In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

# Silence all Python warnings for the rest of the notebook
import warnings
warnings.filterwarnings('ignore')

# `os` is used by a later cell to check whether the sample image is already on disk.
import os

# Pick the compute device. GPU index 2 is preferred (the original choice), but
# 'cuda:2' is not a valid device on hosts exposing fewer than three GPUs, so fall
# back to the first GPU there, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:2' if torch.cuda.device_count() > 2 else 'cuda:0')
else:
    device = torch.device('cpu')

# Ask the inline matplotlib backend to render figures in 'retina' format
%config InlineBackend.figure_format = 'retina'

# `rearrange` is used below to reorder tensor axes. If einops is not installed,
# install it into the kernel's environment with %pip and retry the import.
try:
    from einops import rearrange
except ImportError:
    %pip install einops
    from einops import rearrange

In [None]:
if os.path.exists('dog.jpg'):
    print('dog.jpg exists')
else:
    !wget https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg -O dog.jpg


In [None]:
from sklearn import preprocessing
# Read the image file into a tensor; later cells treat it as (channels, height, width).
img = torchvision.io.read_image("dog.jpg")

# Fit one min-max scaler over every pixel value of the image (all channels flattened
# into a single column), so the whole image shares the same scaling.
scaler_img = preprocessing.MinMaxScaler().fit(img.reshape(-1, 1))
# Display the fitted scaler as the cell output.
scaler_img

In [None]:
# Apply the fitted pixel scaler to the flattened image, restore the original
# image shape, wrap the result in a tensor and place it on the compute device.
img_scaled = torch.tensor(
    scaler_img.transform(img.reshape(-1, 1)).reshape(img.shape)
).to(device)

In [None]:
# Cut a 300x300 patch whose top-left corner is at row 600, column 800 of the scaled image.
crop = transforms.functional.crop(img_scaled.cpu(), top=600, left=800, height=300, width=300)
crop.shape

In [None]:
# matplotlib expects channels last, so move the channel axis to the end before plotting.
plt.imshow(crop.permute(1, 2, 0).cpu().numpy())
print(crop.shape)

In [None]:
# Move the cropped patch to the selected compute device for the cells below.
crop = crop.to(device)

In [None]:
# Unpack the crop's (channels, height, width) sizes; later cells reuse these globals.
num_channels, height, width = crop.shape

In [None]:
def create_coordinate_map(img, scale=1):
    """
    Build per-pixel (row, column) coordinates and the matching pixel values.

    img: torch.Tensor of shape (num_channels, height, width)
    scale: sampling density of the coordinate grid; coordinates are spaced 1/scale
           apart, giving int(height*scale) * int(width*scale) coordinate rows

    return: tuple of
        X: torch.Tensor of shape (int(height*scale) * int(width*scale), 2) holding
           (h, w) coordinates in row-major order (w changes fastest), placed on the
           notebook-level `device`
        Y: torch.Tensor of shape (height * width, num_channels) with the pixel values
           of `img` in the same row-major order (independent of `scale`)
    """

    num_channels, height, width = img.shape

    # Count the grid points with integers and divide afterwards. Building the axes
    # with a fractional arange step (e.g. 1/3) is subject to floating point rounding
    # and can yield a length that disagrees with int(height*scale), which would make
    # the two coordinate columns differ in length.
    n_rows = int(height * scale)
    n_cols = int(width * scale)
    h_axis = torch.arange(n_rows) / scale
    w_axis = torch.arange(n_cols) / scale

    # Row-major flattening of the (n_rows, n_cols) grid: each h value is held for a
    # full row while the w values cycle.
    h_coords = h_axis.repeat_interleave(n_cols)
    w_coords = w_axis.repeat(n_rows)

    # Combine the h and w coordinates into a single (N, 2) tensor on the compute device
    X = torch.stack([h_coords, w_coords], dim=1).float()
    X = X.to(device)

    # Reshape the image to (h * w, num_channels)
    Y = rearrange(img, 'c h w -> (h w) c').float()
    return X, Y

In [None]:
# Build the coordinate inputs (dog_X) and per-pixel channel targets (dog_Y) for the crop.
dog_X, dog_Y = create_coordinate_map(crop, scale=1)

# Display both shapes as the cell output.
dog_X.shape, dog_Y.shape

In [None]:
# MinMaxScaler from -1 to 1
scaler_X = preprocessing.MinMaxScaler(feature_range=(-1, 1)).fit(dog_X.cpu())

# Scale the X coordinates
dog_X_scaled = scaler_X.transform(dog_X.cpu())

# Move the scaled X coordinates to the GPU
dog_X_scaled = torch.tensor(dog_X_scaled).to(device)

# Set to dtype float32
dog_X_scaled = dog_X_scaled.float()

In [None]:
class LinearModel(nn.Module):
    """A single fully connected layer mapping `in_features` inputs to `out_features` outputs."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # The attribute name `linear` determines the state_dict keys.
        self.linear = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x):
        """Apply the affine layer to `x` and return the result."""
        projected = self.linear(x)
        return projected


In [None]:
def train(net, lr, X, Y, epochs, verbose=True, stopping_criteria=0.00001):
    """
    Fit `net` to (X, Y) with full-batch Adam on the mean squared error.

    net: torch.nn.Module
    lr: float, Adam learning rate
    X: torch.Tensor of shape (num_samples, num_features)
    Y: torch.Tensor of shape (num_samples, num_outputs)
    epochs: int, maximum number of optimisation steps; when <= 0 no step is taken
            and the current loss of `net` is returned
    verbose: bool, print the loss every 100 epochs
    stopping_criteria: float, stop as soon as the loss improves by no more than
                       this amount between two consecutive epochs (this also
                       stops on any increase of the loss)

    return: float, the loss computed in the last executed epoch (before that
            epoch's optimiser step)
    """
    import math

    criterion = nn.MSELoss()

    # Nothing to train: report the loss of the model as it is instead of failing
    # on an unbound `loss` variable after an empty loop.
    if epochs <= 0:
        with torch.no_grad():
            return criterion(net(X), Y).item()

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    prev_loss = float('inf')
    current_loss = prev_loss

    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = net(X)

        loss = criterion(outputs, Y)
        loss.backward()
        optimizer.step()

        # Read the scalar once per epoch; every .item() call forces a device sync.
        current_loss = loss.item()

        if verbose and epoch % 100 == 0:
            print(f"Epoch {epoch} loss: {current_loss:.6f}")

        # A NaN loss (e.g. an empty training set) makes every comparison below
        # False, so without this guard the loop would spin for all epochs.
        if math.isnan(current_loss):
            break

        if prev_loss - current_loss <= stopping_criteria:
            break

        prev_loss = current_loss

    return current_loss


In [None]:
def plot_reconstructed_and_original_image(gt, original_img, net, X, title=""):
    """
    Show the network's reconstruction next to the input image and the ground truth.

    gt: torch.Tensor of shape (num_channels, height, width), drawn in the "GT Image" panel
    original_img: torch.Tensor of shape (num_channels, height, width), drawn in the
                  "Original Image" panel; its shape also fixes how the network output
                  is folded back into an image
    net: torch.nn.Module evaluated on X; it is switched to eval mode and left there
    X: torch.Tensor with one row per pixel (height * width rows, row-major order)
    title: str, figure-level title

    return: the raw network output net(X), before reshaping
    """
    num_channels, height, width = original_img.shape

    net.eval()
    with torch.no_grad():
        outputs = net(X)
    # Rows of `outputs` are pixels in row-major order with channels last.
    outputs_reshaped = outputs.reshape(height, width, num_channels)

    fig = plt.figure(figsize=(6, 4))
    gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 1])

    ax0 = fig.add_subplot(gs[0])
    ax1 = fig.add_subplot(gs[1])
    ax2 = fig.add_subplot(gs[2])

    ax0.imshow(outputs_reshaped.cpu())
    ax0.set_title("Reconstructed Image")

    # The image tensors are channel-first; imshow needs channels last.
    ax1.imshow(original_img.cpu().permute(1, 2, 0))
    ax1.set_title("Original Image")

    ax2.imshow(gt.cpu().permute(1, 2, 0))
    ax2.set_title("GT Image")

    # Hide the axes of all three panels (the third one used to keep its ticks).
    for a in (ax0, ax1, ax2):
        a.axis("off")

    fig.suptitle(title, y=0.9)
    plt.tight_layout()
    return outputs

In [None]:
def create_rff_features(X, num_features, sigma, random_state=None):
    """
    Map X to random Fourier features approximating an RBF kernel.

    X: torch.Tensor of shape (num_samples, num_inputs)
    num_features: int, number of random features to generate
    sigma: float, kernel bandwidth; passed to the sampler as gamma = 1 / (2 * sigma**2)
    random_state: optional seed (or numpy RandomState) for the random projection.
                  With the default None every call draws a different projection, so
                  features from separate calls are not comparable with each other.

    return: float32 torch.Tensor of shape (num_samples, num_features) on the CPU
    """
    from sklearn.kernel_approximation import RBFSampler

    # sklearn works on numpy arrays; detach so tensors that track gradients also convert.
    X_numpy = X.detach().cpu().numpy()
    rff = RBFSampler(
        n_components=num_features,
        gamma=1 / (2 * sigma ** 2),
        random_state=random_state,
    )
    X_transformed = rff.fit_transform(X_numpy)
    return torch.tensor(X_transformed, dtype=torch.float32)

In [None]:
def calculate_rmse(predicted, ground_truth):
    """Return the root mean squared error between two tensors as a Python float."""
    mean_squared_error = F.mse_loss(predicted, ground_truth)
    return torch.sqrt(mean_squared_error).item()


In [None]:
def calculate_psnr(predicted, ground_truth, max_val=255):
    """
    Return the peak signal-to-noise ratio in dB as a Python float.

    max_val is the peak value of the signal and must match the range of the data
    passed in (the default of 255 suits 8-bit pixel values).
    """
    rmse = torch.sqrt(F.mse_loss(predicted, ground_truth))
    return (20 * torch.log10(max_val / rmse)).item()

In [None]:
# Baseline: no pixels withheld. p_value is the fraction of pixels dropped from
# training; the previous value of 1 produced an all-False mask, i.e. an empty
# training set. `>=` guarantees that every pixel is kept when p_value is 0.
p_value = 0
mask = torch.rand(height*width) >= p_value

# Random Fourier features for every pixel coordinate; only the kept pixels are used for training.
X_rff = create_rff_features(dog_X_scaled, 37500, 0.008)
X_rff = X_rff.to(device)
X_rff_mask = X_rff[mask]

dog_Y_mask = dog_Y[mask]
dog_Y_mask = dog_Y_mask.to(device)

net = LinearModel(X_rff_mask.shape[1], num_channels)
net.to(device)

train(net, 0.005, X_rff_mask, dog_Y_mask, 5000)

crop = crop.to(device)
mask = mask.to(device)

# The network is evaluated on all pixels; the masked crop shows what it was trained on.
outputs = plot_reconstructed_and_original_image( crop, crop*mask.reshape(height, width), net, X_rff, title="Reconstructed Image with RFF Features")

# dog_Y already holds the crop as (h*w, c) rows in the same order as `outputs`.
# crop.reshape(h*w, c) would only reinterpret the channel-first buffer and scramble
# the pixels. The image was min-max scaled, so the PSNR peak value is 1.0, not 255.
rmse_value = calculate_rmse(outputs, dog_Y.to(outputs.device))
psnr_value = calculate_psnr(outputs, dog_Y.to(outputs.device), max_val=1.0)

print(f"RMSE: {rmse_value}")
print(f"PSNR: {psnr_value}")

### 10% data missing

In [None]:
# Fraction of pixels withheld from training; True in `mask` marks a kept pixel.
p_value = 0.1
mask = torch.rand(height*width) > p_value

# Random Fourier features for every pixel coordinate; only the kept pixels are used for training.
X_rff = create_rff_features(dog_X_scaled, 37500, 0.008)
X_rff = X_rff.to(device)
X_rff_mask = X_rff[mask]

dog_Y_mask = dog_Y[mask]
dog_Y_mask = dog_Y_mask.to(device)

net = LinearModel(X_rff_mask.shape[1], num_channels)
net.to(device)

train(net, 0.005, X_rff_mask, dog_Y_mask, 5000)

crop = crop.to(device)
mask = mask.to(device)

# The network is evaluated on all pixels; the masked crop shows what it was trained on.
outputs = plot_reconstructed_and_original_image( crop, crop*mask.reshape(height, width), net, X_rff, title="Reconstructed Image with RFF Features with 10% data missing")

# dog_Y already holds the crop as (h*w, c) rows in the same order as `outputs`.
# crop.reshape(h*w, c) would only reinterpret the channel-first buffer and scramble
# the pixels. The image was min-max scaled, so the PSNR peak value is 1.0, not 255.
rmse_value = calculate_rmse(outputs, dog_Y.to(outputs.device))
psnr_value = calculate_psnr(outputs, dog_Y.to(outputs.device), max_val=1.0)

print(f"RMSE: {rmse_value}")
print(f"PSNR: {psnr_value}")

### 20% data missing

In [None]:
# Fraction of pixels withheld from training; True in `mask` marks a kept pixel.
p_value = 0.2
mask = torch.rand(height*width) > p_value

# Random Fourier features for every pixel coordinate.
X_rff = create_rff_features(dog_X_scaled, 37500, 0.008)
X_rff = X_rff.to(device)
# Remove elements where the mask is False
X_rff_mask = X_rff[mask]

dog_Y_mask = dog_Y[mask]
dog_Y_mask = dog_Y_mask.to(device)

net = LinearModel(X_rff_mask.shape[1], num_channels)
net.to(device)

train(net, 0.005, X_rff_mask, dog_Y_mask, 5000)

crop = crop.to(device)
mask = mask.to(device)

# The network is evaluated on all pixels; the masked crop shows what it was trained on.
outputs = plot_reconstructed_and_original_image( crop, crop*mask.reshape(height, width), net, X_rff, title="Reconstructed Image with RFF Features with 20% data missing")

# dog_Y already holds the crop as (h*w, c) rows in the same order as `outputs`.
# crop.reshape(h*w, c) would only reinterpret the channel-first buffer and scramble
# the pixels. The image was min-max scaled, so the PSNR peak value is 1.0, not 255.
rmse_value = calculate_rmse(outputs, dog_Y.to(outputs.device))
psnr_value = calculate_psnr(outputs, dog_Y.to(outputs.device), max_val=1.0)

print(f"RMSE: {rmse_value}")
print(f"PSNR: {psnr_value}")

### 30% data missing

In [None]:
# Fraction of pixels withheld from training; True in `mask` marks a kept pixel.
p_value = 0.3
mask = torch.rand(height*width) > p_value

# Random Fourier features for every pixel coordinate; only the kept pixels are used for training.
X_rff = create_rff_features(dog_X_scaled, 37500, 0.008)
X_rff = X_rff.to(device)
X_rff_mask = X_rff[mask]

dog_Y_mask = dog_Y[mask]
dog_Y_mask = dog_Y_mask.to(device)

net = LinearModel(X_rff_mask.shape[1], num_channels)
net.to(device)

train(net, 0.005, X_rff_mask, dog_Y_mask, 5000)

crop = crop.to(device)
mask = mask.to(device)

# The network is evaluated on all pixels; the masked crop shows what it was trained on.
outputs = plot_reconstructed_and_original_image( crop, crop*mask.reshape(height, width), net, X_rff, title="Reconstructed Image with RFF Features with 30% data missing")

# dog_Y already holds the crop as (h*w, c) rows in the same order as `outputs`.
# crop.reshape(h*w, c) would only reinterpret the channel-first buffer and scramble
# the pixels. The image was min-max scaled, so the PSNR peak value is 1.0, not 255.
rmse_value = calculate_rmse(outputs, dog_Y.to(outputs.device))
psnr_value = calculate_psnr(outputs, dog_Y.to(outputs.device), max_val=1.0)

print(f"RMSE: {rmse_value}")
print(f"PSNR: {psnr_value}")

### 40% data missing

In [None]:
# Fraction of pixels withheld from training; True in `mask` marks a kept pixel.
p_value = 0.4
mask = torch.rand(height*width) > p_value

# Random Fourier features for every pixel coordinate; only the kept pixels are used for training.
X_rff = create_rff_features(dog_X_scaled, 37500, 0.008)
X_rff = X_rff.to(device)
X_rff_mask = X_rff[mask]

dog_Y_mask = dog_Y[mask]
dog_Y_mask = dog_Y_mask.to(device)

net = LinearModel(X_rff_mask.shape[1], num_channels)
net.to(device)

train(net, 0.005, X_rff_mask, dog_Y_mask, 5000)

crop = crop.to(device)
mask = mask.to(device)

# The network is evaluated on all pixels; the masked crop shows what it was trained on.
outputs = plot_reconstructed_and_original_image( crop, crop*mask.reshape(height, width), net, X_rff, title="Reconstructed Image with RFF Features with 40% data missing")

# dog_Y already holds the crop as (h*w, c) rows in the same order as `outputs`.
# crop.reshape(h*w, c) would only reinterpret the channel-first buffer and scramble
# the pixels. The image was min-max scaled, so the PSNR peak value is 1.0, not 255.
rmse_value = calculate_rmse(outputs, dog_Y.to(outputs.device))
psnr_value = calculate_psnr(outputs, dog_Y.to(outputs.device), max_val=1.0)

print(f"RMSE: {rmse_value}")
print(f"PSNR: {psnr_value}")

### 50% data missing

In [None]:
# Fraction of pixels withheld from training; True in `mask` marks a kept pixel.
p_value = 0.5
mask = torch.rand(height*width) > p_value

# Random Fourier features for every pixel coordinate; only the kept pixels are used for training.
X_rff = create_rff_features(dog_X_scaled, 37500, 0.008)
X_rff = X_rff.to(device)
X_rff_mask = X_rff[mask]

dog_Y_mask = dog_Y[mask]
dog_Y_mask = dog_Y_mask.to(device)

net = LinearModel(X_rff_mask.shape[1], num_channels)
net.to(device)

train(net, 0.005, X_rff_mask, dog_Y_mask, 5000)

crop = crop.to(device)
mask = mask.to(device)

# The network is evaluated on all pixels; the masked crop shows what it was trained on.
outputs = plot_reconstructed_and_original_image( crop, crop*mask.reshape(height, width), net, X_rff, title="Reconstructed Image with RFF Features with 50% data missing")

# dog_Y already holds the crop as (h*w, c) rows in the same order as `outputs`.
# crop.reshape(h*w, c) would only reinterpret the channel-first buffer and scramble
# the pixels. The image was min-max scaled, so the PSNR peak value is 1.0, not 255.
rmse_value = calculate_rmse(outputs, dog_Y.to(outputs.device))
psnr_value = calculate_psnr(outputs, dog_Y.to(outputs.device), max_val=1.0)

print(f"RMSE: {rmse_value}")
print(f"PSNR: {psnr_value}")

### 60% data missing

In [None]:
# Fraction of pixels withheld from training; True in `mask` marks a kept pixel.
p_value = 0.6
mask = torch.rand(height*width) > p_value

# Random Fourier features for every pixel coordinate; only the kept pixels are used for training.
X_rff = create_rff_features(dog_X_scaled, 37500, 0.008)
X_rff = X_rff.to(device)
X_rff_mask = X_rff[mask]

dog_Y_mask = dog_Y[mask]
dog_Y_mask = dog_Y_mask.to(device)

net = LinearModel(X_rff_mask.shape[1], num_channels)
net.to(device)

train(net, 0.005, X_rff_mask, dog_Y_mask, 5000)

crop = crop.to(device)
mask = mask.to(device)

# The network is evaluated on all pixels; the masked crop shows what it was trained on.
outputs = plot_reconstructed_and_original_image( crop, crop*mask.reshape(height, width), net, X_rff, title="Reconstructed Image with RFF Features with 60% data missing")

# dog_Y already holds the crop as (h*w, c) rows in the same order as `outputs`.
# crop.reshape(h*w, c) would only reinterpret the channel-first buffer and scramble
# the pixels. The image was min-max scaled, so the PSNR peak value is 1.0, not 255.
rmse_value = calculate_rmse(outputs, dog_Y.to(outputs.device))
psnr_value = calculate_psnr(outputs, dog_Y.to(outputs.device), max_val=1.0)

print(f"RMSE: {rmse_value}")
print(f"PSNR: {psnr_value}")

### 70% data missing

In [None]:
# Fraction of pixels withheld from training; True in `mask` marks a kept pixel.
p_value = 0.7
mask = torch.rand(height*width) > p_value

# Random Fourier features for every pixel coordinate; only the kept pixels are used for training.
X_rff = create_rff_features(dog_X_scaled, 37500, 0.008)
X_rff = X_rff.to(device)
X_rff_mask = X_rff[mask]

dog_Y_mask = dog_Y[mask]
dog_Y_mask = dog_Y_mask.to(device)

net = LinearModel(X_rff_mask.shape[1], num_channels)
net.to(device)

train(net, 0.005, X_rff_mask, dog_Y_mask, 5000)

crop = crop.to(device)
mask = mask.to(device)

# The network is evaluated on all pixels; the masked crop shows what it was trained on.
outputs = plot_reconstructed_and_original_image( crop, crop*mask.reshape(height, width), net, X_rff, title="Reconstructed Image with RFF Features with 70% data missing")

# dog_Y already holds the crop as (h*w, c) rows in the same order as `outputs`.
# crop.reshape(h*w, c) would only reinterpret the channel-first buffer and scramble
# the pixels. The image was min-max scaled, so the PSNR peak value is 1.0, not 255.
rmse_value = calculate_rmse(outputs, dog_Y.to(outputs.device))
psnr_value = calculate_psnr(outputs, dog_Y.to(outputs.device), max_val=1.0)

print(f"RMSE: {rmse_value}")
print(f"PSNR: {psnr_value}")

### 80% data missing

In [None]:
# Fraction of pixels withheld from training; True in `mask` marks a kept pixel.
p_value = 0.8
mask = torch.rand(height*width) > p_value

# Random Fourier features for every pixel coordinate; only the kept pixels are used for training.
X_rff = create_rff_features(dog_X_scaled, 37500, 0.008)
X_rff = X_rff.to(device)
X_rff_mask = X_rff[mask]

dog_Y_mask = dog_Y[mask]
dog_Y_mask = dog_Y_mask.to(device)

net = LinearModel(X_rff_mask.shape[1], num_channels)
net.to(device)

train(net, 0.005, X_rff_mask, dog_Y_mask, 5000)

crop = crop.to(device)
mask = mask.to(device)

# The network is evaluated on all pixels; the masked crop shows what it was trained on.
outputs = plot_reconstructed_and_original_image( crop, crop*mask.reshape(height, width), net, X_rff, title="Reconstructed Image with RFF Features with 80% data missing")

# dog_Y already holds the crop as (h*w, c) rows in the same order as `outputs`.
# crop.reshape(h*w, c) would only reinterpret the channel-first buffer and scramble
# the pixels. The image was min-max scaled, so the PSNR peak value is 1.0, not 255.
rmse_value = calculate_rmse(outputs, dog_Y.to(outputs.device))
psnr_value = calculate_psnr(outputs, dog_Y.to(outputs.device), max_val=1.0)

print(f"RMSE: {rmse_value}")
print(f"PSNR: {psnr_value}")

### 90% data missing

In [None]:
# Fraction of pixels withheld from training; True in `mask` marks a kept pixel.
p_value = 0.9
mask = torch.rand(height*width) > p_value

# Random Fourier features for every pixel coordinate; only the kept pixels are used for training.
X_rff = create_rff_features(dog_X_scaled, 37500, 0.008)
X_rff = X_rff.to(device)
X_rff_mask = X_rff[mask]

dog_Y_mask = dog_Y[mask]
dog_Y_mask = dog_Y_mask.to(device)

net = LinearModel(X_rff_mask.shape[1], num_channels)
net.to(device)

train(net, 0.005, X_rff_mask, dog_Y_mask, 5000)

crop = crop.to(device)
mask = mask.to(device)

# The network is evaluated on all pixels; the masked crop shows what it was trained on.
outputs = plot_reconstructed_and_original_image( crop, crop*mask.reshape(height, width), net, X_rff, title="Reconstructed Image with RFF Features with 90% data missing")

# dog_Y already holds the crop as (h*w, c) rows in the same order as `outputs`.
# crop.reshape(h*w, c) would only reinterpret the channel-first buffer and scramble
# the pixels. The image was min-max scaled, so the PSNR peak value is 1.0, not 255.
rmse_value = calculate_rmse(outputs, dog_Y.to(outputs.device))
psnr_value = calculate_psnr(outputs, dog_Y.to(outputs.device), max_val=1.0)

print(f"RMSE: {rmse_value}")
print(f"PSNR: {psnr_value}")