This notebook takes as input the set of mutations present in each dataset and uses the common vocabulary of mutations across datasets to process and generate data usable by the PREDICT-AI model.

In [1]:
import pandas as pd
import numpy as np

In [2]:
import pickle
import json
from collections import defaultdict

In [3]:
from torch.utils.data import Dataset,DataLoader,TensorDataset
import torch

In [4]:
from typing import Optional,Union
from transformers import BatchEncoding,TensorType

In [5]:
# update file path later
with open(f"/data/ajayago/papers_data/systematic_assessment/raw/mutation_files/annotation_vocab_ccle_cbio_icgc_moores_tcga_genie_nci60_nuh_union.pickle", "rb") as f:
    vocab_annovar = pickle.load(f)

In [6]:
len(vocab_annovar)

2324534

In [7]:
vocab_annovar_tokenized = {j:i for i, j in enumerate(sorted(vocab_annovar.keys()))} # id mapping for each mutation across survival and other datasets

In [8]:
len(vocab_annovar_tokenized) # Eg: {'A1BG@A191T': 0,'A1BG@A268V': 1,...}

2324534

In [9]:
vocab_df = pd.DataFrame(vocab_annovar).transpose().sort_index()
vocab_df.columns = ['sift_pred', 'sift4g_pred', 'lrt_pred', 'mutationtaster_pred',
       'mutationassessor_pred', 'fathmm_pred', 'provean_pred', 'metasvm_pred',
       'm_cap_pred', 'primateai_pred', 'deogen2_pred', 'bayesdel_addaf_pred',
       'bayesdel_noaf_pred', 'clinpred_pred', 'list_s2_pred',
       'fathmm_mkl_coding_pred', 'fathmm_xf_coding_pred', 'clinvar_Pathogenic',
       'clinvar_Benign', 'clinvar_Unknown', 'gpd_LU', 'gpd_NCU', 'gpd_PIU']
vocab_df
# vocab_df.to_csv("/data/ajayago/papers_data/systematic_assessment/processed/vocab_predict_ai_ccle_cbio_icgc_moores_tcga_genie_nci60_nuh_union.csv")

Unnamed: 0,sift_pred,sift4g_pred,lrt_pred,mutationtaster_pred,mutationassessor_pred,fathmm_pred,provean_pred,metasvm_pred,m_cap_pred,primateai_pred,...,clinpred_pred,list_s2_pred,fathmm_mkl_coding_pred,fathmm_xf_coding_pred,clinvar_Pathogenic,clinvar_Benign,clinvar_Unknown,gpd_LU,gpd_NCU,gpd_PIU
A1BG@A191T,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,1,1,0,0
A1BG@A268V,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,1,0,0,1
A1BG@A295T,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,1,0,0,1
A1BG@A332E,1,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,1,1,0,0
A1BG@A353V,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,1,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZZEF1@Y255H,0,0,0,0,0,0,0,0,0,0,...,0,0,1,0,0,0,1,0,0,1
ZZEF1@Y2618S,1,1,1,1,1,0,1,0,1,0,...,1,1,1,1,0,0,1,1,0,0
ZZEF1@Y399H,1,1,1,1,1,0,0,0,0,1,...,1,1,1,1,0,0,1,1,0,0
ZZEF1@Y702D,1,1,1,1,1,0,1,0,1,0,...,1,0,1,1,0,0,1,0,0,1


In [10]:
vocab_df

Unnamed: 0,sift_pred,sift4g_pred,lrt_pred,mutationtaster_pred,mutationassessor_pred,fathmm_pred,provean_pred,metasvm_pred,m_cap_pred,primateai_pred,...,clinpred_pred,list_s2_pred,fathmm_mkl_coding_pred,fathmm_xf_coding_pred,clinvar_Pathogenic,clinvar_Benign,clinvar_Unknown,gpd_LU,gpd_NCU,gpd_PIU
A1BG@A191T,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,1,1,0,0
A1BG@A268V,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,1,0,0,1
A1BG@A295T,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,1,0,0,1
A1BG@A332E,1,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,1,1,0,0
A1BG@A353V,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,1,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZZEF1@Y255H,0,0,0,0,0,0,0,0,0,0,...,0,0,1,0,0,0,1,0,0,1
ZZEF1@Y2618S,1,1,1,1,1,0,1,0,1,0,...,1,1,1,1,0,0,1,1,0,0
ZZEF1@Y399H,1,1,1,1,1,0,0,0,0,1,...,1,1,1,1,0,0,1,1,0,0
ZZEF1@Y702D,1,1,1,1,1,0,1,0,1,0,...,1,0,1,1,0,0,1,0,0,1


In [11]:
druid_data_dir = "/data/druid_data/"

In [12]:
raw_data_dir = "/data/ajayago/papers_data/systematic_assessment/raw/mutation_files/"
data_splits_dir = "/data/ajayago/papers_data/systematic_assessment/processed/"
expt1A_dir = data_splits_dir + "Experiment1/SettingA/"
expt1B_dir = data_splits_dir + "Experiment1/SettingB/"
expt2A_dir = data_splits_dir + "Experiment2/SettingA/"
expt2B_dir = data_splits_dir + "Experiment2/SettingB/"

In [13]:
genes324 = list(pd.read_csv("/data/druid_data/raw_data/gene2ind.txt", header=None)[0])
len(genes324)

324

In [14]:
def get_alias_to_canonical_name_map():
    with open("/data/druid_data/raw_data/gene_aliases.json", "r") as fp:
        aliases_on_disk = json.load(fp)

    alias_to_canonical_name_map = {}
    for canonical_name, aliases in aliases_on_disk.items():
        # Some canonical names have only one alias - convert those as list for consistency
        if type(aliases) != list:
            aliases = [aliases]

        for alias in aliases:

            # If an alias is one of the canonical names in GENES_324, do not add it to the map
            # Else, we'd be renaming a canonical named column into something else
            if alias in genes324:
                print(f"Alias {alias} is a canonical_name, skipping")
                continue

            if alias in alias_to_canonical_name_map:
                print(
                    f"Found multiple canonical names for alias - {alias} = {[canonical_name, alias_to_canonical_name_map[alias]]}"
                )
                # Drop aliases with conflicting canonical names per recommendation from clinicians
                alias_to_canonical_name_map.pop(alias)

            # Convert all aliases to be upper case for consistency
            alias_to_canonical_name_map[alias.upper()] = canonical_name.upper()

    return alias_to_canonical_name_map

In [15]:
alias2canonicalmap = get_alias_to_canonical_name_map()

Alias RAD54L is a canonical_name, skipping
Found multiple canonical names for alias - CDK4I = ['CDKN2B', 'CDKN2A']
Found multiple canonical names for alias - IDH = ['IDH2', 'IDH1']
Found multiple canonical names for alias - IDP = ['IDH2', 'IDH1']
Found multiple canonical names for alias - HDMX = ['MDM4', 'MDM2']
Found multiple canonical names for alias - HNPCC = ['MSH2', 'MLH1']
Found multiple canonical names for alias - MRP1 = ['MSH3', 'MDM4']
Alias KRAS is a canonical_name, skipping
Found multiple canonical names for alias - ADPRTL2 = ['PARP3', 'PARP2']
Found multiple canonical names for alias - ADPRTL3 = ['PARP3', 'PARP2']
Found multiple canonical names for alias - MCAP = ['PIK3CA', 'BRD4']
Found multiple canonical names for alias - PI3K = ['PIK3CB', 'PIK3CA']
Found multiple canonical names for alias - R51H3 = ['RAD51D', 'RAD51C']
Found multiple canonical names for alias - PTC = ['RET', 'PTCH1']
Found multiple canonical names for alias - SDH1 = ['SDHB', 'SDHA']
Found multiple canoni

In [16]:
len(set(alias2canonicalmap.values()))

311

In [17]:
def convert2canonical(gene):
    if gene in genes324: # already a canonical name
        return gene
    if gene in alias2canonicalmap.keys(): # not canonical name => convert to canonical name
        return alias2canonicalmap[gene]
    return np.NaN # not in 324 genes
        

### Cell Lines

In [18]:
with open(expt1A_dir + "cell_lines_fold0.pkl", "rb") as f:
    exp1A_cl_fold0 = pickle.load(f)
    
with open(expt1B_dir + "cell_lines_fold0.pkl", "rb") as f:
    exp1B_cl_fold0 = pickle.load(f)

with open(expt2A_dir + "cell_lines_fold0.pkl", "rb") as f:
    exp2A_cl_fold0 = pickle.load(f)

with open(expt2B_dir + "cell_lines_fold0.pkl", "rb") as f:
    exp2B_cl_fold0 = pickle.load(f)

In [19]:
cell_line_data = druid_data_dir + "CCLE_23Q4" # to load raw mutations from

In [20]:
cl_mutations = pd.read_csv(cell_line_data + "/patient_gene_alteration(mutation).csv")
cl_mutations

Unnamed: 0,depmap_id,gene,alteration
0,PR-sxFiuq,SAMD11,L76V
1,PR-DNEoiz,SAMD11,P107S
2,PR-2ei6MD,SAMD11,E160K
3,PR-CYz5sB,SAMD11,A218V
4,PR-xcsbEI,SAMD11,N285S
...,...,...,...
885436,PR-MX9ndc,KDM5D,R68H
885437,PR-AiAKPa,EIF1AY,D83Y
885438,PR-MX9ndc,RPS4Y2,T115A
885439,PR-Bs4EcD,RPS4Y2,P152S


In [21]:
cl_mutations["canonical_gene_name"] = cl_mutations["gene"].apply(lambda x: convert2canonical(x))

In [22]:
len(cl_mutations[cl_mutations.canonical_gene_name.astype(str) != "nan"].depmap_id.unique())

2313

In [23]:
cl_mutations_reduced = cl_mutations[cl_mutations.canonical_gene_name.astype(str) != "nan"].reset_index(drop=True)
cl_mutations_reduced

Unnamed: 0,depmap_id,gene,alteration,canonical_gene_name
0,PR-rsoNmY,TNFRSF14,G68R,TNFRSF14
1,PR-yDgpga,TNFRSF14,G68R,TNFRSF14
2,PR-sgXEkc,TNFRSF14,G89C,TNFRSF14
3,PR-kRqGcx,TNFRSF14,A140S,TNFRSF14
4,PR-ZhEuUF,TNFRSF14,R149M,TNFRSF14
...,...,...,...,...
29444,PR-6Ybf3z,BCORL1,P1755QfsTer20,BCORL1
29445,PR-81oclJ,BCORL1,P1755QfsTer20,BCORL1
29446,PR-Qvs2q6,BCORL1,P1755QfsTer20,BCORL1
29447,PR-EaZDJD,BCORL1,E1767K,BCORL1


### Patients

In [24]:
# Fold 0, 1, 2
for i in range(0, 3):
    with open(expt1A_dir + f"patients_fold{i}.pkl", "rb") as f:
        exec(f"exp1A_patient_fold{i} = pickle.load(f)")
        
    with open(expt1B_dir + f"patients_fold{i}.pkl", "rb") as f:
        exec(f"exp1B_patient_fold{i} = pickle.load(f)")
    
    with open(expt2A_dir + f"patients_fold{i}.pkl", "rb") as f:
        exec(f"exp2A_patient_fold{i} = pickle.load(f)")
    
    with open(expt2B_dir + f"patients_fold{i}.pkl", "rb") as f:
        exec(f"exp2B_patient_fold{i} = pickle.load(f)")

In [25]:
tcga_data = druid_data_dir + "Tcga" # to load raw mutations from
moores_data = druid_data_dir + "Moores"
cbio_hcc_mskimpact_2018_data = druid_data_dir + "CBIO/hcc_mskimpact_2018"
cbio_brca_mskcc_2019_data = druid_data_dir + "CBIO/brca_mskcc_2019"
genie_crc_data = druid_data_dir + "GenieCRC"
genie_nsclc_data = druid_data_dir + "NSCLC"

In [26]:
tcga_mutations = pd.read_csv(tcga_data + "/patient_gene_alteration(mutation).csv")
tcga_mutations

Unnamed: 0,patient_id,gene,alteration
0,TCGA-50-5931,CAMTA1,V870E
1,TCGA-50-5931,CATSPER4,P365=
2,TCGA-50-5931,KDF1,I243T
3,TCGA-50-5931,CSMD2,T417S
4,TCGA-50-5931,SFPQ,G647C
...,...,...,...
3093849,TCGA-YD-A9TA,CNGA2,G303G
3093850,TCGA-YD-A9TA,MAGEA12,R243R
3093851,TCGA-YD-A9TA,ZNF275,L224L
3093852,TCGA-YD-A9TA,L1CAM,P279P


In [27]:
moores_mutations = pd.read_csv(moores_data + "/patient_gene_alteration(mutation).csv")
moores_mutations

Unnamed: 0,patient_id,gene,alteration
0,1,PTEN,splice site 493-1 G>A
1,2,TP53,P151A
2,3,ESR1,Y537S
3,4,PTEN,I67K
4,4,CTNNB1,T257I
...,...,...,...
220,84,GATA3,G335fs*18
221,85,TP53,H168R
222,85,GATA3,N332fs*21
223,86,MLL2,A4571T


In [28]:
cbio_hcc_mutations = pd.read_csv(cbio_hcc_mskimpact_2018_data + "/patient_gene_alteration(mutation).csv")
cbio_hcc_mutations

Unnamed: 0,patient_id,gene,alteration
0,P-0005038-T02-IM6,TNFRSF14,Q242R
1,P-0005038-T02-IM6,JAK1,S729C
2,P-0005038-T02-IM6,MEN1,X224_splice
3,P-0005038-T02-IM6,ALK,E717K
4,P-0015203-T01-IM6,ZRSR2,C172S
...,...,...,...
531,P-0012182-T01-IM5,NEGR1,Q8L
532,P-0012182-T01-IM5,SETD2,S2479A
533,P-0012182-T01-IM5,POLE,V544M
534,P-0012182-T01-IM5,AXIN1,E291*


In [29]:
cbio_brca_mutations = pd.read_csv(cbio_brca_mskcc_2019_data + "/patient_gene_alteration(mutation).csv")
cbio_brca_mutations

Unnamed: 0,patient_id,gene,alteration
0,s_DS_bkm_077_T,VTCN1,S192L
1,s_DS_bkm_078_T2,NOTCH2,D1582N
2,s_DS_bkm_078_T1,NOTCH2,D1582N
3,s_DS_bkm_074_T,NOTCH2,T1303P
4,s_DS_bkm_064_T2,NOTCH2,P6Rfs*27
...,...,...,...
653,s_DS_bkm_058_T,NCOR1,A750V
654,s_DS_bkm_058_T,BCOR,N193T
655,s_DS_bkm_059_T,SF3B1,I641V
656,s_DS_bkm_059_T,ESR1,L536R


In [30]:
genie_crc_mutations = pd.read_csv(genie_crc_data + "/patient_gene_alteration(mutation).csv")
genie_crc_mutations

Unnamed: 0,patient_id,gene,alteration
0,GENIE-DFCI-002643-6598,PALB2,*16*
1,GENIE-DFCI-002643-6598,FBXW7,E287Q
2,GENIE-DFCI-002643-6598,EGFR,R252C
3,GENIE-DFCI-002643-6598,PSMD13,D175H
4,GENIE-DFCI-002643-6598,NRAS,G12D
...,...,...,...
23060,GENIE-VICC-182499-unk-1,ZNF703,D403_P404insAPRRLQLLHLQRAD
23061,GENIE-VICC-397091-unk-1,ZNF703,D403_P404insAPRRLQLLHLQRAD
23062,GENIE-VICC-669338-unk-1,ZNF703,D403_P404insAPRRLQLLHLQRAD
23063,GENIE-VICC-669338-unk-1,ZNF703,A507delinsERP


In [31]:
genie_nsclc_mutations = pd.read_csv(genie_nsclc_data + "/patient_gene_alteration(mutation).csv")
genie_nsclc_mutations

Unnamed: 0,patient_id,gene,alteration
0,GENIE-DFCI-003908-234520,RAD50,*31*
1,GENIE-DFCI-003908-234520,ARID2,Q1227R
2,GENIE-DFCI-003908-234520,FANCB,L43I
3,GENIE-DFCI-003908-234520,SETD2,E1971Kfs*35
4,GENIE-DFCI-003908-234520,POLD1,V553I
...,...,...,...
17341,GENIE-VICC-780278-unk-1,WISP3,S352Y
17342,GENIE-VICC-120723-unk-1,WT1,V73M
17343,GENIE-VICC-287735-unk-2,WT1,P132L
17344,GENIE-VICC-780278-unk-1,WT1,X447_splice


In [32]:
# with RECIST response and survival data
patients_combined = pd.concat([tcga_mutations, moores_mutations, cbio_hcc_mutations, cbio_brca_mutations, genie_crc_mutations, genie_nsclc_mutations], ignore_index=True)
patients_combined

Unnamed: 0,patient_id,gene,alteration
0,TCGA-50-5931,CAMTA1,V870E
1,TCGA-50-5931,CATSPER4,P365=
2,TCGA-50-5931,KDF1,I243T
3,TCGA-50-5931,CSMD2,T417S
4,TCGA-50-5931,SFPQ,G647C
...,...,...,...
3135679,GENIE-VICC-780278-unk-1,WISP3,S352Y
3135680,GENIE-VICC-120723-unk-1,WT1,V73M
3135681,GENIE-VICC-287735-unk-2,WT1,P132L
3135682,GENIE-VICC-780278-unk-1,WT1,X447_splice


In [33]:
patients_combined["canonical_gene_name"] = patients_combined["gene"].apply(lambda x: convert2canonical(x))

In [34]:
len(patients_combined.patient_id.unique())

13920

In [35]:
len(patients_combined[patients_combined.canonical_gene_name.astype(str) != "nan"].patient_id.unique())

12907

In [36]:
patients_combined_reduced = patients_combined[patients_combined.canonical_gene_name.astype(str)!="nan"].reset_index(drop=True)
patients_combined_reduced

Unnamed: 0,patient_id,gene,alteration,canonical_gene_name
0,TCGA-50-5931,DNMT3A,V258M,DNMT3A
1,TCGA-50-5931,MSH2,I356V,MSH2
2,TCGA-50-5931,RICTOR,L42F,RICTOR
3,TCGA-50-5931,NOTCH1,D297G,NOTCH1
4,TCGA-50-5931,MLL2,E668*,MLL2
...,...,...,...,...
130656,GENIE-VICC-780278-unk-1,VHL,E204G,VHL
130657,GENIE-VICC-120723-unk-1,WT1,V73M,WT1
130658,GENIE-VICC-287735-unk-2,WT1,P132L,WT1
130659,GENIE-VICC-780278-unk-1,WT1,X447_splice,WT1


#### Processing Survival datasets

In [37]:
# used to calculate max number of mutations per sample and max number of genes mutated
survival_patients_combined = pd.concat([genie_crc_mutations, genie_nsclc_mutations], ignore_index=True)
survival_patients_combined["canonical_gene_name"] = survival_patients_combined["gene"].apply(lambda x: convert2canonical(x))
survival_patients_combined_reduced = survival_patients_combined[survival_patients_combined.canonical_gene_name.astype(str)!="nan"].reset_index(drop=True)
survival_patients_combined_reduced

Unnamed: 0,patient_id,gene,alteration,canonical_gene_name
0,GENIE-DFCI-002643-6598,PALB2,*16*,PALB2
1,GENIE-DFCI-002643-6598,FBXW7,E287Q,FBXW7
2,GENIE-DFCI-002643-6598,EGFR,R252C,EGFR
3,GENIE-DFCI-002643-6598,NRAS,G12D,NRAS
4,GENIE-DFCI-002643-6598,CDKN2A,R80*,CDKN2A
...,...,...,...,...
28614,GENIE-VICC-780278-unk-1,VHL,E204G,VHL
28615,GENIE-VICC-120723-unk-1,WT1,V73M,WT1
28616,GENIE-VICC-287735-unk-2,WT1,P132L,WT1
28617,GENIE-VICC-780278-unk-1,WT1,X447_splice,WT1


In [38]:
# only consider unique genes mutated in each patient
survival_patients_combined_reduced.drop_duplicates(["patient_id", "canonical_gene_name"]).groupby(["patient_id"]).agg("count")

Unnamed: 0_level_0,gene,alteration,canonical_gene_name
patient_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
GENIE-DFCI-000013-8840,5,5,5
GENIE-DFCI-000016-74822,3,3,3
GENIE-DFCI-000036-9268,2,2,2
GENIE-DFCI-000036-9269,3,3,3
GENIE-DFCI-000036-9278,8,8,8
...,...,...,...
GENIE-VICC-994641-unk-1,2,2,2
GENIE-VICC-994755-unk-1,1,1,1
GENIE-VICC-995221-unk-1,10,10,10
GENIE-VICC-996370-unk-1,1,1,1


In [39]:
# max number of genes mutated across patients - theoretically 324
max_gene_count = survival_patients_combined_reduced.drop_duplicates(["patient_id", "canonical_gene_name"]).groupby(["patient_id"]).agg("count").max()["gene"]
max_gene_count # can even be set to 324 since in the worst case all 324 genes may be mutated

179

In [40]:
# max number of mutations per cell line/patient with an annotation available in the vocab
survival_patients_combined_reduced.patient_id.value_counts(ascending=True)

patient_id
GENIE-DFCI-011512-10189          1
GENIE-VICC-983614-unk-1          1
GENIE-VICC-410616-unk-1          1
GENIE-VICC-181550-unk-1          1
GENIE-VICC-676204-unk-1          1
                              ... 
GENIE-DFCI-091546-350438       144
GENIE-DFCI-039131-203438       200
GENIE-MSK-P-0005824-T01-IM5    218
GENIE-MSK-P-0006612-T01-IM5    242
GENIE-VICC-976300-unk-1        453
Name: count, Length: 3461, dtype: int64

In [41]:
survival_patients_combined_reduced[survival_patients_combined_reduced.patient_id == "GENIE-VICC-113182-unk-2"] # sanity check

Unnamed: 0,patient_id,gene,alteration,canonical_gene_name
27907,GENIE-VICC-113182-unk-2,EGFR,L858R,EGFR


In [42]:
survival_patients_combined_reduced["point_mutations"] = survival_patients_combined_reduced["canonical_gene_name"] + "@" + survival_patients_combined_reduced["alteration"]
survival_patients_combined_reduced

Unnamed: 0,patient_id,gene,alteration,canonical_gene_name,point_mutations
0,GENIE-DFCI-002643-6598,PALB2,*16*,PALB2,PALB2@*16*
1,GENIE-DFCI-002643-6598,FBXW7,E287Q,FBXW7,FBXW7@E287Q
2,GENIE-DFCI-002643-6598,EGFR,R252C,EGFR,EGFR@R252C
3,GENIE-DFCI-002643-6598,NRAS,G12D,NRAS,NRAS@G12D
4,GENIE-DFCI-002643-6598,CDKN2A,R80*,CDKN2A,CDKN2A@R80*
...,...,...,...,...,...
28614,GENIE-VICC-780278-unk-1,VHL,E204G,VHL,VHL@E204G
28615,GENIE-VICC-120723-unk-1,WT1,V73M,WT1,WT1@V73M
28616,GENIE-VICC-287735-unk-2,WT1,P132L,WT1,WT1@P132L
28617,GENIE-VICC-780278-unk-1,WT1,X447_splice,WT1,WT1@X447_splice


In [43]:
len(survival_patients_combined_reduced.patient_id.unique())

3461

In [44]:
# from PREDICT_AI
def get_texts(samples_df):
    texts = (samples_df.groupby(["patient_id", "canonical_gene_name"])["point_mutations"].agg(lambda x: " ".join(list(x))).reset_index(level="canonical_gene_name").apply(lambda x: x["canonical_gene_name"] + " <mutsep> " + x["point_mutations"], axis=1).groupby(["patient_id"]).agg(lambda x: " <gensep> ".join(list(x))))
    return texts

In [45]:
def yield_tokens(texts, genes):
    geneset = genes.union({"<gensep>", "<mutsep>"})
    for text in texts:
        tokens = ["<unk>" if not "@" in tok and not tok in geneset else tok for tok in text.split(" ")]
        univariate_mutation_tokens = ["<mut>" if "@" in tok else tok for tok in tokens]  # Treat all mutations to have the same token <mut>
        yield univariate_mutation_tokens

In [46]:
texts = get_texts(survival_patients_combined_reduced)
texts

patient_id
GENIE-DFCI-000013-8840     BRAF <mutsep> BRAF@G466L <gensep> CDKN2A <muts...
GENIE-DFCI-000016-74822    ARID1A <mutsep> ARID1A@Y551Lfs*72 <gensep> NRA...
GENIE-DFCI-000036-9268     ATRX <mutsep> ATRX@G2123E <gensep> KEAP1 <muts...
GENIE-DFCI-000036-9269     MLL <mutsep> MLL@A3489S <gensep> STK11 <mutsep...
GENIE-DFCI-000036-9278     BCORL1 <mutsep> BCORL1@P566H <gensep> ERBB4 <m...
                                                 ...                        
GENIE-VICC-994641-unk-1    BRAF <mutsep> BRAF@G469A <gensep> TP53 <mutsep...
GENIE-VICC-994755-unk-1                              NRAS <mutsep> NRAS@G12D
GENIE-VICC-995221-unk-1    CDKN2A <mutsep> CDKN2A@*50* <gensep> FLT1 <mut...
GENIE-VICC-996370-unk-1                             TP53 <mutsep> TP53@M237I
GENIE-VICC-996370-unk-2    APC <mutsep> APC@R876* <gensep> ATRX <mutsep> ...
Length: 3461, dtype: object

In [47]:
len(texts) == len(survival_patients_combined_reduced.patient_id.unique())

True

#### Build vocab of genes and mutations for later use

In [48]:
from torchtext.vocab import build_vocab_from_iterator

In [49]:
token_generator = yield_tokens(texts, set(genes324))
vocab = build_vocab_from_iterator(token_generator, specials=["<s>","<pad>","</s>","<unk>"])

In [50]:
gene2id = {vocab.lookup_token(idx):idx for idx in range(len(vocab))}
len(gene2id)

303

In [51]:
unknown = list(set(genes324).difference(set([k for k,v in gene2id.items()])))
len(unknown)

28

In [52]:
# add the missing genes
idx = len(vocab)
for gene in unknown:
    print(f"Adding {gene}")
    gene2id[gene] = idx
    idx += 1

Adding TIPARP
Adding EZR
Adding FGF12
Adding MEF2B
Adding SGK1
Adding CD74
Adding BCR
Adding CD70
Adding KRAS
Adding LTK
Adding BTG2
Adding CD22
Adding PARP2
Adding PPP2R2A
Adding P2RY8
Adding CYP17A1
Adding MKNK1
Adding MERTK
Adding HDAC1
Adding DDR1
Adding EPHB4
Adding SDC4
Adding RAD54L
Adding PTPRO
Adding FGF10
Adding PARP3
Adding TYRO3
Adding CUL4A


In [53]:
gene2id # gene vocabulary

{'<s>': 0,
 '<pad>': 1,
 '</s>': 2,
 '<unk>': 3,
 '<mut>': 4,
 '<mutsep>': 5,
 '<gensep>': 6,
 'TP53': 7,
 'NRAS': 8,
 'APC': 9,
 'EGFR': 10,
 'PIK3CA': 11,
 'MLL2': 12,
 'ATM': 13,
 'ARID1A': 14,
 'BRAF': 15,
 'SMAD4': 16,
 'STK11': 17,
 'KEAP1': 18,
 'SMARCA4': 19,
 'FBXW7': 20,
 'NF1': 21,
 'ATRX': 22,
 'SETD2': 23,
 'NOTCH1': 24,
 'CREBBP': 25,
 'SOX9': 26,
 'BRCA2': 27,
 'ALK': 28,
 'MLL': 29,
 'ERBB4': 30,
 'RBM10': 31,
 'CARD11': 32,
 'MET': 33,
 'EPHA3': 34,
 'MTOR': 35,
 'CTNNB1': 36,
 'PTEN': 37,
 'NTRK3': 38,
 'ROS1': 39,
 'NOTCH3': 40,
 'EP300': 41,
 'BCOR': 42,
 'ERBB2': 43,
 'FLT1': 44,
 'TSC2': 45,
 'MED12': 46,
 'TET2': 47,
 'ASXL1': 48,
 'PDGFRA': 49,
 'AR': 50,
 'KDR': 51,
 'RB1': 52,
 'PTCH1': 53,
 'ERBB3': 54,
 'PBRM1': 55,
 'POLE': 56,
 'RNF43': 57,
 'FANCA': 58,
 'IKZF1': 59,
 'NOTCH2': 60,
 'PIK3R1': 61,
 'RET': 62,
 'BRCA1': 63,
 'NTRK1': 64,
 'FAM123B': 65,
 'DNMT3A': 66,
 'CDKN2A': 67,
 'MSH6': 68,
 'ATR': 69,
 'BRD4': 70,
 'CIC': 71,
 'PDGFRB': 72,
 'PALB2': 

In [54]:

class Tokenizer:
	# vocab: vocabulary instance; annovar: annovar dictionary {gene_indicator: gene_idx}
	def __init__(self, vocab, vocab_annovar_tokenized) -> None:
		self.vocab = vocab
		self.annovar = vocab_annovar_tokenized #{'ZRSR2@V250M': 12175} #12942

	def __call__(self,texts,padding=False,max_length:Optional[int]=None,return_tensors:Optional[Union[str,TensorType]]=None,return_attention_mask=True,):
		return self.batch_encode_plus(texts=texts,padding=padding,max_length=max_length,return_tensors=return_tensors,return_attention_mask=return_attention_mask,)

	def batch_encode_plus(self,texts,padding=False,max_length:Optional[int]=None,return_tensors:Optional[Union[str,TensorType]]=None,return_attention_mask=True,):
		if type(texts) == str:
			texts = [texts]
		masked_texts = [" ".join(["<mut>" if "@" in tok else tok for tok in text.split(" ")]) for text in texts] #['TP53 <mutsep> <mut>']
		batch_tokens = [["<s>"] + text.split(" ") + ["</s>"] for text in masked_texts] #['<s>','TP53','<mutsep>','<mut>','</s>']
		batch_tokens = [[token if token in self.vocab else "<unk>" for token in tokens] for tokens in batch_tokens] #['<s>','<unk>','<mutsep>','<mut>','<gensep>','TP53','<mutsep>','<mut>','</s>']
		batch_numerical_tokens = [self.vocab.lookup_indices(tokens) for tokens in batch_tokens] #[0,8,5,4,2]
		batch_attention = [[1] * len(tokens) for tokens in batch_tokens] 	#! 1=non-padding
		batch_annovar = [[0 if not tok in self.annovar else self.annovar[tok] for tok in tokens] for tokens in [["<s>"]+text.split(" ")+["</s>"] for text in texts]]
		# print(batch_annovar) #[0, 0, 0, 1579, 0, 0, 0, 11577, 0] 

		if padding:
			if max_length is None:
				max_length = max([len(tokens) for tokens in batch_tokens]) 
			for numerical_tokens,attention,annovar in zip(batch_numerical_tokens,batch_attention,batch_annovar):
				padding_num = max_length-len(numerical_tokens)
				if padding_num > 0:
					numerical_tokens += [1]*padding_num
					attention += [0]*padding_num
					annovar += [0]*padding_num

		if return_tensors == "pt":
			try:
				batch_numerical_tokens = torch.tensor(batch_numerical_tokens)
				batch_attention = torch.tensor(batch_attention)
				batch_annovar = torch.tensor(batch_annovar)
			except:
				raise ValueError("Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length.")
			assert batch_annovar.max() <= len(self.annovar), "Check annovar dict consistency in indices"

		output = {"input_ids": batch_numerical_tokens,"attention_mask": batch_attention,"annovar_mask": batch_annovar,}
		if not return_attention_mask:
			output.pop("attention_mask")
		return BatchEncoding(output)
	

In [55]:
tok = Tokenizer(vocab, vocab_annovar_tokenized) # vocab here is for genes324, vocab annovar tokenized is for all mutation

In [56]:
annotated_data = tok(texts, return_tensors="pt", padding=True)

In [57]:
annotated_data

{'input_ids': tensor([[ 0, 15,  5,  ...,  1,  1,  1],
        [ 0, 14,  5,  ...,  1,  1,  1],
        [ 0, 22,  5,  ...,  1,  1,  1],
        ...,
        [ 0, 67,  5,  ...,  1,  1,  1],
        [ 0,  7,  5,  ...,  1,  1,  1],
        [ 0,  9,  5,  ...,  1,  1,  1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'annovar_mask': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])}

In [58]:
annotated_data["input_ids"].shape, annotated_data["attention_mask"].shape, annotated_data["annovar_mask"].shape

(torch.Size([3461, 991]), torch.Size([3461, 991]), torch.Size([3461, 991]))

In [59]:
torch.nonzero(annotated_data["annovar_mask"]).shape

torch.Size([19595, 2])

In [60]:
# Sanity check
texts

patient_id
GENIE-DFCI-000013-8840     BRAF <mutsep> BRAF@G466L <gensep> CDKN2A <muts...
GENIE-DFCI-000016-74822    ARID1A <mutsep> ARID1A@Y551Lfs*72 <gensep> NRA...
GENIE-DFCI-000036-9268     ATRX <mutsep> ATRX@G2123E <gensep> KEAP1 <muts...
GENIE-DFCI-000036-9269     MLL <mutsep> MLL@A3489S <gensep> STK11 <mutsep...
GENIE-DFCI-000036-9278     BCORL1 <mutsep> BCORL1@P566H <gensep> ERBB4 <m...
                                                 ...                        
GENIE-VICC-994641-unk-1    BRAF <mutsep> BRAF@G469A <gensep> TP53 <mutsep...
GENIE-VICC-994755-unk-1                              NRAS <mutsep> NRAS@G12D
GENIE-VICC-995221-unk-1    CDKN2A <mutsep> CDKN2A@*50* <gensep> FLT1 <mut...
GENIE-VICC-996370-unk-1                             TP53 <mutsep> TP53@M237I
GENIE-VICC-996370-unk-2    APC <mutsep> APC@R876* <gensep> ATRX <mutsep> ...
Length: 3461, dtype: object

In [61]:
texts.reset_index().iloc[10]

patient_id                             GENIE-DFCI-000147-350443
0             APC <mutsep> APC@E1309Dfs*4 <gensep> ARID1A <m...
Name: 10, dtype: object

In [62]:
texts.iloc[10]

'APC <mutsep> APC@E1309Dfs*4 <gensep> ARID1A <mutsep> ARID1A@Q611Rfs*8 <gensep> AXL <mutsep> AXL@T328M <gensep> EP300 <mutsep> EP300@Q1904P <gensep> ESR1 <mutsep> ESR1@R28C <gensep> MED12 <mutsep> MED12@T1130P <gensep> NOTCH3 <mutsep> NOTCH3@R2207Q <gensep> NRAS <mutsep> NRAS@Q61L <gensep> POLD1 <mutsep> POLD1@D644N <gensep> TP53 <mutsep> TP53@S241_G245del'

In [63]:
annotated_data["input_ids"][10][:25]

tensor([  0,   9,   5,   4,   6,  14,   5,   4,   6, 103,   5,   4,   6,  41,
          5,   4,   6,  98,   5,   4,   6,  46,   5,   4,   6])

In [64]:
for i in annotated_data["input_ids"][10][:25]:
    print(vocab.lookup_token(i.item()), end= " ")

<s> APC <mutsep> <mut> <gensep> ARID1A <mutsep> <mut> <gensep> AXL <mutsep> <mut> <gensep> EP300 <mutsep> <mut> <gensep> ESR1 <mutsep> <mut> <gensep> MED12 <mutsep> <mut> <gensep> 

In [65]:
torch.nonzero(annotated_data["annovar_mask"][10])

tensor([[11],
        [15],
        [19],
        [23],
        [27],
        [31],
        [35]])

In [66]:
annotated_data["annovar_mask"][10][11], annotated_data["annovar_mask"][10][15], annotated_data["annovar_mask"][10][19], annotated_data["annovar_mask"][10][23]

(tensor(200961), tensor(621983), tensor(657450), tensor(1167051))

In [67]:
for k, v in vocab_annovar_tokenized.items():
    if v in [200961, 621983, 657450, 1167051]:
        print(k, v)

AXL@T328M 200961
EP300@Q1904P 621983
ESR1@R28C 657450
MED12@T1130P 1167051


In [68]:
annotated_data["attention_mask"][10][:20]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [69]:
# save as dataframe for later use
survival_data_input_ids_pd = pd.DataFrame(annotated_data["input_ids"], index=texts.index)
survival_data_input_ids_pd.columns = [f"input_ids_{i}" for i in range(survival_data_input_ids_pd.shape[1])]
survival_data_annovar_ids_pd = pd.DataFrame(annotated_data["annovar_mask"], index=texts.index)
survival_data_annovar_ids_pd.columns = [f"annovar_ids_{i}" for i in range(survival_data_annovar_ids_pd.shape[1])]
survival_data_attention_mask_pd = pd.DataFrame(annotated_data["attention_mask"], index=texts.index)
survival_data_attention_mask_pd.columns = [f"mask_{i}" for i in range(survival_data_attention_mask_pd.shape[1])]

In [70]:
texts

patient_id
GENIE-DFCI-000013-8840     BRAF <mutsep> BRAF@G466L <gensep> CDKN2A <muts...
GENIE-DFCI-000016-74822    ARID1A <mutsep> ARID1A@Y551Lfs*72 <gensep> NRA...
GENIE-DFCI-000036-9268     ATRX <mutsep> ATRX@G2123E <gensep> KEAP1 <muts...
GENIE-DFCI-000036-9269     MLL <mutsep> MLL@A3489S <gensep> STK11 <mutsep...
GENIE-DFCI-000036-9278     BCORL1 <mutsep> BCORL1@P566H <gensep> ERBB4 <m...
                                                 ...                        
GENIE-VICC-994641-unk-1    BRAF <mutsep> BRAF@G469A <gensep> TP53 <mutsep...
GENIE-VICC-994755-unk-1                              NRAS <mutsep> NRAS@G12D
GENIE-VICC-995221-unk-1    CDKN2A <mutsep> CDKN2A@*50* <gensep> FLT1 <mut...
GENIE-VICC-996370-unk-1                             TP53 <mutsep> TP53@M237I
GENIE-VICC-996370-unk-2    APC <mutsep> APC@R876* <gensep> ATRX <mutsep> ...
Length: 3461, dtype: object

In [71]:
len(survival_patients_combined_reduced.patient_id.unique())

3461

In [72]:
# Note: survival data and other datasets follow different processing for use with transformer. Diff max token length
# Survival data: 
# Input: 'ATM <mutsep> ATM@R337C <gensep> DNMT3A <mutsep> DNMT3A@R882H <gensep> NRAS <mutsep> NRAS@G13D <gensep> TP53 <mutsep> TP53@G266R'
# After processing: max length - max number of tokens
# input_ids: [ 0, 18,  5,  4,  6, 85,  5,  4,  6,  8,  5,  4,  6,  7,  5,  4,  2,  1, 1,  1,  1,  1,  1,  1,  1 ...] 
# This is equivalent to <s> ATM <mutsep> <mut> <gensep> DNMT3A <mutsep> <mut> <gensep> NRAS <mutsep> <mut> <gensep> TP53 <mutsep> <mut> </s> <pad> <pad> <pad> <pad> ...
# annovar_mask: [0, 0, 0, 154350, 0, 0, 0, 493064, 0, 0, 0, 1191554, 0, 0, 0, 1837529, 0, 0, ...]
# The non-zero numbers are equivalent to mutation ids in vocab
# attention_mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, ...]
# 1 corresponds to non-padding.

# RECIST/AUDRC data:
# Input: 'ATM <mutsep> ATM@R337C <gensep> DNMT3A <mutsep> DNMT3A@R882H <gensep> NRAS <mutsep> NRAS@G13D <gensep> TP53 <mutsep> TP53@G266R'
# After processing:
# gens_mat_pd: [18,  7,  8, 85,  0,  0,  0, ...] # max length - max number of genes mutated
# The numbers here correspond to ids in genes_324 vocab, 0 indicates no mutations
# muts_mat_pd: [154350,  493064, 1837529, 1191554, 0, 0, 0, ...] # max length - max number of mutations
# The numbers here correspond to ids in vocab_annovar_tokenized, 0 indicates no mutations
# mask_mat_pd: [1,1,1,1,0,0,0,...,0,1,1,1,1,0,0,0...] # max length - sum of max genes and max mutations
# 1 indicates presence of genes or mutations, 0 otherwise

#### Processing RECIST and AUC datasets

In [73]:
recist_dataset = pd.concat([tcga_mutations, moores_mutations, cbio_hcc_mutations, cbio_brca_mutations])
recist_dataset["canonical_gene_name"] = recist_dataset["gene"].apply(lambda x: convert2canonical(x))

In [74]:
recist_dataset_reduced = recist_dataset[~recist_dataset.canonical_gene_name.isna()].reset_index(drop=True)
recist_dataset_reduced["point_mutations"] = recist_dataset_reduced["canonical_gene_name"] + "@" + recist_dataset_reduced["alteration"]
recist_dataset_reduced

Unnamed: 0,patient_id,gene,alteration,canonical_gene_name,point_mutations
0,TCGA-50-5931,DNMT3A,V258M,DNMT3A,DNMT3A@V258M
1,TCGA-50-5931,MSH2,I356V,MSH2,MSH2@I356V
2,TCGA-50-5931,RICTOR,L42F,RICTOR,RICTOR@L42F
3,TCGA-50-5931,NOTCH1,D297G,NOTCH1,NOTCH1@D297G
4,TCGA-50-5931,MLL2,E668*,MLL2,MLL2@E668*
...,...,...,...,...,...
102037,s_DS_bkm_058_T,PIK3CA,H1047L,PIK3CA,PIK3CA@H1047L
102038,s_DS_bkm_058_T,BCOR,N193T,BCOR,BCOR@N193T
102039,s_DS_bkm_059_T,SF3B1,I641V,SF3B1,SF3B1@I641V
102040,s_DS_bkm_059_T,ESR1,L536R,ESR1,ESR1@L536R


In [75]:
cl_mutations_reduced["point_mutations"] = cl_mutations_reduced["canonical_gene_name"] + "@" + cl_mutations_reduced["alteration"]
cl_mutations_reduced

Unnamed: 0,depmap_id,gene,alteration,canonical_gene_name,point_mutations
0,PR-rsoNmY,TNFRSF14,G68R,TNFRSF14,TNFRSF14@G68R
1,PR-yDgpga,TNFRSF14,G68R,TNFRSF14,TNFRSF14@G68R
2,PR-sgXEkc,TNFRSF14,G89C,TNFRSF14,TNFRSF14@G89C
3,PR-kRqGcx,TNFRSF14,A140S,TNFRSF14,TNFRSF14@A140S
4,PR-ZhEuUF,TNFRSF14,R149M,TNFRSF14,TNFRSF14@R149M
...,...,...,...,...,...
29444,PR-6Ybf3z,BCORL1,P1755QfsTer20,BCORL1,BCORL1@P1755QfsTer20
29445,PR-81oclJ,BCORL1,P1755QfsTer20,BCORL1,BCORL1@P1755QfsTer20
29446,PR-Qvs2q6,BCORL1,P1755QfsTer20,BCORL1,BCORL1@P1755QfsTer20
29447,PR-EaZDJD,BCORL1,E1767K,BCORL1,BCORL1@E1767K


In [76]:
# cell line data processing
cl2gens,cl2muts = defaultdict(set),defaultdict(set)
for idx,row in cl_mutations_reduced.iterrows():
    depmap_id,gene,mutation = row['depmap_id'],row['canonical_gene_name'],row['point_mutations']
    cl2gens[depmap_id].add(gene)
    if mutation in vocab_annovar_tokenized.keys():
        cl2muts[depmap_id].add(mutation)
    else:
        cl2muts[depmap_id].add('')

In [77]:
len(cl2gens), len(cl2muts)

(2313, 2313)

In [78]:
### patients data processing
patients2gens,patients2muts = defaultdict(set),defaultdict(set)
for idx,row in recist_dataset_reduced.iterrows():
    patient_id,gene,mutation = row['patient_id'],row['canonical_gene_name'],row['point_mutations']
    patients2gens[patient_id].add(gene)
    if mutation in vocab_annovar_tokenized.keys():
        patients2muts[patient_id].add(mutation)
    else:
        patients2muts[patient_id].add('')

In [79]:
# sanity check for index 10
patients2gens[13], patients2muts[13]

({'ATM', 'DNMT3A', 'NRAS', 'TP53'},
 {'ATM@R337C', 'DNMT3A@R882H', 'NRAS@G13D', 'TP53@G266R'})

In [80]:
len(patients2gens), len(patients2muts)

(9446, 9446)

In [81]:
max_gens = max([max([len(val) for val in cl2gens.values()]), max([len(val) for val in patients2gens.values()])])
max_muts = max([max([len(val) for val in cl2muts.values()]), max([len(val) for val in patients2muts.values()])])
max_length = max_gens + max_muts

In [82]:
max_gens, max_muts, max_length

(254, 543, 797)

In [83]:
ccle2gens_mat,ccle2muts_mat,ccle_mask = np.zeros((len(cl2gens),max_gens),int),np.zeros((len(cl2muts),max_muts),int),np.zeros((len(cl2gens),max_length),int)
patients2gens_mat,patients2muts_mat,patients_mask = np.zeros((len(patients2gens),max_gens),int),np.zeros((len(patients2muts),max_muts),int),np.zeros((len(patients2gens),max_length),int)
print(patients2gens_mat.shape,patients2muts_mat.shape,patients_mask.shape) 
for i,genes in enumerate(cl2gens.values()):
    for j,gene in enumerate(genes):
        ccle2gens_mat[i,j] = gene2id[gene]
        ccle_mask[i,j] = 1
for i,genes in enumerate(patients2gens.values()):
    for j,gene in enumerate(genes):
        patients2gens_mat[i,j] = gene2id[gene]
        patients_mask[i,j] = 1
for i,muts in enumerate(cl2muts.values()):
    for j,mut in enumerate(muts):
        if mut in vocab_annovar_tokenized.keys():
            ccle2muts_mat[i,j] = vocab_annovar_tokenized[mut]
            ccle_mask[i,max_gens+j] = 1
for i,muts in enumerate(patients2muts.values()):
    for j,mut in enumerate(muts):
        if mut in vocab_annovar_tokenized.keys():
            patients2muts_mat[i,j] = vocab_annovar_tokenized[mut]
            patients_mask[i,max_gens+j] = 1

(9446, 254) (9446, 543) (9446, 797)


In [84]:
ccle2gens_mat_pd = pd.DataFrame(ccle2gens_mat,index=cl2gens.keys())
ccle2muts_mat_pd = pd.DataFrame(ccle2muts_mat,index=cl2muts.keys()) 
ccle_mask_pd = pd.DataFrame(ccle_mask,index=cl2gens.keys())
patients2gens_mat_pd = pd.DataFrame(patients2gens_mat,index=patients2gens.keys())
patients2muts_mat_pd = pd.DataFrame(patients2muts_mat,index=patients2muts.keys()) 
patients_mask_pd = pd.DataFrame(patients_mask,index=patients2gens.keys())

In [85]:
ccle2gens_mat_pd.shape, ccle2muts_mat_pd.shape, ccle_mask.shape

((2313, 254), (2313, 543), (2313, 797))

In [86]:
patients2gens_mat_pd.shape, patients2muts_mat_pd.shape, patients_mask_pd.shape

((9446, 254), (9446, 543), (9446, 797))

In [87]:
# sanity check for patient above
patients2gens_mat_pd.loc[13].values

array([ 8, 13, 66,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0

In [88]:
gene2id["ATM"], gene2id["DNMT3A"], gene2id['NRAS'], gene2id['TP53']

(13, 66, 8, 7)

In [89]:
patients2muts_mat_pd.loc[13].values

array([ 174750, 2084268, 1349617,  560636,       0,       0,       0,
             0,       0,       0,       0,       0,       0,       0,
             0,       0,       0,       0,       0,       0,       0,
             0,       0,       0,       0,       0,       0,       0,
             0,       0,       0,       0,       0,       0,       0,
             0,       0,       0,       0,       0,       0,       0,
             0,       0,       0,       0,       0,       0,       0,
             0,       0,       0,       0,       0,       0,       0,
             0,       0,       0,       0,       0,       0,       0,
             0,       0,       0,       0,       0,       0,       0,
             0,       0,       0,       0,       0,       0,       0,
             0,       0,       0,       0,       0,       0,       0,
             0,       0,       0,       0,       0,       0,       0,
             0,       0,       0,       0,       0,       0,       0,
             0,     

In [90]:
for k, v in vocab_annovar_tokenized.items():
    if v in [174750, 2084268, 1349617,  560636,]:
        print(k, v)

ATM@R337C 174750
DNMT3A@R882H 560636
NRAS@G13D 1349617
TP53@G266R 2084268


In [91]:
patients_mask_pd.loc[13].values

array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [None]:
# potentially reuse the tokenizer(tok) from meddata in PREDICT-AI eval `annodata = tok(texts, return_tensors="pt", padding=True)` for survival data

### Combine with train-test splits for AUDRC and RECIST datasets

In [92]:
ccle2gens_mat_pd.columns = [f"input_ids_{i}" for i in range(ccle2gens_mat_pd.shape[1])]
ccle2muts_mat_pd.columns = [f"annovar_ids_{i}" for i in range(ccle2muts_mat_pd.shape[1])]
ccle_mask_pd.columns = [f"mask_{i}" for i in range(ccle_mask_pd.shape[1])]

patients2gens_mat_pd.columns = [f"input_ids_{i}" for i in range(patients2gens_mat_pd.shape[1])]
patients2muts_mat_pd.columns = [f"annovar_ids_{i}" for i in range(patients2muts_mat_pd.shape[1])]
patients_mask_pd.columns = [f"mask_{i}" for i in range(patients_mask_pd.shape[1])]

#### Cell Lines (AUDRC dataset)

In [93]:
# Experiment 1A
# needs data (input mutation ids), annovar (annovar mutation ids), attention (attention mask), drug (fingerprint), label (auc)
exp1A_cl_fold0_processed = {}
for div in ["train", "val", "test"]:
    exp1A_cl_fold0_processed[div] = {}
    for k, v in exp1A_cl_fold0[div].items():
        print(div + ": " + k, end=" -- ")
        exp1A_cl_fold0_processed[div][k] = {}
        # data
        merged_data = v.merge(ccle2gens_mat_pd, how="inner", left_on="sample_id", right_on=ccle2gens_mat_pd.index)
        print(merged_data.shape)
        # mutations2use = cl_mutations.loc[v.sample_id]
        if not (merged_data.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_data.sample_id))
        exp1A_cl_fold0_processed[div][k]["input_ids"] = merged_data

        # annovar
        merged_annovar = v.merge(ccle2muts_mat_pd, how="inner", left_on="sample_id", right_on=ccle2muts_mat_pd.index)
        print(merged_annovar.shape)
        if not (merged_annovar.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_annovar.sample_id))
        exp1A_cl_fold0_processed[div][k]["annovar_ids"] = merged_annovar

        # attention
        merged_attention = v.merge(ccle_mask_pd, how="inner", left_on="sample_id", right_on=ccle_mask_pd.index)
        print(merged_attention.shape)
        if not (merged_attention.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_attention.sample_id))
        exp1A_cl_fold0_processed[div][k]["attention_mask"] = merged_attention

train: BUPARLISIB -- (860, 260)
(860, 549)
(860, 803)
train: CISPLATIN -- (692, 260)
(692, 549)
(692, 803)
train: FLUOROURACIL -- (868, 260)
(868, 549)
(868, 803)
train: GEMCITABINE -- (861, 260)
(861, 549)
(861, 803)
train: PACLITAXEL -- (860, 260)
(860, 549)
(860, 803)
train: SORAFENIB -- (861, 260)
(861, 549)
(861, 803)
train: TEMOZOLOMIDE -- (864, 260)
(864, 549)
(864, 803)
val: BUPARLISIB -- (96, 260)
(96, 549)
(96, 803)
val: CISPLATIN -- (78, 260)
(78, 549)
(78, 803)
val: FLUOROURACIL -- (96, 260)
(96, 549)
(96, 803)
val: GEMCITABINE -- (96, 260)
(96, 549)
(96, 803)
val: PACLITAXEL -- (96, 260)
(96, 549)
(96, 803)
val: SORAFENIB -- (96, 260)
(96, 549)
(96, 803)
val: TEMOZOLOMIDE -- (97, 260)
(97, 549)
(97, 803)
test: BUPARLISIB -- (119, 260)
(119, 549)
(119, 803)
test: CISPLATIN -- (95, 260)
(95, 549)
(95, 803)
test: FLUOROURACIL -- (120, 260)
(120, 549)
(120, 803)
test: GEMCITABINE -- (118, 260)
(118, 549)
(118, 803)
test: PACLITAXEL -- (119, 260)
(119, 549)
(119, 803)
test: SOR

In [94]:
merged_attention

Unnamed: 0,sample_id,drug_name,auc,ic50,drug_category,response_label,mask_0,mask_1,mask_2,mask_3,...,mask_787,mask_788,mask_789,mask_790,mask_791,mask_792,mask_793,mask_794,mask_795,mask_796
0,PR-6134Do,TEMOZOLOMIDE,0.973411,6.793903,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
1,PR-I3vWjE,TEMOZOLOMIDE,0.970762,9.183456,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
2,PR-HlfNG0,TEMOZOLOMIDE,0.974962,8.157104,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
3,PR-MvGdkR,TEMOZOLOMIDE,0.977305,7.104347,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
4,PR-kERVrz,TEMOZOLOMIDE,0.938574,7.450541,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
115,PR-97JLE2,TEMOZOLOMIDE,0.991676,6.531208,1,0,1,1,1,0,...,0,0,0,0,0,0,0,0,0,0
116,PR-ko8bDr,TEMOZOLOMIDE,0.977194,6.834325,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
117,PR-y3wle9,TEMOZOLOMIDE,0.983291,6.807553,1,0,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
118,PR-qtSxBH,TEMOZOLOMIDE,0.977976,6.375077,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0


In [95]:
# Experiment 1B
# needs data (input mutation ids), annovar (annovar mutation ids), attention (attention mask), drug (fingerprint), label (auc)
exp1B_cl_fold0_processed = {}
for div in ["train", "val", "test"]:
    exp1B_cl_fold0_processed[div] = {}
    for k, v in exp1B_cl_fold0[div].items():
        print(div + ": " + str(k), end=" -- ")
        exp1B_cl_fold0_processed[div][k] = {}
        # data
        merged_data = v.merge(ccle2gens_mat_pd, how="inner", left_on="sample_id", right_on=ccle2gens_mat_pd.index)
        print(merged_data.shape)
        # mutations2use = cl_mutations.loc[v.sample_id]
        if not (merged_data.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_data.sample_id))
        exp1B_cl_fold0_processed[div][k]["input_ids"] = merged_data

        # annovar
        merged_annovar = v.merge(ccle2muts_mat_pd, how="inner", left_on="sample_id", right_on=ccle2muts_mat_pd.index)
        print(merged_annovar.shape)
        if not (merged_annovar.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_annovar.sample_id))
        exp1B_cl_fold0_processed[div][k]["annovar_ids"] = merged_annovar

        # attention
        merged_attention = v.merge(ccle_mask_pd, how="inner", left_on="sample_id", right_on=ccle_mask_pd.index)
        print(merged_attention.shape)
        if not (merged_attention.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_attention.sample_id))
        exp1B_cl_fold0_processed[div][k]["attention_mask"] = merged_attention

train: ('CISPLATIN', 'TCGA-CESC', 'TCGA') -- (692, 260)
(692, 549)
(692, 803)
train: ('CISPLATIN', 'TCGA-HNSC', 'TCGA') -- (692, 260)
(692, 549)
(692, 803)
train: ('FLUOROURACIL', 'TCGA-STAD', 'TCGA') -- (868, 260)
(868, 549)
(868, 803)
train: ('GEMCITABINE', 'TCGA-PAAD', 'TCGA') -- (861, 260)
(861, 549)
(861, 803)
train: ('PACLITAXEL', 'TCGA-BRCA', 'TCGA') -- (860, 260)
(860, 549)
(860, 803)
train: ('TEMOZOLOMIDE', 'TCGA-LGG', 'TCGA') -- (864, 260)
(864, 549)
(864, 803)
val: ('CISPLATIN', 'TCGA-CESC', 'TCGA') -- (78, 260)
(78, 549)
(78, 803)
val: ('CISPLATIN', 'TCGA-HNSC', 'TCGA') -- (78, 260)
(78, 549)
(78, 803)
val: ('FLUOROURACIL', 'TCGA-STAD', 'TCGA') -- (96, 260)
(96, 549)
(96, 803)
val: ('GEMCITABINE', 'TCGA-PAAD', 'TCGA') -- (96, 260)
(96, 549)
(96, 803)
val: ('PACLITAXEL', 'TCGA-BRCA', 'TCGA') -- (96, 260)
(96, 549)
(96, 803)
val: ('TEMOZOLOMIDE', 'TCGA-LGG', 'TCGA') -- (97, 260)
(97, 549)
(97, 803)
test: ('CISPLATIN', 'TCGA-CESC', 'TCGA') -- (95, 260)
(95, 549)
(95, 803)
test

In [96]:
exp1B_cl_fold0_processed["train"][('TEMOZOLOMIDE', 'TCGA-LGG', 'TCGA')]["input_ids"]

Unnamed: 0,sample_id,drug_name,auc,ic50,drug_category,response_label,input_ids_0,input_ids_1,input_ids_2,input_ids_3,...,input_ids_244,input_ids_245,input_ids_246,input_ids_247,input_ids_248,input_ids_249,input_ids_250,input_ids_251,input_ids_252,input_ids_253
0,PR-yqcDUP,TEMOZOLOMIDE,0.901061,3.432533,1,1,322,7,0,0,...,0,0,0,0,0,0,0,0,0,0
1,PR-eJKW7S,TEMOZOLOMIDE,0.974787,7.666837,1,1,22,7,103,11,...,0,0,0,0,0,0,0,0,0,0
2,PR-7wfZWS,TEMOZOLOMIDE,0.992637,7.150543,1,0,7,303,46,125,...,0,0,0,0,0,0,0,0,0,0
3,PR-4ysADv,TEMOZOLOMIDE,0.954605,3.987438,1,1,231,7,115,164,...,0,0,0,0,0,0,0,0,0,0
4,PR-iTkQMf,TEMOZOLOMIDE,0.972264,5.150002,1,1,8,216,53,153,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
859,PR-s9Eho8,TEMOZOLOMIDE,0.986594,6.672166,1,0,127,7,0,0,...,0,0,0,0,0,0,0,0,0,0
860,PR-5sCxrE,TEMOZOLOMIDE,0.907780,4.312283,1,1,231,127,7,40,...,0,0,0,0,0,0,0,0,0,0
861,PR-tcGtRs,TEMOZOLOMIDE,0.975047,5.617542,1,1,7,204,297,103,...,0,0,0,0,0,0,0,0,0,0
862,PR-JSUJXv,TEMOZOLOMIDE,0.973179,7.607630,1,1,12,167,10,19,...,0,0,0,0,0,0,0,0,0,0


In [97]:
# Experiment 2A
exp2A_cl_fold0_processed = {}
for div in ["train", "val", "test"]:
    exp2A_cl_fold0_processed[div] = {}
    v = exp2A_cl_fold0[div]
    print(div, end=" -- ")

    # data
    merged_data = v.merge(ccle2gens_mat_pd, how="inner", left_on="sample_id", right_on=ccle2gens_mat_pd.index)
    print(merged_data.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_data.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_data.sample_id))
    exp2A_cl_fold0_processed[div]["input_ids"] = merged_data

    # annovar
    merged_annovar = v.merge(ccle2muts_mat_pd, how="inner", left_on="sample_id", right_on=ccle2muts_mat_pd.index)
    print(merged_annovar.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_annovar.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_annovar.sample_id))
    exp2A_cl_fold0_processed[div]["annovar_ids"] = merged_annovar

    # attention
    merged_attention = v.merge(ccle_mask_pd, how="inner", left_on="sample_id", right_on=ccle_mask_pd.index)
    print(merged_attention.shape)
    if not (merged_attention.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_attention.sample_id))
    exp2A_cl_fold0_processed[div]["attention_mask"] = merged_attention

train -- (156441, 260)
(156441, 549)
(156441, 803)
val -- (17371, 260)
(17371, 549)
(17371, 803)
test -- (21589, 260)
(21589, 549)
(21589, 803)


In [98]:
exp2A_cl_fold0_processed["train"]["input_ids"]

Unnamed: 0,sample_id,drug_name,auc,ic50,drug_category,response_label,input_ids_0,input_ids_1,input_ids_2,input_ids_3,...,input_ids_244,input_ids_245,input_ids_246,input_ids_247,input_ids_248,input_ids_249,input_ids_250,input_ids_251,input_ids_252,input_ids_253
0,PR-132fPs,DOCETAXEL,0.191876,-4.662091,1,1,36,303,226,186,...,0,0,0,0,0,0,0,0,0,0
1,PR-L3QLdq,ELEPHANTIN,0.940458,5.730421,3,0,7,179,327,168,...,0,0,0,0,0,0,0,0,0,0
2,PR-NxSV8u,MITOXANTRONE,0.921925,4.070582,1,0,89,7,11,277,...,0,0,0,0,0,0,0,0,0,0
3,PR-oLPbwB,DACTINOMYCIN,0.179515,-6.588337,1,1,146,11,121,53,...,0,0,0,0,0,0,0,0,0,0
4,PR-4ngqZx,CCT007093,0.989986,3.724712,3,0,162,52,23,50,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
156436,PR-M4505H,PFI-1,0.919051,3.534174,3,1,7,90,8,224,...,0,0,0,0,0,0,0,0,0,0
156437,PR-Bz57NU,NILOTINIB,0.995489,4.073733,1,0,231,220,313,12,...,0,0,0,0,0,0,0,0,0,0
156438,PR-6SyWYo,SAPITINIB,0.492491,-1.567439,2,1,89,7,49,189,...,0,0,0,0,0,0,0,0,0,0
156439,PR-wGySam,TASELISIB,0.901939,2.716776,2,0,22,7,105,55,...,0,0,0,0,0,0,0,0,0,0


In [99]:
# Experiment 2B
# needs data (input mutation ids), annovar (annovar mutation ids), attention (attention mask), drug (fingerprint), label (auc)
exp2B_cl_fold0_processed = {}
for div in ["train", "val", "test"]:
    exp2B_cl_fold0_processed[div] = {}
    for k, v in exp2B_cl_fold0[div].items():
        print(div + ": " + str(k), end=" -- ")
        exp2B_cl_fold0_processed[div][k] = {}
        # data
        merged_data = v.merge(ccle2gens_mat_pd, how="inner", left_on="sample_id", right_on=ccle2gens_mat_pd.index)
        print(merged_data.shape)
        # mutations2use = cl_mutations.loc[v.sample_id]
        if not (merged_data.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_data.sample_id))
        exp2B_cl_fold0_processed[div][k]["input_ids"] = merged_data

        # annovar
        merged_annovar = v.merge(ccle2muts_mat_pd, how="inner", left_on="sample_id", right_on=ccle2muts_mat_pd.index)
        print(merged_annovar.shape)
        if not (merged_annovar.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_annovar.sample_id))
        exp2B_cl_fold0_processed[div][k]["annovar_ids"] = merged_annovar

        # attention
        merged_attention = v.merge(ccle_mask_pd, how="inner", left_on="sample_id", right_on=ccle_mask_pd.index)
        print(merged_attention.shape)
        if not (merged_attention.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_attention.sample_id))
        exp2B_cl_fold0_processed[div][k]["attention_mask"] = merged_attention

train: TCGA-BRCA -- (156441, 260)
(156441, 549)
(156441, 803)
train: TCGA-CESC -- (156441, 260)
(156441, 549)
(156441, 803)
train: TCGA-HNSC -- (156441, 260)
(156441, 549)
(156441, 803)
train: TCGA-STAD -- (156441, 260)
(156441, 549)
(156441, 803)
train: TCGA-PAAD -- (156441, 260)
(156441, 549)
(156441, 803)
train: TCGA-LGG -- (156441, 260)
(156441, 549)
(156441, 803)
val: TCGA-BRCA -- (17371, 260)
(17371, 549)
(17371, 803)
val: TCGA-CESC -- (17371, 260)
(17371, 549)
(17371, 803)
val: TCGA-HNSC -- (17371, 260)
(17371, 549)
(17371, 803)
val: TCGA-STAD -- (17371, 260)
(17371, 549)
(17371, 803)
val: TCGA-PAAD -- (17371, 260)
(17371, 549)
(17371, 803)
val: TCGA-LGG -- (17371, 260)
(17371, 549)
(17371, 803)
test: TCGA-BRCA -- (21589, 260)
(21589, 549)
(21589, 803)
test: TCGA-CESC -- (21589, 260)
(21589, 549)
(21589, 803)
test: TCGA-HNSC -- (21589, 260)
(21589, 549)
(21589, 803)
test: TCGA-STAD -- (21589, 260)
(21589, 549)
(21589, 803)
test: TCGA-PAAD -- (21589, 260)
(21589, 549)
(21589, 803

#### Patients (RECIST datasets)

In [100]:
# Experiment 1A
# needs data (input mutation ids), annovar (annovar mutation ids), attention (attention mask), drug (fingerprint), label (RECIST)
# fold 0
exp1A_patient_fold0_processed = {}
for div in ["train", "val", "test"]:
    exp1A_patient_fold0_processed[div] = {}
    for k, v in exp1A_patient_fold0[div].items():
        print(div + ": " + k, end=" -- ")
        exp1A_patient_fold0_processed[div][k] = {}
        # data
        merged_data = v.merge(patients2gens_mat_pd, how="inner", left_on="sample_id", right_on=patients2gens_mat_pd.index)
        print(merged_data.shape)
        # mutations2use = cl_mutations.loc[v.sample_id]
        if not (merged_data.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_data.sample_id))
        exp1A_patient_fold0_processed[div][k]["input_ids"] = merged_data

        # annovar
        merged_annovar = v.merge(patients2muts_mat_pd, how="inner", left_on="sample_id", right_on=patients2muts_mat_pd.index)
        print(merged_annovar.shape)
        if not (merged_annovar.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_annovar.sample_id))
        exp1A_patient_fold0_processed[div][k]["annovar_ids"] = merged_annovar

        # attention
        merged_attention = v.merge(patients_mask_pd, how="inner", left_on="sample_id", right_on=patients_mask_pd.index)
        print(merged_attention.shape)
        if not (merged_attention.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_attention.sample_id))
        exp1A_patient_fold0_processed[div][k]["attention_mask"] = merged_attention

# fold 1
exp1A_patient_fold1_processed = {}
for div in ["train", "val", "test"]:
    exp1A_patient_fold1_processed[div] = {}
    for k, v in exp1A_patient_fold1[div].items():
        print(div + ": " + k, end=" -- ")
        exp1A_patient_fold1_processed[div][k] = {}
        # data
        merged_data = v.merge(patients2gens_mat_pd, how="inner", left_on="sample_id", right_on=patients2gens_mat_pd.index)
        print(merged_data.shape)
        # mutations2use = cl_mutations.loc[v.sample_id]
        if not (merged_data.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_data.sample_id))
        exp1A_patient_fold1_processed[div][k]["input_ids"] = merged_data

        # annovar
        merged_annovar = v.merge(patients2muts_mat_pd, how="inner", left_on="sample_id", right_on=patients2muts_mat_pd.index)
        print(merged_annovar.shape)
        if not (merged_annovar.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_annovar.sample_id))
        exp1A_patient_fold1_processed[div][k]["annovar_ids"] = merged_annovar

        # attention
        merged_attention = v.merge(patients_mask_pd, how="inner", left_on="sample_id", right_on=patients_mask_pd.index)
        print(merged_attention.shape)
        if not (merged_attention.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_attention.sample_id))
        exp1A_patient_fold1_processed[div][k]["attention_mask"] = merged_attention

# fold 2
exp1A_patient_fold2_processed = {}
for div in ["train", "val", "test"]:
    exp1A_patient_fold2_processed[div] = {}
    for k, v in exp1A_patient_fold2[div].items():
        print(div + ": " + k, end=" -- ")
        exp1A_patient_fold2_processed[div][k] = {}
        # data
        merged_data = v.merge(patients2gens_mat_pd, how="inner", left_on="sample_id", right_on=patients2gens_mat_pd.index)
        print(merged_data.shape)
        # mutations2use = cl_mutations.loc[v.sample_id]
        if not (merged_data.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_data.sample_id))
        exp1A_patient_fold2_processed[div][k]["input_ids"] = merged_data

        # annovar
        merged_annovar = v.merge(patients2muts_mat_pd, how="inner", left_on="sample_id", right_on=patients2muts_mat_pd.index)
        print(merged_annovar.shape)
        if not (merged_annovar.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_annovar.sample_id))
        exp1A_patient_fold2_processed[div][k]["annovar_ids"] = merged_annovar

        # attention
        merged_attention = v.merge(patients_mask_pd, how="inner", left_on="sample_id", right_on=patients_mask_pd.index)
        print(merged_attention.shape)
        if not (merged_attention.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_attention.sample_id))
        exp1A_patient_fold2_processed[div][k]["attention_mask"] = merged_attention

train: BUPARLISIB -- (16, 259)
(16, 548)
(16, 802)
train: CISPLATIN -- (64, 259)
(64, 548)
(64, 802)
train: FLUOROURACIL -- (33, 259)
(33, 548)
(33, 802)
train: GEMCITABINE -- (37, 259)
(37, 548)
(37, 802)
train: PACLITAXEL -- (26, 259)
(26, 548)
(26, 802)
train: SORAFENIB -- (38, 259)
(38, 548)
(38, 802)
train: TEMOZOLOMIDE -- (58, 259)
(58, 548)
(58, 802)
val: BUPARLISIB -- (2, 259)
(2, 548)
(2, 802)
val: CISPLATIN -- (8, 259)
(8, 548)
(8, 802)
val: FLUOROURACIL -- (3, 259)
(3, 548)
(3, 802)
val: GEMCITABINE -- (4, 259)
(4, 548)
(4, 802)
val: PACLITAXEL -- (3, 259)
(3, 548)
(3, 802)
val: SORAFENIB -- (5, 259)
(5, 548)
(5, 802)
val: TEMOZOLOMIDE -- (8, 259)
(8, 548)
(8, 802)
test: BUPARLISIB -- (9, 259)
(9, 548)
(9, 802)
test: CISPLATIN -- (27, 259)
(27, 548)
(27, 802)
test: FLUOROURACIL -- (12, 259)
(12, 548)
(12, 802)
test: GEMCITABINE -- (14, 259)
(14, 548)
(14, 802)
test: PACLITAXEL -- (8, 259)
(8, 548)
(8, 802)
test: SORAFENIB -- (15, 259)
(15, 548)
(15, 802)
test: TEMOZOLOMIDE -

In [101]:
exp1A_patient_fold0_processed["train"]["CISPLATIN"]["input_ids"]

Unnamed: 0,sample_id,drug_name,recist,mappedProject,dataset_name,input_ids_0,input_ids_1,input_ids_2,input_ids_3,input_ids_4,...,input_ids_244,input_ids_245,input_ids_246,input_ids_247,input_ids_248,input_ids_249,input_ids_250,input_ids_251,input_ids_252,input_ids_253
0,TCGA-VS-A94Y,CISPLATIN,0,TCGA-CESC,TCGA,23,12,188,0,0,...,0,0,0,0,0,0,0,0,0,0
1,TCGA-IQ-A61G,CISPLATIN,0,TCGA-HNSC,TCGA,7,28,43,178,317,...,0,0,0,0,0,0,0,0,0,0
2,TCGA-D3-A1Q3,CISPLATIN,1,TCGA-SKCM,TCGA,131,74,44,0,0,...,0,0,0,0,0,0,0,0,0,0
3,TCGA-VS-A9UH,CISPLATIN,1,TCGA-CESC,TCGA,40,52,44,115,12,...,0,0,0,0,0,0,0,0,0,0
4,TCGA-DS-A7WF,CISPLATIN,1,TCGA-CESC,TCGA,11,46,136,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
59,TCGA-VS-A9UA,CISPLATIN,1,TCGA-CESC,TCGA,7,52,8,144,12,...,0,0,0,0,0,0,0,0,0,0
60,TCGA-LK-A4NZ,CISPLATIN,0,TCGA-MESO,TCGA,254,135,75,0,0,...,0,0,0,0,0,0,0,0,0,0
61,TCGA-QS-A5YQ,CISPLATIN,1,TCGA-UCEC,TCGA,284,85,56,108,86,...,0,0,0,0,0,0,0,0,0,0
62,TCGA-IR-A3LL,CISPLATIN,1,TCGA-CESC,TCGA,109,58,163,38,175,...,0,0,0,0,0,0,0,0,0,0


In [102]:
# Experiment 1B
# needs data (input mutation ids), annovar (annovar mutation ids), attention (attention mask), drug (fingerprint), label (RECIST)
# fold 0
exp1B_patient_fold0_processed = {}
for div in ["train", "val", "test"]:
    exp1B_patient_fold0_processed[div] = {}
    for k, v in exp1B_patient_fold0[div].items():
        print(div + ": " + str(k), end=" -- ")
        exp1B_patient_fold0_processed[div][k] = {}
        # data
        merged_data = v.merge(patients2gens_mat_pd, how="inner", left_on="sample_id", right_on=patients2gens_mat_pd.index)
        print(merged_data.shape)
        # mutations2use = cl_mutations.loc[v.sample_id]
        if not (merged_data.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_data.sample_id))
        exp1B_patient_fold0_processed[div][k]["input_ids"] = merged_data

        # annovar
        merged_annovar = v.merge(patients2muts_mat_pd, how="inner", left_on="sample_id", right_on=patients2muts_mat_pd.index)
        print(merged_annovar.shape)
        if not (merged_annovar.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_annovar.sample_id))
        exp1B_patient_fold0_processed[div][k]["annovar_ids"] = merged_annovar

        # attention
        merged_attention = v.merge(patients_mask_pd, how="inner", left_on="sample_id", right_on=patients_mask_pd.index)
        print(merged_attention.shape)
        if not (merged_attention.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_attention.sample_id))
        exp1B_patient_fold0_processed[div][k]["attention_mask"] = merged_attention

# fold 1
exp1B_patient_fold1_processed = {}
for div in ["train", "val", "test"]:
    exp1B_patient_fold1_processed[div] = {}
    for k, v in exp1B_patient_fold1[div].items():
        print(div + ": " + str(k), end=" -- ")
        exp1B_patient_fold1_processed[div][k] = {}
        # data
        merged_data = v.merge(patients2gens_mat_pd, how="inner", left_on="sample_id", right_on=patients2gens_mat_pd.index)
        print(merged_data.shape)
        # mutations2use = cl_mutations.loc[v.sample_id]
        if not (merged_data.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_data.sample_id))
        exp1B_patient_fold1_processed[div][k]["input_ids"] = merged_data

        # annovar
        merged_annovar = v.merge(patients2muts_mat_pd, how="inner", left_on="sample_id", right_on=patients2muts_mat_pd.index)
        print(merged_annovar.shape)
        if not (merged_annovar.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_annovar.sample_id))
        exp1B_patient_fold1_processed[div][k]["annovar_ids"] = merged_annovar

        # attention
        merged_attention = v.merge(patients_mask_pd, how="inner", left_on="sample_id", right_on=patients_mask_pd.index)
        print(merged_attention.shape)
        if not (merged_attention.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_attention.sample_id))
        exp1B_patient_fold1_processed[div][k]["attention_mask"] = merged_attention

# fold 2
exp1B_patient_fold2_processed = {}
for div in ["train", "val", "test"]:
    exp1B_patient_fold2_processed[div] = {}
    for k, v in exp1B_patient_fold2[div].items():
        print(div + ": " + str(k), end=" -- ")
        exp1B_patient_fold2_processed[div][k] = {}
        # data
        merged_data = v.merge(patients2gens_mat_pd, how="inner", left_on="sample_id", right_on=patients2gens_mat_pd.index)
        print(merged_data.shape)
        # mutations2use = cl_mutations.loc[v.sample_id]
        if not (merged_data.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_data.sample_id))
        exp1B_patient_fold2_processed[div][k]["input_ids"] = merged_data

        # annovar
        merged_annovar = v.merge(patients2muts_mat_pd, how="inner", left_on="sample_id", right_on=patients2muts_mat_pd.index)
        print(merged_annovar.shape)
        if not (merged_annovar.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_annovar.sample_id))
        exp1B_patient_fold2_processed[div][k]["annovar_ids"] = merged_annovar

        # attention
        merged_attention = v.merge(patients_mask_pd, how="inner", left_on="sample_id", right_on=patients_mask_pd.index)
        print(merged_attention.shape)
        if not (merged_attention.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_attention.sample_id))
        exp1B_patient_fold2_processed[div][k]["attention_mask"] = merged_attention

train: ('CISPLATIN', 'TCGA-CESC', 'TCGA') -- (29, 259)
(29, 548)
(29, 802)
train: ('CISPLATIN', 'TCGA-HNSC', 'TCGA') -- (22, 259)
(22, 548)
(22, 802)
train: ('FLUOROURACIL', 'TCGA-STAD', 'TCGA') -- (19, 259)
(19, 548)
(19, 802)
train: ('GEMCITABINE', 'TCGA-PAAD', 'TCGA') -- (23, 259)
(23, 548)
(23, 802)
train: ('PACLITAXEL', 'TCGA-BRCA', 'TCGA') -- (15, 259)
(15, 548)
(15, 802)
train: ('TEMOZOLOMIDE', 'TCGA-LGG', 'TCGA') -- (52, 259)
(52, 548)
(52, 802)
val: ('CISPLATIN', 'TCGA-CESC', 'TCGA') -- (4, 259)
(4, 548)
(4, 802)
val: ('CISPLATIN', 'TCGA-HNSC', 'TCGA') -- (3, 259)
(3, 548)
(3, 802)
val: ('FLUOROURACIL', 'TCGA-STAD', 'TCGA') -- (3, 259)
(3, 548)
(3, 802)
val: ('GEMCITABINE', 'TCGA-PAAD', 'TCGA') -- (3, 259)
(3, 548)
(3, 802)
val: ('PACLITAXEL', 'TCGA-BRCA', 'TCGA') -- (3, 259)
(3, 548)
(3, 802)
val: ('TEMOZOLOMIDE', 'TCGA-LGG', 'TCGA') -- (6, 259)
(6, 548)
(6, 802)
test: ('CISPLATIN', 'TCGA-CESC', 'TCGA') -- (15, 259)
(15, 548)
(15, 802)
test: ('CISPLATIN', 'TCGA-HNSC', 'TCGA')

In [103]:
exp1B_patient_fold1_processed["train"][("CISPLATIN", "TCGA-HNSC", "TCGA")]["annovar_ids"]

Unnamed: 0,sample_id,drug_name,recist,mappedProject,dataset_name,annovar_ids_0,annovar_ids_1,annovar_ids_2,annovar_ids_3,annovar_ids_4,...,annovar_ids_533,annovar_ids_534,annovar_ids_535,annovar_ids_536,annovar_ids_537,annovar_ids_538,annovar_ids_539,annovar_ids_540,annovar_ids_541,annovar_ids_542
0,TCGA-P3-A6T8,CISPLATIN,1,TCGA-HNSC,TCGA,0,688216,2085161,0,0,...,0,0,0,0,0,0,0,0,0,0
1,TCGA-CQ-A4CH,CISPLATIN,1,TCGA-HNSC,TCGA,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,TCGA-UF-A7JK,CISPLATIN,1,TCGA-HNSC,TCGA,0,240199,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,TCGA-BA-A6DJ,CISPLATIN,1,TCGA-HNSC,TCGA,0,788619,1665663,1007290,1692774,...,0,0,0,0,0,0,0,0,0,0
4,TCGA-BA-A4IF,CISPLATIN,1,TCGA-HNSC,TCGA,0,2084190,1516018,349648,195978,...,0,0,0,0,0,0,0,0,0,0
5,TCGA-CN-A6V6,CISPLATIN,1,TCGA-HNSC,TCGA,0,1135170,126795,620425,507638,...,0,0,0,0,0,0,0,0,0,0
6,TCGA-CN-A641,CISPLATIN,1,TCGA-HNSC,TCGA,0,347797,2084023,0,0,...,0,0,0,0,0,0,0,0,0,0
7,TCGA-D6-A6EK,CISPLATIN,1,TCGA-HNSC,TCGA,0,1306343,295771,1216011,560645,...,0,0,0,0,0,0,0,0,0,0
8,TCGA-UF-A7JJ,CISPLATIN,1,TCGA-HNSC,TCGA,0,1155311,1364464,450530,116778,...,0,0,0,0,0,0,0,0,0,0
9,TCGA-KU-A6H7,CISPLATIN,1,TCGA-HNSC,TCGA,0,1973398,1680628,389853,188448,...,0,0,0,0,0,0,0,0,0,0


In [104]:
# Experiment 2A
# fold 0
exp2A_patient_fold0_processed = {}
for div in ["train", "val", "test"]:
    exp2A_patient_fold0_processed[div] = {}
    v = exp2A_patient_fold0[div]
    print(div, end=" -- ")

    # data
    merged_data = v.merge(patients2gens_mat_pd, how="inner", left_on="sample_id", right_on=patients2gens_mat_pd.index)
    print(merged_data.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_data.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_data.sample_id))
    exp2A_patient_fold0_processed[div]["input_ids"] = merged_data

    # annovar
    merged_annovar = v.merge(patients2muts_mat_pd, how="inner", left_on="sample_id", right_on=patients2muts_mat_pd.index)
    print(merged_annovar.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_annovar.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_annovar.sample_id))
    exp2A_patient_fold0_processed[div]["annovar_ids"] = merged_annovar

    # attention
    merged_attention = v.merge(patients_mask_pd, how="inner", left_on="sample_id", right_on=patients_mask_pd.index)
    print(merged_attention.shape)
    if not (merged_attention.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_attention.sample_id))
    exp2A_patient_fold0_processed[div]["attention_mask"] = merged_attention

# fold 1
exp2A_patient_fold1_processed = {}
for div in ["train", "val", "test"]:
    exp2A_patient_fold1_processed[div] = {}
    v = exp2A_patient_fold1[div]
    print(div, end=" -- ")

    # data
    merged_data = v.merge(patients2gens_mat_pd, how="inner", left_on="sample_id", right_on=patients2gens_mat_pd.index)
    print(merged_data.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_data.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_data.sample_id))
    exp2A_patient_fold1_processed[div]["input_ids"] = merged_data

    # annovar
    merged_annovar = v.merge(patients2muts_mat_pd, how="inner", left_on="sample_id", right_on=patients2muts_mat_pd.index)
    print(merged_annovar.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_annovar.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_annovar.sample_id))
    exp2A_patient_fold1_processed[div]["annovar_ids"] = merged_annovar

    # attention
    merged_attention = v.merge(patients_mask_pd, how="inner", left_on="sample_id", right_on=patients_mask_pd.index)
    print(merged_attention.shape)
    if not (merged_attention.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_attention.sample_id))
    exp2A_patient_fold1_processed[div]["attention_mask"] = merged_attention

# fold 2
exp2A_patient_fold2_processed = {}
for div in ["train", "val", "test"]:
    exp2A_patient_fold2_processed[div] = {}
    v = exp2A_patient_fold2[div]
    print(div, end=" -- ")

    # data
    merged_data = v.merge(patients2gens_mat_pd, how="inner", left_on="sample_id", right_on=patients2gens_mat_pd.index)
    print(merged_data.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_data.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_data.sample_id))
    exp2A_patient_fold2_processed[div]["input_ids"] = merged_data

    # annovar
    merged_annovar = v.merge(patients2muts_mat_pd, how="inner", left_on="sample_id", right_on=patients2muts_mat_pd.index)
    print(merged_annovar.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_annovar.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_annovar.sample_id))
    exp2A_patient_fold2_processed[div]["annovar_ids"] = merged_annovar

    # attention
    merged_attention = v.merge(patients_mask_pd, how="inner", left_on="sample_id", right_on=patients_mask_pd.index)
    print(merged_attention.shape)
    if not (merged_attention.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_attention.sample_id))
    exp2A_patient_fold2_processed[div]["attention_mask"] = merged_attention

train -- (488, 259)
(488, 548)
(488, 802)
val -- (53, 259)
(53, 548)
(53, 802)
test -- (115, 259)
(115, 548)
(115, 802)
train -- (488, 259)
(488, 548)
(488, 802)
val -- (54, 259)
(54, 548)
(54, 802)
test -- (114, 259)
(114, 548)
(114, 802)
train -- (487, 259)
(487, 548)
(487, 802)
val -- (56, 259)
(56, 548)
(56, 802)
test -- (113, 259)
(113, 548)
(113, 802)


In [105]:
exp2A_patient_fold0_processed["train"]["attention_mask"]

Unnamed: 0,sample_id,drug_name,recist,mappedProject,dataset_name,mask_0,mask_1,mask_2,mask_3,mask_4,...,mask_787,mask_788,mask_789,mask_790,mask_791,mask_792,mask_793,mask_794,mask_795,mask_796
0,TCGA-DB-A64P,TEMOZOLOMIDE,0,TCGA-LGG,TCGA,1,1,1,0,0,...,0,0,0,0,0,0,0,0,0,0
1,TCGA-S9-A89V,TEMOZOLOMIDE,0,TCGA-LGG,TCGA,1,1,1,0,0,...,0,0,0,0,0,0,0,0,0,0
2,P-0001324-T01-IM3,SORAFENIB,0,TCGA-LIHC,CBIO_hcc_mskimpact_2018,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
3,TCGA-S9-A6U8,CARMUSTINE,0,TCGA-LGG,TCGA,1,1,1,0,0,...,0,0,0,0,0,0,0,0,0,0
4,TCGA-CN-4731,CETUXIMAB,0,TCGA-HNSC,TCGA,1,1,1,1,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
483,s_DS_bkm_008_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019,1,1,1,1,0,...,0,0,0,0,0,0,0,0,0,0
484,TCGA-GN-A8LK,CARBOPLATIN,0,TCGA-SKCM,TCGA,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
485,TCGA-VS-A8EJ,CISPLATIN,0,TCGA-CESC,TCGA,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
486,P-0002719-T01-IM3,SORAFENIB,0,TCGA-LIHC,CBIO_hcc_mskimpact_2018,1,1,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [106]:
# Experiment 2B
# needs data (input mutation ids), annovar (annovar mutation ids), attention (attention mask), drug (fingerprint), label (RECIST)
# fold 0
exp2B_patient_fold0_processed = {}
for div in ["train", "val", "test"]:
    exp2B_patient_fold0_processed[div] = {}
    for k, v in exp2B_patient_fold0[div].items():
        print(div + ": " + str(k), end=" -- ")
        exp2B_patient_fold0_processed[div][k] = {}
        # data
        merged_data = v.merge(patients2gens_mat_pd, how="inner", left_on="sample_id", right_on=patients2gens_mat_pd.index)
        print(merged_data.shape)
        # mutations2use = cl_mutations.loc[v.sample_id]
        if not (merged_data.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_data.sample_id))
        exp2B_patient_fold0_processed[div][k]["input_ids"] = merged_data

        # annovar
        merged_annovar = v.merge(patients2muts_mat_pd, how="inner", left_on="sample_id", right_on=patients2muts_mat_pd.index)
        print(merged_annovar.shape)
        if not (merged_annovar.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_annovar.sample_id))
        exp2B_patient_fold0_processed[div][k]["annovar_ids"] = merged_annovar

        # attention
        merged_attention = v.merge(patients_mask_pd, how="inner", left_on="sample_id", right_on=patients_mask_pd.index)
        print(merged_attention.shape)
        if not (merged_attention.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_attention.sample_id))
        exp2B_patient_fold0_processed[div][k]["attention_mask"] = merged_attention

# fold 1
exp2B_patient_fold1_processed = {}
for div in ["train", "val", "test"]:
    exp2B_patient_fold1_processed[div] = {}
    for k, v in exp2B_patient_fold1[div].items():
        print(div + ": " + str(k), end=" -- ")
        exp2B_patient_fold1_processed[div][k] = {}
        # data
        merged_data = v.merge(patients2gens_mat_pd, how="inner", left_on="sample_id", right_on=patients2gens_mat_pd.index)
        print(merged_data.shape)
        # mutations2use = cl_mutations.loc[v.sample_id]
        if not (merged_data.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_data.sample_id))
        exp2B_patient_fold1_processed[div][k]["input_ids"] = merged_data

        # annovar
        merged_annovar = v.merge(patients2muts_mat_pd, how="inner", left_on="sample_id", right_on=patients2muts_mat_pd.index)
        print(merged_annovar.shape)
        if not (merged_annovar.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_annovar.sample_id))
        exp2B_patient_fold1_processed[div][k]["annovar_ids"] = merged_annovar

        # attention
        merged_attention = v.merge(patients_mask_pd, how="inner", left_on="sample_id", right_on=patients_mask_pd.index)
        print(merged_attention.shape)
        if not (merged_attention.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_attention.sample_id))
        exp2B_patient_fold1_processed[div][k]["attention_mask"] = merged_attention

# fold 2
exp2B_patient_fold2_processed = {}
for div in ["train", "val", "test"]:
    exp2B_patient_fold2_processed[div] = {}
    for k, v in exp2B_patient_fold2[div].items():
        print(div + ": " + str(k), end=" -- ")
        exp2B_patient_fold2_processed[div][k] = {}
        # data
        merged_data = v.merge(patients2gens_mat_pd, how="inner", left_on="sample_id", right_on=patients2gens_mat_pd.index)
        print(merged_data.shape)
        # mutations2use = cl_mutations.loc[v.sample_id]
        if not (merged_data.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_data.sample_id))
        exp2B_patient_fold2_processed[div][k]["input_ids"] = merged_data

        # annovar
        merged_annovar = v.merge(patients2muts_mat_pd, how="inner", left_on="sample_id", right_on=patients2muts_mat_pd.index)
        print(merged_annovar.shape)
        if not (merged_annovar.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_annovar.sample_id))
        exp2B_patient_fold2_processed[div][k]["annovar_ids"] = merged_annovar

        # attention
        merged_attention = v.merge(patients_mask_pd, how="inner", left_on="sample_id", right_on=patients_mask_pd.index)
        print(merged_attention.shape)
        if not (merged_attention.shape[0] == v.shape[0]):
            print("Missing some mutations!")
            print(set(v.sample_id) - set(merged_attention.sample_id))
        exp2B_patient_fold2_processed[div][k]["attention_mask"] = merged_attention

train: TCGA-BRCA -- (74, 259)
(74, 548)
(74, 802)
train: TCGA-CESC -- (33, 259)
(33, 548)
(33, 802)
train: TCGA-HNSC -- (39, 259)
(39, 548)
(39, 802)
train: TCGA-STAD -- (38, 259)
(38, 548)
(38, 802)
train: TCGA-PAAD -- (32, 259)
(32, 548)
(32, 802)
train: TCGA-LGG -- (60, 259)
(60, 548)
(60, 802)
val: TCGA-BRCA -- (8, 259)
(8, 548)
(8, 802)
val: TCGA-CESC -- (4, 259)
(4, 548)
(4, 802)
val: TCGA-HNSC -- (5, 259)
(5, 548)
(5, 802)
val: TCGA-STAD -- (5, 259)
(5, 548)
(5, 802)
val: TCGA-PAAD -- (4, 259)
(4, 548)
(4, 802)
val: TCGA-LGG -- (7, 259)
(7, 548)
(7, 802)
test: TCGA-BRCA -- (17, 259)
(17, 548)
(17, 802)
test: TCGA-CESC -- (15, 259)
(15, 548)
(15, 802)
test: TCGA-HNSC -- (12, 259)
(12, 548)
(12, 802)
test: TCGA-STAD -- (12, 259)
(12, 548)
(12, 802)
test: TCGA-PAAD -- (14, 259)
(14, 548)
(14, 802)
test: TCGA-LGG -- (30, 259)
(30, 548)
(30, 802)
train: TCGA-BRCA -- (72, 259)
(72, 548)
(72, 802)
train: TCGA-CESC -- (31, 259)
(31, 548)
(31, 802)
train: TCGA-HNSC -- (37, 259)
(37, 548)

### Save files

In [107]:
save_dir = "/data/ajayago/papers_data/systematic_assessment/input_types/transformer_inputs/"
save_dir_expt1A_dir = save_dir + "Experiment1/SettingA/"
save_dir_expt1B_dir = save_dir + "Experiment1/SettingB/"
save_dir_expt2A_dir = save_dir + "Experiment2/SettingA/"
save_dir_expt2B_dir = save_dir + "Experiment2/SettingB/"

In [108]:
# Experiment 1A
with open(f"{save_dir_expt1A_dir}/cell_lines_fold0_processed.pkl", "wb") as f:
    pickle.dump(exp1A_cl_fold0_processed, f)

with open(f"{save_dir_expt1A_dir}/patients_fold0_processed.pkl", "wb") as f:
    pickle.dump(exp1A_patient_fold0_processed, f)
    
with open(f"{save_dir_expt1A_dir}/patients_fold1_processed.pkl", "wb") as f:
    pickle.dump(exp1A_patient_fold1_processed, f)
    
with open(f"{save_dir_expt1A_dir}/patients_fold2_processed.pkl", "wb") as f:
    pickle.dump(exp1A_patient_fold2_processed, f)

In [109]:
# Experiment 1B
with open(f"{save_dir_expt1B_dir}/cell_lines_fold0_processed.pkl", "wb") as f:
    pickle.dump(exp1B_cl_fold0_processed, f)

with open(f"{save_dir_expt1B_dir}/patients_fold0_processed.pkl", "wb") as f:
    pickle.dump(exp1B_patient_fold0_processed, f)
    
with open(f"{save_dir_expt1B_dir}/patients_fold1_processed.pkl", "wb") as f:
    pickle.dump(exp1B_patient_fold1_processed, f)
    
with open(f"{save_dir_expt1B_dir}/patients_fold2_processed.pkl", "wb") as f:
    pickle.dump(exp1B_patient_fold2_processed, f)

In [110]:
# Experiment 2A
with open(f"{save_dir_expt2A_dir}/cell_lines_fold0_processed.pkl", "wb") as f:
    pickle.dump(exp2A_cl_fold0_processed, f)

with open(f"{save_dir_expt2A_dir}/patients_fold0_processed.pkl", "wb") as f:
    pickle.dump(exp2A_patient_fold0_processed, f)
    
with open(f"{save_dir_expt2A_dir}/patients_fold1_processed.pkl", "wb") as f:
    pickle.dump(exp2A_patient_fold1_processed, f)
    
with open(f"{save_dir_expt2A_dir}/patients_fold2_processed.pkl", "wb") as f:
    pickle.dump(exp2A_patient_fold2_processed, f)

In [111]:
# Experiment 2B
with open(f"{save_dir_expt2B_dir}/cell_lines_fold0_processed.pkl", "wb") as f:
    pickle.dump(exp2B_cl_fold0_processed, f)

with open(f"{save_dir_expt2B_dir}/patients_fold0_processed.pkl", "wb") as f:
    pickle.dump(exp2B_patient_fold0_processed, f)
    
with open(f"{save_dir_expt2B_dir}/patients_fold1_processed.pkl", "wb") as f:
    pickle.dump(exp2B_patient_fold1_processed, f)
    
with open(f"{save_dir_expt2B_dir}/patients_fold2_processed.pkl", "wb") as f:
    pickle.dump(exp2B_patient_fold2_processed, f)

### Survival datasets

In [112]:
### Open files like experiment 2A setting but uniform over all experiments in PREDICT-AI
with open(f"{data_splits_dir}/survival_splits/patients_fold0_survival.pkl", "rb") as f:
    survival_setting_split0 = pickle.load(f)
with open(f"{data_splits_dir}/survival_splits/patients_fold1_survival.pkl", "rb") as f:
    survival_setting_split1 = pickle.load(f)
with open(f"{data_splits_dir}/survival_splits/patients_fold2_survival.pkl", "rb") as f:
    survival_setting_split2 = pickle.load(f)

In [113]:
survival_setting_split0["train"]

Unnamed: 0,sample_id,REGIMEN_NUMBER,drug_name,drug_start_date,drug_end_date,pfs_status,pfs_days,category,dataset_name,mappedProject
128,GENIE-UHN-019873-ARC1,1.0,GEFITINIB,301.0,637.0,1.0,347.0,1.0,GENIE_NSCLC,TCGA-LUAD
58,GENIE-MSK-P-0018585-T01-IM6,3.0,MITOMYCIN,606.0,606.0,1.0,188.0,1.0,GENIE_CRC,TCGA-COAD
137,GENIE-VICC-199259-unk-1,1.0,CRIZOTINIB,43.0,210.0,0.0,167.0,1.0,GENIE_NSCLC,TCGA-LUAD
95,GENIE-MSK-P-0003854-T01-IM5,3.0,DOCETAXEL,1115.0,1115.0,0.0,5.0,1.0,GENIE_NSCLC,TCGA-LUAD
22,GENIE-DFCI-107986-481743,1.0,CAPECITABINE,69.0,95.0,1.0,603.0,1.0,GENIE_CRC,TCGA-COAD
...,...,...,...,...,...,...,...,...,...,...
92,GENIE-MSK-P-0009660-T01-IM5,1.0,BEVACIZUMAB,68.0,68.0,1.0,84.0,1.0,GENIE_NSCLC,TCGA-LUAD
106,GENIE-MSK-P-0015434-T01-IM6,1.0,CRIZOTINIB,22.0,92.0,1.0,63.0,1.0,GENIE_NSCLC,TCGA-LUAD
24,GENIE-DFCI-111555-617724,2.0,BEVACIZUMAB,343.0,343.0,1.0,49.0,1.0,GENIE_CRC,TCGA-COAD
139,GENIE-VICC-647513-unk-1,5.0,OSIMERTINIB,3895.0,4573.0,1.0,1000.0,1.0,GENIE_NSCLC,TCGA-LUAD


In [114]:
# needs data, annovar, attention, pfs times, event status, drug, label
# Survival
# fold 0
survival_patient_fold0_processed = {}
for div in ["train", "val", "test"]:
    survival_patient_fold0_processed[div] = {}
    v = survival_setting_split0[div]
    print(div, end=" -- ")

    # data
    merged_data = v.merge(survival_data_input_ids_pd, how="inner", left_on="sample_id", right_on=survival_data_input_ids_pd.index)
    print(merged_data.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_data.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_data.sample_id))
    survival_patient_fold0_processed[div]["input_ids"] = merged_data

    # annovar
    merged_annovar = v.merge(survival_data_annovar_ids_pd, how="inner", left_on="sample_id", right_on=survival_data_annovar_ids_pd.index)
    print(merged_annovar.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_annovar.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_annovar.sample_id))
    survival_patient_fold0_processed[div]["annovar_ids"] = merged_annovar

    # attention
    merged_attention = v.merge(survival_data_attention_mask_pd, how="inner", left_on="sample_id", right_on=survival_data_attention_mask_pd.index)
    print(merged_attention.shape)
    if not (merged_attention.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_attention.sample_id))
    survival_patient_fold0_processed[div]["attention_mask"] = merged_attention

# fold 1
survival_patient_fold1_processed = {}
for div in ["train", "val", "test"]:
    survival_patient_fold1_processed[div] = {}
    v = survival_setting_split1[div]
    print(div, end=" -- ")

    # data
    merged_data = v.merge(survival_data_input_ids_pd, how="inner", left_on="sample_id", right_on=survival_data_input_ids_pd.index)
    print(merged_data.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_data.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_data.sample_id))
    survival_patient_fold1_processed[div]["input_ids"] = merged_data

    # annovar
    merged_annovar = v.merge(survival_data_annovar_ids_pd, how="inner", left_on="sample_id", right_on=survival_data_annovar_ids_pd.index)
    print(merged_annovar.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_annovar.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_annovar.sample_id))
    survival_patient_fold1_processed[div]["annovar_ids"] = merged_annovar

    # attention
    merged_attention = v.merge(survival_data_attention_mask_pd, how="inner", left_on="sample_id", right_on=survival_data_attention_mask_pd.index)
    print(merged_attention.shape)
    if not (merged_attention.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_attention.sample_id))
    survival_patient_fold1_processed[div]["attention_mask"] = merged_attention

# fold 2
survival_patient_fold2_processed = {}
for div in ["train", "val", "test"]:
    survival_patient_fold2_processed[div] = {}
    v = survival_setting_split2[div]
    print(div, end=" -- ")

    # data
    merged_data = v.merge(survival_data_input_ids_pd, how="inner", left_on="sample_id", right_on=survival_data_input_ids_pd.index)
    print(merged_data.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_data.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_data.sample_id))
    survival_patient_fold2_processed[div]["input_ids"] = merged_data

    # annovar
    merged_annovar = v.merge(survival_data_annovar_ids_pd, how="inner", left_on="sample_id", right_on=survival_data_annovar_ids_pd.index)
    print(merged_annovar.shape)
    # mutations2use = cl_mutations.loc[v.sample_id]
    if not (merged_annovar.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_annovar.sample_id))
    survival_patient_fold2_processed[div]["annovar_ids"] = merged_annovar

    # attention
    merged_attention = v.merge(survival_data_attention_mask_pd, how="inner", left_on="sample_id", right_on=survival_data_attention_mask_pd.index)
    print(merged_attention.shape)
    if not (merged_attention.shape[0] == v.shape[0]):
        print("Missing some mutations!")
        print(set(v.sample_id) - set(merged_attention.sample_id))
    survival_patient_fold2_processed[div]["attention_mask"] = merged_attention

train -- (84, 1001)
(84, 1001)
(84, 1001)
val -- (10, 1001)
(10, 1001)
(10, 1001)
test -- (48, 1001)
(48, 1001)
(48, 1001)
train -- (85, 1001)
(85, 1001)
(85, 1001)
val -- (10, 1001)
(10, 1001)
(10, 1001)
test -- (47, 1001)
(47, 1001)
(47, 1001)
train -- (85, 1001)
(85, 1001)
(85, 1001)
val -- (10, 1001)
(10, 1001)
(10, 1001)
test -- (47, 1001)
(47, 1001)
(47, 1001)


In [115]:
survival_patient_fold0_processed["train"]["input_ids"]

Unnamed: 0,sample_id,REGIMEN_NUMBER,drug_name,drug_start_date,drug_end_date,pfs_status,pfs_days,category,dataset_name,mappedProject,...,input_ids_981,input_ids_982,input_ids_983,input_ids_984,input_ids_985,input_ids_986,input_ids_987,input_ids_988,input_ids_989,input_ids_990
0,GENIE-UHN-019873-ARC1,1.0,GEFITINIB,301.0,637.0,1.0,347.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,1,1,1,1,1,1,1,1,1,1
1,GENIE-MSK-P-0018585-T01-IM6,3.0,MITOMYCIN,606.0,606.0,1.0,188.0,1.0,GENIE_CRC,TCGA-COAD,...,1,1,1,1,1,1,1,1,1,1
2,GENIE-VICC-199259-unk-1,1.0,CRIZOTINIB,43.0,210.0,0.0,167.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,1,1,1,1,1,1,1,1,1,1
3,GENIE-MSK-P-0003854-T01-IM5,3.0,DOCETAXEL,1115.0,1115.0,0.0,5.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,1,1,1,1,1,1,1,1,1,1
4,GENIE-DFCI-107986-481743,1.0,CAPECITABINE,69.0,95.0,1.0,603.0,1.0,GENIE_CRC,TCGA-COAD,...,1,1,1,1,1,1,1,1,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
79,GENIE-MSK-P-0009660-T01-IM5,1.0,BEVACIZUMAB,68.0,68.0,1.0,84.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,1,1,1,1,1,1,1,1,1,1
80,GENIE-MSK-P-0015434-T01-IM6,1.0,CRIZOTINIB,22.0,92.0,1.0,63.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,1,1,1,1,1,1,1,1,1,1
81,GENIE-DFCI-111555-617724,2.0,BEVACIZUMAB,343.0,343.0,1.0,49.0,1.0,GENIE_CRC,TCGA-COAD,...,1,1,1,1,1,1,1,1,1,1
82,GENIE-VICC-647513-unk-1,5.0,OSIMERTINIB,3895.0,4573.0,1.0,1000.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,1,1,1,1,1,1,1,1,1,1


In [116]:
survival_patient_fold0_processed["train"]["annovar_ids"]

Unnamed: 0,sample_id,REGIMEN_NUMBER,drug_name,drug_start_date,drug_end_date,pfs_status,pfs_days,category,dataset_name,mappedProject,...,annovar_ids_981,annovar_ids_982,annovar_ids_983,annovar_ids_984,annovar_ids_985,annovar_ids_986,annovar_ids_987,annovar_ids_988,annovar_ids_989,annovar_ids_990
0,GENIE-UHN-019873-ARC1,1.0,GEFITINIB,301.0,637.0,1.0,347.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
1,GENIE-MSK-P-0018585-T01-IM6,3.0,MITOMYCIN,606.0,606.0,1.0,188.0,1.0,GENIE_CRC,TCGA-COAD,...,0,0,0,0,0,0,0,0,0,0
2,GENIE-VICC-199259-unk-1,1.0,CRIZOTINIB,43.0,210.0,0.0,167.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
3,GENIE-MSK-P-0003854-T01-IM5,3.0,DOCETAXEL,1115.0,1115.0,0.0,5.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
4,GENIE-DFCI-107986-481743,1.0,CAPECITABINE,69.0,95.0,1.0,603.0,1.0,GENIE_CRC,TCGA-COAD,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
79,GENIE-MSK-P-0009660-T01-IM5,1.0,BEVACIZUMAB,68.0,68.0,1.0,84.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
80,GENIE-MSK-P-0015434-T01-IM6,1.0,CRIZOTINIB,22.0,92.0,1.0,63.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
81,GENIE-DFCI-111555-617724,2.0,BEVACIZUMAB,343.0,343.0,1.0,49.0,1.0,GENIE_CRC,TCGA-COAD,...,0,0,0,0,0,0,0,0,0,0
82,GENIE-VICC-647513-unk-1,5.0,OSIMERTINIB,3895.0,4573.0,1.0,1000.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0


In [117]:
drug_fp = pd.read_csv("/data/ajayago/papers_data/systematic_assessment/raw/metadata/drug_morgan_fingerprints.csv")
drug_fp

Unnamed: 0,drug_name,0,1,2,3,4,5,6,7,8,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
0,JW-7-24-1,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,KIN001-260,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,NSC-87877,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
3,GNE-317,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,NAVITOCLAX,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
474,SB590885,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
475,STAUROSPORINE,0,0,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
476,TW 37,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
477,ULIXERTINIB,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [118]:
survival_patient_fold0_processed["train"]["attention_mask"]

Unnamed: 0,sample_id,REGIMEN_NUMBER,drug_name,drug_start_date,drug_end_date,pfs_status,pfs_days,category,dataset_name,mappedProject,...,mask_981,mask_982,mask_983,mask_984,mask_985,mask_986,mask_987,mask_988,mask_989,mask_990
0,GENIE-UHN-019873-ARC1,1.0,GEFITINIB,301.0,637.0,1.0,347.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
1,GENIE-MSK-P-0018585-T01-IM6,3.0,MITOMYCIN,606.0,606.0,1.0,188.0,1.0,GENIE_CRC,TCGA-COAD,...,0,0,0,0,0,0,0,0,0,0
2,GENIE-VICC-199259-unk-1,1.0,CRIZOTINIB,43.0,210.0,0.0,167.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
3,GENIE-MSK-P-0003854-T01-IM5,3.0,DOCETAXEL,1115.0,1115.0,0.0,5.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
4,GENIE-DFCI-107986-481743,1.0,CAPECITABINE,69.0,95.0,1.0,603.0,1.0,GENIE_CRC,TCGA-COAD,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
79,GENIE-MSK-P-0009660-T01-IM5,1.0,BEVACIZUMAB,68.0,68.0,1.0,84.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
80,GENIE-MSK-P-0015434-T01-IM6,1.0,CRIZOTINIB,22.0,92.0,1.0,63.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
81,GENIE-DFCI-111555-617724,2.0,BEVACIZUMAB,343.0,343.0,1.0,49.0,1.0,GENIE_CRC,TCGA-COAD,...,0,0,0,0,0,0,0,0,0,0
82,GENIE-VICC-647513-unk-1,5.0,OSIMERTINIB,3895.0,4573.0,1.0,1000.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0


In [119]:
survival_patient_fold0_processed["train"]["attention_mask"].merge(drug_fp, left_on="drug_name", right_on="drug_name")

Unnamed: 0,sample_id,REGIMEN_NUMBER,drug_name,drug_start_date,drug_end_date,pfs_status,pfs_days,category,dataset_name,mappedProject,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
0,GENIE-UHN-019873-ARC1,1.0,GEFITINIB,301.0,637.0,1.0,347.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
1,GENIE-MSK-P-0018585-T01-IM6,3.0,MITOMYCIN,606.0,606.0,1.0,188.0,1.0,GENIE_CRC,TCGA-COAD,...,0,0,0,0,0,0,0,0,0,0
2,GENIE-VICC-199259-unk-1,1.0,CRIZOTINIB,43.0,210.0,0.0,167.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
3,GENIE-MSK-P-0003854-T01-IM5,3.0,DOCETAXEL,1115.0,1115.0,0.0,5.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
4,GENIE-DFCI-107986-481743,1.0,CAPECITABINE,69.0,95.0,1.0,603.0,1.0,GENIE_CRC,TCGA-COAD,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
73,GENIE-UHN-982281-ARC1,1.0,GEFITINIB,918.0,1518.0,1.0,231.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
74,GENIE-MSK-P-0010041-T01-IM5,1.0,CRIZOTINIB,50.0,184.0,0.0,135.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
75,GENIE-MSK-P-0015434-T01-IM6,1.0,CRIZOTINIB,22.0,92.0,1.0,63.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0
76,GENIE-VICC-647513-unk-1,5.0,OSIMERTINIB,3895.0,4573.0,1.0,1000.0,1.0,GENIE_NSCLC,TCGA-LUAD,...,0,0,0,0,0,0,0,0,0,0


In [120]:
save_dir

'/data/ajayago/papers_data/systematic_assessment/input_types/transformer_inputs/'

In [121]:
# survival data save
with open(f"{save_dir}/survival_processed/survival_fold0_processed.pkl", "wb") as f:
    pickle.dump(survival_patient_fold0_processed, f)

with open(f"{save_dir}/survival_processed/survival_fold1_processed.pkl", "wb") as f:
    pickle.dump(survival_patient_fold1_processed, f)

with open(f"{save_dir}/survival_processed/survival_fold2_processed.pkl", "wb") as f:
    pickle.dump(survival_patient_fold2_processed, f)