In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import torch.optim as optim

import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import os
from einops import rearrange
from sklearn import preprocessing

Creating Dataset

In [130]:
# Load the image; the rearrange pattern below treats it as channels-first (c, h, w).
img = torchvision.io.read_image('Datasets/dog.jpg')

# Fit a single min-max scaler over the flattened image so every channel is
# scaled jointly (MinMaxScaler's default output range), then restore the shape.
scaler_img = preprocessing.MinMaxScaler().fit(img.reshape(-1, 1))
img_scaled = scaler_img.transform(img.reshape(-1, 1)).reshape(img.shape)
img_scaled = torch.tensor(img_scaled)

# crop(img, top, left, height, width): a 300x300 window whose corner is at (600, 800).
cropped_img = torchvision.transforms.functional.crop(img_scaled, 600, 800, 300, 300)

# Switch to channels-last for imshow and the reconstruction helpers, and cast
# to float32. `.to(...)` is used instead of torch.tensor(existing_tensor, ...),
# which copy-constructs a tensor and is discouraged by PyTorch.
cropped_img = rearrange(cropped_img, 'c h w -> h w c').to(torch.float32).contiguous()
original_img = cropped_img.clone()

Matrix Factorization

In [3]:
def mask_image_patch(img,x,y,z,patch_size):
    """Return a copy of `img` with a square patch set to NaN.

    img: torch.Tensor indexed as [row, column, channel]; it is not modified.
    x, y: row and column of the patch's first element.
    z: number of leading channels to mask (channels 0 .. z-1).
    patch_size: side length of the square patch.

    return: a clone of `img` where rows x .. x+patch_size-1, columns
    y .. y+patch_size-1 and channels 0 .. z-1 are NaN.
    """
    img_copy = img.clone()
    # One broadcast index assignment instead of patch_size * patch_size * z
    # separate element writes. Explicit index tensors (rather than slices) keep
    # the indexing semantics of the element-wise version: an index past the end
    # is an error instead of being silently clipped.
    rows = torch.arange(x, x + patch_size, device=img.device)[:, None, None]
    cols = torch.arange(y, y + patch_size, device=img.device)[None, :, None]
    chans = torch.arange(z, device=img.device)[None, None, :]
    img_copy[rows, cols, chans] = torch.nan
    return img_copy

def factorize_matrix(A, r, lr=0.01, max_epochs=1000):
    """Fit a rank-`r` factorization A ~= W @ H using only the non-NaN entries of A.

    A: 2-D torch.Tensor of shape (m, n); NaN entries are treated as missing
       and excluded from the loss.
    r: inner dimension of the factorization.
    lr: learning rate passed to Adam.
    max_epochs: number of Adam steps to take.

    return: (W, H, loss) where W has shape (m, r), H has shape (r, n), and
    loss is the norm of the residual over the observed entries, evaluated
    with the returned factors.
    """
    mask = ~torch.isnan(A)
    m, n = A.shape
    # Create the factors on A's device so torch.mm(W, H) can be subtracted from A.
    W = torch.rand(m, r, requires_grad=True, device=A.device)
    H = torch.rand(r, n, requires_grad=True, device=A.device)
    optimizer = optim.Adam([W, H], lr=lr)

    def observed_residual_norm():
        # Indexing with `mask` drops the missing entries, so the NaNs in A
        # never reach the norm.
        return torch.norm((A - torch.mm(W, H))[mask])

    for _ in range(max_epochs):
        loss = observed_residual_norm()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Re-evaluate after the last step: the in-loop value predates the final
    # update and does not exist at all when max_epochs == 0.
    with torch.no_grad():
        loss = observed_residual_norm()
    return W, H, loss

def image_reconstrunction_matrix_factorization(original_img, masked_img, rank=50):
    """Reconstruct a masked image by low-rank matrix factorization and plot it.

    original_img: accepted for signature compatibility; not used here.
    masked_img: torch.Tensor of shape (height, width, channels) with NaN at
        the pixels to reconstruct.
    rank: inner dimension passed to factorize_matrix.

    Draws the masked image next to the reconstruction; returns nothing.
    """
    height, width, channels = masked_img.shape

    # Flatten the channel axis into the columns so the image becomes a 2-D
    # matrix, factorize it, then fold the product back into image shape.
    W, H, _ = factorize_matrix(rearrange(masked_img, 'h w c -> h (w c)'), rank)
    reconstructed_img = torch.mm(W, H).detach().reshape(height, width, channels)

    # plt.subplots(1, 2) returns a 1-D array of axes, so each panel is
    # addressed with a single index.
    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(masked_img.cpu())
    axs[0].set_title("Masked Image")
    axs[1].imshow(reconstructed_img.cpu())
    axs[1].set_title("Reconstructed Image")

Linear Regression + RFF

In [138]:
def plot_reconstructed_and_original_image(original_img, masked_img, net, X, title=""):
    """Plot the model's reconstruction beside the original and masked images.

    original_img: torch.Tensor of shape (height, width, num_channels)
    masked_img: torch.Tensor shown as-is in the third panel
    net: torch.nn.Module; switched to eval mode and evaluated on X
    X: torch.Tensor of per-pixel inputs; net(X) must reshape to
       (height, width, num_channels)
    title: optional figure-level title; nothing is drawn when it is empty
    """
    height, width, num_channels = original_img.shape
    net.eval()
    with torch.no_grad():
        outputs = net(X).reshape(height, width, num_channels)

    fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(6, 4))

    # Move everything to CPU before handing it to matplotlib.
    ax0.imshow(outputs.cpu())
    ax0.set_title("Reconstructed Image")

    ax1.imshow(original_img.cpu())
    ax1.set_title("Original Image")

    ax2.imshow(masked_img.cpu())
    ax2.set_title("Masked Image")

    for a in (ax0, ax1, ax2):
        a.axis("off")

    if title:
        fig.suptitle(title)

def create_rff_features(X, num_features, sigma, random_state=None):
    """Map X to random Fourier features with scikit-learn's RBFSampler.

    X: torch.Tensor of shape (num_samples, num_inputs)
    num_features: number of random features (RBFSampler n_components)
    sigma: kernel bandwidth; converted to gamma = 1 / (2 * sigma**2)
    random_state: forwarded to RBFSampler; pass an int to get the same
        random projection on every call (default None keeps it unseeded)

    return: float32 torch.Tensor of shape (num_samples, num_features)
    """
    from sklearn.kernel_approximation import RBFSampler

    rff = RBFSampler(
        n_components=num_features,
        gamma=1 / (2 * sigma**2),
        random_state=random_state,
    )
    # scikit-learn works on NumPy arrays, so leave autograd and the device first.
    features = rff.fit_transform(X.detach().cpu().numpy())
    return torch.tensor(features, dtype=torch.float32)

def train(net, lr, X, Y, epochs, verbose=True):
    """Fit `net` to (X, Y) with full-batch Adam on a mean-squared-error loss.

    net: torch.nn.Module
    lr: float, Adam learning rate
    X: torch.Tensor of model inputs, one row per training sample
    Y: torch.Tensor of targets matching the shape of net(X)
    epochs: number of optimizer steps
    verbose: when truthy, print the loss before and after training

    return: float, the loss from the last forward pass (the initial loss when
    epochs == 0)
    """
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    # The initial loss is only a diagnostic (and the return value when
    # epochs == 0), so there is no need to build an autograd graph for it.
    with torch.no_grad():
        loss = criterion(net(X), Y)
    if verbose:
        print(f"Initial loss: {loss.item():.6f}")

    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(net(X), Y)
        loss.backward()
        optimizer.step()

    if verbose:
        print(f"Final loss: {loss.item():.6f}")
    return loss.item()

def mask_image_patch(img,x,y,z,patch_size):
    """Return a copy of `img` with a square patch set to NaN.

    NOTE: this notebook also defines a function with the same name and
    signature in the Matrix Factorization cell; this later definition is the
    one in effect once this cell has run.

    img: torch.Tensor indexed as [row, column, channel]; it is not modified.
    x, y: row and column of the patch's first element.
    z: number of leading channels to mask (channels 0 .. z-1).
    patch_size: side length of the square patch.

    return: a clone of `img` where rows x .. x+patch_size-1, columns
    y .. y+patch_size-1 and channels 0 .. z-1 are NaN.
    """
    img_copy = img.clone()
    # One broadcast index assignment instead of patch_size * patch_size * z
    # separate element writes. Explicit index tensors (rather than slices) keep
    # the indexing semantics of the element-wise version: an index past the end
    # is an error instead of being silently clipped.
    rows = torch.arange(x, x + patch_size, device=img.device)[:, None, None]
    cols = torch.arange(y, y + patch_size, device=img.device)[None, :, None]
    chans = torch.arange(z, device=img.device)[None, None, :]
    img_copy[rows, cols, chans] = torch.nan
    return img_copy

def create_mask(t,x,y,patch_size):
    """Build a boolean mask shaped like `t` that is False inside a square patch.

    t: torch.Tensor indexed as [row, column, channel]; only its shape is used.
    x, y: row and column of the patch's first element.
    patch_size: side length of the square patch.

    return: bool torch.Tensor of shape t.shape, False for rows
    x .. x+patch_size-1 and columns y .. y+patch_size-1 across every channel,
    True everywhere else.
    """
    mask = torch.ones(t.shape, dtype=torch.bool)
    # Single broadcast assignment over the whole patch; index tensors keep the
    # same indexing semantics as writing each element individually.
    rows = torch.arange(x, x + patch_size)[:, None, None]
    cols = torch.arange(y, y + patch_size)[None, :, None]
    chans = torch.arange(t.shape[2])[None, None, :]
    mask[rows, cols, chans] = False
    return mask

class LinearModel(nn.Module):
    """A single affine layer mapping `in_features` inputs to `out_features` outputs."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the affine layer to `x` and return the result."""
        projected = self.linear(x)
        return projected
    
def create_coordinate_map(img):
    """Build the (row, column) coordinate of every pixel in row-major order.

    img: torch.Tensor whose first two dimensions are (height, width); only
         its shape is used.

    return: torch.Tensor of shape (height * width, 2) in the default float
    dtype, where row i * width + j holds (i, j).
    """
    height, width = img.shape[0], img.shape[1]
    dtype = torch.get_default_dtype()
    # Row index repeats `width` times per row; column index cycles 0..width-1.
    rows = torch.arange(height, dtype=dtype).repeat_interleave(width)
    cols = torch.arange(width, dtype=dtype).repeat(height)
    return torch.stack((rows, cols), dim=1)

def stack_itself(t, n):
    """Repeat each row of a 2-D tensor `n` times along a new middle dimension.

    t: torch.Tensor of shape (rows, features)
    n: number of repetitions

    return: tensor of shape (rows, n, features), or (rows, n) when
    features == 1.
    """
    stacked_t = t.unsqueeze(1).expand(-1, n, -1)
    # Squeeze only the trailing feature dimension. A bare squeeze() would also
    # drop the row or repeat dimension whenever it happens to have size 1.
    return stacked_t.squeeze(-1)

def scale(img):
    """Min-max scale each column of `img` to [-1, 1] and return a float32 tensor."""
    rescaler = preprocessing.MinMaxScaler(feature_range=(-1, 1))
    rescaler.fit(img)
    rescaled = rescaler.transform(img)
    return torch.tensor(rescaled).float()

def image_reconstrunction_linear_rff(original_img, masked_img, num_features=100, sigma=0.008, lr=0.005, epochs=1000):
    """Reconstruct a masked image with a linear model on random Fourier features.

    original_img: torch.Tensor of shape (height, width, num_channels); the
        source of the training targets.
    masked_img: torch.Tensor of the same shape with NaN at the pixels to hide;
        a pixel is used for training only if none of its channels is NaN.
    num_features: number of random Fourier features.
    sigma: bandwidth passed to create_rff_features.
    lr: learning rate passed to train.
    epochs: number of training steps passed to train.

    Trains on the visible pixels, then plots the prediction for every pixel
    next to the original and masked images; returns nothing.
    """
    num_channels = original_img.shape[-1]
    y = original_img.reshape(-1, num_channels)
    X = create_coordinate_map(original_img)

    # One flag per pixel: True when the pixel is fully visible. Indexing rows
    # with it keeps features and targets aligned pixel-for-pixel.
    hidden = torch.isnan(masked_img).reshape(-1, num_channels).any(dim=1)
    known = ~hidden

    # NOTE(review): X holds raw pixel indices. The default sigma looks tuned
    # for coordinates scaled to [-1, 1] (see the unused scale() helper) --
    # confirm whether X should be scaled before building the features.
    X_rff = create_rff_features(X, num_features, sigma)
    X_rff_train = X_rff[known]
    y_rff_train = y[known]

    netrff = LinearModel(X_rff_train.shape[1], num_channels)
    train(netrff, lr, X_rff_train, y_rff_train, epochs)
    plot_reconstructed_and_original_image(original_img, masked_img, netrff, X_rff, title="Reconstructed Image with RFF Features")

Reconstructions

In [None]:
# Mask a fixed-size patch at several positions and reconstruct each masked
# image with matrix factorization.
patch_size = 30
x_y_s = [[10,10,3],[90,150,3],[140,60,3]]

# Each entry starts with the (row, column) of the patch corner; pass those two
# scalars (not whole list entries) and forward both images to the
# reconstruction routine.
for x, y, *_ in x_y_s:
    masked_img = mask_image_patch(original_img, x, y, 3, patch_size)
    image_reconstrunction_matrix_factorization(original_img, masked_img)

In [None]:
# Keep the patch corner fixed at (50, 50) and grow the patch, reconstructing
# each masked image with matrix factorization.
x_start, y_start = 50, 50
patch_sizes = [20, 40, 60, 80, 100]
for size in patch_sizes:
    masked_img = mask_image_patch(original_img, x_start, y_start, 3, size)
    image_reconstrunction_matrix_factorization(original_img, masked_img)

In [None]:
# Keep the patch corner fixed at (50, 50) and grow the patch, reconstructing
# each masked image with the linear model on random Fourier features.
x_start, y_start = 50, 50
patch_sizes = [20, 40, 60, 80, 100]
for size in patch_sizes:
    masked_img = mask_image_patch(original_img, x_start, y_start, 3, size)
    image_reconstrunction_linear_rff(original_img, masked_img)

In [None]:
# Mask a fixed-size patch at several positions and reconstruct each masked
# image with the linear model on random Fourier features.
patch_size = 30
x_y_s = [[10,10],[90,150],[140,60]]
# Unpack each (row, column) pair so scalar coordinates are passed, rather than
# whole entries of the list.
for x, y in x_y_s:
    masked_img = mask_image_patch(original_img, x, y, 3, patch_size)
    image_reconstrunction_linear_rff(original_img, masked_img)