In [1]:
import os
from tqdm import tqdm
from itertools import chain
import pandas as pd
import numpy as np
import random, torch
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style = 'white', font_scale=1.5)

import sys
# Make the local COMPASS checkout importable. Set the COMPASS_HOME environment
# variable to point at a different checkout instead of editing this hard-coded
# user-specific path.
sys.path.insert(0, os.environ.get('COMPASS_HOME', '/home/was966/Research/mims-compass/'))
from compass import PreTrainer, FineTuner, loadcompass
from compass.utils import plot_embed_with_label, score
from compass.tokenizer import CANCER_CODE


def onehot(S):
    """One-hot encode a categorical Series while keeping missing values missing.

    Each distinct non-missing value of ``S`` becomes a float column holding 1.0
    where the row has that value and 0.0 otherwise. Rows where ``S`` is missing
    are NaN in every column (not all-zero), so "unknown" stays distinguishable
    from "none of these categories".

    Parameters
    ----------
    S : pd.Series
        Values to encode; missing entries (None/NaN) are allowed.

    Returns
    -------
    pd.DataFrame
        Float frame indexed like ``S`` with one column per category, ordered by
        descending category count.

    Raises
    ------
    AssertionError
        If ``S`` is not a ``pd.Series``.
    """
    assert isinstance(S, pd.Series), 'Input type should be pd.Series'
    # dummy_na=True always appends a NaN-labelled indicator column. Cast to float
    # up front: get_dummies yields uint8 or bool depending on the pandas version,
    # and neither can hold NaN without an upcast.
    dfd = pd.get_dummies(S, dummy_na=True).astype(float)
    # Positional boolean mask rather than index labels: with a duplicated index,
    # label-based .loc would also blank non-missing rows sharing that label.
    is_missing = dfd[np.nan].to_numpy() == 1.0
    dfd.loc[is_missing] = np.nan
    dfd = dfd.drop(columns=[np.nan])
    # Most frequent categories first; sum() skips the NaN rows.
    cols = dfd.sum().sort_values(ascending=False).index.tolist()
    return dfd[cols]

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the ITRP patient/label table and the TPM table from the shared data dir.
# NOTE: read_pickle can execute arbitrary code -- only load trusted files.
data_path = '../00_data/'
df_label, df_tpm = (
    pd.read_pickle(os.path.join(data_path, table_name))
    for table_name in ('ITRP.PATIENT.TABLE', 'ITRP.TPM.TABLE')
)
df_tpm.shape, df_label.shape

((1133, 15672), (1133, 113))

In [3]:
# Restore the pretrained COMPASS checkout onto the CPU; it is used below purely
# as a feature extractor.
pretrainer = loadcompass('../checkpoint/latest/pretrainer.pt', map_location='cpu')

# Handles on the two heads of the model's latent projector.
genesetprojector, cellpathwayprojector = (
    pretrainer.model.latentprojector.genesetprojector,
    pretrainer.model.latentprojector.cellpathwayprojector,
)

# Select/reorder the TPM columns to the model's input features.
# NOTE(review): assumes feature_name is a list of column labels; pandas raises
# KeyError if any of them is absent from df_tpm.
df_tpm = df_tpm[pretrainer.feature_name]
pretrainer.count_parameters()

1019421

In [4]:
# Model input: a 'cancer_code' column (cancer_type mapped through CANCER_CODE)
# followed by the TPM columns, joined on the shared row index.
# NOTE(review): cancer types absent from CANCER_CODE silently map to NaN.
dfcx = (
    df_label['cancer_type']
    .map(CANCER_CODE)
    .to_frame(name='cancer_code')
    .join(df_tpm)
)
dfgeneset, dfcelltype = pretrainer.extract(dfcx, batch_size=128)

100%|#####################################################################################| 9/9 [00:49<00:00,  5.46s/it]


In [5]:
# Persist the clinical labels and the two readout tables. DataFrame.to_csv does
# not create parent directories, so make sure ./ITRP exists on a fresh checkout
# (this is the first cell that writes there).
os.makedirs('./ITRP', exist_ok=True)
df_label.to_csv('./ITRP/00_clinical_label.csv')
dfgeneset.to_csv('./ITRP/01_readouts_geneset.csv')
dfcelltype.to_csv('./ITRP/02_readouts_celltype.csv')

In [6]:
# Per-feature projections for the gene-set and cell-type branches. Their index
# labels have the form '<Index>$$<feature_name>', which the next cells split.
dfg, dfc = pretrainer.project(
    dfcx,
    batch_size=128,
)

100%|#####################################################################################| 9/9 [00:56<00:00,  6.23s/it]


In [7]:
pid = dfg.index.map(lambda x:x.split('$$')[0])
fid = dfg.index.map(lambda x:x.split('$$')[1])
df = pd.DataFrame(index=dfg.index)
df['Index'] = pid
df['feature_name'] = fid
df = df.join(dfg)
df = df.sort_values(['feature_name', 'Index'])
df.to_csv('./ITRP/03_features_geneset.csv.gzip', compression = 'gzip')

In [8]:
pid = dfc.index.map(lambda x:x.split('$$')[0])
fid = dfc.index.map(lambda x:x.split('$$')[1])
df = pd.DataFrame(index=dfc.index)
df['Index'] = pid
df['feature_name'] = fid
df = df.join(dfc)
df = df.sort_values(['feature_name', 'Index'])
df.to_csv('./ITRP/04_features_celltype.gzip', compression = 'gzip')