# Inference 
- Perform inference on the TCGA test set using the model trained on TCGA (use the weights saved in `../code/checkpoints_2`).
- The model is trained using gene expression programs as classes: Common, MITF-Low, OxPhos, Other.
- This is the test notebook for the inference script.
- The script generates `0714_pred.csv`.

#### Notes
- Run pre-trained model on tile test set.
- Test set contains all WSIs except those used for training.

***
## Inference (test code for script)

In [1]:
import torch
import pandas as pd
import numpy as np
import os
from tqdm import tqdm

In [3]:
# Load metadata
#
# Reads the slide-level (WSI) and tile-level metadata tables from
# ../data/TCGA_data and derives the folder that holds the tile images.

data_fld = '../data'
sub_fld = os.path.join(data_fld, 'TCGA_data')

# Slide-level table; its first CSV column becomes the index.
path = os.path.join(sub_fld, 'tcga_wsi_meta.csv')
df_wsi = pd.read_csv(path, index_col=0)

# Tile-level table, read with pandas' default index handling.
path = os.path.join(sub_fld, 'tcga_tile_meta.csv')
df_tile = pd.read_csv(path)

# Tiles live under <home>/disks/TCGA_tiles. expanduser("~") honours $HOME when
# it is set and still resolves the home directory when it is not, whereas
# os.getenv("HOME") would return None and make os.path.join raise a TypeError.
home = os.path.expanduser("~")
disk_path = os.path.join(home, 'disks')
tiles_path = os.path.join(disk_path, 'TCGA_tiles')

In [4]:
# Preview the first two rows of the slide-level metadata table.
df_wsi.head(2)

Unnamed: 0_level_0,clinical_sample_id,primary_tumor_type,CNA_data,ABSOLUTE_purity,rna_subtype,pigment_score,BetaCAT,P53,PTEN,APC
clinical_donor_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
TCGA-3N-A9WC,TCGA-3N-A9WC-06,NON-ACRAL CUTANEOUS,True,0.55,Common,,0,0,0,1
TCGA-3N-A9WD,TCGA-3N-A9WD-06,NON-ACRAL CUTANEOUS,True,0.45,Common,,0,1,0,0


In [5]:
# Preview the first two rows of the tile-level metadata table.
df_tile.head(2)

Unnamed: 0_level_0,Unnamed: 0,x_tile_coord,y_tile_coord,clinical_donor_id,wsi_name,clinical_sample_id,primary_tumor_type,CNA_data,ABSOLUTE_purity,rna_subtype,...,PTEN_prob,APC_prob,APC_BC,OxPhos,APC_OxPhos,APC_BC_OxPhos,P53_MITF,P53_Co,P53_MITF_Co,MITF_Co
wsi_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
TCGA-EE-A3JD-01Z-00-DX1.svs,2,1,8,TCGA-EE-A3JD,TCGA-EE-A3JD-01Z-00-DX1.D4E5B644-C7EF-442D-91F...,TCGA-EE-A3JD-06,NON-ACRAL CUTANEOUS,True,0.26,MITF-Low,...,1.858032e-14,1.0,0,False,False,False,True,True,True,True
TCGA-D3-A3CB-06Z-00-DX1.svs,4,0,5,TCGA-D3-A3CB,TCGA-D3-A3CB-06Z-00-DX1.9862D604-C9E7-44BE-99E...,TCGA-D3-A3CB-06,NON-ACRAL CUTANEOUS,True,0.37,Common,...,0.1249998,1.0,0,False,False,False,False,True,True,True


When importing this model, check that:
- it loads the ORIGINAL weights of the patch-wise network (`/weights_pw1.pth`, even though it says it fails).
- it loads the SAURABH weights for the image-wise network (`/weights_iw_h2gnew_trn1.pth`).

In [12]:
# Wildcard import from the project package. ModelOptions, PatchWiseNetwork,
# ImageWiseNetwork and ImageWiseModel are not imported anywhere else in the
# visible cells, so they presumably come from here.
# NOTE(review): the wildcard may also rebind names imported earlier; keep the
# statement order as is.
from src import *
import sys
# Clear every command-line argument after the program name before parsing,
# so the arguments the notebook kernel was launched with are not passed on.
# NOTE(review): presumably ModelOptions.parse() reads sys.argv - confirm.
sys.argv[1:] = []
args = ModelOptions().parse()


------------ Options -------------
batch_size: 64
beta1: 0.9
beta2: 0.999
channels: 1
checkpoints_path: ./checkpoints_0709
cuda: True
dataset_path: /home/sparkar/disks/2021_7_9_set
debug: False
ensemble: 1
epochs: 50
gpu_ids: 0
log_interval: 5
lr: 0.0001
network: 0
no_cuda: False
patch_stride: 256
seed: 1
test_batch_size: 64
testset_path: ./home/sparkar/disks/2021_7_9_set/test
-------------- End ----------------



In [13]:
# Build the patch-wise and image-wise networks with the configured number of
# channels, then wrap both in the ImageWiseModel whose test() method is used
# for inference further down.
pw_network = PatchWiseNetwork(args.channels)
iw_network = ImageWiseNetwork(args.channels)
im_model = ImageWiseModel(args, iw_network, pw_network)

Loading "image-wise" model...
Loading "patch-wise" model...


In [14]:
# Folder containing the tile images; tile file paths are built from it in the
# inference loop.
img_fld = tiles_path
# Debug helper (disabled): list the contents of the tile folder.
# os.listdir(img_fld)

In [15]:
# Table that inference iterates over: the tile metadata with its index moved
# into a regular column by reset_index().
# For a quick check on a handful of rows, use the sampled variant instead:
# df_out = df_tile.reset_index().sample(n=10)
df_out = df_tile.reset_index()

In [16]:
# Display the table that will be fed to the inference loop.
df_out

Unnamed: 0.1,wsi_id,Unnamed: 0,x_tile_coord,y_tile_coord,clinical_donor_id,wsi_name,clinical_sample_id,primary_tumor_type,CNA_data,ABSOLUTE_purity,...,PTEN_prob,APC_prob,APC_BC,OxPhos,APC_OxPhos,APC_BC_OxPhos,P53_MITF,P53_Co,P53_MITF_Co,MITF_Co
0,TCGA-EE-A3JD-01Z-00-DX1.svs,2,1,8,TCGA-EE-A3JD,TCGA-EE-A3JD-01Z-00-DX1.D4E5B644-C7EF-442D-91F...,TCGA-EE-A3JD-06,NON-ACRAL CUTANEOUS,True,0.26,...,1.858032e-14,1.000000,0,False,False,False,True,True,True,True
1,TCGA-D3-A3CB-06Z-00-DX1.svs,4,0,5,TCGA-D3-A3CB,TCGA-D3-A3CB-06Z-00-DX1.9862D604-C9E7-44BE-99E...,TCGA-D3-A3CB-06,NON-ACRAL CUTANEOUS,True,0.37,...,1.249998e-01,1.000000,0,False,False,False,False,True,True,True
2,TCGA-EE-A2MK-01Z-00-DX1.svs,5,8,6,TCGA-EE-A2MK,TCGA-EE-A2MK-01Z-00-DX1.3A8F8407-BA89-46E6-959...,TCGA-EE-A2MK-06,NON-ACRAL CUTANEOUS,True,0.76,...,1.000000e+00,1.000000,0,False,False,False,False,True,True,True
3,TCGA-ER-A2NF-01Z-00-DX1.svs,7,4,9,TCGA-ER-A2NF,TCGA-ER-A2NF-01Z-00-DX1.1468DD2D-6AC8-4657-A02...,TCGA-ER-A2NF-01,NON-ACRAL CUTANEOUS,True,0.58,...,4.795476e-02,0.999850,0,False,False,False,False,True,True,True
4,TCGA-EE-A2MI-01Z-00-DX1.svs,8,9,6,TCGA-EE-A2MI,TCGA-EE-A2MI-01Z-00-DX1.1C56D0A7-3FA7-49A6-BBC...,TCGA-EE-A2MI-06,NON-ACRAL CUTANEOUS,True,0.66,...,1.546927e-17,1.000000,0,False,False,False,True,False,True,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19916,TCGA-D3-A2J9-06Z-00-DX1.svs,21989,1,5,TCGA-D3-A2J9,TCGA-D3-A2J9-06Z-00-DX1.5526CFD6-96AB-49F8-B88...,TCGA-D3-A2J9-06,NON-ACRAL CUTANEOUS,True,0.59,...,3.670689e-01,0.999996,0,False,False,False,False,True,True,True
19917,TCGA-BF-A5EO-01Z-00-DX1.svs,21990,13,7,TCGA-BF-A5EO,TCGA-BF-A5EO-01Z-00-DX1.1BA74189-485E-4ABF-831...,TCGA-BF-A5EO-01,NON-ACRAL CUTANEOUS,True,0.61,...,8.264103e-01,0.766232,0,True,True,True,False,False,False,False
19918,TCGA-EE-A20C-01Z-00-DX1.svs,21991,7,3,TCGA-EE-A20C,TCGA-EE-A20C-01Z-00-DX1.48BAD79E-DFC8-44A7-92F...,TCGA-EE-A20C-06,NON-ACRAL CUTANEOUS,True,0.89,...,7.886669e-01,0.011984,0,False,False,False,False,True,True,True
19919,TCGA-ER-A2NG-01Z-00-DX1.svs,21992,2,4,TCGA-ER-A2NG,TCGA-ER-A2NG-01Z-00-DX1.35B12E55-502A-4B87-A68...,TCGA-ER-A2NG-06,NON-ACRAL CUTANEOUS,True,0.49,...,1.194280e-01,0.029790,0,False,False,False,False,True,True,True


In [17]:
# Names of the gene-expression-program classes.
# NOTE(review): presumably listed in the same order as the entries of the
# predicted probability vectors - confirm against the training code. Not
# referenced by the other visible cells.
LABELS = ['Common', 'MITF-Low', 'OxPhos', 'Other']

In [18]:
# Run the model on every tile listed in df_out and collect one normalised
# class-probability vector per tile, in row order.
preds = []
# iterrows() is a generator, so tqdm needs an explicit total to show a
# progress bar and a time estimate.
for ix, entry in tqdm(df_out.iterrows(), total=len(df_out)):
    # Tile images are named '<wsi_id without .svs>_<x>_<y>.png'.
    tile_name = entry.wsi_id.split('.svs')[0] + '_' + str(entry.x_tile_coord) + '_' + str(entry.y_tile_coord) + '.png'
    tile_path = os.path.join(img_fld, tile_name)
    assert os.path.isfile(tile_path), f'Tile image not found: {tile_path}'
    # test() returns a one-element list of (prediction, confidence, file)
    # plus the raw network output; only the raw output is used here.
    [[pred, conf, file]], output = im_model.test(tile_path, ensemble=args.ensemble == 1, verbose=False)
    # Exponentiate, sum over the first dimension and renormalise to 1.
    # NOTE(review): this assumes `output` holds log-probabilities with one row
    # per patch / ensemble member - confirm against ImageWiseModel.test.
    # torch.exp keeps everything in torch instead of relying on np.exp
    # returning a tensor when given one.
    softmax_res = torch.sum(torch.exp(output.detach().cpu()), dim=0)
    probs = (softmax_res / torch.sum(softmax_res)).numpy()
    preds.append(probs)

19921it [7:00:07,  1.27s/it]


In [19]:
# Attach the per-tile probability vectors as a new 'pred_prob' column
# (one entry of `preds` per row, in the order the loop produced them).
df_out = df_out.assign(pred_prob=preds)

In [34]:
# Show the selected metadata columns next to the predicted probabilities.
df_out.loc[:, [
    'wsi_id', 'x_tile_coord', 'y_tile_coord',
    'clinical_donor_id', 'wsi_name', 'clinical_sample_id',
    'primary_tumor_type', 'CNA_data', 'ABSOLUTE_purity', 'rna_subtype',
    'BetaCAT', 'P53', 'PTEN', 'APC',
    'pred_prob',
]]

Unnamed: 0,wsi_id,x_tile_coord,y_tile_coord,clinical_donor_id,wsi_name,clinical_sample_id,primary_tumor_type,CNA_data,ABSOLUTE_purity,rna_subtype,BetaCAT,P53,PTEN,APC,pred_prob
15236,TCGA-D3-A2J8-06Z-00-DX1.svs,7,9,TCGA-D3-A2J8,TCGA-D3-A2J8-06Z-00-DX1.5FB8E98F-EAFF-491F-856...,TCGA-D3-A2J8-06,NON-ACRAL CUTANEOUS,True,0.29,MITF-Low,0,0,1,0,"[0.024246216, 0.37138513, 0.03001649, 0.57435215]"
13737,TCGA-D3-A1QA-07Z-00-DX1.svs,5,4,TCGA-D3-A1QA,TCGA-D3-A1QA-07Z-00-DX1.FF80DD52-540E-4378-BAC...,TCGA-D3-A1QA-06,NON-ACRAL CUTANEOUS,True,0.92,Common,0,1,0,0,"[0.10113474, 0.012738679, 0.8591208, 0.027005829]"
14691,TCGA-GF-A6C8-06Z-00-DX1.svs,7,7,TCGA-GF-A6C8,TCGA-GF-A6C8-06Z-00-DX1.9388CB1A-BF64-4CE8-AAC...,TCGA-GF-A6C8-06,NON-ACRAL CUTANEOUS,True,0.68,Common,0,1,0,0,"[0.123553514, 0.3047312, 0.24140185, 0.3303135]"
13293,TCGA-RP-A694-01Z-00-DX1.svs,1,0,TCGA-RP-A694,TCGA-RP-A694-01Z-00-DX1.31B4F597-0CDA-4A3B-B34...,TCGA-RP-A694-06,NON-ACRAL CUTANEOUS,True,0.65,Common,0,0,1,0,"[0.12726061, 0.05458201, 0.62833977, 0.1898176]"
11138,TCGA-GN-A8LK-06Z-00-DX1.svs,7,1,TCGA-GN-A8LK,TCGA-GN-A8LK-06Z-00-DX1.C529E9FC-003A-4E7D-A42...,TCGA-GN-A8LK-06,NON-ACRAL CUTANEOUS,True,0.95,OxPhos,0,1,0,1,"[0.0050443173, 1.7647683e-06, 0.9948322, 0.000..."
9500,TCGA-FS-A1ZD-06Z-00-DX2.svs,7,2,TCGA-FS-A1ZD,TCGA-FS-A1ZD-06Z-00-DX2.B2BF872D-1690-49FC-A8E...,TCGA-FS-A1ZD-06,NON-ACRAL CUTANEOUS,True,0.87,MITF-Low,0,0,1,0,"[0.21602325, 0.44300774, 0.15697816, 0.18399087]"
16438,TCGA-GN-A264-01Z-00-DX1.svs,4,3,TCGA-GN-A264,TCGA-GN-A264-01Z-00-DX1.2C02EEF5-92DF-4DAF-BDE...,TCGA-GN-A264-06,NON-ACRAL CUTANEOUS,True,0.83,Common,0,0,0,0,"[0.103794605, 0.47655484, 0.29156864, 0.12808193]"
1717,TCGA-EE-A20F-01Z-00-DX1.svs,6,4,TCGA-EE-A20F,TCGA-EE-A20F-01Z-00-DX1.78401458-3FDA-4D9D-856...,TCGA-EE-A20F-06,NON-ACRAL CUTANEOUS,True,0.7,Common,0,0,1,0,"[0.9531235, 0.04120145, 2.4798052e-05, 0.00565..."
17957,TCGA-GN-A4U8-06Z-00-DX1.svs,0,5,TCGA-GN-A4U8,TCGA-GN-A4U8-06Z-00-DX1.BD0E3D01-7FE9-40B6-9C8...,TCGA-GN-A4U8-06,NON-ACRAL CUTANEOUS,True,0.52,Common,0,0,1,0,"[0.0061760247, 0.014887752, 0.10771593, 0.8712..."
7322,TCGA-DA-A1HV-01Z-00-DX1.svs,13,3,TCGA-DA-A1HV,TCGA-DA-A1HV-01Z-00-DX1.1206FAE0-BC4B-494B-923...,TCGA-DA-A1HV-06,NON-ACRAL CUTANEOUS,True,0.65,Common,0,1,0,0,"[0.37403658, 0.36738807, 0.07040305, 0.18817234]"


In [20]:
# Save a copy of the predictions for quick tests.
# Create the output folder first: without it to_csv raises and the result of
# the long inference run would be lost.
os.makedirs('../results', exist_ok=True)
# NOTE(review): 'pred_prob' holds numpy arrays, which the CSV stores in their
# string form; readers of this file have to parse that column back.
df_out.to_csv('../results/0714_pred.csv')