# Inference 
- Perform inference on test set using the pre-trained model (use the weights saved in '../code/checkpoints').
- Test notebook for script.
- Script generates `0107_pred_script.csv`.

#### Notes
- Run pre-trained model on tile test set.
- Test set contains all WSIs except those used for training.
- The WSI - EH-15740-002.svs (associated with animal "123 XL") was renamed to EH-15740-000.svs in all files to avoid confusion with the WSI with the same name EH-15740-002.svs (associated with animal "201-223 X4") and same genotype. 

***
## Inference (test code for script)

In [None]:
import torch

In [1]:
'Load metadata'
data_fld = '../data'
meta_fld = os.path.join(data_fld, 'mouse_data')
path = os.path.join(meta_fld, 'wsi_meta.csv')
df_wsi = pd.read_csv(path, index_col=0)
path = os.path.join(meta_fld, 'tile_meta_v2.csv')
df_tile = pd.read_csv(path, index_col=0)

In [2]:
# df_tile

In [4]:
'Load info about training set'
#sp_fld = os.path.join(meta_fld, 'sp_h2g_new_trn_1')
path = os.path.join(meta_fld, 'tile_train_1_1108.csv')
df_trn = pd.read_csv(path, index_col=0)
len(df_trn)

19664

In [5]:
df_trn.to_csv('../results/Extended_Data_Table_1.csv')

In [6]:
'Build df with train/test bt finding WSI used in training (overestimates)'
wsi_trn = df_trn[df_trn.is_in_train].wsi_name.unique()
# wsi_trn = [wsi_file.split('.')[0] for wsi_file in wsi_trn]
len(wsi_trn)
all_wsi = df_tile.wsi_name.unique()
# all WSI used in training are found (even though some are duplicated)
assert len([wsi in all_wsi for wsi in wsi_trn]) == len(wsi_trn)
df = df_tile.copy()
df['is_train'] = df.wsi_name.isin(wsi_trn)
df['is_test'] = ~df.is_train
df.is_train.sum()

5982

In [7]:
len(df.groupby('wsi_id').count())

156

In [19]:
mask = df.is_train
len(df[mask].groupby('wsi_id').count())

58

In [20]:
len(df[mask])

5982

In [21]:
mask = df.is_test
len(df[mask].groupby('wsi_id').count())

98

In [24]:
58 + 98

156

In [22]:
len(df[mask])

17609

In [23]:
17609 + 5982

23591

In [8]:
g = df.groupby(['genotype', 'wsi_id'])
g.size().to_frame().reset_index()['genotype'].value_counts()

CBTP        62
CBT         27
CBTP3       21
CBT3        20
CBTA        12
CBTPA       11
Multiple     3
Name: genotype, dtype: int64

In [9]:
mask = df_trn.is_in_train == True
len(df_trn[mask].wsi_name.unique())

56

In [10]:
len(df_trn[~mask].wsi_name.unique())

92

In [11]:
trn_animals = df_trn[mask].animal_id.unique()
len(trn_animals)

17

In [12]:
df[df.animal_id.isin(trn_animals)]

pd.merge(df, trn_animals_dr, how = 'inner', on = ['animal_id', 'dissected_region'])

Unnamed: 0_level_0,wsi_id,wsi_name,genotype,x_tile_coord,y_tile_coord,file,batch,animal_id,dissected_region,is_CBTP_old,unique_wsi,connected_component,is_train,is_test
Unnamed: 0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
1452,Eran Hodis 2020-09-15-001_CBT3,Eran Hodis 2020-09-15-001.svs,CBT3,38,13,Eran Hodis 2020-09-15-001_38_13_CBT3.png,L41,125,XLR,False,Eran Hodis 2020-09-15-001.svs,4,True,False
1453,Eran Hodis 2020-09-15-001_CBT3,Eran Hodis 2020-09-15-001.svs,CBT3,3,10,Eran Hodis 2020-09-15-001_3_10_CBT3.png,L41,125,XLR,False,Eran Hodis 2020-09-15-001.svs,3,True,False
1454,Eran Hodis 2020-09-15-001_CBT3,Eran Hodis 2020-09-15-001.svs,CBT3,43,12,Eran Hodis 2020-09-15-001_43_12_CBT3.png,L41,125,XLR,False,Eran Hodis 2020-09-15-001.svs,4,True,False
1455,Eran Hodis 2020-09-15-001_CBT3,Eran Hodis 2020-09-15-001.svs,CBT3,19,2,Eran Hodis 2020-09-15-001_19_2_CBT3.png,L41,125,XLR,False,Eran Hodis 2020-09-15-001.svs,1,True,False
1456,Eran Hodis 2020-09-15-001_CBT3,Eran Hodis 2020-09-15-001.svs,CBT3,18,5,Eran Hodis 2020-09-15-001_18_5_CBT3.png,L41,125,XLR,False,Eran Hodis 2020-09-15-001.svs,1,True,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23543,Eran Hodis 2020-09-15-030_CBTP,Eran Hodis 2020-09-15-030.svs,CBTP,4,1,Eran Hodis 2020-09-15-030_4_1_CBTP.png,L41,195,XR,False,Eran Hodis 2020-09-15-030.svs,1,True,False
23544,Eran Hodis 2020-09-15-030_CBTP,Eran Hodis 2020-09-15-030.svs,CBTP,41,8,Eran Hodis 2020-09-15-030_41_8_CBTP.png,L41,195,XR,False,Eran Hodis 2020-09-15-030.svs,2,True,False
23545,Eran Hodis 2020-09-15-030_CBTP,Eran Hodis 2020-09-15-030.svs,CBTP,5,2,Eran Hodis 2020-09-15-030_5_2_CBTP.png,L41,195,XR,False,Eran Hodis 2020-09-15-030.svs,1,True,False
23546,Eran Hodis 2020-09-15-030_CBTP,Eran Hodis 2020-09-15-030.svs,CBTP,41,5,Eran Hodis 2020-09-15-030_41_5_CBTP.png,L41,195,XR,False,Eran Hodis 2020-09-15-030.svs,2,True,False


In [13]:
len(df_trn[~mask].animal_id.unique())

27

In [14]:
len(df_trn[mask].batch.unique())

3

In [15]:
len(df_trn[~mask].batch.unique())

3

In [16]:
'Create test set'
df_tst = df[df.is_test]

In importing this model check that:
- it loads the ORIGINAL weights of the patchwise network (`/weights_pw1.pth` even though it says it fails).
- it loads the SAURABH weights for the image wise network (`'/weights_iw_h2gnew_trn1.pth`)

In [19]:
df_trn.unique_wsi.nunique()

148

In [27]:
df_tst.wsi_id.unique()

array(['Eran Hodis 2020-09-15-012_Multiple',
       'Eran Hodis 2020-09-14-012_Multiple', 'EH-15740-048_Multiple',
       'Eran Hodis 2020-09-14-049_CBT3', 'Eran Hodis 2020-09-15_CBT3',
       'EH-15740-035_CBT3', 'EH-15740-033_CBT3',
       'Eran Hodis 2020-09-14_CBT3', 'Eran Hodis 2020-09-15-022_CBTP3',
       'Eran Hodis 2020-09-14-017_CBTP3', 'EH-15740-028_CBTP3',
       'EH-15740-030_CBTP3', 'Eran Hodis 2020-09-14-021_CBTP3',
       'Eran Hodis 2020-09-15-021_CBTP3',
       'Eran Hodis 2020-09-15-020_CBTP3',
       'Eran Hodis 2020-09-14-023_CBTP3', 'EH-15740-025_CBTP3',
       'Eran Hodis 2020-09-15-019_CBTP3', 'EH-15740-029_CBTP3',
       'EH-15740-024_CBTP3', 'Eran Hodis 2020-09-14-019_CBTP3',
       'Eran Hodis 2020-09-14-020_CBTP3',
       'Eran Hodis 2020-09-15-017_CBTP3',
       'Eran Hodis 2020-09-15-044_CBTA', 'EH-15740-023_CBTA',
       'Eran Hodis 2020-09-14-045_CBTA', 'Eran Hodis 2020-09-15-045_CBTA',
       'EH-15740-022_CBTA', 'Eran Hodis 2020-09-14-046_CBTA',
      

In [22]:
df_tst.unique_wsi.nunique()

98

In [17]:
from src import *
import sys
sys.argv[1:] = []
args = ModelOptions().parse()


------------ Options -------------
batch_size: 64
beta1: 0.9
beta2: 0.999
channels: 1
checkpoints_path: ./checkpoints
cuda: True
dataset_path: /home/tbiancal/melanoma_images/2021_1_7_set
debug: False
ensemble: 1
epochs: 10
gpu_ids: 0
log_interval: 5
lr: 0.001
network: 0
no_cuda: False
patch_stride: 256
seed: 1
test_batch_size: 64
testset_path: ./home/tbiancal/melanoma_images/2021_1_7_set/test
-------------- End ----------------



In [8]:
pw_network = PatchWiseNetwork(args.channels)
iw_network = ImageWiseNetwork(args.channels)
im_model = ImageWiseModel(args, iw_network, pw_network)

Loading "image-wise" model...
Loading "patch-wise" model...
Failed to load pre-trained network from  ./checkpoints/weights_pw1.pth


In [17]:
im_model.label

AttributeError: 'ImageWiseModel' object has no attribute 'label'

In [8]:
img_fld = os.path.join(fld, 'mouse_tiles')

In [49]:
# classes 
classes = ['CBT', 'CBTA', 'CBTP', 'CBT3', 'CBTPA', 'CBTP3']

In [53]:
preds = []
#df_out = df_tst.sample(n=100)
for ix, entry in tqdm(df_tst.iterrows()):
    tile_fld = os.path.join(img_fld, entry.genotype)
    tile_path = os.path.join(tile_fld, entry.file)
    assert os.path.isfile(tile_path)
    [[pred, conf, file]], output = im_model.test(tile_path, ensemble=args.ensemble == 1, verbose=False)
    # output is a tensor of shape
    softmax_res = torch.sum(np.exp(output.cpu().data), dim=0)
    probs = np.array(softmax_res / torch.sum(softmax_res))
    preds.append(probs)
df_out['pred_prob'] = preds

100it [02:05,  1.26s/it]


In [54]:
# ne salvo uno per fare test veloci
df_out.to_csv('../results/0107_pred_script.csv')