Image Reconstruction

Choose any image you like. Use Random Fourier Features (RFF) and Linear Regression to learn the mapping from the image coordinates (X, Y) to the pixel colors (R, G, B). Here, (X, Y) represents the coordinates of the pixels, and (R, G, B) represents the color values at those coordinates.

1. **Load Image**: Select any image of your choice.
2. **Random Fourier Features (RFF)**: Implement RFF to transform the pixel coordinates into a higher-dimensional feature space.
3. **Linear Regression**: Use linear regression to learn the mapping.
4. **Display Results**: Show both the original and reconstructed images.
5. **Metrics**: Calculate the Root Mean Squared Error (RMSE) and Peak Signal-to-Noise Ratio (PSNR) between the original and reconstructed images.

**Key Variables**:
- X, Y: Pixel coordinates.
- R, G, B: Pixel color values.


#### Importing necessary libraries

In [None]:
import torch
print(torch.__version__)

import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

# Remove all the warnings
import warnings
warnings.filterwarnings('ignore')

# Set env CUDA_LAUNCH_BLOCKING=1 before any CUDA work is started in this notebook
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Every tensor and model below is moved to this device: the GPU when one is
# available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Retina display
%config InlineBackend.figure_format = 'retina'

try:
    from einops import rearrange
except ImportError:
    %pip install einops

    from einops import rearrange

In [None]:
import torch

# Report the GPU only when one is present: get_device_name(0) raises on a
# CPU-only machine, while the rest of the notebook falls back to the CPU.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device found; running on CPU")

#### Importing image

In [None]:
import os

# Define the directory path (relative to the notebook's working directory)
directory_path = '../assets/images/'

# Check if the directory exists, and create it if it does not
if not os.path.exists(directory_path):
    os.makedirs(directory_path)

# Check if the file exists, so the image is only downloaded once
if os.path.exists(os.path.join(directory_path, 'dog.jpg')):
    print('dog.jpg exists')
else:
    # Download the file if it does not exist.
    # NOTE(review): the -O target repeats the literal path instead of reusing
    # directory_path; keep the two in sync if the directory ever changes.
    !wget https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg -O ../assets/images/dog.jpg


In [None]:
# Read in an image with torchvision
img = torchvision.io.read_image("../assets/images/dog.jpg")
# The shape is treated as (channels, height, width) by the cells below
print(img.shape)

In [None]:
# Move the channel axis last ('c h w -> h w c') before handing the image to imshow
plt.imshow(rearrange(img, 'c h w -> h w c').numpy())
plt.axis('off')
plt.show()

In [None]:
# # Read in a image from torchvision
# iitgn_img = torchvision.io.read_image("../assets/images/iitgn.jpg")
# print(iitgn_img.shape)
# plt.imshow(rearrange(iitgn_img, 'c h w -> h w c').numpy())


In [None]:
from sklearn import preprocessing

# Fit one min-max scaler on all pixel values at once: the image is flattened
# to a single column, so every channel shares the same scaling.
scaler_img = preprocessing.MinMaxScaler().fit(img.reshape(-1, 1))
scaler_img

In [None]:

# Scale the flattened pixels, then restore the original image shape
img_scaled = scaler_img.transform(img.reshape(-1, 1)).reshape(img.shape)
print(img_scaled.shape)

# Wrap the scaled result in a torch tensor again
img_scaled = torch.tensor(img_scaled)


In [None]:
# Move the scaled image to the selected device and display it
img_scaled = img_scaled.to(device)
img_scaled

#### Crop the image

In [None]:
# Cut a 300x300 patch out of the scaled image. The positional arguments follow
# torchvision's crop signature (top, left, height, width).
crop = torchvision.transforms.functional.crop(img_scaled.cpu(), 600, 800, 300, 300)
crop.shape

In [None]:
# Show the cropped patch (channel axis moved last for imshow)
plt.imshow(rearrange(crop, 'c h w -> h w c').cpu().numpy())
plt.axis('off')
plt.show()

In [None]:
# Keep the crop on the same device as the coordinates and the models
crop = crop.to(device)

#### Create a coordinate map

In [None]:
def create_coordinate_map(img):
    """
    Pair every pixel position of an image with its channel values.

    img: torch.Tensor of shape (num_channels, height, width)

    return: tuple (X, Y)
        X: torch.Tensor of shape (height * width, 2), float32, moved to the
           global `device`; each row is (row index, column index), with the
           column index varying fastest.
        Y: torch.Tensor of shape (height * width, num_channels), float32,
           left on the device `img` lives on.
    """
    num_channels, height, width = img.shape
    print("Number of channels:", num_channels, "\nHeight:", height, "\nWidth:", width)

    # Row index: each value is repeated `width` times (0, 0, ..., 1, 1, ...)
    row_index = torch.arange(height).repeat_interleave(width)
    # Column index: the whole 0..width-1 range is tiled `height` times
    col_index = torch.arange(width).repeat(height)

    # One (row, column) pair per pixel, in row-major order
    X = torch.stack((row_index, col_index), dim=1).float().to(device)
    print("X shape:", X.shape)

    # Flatten the pixels in the same row-major order: (c, h, w) -> (h * w, c)
    Y = img.permute(1, 2, 0).reshape(height * width, num_channels).float()
    print("Y shape:", Y.shape)

    return X, Y

In [None]:
# Build the regression data set from the crop: inputs are pixel coordinates,
# targets are the pixel colours.
dog_X, dog_Y = create_coordinate_map(crop)

print(dog_X) # (300*300, 2) - pixel coordinates
print(dog_Y) # (300*300, 3) - RGB values

In [None]:
# MinMaxScaler from -1 to 1
scaler_X = preprocessing.MinMaxScaler(feature_range=(-1, 1)).fit(dog_X.cpu())

# Scale the X coordinates
dog_X_scaled = scaler_X.transform(dog_X.cpu())

# Move the scaled X coordinates to the GPU
dog_X_scaled = torch.tensor(dog_X_scaled).to(device)

# Set to dtype float32
dog_X_scaled = dog_X_scaled.float()

#### Reconstructing using Linear Model

In [None]:
torch.manual_seed(42)  # Fix the RNG so the weight initialisation is reproducible


class LinearModel(nn.Module):
    """A single affine layer mapping `in_features` inputs to `out_features` outputs."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Kept under the name `linear`: other cells read net.linear.weight / .bias
        self.linear = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x):
        """Apply the affine map to `x` (last dimension must be `in_features`)."""
        prediction = self.linear(x)
        return prediction


In [None]:
# Baseline model: maps the 2 coordinates directly to the 3 colour channels
net = LinearModel(2, 3)
net.to(device)
print(net)


In [None]:
# Inspect the freshly initialised (untrained) parameters
print("Weights:", net.linear.weight)
print("Bias:", net.linear.bias)

In [None]:
def train(net, lr, X, Y, epochs, verbose=True):
    """
    Fit `net` to (X, Y) with full-batch Adam on the mean squared error.

    net: torch.nn.Module
    lr: float, learning rate passed to Adam
    X: torch.Tensor of shape (num_samples, num_features)
    Y: torch.Tensor with the same shape as net(X)
    epochs: int, number of optimisation steps (must be >= 1)
    verbose: bool, print the loss every 100 epochs when True

    return: tuple (final_loss, losses) where `losses` is the list of per-epoch
        losses (each measured before that epoch's update) and `final_loss`
        is its last entry.
    raises: ValueError if epochs < 1
    """
    if epochs < 1:
        # Without at least one step there is no loss to report
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    losses = []
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        outputs = net(X)
        loss = criterion(outputs, Y)
        # Read the scalar once per step instead of calling .item() repeatedly
        loss_value = loss.item()
        losses.append(loss_value)
        loss.backward()
        optimizer.step()
        if verbose and epoch % 100 == 0:
            print(f"Epoch {epoch} loss: {loss_value:.6f}")
    return losses[-1], losses

In [None]:
# Train the baseline linear model for 1000 epochs with learning rate 0.01
train_loss,training_losses = train(net, 0.01, dog_X_scaled, dog_Y, 1000)

In [None]:
'''
Code snippet copied from [https://github.com/nipunbatra/ml-teaching/blob/e2cd59d3e3358473ebfcf70e71d70361bb4501b4/latexify.py#L9] by Nipun Batra
Date: 14 th August , 2024

Changes :
'text.latex.preamble': '\\usepackage{gensymb}',
'text.usetex': False,

'''


import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib

from math import sqrt
SPINE_COLOR = 'gray'

def latexify(fig_width=None, fig_height=None, columns=1):
    """Set up matplotlib's RC params for LaTeX plotting.
    Call this before plotting a figure.

    Parameters
    ----------
    fig_width : float, optional, inches
        Defaults to 3.39 for one column and 6.9 for two columns.
    fig_height : float,  optional, inches
        Defaults to fig_width times the golden ratio; capped at 8 inches.
    columns : {1, 2}
    """

    # code adapted from http://www.scipy.org/Cookbook/Matplotlib/LaTeX_Examples

    # Width and max height in inches for IEEE journals taken from
    # computer.org/cms/Computer.org/Journal%20templates/transactions_art_guide.pdf

    assert(columns in [1,2])

    if fig_width is None:
        fig_width = 3.39 if columns==1 else 6.9 # width in inches

    if fig_height is None:
        golden_mean = (sqrt(5)-1.0)/2.0    # Aesthetic ratio
        fig_height = fig_width*golden_mean # height in inches

    MAX_HEIGHT_INCHES = 8.0
    if fig_height > MAX_HEIGHT_INCHES:
        # Format the numbers into the message: concatenating a str with a
        # float raises TypeError, which used to crash this branch.
        print(f"WARNING: fig_height too large: {fig_height} "
              f"so will reduce to {MAX_HEIGHT_INCHES} inches.")
        fig_height = MAX_HEIGHT_INCHES

    params = {'backend': 'ps',
              'text.latex.preamble': '\\usepackage{gensymb}',
              'axes.labelsize': 8, # fontsize for x and y labels (was 10)
              'axes.titlesize': 8,
              'font.size': 8, # was 10
              'legend.fontsize': 8, # was 10
              'xtick.labelsize': 8,
              'ytick.labelsize': 8,
              'text.usetex': False,
              'figure.figsize': [fig_width,fig_height],
              'font.family': 'serif'
    }

    # Applied globally: affects every figure created after this call
    matplotlib.rcParams.update(params)


def format_axes(ax):
    """Strip the top/right spines of `ax` and restyle what remains.

    The left and bottom spines are drawn thin in SPINE_COLOR, ticks are kept
    only on those two sides and point outwards. Returns the same `ax`.
    """
    spines = ax.spines

    # Hide the frame lines that carry no axis
    spines['top'].set_visible(False)
    spines['right'].set_visible(False)

    # Soften the two remaining frame lines
    for side in ('left', 'bottom'):
        kept_spine = spines[side]
        kept_spine.set_color(SPINE_COLOR)
        kept_spine.set_linewidth(0.5)

    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    ax.xaxis.set_tick_params(direction='out', color=SPINE_COLOR)
    ax.yaxis.set_tick_params(direction='out', color=SPINE_COLOR)

    return ax



In [None]:
# plot training loss graph
plt.figure(figsize=(10, 5))
latexify()
# format_axes returns the axes it styled, so draw on that object directly
loss_ax = format_axes(plt.gca())
loss_ax.plot(training_losses)
loss_ax.set_title("Training Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
plt.show()



In [None]:
def plot_reconstructed_and_original_image(original_img, net, X, title=""):
    """
    Draw the network's reconstruction next to the original image.

    original_img: torch.Tensor of shape (num_channels, height, width)
    net: torch.nn.Module; it is switched to eval mode by this function
    X: torch.Tensor with one input row per pixel (height * width rows),
       ordered row by row
    title: str, figure-level title
    """
    num_channels, height, width = original_img.shape

    net.eval()
    with torch.no_grad():
        # One prediction row per pixel -> (height, width, channels) image
        reconstruction = net(X).reshape(height, width, num_channels)

    fig = plt.figure(figsize=(6, 4))
    grid = gridspec.GridSpec(1, 2, width_ratios=[1, 1])

    # (axes, picture, caption) for the left and right panel
    panels = (
        (plt.subplot(grid[0]), reconstruction.cpu(), "Reconstructed Image"),
        (plt.subplot(grid[1]), original_img.cpu().permute(1, 2, 0), "Original Image"),
    )
    for panel_ax, picture, caption in panels:
        panel_ax.imshow(picture)
        panel_ax.set_title(caption)
        panel_ax.axis("off")

    fig.suptitle(title, y=0.9)
    plt.tight_layout()

In [None]:
# Compare the baseline linear model's output with the original crop
plot_reconstructed_and_original_image(crop, net, dog_X_scaled, title="Reconstructed Image")

#### Using Polynomial Basis Functions

In [None]:
# Use polynomial features of degree "d"

def poly_features(X, degree):
    """
    Expand 2-D coordinates into every polynomial term up to `degree`.

    X: torch.Tensor of shape (num_samples, 2); only the first two columns are used
    degree: int

    return: torch.Tensor of shape (num_samples, (degree + 1) * (degree + 2) / 2),
        float32, on the global `device`. The count includes the constant
        (bias) column that PolynomialFeatures adds by default.
    """
    # scikit-learn works on CPU arrays, so the expansion happens off-device
    coords = X[:, :2].cpu()
    poly = preprocessing.PolynomialFeatures(degree=degree)
    expanded = poly.fit_transform(coords)
    return torch.tensor(expanded, dtype=torch.float32).to(device)

In [None]:
def normalize_image(image):
    """Normalize image to [0, 1] range for float images.

    The minimum maps to 0 and the maximum to 1. A constant image (max == min)
    has no range to stretch and is returned as all zeros instead of the NaNs
    that 0 / 0 would produce.
    """
    image_min = image.min()
    image_max = image.max()
    value_range = image_max - image_min
    if value_range == 0:
        # Multiplying by 0.0 keeps the float result type of the division path
        return (image - image_min) * 0.0
    return (image - image_min) / value_range

def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x)) of a tensor."""
    # Use the library kernel instead of spelling the formula out by hand
    return torch.sigmoid(x)

def apply_sigmoid(image_tensor):
    """Squash every value of an image tensor into the (0, 1) range with the sigmoid."""
    squashed = sigmoid(image_tensor)
    return squashed

def clip(image_tensor):
    """Return a copy of the image tensor with every value limited to [0, 1]."""
    return image_tensor.clamp(min=0, max=1)



In [None]:
# Define polynomial degrees to test
degrees = [5, 10, 50, 100]

# Initialize lists to store training losses and images for each degree.
# NOTE: this rebinds `training_losses`, which until now held the loss curve of
# the linear model; from here on it is a list with one loss curve per degree.
training_losses = []
original_images = []
reconstructed_clipped_images = []

for degree in degrees:
    # Create polynomial features
    dog_X_scaled_poly = poly_features(dog_X_scaled, degree)
    print(f"Degree {degree}: {dog_X_scaled_poly.dtype}, {dog_X_scaled_poly.shape}, {dog_Y.shape}, {dog_Y.dtype}")

    # Initialize a fresh model for this degree (one input per polynomial term)
    net = LinearModel(dog_X_scaled_poly.shape[1], 3)
    net.to(device)

    # Train the model
    train_poly_loss, losses = train(net, 0.005, dog_X_scaled_poly, dog_Y, 1500)
    training_losses.append(losses)

    # Generate reconstructed image (no gradients needed for inference)
    with torch.no_grad():
        output = net(dog_X_scaled_poly)

    # Reshape the per-pixel predictions to (height, width, channels) and clip
    # them to [0, 1] so they form a valid float image
    reconstructed_image = output.cpu().reshape(crop.shape[1], crop.shape[2], -1)
    clipped_image = clip(reconstructed_image)

    # Append images to lists with correct shape (channels last).
    # The original crop is identical for every degree; it is stored once per
    # degree so both lists stay the same length.
    original_images.append(rearrange(crop, 'c h w -> h w c').cpu().numpy())
    reconstructed_clipped_images.append(clipped_image.numpy())

In [None]:
# Sanity check: original and reconstruction should have the same (h, w, c) shape
print(original_images[0].shape)
print(reconstructed_clipped_images[0].shape)

In [None]:
# Peek at the first five pixels of the top row of the first reconstruction
print(reconstructed_clipped_images[0][0,:5,:])

In [None]:
# Create a figure with subplots for loss and images
# One row per polynomial degree: loss curve | original crop | reconstruction
fig, axs = plt.subplots(len(degrees), 3, figsize=(10, len(degrees) * 3))

latexify()

for i, degree in enumerate(degrees):
    loss_ax, original_ax, reconstruction_ax = axs[i]

    # Left column: training loss of this degree
    loss_ax.plot(training_losses[i])
    loss_ax.set_title(f"Training Loss for Degree {degree}")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    format_axes(loss_ax)

    # Middle column: the original crop (identical in every row)
    original_ax.imshow(original_images[0])
    original_ax.set_title("Original Image")
    original_ax.axis('off')

    # Right column: clipped reconstruction for this degree
    reconstruction_ax.imshow(reconstructed_clipped_images[i])
    reconstruction_ax.set_title(f"Reconstructed Clipped Image (Degree {degree})")
    reconstruction_ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Plot all training losses in a single plot
plt.figure(figsize=(12, 6))
# Overlay one loss curve per polynomial degree
for degree, degree_losses in zip(degrees, training_losses):
    plt.plot(degree_losses, label=f'Degree {degree}')
plt.title("Training Loss for Different Polynomial Degrees")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

#### Reconstructing using Random Fourier Features

In [None]:
# create RFF features
def create_rff_features(X, num_features, sigma, random_state=None):
    """
    Map inputs to random Fourier features that approximate an RBF kernel.

    X: torch.Tensor of shape (num_samples, num_inputs)
    num_features: int, number of random features to generate
    sigma: float, kernel bandwidth; converted to gamma = 1 / (2 * sigma**2)
    random_state: int or None, seed for the random projection. The projection
        is drawn by scikit-learn, so torch.manual_seed does not make it
        reproducible; pass an int here to get the same features on every
        call. None (the default) keeps the previous unseeded behaviour.

    return: torch.Tensor of shape (num_samples, num_features), float32, on the
        global `device`
    """
    from sklearn.kernel_approximation import RBFSampler

    rff = RBFSampler(
        n_components=num_features,
        gamma=1 / (2 * sigma**2),
        random_state=random_state,
    )
    # scikit-learn works on CPU NumPy arrays
    features = rff.fit_transform(X.cpu().numpy())
    return torch.tensor(features, dtype=torch.float32).to(device)


In [None]:
# Number of random Fourier features and kernel bandwidth.
# NOTE(review): for the 300*300-pixel crop this feature matrix has
# 90,000 x 37,500 float32 entries (about 13.5 GB) - confirm the device has
# enough memory before running.
n_features = 37500
sigma = 0.008
X_rff = create_rff_features(dog_X_scaled, n_features, sigma)
print(X_rff.shape)

In [None]:
# Linear model on top of the RFF features: one input per feature, 3 outputs
net = LinearModel(X_rff.shape[1], 3)
# Move model to GPUs if available
if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)  # Wrap model to use multiple GPUs
net.to(device)

# Train for 2500 epochs with learning rate 0.005
train_rff_loss , train_rff_losses = train(net, 0.005, X_rff, dog_Y, 2500)

In [None]:
# plot training loss graph
plt.figure(figsize=(10, 5))
latexify()
# format_axes returns the axes it styled, so draw on that object directly
rff_loss_ax = format_axes(plt.gca())
rff_loss_ax.plot(train_rff_losses)
rff_loss_ax.set_title("Training Loss")
rff_loss_ax.set_xlabel("Epoch")
rff_loss_ax.set_ylabel("Loss")
plt.show()



In [None]:
# Compare the RFF model's reconstruction with the original crop
plot_reconstructed_and_original_image(crop, net, X_rff, title="Reconstructed Image with RFF Features")

In [None]:
# Define a function to train the model and return the loss and the trained network
def train_model(X_data, features, sigma, learning_rate=0.005, epochs=2500):
    """
    Build RFF features for `X_data` and fit a linear model on them.

    X_data: torch.Tensor of input coordinates, one row per pixel
    features: int, number of random Fourier features
    sigma: float, RBF bandwidth used for the features
    learning_rate: float, learning rate passed to `train`
    epochs: int, number of training epochs

    The regression targets are read from the notebook-level `dog_Y`.

    return: tuple (net, train_rff_losses, X_rff) - the trained model, its
        per-epoch losses, and the feature matrix it was trained on
    """
    # Generate random Fourier features with the helper defined in this
    # notebook (the previously referenced `generate_rff` does not exist)
    X_rff = create_rff_features(X_data, features, sigma)

    # Initialize and train the model
    net = LinearModel(X_rff.shape[1], 3)
    net.to(device)
    _, train_rff_losses = train(net, learning_rate, X_rff, dog_Y, epochs)

    return net, train_rff_losses, X_rff

# Define a function to plot the training loss
def plot_training_loss(losses, features, sigma):
    """Plot a per-epoch loss curve, labelled with the RFF settings that produced it.

    losses: sequence of per-epoch loss values
    features: number of random Fourier features (used in the title only)
    sigma: RBF bandwidth (used in the title only)
    """
    plt.figure(figsize=(10, 5))
    latexify()
    # format_axes hands back the styled axes; draw on it directly
    loss_ax = format_axes(plt.gca())
    loss_ax.plot(losses)
    loss_ax.set_title(f"Training Loss (Features: {features}, Sigma: {sigma})")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    plt.show()

# Define a function to plot the reconstructed image
def plot_reconstructed_image(crop, net, X_rff, features, sigma):
    """Show the reconstruction next to `crop`, titled with the RFF settings used."""
    figure_title = f"Reconstructed Image (Features: {features}, Sigma: {sigma})"
    plot_reconstructed_and_original_image(crop, net, X_rff, title=figure_title)

# Main function to test multiple combinations
def test_feature_sigma_combinations(X_data, crop, combinations, learning_rate=0.005, epochs=2500):
    """
    Train and visualise one RFF model per (features, sigma) setting.

    X_data: input coordinates handed to `train_model`
    crop: original image tensor shown next to each reconstruction
    combinations: iterable of dicts with the keys "features" and "sigma"
    learning_rate, epochs: forwarded to `train_model`

    For every setting the training-loss curve is plotted first, followed by
    the reconstructed image.
    """
    for setting in combinations:
        features, sigma = setting["features"], setting["sigma"]

        # Fit a fresh model on features generated for this setting
        net, train_rff_losses, X_rff = train_model(X_data, features, sigma, learning_rate, epochs)

        plot_training_loss(train_rff_losses, features, sigma)
        plot_reconstructed_image(crop, net, X_rff, features, sigma)

# Define different feature and sigma combinations
feature_sigma_combinations = [
    {"features": 5000, "sigma": 0.05},
    {"features": 5000, "sigma": 0.1},
    {"features": 15000, "sigma": 0.05},
    {"features": 15000, "sigma": 0.1},
]

# Run the test for the defined combinations on the scaled pixel coordinates.
# `X_data` is only a parameter name inside the helper functions; the
# notebook-level variable holding the coordinates is `dog_X_scaled`.
test_feature_sigma_combinations(dog_X_scaled, crop, feature_sigma_combinations)
