In [None]:
import glob
import numpy as np
import matplotlib.pyplot as plt
import torch
import optuna
import os
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import yaml

from Data.Drosophilla.FlyDataMod import FlyDataModule
from IPython.core.debugger import set_trace
from Models import Transformer as tr
from torch import nn as nn
from Utils import callbacks as cb
from Utils import evaluations as ev
from Utils import HyperParams as hp
from Utils import loggers as lg

# Seed every RNG this notebook can touch so that a Restart & Run All gives
# the same figure. NumPy alone is not enough: the models are torch modules
# and any stochastic layer would otherwise draw from an unseeded generator.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# --- Hi-C contact map ---------------------------------------------------------
# Tab-separated heatmap file; it is sliced by integer position further down.
hic_df = pd.read_csv("Data/Drosophilla/HiC_Maps/GSE69013_S2_merged_IC-heatmap-20K.txt",
    sep='\t')

# Use the first GPU when there is one, otherwise fall back to the CPU so the
# notebook still runs (slowly) on a machine without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def make_datamodule(label_type, label_val, data_win_radius=5):
    """Build an S2 ``FlyDataModule`` for one label type and run ``setup()``.

    Args:
        label_type: label family passed to ``FlyDataModule``
            ("gamma", "insulation" or "directionality" in this notebook).
        label_val: parameter value for that label family.
        data_win_radius: window radius passed to ``FlyDataModule``.

    Returns:
        The data module, already set up.
    """
    dm = FlyDataModule(cell_line="S2",
                       data_win_radius=data_win_radius,
                       batch_size=1,
                       label_type=label_type,
                       label_val=label_val)
    dm.setup()
    return dm


def first_checkpoint(pattern):
    """Return the first checkpoint path matching ``pattern``.

    Matches are sorted so the choice is deterministic when a directory holds
    more than one file (``glob`` itself returns them in arbitrary order).

    Raises:
        FileNotFoundError: if nothing matches, instead of a bare IndexError.
    """
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No checkpoint found matching {pattern!r}")
    return matches[0]


def load_transformer(checkpoint_path):
    """Load a ``TransformerModule`` checkpoint onto ``DEVICE`` for inference.

    ``map_location`` lets a GPU-trained checkpoint load on a CPU-only machine,
    and ``eval()`` switches off training-only behaviour such as dropout.
    """
    model = tr.TransformerModule.load_from_checkpoint(checkpoint_path,
                                                      map_location=DEVICE)
    return model.to(DEVICE).eval()


# --- Data modules: one per label type, all for the S2 cell line ---------------
dm_gamma      = make_datamodule("gamma", 0)
dm_insul      = make_datamodule("insulation", 3)
dm_direc      = make_datamodule("directionality", 10)
features      = dm_gamma.train_dataloader().dataset.features

# --- Trained models: one transformer per label type ---------------------------
gamma_weights = first_checkpoint("Experiments/Table_1_Transformer_Tunning_Gamma/optuna/version_8/checkpoints/*")
gamma_model   = load_transformer(gamma_weights)

insulation_weights = first_checkpoint("Experiments/Table_2_Transformer_Tunning_Insulation/optuna/version_6/checkpoints/*")
insulation_model   = load_transformer(insulation_weights)

direction_weights = first_checkpoint("Experiments/Table_3_Transformer_Tunning_Directionality/optuna/version_5/checkpoints/*")
direction_model   = load_transformer(direction_weights)



# --- Region to visualise ------------------------------------------------------
hic_bias    = 0    # presumably the row offset between Hi-C bins and feature rows - TODO confirm
s_idx = 100
e_idx = 300

# Half-width of the model input window. Must match the ``data_win_radius``
# the data modules were built with; each prediction consumes
# 2 * WIN_RADIUS + 1 consecutive feature rows and reports the centre one.
WIN_RADIUS = 5

# A start index below the radius would turn ``s_idx - WIN_RADIUS`` negative,
# which numpy treats as "count from the end" and silently yields the wrong
# (usually empty) slice. A too-large end index would silently truncate it.
assert s_idx >= WIN_RADIUS, "s_idx must leave room for the left context window"
assert e_idx + WIN_RADIUS <= len(features), "e_idx must leave room for the right context window"

hic_data    = hic_df.iloc[s_idx+hic_bias:e_idx+hic_bias,s_idx+hic_bias:e_idx+hic_bias].to_numpy()
fea_data    = features[s_idx:e_idx,:].transpose()

# Fetch each dataset once instead of rebuilding a DataLoader per access.
gamma_ds    = dm_gamma.train_dataloader().dataset
insul_ds    = dm_insul.train_dataloader().dataset
direc_ds    = dm_direc.train_dataloader().dataset

gamma_lab   = gamma_ds.labels[s_idx:e_idx]
insul_lab   = insul_ds.labels[s_idx:e_idx]
direc_lab   = direc_ds.labels[s_idx:e_idx]


def predict_centers(model, feats, radius=WIN_RADIUS):
    """Slide a ``2 * radius + 1`` wide window over ``feats`` and predict each centre.

    Args:
        model: module called as ``model(window)``; after ``squeeze()`` its
            output is indexed along the window axis.
        feats: numpy array of feature rows, already padded with ``radius``
            extra rows on both sides of the region of interest.
        radius: half-width of the window.

    Returns:
        ``(batch, preds)`` where ``batch`` is ``feats`` as a float tensor with
        a leading batch axis on the model's device, and ``preds`` is a list of
        Python floats, one per window position.

    Note:
        The model is left in eval mode.
    """
    device = next(model.parameters()).device
    batch  = torch.unsqueeze(torch.from_numpy(feats).float(), dim=0).to(device)
    width  = 2 * radius + 1
    preds  = []
    model.eval()               # disable dropout etc. for deterministic output
    with torch.no_grad():      # pure inference: no autograd graph needed
        for i in range(batch.shape[1] - 2 * radius):
            window = batch[:, i:i + width, :]
            preds.append(model(window).squeeze()[radius].item())
    return batch, preds


lo, hi = s_idx - WIN_RADIUS, e_idx + WIN_RADIUS
gamma_feat, gamma_pred = predict_centers(gamma_model,      gamma_ds.features[lo:hi])
insul_feat, insul_pred = predict_centers(insulation_model, insul_ds.features[lo:hi])
direc_feat, direc_pred = predict_centers(direction_model,  direc_ds.features[lo:hi])

In [None]:
# Functions to help with plotting
def pcolormesh_45deg(ax, matrix_c, start=0, resolution=1, *args, **kwargs):
    """Draw a square matrix rotated by 45 degrees on ``ax``.

    Cell ``(row, col)`` is mapped so that its horizontal coordinate is the
    midpoint ``(row + col) / 2`` and its vertical coordinate is the offset
    ``col - row``. The main diagonal therefore lies on ``y = 0`` and entries
    above the diagonal get positive ``y``.

    Args:
        ax: object exposing ``pcolormesh(x, y, c, *args, **kwargs)``
            (a matplotlib Axes in this notebook).
        matrix_c: square 2-D array-like of values to colour.
        start: coordinate of the first bin edge.
        resolution: width of one bin in coordinate units.
        *args, **kwargs: forwarded unchanged to ``ax.pcolormesh``.

    Returns:
        The object returned by ``ax.pcolormesh``, with rasterization enabled.

    Raises:
        ValueError: if ``matrix_c`` is not a square 2-D matrix.
    """
    matrix_c = np.asarray(matrix_c)
    if matrix_c.ndim != 2 or matrix_c.shape[0] != matrix_c.shape[1]:
        raise ValueError(
            "matrix_c must be a square 2-D matrix, got shape %s" % (matrix_c.shape,))

    n = matrix_c.shape[0]
    # n cells need n + 1 bin edges.
    edges = start + resolution * np.arange(n + 1, dtype=float)

    # Row edges run in reverse because the matrix is flipped vertically below;
    # broadcasting a column vector against a row vector gives the full
    # (n + 1, n + 1) grid of corner coordinates.
    row_edges = edges[::-1][:, np.newaxis]
    col_edges = edges[np.newaxis, :]
    x = 0.5 * (col_edges + row_edges)
    y = col_edges - row_edges

    im = ax.pcolormesh(x, y, np.flipud(matrix_c), *args, **kwargs)
    # Keep the heatmap as a bitmap even in vector output formats.
    im.set_rasterized(True)
    return im

In [None]:
import matplotlib.gridspec as grd

# Layout: one tall Hi-C panel, six thin label/prediction tracks and a large
# feature heatmap, stacked in a single column.
NUM_SUBS = 8
gs = grd.GridSpec(ncols=1,
                  nrows=NUM_SUBS,
                  height_ratios=[4,1,1,1,1,1,1,15],
                  width_ratios=[1],
                  wspace=0.1,
                  hspace=0.1)

# Create every panel exactly once, directly on the GridSpec. Creating a
# default grid with plt.subplots() and then overlaying plt.subplot(gs[i])
# leaves a second, unused set of axes in the figure on matplotlib versions
# that no longer remove overlapped axes.
fig = plt.figure(figsize=(8,12))
ax  = [fig.add_subplot(gs[i]) for i in range(NUM_SUBS)]

# --- Panel 0: Hi-C map, rotated so the diagonal runs along the x axis --------
pcolormesh_45deg(ax[0], hic_data, cmap="Reds")
ax[0].set_aspect(0.5)
ax[0].set_ylim(0,50)
ax[0].yaxis.tick_right()
for side in ('top', 'left', 'right'):
    ax[0].spines[side].set_visible(False)
ax[0].set_yticks([0,50])

# --- Panels 1-6: label ("L") and prediction ("P") tracks ---------------------
ylabs     = ["L", "P", "L", "P","L", "P"]
plot_data = [gamma_lab, gamma_pred, insul_lab, insul_pred, direc_lab, direc_pred]
cols      = ['goldenrod','wheat','forestgreen','palegreen','cornflowerblue','lightblue']
for i, (plda, ylab, col) in enumerate(zip(plot_data, ylabs, cols), start=1):
    track = ax[i]
    track.plot(plda, c=col)
    # Span the whole series rather than a hard-coded bin count.
    track.set_xlim(0, len(plda))
    track.set_ylabel(ylab)
    track.axes.get_xaxis().set_visible(False)
    for side in ('top', 'right', 'left'):
        track.spines[side].set_visible(False)
    if i > 2:
        # Insulation and directionality tracks: swap the bottom spine for a
        # thin reference line at zero.
        track.spines['bottom'].set_visible(False)
        track.axhline(0, linewidth=0.5)
    track.yaxis.tick_right()

# --- Panel 7: feature heatmap -------------------------------------------------
# Row names come from the CSV header, skipping its first six columns
# (presumably non-feature columns - TODO confirm). The context manager closes
# the file, and stripping the line ending keeps "\n" out of the last label.
with open('Data/Drosophilla/s2_kc_bg_scaled_18_features_2901.csv') as header_file:
    header_line = header_file.readline()
labels = header_line.rstrip("\r\n").split(",")[6:]

ax[7].imshow(fea_data, 'PuOr', interpolation='none', aspect="auto")
ax[7].set_yticks(range(0, len(labels)))
ax[7].set_yticklabels(labels)
for side in ('top', 'right', 'left', 'bottom'):
    ax[7].spines[side].set_visible(False)

plt.show()
