In [None]:
import glob
import numpy as np
import matplotlib.pyplot as plt
import torch
import optuna
import os
import pytorch_lightning as pl
import yaml

from Data.Drosophilla.FlyDataMod import FlyDataModule
from IPython.core.debugger import set_trace
from Models import Transformer as tr
from torch import nn as nn
from Utils import callbacks as cb
from Utils import evaluations as ev
from Utils import HyperParams as hp
from Utils import loggers as lg

# Seed the RNGs used below: NumPy, plus torch (model init, dropout and
# DataLoader shuffling draw from torch, which np.random.seed does not touch).
np.random.seed(0)
torch.manual_seed(0)

In [None]:
# Experiment grid: every label type is trained/tested on every cell line.
cell_lines = ['S2', 'KC', 'BG']
label_types = ['gamma', 'insulation', 'directionality']
label_vals = [0, 3, 10]

# Tuned hyper-parameter files, one per label type (Table 1..3, in the same
# order as label_types; the directory suffix is the capitalised label type).
ppaths = [
    "Experiments/Table_" + str(table_no) + "_Transformer_Tunning_"
    + label.capitalize() + "/BEST_HPARAMS.yaml"
    for table_no, label in enumerate(label_types, start=1)
]

In [None]:


# Train one Transformer per (label type, training cell line), using the tuned
# hyper-parameters stored in that label type's BEST_HPARAMS.yaml.
for label_type, label_val, ppath in zip(label_types, label_vals, ppaths):
    root_dir = "Experiments/Table_Cell_Lines_"+str(label_type)
    # makedirs also creates a missing "Experiments/" parent and is a no-op
    # when the directory already exists, so re-running the cell is safe.
    os.makedirs(root_dir, exist_ok=True)

    # yaml.load without an explicit Loader warns on PyYAML 5.x and raises
    # TypeError on PyYAML >= 6. safe_load assumes the file holds only plain
    # maps/scalars (TODO confirm if the writer ever emits python/* tags).
    # The `with` block closes the file handle, which was previously leaked.
    with open(ppath, 'r') as pfile:
        params = yaml.safe_load(pfile)

    for i, train_line in enumerate(cell_lines):
        # NOTE(review): i is presumably the logger version; the evaluation
        # cell reads checkpoints from optuna/version_<i>/ — confirm.
        logger = lg.DictLogger(i,
                               root_dir)
        trainer = pl.Trainer(
                    gpus=1,
                    logger=logger,
                    max_epochs=50,
                    callbacks=[cb.getcb()],
                    default_root_dir=root_dir)
        dm      = FlyDataModule(cell_line=train_line,
                                data_win_radius=params['data_win_radius'],
                                batch_size=params['hparams']['batch_size'],
                                label_type=label_type,
                                label_val=label_val)
        dm.setup()
        # Extra run metadata handed to the model through its `hparams` argument.
        hparams = {'train_line': train_line,
                   'label_type': label_type,
                   'label_val':  label_val,
                   'batch_size': params['hparams']['batch_size']}

        model_trans = tr.TransformerModule(
                ntoken=params['ntoken'],
                ninp=params['ninp'],
                nhid=params['nhid'],
                nhead=params['nhead'],
                nlayers=params['nlayers'],
                dropout=params['dropout'],
                loss_type=params['loss_type'],
                lr=params['lr'],
                hparams=hparams
            )
        trainer.fit(model_trans, dm)


In [None]:
metrics = ['mse', 'mae', 'r2', 'spearman', 'pearson']
data    = {}
# Evaluate on the first GPU when present; fall back to CPU so the cell still runs.
device  = "cuda:0" if torch.cuda.is_available() else "cpu"

# data[label_type][metric][i, j] holds the metric of the model trained on
# cell_lines[i] when evaluated on the test split of cell_lines[j].
for label_type, label_val, ppath in zip(label_types, label_vals, ppaths):
    data[label_type] = {met: np.zeros((len(cell_lines), len(cell_lines)))
                        for met in metrics}

    # Use the window radius the models were trained with (the training cell
    # reads it from this same file) rather than a hard-coded 5.
    with open(ppath, 'r') as pfile:
        win_radius = yaml.safe_load(pfile)['data_win_radius']

    # Test data depends only on (label_type, test_line): build each once
    # instead of once per training line.
    test_dms = []
    for test_line in cell_lines:
        dm = FlyDataModule(cell_line=test_line,
                           data_win_radius=win_radius,
                           batch_size=1,
                           label_type=label_type,
                           label_val=label_val)
        dm.setup()
        test_dms.append(dm)

    for i in range(len(cell_lines)):
        # The checkpoint depends only on the training line, so load it once
        # per i rather than once per (i, j) pair.
        ckpt_pattern = ("Experiments/Table_Cell_Lines_" + str(label_type)
                        + "/optuna/version_" + str(i) + "/checkpoints/*")
        ckpts = sorted(glob.glob(ckpt_pattern))  # glob order is arbitrary
        if not ckpts:
            raise FileNotFoundError("No checkpoint matches " + ckpt_pattern)
        # NOTE(review): assumes one checkpoint per version; with several,
        # the lexicographically first is used — confirm that is intended.
        model = tr.TransformerModule.load_from_checkpoint(ckpts[0]).to(device)
        model.eval()  # make sure dropout is disabled while scoring
        for j, dm in enumerate(test_dms):
            vec = ev.getModelMetrics(model, dm, 'test')
            for met in metrics:
                data[label_type][met][i, j] = vec[met]
            

In [None]:
# Display the nested results: data[label_type][metric] is a
# len(cell_lines) x len(cell_lines) matrix (row = train line, col = test line).
data


In [None]:
import matplotlib.gridspec as gridspec
def formall(ax,
            data,
            fm,
            sz=10):
    """Write each cell value onto a heat-map axes and strip its frame.

    Parameters
    ----------
    ax : Axes-like
        Target axes; must provide ``text``, ``spines``, ``set_xticks`` and
        ``set_yticks``.
    data : 2-D array-like
        Values to annotate; ``data[row, col]`` is written at x=col, y=row,
        matching ``imshow``'s layout.
    fm : int
        Number of decimal places in each label.
    sz : int, optional
        Font size of the labels (default 10).

    Returns
    -------
    The same ``ax``, for chaining.
    """
    template = "{:." + str(fm) + "f}"
    for (row, col), value in np.ndenumerate(data):
        ax.text(col, row, template.format(value),
                ha='center', va='center', size=sz)

    # Hide the border and ticks so only the image and its numbers remain.
    for side in ('right', 'top', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    return ax

# Grid of annotated heat-maps: one row per metric, one column per label type.
# Inside each panel, row = training cell line, column = test cell line.
# The grid is built once with the spacing applied directly. Previously a
# second GridSpec + plt.subplot overlaid fresh axes on the ones
# plt.subplots had already created.
fig, ax = plt.subplots(len(metrics), len(label_types), figsize=(5, 10),
                       gridspec_kw={'wspace': 0.05}, squeeze=False)

colors  = ["Oranges", "Greens", "Blues"]
for l, (label_type, color) in enumerate(zip(label_types, colors)):
    ax[0, l].set_title(label_type)
    for m, met in enumerate(metrics):
        disp = data[label_type][met]
        # vmax above the max presumably keeps colours light so the numbers
        # stay readable. NOTE(review): a non-positive max makes vmax <= vmin.
        ax[m, l].imshow(disp, cmap=color, vmax=1.75 * np.max(disp))
        # Directionality MSE is shown with 0 decimals; everything else with 2.
        if met == 'mse' and label_type == "directionality":
            formall(ax[m, l], disp, fm=0)
        else:
            formall(ax[m, l], disp, fm=2)
        if l == 0:
            ax[m, 0].set_ylabel(met)
plt.show()