In [1]:
import sys
import os
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
sys.path.append("../../")
sys.path.append("../")
from utils.evaluation import evaluate
from utils.metrics import Metrics
from models.NNGenerator import AdaptableDenseModel


In [2]:
import neptune
from neptune_pytorch import NeptuneLogger
from neptune.utils import stringify_unsupported
from dotenv import load_dotenv
load_dotenv()

True

In [3]:
import pickle
from git import Repo

# Get the git root directory
repo = Repo(".", search_parent_directories=True)
git_root = repo.git.rev_parse("--show-toplevel")

# Load data
X_Train_pd = pickle.load(open(f"{git_root}/data/splits/train/X_pandas.pck", "rb"))
y_Train_pd = pickle.load(open(f"{git_root}/data/splits/train/y_pandas.pck", "rb"))

X_Val_pd = pickle.load(open(f"{git_root}/data/splits/val/X_pandas.pck", "rb"))
y_Val_pd = pickle.load(open(f"{git_root}/data/splits/val/y_pandas.pck", "rb"))

In [4]:
X_Train = torch.tensor(X_Train_pd.values, dtype=torch.float32)
y_Train = torch.tensor(y_Train_pd.values, dtype=torch.float32)

X_Val = torch.tensor(X_Val_pd.values, dtype=torch.float32)
y_Val = torch.tensor(y_Val_pd.values, dtype=torch.float32)

train_dataset = torch.utils.data.TensorDataset(X_Train, y_Train)
val_dataset = torch.utils.data.TensorDataset(X_Val, y_Val)

In [5]:
def label_from_logits(y_hat: torch.Tensor, threshold = 0.5) -> torch.Tensor:
    with torch.no_grad():
        y_pred_tensor = (torch.sigmoid(y_hat) > threshold).float()
    return y_pred_tensor


def evaluate_from_dataframe(X: pd.DataFrame):
    X_tensor = torch.tensor(X.to_numpy(), dtype=torch.float32)
    
    #model: a pytorch model, which transforms X -> y in torch.Tensor format
    model.eval()
    model.cpu()
    y_pred_tensor = label_from_logits(model(X_tensor))
    
    return pd.DataFrame(y_pred_tensor.numpy())


In [6]:
#Adopted from https://forge.ibisc.univ-evry.fr/alacan/GANs-for-transcriptomics/-/blob/master/src/models/utils.py?ref_type=heads

from metrics.aats import compute_AAts
#from metrics.precision_recall import get_precision_recall

def calc_aat(real_data, fake_data):
    _, _, aat = compute_AAts(real_data, fake_data)
    return aat

# def calc_precision_recall(real_data, fake_data):
#     precision, recall = get_precision_recall(real_data, fake_data)
#     return precision, recall

In [7]:

#Adopted from https://forge.ibisc.univ-evry.fr/alacan/GANs-for-transcriptomics/-/blob/master/src/models/utils.py?ref_type=heads
def wasserstein_loss(y_true: torch.tensor, y_pred: torch.tensor):
    """
    Returns Wasserstein loss (product of real/fake labels and critic scores on real or fake data)
    ----
    Parameters:
        y_true (torch.tensor): true labels (either real or fake)
        y_pred (torch.tensor): critic scores on real or fake data
    Returns:
        (torch.tensor): mean product of real labels and critic scores
    """
    return torch.mean(y_true * y_pred)


def generator_loss(fake_score: torch.tensor):
    """
    Returns generator loss i.e the negative scores of the critic on fake data.
    ----
    Parameters:
        fake_score (torch.tensor): critic scores on fake data
    Returns:
        (torch.tensor): generator loss"""

    return wasserstein_loss(-torch.ones_like(fake_score), fake_score)


def discriminator_loss(real_score: torch.tensor, fake_score: torch.tensor):
    """
    Compute and return the wasserstein loss of critic scores on real and fake data i.e: wassertstein_loss = mean(-score_real) + mean(score_fake)
    ----
    Parameters:
        real_score (torch.tensor): critic scores on real data
        fake_score (torch.tensor): critic scores on fake data
    Returns:
        (torch.tensor): wasserstein loss
    """
    real_loss = wasserstein_loss(-torch.ones_like(real_score), real_score)
    fake_loss = wasserstein_loss(torch.ones_like(fake_score), fake_score)

    return real_loss, fake_loss




In [8]:
def training(D, G, d_optimizer, g_optimizer, train_loader, val_loader, epochs, latentSpaceSize, lambda_gp, iters_critic, device, neptune_logger=None, run = None, trial = None):
# Code adopted from https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
# and https://github.com/Zeleni9/pytorch-wgan/blob/master/models/wgan_gradient_penalty.py
# and https://forge.ibisc.univ-evry.fr/alacan/GANs-for-transcriptomics/-/blob/master/src/models/utils.py?ref_type=heads
    D = D.to(device)
    G = G.to(device)
    aat = 0

    print("Starting Training Loop...")
    # For each epoch
    for epoch in tqdm(range(epochs)):
        errDItem = 0
        errGItem = 0
        Wasserstein_D = 0
        G.train()
        D.train()

        # For each batch in the dataloader
        for X,y in train_loader:

            y = y.to(device).long()
            
            ############################
            # (1) Update D network: maximize D(x|c) - D(G(z|c)) + lambda_gp * gradient_penalty
            ###########################
            for p in D.parameters():
                p.requires_grad = True  

            for p in G.parameters():
                p.requires_grad = False  # to avoid computation

            # Train with all-real batch
            b_size = X.size(0)
            real_data = X.to(device)

            # Train D
            for crit_step in range(iters_critic):
                d_optimizer.zero_grad()
                
                # Generate batch of latent vectors
                z = torch.randn(b_size, latentSpaceSize, device=device)
                # Generate fake image batch with G
                fake_data = G(z, y)
                NB_GENES = fake_data.size(1)

                # Perform random augmentations for stability
                augmentations = torch.distributions.binomial.Binomial(total_count=1, probs=0).sample(torch.tensor([b_size])).to(device)
                fake_data = fake_data + augmentations[:, None] * torch.normal(0, 0.5, size=(NB_GENES,), device=device)
                real_data = real_data + augmentations[:,None] * torch.normal(0, 0.5, size=(b_size,NB_GENES), device=device)

                # Classify batches with D
                errD_real = D(real_data, y)
                errD_fake = D(fake_data, y)
                # Calculate D's loss on the all-fake batch
                real_loss, fake_loss = discriminator_loss(errD_real, errD_fake)

                # Loss for D in
                d_loss = real_loss + fake_loss
                Wasserstein_D = d_loss.item()

                ####### Gradient penalty #######
                BATCH_SIZE = real_data.size(0)

                # Sample alpha from uniform distribution
                alpha = torch.rand(
                    BATCH_SIZE,
                    1,
                    requires_grad=True,
                    device=real_data.device)

                # Interpolation between real data and fake data.
                interpolation = torch.mul(alpha, real_data) + \
                    torch.mul((1 - alpha), fake_data)

                # Get outputs from critic
                disc_outputs = D(interpolation, y)
                grad_outputs = torch.ones_like(
                    disc_outputs,
                    requires_grad=False,
                    device=real_data.device)

                # Retrieve gradients
                gradients = torch.autograd.grad(
                    outputs=disc_outputs,
                    inputs=interpolation,
                    grad_outputs=grad_outputs,
                    create_graph=True,
                    retain_graph=True)[0]

                # Compute gradient penalty
                gradients = gradients.view(BATCH_SIZE, -1)
                grad_norm = gradients.norm(2, dim=1)

                gradient_penalty = torch.mean((grad_norm - 1) ** 2) * lambda_gp
                d_loss += gradient_penalty
                # Calculate gradients for D in backward pass
                d_loss.backward()

                # Compute error of D as sum of the errors on the real and fake batches and the gradient penalty
                errDItem = d_loss.item()
                # Update D
                d_optimizer.step()

            ############################
            # (2) Update G network: maximize D(G(z|c))
            ###########################
            for p in D.parameters():
                p.requires_grad = False  # to avoid computation

            for p in G.parameters():
                p.requires_grad = True

            g_optimizer.zero_grad()
            
            # Generate batch of latent vectors
            z = torch.randn(b_size, latentSpaceSize, device=device)
            # Generate fake image batch with G
            fake_data = G(z, y)

            NB_GENES = fake_data.shape[1]

            # Perform random augmentations for stability
            augmentations = torch.distributions.binomial.Binomial(total_count=1, probs=0).sample(torch.tensor([BATCH_SIZE])).to(device)
            fake_data = fake_data + augmentations[:, None] * torch.normal(0, 0.5, size=(NB_GENES,), device=device)

            # Calculate G's loss based on this output
            errG = D(fake_data, y)
            
            # Calculate gradients for G
            g_loss =  generator_loss(errG)
            g_loss.backward()
            errGItem += g_loss.item()
            # Update G
            g_optimizer.step()

        # Evaluate the model on the validation set
        if epoch % 20 == 0:
            with torch.no_grad():
                G.eval()
                real_datas = np.array([])
                fake_datas = np.array([])
                for X_val, y_val in val_loader:
                    y_val = y_val.to(device).long()
                    X_val = X_val.to(device).to(torch.float16)
                    z = torch.randn(X_val.size(0), latentSpaceSize)
                    fake_data = G(z.to(device), y_val).to(torch.float16)
                    real_datas = torch.vstack((real_datas, X_val)) if real_datas.size else real_data
                    fake_datas = torch.vstack((fake_datas, fake_data)) if fake_datas.size else fake_data
                aat = calc_aat(real_datas.cpu(), fake_datas.cpu())
            #with torch.device(device):
            #    recall, precision = calc_precision_recall(real_datas, fake_datas)

        errDItem /= len(train_loader)
        errGItem /= len(train_loader)
        Wasserstein_D /= len(train_loader)
        # Log the metrics to neptune
        if neptune_logger is not None:
            run[neptune_logger.base_namespace]['D_Loss'].append(errDItem)
            run[neptune_logger.base_namespace]['G_Loss'].append(errGItem)
            run[neptune_logger.base_namespace]['Wasserstein_D'].append(Wasserstein_D)
            run[neptune_logger.base_namespace]['AAT'].append(aat)
            #run[neptune_logger.base_namespace]['Precision'].append(precision)
            #run[neptune_logger.base_namespace]['Recall'].append(recall)

        print(f"Epoch {epoch} D Loss: {errDItem} G Loss: {errGItem} Wasserstein D: {Wasserstein_D}")
            

In [9]:
class Discriminator(nn.Module):
    def __init__(self, input_size, output_size, class_size, embedding_size):
        super().__init__()
        self.model = nn.Sequential(
            #nn.Linear(input_size, 1024),
            nn.Linear(input_size + embedding_size * class_size, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.1),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, output_size),
        )
        self.embedding = nn.Embedding(class_size, embedding_size)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

    def forward(self, x, cls):
        cls = self.embedding(cls)
        x = torch.cat((x, cls.flatten(start_dim=1)), 1)
        return self.model(x)
    
class Generator(nn.Module):
    def __init__(self, input_size, output_size, class_size, embedding_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size + embedding_size * class_size, 512),
            #nn.Linear(input_size , 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.1),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1),
            nn.Linear(1024, 2048),
            nn.BatchNorm1d(2048),
            nn.LeakyReLU(0.1),
            nn.Linear(2048, output_size)
        )
        self.embedding = nn.Embedding(class_size, embedding_size)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

    def forward(self, x, cls):
        cls = self.embedding(cls)
        x = torch.cat((x, cls.flatten(start_dim=1)), 1)
        return self.model(x)

In [10]:
input_dim_data = X_Train.shape[1]
output_dim_data = y_Train.shape[1]
latentSpaceSize = 128

# Define the generator
G = Generator(latentSpaceSize, input_dim_data, output_dim_data, 2)

# Define the discriminator
D = Discriminator(input_dim_data, 16, output_dim_data, 2)





In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
params = {
    "epochs": 800,
    "lambda_gp": 10,
    "latentSpaceSize": latentSpaceSize,
    "device": device,
    "batch_size": 512,
    "lr_g": 0.0001/20,
    "lr_d": 0.001/20,
    "critic_iter": 5,
    "b1": 0.5,
    "b2": 0.999
}




train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1024, shuffle=True)

d_optimizer = torch.optim.Adam(D.parameters(), lr=params["lr_d"], betas=(params["b1"], params["b2"]))
g_optimizer = torch.optim.Adam(G.parameters(), lr=params["lr_g"], betas=(params["b1"], params["b2"]))



In [12]:
run = neptune.init_run(
    api_token=os.getenv("NEPTUNE_API_KEY"),
    project=os.getenv("NEPTUNE_PROJECT_NAME"),
    name="WGAN-GP - 800 - simple",
)

neptune_logger = NeptuneLogger(run=run, model=G)
                               
run[neptune_logger.base_namespace]["hyperparams"] = stringify_unsupported(params)
run[neptune_logger.base_namespace]["G_Structure"] = str(G)
run[neptune_logger.base_namespace]["D_Structure"] = str(D)



training(D, G, d_optimizer, g_optimizer, train_loader , val_loader, params["epochs"], latentSpaceSize, params["lambda_gp"], params["critic_iter"], device, neptune_logger=neptune_logger, run=run)

neptune_logger.log_model()
run.stop()



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/JPL/rna-sequencing/e/RNAS-196
Starting Training Loop...


  0%|          | 0/800 [00:00<?, ?it/s]

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
100%|██████████| 4/4 [00:17<00:00,  4.38s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 0 D Loss: 0.0021393661732440226 G Loss: 0.14092089402508903 Wasserstein D: -2.142912024384612e-05
Epoch 1 D Loss: 0.0012957395373524485 G Loss: 0.09353192739344977 Wasserstein D: -3.511208545911562e-05
Epoch 2 D Loss: 0.001123168251731179 G Loss: 0.07730406595693602 Wasserstein D: -0.00013854753199990812
Epoch 3 D Loss: 0.0006770808379966896 G Loss: 0.06878841376596398 Wasserstein D: -0.00023808991992390238
Epoch 4 D Loss: -0.0002166120843453841 G Loss: 0.06163631457757283 Wasserstein D: -0.000785356113960693
Epoch 5 D Loss: -0.0007302447945087939 G Loss: 0.059820024882788425 Wasserstein D: -0.0014415035297820618
Epoch 6 D Loss: -0.0011937784981894326 G Loss: 0.11527941250926131 Wasserstein D: -0.00196809368533688
Epoch 7 D Loss: -0.002564546528396073 G Loss: 0.2163058628777524 Wasserstein D: -0.003004715992854192
Epoch 8 D Loss: -0.003123686030194476 G Loss: 0.36325119737978584 Wasserstein D: -0.0037797239276912663
Epoch 9 D Loss: -0.002816650417301205 G Loss: 0.4767960730966154

100%|██████████| 4/4 [00:17<00:00,  4.36s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 20 D Loss: -0.0024069670613829073 G Loss: -0.13730183690525852 Wasserstein D: -0.004706678690610232
Epoch 21 D Loss: -0.004447150897312831 G Loss: -0.46317615646582383 Wasserstein D: -0.005851593884554776
Epoch 22 D Loss: -0.0069376463656658895 G Loss: -0.32624970126402125 Wasserstein D: -0.007859504306233013
Epoch 23 D Loss: -0.002248961191910964 G Loss: -0.05166575771111708 Wasserstein D: -0.005861780860207297
Epoch 24 D Loss: 0.005286205065000308 G Loss: 0.030175709693165093 Wasserstein D: -0.007924851837691727
Epoch 25 D Loss: -0.0029273404108060824 G Loss: -0.37291872798354475 Wasserstein D: -0.005585281582145424
Epoch 26 D Loss: -0.0018617644176616534 G Loss: -0.4676255226343662 Wasserstein D: -0.00619491413756684
Epoch 27 D Loss: -0.0048055215315385294 G Loss: -0.4192444775904809 Wasserstein D: -0.006219801786062601
Epoch 28 D Loss: -0.0041939245237337126 G Loss: -0.22834747396670022 Wasserstein D: -0.006851506816757309
Epoch 29 D Loss: -0.004217569227818843 G Loss: -0.402

100%|██████████| 4/4 [00:18<00:00,  4.60s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 40 D Loss: -0.0011489353813491502 G Loss: -0.8639711249958385 Wasserstein D: -0.007654134210173066
Epoch 41 D Loss: -0.00466138374555361 G Loss: -0.9010141808669884 Wasserstein D: -0.007322299730527651
Epoch 42 D Loss: -0.009095171948412915 G Loss: -0.8080448249836901 Wasserstein D: -0.010628301780540627
Epoch 43 D Loss: -0.0018525111091720475 G Loss: -0.8853536223198151 Wasserstein D: -0.006321156775201117
Epoch 44 D Loss: -0.00648728784147676 G Loss: -1.0111792266785682 Wasserstein D: -0.008047766618795329
Epoch 45 D Loss: -0.009679600075408296 G Loss: -0.8523195410941864 Wasserstein D: -0.011487485645534276
Epoch 46 D Loss: -0.008269731815044697 G Loss: -1.0052254062432509 Wasserstein D: -0.011380919209727041
Epoch 47 D Loss: 0.01794618326467234 G Loss: -0.8262450365753441 Wasserstein D: -0.0106625331865324
Epoch 48 D Loss: -0.002409216407295707 G Loss: -1.362741699585548 Wasserstein D: -0.008396512144929046
Epoch 49 D Loss: 0.003148208964954723 G Loss: -1.0784982897184945 Was

100%|██████████| 4/4 [00:18<00:00,  4.55s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 60 D Loss: 0.011882416851870664 G Loss: -0.83280215563474 Wasserstein D: -0.005106472468876339
Epoch 61 D Loss: -0.0036728348765339884 G Loss: -1.4803626078825731 Wasserstein D: -0.009049342228816105
Epoch 62 D Loss: -0.00876898365420895 G Loss: -1.2279082799291277 Wasserstein D: -0.011014784966315423
Epoch 63 D Loss: -0.002843033183704723 G Loss: -1.2725502490163683 Wasserstein D: -0.006492629751458868
Epoch 64 D Loss: -0.005322767304373787 G Loss: -1.1756672167277837 Wasserstein D: -0.007425489959183273
Epoch 65 D Loss: -0.008197544337986232 G Loss: -1.0866167945461673 Wasserstein D: -0.009541458183235221
Epoch 66 D Loss: -0.010137658852797288 G Loss: -0.9785070790277495 Wasserstein D: -0.011605647894052358
Epoch 67 D Loss: -0.013534818495903815 G Loss: -0.8618783200537408 Wasserstein D: -0.015553059277834591
Epoch 68 D Loss: -0.00458643569812908 G Loss: -1.110693199234409 Wasserstein D: -0.013996117598526961
Epoch 69 D Loss: -0.009945407614007697 G Loss: -1.5040984453854862 Wa

100%|██████████| 4/4 [00:17<00:00,  4.44s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 80 D Loss: -0.011367591110976426 G Loss: -0.7415944571261639 Wasserstein D: -0.013654338729965102
Epoch 81 D Loss: -0.013016448987947477 G Loss: -0.7778571275147524 Wasserstein D: -0.01531299177583281
Epoch 82 D Loss: -0.008050865226692253 G Loss: -1.0084767825119978 Wasserstein D: -0.011545698125879248
Epoch 83 D Loss: -0.010587718103315446 G Loss: -1.1447820763487917 Wasserstein D: -0.015302936513940772
Epoch 84 D Loss: -0.016574481150487087 G Loss: -0.5909836056765977 Wasserstein D: -0.019539426256726673
Epoch 85 D Loss: -0.015369630360103154 G Loss: -0.9328118516848638 Wasserstein D: -0.0178012781209879
Epoch 86 D Loss: -0.012723263327058379 G Loss: -0.3871718449192447 Wasserstein D: -0.021015967522467767
Epoch 87 D Loss: -0.022891823228422578 G Loss: -0.1325850504991058 Wasserstein D: -0.02508129773440061
Epoch 88 D Loss: -0.019992069764570755 G Loss: -0.11803742236347048 Wasserstein D: -0.028849955205317145
Epoch 89 D Loss: -0.03017441542832168 G Loss: -0.04854571973006208 

100%|██████████| 4/4 [00:18<00:00,  4.57s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 100 D Loss: -0.031253851377047025 G Loss: 0.5249403069590355 Wasserstein D: -0.05870729726511282
Epoch 101 D Loss: -0.0259581012325687 G Loss: 1.2214605799921743 Wasserstein D: -0.05805141609031837
Epoch 102 D Loss: -0.06325993838010134 G Loss: 1.7582153682108526 Wasserstein D: -0.067701393074089
Epoch 103 D Loss: -0.059483361410927936 G Loss: 2.388074234649018 Wasserstein D: -0.07178728063623388
Epoch 104 D Loss: -0.06320743960934086 G Loss: 2.571054466954478 Wasserstein D: -0.07755552972113336
Epoch 105 D Loss: -0.07529606185592971 G Loss: 2.6154433380473745 Wasserstein D: -0.07939392036491341
Epoch 106 D Loss: -0.05349902839927407 G Loss: 2.68593915609213 Wasserstein D: -0.082024567610734
Epoch 107 D Loss: -0.051957847355129004 G Loss: 2.7004801660150917 Wasserstein D: -0.0840218150532329
Epoch 108 D Loss: -0.06772270069255695 G Loss: 3.284835475308078 Wasserstein D: -0.09030987666203426
Epoch 109 D Loss: -0.08022026248745151 G Loss: 2.688414928796408 Wasserstein D: -0.0846299

100%|██████████| 4/4 [00:18<00:00,  4.54s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 120 D Loss: -0.10570460766345471 G Loss: 5.0592813225059245 Wasserstein D: -0.11116085852776374
Epoch 121 D Loss: -0.10218951085230688 G Loss: 5.282648823478005 Wasserstein D: -0.11312615454613746
Epoch 122 D Loss: 0.06327881846394572 G Loss: 5.493449171106298 Wasserstein D: -0.08696620781104882
Epoch 123 D Loss: -0.0737147698035607 G Loss: 3.5562130771316847 Wasserstein D: -0.10241836601203971
Epoch 124 D Loss: -0.08047657413082523 G Loss: 4.357551533025461 Wasserstein D: -0.09951312058455461
Epoch 125 D Loss: -0.10361164599865466 G Loss: 4.7545375423831535 Wasserstein D: -0.10753470867663831
Epoch 126 D Loss: -0.04957977374950489 G Loss: 5.0587183678900445 Wasserstein D: -0.11238066133085664
Epoch 127 D Loss: -0.10067366219900704 G Loss: 5.291641098636013 Wasserstein D: -0.11068631552316092
Epoch 128 D Loss: -0.10288758711381392 G Loss: 5.522519535118049 Wasserstein D: -0.11703249791285375
Epoch 129 D Loss: -0.10315377562196104 G Loss: 5.59634301712463 Wasserstein D: -0.1204469

100%|██████████| 4/4 [00:18<00:00,  4.52s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 140 D Loss: -0.08347616662512293 G Loss: 5.85018164294583 Wasserstein D: -0.12188758049811517
Epoch 141 D Loss: -0.06574420662193031 G Loss: 6.118104361153983 Wasserstein D: -0.1311111716957359
Epoch 142 D Loss: -0.08604805286114033 G Loss: 6.269702231133734 Wasserstein D: -0.12780765720180698
Epoch 143 D Loss: -0.07694184363305152 G Loss: 6.506160049171714 Wasserstein D: -0.1166980449969952
Epoch 144 D Loss: -0.08029134790380518 G Loss: 6.259009798089941 Wasserstein D: -0.12351192127574574
Epoch 145 D Loss: -0.010916990000051219 G Loss: 6.208969629727877 Wasserstein D: -0.10572149370100115
Epoch 146 D Loss: -0.056302257351108366 G Loss: 6.109132269879321 Wasserstein D: -0.1256523265705242
Epoch 147 D Loss: -0.1127739086017742 G Loss: 6.177499257601225 Wasserstein D: -0.12084012598424525
Epoch 148 D Loss: 0.253821499697812 G Loss: 6.3887050702021675 Wasserstein D: -0.12963226958588286
Epoch 149 D Loss: -0.002963579618013822 G Loss: 6.561809249691196 Wasserstein D: -0.124765116018

100%|██████████| 4/4 [00:18<00:00,  4.58s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 160 D Loss: 4.1560442731097025 G Loss: -0.9778041423695696 Wasserstein D: -0.024244336815147132
Epoch 161 D Loss: 0.49252959564849214 G Loss: -1.0415555110552928 Wasserstein D: -0.022978615927529502
Epoch 162 D Loss: 0.603136662836675 G Loss: -1.6613944122841309 Wasserstein D: -0.007851321380455177
Epoch 163 D Loss: 0.6325238634656359 G Loss: -2.1045269174175663 Wasserstein D: -0.008278695019808683
Epoch 164 D Loss: 0.6392990859238418 G Loss: -2.398957667650876 Wasserstein D: -0.01091045266264802
Epoch 165 D Loss: 0.08712614499605618 G Loss: -2.9187320845944065 Wasserstein D: -0.007535861088679387
Epoch 166 D Loss: 0.13020125969306573 G Loss: -2.8316095575586067 Wasserstein D: -0.003268115170352109
Epoch 167 D Loss: 0.1804401557762306 G Loss: -2.795582594571414 Wasserstein D: -0.004019715569236062
Epoch 168 D Loss: 0.21230233465874945 G Loss: -2.759952558504118 Wasserstein D: -0.0038467553945688102
Epoch 169 D Loss: 0.06908149986000328 G Loss: -1.878341292167877 Wasserstein D: -0

100%|██████████| 4/4 [00:18<00:00,  4.51s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 180 D Loss: -0.027338134658920182 G Loss: 0.19177450558782874 Wasserstein D: -0.029521998825606768
Epoch 181 D Loss: -0.0299168299961757 G Loss: 0.06630244502680493 Wasserstein D: -0.03276510505409507
Epoch 182 D Loss: -0.03510809945059823 G Loss: -0.3590518906973042 Wasserstein D: -0.03866434097290039
Epoch 183 D Loss: -0.0388967474023779 G Loss: -0.7606660173489497 Wasserstein D: -0.044717855386800696
Epoch 184 D Loss: -0.03966474533081055 G Loss: -1.2344332520778363 Wasserstein D: -0.04766378202638426
Epoch 185 D Loss: -0.036589249030693424 G Loss: -1.5114675418480292 Wasserstein D: -0.049871734805874056
Epoch 186 D Loss: -0.03352658398501523 G Loss: -2.4065933169184865 Wasserstein D: -0.05600050279310533
Epoch 187 D Loss: -0.05151104760336709 G Loss: -3.2116882084132907 Wasserstein D: -0.06262241710316051
Epoch 188 D Loss: -0.0498475995097127 G Loss: -3.2183462806514926 Wasserstein D: -0.06100302476149339
Epoch 189 D Loss: -0.0700100785368806 G Loss: -3.605651416978636 Wasser

100%|██████████| 4/4 [00:18<00:00,  4.54s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 200 D Loss: -0.050758425172392305 G Loss: -7.8745068436736 Wasserstein D: -0.05463692191597465
Epoch 201 D Loss: -0.0575090354972786 G Loss: -8.362173540608866 Wasserstein D: -0.06191073597727956
Epoch 202 D Loss: -0.06268760707828548 G Loss: -8.066839978411481 Wasserstein D: -0.06752548351154461
Epoch 203 D Loss: -0.06722394236317882 G Loss: -8.155716055756683 Wasserstein D: -0.07878446912432051
Epoch 204 D Loss: 0.003156715339713997 G Loss: -10.023569533874939 Wasserstein D: -0.04661848161604021
Epoch 205 D Loss: -0.026230935450200433 G Loss: -12.114214203574441 Wasserstein D: -0.049990140474759616
Epoch 206 D Loss: -0.05136468860652897 G Loss: -12.47672640527045 Wasserstein D: -0.058535282428448014
Epoch 207 D Loss: -0.06598308536556217 G Loss: -10.522477803530393 Wasserstein D: -0.07436250806688428
Epoch 208 D Loss: -0.02833141980471311 G Loss: -11.53078728122311 Wasserstein D: -0.06014251708984375
Epoch 209 D Loss: -0.054941704223205996 G Loss: -12.350418604337252 Wasserstei

100%|██████████| 4/4 [00:18<00:00,  4.51s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 220 D Loss: -0.0642902000800713 G Loss: -14.002759259897513 Wasserstein D: -0.08319441922061094
Epoch 221 D Loss: 0.265903446224186 G Loss: -15.283032904138098 Wasserstein D: -0.032992369645125384
Epoch 222 D Loss: -0.008178020690704559 G Loss: -17.877082331197244 Wasserstein D: -0.049953193931312825
Epoch 223 D Loss: -0.048621871254660866 G Loss: -17.993089222407843 Wasserstein D: -0.05473791802679742
Epoch 224 D Loss: -0.061313735855209245 G Loss: -17.654340170480154 Wasserstein D: -0.06573689067280376
Epoch 225 D Loss: -0.07610010934042764 G Loss: -15.46330221883067 Wasserstein D: -0.0801100030645624
Epoch 226 D Loss: -0.08558894204093026 G Loss: -14.865056851527074 Wasserstein D: -0.08940793417550467
Epoch 227 D Loss: -0.07388525742750901 G Loss: -14.95514271476052 Wasserstein D: -0.08939262203403286
Epoch 228 D Loss: 0.11960659160480633 G Loss: -15.001110877190436 Wasserstein D: -0.08139202144596126
Epoch 229 D Loss: -0.06609381162203275 G Loss: -18.91507150076486 Wasserstei

100%|██████████| 4/4 [00:18<00:00,  4.57s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 240 D Loss: -0.0983684379737694 G Loss: -17.784092429634576 Wasserstein D: -0.10507539602426383
Epoch 241 D Loss: 0.011095793930800645 G Loss: -21.594510938737777 Wasserstein D: -0.07102452791654147
Epoch 242 D Loss: -0.02033135607526019 G Loss: -21.938126050508938 Wasserstein D: -0.07783980469603638
Epoch 243 D Loss: -0.062492163864882676 G Loss: -22.269845402324115 Wasserstein D: -0.07701719057309878
Epoch 244 D Loss: -0.049916100668740436 G Loss: -22.408604028341653 Wasserstein D: -0.08558617438469733
Epoch 245 D Loss: -0.06856486180445531 G Loss: -22.288875979977053 Wasserstein D: -0.09148860477901005
Epoch 246 D Loss: -0.08835003592751244 G Loss: -22.075004257522263 Wasserstein D: -0.0990977920852341
Epoch 247 D Loss: -0.04236506081961252 G Loss: -21.98283352885213 Wasserstein D: -0.1042775374192458
Epoch 248 D Loss: -0.11328970635687555 G Loss: -21.425038597800516 Wasserstein D: -0.1181138978971468
Epoch 249 D Loss: -0.10864261147025582 G Loss: -21.038857973538914 Wasserste

100%|██████████| 4/4 [00:18<00:00,  4.54s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 260 D Loss: 0.038978409933877155 G Loss: -31.54672187858528 Wasserstein D: -0.06341419353351727
Epoch 261 D Loss: -0.05813131132325926 G Loss: -34.474427443284256 Wasserstein D: -0.08935194749098557
Epoch 262 D Loss: 0.6361581228829764 G Loss: -31.293971308461437 Wasserstein D: -0.07433220389839652
Epoch 263 D Loss: -0.11117238931722574 G Loss: -28.82467471302806 Wasserstein D: -0.1352543464073768
Epoch 264 D Loss: -0.1014718275803786 G Loss: -28.92845205827193 Wasserstein D: -0.11711480734231589
Epoch 265 D Loss: -0.08822866586538461 G Loss: -28.643989936455146 Wasserstein D: -0.1245706064717753
Epoch 266 D Loss: -0.11886247221406523 G Loss: -28.898797335324588 Wasserstein D: -0.13697574855564357
Epoch 267 D Loss: -0.14754570114029036 G Loss: -28.299487280678918 Wasserstein D: -0.15430485118519177
Epoch 268 D Loss: -0.1362174107478215 G Loss: -28.526448603276606 Wasserstein D: -0.14824739202752812
Epoch 269 D Loss: -0.14367566408810917 G Loss: -29.759570448548644 Wasserstein D: 

100%|██████████| 4/4 [00:18<00:00,  4.53s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 280 D Loss: -0.04953074955440068 G Loss: -35.60151154844911 Wasserstein D: -0.1504589160839161
Epoch 281 D Loss: -0.13516352726862982 G Loss: -37.52738859270003 Wasserstein D: -0.18539874203555234
Epoch 282 D Loss: -0.18283118401374016 G Loss: -37.904958258142 Wasserstein D: -0.19306593674879807
Epoch 283 D Loss: -0.20026665800934904 G Loss: -37.796471575757 Wasserstein D: -0.21507129802570477
Epoch 284 D Loss: -0.20646467408933838 G Loss: -38.808933791580735 Wasserstein D: -0.21735091309447388
Epoch 285 D Loss: -0.19782667893629807 G Loss: -39.671666098641346 Wasserstein D: -0.23318091972724542
Epoch 286 D Loss: -0.17724134538557146 G Loss: -41.92337700370308 Wasserstein D: -0.1922527393261036
Epoch 287 D Loss: -0.21667372430121148 G Loss: -42.66057522647031 Wasserstein D: -0.23629197874269287
Epoch 288 D Loss: -0.24937446967705146 G Loss: -41.735956098649886 Wasserstein D: -0.2543067932128906
Epoch 289 D Loss: 0.02413683004312582 G Loss: -42.984232549067144 Wasserstein D: -0.16

100%|██████████| 4/4 [00:18<00:00,  4.54s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 300 D Loss: -0.16938456288584464 G Loss: -49.97828484915353 Wasserstein D: -0.25332454868129917
Epoch 301 D Loss: -0.19488596082567336 G Loss: -59.57637928249119 Wasserstein D: -0.20577512087521854
Epoch 302 D Loss: -0.1801435163804701 G Loss: -58.0683725263689 Wasserstein D: -0.19072624686714654
Epoch 303 D Loss: -0.21899488755873034 G Loss: -55.37065004468798 Wasserstein D: -0.2720547922841319
Epoch 304 D Loss: -0.26251383427973396 G Loss: -51.154060763912604 Wasserstein D: -0.29181948575106537
Epoch 305 D Loss: -0.0944535582215636 G Loss: -50.78047281545359 Wasserstein D: -0.2708588180008468
Epoch 306 D Loss: -0.21638354054697745 G Loss: -51.832433020318305 Wasserstein D: -0.2670052001526306
Epoch 307 D Loss: -0.17858826697289526 G Loss: -60.34751243858071 Wasserstein D: -0.2014572036849869
Epoch 308 D Loss: -0.25230839869359156 G Loss: -56.39716739254398 Wasserstein D: -0.29047740589488635
Epoch 309 D Loss: 0.005696863561243444 G Loss: -51.435588463203054 Wasserstein D: -0.16

100%|██████████| 4/4 [00:18<00:00,  4.53s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 320 D Loss: 49.85796820367133 G Loss: -56.146821828988884 Wasserstein D: -0.18061124361478365
Epoch 321 D Loss: -0.2546737777603256 G Loss: -71.82153853836593 Wasserstein D: -0.2700476479697061
Epoch 322 D Loss: -0.32524554379336484 G Loss: -58.41647328196706 Wasserstein D: -0.3371375023901879
Epoch 323 D Loss: -0.28489887797749125 G Loss: -56.368867674073975 Wasserstein D: -0.3429134875744373
Epoch 324 D Loss: -0.06430313803932884 G Loss: -73.03310076840275 Wasserstein D: -0.23191251954832276
Epoch 325 D Loss: -0.23160334233637456 G Loss: -69.95351463264518 Wasserstein D: -0.2469093482811134
Epoch 326 D Loss: -0.24243580211292615 G Loss: -60.673370361328125 Wasserstein D: -0.30460085568728146
Epoch 327 D Loss: -0.16418219613028573 G Loss: -60.827424536218174 Wasserstein D: -0.2914694939459954
Epoch 328 D Loss: -0.232536529327606 G Loss: -59.4264397254357 Wasserstein D: -0.3572307666698536
Epoch 329 D Loss: -0.31811093950605057 G Loss: -59.15098281006713 Wasserstein D: -0.3577503

100%|██████████| 4/4 [00:18<00:00,  4.57s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 340 D Loss: 1.0638393588833042 G Loss: -68.76079746059604 Wasserstein D: -0.188448285723066
Epoch 341 D Loss: -0.14059972429608966 G Loss: -84.52130164299811 Wasserstein D: -0.2543792191085282
Epoch 342 D Loss: -0.14804534645347328 G Loss: -74.8486448701445 Wasserstein D: -0.3092968814022891
Epoch 343 D Loss: -0.24546387145569273 G Loss: -69.61268775779884 Wasserstein D: -0.36222866031673406
Epoch 344 D Loss: -0.24909775740616805 G Loss: -67.9438304767742 Wasserstein D: -0.3489769355400459
Epoch 345 D Loss: 0.6085742336886746 G Loss: -70.22031365241205 Wasserstein D: -0.26056217647098995
Epoch 346 D Loss: -0.30807743205890786 G Loss: -74.74856241933115 Wasserstein D: -0.3321420096017264
Epoch 347 D Loss: -0.30515438693386693 G Loss: -75.20105972823563 Wasserstein D: -0.3454304408360194
Epoch 348 D Loss: -0.34842756578138656 G Loss: -71.82839987161276 Wasserstein D: -0.38366299075680177
Epoch 349 D Loss: -0.34838888528463724 G Loss: -69.6537387020938 Wasserstein D: -0.396458392376

100%|██████████| 4/4 [00:18<00:00,  4.52s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 360 D Loss: -0.42057266768875656 G Loss: -79.35086934549825 Wasserstein D: -0.4356536865234375
Epoch 361 D Loss: -0.38054398223236724 G Loss: -78.92455014315519 Wasserstein D: -0.4169848248675153
Epoch 362 D Loss: -0.38404600770323427 G Loss: -79.11497150767933 Wasserstein D: -0.43886053978980005
Epoch 363 D Loss: -0.4328001329115221 G Loss: -78.46261863441734 Wasserstein D: -0.4687803575208971
Epoch 364 D Loss: 0.20312860128762839 G Loss: -91.392532241928 Wasserstein D: -0.1928676258433949
Epoch 365 D Loss: -0.23909482088955966 G Loss: -108.27182999190751 Wasserstein D: -0.2912818001700448
Epoch 366 D Loss: -0.21750558172906195 G Loss: -107.85188869663052 Wasserstein D: -0.2324931004664281
Epoch 367 D Loss: -0.3813788940856507 G Loss: -97.82782462593559 Wasserstein D: -0.4367014744898656
Epoch 368 D Loss: -0.40802170013214323 G Loss: -84.87214100444234 Wasserstein D: -0.4463877111048132
Epoch 369 D Loss: -0.27910403271655104 G Loss: -83.0768744728782 Wasserstein D: -0.4322714638

100%|██████████| 4/4 [00:18<00:00,  4.57s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 380 D Loss: -0.3990769553017783 G Loss: -94.26345617120916 Wasserstein D: -0.4215572864025623
Epoch 381 D Loss: -0.4337920609054032 G Loss: -91.47387818023041 Wasserstein D: -0.45675813901674495
Epoch 382 D Loss: -0.4983804342630026 G Loss: -90.18259995467179 Wasserstein D: -0.5129473492815778
Epoch 383 D Loss: -0.3042022598373306 G Loss: -89.82871374383673 Wasserstein D: -0.5139758236758358
Epoch 384 D Loss: -0.5337821586982353 G Loss: -90.24954362349077 Wasserstein D: -0.5458722414670291
Epoch 385 D Loss: 2.791188913625437 G Loss: -127.75915180553089 Wasserstein D: -0.16920001690204328
Epoch 386 D Loss: 0.47329978676109047 G Loss: -130.19839376169486 Wasserstein D: -0.2003907957277098
Epoch 387 D Loss: -0.14673726542012674 G Loss: -129.13623409671382 Wasserstein D: -0.2683507745916193
Epoch 388 D Loss: -0.26614523934317635 G Loss: -123.85174411160129 Wasserstein D: -0.3360420707222465
Epoch 389 D Loss: -0.3650242198597301 G Loss: -119.15073432122077 Wasserstein D: -0.4320815293

100%|██████████| 4/4 [00:18<00:00,  4.56s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 400 D Loss: -0.21808871022471182 G Loss: -137.60755018754438 Wasserstein D: -0.3529713237202251
Epoch 401 D Loss: -0.5124682979983883 G Loss: -128.6953130868765 Wasserstein D: -0.542417352849787
Epoch 402 D Loss: -0.5168659770405376 G Loss: -114.25045429576527 Wasserstein D: -0.557783913779092
Epoch 403 D Loss: -0.5751018390788899 G Loss: -112.94977009379781 Wasserstein D: -0.583795240708998
Epoch 404 D Loss: -0.5765340311543925 G Loss: -114.03782611126667 Wasserstein D: -0.5913745906803157
Epoch 405 D Loss: -0.3286699415086866 G Loss: -127.01954864288544 Wasserstein D: -0.47909791319520323
Epoch 406 D Loss: -0.5600785475510818 G Loss: -119.43831565163352 Wasserstein D: -0.5810081641990822
Epoch 407 D Loss: 0.14971971845293378 G Loss: -115.13709056294047 Wasserstein D: -0.39410197651469625
Epoch 408 D Loss: -0.509437507682747 G Loss: -120.6946178516308 Wasserstein D: -0.5563265927188046
Epoch 409 D Loss: -0.4120310003107244 G Loss: -117.30743472226017 Wasserstein D: -0.5990324753

KeyboardInterrupt: 

In [None]:
# Save the model
#torch.save(G.state_dict(), f"{git_root}/experiments/generating/models/G_800_WGAN_GP.pt")
#torch.save(D.state_dict(), f"{git_root}/experiments/generating/models/D_800_WGAN_GP.pt")

In [13]:
G.eval()

Generator(
  (model): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.1)
    (3): Linear(in_features=512, out_features=1024, bias=True)
    (4): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.1)
    (6): Linear(in_features=1024, out_features=2048, bias=True)
    (7): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.1)
    (9): Linear(in_features=2048, out_features=5045, bias=True)
  )
  (embedding): Embedding(105, 2)
)

In [14]:
G(torch.randn(1, latentSpaceSize, device=device), y_Val[0].long().unsqueeze(dim=0).to(device))

tensor([[-0.1270,  0.0696, -0.2601,  ..., -1.1326, -0.1175,  0.0132]],
       device='cuda:0')

In [15]:
X_Val[0]

tensor([0.0000, 0.0000, 0.0000,  ..., 3.3789, 0.0000, 0.0000])

In [16]:
# Classify the discriminator on the first 50 validation data points
with torch.no_grad():
    D.eval().cpu()
    y_pred = D(X_Val[:50], y_Val[:50].long())
    print(y_pred.mean())



tensor(22.0881)


In [17]:
# Generate 50 samples
with torch.no_grad():
    G.eval().cpu()
    fake_data = G(torch.randn(50, latentSpaceSize), y_Val[:50].long())
    print(fake_data)

# Evaluate the generated data on the discriminator
with torch.no_grad():
    D.eval().cpu()
    y_pred = D(fake_data, y_Val[:50].long())
    print(y_pred.mean())

tensor([[ 2.1170e-01, -7.4901e-01,  2.0443e-01,  ...,  2.1142e+00,
          5.0957e-01,  7.0062e-01],
        [ 8.7831e-03,  1.4698e-01, -9.8590e-02,  ..., -1.6326e+00,
         -1.4369e-02, -3.0847e-01],
        [ 3.7625e-01, -1.4402e-01, -2.8685e-01,  ...,  2.5300e+00,
         -4.7173e-01, -3.8767e-01],
        ...,
        [-9.4485e-02,  3.4533e-01,  3.4342e-01,  ..., -1.6533e+00,
          1.4203e-01, -2.1923e-01],
        [ 4.0288e-01,  9.9806e-02,  2.4194e-01,  ..., -2.2537e-01,
         -3.7465e-02,  1.4169e-03],
        [ 8.6841e-01, -2.4582e-01,  4.3270e-02,  ...,  1.8616e+00,
          1.0579e-01, -5.7536e-02]])
tensor(240.7790)


The discriminator classifies the generated output clearly as fake. Fake is positiv