# IMPORTS

In [None]:
import os
import numpy as np
import tarfile

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
from torchvision.utils import save_image
from torch.distributions.normal import Normal
from pytorch_model_summary import summary

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

import matplotlib.pyplot as plt
%matplotlib inline

# GENERAL FUNCTIONS

In [None]:
# FUNCTIONS FOR VAE

def criterionVAE(x, x_hat, z, mu, log_std):
    """Negative ELBO of a Gaussian VAE, averaged over the batch.

    Args:
        x: original inputs.
        x_hat: decoder reconstructions (same shape as ``x``).
        z: sampled latents; only its first dimension (batch size) is used.
        mu: posterior means, same shape as ``log_std``.
        log_std: log standard deviations of the diagonal Gaussian posterior.

    Returns:
        Scalar tensor ``(sum-reduced squared error + KL(q(z|x) || N(0, I))) / batch``.
    """
    # Reconstruction term: squared error summed over every element of the batch.
    RE = F.mse_loss(x_hat, x, reduction='sum')

    # Closed-form KL between N(mu, std^2) and N(0, 1), summed over all latent
    # entries: 0.5 * sum(mu^2 + std^2 - log(std^2) - 1).
    # - The "-1" is applied once per latent entry, not once per sample.
    # - log(std^2) is written as 2*log_std so that a very negative log_std
    #   cannot underflow std^2 to 0 and turn the KL into +inf.
    KL = 0.5 * torch.sum(mu.pow(2) + torch.exp(2 * log_std) - 2 * log_std - 1)

    return (RE + KL) / z.shape[0]


def trainVAE(x):
    """Perform one optimisation step of the global VAE ``model`` on batch ``x``.

    Encodes ``x``, samples a latent code with ``model.reparameterize``,
    decodes it, and minimises ``criterionVAE`` with the global ``optimizer``.
    Returns the batch loss as a Python float.
    """
    model.zero_grad()

    mean, log_sigma = model.encoder(x)
    latent = model.reparameterize(mean, log_sigma)
    reconstruction = model.decoder(latent)
    batch_loss = criterionVAE(x, reconstruction, latent, mean, log_sigma)

    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

In [None]:
# FUNCTIONS FOR GAN

def D_train(x):
    """Run one discriminator update on the real batch ``x``; return the loss.

    Real images are labelled 1 and generated images 0. Uses the globals
    ``D``, ``G``, ``D_optimizer``, ``criterionGAN``, ``mnist_dim``, ``z_dim``
    and ``device``.
    """
    D.zero_grad()

    n = x.shape[0]
    x_real = x.view(-1, mnist_dim).to(device)
    y_real = torch.ones(n, 1).to(device)
    D_real_loss = criterionGAN(D(x_real), y_real)

    # Size the fake batch like the real one: the loader's last batch can be
    # smaller than `bs`, and a mismatch skews the real/fake balance.
    z = torch.randn(n, z_dim).to(device)
    # detach(): this step only updates D, so do not backpropagate through G.
    x_fake = G(z).detach()
    y_fake = torch.zeros(n, 1).to(device)
    D_fake_loss = criterionGAN(D(x_fake), y_fake)

    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()


def G_train(x):
    """Run one generator update and return the generator loss.

    ``x`` is the current real batch. Only its size is used, so that the
    generated batch matches it (the loader's last batch can be smaller than
    ``bs``). Uses the globals ``G``, ``D``, ``G_optimizer``, ``criterionGAN``,
    ``z_dim`` and ``device``.
    """
    G.zero_grad()

    n = x.shape[0]
    z = torch.randn(n, z_dim).to(device)
    # G is trained to make D label its samples as real (target 1).
    y = torch.ones(n, 1).to(device)

    G_loss = criterionGAN(D(G(z)), y)

    # This also fills D's .grad; D_train calls D.zero_grad() before using them.
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()


# Binary cross-entropy on the discriminator's sigmoid outputs, batch-averaged.
criterionGAN = nn.BCELoss(reduction='mean')

# EXPERIMENT MNIST

In [None]:
bs = 64

# ToTensor gives pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the tanh outputs of the MNIST decoder and generator.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))])

trainset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
mnist_loader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True)

# Flattened image size (height * width). `train_data` is a deprecated alias
# of `data`.
mnist_dim = trainset.data.size(1) * trainset.data.size(2)

# The loader keeps its own reference to the dataset.
del trainset

## VAE MNIST

In [None]:
class VAE(nn.Module):
    """Fully connected VAE for flattened images.

    Args:
        D: input dimensionality (flattened image size).
        M: latent dimensionality.
    """

    def __init__(self, D, M):
        super(VAE, self).__init__()
        self.D = D
        self.M = M

        # Encoder widths: D -> 1024 -> 512 -> 256 -> 2*M (mu and log_std stacked).
        self.enc1 = nn.Linear(self.D, 1024)
        self.enc2 = nn.Linear(self.enc1.out_features, self.enc1.out_features//2)
        self.enc3 = nn.Linear(self.enc2.out_features, self.enc2.out_features//2)
        self.enc4 = nn.Linear(self.enc3.out_features, self.M*2)

        # Decoder widths: M -> 256 -> 512 -> 1024 -> D.
        self.dec1 = nn.Linear(self.M, 256)
        self.dec2 = nn.Linear(self.dec1.out_features, self.dec1.out_features*2)
        self.dec3 = nn.Linear(self.dec2.out_features, self.dec2.out_features*2)
        self.dec4 = nn.Linear(self.dec3.out_features, self.D)

    def encoder(self, x):
        """Map a (batch, D) input to ``(mu, log_std)``, each of shape (batch, M)."""
        # training=self.training lets model.eval() switch dropout off; with the
        # F.dropout default (training=True) it would stay active at inference.
        x = F.leaky_relu(self.enc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.enc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.enc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = self.enc4(x).view(-1, 2, self.M)
        return x[:, 0, :], x[:, 1, :]

    def reparameterize(self, mu, log_std):
        """Sample ``z = mu + eps * exp(log_std)`` with ``eps ~ N(0, I)``."""
        std = log_std.exp()
        eps = torch.randn_like(std)
        return mu + (eps*std)  # z

    def decoder(self, z):
        """Map latents of shape (batch, M) to outputs in [-1, 1] of shape (batch, D)."""
        x = F.leaky_relu(self.dec1(z), 0.2)
        x = F.leaky_relu(self.dec2(x), 0.2)
        x = F.leaky_relu(self.dec3(x), 0.2)
        return torch.tanh(self.dec4(x))

In [None]:
model = VAE(mnist_dim, 20)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr = 0.0002)

# numel() is the number of scalars in a tensor. The previous version summed
# each tensor's shape dimensions (e.g. 1024 + 784 for a 1024x784 weight),
# which vastly under-reports the model size.
n_weights = sum(parameter.numel() for parameter in model.parameters())

print(f'There are {n_weights} weights in the VAE.')

In [None]:
n_epoch = 50
lossesVAE = []
os.makedirs("MNIST/VAE_outputs", exist_ok=True)
for epoch in range(1, n_epoch+1):
    epochloss = 0
    for batch_idx, (x, _) in enumerate(mnist_loader):
        x = x.view(x.shape[0], -1)
        batchloss = trainVAE(x.to(device))
        epochloss += batchloss
        lossesVAE.append(batchloss)

    # Reconstruct (up to) the first 8 images of the last batch as a visual
    # check; no gradients are needed for this.
    n_show = min(8, x.shape[0])
    with torch.no_grad():
        mu, log_std = model.encoder(x[:n_show].to(device))
        z = model.reparameterize(mu, log_std)
        x_hat = model.decoder(z)

    # Inputs and tanh outputs live in [-1, 1]; save_image expects [0, 1] and
    # clamps, so rescale first instead of clipping away the lower half.
    pair = torch.cat((x[:n_show].view(n_show, 1, 28, 28),
                      x_hat.view(n_show, 1, 28, 28).cpu()))
    save_image((pair + 1) / 2, f"MNIST/VAE_outputs/output_{epoch}.png", nrow=n_show)

    print('[%d/%d]: loss: %.3f' % (
            (epoch), n_epoch, epochloss/len(mnist_loader)))

In [None]:
#torch.save(model, 'MNIST/VAE')

# torch.load unpickles arbitrary objects: only load checkpoints you trust.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects a
# whole pickled nn.Module; pass weights_only=False there if loading fails.
model = torch.load('MNIST/VAE', map_location=device)

In [None]:
# Latent-space interpolation: 10 rows, each decoding 10 evenly spaced points
# (both endpoints included) on the line between two random latent codes.
rows = []
with torch.no_grad():
    for _ in range(10):
        z = np.random.randn(20, 2)
        # np.linspace replaces the hand-rolled loop, whose `i != 0 or i != 9`
        # test was always true; with a step of 1/8 over 10 rows, the last row
        # overshot the second endpoint.
        z_interpolate = np.linspace(z[:, 0], z[:, 1], 10)
        z_interpolate = torch.tensor(z_interpolate, dtype=torch.float).to(device)

        interp_VAE = model.decoder(z_interpolate).view(10, 1, 28, 28).cpu()
        rows.append(interp_VAE)

    # Build from a fresh list instead of checking dir(), so re-running the
    # cell does not keep appending to results from earlier runs.
    result_ = torch.cat(rows)
    # Decoder outputs are tanh-bounded in [-1, 1]; save_image expects [0, 1].
    save_image((result_ + 1) / 2, f"MNIST/VAE_outputs/interpolation.png", nrow=10)

with torch.no_grad():
    z_test = torch.randn(64, 20).to(device)
    test_VAE = model.decoder(z_test)
    save_image((test_VAE.view(64, 1, 28, 28).cpu() + 1) / 2,
               f"MNIST/VAE_outputs/test.png", nrow=8)

In [None]:
# Per-batch training loss of the MNIST VAE.
ax = plt.gca()
ax.plot(lossesVAE)
ax.grid()
ax.set_xlabel('Batch')
ax.set_ylabel('Loss')
plt.savefig('MNIST/VAE_MNIST')

## GAN MNIST

In [None]:
class Generator(nn.Module):
    """MLP generator mapping a latent vector to a flattened image in [-1, 1]."""

    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()
        width = 256
        # Hidden widths double at every layer: 256 -> 512 -> 1024.
        self.fc1 = nn.Linear(g_input_dim, width)
        self.fc2 = nn.Linear(width, 2 * width)
        self.fc3 = nn.Linear(2 * width, 4 * width)
        self.fc4 = nn.Linear(4 * width, g_output_dim)

    def forward(self, x):
        h = x
        for hidden in (self.fc1, self.fc2, self.fc3):
            h = F.leaky_relu(hidden(h), 0.2)
        return torch.tanh(self.fc4(h))
    
class Discriminator(nn.Module):
    """MLP discriminator mapping a flattened image to a probability of being real."""

    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        # Widths: d_input_dim -> 1024 -> 512 -> 256 -> 1.
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    def forward(self, x):
        # training=self.training so that D.eval() disables dropout; the
        # F.dropout default (training=True) would keep it on permanently.
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [None]:
z_dim = 100

G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)

lr = 0.0002
G_optimizer = optim.Adam(G.parameters(), lr = lr)
D_optimizer = optim.Adam(D.parameters(), lr = lr)

# numel() counts the scalars in each tensor; summing shape dimensions (the
# previous approach) badly under-counts weight matrices.
n_weights = sum(parameter.numel() for parameter in G.parameters())

print(f'There are {n_weights} weights in the GAN Generator.')

n_weights = sum(parameter.numel() for parameter in D.parameters())

print(f'There are {n_weights} weights in the GAN Discriminator.')

In [None]:
n_epoch = 200
lossesGAN_D = []
lossesGAN_G = []
os.makedirs("MNIST/GAN_outputs", exist_ok=True)

for epoch in range(1, n_epoch+1):
    D_epochloss, G_epochloss = 0, 0
    for batch_idx, (x, _) in enumerate(mnist_loader):
        # One discriminator step followed by one generator step per batch.
        D_loss, G_loss = D_train(x), G_train(x)
        D_epochloss += D_loss
        G_epochloss += G_loss
        lossesGAN_D.append(D_loss)
        lossesGAN_G.append(G_loss)

    # Snapshot 16 samples; no autograd graph is needed for this.
    with torch.no_grad():
        z = torch.randn(16, z_dim).to(device)
        G_output = G(z)
    # Generator output is tanh-bounded in [-1, 1]; save_image expects [0, 1].
    save_image((G_output.reshape(16, 1, 28, 28).cpu() + 1) / 2,
               f"MNIST/GAN_outputs/output_{epoch}.png", nrow=8)

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epoch, D_epochloss/len(mnist_loader), G_epochloss/len(mnist_loader)))

In [None]:
#torch.save(G, 'MNIST/Generator')
#torch.save(D, 'MNIST/Discriminator')

# torch.load unpickles arbitrary objects: only load checkpoints you trust.
# map_location lets GPU-saved checkpoints be restored on a CPU-only host.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects a
# whole pickled nn.Module; pass weights_only=False there if loading fails.
G = torch.load('MNIST/Generator', map_location=device)
D = torch.load('MNIST/Discriminator', map_location=device)

In [None]:
# Latent size of the MNIST generator (z_dim = 100 when it was built).
latent_dim = 100

with torch.no_grad():
    z_test = torch.randn(64, latent_dim).to(device)
    test_GAN = G(z_test)
    # tanh output in [-1, 1] -> [0, 1] for save_image.
    save_image((test_GAN.view(64, 1, 28, 28).cpu() + 1) / 2,
               f"MNIST/GAN_outputs/test.png", nrow=8)

# 10 rows, each decoding 10 evenly spaced points (endpoints included) between
# two random latent codes.
rows = []
with torch.no_grad():
    for _ in range(10):
        z = np.random.randn(latent_dim, 2)
        # Replaces the old loop, whose always-true `i != 0 or i != 9` test
        # made the last row overshoot the second endpoint.
        z_interpolate = np.linspace(z[:, 0], z[:, 1], 10)
        z_interpolate = torch.tensor(z_interpolate, dtype=torch.float).to(device)

        interp_GAN = G(z_interpolate).view(10, 1, 28, 28).cpu()
        rows.append(interp_GAN)

    # Rebuilt from scratch each run instead of appending to a stale result__.
    result__ = torch.cat(rows)
    save_image((result__ + 1) / 2, f"MNIST/GAN_outputs/interpolation.png", nrow=10)

In [None]:
# Per-batch discriminator and generator losses of the MNIST GAN.
ax = plt.gca()
for series, name in ((lossesGAN_D, 'Discriminator'), (lossesGAN_G, 'Generator')):
    ax.plot(series, label=name)
ax.grid()
ax.legend()
ax.set_xlabel('Batch')
ax.set_ylabel('Loss')
plt.savefig('MNIST/GAN_MNIST')

# EXPERIMENT IMAGENETTE

In [None]:
bs = 48

# Tensors stay in [0, 1]: unlike the MNIST experiment, no Normalize step is
# applied here (a 0.5/0.5 Normalize was tried and disabled).
transform = transforms.Compose([transforms.ToTensor(), transforms.CenterCrop(160)])

trainset = torchvision.datasets.ImageFolder(
    root="data/imagenette2-160/train/",
    transform=transform,
)
imag_loader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True)

# Shape of one transformed sample as a 1-D tensor (presumably [3, 160, 160]).
first_image, _ = trainset[0]
imag_dim = torch.tensor(first_image.shape)

del trainset, first_image

## VAE IMAG

In [None]:
class VAE(nn.Module):
    """Convolutional VAE for 3x160x160 images.

    Args:
        D: stored as ``self.D``; not otherwise used by this architecture.
        M: latent dimensionality.
    """

    def __init__(self, D, M):
        super(VAE, self).__init__()
        self.D = D
        self.M = M

        # Encoder: four 3x3 convs, each followed by a 2x2 max-pool, take a
        # 160x160 input to 128 channels at 8x8 (160->79->38->18->8). A linear
        # layer then emits mu and log_std stacked (2*M values).
        self.enc1 = nn.Conv2d(3, 16, kernel_size=3)
        self.enc2 = nn.Conv2d(16, 32, kernel_size=3)
        self.enc3 = nn.Conv2d(32, 64, kernel_size=3)
        self.enc4 = nn.Conv2d(64, 128, kernel_size=3)
        self.enc5 = nn.Linear(128 * 8 * 8, self.M * 2)

        # Decoder: project to 256x8x8, then upsample with transposed convs
        # (8->18->38->78->158->160->162) before resizing to 160x160.
        self.dec1 = nn.Linear(self.M, 256 * 8 * 8)
        self.dec2 = nn.ConvTranspose2d(256, 128, 4, stride=2)
        self.norm2 = nn.BatchNorm2d(128)
        self.dec3 = nn.ConvTranspose2d(128, 64, 4, stride=2)
        self.norm3 = nn.BatchNorm2d(64)
        self.dec4 = nn.ConvTranspose2d(64, 32, 4, stride=2)
        self.norm4 = nn.BatchNorm2d(32)
        self.dec5 = nn.ConvTranspose2d(32, 16, 4, stride=2)
        self.norm5 = nn.BatchNorm2d(16)
        self.dec6 = nn.ConvTranspose2d(16, 8, 3)
        self.norm6 = nn.BatchNorm2d(8)
        self.dec7 = nn.ConvTranspose2d(8, 3, 3)

    def encoder(self, x):
        """Return ``(mu, log_std)``, each of shape (batch, M)."""
        h = x
        for conv in (self.enc1, self.enc2, self.enc3, self.enc4):
            h = F.max_pool2d(F.relu(conv(h)), 2)
        stats = self.enc5(torch.flatten(h, 1)).view(-1, 2, self.M)
        mu, log_std = stats[:, 0, :], stats[:, 1, :]
        return mu, log_std

    def reparameterize(self, mu, log_std):
        """Sample ``mu + eps * exp(log_std)`` with standard-normal ``eps``."""
        noise = torch.randn_like(log_std)
        return mu + noise * log_std.exp()

    def decoder(self, z):
        """Map latents of shape (batch, M) to images of shape (batch, 3, 160, 160)."""
        h = F.leaky_relu(self.dec1(z)).view(-1, 256, 8, 8)
        stages = (
            (self.dec2, self.norm2),
            (self.dec3, self.norm3),
            (self.dec4, self.norm4),
            (self.dec5, self.norm5),
            (self.dec6, self.norm6),
        )
        for deconv, norm in stages:
            h = norm(F.leaky_relu(deconv(h)))
        # dec7 produces 162x162; nearest-neighbour resize back to 160x160.
        return F.interpolate(self.dec7(h), size=(160, 160), mode='nearest')

In [None]:
model = VAE(imag_dim, 100)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr = 0.0002)

# numel() counts the scalars in each tensor; summing shape dimensions (the
# previous approach) badly under-counts conv kernels and weight matrices.
n_weights = sum(parameter.numel() for parameter in model.parameters())

print(f'There are {n_weights} weights in the VAE.')

In [None]:
n_epoch = 100
losses = []
os.makedirs("IMAG/VAE_outputs", exist_ok=True)
for epoch in range(1, n_epoch+1):
    epochloss = 0
    for batch_idx, (x, _) in enumerate(imag_loader):
        loss = trainVAE(x.to(device))
        epochloss += loss
        losses.append(loss)
        print ("\r Batch: [{}/{}]".format(batch_idx+1, len(imag_loader)), end="")

    # Reconstruct (up to) the first 8 images of the last batch for a visual
    # check. The encoder returns log standard deviations, not log variances.
    n_show = min(8, x.shape[0])
    with torch.no_grad():
        mu, log_std = model.encoder(x[:n_show].to(device))
        z = model.reparameterize(mu, log_std)
        x_hat = model.decoder(z)
    save_image(torch.cat((x[:n_show].view(n_show, 3, 160, 160),
                          x_hat.view(n_show, 3, 160, 160).cpu())),
               f"IMAG/VAE_outputs/output_{epoch}.png", nrow=n_show)

    print('[%d/%d]: loss: %.3f' % (
            (epoch), n_epoch, epochloss/len(imag_loader)))

In [None]:
os.makedirs('IMAG', exist_ok=True)
torch.save(model, 'IMAG/VAE')

# torch.load unpickles arbitrary objects: only load checkpoints you trust.
# map_location lets a GPU-saved checkpoint be restored on a CPU-only host.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects a
# whole pickled nn.Module; pass weights_only=False there if loading fails.
model = torch.load('IMAG/VAE', map_location=device)

In [None]:
with torch.no_grad():
    z_test = torch.randn(64, 100).to(device)
    test_VAE = model.decoder(z_test)
    save_image(test_VAE.view(64, 3, 160, 160).cpu(),
               f"IMAG/VAE_outputs/test.png", nrow=8)


# 10 rows, each decoding 10 evenly spaced points (endpoints included) between
# two random 100-dimensional latent codes.
rows = []
with torch.no_grad():
    for _ in range(10):
        z = np.random.randn(100, 2)
        # Replaces the old loop, whose always-true `i != 0 or i != 9` test
        # made the last row overshoot the second endpoint.
        z_interpolate = np.linspace(z[:, 0], z[:, 1], 10)
        z_interpolate = torch.tensor(z_interpolate, dtype=torch.float).to(device)

        interp_VAE = model.decoder(z_interpolate).view(10, 3, 160, 160).cpu()
        rows.append(interp_VAE)

    # Built from a fresh list: the old `"result_" in dir()` check picked up the
    # MNIST cell's 28x28 result_ and failed to concatenate it with these images.
    result_ = torch.cat(rows)
    save_image(result_, f"IMAG/VAE_outputs/interpolation.png", nrow=10)

In [None]:
# Display the shape of the last decoded interpolation strip; the preceding
# cell reshapes it with .view(10, 3, 160, 160).
interp_VAE.shape

In [None]:
# Interpolate in latent space between pairs of real images from one batch.
# Each output row is: source image, 10 decoded codes, target image.
x, _ = next(iter(imag_loader))

rows = []
with torch.no_grad():
    for _ in range(10):
        # Two distinct images from the batch: index 0 = start, 1 = end.
        pair = torch.as_tensor(np.random.choice(x.shape[0], 2, replace=False))
        endpoints = x[pair]
        mu, log_std = model.encoder(endpoints.to(device))
        z = model.reparameterize(mu, log_std)

        # 10 evenly spaced codes from z[0] to z[1], both endpoints included.
        # The old loop's always-true `i != 0 or i != 9` test made the last
        # code overshoot the end image's code.
        weights = torch.linspace(0, 1, 10, device=z.device).unsqueeze(1)
        z_interpolate = z[0] + weights * (z[1] - z[0])

        interp_VAE = model.decoder(z_interpolate).view(10, 3, 160, 160).cpu()
        interp_VAE = torch.cat((endpoints[:1], interp_VAE, endpoints[1:]))
        rows.append(interp_VAE)

    # Built from a fresh list, so no `del result_` is needed beforehand (which
    # raised NameError when result_ did not exist yet).
    result_ = torch.cat(rows)
    save_image(result_, f"IMAG/VAE_outputs/interpolation2.png", nrow=12)

In [None]:
# Append the losses of the latest training run to the saved loss history.
# NOTE(review): losses_p1 is only created by the later `losses_p1 = losses`
# cell, so this cell depends on that one having run first. If losses_p1 is
# still the very same list object as losses, extend() duplicates its entries.
losses_p1.extend(losses)

In [None]:
# Per-batch training loss history of the Imagenette VAE.
ax = plt.gca()
ax.plot(losses_p1)
ax.grid()
ax.set_xlabel('Batch')
ax.set_ylabel('Loss')
plt.savefig('IMAG/VAE_IMAG')

In [None]:
# Keep the current loss history under a second name. This binds the same
# list object; it does not make a copy.
losses_p1 = losses

## GAN IMAG

In [None]:
def D_train(x):
    """One discriminator update for the Imagenette GAN; return the loss.

    Real images get a smoothed target of 0.95 instead of 1 (one-sided label
    smoothing); generated images get 0. Uses the globals ``D``, ``G``,
    ``D_optimizer``, ``criterionGAN``, ``z_dim`` and ``device``.
    """
    D.zero_grad()

    n = x.shape[0]
    x_real = x.to(device)
    y_real = torch.full((n, 1), 0.95).to(device)
    D_real_loss = criterionGAN(D(x_real), y_real)

    z = torch.randn(n, z_dim).to(device)
    # detach(): this step only updates D, so skip backpropagation through G.
    x_fake = G(z).detach()
    y_fake = torch.zeros(n, 1).to(device)
    D_fake_loss = criterionGAN(D(x_fake), y_fake)

    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

#def G_train(x, real):
#    G.zero_grad()
#
#    z = torch.randn(x.shape[0], z_dim).to(device)
#    y = torch.ones(x.shape[0], 1).to(device)
#
#    G_output = G(z)
#    D_output = D(G_output, gen=True)
#    
#    D_output, real = torch.mean(D_output, 0), torch.mean(real, 0)
#    G_loss = F.mse_loss(D_output, real)
#
#    G_loss.backward()
#    G_optimizer.step()
#        
#    return G_loss.data.item()

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator producing 3x160x160 images.

    Each latent vector is reshaped directly into a 256x8x8 feature map, so it
    must hold 256*8*8 values. ``g_input_dim`` and ``g_output_dim`` are
    accepted for interface compatibility but not used.
    """

    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()
        # Four stride-2 upsampling stages: 8 -> 18 -> 38 -> 78 -> 158 pixels.
        self.gen2 = nn.ConvTranspose2d(256, 128, 4, stride=2)
        self.norm2 = nn.BatchNorm2d(128)
        self.gen3 = nn.ConvTranspose2d(128, 64, 4, stride=2)
        self.norm3 = nn.BatchNorm2d(64)
        self.gen4 = nn.ConvTranspose2d(64, 32, 4, stride=2)
        self.norm4 = nn.BatchNorm2d(32)
        self.gen5 = nn.ConvTranspose2d(32, 16, 4, stride=2)
        self.norm5 = nn.BatchNorm2d(16)
        # Two kernel-3 stages: 158 -> 160 -> 162 pixels, ending at 3 channels.
        self.gen6 = nn.ConvTranspose2d(16, 8, 3)
        self.norm6 = nn.BatchNorm2d(8)
        self.gen7 = nn.ConvTranspose2d(8, 3, 3)

    def forward(self, z):
        h = z.view(-1, 256, 8, 8)
        stages = (
            (self.gen2, self.norm2),
            (self.gen3, self.norm3),
            (self.gen4, self.norm4),
            (self.gen5, self.norm5),
            (self.gen6, self.norm6),
        )
        for deconv, norm in stages:
            h = norm(F.leaky_relu(deconv(h)))
        # Resize the 162x162 output back to 160x160.
        return F.interpolate(self.gen7(h), size=(160, 160), mode='nearest')

    
class Discriminator(nn.Module):
    """Convolutional discriminator: 3x160x160 image -> probability of being real.

    ``d_input_dim`` and the ``gen`` flag of ``forward`` are accepted for
    interface compatibility but not used.
    """

    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        # Four conv(3x3) + max-pool(2) stages: 160 -> 79 -> 38 -> 18 -> 8 pixels.
        self.dis1 = nn.Conv2d(3, 16, kernel_size=3)
        self.dis2 = nn.Conv2d(16, 32, kernel_size=3)
        self.dis3 = nn.Conv2d(32, 64, kernel_size=3)
        self.dis4 = nn.Conv2d(64, 128, kernel_size=3)
        self.dis5 = nn.Linear(128 * 8 * 8, 1)

    def forward(self, x, gen=False):
        h = x
        for conv in (self.dis1, self.dis2, self.dis3, self.dis4):
            h = F.max_pool2d(F.leaky_relu(conv(h)), 2)
        logit = self.dis5(torch.flatten(h, 1))
        return torch.sigmoid(logit)

In [None]:
# The Imagenette generator reshapes z directly into a 256x8x8 feature map.
z_dim = 256*8*8

G = Generator(g_input_dim = z_dim, g_output_dim = imag_dim).to(device)
D = Discriminator(imag_dim).to(device)

lr_G = 0.00001
lr_D = 0.00005
G_optimizer = optim.Adam(G.parameters(), lr = lr_G)
D_optimizer = optim.Adam(D.parameters(), lr = lr_D)

# numel() counts the scalars in each tensor; summing shape dimensions (the
# previous approach) badly under-counts conv kernels and weight matrices.
n_weights = sum(parameter.numel() for parameter in G.parameters())

print(f'There are {n_weights} weights in the GAN Generator.')

n_weights = sum(parameter.numel() for parameter in D.parameters())

print(f'There are {n_weights} weights in the GAN Discriminator.')

In [None]:
n_epoch = 200
lossesGAN_D = []
lossesGAN_G = []
os.makedirs("IMAG/GAN_outputs", exist_ok=True)

for epoch in range(1, n_epoch+1):
    D_epochloss, G_epochloss = 0, 0
    for batch_idx, (x, _) in enumerate(imag_loader):
        # One discriminator step followed by one generator step per batch.
        D_loss = D_train(x)
        G_loss = G_train(x)
        D_epochloss += D_loss
        G_epochloss += G_loss
        lossesGAN_D.append(D_loss)
        lossesGAN_G.append(G_loss)
        print ("\r Batch: [{}/{}]".format(batch_idx+1, len(imag_loader)), end="")

    # Snapshot 16 samples; no autograd graph is needed for this.
    with torch.no_grad():
        z = torch.randn(16, z_dim).to(device)
        G_output = G(z)
    save_image(G_output.reshape(16, 3, 160, 160).cpu(), f"IMAG/GAN_outputs/output_{epoch}.png", nrow=8)

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epoch, D_epochloss/len(imag_loader), G_epochloss/len(imag_loader)))

In [None]:
# Persist both networks as whole pickled modules.
for net, path in ((G, 'IMAG/Generator'), (D, 'IMAG/Discriminator')):
    torch.save(net, path)

# To restore instead of saving:
#G = torch.load('IMAG/Generator')
#D = torch.load('IMAG/Discriminator')

In [None]:
with torch.no_grad():
    z_test = torch.randn(64, z_dim).to(device)
    test_GAN = G(z_test)
    save_image(test_GAN.view(64, 3, 160, 160).cpu(),
               f"IMAG/GAN_outputs/test.png", nrow=8)

# 10 rows, each decoding 10 evenly spaced points (endpoints included) between
# two random latent codes.
rows = []
with torch.no_grad():
    for _ in range(10):
        z = np.random.randn(z_dim, 2)
        # Replaces the old loop, whose always-true `i != 0 or i != 9` test
        # made the last row overshoot the second endpoint.
        z_interpolate = np.linspace(z[:, 0], z[:, 1], 10)
        z_interpolate = torch.tensor(z_interpolate, dtype=torch.float).to(device)

        # Keep only the top-left 150x150 region of each 160x160 sample.
        interp_GAN = G(z_interpolate).view(10, 3, 160, 160)[:, :, :150, :150].cpu()
        rows.append(interp_GAN)

    # Built from a fresh list: the old `"result__" in dir()` check picked up
    # the MNIST cell's 28x28 result__ and failed to concatenate with these images.
    result__ = torch.cat(rows)
    save_image(result__, f"IMAG/GAN_outputs/interpolation.png", nrow=10)

In [None]:
# Per-batch discriminator and generator losses of the Imagenette GAN.
ax = plt.gca()
for series, name in ((lossesGAN_D, 'Discriminator'), (lossesGAN_G, 'Generator')):
    ax.plot(series, label=name)
ax.grid()
ax.legend()
ax.set_xlabel('Batch')
ax.set_ylabel('Loss')
plt.savefig('IMAG/GAN_IMAG')