### Libraries

In [14]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

import datetime

# custom imports
sys.path.append('../src')
from utils.dataset import load_datasets
from utils.foldergen import generate_folder
from classes.gan.gan import Generator, Discriminator
from classes.gan.ganCNN import GeneratorCNN, DiscriminatorCNN

# Device configuration: use the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Hyperparameters

In [15]:
datasettype = "MNIST"  # "MNIST" or "CIFAR10"

# Train a new network (False) or continue training a previously saved one (True).
continueTraining = False

# Optional RNG seed for reproducible runs; None keeps the unseeded behaviour.
seed = None
if seed is not None:
    torch.manual_seed(seed)

# learning rate
lr = 0.0002

# number of epochs and mini-batch size
num_epochs = 100
batch_size = 25

# hidden size passed to both the generator and the discriminator
hidden_dim = 100

### Get Data

In [16]:
# Both datasets are scaled to [-1, 1]: ToTensor yields values in [0, 1], then
# Normalize(mean=0.5, std=0.5) applies (x - 0.5) / 0.5 to every channel.
if datasettype == "MNIST":
    transform = transforms.Compose(
        [
            transforms.ToTensor(),  # create PyTorch Tensor | shape: (channels, height, width)
            transforms.Normalize([0.5], [0.5]),  # single greyscale channel -> [-1, 1]
        ]
    )
elif datasettype == "CIFAR10":
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # three channels -> [-1, 1]
        ]
    )
else:
    # Fail here with a clear message instead of a NameError on `transform` below.
    raise ValueError(f"Unsupported datasettype {datasettype!r}; expected 'MNIST' or 'CIFAR10'")

# create folder structure
data_folder = "data"
load_folder = "load"
output_folder = "output"
folders = [data_folder, load_folder, output_folder]
generate_folder(folders)

dataset_train, loader_train = load_datasets(datasettype, transform, batch_size)
# download=False: presumably relies on the train call above having fetched the files -- TODO confirm in load_datasets
dataset_test, loader_test = load_datasets(datasettype, transform, batch_size, train=False, download=False)

### Network

In [17]:
# length of the latent noise vector fed to the generator
z_dim = 100

# automatically calculate the dimension -> does not work with CIFAR10
# data_dim = 1
# for dimension in range(1, dataset_train.data.ndim):
#     data_dim *= dataset_train.data.size(dimension)

# Per-dataset geometry:
#   image_channels - colour channels of one image (used by the CNN variant below)
#   data_dim       - number of values in one flattened image (used by the MLP GAN)
if datasettype == "MNIST":
    image_channels = 1  # MNIST is greyscale: one channel
    data_dim = 28 * 28
elif datasettype == "CIFAR10":
    image_channels = 3
    data_dim = 32 * 32 * 3
else:
    raise ValueError(f"Unsupported datasettype {datasettype!r}; expected 'MNIST' or 'CIFAR10'")

Gen = Generator(g_input_dim=z_dim, g_hidden_dim=hidden_dim, g_output_dim=data_dim).to(device)
Dis = Discriminator(d_input_dim=data_dim, d_hidden_dim=hidden_dim, data_dim=data_dim).to(device)

# Convolutional alternative.
# NOTE(review): unlike the MLP models above these are not moved to `device`, and img_size
# receives the flattened length data_dim -- confirm what GeneratorCNN/DiscriminatorCNN expect.
# Gen = GeneratorCNN(latent_dim = z_dim, img_channels = image_channels, img_size = data_dim)
# Dis = DiscriminatorCNN(img_channels = image_channels, img_size = data_dim)

lossFunction = nn.BCELoss()

# Optimizers: they hold references to the parameters of exactly these Gen/Dis instances,
# so later code must update these modules in place rather than replace them.
Gen_optimizer = optim.Adam(Gen.parameters(), lr=lr)
Dis_optimizer = optim.Adam(Dis.parameters(), lr=lr)

KeyboardInterrupt: 

### Training

In [None]:
def Dis_train(x):
    """Run one optimisation step of the discriminator and return its loss.

    The discriminator is pushed towards 1 on the real batch ``x`` and towards 0
    on an equally sized batch that the generator produces from Gaussian noise.

    Uses the notebook globals ``Gen``, ``Dis``, ``lossFunction``,
    ``Dis_optimizer``, ``z_dim`` and ``device``.

    Args:
        x: a batch of real samples from ``loader_train``; dimension 0 is the batch size.

    Returns:
        The summed real + fake BCE loss as a Python float.
    """
    Dis.zero_grad()

    # Size targets and noise from the actual batch, so a smaller final batch
    # does not cause a BCELoss input/target size mismatch.
    n = x.size(0)

    # real data -> label 1
    x_real = x.to(device)
    y_real = torch.ones(n, 1, device=device)
    D_real_loss = lossFunction(Dis(x_real), y_real)

    # fake data -> label 0; detached so this step does not backpropagate into the generator
    z = torch.randn(n, z_dim).to(device)
    x_fake = Gen(z).detach()
    y_fake = torch.zeros(n, 1, device=device)
    D_fake_loss = lossFunction(Dis(x_fake), y_fake)

    # loss
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    Dis_optimizer.step()

    return D_loss.item()

In [None]:
def Gen_train(x):
    """Run one optimisation step of the generator and return its loss.

    Generates a batch from Gaussian noise, scores it with the discriminator and
    trains the generator so that the discriminator labels it 1 ("real").
    Gradients also flow into ``Dis``'s parameters, but only ``Gen_optimizer`` steps.

    Uses the notebook globals ``Gen``, ``Dis``, ``lossFunction``,
    ``Gen_optimizer``, ``z_dim`` and ``device``.

    Args:
        x: the current batch of real samples. Only its size (dimension 0) is
           used, so each generator step matches the discriminator step batch for batch.

    Returns:
        The generator's BCE loss as a Python float.
    """
    Gen.zero_grad()

    n = x.size(0)
    z = torch.randn(n, z_dim).to(device)
    y = torch.ones(n, 1, device=device)

    # loss: push D(G(z)) towards the "real" label
    G_loss = lossFunction(Dis(Gen(z)), y)
    G_loss.backward()
    Gen_optimizer.step()

    return G_loss.item()

In [None]:
# Load model for further training

discriminator_file = "discriminator.pth"
generator_file = "generator.pth"

generator_path = os.path.join(load_folder, generator_file)
discriminator_path = os.path.join(load_folder, discriminator_file)

if continueTraining:
    if os.path.isfile(generator_path) and os.path.isfile(discriminator_path):
        # Load into the existing Gen/Dis rather than constructing new modules:
        # Gen_optimizer and Dis_optimizer were created for these exact parameters, so
        # replacing the modules would leave the optimizers stepping the discarded networks.
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        Gen.load_state_dict(torch.load(generator_path, map_location=device))
        Dis.load_state_dict(torch.load(discriminator_path, map_location=device))
    else:
        print(f"continueTraining is set but no saved models were found in '{load_folder}'; training from scratch.")

In [None]:
# Per-run output folders, named after the start time (day, month, hour, minute).
current_time = datetime.datetime.now()
formatted_time = current_time.strftime("%d%m-%H%M")

pictures_folder = os.path.join(output_folder, formatted_time, "pic")
model_folder = os.path.join(output_folder, formatted_time, "model")
os.makedirs(pictures_folder, exist_ok=True)
os.makedirs(model_folder, exist_ok=True)

# (channels, height, width) of one image, used to un-flatten generator output for saving:
# MNIST 1x28x28 (= 28 * 28), CIFAR10 3x32x32 (= 32 * 32 * 3).
image_shape = {"MNIST": (1, 28, 28), "CIFAR10": (3, 32, 32)}[datasettype]
sample_every = 10  # save a grid of generated samples every `sample_every` epochs

# Training loop
for epoch in range(1, num_epochs + 1):
    D_losses, G_losses = [], []
    for x, _ in loader_train:
        D_losses.append(Dis_train(x))
        G_losses.append(Gen_train(x))

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            epoch, num_epochs, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))

    if epoch % sample_every == 0:
        with torch.no_grad():
            test_z = torch.randn(batch_size, z_dim).to(device)
            generated = Gen(test_z)

            # Zero-pad the epoch number to the width of num_epochs so the files sort in order.
            formatted_number = "{:0{}}".format(epoch, len(str(num_epochs)))

            # Training data was normalised to [-1, 1]; map back to the [0, 1] range that
            # save_image expects (assumes Gen outputs values in [-1, 1] -- TODO confirm in Generator).
            images = generated.view(generated.size(0), *image_shape) * 0.5 + 0.5
            save_image(images, os.path.join(pictures_folder, formatted_number + '.png'))

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [25, 784]

In [None]:
# Save a final grid of samples from the trained generator.
# (channels, height, width) of one image: MNIST 1x28x28, CIFAR10 3x32x32.
image_shape = {"MNIST": (1, 28, 28), "CIFAR10": (3, 32, 32)}[datasettype]

with torch.no_grad():
    test_z = torch.randn(batch_size, z_dim).to(device)
    generated = Gen(test_z)

    # Map from the [-1, 1] training range back to [0, 1] for save_image
    # (assumes Gen outputs values in [-1, 1] -- TODO confirm in Generator).
    images = generated.view(generated.size(0), *image_shape) * 0.5 + 0.5
    save_image(images, os.path.join(pictures_folder, 'final.png'))

### Save Models

In [None]:
# Persist only the learned weights (state dicts) into this run's model folder.
torch.save(Gen.state_dict(), f"{model_folder}/generator.pth")
torch.save(Dis.state_dict(), f"{model_folder}/discriminator.pth")