In [None]:
import skimage
import sklearn.feature_extraction
import matplotlib.pyplot as plt
import numpy as np
import os
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torchvision

from torch.optim import lr_scheduler

%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter
from fastai.layers import PixelShuffle_ICNR

In [None]:
# Root folder of the dataset; later cells read `train/<class>/images/*.JPEG` and `val/images/*.JPEG` under it.
# NOTE(review): `dir` shadows the Python builtin of the same name and the path is machine-specific.
dir = '/workspace/data/Dhruv/pytorch/SuperResolution/tiny-imagenet-200'
# TensorBoard writer; event files for this run go under runs/resunet34.
writer = SummaryWriter('runs/resunet34')

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load every 3-channel training image into one pre-allocated array.
# 1821 of the 100000 training images have shape (64, 64) instead of (64, 64, 3)
# and are skipped, which leaves 98179 usable images.
n_train_rgb = 100000 - 1821
train_images = np.zeros((n_train_rgb, 64, 64, 3))
train_root = os.path.join(dir, 'train')
i = 0
# sorted(): os.listdir returns entries in arbitrary order; sorting makes the
# image order reproducible between runs and machines.
for f1 in sorted(os.listdir(train_root)):
    class_dir = os.path.join(train_root, f1, 'images')
    for f2 in sorted(os.listdir(class_dir)):
        if not f2.endswith(".JPEG"):
            continue
        img = skimage.io.imread(os.path.join(class_dir, f2))
        if img.ndim != 3:
            continue  # image without a channel axis
        if i >= n_train_rgb:
            raise IndexError(
                'Found more than {} 3-channel training images under {}; '
                'update the expected count.'.format(n_train_rgb, train_root))
        train_images[i] = img
        i += 1

# Fewer images than expected would silently leave all-zero samples at the end
# of the array, so make that visible.
if i != n_train_rgb:
    print('WARNING: expected {} 3-channel training images but loaded {}; '
          'the remaining rows of train_images are zeros.'.format(n_train_rgb, i))

In [None]:
# Scale pixel values from [0, 255] to [0, 1].
# Dividing in place gives the same values as the per-image loop but avoids
# allocating a second full-size array.
train_images /= 255

In [None]:
# Load every 3-channel validation image into one pre-allocated array.
n_val_rgb = 9832  # 9832 validation images
test_images = np.zeros((n_val_rgb, 64, 64, 3))
val_root = os.path.join(dir, 'val', 'images')
i = 0
# sorted(): os.listdir returns entries in arbitrary order; sorting makes the
# image order reproducible between runs and machines.
for f in sorted(os.listdir(val_root)):
    if not f.endswith(".JPEG"):
        continue
    img = skimage.io.imread(os.path.join(val_root, f))
    if img.ndim != 3:
        continue  # image without a channel axis
    if i >= n_val_rgb:
        raise IndexError(
            'Found more than {} 3-channel validation images under {}; '
            'update the expected count.'.format(n_val_rgb, val_root))
    test_images[i] = img
    i += 1

# Fewer images than expected would silently leave all-zero samples at the end
# of the array, so make that visible.
if i != n_val_rgb:
    print('WARNING: expected {} 3-channel validation images but loaded {}; '
          'the remaining rows of test_images are zeros.'.format(n_val_rgb, i))

In [None]:
# Scale pixel values from [0, 255] to [0, 1].
# Dividing in place gives the same values as the per-image loop but avoids
# allocating a second full-size array.
test_images /= 255

In [None]:
def bicubicDownsample(images, scale_factor=0.5):
    """Resize a batch of images with bicubic interpolation.

    Args:
        images: image tensor; in this notebook it is called with
            (N, C, H, W) float tensors.
        scale_factor: multiplier applied to the spatial size
            (0.5 halves height and width).

    Returns:
        The resized tensor produced by `interpolate` with
        `mode='bicubic'` and `align_corners=True`.
    """
    return F.interpolate(
        images,
        scale_factor=scale_factor,
        mode='bicubic',
        align_corners=True,
    )

In [None]:
# High-resolution targets: move the channel axis from last to second position
# (N, H, W, C) -> (N, C, H, W) and convert to float32.
y_tr = torch.from_numpy(train_images).permute(0, 3, 1, 2).float()
y_te = torch.from_numpy(test_images).permute(0, 3, 1, 2).float()

In [None]:
# Drop the numpy array names; the float32 tensors y_tr / y_te are used from here on.
del train_images, test_images

In [None]:
# Low-resolution network inputs: the targets downsampled by the default factor of 0.5.
x_tr = bicubicDownsample(y_tr).float()
x_te = bicubicDownsample(y_te).float()

In [None]:
# Make the target tensors contiguous in memory (they were built through `permute`).
y_tr = y_tr.contiguous()
y_te = y_te.contiguous()

In [None]:
# Sanity check: print whether each input/target tensor is contiguous
# (order: x_tr, x_te, y_tr, y_te).
for tensor in (x_tr, x_te, y_tr, y_te):
    print(tensor.is_contiguous())

In [None]:
# Disabled code kept as a string literal: a `transforms.Normalize` with per-channel mean/std.
# `normalize` is not defined or applied anywhere else in the visible notebook.
'''
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
'''

In [None]:
class _PairedTensorDataset(Dataset):
    """Dataset over two aligned tensors: inputs `x` and targets `y`.

    Shared implementation for `TrainDataset` and `TestDataset`, which only
    differ in which module-level tensors they wrap.

    Args:
        x: input tensor, indexed along its first dimension.
        y: target tensor, indexed along its first dimension.
        transform: optional callable applied to the input sample only.
    """

    def __init__(self, x, y, transform=None):
        self.x = x
        self.y = y
        self.n_samples = self.x.shape[0]
        self.transform = transform

    def __getitem__(self, index):
        """Return `(input, target)` for `index`; the transform (if any) is applied to the input."""
        sample = self.x[index]
        if self.transform:
            sample = self.transform(sample)
        return sample, self.y[index]

    def __len__(self):
        """Number of samples (size of the first dimension of `x`)."""
        return self.n_samples


# Creating custom training dataset
class TrainDataset(_PairedTensorDataset):
    """Training pairs taken from the module-level tensors `x_tr` and `y_tr`."""

    def __init__(self, transform=None):
        super().__init__(x_tr, y_tr, transform)


# Creating custom testing dataset
class TestDataset(_PairedTensorDataset):
    """Validation pairs taken from the module-level tensors `x_te` and `y_te`."""

    def __init__(self, transform=None):
        super().__init__(x_te, y_te, transform)

In [None]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 32

In [None]:
train_dataset = TrainDataset()
test_dataset = TestDataset()

# Train loader: splits the training data into shuffled mini-batches
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True, # data reshuffled at every epoch
                          num_workers=2) # Use several subprocesses to load the data

# Validation loader: kept in a fixed order so evaluation batches (and the
# samples displayed in the Inference cells) are the same on every run
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=2) # Use several subprocesses to load the data

In [None]:
# Number of passes over the training set performed by train_model().
EPOCHS = 200
n_samples = len(train_dataset)
# Mini-batches per epoch; the last batch may be smaller, hence the ceil.
n_iterations = math.ceil(n_samples/batch_size)

## Creating Model

In [None]:
class DoubleConv(nn.Module):
    """Residual block: conv3x3-BN-ReLU-conv3x3-BN plus a 1x1-conv shortcut.

    The shortcut and the residual branch are summed and passed through ReLU.
    All three convolutions get Xavier-uniform weights and a bias of 0.1.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both branches.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 convolution so the shortcut has `out_channels` channels
        self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        residual_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        self.double_conv = nn.Sequential(*residual_layers)

        # Initialise the shortcut first, then the two residual convolutions
        # (positions 0 and 3 of the sequential)
        convs = [self.skip, self.double_conv[0], self.double_conv[3]]
        for conv in convs:
            torch.nn.init.xavier_uniform_(conv.weight)
            conv.bias.data.fill_(0.1)

    def forward(self, x):
        """Return ReLU(shortcut(x) + residual(x))."""
        shortcut = self.skip(x)
        residual = self.double_conv(x)
        # In-place ReLU on the freshly created sum
        return F.relu_(shortcut + residual)


class PsUpsample(nn.Module):
    """Decoder stage: pixel-shuffle upsampling, skip concatenation, then `DoubleConv`.

    Args:
        in_channels: channels of the feature map to upsample; also the channel
            count fed to `DoubleConv` after concatenation.
        out_channels: channels produced by the stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Upsampling using pixel shuffle, requesting in_channels//2 output
        # features at scale 2
        self.upsample = PixelShuffle_ICNR(in_channels, in_channels // 2, scale=2)
        # The concatenated tensor must have `in_channels` channels, so the skip
        # input presumably supplies the other half - verify against callers
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample `x1`, concatenate it after `x2` on the channel axis and convolve."""
        upsampled = self.upsample(x1)
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)

In [None]:
class ResUNet(nn.Module):
    """U-Net style network with a pretrained ResNet-34 encoder and a pixel-shuffle decoder.

    The input is first upsampled 2x (bilinear), passed through the ResNet-34
    children used as encoder stages, and decoded by five `PsUpsample` stages
    that each merge one skip tensor. A final `DoubleConv` produces the output.

    Args:
        out_channels: number of channels of the output image.
    """

    def __init__(self, out_channels):
        super().__init__()

        self.base_model = torchvision.models.resnet34(pretrained=True, progress=False)
        self.base_layers = list(self.base_model.children())

        # Encoder path (these modules are the same objects held by `self.base_model`)
        self.in_layer1 = self.base_layers[0]
        self.in_layer2 = nn.Sequential(*self.base_layers[1:4])
        self.layer1 = nn.Sequential(*self.base_layers[4])
        self.layer2 = nn.Sequential(*self.base_layers[5])
        self.layer3 = nn.Sequential(*self.base_layers[6])
        self.layer4 = nn.Sequential(*self.base_layers[7])

        # Cross path: 1x1 convolutions that reduce the channel count of the two
        # shallowest skip tensors before they are concatenated in the decoder
        self.down_in1 = nn.Conv2d(64, 32, kernel_size=1)
        self.down_up = nn.Conv2d(3, 16, kernel_size=1)

        # Decoder path
        self.up1 = PsUpsample(512, 256)
        self.up2 = PsUpsample(256, 128)
        self.up3 = PsUpsample(128, 64)
        self.up4 = PsUpsample(64, 32)
        self.up5 = PsUpsample(32, 16)

        # Honour the constructor argument instead of a hard-coded 3
        self.out_layer = DoubleConv(16, out_channels)

    def forward(self, x):
        """Run the encoder/decoder on `x` and return the output of the final `DoubleConv`."""
        # Encoder path
        x_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x_in1 = self.in_layer1(x_up)
        x_in2 = self.in_layer2(x_in1)  # This is of same size as x_l1 so not used as a skip
        x_l1 = self.layer1(x_in2)
        x_l2 = self.layer2(x_l1)
        x_l3 = self.layer3(x_l2)
        x_l4 = self.layer4(x_l3)

        # Decoder path: each stage upsamples and merges the next shallower skip
        x = self.up1(x_l4, x_l3)
        x = self.up2(x, x_l2)
        x = self.up3(x, x_l1)
        x_in1 = self.down_in1(x_in1)
        x = self.up4(x, x_in1)
        x_up = self.down_up(x_up)
        x = self.up5(x, x_up)
        x = self.out_layer(x)

        return x

## Creating Loss Function (Perceptual Loss)

In [None]:
class VGGPerceptualLoss(nn.Module):
    """Frozen feature extractor built from the first nine layers of pretrained VGG-19.

    The layers are registered in `self.relu2_2` under the names
    "relu2_2_1" ... "relu2_2_9". All parameters have `requires_grad=False`.
    """

    def __init__(self):
        super().__init__()
        vgg_features = torchvision.models.vgg19(pretrained=True, progress=False).features
        self.relu2_2 = nn.Sequential()
        for position in range(1, 10):
            self.relu2_2.add_module(name="relu2_2_" + str(position),
                                    module=vgg_features[position - 1])
        # Fix the perceptual-loss model parameters so they are never trained
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, x):
        """Return the feature map produced by the nine registered layers."""
        return self.relu2_2(x)

In [None]:
# Single shared feature-extractor instance on the training device; used by PerceptualLoss().
VGGLoss = VGGPerceptualLoss().to(device)

In [None]:
def PerceptualLoss(x, y):
    """Feature-space L1 loss plus a pixel-space L1 term.

    Both inputs are passed through the module-level `VGGLoss` extractor.

    Args:
        x: predicted images.
        y: target images.

    Returns:
        Scalar tensor:
        sum|feat(y) - feat(x)| / (C*H*W) + mean|x - y| / (C*H*W),
        where C, H, W are the dimensions of one feature map.
    """
    feats_x = VGGLoss(x)
    feats_y = VGGLoss(y)

    # Number of elements in a single sample's feature map
    feat_size = feats_y.shape[1] * feats_y.shape[2] * feats_y.shape[3]

    # L1 distance between feature maps, summed over the whole batch
    feature_term = F.l1_loss(feats_y, feats_x, reduction='sum') / feat_size
    # NOTE(review): the pixel term is a mean L1 that is additionally divided by
    # the feature-map size, which makes it very small relative to the feature
    # term - confirm this weighting is intended.
    pixel_term = F.l1_loss(x, y) / feat_size
    return feature_term + pixel_term

## Training Loop

In [None]:
# Implementing checkpoints
def save_checkpoint_best(epoch, model):
    """Save the weights of the best model so far.

    Args:
        epoch: epoch number, used in the file name.
        model: module whose `state_dict()` is written.
    """
    print("Saving best model")
    PATH = "/workspace/data/Dhruv/pytorch/SuperResolution/BestModel/best_model_"+str(epoch)+".pt"
    # Create the target folder on first use so a finished epoch is never lost
    # to a missing directory
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    torch.save(model.state_dict(), PATH)

def save_checkpoint(epoch, model, optimizer, loss):  # Saving model in a way so we can load and start training again
    """Save a resumable checkpoint for `epoch`.

    The file holds a dict with the keys 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'val_loss'.

    Args:
        epoch: epoch number, stored and used in the file name.
        model: module whose `state_dict()` is stored.
        optimizer: optimizer whose `state_dict()` is stored.
        loss: value stored under 'val_loss'.
    """
    PATH = "/workspace/data/Dhruv/pytorch/SuperResolution/Models/model_"+str(epoch)+".pt"
    print("Saving model")
    # Create the target folder on first use so a finished epoch is never lost
    # to a missing directory
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': loss,
            }, PATH)

In [None]:
# Per-epoch loss histories; train_model() appends one summed loss per epoch to each.
tr_loss_log = []
val_loss_log = []

In [None]:
# Build the network (constructor argument 3) and move it to the selected device.
model = ResUNet(3).to(device)

In [None]:
# Freeze the ResNet-34 base model: its parameters receive no gradient updates
for frozen_param in model.base_model.parameters():
    frozen_param.requires_grad_(False)

In [None]:
# Debug check: print the requires_grad flag of every model parameter (one line per parameter).
for param in model.parameters():
    print(param.requires_grad)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
#loss = PerceptualLoss used directly in training loop

# Log the model graph to TensorBoard using one training batch as example input.
# next(iterator) is the standard iterator protocol; the `.next()` method is not
# available on every DataLoader iterator.
example = iter(train_loader)
example_data, example_target = next(example)
writer.add_graph(model, example_data.to(device))
writer.close()

In [None]:
# Learning-rate scheduler driven by the validation loss passed to scheduler.step() in train_model();
# 'min' mode with a patience of 6.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=6, verbose=True)

In [None]:
# Training Loop
def train_model():
    """Train the module-level `model` for `EPOCHS` epochs.

    Each epoch:
      1. runs one optimisation pass over `train_loader`,
      2. evaluates on `test_loader` without gradients,
      3. steps `scheduler` with the validation loss,
      4. saves a resumable checkpoint, plus a "best" checkpoint whenever the
         validation loss improves,
      5. prints a summary line and logs both losses to TensorBoard.

    Losses are the sums of the per-batch losses over the epoch. They are
    appended to the module-level `tr_loss_log` and `val_loss_log`.
    """
    least_val_loss = math.inf

    for epoch in range(EPOCHS):

        beg_time = time.time()  # To calculate time taken for each epoch

        train_loss = 0.0
        val_loss = 0.0

        # Make sure the model is in training mode even when an earlier cell
        # left it in eval mode (e.g. after loading a checkpoint)
        model.train()
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            # Forward pass
            out = model(x)
            # Calculating loss
            loss = PerceptualLoss(out, y)
            # Backward pass
            loss.backward()
            # Update weights
            optimizer.step()
            # Accumulate training loss
            train_loss += loss.item()
        tr_loss_log.append(train_loss)

        model.eval()
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                out = model(x)
                # Calculating loss
                loss = PerceptualLoss(out, y)
                # Accumulate validation loss
                val_loss += loss.item()
            val_loss_log.append(val_loss)
        model.train()

        scheduler.step(val_loss)

        # Saving checkpoints
        save_checkpoint(epoch+1, model, optimizer, val_loss)
        if val_loss < least_val_loss:
            save_checkpoint_best(epoch+1, model)
            least_val_loss = val_loss

        end_time = time.time()
        print('Epoch: {:.0f}/{:.0f}, Time: {:.0f}m {:.0f}s, Train_Loss: {:.4f}, Val_loss: {:.4f}'.format(
            epoch+1, EPOCHS, (end_time-beg_time)//60, (end_time-beg_time)%60, train_loss, val_loss))

        # Global step = number of training iterations completed so far.
        # (Previously the step used the loop index left over from the
        # validation loop, which has nothing to do with training progress.)
        global_step = (epoch + 1) * n_iterations
        writer.add_scalar('Training_loss', train_loss, global_step)
        writer.add_scalar('Validation_loss', val_loss, global_step)
        writer.close()

In [None]:
# Run the full training loop (EPOCHS epochs; writes checkpoints and TensorBoard scalars).
train_model()

In [None]:
# Open TensorBoard on the log directory used by `writer`.
%tensorboard --logdir=runs/resunet34

In [None]:
# Saving the final model
PATH = "/workspace/data/Dhruv/pytorch/SuperResolution/FinalModel/final_trained_model.pt"
print("Saving final model")
# Create the target folder if needed so the save cannot fail on a missing directory
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(model.state_dict(), PATH)

In [None]:

# To load the best model (fill in the best model epoch number)
loaded_best_model = ResUNet(3).to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can also be loaded on a CPU-only machine
checkpoint = torch.load("/workspace/data/Dhruv/pytorch/SuperResolution/BestModel/best_model_15.pt",
                        map_location=device)
loaded_best_model.load_state_dict(checkpoint)
loaded_best_model.eval()
# NOTE: rebinding `model` does not touch `optimizer`, which still refers to the
# parameters of the previous `model` object
model = loaded_best_model

In [None]:

# Loading a model with desired epoch number
loaded_model = ResUNet(3).to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can also be loaded on a CPU-only machine
checkpoint = torch.load("/workspace/data/Dhruv/pytorch/SuperResolution/Models/model_8.pt",
                        map_location=device)
# This file is the dict written by save_checkpoint(); only the weights are restored here
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_model.eval()
# NOTE: rebinding `model` does not touch `optimizer`, which still refers to the
# parameters of the previous `model` object
model = loaded_model


## Inference

In [None]:
model.eval()
# Fetch one validation batch; next(iterator) is the standard iterator protocol
# and works on every DataLoader iterator, unlike the `.next()` method
example = iter(test_loader)
example_data, example_target = next(example)
# Low-resolution input of sample 15, channels moved last for imshow
plt.imshow(example_data[15].permute(1,2,0))

In [None]:
# Super-resolve the batch; no_grad avoids building an autograd graph that
# inference never uses
with torch.no_grad():
    out = model(example_data.to(device))
# Network output for sample 15, channels moved last for imshow
plt.imshow(out[15].cpu().detach().permute(1,2,0))

In [None]:
# Ground-truth target of sample 15 from the same batch, channels moved last for imshow.
plt.imshow(example_target[15].permute(1,2,0))

In [None]:
# Visualize feature maps
# Captured layer outputs, keyed by the name passed to get_activation()
activation = {}


def get_activation(name):
    """Build a forward hook that stores the hooked module's output.

    Args:
        name: key under which the detached output is stored in `activation`.

    Returns:
        A callable with the forward-hook signature (module, inputs, output).
    """
    def hook(_module, _inputs, output):
        # Store a detached copy of the reference so no autograd graph is kept
        activation[name] = output.detach()

    return hook

In [None]:
# Show the input image of validation sample 30 (channels moved last for imshow).
data, _ = test_dataset[30]
plt.imshow(data.permute(1,2,0))

In [None]:
# Capture the output of `in_layer2` for validation sample 30 and plot every channel.
hook_handle = model.in_layer2.register_forward_hook(get_activation('in_layer2'))
data, _ = test_dataset[30]
# Add a batch dimension without modifying the tensor view returned by the dataset
data = data.unsqueeze(0)
try:
    with torch.no_grad():
        output = model(data.to(device))
finally:
    # Remove the hook so re-running this cell does not stack duplicate hooks
    hook_handle.remove()

act = activation['in_layer2'].squeeze().cpu()
for idx in range(act.size(0)):
    plt.imshow(act[idx])
    plt.show()