In [1]:
import os
import sys

import pandas as pd

import torch
import torch.nn as nn

sys.path.append('../src')
from model import *

  from .autonotebook import tqdm as notebook_tqdm


### Load data

In [2]:
# Root folder of the GDSC/ and TCGA/ inputs. The trailing slash is required
# because file paths are built from it by plain string concatenation.
data_dir = "../data/"

In [3]:
# GDSC inputs: the expression matrix (its first CSV column becomes the row
# index) and the info table that supplies `tissue_label` further down.
gdsc_data_dir = f'{data_dir}GDSC/GDSC_gex.csv'
gdsc_info_dir = f'{data_dir}GDSC/GDSC_info.csv'

gdsc_data_df = pd.read_csv(gdsc_data_dir, index_col=0)
gdsc_info_df = pd.read_csv(gdsc_info_dir)

# Show both shapes as a quick sanity check of what was read.
(gdsc_data_df.shape, gdsc_info_df.shape)

((673, 978), (673, 5))

In [4]:
# TCGA inputs come in two splits (unlabeled / labeled); each split has an
# expression matrix and an info table.
tcga_unlabeled_data_dir = f'{data_dir}TCGA/TCGA_unlabeled_gex.csv'
tcga_unlabeled_info_dir = f'{data_dir}TCGA/TCGA_unlabeled_info.csv'
tcga_labeled_data_dir = f'{data_dir}TCGA/TCGA_labeled_gex.csv'
tcga_labeled_info_dir = f'{data_dir}TCGA/TCGA_labeled_info.csv'

# Expression matrices use their first CSV column as the row index; the info
# tables keep pandas' default integer index.
tcga_unlabeled_data_df = pd.read_csv(tcga_unlabeled_data_dir, index_col=0)
tcga_unlabeled_info_df = pd.read_csv(tcga_unlabeled_info_dir)
tcga_labeled_data_df = pd.read_csv(tcga_labeled_data_dir, index_col=0)
tcga_labeled_info_df = pd.read_csv(tcga_labeled_info_dir)

# Show all four shapes as a quick sanity check of what was read.
(
    tcga_unlabeled_data_df.shape,
    tcga_unlabeled_info_df.shape,
    tcga_labeled_data_df.shape,
    tcga_labeled_info_df.shape,
)

((8042, 978), (8042, 5), (358, 978), (358, 5))

### Load alignment model

In [5]:
# Size of the latent space; passed as `n_latent` to both networks below.
dim_latent = 128

# Prefer the second GPU (the original setting), but degrade gracefully so the
# notebook still runs on single-GPU and CPU-only machines instead of failing
# at the first `.to(device)` / `map_location=device` call.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [6]:
# Wrap each expression matrix in an AlignerDataset together with a domain tag
# ('gdsc' or 'tcga') and the `tissue_label` column of its info table.
gdsc_dataset = AlignerDataset(
    gdsc_data_df,
    'gdsc',
    gdsc_info_df['tissue_label'],
)
tcga_unlabeled_dataset = AlignerDataset(
    tcga_unlabeled_data_df,
    'tcga',
    tcga_unlabeled_info_df['tissue_label'],
)
tcga_labeled_dataset = AlignerDataset(
    tcga_labeled_data_df,
    'tcga',
    tcga_labeled_info_df['tissue_label'],
)

In [7]:
source = 'GDSC'
target = 'TCGA'

best_model_name = f'../src/ckpts/THERAPI_aligner_{source}_{target}.pt'

# Read the checkpoint once and reuse it for both networks (it was previously
# loaded from disk twice).
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
checkpoint = torch.load(best_model_name, map_location=device)

# GDSC autoencoder: one output class per distinct tissue label.
best_gdsc_AE = GDSC_AE(
    n_genes=gdsc_dataset.n_genes,
    n_classes=len(gdsc_info_df['tissue_label'].unique()),
    n_latent=dim_latent,
)
best_gdsc_AE.load_state_dict(checkpoint['gdsc_AE'])
best_gdsc_AE.to(device)

# TCGA weight encoder. `n_celines` was the literal 673, which is the number of
# rows in the GDSC expression matrix; derive it from the data so the two cannot
# drift apart. (presumably the number of GDSC cell lines - confirm in model.py)
best_tcga_weightencoder = TCGA_weightencoder(
    n_genes=tcga_unlabeled_dataset.n_genes,
    n_latent=dim_latent,
    n_celines=len(gdsc_data_df),
)
best_tcga_weightencoder.load_state_dict(checkpoint['tcga_weightencoder'])
# Last expression on purpose: displays the architecture as the cell output.
best_tcga_weightencoder.to(device)

TCGA_weightencoder(
  (Q): Sequential(
    (0): Linear(in_features=978, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=128, bias=True)
  )
  (K): Linear(in_features=128, out_features=128, bias=False)
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=978, bias=True)
  )
)

### Save alignment results

In [8]:
# Destination folder for the CSV files exported by this notebook.
alignment_result_folder = '../data/Alignment_results/'

# exist_ok=True keeps this cell re-runnable when the folder already exists.
os.makedirs(name=alignment_result_folder, exist_ok=True)

In [9]:
# Put both networks into evaluation mode (nn.Module.eval) before running the
# forward passes below. The second call is the cell's last expression, so the
# module it returns is displayed as the cell output.
best_gdsc_AE.eval()
best_tcga_weightencoder.eval()

TCGA_weightencoder(
  (Q): Sequential(
    (0): Linear(in_features=978, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=128, bias=True)
  )
  (K): Linear(in_features=128, out_features=128, bias=False)
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=978, bias=True)
  )
)

In [10]:
# Move the GDSC data held by the dataset to the compute device; it is reused
# as an input to the TCGA weight encoder later on.
gdsc_dataset_input = gdsc_dataset.data.to(device)

# Inference only: no_grad avoids building an autograd graph that would never
# be used and would only hold on to extra device memory.
with torch.no_grad():
    # Only the first output (the latent codes) is kept.
    gdsc_latent, _ = best_gdsc_AE(gdsc_dataset_input)

# One row per GDSC sample, labelled with the expression matrix's index.
gdsc_latent_df = pd.DataFrame(
    gdsc_latent.detach().cpu().numpy(),
    index=gdsc_data_df.index,
)
gdsc_latent_df.shape

(673, 128)

In [11]:
# Persist the GDSC latent codes (row index included) in the results folder.
gdsc_latent_path = f'{alignment_result_folder}GDSC_alignment_latent.csv'
gdsc_latent_df.to_csv(gdsc_latent_path)

In [12]:
# Inference only: no_grad avoids building an autograd graph for this large
# batch, which would otherwise be kept alive by the output tensors.
with torch.no_grad():
    # The encoder returns four values; the last one is not needed here.
    tcga_unlabeled_weights, tcga_unlabeled_latent, tcga_unlabeled_wgex, _ = best_tcga_weightencoder(
        tcga_unlabeled_dataset.data.to(device),
        gdsc_latent,
        gdsc_dataset_input,
    )

# Latent codes: one row per unlabeled TCGA sample.
tcga_unlabeled_latent_df = pd.DataFrame(
    tcga_unlabeled_latent.detach().cpu().numpy(),
    index=tcga_unlabeled_data_df.index,
)

# Weights: rows are TCGA samples, columns are labelled with the GDSC row index.
tcga_unlabeled_weights_df = pd.DataFrame(
    tcga_unlabeled_weights.detach().cpu().numpy(),
    index=tcga_unlabeled_data_df.index,
    columns=gdsc_data_df.index,
)

# "wgex" output: rows are TCGA samples, columns are labelled with the GDSC
# gene columns. (presumably a weighted gene-expression profile - confirm)
tcga_unlabeled_wgex_df = pd.DataFrame(
    tcga_unlabeled_wgex.detach().cpu().numpy(),
    index=tcga_unlabeled_data_df.index,
    columns=gdsc_data_df.columns,
)

tcga_unlabeled_latent_df.shape, tcga_unlabeled_weights_df.shape, tcga_unlabeled_wgex_df.shape

((8042, 128), (8042, 673), (8042, 978))

In [13]:
# Export the three unlabeled-TCGA result tables, in the same order as before.
for unlabeled_frame, unlabeled_file in (
    (tcga_unlabeled_latent_df, 'TCGA_unlabeled_alignment_latent.csv'),
    (tcga_unlabeled_weights_df, 'TCGA_unlabeled_alignment_weights.csv'),
    (tcga_unlabeled_wgex_df, 'TCGA_unlabeled_alignment_wgex.csv'),
):
    unlabeled_frame.to_csv(alignment_result_folder + unlabeled_file)

In [14]:
# Inference only: no_grad avoids building an autograd graph that would never
# be used and would only hold on to extra device memory.
with torch.no_grad():
    # The encoder returns four values; the last one is not needed here.
    tcga_labeled_weights, tcga_labeled_latent, tcga_labeled_wgex, _ = best_tcga_weightencoder(
        tcga_labeled_dataset.data.to(device),
        gdsc_latent,
        gdsc_dataset_input,
    )

# Latent codes: one row per labeled TCGA sample.
tcga_labeled_latent_df = pd.DataFrame(
    tcga_labeled_latent.detach().cpu().numpy(),
    index=tcga_labeled_data_df.index,
)

# Weights: rows are TCGA samples, columns are labelled with the GDSC row index.
tcga_labeled_weights_df = pd.DataFrame(
    tcga_labeled_weights.detach().cpu().numpy(),
    index=tcga_labeled_data_df.index,
    columns=gdsc_data_df.index,
)

# "wgex" output: rows are TCGA samples, columns are labelled with the GDSC
# gene columns. (presumably a weighted gene-expression profile - confirm)
tcga_labeled_wgex_df = pd.DataFrame(
    tcga_labeled_wgex.detach().cpu().numpy(),
    index=tcga_labeled_data_df.index,
    columns=gdsc_data_df.columns,
)

tcga_labeled_latent_df.shape, tcga_labeled_weights_df.shape, tcga_labeled_wgex_df.shape

((358, 128), (358, 673), (358, 978))

In [15]:
# Export the three labeled-TCGA result tables, in the same order as before.
for labeled_frame, labeled_file in (
    (tcga_labeled_latent_df, 'TCGA_labeled_alignment_latent.csv'),
    (tcga_labeled_weights_df, 'TCGA_labeled_alignment_weights.csv'),
    (tcga_labeled_wgex_df, 'TCGA_labeled_alignment_wgex.csv'),
):
    labeled_frame.to_csv(alignment_result_folder + labeled_file)