In [None]:
import sys, os, re, random, warnings, subprocess, time
sys.path.append(os.path.dirname(os.getcwd()))
warnings.filterwarnings("ignore")

from tqdm import tqdm_notebook
from itertools import product
from random import shuffle

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

from torchpgm.model import *
from torchpgm.edge import Edge
from torchpgm.layers import *
from torchpgm.graphic import draw_G
from torchpgm.data import RBMDataWithPAM, RBMData

from sklearn.metrics import * 
from config import *
from utils import *
from utils.pam import *

from walker import * 

In [None]:
# --- Run configuration --------------------------------------------------------
device = "cuda"                  # NOTE(review): no visible cell uses this; models are moved to "cpu" below
folder = f"{DATA}/cas9/vink"     # Vink Cas9 data directory (DATA comes from the config module)

# Model / training sizes
batch_size, Nh, Npam = 400, 200, 5          # Nh = hidden-layer size; Npam = PAM positions predicted (presumably)
n_epochs, start_supervision_epoch = 4000, 10

# Regularisation / optimiser settings
l1b, l2, lambda_pa, lr = 0.25, 0.5, 0.00, 0.0001

# Layer names of the RBM graph
visible_layers, hidden_layers = ["pi"], ["hidden"]

In [None]:
import pandas as pd
from scipy.stats import *

import seaborn as sns
from sklearn.cluster import AgglomerativeClustering
from scipy.stats import *
from sklearn.linear_model import *
from sklearn.ensemble import *
from sklearn.naive_bayes import *
from sklearn.svm import *
from sklearn.neighbors import *
from sklearn.preprocessing import *
from sklearn.model_selection import *


def test_accept(X, pam):
    """Return, for each sample in X, 1 if it accepts every nucleotide of `pam`, else 0.

    Each sample `x` is iterated in parallel with `pam`; position p is rejected when
    x[p][NUC_IDS[pam[p]]] == 0. Every position is checked (no short-circuit).
    """
    verdicts = []
    for sample in X:
        rejected = [bool(row[NUC_IDS[nuc]] == 0) for row, nuc in zip(sample, pam)]
        verdicts.append(0 if any(rejected) else 1)
    return verdicts

def test_position(X, i):
    return [int(x[i].sum() == 4) if len(x)>i else 1 for x in X]

def test_position_aa(X, i, j):
    return [int(x[i,j] == 1) if len(x)>i else 1 for x in X]

In [None]:
# Gamma values of the saved SSL-RBM checkpoints (one weights file per gamma), sorted ascending.
# The exact float values matter: they are formatted into checkpoint file names, so every series
# keeps its original arithmetic (e.g. 1.4**i/1000 rather than 0.001*1.4**i).
gamma_series = (
    [1],
    [2 * 1.05**i for i in range(50)],
    [0],
    [1e-7 * 1.05**i for i in range(100)],
    [1.4**i / 1000 for i in range(50)],
    [0.01 * 1.05**i for i in range(150)],
    [0.0001 * 1.05**i for i in range(100)],
    [1e-5 * 1.05**i for i in range(50)],
    [10000 * 1.1**i for i in range(50)],
    [20 * 1.08**i for i in range(50)],
)
gammas = sorted(g for series in gamma_series for g in series)
best_epoch = 90  # checkpoint epoch loaded for every gamma

In [None]:
# Labelled training set (Vink data): for the last 154 sequences, a 21 x 736 one-hot block whose
# channel 0 is 1 - (sum of the 20 other channels) (presumably the gap channel), flattened per row.
# NOTE(review): absolute path; `folder` from the config cell presumably points to the same directory.
vink_data = torch.load(f"/home/malbranke/data/cas9/vink/data.pt")
data = torch.stack(list(vink_data["x"]),0)
X_train = []
for x in data[-154:]:  # only the last 154 rows were ever kept, so don't build the others
    x_ = torch.zeros(21,736)
    x_[1:] = x
    x_[0] = 1-x.sum(0)
    X_train.append(x_.flatten())
X_train = torch.stack(X_train,0)

# Hidden representation and classifier-head output of every gamma checkpoint on the training set.
# torch.no_grad (as in the test-set cells): these tensors are later passed to sklearn
# (StandardScaler, PCA), whose np.asarray conversion fails on tensors that require grad.
# NOTE(review): model_rbm_ssl is not defined in any visible cell.
Xs_train_pi_ssl = []
Xs_train_predict = []
with torch.no_grad():
    for gamma in tqdm_notebook(gammas):
        model_full_name = f"rbmssl2_pid_h{Nh}_npam{Npam}_gamma{gamma}"
        model_rbm_ssl.load(f"{DATA}/cas9/vink/weights/{model_full_name}_{best_epoch}.h5")
        model_rbm_ssl = model_rbm_ssl.to("cpu")
        X_pi_ssl = model_rbm_ssl.edges["pi -> hidden"](X_train, False)
        Xs_train_pi_ssl.append(X_pi_ssl)
        Xs_train_predict.append(model_rbm_ssl.classifier(X_pi_ssl))


In [None]:
# Binary PAM-acceptance labels of the same last 154 training sequences:
# Y_train[k, s] = test_accept result of sample s for the k-th PAM in PAM4.
Y_ = torch.stack([label[:4] for label in vink_data["y"][-154:]], 0)  # first 4 PAM positions
Y_train = torch.tensor([test_accept(Y_, lit_pam) for lit_pam in PAM4]).int()

In [None]:
# Walton test set: PAM activity table with each variant's PI-domain sequence (pi_seq) and
# its aligned form (pi_msa).
df_pivot = pd.read_csv(f"/home/malbranke/data//cas9//walton/df_pivot.csv", index_col = 0)
X_seq = torch.zeros(len(df_pivot), 40, len(nnz_idx))
for i, seq in enumerate(df_pivot.pi_seq):
    X_seq[i, :20] = torch.tensor(to_onehot([AA_IDS[x] for x in seq], (None, 20)).T)

# One-hot encode the aligned sequences on 21 channels: 0 for characters outside AA (e.g. gaps),
# AA_IDS + 1 otherwise. The alignment width is read from the MSA itself. It used to come from
# `collias_data`, which no cell of this notebook defines, so a fresh kernel raised NameError.
msa_width = len(df_pivot.pi_msa.iloc[0]) if len(df_pivot) else 0
X_test = torch.zeros(len(df_pivot), 21, msa_width)
for i, seq in enumerate(df_pivot.pi_msa):
    X_test[i] = torch.tensor(to_onehot([AA_IDS[x]+1 if x in AA else 0 for x in seq], (None, 21)).T).float()
walton_data = {"pam":[torch.ones(4,4) for _ in range(len(df_pivot))], "pi_seq":list(X_seq)}

# Test-set hidden representation and classifier output for every "rbmssl2_pid" checkpoint,
# plus val_classifier results (kept in preds2).
# NOTE(review): model_rbm_ssl and val_loader_labelled are not defined in any visible cell, and
# Xs_pi_ssl / Xs_predict are overwritten by the two following cells.
Xs_pi_ssl = []
Xs_predict = []
preds2 = []
with torch.no_grad():
    for gamma in tqdm_notebook(gammas):
        model_full_name = f"rbmssl2_pid_h{Nh}_npam{Npam}_gamma{gamma}"
        model_rbm_ssl.load(f"{DATA}/cas9/vink/weights/{model_full_name}_{best_epoch}.h5")
        model_rbm_ssl = model_rbm_ssl.to("cpu")
        X_pi_ssl = model_rbm_ssl.edges["pi -> hidden"](X_test, False)
        Xs_pi_ssl.append(X_pi_ssl)
        Xs_predict.append(model_rbm_ssl.classifier(X_pi_ssl))
        preds2.append(model_rbm_ssl.val_classifier(val_loader_labelled, visible_layers, hidden_layers, 0))


In [None]:
# Walton test set: PAM activity table with each variant's PI-domain sequence (pi_seq) and
# its aligned form (pi_msa).
df_pivot = pd.read_csv(f"/home/malbranke/data//cas9//walton/df_pivot.csv", index_col = 0)
X_seq = torch.zeros(len(df_pivot), 40, len(nnz_idx))
for i, seq in enumerate(df_pivot.pi_seq):
    X_seq[i, :20] = torch.tensor(to_onehot([AA_IDS[x] for x in seq], (None, 20)).T)

# One-hot encode the aligned sequences on 21 channels: 0 for characters outside AA (e.g. gaps),
# AA_IDS + 1 otherwise. The alignment width is read from the MSA itself. It used to come from
# `collias_data`, which no cell of this notebook defines, so a fresh kernel raised NameError.
msa_width = len(df_pivot.pi_msa.iloc[0]) if len(df_pivot) else 0
X_test = torch.zeros(len(df_pivot), 21, msa_width)
for i, seq in enumerate(df_pivot.pi_msa):
    X_test[i] = torch.tensor(to_onehot([AA_IDS[x]+1 if x in AA else 0 for x in seq], (None, 21)).T).float()
walton_data = {"pam":[torch.ones(4,4) for _ in range(len(df_pivot))], "pi_seq":list(X_seq)}

# Same evaluation for the "classifier_rbmssl2_pid" checkpoints (results kept in preds3).
# NOTE(review): model_rbm_ssl and val_loader_labelled are not defined in any visible cell, and
# this cell overwrites Xs_pi_ssl / Xs_predict from the previous one.
Xs_pi_ssl = []
Xs_predict = []
preds3 = []
with torch.no_grad():
    for gamma in tqdm_notebook(gammas):
        model_full_name = f"classifier_rbmssl2_pid_h{Nh}_npam{Npam}_gamma{gamma}"
        model_rbm_ssl.load(f"{DATA}/cas9/vink/weights/{model_full_name}_{best_epoch}.h5")
        model_rbm_ssl = model_rbm_ssl.to("cpu")
        X_pi_ssl = model_rbm_ssl.edges["pi -> hidden"](X_test, False)
        Xs_pi_ssl.append(X_pi_ssl)
        Xs_predict.append(model_rbm_ssl.classifier(X_pi_ssl))
        preds3.append(model_rbm_ssl.val_classifier(val_loader_labelled, visible_layers, hidden_layers, 0))

In [None]:
class PAM_classifier2(nn.Module):
    """Two-layer MLP head mapping hidden-unit activations to PAM logits.

    forward: BatchNorm1d -> ReLU -> Linear(in_features, hidden_features) -> BatchNorm1d
    -> ReLU -> Linear(hidden_features, out_features).

    Args:
        in_features: input size (this notebook passes the RBM hidden size, Nh).
        out_features: number of output logits (this notebook passes Npam * 4).
        dropout: accepted for signature compatibility but unused; no dropout layer is applied.
        hidden_features: width of the intermediate layer (default 50, the original hard-coded width).
    """

    def __init__(self, in_features, out_features, dropout = 0.8, hidden_features = 50):
        super(PAM_classifier2, self).__init__()
        self.bn1 = nn.BatchNorm1d(in_features)
        self.bn2 = nn.BatchNorm1d(hidden_features)

        self.linear1 = nn.Linear(in_features, hidden_features)
        self.linear2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Return logits of shape (batch, out_features) for x of shape (batch, in_features)."""
        h = F.relu(self.bn1(x))
        h = F.relu(self.bn2(self.linear1(h)))
        return self.linear2(h)

    def l1b_reg(self):
        """Return the sum of squared weights of `linear1` (despite the name, a squared L2 norm).

        Uses the weight Parameter rather than `.data`. With `.data` the result was detached, so
        adding it to a loss contributed no gradient. The returned value is unchanged.
        """
        return self.linear1.weight.pow(2).sum()


# Walton test set: PAM activity table with each variant's PI-domain sequence (pi_seq) and
# its aligned form (pi_msa).
df_pivot = pd.read_csv(f"/home/malbranke/data//cas9//walton/df_pivot.csv", index_col = 0)
X_seq = torch.zeros(len(df_pivot), 40, len(nnz_idx))
for i, seq in enumerate(df_pivot.pi_seq):
    X_seq[i, :20] = torch.tensor(to_onehot([AA_IDS[x] for x in seq], (None, 20)).T)

# One-hot encode the aligned sequences on 21 channels: 0 for characters outside AA (e.g. gaps),
# AA_IDS + 1 otherwise. The alignment width is read from the MSA itself. It used to come from
# `collias_data`, which no cell of this notebook defines, so a fresh kernel raised NameError.
msa_width = len(df_pivot.pi_msa.iloc[0]) if len(df_pivot) else 0
X_test = torch.zeros(len(df_pivot), 21, msa_width)
for i, seq in enumerate(df_pivot.pi_msa):
    X_test[i] = torch.tensor(to_onehot([AA_IDS[x]+1 if x in AA else 0 for x in seq], (None, 21)).T).float()
walton_data = {"pam":[torch.ones(4,4) for _ in range(len(df_pivot))], "pi_seq":list(X_seq)}

# Same evaluation for the "classifier2_rbmssl_drelu_pid" checkpoints, whose classifier head is a
# PAM_classifier2 (results kept in preds4).
# NOTE(review): model_rbm_ssl and val_loader_labelled are not defined in any visible cell, and
# this cell overwrites Xs_pi_ssl / Xs_predict. The new classifier is never put in eval mode, so
# its BatchNorm layers may use batch statistics here. Confirm what load()/val_classifier do.
Xs_pi_ssl = []
Xs_predict = []
preds4 = []
classifier = PAM_classifier2(Nh, Npam * 4, dropout = 0.)
model_rbm_ssl.classifier = classifier
with torch.no_grad():
    for gamma in tqdm_notebook(gammas):
        model_full_name = f"classifier2_rbmssl_drelu_pid_h{Nh}_npam{Npam}_gamma{gamma}"
        model_rbm_ssl.load(f"{DATA}/cas9/vink/weights/{model_full_name}_{best_epoch}.h5")
        model_rbm_ssl = model_rbm_ssl.to("cpu")
        X_pi_ssl = model_rbm_ssl.edges["pi -> hidden"](X_test, False)
        Xs_pi_ssl.append(X_pi_ssl)
        Xs_predict.append(model_rbm_ssl.classifier(X_pi_ssl))
        preds4.append(model_rbm_ssl.val_classifier(val_loader_labelled, visible_layers, hidden_layers, 0))

In [None]:
from scipy.ndimage import *

In [None]:
# Field 1 of each preds2 entry (plotted as balanced accuracy) versus gamma. The curve is smoothed
# with a Gaussian kernel (sigma = n points) and shaded with a +/- rolling standard deviation.
# gammas[0] == 0 is skipped because the x axis is logarithmic.
n = 30
values = np.array([result[1] for result in preds2])
fvalues = gaussian_filter1d(values[1:], n, mode="nearest")
plt.plot(gammas[1:], fvalues, c="red")

window_stds = []
for start in range(1, len(values) - n):
    window_stds.append(np.std(values[start:start + n]))
errors = np.array(window_stds)
centre = fvalues[n//2:-n//2]
plt.fill_between(gammas[n//2+1:-n//2], centre - errors, centre + errors, color="red", alpha = 0.2)

plt.xscale("log")
plt.ylim(0.7, 0.95)
plt.xlim(1e-6, 1e5)
plt.xlabel("Gamma (strength of the classifier)")
plt.ylabel("Balanced Accuracy for PAM prediction")

In [None]:
n = 20
# Y[k, s]: log10 PAM activity of Walton sample s for the k-th PAM in PAM4 (rows = PAM4 columns of
# df_pivot). Later cells binarise it with the threshold -3.5. A dead first assignment that
# binarised Y and was immediately overwritten has been removed.
Y = torch.tensor(np.log10(df_pivot[PAM4].values.T))

In [None]:
def safe_roc_auc_score(y, pred):
    """Return the ROC AUC of `pred` against binary labels `y`, or NaN when undefined.

    roc_auc_score raises ValueError when `y` contains a single class. This wrapper returns
    float("nan") in that case, as its name promises, and otherwise returns the plain AUROC.
    """
    if len(np.unique(np.asarray(y))) < 2:
        return float("nan")
    return roc_auc_score(y, pred)

In [None]:
# Score PAM-acceptance predictions on the Walton test set for every gamma and every sufficiently
# balanced 4-nt PAM in PAM4. Scores are appended per method, in this order:
#   0. the RBM classifier head (Xs_predict),
#   1. (only if USE_LOGREG_CV) LogisticRegressionCV on the standardised hidden representation,
#   2. LogisticRegression on the standardised hidden representation.
# pears / spears / p_vals are flat lists in (gamma, accepted PAM, method) order. Despite its
# name, `spears` holds AUROC values (safe_roc_auc_score).
# NOTE(review): Xs_pi_ssl / Xs_predict come from whichever Walton cell ran last, while
# Xs_train_pi_ssl comes from the "rbmssl2_pid" checkpoints. Confirm both come from the same models.

ACTIVITY_THRESHOLD = -3.5  # log10 activity above which a PAM counts as accepted
MAX_IMBALANCE = 0.45       # skip PAMs whose positive rate is within 5% of 0 or 1 (train or test)
USE_LOGREG_CV = False      # was disabled in the original via `zip([], ...)`


def _as_numpy(t):
    """Return `t` as a NumPy array, detaching torch tensors first.

    sklearn converts inputs with np.asarray, which raises on tensors that require grad.
    """
    if torch.is_tensor(t):
        return t.detach().cpu().numpy()
    return np.asarray(t)


def _record_scores(y_true, y_score):
    """Append Pearson r, AUROC and one-sided Spearman p-value of y_score vs y_true to the global lists."""
    pears.append(pearsonr(y_true, y_score)[0])
    spears.append(safe_roc_auc_score(y_true, y_score))
    p_vals.append(spearmanr(y_true, y_score, alternative="greater")[1])


pears = []
spears = []
p_vals = []

clfs = []     # unused here; kept so later cells see the same names
all_preds = []
all_y = []
all_accepted = []
nb_data = []  # unused here; kept so later cells see the same names

sk_models = ([("LogRegCV", LogisticRegressionCV)] if USE_LOGREG_CV else []) + [("LogReg", LogisticRegression)]

with torch.no_grad():
    for X_train_pi_ssl, X_pi_ssl, X_pred in tqdm_notebook(zip(Xs_train_pi_ssl, Xs_pi_ssl, Xs_predict)):
        # Standardisation depends only on the training features, so fit it once per gamma.
        x_train = _as_numpy(X_train_pi_ssl)
        scaler = StandardScaler().fit(x_train)
        z_train = scaler.transform(x_train)
        z_test = scaler.transform(_as_numpy(X_pi_ssl))

        for i, (lit_pam, pam) in enumerate(zip(PAM4, PAM4_tensor)):
            y_train = Y_train[i]
            y_test = np.array((Y[i] > ACTIVITY_THRESHOLD), dtype=int)

            train_rate = np.mean(y_train.float().numpy())
            if np.abs(0.5 - train_rate) >= MAX_IMBALANCE or np.abs(0.5 - np.mean(y_test)) >= MAX_IMBALANCE:
                continue

            # Previously unreachable (it was indented under `continue`), so all_y stayed empty.
            all_y.append(y_test)
            all_accepted.append(lit_pam)

            # Classifier head: logits viewed as (sample, Npam, 4). For each of the first 4 positions,
            # take the entry selected by pam.argmax(0) (presumably this PAM's nucleotide) and average.
            pred = X_pred.view(-1, Npam, 4)[:, torch.arange(4), pam.argmax(0)].mean(-1).detach().numpy()
            preds = [pred]
            _record_scores(y_test, pred)

            for name, model_cls in sk_models:
                clf = model_cls().fit(z_train, y_train)
                pred = clf.predict_proba(z_test)[:, 1]
                preds.append(pred)
                _record_scores(y_test, pred)

            all_preds.append(preds)

In [None]:
# Shape of the stacked predictions: (gamma * accepted PAM, method, test sample).
np.asarray(all_preds).shape

In [None]:
# Peek at one stored prediction: gamma index 300, last accepted PAM, method 0 (classifier head),
# last test sample. all_preds is laid out as (gamma, accepted PAM, method, sample).
# The reshape uses len(gammas) instead of `all_accepted_idx`, which is only defined in a later
# cell. The old sample index 105 was out of bounds for a 105-sample axis; -1 selects the last one.
preds_arr = np.array(all_preds)
preds_arr.reshape(len(gammas), -1, *preds_arr.shape[1:])[300, -1, 0, -1]

In [None]:
# Indices into PAM4 of the PAMs kept by the evaluation loop. Both the training labels and the
# binarised test activities (log10 > -3.5) need a positive rate strictly between 5% and 95%.
# The set does not depend on gamma. The old outer loop over gammas broke after its first
# iteration and is dropped.
all_accepted_idx = []
for i, (lit_pam, pam) in enumerate(zip(PAM4, PAM4_tensor)):
    train_rate = np.mean(Y_train[i].float().numpy())
    test_rate = np.mean(np.array((Y[i] > -3.5), dtype=int))
    if np.abs(0.5 - train_rate) >= 0.45 or np.abs(0.5 - test_rate) >= 0.45:
        continue
    all_accepted_idx.append(i)

In [None]:
# Measured PAM profile as sequence-logo heights, per Walton test sample.
# Y2[p, b, s] starts as the summed linear activity 10**Y of every PAM in PAM4 whose nucleotide at
# position p is b (b in "ATCG" order), for sample s.
Y2 = torch.zeros(16, Y.size(1))
for k, (j, nuc) in tqdm_notebook(enumerate(product(range(4), "ATCG"))):
    for i, (lit_pam, pam) in enumerate(zip(PAM4, PAM4_tensor)):
        if lit_pam[j] == nuc:
            Y2[k] += torch.exp(Y[i] * np.log(10))
Y2 = Y2.reshape(4, 4, -1)
# Normalise to per-position nucleotide frequencies first, then scale by the information content
# R = 2 + sum_b p_b log2 p_b (bits), matching the Y2_pred cell. The entropy used to be computed on
# the raw summed activities, which is not an information content. xlogy(0, 0) = 0 avoids NaNs.
Y2 = Y2 / Y2.sum(1)[:, None]
R = 2 + (torch.xlogy(Y2, Y2) / np.log(2)).sum(1)[:, None]
Y2 = Y2 * R


In [None]:
# Measured PAM profile as sequence-logo heights, per Walton test sample.
# Y2[p, b, s] starts as the summed linear activity 10**Y of every PAM in PAM4 whose nucleotide at
# position p is b (b in "ATCG" order), for sample s.
Y2 = torch.zeros(16, Y.size(1))
for k, (j, nuc) in tqdm_notebook(enumerate(product(range(4), "ATCG"))):
    for i, (lit_pam, pam) in enumerate(zip(PAM4, PAM4_tensor)):
        if lit_pam[j] == nuc:
            Y2[k] += torch.exp(Y[i] * np.log(10))
Y2 = Y2.reshape(4, 4, -1)
# Normalise to per-position nucleotide frequencies first, then scale by the information content
# R = 2 + sum_b p_b log2 p_b (bits), matching the Y2_pred cell. The entropy used to be computed on
# the raw summed activities, which is not an information content. xlogy(0, 0) = 0 avoids NaNs.
Y2 = Y2 / Y2.sum(1)[:, None]
R = 2 + (torch.xlogy(Y2, Y2) / np.log(2)).sum(1)[:, None]
Y2 = Y2 * R


In [None]:
# Display the sorted gamma grid; several later cells use index 300 of this list as their single gamma.
gammas

In [None]:
# Shape of the training-set hidden representation at gamma index 300 (used by the PCA cell below).
Xs_train_pi_ssl[300].shape

In [None]:
from sklearn.decomposition import *
with torch.no_grad():
    # Fit a 3-component PCA on the training-set hidden representation at gamma index 300 and
    # project the Walton test-set representation onto it.
    # .detach(): the training features may carry autograd history (they were not always computed
    # under no_grad), and sklearn's np.asarray conversion fails on such tensors.
    pcaer = PCA(3).fit(Xs_train_pi_ssl[300].detach().cpu().numpy())
    X_pca = pcaer.transform(Xs_pi_ssl[300].detach().cpu().numpy())

In [None]:
# One panel per PAM in PAM4: test samples on the first two PCA components, green if the measured
# log10 activity for that PAM exceeds -3.5, red otherwise.
label_colors = np.array(["red", "green"])
plt.figure(figsize=(30, 120))
for k, lit in enumerate(PAM4):
    accepted = (Y[k] > -3.5).int().numpy()
    plt.subplot(32, 8, k + 1)
    plt.scatter(X_pca[:, 0], X_pca[:, 1], c=label_colors[accepted])
    plt.title(lit)
plt.show()

In [None]:

# Same PCA scatter grid as above, zoomed on the central cluster (fixed x/y limits per panel).
label_colors = np.array(["red", "green"])
plt.figure(figsize=(30, 120))
for k, lit in enumerate(PAM4):
    accepted = (Y[k] > -3.5).int().numpy()
    plt.subplot(32, 8, k + 1)
    plt.scatter(X_pca[:, 0], X_pca[:, 1], c=label_colors[accepted])
    plt.title(lit)
    plt.xlim(-0.6, 0.5)
    plt.ylim(-0.07, 0.05)
plt.show()

In [None]:
# Shape check: gamma-300 classifier output viewed as (sample, 5, 4), softmaxed over the last axis,
# last of the 5 positions dropped, then permuted to (position, 4, sample).
Xs_predict[300].reshape(-1, 5,4).softmax(-1)[:,:-1].permute(1,2,0).size()

In [None]:
# Predicted PAM profile (gamma index 300) as logo heights. Logits/2 are viewed as (sample, 5, 4),
# the 5th position is dropped, sigmoid is applied and axes become (position, nucleotide, sample).
# Each position is normalised to frequencies and scaled by R = 2 + sum p log2 p.
halved_logits = Xs_predict[300] / 2
probs = halved_logits.reshape(-1, 5, 4)[:, :-1].sigmoid().permute(1, 2, 0)
freqs = probs / (probs.sum(1)[:, None])
R = 2 + (freqs * np.log2(freqs)).sum(1)[:, None]
Y2_pred = freqs / freqs.sum(1)[:, None] * R

In [None]:
import seaborn as sns

# Use seaborn's white-grid style for all subsequent matplotlib figures.
sns.set_style("whitegrid")

In [None]:
# Indices into PAM4 of the PAMs kept by the evaluation loop. Both the training labels and the
# binarised test activities (log10 > -3.5) need a positive rate strictly between 5% and 95%.
# The set does not depend on gamma. The old outer loop over gammas broke after its first
# iteration and is dropped.
all_accepted_idx = []
for i, (lit_pam, pam) in enumerate(zip(PAM4, PAM4_tensor)):
    train_rate = np.mean(Y_train[i].float().numpy())
    test_rate = np.mean(np.array((Y[i] > -3.5), dtype=int))
    if np.abs(0.5 - train_rate) >= 0.45 or np.abs(0.5 - test_rate) >= 0.45:
        continue
    all_accepted_idx.append(i)

In [None]:
# Leftover scratch cell: it only evaluates the literal 0 and does not affect the analysis.
0

In [None]:
# Median AUROC over the accepted PAMs versus gamma, for the classifier head (blue) and logistic
# regression on the representation layer (red). Curves are Gaussian-smoothed (sigma = n points)
# with a +/- rolling-std band; gammas[0] == 0 is skipped because the x axis is logarithmic.
# Two methods are scored per PAM (LogRegCV is disabled in the evaluation cell). The previous
# version plotted the red curve twice and gave three legend labels, so curves were mislabelled.
values = np.median(np.array(spears).reshape(len(gammas), -1, 2), 1)  # (gamma, method)
n = 50
handles = []
for col, color in enumerate(["blue", "red"]):
    fvalues = gaussian_filter1d(values[1:, col], n, mode="nearest")
    line, = plt.plot(gammas[1:], fvalues, c=color)
    handles.append(line)
    errors = np.array([np.std(values[i:i+n, col]) for i in range(1, len(values)-n)])
    plt.fill_between(gammas[n//2+1:-n//2], fvalues[n//2:-n//2]-errors, fvalues[n//2:-n//2]+errors, color=color, alpha = 0.2)

# Explicit handles so the labels attach to the curves, not to the shaded bands.
plt.legend(handles, ["Classifier", "LogReg over Repr. Layer"])
plt.xlabel("Gamma (Strength of the classifer)")
plt.ylabel("Median AUROC")
plt.xlim(1e-2,1e2)
plt.xscale("log")


In [None]:
# Per-PAM AUROC versus gamma: one panel per accepted PAM and one Gaussian-smoothed curve (with a
# +/- rolling-std band) per scoring method. Panel titles give the positive rate of the training
# labels and of the binarised test activities.
# The method count is derived from the stored scores. The old code assumed 3 methods while only
# 2 are scored (reshape error / misaligned columns), used column 1's std for the green band, and
# capped the panels at 83.
n_methods = len(spears) // len(all_accepted)  # scores appended per (gamma, accepted PAM)
all_labels = ["Classifier", "LogReg + L2 over Repr. Layer", "LogReg over Repr. Layer"]
method_labels = all_labels if n_methods == 3 else [all_labels[0], all_labels[2]][:n_methods]
method_colors = ["blue", "red", "green"][:n_methods]
scores = np.array(spears).reshape(len(gammas), -1, n_methods)  # (gamma, accepted PAM, method)

ncols = 5
nrows = max(1, -(-len(all_accepted_idx) // ncols))
plt.figure(figsize=(25, 70 * nrows / 17))
n = 20
for p, idx in enumerate(all_accepted_idx):
    lit = PAM4[idx]
    balance_train = int(np.array(Y_train[idx]).mean()*100)
    balance_test = int(np.array((Y[idx] > -3.5), dtype=int).mean()*100)

    plt.subplot(nrows, ncols, p + 1)
    values = scores[:, p]  # (gamma, method)
    for m in range(n_methods):
        fvalues = gaussian_filter1d(values[1:, m], n, mode="nearest")
        plt.plot(gammas[1:], fvalues, c=method_colors[m], label=method_labels[m])
        errors = np.array([np.std(values[k:k+n, m]) for k in range(1, len(values)-n)])
        plt.fill_between(gammas[n//2+1:-n//2], fvalues[n//2:-n//2]-errors, fvalues[n//2:-n//2]+errors, color=method_colors[m], alpha = 0.2)

    plt.legend()
    plt.title(f"{lit} : Train {balance_train}%+ / Test {balance_test}%+")
    plt.xscale("log")


In [None]:
from weblogo import Logo

In [None]:
from logomaker import Logo

In [None]:
# For every Walton test sample, draw side-by-side sequence logos: the measured PAM profile
# (Y2, left) and the profile predicted at gamma index 300 (Y2_pred, right). Logo rows are PAM
# positions 1-4; columns are nucleotides in A, T, C, G order, the order Y2 was built in.
# NOTE(review): confirm that Y2_pred's nucleotide axis uses the same order.
# A dead DataFrame built for idx = -5 (immediately discarded) has been removed. The loop now
# covers Y2's sample axis instead of a hard-coded 105, and figures are closed after display.
nucleotides = ["A", "T", "C", "G"]
positions = [1, 2, 3, 4]


def _logo_frame(heights):
    """Wrap a (4 positions, 4 nucleotides) tensor as the DataFrame logomaker's Logo expects."""
    return pd.DataFrame(heights.detach().numpy(), index=positions, columns=nucleotides)


for i in range(Y2.size(-1)):
    fig, axs = plt.subplots(1, 2, figsize=(10, 2))
    for ax, profile in zip(axs, (Y2, Y2_pred)):
        df = _logo_frame(profile[:, :, i])
        Logo(df, shade_below=.5, ax=ax,
             fade_below=.5,
             font_name='Arial Rounded MT Bold')
        ax.set_ylim(0, 2)
    # Title on the right panel, the axes plt.title targeted after plt.subplots.
    axs[1].set_title(df_pivot.index[i])
    plt.show()
    plt.close(fig)
