In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import time
from pathlib import Path
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
from pytorch_lightning.loggers import TensorBoardLogger
import torch.nn.functional as F
import pytorch_lightning as pl

# Root folder for all experiment outputs: a `results` directory that sits
# next to (not inside) the current working directory.
path_results = Path.cwd().parent.joinpath('results')


In [None]:
# Preview the two-moons toy problem: `data` is the (points, labels) pair
# returned by make_moons; each point is coloured by its class label.
data = make_moons(n_samples=10000, noise=0.3, random_state=1)
colors = [('C0', 'C1')[int(label != 0)] for label in data[1]]

plt.figure()
# data[0].T has one row per coordinate, so unpacking it yields the x and y columns.
plt.scatter(*data[0].T, c=colors, alpha=0.2)

In [None]:
class MoonsDataset(Dataset):
    """Two-moons samples exposed as a map-style dataset of (point, label) pairs.

    Points and labels are generated once by ``make_moons`` with a fixed
    ``random_state=1`` and stored as float32 tensors in ``self.x`` / ``self.y``.

    Args:
        n_samples: number of points forwarded to ``make_moons``.
        noise: noise level forwarded to ``make_moons`` (``None`` by default).
    """

    def __init__(self, n_samples=10000, noise=None):
        points, labels = make_moons(n_samples=n_samples, noise=noise, random_state=1)
        self.x = torch.from_numpy(points).float()
        self.y = torch.from_numpy(labels).float()

    def __len__(self):
        # One row of `x` per sample.
        return self.x.shape[0]

    def __getitem__(self, idx):
        point = self.x[idx]
        label = self.y[idx]
        return point, label


class MoonsDataModule(pl.LightningDataModule):
    """Lightning data module serving a random train/validation split of ``MoonsDataset``.

    Args:
        n_samples: total number of points to generate.
        noise: noise level passed through to ``MoonsDataset``.
        val_fraction: share of the samples held out for validation
            (the default 0.2 reproduces the 8000/2000 split for 10000 samples).
        batch_size: batch size used by both data loaders.
        num_workers: worker processes used by both data loaders.
    """

    def __init__(self, n_samples=10000, noise=None, val_fraction=0.2, batch_size=32, num_workers=4):
        super().__init__()
        self.n_samples = n_samples
        self.noise = noise
        self.val_fraction = val_fraction
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Build the dataset and split it into ``data_train`` / ``data_val``."""
        # Honour the configured sample count instead of a hard-coded 10000.
        data = MoonsDataset(n_samples=self.n_samples, noise=self.noise)
        # Derive the split sizes from the actual dataset length so they always sum to len(data).
        n_val = int(round(len(data) * self.val_fraction))
        n_train = len(data) - n_val
        self.data_train, self.data_val = torch.utils.data.random_split(data, [n_train, n_val])

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return torch.utils.data.DataLoader(
            self.data_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return torch.utils.data.DataLoader(
            self.data_val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    

In [None]:
class Classifier(pl.LightningModule):
    """Three-layer MLP (2 -> 10 -> 10 -> 1) that emits one logit per input point.

    Trained with binary cross-entropy on logits; loss and accuracy are logged
    as ``train_*`` during training and ``val_*`` during validation.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def _shared_step(self, batch, stage):
        """Compute loss and accuracy for one batch and log them under ``<stage>_loss`` / ``<stage>_acc``."""
        inputs, targets = batch
        # Targets arrive as a flat vector; add a column axis to match the (N, 1) logits.
        targets = targets.unsqueeze(1)
        logits = self(inputs)
        loss = F.binary_cross_entropy_with_logits(logits, targets)
        # A prediction is the rounded sigmoid probability; accuracy is the fraction of matches.
        hits = torch.sigmoid(logits).round() == targets
        acc = hits.sum() / targets.shape[0]
        self.log(f'{stage}_loss', loss)
        self.log(f'{stage}_acc', acc)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')


In [None]:
# Train the baseline classifier on noisy moons and log to TensorBoard.
dm = MoonsDataModule(n_samples=10000, noise=0.3)
model = Classifier()

# Each run writes to results/classifier/<timestamp>; the timestamp has minute resolution.
timestamp = time.strftime('%Y-%m-%d_%H%M', time.localtime())
path_results_exp = path_results / 'classifier' / timestamp
# exist_ok avoids the check-then-create race and makes re-running the cell within the same minute safe.
path_results_exp.mkdir(parents=True, exist_ok=True)
logger = TensorBoardLogger(save_dir=path_results_exp, name='', version='')

trainer = pl.Trainer(accelerator="auto", devices=1, max_epochs=30, logger=logger)
trainer.fit(model, datamodule=dm)

In [None]:
class Generator(nn.Module):
    """MLP mapping a latent vector to a data-space point.

    Three hidden layers of width 10 with LeakyReLU(0.2) activations, followed
    by a linear output layer of size ``data_dim``.

    Args:
        latent_dim: size of the input vector.
        data_dim: size of the generated output vector.
    """

    def __init__(self, latent_dim, data_dim):
        super().__init__()
        # Layers are created in the same order as they are applied so that
        # the Sequential indices (and state_dict keys) stay model.0 ... model.6.
        widths = [latent_dim, 10, 10, 10]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.LeakyReLU(0.2))
        layers.append(nn.Linear(widths[-1], data_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        return self.model(z)
    
# class Generator(nn.Module):
#     def __init__(self, latent_dim, data_dim):
#         super().__init__()
#         self.mapping = nn.Sequential(
#             nn.Linear(latent_dim, 10),
#             nn.LeakyReLU(0.2),
#             nn.Linear(10, latent_dim),
#             nn.LeakyReLU(0.2))
#         self.synthesis = nn.Sequential(
#             nn.Linear(latent_dim, 10),
#             nn.LeakyReLU(0.2),
#             nn.Linear(10, 10),
#             nn.LeakyReLU(0.2),
#             nn.Linear(10, 10),
#             nn.LeakyReLU(0.2),
#             nn.Linear(10, data_dim))
        
#     def forward(self, z):
#         w = self.mapping(z)
#         x = self.synthesis(w)
#         return x

class Discriminator(nn.Module):
    """MLP critic/discriminator producing one unnormalised score per input row.

    Two hidden layers of width 10 with LeakyReLU(0.2) activations, followed by
    a linear layer with a single output.

    Args:
        data_dim: size of each input vector.
    """

    def __init__(self, data_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(data_dim, 10),
            nn.LeakyReLU(0.2),
            nn.Linear(10, 10),
            nn.LeakyReLU(0.2),
            nn.Linear(10, 1))

    def forward(self, x):
        y = self.model(x)
        return y

    def _get_gradient_penalty(self, crit, real, fake, epsilon):
        '''
        Return the WGAN-GP gradient penalty of a critic on mixes of real and fake samples.

        Parameters:
            crit: the critic (any callable mapping a batch of samples to scores)
            real: a batch of real samples
            fake: a batch of fake samples, same shape as `real`
            epsilon: per-sample mixing proportions, broadcastable against `real`
        Returns:
            penalty: scalar tensor, the mean over the batch of
                (||grad_x crit(x_mixed)||_2 - 1)^2. It stays attached to the
                autograd graph (create_graph=True) so it can be back-propagated.
        '''
        # Interpolate between real and fake samples.
        mixed_images = real * epsilon + fake * (1 - epsilon)
        # The gradient is taken with respect to the mixed samples, so they must
        # track gradients even when none of the inputs do. In that case the
        # tensor is a leaf and the flag can be switched on in place.
        if not mixed_images.requires_grad:
            mixed_images.requires_grad_(True)

        # Critic scores on the mixed samples.
        mixed_scores = crit(mixed_images)

        # Gradient of the scores with respect to the mixed samples.
        gradient = torch.autograd.grad(
            inputs=mixed_images,
            outputs=mixed_scores,
            grad_outputs=torch.ones_like(mixed_scores),
            create_graph=True,
            retain_graph=True,
        )[0]

        # Flatten everything but the batch axis, then take one L2 norm per sample.
        gradient = gradient.reshape(len(gradient), -1)
        gradient_norm = gradient.norm(2, dim=1)

        # Mean of the squared distances from 1. Squaring the mean instead would
        # let norms above and below 1 cancel each other out.
        penalty = torch.mean((gradient_norm - 1) ** 2)

        return penalty



class GAN(pl.LightningModule):
    """(Optionally class-conditional) GAN trained with either the standard or the WGAN-GP loss.

    The module is written for Lightning's automatic optimisation with two
    optimisers: ``training_step`` receives ``optimizer_idx`` 0 for the
    generator update and 1 for the discriminator update, matching the order
    returned by ``configure_optimizers``.

    Args:
        latent_dim: size of the random latent vector.
        data_dim: size of a data sample.
        data_module: data module whose ``data_train`` split is plotted at the
            end of every training epoch.
        c_dim: number of classes; 0 means unconditional. When positive, a
            one-hot label is appended to the generator input and to every
            discriminator input.
        loss: ``'standard'`` (binary cross-entropy on logits) or ``'wasserstein'``
            (Wasserstein loss with gradient penalty).
        lambda_gp: gradient-penalty weight for the Wasserstein loss.

    Raises:
        ValueError: if ``loss`` is not one of the supported names.
    """

    _LOSSES = ('standard', 'wasserstein')

    def __init__(self, latent_dim, data_dim, data_module, c_dim=0, loss='standard', lambda_gp=10):
        super().__init__()
        self.save_hyperparameters()
        # Fail at construction time rather than with an unbound local inside training_step.
        if loss not in self._LOSSES:
            raise ValueError(f"unknown loss {loss!r}; expected one of {self._LOSSES}")
        self.latent_dim = latent_dim
        self.generator = Generator(latent_dim + c_dim, data_dim)
        self.discriminator = Discriminator(data_dim + c_dim)
        self.dm = data_module
        self.c_dim = c_dim # class dimensionality, 0=unconditional
        self.loss = loss
        self.lambda_gp = lambda_gp # gradient penalty weight for wasserstein loss

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between logits ``y_hat`` and targets ``y``."""
        return F.binary_cross_entropy_with_logits(y_hat, y)

    def wasserstein_loss_gp(self, mode, pred_fake, pred_real=None, x_real=None, x_fake=None, lambda_gp=None):
        """Wasserstein loss for the generator or the discriminator (with gradient penalty).

        Args:
            mode: ``'generator'`` or ``'discriminator'``.
            pred_fake: discriminator scores on generated samples.
            pred_real: discriminator scores on real samples (discriminator mode only).
            x_real: real discriminator inputs (discriminator mode only).
            x_fake: fake discriminator inputs (discriminator mode only).
            lambda_gp: gradient-penalty weight; ``None`` falls back to ``self.lambda_gp``.

        Raises:
            ValueError: if ``mode`` is not recognised.
        """
        if mode == 'generator':
            return -pred_fake.mean()
        if mode == 'discriminator':
            if lambda_gp is None:
                lambda_gp = self.lambda_gp
            # One mixing coefficient per sample, broadcast over the feature axis.
            epsilon = torch.rand(x_real.shape[0], 1, device=self.device, requires_grad=True)
            gp = self.discriminator._get_gradient_penalty(self.discriminator, x_real, x_fake.detach(), epsilon)
            return torch.mean(pred_fake - pred_real) + lambda_gp * gp
        raise ValueError(f"unknown mode {mode!r}; expected 'generator' or 'discriminator'")

    def _sample_latent(self, n):
        """Draw ``n`` generator inputs.

        Returns ``(z, c, labels)``: ``z`` is the generator input (latent noise,
        with the one-hot class code appended when conditional), ``c`` the
        one-hot class code and ``labels`` the integer class ids. ``c`` and
        ``labels`` are ``None`` for an unconditional model.
        """
        z = torch.randn(n, self.latent_dim, device=self.device)
        if self.c_dim > 0:
            labels = torch.randint(self.c_dim, size=(n,), device=self.device)
            # one_hot yields integers; cast explicitly so concatenation does not rely on type promotion.
            c = F.one_hot(labels, num_classes=self.c_dim).to(z.dtype)
            return torch.cat([z, c], dim=1), c, labels
        return z, None, None

    def training_step(self, batch, batch_idx, optimizer_idx):
        x_real, y_real = batch
        z, c, _ = self._sample_latent(x_real.shape[0])

        # generator
        if optimizer_idx == 0:
            x_fake = self.generator(z)
            if c is not None:
                x_fake = torch.cat([x_fake, c], dim=1)
            pred_fake = self.discriminator(x_fake)
            if self.loss == 'standard':
                loss_gen = self.adversarial_loss(pred_fake, torch.ones_like(pred_fake)) # fakes should be predicted as true
            elif self.loss == 'wasserstein':
                loss_gen = self.wasserstein_loss_gp('generator', pred_fake)
            else:
                raise ValueError(f"unknown loss {self.loss!r}; expected one of {self._LOSSES}")
            self.log('loss_gen', loss_gen)
            return loss_gen

        # discriminator
        elif optimizer_idx == 1:
            # Detach so the discriminator update does not back-propagate into the generator.
            x_fake = self.generator(z).detach()
            if c is not None:
                x_fake = torch.cat([x_fake, c], dim=1)
                c_real = F.one_hot(y_real.long(), num_classes=self.c_dim).to(x_real.dtype)
                x_real = torch.cat([x_real, c_real], dim=1)
            pred_real = self.discriminator(x_real)
            pred_fake = self.discriminator(x_fake)
            if self.loss == 'standard':
                loss_disc = 0.5 * (self.adversarial_loss(pred_fake, torch.zeros_like(pred_fake)) # fakes should be predicted as false
                                   + self.adversarial_loss(pred_real, torch.ones_like(pred_real))) # reals should be predicted as true
            elif self.loss == 'wasserstein':
                loss_disc = self.wasserstein_loss_gp('discriminator', pred_fake, pred_real, x_real, x_fake, lambda_gp=self.lambda_gp)
            else:
                raise ValueError(f"unknown loss {self.loss!r}; expected one of {self._LOSSES}")
            self.log('loss_disc', loss_disc)
            return loss_disc

    def configure_optimizers(self):
        # Order matters: index 0 is the generator, index 1 the discriminator (see training_step).
        optim_gen = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999))
        optim_disc = torch.optim.Adam(self.discriminator.parameters(), lr=0.001, betas=(0.5, 0.999))
        return optim_gen, optim_disc

    def on_train_epoch_end(self):
        """Log a scatter plot of training points and freshly generated points to TensorBoard."""
        n_points = 1000
        # Sampling for visualisation only: no autograd graph needed.
        with torch.no_grad():
            z, _, rnd_label = self._sample_latent(n_points)
            x_fake = self.generator(z).cpu().numpy()

        # Slice the training split once instead of once per plotted column.
        x_train, y_train = self.dm.data_train[:n_points]

        fig, ax = plt.subplots()
        ax.scatter(x_train[:, 0], x_train[:, 1], alpha=0.5, c=['C0' if y == 0 else 'C1' for y in y_train])
        color = ['C2' if y == 0 else 'C3' for y in rnd_label.tolist()] if rnd_label is not None else 'k'
        ax.scatter(x_fake[:, 0], x_fake[:, 1], c=color)
        tensorboard_logger = self.logger.experiment
        tensorboard_logger.add_figure("generated_images", fig, self.current_epoch)


In [None]:
# Experiment: class-conditional GAN with the standard (BCE) loss and a 10-d latent space.
experiment_name = 'standard_GAN_cond_latent10'

dm = MoonsDataModule(n_samples=10000, noise=0.1)
model = GAN(10, 2, dm, c_dim=2, loss='standard', lambda_gp=10)

# Each run writes to results/GAN/<timestamp>/<experiment_name>; the timestamp has minute resolution.
timestamp = time.strftime('%Y-%m-%d_%H%M', time.localtime())
path_results_exp = path_results / 'GAN' / timestamp
# exist_ok avoids the check-then-create race and makes re-running the cell within the same minute safe.
path_results_exp.mkdir(parents=True, exist_ok=True)
logger = TensorBoardLogger(save_dir=path_results_exp, name=experiment_name, version='')

trainer = pl.Trainer(accelerator="auto", devices=1, max_epochs=100, logger=logger)
trainer.fit(model, dm)

In [None]:
# Experiment: class-conditional GAN with the Wasserstein + gradient-penalty loss and a 10-d latent space.
experiment_name = 'wasserstein_GAN_cond_latent10'

dm = MoonsDataModule(n_samples=10000, noise=0.1)
model = GAN(10, 2, dm, c_dim=2, loss='wasserstein', lambda_gp=10)

# Each run writes to results/GAN/<timestamp>/<experiment_name>; the timestamp has minute resolution.
timestamp = time.strftime('%Y-%m-%d_%H%M', time.localtime())
path_results_exp = path_results / 'GAN' / timestamp
# exist_ok avoids the check-then-create race and makes re-running the cell within the same minute safe.
path_results_exp.mkdir(parents=True, exist_ok=True)
logger = TensorBoardLogger(save_dir=path_results_exp, name=experiment_name, version='')

trainer = pl.Trainer(accelerator="auto", devices=1, max_epochs=100, logger=logger)
trainer.fit(model, dm)