<a href="https://colab.research.google.com/github/abhinav9629/JPEGUP/blob/main/Jpeg2SRJpeg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [48]:
# Mount Google Drive into the Colab VM's filesystem under /content/drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [49]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [50]:
class C_Block(nn.Module):
    """Conv2d -> BatchNorm2d (optional) -> activation (optional).

    Args:
        in_ch: input channel count for the convolution.
        out_ch: output channel count for the convolution.
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv2d``.
        disc: when True the activation is ``LeakyReLU(0.2)`` (the variant the
            discriminator passes); otherwise a per-channel ``PReLU``.
        use_activation: when False the activation module is still built but is
            skipped in ``forward``.
        use_batchnorm: when True a ``BatchNorm2d`` follows the conv and the conv
            bias is disabled (BN's learned shift subsumes it); otherwise the
            normalization slot is an ``nn.Identity``.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, disc=False, use_activation=True, use_batchnorm=True):
        super().__init__()
        # Submodules are created in the order conv -> bn -> act so parameter
        # initialisation order and state_dict layout are preserved.
        self.convnet = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,
        )
        if use_batchnorm:
            self.bn = nn.BatchNorm2d(out_ch)
        else:
            self.bn = nn.Identity()
        if disc:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.PReLU(num_parameters=out_ch)
        self.use_activation = use_activation

    def forward(self, input):
        normed = self.bn(self.convnet(input))
        if not self.use_activation:
            return normed
        return self.act(normed)

In [51]:
class US_Block(nn.Module):
    """Sub-pixel upsampling stage: conv -> PixelShuffle(scale) -> PReLU.

    The conv expands the channels to ``in_ch * scale**2``. PixelShuffle then
    folds them back to ``in_ch`` while multiplying height and width by ``scale``.
    """

    def __init__(self, in_ch, scale):
        super().__init__()
        expanded_ch = in_ch * scale ** 2
        self.convnet = nn.Conv2d(in_ch, expanded_ch, kernel_size=3, stride=1, padding=1)
        self.ps = nn.PixelShuffle(scale)
        self.act = nn.PReLU(num_parameters=in_ch)

    def forward(self, input):
        return self.act(self.ps(self.convnet(input)))

In [52]:
class ResidualNet_Block(nn.Module):
    """Residual unit: two 3x3, stride-1, padding-1 C_Blocks plus an identity skip.

    The first C_Block applies its activation and the second does not. The
    second block's output is added to the unit's input, so the channel count
    (``in_ch``) and the spatial size are unchanged.
    """

    def __init__(self, in_ch):
        super().__init__()
        same_size = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = C_Block(in_ch, in_ch, **same_size)
        self.block2 = C_Block(in_ch, in_ch, use_activation=False, **same_size)

    def forward(self, input):
        residual = self.block2(self.block1(input))
        return input + residual

In [53]:
class Generator(nn.Module):
    """SRGAN-style generator that upsamples its input by 4x.

    Pipeline:
        9x9 conv + PReLU (no BatchNorm)
        -> ``num_blocks`` ResidualNet_Blocks
        -> 3x3 conv + BatchNorm, plus a long skip from the first conv
        -> two 2x sub-pixel upsampling stages
        -> 9x9 conv back to ``in_ch`` channels
        -> tanh, so outputs lie in (-1, 1). Presumably training targets are
           normalized to that range; TODO confirm.

    Args:
        in_ch: image channels of both the input and the output.
        num_ch: feature width used throughout the network body.
        num_blocks: number of ResidualNet_Block stages in the trunk.

    Raises:
        ValueError: if ``num_blocks`` is negative.
    """

    def __init__(self, in_ch = 3, num_ch = 64, num_blocks = 16):
        super().__init__()
        if num_blocks < 0:
            raise ValueError(f"num_blocks must be non-negative, got {num_blocks}")
        self.init_convnet = C_Block(in_ch, num_ch, kernel_size=9, stride=1, padding=4, use_batchnorm=False)
        # Build exactly num_blocks residual stages (previously hard-coded to 16
        # and the argument was ignored). With the default of 16, submodule names
        # (residual.0 ... residual.15) and init order are unchanged, so existing
        # checkpoints still load.
        self.residual = nn.Sequential(*[ResidualNet_Block(num_ch) for _ in range(num_blocks)])
        self.convnet = C_Block(num_ch, num_ch, kernel_size=3, stride=1, padding=1, use_activation=False)
        # Two 2x PixelShuffle stages -> overall 4x spatial upscaling.
        self.upsamples = nn.Sequential(
            US_Block(num_ch, scale=2),
            US_Block(num_ch, scale=2),
        )
        self.final_convnet = nn.Conv2d(num_ch, in_ch, kernel_size=9, stride=1, padding=4)

    def forward(self, input):
        initial = self.init_convnet(input)
        x = self.residual(initial)
        # Long skip connection around the whole residual trunk.
        x = self.convnet(x) + initial
        x = self.upsamples(x)
        out = self.final_convnet(x)
        return torch.tanh(out)


In [54]:
class Discriminator(nn.Module):
    """SRGAN-style discriminator that returns one unnormalized score per image.

    Eight 3x3 LeakyReLU(0.2) C_Blocks alternate stride 1 and stride 2 while
    widening from ``features`` to ``8 * features`` channels. Every stage except
    the first uses BatchNorm. An adaptive average pool to 6x6 gives the dense
    head a fixed input size whatever the input resolution.

    Args:
        in_ch: number of input image channels.
        features: channel width of the first conv stage.

    ``forward`` returns a tensor of shape (batch, 1). No sigmoid is applied,
    so the output is presumably meant for a logits-based loss such as
    ``BCEWithLogitsLoss``. Confirm against the training code.
    """

    def __init__(self, in_ch = 3, features = 64):
        super().__init__()
        # Channel multipliers for the eight conv stages. Even-indexed stages keep
        # the resolution (stride 1); odd-indexed stages halve it (stride 2).
        multipliers = (1, 1, 2, 2, 4, 4, 8, 8)
        blocks = []
        prev_ch = in_ch
        for idx, mult in enumerate(multipliers):
            out_ch = features * mult
            blocks.append(
                C_Block(
                    prev_ch,
                    out_ch,
                    kernel_size=3,
                    stride=1 if idx % 2 == 0 else 2,
                    padding=1,
                    disc=True,
                    use_activation=True,
                    # The first stage has no BatchNorm.
                    use_batchnorm=idx > 0,
                )
            )
            prev_ch = out_ch
        self.convnet = nn.Sequential(*blocks)
        self.dense = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            # Flattened size follows the last conv width (features * 8) rather
            # than a hard-coded 512, so non-default `features` values work.
            nn.Linear(prev_ch * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, input):
        x = self.convnet(input)
        out = self.dense(x)
        return out

In [55]:
def test(low_res = 24, batch_size = 5):
    """Smoke-test the Generator/Discriminator pair on random low-resolution input.

    Builds default-configured models and feeds a random
    ``(batch_size, 3, low_res, low_res)`` batch through the generator, then
    the generator output through the discriminator. Prints both output shapes
    and asserts they match the expected 4x upscaling and the (batch, 1)
    discriminator score.

    Args:
        low_res: height/width of the random input images.
        batch_size: number of images in the random batch.
    """
    # Same creation order as before (input, generator, discriminator) so
    # seeded runs draw the same random numbers.
    x = torch.randn((batch_size, 3, low_res, low_res))
    gen = Generator()
    disc = Discriminator()
    # A shape check needs no gradients. The previous torch.cuda.amp.autocast()
    # block was deprecated and had no effect on these CPU tensors.
    with torch.no_grad():
        gen_out = gen(x)
        disc_out = disc(gen_out)
    print(gen_out.shape)
    print(disc_out.shape)
    # Generator upsamples 4x via two 2x PixelShuffle stages.
    assert gen_out.shape == (batch_size, 3, 4 * low_res, 4 * low_res)
    assert disc_out.shape == (batch_size, 1)
test()

torch.Size([5, 3, 96, 96])
torch.Size([5, 1])
