In [2]:
import os
import time
import random
from dataclasses import dataclass

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

import torchvision
import torchvision.transforms as T
from torchvision.models import vgg19

### Configuration

In [9]:
@dataclass
class Config:
    """Hyper-parameters and runtime settings used throughout the notebook.

    Every attribute below is a dataclass field, so any of them can be
    overridden per instance, e.g. ``Config(batch_size=16, device="cpu")``.
    """

    data_path: str = ""
    batch_size: int = 32
    pin_memory: bool = True
    num_workers: int = 2
    lr: float = 0.0005
    momentum: float = 0.9
    betas: tuple = (.9, .999)
    seed: int = 42
    # The annotation is what makes ``device`` a real dataclass field (part of
    # ``__init__``, ``repr`` and ``==``); un-annotated it was only a plain class
    # attribute. The default is still resolved once, when the class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    
def set_seed(seed=Config.seed):
    """Seed the global RNGs of ``random``, NumPy and PyTorch.

    Args:
        seed: Integer seed to apply. Defaults to ``Config.seed`` as it was when
            this function was defined. Pass ``None`` to leave every generator
            untouched.
    """
    # Test the argument itself rather than ``Config.seed``: otherwise an
    # explicit ``set_seed(None)`` would still attempt to seed with ``None``,
    # and an explicit seed would be ignored whenever ``Config.seed`` is None.
    if seed is None:
        print("seed was nulltype so no seed set")
        return

    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    print(f"Manual Seed set to {seed}")

In [10]:
# Build the run configuration, then seed every RNG from that instance so that
# an override such as ``Config(seed=123)`` is actually honoured.
config = Config()
set_seed(config.seed)

Manual Seed set to 42


### SRGAN Blocks

In [17]:
class ConvBlock(nn.Module):
    """Convolution, then optional batch norm, then an optional activation.

    The activation module is ``LeakyReLU(0.2)`` when ``discriminator`` is true
    and a per-channel ``PReLU`` otherwise; it is only applied in ``forward``
    when ``use_act`` is true. Any extra keyword arguments (kernel size, stride,
    padding, ...) are forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels,
                 discriminator: bool = False,
                 use_act=True, use_bn=True, **kwargs):
        super().__init__()
        self.use_act = use_act

        # The conv carries its own bias only when no batch norm follows it.
        self.cnn = nn.Conv2d(in_channels, out_channels,
                             **kwargs, bias=(not use_bn))

        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()

        if discriminator:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.PReLU(out_channels)

    def forward(self, x):
        normed = self.bn(self.cnn(x))
        if not self.use_act:
            return normed
        return self.act(normed)
    

class UpsampleBlock(nn.Module):
    """Sub-pixel upsampling stage: 3x3 conv, pixel shuffle by ``sf``, PReLU."""

    def __init__(self, in_channels, sf):
        super().__init__()
        expanded_channels = in_channels * sf ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels,
                              kernel_size=3, stride=1, padding=1)
        # (in_ch * sf^2, H, W) -> (in_ch, H * sf, W * sf)
        self.ps = nn.PixelShuffle(upscale_factor=sf)
        self.act = nn.PReLU(num_parameters=in_channels)

    def forward(self, x):
        expanded = self.conv(x)
        shuffled = self.ps(expanded)
        return self.act(shuffled)


class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock`` layers (the second without activation) plus a skip connection."""

    def __init__(self, in_channels):
        super().__init__()
        # Both convolutions share the same shape-preserving geometry.
        conv_args = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = ConvBlock(in_channels, in_channels, **conv_args)
        self.block2 = ConvBlock(in_channels, in_channels,
                                use_act=False, **conv_args)

    def forward(self,x):
        residual = self.block2(self.block1(x))
        return residual + x
    
class Generator(nn.Module):
    """Super-resolution generator that upscales its input 4x (two 2x stages).

    Pipeline: 9x9 stem conv -> ``num_blocks`` residual blocks -> 3x3 conv with
    a long skip back to the stem output -> two pixel-shuffle upsamplers ->
    9x9 conv followed by ``tanh``.
    """

    def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
        super().__init__()
        self.initial = ConvBlock(in_channels, num_channels, use_bn=False,
                                 kernel_size=9, stride=1, padding=4)

        residual_blocks = [ResidualBlock(num_channels) for _ in range(num_blocks)]
        self.resnet = nn.Sequential(*residual_blocks)

        self.conv = ConvBlock(num_channels, num_channels, use_act=False,
                              kernel_size=3, stride=1, padding=1)

        upsamplers = [UpsampleBlock(num_channels, sf=2) for _ in range(2)]
        self.upsnet = nn.Sequential(*upsamplers)

        self.final = nn.Conv2d(num_channels, in_channels,
                               kernel_size=9, stride=1, padding=4)

    def forward(self, image):
        stem = self.initial(image)
        # Long skip connection: trunk output is added to the stem features.
        trunk = self.conv(self.resnet(stem)) + stem
        upscaled = self.upsnet(trunk)
        return torch.tanh(self.final(upscaled))


class Discriminator(nn.Module):
    """Convolutional critic that maps an image batch to one raw logit per image.

    Args:
        in_channels: Number of channels of the input images.
        features: Output channel count of each successive ``ConvBlock``.
            Blocks at odd positions use stride 2, the others stride 1; the
            first block has no batch norm.

    ``forward`` returns a tensor of shape ``(N, 1)``. No sigmoid is applied.
    """

    def __init__(self, in_channels=3,
                 features=(64, 64, 128, 128, 256, 256, 512, 512)):
        super().__init__()

        # NOTE(review): no ``padding`` is passed, so every 3x3 conv trims the
        # feature map; reference SRGAN implementations typically use
        # ``padding=1`` -- confirm this is intentional.
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(in_channels, feature, discriminator=True,
                          kernel_size=3, stride=1 + idx % 2,
                          use_act=True,
                          use_bn=(idx != 0))
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),  # ensures running for variable i/p size
            nn.Flatten(),
            # ``in_channels`` now holds the channel count leaving the last
            # block, so the head matches any ``features`` sequence instead of
            # assuming the final entry is 512.
            nn.Linear(in_channels * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            # nn.Sigmoid()  ## optional
        )

    def forward(self, x):
        x = self.blocks(x)
        return self.classifier(x)

In [25]:
def test(low_res=128, batch_size=5):
    """Smoke-test the generator and discriminator on random input.

    Pushes a random batch through ``Generator`` and then ``Discriminator`` on
    ``config.device``, checks the resulting shapes and prints them.

    Args:
        low_res: Height and width of the random low-resolution input.
        batch_size: Number of random images in the batch.

    Raises:
        AssertionError: If either network returns an unexpected shape.
    """
    device = torch.device(config.device)

    # Inference only: no autograd graph is needed for a shape check, and
    # autocast is enabled only where it has an effect (CUDA tensors).
    with torch.no_grad(), torch.autocast(device_type=device.type,
                                         enabled=(device.type == "cuda")):
        x = torch.randn((batch_size, 3, low_res, low_res), device=device)
        gen = Generator().to(device)
        gen_out = gen(x)
        disc = Discriminator().to(device)
        disc_out = disc(gen_out)

    # The generator has two 2x upsampling stages, hence the factor of 4.
    assert gen_out.shape == (batch_size, 3, 4 * low_res, 4 * low_res)
    assert disc_out.shape == (batch_size, 1)

    print(gen_out.shape)
    print(disc_out.shape)

In [26]:
# Wall-clock timing of one smoke-test run.
start = time.perf_counter()
test()
elapsed = time.perf_counter() - start
print(f"time taken: {elapsed:.4f}s")

torch.Size([5, 3, 512, 512])
torch.Size([5, 1])
time taken: 5.8180s


In [None]:
# phi_5,4 features: 5th conv layer of VGG19, taken after the activation and
# before the max-pooling layer.
class VGGLoss(nn.Module):
    """MSE between frozen, pretrained VGG19 feature maps of two image batches."""

    def __init__(self):
        super().__init__()
        backbone = vgg19(pretrained=True).features[:36]
        self.vgg = backbone.eval().to(config.device)
        self.loss = nn.MSELoss()

        # Freeze the feature extractor; only the inputs should receive gradients.
        self.vgg.requires_grad_(False)

    def forward(self, input, target):
        input_feats = self.vgg(input)
        target_feats = self.vgg(target)
        return self.loss(input_feats, target_feats)
