# Run Project
This notebook lets you run (train and evaluate) the models. The dataset is not included.

In [None]:
# Path to the pickled pandas DataFrame (.pkl) with 'img_l' (input) and
# 'img_h' (target) image columns.
DATASET_PATH = ''
# Directory in which the trained model weights ('model1.pt') are written.
MODEL_PATH = ''


In [None]:
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd

## Dataset Definition

In [None]:
class myDataset(Dataset):
    """Map-style dataset serving (input, target) image pairs from a DataFrame.

    Rows are split deterministically by position. One "part" is
    ``len(dataframe) // sum(train_test_split)`` rows. The first
    ``train_test_split[0]`` parts form the 'train' subset and the next
    ``train_test_split[1]`` parts form the 'test' subset, so the two subsets
    never share rows. Leftover rows at the end are unused.

    Each item is read from the ``'img_l'`` (input) and ``'img_h'`` (target)
    columns, given a leading channel axis, and cast to ``float32``.
    """

    def __init__(self, dataframe: pd.DataFrame, mode: str = 'train', train_test_split=(9, 1)):
        """
        Args:
            dataframe: Source rows; must have 'img_l' and 'img_h' columns
                holding NumPy arrays.
            mode: Which subset to serve, either 'train' or 'test'.
            train_test_split: Relative sizes (train, test) of the two subsets.

        Raises:
            ValueError: If ``mode`` is neither 'train' nor 'test'.
        """
        if mode not in ('train', 'test'):
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        self.dataset = dataframe
        self.mode = mode
        self.train_test_split = train_test_split
        # Most recently returned (input, target) arrays, kept for debugging.
        self.last_input = None
        self.last_target = None

    def _part_size(self) -> int:
        """Return the number of rows in one split part."""
        return self.dataset.shape[0] // sum(self.train_test_split)

    def __len__(self):
        part = self._part_size()
        if self.mode == 'train':
            return part * self.train_test_split[0]
        return part * self.train_test_split[1]

    def __getitem__(self, index):
        # Test rows start right after the train rows so the subsets are disjoint.
        if self.mode == 'test':
            index = index + self._part_size() * self.train_test_split[0]
        row = self.dataset.iloc[index]
        # Add a leading channel axis, e.g. (H, W) -> (1, H, W).
        input_img = row['img_l'][np.newaxis, :].astype(np.float32)
        target_img = row['img_h'][np.newaxis, :].astype(np.float32)
        self.last_input = input_img
        self.last_target = target_img
        return input_img, target_img

## Model Definition

In [None]:
"""
Following the SRCNN model from this paper: https://arxiv.org/pdf/1501.00092
Assumptions:
  n is the number of output channels
  c is the number of input channels
  k is the kernel size, a.k.a. f_1 (the filter size)
SRCNN is just that, a simple CNN:
  NO attention
  NO skip connections
  NO pooling
  NO activation functions other than ReLU
In the original implementation the output image is smaller than the input image (the convolutions are unpadded), so the paper simply compares the output with the central region of the ground-truth image.
Difference in setup and loss calculation:
  The paper takes an image and applies blur, and the network's output image is smaller than its input image, so only a central subset of the image is used for comparison. In our class project setup, we crop the center of the target image to match the output size instead!
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class patch(nn.Module):
    """SRCNN patch-extraction stage: an unpadded convolution followed by ReLU.

    With no padding, each spatial side of the output shrinks by ``k_size - 1``.
    """

    def __init__(self, in_channels, out_channels, k_size):
        super().__init__()
        self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=k_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.c1(x))

class mapping(nn.Module):
    """SRCNN non-linear mapping stage: an unpadded convolution followed by ReLU.

    With no padding, each spatial side of the output shrinks by ``k_size - 1``.
    """

    def __init__(self, in_channels, out_channels, k_size):
        super().__init__()
        self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=k_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        features = self.c1(x)
        return self.relu(features)

class reconstruction(nn.Module):
    """SRCNN reconstruction stage: a single unpadded convolution, no activation."""

    def __init__(self, in_channels, out_channels, k_size):
        super().__init__()
        self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=k_size)

    def forward(self, x):
        return self.c1(x)

class SRCNN(nn.Module):
    """SRCNN: bicubic resize followed by three unpadded convolutions.

    Because the convolutions use no padding, the output is smaller than
    ``out_size`` by ``f1 + f2 + f3 - 3`` pixels along each spatial axis
    (167 x 209 with the defaults).

    Args:
        f1, f2, f3: Kernel sizes of the patch, mapping and reconstruction layers.
        n1, n2: Number of feature maps produced by the first two layers.
        img_channel: Number of image channels in both input and output.
        out_size: (height, width) the input is bicubically resized to before
            the convolutions. The default keeps the previously hard-coded size.
    """

    def __init__(self, f1=9, f2=1, f3=5, n1=64, n2=32, img_channel=1, out_size=(179, 221)):
        super(SRCNN, self).__init__()
        self.out_size = tuple(out_size)
        self.part1 = patch(img_channel, n1, f1)
        self.part2 = mapping(n1, n2, f2)
        self.part3 = reconstruction(n2, img_channel, f3)

    def forward(self, x):
        """Resize ``x`` (N, C, H, W) to ``out_size`` and apply the three layers."""
        out = F.interpolate(x, size=self.out_size, mode='bicubic')
        out = self.part1(out)
        out = self.part2(out)
        out = self.part3(out)
        return out

    # No ``__call__`` override: nn.Module.__call__ already dispatches to
    # ``forward`` and, unlike a direct override, also runs registered hooks.


## Train / Test Definition

In [None]:
class SRCNN_loss:
    """Sum-of-squared-errors loss against a centre-cropped target.

    SRCNN's unpadded convolutions make the prediction spatially smaller than
    the target, so the target is cropped to the prediction's size (keeping
    its central region) before comparison.
    """

    def forward(self, x, y):
        """Return ``sum((crop(y) - x) ** 2)`` as a scalar tensor.

        Args:
            x: Prediction of shape (N, C, h, w).
            y: Target of shape (N, C, H, W) with H >= h and W >= w.

        Raises:
            ValueError: If the prediction is spatially larger than the target.
        """
        h, w = x.shape[2], x.shape[3]
        if y.shape[2] < h or y.shape[3] < w:
            raise ValueError(
                f"prediction {tuple(x.shape[2:])} is larger than target {tuple(y.shape[2:])}"
            )
        # Height margin slices dim 2 and width margin slices dim 3. Slicing by
        # the prediction's size (not by 2 * margin) also handles odd differences.
        h_margin = (y.shape[2] - h) // 2
        w_margin = (y.shape[3] - w) // 2
        y_sub_img = y[:, :, h_margin:h_margin + h, w_margin:w_margin + w]
        return torch.square(y_sub_img - x).sum()

    def __call__(self, *args, **kwds):
        return self.forward(*args, **kwds)

def train_loop(dataloader, device, model, loss_fn, optimizer, batch_size, log_every=2):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: Yields ``(X, y)`` batches.
        device: Device the batches are moved to before the forward pass.
        model: Module being trained; its parameters are updated in place.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.
        batch_size: Batch size of ``dataloader``; used only for progress output.
        log_every: Print the loss every ``log_every`` batches (positive int).
    """
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers.
    # Unnecessary in this situation but added for best practices.
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X.to(device))
        # Call the loss directly, as test_loop does; for nn.Module losses
        # this also runs their hooks.
        loss = loss_fn(pred, y.to(device))

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if batch % log_every == 0:
            # Samples seen so far, assuming every batch but the last is full.
            current = batch * batch_size + X.shape[0]
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

def test_loop(dataloader, device, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X.to(device))
            test_loss += loss_fn(pred, y.to(device)).item()
            # correct += (pred.argmax(1) == y.to(device)).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Avg loss: {test_loss:>8f} \n")

## Test / Train Run

In [None]:
# Create model
# model = ImageUpscaler(scale_factor=2, num_channels=1, num_residual_blocks=8, base_channels=64)
# model = UpsampleCNN()
model = SRCNN()

# Get Data
# NOTE: pd.read_pickle can execute arbitrary code while unpickling;
# only point DATASET_PATH at files you trust.
dataset = pd.read_pickle(DATASET_PATH)
# NOTE(review): not used in this cell; kept in case later cells inspect it.
input_image = dataset.loc[0, 'img_l']
train_dataset = myDataset(dataset)
test_dataset = myDataset(dataset, 'test')

batch_size = 50
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Shuffling is unnecessary for evaluation.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Loss function (common choices for super-resolution)
# criterion = nn.MSELoss()
criterion = SRCNN_loss()  # SRCNN loss as in the paper

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_loader, device, model, criterion, optimizer, batch_size)
    test_loop(test_loader, device, model, criterion)
print("Done!")
# Create the output directory if needed ('' means the current directory).
if MODEL_PATH:
    os.makedirs(MODEL_PATH, exist_ok=True)
torch.save(model.state_dict(), os.path.join(MODEL_PATH, 'model1.pt'))