In [None]:
import skimage
import sklearn.feature_extraction
import matplotlib.pyplot as plt
import numpy as np
import os
import math
import time
from skimage import io
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torchvision

from torch.optim import lr_scheduler

%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter
from fastai.layers import PixelShuffle_ICNR
from skimage.transform import rotate

In [None]:
# TensorBoard writer for this run; event files are written under
# runs/x2_attention/64-128-SA-avgpool4.
writer = SummaryWriter('runs/x2_attention/64-128-SA-avgpool4')

In [None]:
# Use the CUDA device with index 1 when CUDA is available, otherwise the CPU.
# NOTE(review): 'cuda:1' presumably needs at least two visible GPUs - confirm.
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

In [None]:
# Build the training set: every PNG under <dir>/Train is resized to 128x128 and
# stored together with its 90/180/270-degree rotations (4 samples per file).
dir = '/workspace/data/Dhruv/pytorch/SuperResolution/Data'
# Sort the listing so the sample order does not depend on the filesystem.
train_files = sorted(f for f in os.listdir(dir + '/' + 'Train') if f.endswith(".png"))
# Size the buffer from the number of files found, so extra files cannot
# overflow it and missing files do not leave all-zero images in the set.
train_images = np.zeros((len(train_files)*4, 128, 128, 3))
i = 0
for f in train_files:
    train_images[i] = skimage.transform.resize(
        io.imread(dir + '/Train/' + f), (128, 128), mode='constant')
    # The rotations are taken from the already resized (square) image.
    train_images[i+1] = rotate(train_images[i], angle=90)
    train_images[i+2] = rotate(train_images[i], angle=180)
    train_images[i+3] = rotate(train_images[i], angle=270)
    i += 4

In [None]:
# Load the Flickr2K PNGs at 512x512: the first n_train1 files (in sorted order)
# go to train_images1, the next n_test1 files go to test_images1.
dir = '/workspace/data/Dhruv/pytorch/SuperResolution/Flickr2K'
n_train1 = 2400
n_test1 = 250
train_images1 = np.zeros((n_train1, 512, 512, 3))
test_images1 = np.zeros((n_test1, 512, 512, 3))
# Sorted so the train/test split is reproducible across machines.
flickr_files = sorted(f for f in os.listdir(dir) if f.endswith(".png"))
i = 0
# Cap the listing at the buffer capacity so additional files in the folder
# cannot index past the end of test_images1.
for f in flickr_files[:n_train1 + n_test1]:
    resized = skimage.transform.resize(
        skimage.io.imread(dir + '/' + f), (512, 512), mode='constant')
    if i < n_train1:
        train_images1[i] = resized
    else:
        test_images1[i - n_train1] = resized
    i += 1

In [None]:
def patchExtract(images, patch_size=(128, 128), max_patches=2):
    """Extract patches from a batch of images with sklearn's PatchExtractor.

    Args:
        images: image batch passed straight to PatchExtractor.
        patch_size: (height, width) of each patch; defaults to 128x128.
        max_patches: forwarded to PatchExtractor as the per-image patch
            limit; defaults to 2.

    Returns:
        Whatever PatchExtractor.transform returns for `images`.
    """
    extractor = sklearn.feature_extraction.image.PatchExtractor(
        patch_size=patch_size, max_patches=max_patches)
    # fit() is kept so the call sequence matches the estimator API.
    extractor.fit(images)
    return extractor.transform(images)

In [None]:
# Build the validation set: every PNG under <dir>/Validation is resized to
# 128x128 and stored together with its 90/180/270-degree rotations.
dir = '/workspace/data/Dhruv/pytorch/SuperResolution/Data'
# Sort the listing so the sample order does not depend on the filesystem.
val_files = sorted(f for f in os.listdir(dir + '/' + 'Validation') if f.endswith(".png"))
# Size the buffer from the number of files found, so extra files cannot
# overflow it and missing files do not leave all-zero images in the set.
test_images = np.zeros((len(val_files)*4, 128, 128, 3))
i = 0
for f in val_files:
    test_images[i] = skimage.transform.resize(
        io.imread(dir + '/Validation/' + f), (128, 128), mode='constant')
    # The rotations are taken from the already resized (square) image.
    test_images[i+1] = rotate(test_images[i], angle=90)
    test_images[i+2] = rotate(test_images[i], angle=180)
    test_images[i+3] = rotate(test_images[i], angle=270)
    i += 4

In [None]:
# Optional steps, currently disabled: append the Flickr2K arrays and/or cut
# the images into patches with patchExtract.
#train_images = np.concatenate((train_images, train_images1), axis=0)
#test_images = np.concatenate((test_images, test_images1), axis=0)
#train_images = patchExtract(train_images)
#test_images = patchExtract(test_images)
# Shuffle both arrays in place along the first (sample) axis.
# NOTE(review): no random seed is set in the visible cells, so the order
# differs between runs.
np.random.shuffle(train_images)
np.random.shuffle(test_images)

In [None]:
# Sanity check: array shapes are (num_samples, height, width, channels).
print(train_images.shape)
print(test_images.shape)

In [None]:
def bicubicDownsample(images, scale_factor=0.5):
    """Rescale a batch of images with bicubic interpolation.

    Args:
        images: tensor accepted by torch.nn.functional.interpolate in
            'bicubic' mode (i.e. a 4-D batch).
        scale_factor: spatial scale multiplier; the default 0.5 halves
            height and width.

    Returns:
        The interpolated tensor (align_corners=False).
    """
    return F.interpolate(
        images,
        scale_factor=scale_factor,
        mode='bicubic',
        align_corners=False,
    )

In [None]:
# High-resolution targets: move the channel axis from last to second
# position (N,H,W,C -> N,C,H,W) and cast to float32.
y_tr = torch.from_numpy(train_images).permute(0,3,1,2)
y_tr = y_tr.float()
y_te = torch.from_numpy(test_images).permute(0,3,1,2)
y_te = y_te.float()

In [None]:
# Drop the numpy names; only the tensors above are used from here on.
del train_images
del test_images

In [None]:
# Low-resolution inputs: bicubic downsample of the targets with the default
# scale factor of bicubicDownsample.
x_tr = bicubicDownsample(y_tr)
x_tr = x_tr.float()
x_te = bicubicDownsample(y_te)
x_te = x_te.float()

In [None]:
# Make the (permuted) targets contiguous in memory, then print the
# contiguity flag of all four tensors as a check.
y_tr = y_tr.contiguous()
y_te = y_te.contiguous()
print(x_tr.is_contiguous())
print(x_te.is_contiguous())
print(y_tr.is_contiguous())
print(y_te.is_contiguous())

In [None]:
# Creating custom training dataset
class TrainDataset(Dataset):
    """Paired (low-resolution, high-resolution) training samples.

    Args:
        transform: optional callable applied separately to the input and to
            the target of every sample.
        x: input tensor indexed along its first axis; defaults to the
            notebook-level `x_tr`.
        y: target tensor indexed along its first axis; defaults to the
            notebook-level `y_tr`.

    Raises:
        ValueError: if `x` and `y` hold a different number of samples.
    """

    def __init__(self, transform=None, x=None, y=None):
        # Fall back to the notebook globals only when no data is passed in,
        # so existing `TrainDataset()` calls keep working.
        self.x = x_tr if x is None else x
        self.y = y_tr if y is None else y
        if self.x.shape[0] != self.y.shape[0]:
            raise ValueError("x and y must contain the same number of samples")
        self.n_samples = self.x.shape[0]
        self.transform = transform

    def __getitem__(self, index):
        """Return the (input, target) pair at `index`, transformed if requested."""
        if self.transform:
            return self.transform(self.x[index]), self.transform(self.y[index])
        else:
            return self.x[index], self.y[index]

    def __len__(self):
        """Number of samples in the dataset."""
        return self.n_samples
    
# Creating custom testing dataset
class TestDataset(Dataset):
    """Paired (low-resolution, high-resolution) validation samples.

    Args:
        transform: optional callable applied separately to the input and to
            the target of every sample.
        x: input tensor indexed along its first axis; defaults to the
            notebook-level `x_te`.
        y: target tensor indexed along its first axis; defaults to the
            notebook-level `y_te`.

    Raises:
        ValueError: if `x` and `y` hold a different number of samples.
    """

    def __init__(self, transform=None, x=None, y=None):
        # Fall back to the notebook globals only when no data is passed in,
        # so existing `TestDataset()` calls keep working.
        self.x = x_te if x is None else x
        self.y = y_te if y is None else y
        if self.x.shape[0] != self.y.shape[0]:
            raise ValueError("x and y must contain the same number of samples")
        self.n_samples = self.x.shape[0]
        self.transform = transform

    def __getitem__(self, index):
        """Return the (input, target) pair at `index`, transformed if requested."""
        if self.transform:
            return self.transform(self.x[index]), self.transform(self.y[index])
        else:
            return self.x[index], self.y[index]

    def __len__(self):
        """Number of samples in the dataset."""
        return self.n_samples

In [None]:
# Number of samples per mini-batch.
batch_size = 1

In [None]:
# Dataset objects wrapping the prepared input/target tensors.
train_dataset = TrainDataset()
test_dataset = TestDataset()

# Implementing train loader to split the data into batches
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True, # data reshuffled at every epoch
                          num_workers=2) # Use several subprocesses to load the data

# Implementing test loader to split the validation data into batches
# NOTE(review): shuffle=True also randomises which sample the inference cells
# pull from this loader - confirm that is intended.
test_loader = DataLoader(dataset=test_dataset,
                          batch_size=batch_size,
                          shuffle=True, # data reshuffled at every epoch
                          num_workers=2) # Use several subprocesses to load the data

In [None]:
# Number of passes over the training set.
EPOCHS = 100
n_samples = len(train_dataset)
# Batches per epoch (last batch may be partial, hence ceil).
n_iterations = math.ceil(n_samples/batch_size)

## Creating Model

In [None]:
from torch.nn.utils import spectral_norm

def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    """Return an nn.Conv2d wrapped in spectral normalization.

    All arguments are forwarded unchanged to nn.Conv2d.
    """
    return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias))

class Self_Attn(nn.Module):
    """Self-attention layer whose key/value maps are average-pooled by 4.

    Queries are computed at full resolution; keys (`phi`) and values (`g`)
    are pooled with a 4x4 window before the attention product, so the
    attention matrix has H*W rows and (H//4)*(W//4) columns. The result is
    added to the input scaled by the learnable `sigma`, which starts at 0
    (the layer is an identity mapping at initialisation).
    """

    def __init__(self, in_channels):
        super(Self_Attn, self).__init__()
        self.in_channels = in_channels
        # Width of the projected query/key/value maps.
        if(in_channels>64):
            self.out_dim = 16
        else:
            self.out_dim = 8
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=self.out_dim, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=self.out_dim, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=self.out_dim, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_attn = snconv2d(in_channels=self.out_dim, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
        # The attribute keeps its historical name, but this is average pooling.
        self.maxpool = nn.AvgPool2d(4, stride=4, padding=0)
        self.softmax  = nn.Softmax(dim=-1)
        self.sigma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply pooled self-attention.

        Args:
            x: feature maps of shape (B, C, H, W); H and W must be at least
                4 so the 4x4 pooling produces an output.

        Returns:
            Tensor of the same shape as `x`: `x + sigma * attention(x)`.
        """
        bs, ch, h, w = x.size()
        # Theta path (queries): (B, out_dim, H*W)
        theta = self.snconv1x1_theta(x)
        theta = theta.view(bs, self.out_dim, h*w)
        # Phi path (keys). Flatten with -1 rather than h*w//16: the pooled map
        # has (h//4)*(w//4) positions, which differs from h*w//16 whenever h
        # or w is not a multiple of 4.
        phi = self.snconv1x1_phi(x)
        phi = self.maxpool(phi)
        phi = phi.view(bs, self.out_dim, -1)
        # Attention map: softmax over the pooled positions, (B, H*W, M)
        attn = torch.bmm(theta.permute(0, 2, 1), phi)
        attn = self.softmax(attn)
        # g path (values): (B, out_dim, M)
        g = self.snconv1x1_g(x)
        g = self.maxpool(g)
        g = g.view(bs, self.out_dim, -1)
        # Attention-weighted values, reshaped back to the input resolution
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
        attn_g = attn_g.view(bs, self.out_dim, h, w)
        attn_g = self.snconv1x1_attn(attn_g)
        # Residual connection gated by sigma
        out = x + self.sigma*attn_g
        return out


In [None]:
'''
class Self_Attn(nn.Module):
    """ Self attention Layer"""
    def __init__(self,in_dim):
        super(Self_Attn,self).__init__()
        self.chanel_in = in_dim
        #self.activation = activation
        
        if(in_dim>64):
            out_dim = 16
        else:
            out_dim = 8
        
        self.query_conv = snconv2d(in_channels = in_dim , out_channels = out_dim , kernel_size= 1)
        self.key_conv = snconv2d(in_channels = in_dim , out_channels = out_dim , kernel_size= 1)
        self.value_conv = snconv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax  = nn.Softmax(dim=-1)
        
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        m_batchsize,C,width ,height = x.size()
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)
        energy =  torch.bmm(proj_query,proj_key) # transpose check
        attention = self.softmax(energy) # BX (N) X (N) 
        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N

        out = torch.bmm(proj_value,attention.permute(0,2,1) )
        out = out.view(m_batchsize,C,width,height)
        
        out = self.gamma*out + x
        return out
'''

In [None]:
## Channel Attention (CA) Layer
class CALayer(nn.Module):
    """Channel attention: rescale each channel by a learned gate in (0, 1).

    The gate is computed from the globally average-pooled features by two
    spectrally-normalised 1x1 convolutions (squeeze, then expand) followed
    by a sigmoid. The squeeze ratio is 16 for more than 64 channels and 8
    otherwise.
    """

    def __init__(self, channel):
        super(CALayer, self).__init__()
        reduction = 16 if channel > 64 else 8
        squeezed = channel // reduction
        # Global average pooling: one value per channel.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Channel squeeze -> ReLU -> channel expand -> sigmoid gate.
        self.conv_du = nn.Sequential(
            snconv2d(channel, squeezed, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            snconv2d(squeezed, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return `x` multiplied channel-wise by its attention weights."""
        channel_weights = self.conv_du(self.avg_pool(x))
        return x * channel_weights


class DecoderBlockAttn(nn.Module):
    """Three-conv residual block followed by self- and channel-attention.

    Attribute names are historical: `bn*` are InstanceNorm2d layers and
    `lrelu` is a plain ReLU. Layers are created in the same order as before
    so state_dict keys and random initialisation are unchanged.
    """

    def __init__(self, nf, bias=True):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
        self.bn1 = nn.InstanceNorm2d(nf)
        self.conv2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
        self.bn2 = nn.InstanceNorm2d(nf)
        self.conv3 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
        self.bn3 = nn.InstanceNorm2d(nf)
        self.lrelu = nn.ReLU(inplace=True)
        self.att_sa = Self_Attn(nf)
        self.att_ca = CALayer(nf)

        # Kaiming-uniform re-initialisation of the three conv weights.
        for conv in (self.conv1, self.conv2, self.conv3):
            nn.init.kaiming_uniform_(conv.weight, nonlinearity='relu')

    def forward(self, x):
        """Run the conv stack, then self-attention, then channel attention."""
        act = self.lrelu
        x1 = self.bn1(act(self.conv1(x)))
        running = x + x1
        x2 = self.bn2(act(self.conv2(running)))
        x3 = self.bn3(act(self.conv3(running + x2)))
        fused = x + x3
        spatial = self.att_sa(fused)
        gated = self.att_ca(fused + spatial)
        return gated + x

class DecoderBlock(nn.Module):
    """Residual block of three 3x3 convolutions with cumulative skip sums.

    Each stage is conv -> ReLU -> InstanceNorm; stage k receives the block
    input plus the outputs of all earlier stages. Attribute names are
    historical: `bn*` are InstanceNorm2d layers and `lrelu` is a plain ReLU.
    """

    def __init__(self, nf, bias=True):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
        self.bn1 = nn.InstanceNorm2d(nf)
        self.conv2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
        self.bn2 = nn.InstanceNorm2d(nf)
        self.conv3 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
        self.bn3 = nn.InstanceNorm2d(nf)
        self.lrelu = nn.ReLU(inplace=True)

        # Kaiming-uniform re-initialisation of the three conv weights.
        for conv in (self.conv1, self.conv2, self.conv3):
            nn.init.kaiming_uniform_(conv.weight, nonlinearity='relu')

    def forward(self, x):
        """Return the third stage output plus the block input."""
        act = self.lrelu
        x1 = self.bn1(act(self.conv1(x)))
        running = x + x1
        x2 = self.bn2(act(self.conv2(running)))
        x3 = self.bn3(act(self.conv3(running + x2)))
        return x3 + x

    
class DecoderLayer(nn.Module):
    """Decoder stage: upsample x2, concatenate the skip input, refine and
    reduce the channel count from `in_channels` to `out_channels`.

    `n_blocks` is accepted for interface compatibility but is not used; the
    stage always contains the same fixed sequence of blocks.
    """

    def __init__(self, in_channels, out_channels, n_blocks):
        super().__init__()
        # Conv doubles the channels, PixelShuffle(2) trades 4x channels for
        # 2x resolution, so the upsampled map has in_channels/2 channels.
        expand = nn.Conv2d(in_channels, in_channels*2, kernel_size=3, stride=1, padding=1, bias=False)
        self.upsample = nn.Sequential(expand, nn.PixelShuffle(2), nn.LeakyReLU(0.2, inplace=True))
        self.reduce_channels = nn.Conv2d(in_channels, out_channels, kernel_size = 1)
        self.blocks = self.make_layer(in_channels, out_channels, n_blocks)

    def make_layer(self, in_channels, out_channels, n_blocks):
        """Assemble the refinement stack for this stage.

        The 1x1 `reduce_channels` conv is the same module object that is
        stored on the layer, so it appears under both names in state_dict.
        """
        return nn.Sequential(
            DecoderBlock(nf=in_channels),
            self.reduce_channels,
            DecoderBlock(nf=out_channels),
            DecoderBlockAttn(nf=out_channels),
        )

    def forward(self, x1, x2):
        """Upsample `x1`, concatenate it after `x2` on the channel axis and
        run the refinement stack."""
        merged = torch.cat((x2, self.upsample(x1)), dim=1)
        return self.blocks(merged)

class DecoderLayerNoReduce(nn.Module):
    """Decoder stage that upsamples x2, concatenates the skip input and
    refines it while keeping `in_channels` channels.

    `n_blocks` is accepted for interface compatibility but is not used; the
    stage always contains the same fixed sequence of blocks.
    """

    def __init__(self, in_channels, n_blocks):
        super().__init__()
        # Conv doubles the channels, PixelShuffle(2) trades 4x channels for
        # 2x resolution, so the upsampled map has in_channels/2 channels.
        expand = nn.Conv2d(in_channels, in_channels*2, kernel_size=3, stride=1, padding=1, bias=False)
        self.upsample = nn.Sequential(expand, nn.PixelShuffle(2), nn.LeakyReLU(0.2, inplace=True))
        self.blocks = self.make_layer(in_channels, n_blocks)

    def make_layer(self, in_channels, n_blocks):
        """Assemble the refinement stack: two plain blocks, one attention block."""
        return nn.Sequential(
            DecoderBlock(nf=in_channels),
            DecoderBlock(nf=in_channels),
            DecoderBlockAttn(nf=in_channels),
        )

    def forward(self, x1, x2):
        """Upsample `x1`, concatenate it after `x2` on the channel axis and
        run the refinement stack."""
        merged = torch.cat((x2, self.upsample(x1)), dim=1)
        return self.blocks(merged)

'''
class FinalDecoderLayer(nn.Module):
    def __init__(self, in_channels, n_blocks):
        super().__init__()
        self.upsample = PixelShuffle_ICNR(in_channels, in_channels, scale=2)
        self.blocks = self.make_layer(in_channels)
    
    def make_layer(self, in_channels):
        Blocks = []
        Blocks.append(DecoderBlock(nf=in_channels))
        Blocks.append(DecoderBlock(nf=in_channels))
        Blocks.append(DecoderBlockAttn(nf=in_channels))
        return nn.Sequential(*Blocks)
            
    def forward(self, x):
        x = self.upsample(x)
        x = self.blocks(x)
        return x
'''   
class ResUNet(nn.Module):
    """U-Net style x2 super-resolution network on a ResNet-34 encoder.

    The encoder reuses torchvision's pretrained ResNet-34 stem and its four
    residual stages. The decoder upsamples step by step, concatenating the
    matching encoder features, then a 1x1 projection of the input image and
    finally a 1x1 projection of the bicubically upsampled input.

    Args:
        out_channels: number of channels of the predicted image.
    """

    def __init__(self, out_channels):
        super().__init__()

        # NOTE: storing the full ResNet as an attribute registers it as a
        # submodule, so encoder weights also appear under `base_model.*` in
        # the state_dict. Left as is so existing checkpoints keep loading.
        self.base_model = torchvision.models.resnet34(pretrained=True, progress=False)
        self.base_layers = list(self.base_model.children())

        # Encoder path. Children 0-2 form the stem; child 3 is skipped and
        # children 4-7 are the four residual stages.
        self.in_layer = nn.Sequential(*self.base_layers[0:3])
        self.layer1 = nn.Sequential(*self.base_layers[4])
        self.layer2 = nn.Sequential(*self.base_layers[5])
        self.layer3 = nn.Sequential(*self.base_layers[6])
        self.layer4 = nn.Sequential(*self.base_layers[7])

        # Cross path: 1x1 projections of the raw input (and of its upsampled
        # version) used as skip inputs for the last two decoder stages.
        self.cross = nn.Conv2d(3, 32 ,kernel_size=1)
        self.cross_upsample = nn.Conv2d(3, 32 ,kernel_size=1)

        # Decoder path. The last argument (n_blocks) is currently ignored by
        # the decoder layers; the original notes on intended values are kept.
        self.up1 = DecoderLayer(512, 256, 3)
        self.up2 = DecoderLayer(256, 128, 3) # should be 6
        self.up3 = DecoderLayer(128, 64, 3) # should be 4
        self.up4 = DecoderLayerNoReduce(64, 3)
        self.up5 = DecoderLayerNoReduce(64, 3)

        # Honour `out_channels` instead of a hard-coded 3; identical for the
        # existing ResUNet(3) callers and checkpoints.
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=9, padding=4)


    def forward(self, x):
        """Predict the x2 upscaled image for a batch `x` of 3-channel images.

        NOTE(review): spatial size presumably has to be divisible by 16 so
        that encoder and decoder feature maps line up - confirm.
        """
        #Encoder path
        x_inp = x
        x_in = self.in_layer(x)
        x_l1 = self.layer1(x_in)
        x_l2 = self.layer2(x_l1)
        x_l3 = self.layer3(x_l2)
        x_l4 = self.layer4(x_l3)

        # Decoder path
        x = self.up1(x_l4, x_l3)
        x = self.up2(x, x_l2)
        x = self.up3(x, x_l1)
        # Bicubic x2 version of the input, used as the skip of the last stage.
        x_upsample = F.interpolate(x_inp, scale_factor=2, mode='bicubic', align_corners=False)
        x_inp = self.cross(x_inp)
        x = self.up4(x, x_inp)
        x_upsample = self.cross_upsample(x_upsample)
        x = self.up5(x, x_upsample)
        x = self.final_conv(x)
        return x

## Creating Loss Function (Perceptual Loss)

In [None]:
class VGGPerceptualLoss(nn.Module):
    """Frozen VGG-16 feature extractor split into four consecutive stages.

    The stages cover `features[0:4]`, `[4:9]`, `[9:16]` and `[16:23]` of
    torchvision's pretrained VGG-16 and are exposed as `relu1_2`, `relu2_2`,
    `relu3_3` and `relu4_3`. All parameters have requires_grad=False.
    """

    def __init__(self):
        super().__init__()
        features = torchvision.models.vgg16(pretrained=True, progress=False).features
        self.relu1_2 = nn.Sequential()
        self.relu2_2 = nn.Sequential()
        self.relu3_3 = nn.Sequential()
        self.relu4_3 = nn.Sequential()
        # (stage attribute, first feature index, one-past-last feature index)
        stage_ranges = (
            ("relu1_2", 0, 4),
            ("relu2_2", 4, 9),
            ("relu3_3", 9, 16),
            ("relu4_3", 16, 23),
        )
        for stage_name, start, stop in stage_ranges:
            stage = getattr(self, stage_name)
            # Sub-modules are named <stage>_<1-based position in the stage>.
            for position, idx in enumerate(range(start, stop), start=1):
                stage.add_module(name=stage_name + "_" + str(position), module=features[idx])
        # Setting requires_grad=False to fix the perceptual loss model parameters
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the outputs of the four stages, shallowest first."""
        outputs = []
        for stage in (self.relu1_2, self.relu2_2, self.relu3_3, self.relu4_3):
            x = stage(x)
            outputs.append(x)
        return tuple(outputs)
    
def gram(x):
    """Per-sample Gram matrix of a (N, C, H, W) feature batch.

    Each sample's features are flattened to (C, H*W) and multiplied by their
    own transpose; the result is divided by C*H*W. torch.bmm works sample by
    sample, so no division by the batch size N is involved.

    Returns:
        Tensor of shape (N, C, C).
    """
    n, c, h, w = x.shape
    flat = x.view(n, c, h*w)
    return torch.bmm(flat, flat.transpose(1, 2)) / (c*h*w)

def TVR(x, TV_WEIGHT=1e-6):
    """Total-variation regulariser of a 4-D batch.

    Sums the absolute differences between neighbouring elements along the
    last axis and along the second-to-last axis over the whole batch, and
    scales the total by `TV_WEIGHT`.
    """
    along_last = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().sum()
    along_rows = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().sum()
    return TV_WEIGHT*(along_last + along_rows)

In [None]:
# Single shared feature extractor for the loss, moved to the training device.
VGGLoss = VGGPerceptualLoss().to(device)

In [None]:
def PerceptualLoss(x, y, STYLE_WEIGHT=1e-3, CONTENT_WEIGHT=1e-1, PIXEL_WEIGHT=1):
    """Weighted sum of pixel, content, style and total-variation terms.

    Args:
        x: predicted image batch.
        y: target image batch of the same shape.
        STYLE_WEIGHT: weight of the Gram-matrix (style) term.
        CONTENT_WEIGHT: weight of the VGG feature (content) term.
        PIXEL_WEIGHT: weight of the per-pixel L1 term.

    Returns:
        Scalar tensor with the total loss. L1 terms are summed over the whole
        batch and divided by the element count of one sample (C*H*W), not by
        the batch size.
    """
    x_features = VGGLoss(x)
    y_features = VGGLoss(y)

    def _elements_per_sample(t):
        # C*H*W of a (N, C, H, W) tensor.
        return t.shape[1] * t.shape[2] * t.shape[3]

    # Per-pixel loss on the images themselves.
    pixel_loss = F.l1_loss(x, y, reduction='sum') / _elements_per_sample(y)

    # Total variation regularization of the prediction.
    tvr_loss = TVR(x)

    # Content loss: L1 distance on the two deepest feature stages.
    content_loss = 0.0
    for y_feat, x_feat in zip(y_features[2:4], x_features[2:4]):
        content_loss += F.l1_loss(y_feat, x_feat, reduction='sum') / _elements_per_sample(y_feat)

    # Style loss: L1 distance between Gram matrices of all four stages.
    style_loss = 0.0
    for x_feat, y_feat in zip(x_features[:4], y_features[:4]):
        style_loss += F.l1_loss(gram(x_feat), gram(y_feat), reduction='sum')

    total_loss = STYLE_WEIGHT*style_loss + CONTENT_WEIGHT*content_loss + PIXEL_WEIGHT*pixel_loss + tvr_loss
    return total_loss

In [None]:
'''
example = iter(train_loader)
example_data, example_target = example.next()
example_data = F.interpolate(example_data, scale_factor=4, mode='bilinear', align_corners=False)
loss = PerceptualLoss(example_data.to(device), example_target.to(device))
del example
del example_data
del example_target
'''

## Training Loop

In [None]:
# Implementing checkpoints
def save_checkpoint_best(epoch, model):
    """Save the weights of the best model so far.

    Args:
        epoch: epoch number, embedded in the file name.
        model: module whose state_dict is written.
    """
    print("Saving best model")
    PATH = "/workspace/data/Dhruv/pytorch/SuperResolution/BestModel/best_model_"+str(epoch)+".pt"
    # Create the target folder if needed so a missing directory does not
    # abort training after a full epoch.
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    torch.save(model.state_dict(), PATH)

def save_checkpoint(epoch, model, optimizer, loss):  # Saving model in a way so we can load and start training again
    """Save a resumable checkpoint.

    The file holds a dict with the keys 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'val_loss'.

    Args:
        epoch: epoch number, stored and embedded in the file name.
        model: module whose state_dict is written.
        optimizer: optimizer whose state_dict is written.
        loss: validation loss stored under 'val_loss'.
    """
    PATH = "/workspace/data/Dhruv/pytorch/SuperResolution/Models/model_"+str(epoch)+".pt"
    print("Saving model")
    # Create the target folder if needed so a missing directory does not
    # abort training.
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': loss,
            }, PATH)

In [None]:
#del model
# Per-epoch loss histories filled by train_model().
tr_loss_log = []
val_loss_log = []
model = ResUNet(3).to(device)

In [None]:
# Option A: freeze the encoder modules and optimise with a single learning
# rate. NOTE(review): if the next cell is also run it re-enables these
# gradients and replaces this optimizer - confirm which one is intended.
for name, layer in model._modules.items():
    if name in ['in_layer', 'layer1', 'layer2', 'layer3', 'layer4']:
        for param in layer.parameters():
            param.requires_grad = False
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Option B: train everything, with a smaller learning rate for the encoder.
# Parameters are collected into two groups by top-level module name; modules
# with any other name are not added to either group.
encoder=[]
decoder=[]
for name, layer in model._modules.items():
    if name in ['in_layer', 'layer1', 'layer2', 'layer3', 'layer4']:
        print(name)
        for param in layer.parameters():
            param.requires_grad = True
            encoder.append(param)
    elif name in ['cross', 'cross_upsample', 'up1', 'up2', 'up3', 'up4', 'up5', 'final_conv']:
        print(name)
        for param in layer.parameters():
            param.requires_grad = True
            decoder.append(param)

# Decoder group uses the default lr (1e-4); encoder group overrides it (5e-5).
optimizer = torch.optim.Adam([
                {'params': decoder},
                {'params': encoder, 'lr': 0.00005}
            ], lr=0.0001)

In [None]:
# Print the trainable flag of every parameter as a check.
for param in model.parameters():
    print(param.requires_grad)

In [None]:
# Log the model graph to TensorBoard using one training batch as example input.
example = iter(train_loader)
# Use the built-in next(): the iterator protocol only guarantees __next__,
# and DataLoader iterators do not provide a .next() method on current PyTorch.
example_data, example_target = next(example)
writer.add_graph(model, example_data.to(device))
writer.close()
del example
del example_data
del example_target

In [None]:
# Training Loop
def train_model():
    """Train the notebook-level `model` for EPOCHS epochs.

    Uses the globals `model`, `optimizer`, `train_loader`, `test_loader`,
    `device`, `writer`, `tr_loss_log` and `val_loss_log`. After every epoch
    the model is evaluated on `test_loader`; the summed (not averaged)
    training and validation losses are appended to the logs, printed and
    written to TensorBoard. The weights are saved with
    save_checkpoint_best() whenever the validation loss improves.
    """
    least_val_loss = math.inf

    # Start in training mode even if an earlier cell left the model in eval().
    model.train()
    try:
        for epoch in range(EPOCHS):

            beg_time = time.time() #To calculate time taken for each epoch

            train_loss = 0.0
            val_loss = 0.0

            for x, y in train_loader:
                x = x.to(device)
                y = y.to(device)
                optimizer.zero_grad()
                # Forward pass
                out = model(x)
                #Calculating loss
                loss = PerceptualLoss(out, y)
                # Backward pass
                loss.backward()
                # Update the weights
                optimizer.step()
                # Accumulate training loss over the epoch
                train_loss += loss.item()
            tr_loss_log.append(train_loss)

            model.eval()
            with torch.no_grad():
                for x, y in test_loader:
                    x = x.to(device)
                    y = y.to(device)
                    out = model(x)
                    #Calculating loss
                    loss = PerceptualLoss(out, y)
                    # Accumulate validation loss over the epoch
                    val_loss += loss.item()
                val_loss_log.append(val_loss)
            model.train()


            # Saving checkpoints
            #save_checkpoint(epoch+1, model, optimizer, val_loss)
            if(val_loss < least_val_loss):
                save_checkpoint_best(epoch+1, model)
                least_val_loss = val_loss

            end_time = time.time()
            print('Epoch: {:.0f}/{:.0f}, Time: {:.0f}m {:.0f}s, Train_Loss: {:.4f}, Val_loss: {:.4f}'.format(
                  epoch+1, EPOCHS, (end_time-beg_time)//60, (end_time-beg_time)%60, train_loss, val_loss))
            writer.add_scalar('Training_loss', train_loss, epoch+1)
            writer.add_scalar('Validation_loss', val_loss, epoch+1)
            # Flush so the scalars are visible in TensorBoard right away
            # without closing the writer in the middle of training.
            writer.flush()
    finally:
        # Close once, also when training is interrupted.
        writer.close()

In [None]:
# Run the full training loop (EPOCHS epochs) with the model/optimizer set up above.
train_model()

In [None]:
%tensorboard --logdir=runs/x2_attention    # To beat 267, 36

In [None]:
# Saving the final model (weights only, as a state_dict)
PATH = "/workspace/data/Dhruv/pytorch/SuperResolution/FinalModel/attn_x2/128-256-avgpool4.pt"
print("Saving final model")
torch.save(model.state_dict(), PATH)

In [None]:
# Saving just model: this pickles the whole module object rather than its
# state_dict, so loading it needs the class definitions from this notebook.
PATH = "/workspace/data/Dhruv/pytorch/SuperResolution/JustModels/final_trained_model_inorm128.pt"
torch.save(model, PATH)

In [None]:
# Loading model
model = ResUNet(3).to(device)
# map_location places the stored tensors on `device`, so a checkpoint saved on
# a different GPU (or on a GPU machine) still loads here.
model.load_state_dict(torch.load("/workspace/data/Dhruv/pytorch/SuperResolution/FinalModel/36433/model-128-256.pt", map_location=device))

In [None]:
# To load the final model (fill in the final model epoch number)
loaded_final_model = ResUNet(3).to(device)
# map_location places the stored tensors on `device`, so a checkpoint saved on
# a different device still loads here.
checkpoint = torch.load("/workspace/data/Dhruv/pytorch/SuperResolution/FinalModel/modified-res-dense-inorm-big128.pt", map_location=device)
loaded_final_model.load_state_dict(checkpoint)
#loaded_final_model.eval()
model = loaded_final_model
del loaded_final_model

In [None]:
# To load the best model (fill in the best model epoch number)
# Drop the current model first; this expects `model` to exist already.
del model
model = ResUNet(3).to(device)
# map_location places the stored tensors on `device`, so a checkpoint saved on
# a different device still loads here.
model.load_state_dict(torch.load("/workspace/data/Dhruv/pytorch/SuperResolution/BestModel/best_model_74.pt", map_location=device)) #3

In [None]:
# Loading a model with desired epoch number
loaded_model = ResUNet(3).to(device)
# The file is a dict as written by save_checkpoint(); map_location places the
# stored tensors on `device`.
checkpoint = torch.load("/workspace/data/Dhruv/pytorch/SuperResolution/Models/model_8.pt", map_location=device)
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_model.eval()
model = loaded_model

## Inference

In [None]:
# Per-channel statistics, shaped (3, 1, 1) to broadcast over a (3, H, W) image.
# NOTE(review): these look like the usual ImageNet normalisation constants.
std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
def unorm(data):
    """Undo mean/std normalisation and convert a tensor to a uint8 image.

    Args:
        data: tensor of shape (3, H, W); may live on any device and may
            require grad.

    Returns:
        numpy uint8 array of shape (H, W, 3) with values clipped to 0..255.
    """
    # detach() + cpu() make the conversion work for GPU tensors and tensors
    # that are part of an autograd graph; the arithmetic below allocates new
    # arrays, so the caller's tensor is never modified.
    img = data.detach().cpu().numpy()
    img = ((img * std + mean).transpose(1, 2, 0)*255.0).clip(0, 255).astype("uint8")
    return img

In [None]:
# Visual check on one test batch: input, bilinear upsampling, model output
# and ground truth side by side.
model.eval()
example = iter(test_loader)
# Use the built-in next(): the iterator protocol only guarantees __next__,
# and DataLoader iterators do not provide a .next() method on current PyTorch.
example_data, example_target = next(example)

plt.figure(figsize=(10,10))
# Downsampled
plt.subplot(2,2,1)
plt.title('Downsampled')
plt.imshow(example_data[0].permute(1,2,0))
# Bi-linear upsampled
out_1 = F.interpolate(example_data, scale_factor=2, mode='bilinear', align_corners=True)
plt.subplot(2,2,2)
plt.title('Bi-linear upsampled')
plt.imshow(out_1[0].permute(1,2,0))
# Model prediction (no gradients needed for inference)
with torch.no_grad():
    out_2 = model(example_data.to(device))
plt.subplot(2,2,3)
plt.title('Prediction')
plt.imshow(out_2[0].cpu().detach().permute(1,2,0))
# Ground truth
plt.subplot(2,2,4)
plt.title('Ground truth')
plt.imshow(example_target[0].permute(1,2,0))

In [None]:
# Side-by-side comparison of two models on the batch fetched by the previous
# cell (`example_data`).
# NOTE(review): `model1` is not defined anywhere in the visible notebook; this
# cell fails on a fresh kernel unless a second model is loaded first.
# NOTE(review): index 15 needs a batch of at least 16 samples, but batch_size
# is set to 1 earlier in the notebook - confirm the intended batch size.
model.eval()
model1.eval()
#example = iter(test_loader)
#example_data, example_target = example.next()

# old is only with 2 in perceptual, new is with 2 and 3 in perceptual

plt.figure(figsize=(20,20))
plt.subplot(1,2,1)
out_2 = model(example_data.to(device))
plt.title('Prediction model')
plt.imshow(out_2[15].cpu().detach().permute(1,2,0))

plt.subplot(1,2,2)
plt.title('Prediction model1')
out_2 = model1(example_data.to(device))
plt.imshow(out_2[15].cpu().detach().permute(1,2,0))

In [None]:
# For comparison: run the model on a single external image and show the result.
import matplotlib
directory = 'compare/butterfly.png'
img = skimage.io.imread(directory)
# Scale 8-bit pixel values to [0, 1].
# NOTE(review): assumes a 3-channel 8-bit image - confirm for files with alpha.
img = img/255.0
#img = skimage.transform.resize(img,(64,64))
img = np.asarray(img)
# Add a batch axis and reorder (1, H, W, C) -> (1, C, H, W) for the model.
img = np.expand_dims(img, 0)
img = torch.from_numpy(img).permute(0,3,1,2).float()
out = model(img.to(device))
# Back to (H, W, C) on the CPU for plotting.
out_img = out[0].cpu().detach().permute(1,2,0).numpy()
plt.imshow(out_img)
#matplotlib.image.imsave('compare/my_butterfly_4x.png', np.clip(out_img,0,1))

## Visualizing activations

In [None]:
# Visualize feature maps
# Outputs captured by the forward hooks, keyed by the name given at registration.
activation = {}
def get_activation(name):
    """Create a forward hook that stores a module's output under `name`.

    The returned callable has the (module, input, output) signature expected
    by register_forward_hook and saves `output.detach()` in the module-level
    `activation` dict on every forward pass.
    """
    def _record(module, inputs, output):
        activation[name] = output.detach()
    return _record

In [None]:
# Show the low-resolution input used for the activation visualisation.
data, _ = test_dataset[32]
plt.imshow(data.permute(1,2,0))

In [None]:
# Find activations for all the layers of the model
# A hook is registered on every top-level submodule of the model.
# NOTE(review): hooks are never removed, so re-running this cell registers
# additional hooks on the same modules.
for name, layer in model._modules.items():
  layer.register_forward_hook(get_activation(name))
data, _ = test_dataset[32]
# Add the batch axis in place, then run one forward pass to fill `activation`.
data.unsqueeze_(0)
output = model(data.to(device))

In [None]:
# Plot every channel of the 'up5' output as its own image.
act = activation['up5'].squeeze().cpu()
for idx in range(act.size(0)):
    plt.imshow(act[idx])
    plt.show()

In [None]:
# Ad-hoc inspection: display the structure of the first encoder stage.
model.layer1

In [None]:
# Fresh pretrained ResNet-34, separate from the one inside `model`.
base_model = torchvision.models.resnet34(pretrained=True, progress=False)

In [None]:
# Gradient of one weight of the fresh ResNet.
# NOTE(review): presumably None, since no backward pass is run on
# `base_model` in the visible cells - confirm.
base_model.layer3[0].conv1.weight.grad

In [None]:
# Display one decoder conv weight of the current model.
model.up3.blocks[0].conv1.weight