# Data Visualization

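This notebook generates figures that illustrate how the insulation, directionality and gamma labels are computed from the Hi-C contact maps, and how the values of those labels are distributed across our training data.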

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Image
from IPython.core.debugger import set_trace
from torch import nn as nn
import torch
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader
from sklearn import preprocessing
from torch.nn import TransformerEncoderLayer, TransformerEncoder
from torch.nn import functional as F
from Data.Drosophilla import FlyDataMod as dm

In [None]:
# Hi-C contact-map text files, one per cell line (same order as `cell_lines`).
cell_line_hics = [
    "Data/Drosophilla/HiC_Maps/GSE69013_S2_merged_IC-heatmap-20K.txt",
    "Data/Drosophilla/HiC_Maps/GSE69013_KC_merged_IC-heatmap-20K.txt",
    "Data/Drosophilla/HiC_Maps/GSE69013_BG3_merged_IC-heatmap-20K.txt",
]

cell_lines = ['S2','KC','BG']

# (cell_line, chromosome) -> square sub-matrix of the contact map restricted
# to the bins of that chromosome.
chro_mats = {}

# Chromosome labels that are left out of `chro_mats`.
# The epigenetic data is not used for chro 4.
EXCLUDED_CHROS = ('4',)


def getChro(x):
    """Return the chromosome part of a bin label (the text before the first ':')."""
    return x.split(":")[0]


for cell_line, fn in zip(cell_lines, cell_line_hics):
    # Everything is read as text: the first row is skipped, column 0 holds the
    # bin labels and the remaining columns hold the numeric matrix.
    hic  = np.loadtxt(fn, dtype=str, skiprows=1)
    mat  = hic[:, 1:].astype(float)
    bins = hic[:, 0]
    # NOTE: rewritten for every cell line, so the file ends up holding the bin
    # labels of the last file in `cell_line_hics`.
    np.save("bins.npy", bins)

    chro_at_bins = np.array([getChro(label) for label in bins])
    chros        = np.unique(chro_at_bins)
    # Drop excluded chromosomes by label instead of by their position in the
    # sorted array, so the result does not depend on which labels a file has.
    chros        = chros[~np.isin(chros, EXCLUDED_CHROS)]

    for chro in chros:
        # Plain boolean indexing on rows, then columns.  `np.delete` with a
        # boolean mask only means "remove where True" from NumPy 1.19 on.
        keep = chro_at_bins == chro
        chro_mats[cell_line, chro] = mat[keep][:, keep]


In [None]:
# chart explaining Insulation score
class computeInsulation(torch.nn.Module):
    """Insulation profile, its derivative and boundary calls for contact maps.

    Args:
        window_radius: half-width ``r`` of the square window slid along the
            main diagonal; the window covers ``(2r+1) x (2r+1)`` entries.
        deriv_size: number of consecutive insulation values averaged on each
            side when forming the difference vector ``dv``.

    ``forward`` expects a 4-D tensor ``(batch, channels, H, W)``: the diagonal
    is taken over dims 2 and 3.  Boundaries are only read from channel 0.
    """

    def __init__(self, window_radius=10, deriv_size=10):
        super().__init__()
        self.window_radius = window_radius
        self.deriv_size  = deriv_size
        # Window mean; forward() rescales it to a window sum.
        self.di_pool     = torch.nn.AvgPool2d(kernel_size=(2*window_radius+1), stride=1)
        self.top_pool    = torch.nn.AvgPool1d(kernel_size=deriv_size, stride=1)
        self.bottom_pool = torch.nn.AvgPool1d(kernel_size=deriv_size, stride=1)

    def forward(self, x):
        """Compute the insulation quantities for ``x``.

        Returns:
            iv_no: sum of the contacts inside the window centred on each
                diagonal position that has a full window.
            iv: ``log2(iv_no / mean(iv_no))``, the mean being taken over every
                element of ``iv_no`` (all batch items together).
            dv: mean of ``deriv_size`` consecutive ``iv`` values minus the mean
                of the ``deriv_size`` values starting ``deriv_size`` earlier.
            boundaries: list with one 1-D index tensor per batch item holding
                the positions ``j`` where ``dv[j] < 0`` and ``dv[j+2] > 0``,
                shifted by ``window_radius + deriv_size``.
        """
        window_area = (2*self.window_radius + 1)**2
        # mean * area == sum of the window
        iv     = torch.diagonal(self.di_pool(x), dim1=2, dim2=3)*window_area
        iv_no  = iv
        iv     = torch.log2(iv/torch.mean(iv))
        top    = self.top_pool(iv[:,:,self.deriv_size:])
        bottom = self.bottom_pool(iv[:,:,:-self.deriv_size])
        dv     = top - bottom
        # Negative-to-positive change of dv across a gap of two positions.
        # Comparing two shifted views of dv avoids allocating zero padding,
        # so the computation stays on the device and dtype of the input.
        band   = (dv[:,:,:-2] < 0) & (dv[:,:,2:] > 0)
        offset = self.window_radius + self.deriv_size
        boundaries = [torch.where(band[i,0])[0] + offset
                      for i in range(band.shape[0])]
        return iv_no, iv, dv, boundaries


In [None]:
import torch
from torch.nn import functional as F
class computeDirectionality(torch.nn.Module):
    """Signed directionality score along the diagonal of a contact map.

    Two fixed ``(2*radius+1) x (2*radius+1)`` kernels are applied to the input:
    ``up`` sums the ``radius`` entries directly above the kernel centre (same
    column) and ``down`` sums the ``radius`` entries directly below it.

    Args:
        radius: number of entries summed on each side of the centre.
    """

    def __init__(self,
                 radius=2
                ):
        # Required so the module can be called as ``module(x)`` and so that
        # ``.to(...)`` / ``.cuda()`` move the kernels registered below.
        super().__init__()
        size = 2*radius + 1
        up   = torch.zeros((size, size))
        down = torch.zeros((size, size))
        up[:radius, radius]     = 1
        down[radius+1:, radius] = 1
        # conv2d weights are laid out as (out_channels, in_channels, kH, kW).
        self.register_buffer("up", up[None, None])
        self.register_buffer("down", down[None, None])

    def forward(self, x):
        """Compute the directionality terms for ``x`` (``(batch, 1, H, W)``).

        Returns:
            a: sum of the entries above each position (``up`` kernel).
            b: sum of the entries below each position (``down`` kernel).
            e: ``(a + b) / 2``.
            term: ``(a-e)**2/e + (b-e)**2/e``; NaN wherever ``e`` is 0.
            di_vec: main diagonal of ``sign(b-a) * term`` with size-1 dims
                removed -- 1-D for a single map, ``(batch, n)`` for a batch.
        """
        # Match the input so float64 or GPU inputs work without extra casts.
        up   = self.up.to(dtype=x.dtype, device=x.device)
        down = self.down.to(dtype=x.dtype, device=x.device)
        a       = F.conv2d(x, up)
        b       = F.conv2d(x, down)
        e       = (a+b)/2
        sign    = torch.sign(b-a)
        term    = ((a-e)**2)/e + ((b-e)**2)/e
        di      = sign * term
        # Take the diagonal over the two spatial dims *before* squeezing, so a
        # batch dimension is never mistaken for a spatial one.
        di_vec  = torch.diagonal(di, dim1=-2, dim2=-1).squeeze()
        return  a, b, e, term, di_vec

In [None]:
# ---- settings for the worked example -------------------------------------
cell_line_hic = "Data/Drosophilla/HiC_Maps/GSE69013_S2_merged_IC-heatmap-20K.txt"
cell_line = 'S2'
chro = '2L'
win_radius, deriv_size = 2, 2
radius = 4
insul_no_norm, insul_vecs, dv_vecs = {}, {}, {}

insulationComputer = computeInsulation(window_radius=win_radius,
                                       deriv_size=deriv_size)
directionComputer = computeDirectionality(radius=radius)

# Contact map of the chosen chromosome with two leading singleton dims added,
# every value divided by 10.
mat_torch = torch.from_numpy(chro_mats[cell_line, chro])[None, None] / 10

# The insulation pipeline is fed the rounded map, the directionality pipeline
# the unrounded one; `mat_torch` itself keeps the unrounded values.
iv_no, iv, dv, boundaries = insulationComputer(
    torch.round(mat_torch).to(dtype=torch.float32))
a, b, e, term, di_vec = directionComputer.forward(
    mat_torch.to(dtype=torch.float32))


def formall(ax, data, fm, sz=10):
    """Write every value of a 2-D array onto ``ax`` and strip the axes frame.

    Args:
        ax: axes-like object to draw on.
        data: 2-D array; entry ``[row, col]`` is written at ``x=col, y=row``.
        fm: number of decimals used when formatting each value.
        sz: text size passed to ``ax.text``.

    Returns:
        The same ``ax`` with all four spines hidden and no ticks.
    """
    template = "{:." + str(fm) + "f}"
    for (row, col), value in np.ndenumerate(data):
        ax.text(col, row, template.format(value),
                ha='center', va='center', size=sz)
    for side in ('right', 'top', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    return ax


In [None]:
print(mat_torch.shape)
print(iv.shape)

# One entry per figure: (figure size, values to draw, imshow options, decimals).
#   1. top-left 12x12 corner of the scaled contact map
#   2. raw window sums (first 8 positions)
#   3. log2-normalised insulation values (first 8 positions)
#   4. twice the difference vector (first 6 positions)
panels = [
    ((5, 5),   mat_torch[0, 0, 0:12, 0:12], dict(cmap="Reds", vmax=40),           0),
    ((4, 1),   iv_no[0, :, 0:8],            dict(cmap="Greens", vmin=0, vmax=400), 0),
    ((3.8, 1), iv[0, :, 0:8],               dict(cmap="Greens", vmax=2),           2),
    ((3.5, 1), 2*dv[0, :, 0:6],             dict(cmap="Greens", vmax=3),           2),
]

for figsize, data, imshow_kw, decimals in panels:
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(data, **imshow_kw)
    formall(ax, data, fm=decimals)
    plt.show()


In [None]:
#Direction
print(a.shape)
print(b.shape)

# a, b and e are drawn as 6x1 columns holding the first six diagonal entries;
# each entry is (tensor, colour-scale minimum, colour-scale maximum).
column_panels = [(a, 20, 80), (b, 20, 80), (e, 0, 150)]
for source, lo, hi in column_panels:
    fig, ax = plt.subplots(figsize=(1, 3.5))
    data = torch.diagonal(source[0, 0, 0:6, 0:6]).unsqueeze(1)
    ax.imshow(data, cmap="Blues", vmin=lo, vmax=hi)
    formall(ax, data, fm=0)
    plt.show()

# The unsigned term is drawn as a 1x6 row.
fig, ax = plt.subplots(figsize=(3.5, 1))
data = torch.diagonal(term[0, 0, 0:6, 0:6]).unsqueeze(0)
ax.imshow(data, cmap="Blues", vmin=0, vmax=25)
formall(ax, data, fm=0)
plt.show()

# Signed directionality values, also as a 1x6 row.
fig, ax = plt.subplots(figsize=(3.5, 1))
di_row = di_vec[0:6].unsqueeze(0)
ax.imshow(di_row, cmap="Blues", vmin=-4, vmax=25)
formall(ax, di_row, fm=0)
plt.show()

In [None]:
#Gamma
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator, FuncFormatter)
# The label table is read as text (first row skipped); column 7 is converted
# to float and only its first 13 rows are kept for the illustration.
gamma_df = np.loadtxt("Data/Drosophilla/clean_labels.csv",
                     dtype=str,
                     skiprows=1,
                     delimiter=',')
gammas = gamma_df[:,7].astype(float)[0:13]


# Column i of `tad_gammas` is 1 where gamma exceeds the threshold 0.1*i.
exps = 30 
tad_gammas = np.zeros((len(gammas), exps))
for i in range(0, exps):
    tad_gammas[:,i] = gammas>(.1*i)


fig, ax = plt.subplots(1, figsize=(2.5,5))
ax.imshow(tad_gammas, 
          cmap="OrRd",
          aspect='auto',
          origin='upper',
          interpolation='None')

ax.grid(color='black', linewidth=.1, which='minor')
ax.xaxis.set_major_locator(MultipleLocator(10))
ax.xaxis.set_minor_locator(MultipleLocator(1))
# Label each major tick with the threshold of the column it sits on.  A fixed
# list of labels is assigned by tick index, and MultipleLocator also emits a
# tick one step below the visible range, which shifts such labels by one.
ax.xaxis.set_major_formatter(
    FuncFormatter(lambda col, _pos: "{:.1f}".format(.1*col)))


ax.spines['left'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.show()


# Line plot of the same gamma values, rounded to two decimals.
print(gammas.shape)
fig, ax = plt.subplots(1, figsize=(3.5,1))
x = list(range(0, len(gammas)))
y = np.round(gammas ,decimals=2)
print(x)
ax.plot(x,y,color='orange')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_ylim(0,4)



# First six rounded gamma values as an annotated 1x6 strip.
fig, ax = plt.subplots(figsize=(3.5,1))
data=np.expand_dims(y,0)[:,0:6]
ax.imshow(data,cmap="Oranges", vmin=0, vmax=5)
formall(ax, data, fm=2)
plt.show()
print(data.shape)

In [None]:
def buildValueHistogram(cell_line, label_type, label_val):
    """Return the label values that feed one histogram panel.

    Builds a ``FlyDataModule`` and, through it, a ``FlyDataset`` over the
    "full" split with a data window radius of 1, then returns
    ``dataset.label_vecs[:, 1, 0]``.

    Args:
        cell_line: cell line name forwarded to the data module and dataset.
        label_type: label kind forwarded to the data module and dataset.
        label_val: label parameter forwarded to the data module and dataset.

    NOTE(review): presumably the returned slice holds one label value per
    sample -- confirm against the layout of ``FlyDataset.label_vecs``.
    """
    data_win_radius = 1
    batch_size = 1
    # Arguments shared by the data module and the dataset.
    shared = dict(cell_line=cell_line,
                  data_win_radius=data_win_radius,
                  label_type=label_type,
                  label_val=label_val)
    dmod = dm.FlyDataModule(batch_size=batch_size, **shared)
    dataset = dmod.FlyDataset(tvt="full", **shared)
    return dataset.label_vecs[:, 1, 0]

In [None]:
# One histogram per label type for the S2 cell line, drawn side by side.
l_types = ['gamma', 'insulation', 'directionality']
colors = ['goldenrod', 'forestgreen','cornflowerblue']
label_font = dict(fontfamily='sans-serif', fontsize=14)

fig, ax = plt.subplots(1, 3, figsize=(10, 3))
ax[0].set_ylabel("Num of Bins", **label_font)
for i, (l_type, color) in enumerate(zip(l_types, colors)):
    panel = ax[i]
    x = buildValueHistogram(cell_line="S2",
                            label_type=l_type,
                            label_val=10)
    panel.hist(x, rwidth=0.95, color=color, bins=20)
    for side in ('top', 'right'):
        panel.spines[side].set_visible(False)
    panel.set_xlabel(l_type.capitalize(), **label_font)
plt.show()