# GANgster
In this notebook we train a GAN (a fully connected generator and discriminator) to generate grayscale mugshot images.

In [None]:
import os
import glob
import json

import torch
from torch import nn
from IPython import display

import PIL
import time

import math
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# os.system("cd /content/data && unzip /content/drive/MyDrive/data.zip")

In [None]:
# Side length (pixels) of the square images and size of the generator's latent input.
image_size = 96
generator_entry = 50

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("running on GPU" if device.type == "cuda" else "running on CPU")

# Preprocessing pipeline: tensor conversion, normalisation with mean 0.5 / std 0.5,
# resize to image_size x image_size, then grayscale conversion.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.Resize((image_size, image_size)),
        transforms.Grayscale(),
    ]
)

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """In-memory dataset of ``(image, filename)`` pairs read from one folder.

    Only every ``every_nth`` entry of the sorted folder listing is kept. Each
    kept file is opened with PIL, optionally transformed, and stored eagerly in
    ``self.df``.
    """

    def __init__(self, images_folder, transform=None, every_nth=10):
        """Load and cache the subsampled images.

        Args:
            images_folder: Directory whose direct entries are image files.
            transform: Optional callable applied to each opened PIL image.
            every_nth: Keep the entries whose position in the sorted listing
                is a multiple of this value (the default keeps every 10th).

        Raises:
            ValueError: If ``every_nth`` is smaller than 1.
        """
        if every_nth < 1:
            raise ValueError("every_nth must be >= 1")

        # glob returns entries in arbitrary order; sorting makes the subsample
        # identical on every run and on every filesystem.
        filepaths = sorted(glob.glob(f"{images_folder}/*"))

        self.df = []
        for filepath in filepaths[::every_nth]:
            # The context manager releases the file handle; load() pulls the
            # pixel data into memory first so the image stays usable afterwards.
            with PIL.Image.open(filepath) as image:
                image.load()
                sample = image if transform is None else transform(image)
            self.df.append((sample, os.path.basename(filepath)))

    def __len__(self):
        """Return the number of cached samples."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the ``(image, filename)`` pair stored at ``index``."""
        return self.df[index]


# Build the training set from the front-facing image folder, applying the
# preprocessing pipeline to every loaded image.
train_set = CustomDataset("../data/front/front", transform)

In [None]:
# Mini-batches of 64 samples, reshuffled at every pass over the dataset.
batch_size = 64
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)

In [None]:
# Preview up to 16 real training images in a 4x4 grid, titled with their filenames.
real_samples, real_labels = next(iter(train_loader))
# A batch can hold fewer than 16 images when the dataset is small, so cap the
# loop at the batch size instead of indexing past its end.
for i in range(min(16, len(real_samples))):
    ax = plt.subplot(4, 4, i + 1)
    ax.imshow(real_samples[i].reshape(image_size, image_size), cmap="gray")
    ax.set_title(real_labels[i], fontsize=12, pad=1.0)
    ax.set_xticks([])
    ax.set_yticks([])

In [None]:
class Discriminator(nn.Module):
    """Fully connected classifier that scores flattened images.

    The input is flattened to ``image_size ** 2`` features and passed through
    three Linear/ReLU/Dropout(0.3) stages (2048, 1024 and 512 units), followed
    by a final Linear + Sigmoid that yields one score per sample.
    """

    def __init__(self):
        super().__init__()
        widths = (image_size ** 2, 2048, 1024, 512)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.3)]
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        # The layer order fixes the state_dict keys (model.0, model.3, ...).
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample of ``x`` and return its score, shape (batch, 1)."""
        flattened = x.view(x.size(0), image_size ** 2)
        return self.model(flattened)

In [None]:
# Instantiate the discriminator on the selected device and restore saved weights.
discriminator = Discriminator().to(device=device)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU still loads when running on the CPU.
discriminator.load_state_dict(
    torch.load("../models/2000_96_1e-05/discriminator", map_location=device)
)

In [None]:
class Generator(nn.Module):
    """Fully connected network mapping latent vectors to single-channel images.

    A latent vector of ``generator_entry`` values goes through four Linear/ReLU
    stages (512, 1024, 2048 and 4096 units) and a final Linear + Tanh producing
    ``image_size ** 2`` values.
    """

    def __init__(self):
        super().__init__()
        widths = (generator_entry, 512, 1024, 2048, 4096)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers += [nn.Linear(widths[-1], image_size ** 2), nn.Tanh()]
        # The layer order fixes the state_dict keys (model.0, model.2, ...).
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return generated images shaped (batch, 1, image_size, image_size)."""
        flat_images = self.model(x)
        return flat_images.view(x.size(0), 1, image_size, image_size)

In [None]:
# Instantiate the generator on the selected device and restore saved weights.
generator = Generator().to(device=device)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU still loads when running on the CPU.
generator.load_state_dict(
    torch.load("../models/2000_96_1e-05/generator", map_location=device)
)

In [None]:
# Training hyperparameters.
lr = 1e-5  # learning rate shared by both Adam optimizers
num_epochs = 2000
show_every = 10  # log progress and capture a preview frame every N epochs

# Binary cross-entropy on the discriminator's sigmoid output.
loss_function = nn.BCELoss()

print("DATASET SIZE :", len(train_set))

In [None]:
# One Adam optimizer per network, both using the same learning rate.
optimizer_discriminator, optimizer_generator = (
    torch.optim.Adam(network.parameters(), lr=lr)
    for network in (discriminator, generator)
)

In [None]:
# Wall-clock duration of each logging interval, in seconds.
per_x_time = []
# Loss values recorded at each logging step, per network.
losses = {"generator": [], "discriminator": []}

# Fixed latent vector reused for every preview frame so frames are comparable.
show_test_sample = torch.randn(1, generator_entry).to(device)

In [None]:
# Preview frames (PIL images) appended by the training loop at each logging
# step; they are later written out as an animated GIF.
evolution: list = []

In [None]:
# Adversarial training loop. For every batch the discriminator is updated on
# real + generated images, then the generator is updated so that its output is
# classified as real. Every `show_every` epochs the last batch losses, timing
# statistics and a preview frame are recorded.
epoch_start = time.time()
for epoch in range(num_epochs):
    loss_discriminator = 0
    loss_generator = 0
    for real_samples, _ in train_loader:
        # --- Data for training the discriminator ---
        real_samples = real_samples.to(device=device)
        real_samples_size = real_samples.shape[0]

        # Real images are labelled 1 (sized to the actual batch, which may be
        # smaller than batch_size for the last batch).
        real_samples_labels = torch.ones((real_samples_size, 1)).to(device=device)

        # Generated images are labelled 0. They are detached so that the
        # discriminator's backward pass does not also walk the generator.
        latent_space_samples = torch.randn((real_samples_size, generator_entry)).to(device=device)
        generated_samples = generator(latent_space_samples).detach()
        generated_samples_labels = torch.zeros((real_samples_size, 1)).to(device=device)

        # Concatenate real and generated data
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat((real_samples_labels, generated_samples_labels))

        # --- Training the discriminator ---
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(output_discriminator, all_samples_labels)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # --- Data to train the generator ---
        latent_space_samples = torch.randn((real_samples_size, generator_entry)).to(
            device=device
        )

        # --- Training the generator ---
        # The generator is rewarded when the discriminator labels its output as real.
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(output_discriminator_generated, real_samples_labels)
        loss_generator.backward()
        optimizer_generator.step()

    # Show loss
    if epoch % show_every == 0 and epoch != 0:
        per_x_time.append(time.time() - epoch_start)
        epoch_start = time.time()
        losses["generator"].append(float(loss_generator))
        losses["discriminator"].append(float(loss_discriminator))

        average_time = sum(per_x_time) / len(per_x_time)

        ld = str(round(float(loss_discriminator), 10)).zfill(12)
        lg = str(round(float(loss_generator), 10)).zfill(12)
        it = str(round(per_x_time[-1], 4)).zfill(9)
        av = str(round(average_time, 2)).zfill(6)
        # Estimated time left: average interval duration x remaining intervals.
        seconds = round(average_time * ((num_epochs - epoch) / show_every), 2)
        m, s = divmod(seconds, 60)
        h, m = divmod(m, 60)
        tl = "%d:%02d:%02d" % (h, m, s)
        print(f"[+] EPOCH : {str(epoch).zfill(7)} | LD.: {ld} | LG.: {lg} | IT : {it}s | AV : {av} | TL : {tl}")

        # Preview frame from the fixed latent vector; no gradients are needed.
        with torch.no_grad():
            generated_sample = generator(show_test_sample).cpu()
        # The generator ends in Tanh, so pixels lie in [-1, 1]. ToPILImage
        # scales float input by 255 and casts to uint8, so map to [0, 1] first.
        frame = (generated_sample[0].reshape(image_size, image_size) + 1) / 2
        image = transforms.ToPILImage()(frame.clamp(0, 1))
        evolution.append(image)

In [None]:
# Choose an output directory named after the run's hyperparameters; when it
# already exists, append an increasing numeric suffix until a free name is found.
# NOTE(review): the weights are loaded from "../models/..." but saved under
# "/content/..." - confirm this is the intended output location.
x = 1
path = f"/content/{num_epochs}_{image_size}_{lr}"
while os.path.isdir(path):
    path = f"/content/{num_epochs}_{image_size}_{lr}_{x}"
    x += 1

os.mkdir(path)

# Persist the weights of both networks inside the run directory.
for target, network in (
    (f"{path}/generator", generator),
    (f"{path}/discriminator", discriminator),
):
    torch.save(network.state_dict(), target)

In [None]:
# Write the captured preview frames as an animated GIF in the run directory.
# No frame is captured when training ran for too few epochs (or not at all),
# so guard against indexing into an empty list.
if evolution:
    evolution[0].save(
        f'{path}/progressive_text.gif',
        format='GIF',
        append_images=evolution[1:],
        save_all=True,
        duration=75,
        loop=1,
    )
else:
    print("No preview frames were captured; skipping GIF export.")

In [None]:
# Summarise timing and loss history and store it next to the saved weights.
total_time = sum(per_x_time)
data = {
    # per_x_time is empty when no logging step ran; avoid dividing by zero.
    "avg_time": round(total_time / len(per_x_time), 2) if per_x_time else 0.0,
    "total_time": total_time,
    "losses": losses,
}

# The file is only written, never read back, so plain "w" mode is enough.
with open(f"{path}/meta.json", "w", encoding="utf-8") as f:
    json.dump(data, f, indent=4)

In [None]:
# Number of latent vectors sampled when evaluating the generator.
samples_amount = 150

In [None]:
# Generate `samples_amount` images, score them with the discriminator, print
# the average score, then display each image together with its own score.
samples_to_try = torch.randn(samples_amount, generator_entry).to(device=device)

# Score in eval mode so the discriminator's Dropout layers are inactive, and
# restore the previous mode afterwards. No gradients are needed here.
was_training = discriminator.training
discriminator.eval()
try:
    with torch.no_grad():
        generated_samples = generator(samples_to_try)
        discriminated_samples = discriminator(generated_samples)
finally:
    discriminator.train(was_training)

generated_samples = generated_samples.cpu()
discriminated_samples = discriminated_samples.cpu()

# Divide by the actual sample count rather than a hard-coded 150.
avg = sum(float(x) for x in discriminated_samples) / samples_amount
print("Average score :", round(avg * 100, 2), "%")

to_pil = transforms.ToPILImage()
for sample, score in zip(generated_samples, discriminated_samples):
    # The generator ends in Tanh, so pixels lie in [-1, 1]. ToPILImage scales
    # float input by 255 and casts to uint8, so map to [0, 1] first.
    pixels = ((sample.reshape(image_size, image_size) + 1) / 2).clamp(0, 1)
    image = to_pil(pixels)
    display.display(image)
    print(round(float(score) * 100, 2), "%")