<a href="https://colab.research.google.com/github/Costub/SRGAN/blob/master/SRGANs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Imports
import torch
import torch.nn as nn
from torchvision.models import vgg16

In [None]:
# discriminator.py
class Discriminator(nn.Module):
    """SRGAN discriminator: a strided conv stack followed by a dense head.

    Args:
        im_shape: side length of the square input images. It is only used to
            size the first ``nn.Linear`` layer; any positive size works, not
            just multiples of 16.

    ``forward`` takes an ``(N, 3, im_shape, im_shape)`` batch and returns an
    ``(N, 1)`` tensor of sigmoid probabilities.
    """

    def __init__(self, im_shape=256):
        super().__init__()

        # Each stride-2 conv below (kernel 3, padding 1) maps a side of n to
        # floor((n - 1) / 2) + 1 == ceil(n / 2), and there are four of them.
        # For multiples of 16 this reduces to the old im_shape * im_shape * 2.
        side = int(im_shape)
        for _ in range(4):
            side = (side + 1) // 2
        in_features = 512 * side * side

        self.disc = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Flatten(),

            nn.Linear(in_features, 1024),
            nn.LeakyReLU(0.2),

            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return one real/fake probability per image in ``x``."""
        return self.disc(x)


In [None]:
# generator.py
class ResBlock(nn.Module):
    """Residual block for the generator trunk (64 channels, shape-preserving).

    The branch is two (Conv3x3 -> BatchNorm -> PReLU) stages, and its output
    is added to the block input.

    NOTE(review): the SRGAN paper's residual block ends at the second
    BatchNorm (no activation before the skip-add); the trailing PReLU here
    deviates from that — confirm it is intended.
    """

    def __init__(self):
        super().__init__()
        stages = []
        for _ in range(2):
            stages += [
                nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(64),
                nn.PReLU(),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the residual branch and add the skip connection."""
        residual = self.block(x)
        return residual + x


class Generator(nn.Module):
    """SRGAN generator that upsamples its input 4x in height and width.

    Args:
        resblocks: number of ``ResBlock`` modules in the residual trunk.

    ``forward`` maps an ``(N, 3, H, W)`` batch to ``(N, 3, 4H, 4W)``.
    """

    def __init__(self, resblocks):
        super().__init__()
        # Head: 9x9 conv, spatial size preserved by padding=4.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, padding=4, stride=1),
            nn.PReLU(),
        )
        # Residual trunk.
        trunk = [ResBlock() for _ in range(resblocks)]
        self.resblocks = nn.Sequential(*trunk)
        # Post-trunk conv + BN; its output is summed with the head output.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(64),
        )
        # Two sub-pixel stages: each conv widens 64 -> 256 channels and
        # PixelShuffle(2) folds that back to 64 channels at 2x height/width.
        upsample = []
        for _ in range(2):
            upsample += [
                nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, padding=1, stride=1),
                nn.PixelShuffle(2),
                nn.PReLU(),
            ]
        self.conv3 = nn.Sequential(*upsample)
        # Tail: project the 64 feature channels back to 3 image channels.
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=9, padding=4, stride=1)

    def forward(self, x):
        """Super-resolve ``x`` by a factor of 4."""
        head = self.conv1(x)
        trunk = self.conv2(self.resblocks(head))
        upsampled = self.conv3(trunk + head)
        return self.conv4(upsampled)
    

In [None]:
class GeneratorLoss(nn.Module):
    def __init__(self):
        super(GeneratorLoss,self).__init__()
        VGG16 = vgg16(pretrained=True)
        loss_network = nn.Sequential(*list(VGG16.features)[:18]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCELoss()

    def forward(self,Gen,Disc,lr,hr):
        hr_out = Gen(lr)
        content_loss = self.mse_loss(self.loss_network(hr_out),self.loss_network(hr))
        adversarial_loss = self.bce_loss(Disc(hr).reshape(-1),torch.ones_like(Disc(hr)))
        return content_loss + 0.001*adversarial_loss

class DiscriminatorLoss(nn.Module):
    """Binary cross-entropy loss for the SRGAN discriminator.

    The loss is the mean of ``BCE(Disc(hr), 1)`` and ``BCE(Disc(Gen(lr)), 0)``.
    """

    def __init__(self):
        super().__init__()
        self.bce_loss = nn.BCELoss()

    def forward(self, Gen, Disc, lr, hr):
        """Return the scalar discriminator loss for one batch.

        Args:
            Gen: generator called on ``lr`` to produce fake high-res images.
            Disc: discriminator; expected to return one probability per image.
            lr: low-resolution input batch.
            hr: matching high-resolution (real) batch.
        """
        # This loss only trains the discriminator. Running the generator
        # without autograd keeps gradients out of Gen and saves the memory of
        # its activation graph.
        with torch.no_grad():
            hr_out = Gen(lr)
        real_pred = Disc(hr).reshape(-1)
        fake_pred = Disc(hr_out).reshape(-1)
        # Targets are built from the flattened predictions so shapes match.
        loss_real = self.bce_loss(real_pred, torch.ones_like(real_pred))
        loss_fake = self.bce_loss(fake_pred, torch.zeros_like(fake_pred))
        return (loss_real + loss_fake) / 2

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Testing models and loss on random data.
# A small batch keeps activation memory manageable. A batch of 64 at 256x256
# exhausts memory: the generator's last upsampling stage alone holds
# batch x 64 x 256 x 256 floats.
batch_size = 4
hr = torch.rand(batch_size, 3, 256, 256, device=device)
lr = torch.rand(batch_size, 3, 64, 64, device=device)
G = Generator(5).to(device)
D = Discriminator().to(device)
lossG = GeneratorLoss().to(device)
lossD = DiscriminatorLoss().to(device)

# No optimizer step happens here, so skip building autograd graphs.
with torch.no_grad():
    hr_out = G(lr)
    assert hr_out.shape == hr.shape, hr_out.shape
    print(hr_out.shape)
    print(lossG(G, D, lr, hr))
    print(lossD(G, D, lr, hr))

RuntimeError: ignored