### Libraries

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

import datetime

# Device configuration: run on the GPU when CUDA is usable, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Hyperparameters

In [None]:
# When True, previously saved weights are loaded (if present) before training;
# when False, training always starts from freshly initialised networks.
continueTraining = True

# Step size handed to both Adam optimizers
lr = 0.0002

# Number of passes over the training set, and samples per mini-batch
num_epochs, batch_size = 100, 25

# Width of every hidden layer in the generator and the discriminator
hidden_dim = 100

### Get Data

In [None]:
# Preprocessing applied to every image as it is read from the dataset
transform = transforms.Compose(
    [
        transforms.ToTensor(),  # create PyTorch Tensor | shape: (channels, height, width)
        # (x - 0.5) / 0.5 maps pixel values to [-1, 1], the same range the
        # generator's Tanh output layer produces
        transforms.Normalize([0.5], [0.5]),
    ]
)

# Dataset root directory. exist_ok=True makes this a no-op when the folder is
# already there and avoids the check-then-create race of a separate exists() test.
folder_path = "data"
os.makedirs(folder_path, exist_ok=True)

# create datasets
dataset_train = datasets.MNIST(
    root=folder_path, train=True, transform=transform, download=True
)
# download=False: relies on the files already placed under `folder_path` by the
# training-set download above
dataset_test = datasets.MNIST(
    root=folder_path, train=False, transform=transform, download=False
)

# create dataloaders: training batches are reshuffled every epoch, test batches
# keep the dataset order
loader_train = torch.utils.data.DataLoader(
    dataset=dataset_train, batch_size=batch_size, shuffle=True
)
loader_test = torch.utils.data.DataLoader(
    dataset=dataset_test, batch_size=batch_size, shuffle=False
)

### Generator

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to a flattened sample.

    The network is four linear layers; the first three are followed by
    LeakyReLU(0.2) and the last by Tanh, so every output value lies in [-1, 1].

    Args:
        g_input_dim: Length of the input noise vector.
        g_hidden_dim: Width of each of the three hidden layers.
        g_output_dim: Length of the generated (flattened) sample.
    """

    def __init__(self, g_input_dim, g_hidden_dim, g_output_dim):
        super().__init__()

        # Layers are created strictly in forward order so that parameter
        # initialisation and the `model.<index>` state-dict keys stay stable.
        stack = [nn.Linear(g_input_dim, g_hidden_dim), nn.LeakyReLU(0.2)]
        for _ in range(2):
            stack.append(nn.Linear(g_hidden_dim, g_hidden_dim))
            stack.append(nn.LeakyReLU(0.2))
        stack.append(nn.Linear(g_hidden_dim, g_output_dim))
        stack.append(nn.Tanh())

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the generated samples for the batch of noise vectors `x`."""
        generated = self.model(x)
        return generated

### Discriminator

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring a flattened sample.

    The network is four linear layers; the first three are followed by
    LeakyReLU(0.2) and the last by Sigmoid, so the single output per sample
    lies in (0, 1).

    Args:
        d_input_dim: Length of one flattened input sample.
        d_hidden_dim: Width of each of the three hidden layers.
    """

    def __init__(self, d_input_dim, d_hidden_dim):
        super().__init__()

        # Layers are created strictly in forward order so that parameter
        # initialisation and the `model.<index>` state-dict keys stay stable.
        stack = [nn.Linear(d_input_dim, d_hidden_dim), nn.LeakyReLU(0.2)]
        for _ in range(2):
            stack.append(nn.Linear(d_hidden_dim, d_hidden_dim))
            stack.append(nn.LeakyReLU(0.2))
        stack.append(nn.Linear(d_hidden_dim, 1))
        stack.append(nn.Sigmoid())

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return one score in (0, 1) for each row of `x`."""
        score = self.model(x)
        return score

### Network

In [None]:
# Length of the noise vector fed to the generator
z_dim = 100

# Flattened size of one sample: the product of every dimension of the training
# data after the leading (sample index) one
data_dim = 1
for extent in dataset_train.data.shape[1:]:
    data_dim *= extent

# Build both networks (generator first, then discriminator) and move them to
# the selected device
Gen = Generator(g_input_dim=z_dim, g_hidden_dim=hidden_dim, g_output_dim=data_dim)
Gen = Gen.to(device)
Dis = Discriminator(d_input_dim=data_dim, d_hidden_dim=hidden_dim)
Dis = Dis.to(device)

# Binary cross-entropy between the discriminator's sigmoid output and the 0/1 labels
lossFunction = nn.BCELoss()

# Optimizers: one Adam instance per network, sharing the same learning rate

Gen_optimizer = optim.Adam(Gen.parameters(), lr=lr)
Dis_optimizer = optim.Adam(Dis.parameters(), lr=lr)

### Training

In [None]:
def Dis_train(x):
    """Run one discriminator update on a real batch and an equally sized fake batch.

    Args:
        x: Batch of real samples; it is flattened to rows of length ``data_dim``.

    Returns:
        float: Sum of the BCE loss on the real batch (target 1) and on the
        generated batch (target 0) for this step.
    """
    Dis.zero_grad()

    # real data, labelled 1
    x_real = x.view(-1, data_dim).to(device)
    # Size labels and noise from the batch actually received rather than the
    # global batch_size, so a shorter final batch does not cause a shape
    # mismatch inside the loss.
    n_samples = x_real.size(0)
    y_real = torch.ones(n_samples, 1, device=device)

    D_real_loss = lossFunction(Dis(x_real), y_real)

    # fake data, labelled 0
    z = torch.randn(n_samples, z_dim).to(device)
    # detach(): only the discriminator is updated here, so there is no need to
    # backpropagate through (and accumulate gradients in) the generator.
    x_fake = Gen(z).detach()
    y_fake = torch.zeros(n_samples, 1, device=device)

    D_fake_loss = lossFunction(Dis(x_fake), y_fake)

    # loss
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    Dis_optimizer.step()

    return D_loss.item()

In [None]:
def Gen_train(x):
    """Run one generator update and return its loss.

    A batch of ``batch_size`` noise vectors is pushed through the generator and
    then the discriminator; the loss compares the discriminator's scores with
    all-ones targets, i.e. the generator is trained to have its samples scored
    as real.

    Args:
        x: Not used by this function; accepted so it can be called with the
            same argument as ``Dis_train``.

    Returns:
        float: The generator's BCE loss for this step.
    """
    Gen.zero_grad()

    noise = torch.randn(batch_size, z_dim).to(device)
    real_labels = torch.ones(batch_size, 1).to(device)

    # Score the freshly generated samples with the discriminator
    scores = Dis(Gen(noise))
    G_loss = lossFunction(scores, real_labels)

    # Backpropagate through the discriminator into the generator, then update
    # the generator's parameters
    G_loss.backward()
    Gen_optimizer.step()

    return G_loss.item()

In [None]:
# Load model for further training

folder_path = "load"
discriminator_file = "discriminator.pth"
generator_file = "generator.pth"

generator_path = os.path.join(folder_path, generator_file)
discriminator_path = os.path.join(folder_path, discriminator_file)

# Resume only when requested and both checkpoint files are present
# (isfile() is also False when the folder itself is missing).
if continueTraining and os.path.isfile(generator_path) and os.path.isfile(discriminator_path):
    # Load the weights into the EXISTING networks instead of rebinding Gen/Dis
    # to new instances: the optimizers were built from the existing networks'
    # parameters, so replacing the modules would leave the optimizers updating
    # objects that are no longer trained or saved.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    Gen.load_state_dict(torch.load(generator_path, map_location=device))
    Dis.load_state_dict(torch.load(discriminator_path, map_location=device))

In [None]:
# Create a per-run output directory named after the current day, month, hour
# and minute, with one subfolder for sample images and one for saved models.
current_time = datetime.datetime.now()
formatted_time = current_time.strftime("%d%m-%H%M")

folder_path = 'output/' + formatted_time
pictures_folder = os.path.join(folder_path, "pic")
model_folder = os.path.join(folder_path, "model")

# exist_ok=True makes each call a no-op for folders that already exist and
# avoids the check-then-create race of a separate os.path.exists() test.
for directory in (folder_path, pictures_folder, model_folder):
    os.makedirs(directory, exist_ok=True)

# Training loop: every mini-batch performs one discriminator step followed by
# one generator step; the per-epoch mean of each loss is printed.
for epoch in range(1, num_epochs + 1):
    D_losses, G_losses = [], []
    for x, _ in loader_train:  # labels are not needed for GAN training
        D_losses.append(Dis_train(x))
        G_losses.append(Gen_train(x))

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), num_epochs, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))

    # Every 10th epoch, save a grid of freshly generated samples
    if epoch % 10 == 0:
        with torch.no_grad():  # sampling only; no gradients required
            test_z = torch.randn(batch_size, z_dim).to(device)
            generated = Gen(test_z)

        # Zero-pad the epoch number to the width of num_epochs so the files
        # sort in training order
        formatted_number = "{:0{}}".format(epoch, len(str(num_epochs)))

        # The generator's Tanh output lies in [-1, 1] but save_image expects
        # values in [0, 1] and clips everything else; map the range back first
        # so the darker half of the intensities is not lost.
        # The 1x28x28 image shape is hard-coded here.
        images = (generated.view(generated.size(0), 1, 28, 28) + 1) / 2
        save_image(images, os.path.join(pictures_folder, formatted_number + '.png'))

In [None]:
# Save one last grid of generated samples after training has finished
with torch.no_grad():  # sampling only; no gradients required
    test_z = torch.randn(batch_size, z_dim).to(device)
    generated = Gen(test_z)

    # The generator's Tanh output lies in [-1, 1] but save_image expects values
    # in [0, 1] and clips everything else; map the range back before saving.
    # The 1x28x28 image shape is hard-coded here.
    images = (generated.view(generated.size(0), 1, 28, 28) + 1) / 2
    save_image(images, os.path.join(pictures_folder, 'final.png'))

### Save Models

In [None]:
# Write the learned parameters (state dicts) of both networks into the run's
# model folder: generator first, then discriminator
for network, filename in ((Gen, 'generator.pth'), (Dis, 'discriminator.pth')):
    torch.save(network.state_dict(), model_folder + '/' + filename)