In [1]:
import os
import sys

import pandas as pd

import torch
import torch.nn as nn

sys.path.append('../src')
from model import *

  from .autonotebook import tqdm as notebook_tqdm


### Load data

In [2]:
# Root folder for every input CSV; the trailing slash matters because
# later cells build file paths by plain string concatenation.
data_dir = "../data/"

In [3]:
# Load the GDSC expression matrix and its per-sample annotation table.
# NOTE: despite the `_dir` suffix, these two variables hold CSV file paths.
gdsc_data_dir = f'{data_dir}GDSC/GDSC_gex.csv'
gdsc_info_dir = f'{data_dir}GDSC/GDSC_info.csv'

# The first CSV column of the expression file becomes the row index.
gdsc_data_df = pd.read_csv(gdsc_data_dir, index_col=0)
gdsc_info_df = pd.read_csv(gdsc_info_dir)

# Display both shapes so the row counts can be compared by eye.
(gdsc_data_df.shape, gdsc_info_df.shape)

((673, 978), (673, 5))

In [4]:
# Load the external cohort: an unlabeled and a labeled subset, each with an
# expression matrix and an annotation table.
# NOTE: despite the `_dir` suffix, these variables hold CSV file paths.
external_unlabeled_data_dir = f'{data_dir}External/External_unlabeled_gex.csv'
external_unlabeled_info_dir = f'{data_dir}External/External_unlabeled_info.csv'
external_labeled_data_dir = f'{data_dir}External/External_labeled_gex.csv'
external_labeled_info_dir = f'{data_dir}External/External_labeled_info.csv'

# Expression files use their first CSV column as the row index; the
# annotation files are read with the default integer index.
external_unlabeled_data_df = pd.read_csv(external_unlabeled_data_dir, index_col=0)
external_unlabeled_info_df = pd.read_csv(external_unlabeled_info_dir)
external_labeled_data_df = pd.read_csv(external_labeled_data_dir, index_col=0)
external_labeled_info_df = pd.read_csv(external_labeled_info_dir)

# Display all four shapes for a quick consistency check.
(
    external_unlabeled_data_df.shape,
    external_unlabeled_info_df.shape,
    external_labeled_data_df.shape,
    external_labeled_info_df.shape,
)

((807, 978), (807, 3), (179, 978), (179, 3))

### Tissue-specific data generation

In [5]:
# Value of the `tissue_type` column used to subset the GDSC samples below.
tissue = "breast"

In [6]:
# Restrict GDSC to the selected tissue.
# Annotation rows are filtered on `tissue_type` and re-indexed from zero;
# the expression rows are then pulled by the matching `ID` labels, so both
# frames share the same row order.
is_selected_tissue = gdsc_info_df['tissue_type'] == tissue
tissue_gdsc_info_df = gdsc_info_df.loc[is_selected_tissue].reset_index(drop=True)

tissue_sample_ids = tissue_gdsc_info_df['ID']
tissue_gdsc_data_df = gdsc_data_df.loc[tissue_sample_ids]

(tissue_gdsc_data_df.shape, tissue_gdsc_info_df.shape)

((45, 978), (45, 5))

In [7]:
# Zero-gene matching: every gene whose column sums to exactly zero in the
# unlabeled external cohort is forced to zero in the tissue-specific GDSC
# matrix as well.

# Select the column labels directly instead of transposing the whole matrix.
zero_gene_mask = external_unlabeled_data_df.sum() == 0
zero_genes = external_unlabeled_data_df.columns[zero_gene_mask].to_list()

# `assign` returns a new frame, so the `.loc` selection made in the previous
# cell is never mutated in place (no SettingWithCopy warnings) and re-running
# this cell is idempotent. As with per-column assignment, a gene missing from
# the GDSC columns is appended as a new all-zero column.
tissue_gdsc_data_df = tissue_gdsc_data_df.assign(**dict.fromkeys(zero_genes, 0))

### Load alignment model

In [8]:
# Latent width handed to both model constructors as `n_latent`.
dim_latent = 128

# Keep the originally configured second GPU when it exists, but degrade
# gracefully so the notebook still runs on single-GPU and CPU-only machines
# instead of failing on an invalid device ordinal.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'

In [9]:
# Wrap each expression matrix in an AlignerDataset together with a domain
# tag ('gdsc' / 'external') and the matching `tissue_label` column.
gdsc_dataset = AlignerDataset(
    tissue_gdsc_data_df,
    'gdsc',
    tissue_gdsc_info_df['tissue_label'],
)
external_unlabeled_dataset = AlignerDataset(
    external_unlabeled_data_df,
    'external',
    external_unlabeled_info_df['tissue_label'],
)
external_labeled_dataset = AlignerDataset(
    external_labeled_data_df,
    'external',
    external_labeled_info_df['tissue_label'],
)

In [10]:
source = 'GDSC'
target = 'External'

best_model_name = f'../src/ckpts/THERAPI_aligner_{source}_{target}.pt'

# Deserialise the checkpoint once and reuse it for both sub-models; it was
# previously read from disk twice.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
aligner_checkpoint = torch.load(best_model_name, map_location=device)

# Source-side autoencoder. The class count is taken from the full GDSC
# annotation table, not from the tissue subset.
best_gdsc_AE = SOURCE_AE(
    n_genes=gdsc_dataset.n_genes,
    n_classes=len(gdsc_info_df['tissue_label'].unique()),
    n_latent=dim_latent,
)
best_gdsc_AE.load_state_dict(aligner_checkpoint['source_AE'])
best_gdsc_AE.to(device)

# Target-side weight encoder.
# NOTE(review): 673 equals the row count of the full GDSC matrix loaded
# above; presumably the number of cell lines used in training - confirm
# before changing.
best_external_weightencoder = TARGET_weightencoder(
    n_genes=external_unlabeled_dataset.n_genes,
    n_latent=dim_latent,
    n_celines=673,
)
best_external_weightencoder.load_state_dict(aligner_checkpoint['target_weightencoder'])
best_external_weightencoder.to(device)

TARGET_weightencoder(
  (Q): Sequential(
    (0): Linear(in_features=978, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=128, bias=True)
  )
  (K): Linear(in_features=128, out_features=128, bias=False)
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=978, bias=True)
  )
)

### Save alignment results

In [11]:
# Output folder for every CSV written below; the trailing slash is required
# because file names are appended by string concatenation.
alignment_result_folder = "../data/Alignment_results_External/"

# Create the folder (and any missing parents); a pre-existing folder is fine.
os.makedirs(name=alignment_result_folder, exist_ok=True)

In [12]:
# Put both models into evaluation mode before running inference.
best_gdsc_AE.eval()
# This call is the cell's last expression, so its return value is what the
# notebook renders as the cell output.
best_external_weightencoder.eval()

TARGET_weightencoder(
  (Q): Sequential(
    (0): Linear(in_features=978, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=128, bias=True)
  )
  (K): Linear(in_features=128, out_features=128, bias=False)
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=978, bias=True)
  )
)

In [13]:
# Encode the tissue-specific GDSC samples with the source autoencoder.
gdsc_dataset_input = gdsc_dataset.data.to(device)

# Inference only: disable autograd so no computation graph is built and
# kept alive for tensors that are never back-propagated through.
with torch.no_grad():
    # The second return value is not needed here.
    gdsc_latent, _ = best_gdsc_AE(gdsc_dataset_input)

# Rows are labelled with the GDSC sample index so the CSV stays joinable.
gdsc_latent_df = pd.DataFrame(
    gdsc_latent.detach().cpu().numpy(),
    index=tissue_gdsc_data_df.index,
)
gdsc_latent_df.shape

(45, 128)

In [14]:
# Persist the GDSC latent embedding (row index is written as the first column).
gdsc_latent_path = f'{alignment_result_folder}GDSC_alignment_latent.csv'
gdsc_latent_df.to_csv(gdsc_latent_path)

In [15]:
# Run the unlabeled external samples through the target weight encoder,
# using the GDSC latent codes and GDSC inputs as the reference set.
# Inference only: autograd is disabled so no graph is built or retained.
with torch.no_grad():
    # The fourth return value is not needed here.
    external_unlabeled_weights, external_unlabeled_latent, external_unlabeled_wgex, _ = best_external_weightencoder(
        external_unlabeled_dataset.data.to(device),
        gdsc_latent,
        gdsc_dataset_input,
    )

# All three outputs are indexed by external sample.
external_unlabeled_latent_df = pd.DataFrame(
    external_unlabeled_latent.detach().cpu().numpy(),
    index=external_unlabeled_data_df.index,
)

# One weight column per tissue-specific GDSC sample.
external_unlabeled_weights_df = pd.DataFrame(
    external_unlabeled_weights.detach().cpu().numpy(),
    index=external_unlabeled_data_df.index,
    columns=tissue_gdsc_data_df.index,
)

# Third model output, labelled with the GDSC gene columns.
external_unlabeled_wgex_df = pd.DataFrame(
    external_unlabeled_wgex.detach().cpu().numpy(),
    index=external_unlabeled_data_df.index,
    columns=gdsc_data_df.columns,
)

external_unlabeled_latent_df.shape, external_unlabeled_weights_df.shape, external_unlabeled_wgex_df.shape

((807, 128), (807, 45), (807, 978))

In [16]:
# Write the three unlabeled-cohort result tables; the file name differs only
# in the suffix identifying the table.
for suffix, frame in (
    ('latent', external_unlabeled_latent_df),
    ('weights', external_unlabeled_weights_df),
    ('wgex', external_unlabeled_wgex_df),
):
    frame.to_csv(f'{alignment_result_folder}External_unlabeled_alignment_{suffix}.csv')

In [17]:
# Run the labeled external samples through the target weight encoder,
# using the GDSC latent codes and GDSC inputs as the reference set.
# Inference only: autograd is disabled so no graph is built or retained.
with torch.no_grad():
    # The fourth return value is not needed here.
    external_labeled_weights, external_labeled_latent, external_labeled_wgex, _ = best_external_weightencoder(
        external_labeled_dataset.data.to(device),
        gdsc_latent,
        gdsc_dataset_input,
    )

# All three outputs are indexed by external sample.
external_labeled_latent_df = pd.DataFrame(
    external_labeled_latent.detach().cpu().numpy(),
    index=external_labeled_data_df.index,
)

# One weight column per tissue-specific GDSC sample.
external_labeled_weights_df = pd.DataFrame(
    external_labeled_weights.detach().cpu().numpy(),
    index=external_labeled_data_df.index,
    columns=tissue_gdsc_data_df.index,
)

# Third model output, labelled with the GDSC gene columns.
external_labeled_wgex_df = pd.DataFrame(
    external_labeled_wgex.detach().cpu().numpy(),
    index=external_labeled_data_df.index,
    columns=gdsc_data_df.columns,
)

external_labeled_latent_df.shape, external_labeled_weights_df.shape, external_labeled_wgex_df.shape

((179, 128), (179, 45), (179, 978))

In [18]:
# Write the three labeled-cohort result tables, keyed by file-name suffix.
labeled_outputs = {
    'latent': external_labeled_latent_df,
    'weights': external_labeled_weights_df,
    'wgex': external_labeled_wgex_df,
}
for table_kind, table in labeled_outputs.items():
    table.to_csv(f'{alignment_result_folder}External_labeled_alignment_{table_kind}.csv')