In [1]:
# For development and debugging:
# reload modules automatically without restarting the kernel
%load_ext autoreload
%autoreload 2

In [2]:
import tensorflow as tf
import numpy as np
import pandas as pd
pd.options.display.max_columns = None
import os
import sys
import matplotlib.pyplot as plt
import json
import copy
import time

# Add EXTERNAL_LIBS_PATH to sys paths (for loading libraries)
# NOTE(review): absolute, machine-specific path; adjust when running elsewhere
EXTERNAL_LIBS_PATH = '/home/hhughes/Documents/Master_Thesis/Project/workspace/libs'
sys.path.insert(1, EXTERNAL_LIBS_PATH)

# Load custom libs (resolved through EXTERNAL_LIBS_PATH added above)
import NN_interpretability as nn_inter
import Data_augmentation as data_aug

# Disable GPU
# NOTE(review): this is set after `import tensorflow`; confirm it still takes
# effect before TensorFlow initializes its devices
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [3]:
# Notebook configuration, gathered in a single dictionary
# NOTE(review): 'seed' is stored here but is not applied in any visible cell
params = {
    # Root folder holding Metadata/, Models/, Cells/ and Score_maps/
    'base_path': '/home/hhughes/Documents/Master_Thesis/Project/workspace/Interpretability',
    # Sub-folder of Models/ that contains one directory per training run
    'model_dir': 'BL_t1',
    # Name of the saved model / targets file to use inside each run directory
    'CMA': 'CMA_0',
    'seed': 123,
    # mapobject_id_cell of the cells to analyse (kept as strings: used in file names)
    'cells': ['340547', '307720', '321021', '232615', '205760', '379184'],
}

In [4]:
# Folder that receives the score maps of the selected model directory
score_maps_root = os.path.join(params['base_path'], 'Score_maps')
output_path = os.path.join(score_maps_root, params['model_dir'])
# Create it (and any missing parents); no error if it already exists
os.makedirs(output_path, exist_ok=True)
print('Output path:\n{}'.format(output_path))

Output path:
/home/hhughes/Documents/Master_Thesis/Project/workspace/Interpretability/Score_maps/BL_t1


# 1.- Load Data

## 1.1.- Load general data (independent of the model)

In [5]:
# Read the per-cell metadata table and preview its first rows
metadata_path = os.path.join(params['base_path'], 'Metadata', 'filtered_metadata.csv')
with open(metadata_path, 'r') as metadata_file:
    metadata_df = pd.read_csv(metadata_file)
metadata_df.head()

Unnamed: 0,mapobject_id_cell,mapobject_id,plate_name,well_name,well_pos_y,well_pos_x,tpoint,zplane,label,is_border,plate_name_cell,well_name_cell,well_pos_y_cell,well_pos_x_cell,tpoint_cell,zplane_cell,label_cell,is_border_cell,is_mitotic,is_mitotic_labels,is_polynuclei_HeLa,is_polynuclei_HeLa_labels,is_polynuclei_184A1,is_polynuclei_184A1_labels,cell_cycle,cell_type,perturbation,duration,cell_size,00_DAPI_avg,07_H2B_avg,01_CDK9_pT186_avg,03_CDK9_avg,05_GTF2B_avg,07_SETD1A_avg,08_H3K4me3_avg,09_SRRM2_avg,10_H3K27ac_avg,11_KPNA2_MAX_avg,12_RB1_pS807_S811_avg,13_PABPN1_avg,14_PCNA_avg,15_SON_avg,16_H3_avg,17_HDAC3_avg,19_KPNA1_MAX_avg,20_SP100_avg,21_NCL_avg,01_PABPC1_avg,02_CDK7_avg,03_RPS6_avg,05_Sm_avg,07_POLR2A_avg,09_CCNT1_avg,10_POL2RA_pS2_avg,11_PML_avg,12_YAP1_avg,13_POL2RA_pS5_avg,15_U2SNRNPB_avg,18_NONO_avg,20_ALYREF_avg,21_COIL_avg,00_BG488_avg,00_BG568_avg,00_EU_avg,09_SRRM2_ILASTIK_avg,15_SON_ILASTIK_avg,set
0,263042,263055,plate01,J16,0,0,0,0,4,0,plate01,J16,0,0,0,0,4,0,0.0,,0.0,,0.0,,S,184A1,DMSO,720.0,8401.0,39.610201,282.212391,14.709787,34.481345,122.127518,29.408288,48.079878,56.508154,71.84475,93.999276,108.771761,157.8085,85.331438,43.616172,323.670767,128.386167,134.658031,18.805303,57.414965,26.659837,122.008728,35.669525,62.715594,133.643428,77.230008,283.199741,6.970504,92.891114,280.267874,41.649294,350.631178,356.940289,10.909557,8.600135,1.771257,431.543626,7754.927152,9285.622307,train
1,263043,263056,plate01,J16,0,0,0,0,5,0,plate01,J16,0,0,0,0,5,0,0.0,,0.0,,0.0,,S,184A1,DMSO,720.0,12207.0,34.884746,249.743762,14.587963,33.729957,104.78706,25.446742,35.599046,54.977671,48.061287,90.354757,97.722536,145.751775,95.141995,43.284735,308.379265,118.856662,125.728553,14.72206,52.481306,21.451844,122.747143,30.401393,53.993867,124.670275,81.619971,272.290742,9.256049,120.165421,267.31488,39.224038,322.807908,387.219828,10.134157,7.880916,1.749512,461.474236,8818.934136,11041.621938,test
2,263044,263057,plate01,J16,0,0,0,0,6,0,plate01,J16,0,0,0,0,6,0,0.0,,0.0,,0.0,,G1,184A1,DMSO,720.0,15734.0,31.429217,184.722779,7.892179,19.263491,81.864454,23.929299,29.5369,39.606383,39.047856,77.348039,12.910413,149.859577,42.855445,25.530594,219.118552,102.575527,125.740665,5.12014,27.724444,24.18737,49.106275,29.276163,35.680201,85.295744,60.60249,232.305672,8.39551,124.383677,193.414682,36.115297,273.154826,252.115717,7.854372,8.116748,1.803605,372.570739,5740.956972,7330.80844,train
3,263045,263058,plate01,J16,0,0,0,0,7,0,plate01,J16,0,0,0,0,7,0,0.0,,0.0,,0.0,,S,184A1,DMSO,720.0,15767.0,43.34909,241.460906,18.110119,44.601291,122.976267,40.604674,48.76011,58.076779,77.012867,121.283321,104.755645,159.79255,84.807594,43.961097,319.584115,120.314566,120.470444,25.555966,60.24333,21.220701,140.246102,27.944981,60.354207,169.718381,102.371663,264.648847,10.370258,116.424263,268.396552,43.573547,351.399875,357.7626,10.338172,7.913029,1.773024,369.910382,8085.528636,9678.38181,test
4,263047,263060,plate01,J16,0,0,0,0,9,0,plate01,J16,0,0,0,0,9,0,0.0,,0.0,,0.0,,G1,184A1,DMSO,720.0,11930.0,28.164459,149.241408,27.710272,72.874109,148.002903,51.391368,56.976488,116.255474,56.907744,66.943473,97.452245,187.995952,53.530015,60.91758,134.426297,121.980833,121.369045,59.501243,54.852952,26.389251,162.722979,36.026263,79.981471,237.625004,150.247531,255.964144,18.029298,188.560486,192.44675,66.494555,242.252349,174.614508,9.784177,8.899475,1.97013,552.887343,15128.64694,13469.590947,train


In [6]:
# Read the parameters JSON stored next to the metadata
model_params_path = os.path.join(params['base_path'], 'Metadata', 'parameters.json')
with open(model_params_path, 'r') as model_params_file:
    model_params = json.load(model_params_file)
# To inspect the available entries, run: model_params.keys()

In [7]:
# Read the channel description table
channels_path = os.path.join(params['base_path'], 'Metadata', 'channels.csv')
with open(channels_path, 'r') as channels_file:
    channels_df = pd.read_csv(channels_file)

# Ids of the channels listed in model_params['input_channels']
# (returned in the row order of channels_df)
is_input_channel = channels_df['name'].isin(model_params['input_channels'])
input_ids = channels_df.loc[is_input_channel, 'channel_id'].values

# Id of the '00_EU' channel (the output channel)
is_output_channel = channels_df['name'] == '00_EU'
output_id = channels_df.loc[is_output_channel, 'channel_id'].values[0]

# Normalization values, ordered by channel id
channels_by_id = channels_df.sort_values(by=['channel_id'])
norm_vals = channels_by_id['normalization_vals'].values

## 1.2.- Load Model Data

In [8]:
# Load one saved model per run directory found under Models/<model_dir>
models = {}
models_path = os.path.join(params['base_path'], 'Models', params['model_dir'])
# os.listdir returns entries in arbitrary order; sort so that the loading order
# (and therefore the model summarized below) is the same on every run
run_names = sorted(os.listdir(models_path))
if not run_names:
    # Fail with a clear message instead of a NameError at the summary call below
    raise FileNotFoundError('No model runs found in: {}'.format(models_path))
for model in run_names:
    print('Loading model: ', model)
    models[model] = tf.keras.models.load_model(os.path.join(models_path, model, params['CMA']))
print('')
# Show the architecture of the last loaded run only
# (presumably all runs share the same architecture; verify if in doubt)
models[model].summary()

Loading model:  Run_2
Loading model:  Run_1

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 224, 224, 64)      19072     
_________________________________________________________________
batch_normalization (BatchNo (None, 224, 224, 64)      256       
_________________________________________________________________
re_lu (ReLU)                 (None, 224, 224, 64)      0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 112, 112, 64)      0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 112, 112, 128)     73856     
_________________________________________________________________
batch_normalization_1 (Batch (None, 112, 112, 128)     512       
_________________________________________________________________
re_lu_1 (Re

In [9]:
# Build one table with the true value and the prediction of every run
models_path = os.path.join(params['base_path'], 'Models', params['model_dir'])
targets_df = pd.DataFrame()
for run_idx, model in enumerate(os.listdir(models_path)):
    print('Reading predicted values for model: ', model)
    targets_file_path = os.path.join(models_path, model, 'targets_'+params['CMA']+'.csv')
    with open(targets_file_path, 'r') as targets_file:
        run_df = pd.read_csv(targets_file)

    # Store the prediction under a run-specific column (appended as the last
    # column) and discard the generic 'y_hat' and the residual column
    prediction_name = 'y_hat_' + model
    run_df = (run_df
              .assign(**{prediction_name: run_df['y_hat']})
              .drop(columns=['y - y_hat', 'y_hat']))

    if run_idx == 0:
        # The first run provides the shared columns (y, set, perturbation, ...)
        targets_df = run_df.copy()
        continue

    # Later runs only contribute their prediction column, matched by cell id
    targets_df = targets_df.merge(run_df[['mapobject_id_cell', prediction_name]],
                                  on='mapobject_id_cell',
                                  how='left')
targets_df.head()

Reading predicted values for model:  Run_2
Reading predicted values for model:  Run_1


Unnamed: 0,y,mapobject_id_cell,set,perturbation,cell_cycle,y_hat_Run_2,y_hat_Run_1
0,337.60318,305670,train,DMSO,G1,375.793793,346.304291
1,436.179108,248989,train,DMSO,G1,448.843842,413.998138
2,344.612427,212764,train,normal,G1,332.148438,339.042633
3,397.731262,383254,train,normal,G1,322.935028,361.397125
4,251.880371,287059,train,DMSO,G1,360.65802,282.111298


## 1.3.- Load random cells

In [None]:
# Show the target/prediction rows of the cells selected in params['cells']
selected_ids = np.array(params['cells'], dtype=np.int64)
mask = targets_df['mapobject_id_cell'].isin(selected_ids)
targets_df.loc[mask]

In [None]:
# Load the selected cells (image + mask) and plot them side by side
cells = {}
n_cells = len(params['cells'])
plt.figure(figsize=(n_cells*11,10))

for i, cell in enumerate(params['cells'], 1):
    temp_path = os.path.join(params['base_path'], 'Cells', cell+'.npz')
    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # the context manager closes it as soon as both arrays have been read
    with np.load(temp_path) as temp_cell:
        # Normalize cell (the division already produces a new array, no extra copy needed)
        normalized_img = temp_cell['img'] / norm_vals
        cells[cell+'_mask'] = np.array(temp_cell['mask'])
    # filter according to the input channels
    cells[cell+'_img'] = normalized_img[:,:,input_ids].astype(np.float32)

    # Plot cells: scale every channel by its own maximum and show input channels 10-12
    # (presumably rendered as an RGB image by plot_cell; confirm in NN_interpretability)
    temp_img = (cells[cell+'_img'] / np.max(cells[cell+'_img'], axis=(0,1)))[:,:,10:13]
    plt.subplot(1, n_cells, i)
    nn_inter.plot_cell(img=temp_img, title=cell)

In [None]:
# Sanity check: recompute y_hat from each image with the loaded models and
# print it next to the stored prediction and the true value
for model in os.listdir(models_path):
    print('Model: ', model)
    stored_prediction_col = 'y_hat_'+model
    for cell in params['cells']:
        print('\tmapobject_id_cell: '+cell)

        # Values stored in the targets table for this cell
        is_cell = targets_df.mapobject_id_cell == int(cell)
        y_true = round(targets_df.y[is_cell].values[0], 2)
        y_hat = round(targets_df[stored_prediction_col][is_cell].values[0], 2)

        # Fresh prediction: add a leading batch axis before calling the model
        batch = tf.expand_dims(cells[cell+'_img'], axis=0)
        y_hat_sanity = round(float(models[model].predict(batch)[0][0]), 2)

        print('\t\ty_true: {}, y_hat: {}, y_hat_sanity: {}'.format(y_true, y_hat, y_hat_sanity))

# 2.- Get Score Matrix for each cell (VarGrad IG)

## 2.1.- VarGrad IG without mask

### Get Score Maps

In [None]:
# get the score maps for each cell and model (no cell mask applied)
tic = time.time()
VarGrad_IG_no_mask = {}
for model in os.listdir(models_path):
    print('\nModel: ', model)
    for cell in params['cells']:
        print('\tProcessing mapobject_id_cell: '+cell)

        # One entry per (cell, model) pair, keyed as '<cell>_<model>'
        key = cell + '_' + model
        VarGrad_IG_no_mask[key] = nn_inter.get_VarGrad(img=cells[cell+'_img'],
                                                       baseline='black',
                                                       model=models[model],
                                                       n_images=15)
tac = time.time()
# Elapsed time in minutes, rounded to 2 decimals (ndigits belongs to round(), not format())
print('Score maps finished in (mins): {}'.format(round((tac-tic)/60, 2)))

### Save Score Maps

In [None]:
# Write every unmasked score map to '<output_path>/no_mask_<cell>_<model>.npy'
prefix = 'no_mask_'

for model in os.listdir(models_path):
    for cell in params['cells']:
        print('Saving Score map for cell {} and model {}'.format(cell, model))
        key = cell + '_' + model
        # .numpy() converts the stored score map to a NumPy array before saving
        score_map_array = VarGrad_IG_no_mask[key].numpy()
        np.save(os.path.join(output_path, prefix+key+'.npy'), score_map_array)

### Plot Score maps for both models and compare

In [None]:
# Plot the unmasked score maps of Run_1 and Run_2 for every cell
# Settings shared by all figures of this cell
plot_kwargs = dict(top_percent=1,
                   channels_df=channels_df,
                   img_size=(7,7),
                   score_map_same_sacale=True)
for cell in params['cells']:
    print('Plotting cell: '+cell)
    nn_inter.plot_VarGrad_IG_with_control(img=cells[cell+'_img'],
                                          score_map_1=VarGrad_IG_no_mask[cell+'_Run_1'],
                                          score_map_2=VarGrad_IG_no_mask[cell+'_Run_2'],
                                          **plot_kwargs)

In [None]:
# Release the unmasked score maps; they are not used in the remaining cells
del VarGrad_IG_no_mask

## 2.2.- VarGrad IG with mask

### Get Score Maps

In [None]:
# get the score maps for each cell and model, passing the cell mask
VarGrad_IG = {}
tic = time.time()
for model in os.listdir(models_path):
    print('\nModel: ', model)
    for cell in params['cells']:
        print('\tProcessing mapobject_id_cell: '+cell)

        # create a temp cell mask that matches the cell image shape:
        # the 2D mask is repeated along a new last axis, once per image channel
        n_channels = cells[cell+'_img'].shape[-1]
        temp_mask = np.repeat(cells[cell+'_mask'][:,:,None], n_channels, axis=2)

        # One entry per (cell, model) pair, keyed as '<cell>_<model>'
        key = cell + '_' + model
        VarGrad_IG[key] = nn_inter.get_VarGrad(img=cells[cell+'_img'],
                                               img_mask=temp_mask,
                                               baseline='black',
                                               model=models[model],
                                               n_images=15)
tac = time.time()
# Elapsed time in minutes, rounded to 2 decimals (ndigits belongs to round(), not format())
print('Score maps finished in: {} mins'.format(round((tac-tic)/60, 2)))

### Save Score Maps

In [None]:
# Write every masked score map to '<output_path>/<cell>_<model>.npy' (no prefix)
prefix = ''

for model in os.listdir(models_path):
    for cell in params['cells']:
        print('Saving Score map for cell {} and model {}'.format(cell, model))
        key = cell + '_' + model
        # .numpy() converts the stored score map to a NumPy array before saving
        score_map_array = VarGrad_IG[key].numpy()
        np.save(os.path.join(output_path, prefix+key+'.npy'), score_map_array)

### Plot Score maps for both models and compare

In [None]:
# Plot the masked score maps of Run_1 and Run_2 for every cell
# Settings shared by all figures of this cell
plot_kwargs = dict(top_percent=1,
                   channels_df=channels_df,
                   img_size=(7,7),
                   score_map_same_sacale=True)
for cell in params['cells']:
    print('Plotting cell: '+cell)
    nn_inter.plot_VarGrad_IG_with_control(img=cells[cell+'_img'],
                                          img_mask=cells[cell+'_mask'],
                                          score_map_1=VarGrad_IG[cell+'_Run_1'],
                                          score_map_2=VarGrad_IG[cell+'_Run_2'],
                                          **plot_kwargs)

### Plot Top 5%, 10%, 20%, 50% pixels

### 5%

In [None]:
# Same comparison plot, restricted to the top 5% pixels (top_percent=0.05)
plot_kwargs = dict(top_percent=0.05,
                   channels_df=channels_df,
                   img_size=(7,7),
                   score_map_same_sacale=False)
for cell in params['cells']:
    print('Plotting cell: '+cell)
    nn_inter.plot_VarGrad_IG_with_control(img=cells[cell+'_img'],
                                          img_mask=cells[cell+'_mask'],
                                          score_map_1=VarGrad_IG[cell+'_Run_1'],
                                          score_map_2=VarGrad_IG[cell+'_Run_2'],
                                          **plot_kwargs)

### 10%

In [None]:
# Same comparison plot, restricted to the top 10% pixels (top_percent=0.1)
plot_kwargs = dict(top_percent=0.1,
                   channels_df=channels_df,
                   img_size=(7,7),
                   score_map_same_sacale=False)
for cell in params['cells']:
    print('Plotting cell: '+cell)
    nn_inter.plot_VarGrad_IG_with_control(img=cells[cell+'_img'],
                                          img_mask=cells[cell+'_mask'],
                                          score_map_1=VarGrad_IG[cell+'_Run_1'],
                                          score_map_2=VarGrad_IG[cell+'_Run_2'],
                                          **plot_kwargs)

### 20%

In [None]:
# Same comparison plot, restricted to the top 20% pixels (top_percent=0.2)
plot_kwargs = dict(top_percent=0.2,
                   channels_df=channels_df,
                   img_size=(7,7),
                   score_map_same_sacale=False)
for cell in params['cells']:
    print('Plotting cell: '+cell)
    nn_inter.plot_VarGrad_IG_with_control(img=cells[cell+'_img'],
                                          img_mask=cells[cell+'_mask'],
                                          score_map_1=VarGrad_IG[cell+'_Run_1'],
                                          score_map_2=VarGrad_IG[cell+'_Run_2'],
                                          **plot_kwargs)

### 50%

In [None]:
# Same comparison plot, restricted to the top 50% pixels (top_percent=0.5)
plot_kwargs = dict(top_percent=0.5,
                   channels_df=channels_df,
                   img_size=(7,7),
                   score_map_same_sacale=False)
for cell in params['cells']:
    print('Plotting cell: '+cell)
    nn_inter.plot_VarGrad_IG_with_control(img=cells[cell+'_img'],
                                          img_mask=cells[cell+'_mask'],
                                          score_map_1=VarGrad_IG[cell+'_Run_1'],
                                          score_map_2=VarGrad_IG[cell+'_Run_2'],
                                          **plot_kwargs)