In [1]:
import sys
import os
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
sys.path.append("../../")
sys.path.append("../")
from utils.evaluation import evaluate
from utils.metrics import Metrics
from models.NNGenerator import AdaptableDenseModel


In [2]:
import neptune
from neptune_pytorch import NeptuneLogger
from neptune.utils import stringify_unsupported
from dotenv import load_dotenv
load_dotenv()

True

In [3]:
import pickle
from git import Repo

# Get the git root directory
repo = Repo(".", search_parent_directories=True)
git_root = repo.git.rev_parse("--show-toplevel")

# Load data
X_Train_pd = pickle.load(open(f"{git_root}/data/splits/train/X_pandas.pck", "rb"))
y_Train_pd = pickle.load(open(f"{git_root}/data/splits/train/y_pandas.pck", "rb"))

X_Val_pd = pickle.load(open(f"{git_root}/data/splits/val/X_pandas.pck", "rb"))
y_Val_pd = pickle.load(open(f"{git_root}/data/splits/val/y_pandas.pck", "rb"))

In [4]:
X_Train = torch.tensor(X_Train_pd.values, dtype=torch.float32)
y_Train = torch.tensor(y_Train_pd.values, dtype=torch.float32)

X_Val = torch.tensor(X_Val_pd.values, dtype=torch.float32)
y_Val = torch.tensor(y_Val_pd.values, dtype=torch.float32)

train_dataset = torch.utils.data.TensorDataset(X_Train, y_Train)
val_dataset = torch.utils.data.TensorDataset(X_Val, y_Val)

In [5]:
def label_from_logits(y_hat: torch.Tensor, threshold = 0.5) -> torch.Tensor:
    with torch.no_grad():
        y_pred_tensor = (torch.sigmoid(y_hat) > threshold).float()
    return y_pred_tensor


def evaluate_from_dataframe(X: pd.DataFrame):
    X_tensor = torch.tensor(X.to_numpy(), dtype=torch.float32)
    
    #model: a pytorch model, which transforms X -> y in torch.Tensor format
    model.eval()
    model.cpu()
    y_pred_tensor = label_from_logits(model(X_tensor))
    
    return pd.DataFrame(y_pred_tensor.numpy())


In [6]:
from metrics.aats import compute_AAts
#from metrics.precision_recall import get_precision_recall

def calc_aat(real_data, fake_data):
    _, _, aat = compute_AAts(real_data, fake_data)
    return aat

# def calc_precision_recall(real_data, fake_data):
#     precision, recall = get_precision_recall(real_data, fake_data)
#     return precision, recall

In [None]:

#Adapted from https://forge.ibisc.univ-evry.fr/alacan/GANs-for-transcriptomics/-/blob/master/src/models/utils.py?ref_type=heads
def wasserstein_loss(y_true: torch.tensor, y_pred: torch.tensor):
    """
    Returns Wasserstein loss (product of real/fake labels and critic scores on real or fake data)
    ----
    Parameters:
        y_true (torch.tensor): true labels (either real or fake)
        y_pred (torch.tensor): critic scores on real or fake data
    Returns:
        (torch.tensor): mean product of real labels and critic scores
    """
    return torch.mean(y_true * y_pred)


def generator_loss(fake_score: torch.tensor):
    """
    Returns generator loss i.e the negative scores of the critic on fake data.
    ----
    Parameters:
        fake_score (torch.tensor): critic scores on fake data
    Returns:
        (torch.tensor): generator loss"""

    return wasserstein_loss(-torch.ones_like(fake_score), fake_score)


def discriminator_loss(real_score: torch.tensor, fake_score: torch.tensor):
    """
    Compute and return the wasserstein loss of critic scores on real and fake data i.e: wassertstein_loss = mean(-score_real) + mean(score_fake)
    ----
    Parameters:
        real_score (torch.tensor): critic scores on real data
        fake_score (torch.tensor): critic scores on fake data
    Returns:
        (torch.tensor): wasserstein loss
    """
    real_loss = wasserstein_loss(-torch.ones_like(real_score), real_score)
    fake_loss = wasserstein_loss(torch.ones_like(fake_score), fake_score)

    return real_loss, fake_loss


In [7]:
def training(D, G, d_optimizer, g_optimizer, train_loader, val_loader, epochs, latentSpaceSize, lambda_gp, iters_critic, device, neptune_logger=None, run = None, trial = None):
# Code adopted from https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
# and https://github.com/Zeleni9/pytorch-wgan/blob/master/models/wgan_gradient_penalty.py
    one = torch.tensor(1, dtype=torch.float).to(device)
    mone = (one * -1).to(device)
    D = D.to(device)
    G = G.to(device)
    aat = 0

    print("Starting Training Loop...")
    # For each epoch
    for epoch in tqdm(range(epochs)):
        errDItem = 0
        errGItem = 0
        Wasserstein_D = 0
        G.train()
        D.train()

        # For each batch in the dataloader
        for X,y in train_loader:

            y = y.to(device).long()
            
            ############################
            # (1) Update D network: maximize D(x|c) - D(G(z|c)) + lambda_gp * gradient_penalty
            ###########################
            ## Train with all-real batch
            for p in D.parameters():
                p.requires_grad = True  

            for p in G.parameters():
                p.requires_grad = False  # to avoid computation

            d_optimizer.zero_grad()
            b_size = X.size(0)

            real_data = X.to(device)

            for crit_step in range(iters_critic):
                # Forward pass real batch through D
                
                errD_real = D(real_data, y)
                # Calculate loss on all-real batch
                errD_real = errD_real.mean()
                # Calculate gradients for D in backward pass
                errD_real.backward(mone)

                ## Train with all-fake batch
                # Generate batch of latent vectors
                z = torch.randn(b_size, latentSpaceSize)
                # Generate fake image batch with G
                fake_data = G(z.to(device), y)
                # Classify all fake batch with D
                #print(fake.shape)
                errD_fake = D(fake_data.detach(), y)
                # Calculate D's loss on the all-fake batch
                errD_fake = errD_fake.mean()
                # Calculate the gradients for this batch, accumulated (summed) with previous gradients
                errD_fake.backward(one)
                
                # Gradient penalty
                alpha = torch.rand(real_data.size(0), 1, device=device, requires_grad=True)
                alpha = alpha.expand(real_data.size())
                interpolates = alpha * real_data + ((1 - alpha) * fake_data.detach())
                interpolates.requires_grad_(True)

                d_interpolates = D(interpolates, y)
                gradients = torch.autograd.grad(
                    outputs=d_interpolates,
                    inputs=interpolates,
                    grad_outputs=torch.ones(d_interpolates.size(), device=device),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True,
                )[0]

                gradients = gradients.view(gradients.size(0), -1)
                gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp
                gradient_penalty.backward()

                # Compute error of D as sum of the errors on the real and fake batches and the gradient penalty
                errD = errD_real - errD_fake + gradient_penalty
                errDItem += errD.item()
                Wasserstein_D += errD_real.item() - errD_fake.item()
                # Update D
                d_optimizer.step()

            Wasserstein_D /= iters_critic
            errDItem /= iters_critic
            ############################
            # (2) Update G network: maximize D(G(z|c))
            ###########################
            for p in D.parameters():
                p.requires_grad = False  # to avoid computation

            for p in G.parameters():
                p.requires_grad = True

            g_optimizer.zero_grad()

            z = torch.randn(b_size, latentSpaceSize)
            # Generate fake image batch with G
            fake_data = G(z.to(device), y)
            # Calculate G's loss based on this output
            errG = D(fake_data, y)
            errG = errG.mean()
            # Calculate gradients for G
            errG.backward(mone)
            errGItem += errG.item()
            # Update G
            g_optimizer.step()

        with torch.no_grad():
            G.eval()
            real_datas = np.array([])
            fake_datas = np.array([])
            for X_val, y_val in val_loader:
                y_val = y_val.to(device).long()
                X_val = X_val.to(device).to(torch.float16)
                z = torch.randn(X_val.size(0), latentSpaceSize)
                fake_data = G(z.to(device), y_val).to(torch.float16)
                real_datas = torch.vstack((real_datas, X_val)) if real_datas.size else real_data
                fake_datas = torch.vstack((fake_datas, fake_data)) if fake_datas.size else fake_data
            aat = calc_aat(real_datas.cpu(), fake_datas.cpu())
            #with torch.device(device):
            #    recall, precision = calc_precision_recall(real_datas, fake_datas)

        errDItem /= len(train_loader)
        errGItem /= len(train_loader)
        Wasserstein_D /= len(train_loader)
        if neptune_logger is not None:
            run[neptune_logger.base_namespace]['D_Loss'].append(errDItem)
            run[neptune_logger.base_namespace]['G_Loss'].append(errGItem)
            run[neptune_logger.base_namespace]['Wasserstein_D'].append(Wasserstein_D)
            run[neptune_logger.base_namespace]['AAT'].append(aat)
            #run[neptune_logger.base_namespace]['Precision'].append(precision)
            #run[neptune_logger.base_namespace]['Recall'].append(recall)

        print(f"Epoch {epoch} D Loss: {errDItem} G Loss: {errGItem} Wasserstein D: {Wasserstein_D}")
            

In [8]:
class Discriminator(nn.Module):
    def __init__(self, input_size, output_size, class_size, embedding_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size + embedding_size * class_size, 3000),
            nn.ReLU(),
            nn.BatchNorm1d(3000),
            nn.Dropout(0.4),
            nn.Linear(3000, 1000),
            nn.ReLU(),
            nn.BatchNorm1d(1000),
            nn.Dropout(0.4),
            nn.Linear(1000, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.4),
            nn.Linear(128, output_size),
            nn.Sigmoid()
        )
        self.embedding = nn.Embedding(class_size, embedding_size)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

    def forward(self, x, cls):
        cls = self.embedding(cls)
        x = torch.cat((x, cls.flatten(start_dim=1)), 1)
        return self.model(x)
    
class Generator(nn.Module):
    def __init__(self, input_size, output_size, class_size, embedding_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size + embedding_size * class_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Linear(2048, 4096),
            nn.ReLU(),
            nn.Linear(4096, output_size)
        )
        self.embedding = nn.Embedding(class_size, embedding_size)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

    def forward(self, x, cls):
        cls = self.embedding(cls)
        x = torch.cat((x, cls.flatten(start_dim=1)), 1)
        return self.model(x)

In [9]:
input_dim_data = X_Train.shape[1]
output_dim_data = y_Train.shape[1]
latentSpaceSize = 256

# Define the generator

G = Generator(latentSpaceSize, input_dim_data, output_dim_data, 2)

# Define the discriminator
D = Discriminator(input_dim_data, 1, output_dim_data, 2)





In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
params = {
    "epochs": 20,
    "lambda_gp": 10,
    "latentSpaceSize": latentSpaceSize,
    "device": device,
    "batch_size": 64,
    "lr": 0.0001,
    "critic_iter": 5,
    "b1": 0.5,
    "b2": 0.999
}




train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1024, shuffle=True)

d_optimizer = torch.optim.Adam(D.parameters(), lr=params["lr"], betas=(params["b1"], params["b2"]))
g_optimizer = torch.optim.Adam(G.parameters(), lr=params["lr"], betas=(params["b1"], params["b2"]))



In [11]:
run = neptune.init_run(
    api_token=os.getenv("NEPTUNE_API_KEY"),
    project=os.getenv("NEPTUNE_PROJECT_NAME"),
    name="WGAN-GP",
)

neptune_logger = NeptuneLogger(run=run, model=G)
                               
run[neptune_logger.base_namespace]["hyperparams"] = stringify_unsupported(params)
run[neptune_logger.base_namespace]["G_Structure"] = str(G)
run[neptune_logger.base_namespace]["D_Structure"] = str(D)



training(D, G, d_optimizer, g_optimizer, train_loader , val_loader, params["epochs"], latentSpaceSize, params["lambda_gp"], params["critic_iter"], device, neptune_logger=neptune_logger, run=run)

neptune_logger.log_model()
run.stop()



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/JPL/rna-sequencing/e/RNAS-171
Starting Training Loop...


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 4/4 [00:17<00:00,  4.31s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 0 D Loss: 0.016260985444895043 G Loss: 0.49264545841868984 Wasserstein D: 7.4185301643845385e-06


100%|██████████| 4/4 [00:17<00:00,  4.44s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 1 D Loss: 0.009524042271045054 G Loss: 0.4778437015521328 Wasserstein D: 0.0003423478946925875


100%|██████████| 4/4 [00:17<00:00,  4.26s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 2 D Loss: 0.009743568447708408 G Loss: 0.4442377671351253 Wasserstein D: 0.00011703987946582822


100%|██████████| 4/4 [00:17<00:00,  4.41s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 3 D Loss: 0.011068590348177082 G Loss: 0.3294036224758698 Wasserstein D: 0.00039415535688908576


100%|██████████| 4/4 [00:17<00:00,  4.48s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 4 D Loss: 0.01122750897506513 G Loss: 0.3036306252561046 Wasserstein D: 0.0004635800357487995


100%|██████████| 4/4 [00:17<00:00,  4.43s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 5 D Loss: 0.01089974301769699 G Loss: 0.28204929740490564 Wasserstein D: 0.0001774547632853972


100%|██████████| 4/4 [00:17<00:00,  4.48s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 6 D Loss: 0.011323772300848904 G Loss: 0.4786574091177881 Wasserstein D: 0.00043236208285085366


100%|██████████| 4/4 [00:17<00:00,  4.33s/it]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 7 D Loss: 0.011311288658963713 G Loss: 0.26760345080983317 Wasserstein D: 0.00036800679893960424


 50%|█████     | 2/4 [00:17<00:17,  8.60s/it]


KeyboardInterrupt: 

In [12]:
G.eval()

Generator(
  (model): Sequential(
    (0): Linear(in_features=653, out_features=256, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=256, out_features=1024, bias=True)
    (4): ReLU()
    (5): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Linear(in_features=1024, out_features=2048, bias=True)
    (7): ReLU()
    (8): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): Linear(in_features=2048, out_features=4096, bias=True)
    (10): ReLU()
    (11): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): Linear(in_features=4096, out_features=5045, bias=True)
  )
  (embedding): Embedding(105, 5)
)

In [21]:
G(torch.randn(1, latentSpaceSize, device=device), y_Val[0].long().unsqueeze(dim=0).to(device))

tensor([[-26.4828,  17.0079,   0.3647,  ...,  -7.2669,  -2.4989,  13.0806]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

In [22]:
X_Val[0]

tensor([0.0000, 0.0000, 0.0000,  ..., 3.3789, 0.0000, 0.0000])