In [65]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard

import pandas as pd

# Data preprocessing

## Read cleaned data

In [66]:
# Load the cleaned Olympic history table from its CSV export (path is relative to the notebook).
df = pd.read_csv(filepath_or_buffer='../OlympicHistory/CleanedData.csv')

## Select the continuous columns

In [98]:
# Keep only the seven columns that are treated as continuous features downstream.
df_conti = df.loc[
    :,
    ['ID', 'Height', 'Weight', 'Year', 'AmountOfSport', 'AmountOfEvent', 'YearOfBirth'],
]

In [99]:
# Preview the selected columns (displayed as the cell's last expression).
df_conti

Unnamed: 0,ID,Height,Weight,Year,AmountOfSport,AmountOfEvent,YearOfBirth
0,1,180.0,80.0,1992,1,1,1968
1,2,170.0,60.0,2012,1,1,1989
2,3,175.0,71.0,1920,1,1,1896
3,4,182.0,95.0,1900,1,1,1866
4,5,185.0,82.0,1988,1,2,1967
...,...,...,...,...,...,...,...
188164,135568,171.0,69.0,2016,1,1,1983
188165,135569,179.0,89.0,1976,1,1,1947
188166,135570,176.0,59.0,2014,1,2,1987
188167,135571,185.0,96.0,1998,1,1,1968


## Reshape data to a 4D array

In [100]:
# Pack the table into single-channel 28x28 "images" so the image-style GAN below can consume it.
# Values are laid out row-major (record after record), so every 784-value sample holds
# 784 / 7 = 112 consecutive records. Trailing values that do not fill a whole sample are dropped.
sample_size = 1 * 28 * 28
flat_values = df_conti.to_numpy().flatten()
# Derive the cut-off from the data instead of hard-coding it (1317120 for the current CSV),
# so a CSV with a different number of rows still reshapes cleanly.
usable_values = (flat_values.size // sample_size) * sample_size
input_data = flat_values[:usable_values].reshape(-1, 1, 28, 28)
input_data.shape

(1680, 1, 28, 28)

# Build PyTorch Dataset

## Dataset class

In [70]:
class OlympicDataset(Dataset):
    """Map-style dataset over pre-shaped numeric samples.

    Args:
        data: Array whose first axis indexes samples. A NumPy array is used
            directly; a pandas DataFrame is converted with ``to_numpy()`` first
            (``torch.from_numpy`` itself only accepts NumPy arrays).
        transform: Optional callable applied to each sample tensor inside
            ``__getitem__``. With ``None`` (the default) samples are returned
            unchanged.

    Samples are stored as one ``float32`` tensor and returned without a label.
    """

    def __init__(self, data, transform=None):
        if isinstance(data, pd.DataFrame):
            data = data.to_numpy()
        self.data = torch.from_numpy(data).float()
        self.transform = transform

    def __len__(self):
        """Return the number of samples (size of the first axis)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, passed through ``transform`` when one was given."""
        sample = self.data[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

# Build GAN

## Hyperparameters

In [71]:
# Hyperparameters and run configuration.
seed = 42  # seeds torch's RNG so weight init, noise sampling and shuffling repeat on re-runs
torch.manual_seed(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4  # learning rate used for both Adam optimizers
z_dim = 64  # length of the latent noise vector fed to the generator
image_dim = 28 * 28 * 1  # 784 values per flattened sample
batch_size = 32
num_epochs = 5

In [72]:
# Fixed latent batch: the same noise is fed to the generator at every logging step,
# so the logged grids of generated samples are comparable across epochs.
# No torchvision transform pipeline is configured; the dataset is built with transform=None.
fixed_noise = torch.randn(batch_size, z_dim).to(device)

## Dataset and DataLoader

In [73]:
# Wrap the reshaped array in the dataset (no transform) and serve it in shuffled mini-batches.
dataset = OlympicDataset(input_data)
loader = DataLoader(dataset=dataset, shuffle=True, batch_size=batch_size)

## Discriminator

In [74]:
class Discriminator(nn.Module):
    """Two-layer fully connected classifier ending in a sigmoid.

    Args:
        in_features: Length of each flattened input sample.

    ``forward`` maps a ``(batch, in_features)`` tensor to ``(batch, 1)`` sigmoid scores.
    """

    def __init__(self, in_features):
        super().__init__()
        layers = [
            nn.Linear(in_features=in_features, out_features=128),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(in_features=128, out_features=1),
            nn.Sigmoid(),  # squashes the single logit so BCELoss can consume it
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.disc(x)
        return scores

# Build the discriminator for flattened image_dim-long samples, then move it to the chosen device.
disc = Discriminator(in_features=image_dim)
disc = disc.to(device)

## Generator

In [75]:
class Generator(nn.Module):
    """Two-layer fully connected network mapping latent noise to a flattened sample.

    Args:
        z_dim: Length of the latent noise vector.
        img_dim: Length of each generated (flattened) sample.

    ``forward`` maps a ``(batch, z_dim)`` tensor to ``(batch, img_dim)``.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        layers = [
            nn.Linear(in_features=z_dim, out_features=256),
            nn.LeakyReLU(negative_slope=0.01),
            # The output stays unbounded: no Tanh is applied here, which would only
            # be appropriate for training data scaled to [-1, 1].
            nn.Linear(in_features=256, out_features=img_dim),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        generated = self.gen(x)
        return generated
    
# Build the generator (z_dim noise values in, image_dim sample values out), then move it to the device.
gen = Generator(z_dim=z_dim, img_dim=image_dim)
gen = gen.to(device)

## Optimizer

In [76]:
# Binary cross-entropy on the discriminator's sigmoid outputs, plus one Adam optimizer
# per network (both use the same learning rate).
criterion = nn.BCELoss()
opt_gen = optim.Adam(params=gen.parameters(), lr=lr)
opt_disc = optim.Adam(params=disc.parameters(), lr=lr)

## Tensorboard

In [77]:
# Separate TensorBoard runs for generated and real samples, written under logs/.
writer_real = SummaryWriter(log_dir="logs/real")
writer_fake = SummaryWriter(log_dir="logs/fake")

# Train/Test

In [78]:
# Alternating GAN training: one discriminator update, then one generator update per batch.
# NOTE(review): the real samples are the raw, unscaled column values from the CSV;
# standardising them before training would probably help — confirm before changing.
step = 0  # TensorBoard global step; advances once per logged grid (first batch of each epoch)
for epoch in range(num_epochs):
    for batch_idx, real in enumerate(loader):
        # Flatten every (1, 28, 28) sample into a vector for the fully connected networks.
        real = real.view(-1, image_dim).to(device)
        # Use a local name: the last batch may be smaller than the `batch_size`
        # hyperparameter, and the global must not be overwritten by the loop.
        cur_batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach() keeps the discriminator update from backpropagating into the
        # generator, so the generator graph stays intact without retain_graph=True.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                # Dedicated names so the training-step `fake` batch is not clobbered.
                fake_preview = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_preview = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_preview, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_preview, normalize=True)

                writer_fake.add_image(
                    "Fake Data", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Real Data", img_grid_real, global_step=step
                )
                step += 1

# Push buffered events to disk so TensorBoard shows the final grids right away.
writer_fake.flush()
writer_real.flush()

Epoch [0/5] Batch 0/53                       Loss D: 0.3637, loss G: 0.6956
Epoch [1/5] Batch 0/53                       Loss D: 0.6598, loss G: 0.3592
Epoch [2/5] Batch 0/53                       Loss D: 0.6658, loss G: 0.3949
Epoch [3/5] Batch 0/53                       Loss D: 0.6398, loss G: 0.4510
Epoch [4/5] Batch 0/53                       Loss D: 0.7173, loss G: 0.4578


In [79]:
# Shape check: regroup the generated values into rows of 7, matching the 7 columns of df_conti.
fake.reshape(-1, 7).shape

torch.Size([1792, 7])

In [80]:
# Regroup the generated values into rows of 7 (one value per df_conti column) for inspection.
# detach() drops the autograd graph; cpu() is required before numpy() whenever the
# generator ran on CUDA (numpy() raises on a GPU tensor) and is a no-op on CPU.
fake_df = pd.DataFrame(fake.detach().cpu().flatten().reshape(-1, 7).numpy())
fake_df

Unnamed: 0,0,1,2,3,4,5,6
0,4.069705,-1.075363,1.431894,-0.148170,1.552395,1.133369,2.552165
1,6.395461,-0.763557,1.009190,1.603559,-1.231239,-1.104914,1.234785
2,4.867707,0.029303,1.400914,2.448937,1.272824,-0.925300,4.387727
3,6.899322,0.329826,0.421427,4.667288,-0.435276,-1.354734,0.749507
4,4.815613,1.390180,0.846305,1.519486,0.991203,0.508393,4.558118
...,...,...,...,...,...,...,...
1787,2.726169,0.776640,1.178219,0.365224,1.033957,1.257541,3.880444
1788,2.085245,0.826460,-0.616106,0.557458,1.186213,0.301323,0.700825
1789,6.636372,-0.403226,-1.356165,1.431967,-0.699627,0.280025,-0.030878
1790,2.996233,0.693563,0.114380,1.280619,-1.448111,-2.305799,2.851131
