In [1]:
import sys
import os
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
sys.path.append("../../")
sys.path.append("../")
from utils.evaluation import evaluate
from utils.metrics import Metrics
from models.NNGenerator import AdaptableDenseModel


In [2]:
import neptune
from neptune_pytorch import NeptuneLogger
from neptune.utils import stringify_unsupported
from dotenv import load_dotenv
load_dotenv()

True

In [3]:
import pickle
from git import Repo

# Get the git root directory
repo = Repo(".", search_parent_directories=True)
git_root = repo.git.rev_parse("--show-toplevel")

# Load data
X_Train_pd = pickle.load(open(f"{git_root}/data/splits/train/X_pandas.pck", "rb"))
y_Train_pd = pickle.load(open(f"{git_root}/data/splits/train/y_pandas.pck", "rb"))

X_Val_pd = pickle.load(open(f"{git_root}/data/splits/val/X_pandas.pck", "rb"))
y_Val_pd = pickle.load(open(f"{git_root}/data/splits/val/y_pandas.pck", "rb"))

In [4]:
X_Train = torch.tensor(X_Train_pd.values, dtype=torch.float32)
y_Train = torch.tensor(y_Train_pd.values, dtype=torch.float32)

X_Val = torch.tensor(X_Val_pd.values, dtype=torch.float32)
y_Val = torch.tensor(y_Val_pd.values, dtype=torch.float32)

train_dataset = torch.utils.data.TensorDataset(X_Train, y_Train)
val_dataset = torch.utils.data.TensorDataset(X_Val, y_Val)

In [5]:
def label_from_logits(y_hat: torch.Tensor, threshold = 0.5) -> torch.Tensor:
    with torch.no_grad():
        y_pred_tensor = (torch.sigmoid(y_hat) > threshold).float()
    return y_pred_tensor


def evaluate_from_dataframe(X: pd.DataFrame):
    X_tensor = torch.tensor(X.to_numpy(), dtype=torch.float32)
    
    #model: a pytorch model, which transforms X -> y in torch.Tensor format
    model.eval()
    model.cpu()
    y_pred_tensor = label_from_logits(model(X_tensor))
    
    return pd.DataFrame(y_pred_tensor.numpy())


In [6]:
from metrics.aats import compute_AAts
#from metrics.precision_recall import get_precision_recall

def calc_aat(real_data, fake_data):
    _, _, aat = compute_AAts(real_data, fake_data)
    return aat

# def calc_precision_recall(real_data, fake_data):
#     precision, recall = get_precision_recall(real_data, fake_data)
#     return precision, recall

In [7]:

#Adapted from https://forge.ibisc.univ-evry.fr/alacan/GANs-for-transcriptomics/-/blob/master/src/models/utils.py?ref_type=heads
def wasserstein_loss(y_true: torch.tensor, y_pred: torch.tensor):
    """
    Returns Wasserstein loss (product of real/fake labels and critic scores on real or fake data)
    ----
    Parameters:
        y_true (torch.tensor): true labels (either real or fake)
        y_pred (torch.tensor): critic scores on real or fake data
    Returns:
        (torch.tensor): mean product of real labels and critic scores
    """
    return torch.mean(y_true * y_pred)


def generator_loss(fake_score: torch.tensor):
    """
    Returns generator loss i.e the negative scores of the critic on fake data.
    ----
    Parameters:
        fake_score (torch.tensor): critic scores on fake data
    Returns:
        (torch.tensor): generator loss"""

    return wasserstein_loss(-torch.ones_like(fake_score), fake_score)


def discriminator_loss(real_score: torch.tensor, fake_score: torch.tensor):
    """
    Compute and return the wasserstein loss of critic scores on real and fake data i.e: wassertstein_loss = mean(-score_real) + mean(score_fake)
    ----
    Parameters:
        real_score (torch.tensor): critic scores on real data
        fake_score (torch.tensor): critic scores on fake data
    Returns:
        (torch.tensor): wasserstein loss
    """
    real_loss = wasserstein_loss(-torch.ones_like(real_score), real_score)
    fake_loss = wasserstein_loss(torch.ones_like(fake_score), fake_score)

    return real_loss, fake_loss


In [8]:
def training(D, G, d_optimizer, g_optimizer, train_loader, val_loader, epochs, latentSpaceSize, lambda_gp, iters_critic, device, neptune_logger=None, run = None, trial = None):
# Code adopted from https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
# and https://github.com/Zeleni9/pytorch-wgan/blob/master/models/wgan_gradient_penalty.py
    D = D.to(device)
    G = G.to(device)
    aat = 0

    print("Starting Training Loop...")
    # For each epoch
    for epoch in tqdm(range(epochs)):
        errDItem = 0
        errGItem = 0
        Wasserstein_D = 0
        G.train()
        D.train()

        # For each batch in the dataloader
        for X,y in train_loader:

            y = y.to(device).long()
            
            ############################
            # (1) Update D network: maximize D(x|c) - D(G(z|c)) + lambda_gp * gradient_penalty
            ###########################
            ## Train with all-real batch
            for p in D.parameters():
                p.requires_grad = True  

            for p in G.parameters():
                p.requires_grad = False  # to avoid computation

            d_optimizer.zero_grad()
            b_size = X.size(0)

            real_data = X.to(device)

            for crit_step in range(iters_critic):
                # Forward pass real batch through D
                
                errD_real = D(real_data, y)

                ## Train with all-fake batch
                # Generate batch of latent vectors
                z = torch.randn(b_size, latentSpaceSize, device=device)
                # Generate fake image batch with G
                fake_data = G(z, y)
                # Classify all fake batch with D
                errD_fake = D(fake_data.detach(), y)
                # Calculate D's loss on the all-fake batch
                real_loss, fake_loss = discriminator_loss(errD_real, errD_fake)

                # Loss for D in
                d_loss = real_loss + fake_loss
                Wasserstein_D += d_loss.item()

                # Fixed batch size
                BATCH_SIZE = real_data.size(0)

                # Sample alpha from uniform distribution
                alpha = torch.rand(
                    BATCH_SIZE,
                    1,
                    requires_grad=True,
                    device=real_data.device)

                # Interpolation between real data and fake data.
                interpolation = torch.mul(alpha, real_data) + \
                    torch.mul((1 - alpha), fake_data)

                # Get outputs from critic
                disc_outputs = D(interpolation, y)
                grad_outputs = torch.ones_like(
                    disc_outputs,
                    requires_grad=False,
                    device=real_data.device)

                # Retrieve gradients
                gradients = torch.autograd.grad(
                    outputs=disc_outputs,
                    inputs=interpolation,
                    grad_outputs=grad_outputs,
                    create_graph=True,
                    retain_graph=True)[0]

                # Compute gradient penalty
                gradients = gradients.view(BATCH_SIZE, -1)
                grad_norm = gradients.norm(2, dim=1)

                gradient_penalty = torch.mean((grad_norm - 1) ** 2) * lambda_gp
                
                d_loss += gradient_penalty
                # Calculate gradients for D in backward pass
                d_loss.backward()

                # Compute error of D as sum of the errors on the real and fake batches and the gradient penalty
                errDItem += d_loss.item()
                # Update D
                d_optimizer.step()

            Wasserstein_D /= iters_critic
            errDItem /= iters_critic
            ############################
            # (2) Update G network: maximize D(G(z|c))
            ###########################
            for p in D.parameters():
                p.requires_grad = False  # to avoid computation

            for p in G.parameters():
                p.requires_grad = True

            g_optimizer.zero_grad()

            z = torch.randn(b_size, latentSpaceSize, device=device)
            # Generate fake image batch with G
            fake_data = G(z, y)
            # Calculate G's loss based on this output
            errG = D(fake_data, y)
            
            # Calculate gradients for G
            g_loss = generator_loss(errG)
            g_loss.backward()
            errGItem += g_loss.item()
            # Update G
            g_optimizer.step()

        if epoch % 20 == 0:
            with torch.no_grad():
                G.eval()
                real_datas = np.array([])
                fake_datas = np.array([])
                for X_val, y_val in val_loader:
                    y_val = y_val.to(device).long()
                    X_val = X_val.to(device).to(torch.float16)
                    z = torch.randn(X_val.size(0), latentSpaceSize)
                    fake_data = G(z.to(device), y_val).to(torch.float16)
                    real_datas = torch.vstack((real_datas, X_val)) if real_datas.size else real_data
                    fake_datas = torch.vstack((fake_datas, fake_data)) if fake_datas.size else fake_data
                aat = calc_aat(real_datas.cpu(), fake_datas.cpu())
            #with torch.device(device):
            #    recall, precision = calc_precision_recall(real_datas, fake_datas)

        errDItem /= len(train_loader)
        errGItem /= len(train_loader)
        Wasserstein_D /= len(train_loader)
        if neptune_logger is not None:
            run[neptune_logger.base_namespace]['D_Loss'].append(errDItem)
            run[neptune_logger.base_namespace]['G_Loss'].append(errGItem)
            run[neptune_logger.base_namespace]['Wasserstein_D'].append(Wasserstein_D)
            run[neptune_logger.base_namespace]['AAT'].append(aat)
            #run[neptune_logger.base_namespace]['Precision'].append(precision)
            #run[neptune_logger.base_namespace]['Recall'].append(recall)

        print(f"Epoch {epoch} D Loss: {errDItem} G Loss: {errGItem} Wasserstein D: {Wasserstein_D}")
            

In [None]:
class Discriminator(nn.Module):
    def __init__(self, input_size, output_size, class_size, embedding_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, 1024),
            #nn.Linear(input_size + embedding_size * class_size, 512),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.2),
            nn.Linear(1024, 400),
            nn.BatchNorm1d(400),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.2),
            nn.Linear(400, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.2),
            nn.Linear(256, output_size),
        )
        self.embedding = nn.Embedding(class_size, embedding_size)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

    def forward(self, x, cls):
        #cls = self.embedding(cls)
        #x = torch.cat((x, cls.flatten(start_dim=1)), 1)
        return self.model(x)
    
class Generator(nn.Module):
    def __init__(self, input_size, output_size, class_size, embedding_size):
        super().__init__()
        self.model = nn.Sequential(
            #nn.Linear(input_size + embedding_size * class_size, 512),
            nn.Linear(input_size , 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.1),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1),
            nn.Linear(1024, 2048),
            nn.BatchNorm1d(2048),
            nn.LeakyReLU(0.1),
            nn.Linear(2048, output_size)
        )
        self.embedding = nn.Embedding(class_size, embedding_size)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

    def forward(self, x, cls):
        #cls = self.embedding(cls)
        #x = torch.cat((x, cls.flatten(start_dim=1)), 1)
        return self.model(x)

In [10]:
input_dim_data = X_Train.shape[1]
output_dim_data = y_Train.shape[1]
latentSpaceSize = 128

# Define the generator

G = Generator(latentSpaceSize, input_dim_data, output_dim_data, 2)

# Define the discriminator
D = Discriminator(input_dim_data, 16, output_dim_data, 2)





In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
params = {
    "epochs": 800,
    "lambda_gp": 10,
    "latentSpaceSize": latentSpaceSize,
    "device": device,
    "batch_size": 512,
    "lr_g": 0.0001,
    "lr_d": 0.0008,
    "critic_iter": 6,
    "b1": 0.5,
    "b2": 0.999
}




train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1024, shuffle=True)

d_optimizer = torch.optim.Adam(D.parameters(), lr=params["lr_d"], betas=(params["b1"], params["b2"]))
g_optimizer = torch.optim.Adam(G.parameters(), lr=params["lr_g"], betas=(params["b1"], params["b2"]))



In [12]:
run = neptune.init_run(
    api_token=os.getenv("NEPTUNE_API_KEY"),
    project=os.getenv("NEPTUNE_PROJECT_NAME"),
    name="WGAN-GP - 800 - simple",
)

neptune_logger = NeptuneLogger(run=run, model=G)
                               
run[neptune_logger.base_namespace]["hyperparams"] = stringify_unsupported(params)
run[neptune_logger.base_namespace]["G_Structure"] = str(G)
run[neptune_logger.base_namespace]["D_Structure"] = str(D)



training(D, G, d_optimizer, g_optimizer, train_loader , val_loader, params["epochs"], latentSpaceSize, params["lambda_gp"], params["critic_iter"], device, neptune_logger=neptune_logger, run=run)

neptune_logger.log_model()
run.stop()



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/JPL/rna-sequencing/e/RNAS-183
Starting Training Loop...


  0%|          | 0/800 [00:00<?, ?it/s]

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
100%|██████████| 4/4 [00:17<00:00,  4.42s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 0 D Loss: 0.035091672862384383 G Loss: -0.05762551410996414 Wasserstein D: 6.72552447685773e-05
Epoch 1 D Loss: 0.04260598517985928 G Loss: -0.44061562898275736 Wasserstein D: -0.0002594679312296642
Epoch 2 D Loss: 0.04062837096103206 G Loss: -0.3771143027327277 Wasserstein D: -0.00018668166625630432
Epoch 3 D Loss: 0.024141028851231438 G Loss: -0.2404989467321576 Wasserstein D: -0.00021485700903980349
Epoch 4 D Loss: 0.017893668101375205 G Loss: -0.31839795275167987 Wasserstein D: -0.0002252476338433242
Epoch 5 D Loss: 0.009113774584697325 G Loss: -0.496205041041741 Wasserstein D: -0.0003227011369971801
Epoch 6 D Loss: 0.00934959159122619 G Loss: -0.6267793524515378 Wasserstein D: -0.0008520398970363615
Epoch 7 D Loss: 0.006029940568200679 G Loss: -0.758686542927802 Wasserstein D: -0.002049036654661386
Epoch 8 D Loss: 0.0025165940816169663 G Loss: -0.8926777906351157 Wasserstein D: -0.004235582835052569
Epoch 9 D Loss: 0.0011501017664687155 G Loss: -0.7369100380610752 Wasserstei

100%|██████████| 4/4 [00:18<00:00,  4.56s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 20 D Loss: -0.2785913818757723 G Loss: 9.752666086583705 Wasserstein D: -0.31950820943194114
Epoch 21 D Loss: -0.36850917222694857 G Loss: 7.561513130481426 Wasserstein D: -0.46777252396729957


KeyboardInterrupt: 

In [None]:
# Save the model
#torch.save(G.state_dict(), f"{git_root}/experiments/generating/models/G_800_WGAN_GP.pt")
#torch.save(D.state_dict(), f"{git_root}/experiments/generating/models/D_800_WGAN_GP.pt")

In [13]:
G.eval()

Generator(
  (model): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.1)
    (3): Linear(in_features=512, out_features=1024, bias=True)
    (4): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.1)
    (6): Linear(in_features=1024, out_features=2048, bias=True)
    (7): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.1)
    (9): Linear(in_features=2048, out_features=5045, bias=True)
  )
  (embedding): Embedding(105, 2)
)

In [None]:
G(torch.randn(1, latentSpaceSize, device=device), y_Val[0].long().unsqueeze(dim=0).to(device))

tensor([[ 0.3123,  0.3945, -0.2576,  ...,  3.7707,  0.0848, -0.4399]],
       device='cuda:0')

In [17]:
X_Val[0]

tensor([0.0000, 0.0000, 0.0000,  ..., 3.3789, 0.0000, 0.0000])

In [None]:
# Classify the discriminator on the first 50 validation data points
with torch.no_grad():
    D.eval().cpu()
    y_pred = D(X_Val[:50], y_Val[:50].long())
    print(y_pred.mean())



tensor(116918.2500)


In [22]:
# Generate 50 samples
with torch.no_grad():
    G.eval().cpu()
    fake_data = G(torch.randn(50, latentSpaceSize), y_Val[:50].long())
    print(fake_data)

# Evaluate the generated data on the discriminator
with torch.no_grad():
    D.eval().cpu()
    y_pred = D(fake_data, y_Val[:50].long())
    print(y_pred.mean())

tensor([[ 1082.1344,   354.4251,   904.1985,  ...,  -389.1379,   386.4493,
         -1525.4122],
        [  319.9232,  1458.5280,    78.4604,  ...,  -986.5058,  1140.0000,
           306.9229],
        [ 1623.1976,  1710.8981,  1674.1731,  ...,   239.8209,  2076.3311,
           552.0113],
        ...,
        [ 1203.5183,   879.1621,   824.4224,  ..., -1108.3838,   656.3746,
         -1935.3901],
        [  -81.6840,   884.2855,  -201.2603,  ...,  -539.1361,   656.3450,
           598.7426],
        [  528.3951,  1721.8861,   215.2771,  ..., -1230.1609,  1348.0797,
            78.7755]])
tensor(96606.2500)
