In [1]:
import torch as t
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset

In [2]:
class UpSample(nn.Module):
    """2x spatial upsampling block: padded 5x5 conv -> pixel shuffle -> instance norm -> GLU.

    Channel bookkeeping for an input of shape (N, input_channels, H, W):
      * the convolution emits ``output_channels`` maps at the same H x W
        (reflection padding of 2 offsets the 5x5 kernel),
      * ``PixelShuffle(2)`` trades channels for resolution, giving
        ``output_channels // 4`` maps at 2H x 2W,
      * the GLU over the channel axis halves that again, so the block returns
        ``output_channels // 8`` maps at 2H x 2W.

    Args:
        input_channels: number of channels in the incoming feature map.
        output_channels: number of channels produced by the convolution;
            must be divisible by 8 for the shuffle and GLU to work.
    """

    def __init__(self, input_channels, output_channels) -> None:
        super().__init__()
        padded_conv = [
            nn.ReflectionPad2d(2),
            nn.Conv2d(input_channels, output_channels, kernel_size=5, stride=1),
        ]
        self.conv = nn.Sequential(*padded_conv)
        self.upscale = nn.PixelShuffle(2)
        self.nonlinearity = nn.GLU(dim=1)
        # The norm runs after the shuffle, hence the channel count divided by 4.
        self.normFunc = nn.InstanceNorm2d(output_channels // 4, affine=True)

    def forward(self, x):
        """Return the upsampled, normalised and gated feature map for ``x``."""
        shuffled = self.upscale(self.conv(x))
        return self.nonlinearity(self.normFunc(shuffled))

In [3]:
class DownSample(nn.Module):
    """Stride-2 downsampling block: 5x5 conv -> instance norm -> GLU.

    For an input of shape (N, input_channels, H, W) the block returns
    (N, output_channels // 2, ceil(H / 2), ceil(W / 2)): the convolution
    halves both spatial axes and the GLU over the channel axis halves the
    channel count.

    The convolution is two-dimensional. The block is fed 4-D feature maps and
    is followed by ``InstanceNorm2d``; a ``Conv1d`` here rejects 4-D input.

    Args:
        input_channels: number of channels in the incoming feature map.
        output_channels: number of channels produced by the convolution;
            must be even so the GLU can split it in two.

    Note:
        Reflection padding of 2 requires both spatial dimensions of the input
        to be larger than 2.
    """

    def __init__(self, input_channels, output_channels) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=5,
            stride=2,
            padding=2,
            padding_mode='reflect',
        )
        self.nonlinearity = nn.GLU(dim=1)
        self.normFunc = nn.InstanceNorm2d(num_features=output_channels, affine=True)

    def forward(self, x):
        """Return the downsampled, normalised and gated feature map for ``x``."""
        conv = self.conv(x)
        normalize = self.normFunc(conv)
        return self.nonlinearity(normalize)

In [4]:
class ResBlock(nn.Module):
    """Gated residual block over 256-channel feature maps.

    Pipeline: (1x3 conv 256->512) -> instance norm -> GLU (back to 256)
    -> (1x3 conv 256->256) -> instance norm, then the input is added back.

    The (1, 3) kernels slide along the last (width) axis, so the zero padding
    is applied on the left and right. ``ZeroPad2d`` takes its tuple as
    (left, right, top, bottom); padding top/bottom instead grows the height
    and shrinks the width, which makes the residual addition fail.

    Input and output both have shape (N, 256, H, W).

    NOTE(review): the original author's comment guessed the expected input to
    be "64x94" but marked it as unsure; nothing in this class depends on it.
    """

    def __init__(self) -> None:
        super().__init__()
        self.nonlinearity = nn.GLU(dim=1)
        self.conv1 = nn.Sequential(
            nn.ZeroPad2d((1, 1, 0, 0)),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(1, 3)),
        )
        self.conv2 = nn.Sequential(
            nn.ZeroPad2d((1, 1, 0, 0)),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(1, 3)),
        )
        # Non-affine instance norms hold no parameters or buffers, so creating
        # them once here (rather than on every forward call) leaves the
        # state_dict unchanged.
        self.norm1 = nn.InstanceNorm2d(512)
        self.norm2 = nn.InstanceNorm2d(256)

    def forward(self, x):
        """Apply the two gated convolutions and add the skip connection."""
        gated = self.nonlinearity(self.norm1(self.conv1(x)))
        refined = self.norm2(self.conv2(gated))
        return refined + x

         

In [None]:
class Generator(nn.Module):
    """2D -> 1D -> 2D convolutional generator.

    Expects input of shape (N, 1, features, T) and returns a tensor of shape
    (N, 1, features, T_out), where ``T_out = 4 * ceil(ceil(T / 2) / 2)``
    (equal to ``T`` whenever ``T`` is a multiple of 4).

    Stages:
      1. 5x15 conv + GLU, then two stride-2 ``DownSample`` blocks, giving
         (N, 256, F', T') with ``F' = ceil(ceil(features / 2) / 2)``.
      2. Channels and the reduced feature axis are merged into
         (N, 256 * F', 1, T'); a 1x1 conv maps this to 256 channels, six
         ``ResBlock`` modules run along the time axis, and a second 1x1 conv
         restores 256 * F' channels, which are unfolded back to (N, 256, F', T').
      3. Two ``UpSample`` blocks return to (N, 64, 4F', 4T').
      4. ``conv4`` produces ``features`` channels; the channel and height axes
         are swapped so ``conv5`` can collapse the 4F' axis to one channel.

    Args:
        features: size of the feature (height) axis of the input.

    Note:
        Relies on ``DownSample``, ``ResBlock`` and ``UpSample`` accepting 4-D
        (N, C, H, W) tensors. Reflection padding of 7 on the time axis requires
        ``T > 7`` at the input and ``T_out > 7`` before ``conv5``.
    """

    def __init__(self, features) -> None:
        super().__init__()
        # Height of the feature axis after two ceil-halving downsamples.
        reduced = ((features + 1) // 2 + 1) // 2

        self.nonlinearity = nn.GLU(dim=1)
        # ReflectionPad2d takes ONE argument, a (left, right, top, bottom)
        # tuple. For a (5, 15) kernel the width needs 7 and the height needs 2.
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d((7, 7, 2, 2)),
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=(5, 15)),
        )
        self.downsample1 = DownSample(input_channels=64, output_channels=256)
        self.downsample2 = DownSample(input_channels=128, output_channels=512)

        # 2D -> 1D projection (Conv2d requires an explicit kernel_size).
        self.conv2 = nn.Conv2d(in_channels=reduced * 256, out_channels=256, kernel_size=1)
        self.norm2 = nn.InstanceNorm2d(256)
        for i in range(6):
            self.add_module(f"resblock{i+1}", ResBlock())
        # 1D -> 2D projection: the inverse of conv2, so the channel counts are mirrored.
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=reduced * 256, kernel_size=1)
        self.norm3 = nn.InstanceNorm2d(reduced * 256)

        self.upsample1 = UpSample(input_channels=256, output_channels=1024)
        self.upsample2 = UpSample(input_channels=128, output_channels=512)
        self.conv4 = nn.Sequential(
            nn.ReflectionPad2d((7, 7, 2, 2)),
            nn.Conv2d(in_channels=64, out_channels=features, kernel_size=(5, 15)),
        )
        self.conv5 = nn.Sequential(
            nn.ReflectionPad2d((7, 7, 2, 2)),
            nn.Conv2d(in_channels=reduced * 4, out_channels=1, kernel_size=(5, 15)),
        )

    def forward(self, x):
        """Run the generator on ``x`` of shape (N, 1, features, T)."""
        hidden = self.nonlinearity(self.conv1(x))
        hidden = self.downsample1(hidden)
        hidden = self.downsample2(hidden)

        # Merge channels with the reduced feature axis. The sizes are read
        # from the tensor itself; it has already been downsampled, so the
        # halving formula must not be applied a second time.
        batch, channels, feat, frames = hidden.shape
        hidden = hidden.reshape(batch, channels * feat, 1, frames)

        hidden = self.norm2(self.conv2(hidden))
        for i in range(6):
            hidden = getattr(self, f"resblock{i+1}")(hidden)

        hidden = self.norm3(self.conv3(hidden))
        # Exact inverse of the merge above.
        hidden = hidden.reshape(batch, channels, feat, frames)

        hidden = self.upsample1(hidden)
        hidden = self.upsample2(hidden)

        hidden = self.conv4(hidden)  # (N, features, 4 * feat, 4 * frames)
        # Swap the channel and height axes with permute; reshape would only
        # reinterpret memory and scramble the data. Axes 1 and 2 are the pair
        # exchanged because conv5 is declared with 4 * feat input channels.
        hidden = hidden.permute(0, 2, 1, 3)

        return self.conv5(hidden)



In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that emits one probability per sample.

    Expects input of shape (N, 1, H, W) whose spatial sizes are the ``height``
    and ``width`` given at construction (only their ceil-divided product is
    used, so the order of the two arguments does not matter).

    Stages: 3x3 conv + GLU -> three stride-2 ``DownSample`` blocks (each axis
    shrinks to ceil(size / 8), 512 channels) -> 1x5 conv + instance norm + GLU
    -> 1x3 conv to a single channel -> flatten -> linear -> sigmoid.

    Args:
        width: size of one spatial axis of the input.
        height: size of the other spatial axis of the input.

    Returns (from ``forward``):
        Tensor of shape (N, 1) with values in (0, 1).

    Note:
        Relies on ``DownSample`` halving both spatial axes of a 4-D tensor.
        Reflection padding of 2 on the last axis requires that axis to be at
        least 3 wide after downsampling, i.e. at least 17 at the input.
    """

    def __init__(self, width, height) -> None:
        super().__init__()
        self.nonlinearity = nn.GLU(dim=1)
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3),
        )
        self.downsample1 = DownSample(input_channels=64, output_channels=256)
        self.downsample2 = DownSample(input_channels=128, output_channels=512)
        self.downsample3 = DownSample(input_channels=256, output_channels=1024)
        # Padding tuples are (left, right, top, bottom). The (1, k) kernels
        # span the width, so the padding goes on the left/right; that keeps
        # the map size fixed, which the flattened size expected by fc needs.
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((2, 2, 0, 0)),
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=(1, 5)),
        )
        # Non-affine, so it adds nothing to the state_dict.
        self.norm2 = nn.InstanceNorm2d(1024)
        self.conv3 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 0, 0)),
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=(1, 3)),
        )
        # Three ceil-halvings of an axis equal a single ceil-division by 8.
        self.fc = nn.Linear(((width + 7) // 8) * ((height + 7) // 8), 1)

    def forward(self, x):
        """Return the per-sample probability for ``x`` of shape (N, 1, H, W)."""
        hidden = self.nonlinearity(self.conv1(x))
        hidden = self.downsample1(hidden)
        hidden = self.downsample2(hidden)
        hidden = self.downsample3(hidden)

        hidden = self.nonlinearity(self.norm2(self.conv2(hidden)))

        score_map = self.conv3(hidden)
        flat = score_map.reshape(score_map.shape[0], -1)
        return t.sigmoid(self.fc(flat))
