In [None]:
import os
# Install the pinned runtime dependencies into the interpreter running this
# kernel. `sys.executable -m pip` avoids picking up another Python's `pip`
# from PATH, and `check=True` stops Run-All here if installation fails
# (the previous os.system call silently ignored the exit status).
# NOTE(review): `requests` and `biopython` are unpinned, unlike the others.
import subprocess
import sys

subprocess.run(
    [sys.executable, '-m', 'pip', 'install', '-q',
     'glob2==0.7', 'requests', 'pytest-shutil==1.7.0', 'pyBigWig==0.3.18',
     'urllib3==1.26.14', 'tqdm==4.64.1', 'joblib==1.2.0', 'ipywidgets==8.0.4',
     'biopython'],
    check=True,
)

In [None]:
!rm -r TECSAS/
!git clone https://github.com/ed29rice/TECSAS.git

In [None]:
import TECSAS.TECSAS as TECSAS

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import matplotlib.colors as colors
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import dataset
from torch.nn import functional as F
from sklearn.metrics import confusion_matrix

In [None]:
# Directory holding both checkpoints loaded below: the pretrained weights
# (bv_K562_124.pt) and the training/test split (training_info_set_124.pt).
dpath = "./"

In [None]:
# ---- Input window / batching ----------------------------------------------
# NOTE(review): `n_neigbors` is misspelled, but later cells read it by this name.
n_neigbors = 14
n_predict = 3
NEXP = 124       # presumably the number of input experiments (tracks) — confirm
nbatches = 8000  # later cells derive the batch length as len(train_data) // nbatches

# ---- Transformer encoder hyper-parameters ---------------------------------
emsize, d_hid = 128, 64  # embedding size / feed-forward width in nn.TransformerEncoder
nlayers, nhead = 2, 8    # nn.TransformerEncoderLayer count / heads per nn.MultiheadAttention
dropout = 0.01           # dropout probability

# Model input width: NEXP times the window length (2*n_neigbors + 1);
# presumably NEXP values per window position — confirm.
nfeatures = (2 * n_neigbors + 1) * NEXP
ostates = 5  # number of output states given to the model (presumably the classes)

# Prefer the GPU when one is visible to torch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the TECSAS model from the configuration above and move it to `device`.
model = TECSAS.TECSAS(n_predict, emsize, nhead, d_hid, nlayers, nfeatures, ostates, dropout).to(device)

# Count trainable parameters. Tensor.numel() returns a plain int for every
# shape. The previous np.prod(p.size()) returned the float 1.0 for 0-d
# parameters and left an exhausted `filter` iterator in the namespace.
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Number of params:',params)

In [None]:
# n_neigbors, n_predict, NEXP and nbatches are already set in the configuration
# cell that built `model`. Re-assigning them here (as this cell used to) would
# silently override any edit made there and desynchronise them from the model.
# Instead, fail fast if the window constants no longer match the model's
# input width.
assert nfeatures == NEXP * (2 * n_neigbors + 1), (
    'n_neigbors/NEXP changed after the model was built; re-run the model cell'
)

In [None]:
# Load the pretrained K562 weights. map_location='cpu' lets a checkpoint saved
# on a GPU load on a CPU-only machine. load_state_dict then copies the values
# into the model's parameters, which already live on `device`.
# os.path.join avoids building '/bv_K562_124.pt' (filesystem root) when
# dpath is empty.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
dict_p = torch.load(os.path.join(dpath, 'bv_K562_124.pt'), map_location='cpu')

# Every stored key carries one extra leading component (presumably a wrapper
# prefix such as 'module.' — confirm). Drop everything up to and including the
# first '.' so the keys match the model's own parameter names.
tmp_dict = {k.partition('.')[2]: v for k, v in dict_p.items()}

model.load_state_dict(tmp_dict)
model.eval()

In [None]:
# Load the training/test split saved at training time. Tensors are mapped to
# the CPU: get_batch_test moves each batch to `device`, and a GPU-saved file
# then also loads on a CPU-only machine.
# NOTE(review): this cell used to re-assign NEXP=124; NEXP now comes only from
# the configuration cell that built the model.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(os.path.join(dpath, 'training_info_set_124.pt'), map_location='cpu')
epoch = checkpoint['epoch']
loss = checkpoint['best_val_loss']
train_data = checkpoint['train_data']
test_data = checkpoint['test_data']
ntest_loci = checkpoint['ntest_loci']  # sliced alongside test_data rows by get_batch_test
loci_indx = checkpoint['loci_indx']

In [None]:
# Re-count trainable parameters after loading the checkpoint. numel() yields
# a plain int for every shape and leaves no exhausted filter iterator behind.
params = sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
def get_batch_test(source: Tensor, i: int, n_predict: int, ndxs=None, batch_len=None):
    """Slice the ``i``-th evaluation batch out of ``source``.

    Rows ``[i*batch_len, (i+1)*batch_len)`` are taken. The first
    ``2*(n_predict-1)+1`` columns are returned as ``target``. The remaining
    columns are returned as ``data`` with a trailing singleton axis appended.

    Args:
        source: 2-D tensor whose rows are samples. Presumably it holds the
            label columns followed by the feature columns — confirm.
        i: Batch index.
        n_predict: Determines how many leading columns are targets.
        ndxs: Optional per-row sequence (e.g. locus ids) sliced with the same
            rows. When ``None``, ``indexes`` is ``None`` instead of raising.
        batch_len: Rows per batch. Defaults to the module-level ``bptt`` so
            existing calls behave exactly as before.

    Returns:
        ``(data, target, indexes)``, with ``data`` and ``target`` moved to the
        module-level ``device``.
    """
    if batch_len is None:
        batch_len = bptt  # backward-compatible default: notebook-level batch length
    n_target_cols = 2 * (n_predict - 1) + 1
    rows = slice(i * batch_len, (i + 1) * batch_len)
    batch = source[rows]
    data = batch[:, n_target_cols:][:, :, np.newaxis]
    target = batch[:, :n_target_cols]
    indexes = None if ndxs is None else ndxs[rows]
    return data.to(device), target.to(device), indexes

In [None]:
#Predicting subcompartments for K562 test set
# Rows per batch: the same batch length that len(train_data)//nbatches gives.
bptt = len(train_data)//nbatches
assert bptt > 0, 'train_data has fewer rows than nbatches; cannot form a batch'
# Ceil division so the final, partial batch of the test set is evaluated too.
# Floor division silently dropped up to bptt-1 test rows from every metric.
nbatches_eval = -(-len(test_data) // bptt)

# Per-batch CPU results; a later cell concatenates them.
# NOTE(review): despite their names, *_inputs collect the TRUE state at
# position n_predict-1 and *_targets collect the model INPUT features. The
# names are kept because later code may read them.
l=[]   # predicted states, every row
lt=[]  # true states, every row
failed_inputs=[]
failed_targets=[]
failed_pred=[]
failed_loci=[]
suc_inputs=[]
suc_targets=[]
suc_pred=[]
suc_loci=[]
with torch.no_grad():
    for batch in range(nbatches_eval):
        data, targets, batch_loci = get_batch_test(test_data, batch, n_predict=n_predict, ndxs=ntest_loci)
        if batch%10==0: print(batch, nbatches_eval, len(targets))
        # model(...)[0] is presumably per-position class scores: argmax over the
        # last dim, then keep output position n_predict-1.
        prediction = model(data, None)[0].argmax(dim=-1)[:, n_predict-1].cpu()
        # Bring everything to the CPU once so masks and tensors share a device.
        true_state = targets[:, n_predict-1].cpu()
        data_cpu = data.cpu()
        hit = prediction == true_state
        miss = ~hit
        failed_inputs.append(true_state[miss])
        failed_targets.append(data_cpu[miss])
        failed_pred.append(prediction[miss])
        failed_loci.append(batch_loci[miss])
        suc_inputs.append(true_state[hit])
        suc_targets.append(data_cpu[hit])
        suc_pred.append(prediction[hit])
        suc_loci.append(batch_loci[hit])
        l.append(prediction)
        lt.append(true_state)

In [None]:
# Flatten each per-batch list into a single numpy array. np.concatenate
# converts the CPU tensors collected during evaluation.
(failed_inputs, failed_pred, failed_targets, failed_loci,
 suc_inputs, suc_pred, suc_targets, suc_loci,
 l, lt) = [
    np.concatenate(chunks)
    for chunks in (
        failed_inputs, failed_pred, failed_targets, failed_loci,
        suc_inputs, suc_pred, suc_targets, suc_loci,
        l, lt,
    )
]

In [None]:
# Overall accuracy: fraction of evaluated rows whose predicted state equals the true state.
print('BT Accuracy:')
print('test:', np.round((l == lt).mean(), 4))

In [None]:
# Row-normalised confusion matrix: rows = true state, columns = predicted
# state, so the diagonal is per-class recall. sklearn's signature is
# confusion_matrix(y_true, y_pred), so the true labels `lt` must come first.
# The previous call passed (l, lt), which made normalize='true' normalise over
# the PREDICTED labels instead.
# NOTE(review): if that precision-style matrix was intended, it equals
# confusion_matrix(lt, l, normalize='pred').T.
conf_matrix_P = np.round(confusion_matrix(lt, l, normalize='true'), 2)
print('BT Confusion matrix:')
print(conf_matrix_P)