## Super Resolution Models

### Imports and Utils

In [1]:
"""
Importing necessary libraries
"""
import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
from sklearn import preprocessing
from sklearn.kernel_approximation import RBFSampler

# Remove all the warnings
import warnings
warnings.filterwarnings('ignore')

# Set env CUDA_LAUNCH_BLOCKING=1
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    from einops import rearrange
except ImportError:
    %pip install einops
    from einops import rearrange

In [None]:
# Create Coordinate Map
def create_coordinate_map(img, scale):
    """
    img: torch.Tensor of shape (num_channels, height, width)
    scale: int, the scale factor for the image
    
    Return: tuple of torch.Tensor of shape (height * width * scale**2, 2) and torch.Tensor of shape (height * width * scale**2, num_channels)
    """
    
    # Upscale the image
    num_channels, height, width = img.shape
    
    # Create a 2D grid of (x,y) coordinates (h, w)
    w_coords = torch.arange(0, width, 1/scale).repeat(int(height*scale), 1)
    h_coords = torch.arange(0, height, 1/scale).repeat(int(width*scale), 1).t()
    w_coords = w_coords.reshape(-1)
    h_coords = h_coords.reshape(-1)

    # Combine the x and y coordinates into a single tensor
    X = torch.stack([h_coords, w_coords], dim=1).float()

    # Reshape the image to (h * w, num_channels)
    Y = torch.tensor(rearrange(img, 'c h w -> (h w) c').float())

    # Move X to GPU if available
    X = X.to(device)
    Y = Y.to(device)
    return X, Y

In [2]:
"""
Making Functions for Plotting and Comparing
"""

def plot_compare_two_images(img1, img2, title1='Image 1', title2='Image 2', main_title='Comparison'):
    """
    Plots a comparison between two images in a subplot.
    """

    fig, ax = plt.subplots(1, 2, figsize=(15, 10))

    ax[0].imshow(rearrange(img1, 'c h w -> h w c').numpy())
    ax[0].set_title(title1)
    ax[0].axis('off')

    ax[1].imshow(rearrange(img2, 'c h w -> h w c').numpy())
    ax[1].set_title(title2)
    ax[1].axis('off')

    fig.suptitle(main_title)
    plt.show()

def plot_compare_three_images(img1, img2, img3, title1='Image 1', title2='Image 2', title3='Image 3', main_title='Comparison'):
    """
    Plots a comparison between two images in a subplot.
    """
    
    fig, ax = plt.subplots(1, 3, figsize=(20, 7))

    ax[0].imshow(rearrange(img1, 'c h w -> h w c').numpy())
    ax[0].set_title(title1)
    ax[0].axis('off')

    ax[1].imshow(rearrange(img2, 'c h w -> h w c').numpy())
    ax[1].set_title(title2)
    ax[1].axis('off')

    ax[2].imshow(rearrange(img3, 'c h w -> h w c').numpy())
    ax[2].set_title(title3)
    ax[2].axis('off')

    fig.suptitle(main_title)
    plt.show()

### Dataset Creation and Preprocessing

In [None]:
# Set the path to the image
path = './Dataset/dog.jpg'

# Load the image
if not os.path.exists(path):
    !wget https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg -O {path}

# Read in a image from torchvision
img = torchvision.io.read_image(path)
print(f"Original Image Shape: {img.shape}")
plt.imshow(rearrange(img, 'c h w -> h w c').numpy())
plt.title('Original Image')
plt.show()

In [None]:
# Normalize the image
scaler_img = preprocessing.MinMaxScaler().fit(img.reshape(-1, 1))
img = torch.tensor(scaler_img.transform(img.reshape(-1, 1)).reshape(img.shape))
img = img.to(device)
print(f'Image shape: {img.shape}')

In [None]:
# Crop the image
img_cropped = torchvision.transforms.functional.crop(img.cpu(), 600, 800, 400, 400)
img_cropped = img_cropped.to(device)
print(f'Image cropped shape: {img_cropped.shape}')
plt.imshow(rearrange(img_cropped, 'c h w -> h w c').cpu().numpy())
plt.title('Cropped Image')
plt.show()

In [12]:
# Create the coordinate map
X, Y = create_coordinate_map(img_cropped, scale=1)
print(f'X shape: {X.shape}, Y shape: {Y.shape}')

# MinMaxScaler
minmax = preprocessing.MinMaxScaler(feature_range=(-1, 1)).fit(X.cpu())
X_scaled = minmax.transform(X.cpu())

# Move the scaled X coordinates to the GPU
X_scaled = torch.tensor(X_scaled).to(device).float()

X_train shape:  torch.Size([256, 42])
X_test shape:  torch.Size([64, 42])
Y_train shape:  torch.Size([256, 4])
Y_test shape:  torch.Size([64, 4])


In [None]:
num_features = 15000
sigma = 0.008

# Transform the X coordinates to the new map space
RFF = RBFSampler(n_components=num_features, gamma=1/(2 * sigma**2))
X_RFF = RFF.fit_transform(X.cpu().numpy())
X_RFF = torch.tensor(X_RFF, dtype=torch.float32).to(device)

### Model Creation and Training

In [3]:
class SuperResModel(nn.Module):
    """
    A Linear Regression Model for Super Resolution
    """
    def __init__(self, in_features, out_features):
        super(SuperResModel, self).__init__()
        """
        
        """
        self.linear = nn.Linear(in_features, out_features)
        
    def forward(self, x):
        return self.linear(x)
    
    def fit(self, X, Y, epochs=10000, lr=0.01):
        """
        X: torch.Tensor of shape (n_samples, n_features)
        Y: torch.Tensor of shape (n_samples, n_channels)
        epochs: int, the number of epochs
        """
        criteria = nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        
        for epoch in range(epochs):
            # Forward pass
            preds = self.forward(X)
            loss = criteria(preds, Y)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Print the loss
            if (epoch+1) % 1000 == 0:
                print(f'Epoch {epoch+1} Loss: {loss.item()}')
                print("\n----------------------------------------------------\n")

    def save_model(self, file_path):
        """
        Save the model to a file.
        """
        torch.save(self.state_dict(), file_path)

    def load_model(self, file_path):
        """
        Load the model from a file.
        """
        self.load_state_dict(torch.load(file_path))

In [26]:
# Initialize the model
model = SuperResModel(X_RFF.shape[1], Y.shape[1]).to(device)

In [27]:
# Train the model
model.fit(X_RFF, Y, epochs=50, lr=0.01)

Epoch 10 Loss: 1.132049322128296

----------------------------------------------------

Epoch 20 Loss: 0.8226261138916016

----------------------------------------------------

Epoch 30 Loss: 0.7468557953834534

----------------------------------------------------

Epoch 40 Loss: 0.7439714074134827

----------------------------------------------------

Epoch 50 Loss: 0.7437458634376526

----------------------------------------------------

Training Accuracy: 100.0%


### Testing and Saving

In [28]:
# Evaluate the model
model.eval()

with torch.no_grad():
    Y_pred = model(X_test)
    Y_pred = torch.argmax(Y_pred, axis=1)

    accuracy_test = (Y_pred == torch.argmax(Y_test, axis=1)).sum().item() / Y_test.shape[0]

print(f'Testing Accuracy: {accuracy_test * 100}%')

Testing Accuracy: 100.0%


In [30]:
# Models directory
directory = "Models"
os.makedirs(directory, exist_ok=True)

# Save the model
file_path = "Models/HGR_Model.pth"
model.save_model(file_path)

In [31]:
# Initialize the model
model = HGRModel(X_train.shape[1], Y_train.shape[1]).to(device)

# Load the model
file_path = "Models/HGR_Model.pth"
model.load_model(file_path)

In [4]:
# Load the model
model_path = "Models/HGR_Model.pth"
model = HGRModel(42, 4)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# Initialize MediaPipe Hands
HandLandmarker = mp_hands.Hands(
    static_image_mode=False,
    max_num_hands=1,
    min_detection_confidence=0.8,
    min_tracking_confidence=0.5
)
classes = {
    0: "Right hand open",
    1: "Left hand open",
    2: "Right hand close",
    3: "Left hand close"
}

# Initialize the webcam
cap = cv2.VideoCapture(0)

with HandLandmarker as landmarker:
    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            print("Ignoring empty camera frame.")
            continue

        # Convert the BGR image to RGB.
        frame = cv2.flip(frame, 1)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Detect the hand landmarks
        frame.flags.writeable = False
        results = landmarker.process(frame)

        # Draw the hand annotations on the image.
        frame.flags.writeable = True
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        key = cv2.waitKey(5) & 0xFF
        if results.multi_hand_landmarks:
            for hand_landmarks in results.multi_hand_landmarks:
                mp_drawings.draw_landmarks(
                    frame,
                    hand_landmarks,
                    mp_hands.HAND_CONNECTIONS,
                    mp_drawings.DrawingSpec(color=(97, 137, 48), thickness=2, circle_radius=4),
                    mp_drawings.DrawingSpec(color=(255, 255, 255), thickness=2, circle_radius=2),
                )

                # Convert the landmarks to a list wrt the image
                landmarks = landmarks_to_list(frame, results.multi_hand_landmarks)

                # Normalize the landmarks
                landmarks = normalize_landmarks(landmarks).reshape(1, -1)
                Y_pred = model(landmarks)
                pred_class = torch.argmax(Y_pred, axis=1).item()

                # Annotate the predicted class on the screen
                cv2.putText(frame, classes[pred_class], (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)

        cv2.imshow('MediaPipe Hands', frame)
        if key == 27: # ESC
            break

cap.release()
cv2.destroyAllWindows()