In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
#from skimage.metrics import structural_similarity as ssim

import cv2
import numpy as np
import math
import os
import matplotlib.pyplot as plt
%matplotlib inline

from PIL import Image, ImageFilter
from model import SRCNN


In [12]:
# Define a function for peak signal-to-noise ratio (PSNR)
def psnr(target, ref):
    # Convert images to torch tensors
    target_data = torch.tensor(target, dtype=torch.float)
    ref_data = torch.tensor(ref, dtype=torch.float)
    
    # Calculate the squared difference between the two images
    diff = ref_data - target_data
    mse = torch.mean(diff ** 2)
    
    # Calculate PSNR
    psnr_value = 20 * torch.log10(255. / torch.sqrt(mse))
    
    return psnr_value.item()

# Define function for mean squared error (MSE)
def mse(target, ref):
    # Convert images to torch tensors
    target_data = torch.tensor(target, dtype=torch.float)
    ref_data = torch.tensor(ref, dtype=torch.float)
    
    # Calculate the squared difference between the two images
    diff = ref_data - target_data
    mse_value = torch.mean(diff ** 2)
    
    return mse_value.item()

# Define function that combines all three image quality metrics
def compare_images(target, ref):
    scores = []
    scores.append(psnr(target, ref))
    scores.append(mse(target, ref))
    
    return scores


#### Super Resolution Convolution Neural Network (SRCNN) Model
1) Explain the achitecture and hyper parameters of the SRCNN network from the original paper
2) 

In [13]:
# Instantiate the model
model = SRCNN()

optimizer = optim.Adam(model.parameters(), lr=0.0003)
criterion = nn.MSELoss()

print(model)


SRCNN(
  (conv1): Conv2d(1, 128, kernel_size=(9, 9), stride=(1, 1))
  (conv2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 1, kernel_size=(5, 5), stride=(1, 1))
  (relu): ReLU()
)


In [15]:
import h5py
import torch
from torch.utils.data import Dataset, DataLoader

class HDF5Dataset(Dataset):
    def __init__(self, file_path):
        self.file = h5py.File(file_path, 'r')
        self.data = self.file['data']
        self.label = self.file['label']

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx][0]  # Squeeze the batch dimension
        label = self.label[idx][0]  # Squeeze the batch dimension
        return data, label

train_dataset = HDF5Dataset('train_data.h5')
test_dataset = HDF5Dataset('test_data.h5')

# View contents of train_dataset
print("Train Dataset:")
for i in range(len(train_dataset)):
    data, label = train_dataset[i]
    print(f"Sample {i}:")
    print("Data Shape:", data.shape)
    print("Label Shape:", label.shape)
    
    break



Train Dataset:
Sample 0:
Data Shape: (1, 32, 32)
Label Shape: (1, 20, 20)


### Defining the Dataloader

In [16]:
train_loader = DataLoader(train_dataset, shuffle=True)
test_loader = DataLoader(test_dataset, shuffle=False)

# Print the shape of batches in train_loader
for inputs, labels in train_loader:
    print(f"Batch Input Shape: {inputs.shape}, Batch Label Shape: {labels.shape}")
    break

Batch Input Shape: torch.Size([1, 1, 32, 32]), Batch Label Shape: torch.Size([1, 1, 20, 20])


#### Training process

In [None]:
num_epochs = 10
batch_size = 64

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        
    # Print the average loss for the epoch
    print('[Epoch %d] Average Loss: %.3f' % (epoch + 1, running_loss / len(train_loader)))
    break


print('Finished Training')
