In [None]:
import skimage
import sklearn.feature_extraction
import matplotlib.pyplot as plt
import numpy as np
import os
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torchvision

%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter
from fastai.layers import PixelShuffle_ICNR

In [None]:
# Root folder of the dataset; the loading cells read its 'Train' and 'Validation' subfolders.
# NOTE(review): `dir` shadows the Python builtin of the same name. Later cells reference it,
# so a rename has to be applied notebook-wide.
dir = '/workspace/data/Dhruv/pytorch/SuperResolution/Data'
# TensorBoard writer for this run; it is used for the model graph and the per-epoch losses.
writer = SummaryWriter('runs/2_2_L1_Res')

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load every training PNG and resize it to a common 2048x1080 shape.
# The file list is sorted because os.listdir returns entries in arbitrary order.
train_files = sorted(f for f in os.listdir(os.path.join(dir, 'Train')) if f.endswith(".png"))

# Size the buffer from the number of files actually found: a hard-coded count raises
# IndexError when there are more files and leaves all-zero images when there are fewer.
# float32 halves the memory of the default float64 buffer.
train_images = np.zeros((len(train_files), 2048, 1080, 3), dtype=np.float32)
for i, f in enumerate(train_files):
    train_images[i] = skimage.transform.resize(
        skimage.io.imread(os.path.join(dir, 'Train', f)), (2048, 1080), mode='constant')

In [None]:
def patchExtract(images, patch_size=(224, 224), max_patches=5, random_state=None):
    """Sample random patches from a batch of images with scikit-learn's PatchExtractor.

    Args:
        images: Array of images handed to ``PatchExtractor`` (this notebook passes an
            ``(n_images, height, width, 3)`` array).
        patch_size: ``(height, width)`` of every patch; defaults to 224x224.
        max_patches: Forwarded to ``PatchExtractor``; as an integer it is the number of
            patches sampled per image (5 by default).
        random_state: Optional seed forwarded to ``PatchExtractor`` so the sampled patches
            can be reproduced. The default ``None`` keeps the previous unseeded behaviour.

    Returns:
        The array of extracted patches returned by ``PatchExtractor.transform``.
    """
    extractor = sklearn.feature_extraction.image.PatchExtractor(
        patch_size=patch_size, max_patches=max_patches, random_state=random_state)
    # fit() returns the estimator itself; it is kept so the usual fit/transform sequence
    # is preserved, without binding an unused result.
    return extractor.fit(images).transform(images)

In [None]:
# Replace the full-size training images with patches sampled from them.
# The name is reused, so re-running this cell on its own would extract patches from patches.
train_images = patchExtract(train_images)

In [None]:
# Load every validation PNG and resize it to a common 2048x1080 shape.
# The file list is sorted because os.listdir returns entries in arbitrary order.
test_files = sorted(f for f in os.listdir(os.path.join(dir, 'Validation')) if f.endswith(".png"))

# Size the buffer from the number of files actually found: a hard-coded count raises
# IndexError when there are more files and leaves all-zero images when there are fewer.
# float32 halves the memory of the default float64 buffer.
test_images = np.zeros((len(test_files), 2048, 1080, 3), dtype=np.float32)
for i, f in enumerate(test_files):
    test_images[i] = skimage.transform.resize(
        skimage.io.imread(os.path.join(dir, 'Validation', f)), (2048, 1080), mode='constant')

In [None]:
# Disabled experiment kept as a bare string literal (it is never executed): extract the
# training patches in two halves of 400 images and concatenate the results.
'''
train_images1 = np.asarray(train_images[:400], dtype=np.float32)
train_images_patches1 = patchExtract(train_images1)
del train_images1

test_images = np.asarray(test_images, dtype=np.float32)
test_images_patches = patchExtract(test_images)

train_images1 = np.asarray(train_images[400:], dtype=np.float32)
train_images_patches2 = patchExtract(train_images1)
train_images_patches = np.concatenate((train_images_patches1, train_images_patches2), axis=0)
'''

In [None]:
# Replace the full-size validation images with patches sampled from them.
# The name is reused, so re-running this cell on its own would extract patches from patches.
test_images = patchExtract(test_images)

In [None]:
# Sanity check: shapes of the extracted training and validation patch arrays.
print(train_images.shape)
print(test_images.shape)

In [None]:
def bicubicDownsample(images, scale_factor=0.5):
    """Resize a batch of images with bicubic interpolation.

    Args:
        images: Tensor to resize. This notebook passes channels-first 4-D batches;
            the call is forwarded unchanged to ``torch.nn.functional.interpolate``.
        scale_factor: Multiplier for the spatial size. The default 0.5 halves it.

    Returns:
        The interpolated tensor (``mode='bicubic'``, ``align_corners=True``).
    """
    return torch.nn.functional.interpolate(
        images,
        scale_factor=scale_factor,
        mode='bicubic',
        align_corners=True,
    )

In [None]:
# Target tensors: move the last (channel) axis of the patch arrays to position 1 and
# convert to float32.
y_tr = torch.from_numpy(train_images).permute(0, 3, 1, 2).float()
y_te = torch.from_numpy(test_images).permute(0, 3, 1, 2).float()

In [None]:
# Drop the numpy patch arrays; only the tensors built from them are used from here on.
del train_images
del test_images

In [None]:
# Network inputs: the targets downsampled with the default bicubic factor (0.5).
x_tr = bicubicDownsample(y_tr).float()
x_te = bicubicDownsample(y_te).float()

In [None]:
# The targets were produced through a permute; store them contiguously in memory.
y_tr = y_tr.contiguous()
y_te = y_te.contiguous()

In [None]:
# Sanity check: report whether each input/target tensor is contiguous.
print(x_tr.is_contiguous())
print(x_te.is_contiguous())
print(y_tr.is_contiguous())
print(y_te.is_contiguous())

In [None]:
# Creating custom training dataset
class TrainDataset(Dataset):
    """Dataset of (input, target) tensor pairs used for training.

    Args:
        x: Input tensor indexed along its first dimension. Defaults to the notebook-level
            ``x_tr`` when omitted, so ``TrainDataset()`` keeps working as before.
        y: Target tensor indexed along its first dimension. Defaults to the notebook-level
            ``y_tr`` when omitted.

    Raises:
        ValueError: If ``x`` and ``y`` do not hold the same number of samples.
    """

    def __init__(self, x=None, y=None):
        # Fall back to the module-level tensors only when nothing is passed explicitly.
        self.x = x_tr if x is None else x
        self.y = y_tr if y is None else y
        if self.x.shape[0] != self.y.shape[0]:
            raise ValueError(
                "x and y must contain the same number of samples, got {} and {}".format(
                    self.x.shape[0], self.y.shape[0]))
        self.n_samples = self.x.shape[0]

    def __getitem__(self, index):
        """Return the (input, target) pair stored at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of samples."""
        return self.n_samples
    
# Creating custom testing dataset
class TestDataset(Dataset):
    """Dataset of (input, target) tensor pairs used for validation.

    Args:
        x: Input tensor indexed along its first dimension. Defaults to the notebook-level
            ``x_te`` when omitted, so ``TestDataset()`` keeps working as before.
        y: Target tensor indexed along its first dimension. Defaults to the notebook-level
            ``y_te`` when omitted.

    Raises:
        ValueError: If ``x`` and ``y`` do not hold the same number of samples.
    """

    def __init__(self, x=None, y=None):
        # Fall back to the module-level tensors only when nothing is passed explicitly.
        self.x = x_te if x is None else x
        self.y = y_te if y is None else y
        if self.x.shape[0] != self.y.shape[0]:
            raise ValueError(
                "x and y must contain the same number of samples, got {} and {}".format(
                    self.x.shape[0], self.y.shape[0]))
        self.n_samples = self.x.shape[0]

    def __getitem__(self, index):
        """Return the (input, target) pair stored at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of samples."""
        return self.n_samples

In [None]:
# Disabled ToTensor + Normalize transform, kept as a bare string literal (never executed).
'''
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])
'''

In [None]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 32

In [None]:
train_dataset = TrainDataset()
test_dataset = TestDataset()

# Training loader: mini-batches are reshuffled at every epoch and loaded by two workers.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=2)

# Validation loader: kept in a fixed order. The summed validation loss does not depend on
# the order, and a fixed order makes the sample picked by the inference cells reproducible.
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=2)

In [None]:
# Number of passes over the training set.
EPOCHS = 200
# Number of training samples and of training batches per epoch (last batch may be partial).
n_samples = len(train_dataset)
n_iterations = math.ceil(n_samples/batch_size)

## Creating Model

In [None]:
class DoubleConv(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm layers and a 1x1-conv shortcut.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both the shortcut and the
            residual branch.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 projection so the shortcut has the same channel count as the residual branch.
        self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # Residual branch; padding=1 keeps the spatial size unchanged for 3x3 kernels.
        residual_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        self.double_conv = nn.Sequential(*residual_layers)

    def forward(self, x):
        """Return ``relu(skip(x) + double_conv(x))``."""
        shortcut = self.skip(x)
        residual = self.double_conv(x)
        # The sum is a fresh tensor, so the in-place ReLU does not touch either branch input.
        return F.relu_(shortcut + residual)
    

class PsUpsample(nn.Module):
    """Decoder stage: pixel-shuffle upsampling, skip concatenation, then a DoubleConv.

    Args:
        in_channels: Channel count of the low-resolution input ``x1``. The stage returns
            ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        half_channels = in_channels // 2
        # Called as PixelShuffle_ICNR(in_channels, half_channels).
        # NOTE(review): presumably this yields half_channels channels at twice the
        # resolution - confirm against the fastai version in use.
        self.upsample = PixelShuffle_ICNR(in_channels, half_channels)
        # The concatenation of the skip and the upsampled map must have in_channels channels.
        self.dconv = DoubleConv(in_channels, half_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate it after the skip ``x2`` on dim 1, and fuse them."""
        upsampled = self.upsample(x1)
        merged = torch.cat((x2, upsampled), dim=1)
        return self.dconv(merged)

In [None]:
# Pretrained ResNet-34 used for transfer learning; Res_Unet takes its layers directly from
# this object.
resnet_model = torchvision.models.resnet34(pretrained=True, progress=False) # Pretrained Resnet34 for transfer learning

In [None]:
from fastai.model import DynamicUnet

In [None]:
# Disabled: this cell called DynamicUnet(resnet_), but `resnet_` is not defined anywhere in
# the notebook, so Restart & Run All stopped here with a NameError. `mymodel` is not used by
# any other cell; the hand-written Res_Unet is the model that gets trained.
# mymodel = DynamicUnet(resnet_)

In [None]:
class Res_Unet(nn.Module):
    """U-Net style super-resolution network built on the pretrained ResNet-34 encoder.

    The encoder reuses the layers of the notebook-level ``resnet_model``. The decoder is
    three ``PsUpsample`` stages fed with encoder skip connections, followed by two
    pixel-shuffle blocks and a 7x7 output convolution that produces 3 channels.

    NOTE(review): the encoder modules are the very objects held by ``resnet_model``, so all
    ``Res_Unet`` instances created in one session share the same encoder weights.
    NOTE(review): checkpoints trained before the double-shortcut fix in ``forward`` will
    produce different outputs when loaded into this class.
    """

    @staticmethod
    def _residual_branch(block):
        """Return the conv-bn-relu-conv-bn branch of a ResNet block, without its shortcut."""
        return nn.Sequential(block.conv1, block.bn1, block.relu, block.conv2, block.bn2)

    @staticmethod
    def _projected_block(x, residual, shortcut):
        """Apply a split ResNet block: ``relu(residual(x) + shortcut(x))``."""
        return F.relu(residual(x) + shortcut(x), inplace=True)

    def __init__(self):
        super().__init__()

        # Encoder path
        # Pretrained ResNet stem. (A from-scratch 7x7 conv + BN + ReLU stem was tried
        # earlier and is no longer used.)
        self.in_conv = nn.Sequential(
            resnet_model.conv1,
            resnet_model.bn1,
            resnet_model.relu,
            resnet_model.maxpool
        )

        # Blocks taken whole: their own forward() already includes the identity shortcut.
        self.layer1_0 = resnet_model.layer1[0]
        self.layer1_1 = resnet_model.layer1[1]
        self.layer1_2 = resnet_model.layer1[2]

        # The first block of layers 2-4 is split into its residual branch (layerN_0) and its
        # `downsample` shortcut (ilN); forward() adds the two back together.
        self.layer2_0 = self._residual_branch(resnet_model.layer2[0])
        self.layer2_1 = resnet_model.layer2[1]
        self.layer2_2 = resnet_model.layer2[2]
        self.layer2_3 = resnet_model.layer2[3]

        self.il2 = resnet_model.layer2[0].downsample

        self.layer3_0 = self._residual_branch(resnet_model.layer3[0])
        self.layer3_1 = resnet_model.layer3[1]
        self.layer3_2 = resnet_model.layer3[2]
        self.layer3_3 = resnet_model.layer3[3]
        self.layer3_4 = resnet_model.layer3[4]
        self.layer3_5 = resnet_model.layer3[5]

        self.il3 = resnet_model.layer3[0].downsample

        self.layer4_0 = self._residual_branch(resnet_model.layer4[0])
        self.layer4_1 = resnet_model.layer4[1]
        self.layer4_2 = resnet_model.layer4[2]

        self.il4 = resnet_model.layer4[0].downsample

        # Decoder path. Each PsUpsample(c) consumes c channels plus a skip connection and
        # returns c // 2 channels: 512 -> 256 -> 128 -> 64.
        self.up1 = PsUpsample(512)
        self.up2 = PsUpsample(256)
        self.up3 = PsUpsample(128)
        # Receives the 64-channel output of up3. Two pixel-shuffle blocks with a 5x5 conv in
        # between (similar to the SRGAN upsampling block; repeat it for a x4 model).
        self.outconvblock = nn.Sequential(
            PixelShuffle_ICNR(64, 64),
            nn.ReLU(inplace=True),                # Remove this ReLU?
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            PixelShuffle_ICNR(64, 64),
            nn.ReLU(inplace=True),
        )
        # Final projection to 3 image channels; the ReLU keeps the output non-negative.
        self.outconv = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=7, padding=3),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Map a batch of low-resolution images to super-resolved images.

        Args:
            x: Batch of input images; the pretrained stem expects 3 channels.

        Returns:
            Tensor with 3 channels produced by ``outconv``.
        """
        # Bilinearly upsample by 2 before the encoder.
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder path
        x = self.in_conv(x)

        # Whole ResNet blocks already add their identity shortcut and apply the final ReLU
        # inside their own forward(), so they are simply chained here. Adding the identity
        # again would apply the shortcut twice and distort the pretrained features.

        # Layer 1
        x = self.layer1_0(x)
        x = self.layer1_1(x)
        x = self.layer1_2(x)
        x1 = x

        # Layer 2
        x = self._projected_block(x, self.layer2_0, self.il2)
        x = self.layer2_1(x)
        x = self.layer2_2(x)
        x = self.layer2_3(x)
        x2 = x

        # Layer 3
        x = self._projected_block(x, self.layer3_0, self.il3)
        x = self.layer3_1(x)
        x = self.layer3_2(x)
        x = self.layer3_3(x)
        x = self.layer3_4(x)
        x = self.layer3_5(x)
        x3 = x

        # Layer 4
        x = self._projected_block(x, self.layer4_0, self.il4)
        x = self.layer4_1(x)
        x = self.layer4_2(x)

        # Decoder path: each stage fuses the upsampled map with the matching encoder skip.
        x = self.up1(x, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.outconvblock(x)
        x = self.outconv(x)

        return x

In [None]:
# Instantiate the network; its encoder layers come from the pretrained `resnet_model`.
model = Res_Unet()

In [None]:
# Freeze the whole network first ...
for param in model.parameters():
    param.requires_grad = False

# ... then re-enable gradients for the decoder and output heads only.
# in_conv stays frozen: its unfreeze loop was commented out.
trainable_modules = (model.up1, model.up2, model.up3, model.outconv, model.outconvblock)
for module in trainable_modules:
    for param in module.parameters():
        param.requires_grad = True

In [None]:
# Sanity check: print the requires_grad flag of every parameter to confirm what is frozen.
for param in model.parameters():
    print(param.requires_grad)

In [None]:
# Display the module tree of the network.
model

## Creating Loss Function (Perceptual Loss)

In [None]:
class VGGPerceptualLoss(nn.Module):
    """Frozen feature extractor made of the first nine layers of pretrained VGG-19.

    The layers are registered under ``relu2_2`` with the names ``relu2_2_1`` to
    ``relu2_2_9``. All parameters have ``requires_grad=False``.
    """

    def __init__(self):
        super().__init__()
        vgg = torchvision.models.vgg19(pretrained=True, progress=False)
        self.relu2_2 = nn.Sequential()
        # Copy layers 0..8 of vgg.features, numbering the registered names from 1.
        for position in range(1, 10):
            self.relu2_2.add_module(name="relu2_2_"+str(position), module=vgg.features[position - 1])
        # The extractor only measures feature distances, so its weights stay fixed.
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, x):
        """Return the activations of the nine-layer VGG-19 prefix for ``x``."""
        return self.relu2_2(x)

In [None]:
# Shared feature extractor used by PerceptualLoss, moved to the training device.
VGGLoss = VGGPerceptualLoss().to(device)

In [None]:
def PerceptualLoss(x, y):
    """Combined VGG-feature and pixel L1 loss between a prediction and its target.

    Both tensors are passed through the notebook-level ``VGGLoss`` extractor. The result is

        sum(|feat(y) - feat(x)|) / (C*H*W)  +  mean(|x - y|) / (C*H*W)

    where C, H and W are dimensions 1-3 of the feature map of ``y``.

    NOTE(review): the feature term is summed over the whole batch, while the pixel term is
    already a mean and is divided again by the feature-map size, so its contribution is
    very small in comparison. Confirm that this weighting is intended.

    Args:
        x: Predicted images.
        y: Target images.

    Returns:
        Scalar loss tensor.
    """
    pred_features = VGGLoss(x)
    target_features = VGGLoss(y)

    # Per-sample size of the feature map, used as the normaliser of both terms.
    channels, height, width = target_features.shape[1:4]
    norm = channels * height * width

    feature_term = F.l1_loss(target_features, pred_features, reduction='sum') / norm
    pixel_term = F.l1_loss(x, y) / norm
    return feature_term + pixel_term

## Training Loop

In [None]:
# Implementing checkpoints
def save_checkpoint_best(epoch, model):
    """Save the weights of the best model so far.

    Only ``model.state_dict()`` is written (no optimizer state), to
    ``BestModel/best_model_<epoch>.pt`` under the project folder.

    Args:
        epoch: Epoch number used in the file name.
        model: Module whose state dict is saved.
    """
    print("Saving best model")
    PATH = "/workspace/data/Dhruv/pytorch/SuperResolution/BestModel/best_model_"+str(epoch)+".pt"
    # torch.save does not create missing folders; make sure the target folder exists.
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    torch.save(model.state_dict(), PATH)

def save_checkpoint(epoch, model, optimizer, loss):
    """Save a resumable checkpoint to ``Models/model_<epoch>.pt`` under the project folder.

    The file holds a dict with the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``val_loss``, so training can be restarted from it.

    Args:
        epoch: Epoch number, stored in the checkpoint and used in the file name.
        model: Module whose state dict is saved.
        optimizer: Optimizer whose state dict is saved.
        loss: Value stored under ``val_loss``.
    """
    PATH = "/workspace/data/Dhruv/pytorch/SuperResolution/Models/model_"+str(epoch)+".pt"
    print("Saving model")
    # torch.save does not create missing folders; make sure the target folder exists.
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': loss,
            }, PATH)

In [None]:
# Per-epoch loss histories; train_model() appends one value per epoch to each list.
tr_loss_log = []
val_loss_log = []

In [None]:
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# The loss is PerceptualLoss, called directly in the training loop.

# Log the model graph to TensorBoard using one training batch as the example input.
example = iter(train_loader)
# Use the builtin next(): iterators are not guaranteed to expose a .next() method.
example_data, example_target = next(example)
writer.add_graph(model, example_data.to(device))
writer.close()

In [None]:
# Training Loop
def train_model():
    """Train the notebook-level ``model`` for ``EPOCHS`` epochs and validate after each one.

    Per epoch the function:
      * runs one optimisation pass over ``train_loader`` with ``PerceptualLoss``;
      * evaluates on ``test_loader`` without gradients;
      * appends the summed batch losses to ``tr_loss_log`` / ``val_loss_log``;
      * saves a resumable checkpoint, plus a "best" copy when the validation loss improves;
      * prints a summary line and logs both losses to TensorBoard at step ``epoch + 1``.

    The reported losses are sums of the per-batch losses, not means.

    NOTE(review): ``model.train()`` also puts the frozen encoder's BatchNorm layers in
    training mode, so their running statistics keep updating - confirm this is intended.
    """
    least_val_loss = math.inf

    # Make sure training starts in training mode even if an inference cell called eval().
    model.train()

    for epoch in range(EPOCHS):

        beg_time = time.time()  # To calculate time taken for each epoch

        train_loss = 0.0
        val_loss = 0.0

        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            # Forward pass
            out = model(x)
            # Calculating loss
            loss = PerceptualLoss(out, y)
            # Backward pass
            loss.backward()
            # Update the weights
            optimizer.step()
            # Accumulate training loss
            train_loss += loss.item()
        tr_loss_log.append(train_loss)

        model.eval()
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                out = model(x)
                # Calculating loss
                loss = PerceptualLoss(out, y)
                # Accumulate validation loss
                val_loss += loss.item()
        val_loss_log.append(val_loss)
        model.train()

        # Saving checkpoints
        save_checkpoint(epoch+1, model, optimizer, val_loss)
        if val_loss < least_val_loss:
            save_checkpoint_best(epoch+1, model)
            least_val_loss = val_loss

        end_time = time.time()
        print('Epoch: {:.0f}/{:.0f}, Time: {:.0f}m {:.0f}s, Train_Loss: {:.4f}, Val_loss: {:.4f}'.format(
            epoch+1, EPOCHS, (end_time-beg_time)//60, (end_time-beg_time)%60, train_loss, val_loss))
        # Both scalars are per-epoch values, so the epoch number is the step. The previous
        # step reused the index left over from the validation loop.
        writer.add_scalar('Training_loss', train_loss, epoch+1)
        writer.add_scalar('Validation_loss', val_loss, epoch+1)
        # Write this epoch's scalars to disk.
        writer.close()

In [None]:
# Run the full training loop (EPOCHS epochs); it writes checkpoints and TensorBoard logs.
train_model()

In [None]:
# Saving the final model (weights only, as a plain state dict)
PATH = "/workspace/data/Dhruv/pytorch/SuperResolution/FinalModel/final_trained_model.pt"
print("Saving final model")
# torch.save does not create missing folders; make sure the target folder exists.
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(model.state_dict(), PATH)

In [None]:
# Disabled snippet for loading a "best" checkpoint, kept as a bare string literal (never
# executed).
# NOTE(review): it would not run as written. SR_UNet is not defined in the visible notebook,
# and save_checkpoint_best stores a plain state dict, so checkpoint['model_state_dict']
# would not be found.
'''
# To load the best model (fill in the best model epoch number)
loaded_best_model = SR_UNet(3).to(device)
checkpoint = torch.load("/workspace/data/Dhruv/pytorch/SuperResolution/BestModel/best_model_74.pt")
loaded_best_model.load_state_dict(checkpoint['model_state_dict'])
loaded_best_model.eval()
#model = loaded_best_model
'''

In [None]:
# Loading a model with desired epoch number
# NOTE(review): Res_Unet takes its encoder layers from the shared `resnet_model` object, so
# loading weights into `loaded_model` also overwrites the encoder weights seen by `model`.
loaded_model = Res_Unet().to(device)
# map_location lets a checkpoint saved on a GPU be loaded on whatever `device` is in use,
# including a CPU-only machine.
checkpoint = torch.load("/workspace/data/Dhruv/pytorch/SuperResolution/Models/model_1.pt",
                        map_location=device)
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_model.eval()
#model = loaded_model

## Inference

In [None]:
# Switch to inference mode and show one low-resolution input from a validation batch.
model.eval()
example = iter(test_loader)
# Use the builtin next(): iterators are not guaranteed to expose a .next() method.
example_data, example_target = next(example)
# imshow expects the channel axis last, so move it from position 0 to the end.
plt.imshow(example_data[15].permute(1,2,0))

In [None]:
# Super-resolve the example batch. no_grad avoids building the autograd graph, which is
# not needed for inference.
with torch.no_grad():
    out = model(example_data.to(device))
plt.imshow(out[15].cpu().detach().permute(1,2,0))

In [None]:
# Show the target image for the same sample, for visual comparison.
plt.imshow(example_target[15].permute(1,2,0))

In [None]:
%tensorboard --logdir=runs/2_2_L1_Res

In [None]:
# TODO: include data preprocessing, reduce model complexity, reduce the learning rate or add a decay schedule, and add dropout layers.