This notebook performs data augmentation using only patient data, with unconditional Attention GANs.

Reference: https://forge.ibisc.univ-evry.fr/alacan/GANs-for-transcriptomics

In [1]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import sys

In [2]:
# Put the project's source folders on the import path so project-local modules
# can be imported from this notebook. Paths are relative to the notebook's
# working directory.
sys.path.extend(["../src/baselines/", "../src/metrics/"])

In [3]:
from attgan import WGAN_GP_SA

In [4]:
from configs import CONFIG_WGAN as CONFIG

In [5]:
# Use the first GPU when CUDA is usable; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without a GPU instead of failing
# with a CUDA error.
CONFIG["device"] = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# Which pre-split TCGA sample to use: it selects the train/test CSVs that are
# read below and is embedded in the name of the augmented CSV that is written.
sample_id = 2

In [7]:
# Patient dataset: read the train/test splits for the selected sample.
# The first CSV column is used as the row index.
tcga_train_df = pd.read_csv(
    f"../data/diffusion_pretraining/tcga_diffusion_train_sample{sample_id}.csv",
    index_col=0,
)
tcga_test_df = pd.read_csv(
    f"../data/diffusion_pretraining/tcga_diffusion_test_sample{sample_id}.csv",
    index_col=0,
)

# Each dataset yields a pair of tensors that are both built from the same
# expression values (the second tensor is a copy of the first, not a label).
tcga_train_values = tcga_train_df.values
tcga_test_values = tcga_test_df.values
tcga_train_dataset = TensorDataset(
    torch.tensor(tcga_train_values),
    torch.tensor(tcga_train_values),
)
tcga_test_dataset = TensorDataset(
    torch.tensor(tcga_test_values),
    torch.tensor(tcga_test_values),
)

# Training batches are reshuffled on every pass; test batches keep file order.
tcga_train_dataloader = DataLoader(tcga_train_dataset, batch_size=256, shuffle=True)
tcga_test_dataloader = DataLoader(tcga_test_dataset, batch_size=256, shuffle=False)

In [8]:
# Map every column (gene) position to its own list of all column positions,
# so each gene is associated with every gene in the dataset.
# NOTE(review): this materialises n_genes separate lists of length n_genes,
# which grows quadratically with the number of columns.
n_genes = tcga_train_df.shape[1]
goi_dict_ppi = {gene_idx: list(range(n_genes)) for gene_idx in range(n_genes)}

In [9]:
# Override the imported WGAN configuration for this run in a single call.
# Keys are listed in the order they should be inserted into CONFIG.
CONFIG.update(
    {
        # The same all-genes mapping is used for both the PPI and the
        # correlation entries (they share one dict object).
        "goi_dict_ppi": goi_dict_ppi,
        "genes_of_interest_ppi": list(range(tcga_train_df.shape[1])),
        "goi_dict_corr": goi_dict_ppi,
        # presumably -1 disables threshold-based filtering — TODO confirm
        # against the attgan implementation.
        "ppi_threshold": -1,
        "corr_threshold": -1,
        "is_gamma_fixed": False,
        "pretrain_without_attention": True,
        "attention_disc": True,
        # With -1 the model reports "Pretraining model without attention while
        # gradient penalty is above -1." when it is constructed.
        "attention_threshold": -1,
    }
)

In [10]:
# Instantiate the attention GAN from the customised CONFIG. Construction prints
# the attention/pretraining settings it picked up (see the cell output).
attgan_model = WGAN_GP_SA(CONFIG)

1-head attention module
Pretraining model without attention while gradient penalty is above -1.
1-head attention module
Pretraining model without attention while gradient penalty is above -1.


In [11]:
# Train the GAN on the TCGA training batches, with the TCGA test batches passed
# as the validation loader. All hyperparameters are read from CONFIG.
# iters_critic and lambda_penalty are not passed to train() here.
attgan_model.train(
    TrainDataLoader=tcga_train_dataloader,
    ValDataLoader=tcga_test_dataloader,
    z_dim=CONFIG['latent_dim'],
    epochs=CONFIG['epochs'],
    categorical=None,
    step=CONFIG['step'],
    verbose=True,
    checkpoint_dir=CONFIG['checkpoint_dir'],
    log_dir=CONFIG['log_dir'],
    fig_dir=CONFIG['fig_dir'],
    prob_success=CONFIG['prob_success'],
    norm_scale=CONFIG['norm_scale'],
    optimizer=CONFIG['optimizer'],
    lr_g=CONFIG['lr_g'],
    lr_d=CONFIG['lr_d'],
    nb_principal_components=CONFIG['nb_principal_components'],
    config=CONFIG,
    hyperparameters_search=False,
)

Directory '../src/baselines/gan/logs/20240417-170511' created
Directory '../src/baselines/gan/checkpoints/20240417-170511' created
Directory '../src/baselines/gan/figures/20240417-170511' created
Time of training: 86.2435 sec = 1.4374 minute(s) = 0.024 hour(s)
--------------------
Discriminator saved at ../src/baselines/gan/checkpoints/20240417-170511/_disc.pt and generator saved at ../src/baselines/gan/checkpoints/20240417-170511/_gen.pt.


In [12]:
# Ask the trained model for a (real, generated) pair of array-likes computed
# over the training loader, using the configured latent dimension.
x_real, x_gen = attgan_model.real_fake_data(tcga_train_dataloader, z_dim=CONFIG['latent_dim'])
# Display both shapes as a quick check that generated and real data match in size.
x_gen.shape, x_real.shape

((476, 7776), (476, 7776))

In [13]:
# Wrap the generated samples in a frame that reuses the training data's column
# names and row index, then write it out as the augmented dataset for this sample.
# NOTE(review): the generated rows are labelled with the training patient IDs,
# but the training loader shuffles and the GAN is unconditional, so presumably
# these IDs are only placeholders and do not tie a row to a patient — confirm
# before joining this file with per-patient metadata.
augmented_tcga_df = pd.DataFrame(
    x_gen,
    columns=tcga_train_df.columns,
    index=tcga_train_df.index,
)
augmented_tcga_df.to_csv(f"/data/ajayago/druid/intermediate/cs6220/baselines/augmented_attgan_tcga_sample{sample_id}.csv")

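Inference on cell lines — not relevant to this patient-only notebook. The cells below are commented out, and the outputs shown under them are left over from an earlier run.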

In [12]:
# cl_train_df = pd.read_csv("../data/diffusion_pretraining/cl_diffusion_train_sample0.csv", index_col=0)
# cl_train_df.shape

(1569, 7776)

In [13]:
# cl_test_df = pd.read_csv("../data/diffusion_pretraining/cl_diffusion_test_sample0.csv", index_col=0)
# cl_test_df.shape

(175, 7776)

In [14]:
# cl_train_dataset = TensorDataset(torch.tensor(cl_train_df.values), torch.tensor(cl_train_df.values))
# cl_train_dataloader = DataLoader(cl_train_dataset, batch_size=256, shuffle=True)

In [15]:
# cl_test_dataset = TensorDataset(torch.tensor(cl_test_df.values), torch.tensor(cl_test_df.values))
# cl_test_dataloader = DataLoader(cl_test_dataset, batch_size=256, shuffle=True)

In [17]:
# x_real_cl, x_gen_cl = attgan_model.real_fake_data(cl_train_dataloader, z_dim=CONFIG['latent_dim'])
# x_gen_cl.shape, x_real_cl.shape

((1569, 7776), (1569, 7776))

In [18]:
# x_real_cl_test, x_gen_cl_test = attgan_model.real_fake_data(cl_test_dataloader, z_dim=CONFIG['latent_dim'])
# x_gen_cl_test.shape, x_real_cl_test.shape

((175, 7776), (175, 7776))

In [19]:
# np.concatenate((x_gen_cl, x_gen_cl_test)).shape

(1744, 7776)

In [20]:
# new_idx = list(cl_train_df.index) + list(cl_test_df.index)
# len(new_idx)

1744

In [21]:
# pd.DataFrame(np.concatenate((x_gen_cl, x_gen_cl_test)), columns = cl_train_df.columns, index=new_idx).to_csv("/data/ajayago/druid/intermediate/cs6220/baselines/augmented_attgan.csv")