In [None]:
import skimage
import sklearn.feature_extraction
import matplotlib.pyplot as plt
import numpy as np
import os
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torchvision

from torch.optim import lr_scheduler

%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter
from fastai.layers import PixelShuffle_ICNR

In [None]:
# Root data folder; the 'Train' and 'Validation' sub-folders are read from it below.
# NOTE(review): the name `dir` shadows the Python builtin and is reassigned in later cells.
dir = '/workspace/data/Dhruv/pytorch/SuperResolution/Data'
# TensorBoard writer; events are written under runs/resunet50decoder.
writer = SummaryWriter('runs/resunet50decoder')

In [None]:
# Use the CUDA device with index 1 when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): 'cuda:1' presumably requires a machine with at least two GPUs - confirm.
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load every PNG in the 'Train' folder, resized to 512x512, into a buffer
# preallocated for 800 images.
train_images = np.zeros((800, 512, 512, 3))
i = 0
for f in os.listdir(dir + '/' + 'Train'):
    if(f.endswith(".png")):
        train_images[i] = (skimage.transform.resize(
            skimage.io.imread(dir + '/Train/' + f),(512,512), mode ='constant'))
        i += 1
# Keep only the slots that were actually filled, so that unused all-zero
# placeholder images never reach the training set when fewer files are found.
train_images = train_images[:i]

In [None]:
# Load every JPEG in the ImagenetData training folder, resized to 300x300,
# into a buffer preallocated for 7276 images.
dir = '/workspace/data/Dhruv/pytorch/SuperResolution/ImagenetData/content/train'
train_images1 = np.zeros((7276, 300, 300, 3))
i = 0
for f in os.listdir(dir):
    if(f.endswith(".jpg")):
        train_images1[i] = (skimage.transform.resize(
            skimage.io.imread(dir + '/' + f),(300,300), mode ='constant'))
        i += 1
# Keep only the slots that were actually filled, so that unused all-zero
# placeholder images are not shuffled into the train/test split below.
train_images1 = train_images1[:i]
np.random.shuffle(train_images1)
# Everything after the first 5776 shuffled images is held out for testing.
test_images1 = train_images1[5776:]
train_images1 = train_images1[:5776]

In [None]:
# Function to extract random patches from the images
# (128x128 patches, 5 per image by default).
def patchExtract(images, patch_size=(128, 128), max_patches=5, random_state=None):
    """Sample random patches from a batch of images.

    Args:
        images: array of images passed to scikit-learn's PatchExtractor.
        patch_size: (height, width) of each extracted patch.
        max_patches: forwarded to PatchExtractor as the per-image patch limit.
        random_state: optional seed or RandomState forwarded to PatchExtractor
            so the sampling can be made reproducible. The default (None)
            keeps the previous unseeded behaviour.

    Returns:
        The array of patches returned by PatchExtractor.transform.
    """
    extractor = sklearn.feature_extraction.image.PatchExtractor(
        patch_size=patch_size, max_patches=max_patches, random_state=random_state)
    # fit() returns the extractor itself, so it can be chained with transform().
    return extractor.fit(images).transform(images)

In [None]:
# Load every PNG in the 'Validation' folder, resized to 512x512, into a buffer
# preallocated for 100 images.
dir = '/workspace/data/Dhruv/pytorch/SuperResolution/Data'
test_images = np.zeros((100, 512, 512, 3))
i = 0
for f in os.listdir(dir + '/' + 'Validation'):
    if(f.endswith(".png")):
        test_images[i] = (skimage.transform.resize(
            skimage.io.imread(dir + '/Validation/' + f),(512,512), mode ='constant'))
        i += 1
# Keep only the slots that were actually filled, so that unused all-zero
# placeholder images never reach the validation set when fewer files are found.
test_images = test_images[:i]

In [None]:
# Replace every image set by random patches of the default patchExtract size.
# The sets loaded from the ImagenetData folder request max_patches=3;
# the other two sets use the function's default max_patches.
train_images1 = patchExtract(train_images1, max_patches=3)
test_images1 = patchExtract(test_images1, max_patches=3)
train_images = patchExtract(train_images)
test_images = patchExtract(test_images)

In [None]:
# Report the shape of each patch array before the sets are merged.
for patch_array in (train_images, test_images, train_images1, test_images1):
    print(patch_array.shape)

In [None]:
# Merge the patches of both sources along the sample axis, then shuffle each
# merged set in place (np.random.shuffle permutes along the first axis).
train_images = np.concatenate((train_images, train_images1), axis=0)
test_images = np.concatenate((test_images, test_images1), axis=0)
np.random.shuffle(train_images)
np.random.shuffle(test_images)

In [None]:
# Report the shape of the merged training and test patch arrays.
for merged_array in (train_images, test_images):
    print(merged_array.shape)

In [None]:
def bicubicDownsample(images, scale_factor=0.5):
    """Rescale a batch of images with bicubic interpolation.

    Args:
        images: tensor handed to torch.nn.functional.interpolate.
        scale_factor: multiplier for the spatial size (0.5 halves it).

    Returns:
        The interpolated tensor (computed with align_corners=True).
    """
    return torch.nn.functional.interpolate(
        images,
        scale_factor=scale_factor,
        mode='bicubic',
        align_corners=True,
    )

In [None]:
# Ground-truth tensors: move the last (channel) axis of the numpy patch arrays
# to position 1 and convert to float32.
y_tr = torch.from_numpy(train_images).permute(0, 3, 1, 2).float()
y_te = torch.from_numpy(test_images).permute(0, 3, 1, 2).float()

In [None]:
# The numpy patch arrays are no longer needed once the tensors exist.
del train_images, test_images

In [None]:
# Network inputs: bicubically downsampled copies of the ground-truth patches.
x_tr = bicubicDownsample(y_tr).float()
x_te = bicubicDownsample(y_te).float()

In [None]:
# Rebind the targets to contiguous tensors (the permute applied when they were
# created only rearranges strides); the next cell prints the contiguity flags.
y_tr = y_tr.contiguous()
y_te = y_te.contiguous()

In [None]:
# Print whether each input / target tensor is stored contiguously.
for tensor in (x_tr, x_te, y_tr, y_te):
    print(tensor.is_contiguous())

In [None]:
# Per-channel normalization transform.
# NOTE(review): these look like the usual ImageNet channel statistics - confirm.
# NOTE(review): `normalize` is not referenced anywhere else in the visible notebook.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

In [None]:
# Creating custom training dataset
class TrainDataset(Dataset):
    """Training pairs taken from the global tensors x_tr (inputs) and y_tr (targets).

    An optional `transform` callable is applied to the input only; the target
    is always returned unchanged.
    """

    def __init__(self, transform=None):
        self.transform = transform
        self.x, self.y = x_tr, y_tr
        self.n_samples = self.x.shape[0]

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        sample, target = self.x[index], self.y[index]
        if self.transform:
            sample = self.transform(sample)
        return sample, target
    
# Creating custom testing dataset
class TestDataset(Dataset):
    """Test pairs taken from the global tensors x_te (inputs) and y_te (targets).

    An optional `transform` callable is applied to the input only; the target
    is always returned unchanged.
    """

    def __init__(self, transform=None):
        self.transform = transform
        self.x, self.y = x_te, y_te
        self.n_samples = self.x.shape[0]

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        sample, target = self.x[index], self.y[index]
        if self.transform:
            sample = self.transform(sample)
        return sample, target

In [None]:
# Number of samples per batch for both data loaders.
batch_size = 32

In [None]:
train_dataset = TrainDataset()
test_dataset = TestDataset()

# Training loader: batches are reshuffled at every epoch and loaded by two
# worker subprocesses.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=2)

# Test / validation loader: the order is kept fixed, so repeated runs see the
# same batches (and the same sample at a given index within a batch).
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=2)

In [None]:
# Number of passes over the training set.
EPOCHS = 100
n_samples = len(train_dataset)
# Batches per epoch (the last batch may be partial, hence the ceil).
# NOTE(review): n_iterations is not used anywhere else in the visible notebook.
n_iterations = math.ceil(n_samples/batch_size)

## Creating Model

In [None]:
class DecoderBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages added to a shortcut.

    When the channel count changes, the shortcut goes through a 1x1 convolution
    so it can be added to the residual branch; otherwise the input is used as is.
    A ReLU is applied after the first stage and after the addition.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # The projection exists only when the shortcut has to change channel count.
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        shortcut = x if self.in_channels == self.out_channels else self.skip(x)
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        # Element-wise sum of shortcut and residual branch, followed by ReLU.
        return self.relu(shortcut + residual)

class DecoderLayer(nn.Module):
    """Decoder stage: upsample, concatenate with a skip tensor, refine with blocks.

    `x1` is upsampled with PixelShuffle_ICNR(in_channels, in_channels//2, scale=2),
    concatenated after `x2` along the channel axis and passed through the
    DecoderBlock stack. The first block takes `in_channels` channels.
    """

    def __init__(self, in_channels, out_channels, n_blocks):
        super().__init__()
        self.upsample = PixelShuffle_ICNR(in_channels, in_channels//2, scale=2)
        self.blocks = self.make_layer(in_channels, out_channels, n_blocks)

    def make_layer(self, in_channels, out_channels, n_blocks):
        """Build the block stack: in->in, in->out, then out->out for the rest.

        The first two blocks are always created, whatever `n_blocks` is.
        """
        stages = [
            DecoderBlock(in_channels, in_channels),
            DecoderBlock(in_channels, out_channels),
        ]
        stages.extend(DecoderBlock(out_channels, out_channels)
                      for _ in range(n_blocks - 2))
        return nn.Sequential(*stages)

    def forward(self, x1, x2):
        merged = torch.cat((x2, self.upsample(x1)), dim=1)
        return self.blocks(merged)

class FinalDecoderLayer(nn.Module):
    """Output stage: a stack of DecoderBlocks with no upsampling or skip input."""

    def __init__(self, in_channels, out_channels, n_blocks):
        super().__init__()
        self.blocks = self.make_layer(in_channels, out_channels, n_blocks)

    def make_layer(self, in_channels, out_channels, n_blocks):
        """Build the block stack: in->in, in->out, then out->out for the rest.

        The first two blocks are always created, whatever `n_blocks` is.
        """
        stages = [
            DecoderBlock(in_channels, in_channels),
            DecoderBlock(in_channels, out_channels),
        ]
        stages.extend(DecoderBlock(out_channels, out_channels)
                      for _ in range(n_blocks - 2))
        return nn.Sequential(*stages)

    def forward(self, x):
        return self.blocks(x)

In [None]:
class ResUNet(nn.Module):
    """U-Net style network with a pretrained ResNet-50 encoder.

    The input is first upsampled by 2 with bilinear interpolation, encoded by
    the ResNet-50 layers, and decoded by DecoderLayer stages that each take the
    previous decoder output plus one encoder (or projected) skip tensor.

    Args:
        out_channels: number of channels produced by the final decoder layer.
    """

    def __init__(self, out_channels):
        super().__init__()

        self.base_model = torchvision.models.resnet50(pretrained=True, progress=False)
        self.base_layers = list(self.base_model.children())

        # Encoder path (slices of the ResNet-50 children list)
        self.in_layer1 = self.base_layers[0]
        self.in_layer2 = nn.Sequential(*self.base_layers[1:4])
        self.layer1 = nn.Sequential(*self.base_layers[4])
        self.layer2 = nn.Sequential(*self.base_layers[5])
        self.layer3 = nn.Sequential(*self.base_layers[6])
        self.layer4 = nn.Sequential(*self.base_layers[7])

        # Cross path: 1x1 convolutions that give the two shallow skip tensors
        # the channel counts expected by up4 and up5.
        self.down_in1 = nn.Conv2d(64, 128 ,kernel_size=1)
        self.down_up = nn.Conv2d(3, 64, kernel_size=1)

        # Decoder path
        self.up1 = DecoderLayer(2048, 1024, 3)
        self.up2 = DecoderLayer(1024, 512, 6)
        self.up3 = DecoderLayer(512, 256, 4)
        self.up4 = DecoderLayer(256, 128, 3)
        self.up5 = DecoderLayer(128, 64, 3)

        # Honour the requested number of output channels (was hard-coded to 3).
        self.out_layer = FinalDecoderLayer(64, out_channels, 3)

    def forward(self, x):
        """Return the decoded output for input batch `x`."""

        # Encoder path
        x_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x_in1 = self.in_layer1(x_up)
        x_in2 = self.in_layer2(x_in1) # This is of same size as x_l1 so not used
        x_l1 = self.layer1(x_in2)
        x_l2 = self.layer2(x_l1)
        x_l3 = self.layer3(x_l2)
        x_l4 = self.layer4(x_l3)

        # Decoder path
        x = self.up1(x_l4, x_l3)
        x = self.up2(x, x_l2)
        x = self.up3(x, x_l1)
        x_in1 = self.down_in1(x_in1)
        x = self.up4(x, x_in1)
        x_up = self.down_up(x_up)
        x = self.up5(x, x_up)
        x = self.out_layer(x)

        return x

## Creating Loss Function (Perceptual Loss)

In [None]:
class VGGPerceptualLoss(nn.Module):
    """Frozen VGG-16 feature extractor split into four consecutive stages.

    The stages wrap `features[0:4]`, `[4:9]`, `[9:16]` and `[16:23]` of the
    pretrained torchvision VGG-16 and are named relu1_2, relu2_2, relu3_3 and
    relu4_3. All parameters have requires_grad set to False.
    """

    def __init__(self):
        super().__init__()
        features = torchvision.models.vgg16(pretrained=True, progress=False).features
        stage_specs = (
            ("relu1_2", 0, 4),
            ("relu2_2", 4, 9),
            ("relu3_3", 9, 16),
            ("relu4_3", 16, 23),
        )
        for stage_name, start, stop in stage_specs:
            stage = nn.Sequential()
            # Sub-modules are numbered from 1 within each stage.
            for position, layer_idx in enumerate(range(start, stop), start=1):
                stage.add_module(name=stage_name + "_" + str(position),
                                 module=features[layer_idx])
            setattr(self, stage_name, stage)
        # The loss network is fixed: none of its parameters is trained.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the outputs of the four stages, shallowest first, as a tuple."""
        stage_outputs = []
        for stage in (self.relu1_2, self.relu2_2, self.relu3_3, self.relu4_3):
            x = stage(x)
            stage_outputs.append(x)
        return tuple(stage_outputs)
    
# Function to calculate Gram matrix
def gram(x):
    """Return the normalized Gram matrix of each sample in a 4-D feature batch.

    Args:
        x: tensor of shape (N, C, H, W).

    Returns:
        Tensor of shape (N, C, C): for every sample, the product of the
        flattened (C, H*W) feature matrix with its transpose, divided by C*H*W.
    """
    (N, C, H, W) = x.shape
    # reshape (unlike view) also accepts non-contiguous feature maps.
    psy = x.reshape(N, C, H*W)
    psy_T = psy.transpose(1, 2)
    # torch.bmm multiplies each batch element independently and does not
    # average over N, so one Gram matrix per sample is kept.
    G = torch.bmm(psy, psy_T) / (C*H*W)
    return G

def TVR(x, TV_WEIGHT=1e-8):
    """Weighted total-variation term of a 4-D tensor.

    Sums the absolute differences between neighbouring elements along the last
    axis and along the second-to-last axis, then scales the total by TV_WEIGHT.
    """
    along_last_axis = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().sum()
    along_prev_axis = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().sum()
    return TV_WEIGHT * (along_last_axis + along_prev_axis)

In [None]:
# Single shared instance of the frozen VGG feature extractor, placed on `device`.
VGGLoss = VGGPerceptualLoss().to(device)

In [None]:
def PerceptualLoss(x, y, STYLE_WEIGHT=1e-4, CONTENT_WEIGHT=1e-4, PIXEL_WEIGHT=1):
    """Combined pixel, content, style and total-variation loss.

    Args:
        x: predicted image batch (passed through VGGLoss and TVR).
        y: target image batch of the same shape as `x`.
        STYLE_WEIGHT: weight of the Gram-matrix (style) term.
        CONTENT_WEIGHT: weight of the VGG feature (content) term.
        PIXEL_WEIGHT: weight of the per-pixel L1 term.

    Returns:
        Scalar tensor: STYLE_WEIGHT*style + CONTENT_WEIGHT*content
        + PIXEL_WEIGHT*pixel + TVR(x).
    """
    x_features = VGGLoss(x)
    y_features = VGGLoss(y)

    # Per-pixel loss: L1 summed over the whole batch, divided by the size of
    # one sample (C*H*W), so it grows with the batch size.
    C, H, W = y.shape[1], y.shape[2], y.shape[3]
    pixel_loss = F.l1_loss(x, y, reduction='sum') / (C*H*W)

    # Total variation regularization of the prediction.
    tvr_loss = TVR(x)

    # Content loss: L1 between the two deepest VGG feature maps (indices 2 and 3),
    # each normalized by the size of one feature map.
    content_loss = 0.0
    for x_feat, y_feat in zip(x_features[2:4], y_features[2:4]):
        C, H, W = y_feat.shape[1], y_feat.shape[2], y_feat.shape[3]
        content_loss += (F.l1_loss(y_feat, x_feat, reduction='sum') / (C*H*W))

    # Style loss: L1 between the Gram matrices of all four feature maps
    # (gram() already normalizes by C*H*W).
    style_loss = 0.0
    for x_feat, y_feat in zip(x_features, y_features):
        style_loss += F.l1_loss(gram(x_feat), gram(y_feat), reduction='sum')

    total_loss = STYLE_WEIGHT*style_loss + CONTENT_WEIGHT*content_loss + PIXEL_WEIGHT*pixel_loss + tvr_loss
    return total_loss

## Training Loop

In [None]:
# Implementing checkpoints
def save_checkpoint_best(epoch, model):
    """Save the weights (state_dict) of `model` as the best model of `epoch`.

    The file is written to the BestModel folder as best_model_<epoch>.pt.
    """
    print("Saving best model")
    PATH = "/workspace/data/Dhruv/pytorch/SuperResolution/BestModel/best_model_"+str(epoch)+".pt"
    # torch.save does not create missing folders, so make sure the target exists.
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    torch.save(model.state_dict(), PATH)

def save_checkpoint(epoch, model, optimizer, loss):
    """Save a resumable checkpoint to the Models folder as model_<epoch>.pt.

    The saved dict holds the epoch, the model and optimizer state_dicts and the
    given loss under the key 'val_loss', so training can be restarted from it.
    """
    PATH = "/workspace/data/Dhruv/pytorch/SuperResolution/Models/model_"+str(epoch)+".pt"
    print("Saving model")
    # torch.save does not create missing folders, so make sure the target exists.
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': loss,
            }, PATH)

In [None]:
# Per-epoch loss histories; train_model() appends one value to each per epoch.
tr_loss_log = []
val_loss_log = []

In [None]:
# Build the network (3 passed as out_channels) and move it to `device`.
model = ResUNet(3).to(device)

In [None]:
# Freeze the encoder: parameters of these sub-modules are excluded from training.
for encoder_name in ('in_layer1', 'in_layer2', 'layer1', 'layer2', 'layer3', 'layer4'):
    if encoder_name in model._modules:
        for param in model._modules[encoder_name].parameters():
            param.requires_grad = False

In [None]:
# Sanity check: print the requires_grad flag of every model parameter.
for param in model.parameters():
    print(param.requires_grad)

In [None]:
optimizer = torch.optim.Adadelta(model.parameters())
#loss = PerceptualLoss used directly in training loop

# Log the model graph to TensorBoard, traced on one training batch.
example = iter(train_loader)
# Use the built-in next(): iterators have no .next() method in Python 3 and
# recent DataLoader iterators no longer provide one.
example_data, example_target = next(example)
writer.add_graph(model, example_data.to(device))
writer.close()

In [None]:
# Reduce the learning rate when the monitored value stops decreasing ('min' mode),
# waiting 7 epochs without improvement.
# NOTE(review): the only scheduler.step(...) call in train_model is commented out,
# so this scheduler currently has no effect - confirm that is intended.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=7, verbose=True)

In [None]:
# Training Loop
def train_model():
    """Train the global `model` for EPOCHS epochs, validating after each one.

    Uses the module-level `model`, `optimizer`, `train_loader`, `test_loader`,
    `device` and `writer`. For every epoch it:
      * runs one optimization pass over `train_loader` with PerceptualLoss,
      * evaluates on `test_loader` without gradients,
      * appends the summed batch losses to `tr_loss_log` / `val_loss_log`,
      * saves the weights via save_checkpoint_best when the validation loss
        is the lowest seen so far,
      * prints a summary line and logs both losses to TensorBoard.

    The reported losses are sums over batches, not averages.
    """
    least_val_loss = math.inf
    for epoch in range(EPOCHS):
        beg_time = time.time() #To calculate time taken for each epoch

        train_loss = 0.0
        val_loss = 0.0

        # Enforce training mode at the start of every epoch, even if the model
        # was left in eval mode before this function was called.
        model.train()
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            # Forward pass
            out = model(x)
            #Calculating loss
            loss = PerceptualLoss(out, y)
            # Backward pass
            loss.backward()
            # Update weights
            optimizer.step()
            # Get training loss
            train_loss += loss.item()
        tr_loss_log.append(train_loss)

        model.eval()
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                out = model(x)
                #Calculating loss
                loss = PerceptualLoss(out, y)
                # Get validation loss
                val_loss += loss.item()
        val_loss_log.append(val_loss)
        # Leave the model in training mode, as callers found it before.
        model.train()

        #scheduler.step(val_loss)

        # Saving checkpoints
        #save_checkpoint(epoch+1, model, optimizer, val_loss)
        if(val_loss < least_val_loss):
            save_checkpoint_best(epoch+1, model)
            least_val_loss = val_loss

        end_time = time.time()
        print('Epoch: {:.0f}/{:.0f}, Time: {:.0f}m {:.0f}s, Train_Loss: {:.4f}, Val_loss: {:.4f}'.format(
              epoch+1, EPOCHS, (end_time-beg_time)//60, (end_time-beg_time)%60, train_loss, val_loss))
        writer.add_scalar('Training_loss', train_loss, epoch+1)
        writer.add_scalar('Validation_loss', val_loss, epoch+1)
        # Flush pending events to disk; closing and re-opening the writer on
        # every epoch is unnecessary.
        writer.flush()

In [None]:
# Run the full training loop (EPOCHS epochs).
train_model()

In [None]:
%tensorboard --logdir=runs/resunet50decoder

In [None]:
# Saving the final model
PATH = "/workspace/data/Dhruv/pytorch/SuperResolution/FinalModel/final_trained_model.pt"
print("Saving final model")
# torch.save does not create missing folders, so make sure the target exists.
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(model.state_dict(), PATH)

In [None]:
# To load the best model (fill in the best model epoch number)
loaded_best_model = ResUNet(3).to(device)
# map_location remaps the stored tensors onto `device`, so the file also loads
# on a machine where the GPU it was saved from is unavailable.
checkpoint = torch.load("/workspace/data/Dhruv/pytorch/SuperResolution/BestModel/best_model_23.pt",
                        map_location=device)
loaded_best_model.load_state_dict(checkpoint)
loaded_best_model.eval()
model = loaded_best_model

In [None]:

# Loading a model with desired epoch number
# (expects a checkpoint dict in the format written by save_checkpoint)
loaded_model = ResUNet(3).to(device)
# map_location remaps the stored tensors onto `device`, so the file also loads
# on a machine where the GPU it was saved from is unavailable.
checkpoint = torch.load("/workspace/data/Dhruv/pytorch/SuperResolution/Models/model_8.pt",
                        map_location=device)
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_model.eval()
model = loaded_model


## Inference

In [None]:
model.eval()
example = iter(test_loader)
# Use the built-in next(): iterators have no .next() method in Python 3 and
# recent DataLoader iterators no longer provide one.
example_data, example_target = next(example)

# Compare sample 15 of the batch: input, bilinear baseline, prediction, target.
plt.figure(figsize=(15,15))
# Downsampled
plt.subplot(2,2,1)
plt.title('Downsampled')
plt.imshow(example_data[15].permute(1,2,0))
# Bi-linear upsampled
out_1 = F.interpolate(example_data, scale_factor=2, mode='bilinear', align_corners=True)
plt.subplot(2,2,2)
plt.title('Bi-linear upsampled')
plt.imshow(out_1[15].permute(1,2,0))
# Model prediction (no autograd graph is needed for inference)
with torch.no_grad():
    out_2 = model(example_data.to(device))
plt.subplot(2,2,3)
plt.title('Prediction')
plt.imshow(out_2[15].cpu().detach().permute(1,2,0))
# Ground truth
plt.subplot(2,2,4)
plt.title('Ground truth')
plt.imshow(example_target[15].permute(1,2,0))

## Visualizing activations

In [None]:
# Visualize feature maps
activation = dict()
def get_activation(name):
    """Build a forward hook that stores a module's detached output in `activation[name]`."""
    def _store_output(_module, _inputs, output):
        activation[name] = output.detach()
    return _store_output

In [None]:
# Show the input of test sample 16 (channel axis moved last for imshow).
data, _ = test_dataset[16]
plt.imshow(data.permute(1,2,0))

In [None]:
# Find activations for all the top-level sub-modules of the model
hook_handles = [layer.register_forward_hook(get_activation(name))
                for name, layer in model._modules.items()]
data, _ = test_dataset[16]
# Add the batch dimension without mutating the tensor in place.
data = data.unsqueeze(0)
try:
    # No autograd graph is needed just to record activations.
    with torch.no_grad():
        output = model(data.to(device))
finally:
    # Remove the hooks so that re-running this cell does not stack duplicates.
    for handle in hook_handles:
        handle.remove()

In [None]:
# Plot every feature map recorded for the 'up5' sub-module, one figure each.
act = activation['up5'].squeeze().cpu()
for feature_map in act:
    plt.imshow(feature_map)
    plt.show()