In [None]:
!rm -r ./reduce-mode-collapse-in-gan
!git clone -b synthetic_experiment https://github.com/ThViviani/reduce-mode-collapse-in-gan.git

import sys; sys.path.append('./reduce-mode-collapse-in-gan')

In [None]:
!pip install -r ./reduce-mode-collapse-in-gan/requirements.txt

In [None]:
!rm -r ./gan
!git clone https://github.com/tntrung/gan.git

sys.path.append('./gan')

# Prepare 2D data from dist-gan
https://github.com/tntrung/gan/tree/master/distgan_toy2d  
I ran into problems importing the `distgan_toy2d` module, so I copied the relevant code into the cells below.

In [None]:
import itertools
import os
from datetime import datetime

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import wandb
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

In [None]:
# Distance threshold around each centroid, in raw data units. It is replaced
# by its normalized value once normalize_toydata has been applied.
VAR = 0.1
BATCH_SIZE = 128
MAX_EPOCHS = 500
SEED = 42
# Read the Weights & Biases API key from the environment so the secret never
# has to be pasted into (and committed with) the notebook. Falls back to the
# previous empty-string default when WANDB_API_KEY is not set.
WANDB_TOKEN = os.environ.get('WANDB_API_KEY', '')

seed_everything(SEED)

In [None]:
def read_toydata(toyfile):
    """Load a whitespace-separated numeric table from a text file.

    Square brackets are stripped from every line before parsing, so rows
    written as ``[0.1 0.2]`` are accepted. Blank (or bracket-only) lines are
    skipped instead of producing empty rows, which would otherwise yield a
    ragged array.

    Args:
        toyfile: Path of the text file to read.

    Returns:
        np.ndarray holding one row of floats per non-empty line.

    Raises:
        ValueError: If a token on a line cannot be parsed as a float.
    """
    data = []
    # The context manager closes the file even when float() raises mid-parse.
    with open(toyfile, 'r') as fid:
        for line in fid:
            tokens = line.replace('[', '').replace(']', '').split()
            if tokens:
                data.append([float(tok) for tok in tokens])
    return np.array(data)

In [None]:
def maxabs(a, axis=None):
    """Return the element of largest magnitude along ``axis``, sign preserved.

    The minimum is returned where its magnitude strictly exceeds the maximum;
    otherwise (including ties) the maximum is returned.
    """
    lowest = a.min(axis)
    highest = a.max(axis)
    return np.where(highest >= -lowest, highest, lowest)

def normalize_toydata(toydata, centroids, var):
    """Rescale the data, the centroids and the radius with one common factor.

    Data and centroids are divided by the largest absolute value found in
    ``toydata`` and then mapped with ``(x + 1) / 2``. ``var`` is divided by
    the same factor and additionally by ``sqrt(2)``.

    Args:
        toydata: Array of samples; its largest-magnitude entry sets the scale.
        centroids: Array of centroid coordinates, rescaled like the data.
        var: Radius in raw data units.

    Returns:
        Tuple ``(toydata, centroids, var)`` holding the rescaled values.
    """
    # maxabs returns the *signed* extreme. Use its magnitude so that a dominant
    # negative extreme neither mirrors the data nor turns `var` negative.
    # The scale is computed once and shared, so data and centroids are mapped
    # by exactly the same factor.
    scale = np.abs(maxabs(toydata))
    centroids = (centroids / scale + 1) / 2
    var = (var / scale) / np.sqrt(2)
    toydata = (toydata / scale + 1) / 2
    return toydata, centroids, var

In [None]:
# 5x5 grid of centroids at the even integer coordinates -4..4 on both axes.
grid_axis = range(-4, 5, 2)
grid_centroids = np.array(list(itertools.product(grid_axis, repeat=2)))

# Load the toy samples, then rescale samples, centroids and radius together.
# NOTE: VAR is overwritten with its normalized value here, so re-running this
# cell without re-running the config cell normalizes VAR a second time.
toydata = read_toydata('gan/distgan_toy2d/toy_data/toydatav2.txt')
toydata, grid_centroids, VAR = normalize_toydata(toydata, grid_centroids, VAR)

In [None]:
# Normalized samples in blue, grid centroids overlaid as red crosses.
sample_x, sample_y = toydata[:, 0], toydata[:, 1]
center_x, center_y = grid_centroids[:, 0], grid_centroids[:, 1]
plt.scatter(sample_x, sample_y, color='b')
plt.scatter(center_x, center_y, marker='x', color='r')

In [None]:
# Select the samples lying within VAR of the last grid centroid.
corner_centroid = grid_centroids[-1]
corner_distance = np.linalg.norm(toydata - corner_centroid, axis=1)
upper_right_corner_mode = toydata[corner_distance <= VAR]

# All samples in blue, the selected ones in green, centroids as red crosses.
plt.scatter(toydata[:, 0], toydata[:, 1], color='b')
plt.scatter(
    upper_right_corner_mode[:, 0],
    upper_right_corner_mode[:, 1],
    color='g',
)
plt.scatter(grid_centroids[:, 0], grid_centroids[:, 1], marker='x', color='r')

In [None]:
def evaluate_mode_covered(data, centroids, var):
    """Count, for each centroid, the points lying within ``var`` of it.

    Args:
        data: Array of points, one per row.
        centroids: Sequence of centroid coordinates.
        var: Distance threshold (inclusive).

    Returns:
        np.ndarray with one count per centroid, in centroid order.
    """
    counts = [
        (np.linalg.norm(data - center, axis=1) <= var).sum()
        for center in centroids
    ]
    return np.array(counts)

In [None]:
# Number of centroids that have at least 20 real samples within VAR of them.
points_per_mode = evaluate_mode_covered(toydata, grid_centroids, VAR)
(points_per_mode >= 20).sum()

In [None]:
# Display the centroid coordinates as returned by normalize_toydata.
grid_centroids

In [None]:
# TensorDataset pairs every sample with a placeholder label of 1.0.
features = torch.FloatTensor(toydata)
dummy_labels = torch.ones(len(toydata))
train_dataset = TensorDataset(features, dummy_labels)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

Batch visualization

In [None]:
# Take the sample tensor of one batch (index 0; index 1 holds the dummy labels).
b = next(iter(train_loader))[0]
# Plot column 0 against column 1. The previous call plotted column 0 against
# itself, which collapses every batch onto the diagonal.
plt.scatter(b[:, 0], b[:, 1])
plt.scatter(grid_centroids[:,0], grid_centroids[:,1], marker='x', color='r');

# Experiments

In [None]:
from utils.train_options import TrainOptions


# Single options object shared by all experiments: it is passed as `opt` to
# every model factory registered in EXPERIMENTS.
train_options = TrainOptions(
    latent_dim=2,
    batch_size=BATCH_SIZE,
    lr=1e-3,
    betas=(0.8, 0.999)
)

In [None]:
from trainers.synthetic_adversarial_trainer import *
from models.discriminators import Critic2D
from models.utils import MLP
from models.generators import Generator2D


def create_model_fn(model_class, centroids, var, opt, use_encoder=False, use_r1r2=False):
    """Return a zero-argument factory that builds a fresh ``model_class`` model.

    Every call of the returned factory constructs new critic, generator and
    (optionally) encoder networks, so each experiment starts from its own
    freshly created modules.

    Args:
        model_class: Trainer class to instantiate.
        centroids: Forwarded to the model as ``centroids``.
        var: Forwarded to the model as ``var``.
        opt: Forwarded to the model as ``opt``.
        use_encoder: When true an ``MLP()`` encoder is passed, otherwise ``None``.
        use_r1r2: Forwarded as ``use_r1r2_penalty``, but only to classes whose
            name contains ``'Rp'``; it is ignored for all other classes.

    Returns:
        Callable taking no arguments and returning the constructed model.
    """
    def wrapper():
        # Construction order (critic, generator, encoder) is kept on purpose.
        kwargs = {
            'critic': Critic2D(output_dim=1),
            'generator': Generator2D(),
            'encoder': MLP() if use_encoder else None,
            'prior_type': 'uniform',
            'centroids': centroids,
            'var': var,
            'opt': opt,
        }
        # Only the 'Rp' classes are handed the R1/R2 penalty switch.
        if 'Rp' in model_class.__name__:
            kwargs['use_r1r2_penalty'] = use_r1r2
        return model_class(**kwargs)
    return wrapper


# (run name, trainer class, extra create_model_fn flags) for every experiment.
# The order of this table is the order in which the experiments are run.
_EXPERIMENT_SPECS = [
    ('StandardGAN', SyntheticVanilaGAN, {}),
    ('NEVanilaGAN', SynthNEVanilaGAN, {'use_encoder': True}),
    ('DistVanilaGAN', SynthDistVanilaGAN, {'use_encoder': True}),
    ('DpVanilaGAN', SynthDpVanilaGAN, {}),
    ('RpGAN', SyntheticRpGAN, {}),
    ('NERpGAN', SynthNERpGAN, {'use_encoder': True}),
    ('DistRpGAN', SynthDistRpGAN, {'use_encoder': True}),
    ('DpRpGAN', SynthDpRpGAN, {}),
    ('RpGAN_R1R2', SyntheticRpGAN, {'use_r1r2': True}),
    ('DistRpGAN+R1R2', SynthDistRpGAN, {'use_encoder': True, 'use_r1r2': True}),
    ('NERpGAN+R1R2', SynthNERpGAN, {'use_encoder': True, 'use_r1r2': True}),
    ('DpRpGAN+R1R2', SynthDpRpGAN, {'use_r1r2': True}),
    ('NEVanilaGAN_hat', SynthNEhatVanilaGAN, {}),
    ('NERpGAN_hat', SynthNEhatRpGAN, {}),
    ('NERpGAN_hat+R1R2', SynthNEhatRpGAN, {'use_r1r2': True}),
]

# Map each run name to a factory sharing the same centroids, radius and options.
EXPERIMENTS = {
    run_name: create_model_fn(
        model_class,
        centroids=grid_centroids,
        var=VAR,
        opt=train_options,
        **flags,
    )
    for run_name, model_class, flags in _EXPERIMENT_SPECS
}

In [None]:
# Authenticate with Weights & Biases using the token from the config cell.
wandb.login(key=WANDB_TOKEN)

In [None]:
# Empty summary table; the training loop adds one row per experiment name.
_result_columns = ['registered modes', 'registered samples']
results = pd.DataFrame(columns=_result_columns)

In [None]:
# Train every registered experiment in turn and record its mode-coverage
# result in `results`, one row per experiment name.
for name, model_fn in EXPERIMENTS.items():
    print(f"Running {name}")

    # The timestamp suffix keeps run names distinct across repeated executions.
    wandb_logger = WandbLogger(
        project='Synthetic2D_vkr',
        save_dir='',
        log_model=True,
        name=name + "_" + str(datetime.now())
    )

    try:
        trainer = Trainer(
            max_epochs=MAX_EPOCHS,
            logger=wandb_logger,
            deterministic=True,
            callbacks=[LearningRateMonitor(logging_interval='epoch')]
        )

        model = model_fn()
        trainer.fit(model=model, train_dataloaders=train_loader)
        results.loc[name] = model.compute_modes_covered()
    finally:
        # Always close the W&B run, even when training fails or is interrupted.
        # Otherwise the run stays open and a later WandbLogger would log into it.
        wandb.finish()

In [None]:
# Display the per-experiment summary table filled in by the training loop.
results