In [1]:
import torch as t
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset

In [None]:
class DownSample(nn.Module):
    """Stride-2 downsampling block: Conv2d -> InstanceNorm2d -> GLU.

    For an input of shape (N, input_channels, H, W) the output has shape
    (N, output_channels // 2, ceil(H / 2), ceil(W / 2)). The conv halves the
    spatial size and the GLU halves the channel axis: it splits the channels
    in two and gates one half with the sigmoid of the other.

    Args:
        input_channels: channel count of the incoming tensor.
        output_channels: channels produced by the convolution. Must be even so
            the GLU can split it.
        input_dims: unused. It is kept only so existing call sites keep
            working. ``InstanceNorm2d`` needs a *channel* count, which is
            taken from ``output_channels``.

    Raises:
        ValueError: if ``output_channels`` is odd.
    """

    def __init__(self, input_channels, output_channels, input_dims) -> None:
        super().__init__()
        if output_channels % 2:
            raise ValueError(
                f"output_channels must be even for GLU, got {output_channels}"
            )
        # Padding belongs *inside* the conv. With k=5, s=2, p=2 the output is
        # exactly ceil(H/2), and border outputs are computed from zero-padded
        # input instead of being zero-filled afterwards. The nn.Sequential
        # wrapper keeps parameter names at conv.0.* as before.
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=input_channels,
                out_channels=output_channels,
                kernel_size=5,
                stride=2,
                padding=2,
            )
        )
        # num_features must equal the channel count of the tensor being normalized.
        self.norm = nn.InstanceNorm2d(output_channels)
        self.nonlinearity = nn.GLU(dim=1)
        # No PixelShuffle here: upscaling by 2 would undo the stride-2 reduction.

    def forward(self, x):
        """Downsample ``x`` (N, C_in, H, W) -> (N, output_channels // 2, ceil(H/2), ceil(W/2))."""
        conv = self.conv(x)
        normalized = self.norm(conv)
        return self.nonlinearity(normalized)

In [None]:
class UpSample(nn.Module):
    """Sub-pixel upsampling block: Conv2d -> InstanceNorm2d -> PixelShuffle(2) -> GLU.

    For an input of shape (N, input_channels, H, W) the output has shape
    (N, output_channels // 8, 2 * H, 2 * W). The stride-1 conv keeps H and W.
    PixelShuffle(2) trades a factor of 4 in channels for 2x height and width.
    The GLU then halves the channel axis.

    Args:
        input_channels: channel count of the incoming tensor.
        output_channels: channels produced by the convolution. Must be
            divisible by 8: by 4 for PixelShuffle(2), then by 2 for GLU.
        input_dims: unused. It is kept only so existing call sites keep
            working. ``InstanceNorm2d`` needs a *channel* count, which is
            taken from ``output_channels``.

    Raises:
        ValueError: if ``output_channels`` is not divisible by 8.
    """

    def __init__(self, input_channels, output_channels, input_dims) -> None:
        super().__init__()
        if output_channels % 8:
            raise ValueError(
                f"output_channels must be divisible by 8 for PixelShuffle(2) + GLU, got {output_channels}"
            )
        # k=5, s=1, p=2 keeps H and W unchanged. Padding is done by the conv
        # itself so border outputs are computed, not zero-filled. The
        # nn.Sequential wrapper keeps parameter names at conv.0.* as before.
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=input_channels,
                out_channels=output_channels,
                kernel_size=5,
                stride=1,
                padding=2,
            )
        )
        # Normalization runs on the conv output, which has output_channels channels.
        self.norm = nn.InstanceNorm2d(output_channels)
        self.upscale = nn.PixelShuffle(2)
        self.nonlinearity = nn.GLU(dim=1)

    def forward(self, x):
        """Upsample ``x`` (N, C_in, H, W) -> (N, output_channels // 8, 2H, 2W)."""
        conv = self.conv(x)
        normalized = self.norm(conv)
        upscaled = self.upscale(normalized)
        return self.nonlinearity(upscaled)

In [None]:
class ResBlock(nn.Module):
    """Gated residual block: ``x + IN(Conv(GLU(IN(Conv(x)))))`` with (1, 3) kernels.

    The first conv doubles the channels so the GLU can gate them back down to
    ``channels``. The second conv keeps ``channels``. Both convs pad one column
    on each side of the width axis, so H and W are preserved and the residual
    addition lines up with the input.

    Args:
        input_dims: unused. It is kept only so existing call sites keep
            working. ``InstanceNorm2d`` needs a *channel* count, which is
            derived from ``channels``.
        channels: channel count of the input and output (default 256, the
            value previously hard-coded).
    """

    def __init__(self, input_dims, channels: int = 256) -> None:
        super().__init__()
        # padding=(0, 1) pads width (the axis the (1, 3) kernel spans) before
        # convolving. The old ZeroPad2d((0, 0, 1, 1)) padded height *after* the
        # conv, which grew H by 2 and shrank W by 2, so `+ x` always failed.
        # The nn.Sequential wrappers keep parameter names at conv{1,2}.0.*.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=2 * channels,
                      kernel_size=(1, 3), padding=(0, 1))
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels,
                      kernel_size=(1, 3), padding=(0, 1))
        )
        # One norm per conv output, each sized to that tensor's channel count.
        self.norm1 = nn.InstanceNorm2d(2 * channels)
        self.norm2 = nn.InstanceNorm2d(channels)
        self.nonlinearity = nn.GLU(dim=1)

    def forward(self, x):
        """Return ``x`` plus the gated two-conv residual. The output has the same shape as ``x``."""
        gated = self.nonlinearity(self.norm1(self.conv1(x)))
        residual = self.norm2(self.conv2(gated))
        return residual + x

         