Image Reconstruction

Choose any image you like. Use Random Fourier Features (RFF) and Linear Regression to learn the mapping from the image coordinates (X, Y) to the pixel colors (R, G, B). Here, (X, Y) represents the coordinates of the pixels, and (R, G, B) represents the color values at those coordinates.

1. **Load Image**: Select any image of your choice.
2. **Random Fourier Features (RFF)**: Implement RFF to lift the pixel coordinates into a high-dimensional feature space.
3. **Linear Regression**: Use linear regression to learn the mapping from the RFF features to the color values.
4. **Display Results**: Show both the original and reconstructed images.
5. **Metrics**: Calculate the Root Mean Squared Error (RMSE) and Peak Signal-to-Noise Ratio (PSNR) between the original and reconstructed images.

**Key Variables**:
- X, Y: Pixel coordinates.
- R, G, B: Pixel color values.


### Latexify

In [None]:
'''
Code snippet copied from [https://github.com/nipunbatra/ml-teaching/blob/e2cd59d3e3358473ebfcf70e71d70361bb4501b4/latexify.py#L9] by Nipun Batra
Date: 14 th August , 2024

Changes :
'text.latex.preamble': '\\usepackage{gensymb}',
'text.usetex': False,

'''


import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib

from math import sqrt
SPINE_COLOR = 'gray'

def latexify(fig_width=None, fig_height=None, columns=1):
    """Set up matplotlib's RC params for LaTeX plotting.
    Call this before plotting a figure.

    Parameters
    ----------
    fig_width : float, optional, inches
        Defaults to 3.39 for one column and 6.9 for two columns.
    fig_height : float,  optional, inches
        Defaults to ``fig_width`` times the golden ratio (~0.618).
        Heights above 8 inches are capped at 8 with a printed warning.
    columns : {1, 2}
    """

    # code adapted from http://www.scipy.org/Cookbook/Matplotlib/LaTeX_Examples

    # Width and max height in inches for IEEE journals taken from
    # computer.org/cms/Computer.org/Journal%20templates/transactions_art_guide.pdf

    assert(columns in [1,2])

    if fig_width is None:
        fig_width = 3.39 if columns==1 else 6.9 # width in inches

    if fig_height is None:
        golden_mean = (sqrt(5)-1.0)/2.0    # Aesthetic ratio
        fig_height = fig_width*golden_mean # height in inches

    MAX_HEIGHT_INCHES = 8.0
    if fig_height > MAX_HEIGHT_INCHES:
        # Format the numbers into the message; concatenating a float onto a
        # str would raise TypeError instead of printing the warning.
        print(f"WARNING: fig_height too large: {fig_height} "
              f"so will reduce to {MAX_HEIGHT_INCHES} inches.")
        fig_height = MAX_HEIGHT_INCHES

    # NOTE(review): 'backend': 'ps' is written into rcParams here; confirm this
    # is intended when running inside a notebook with an inline backend.
    params = {'backend': 'ps',
              'text.latex.preamble': '\\usepackage{gensymb}',
              'axes.labelsize': 8, # fontsize for x and y labels (was 10)
              'axes.titlesize': 8,
              'font.size': 8, # was 10
              'legend.fontsize': 8, # was 10
              'xtick.labelsize': 8,
              'ytick.labelsize': 8,
              'text.usetex': False,
              'figure.figsize': [fig_width,fig_height],
              'font.family': 'serif'
    }

    matplotlib.rcParams.update(params)


def format_axes(ax):
    """Apply a minimal axes style and return ``ax`` for chaining.

    The top and right spines are hidden. The left and bottom spines are drawn
    thin in ``SPINE_COLOR``. Ticks point outward and appear only on the bottom
    and left axes.
    """
    hidden_sides = ('top', 'right')
    kept_sides = ('left', 'bottom')

    for side in hidden_sides:
        ax.spines[side].set_visible(False)

    for side in kept_sides:
        spine = ax.spines[side]
        spine.set_color(SPINE_COLOR)
        spine.set_linewidth(0.5)

    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_tick_params(direction='out', color=SPINE_COLOR)

    return ax



## Importing necessary libraries

In [None]:
import torch
print(torch.__version__)

import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

# Remove all the warnings
# NOTE(review): this ignores every warning category for the rest of the
# session, including deprecation warnings — consider narrowing it.
import warnings
warnings.filterwarnings('ignore')

# Set env CUDA_LAUNCH_BLOCKING=1
# NOTE(review): CUDA_LAUNCH_BLOCKING is normally a debugging aid (synchronous
# kernel launches, slower GPU work) — confirm it is wanted for full runs.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Compute device used by later cells: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Retina display
# IPython magic (only valid inside Jupyter/IPython): high-DPI inline figures.
%config InlineBackend.figure_format = 'retina'

# ``rearrange`` from einops is used later for 'c h w -> h w c' layout changes.
# If einops is missing it is installed on the fly via the IPython ``%pip`` magic.
try:
    from einops import rearrange
except ImportError:
    %pip install einops

    from einops import rearrange

In [None]:
import torch
print("CUDA available:", torch.cuda.is_available())
# torch.cuda.get_device_name(0) raises when no CUDA device exists, so only
# query the device name when CUDA is actually available.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

## Importing image

In [None]:
import os
import urllib.request

# Define the directory path
directory_path = '../assets/images/'
image_path = os.path.join(directory_path, 'dog.jpg')
image_url = 'https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg'

# Create the directory if it does not exist (no error when it already does)
os.makedirs(directory_path, exist_ok=True)

# An empty file (e.g. left behind by an earlier failed download) is treated
# as missing so that it gets re-downloaded.
if os.path.isfile(image_path) and os.path.getsize(image_path) > 0:
    print('dog.jpg exists')
else:
    # Pure-Python download: works outside IPython and without `wget`.
    # Write to a temporary name and rename only on success, so an interrupted
    # download never leaves a truncated dog.jpg behind.
    tmp_path = image_path + '.part'
    urllib.request.urlretrieve(image_url, tmp_path)
    os.replace(tmp_path, image_path)


In [None]:
# Decode the downloaded photo into a tensor laid out as (channels, height, width)
img = torchvision.io.read_image(path="../assets/images/dog.jpg")
print(img.shape)

In [None]:
# matplotlib wants channels last, so move C behind H and W before displaying
plt.imshow(img.permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()

In [None]:
# # Read in a image from torchvision
# iitgn_img = torchvision.io.read_image("../assets/images/iitgn.jpg")
# print(iitgn_img.shape)
# plt.imshow(rearrange(iitgn_img, 'c h w -> h w c').numpy())


In [None]:
from sklearn import preprocessing

# A single min-max scaler shared by all channels: fit on every pixel intensity
scaler_img = preprocessing.MinMaxScaler()
scaler_img.fit(img.reshape(-1, 1))
scaler_img

In [None]:

# Scale the flattened pixel values, then restore the original (C, H, W) shape
flat_scaled = scaler_img.transform(img.reshape(-1, 1))
img_scaled = flat_scaled.reshape(img.shape)
print(img_scaled.shape)

img_scaled = torch.tensor(img_scaled)


In [None]:
# Place the scaled image on the compute device; the bare name displays it
img_scaled = img_scaled.to(device=device)
img_scaled

## Crop the image

In [None]:
# 300x300 window whose top-left corner sits at row 600, column 800
crop = torchvision.transforms.functional.crop(
    img_scaled.cpu(), top=600, left=800, height=300, width=300
)
crop.shape

In [None]:
# Display the crop channels-last, as matplotlib expects
plt.imshow(crop.cpu().permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()

In [None]:
# Keep the crop on the compute device for training and evaluation
crop = crop.to(device=device)

## Create a coordinate map

In [None]:
def create_coordinate_map(img):
    """Turn an image into (pixel coordinate -> colour) regression pairs.

    img: torch.Tensor of shape (num_channels, height, width)

    return: tuple ``(X, Y)``.
        X is a float32 tensor of shape (height * width, 2). It holds the
        (row, column) index of every pixel and is moved to ``device``.
        Y is a float32 tensor of shape (height * width, num_channels). It
        holds the matching colour values and stays on img's device.
    Pixels are enumerated row by row, so the column index varies fastest.
    """
    num_channels, height, width = img.shape
    print("Number of channels:", num_channels, "\nHeight:", height, "\nWidth:", width)

    # 'ij' indexing: rows[i, j] == i and cols[i, j] == j
    rows, cols = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
    X = torch.stack((rows.flatten(), cols.flatten()), dim=1).float().to(device)
    print("X shape:", X.shape)

    # (C, H, W) -> (H, W, C) -> (H*W, C): same row-major pixel order as X
    Y = img.permute(1, 2, 0).reshape(-1, num_channels).float()
    print("Y shape:", Y.shape)

    return X, Y

In [None]:
dog_X, dog_Y = create_coordinate_map(crop)

# dog_X: pixel coordinates, (300*300, 2); dog_Y: RGB values, (300*300, 3)
print(dog_X)
print(dog_Y)

In [None]:
# Rescale each coordinate column independently to [-1, 1] (sklearn needs CPU data)
coords_cpu = dog_X.cpu()
scaler_X = preprocessing.MinMaxScaler(feature_range=(-1, 1))
scaler_X.fit(coords_cpu)

# Transform, move the result to the compute device, and store as float32
dog_X_scaled = torch.tensor(scaler_X.transform(coords_cpu)).to(device).float()

### Functions to Calculate RMSE and PSNR

In [None]:
# Functions to calculate RMSE and PSNR
def calculate_rmse(original_image, reconstructed_image):
    """Return the root-mean-squared error between two tensors as a Python float."""
    squared_error = (original_image - reconstructed_image).pow(2)
    return squared_error.mean().sqrt().item()

def calculate_psnr(original_image, reconstructed_image, max_pixel_value=1.0):
    """Return the peak signal-to-noise ratio (dB) between two tensors.

    ``max_pixel_value`` is the largest representable intensity (1.0 for
    images scaled to [0, 1]). Identical inputs give ``float('inf')``.
    """
    mse = (original_image - reconstructed_image).pow(2).mean()
    # Zero error means the ratio diverges.
    if mse == 0:
        return float('inf')
    rmse = torch.sqrt(mse)
    return (20 * torch.log10(max_pixel_value / rmse)).item()

### Setting up device

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Seed PyTorch's global RNG so weight initialisation is repeatable
torch.manual_seed(42)

# Prefer the GPU; report when more than one is visible
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")


### Defining Linear Model

In [None]:
# Define the linear model
class LinearModel(nn.Module):
    """Single affine layer mapping ``input_dim`` features to ``output_dim`` outputs.

    The layer is stored as ``self.fc``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        return self.fc(x)

## Reconstructing using Linear Model

In [None]:
# Instantiate and move the model to the device (support multiple GPUs)
net = LinearModel(2, 3).to(device)
if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)

print(net)
# DataParallel wraps the model, so the layer is reached through ``.module``.
# LinearModel stores its layer as ``fc`` (there is no ``linear`` attribute).
base_net = net.module if isinstance(net, nn.DataParallel) else net
print("Weights:", base_net.fc.weight)
print("Bias:", base_net.fc.bias)

# Training function with mixed precision and memory optimization
def train(net, lr, X, Y, epochs, batch_size=512, verbose=True):
    """Fit ``net`` to (X, Y) with Adam and an MSE loss over shuffled mini-batches.

    net: torch.nn.Module
    lr: float, Adam learning rate
    X: torch.Tensor of shape (num_samples, num_features), model inputs
    Y: torch.Tensor of shape (num_samples, num_outputs), regression targets
    epochs: int, number of full passes over the data (counted from 1)
    batch_size: int, samples per optimisation step
    verbose: bool, print the mean loss every 200 epochs

    return: list of length ``epochs`` holding each epoch's mean batch loss
    """
    losses = []
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    # torch.cuda.amp only works on CUDA. Enable it explicitly there instead of
    # relying on its warn-and-disable fallback on CPU.
    on_cuda = device.type == "cuda"
    scaler = GradScaler(enabled=on_cuda)

    # DataLoader to handle batch processing and shuffle the data
    data_loader = DataLoader(TensorDataset(X, Y), batch_size=batch_size, shuffle=True)

    for epoch in range(1, epochs + 1):
        epoch_loss = 0.0
        for batch_X, batch_Y in data_loader:
            batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)

            optimizer.zero_grad()

            with autocast(enabled=on_cuda):
                outputs = net(batch_X)
                loss = criterion(outputs, batch_Y)

            # With AMP disabled the scaler passes the loss and step through unchanged
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()

        mean_loss = epoch_loss / len(data_loader)
        losses.append(mean_loss)

        if verbose and epoch % 200 == 0:
            print(f"Epoch {epoch} loss: {mean_loss:.6f}")

        if on_cuda:
            # Release cached, unused allocator blocks after each epoch
            torch.cuda.empty_cache()

    return losses

# Train the coordinate-only linear model. ``train`` returns one list of
# per-epoch mean losses, so its result is bound to a single name.
training_losses = train(net, 0.01, dog_X_scaled, dog_Y, 1000)

# Plot training loss graph
plt.figure(figsize=(10, 5))
plt.plot(training_losses)
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()


# Reconstruct the crop from this model so it can be scored against the original.
net.eval()
with torch.no_grad():
    predicted_colors = net(dog_X_scaled).float()

# Both images as (height, width, channels) float32 tensors on the same device;
# the prediction is clipped to the [0, 1] range of the scaled crop.
original_image = crop.permute(1, 2, 0).float().to(device)
reconstructed_clipped_image = torch.clamp(
    predicted_colors.reshape(crop.shape[1], crop.shape[2], -1), 0, 1
).to(device)

# Calculate RMSE and PSNR
rmse_value = calculate_rmse(original_image, reconstructed_clipped_image)
print(f"RMSE: {rmse_value}")

psnr_value = calculate_psnr(original_image, reconstructed_clipped_image)
print(f"PSNR: {psnr_value}")

# Plot function for reconstructed and original images
def plot_reconstructed_and_original_image(original_img, net, X, title=""):
    """Show the network's reconstruction next to the original image.

    original_img: torch.Tensor of shape (num_channels, height, width)
    net: torch.nn.Module mapping X to one colour row per pixel
    X: torch.Tensor of shape (height * width, num_features), pixels in row-major
        order (the layout produced by ``create_coordinate_map``)
    title: str, figure super-title
    """
    # Evaluate without gradient tracking, then restore the caller's train/eval mode.
    was_training = net.training
    net.eval()
    with torch.no_grad():
        outputs = net(X).float().cpu().numpy()
    net.train(was_training)

    num_channels, height, width = original_img.shape
    # Linear predictions can overshoot [0, 1]; clip them explicitly before display.
    outputs = np.clip(outputs.reshape(height, width, num_channels), 0.0, 1.0)

    fig = plt.figure(figsize=(6, 4))
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1])

    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])

    ax0.imshow(outputs)
    ax0.set_title("Reconstructed Image")

    ax1.imshow(original_img.cpu().permute(1, 2, 0))
    ax1.set_title("Original Image")

    for a in [ax0, ax1]:
        a.axis("off")

    fig.suptitle(title, y=0.9)
    plt.tight_layout()

# Assuming crop and dog_X_scaled are available
plot_reconstructed_and_original_image(crop, net, dog_X_scaled, title="Reconstructed Image")


## Using Polynomial Basis Functions

In [None]:
def normalize_image(image):
    """Linearly rescale ``image`` (tensor or array) to the [0, 1] range.

    A constant image (max == min) has no range to stretch, so it is mapped to
    all zeros instead of producing NaNs from a 0/0 division.
    """
    image_min = image.min()
    shifted = image - image_min
    value_range = image.max() - image_min
    if value_range == 0:
        # ``shifted`` is all zeros here; true division by 1 keeps the same
        # dtype promotion as the normal path (integer input -> float result).
        return shifted / 1
    return shifted / value_range

def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x)) on a torch tensor.

    Delegates to ``torch.sigmoid``. The hand-written formula overflows
    ``exp(-x)`` to inf for very negative ``x``. That is harmless in the
    forward pass but makes the autograd gradient NaN.
    """
    return torch.sigmoid(x)

def apply_sigmoid(image_tensor):
    """Apply sigmoid to image tensor to squash every value into (0, 1)."""
    return sigmoid(image_tensor)

def clip(image_tensor):
    """Clip image tensor to [0, 1] range."""
    return torch.clamp(image_tensor, 0, 1)



In [None]:
# # Use polynomial features of degree "d"

# def poly_features(X, degree):
#     """
#     X: torch.Tensor of shape (num_samples, 2)
#     degree: int

#     return: torch.Tensor of shape (num_samples, degree * (degree + 1) / 2)
#     """
#     X1 = X[:, 0]
#     X2 = X[:, 1]
#     X1 = X1.unsqueeze(1)
#     X2 = X2.unsqueeze(1)
#     X = torch.cat([X1, X2], dim=1)
#     poly = preprocessing.PolynomialFeatures(degree=degree)
#     X = poly.fit_transform(X.cpu())
#     return torch.tensor(X, dtype=torch.float32).to(device)



# # Define polynomial degrees to test
# degrees = [5, 10, 50, 100]

# # Initialize lists to store training losses and images for each degree
# training_losses = []
# original_images = []
# reconstructed_clipped_images = []

# for degree in degrees:
#     # Create polynomial features
#     dog_X_scaled_poly = poly_features(dog_X_scaled, degree)
#     print(f"Degree {degree}: {dog_X_scaled_poly.dtype}, {dog_X_scaled_poly.shape}, {dog_Y.shape}, {dog_Y.dtype}")

#     # Initialize and train the model
#     net = LinearModel(dog_X_scaled_poly.shape[1], 3)
#     net.to(device)

#     # Train the model
#     train_poly_loss, losses = train(net, 0.005, dog_X_scaled_poly, dog_Y, 1500)
#     training_losses.append(losses)

#     # Generate reconstructed image
#     with torch.no_grad():
#         output = net(dog_X_scaled_poly)

#     # Reshape output and apply transformations
#     reconstructed_image = output.cpu().reshape(crop.shape[1], crop.shape[2], -1)
#     clipped_image = clip(reconstructed_image)

#     # Append images to lists with correct shape
#     original_images.append(rearrange(crop, 'c h w -> h w c').cpu().numpy())
#     reconstructed_clipped_images.append(clipped_image.numpy())

In [None]:
# Define polynomial degrees to test
degrees = [5, 10, 50, 100]


def poly_features(X, degree, batch_size=1000):
    """Expand input coordinates into all monomials up to ``degree``.

    X: torch.Tensor of shape (num_samples, 2)
    degree: int
    batch_size: int, rows transformed per chunk to cap peak host memory

    return: float32 torch.Tensor on ``device``. For 2 input columns it has
        shape (num_samples, (degree + 1) * (degree + 2) / 2), bias column
        included. Columns follow sklearn's PolynomialFeatures ordering.
    """
    # The expansion depends only on the number of input columns, so fit the
    # transformer once rather than refitting it for every batch.
    poly = preprocessing.PolynomialFeatures(degree=degree)
    poly.fit(X[:1].cpu())

    # Process in batches to avoid excessive RAM usage
    X_batches = []
    for start in range(0, X.shape[0], batch_size):
        X_batch = X[start:start + batch_size].cpu()  # sklearn works on CPU data
        X_poly = poly.transform(X_batch)
        X_batches.append(torch.tensor(X_poly, dtype=torch.float32).to(device))

    return torch.cat(X_batches, dim=0)  # Concatenate all batches

# Adjusting the train loop to minimize memory usage
def run():
    """Fit one linear model per entry of ``degrees`` on polynomial coordinate features.

    Returns
    -------
    training_losses : list of lists
        Per-epoch losses for each degree, in ``degrees`` order.
    original_images : list of numpy.ndarray
        The target crop as (height, width, channels), appended once per degree.
    reconstructed_clipped_images : list of numpy.ndarray
        Each model's reconstruction as (height, width, channels), clipped to [0, 1].
    """
    training_losses = []
    original_images = []
    reconstructed_clipped_images = []

    for degree in degrees:
        # Create polynomial features (batched to cap peak host memory)
        dog_X_scaled_poly = poly_features(dog_X_scaled, degree, batch_size=1000)
        print(f"Degree {degree}: {dog_X_scaled_poly.dtype}, {dog_X_scaled_poly.shape}, {dog_Y.shape}, {dog_Y.dtype}")

        # Initialize and train the model
        net = LinearModel(dog_X_scaled_poly.shape[1], 3).to(device)

        # ``train`` returns a single list of per-epoch mean losses
        losses = train(net, 0.005, dog_X_scaled_poly, dog_Y, 1500)
        training_losses.append(losses)

        # Generate reconstructed image with no gradient tracking
        with torch.no_grad():
            output = net(dog_X_scaled_poly)

        # Reshape output to (height, width, channels) and clip to [0, 1]
        reconstructed_image = output.cpu().reshape(crop.shape[1], crop.shape[2], -1)
        clipped_image = clip(reconstructed_image)

        original_images.append(rearrange(crop, 'c h w -> h w c').cpu().numpy())
        reconstructed_clipped_images.append(clipped_image.numpy())

        # Drop the large feature matrix and the model before the next degree
        del dog_X_scaled_poly, output, reconstructed_image, clipped_image, net
        torch.cuda.empty_cache()

    if original_images:
        print("Original Image Shape:", original_images[0].shape)
        print("Reconstructed Image Shape:", reconstructed_clipped_images[0].shape)
        print(reconstructed_clipped_images[0][0, :5, :])

    return training_losses, original_images, reconstructed_clipped_images

# Run the experiment; the plotting cell below reads these three lists.
training_losses, original_images, reconstructed_clipped_images = run()


In [None]:
# Apply the compact plotting style first so the rcParams set by latexify()
# take effect on the figure created below.
latexify()

# Create a figure with subplots for loss and images.
# squeeze=False keeps ``axs`` 2-D even when ``degrees`` has a single entry.
fig, axs = plt.subplots(len(degrees), 3, figsize=(10, len(degrees) * 3), squeeze=False)

for i, degree in enumerate(degrees):
    # Plot training loss
    axs[i, 0].plot(training_losses[i])
    axs[i, 0].set_title(f"Training Loss for Degree {degree}")
    axs[i, 0].set_xlabel("Epoch")
    axs[i, 0].set_ylabel("Loss")
    format_axes(axs[i, 0])

    # Plot original image (the same crop is the target for every degree)
    axs[i, 1].imshow(original_images[0])
    axs[i, 1].set_title("Original Image")
    axs[i, 1].axis('off')

    # Plot reconstructed clipped image
    axs[i, 2].imshow(reconstructed_clipped_images[i])
    axs[i, 2].set_title(f"Reconstructed Clipped Image (Degree {degree})")
    axs[i, 2].axis('off')

plt.tight_layout()
plt.show()


# Plot all training losses in a single plot
plt.figure(figsize=(12, 6))
for i, degree in enumerate(degrees):
    plt.plot(training_losses[i], label=f'Degree {degree}')
plt.title("Training Loss for Different Polynomial Degrees")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

#### Reconstructing using Random Fourier Features

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt
from sklearn.kernel_approximation import RBFSampler

# Clear unused memory after each iteration
def clear_memory():
    """Return cached but unused CUDA allocator memory to the driver.

    Does nothing when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Function to create RFF features
def create_rff_features(X, num_features, sigma, random_state=None):
    """Map inputs to Random Fourier Features that approximate an RBF kernel.

    X: torch.Tensor of shape (num_samples, num_input_dims); copied to CPU/NumPy for sklearn
    num_features: int, number of random Fourier components (output width)
    sigma: float, RBF length-scale; converted to sklearn's gamma = 1 / (2 * sigma**2)
    random_state: None or int, seed for the random projection.
        With None (the default), new random frequencies are drawn on every
        call, so results vary between runs; pass an int to make them
        reproducible.

    return: float32 torch.Tensor of shape (num_samples, num_features) on ``device``
    """
    rff = RBFSampler(
        n_components=num_features,
        gamma=1 / (2 * sigma ** 2),
        random_state=random_state,
    )
    X_np = X.cpu().numpy()
    features = rff.fit_transform(X_np)
    return torch.tensor(features, dtype=torch.float32).to(device)


# Training function with mixed precision
def train(net, learning_rate, X_data, Y_data, epochs, batch_size=256, verbose=True):
    """Fit ``net`` to (X_data, Y_data) with Adam and an MSE loss over shuffled mini-batches.

    net: torch.nn.Module
    learning_rate: float, Adam learning rate
    X_data: torch.Tensor of shape (num_samples, num_features)
    Y_data: torch.Tensor of shape (num_samples, num_outputs)
    epochs: int, number of passes over the data (counted from 0)
    batch_size: int, samples per optimisation step
    verbose: bool, print the mean loss every 100 epochs (default True)

    return: list of length ``epochs`` holding each epoch's mean batch loss
    """
    on_cuda = device.type == "cuda"
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    # torch.cuda.amp only works on CUDA. Enable it explicitly there instead of
    # relying on its warn-and-disable fallback on CPU.
    scaler = GradScaler(enabled=on_cuda)

    # Create DataLoader for batch processing
    data_loader = DataLoader(TensorDataset(X_data, Y_data), batch_size=batch_size, shuffle=True)

    losses = []
    for epoch in range(epochs):
        epoch_loss = 0
        for batch_X, batch_Y in data_loader:
            # No-op when the tensors already live on ``device``
            batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
            optimizer.zero_grad()

            # Forward pass (mixed precision on CUDA only)
            with autocast(enabled=on_cuda):
                output = net(batch_X)
                loss = criterion(output, batch_Y)

            # Backward pass and update gradients
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()

        mean_loss = epoch_loss / len(data_loader)
        losses.append(mean_loss)

        if verbose and epoch % 100 == 0:
            print(f"Epoch {epoch} loss: {mean_loss}")

    return losses

# Function to plot reconstructed and original images
def plot_reconstructed_and_original_image(crop, net, X_data, title):
    """Display the model's reconstruction of ``crop`` beside the original crop.

    crop: torch.Tensor of shape (num_channels, height, width)
    net: torch.nn.Module mapping X_data to one colour row per pixel
    X_data: model inputs for every pixel, in row-major (h, w) order
    title: str, figure super-title
    """
    # Evaluate without gradient tracking, then restore the caller's train/eval mode.
    was_training = net.training
    net.eval()
    with torch.no_grad():
        predicted = net(X_data).float().cpu().numpy()
    net.train(was_training)

    height, width = crop.shape[1], crop.shape[2]
    # Linear predictions can overshoot [0, 1]; clip them before display.
    reconstructed_image = np.clip(predicted.reshape(height, width, -1), 0.0, 1.0)
    original_image = crop.detach().cpu().permute(1, 2, 0).numpy()

    fig, (ax_rec, ax_orig) = plt.subplots(1, 2, figsize=(10, 5))
    ax_rec.imshow(reconstructed_image)
    ax_rec.set_title("Reconstructed Image")
    ax_orig.imshow(original_image)
    ax_orig.set_title("Original Image")
    for ax in (ax_rec, ax_orig):
        ax.axis('off')

    fig.suptitle(title)
    plt.show()

# Function to train and return the loss and the trained network
def train_model(X_data, features, sigma, learning_rate=0.005, epochs=2500, Y_data=None):
    """Fit a linear model on Random Fourier Features of ``X_data``.

    X_data: torch.Tensor of shape (num_samples, num_input_dims), model inputs
    features: int, number of RFF components
    sigma: float, RBF length-scale passed to ``create_rff_features``
    learning_rate: float, Adam learning rate
    epochs: int, number of training epochs
    Y_data: torch.Tensor of shape (num_samples, num_outputs) or None. These
        are the regression targets; None falls back to the notebook-level
        ``dog_Y``.

    return: tuple (trained net, per-epoch losses, RFF feature tensor used for training)
    """
    targets = dog_Y if Y_data is None else Y_data

    # Generate random Fourier features
    X_rff = create_rff_features(X_data, features, sigma)

    # Output width follows the targets rather than a hard-coded 3
    net = LinearModel(X_rff.shape[1], targets.shape[1])

    # If using multiple GPUs
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    net.to(device)
    train_rff_losses = train(net, learning_rate, X_rff, targets, epochs)

    return net, train_rff_losses, X_rff

# Function to plot the training loss
def plot_training_loss(losses, features, sigma):
    """Draw the per-epoch loss curve for one (features, sigma) configuration."""
    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(losses)
    ax.set_title(f"Training Loss (Features: {features}, Sigma: {sigma})")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    plt.show()

# Function to plot the reconstructed image
def plot_reconstructed_image(crop, net, X_rff, features, sigma):
    """Plot one RFF reconstruction, titled with its hyper-parameters."""
    label = f"Reconstructed Image (Features: {features}, Sigma: {sigma})"
    plot_reconstructed_and_original_image(crop, net, X_rff, title=label)

# Main function to test multiple combinations of features and sigma
def test_feature_sigma_combinations(X_data, crop, combinations, learning_rate=0.005, epochs=2500):
    """Train, plot and score one RFF model per entry of ``combinations``.

    X_data: torch.Tensor of pixel inputs in row-major (h, w) order
    crop: torch.Tensor of shape (num_channels, height, width), the target image
    combinations: iterable of dicts with keys "features" and "sigma"
    learning_rate, epochs: forwarded to ``train_model``

    return: list of dicts with keys "features", "sigma", "rmse" and "psnr",
        one per combination (RMSE/PSNR computed on predictions clipped to [0, 1])
    """
    # Ground-truth colours in the same row-major pixel order as X_data
    target = crop.permute(1, 2, 0).reshape(-1, crop.shape[0]).float()

    results = []
    for comb in combinations:
        features = comb["features"]
        sigma = comb["sigma"]

        # Train the model and get losses
        net, train_rff_losses, X_rff = train_model(X_data, features, sigma, learning_rate, epochs)

        plot_training_loss(train_rff_losses, features, sigma)
        plot_reconstructed_image(crop, net, X_rff, features, sigma)

        # Score the reconstruction (the task asks for RMSE and PSNR)
        with torch.no_grad():
            prediction = torch.clamp(net(X_rff).float(), 0, 1)
        reference = target.to(prediction.device)
        rmse = calculate_rmse(reference, prediction)
        psnr = calculate_psnr(reference, prediction)
        print(f"Features: {features}, Sigma: {sigma} -> RMSE: {rmse:.6f}, PSNR: {psnr:.2f} dB")
        results.append({"features": features, "sigma": sigma, "rmse": rmse, "psnr": psnr})

        # Drop references to the large tensors first; otherwise the cache
        # cannot release their memory.
        del net, X_rff, prediction, reference
        clear_memory()

    return results

# Define different feature and sigma combinations
# NOTE(review): each run materialises an RFF matrix of shape
# (num_pixels, features) in float32 — for the 300x300 crop and 37,500 features
# that tensor alone is about 13.5 GB; reduce it if memory is limited.
feature_sigma_combinations = [
    {"features": 5000, "sigma": 0.05},
    {"features": 5000, "sigma": 0.1},
    {"features": 15000, "sigma": 0.05},
    {"features": 15000, "sigma": 0.1},
    {"features": 37500, "sigma": 0.008},
]

# Run the test for the defined combinations
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The RFF models take the crop's pixel coordinates scaled to [-1, 1]
test_feature_sigma_combinations(dog_X_scaled, crop, feature_sigma_combinations)
