In [None]:
import skimage
import sklearn.feature_extraction
import matplotlib.pyplot as plt
import numpy as np
import os
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torchvision

%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Root folder of the dataset; later cells read its 'Train' and 'Validation' subfolders.
# NOTE(review): the name `dir` shadows Python's built-in dir(). It is kept because
# other cells reference it; rename everywhere in one pass if this is cleaned up.
dir = '/workspace/data/Dhruv/pytorch/SuperResolution/Data'
# TensorBoard writer created with default arguments.
writer = SummaryWriter()

In [None]:
# Use the first CUDA GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Load every PNG of the training folder, resized to a common (2048, 1080) = (rows, cols) grid.
# NOTE(review): (2048, 1080) makes the images taller than wide — confirm this is the
# intended orientation for the dataset.
train_images = []
train_dir = os.path.join(dir, 'Train')
# os.listdir returns entries in arbitrary order; sorting keeps the [:400] / [400:]
# split done in later cells identical from run to run.
for fname in sorted(os.listdir(train_dir)):
    if fname.endswith(".png"):
        resized = skimage.transform.resize(
            skimage.io.imread(os.path.join(train_dir, fname)), (2048, 1080), mode='constant')
        # Later cells cast to float32 anyway; casting here avoids holding wider
        # floating-point copies of every image in memory in the meantime.
        train_images.append(resized.astype(np.float32))

In [None]:
# Load every PNG of the validation folder, resized to a common (2048, 1080) = (rows, cols) grid.
# NOTE(review): (2048, 1080) makes the images taller than wide — confirm this is the
# intended orientation for the dataset.
test_images = []
validation_dir = os.path.join(dir, 'Validation')
# os.listdir returns entries in arbitrary order; sorting makes the load order reproducible.
for fname in sorted(os.listdir(validation_dir)):
    if fname.endswith(".png"):
        resized = skimage.transform.resize(
            skimage.io.imread(os.path.join(validation_dir, fname)), (2048, 1080), mode='constant')
        # A later cell casts to float32 anyway; casting here avoids holding wider
        # floating-point copies of every image in memory in the meantime.
        test_images.append(resized.astype(np.float32))

In [None]:
def patchExtract(images, patch_size=(64, 64), max_patches=20, random_state=None):
    """Sample patches from a stack of images with scikit-learn's PatchExtractor.

    With the defaults, up to 20 patches of 64x64 pixels are taken from each image.

    Parameters
    ----------
    images : array-like
        Stack of images handed to ``PatchExtractor.transform``
        (presumably shaped (n_images, height, width[, channels]) — confirm against callers).
    patch_size : tuple of int
        Height and width of every patch.
    max_patches : int or float
        Forwarded to ``PatchExtractor`` as its ``max_patches`` argument.
    random_state : int, RandomState instance or None
        Forwarded to ``PatchExtractor``; pass an int to get the same patches on
        every run. The default ``None`` keeps the previous unseeded behaviour.

    Returns
    -------
    The array of extracted patches returned by ``PatchExtractor.transform``.
    """
    extractor = sklearn.feature_extraction.image.PatchExtractor(
        patch_size=patch_size, max_patches=max_patches, random_state=random_state)
    # fit() is still called, as in the original, but its return value is no longer
    # bound to an unused local.
    extractor.fit(images)
    return extractor.transform(images)

In [None]:
# First chunk: stack the first 400 training images into a single float32 array.
# NOTE(review): 400 is a hard-coded chunk size, presumably chosen to limit peak memory — confirm.
train_images1 = np.asarray(train_images[:400], dtype=np.float32)

In [None]:
# Sample patches from the first chunk using patchExtract's default patch size and count.
train_images_patches1 = patchExtract(train_images1)

In [None]:
# Drop the stacked first chunk now that its patches have been extracted.
del train_images1

In [None]:
# Rebind test_images from the list of loaded images to one stacked float32 array.
test_images = np.asarray(test_images, dtype=np.float32)

In [None]:
# Sample validation patches using patchExtract's default patch size and count.
test_images_patches = patchExtract(test_images)

In [None]:
# Second chunk: the training images from index 400 onwards, stacked as float32.
train_images1 = np.asarray(train_images[400:], dtype=np.float32)
train_images_patches2 = patchExtract(train_images1)
# Release the stacked second chunk immediately: nothing after this point uses it,
# and otherwise it stays alive for the rest of the session.
del train_images1
# Join the patches of both chunks along the sample axis.
train_images_patches = np.concatenate((train_images_patches1, train_images_patches2), axis=0)

In [None]:
# Show the shape of the training patches, then of the validation patches.
for patches in (train_images_patches, test_images_patches):
    print(patches.shape)

In [None]:
def bicubicDownsample(images, scale_factor=0.5):
    """Resize a batch of images with bicubic interpolation.

    Args:
        images: tensor passed to ``interpolate`` (bicubic mode expects a
            4-D (N, C, H, W) batch).
        scale_factor: multiplier applied to the spatial size; the default 0.5
            halves height and width.

    Returns:
        The resized tensor, computed with ``align_corners=True``.
    """
    return F.interpolate(
        images,
        scale_factor=scale_factor,
        mode='bicubic',
        align_corners=True,
    )

In [None]:
# Remove the per-chunk patch arrays and the full-size image collections; only the
# concatenated train_images_patches and test_images_patches are used from here on.
del train_images_patches1
del train_images_patches2
del train_images
del test_images

In [None]:
# High-resolution targets: wrap the patch arrays as tensors, move axis 3 to position 1
# (presumably channels-last -> channels-first — confirm), and make sure they are float32.
y_tr = torch.from_numpy(train_images_patches).permute(0, 3, 1, 2).float()
y_te = torch.from_numpy(test_images_patches).permute(0, 3, 1, 2).float()

In [None]:
# Remove the NumPy names for the patch arrays.
# NOTE(review): torch.from_numpy shares memory with its source array, so the data most
# likely stays alive through y_tr / y_te; this only tidies the namespace — confirm.
del train_images_patches
del test_images_patches

In [None]:
# Low-resolution inputs: bicubic downsampling of the targets with the default
# scale factor of bicubicDownsample, kept as float32.
x_tr = bicubicDownsample(y_tr).float()
x_te = bicubicDownsample(y_te).float()

In [None]:
# Store the targets in contiguous memory (they were permuted earlier, which
# typically leaves a non-contiguous view).
y_tr = y_tr.contiguous()
y_te = y_te.contiguous()

In [None]:
# Sanity check: report the memory contiguity of inputs and targets, in the
# order x_tr, x_te, y_tr, y_te.
for tensor in (x_tr, x_te, y_tr, y_te):
    print(tensor.is_contiguous())

In [None]:
class _PairedTensorDataset(Dataset):
    """Dataset of (input, target) pairs indexed along the first dimension.

    Shared implementation for TrainDataset and TestDataset.

    Args:
        x: indexable inputs with a ``shape`` attribute (e.g. a tensor).
        y: indexable targets with the same first-dimension length as ``x``.

    Raises:
        ValueError: if ``x`` and ``y`` hold a different number of samples.
    """

    def __init__(self, x, y):
        if x.shape[0] != y.shape[0]:
            raise ValueError(
                "x and y must contain the same number of samples, got "
                + str(x.shape[0]) + " and " + str(y.shape[0]))
        self.x = x
        self.y = y
        self.n_samples = self.x.shape[0]

    def __getitem__(self, index):
        """Return the (input, target) pair stored at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of samples."""
        return self.n_samples


# Creating custom training dataset
class TrainDataset(_PairedTensorDataset):
    """Training pairs; uses the notebook globals ``x_tr`` / ``y_tr`` unless tensors are passed."""

    def __init__(self, x=None, y=None):
        # The globals are looked up only when no tensor is supplied, so
        # TrainDataset() keeps working exactly as before.
        super().__init__(x_tr if x is None else x, y_tr if y is None else y)


# Creating custom testing dataset
class TestDataset(_PairedTensorDataset):
    """Validation pairs; uses the notebook globals ``x_te`` / ``y_te`` unless tensors are passed."""

    def __init__(self, x=None, y=None):
        super().__init__(x_te if x is None else x, y_te if y is None else y)

In [None]:
# Disabled normalization transform (the mean/std values look like the usual ImageNet
# statistics — confirm). It sits inside a bare string literal, so none of it is executed.
'''
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])
'''

In [None]:
# Mini-batch size used by both data loaders and by the n_iterations computation.
batch_size = 16

In [None]:
train_dataset = TrainDataset()
test_dataset = TestDataset()

# Training loader: splits the training data into batches.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True, # data reshuffled at every epoch
                          num_workers=2) # Use several subprocesses to load the data

# Validation loader: splits the validation data into batches.
# No shuffling here: the summed validation loss does not depend on batch order, and a
# fixed order makes the first batch (previewed with plt.imshow further down) reproducible.
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=2) # Use several subprocesses to load the data

In [None]:
# Number of passes over the training set.
EPOCHS = 200
n_samples = len(train_dataset)
# Batches per epoch, counting a final partial batch.
n_iterations = math.ceil(n_samples/batch_size)

## Creating Model

In [None]:
class DoubleConv(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers with a 1x1-conv shortcut.

    The shortcut projects the input to ``out_channels`` so it can be added to the
    output of the main branch; a ReLU is applied to the sum. The 3x3 convolutions
    use padding=1, so height and width are unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 projection used as the skip connection.
        self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        main_branch = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        self.double_conv = nn.Sequential(*main_branch)

    def forward(self, x):
        shortcut = self.skip(x)
        residual = self.double_conv(x)
        # Element-wise sum of both branches, followed by an in-place ReLU on that sum.
        return F.relu_(shortcut + residual)
    

class PsUpsample(nn.Module):
    """Upsample ``x1`` by 2, halve its channels with a 3x3 conv, and concatenate it after ``x2``.

    Despite the "Ps" (pixel shuffle) in the name, the upsampling done here is
    nearest-neighbour interpolation; no pixel-shuffle layer is involved.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1)

    def forward(self, x1, x2):
        upsampled = self.conv(F.interpolate(x1, scale_factor=2, mode='nearest'))
        # x2 comes first along the channel axis, the upsampled features second.
        return torch.cat((x2, upsampled), dim=1)
    
class UpConcatConv(nn.Module):
    """Decoder stage: upsample, reduce channels, concatenate a skip tensor, apply DoubleConv.

    ``x1`` is upsampled by 2 (nearest neighbour) and reduced from ``in_channels`` to
    ``out_channels`` by a 1x1 conv. It is then concatenated after ``x2`` and passed
    through ``DoubleConv(in_channels, out_channels)``, so ``x2`` must carry
    ``in_channels - out_channels`` channels and match the upsampled spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 convolution bringing the upsampled features down to out_channels.
        self.reduce = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        reduced = self.reduce(F.interpolate(x1, scale_factor=2, mode='nearest'))
        # No cropping of the skip tensor is needed: the convolutions are padded
        # and therefore keep the spatial size.
        merged = torch.cat((x2, reduced), dim=1)
        return self.conv(merged)

    
class InConv(nn.Module):
    """Input stem: a 9x9 convolution to 64 channels followed by ReLU.

    padding=4 keeps the spatial size of the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        stem_conv = nn.Conv2d(in_channels, 64, kernel_size=9, padding=4)
        stem_act = nn.ReLU(inplace=True)
        self.inconv = nn.Sequential(stem_conv, stem_act)

    def forward(self, x):
        return self.inconv(x)

In [None]:
class SR_UNet(nn.Module):
    """U-Net style network that outputs a 3-channel image at twice the input resolution.

    Encoder: 9x9 stem (64 channels), then DoubleConv stages producing 128, 256, 512 and
    1024 channels with 2x2 max-pooling between them.
    Decoder: three UpConcatConv stages producing 512, 256 and 128 channels, each fed
    the encoder output of the matching resolution.
    Head: nearest-neighbour x2 upsampling, 3x3 conv to 64 channels + ReLU, 9x9 conv to 3 channels.

    Input height and width should be multiples of 8 so that the upsampled decoder
    maps line up with the encoder maps when they are concatenated.

    Author's note: parameters = 20009219 (not re-verified here).
    """

    def __init__(self, in_channels):
        super().__init__()
        # Encoder
        self.inconv = InConv(in_channels)
        self.dconv1 = DoubleConv(64, 128)
        self.pool1 = nn.MaxPool2d(2)
        self.dconv2 = DoubleConv(128, 256)
        self.pool2 = nn.MaxPool2d(2)
        self.dconv3 = DoubleConv(256, 512)
        self.pool3 = nn.MaxPool2d(2)
        self.dconv4 = DoubleConv(512, 1024)
        # Decoder: each stage halves the channel count.
        self.up1 = UpConcatConv(1024, 512)  # output channels = 512
        self.up2 = UpConcatConv(512, 256)   # output channels = 256
        self.up3 = UpConcatConv(256, 128)   # output channels = 128
        # Head. It receives 128 channels after the x2 upsampling in forward();
        # the author notes this block can be repeated for x4.
        self.outconvblock = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.outconv = nn.Conv2d(64, 3, kernel_size=9, padding=4)

    def forward(self, x):
        stem = self.inconv(x)
        # Contracting path; enc1..enc3 are kept for the skip connections.
        enc1 = self.dconv1(stem)
        enc2 = self.dconv2(self.pool1(enc1))
        enc3 = self.dconv3(self.pool2(enc2))
        bottleneck = self.dconv4(self.pool3(enc3))
        # Expanding path back to the input resolution.
        dec = self.up1(bottleneck, enc3)
        dec = self.up2(dec, enc2)
        dec = self.up3(dec, enc1)
        # Super-resolution step: double the spatial size, then project to 3 channels.
        dec = F.interpolate(dec, scale_factor=2, mode='nearest')
        return self.outconv(self.outconvblock(dec))

## Creating Loss Function (Perceptual Loss)

In [None]:
class VGGPerceptualLoss(nn.Module):
    """Frozen feature extractor made of the first nine layers of torchvision's VGG-16.

    The layers are registered as ``relu2_2_1`` ... ``relu2_2_9`` inside ``self.relu2_2``
    (the author's name for the activation where the stack ends), and every parameter
    has ``requires_grad`` switched off.

    NOTE(review): ``pretrained=True`` is the older torchvision way of requesting
    pretrained weights; check it against the installed torchvision version.
    """

    def __init__(self):
        super().__init__()
        vgg_features = torchvision.models.vgg16(pretrained=True, progress=False).features
        self.relu2_2 = nn.Sequential()
        for position in range(1, 10):
            self.relu2_2.add_module(name="relu2_2_"+str(position), module=vgg_features[position - 1])
        # Freeze the extractor so the loss network is never updated.
        for frozen in self.parameters():
            frozen.requires_grad = False

    def forward(self, x):
        return self.relu2_2(x)

In [None]:
# Single shared instance of the frozen VGG feature extractor, placed on the training device.
VGGLoss = VGGPerceptualLoss().to(device)

In [None]:
def PerceptualLoss(x, y):
    """Feature-space loss between ``x`` and ``y`` computed on VGGLoss activations.

    Both inputs go through the shared ``VGGLoss`` extractor. The squared differences
    of the two feature tensors are summed over every element and divided by C*H*W of
    the feature map. The sum also runs over the batch dimension, which is not divided
    out, so the value grows with the number of samples in the batch.

    Returns:
        A scalar tensor.
    """
    pred_feats = VGGLoss(x)
    target_feats = VGGLoss(y)

    # Number of feature elements per sample (channels * height * width).
    elems_per_sample = target_feats.shape[1] * target_feats.shape[2] * target_feats.shape[3]
    # Summed squared error stands in for the squared Euclidean norm.
    squared_error = F.mse_loss(target_feats, pred_feats, reduction='sum')
    return squared_error / elems_per_sample

## Training Loop

In [None]:
# Implementing checkpoints
def _write_checkpoint(path, epoch, model, optimizer, loss):
    """Write epoch, model/optimizer state dicts and validation loss to ``path``.

    The parent folder is created when it does not exist, so the first save into
    a fresh output location does not fail.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': loss,
            }, path)


def save_checkpoint_best(epoch, model, optimizer, loss,
                         directory="/workspace/data/Dhruv/pytorch/SuperResolution/BestModel"):
    """Save a checkpoint as ``best_model_<epoch>.pt`` inside ``directory``.

    Args:
        epoch: epoch index stored in the checkpoint and used in the file name.
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        loss: value stored under the ``'val_loss'`` key.
        directory: destination folder; defaults to the original hard-coded location.
    """
    print("Saving best model")
    PATH = os.path.join(directory, "best_model_"+str(epoch)+".pt")
    _write_checkpoint(PATH, epoch, model, optimizer, loss)


def save_checkpoint(epoch, model, optimizer, loss,
                    directory="/workspace/data/Dhruv/pytorch/SuperResolution/Models"):
    """Save a checkpoint as ``model_<epoch>.pt`` inside ``directory``.

    Args:
        epoch: epoch index stored in the checkpoint and used in the file name.
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        loss: value stored under the ``'val_loss'`` key.
        directory: destination folder; defaults to the original hard-coded location.
    """
    PATH = os.path.join(directory, "model_"+str(epoch)+".pt")
    print("Saving model")
    _write_checkpoint(PATH, epoch, model, optimizer, loss)

In [None]:
# Per-epoch loss histories; train_model() appends one summed loss per epoch to each list.
tr_loss_log = []
val_loss_log = []

In [None]:
model = SR_UNet(3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Add weight decay?
#loss = PerceptualLoss used directly in training loop

# Log the model graph to TensorBoard using one training batch as the example input.
example = iter(train_loader)
# Use the built-in next(): DataLoader iterators in current PyTorch releases have no
# .next() method, so example.next() raises AttributeError there.
example_data, example_target = next(example)
writer.add_graph(model, example_data.to(device))
# NOTE(review): train_model() writes to this same writer after it is closed here;
# confirm the writer re-opens as expected, or flush here instead of closing.
writer.close()

In [None]:
# Training Loop
def train_model():
    """Train the global ``model`` for ``EPOCHS`` epochs with the perceptual loss.

    Each epoch:
      1. runs one optimisation pass over ``train_loader`` and sums the batch losses;
      2. evaluates on ``test_loader`` in eval mode without gradients and sums the
         batch losses;
      3. saves a checkpoint, plus a "best" checkpoint when the summed validation
         loss improves on the lowest value seen so far;
      4. prints a summary line and logs both losses to TensorBoard.

    The summed losses are appended to ``tr_loss_log`` / ``val_loss_log``. They are
    sums over batches, not means, so they scale with the dataset size.
    The function reads the notebook globals ``model``, ``optimizer``, the two
    loaders, ``device``, ``writer``, ``EPOCHS`` and ``n_iterations``.
    """
    least_val_loss = math.inf

    # Guarantee training mode for the first epoch even if the model was left in
    # eval mode by an earlier cell.
    model.train()

    for epoch in range(EPOCHS):

        beg_time = time.time()  # To calculate time taken for each epoch

        train_loss = 0.0
        val_loss = 0.0

        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            # Forward pass
            out = model(x)
            # Calculating loss
            loss = PerceptualLoss(out, y)
            # Backward pass
            loss.backward()
            # Update the model weights
            optimizer.step()
            # Accumulate training loss
            train_loss += loss.item()
        tr_loss_log.append(train_loss)

        model.eval()
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                out = model(x)
                # Calculating loss
                loss = PerceptualLoss(out, y)
                # Accumulate validation loss
                val_loss += loss.item()
            val_loss_log.append(val_loss)
        model.train()

        # Saving checkpoints
        save_checkpoint(epoch, model, optimizer, val_loss)
        if val_loss < least_val_loss:
            save_checkpoint_best(epoch, model, optimizer, val_loss)
            least_val_loss = val_loss

        end_time = time.time()
        print('Epoch: {:.0f}/{:.0f}, Time: {:.0f}m {:.0f}s, Train_Loss: {:.4f}, Val_loss: {:.4f}'.format(
            epoch+1, EPOCHS, (end_time-beg_time)//60, (end_time-beg_time)%60, train_loss, val_loss))
        # Step = number of training iterations completed so far. The previous code
        # added the batch index left over from the validation loop, which offset
        # every point by an unrelated amount.
        global_step = (epoch + 1) * n_iterations
        writer.add_scalar('Training_loss', train_loss, global_step)
        writer.add_scalar('Validation_loss', val_loss, global_step)

    # Make sure the logged scalars reach disk once training has finished.
    writer.flush()

In [None]:
# Run the full training loop defined above (EPOCHS epochs, one checkpoint per epoch).
train_model()

In [None]:
FILE = "/workspace/data/Dhruv/pytorch/SuperResolution/FinalModel/final_trained_model.pt"
print("Saving final model")
# Create the destination folder first; torch.save does not create missing directories,
# and failing here would lose the result of the whole training run.
os.makedirs(os.path.dirname(FILE), exist_ok=True)
# Only the model weights are stored here (no optimizer state).
torch.save(model.state_dict(), FILE)

In [None]:
# Fetch one validation batch and show its low-resolution input number 15.
example = iter(test_loader)
# Built-in next() instead of example.next(): DataLoader iterators in current PyTorch
# releases have no .next() method.
example_data, example_target = next(example)
# Index 15 is the last sample of a full batch of 16; permute moves axis 0 last for imshow.
plt.imshow(example_data[15].permute(1,2,0))

In [None]:
# Run inference in eval mode and without autograd: in training mode BatchNorm would
# normalise with the statistics of this single batch and update its running averages.
was_training = model.training
model.eval()
with torch.no_grad():
    out = model(example_data.to(device))
# Put the model back in whatever mode it was in before this cell ran.
model.train(was_training)
# Clamp to [0, 1], the range imshow expects for floating-point RGB data.
plt.imshow(out[15].cpu().permute(1,2,0).clamp(0, 1))

In [None]:
# Show the high-resolution target of the same sample (index 15) for comparison.
plt.imshow(example_target[15].permute(1,2,0))

In [None]:
%tensorboard --logdir=runs

In [None]:
# TODO: include data preprocessing, reduce model complexity, reduce the learning rate or add a decay rate, and add dropout layers.