# 1.0 VGG 19 Classifier(SR and HR)

In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
import numpy as np
import matplotlib.pyplot as plt
import os
import PIL

In [None]:
%%capture
try:
    import wandb
    import yaml
    import torchinfo
except:
    %pip install wandb
    %pip install pyyaml
    %pip install torchinfo
    import yaml
    import wandb
    import torchinfo

In [None]:
# Setup device agnostic code
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device  # bare expression so the notebook displays the selected device

## 1.1(a) VGG19 Model

In [None]:
import torchvision.models as models
import torch.nn as nn
from torchvision.models import VGG19_Weights

class VGG19Classifier(nn.Module):
    """
    Two-class classifier on top of a frozen, truncated VGG19 feature extractor.

    The first 25 modules (indices 0-24) of the ImageNet-pretrained
    ``vgg19.features`` stack are kept with gradients disabled; a small
    trainable fully connected head maps the flattened features to 2 logits.

    Parameters:
    - num_features (int): Length of the flattened feature map fed to the head.
      It must match the output of the truncated feature extractor for the
      input resolution in use (this notebook passes 28*28*512 together with
      3x224x224 inputs).
    """

    def __init__(self, num_features):
        super().__init__()

        # Adjusting to use the new weights parameter
        vgg19 = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

        # Keep feature modules 0..features_layers (inclusive)
        features_layers = 24
        self.features = nn.Sequential(*list(vgg19.features.children())[:features_layers + 1])

        # Freeze the features layers so only the classifier head is trained
        for param in self.features.parameters():
            param.requires_grad = False

        # Replace the classifier
        self.classifier = nn.Sequential(
            nn.Linear(in_features=num_features, out_features=512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(in_features=512, out_features=64),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(in_features=64, out_features=2)
        )

    def forward(self, x):
        """Return the classifier logits for a batch of images `x`."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten the features
        x = self.classifier(x)
        return x

    def load_weights(self, path):
        """
        Load model weights from a file stored locally.

        Accepts either a bare ``state_dict`` or a checkpoint dictionary that
        keeps the weights under ``'model_state_dict'`` (the layout written by
        ``save_checkpoint`` and ``EarlyStopping.save_checkpoint`` in this
        notebook).

        Parameters:
        - path: Path of the file to load.
        """
        # map_location="cpu" lets a file saved on a GPU be read on a CPU-only
        # machine; load_state_dict then copies the values into the existing
        # parameters on whichever device the model currently lives.
        checkpoint = torch.load(path, map_location="cpu")
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            checkpoint = checkpoint["model_state_dict"]
        self.load_state_dict(checkpoint)



In [None]:
# Set the API key
# SECURITY: the key must never be hard-coded in the notebook. The value that
# used to be written here was stored in plain text and should be revoked and
# regenerated in the W&B account settings.
# Reuse WANDB_API_KEY when the environment already provides it; otherwise ask
# for it interactively without echoing it to the output.
if not os.environ.get('WANDB_API_KEY'):
    from getpass import getpass
    os.environ['WANDB_API_KEY'] = getpass("Enter your W&B API key: ")

In [None]:
# Define the wandb entity and project
# These identify where runs are logged on Weights & Biases.
project_name = "SRCNN+VGG"  # replace with your wandb project name
entity_name = "pershadmayank"  # replace with your wandb entity

## 1.2 Model Summary

In [None]:
from torchinfo import summary
# Build the classifier on the selected device. num_features must equal the
# flattened output size of the frozen feature extractor for the input size
# used in the summary below.
classifier = VGG19Classifier(num_features=28*28*512).to(device)
# Show a layer-by-layer summary for one 3x224x224 input
summary(classifier, input_size=[1, 3, 224, 224])

## 1.3 `train_step()`

In [None]:
from typing import Tuple
import torch

def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn,
               optimizer: torch.optim.Optimizer,
               device: str,
               max_pixel_value: float = 1.0) -> Tuple[float, float]:
  """
  Trains the model for one pass over `dataloader` (forward pass, loss
  computation, backpropagation and optimizer step for every batch).

  Parameters:
  - model (torch.nn.Module): The neural network model to be trained.
  - dataloader (torch.utils.data.DataLoader): DataLoader yielding (inputs, class labels).
  - loss_fn: Loss function called as loss_fn(predictions, labels).
  - optimizer (torch.optim.Optimizer): Optimizer used for parameter updates.
  - device (str): Device to run the training on ('cuda' or 'cpu').
  - max_pixel_value (float, optional): Unused; kept so existing callers keep working.

  Returns:
  - train_loss (float): Loss averaged over the batches.
  - train_acc (float): Classification accuracy averaged over the batches.

  Raises:
  - ValueError: If `device` is not 'cuda' or 'cpu'.
  - TypeError: If the provided model, dataloader, loss function, or optimizer are of the wrong type.
  """

  # Validate input parameters for type and value
  if not isinstance(model, torch.nn.Module):
    raise TypeError("model must be an instance of torch.nn.Module")
  if not isinstance(dataloader, torch.utils.data.DataLoader):
    raise TypeError("dataloader must be an instance of torch.utils.data.DataLoader")
  if not callable(loss_fn):
    raise TypeError("loss_fn must be callable")
  if not isinstance(optimizer, torch.optim.Optimizer):
    raise TypeError("optimizer must be an instance of torch.optim.Optimizer")
  if device not in ['cuda', 'cpu']:
    raise ValueError("device must be 'cuda' or 'cpu'")

  # Ensure model is on the correct device
  model.to(device)

  # Put the model in train mode
  model.train()

  # Running sums of the per-batch loss and accuracy
  train_loss = 0.0
  train_acc = 0.0

  # Loop through batches of data
  for X, y in dataloader:
    X, y = X.to(device), y.to(device)

    # Forward pass
    y_pred = model(X)

    # Calculate loss
    loss = loss_fn(y_pred, y)
    train_loss += loss.item()

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Accumulate (+=) the batch accuracy; plain assignment would keep only the
    # last batch. Softmax is monotonic, so argmax over the logits is enough.
    y_pred_class = torch.argmax(y_pred, dim=1)
    train_acc += (y_pred_class == y).sum().item() / len(y_pred)

  # Adjust metrics to get average loss and accuracy per batch
  train_loss /= len(dataloader)
  train_acc /= len(dataloader)

  return train_loss, train_acc


## 1.4 `test_step()`

In [None]:
from typing import Tuple
import torch

def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn,
              device: str,
              max_pixel_value: float = 1.0) -> Tuple[float, float]:
  """
  Performs a single evaluation step, calculating the average loss and PSNR
  over the provided dataloader.

  Parameters:
  - model (torch.nn.Module): The neural network model to be evaluated.
  - dataloader (torch.utils.data.DataLoader): DataLoader for the dataset to evaluate.
  - loss_fn: The loss function used for evaluation.
  - device (str): The device to run the evaluation on ('cuda' or 'cpu').
  - max_pixel_value (float, optional): The maximum pixel value used in the PSNR calculation. Default is 1.0.

  Returns:
  - test_loss (float): The average loss over the dataloader.
  - test_psnr (float): The average Peak Signal-to-Noise Ratio over the dataloader.

  Raises:
  - ValueError: If `device` is not 'cuda' or 'cpu'.
  - TypeError: If the provided model, dataloader, loss function, or device are of the wrong type.
  """

  # Validate input parameters
  if not isinstance(model, torch.nn.Module):
    raise TypeError("model must be an instance of torch.nn.Module")
  if not isinstance(dataloader, torch.utils.data.DataLoader):
    raise TypeError("dataloader must be an instance of torch.utils.data.DataLoader")
  if not callable(loss_fn):
    raise TypeError("loss_fn must be callable")
  if device not in ['cuda', 'cpu']:
    raise ValueError("device must be 'cuda' or 'cpu'")

  # Ensure model is on the correct device
  model.to(device)

  # Put the model in eval mode
  model.eval()

  test_loss = 0.0
  test_psnr = 0.0

  with torch.inference_mode():
    for _, (X, y) in enumerate(dataloader):
      X, y = X.to(device), y.to(device)

      # Forward pass
      test_pred = model(X)

      # Calculate loss
      loss = loss_fn(test_pred, y)
      test_loss += loss.item()

      # Calculate and accumulate acc
      test_pred_labels = torch.argmax(torch.softmax(test_pred, dim=1), dim=1)
      test_acc = (test_pred_labels == y).sum().item() / len(test_pred_labels)

  # Compute average loss and PSNR
  test_loss /= len(dataloader)
  test_acc /= len(dataloader)

  return test_loss, test_acc


## 1.5 Checkpoint Saving

In [None]:
import torch
import os

def save_checkpoint(epoch, model, optimizer, loss, path="checkpoint.pth"):
  """
  Writes a resumable training checkpoint to disk.

  Parameters:
  - epoch: Index of the epoch that has just finished.
  - model: Model whose state_dict is stored.
  - optimizer: Optimizer whose state_dict is stored.
  - loss: Loss value recorded alongside the weights.
  - path: Destination file for the checkpoint.
  """
  # Record the *next* epoch index so a resumed run continues after this one
  state = dict(
    epoch=epoch + 1,
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    loss=loss,
  )
  torch.save(state, path)
  print(f"Checkpoint saved at '{path}'")

In [None]:
import os
import shutil

def delete_checkpoints(directory):
    """
    Empties `directory`: removes every file, symlink and sub-directory in it,
    leaving the directory itself in place.

    Parameters:
    - directory: The path to the directory whose contents are to be deleted.
    """
    # Nothing to clean up when the directory is missing
    if not os.path.exists(directory):
        print(f"The directory {directory} does not exist.")
        return

    for entry_name in os.listdir(directory):
        target = os.path.join(directory, entry_name)
        try:
            is_plain_entry = os.path.isfile(target) or os.path.islink(target)
            if is_plain_entry:
                # Regular files and symlinks are unlinked directly
                os.unlink(target)
            elif os.path.isdir(target):
                # Sub-directories are removed together with their contents
                shutil.rmtree(target)
        except Exception as err:
            # Best effort: report the failure and continue with the next entry
            print(f'Failed to delete {target}. Reason: {err}')

    print(f"All checkpoints in '{directory}' have been deleted.")

## 1.6 Early Stopping and Saving Best Model Parameters

In [None]:
class EarlyStopping:
  """
  Stops training when the validation loss has not improved for `patience`
  consecutive calls, saving the model each time the loss does improve.

  Parameters:
  - patience: Number of consecutive non-improving calls before `early_stop` is set.
  - verbose: When True, print the counter and every save.
  - delta: A new score must be at least `best_score + delta` to count as an improvement.
  - path: File the best model/optimizer state is written to.
  """

  def __init__(self, patience=7, verbose=False, delta=0, path='best_model.pth'):
    self.patience = patience
    self.verbose = verbose
    self.delta = delta
    self.best_score = None
    self.early_stop = False
    # float('inf') instead of np.Inf: the np.Inf alias was removed in NumPy 2.0
    self.val_loss_min = float('inf')
    self.counter = 0
    self.path = path

  def __call__(self, val_loss, model, optimizer):
    """Record one validation loss; save on improvement, otherwise count towards patience."""
    # Higher score is better, so negate the loss
    score = -val_loss

    if self.best_score is None:
      # First call: always becomes the best score
      self.best_score = score
      self.save_checkpoint(val_loss, model, optimizer)
    elif score < self.best_score + self.delta:
      self.counter += 1
      if self.verbose:
          print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
      if self.counter >= self.patience:
          self.early_stop = True
    else:
      self.best_score = score
      self.save_checkpoint(val_loss, model, optimizer)
      self.counter = 0

  def save_checkpoint(self, val_loss, model, optimizer):
    '''Saves the model and optimizer state to `self.path` when validation loss decreases.'''
    if self.verbose:
      print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': val_loss,
    }, self.path)
    self.val_loss_min = val_loss

## 1.7 Combining `train_step()` and `test_step()`

In [None]:
import os
import time
import wandb
import signal
from tqdm.auto import tqdm

def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn,
          path: str,
          start_epoch: int = 0,
          end_epoch: int = 200,
          checkpoint_interval=20,
          device: str = device):
  """
  Trains `model` for epochs `start_epoch` .. `end_epoch - 1`, validating after
  every epoch, with early stopping, periodic checkpoints and W&B logging when
  a run is active.

  Parameters:
  - model: Model to train.
  - train_dataloader: DataLoader passed to train_step().
  - test_dataloader: DataLoader passed to test_step() for validation.
  - optimizer: Optimizer used for parameter updates.
  - loss_fn: Loss function.
  - path: Run name; used for the best-model file name and the checkpoint sub-directory.
  - start_epoch: First epoch index.
  - end_epoch: Epoch index at which to stop (exclusive).
  - checkpoint_interval: A checkpoint is written whenever epoch % checkpoint_interval == 0.
  - device: Device string forwarded to train_step()/test_step().

  Returns:
  - dict with per-epoch lists under "train_loss", "val_loss", "train_acc", "val_acc".

  Notes:
  - An interrupt (Ctrl-C / kernel interrupt) is handled by catching
    KeyboardInterrupt around the epoch loop. Unlike installing a SIGINT
    handler, this really stops training, leaves no handler behind once the
    function returns, and cannot reference an epoch or loss that does not
    exist yet.
  - After an interrupt the periodic checkpoints are kept so the run can be
    resumed; after a normal finish they are deleted.
  """

  # 1.0 Start the timer
  start_time = time.time()

  # 2.0 Create an empty results dictionary
  results = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

  # 3.0 Directory to save the model (absolute, like every other path in this
  # notebook; the leading "/" was missing, which made it depend on the
  # current working directory)
  model_dir = "/hard-disk-2/users/mpershad/SRCNN_DIV2K/Models/Classifier"
  os.makedirs(model_dir, exist_ok=True)

  model_path = f"{model_dir}/{path}.pth"

  # Defined before the loop so the final cleanup works even when no epoch runs
  checkpoint_root = "/hard-disk-2/users/mpershad/SRCNN_DIV2K/Checkpoints/Classifier"
  checkpoint_dir = f"{checkpoint_root}/{path}"

  # 4.0 Calculate Logging Interval (10% of total epochs, at least 1)
  total_epochs = end_epoch - start_epoch
  log_interval = max(total_epochs // 10, 1)

  # 5.0 Initialize the early stopping
  early_stopper = EarlyStopping(patience=int(total_epochs/5), verbose=True, delta=5e-6, path=model_path)

  # State of the last fully completed epoch, used if training is interrupted
  last_epoch = start_epoch - 1
  last_val_loss = None
  interrupted = False

  try:
    # Loop through training and testing steps for a number of epochs
    for epoch in tqdm(range(start_epoch, end_epoch)):

      # 6.0 Training step
      train_loss, train_acc = train_step(model=model,
                                         dataloader=train_dataloader,
                                         loss_fn=loss_fn,
                                         optimizer=optimizer,
                                         device=device)

      # 7.0 Testing step (validation)
      val_loss, val_acc = test_step(model=model,
                                    dataloader=test_dataloader,
                                    loss_fn=loss_fn,
                                    device=device)

      last_epoch, last_val_loss = epoch, val_loss

      # 8.0 Update results dictionary
      results["train_loss"].append(train_loss)
      results["val_loss"].append(val_loss)
      results["train_acc"].append(train_acc)
      results["val_acc"].append(val_acc)

      # Log values to wandb
      if wandb.run:
        wandb.log({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss, 'train_acc': train_acc, 'val_acc': val_acc})

      # Log values every 10% of the total epochs
      if epoch % log_interval == 0 or epoch == end_epoch - 1:

        # Print out what's happening
        print(f"Epoch: {epoch}, Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

      # 9.0 Early stopping
      early_stopper(val_loss, model, optimizer)
      if early_stopper.early_stop:
          print("Early stopping")
          break

      # 10.0 Checkpoint Saving
      if epoch % checkpoint_interval == 0 or epoch == end_epoch - 1:

        # Create the directory for saving Checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint_path = f"{checkpoint_dir}/epoch_{epoch}.pth"

        save_checkpoint(epoch=epoch, model=model, optimizer=optimizer, loss=val_loss, path=checkpoint_path)

        # Optionally log checkpoint to wandb if connected
        if wandb.run:
          wandb.save(checkpoint_path, base_path=checkpoint_root)

  except KeyboardInterrupt:
    interrupted = True
    print("Training stopped by user! Saving a checkpoint before exiting...")
    checkpoint_path = f"{model_dir}/INTERRUPTED_{path}.pth"
    # last_epoch is the last fully completed epoch, so the stored 'epoch'
    # (last_epoch + 1) is the one to resume from
    save_checkpoint(epoch=last_epoch, model=model, optimizer=optimizer, loss=last_val_loss, path=checkpoint_path)
    print(f'Checkpoint saved. Safetly terminated training.')

  # 11.0 Use wandb.save() to ensure the model file is saved to W&B if connected.
  # The best-model file only exists once at least one epoch has completed.
  if wandb.run and os.path.exists(model_path):
    wandb.save(model_path, base_path=model_dir)

    # 12.0 Log the model as an artifact if connected
    artifact = wandb.Artifact('Classifier', type='model', description="A super-resolution model")
    artifact.add_file(model_path)
    wandb.log_artifact(artifact)

  # 13.0 Calculate and log training duration
  end_time = time.time()
  total_training_time = end_time - start_time
  if wandb.run:
    wandb.log({'total_training_time': total_training_time})

  print(f"Total training time: {total_training_time:.3f} seconds")

  wandb.finish()

  # Keep the periodic checkpoints after an interrupt so the run can be resumed
  if not interrupted:
    delete_checkpoints(checkpoint_dir)

  return results

## 1.8 Plot Loss Curves

In [None]:
import matplotlib.pyplot as plt

def plot_loss_curves(results):
    """
    Plot training/validation curves, one panel per metric present in `results`.

    A loss panel is drawn when both 'train_loss' and 'val_loss' are present,
    and an accuracy panel when both 'train_acc' and 'val_acc' are present.
    The number of panels is derived from those same pairs, so it always
    matches what is drawn.

    Parameters:
    - results: dict mapping metric names to per-epoch lists of values.
    """
    # (train key, validation key, label used in titles/legends)
    panel_specs = [
        ('train_loss', 'val_loss', 'Loss'),
        ('train_acc', 'val_acc', 'Acc'),
    ]
    panels = [spec for spec in panel_specs if spec[0] in results and spec[1] in results]
    if not panels:
        print("No data to plot.")
        return

    # squeeze=False always returns a 2-D array of axes, even for one panel
    fig, axes = plt.subplots(1, len(panels), figsize=(6 * len(panels), 6), squeeze=False)

    for ax, (train_key, val_key, label) in zip(axes[0], panels):
        train_values = list(results[train_key])
        val_values = list(results[val_key])

        ax.plot(train_values, label=f'Training {label}', color='blue')
        ax.plot(val_values, label=f'Validation {label}', color='red')
        ax.set_title(f'{label} Over Epochs')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(label)

        # Add a 5% margin around the data; skip it when there is no data or
        # the curve is constant (the limits would be empty or degenerate)
        combined = train_values + val_values
        if combined:
            low, high = min(combined), max(combined)
            margin = 0.05 * (high - low)
            if margin > 0:
                ax.set_ylim([low - margin, high + margin])
        ax.legend()

    plt.tight_layout()
    plt.show()

# 2.0 Image Quality Metrics

1. Mean Squared Error(MSE) :  MSE loss function is evaluated only by the difference between the central pixels of Xi and the network output.

$$
L(\Theta) = \frac{1}{n} \sum_{i=1}^n \left \| F(Y_i; \Theta) - X_i \right \|^2,
$$

where n is the number of training examples. Using MSE as the loss function favors a high PSNR.

Alternative evaluation metrics,

1. Peak Signal To Noise Ratio(PSNR)

2. Structural Similarity Index(SSIM)

## 2.1 Mean Squared Error(MSE)

In [None]:
import numpy as np

def mse(original : np.ndarray, target : np.ndarray) -> float:
  """
  Return the Mean Squared Error (MSE) between two images.

  Parameters:
  original (np.ndarray): The reference image.
  target (np.ndarray): The image compared against the reference.

  Returns:
  float: Mean of the squared element-wise differences.
  """
  # Promote to float64 before subtracting so unsigned integers cannot wrap around
  difference = original.astype(np.float64) - target.astype(np.float64)
  return np.mean(np.square(difference))

## 2.2 Peak Signal To Noise Ratio(PSNR)

In [None]:
import numpy as np

def psnr(original : np.ndarray, target : np.ndarray, max_value : float = 255.0) -> float:
  """
  Compute the Peak Signal to Noise Ratio (PSNR) between two images.

  Parameters:
  original (numpy.ndarray): The reference image.
  target (numpy.ndarray): The image compared against the reference.
  max_value (float): Peak signal value of the images. Defaults to 255.0
    (8-bit images); pass 1.0 for images scaled to [0, 1].

  Returns:
  float: The PSNR value in decibels (dB); infinity when the images are identical.
  """

  # Promote to float64 before subtracting so unsigned integers cannot wrap around
  original_data = original.astype(np.float64)
  target_data = target.astype(np.float64)

  error = np.mean((original_data - target_data)**2)
  if error == 0:
    # MSE is zero means no noise is present in the signal.
    # Therefore, PSNR is infinite.
    return float('inf')

  return 20*np.log10(max_value / np.sqrt(error))

## 2.3 Structural Similarity Index(SSIM)

In [None]:
from skimage.metrics import structural_similarity as ssim_lib
from skimage.color import rgb2gray
import numpy as np

def ssim(original : np.ndarray, target : np.ndarray) -> float:
  """
  Compute the Structural Similarity Index (SSIM) between two images.

  Parameters:
  original (np.ndarray): The reference image.
  target (np.ndarray): The image compared against the reference.

  Returns:
  float: The mean SSIM value.
  """

  # Convert images to grayscale if they are in color because SSIM is often computed in grayscale
  if original.ndim == 3:
    original = rgb2gray(original)
  if target.ndim == 3:
    target = rgb2gray(target)

  # Dynamically determine the data range for SSIM calculation.
  # Any floating point dtype is assumed to be in [0, 1]; np.issubdtype covers
  # every float type instead of only float32/float64.
  if np.issubdtype(original.dtype, np.floating):
      data_range = 1
  else:
      # For integer types, use the maximum possible value of the dtype
      data_range = np.iinfo(original.dtype).max

  # Only the mean SSIM is needed, so the full per-pixel map is not requested
  return ssim_lib(original, target, data_range=data_range)

## 2.4 Combined Metric

In [None]:
from typing import Tuple
import numpy as np

def combined_metric(original : np.ndarray, target : np.ndarray) -> Tuple:
  """
  Evaluate MSE, PSNR and SSIM for one image pair in a single call.

  Parameters:
  original (np.ndarray): The reference image.
  target (np.ndarray): The image compared against the reference.

  Return:
  Tuple: (MSE, PSNR, SSIM)
  """
  # The tuple elements are evaluated left to right: mse, then psnr, then ssim
  return (
    mse(original, target),
    psnr(original, target),
    ssim(original, target),
  )

# 3.0 Transforming Data

To prepare our dataset for training a Super-Resolution CNN (SRCNN), we need to simulate low-resolution (LR) images from our high-resolution (HR) images. The following steps are taken for each image in our HR dataset:




## 3.1 LR Transformations

1. **Apply Gaussian Blur**:
   A Gaussian blur is applied to the HR images. This step is performed to smooth out the images and simulate the loss of detail that occurs in a real-world low-resolution image.

2. **Downscale**:
   The blurred HR images are then downsampled by the specified scale factor.

3. **Upscale**:
   The downsampled images are then upscaled back to the original dimensions using bicubic interpolation.

4. **Save Processed Images**:
  The processed LR images are saved back to the filesystem for use in training the SRCNN.

In [None]:
# Hyperparameters for Gaussian Blur
# These values are later collected into lr_image_args and forwarded to generate_lr_image.
kernel_size = (9, 9)  # Gaussian kernel size
sigma = 1.5  # Gaussian sigma
scale = 3  # factor used for the down-and-up sampling

In [None]:
# Normalise the answer so stray spaces or capitals do not cause a mismatch
operation = input("Enter the operation that needs to performed on HR Image: ").strip().lower()

# Print message about the processing type
if operation == 'blur':
    print(f"Applying Gaussian Blur with kernel size {kernel_size} and sigma {sigma}.")
elif operation == 'downup':
    print("Applying Down-and-Up sampling.")
elif operation == 'both':
    print(f"Applying both Gaussian Blur with kernel size {kernel_size} and sigma {sigma}, and Down-and-Up sampling.")
else:
    # An unknown operation would otherwise pass through silently and leave the
    # "LR" images identical to the HR ones
    raise ValueError(f"Unsupported operation '{operation}'. Expected 'blur', 'downup' or 'both'.")

In [None]:
import cv2 as cv
import numpy as np
from typing import Tuple

def generate_lr_image(image: np.ndarray, operation: str, scale: int, kernel_size: Tuple[int, int] = (5, 5), sigma: float = 0) -> np.ndarray:
  """
  Processes the image based on the specified operation: Gaussian Blur, Down-and-Up sampling, or both,
  for both grayscale and color images.

  Parameters:
  - image (np.ndarray): The input image to process. Can be grayscale (2-D) or color (3-D).
  - operation (str): The operation to perform - 'blur', 'downup', 'both'.
  - scale (int): The downscaling factor for down-and-up sampling.
  - kernel_size (Tuple[int, int]): The kernel size for Gaussian blur.
  - sigma (float): The sigma value for Gaussian blur.

  Returns:
  - np.ndarray: Low-resolution image with the same height and width as the input.

  Raises:
  - ValueError: If `operation` is not one of the supported values, or if the
    image is too small to be downscaled by `scale`.
  """
  # Reject unknown operations instead of silently returning the input unchanged
  if operation not in ('blur', 'downup', 'both'):
    raise ValueError(f"Unsupported operation '{operation}'. Expected 'blur', 'downup' or 'both'.")

  # Height and width are the first two axes for grayscale and color alike
  h, w = image.shape[:2]

  lr_image = image

  if operation in ('blur', 'both'):
    # Apply Gaussian Blur
    lr_image = cv.GaussianBlur(lr_image, kernel_size, sigma)

  if operation in ('downup', 'both'):
    # cv.resize takes the target size as (width, height)
    small_size = (w // scale, h // scale)
    if min(small_size) < 1:
      raise ValueError(f"Image of size {w}x{h} is too small to downscale by a factor of {scale}.")

    # Downscale and then upscale back to the original size (bicubic both ways)
    lr_image = cv.resize(lr_image, small_size, interpolation=cv.INTER_CUBIC)
    lr_image = cv.resize(lr_image, (w, h), interpolation=cv.INTER_CUBIC)

  return lr_image

In [None]:
def modcrop(img, modulo):

  """
  Crop the image so that its height and width are divisible by `modulo`.

  Parameters:
  - img: The input image to crop. An np.ndarray is used as is; anything else
    (for example a PIL image) is converted with np.array first.
  - modulo (int): The value the height and width must be divisible by.

  Returns:
  - np.ndarray: The cropped image (the top-left region of the input).

  Raises:
  - ValueError: If the image is neither 2-D nor 3-D.
  """

  # Convert non-array inputs without referring to PIL.Image, which is not
  # guaranteed to be loaded by a plain `import PIL`
  if not isinstance(img, np.ndarray):
    img = np.array(img)

  if img.ndim not in (2, 3):
    raise ValueError("Unsupported image dimensions")

  # Drop the remainder rows/columns at the bottom/right edge
  height, width = img.shape[:2]
  height -= height % modulo
  width -= width % modulo

  # Slicing the first two axes keeps every channel of a 3-D image
  return img[:height, :width]

## 3.2 Unpaired Dataset

Note: We are working with only one scale(3x)
* Classification:
    * Training:     DIV2K(Train)
    * Validation:   DIV2K(Validation)


In [None]:
# DIV2K Dataset
# Locations of the high-resolution source images for training and validation
div2k_train_path = "/hard-disk-2/users/mpershad/DIV2K_train_HR"
div2k_validation_path = "/hard-disk-2/users/mpershad/DIV2K_valid_HR"

# Side length of the square patches cut from the images
patch_size = 224

In [None]:
# Roots of the classification datasets; these are later given to ImageFolder
train_path = "/hard-disk-2/users/mpershad/SRCNN_DIV2K/Data/Unpaired/Classification/Train"
validation_path = "/hard-disk-2/users/mpershad/SRCNN_DIV2K/Data/Unpaired/Classification/Validation" 

In [None]:
# Parameters for generate_lr_image
# Bundled into one dict so get_patch can forward them with **lr_image_args
lr_image_args = {
    'operation': operation,
    'scale': scale,
    'kernel_size': kernel_size,
    'sigma': sigma
}

In [None]:
# Function to random crops of images

def _get_patch(*args, patch_size=224):
    """
    Get a random patch of the specified size from the provided images.
    """
    # Get the height and width of the image
    h, w = args[0].shape[:2]

    # Randomly select the top left corner of the patch
    ix = np.random.randint(0, w - patch_size)
    iy = np.random.randint(0, h - patch_size)

    # Return the cropped images
    return [img[iy: iy + patch_size, ix: ix + patch_size] for img in args]

In [None]:
import random
def augment(*args, hflip=True, rot=True):

    """
    Apply the same random flips/transposition to every provided image.

    Parameters:
    - *args: Images (arrays) with height and width as their first two axes;
      both 2-D (grayscale) and 3-D (color) arrays are supported.
    - hflip (bool): Allow a horizontal flip (applied with probability 0.5).
    - rot (bool): Allow a vertical flip and a height/width transposition
      (each applied with probability 0.5).

    Returns:
    - list: The augmented images, in the same order as the inputs.
    """

    # Draw the random decisions once so all images get the same treatment
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        # Only the first two axes are indexed, so a channel axis is optional
        if hflip: img = img[:, ::-1]
        if vflip: img = img[::-1]
        if rot90: img = img.swapaxes(0, 1)
        return img
    return [_augment(img) for img in args]

In [None]:
import os
from pathlib import Path
import numpy as np
from PIL import Image
from tqdm import tqdm
# Function to generate and save the 224x224 patches for LR and HR image classification

def get_patch(dataset_path: str, save_lr_path: str, save_hr_path: str, patch_size: int, num_patches: int, ext: str = "*.png"):

    """
    Generate and save matching HR/LR patches for LR and HR image classification.

    For every image found in `dataset_path`, an LR version is produced with
    generate_lr_image (using the module-level lr_image_args), and
    `num_patches` random, jointly augmented patch pairs are written to disk.

    Parameters:
    - dataset_path (str): The path to the HR dataset.
    - save_lr_path (str): The path to save the LR patches.
    - save_hr_path (str): The path to save the HR patches.
    - patch_size (int): The size of the patches to generate.
    - num_patches (int): The number of patches to generate per image.
    - ext (str): Glob pattern selecting the image files.

    Returns:
    - None
    """

    # Create the directories to save the patches
    os.makedirs(save_hr_path, exist_ok=True)
    os.makedirs(save_lr_path, exist_ok=True)

    # Get the list of image
    filenames_path_list = sorted(Path(dataset_path).glob(f"{ext}"))

    # Loop through the images
    for filename_path in tqdm(filenames_path_list):
        # Close the file handle as soon as the pixel data has been converted
        with Image.open(filename_path) as handle:
            img = handle.convert('RGB')

        # Use the modcrop function to crop the image; make the result
        # contiguous so it can be handed on to the image-processing routines
        hr_image = np.ascontiguousarray(modcrop(img, 3))

        # The LR image depends only on the HR image, so build it once per
        # image instead of once per patch
        lr_image = generate_lr_image(hr_image, **lr_image_args)

        # Use the _get_patch function to get the patches
        for i in range(num_patches):
            hr_patch, lr_patch = _get_patch(hr_image, lr_image, patch_size=patch_size)

            # Augment the patches
            hr_patch, lr_patch = augment(hr_patch, lr_patch)

            # Save the serialized patches
            hr_patch_path = os.path.join(save_hr_path, f"{filename_path.stem}_{i+1}.png")
            lr_patch_path = os.path.join(save_lr_path, f"{filename_path.stem}_{i+1}.png")

            Image.fromarray(hr_patch).save(hr_patch_path)
            Image.fromarray(lr_patch).save(lr_patch_path)

In [None]:
# Call the get_patch to generate training data
# get_patch(dataset_path=div2k_train_path,
#           save_lr_path=os.path.join(train_path, "LR"),
#           save_hr_path=os.path.join(train_path, "HR"),
#           patch_size=patch_size,
#           num_patches=10)

In [None]:
# Call the get_patch function to generate validation data
# get_patch(dataset_path=div2k_validation_path,
#           save_lr_path=os.path.join(validation_path, "LR"),
#           save_hr_path=os.path.join(validation_path, "HR"),
#           patch_size=patch_size,
#           num_patches=10)

In [None]:
def walk_through_dir(path : str) -> None:
  """
  Print, for `path` and every directory below it, how many files it contains.
  """
  for current_dir, _subdirs, files in os.walk(path):
    file_count = len(files)
    print(f"There are {file_count} images in the {current_dir}")

In [None]:
# Report how many files each directory of the classification dataset contains
image_path = '/hard-disk-2/users/mpershad/SRCNN_DIV2K/Data/Unpaired/Classification'
walk_through_dir(image_path)

# 4.0 Loading Images data

## 4.1 Loading Images using ImageFolder 

### 4.1.1 Transforms

In [None]:
from torchvision import transforms

# Transform performing the same preprocessing as VGG19 before extracting
# features (resize to 224x224 and ImageNet mean/std normalisation), plus a
# random horizontal flip used as augmentation
vgg_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_transform = vgg_transform

# Validation uses the same preprocessing WITHOUT the random flip, so the
# validation metrics are deterministic and comparable between epochs
validation_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
#  Use ImageFolder to create dataset(s)
from torchvision import datasets

# Training dataset: class labels come from the sub-directory names under train_path
train_dataset = datasets.ImageFolder(root=train_path,
                                     transform=train_transform,
                                     target_transform=None) # Transform for the target (our data is in standard classification format, labels are inferred form the directory name)

# Validation dataset, built the same way from validation_path
validation_dataset = datasets.ImageFolder(root=validation_path,
                                        transform=validation_transform,
                                        target_transform=None)

In [None]:
# Get class name as list
class_names = train_dataset.classes
class_names  # bare expression so the notebook displays the list

In [None]:
# Get class names as dict
class_names_dict = train_dataset.class_to_idx
class_names_dict  # bare expression so the notebook displays the mapping

In [None]:
# Length of the dataset
# Displays (number of training samples, number of validation samples)
len(train_dataset), len(validation_dataset)

In [None]:
# Index on the train_data to get a single Image and label
# The image has already gone through train_transform at this point
img, label = train_dataset[0]
print(f"Image Shape: {img.shape}")
print(f"Image DataType: {img.dtype}")
print(f"Image Label: {class_names[label]}")
print(f"Label dataType: {type(label)}")

### 4.1.2 Plot a random Image and Label

In [None]:
# plot a random image
torch.manual_seed(42)
random_idx = torch.randint(0, len(train_dataset), size=[1]).item()
img, label = train_dataset[random_idx]
# Move the channel axis last, as expected by matplotlib
img = img.permute(1, 2, 0)
print(f"Image Shape: {img.shape}")
print(f"Image DataType: {img.dtype}")
print(f"Image Label: {class_names[label]}")

# The dataset transform normalises with the ImageNet mean/std, which pushes
# values outside [0, 1]; undo it so imshow shows the real colours instead of
# clipping them
norm_mean = torch.tensor([0.485, 0.456, 0.406])
norm_std = torch.tensor([0.229, 0.224, 0.225])
img_display = (img * norm_std + norm_mean).clamp(0, 1)

# plot the image
plt.figure(figsize=(6, 4))
plt.imshow(img_display)
plt.axis("off")
plt.title(class_names[label], fontsize=12)

### 4.1.3 Turn custom Dataset into DataLoader

In [None]:
import os
from torch.utils.data import DataLoader

# Setup batch size
BATCH_SIZE = 32

# os.cpu_count() can return None when the count cannot be determined; fall
# back to 0 (load data in the main process) because DataLoader needs an integer
NUM_WORKERS = os.cpu_count() or 0

# Turn the dataset into iterable(batches)
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              shuffle=True)

# Validation data is not shuffled
validation_dataloader = DataLoader(dataset=validation_dataset,
                                     batch_size=BATCH_SIZE,
                                     num_workers=NUM_WORKERS,
                                     shuffle=False)

In [None]:
# len() of a DataLoader is its number of batches
print(f"Length of Training Data: {len(train_dataloader)}")
print(f"Length of Validation Data: {len(validation_dataloader)}")

# 5.0 Tracking experiments

## 5.1 Config

In [None]:
# Run configuration: passed to wandb.init and read again when calling train()
config={ 'start_epoch': 0,
         'end_epoch': 200, 
         'architecture': 'VGGClassifier', 
         'checkpoint_interval': 50}

## 5.2 Log In

In [None]:
import wandb

# Try to log in to your W&B account and start a run
try:
    # wandb.login() reports whether a key is configured; only announce success
    # when it actually is
    if wandb.login():
        print("Successfully logged in to W&B!")
    else:
        print("W&B login could not be confirmed.")

    # Initialize a W&B run, using the project/entity defined at the top of the notebook
    wandb.init(project=project_name, entity=entity_name, config=config)

except Exception as e:
    # Both the login and the run initialisation can fail here
    print("Error during W&B login/initialisation:", e)

## 5.2 Setup Loss function and optimizer

In [None]:
import torch.nn as nn
import torch

# Initialize the VGG19 model
# A fresh instance is created here; train_step()/test_step() move it to the target device.
classifier = VGG19Classifier(num_features=28*28*512)

# Setup the loss function and optimizer
# CrossEntropyLoss takes the raw logits produced by the classifier
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

In [None]:
# Name of this run; used for the saved model, checkpoint and metrics file names
model_name = str(input("Enter model name:"))

In [None]:
# Train the model using the values from the config dict
# The returned dict holds the per-epoch loss/accuracy lists
model_results = train(model=classifier,
                      train_dataloader=train_dataloader,
                      test_dataloader=validation_dataloader,
                      optimizer=optimizer,
                      loss_fn=loss_fn,
                      path=model_name,
                      start_epoch=config['start_epoch'],
                      end_epoch=config['end_epoch'],
                      device=device,
                      checkpoint_interval=config['checkpoint_interval'])

## 5.3 Save Results Dictionary

In [None]:
import pickle
import os

# Directory that collects the pickled metrics of every classifier run
metrics_path = '/hard-disk-2/users/mpershad/SRCNN_DIV2K/Metrics/Classifier'

# Create the directory on first use
metrics_dir_missing = not os.path.exists(metrics_path)
if metrics_dir_missing:
  os.makedirs(metrics_path)

# One pickle file per run, named after the model
pickle_file_path = os.path.join(metrics_path, f"metrics_{model_name}.pkl")

with open(pickle_file_path, 'wb') as handle:
    pickle.dump(model_results, handle)

print(f"Results saved to {pickle_file_path}")

In [None]:
# Plot the curves recorded during training
plot_loss_curves(model_results)