This notebook embeds the transformer inputs of patients and cell lines with the transformer encoder trained in PREDICT-AI (in the downstream drug response prediction, DRP, setting) and produces fixed-length inputs that are consumed by the diffDRP model.

In [1]:
import pandas as pd
import numpy as np
import pickle

In [2]:
import torch.nn as nn
import torch
from torch.utils.data import TensorDataset, DataLoader

In [3]:
# All inference below runs on the first CUDA device.
device = torch.device("cuda", 0)

In [4]:
# Directory holding the per-fold PREDICT-AI checkpoints (trailing slash is
# relied on by the string concatenation in the torch.load cells).
predict_ai_trained_folder = (
    "/data//papers_data/systematic_assessment/run_files/"
    "PREDICT-AI/saved_model_annotated_mutations/"
)

In [10]:
class TransformerEmbedder(nn.Module):  # Based on PREDICT-AI
    """Patient branch of PREDICT-AI, used to turn token inputs into fixed-length vectors.

    Mutation token ids are embedded with ``embs_gene``; annotation ids (the
    notebook's ``annovar_ids`` columns) select rows of the vocabulary table read
    from ``vocab_path`` and are projected to ``hidden_dim`` by ``fc``. Both token
    sequences are concatenated along the token axis and run through
    ``transformer_encoder``. The encoder output is then averaged over the hidden
    dimension, so every sample yields one value per token position, i.e. a vector
    of length ``n_mut + n_anno`` (254 + 543 = 797 in this notebook).

    Args:
        vocab_path: CSV with an index column followed by ``annotation_emb_dim``
            (23) numeric columns, one row per vocabulary entry.
        device: Device that mutation ids, the padding mask and the projected
            annotation features are moved to in ``forward``. It should match the
            device the module itself is moved to with ``.to(...)``.
    """

    DEFAULT_VOCAB_PATH = "/data//papers_data/systematic_assessment/processed/vocab_predict_ai_ccle_cbio_icgc_moores_tcga_genie_nci60_nuh_union.csv"

    def __init__(self, vocab_path=DEFAULT_VOCAB_PATH, device="cuda:0"):
        super().__init__()
        self.config = {"hidden_dim": 64, "annotation_emb_dim": 23}
        vocab_df = pd.read_csv(vocab_path, index_col=0)
        # Plain attribute, not a registered buffer: it is kept out of state_dict
        # (so the PREDICT-AI checkpoint loads strictly) and is not moved by
        # Module.to(); rows are gathered where it lives and only the gathered
        # slice is sent to self.device.
        self.annotation_tensor = torch.tensor(vocab_df.values, dtype=torch.float32)
        self.config["n_vocab"] = vocab_df.shape[0]
        self.device = torch.device(device)
        # pretrained gene embedding from survival prediction model
        self.embs_gene = nn.Embedding(
            num_embeddings=self.config["n_vocab"],
            embedding_dim=self.config["hidden_dim"],
            padding_idx=1,
        )
        self.fc = nn.Linear(
            in_features=self.config["annotation_emb_dim"],
            out_features=self.config["hidden_dim"],
        )
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.config["hidden_dim"],
                nhead=4,
                dropout=0.1,
                batch_first=True,
            ),
            num_layers=8,
        )

    def patient_predictor(self, patient_mut_input, patient_anno_input, patient_mut_input_mask):
        """Embed one batch and return a ``(batch, n_mut + n_anno)`` tensor.

        Args:
            patient_mut_input: integer token ids, shape ``(batch, n_mut)``.
            patient_anno_input: integer row indices into ``annotation_tensor``,
                shape ``(batch, n_anno)``; may live on any device.
            patient_mut_input_mask: boolean ``src_key_padding_mask`` of shape
                ``(batch, n_mut + n_anno)``; ``True`` marks positions the encoder
                ignores. NOTE(review): the notebook passes the upstream
                ``attention_mask`` columns unchanged - confirm they use this
                convention (True = padding), the opposite of the usual
                HuggingFace meaning, and match how PREDICT-AI was trained.
        """
        mut_ids = patient_mut_input.to(self.device)
        padding_mask = patient_mut_input_mask.to(self.device)
        mut_emb = self.embs_gene(mut_ids)  # (batch, n_mut, hidden_dim)
        # Gather on the lookup table's own device so GPU-resident ids also work.
        anno_rows = self.annotation_tensor[patient_anno_input.to(self.annotation_tensor.device)]
        anno_emb = self.fc(anno_rows.to(self.device))  # (batch, n_anno, hidden_dim)
        tokens = torch.cat((mut_emb, anno_emb), dim=1)  # (batch, n_mut + n_anno, hidden_dim)
        encoded = self.transformer_encoder(tokens, src_key_padding_mask=padding_mask)
        # Average over the hidden dimension (not the token axis): this is what
        # gives one feature per token position and hence a fixed-length output.
        return encoded.mean(dim=2)

    def forward(self, patient_mut_input, patient_anno_input, patient_mut_input_mask):
        """Alias for :meth:`patient_predictor`."""
        return self.patient_predictor(patient_mut_input, patient_anno_input, patient_mut_input_mask)

In [7]:
# Column names selected from the three upstream token tables.
features2select_inputs_ids = list(map("input_ids_{}".format, range(254)))
features2select_annovar_ids = list(map("annovar_ids_{}".format, range(543)))
# The padding mask covers the concatenated mutation + annotation tokens,
# so its length is 254 + 543 = 797.
features2select_mask_ids = list(map("mask_{}".format, range(254 + 543)))

In [8]:
# Input directory: processed (tokenised) transformer inputs for Experiment 2, Setting A.
exp2A_dir = (
    "/data//papers_data/systematic_assessment/input_types/"
    "transformer_inputs/Experiment2/SettingA/"
)

In [9]:
# Experiment 2A labelled data: cell lines (fold 0) and patients (folds 0-2).
# The conversion cells below read all four of these names, so this cell must run
# for the notebook to execute top to bottom (it was previously commented out).
with open(f"{exp2A_dir}/cell_lines_fold0_processed.pkl", "rb") as f:
    exp2A_cl_fold0_processed = pickle.load(f)

with open(f"{exp2A_dir}/patients_fold0_processed.pkl", "rb") as f:
    exp2A_patient_fold0_processed = pickle.load(f)

with open(f"{exp2A_dir}/patients_fold1_processed.pkl", "rb") as f:
    exp2A_patient_fold1_processed = pickle.load(f)

with open(f"{exp2A_dir}/patients_fold2_processed.pkl", "rb") as f:
    exp2A_patient_fold2_processed = pickle.load(f)

In [10]:
# Unlabelled patients for folds 0-2 (only the "train" split is used below).
def _read_pickle(path):
    """Return the object stored in the pickle file at *path*."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


exp2A_patient_fold0_unlabelled_processed = _read_pickle(f"{exp2A_dir}/patients_fold0_processed_unlabelled.pkl")
exp2A_patient_fold1_unlabelled_processed = _read_pickle(f"{exp2A_dir}/patients_fold1_processed_unlabelled.pkl")
exp2A_patient_fold2_unlabelled_processed = _read_pickle(f"{exp2A_dir}/patients_fold2_processed_unlabelled.pkl")

#### Experiment 2A

**Fold 0 Patients**

In [5]:
# load pretrained PREDICT-AI (fold 0). map_location="cpu" keeps this long-lived
# checkpoint dict off the GPU; load_state_dict later copies the weights onto the
# model's device.
predict_ai_fold0 = torch.load(
    predict_ai_trained_folder + "best_pretrained_validation_CI_val_corr_2A_ALL_fold0.pth",
    map_location="cpu",
)

In [8]:
# Parameter names stored in the fold-0 checkpoint's model state dict.
fold0_checkpoint_state = predict_ai_fold0["model"]
fold0_checkpoint_state.keys()

odict_keys(['embs_gene.weight', 'embs_drug.0.weight', 'embs_drug.0.bias', 'embs_drug.1.weight', 'embs_drug.1.bias', 'embs_drug.1.running_mean', 'embs_drug.1.running_var', 'embs_drug.1.num_batches_tracked', 'embs_drug.3.weight', 'embs_drug.3.bias', 'embs_drug.4.weight', 'embs_drug.4.bias', 'embs_drug.4.running_mean', 'embs_drug.4.running_var', 'embs_drug.4.num_batches_tracked', 'embs_drug.6.weight', 'embs_drug.6.bias', 'fc.weight', 'fc.bias', 'transformer_encoder.layers.0.self_attn.in_proj_weight', 'transformer_encoder.layers.0.self_attn.in_proj_bias', 'transformer_encoder.layers.0.self_attn.out_proj.weight', 'transformer_encoder.layers.0.self_attn.out_proj.bias', 'transformer_encoder.layers.0.linear1.weight', 'transformer_encoder.layers.0.linear1.bias', 'transformer_encoder.layers.0.linear2.weight', 'transformer_encoder.layers.0.linear2.bias', 'transformer_encoder.layers.0.norm1.weight', 'transformer_encoder.layers.0.norm1.bias', 'transformer_encoder.layers.0.norm2.weight', 'transforme

In [11]:
# Fresh encoder for fold 0; weights are filled in from the checkpoint next.
trf_model0 = TransformerEmbedder()
trf_model0 = trf_model0.to(device)

In [13]:
# Load the encoder weights from the PREDICT-AI patient predictor (fold 0).
# The checkpoint also holds keys this model lacks (e.g. the drug branch,
# embs_drug.*), so keep only the keys trf_model0 defines.
trf_model0_dict = trf_model0.state_dict()
pretrained_dict = {k: v for k, v in predict_ai_fold0["model"].items() if k in trf_model0_dict}
# Strict loading raises if any of trf_model0's parameters is missing from the
# checkpoint, so success guarantees nothing stays randomly initialised. (The
# former "update the model dict" step was dead code: its result was never loaded.)
trf_model0.load_state_dict(pretrained_dict, strict=True)

<All keys matched successfully>

In [14]:
# Freeze every weight: the encoder is only used for inference.
for weight in trf_model0.parameters():
    weight.requires_grad_(False)

In [15]:
# Inference mode (disables dropout); eval() is train(False).
trf_model0.train(False)

TransformerEmbedder(
  (embs_gene): Embedding(2324534, 64, padding_idx=1)
  (fc): Linear(in_features=23, out_features=64, bias=True)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
)

In [21]:
# Fold-0 training inputs: metadata columns followed by input_ids_0..253.
fold0_train_input_ids = exp2A_patient_fold0_processed["train"]["input_ids"]
fold0_train_input_ids

Unnamed: 0,sample_id,drug_name,recist,mappedProject,dataset_name,input_ids_0,input_ids_1,input_ids_2,input_ids_3,input_ids_4,...,input_ids_244,input_ids_245,input_ids_246,input_ids_247,input_ids_248,input_ids_249,input_ids_250,input_ids_251,input_ids_252,input_ids_253
0,TCGA-DB-A64P,TEMOZOLOMIDE,0,TCGA-LGG,TCGA,71,252,176,0,0,...,0,0,0,0,0,0,0,0,0,0
1,TCGA-S9-A89V,TEMOZOLOMIDE,0,TCGA-LGG,TCGA,22,66,21,0,0,...,0,0,0,0,0,0,0,0,0,0
2,P-0001324-T01-IM3,SORAFENIB,0,TCGA-LIHC,CBIO_hcc_mskimpact_2018,206,9,68,210,249,...,0,0,0,0,0,0,0,0,0,0
3,TCGA-S9-A6U8,CARMUSTINE,0,TCGA-LGG,TCGA,22,7,176,0,0,...,0,0,0,0,0,0,0,0,0,0
4,TCGA-CN-4731,CETUXIMAB,0,TCGA-HNSC,TCGA,278,141,113,7,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
483,s_DS_bkm_008_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019,68,11,82,7,0,...,0,0,0,0,0,0,0,0,0,0
484,TCGA-GN-A8LK,CARBOPLATIN,0,TCGA-SKCM,TCGA,122,303,49,7,115,...,0,0,0,0,0,0,0,0,0,0
485,TCGA-VS-A8EJ,CISPLATIN,0,TCGA-CESC,TCGA,122,7,102,44,41,...,0,0,0,0,0,0,0,0,0,0
486,P-0002719-T01-IM3,SORAFENIB,0,TCGA-LIHC,CBIO_hcc_mskimpact_2018,23,34,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [22]:
# fold 0: embed the train/val/test patients with the fold-0 encoder.
embedding_columns = [f"transformer_embedded_{i}" for i in range(len(features2select_mask_ids))]
patient_meta_columns = ["sample_id", "drug_name", "recist", "mappedProject", "dataset_name"]
exp2A_patient_fold0_transformed = {}
for split in ["train", "val", "test"]:
    print(split)
    tables = exp2A_patient_fold0_processed[split]
    split_dataset = TensorDataset(
        torch.LongTensor(tables["input_ids"][features2select_inputs_ids].values),
        torch.LongTensor(tables["annovar_ids"][features2select_annovar_ids].values),
        torch.tensor(tables["attention_mask"][features2select_mask_ids].values).bool(),
    )
    split_loader = DataLoader(split_dataset, batch_size=512, shuffle=False)
    print(len(split_dataset))
    batch_embeddings = []
    with torch.no_grad():
        for mat, anno, mask in split_loader:
            # anno stays on the CPU: it indexes the CPU-resident annotation table.
            out = trf_model0(mat.to(device, torch.int32), anno, mask.to(device))
            batch_embeddings.append(out.cpu())  # don't accumulate results on the GPU
    embedded = pd.DataFrame(torch.cat(batch_embeddings).numpy(), columns=embedding_columns)
    # concat(axis=1) aligns on the index; reset it so rows pair up by position,
    # which is the (unshuffled) order the embeddings were produced in.
    meta = tables["input_ids"][patient_meta_columns].reset_index(drop=True)
    exp2A_patient_fold0_transformed[split] = pd.concat([embedded, meta], axis=1)

train
488
val
53
test
115


In [25]:
# Fold-0 test embeddings with their metadata columns appended.
fold0_test_embedded = exp2A_patient_fold0_transformed["test"]
fold0_test_embedded

Unnamed: 0,transformer_embedded_0,transformer_embedded_1,transformer_embedded_2,transformer_embedded_3,transformer_embedded_4,transformer_embedded_5,transformer_embedded_6,transformer_embedded_7,transformer_embedded_8,transformer_embedded_9,...,transformer_embedded_792,transformer_embedded_793,transformer_embedded_794,transformer_embedded_795,transformer_embedded_796,sample_id,drug_name,recist,mappedProject,dataset_name
0,0.000040,0.000027,0.000054,0.000043,0.000031,0.000033,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,s_DS_bkm_001_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019
1,0.000026,0.000044,0.000017,0.000024,0.000031,0.000017,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,s_DS_bkm_006_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019
2,0.000035,0.000044,0.000014,0.000001,0.000033,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,s_DS_bkm_013_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019
3,0.000031,0.000035,0.000003,0.000031,0.000048,0.000016,0.000026,0.000026,0.000033,0.000004,...,0.000024,0.000024,0.000024,0.000024,0.000024,s_DS_bkm_020_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019
4,0.000047,0.000026,0.000044,0.000020,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,s_DS_bkm_021_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
110,0.000026,0.000005,0.000014,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,TCGA-S9-A6WM,TEMOZOLOMIDE,0,TCGA-LGG,TCGA
111,0.000026,0.000015,0.000048,-0.000017,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,TCGA-S9-A6WN,TEMOZOLOMIDE,0,TCGA-LGG,TCGA
112,0.000025,0.000035,-0.000017,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,TCGA-DB-A4XD,TEMOZOLOMIDE,1,TCGA-LGG,TCGA
113,0.000020,0.000048,0.000020,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,TCGA-FG-A4MW,TEMOZOLOMIDE,0,TCGA-LGG,TCGA


**Fold 1 Patients**

In [22]:
# load pretrained PREDICT-AI (fold 1). map_location="cpu" keeps this long-lived
# checkpoint dict off the GPU; load_state_dict later copies the weights onto the
# model's device.
predict_ai_fold1 = torch.load(
    predict_ai_trained_folder + "best_pretrained_validation_CI_val_corr_2A_ALL_fold1.pth",
    map_location="cpu",
)

In [23]:
# Fresh encoder for fold 1; weights are filled in from the checkpoint next.
trf_model1 = TransformerEmbedder()
trf_model1 = trf_model1.to(device)

In [26]:
# Load the encoder weights from the PREDICT-AI patient predictor (fold 1).
# The checkpoint may hold keys this model lacks (the fold-0 checkpoint, for
# example, also stores the drug branch), so keep only the keys trf_model1 defines.
trf_model1_dict = trf_model1.state_dict()
pretrained_dict = {k: v for k, v in predict_ai_fold1["model"].items() if k in trf_model1_dict}
# Strict loading raises if any of trf_model1's parameters is missing from the
# checkpoint, so success guarantees nothing stays randomly initialised. (The
# former "update the model dict" step was dead code: its result was never loaded.)
trf_model1.load_state_dict(pretrained_dict, strict=True)

<All keys matched successfully>

In [27]:
# Freeze every weight: the encoder is only used for inference.
for weight in trf_model1.parameters():
    weight.requires_grad_(False)

In [28]:
# Inference mode (disables dropout); eval() is train(False).
trf_model1.train(False)

TransformerEmbedder(
  (embs_gene): Embedding(2324534, 64, padding_idx=1)
  (fc): Linear(in_features=23, out_features=64, bias=True)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
)

In [31]:
# Fold-1 training inputs: metadata columns followed by input_ids_0..253.
fold1_train_input_ids = exp2A_patient_fold1_processed["train"]["input_ids"]
fold1_train_input_ids

Unnamed: 0,sample_id,drug_name,recist,mappedProject,dataset_name,input_ids_0,input_ids_1,input_ids_2,input_ids_3,input_ids_4,...,input_ids_244,input_ids_245,input_ids_246,input_ids_247,input_ids_248,input_ids_249,input_ids_250,input_ids_251,input_ids_252,input_ids_253
0,TCGA-FD-A6TC,GEMCITABINE,1,TCGA-BLCA,TCGA,7,85,77,178,89,...,0,0,0,0,0,0,0,0,0,0
1,TCGA-S9-A6TS,CARMUSTINE,0,TCGA-LGG,TCGA,22,7,20,176,50,...,0,0,0,0,0,0,0,0,0,0
2,TCGA-VR-A8EQ,FLUOROURACIL,1,TCGA-ESCA,TCGA,7,255,160,16,24,...,0,0,0,0,0,0,0,0,0,0
3,s_DS_bkm_034_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019,129,37,97,54,20,...,0,0,0,0,0,0,0,0,0,0
4,TCGA-YU-A90Q,CARBOPLATIN,1,TCGA-TGCT,TCGA,8,74,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
483,s_DS_bkm_013_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019,7,82,154,44,28,...,0,0,0,0,0,0,0,0,0,0
484,TCGA-GN-A8LK,CARBOPLATIN,0,TCGA-SKCM,TCGA,122,303,49,7,115,...,0,0,0,0,0,0,0,0,0,0
485,TCGA-VS-A8EJ,CISPLATIN,0,TCGA-CESC,TCGA,122,7,102,44,41,...,0,0,0,0,0,0,0,0,0,0
486,P-0021780-T01-IM6,SORAFENIB,0,TCGA-LIHC,CBIO_hcc_mskimpact_2018,36,7,124,232,43,...,0,0,0,0,0,0,0,0,0,0


In [32]:
# fold 1: embed the train/val/test patients with the fold-1 encoder.
embedding_columns = [f"transformer_embedded_{i}" for i in range(len(features2select_mask_ids))]
patient_meta_columns = ["sample_id", "drug_name", "recist", "mappedProject", "dataset_name"]
exp2A_patient_fold1_transformed = {}
for split in ["train", "val", "test"]:
    print(split)
    tables = exp2A_patient_fold1_processed[split]
    split_dataset = TensorDataset(
        torch.LongTensor(tables["input_ids"][features2select_inputs_ids].values),
        torch.LongTensor(tables["annovar_ids"][features2select_annovar_ids].values),
        torch.tensor(tables["attention_mask"][features2select_mask_ids].values).bool(),
    )
    split_loader = DataLoader(split_dataset, batch_size=512, shuffle=False)
    print(len(split_dataset))
    batch_embeddings = []
    with torch.no_grad():
        for mat, anno, mask in split_loader:
            # anno stays on the CPU: it indexes the CPU-resident annotation table.
            out = trf_model1(mat.to(device, torch.int32), anno, mask.to(device))
            batch_embeddings.append(out.cpu())  # don't accumulate results on the GPU
    embedded = pd.DataFrame(torch.cat(batch_embeddings).numpy(), columns=embedding_columns)
    # concat(axis=1) aligns on the index; reset it so rows pair up by position,
    # which is the (unshuffled) order the embeddings were produced in.
    meta = tables["input_ids"][patient_meta_columns].reset_index(drop=True)
    exp2A_patient_fold1_transformed[split] = pd.concat([embedded, meta], axis=1)

train
488
val
54
test
114


In [33]:
# Fold-1 test embeddings with their metadata columns appended.
fold1_test_embedded = exp2A_patient_fold1_transformed["test"]
fold1_test_embedded

Unnamed: 0,transformer_embedded_0,transformer_embedded_1,transformer_embedded_2,transformer_embedded_3,transformer_embedded_4,transformer_embedded_5,transformer_embedded_6,transformer_embedded_7,transformer_embedded_8,transformer_embedded_9,...,transformer_embedded_792,transformer_embedded_793,transformer_embedded_794,transformer_embedded_795,transformer_embedded_796,sample_id,drug_name,recist,mappedProject,dataset_name
0,0.000048,0.000029,0.000038,0.000030,0.000035,0.000041,0.000043,0.000042,0.000042,0.000042,...,0.00005,0.00005,0.00005,0.00005,0.00005,s_DS_bkm_002_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019
1,0.000042,0.000049,0.000028,0.000036,0.000028,0.000042,0.000042,0.000042,0.000042,0.000042,...,0.00005,0.00005,0.00005,0.00005,0.00005,s_DS_bkm_003_T,BUPARLISIB,1,TCGA-BRCA,CBIO_brca_mskcc_2019
2,0.000038,0.000036,0.000057,0.000042,0.000042,0.000042,0.000042,0.000042,0.000042,0.000042,...,0.00005,0.00005,0.00005,0.00005,0.00005,s_DS_bkm_008_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019
3,0.000028,0.000045,0.000037,0.000051,0.000026,0.000042,0.000042,0.000042,0.000042,0.000042,...,0.00005,0.00005,0.00005,0.00005,0.00005,s_DS_bkm_009_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019
4,0.000047,0.000011,0.000049,0.000054,0.000051,0.000037,0.000039,0.000042,0.000042,0.000042,...,0.00005,0.00005,0.00005,0.00005,0.00005,s_DS_bkm_010_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
109,0.000026,0.000038,0.000042,0.000042,0.000042,0.000042,0.000042,0.000042,0.000042,0.000042,...,0.00005,0.00005,0.00005,0.00005,0.00005,TCGA-HT-A61C,TEMOZOLOMIDE,1,TCGA-LGG,TCGA
110,0.000038,0.000038,0.000017,0.000006,0.000042,0.000042,0.000042,0.000042,0.000042,0.000042,...,0.00005,0.00005,0.00005,0.00005,0.00005,TCGA-HW-A5KJ,TEMOZOLOMIDE,0,TCGA-LGG,TCGA
111,0.000007,0.000042,0.000042,0.000042,0.000042,0.000042,0.000042,0.000042,0.000042,0.000042,...,0.00005,0.00005,0.00005,0.00005,0.00005,TCGA-DU-A6S6,TEMOZOLOMIDE,0,TCGA-LGG,TCGA
112,0.000035,0.000042,0.000039,0.000006,0.000054,0.000042,0.000042,0.000042,0.000042,0.000042,...,0.00005,0.00005,0.00005,0.00005,0.00005,TCGA-S9-A6TS,TEMOZOLOMIDE,0,TCGA-LGG,TCGA


**Fold 2 Patients**

In [24]:
# load pretrained PREDICT-AI (fold 2). map_location="cpu" keeps this long-lived
# checkpoint dict off the GPU; load_state_dict later copies the weights onto the
# model's device.
predict_ai_fold2 = torch.load(
    predict_ai_trained_folder + "best_pretrained_validation_CI_val_corr_2A_ALL_fold2.pth",
    map_location="cpu",
)

In [25]:
# Fresh encoder for fold 2; weights are filled in from the checkpoint next.
trf_model2 = TransformerEmbedder()
trf_model2 = trf_model2.to(device)

In [29]:
# Load the encoder weights from the PREDICT-AI patient predictor (fold 2).
# The checkpoint may hold keys this model lacks (the fold-0 checkpoint, for
# example, also stores the drug branch), so keep only the keys trf_model2 defines.
trf_model2_dict = trf_model2.state_dict()
pretrained_dict = {k: v for k, v in predict_ai_fold2["model"].items() if k in trf_model2_dict}
# Strict loading raises if any of trf_model2's parameters is missing from the
# checkpoint, so success guarantees nothing stays randomly initialised. (The
# former "update the model dict" step was dead code: its result was never loaded.)
trf_model2.load_state_dict(pretrained_dict, strict=True)

<All keys matched successfully>

In [30]:
# Freeze every weight: the encoder is only used for inference.
for weight in trf_model2.parameters():
    weight.requires_grad_(False)

In [31]:
# Inference mode (disables dropout); eval() is train(False).
trf_model2.train(False)

TransformerEmbedder(
  (embs_gene): Embedding(2324534, 64, padding_idx=1)
  (fc): Linear(in_features=23, out_features=64, bias=True)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
)

In [39]:
# Fold-2 training inputs: metadata columns followed by input_ids_0..253.
fold2_train_input_ids = exp2A_patient_fold2_processed["train"]["input_ids"]
fold2_train_input_ids

Unnamed: 0,sample_id,drug_name,recist,mappedProject,dataset_name,input_ids_0,input_ids_1,input_ids_2,input_ids_3,input_ids_4,...,input_ids_244,input_ids_245,input_ids_246,input_ids_247,input_ids_248,input_ids_249,input_ids_250,input_ids_251,input_ids_252,input_ids_253
0,P-0021780-T01-IM6,SORAFENIB,0,TCGA-LIHC,CBIO_hcc_mskimpact_2018,36,7,124,232,43,...,0,0,0,0,0,0,0,0,0,0
1,TCGA-A5-A1OH,CARBOPLATIN,1,TCGA-UCEC,TCGA,115,257,232,152,103,...,0,0,0,0,0,0,0,0,0,0
2,TCGA-DX-A7EQ,DOXORUBICIN,0,TCGA-SARC,TCGA,43,105,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,TCGA-FB-A5VM,GEMCITABINE,0,TCGA-PAAD,TCGA,8,7,71,0,0,...,0,0,0,0,0,0,0,0,0,0
4,s_DS_bkm_035_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019,13,181,126,17,78,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
482,TCGA-GN-A8LK,CARBOPLATIN,0,TCGA-SKCM,TCGA,122,303,49,7,115,...,0,0,0,0,0,0,0,0,0,0
483,TCGA-EX-A3L1,CISPLATIN,1,TCGA-CESC,TCGA,40,16,192,0,0,...,0,0,0,0,0,0,0,0,0,0
484,TCGA-3A-A9IC,FLUOROURACIL,0,TCGA-PAAD,TCGA,8,7,113,0,0,...,0,0,0,0,0,0,0,0,0,0
485,P-0020359-T01-IM6,SORAFENIB,0,TCGA-LIHC,CBIO_hcc_mskimpact_2018,7,183,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [40]:
# fold 2: embed the train/val/test patients with the fold-2 encoder.
embedding_columns = [f"transformer_embedded_{i}" for i in range(len(features2select_mask_ids))]
patient_meta_columns = ["sample_id", "drug_name", "recist", "mappedProject", "dataset_name"]
exp2A_patient_fold2_transformed = {}
for split in ["train", "val", "test"]:
    print(split)
    tables = exp2A_patient_fold2_processed[split]
    split_dataset = TensorDataset(
        torch.LongTensor(tables["input_ids"][features2select_inputs_ids].values),
        torch.LongTensor(tables["annovar_ids"][features2select_annovar_ids].values),
        torch.tensor(tables["attention_mask"][features2select_mask_ids].values).bool(),
    )
    split_loader = DataLoader(split_dataset, batch_size=512, shuffle=False)
    print(len(split_dataset))
    batch_embeddings = []
    with torch.no_grad():
        for mat, anno, mask in split_loader:
            # anno stays on the CPU: it indexes the CPU-resident annotation table.
            out = trf_model2(mat.to(device, torch.int32), anno, mask.to(device))
            batch_embeddings.append(out.cpu())  # don't accumulate results on the GPU
    embedded = pd.DataFrame(torch.cat(batch_embeddings).numpy(), columns=embedding_columns)
    # concat(axis=1) aligns on the index; reset it so rows pair up by position,
    # which is the (unshuffled) order the embeddings were produced in.
    meta = tables["input_ids"][patient_meta_columns].reset_index(drop=True)
    exp2A_patient_fold2_transformed[split] = pd.concat([embedded, meta], axis=1)

train
487
val
56
test
113


In [41]:
# Fold-2 test embeddings with their metadata columns appended.
fold2_test_embedded = exp2A_patient_fold2_transformed["test"]
fold2_test_embedded

Unnamed: 0,transformer_embedded_0,transformer_embedded_1,transformer_embedded_2,transformer_embedded_3,transformer_embedded_4,transformer_embedded_5,transformer_embedded_6,transformer_embedded_7,transformer_embedded_8,transformer_embedded_9,...,transformer_embedded_792,transformer_embedded_793,transformer_embedded_794,transformer_embedded_795,transformer_embedded_796,sample_id,drug_name,recist,mappedProject,dataset_name
0,0.000014,0.000063,0.000046,0.000035,0.000052,0.000052,0.000052,0.000052,0.000052,0.000052,...,0.000046,0.000046,0.000046,0.000046,0.000046,s_DS_bkm_005_T,BUPARLISIB,1,TCGA-BRCA,CBIO_brca_mskcc_2019
1,0.000057,0.000039,0.000040,0.000020,0.000052,0.000052,0.000052,0.000052,0.000052,0.000052,...,0.000046,0.000046,0.000046,0.000046,0.000046,s_DS_bkm_007_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019
2,0.000046,0.000038,0.000037,0.000036,0.000052,0.000052,0.000052,0.000052,0.000052,0.000052,...,0.000046,0.000046,0.000046,0.000046,0.000046,s_DS_bkm_018_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019
3,0.000040,0.000047,0.000057,0.000040,0.000042,0.000052,0.000052,0.000052,0.000052,0.000052,...,0.000046,0.000046,0.000046,0.000046,0.000046,s_DS_bkm_029_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019
4,0.000040,0.000054,0.000036,0.000052,0.000052,0.000052,0.000052,0.000052,0.000052,0.000052,...,0.000046,0.000046,0.000046,0.000046,0.000046,s_DS_bkm_030_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
108,0.000046,0.000011,0.000052,0.000052,0.000052,0.000052,0.000052,0.000052,0.000052,0.000052,...,0.000046,0.000046,0.000046,0.000046,0.000046,TCGA-DH-A7UT,TEMOZOLOMIDE,0,TCGA-LGG,TCGA
109,0.000041,0.000029,0.000037,0.000038,0.000011,0.000052,0.000052,0.000052,0.000052,0.000052,...,0.000046,0.000046,0.000046,0.000046,0.000046,TCGA-IK-8125,TEMOZOLOMIDE,1,TCGA-LGG,TCGA
110,0.000045,0.000046,0.000050,0.000038,0.000037,0.000011,0.000052,0.000052,0.000052,0.000052,...,0.000045,0.000045,0.000045,0.000045,0.000045,TCGA-DH-A66B,TEMOZOLOMIDE,0,TCGA-LGG,TCGA
111,0.000042,0.000046,0.000011,0.000052,0.000052,0.000052,0.000052,0.000052,0.000052,0.000052,...,0.000046,0.000046,0.000046,0.000046,0.000046,TCGA-DH-A7UU,TEMOZOLOMIDE,0,TCGA-LGG,TCGA


In [42]:
# Experiment 2A patients: one pickle per fold, each a dict of train/val/test DataFrames.
save_dir_expt2A_dir = "/data//papers_data/systematic_assessment/input_types/transformer_inputs_transformed_797/Experiment2/SettingA/"
patient_outputs_by_fold = {
    0: exp2A_patient_fold0_transformed,
    1: exp2A_patient_fold1_transformed,
    2: exp2A_patient_fold2_transformed,
}
for fold_idx, fold_frames in patient_outputs_by_fold.items():
    with open(f"{save_dir_expt2A_dir}/patients_fold{fold_idx}_processed.pkl", "wb") as f:
        pickle.dump(fold_frames, f)

#### Cell Lines

**Fold 0**

In [45]:
# Freeze the fold-0 encoder again before embedding cell lines (it was already
# frozen in the patient section; repeating it keeps this section self-contained).
for weight in trf_model0.parameters():
    weight.requires_grad_(False)

In [46]:
# Inference mode (disables dropout); eval() is train(False).
trf_model0.train(False)

TransformerEmbedder(
  (embs_gene): Embedding(2324534, 64, padding_idx=1)
  (fc): Linear(in_features=23, out_features=64, bias=True)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
)

In [14]:
# Fold-0 cell-line training inputs: metadata (auc, ic50, ...) then input_ids_0..253.
cl_fold0_train_input_ids = exp2A_cl_fold0_processed["train"]["input_ids"]
cl_fold0_train_input_ids

Unnamed: 0,sample_id,drug_name,auc,ic50,drug_category,response_label,input_ids_0,input_ids_1,input_ids_2,input_ids_3,...,input_ids_244,input_ids_245,input_ids_246,input_ids_247,input_ids_248,input_ids_249,input_ids_250,input_ids_251,input_ids_252,input_ids_253
0,PR-132fPs,DOCETAXEL,0.191876,-4.662091,1,1,36,303,226,186,...,0,0,0,0,0,0,0,0,0,0
1,PR-L3QLdq,ELEPHANTIN,0.940458,5.730421,3,0,7,179,327,168,...,0,0,0,0,0,0,0,0,0,0
2,PR-NxSV8u,MITOXANTRONE,0.921925,4.070582,1,0,89,7,11,277,...,0,0,0,0,0,0,0,0,0,0
3,PR-oLPbwB,DACTINOMYCIN,0.179515,-6.588337,1,1,146,11,121,53,...,0,0,0,0,0,0,0,0,0,0
4,PR-4ngqZx,CCT007093,0.989986,3.724712,3,0,162,52,23,50,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
156436,PR-M4505H,PFI-1,0.919051,3.534174,3,1,7,90,8,224,...,0,0,0,0,0,0,0,0,0,0
156437,PR-Bz57NU,NILOTINIB,0.995489,4.073733,1,0,231,220,313,12,...,0,0,0,0,0,0,0,0,0,0
156438,PR-6SyWYo,SAPITINIB,0.492491,-1.567439,2,1,89,7,49,189,...,0,0,0,0,0,0,0,0,0,0
156439,PR-wGySam,TASELISIB,0.901939,2.716776,2,0,22,7,105,55,...,0,0,0,0,0,0,0,0,0,0


In [15]:
# fold 0: embed the train/val/test cell lines with the fold-0 encoder.
embedding_columns = [f"transformer_embedded_{i}" for i in range(len(features2select_mask_ids))]
cl_meta_columns = ["sample_id", "drug_name", "auc"]
exp2A_cl_fold0_transformed = {}
for split in ["train", "val", "test"]:
    print(split)
    tables = exp2A_cl_fold0_processed[split]
    split_dataset = TensorDataset(
        torch.LongTensor(tables["input_ids"][features2select_inputs_ids].values),
        torch.LongTensor(tables["annovar_ids"][features2select_annovar_ids].values),
        torch.tensor(tables["attention_mask"][features2select_mask_ids].values).bool(),
    )
    split_loader = DataLoader(split_dataset, batch_size=512, shuffle=False)
    print(len(split_dataset))
    batch_embeddings = []
    with torch.no_grad():
        for mat, anno, mask in split_loader:
            # anno stays on the CPU: it indexes the CPU-resident annotation table.
            out = trf_model0(mat.to(device, torch.int32), anno, mask.to(device))
            # Move each batch off the GPU: the train split alone is ~156k x 797 floats.
            batch_embeddings.append(out.cpu())
    embedded = pd.DataFrame(torch.cat(batch_embeddings).numpy(), columns=embedding_columns)
    # concat(axis=1) aligns on the index; reset it so rows pair up by position,
    # which is the (unshuffled) order the embeddings were produced in.
    meta = tables["input_ids"][cl_meta_columns].reset_index(drop=True)
    exp2A_cl_fold0_transformed[split] = pd.concat([embedded, meta], axis=1)

train
156441
val
17371
test
21589


In [16]:
# Fold-0 cell-line test embeddings with sample_id / drug_name / auc appended.
cl_fold0_test_embedded = exp2A_cl_fold0_transformed["test"]
cl_fold0_test_embedded

Unnamed: 0,transformer_embedded_0,transformer_embedded_1,transformer_embedded_2,transformer_embedded_3,transformer_embedded_4,transformer_embedded_5,transformer_embedded_6,transformer_embedded_7,transformer_embedded_8,transformer_embedded_9,...,transformer_embedded_790,transformer_embedded_791,transformer_embedded_792,transformer_embedded_793,transformer_embedded_794,transformer_embedded_795,transformer_embedded_796,sample_id,drug_name,auc
0,0.000031,0.000028,0.000040,0.000043,0.000031,0.000031,3.080256e-05,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,PR-7wTlSr,FLUOROURACIL,0.779342
1,0.000021,0.000034,0.000030,0.000016,0.000031,0.000022,-1.832098e-05,0.000007,0.000030,0.000043,...,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,PR-1T6Ogc,FLUOROURACIL,0.717074
2,0.000019,0.000035,0.000019,0.000049,0.000060,0.000016,1.867302e-05,0.000035,0.000020,0.000011,...,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,PR-rWrc4g,FLUOROURACIL,0.901599
3,0.000019,0.000035,0.000044,0.000037,0.000026,0.000018,2.413988e-06,-0.000018,0.000053,0.000058,...,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,PR-Q0HrBx,FLUOROURACIL,0.904370
4,0.000027,0.000042,0.000012,0.000029,0.000018,-0.000020,4.939735e-06,0.000017,0.000032,0.000024,...,0.000020,0.000020,0.000020,0.000020,0.000020,0.000020,0.000020,PR-ZU3ok7,FLUOROURACIL,0.913728
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21584,0.000024,0.000035,0.000014,0.000003,0.000019,0.000008,-9.685755e-07,0.000029,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,PR-bTdCzB,TRETINOIN,0.981736
21585,0.000035,0.000051,0.000025,0.000006,0.000030,0.000010,6.258488e-06,0.000019,0.000042,0.000021,...,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,PR-0sb1bx,TRETINOIN,0.916793
21586,0.000013,-0.000011,0.000020,0.000022,0.000030,-0.000018,5.016476e-05,0.000031,0.000027,0.000008,...,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,PR-dTN7So,TRETINOIN,0.988135
21587,0.000003,0.000026,0.000012,0.000035,0.000031,0.000031,3.079139e-05,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,PR-kRtxhU,TRETINOIN,0.972332


In [17]:
# Experiment 2A cell lines (fold-0 encoder): dict of train/val/test DataFrames.
save_dir_expt2A_dir = "/data//papers_data/systematic_assessment/input_types/transformer_inputs_transformed_797/Experiment2/SettingA/"
cl_output_path = f"{save_dir_expt2A_dir}/cell_lines_fold0_processed.pkl"
with open(cl_output_path, "wb") as out_file:
    pickle.dump(exp2A_cl_fold0_transformed, out_file)

#### Unlabelled Patients

In [18]:
# Unlabelled fold-0 inputs: indexed by sample ID, only input_ids_* columns.
unlabelled_fold0_input_ids = exp2A_patient_fold0_unlabelled_processed["train"]["input_ids"]
unlabelled_fold0_input_ids

Unnamed: 0,input_ids_0,input_ids_1,input_ids_2,input_ids_3,input_ids_4,input_ids_5,input_ids_6,input_ids_7,input_ids_8,input_ids_9,...,input_ids_244,input_ids_245,input_ids_246,input_ids_247,input_ids_248,input_ids_249,input_ids_250,input_ids_251,input_ids_252,input_ids_253
TCGA-50-5931,12,70,105,66,7,76,24,89,0,0,...,0,0,0,0,0,0,0,0,0,0
TCGA-LN-A7HV,29,17,149,7,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
TCGA-EK-A3GM,34,43,54,272,11,124,249,141,63,101,...,0,0,0,0,0,0,0,0,0,0
TCGA-P6-A5OF,146,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
TCGA-B8-5550,115,224,106,325,41,40,42,296,39,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
s_DS_bkm_055_T,13,11,48,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
s_DS_bkm_056_T,43,123,11,60,133,173,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
s_DS_bkm_057_T,150,107,152,22,11,136,62,167,126,66,...,0,0,0,0,0,0,0,0,0,0
s_DS_bkm_058_T,42,11,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [20]:
# fold 0: embed the unlabelled patients with the fold-0 encoder (trf_model0) and
# store one DataFrame per split in exp2A_patient_fold0_unlabelled_transformed.
# The DataLoader uses shuffle=False, so row i of the output corresponds to row i of
# the processed "input_ids" table (the output itself carries a RangeIndex).
exp2A_patient_fold0_unlabelled_transformed = {}
trf_model0.eval()  # disable dropout for inference; a no-op if already in eval mode
for split in ["train"]:  # named `split` so the builtin `type` is not shadowed
    print(split)
    split_inputs = exp2A_patient_fold0_unlabelled_processed[split]
    train_target_data = split_inputs["input_ids"][features2select_inputs_ids].values
    train_target_attention = split_inputs["attention_mask"][features2select_mask_ids].values
    train_target_annovar = split_inputs["annovar_ids"][features2select_annovar_ids].values
    train_patient_dataset = TensorDataset(
        torch.LongTensor(train_target_data),
        torch.LongTensor(train_target_annovar),
        torch.tensor(train_target_attention).bool(),
    )
    target_dataloader_train = DataLoader(train_patient_dataset, batch_size=512, shuffle=False)
    print(len(train_patient_dataset))
    train_target_embedded = []
    with torch.no_grad():  # inference only: no autograd graph is needed
        for mat, anno, mask in target_dataloader_train:
            # `anno` is handed to the model without being moved to `device`.
            mat, mask = mat.to(device, torch.int32), mask.to(device)
            # Bring each batch back to the CPU immediately so GPU memory does not
            # grow with the number of batches.
            train_target_embedded.append(trf_model0(mat, anno, mask).cpu())
    df = pd.DataFrame(
        torch.cat(train_target_embedded).numpy(),
        # 797 must equal the encoder's output width; pandas raises a ValueError otherwise.
        columns=[f"transformer_embedded_{i}" for i in range(797)],
    )
    exp2A_patient_fold0_unlabelled_transformed[split] = df

train
9331


In [32]:
# fold 1: embed the unlabelled patients with the fold-1 encoder (trf_model1) and
# store one DataFrame per split in exp2A_patient_fold1_unlabelled_transformed.
# The DataLoader uses shuffle=False, so row i of the output corresponds to row i of
# the processed "input_ids" table (the output itself carries a RangeIndex).
exp2A_patient_fold1_unlabelled_transformed = {}
trf_model1.eval()  # disable dropout for inference; a no-op if already in eval mode
for split in ["train"]:  # named `split` so the builtin `type` is not shadowed
    print(split)
    split_inputs = exp2A_patient_fold1_unlabelled_processed[split]
    train_target_data = split_inputs["input_ids"][features2select_inputs_ids].values
    train_target_attention = split_inputs["attention_mask"][features2select_mask_ids].values
    train_target_annovar = split_inputs["annovar_ids"][features2select_annovar_ids].values
    train_patient_dataset = TensorDataset(
        torch.LongTensor(train_target_data),
        torch.LongTensor(train_target_annovar),
        torch.tensor(train_target_attention).bool(),
    )
    target_dataloader_train = DataLoader(train_patient_dataset, batch_size=512, shuffle=False)
    print(len(train_patient_dataset))
    train_target_embedded = []
    with torch.no_grad():  # inference only: no autograd graph is needed
        for mat, anno, mask in target_dataloader_train:
            # `anno` is handed to the model without being moved to `device`.
            mat, mask = mat.to(device, torch.int32), mask.to(device)
            # Bring each batch back to the CPU immediately so GPU memory does not
            # grow with the number of batches.
            train_target_embedded.append(trf_model1(mat, anno, mask).cpu())
    df = pd.DataFrame(
        torch.cat(train_target_embedded).numpy(),
        # 797 must equal the encoder's output width; pandas raises a ValueError otherwise.
        columns=[f"transformer_embedded_{i}" for i in range(797)],
    )
    exp2A_patient_fold1_unlabelled_transformed[split] = df

train
9332


In [33]:
# fold 2: embed the unlabelled patients with the fold-2 encoder (trf_model2) and
# store one DataFrame per split in exp2A_patient_fold2_unlabelled_transformed.
# The DataLoader uses shuffle=False, so row i of the output corresponds to row i of
# the processed "input_ids" table (the output itself carries a RangeIndex).
exp2A_patient_fold2_unlabelled_transformed = {}
trf_model2.eval()  # disable dropout for inference; a no-op if already in eval mode
for split in ["train"]:  # named `split` so the builtin `type` is not shadowed
    print(split)
    split_inputs = exp2A_patient_fold2_unlabelled_processed[split]
    train_target_data = split_inputs["input_ids"][features2select_inputs_ids].values
    train_target_attention = split_inputs["attention_mask"][features2select_mask_ids].values
    train_target_annovar = split_inputs["annovar_ids"][features2select_annovar_ids].values
    train_patient_dataset = TensorDataset(
        torch.LongTensor(train_target_data),
        torch.LongTensor(train_target_annovar),
        torch.tensor(train_target_attention).bool(),
    )
    target_dataloader_train = DataLoader(train_patient_dataset, batch_size=512, shuffle=False)
    print(len(train_patient_dataset))
    train_target_embedded = []
    with torch.no_grad():  # inference only: no autograd graph is needed
        for mat, anno, mask in target_dataloader_train:
            # `anno` is handed to the model without being moved to `device`.
            mat, mask = mat.to(device, torch.int32), mask.to(device)
            # Bring each batch back to the CPU immediately so GPU memory does not
            # grow with the number of batches.
            train_target_embedded.append(trf_model2(mat, anno, mask).cpu())
    df = pd.DataFrame(
        torch.cat(train_target_embedded).numpy(),
        # 797 must equal the encoder's output width; pandas raises a ValueError otherwise.
        columns=[f"transformer_embedded_{i}" for i in range(797)],
    )
    exp2A_patient_fold2_unlabelled_transformed[split] = df

train
9333


In [35]:
# Preview the fold-0 embeddings of the unlabelled patients ("train" split). Rows follow
# the order of the processed input table (the DataLoader was not shuffled) under a
# RangeIndex; columns are transformer_embedded_0 ... transformer_embedded_796.
exp2A_patient_fold0_unlabelled_transformed["train"]

Unnamed: 0,transformer_embedded_0,transformer_embedded_1,transformer_embedded_2,transformer_embedded_3,transformer_embedded_4,transformer_embedded_5,transformer_embedded_6,transformer_embedded_7,transformer_embedded_8,transformer_embedded_9,...,transformer_embedded_787,transformer_embedded_788,transformer_embedded_789,transformer_embedded_790,transformer_embedded_791,transformer_embedded_792,transformer_embedded_793,transformer_embedded_794,transformer_embedded_795,transformer_embedded_796
0,0.000016,0.000033,0.000046,0.000014,0.000035,0.000017,0.000022,0.000010,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024
1,0.000031,0.000018,0.000020,0.000035,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024
2,0.000032,0.000015,0.000026,0.000019,0.000026,0.000009,0.000044,-0.000018,0.000024,0.000028,...,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023
3,0.000010,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024
4,0.000032,0.000030,0.000027,0.000022,0.000002,0.000019,0.000019,0.000044,0.000035,0.000031,...,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9326,0.000015,0.000026,0.000014,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024
9327,0.000015,0.000029,0.000026,0.000007,0.000017,0.000005,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024
9328,0.000010,0.000031,0.000035,0.000024,0.000026,0.000018,-0.000013,0.000018,0.000020,0.000013,...,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023,0.000023
9329,0.000019,0.000026,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024


In [36]:
# Experiment 2A patients unlabelled: write each fold's {split: DataFrame} embedding
# dict to its own pickle, patients_fold{0,1,2}_processed_unlabelled.pkl.
save_dir_expt2A_dir = "/data//papers_data/systematic_assessment/input_types/transformer_inputs_transformed_797/Experiment2/SettingA/"
unlabelled_by_fold = (
    exp2A_patient_fold0_unlabelled_transformed,
    exp2A_patient_fold1_unlabelled_transformed,
    exp2A_patient_fold2_unlabelled_transformed,
)
for fold_idx, fold_embeddings in enumerate(unlabelled_by_fold):
    with open(f"{save_dir_expt2A_dir}/patients_fold{fold_idx}_processed_unlabelled.pkl", "wb") as f:
        pickle.dump(fold_embeddings, f)