In [1]:
import torch as t
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset

In [None]:
class DownSample(nn.Module):
    """Strided conv -> instance norm -> GLU block that halves the spatial size.

    Reflection padding of ``kernel_size // 2`` is applied *before* the stride-2
    convolution, so for an odd ``kernel_size`` an (N, input_channels, H, W) input
    yields (N, output_channels // 2, ceil(H / 2), ceil(W / 2)). The GLU splits
    the channel dim in half, so ``output_channels`` must be even.

    Args:
        input_channels: channel count of the incoming tensor.
        output_channels: channels produced by the convolution (before the GLU).
        kernel_size: odd convolution kernel size; defaults to 5 (the previous
            hard-coded value). H and W must exceed ``kernel_size // 2`` for the
            reflection padding.
    """

    def __init__(self, input_channels, output_channels, kernel_size=5) -> None:
        super(DownSample, self).__init__()
        # Pad the *input* so the conv sees reflected borders ("same"-style);
        # padding the conv output instead first shrinks the map by kernel_size - 1.
        self.conv = nn.Sequential(
            nn.ReflectionPad2d(kernel_size // 2),
            nn.Conv2d(
                in_channels=input_channels,
                out_channels=output_channels,
                kernel_size=kernel_size,
                stride=2,
            ),
        )
        # No PixelShuffle here: it would undo the stride-2 downsampling and divide
        # the channel count by 4, which contradicts how callers chain these blocks.
        # The norm is built once, sized to the conv's output channels.
        self.norm = nn.InstanceNorm2d(output_channels)
        self.nonlinearity = nn.GLU(dim=1)

    def forward(self, x):
        """Return the downsampled, normalized, GLU-gated feature map."""
        return self.nonlinearity(self.norm(self.conv(x)))

In [None]:
class UpSample(nn.Module):
    """Conv -> optional pixel shuffle -> instance norm -> GLU block.

    Reflection padding of 2 is applied *before* the 5x5 stride-1 convolution,
    so the convolution preserves H and W. ``nn.PixelShuffle(upscale_factor)``
    then maps (N, C, H, W) -> (N, C / r**2, r*H, r*W) with r = upscale_factor,
    and the GLU halves the channel count. The final output is therefore
    (N, output_channels // (2 * r**2), r*H, r*W).

    Args:
        input_channels: channel count of the incoming tensor.
        output_channels: channels produced by the convolution; must be divisible
            by ``2 * upscale_factor**2``.
        upscale_factor: spatial upsampling factor. Defaults to 1, which is an
            identity shuffle and keeps the previous behaviour (no spatial
            change). Pass 2 to actually upsample.
    """

    def __init__(self, input_channels, output_channels, upscale_factor=1) -> None:
        super(UpSample, self).__init__()
        # Pad the *input* so the conv sees reflected borders and keeps H, W;
        # padding after the conv fails on inputs smaller than the kernel.
        self.conv = nn.Sequential(
            nn.ReflectionPad2d(2),
            nn.Conv2d(
                in_channels=input_channels,
                out_channels=output_channels,
                kernel_size=5,
                stride=1,
            ),
        )
        self.upscale = nn.PixelShuffle(upscale_factor)
        # Built once, sized to the channel count actually normalized (after shuffle).
        self.norm = nn.InstanceNorm2d(output_channels // upscale_factor ** 2)
        self.nonlinearity = nn.GLU(dim=1)

    def forward(self, x):
        """Return the (optionally upsampled), normalized, GLU-gated feature map."""
        return self.nonlinearity(self.norm(self.upscale(self.conv(x))))

In [None]:
class ResBlock(nn.Module):
    """Residual block: conv -> IN -> GLU -> conv -> IN, added back onto the input.

    Both convolutions use a (1, 3) kernel, so they slide along the last tensor
    dimension only. One column of zero padding on each side of that dimension
    keeps the output shape equal to the input shape (N, channels, H, W), which
    the skip connection requires. Any H, W works.

    Args:
        channels: channel count of the input and output. The first conv expands
            to ``2 * channels`` so the GLU brings it back to ``channels``.
            Defaults to 256, the previously hard-coded value.
    """

    def __init__(self, channels=256) -> None:
        super(ResBlock, self).__init__()
        self.nonlinearity = nn.GLU(dim=1)
        # ZeroPad2d takes (left, right, top, bottom): pad the last dim, which is
        # the one the (1, 3) kernel consumes. Padding top/bottom instead grows H,
        # shrinks W, and breaks the residual addition.
        self.conv1 = nn.Sequential(
            nn.ZeroPad2d((1, 1, 0, 0)),
            nn.Conv2d(in_channels=channels, out_channels=2 * channels, kernel_size=(1, 3)),
        )
        self.norm1 = nn.InstanceNorm2d(2 * channels)
        self.conv2 = nn.Sequential(
            nn.ZeroPad2d((1, 1, 0, 0)),
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 3)),
        )
        self.norm2 = nn.InstanceNorm2d(channels)

    def forward(self, x):
        """Return ``x`` plus the block's residual branch (same shape as ``x``)."""
        hidden = self.nonlinearity(self.norm1(self.conv1(x)))
        hidden = self.norm2(self.conv2(hidden))
        return hidden + x

         

In [None]:
class Discriminator(nn.Module):
    """Convolutional real/fake classifier returning one probability per sample.

    Pipeline: 3x3 conv + GLU -> three ``DownSample`` blocks -> (1, 5) conv +
    instance norm + GLU -> (1, 3) conv to one channel -> flatten -> linear ->
    sigmoid. Expects a single-channel input (N, 1, H, W) and returns (N, 1).

    NOTE(review): assumes each ``DownSample`` returns ``output_channels // 2``
    channels and ceil-halves H and W, so the final map is ceil(H/8) x ceil(W/8).
    Confirm this against the DownSample definition. The reflection padding
    around the (1, 5) conv needs the last dim after downsampling to be >= 3.

    Args:
        width, height: spatial size of the inputs this instance will see. Only
            the product of their ceil(/8) values is used, to size the final
            linear layer.
    """

    def __init__(self, width, height) -> None:
        super().__init__()
        self.nonlinearity = nn.GLU(dim=1)
        # Padding precedes each conv so spatial size is preserved; the (1, k)
        # kernels only span the last dim, so only that dim (left/right) is padded.
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3),
        )
        self.downsample1 = DownSample(input_channels=64, output_channels=256)
        self.downsample2 = DownSample(input_channels=128, output_channels=512)
        self.downsample3 = DownSample(input_channels=256, output_channels=1024)
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((2, 2, 0, 0)),
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=(1, 5)),
        )
        # Built once with the channel count it actually normalizes (conv2 output).
        self.norm = nn.InstanceNorm2d(1024)
        self.conv3 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 0, 0)),
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=(1, 3)),
        )
        # Integer ceil(x / 8) for each side; nn.Linear needs an int in_features.
        self.fc = nn.Linear(((width + 7) // 8) * ((height + 7) // 8), 1)

    def forward(self, x):
        """Return per-sample probabilities of shape (N, 1)."""
        features = self.nonlinearity(self.conv1(x))
        features = self.downsample3(self.downsample2(self.downsample1(features)))
        features = self.nonlinearity(self.norm(self.conv2(features)))
        flat = self.conv3(features).flatten(start_dim=1)
        return t.sigmoid(self.fc(flat))
