# In [None]:
# Colab only: mount Google Drive at /content/drive.
# NOTE(review): presumably the ./Data, ./Weights and ./Results directories used
# below live on this Drive — confirm before running outside Colab.
from google.colab import drive
drive.mount('/content/drive')

# In [None]:
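# Change into the project folder on Drive (fill in the path) so that the
# relative ./Data, ./Weights and ./Results paths used below resolve:
# cd drive/My\ Drive/<project-dir>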

# In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm.notebook import tqdm
from torch import autograd

# Load the training set once and cast it to float32, so that torch.from_numpy
# yields tensors matching the nn layers' default parameter dtype.
x_train = np.load('./Data/train_HR.npy').astype(np.float32)
print(x_train.shape)

class Encoder(nn.Module):
    """Convolutional encoder mapping single-channel images to latent codes.

    Three ReLU conv layers shrink the image to a 64x9x9 feature map. For a
    150x150 input (the size the Decoder produces) the spatial size goes
    150 -> 49 -> 15 -> 9. The map is then flattened and linearly projected
    to ``latent_dim`` features. No activation is applied to the code.

    Args:
        latent_dim: size of the latent code. Defaults to 1000, matching
            ``Discriminator(dim_z=1000)``.
    """

    def __init__(self, latent_dim=1000):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 16, 7, stride=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 7, stride=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 7)
        self.flat = nn.Flatten()
        # 64 channels x 9 x 9 spatial = 5184 inputs. The conv stack must end
        # at 9x9 for this layer to accept the flattened features.
        self.linear = nn.Linear(64 * 9 * 9, latent_dim)

    def forward(self, x):
        """Encode a batch ``x`` of shape (N, 1, H, W) into codes of shape (N, latent_dim)."""
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        return self.linear(self.flat(h))
        
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent codes back to images.

    The code is linearly projected to 64*9*9 features and reshaped to a
    64x9x9 feature map. It is then upsampled 9 -> 15 -> 49 -> 150 into a
    single-channel 150x150 image, squashed into [-1, 1] by ``tanh``.

    Args:
        latent_dim: size of the latent code. Defaults to 1000, matching
            ``Discriminator(dim_z=1000)``.
    """

    def __init__(self, latent_dim=1000):
        super().__init__()

        self.linear = nn.Linear(latent_dim, 64 * 9 * 9)
        self.conv4 = nn.ConvTranspose2d(64, 32, 7)
        # output_padding=2 is needed to land exactly on 49 and then 150 pixels.
        self.conv5 = nn.ConvTranspose2d(32, 16, 7, stride=3, padding=1, output_padding=2)
        self.conv6 = nn.ConvTranspose2d(16, 1, 6, stride=3, padding=1, output_padding=2)

    def forward(self, x):
        """Decode codes ``x`` of shape (N, latent_dim) into images of shape (N, 1, 150, 150)."""
        h = self.linear(x).reshape(-1, 64, 9, 9)
        h = F.relu(self.conv4(h))
        h = F.relu(self.conv5(h))
        return torch.tanh(self.conv6(h))

class Discriminator(nn.Module):
    """MLP scoring each latent code with a probability in (0, 1).

    Architecture: Linear(dim_z, dim_h) -> ReLU -> Linear(dim_h, dim_h) ->
    ReLU -> Linear(dim_h, 1) -> Sigmoid, stored as ``self.network``.

    Args:
        dim_z: size of the latent codes being scored.
        dim_h: width of the two hidden layers.
    """

    def __init__(self, dim_z=1000, dim_h=256):
        super().__init__()
        self.dim_z = dim_z
        self.dim_h = dim_h
        self.network = nn.Sequential(
            nn.Linear(dim_z, dim_h),
            nn.ReLU(),
            nn.Linear(dim_h, dim_h),
            nn.ReLU(),
            nn.Linear(dim_h, 1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Return scores of shape (N, 1) for codes ``z`` of shape (N, dim_z)."""
        return self.network(z)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Models are built in the order encoder -> decoder -> discriminator, so the
# parameter initialisation draws from the global RNG in a fixed sequence.
encoder = Encoder().to(device)
decoder = Decoder().to(device)
Disc = Discriminator().to(device)

# Reconstruction phase: one Adam each for the encoder and the decoder.
optim_encoder = optim.Adam(encoder.parameters(), lr=1e-3)
optim_decoder = optim.Adam(decoder.parameters(), lr=1e-3)
# Adversarial phase: the discriminator's Adam, plus a second, slower Adam
# (separate moment estimates) over the same encoder parameters.
optim_D = optim.Adam(Disc.parameters(), lr=1e-3)
optim_encoder_reg = optim.Adam(encoder.parameters(), lr=1e-4)

EPS = 1e-15  # keeps log() finite when the sigmoid output saturates at 0 or 1
PRIOR_STD = 5.  # std of the Gaussian prior the latent codes are pushed towards
N_PRIOR_SAMPLES = 100  # prior samples drawn per discriminator update
ae_criterion = nn.MSELoss()
n_epochs = 300
loss_array = []  # per-epoch mean reconstruction loss per sample
pbar = tqdm(range(1, n_epochs+1))
for epoch in pbar:
    total_rec_loss = 0.0
    total_disc_loss = 0.0
    total_gen_loss = 0.0
    n_samples = 0

    for i in range(x_train.shape[0]):
        # Each x_train[i] is fed to the networks as one mini-batch (the
        # encoder's Flatten/Linear need a leading batch dimension).
        # NOTE(review): presumably (B, 1, 150, 150) — confirm against data prep.
        data = torch.from_numpy(x_train[i]).to(device)
        batch_size = data.size(0)
        n_samples += batch_size

        ### Encoder/Decoder: reconstruction step
        encoding = encoder(data)
        fake = decoder(encoding)
        ae_loss = ae_criterion(fake, data)
        total_rec_loss += ae_loss.item() * batch_size

        optim_encoder.zero_grad()
        optim_decoder.zero_grad()
        ae_loss.backward()
        optim_encoder.step()
        optim_decoder.step()

        ### Discriminator: prior samples are "real", encoder codes are "fake"
        z_real_gauss = (torch.randn(N_PRIOR_SAMPLES, Disc.dim_z) * PRIOR_STD).to(device)
        D_real_gauss = Disc(z_real_gauss)

        # Only Disc is updated in this step, so do not backprop into the encoder.
        z_fake_gauss = encoder(data).detach()
        D_fake_gauss = Disc(z_fake_gauss)

        # Average the real and fake terms separately. The two score tensors
        # have N_PRIOR_SAMPLES and batch_size rows, and adding them
        # element-wise only broadcasts when batch_size is 1 or N_PRIOR_SAMPLES.
        D_loss = -(torch.mean(torch.log(D_real_gauss + EPS))
                   + torch.mean(torch.log(1 - D_fake_gauss + EPS)))
        total_disc_loss += D_loss.item() * batch_size

        optim_D.zero_grad()
        D_loss.backward()
        optim_D.step()

        ### Generator: update the encoder so its codes fool the discriminator
        z_fake_gauss = encoder(data)
        D_fake_gauss = Disc(z_fake_gauss)

        G_loss = -torch.mean(torch.log(D_fake_gauss + EPS))
        total_gen_loss += G_loss.item() * batch_size

        optim_encoder_reg.zero_grad()
        G_loss.backward()
        optim_encoder_reg.step()

    # The totals are weighted by batch size, so normalise by the number of
    # samples seen. Dividing by x_train.shape[0] (the number of batches)
    # would inflate the average by the batch size.
    train_loss = total_rec_loss / n_samples
    loss_array.append(train_loss)

    torch.save(encoder, './Weights/AAE/AAE_Enc.pth')
    torch.save(decoder, './Weights/AAE/AAE_Dec.pth')

    pbar.set_postfix({
        'Recon Loss': train_loss,
        'D Loss': total_disc_loss / n_samples,
        'G Loss': total_gen_loss / n_samples,
    })
    np.save('Results/AAE_loss.npy', loss_array)