In [1]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.optim import Adam
import numpy as np
from sklearn.model_selection import train_test_split
import os
import xarray as xr
import matplotlib.pyplot as plt
import copy
import random

In [16]:
class Generator(nn.Module):
    """Convolutional encoder -> fully-connected bottleneck -> transposed-conv decoder.

    Each of the four encoder convolutions (kernel 4, stride 2, padding 1) halves
    the spatial size of an even-sized input. Each of the four decoder transposed
    convolutions (same hyper-parameters) doubles it again. So an input of shape
    ``(batch, input_channels, image_size, image_size)`` yields an output of shape
    ``(batch, output_channels, image_size, image_size)``.

    Args:
        input_channels: Channels in the input tensor.
        output_channels: Channels in the generated tensor.
        image_size: Height/width of the square input. Must be a positive multiple
            of 16 (four halvings). The default of 48 gives a 256 x 3 x 3 = 2304
            feature bottleneck.
        latent_dim: Width of the fully-connected bottleneck layer.
        dropout_p: Dropout probability applied after each fully-connected layer.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 16.
    """

    # Number of stride-2 down-sampling (and up-sampling) stages.
    _NUM_STAGES = 4
    # Channels produced by the last encoder convolution (conv4).
    _BOTTLENECK_CHANNELS = 256

    def __init__(self, input_channels=2, output_channels=2,
                 image_size=48, latent_dim=64, dropout_p=0.2):
        super().__init__()
        scale = 2 ** self._NUM_STAGES
        if image_size <= 0 or image_size % scale != 0:
            raise ValueError(
                f"image_size must be a positive multiple of {scale}, got {image_size}"
            )
        side = image_size // scale
        # (channels, height, width) of the encoder output, e.g. (256, 3, 3) for 48x48.
        self.bottleneck_shape = (self._BOTTLENECK_CHANNELS, side, side)
        flat_features = self._BOTTLENECK_CHANNELS * side * side

        # Submodule names and creation order are kept stable: names determine
        # state_dict keys, and order determines how a fixed torch seed is consumed
        # during weight initialisation.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, self._BOTTLENECK_CHANNELS, kernel_size=4, stride=2, padding=1)
        self.dropout = nn.Dropout(dropout_p)
        self.fc_in = nn.Linear(flat_features, latent_dim)
        self.fc_out = nn.Linear(latent_dim, flat_features)
        self.tconv1 = nn.ConvTranspose2d(self._BOTTLENECK_CHANNELS, 128, kernel_size=4, stride=2, padding=1)
        self.tconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.tconv3 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1)
        self.tconv4 = nn.ConvTranspose2d(64, output_channels, kernel_size=4, stride=2, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map ``x`` (batch, input_channels, H, W) to (batch, output_channels, 2**4 * side, ...).

        With the configured ``image_size`` as H == W, the output spatial size
        equals the input spatial size.
        """
        batch = x.size(0)
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.relu(conv(x))
        x = torch.flatten(x, start_dim=1)
        x = self.dropout(self.relu(self.fc_in(x)))
        x = self.dropout(self.relu(self.fc_out(x)))
        x = x.reshape(batch, *self.bottleneck_shape)
        for tconv in (self.tconv1, self.tconv2, self.tconv3):
            x = self.relu(tconv(x))
        # NOTE(review): ReLU on the final layer restricts generated values to >= 0;
        # confirm this matches how the training data is normalised.
        return self.relu(self.tconv4(x))

class Discriminator(nn.Module):
    """Convolutional binary classifier returning one sigmoid score per sample.

    Two stride-2 convolutions (kernel 4, padding 1) each halve the spatial size
    of an even-sized input. A third stride-1 convolution with ``padding='same'``
    keeps it unchanged. The flattened features feed a single-output linear
    layer followed by a sigmoid.

    Args:
        input_channels: Channels in the input tensor.
        image_size: Height/width of the square input. Must be a positive multiple
            of 4 (two halvings). The default of 48 gives 512 x 12 x 12 = 73728
            features into the final linear layer.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 4.
    """

    def __init__(self, input_channels=2, image_size=48):
        super().__init__()
        if image_size <= 0 or image_size % 4 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 4, got {image_size}"
            )
        side = image_size // 4
        # Submodule names and creation order are kept stable (state_dict keys and
        # seeded weight initialisation depend on them).
        self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        # padding='same' (allowed only with stride 1) pads so the side x side size is
        # preserved; with an even kernel the padding is split unevenly.
        self.conv3 = nn.Conv2d(256, 512, kernel_size=4, stride=1, padding='same')
        self.fc = nn.Linear(512 * side * side, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid scores of shape (batch, 1) for ``x`` of shape (batch, C, H, W)."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = torch.flatten(x, start_dim=1)
        # NOTE(review): if training pairs this with nn.BCELoss, nn.BCEWithLogitsLoss on
        # the pre-sigmoid logits is the numerically more stable equivalent.
        return self.sigmoid(self.fc(x))

tensor([[0.4952],
        [0.4821],
        [0.4922],
        [0.4869],
        [0.4916],
        [0.4962],
        [0.4887],
        [0.4958]], grad_fn=<SigmoidBackward0>)