In [None]:
import sys, os, re, random, warnings, subprocess, time
import pandas as pd
sys.path.append(os.path.dirname(os.getcwd()))
warnings.filterwarnings("ignore")

from tqdm import tqdm_notebook
from sklearn.metrics import * 

from itertools import product
from random import shuffle

import torch, torch_geometric
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from torchpgm.model import *
from torchpgm.layers import *

from cld.postprocessing import *
from cld.criterion import *
from cld.walker import *

from utils import *
from config import *

import seaborn as sns
from scipy.ndimage import *
sns.set_style("whitegrid")

# High-resolution figures, both when displayed inline and when saved to disk.
plt.rcParams.update({'figure.dpi': 300, 'savefig.dpi': 300})

In [None]:
# --- run configuration -------------------------------------------------------
device = "cuda"
folder = f"{DATA}/vink"  # root holding the trained RBM checkpoints (under weights/)

Nh = 200         # number of Gaussian hidden units
Npam = 5         # PAM length (nucleotides) seen by the classifier
best_epoch = 90  # checkpoint epoch to load

q_pi = 21   # one-hot alphabet size of the PID alignment
N_pi = 736  # number of aligned PID positions

model_full_name = "rbmssl2_pid_h{}_npam{}_gamma1".format(Nh, Npam)

def lit_to_pam(s):
    """Encode a PAM written in IUPAC letters as the classifier's target tensor.

    The string is right-padded with "N" up to `Npam` nucleotides. Each letter is
    expanded through `NAd_idx` (presumably 4 values per nucleotide, matching the
    `Npam * 4` classifier output). The flat encoding is returned as a float
    tensor of shape (1, len(encoding)) on `device`.
    """
    padded = s + "N" * max(0, Npam - len(s))
    encoding = [value for letter in padded for value in NAd_idx[letter]]
    return torch.tensor(encoding).float()[None].to(device)

In [None]:
# Build the semi-supervised RBM (gamma = 1 checkpoint): a one-hot PID visible layer,
# a Gaussian hidden layer, and a PAM classifier taking Nh inputs and producing
# Npam * 4 outputs.
pi = OneHotLayer(None, N=N_pi, q=q_pi, name="pi")
h = GaussianLayer(N=Nh, name="hidden")
classifier = PAM_classifier(Nh, Npam * 4)
E = [(pi.name, h.name)]
E.sort()

model_rbm = PI_RBM_SSL(classifier, layers={pi.name: pi, h.name: h}, edges=E, name=model_full_name)
model_rbm = model_rbm.to(device)
# Load the checkpoint saved at `best_epoch`.
model_rbm.load(f"{folder}/weights/{model_full_name}_{best_epoch}.h5")
model_rbm.eval()
model_rbm = model_rbm.to("cpu")
# Presumably annealed importance sampling to estimate the normalization `model_rbm.Z`
# that later cells subtract from the model output — TODO confirm in torchpgm.
model_rbm.ais()

In [None]:
# Wild-type SpyCas9 PID encoding and the index set `zero_idx` (both saved earlier).
x_cas9 = torch.load(f"{DATA}/x_cas9.pt")
zero_idx = torch.load(f"{DATA}/zero_idx.pt")

kept_idx = list(range(N_pi))  # every aligned position is kept
target = lit_to_pam("NGG")    # PAM used as the classifier target

In [None]:
# numpy and deepcopy are only imported in a later cell; import them here so this
# cell also runs on a fresh kernel (Restart & Run All).
import numpy as np
from copy import deepcopy

# Reference: normalized RBM score of the wild-type PID (model output / N_pi - Z).
with torch.no_grad():
    e0 = (model_rbm({"pi": x_cas9[None]}) / N_pi - model_rbm.Z)[0].item()
# Target windows: RBM score slightly below wild type, and two ranges of distance
# (number of mismatching positions) to SpyCas9.
e_plage = [(e0 - 0.03, e0 - 0.02)]
sim_plage = [(30, 35), (50, 55)]
x = []       # sampled sequences, one entry per (distance window, energy window)
TRACKS = []  # (energy traces, distance traces) of the first 10 chains of each run
for s, e in product(sim_plage, e_plage):
    T = 0.1 * torch.ones(1, 1, len(kept_idx))
    objective = RbmCriterion(model_rbm, postprocessing=ConstantPostprocessor(0))
    # Constraint bounds presumably ramp from the wild-type value to the target window
    # over the first 100 steps (the plots below draw such ramps) — confirm the exact
    # LinearPostprocessor semantics in cld.postprocessing.
    constraints = [
        SimCriterion(x_cas9, list(range(N_pi)), postprocessing=LinearPostprocessor(100, 0, s[0], 0, s[1])),
        RbmCriterion(model_rbm, postprocessing=LinearPostprocessor(100, e0, e[0], e0, e[1])),
    ]
    weight_constraints = [10, 1000]

    walker = Walker(x_cas9.view(21, -1).clone(), model_rbm, objective, constraints, zero_idx, gamma=1, n=1, a=1,
                    c=1e-2, eps=1, target=target.cpu(), T=T, weight_constraints=weight_constraints)

    x.append(walker.run(16, n_epochs=200, verbose=False))
    # `e` / `sim` are rebound to the recorded traces; `product` rebinds the loop
    # target `e` on the next iteration, so this does not affect the windows.
    e = np.concatenate([track["e"][None] for track in walker.TRACKS])
    for e_ in e.T[:10]:
        plt.plot(e_)
    plt.show()
    sim = np.concatenate([track["abs_diff"][None] for track in walker.TRACKS])
    for sim_ in sim.T[:10]:
        plt.plot(sim_)
    plt.show()
    TRACKS.append((deepcopy(e.T[:10]), deepcopy(sim.T[:10])))

In [None]:
n = 3  # Gaussian smoothing width (steps) applied to the displayed traces
for (track_e, track_s), (s, e) in zip(TRACKS, product(sim_plage, e_plage)):
    plt.figure(figsize=(10, 5))

    # Left: smoothed (negated) RBM score traces against the target bounds.
    plt.subplot(121)
    print(e, s)
    for e_ in track_e[:5]:
        plt.plot(gaussian_filter1d(-e_, n, mode="nearest"))
    for bound in e:
        plt.plot([0, 100], [.15, -bound], color="black")
        plt.plot([100, 200], [-bound, -bound], color="black")
    plt.ylabel("E_RBM")
    plt.xlabel("Langevin Dynamic Step")

    # Right: smoothed distance-to-wild-type traces against the target bounds.
    plt.subplot(122)
    for sim_ in track_s[:5]:
        plt.plot(gaussian_filter1d(sim_, n, mode="nearest"))
    for bound in s:
        plt.plot([0, 100], [0, bound], color="black")
        plt.plot([100, 200], [bound, bound], color="black")
    plt.ylabel("Distance to SpyCas9")
    plt.xlabel("Langevin Dynamic Step")
    plt.show()

In [None]:
from pathos.multiprocessing import ProcessingPool as Pool
import subprocess
import multiprocessing
import numpy as np
import pyfoldx as foldx
from pyfoldx.structure import Structure
from sklearn.linear_model import *

from copy import deepcopy

from torch.distributions.one_hot_categorical import OneHotCategorical
# IUPAC nucleotide code -> the concrete bases it stands for.
NAd_in = {
    "A": "A", "T": "T", "C": "C", "G": "G",
    "W": "AT", "S": "CG", "M": "AC", "K": "TG", "R": "AG", "Y": "TC",
    "B": "TCG", "D": "ATG", "H": "ATC", "V": "ACG", "N": "ATCG",
}
# IUPAC codes ordered so that list index == bit mask over A=1, T=2, C=4, G=8 ("O" = none).
NAd = ["O", "A", "T", "W", "C", "M", "Y", "H", "G", "R", "K", "D", "S", "V", "B", "N"]

# PDB three-letter residue names <-> one-letter amino-acid codes.
pl3to1 = {
    'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
    'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N',
    'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W',
    'ALA': 'A', 'VAL': 'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M',
}
pl1to3 = dict(zip(pl3to1.values(), pl3to1.keys()))

# Subsequent `.to(device)` calls in this notebook now target the CPU.
device = "cpu"
def aux(args):
    """Thread a PID sequence onto chain A of a FoldX structure and score it.

    Takes a single tuple so it can be used with ``multiprocessing.Pool.map``.

    Args:
        args: ``(seq, struct, repair)``. ``seq`` is a sequence of one-letter
            amino-acid codes, ``struct`` is a pyfoldx ``Structure``, and ``repair``
            is a ``(do_repair, idx)`` pair (unpacked but unused here).

    Returns:
        float: the ``"total"`` column of the ``"model"`` row of ``getTotalEnergy()``.

    Notes:
        ``seq[k]`` is written to residue number ``1102 + k`` of chain A. Positions
        absent from the chain are skipped. ``struct`` is mutated in place.
    """
    seq, struct, repair = args
    chain = "A"
    do_repair, idx = repair  # kept for interface parity with the other aux_* helpers
    for position, residue_code in enumerate(seq, start=1102):
        try:
            residues = struct.data[chain]
            if position in residues.keys():
                res = residues[position]
                res.code = pl1to3[residue_code]
                residues[position] = res
        except Exception:
            # Best effort: skip residues that cannot be mapped (e.g. a code missing
            # from pl1to3). Unlike the former bare `except:`, this lets
            # KeyboardInterrupt/SystemExit propagate out of pool workers.
            continue
    return float(struct.getTotalEnergy().loc["model"]["total"])

def foldx_energy(x_sampled, structs, repair=False, processes=32):
    """FoldX total energy of each sampled PID sequence, computed in a process pool.

    Args:
        x_sampled: iterable of per-sequence tensors indexed ``x[row, position]``
            (callers pass ``reshape(-1, 21, 736)``). Row 0 is dropped before the
            argmax, so ``AA[k]`` names alphabet row ``k + 1``. Only ``nnz_idx``
            positions are threaded.
        structs: one pyfoldx ``Structure`` per sequence (each one is mutated).
        repair: forwarded to ``aux`` as ``(repair, i)``.
        processes: maximum number of worker processes (default 32, as before).

    Returns:
        torch.Tensor with one energy per sequence.

    Raises:
        ValueError: if ``structs`` and ``x_sampled`` differ in length. ``zip``
            previously dropped the extra entries silently.
    """
    seqs = [[AA[x_] for x_ in x[1:, nnz_idx].cpu().argmax(0)] for x in x_sampled]
    if len(structs) != len(seqs):
        raise ValueError(f"got {len(seqs)} sequences but {len(structs)} structures")
    repairs = [(repair, i) for i in range(len(seqs))]
    # Never fork more workers than there are jobs (e.g. a single wild-type call).
    n_workers = max(1, min(processes, len(seqs)))
    with multiprocessing.Pool(processes=n_workers) as pool:
        energies = pool.map(aux, list(zip(seqs, structs, repairs)))
    return torch.tensor(energies)

def aux_with_dna(args):
    """Thread a PID sequence onto chain A of a (protein-DNA) FoldX structure and score it.

    Takes a single tuple so it can be used with ``multiprocessing.Pool.map``. In this
    version the chain and numbering are the same as ``aux``; callers pass the complex.

    Args:
        args: ``(seq, struct, repair)``. ``seq`` is one-letter amino-acid codes,
            ``struct`` is a pyfoldx ``Structure``, and ``repair`` is an unused
            ``(do_repair, idx)`` pair.

    Returns:
        float: the ``"total"`` column of the ``"model"`` row of ``getTotalEnergy()``.

    Notes:
        ``seq[k]`` goes to residue ``1102 + k`` of chain A. Missing positions are
        skipped. ``struct`` is mutated in place.
    """
    seq, struct, repair, = args
    do_repair, idx = repair  # kept for interface parity; unused here
    chain = "A"
    for position, residue_code in enumerate(seq, start=1102):
        try:
            residues = struct.data[chain]
            if position in residues.keys():
                res = residues[position]
                res.code = pl1to3[residue_code]
                residues[position] = res
        except Exception:
            # Best effort: skip unmappable residues. Unlike the former bare
            # `except:`, this does not swallow KeyboardInterrupt/SystemExit.
            continue
    return float(struct.getTotalEnergy().loc["model"]["total"])

def foldx_energy_with_dna(x_sampled, structs, repair=False, processes=32):
    """FoldX total energy of each sampled PID sequence threaded onto a protein-DNA complex.

    Args:
        x_sampled: iterable of per-sequence tensors indexed ``x[row, position]``
            (callers pass ``reshape(-1, 21, 736)``). Row 0 is dropped before the
            argmax. Only ``nnz_idx`` positions are threaded.
        structs: one pyfoldx ``Structure`` per sequence (each one is mutated).
        repair: forwarded to ``aux_with_dna`` as ``(repair, i)``.
        processes: maximum number of worker processes (default 32, as before).

    Returns:
        torch.Tensor with one energy per sequence.

    Raises:
        ValueError: if ``structs`` and ``x_sampled`` differ in length. ``zip``
            previously truncated silently.
    """
    seqs = [[AA[x_] for x_ in x[1:, nnz_idx].cpu().argmax(0)] for x in x_sampled]
    if len(structs) != len(seqs):
        raise ValueError(f"got {len(seqs)} sequences but {len(structs)} structures")
    repairs = [(repair, i) for i in range(len(seqs))]
    # Never fork more workers than there are jobs.
    n_workers = max(1, min(processes, len(seqs)))
    with multiprocessing.Pool(processes=n_workers) as pool:
        energies = pool.map(aux_with_dna, list(zip(seqs, structs, repairs)))
    return torch.tensor(energies)

def aux_with_dna_interface(args):
    """Thread a PID sequence onto chain A and return a FoldX interaction energy.

    Takes a single tuple so it can be used with ``multiprocessing.Pool.map``.

    Args:
        args: ``(seq, struct, repair)``. ``seq`` is one-letter amino-acid codes,
            ``struct`` is a pyfoldx ``Structure``, and ``repair`` is a
            ``(do_repair, idx)`` pair.

    Returns:
        float: entry (row ``"B"``, column ``"D"``) of the ``"Interaction Energy"``
        table from ``getInterfaceEnergy(verbose=False)``.

    Notes:
        ``seq[k]`` goes to residue ``1 + k`` of chain A. Missing positions are
        skipped. ``struct`` is mutated in place.
    """
    seq, struct, repair, = args
    do_repair, idx = repair
    chain = "A"
    for position, residue_code in enumerate(seq, start=1):
        try:
            residues = struct.data[chain]
            if position in residues.keys():
                res = residues[position]
                res.code = pl1to3[residue_code]
                residues[position] = res
        except Exception:
            # Best effort: skip unmappable residues. Unlike the former bare
            # `except:`, this does not swallow KeyboardInterrupt/SystemExit.
            continue
    if do_repair:
        # NOTE(review): `structures` is not defined anywhere in the visible notebook.
        # When run inside a Pool worker, this assignment would only touch the
        # worker's copy. `repair()` itself is commented out upstream.
        structures[idx] = deepcopy(struct)
    return float(struct.getInterfaceEnergy(verbose=False)["Interaction Energy"].loc["B"].loc["D"])

def interface_energy_with_dna(x_sampled, structs, repair=False, processes=32):
    """FoldX protein-DNA interaction energy of each sampled PID sequence.

    Args:
        x_sampled: iterable of per-sequence tensors indexed ``x[row, position]``.
            Row 0 is dropped before the argmax. Only ``nnz_idx`` positions are threaded.
        structs: one pyfoldx ``Structure`` per sequence. They are deep-copied
            here, so the caller's objects are left untouched.
        repair: forwarded to ``aux_with_dna_interface`` as ``(repair, i)``.
        processes: maximum number of worker processes (default 32, as before).

    Returns:
        torch.Tensor with one interaction energy per sequence.

    Raises:
        ValueError: if ``structs`` and ``x_sampled`` differ in length. ``zip``
            previously truncated silently.
    """
    seqs = [[AA[x_] for x_ in x[1:, nnz_idx].cpu().argmax(0)] for x in x_sampled]
    if len(structs) != len(seqs):
        raise ValueError(f"got {len(seqs)} sequences but {len(structs)} structures")
    repairs = [(repair, i) for i in range(len(seqs))]
    structs = [deepcopy(s) for s in structs]
    # Never fork more workers than there are jobs.
    n_workers = max(1, min(processes, len(seqs)))
    with multiprocessing.Pool(processes=n_workers) as pool:
        energies = pool.map(aux_with_dna_interface, list(zip(seqs, structs, repairs)))
    return torch.tensor(energies)

def abs_diff(x, x0, idx=None, q=21):
    """Number of positions where the argmax residue of ``x`` differs from ``x0``.

    Args:
        x: tensor whose first dim is the batch, reshapeable to ``(batch, q, L)``.
        x0: reference tensor reshapeable to ``(batch0, q, L)``. ``batch0`` must
            broadcast against ``batch`` (typically 1).
        idx: positions to compare; defaults to the global ``nnz_idx``.
        q: alphabet size (rows per position); 21 in this notebook.

    Returns:
        Float tensor with the mismatch count per (broadcast) batch element.
    """
    if idx is None:
        idx = nnz_idx
    x = x.reshape(x.size(0), q, -1)
    x0 = x0.reshape(x0.size(0), q, -1)
    # The boolean mask converts straight to float; the former .int() step was redundant.
    return (x[:, :, idx].argmax(1) != x0[:, :, idx].argmax(1)).float().sum(-1)

def mean_diff(x, x0, idx=None, q=21):
    """Fraction of compared positions where the argmax residue of ``x`` differs from ``x0``.

    Same inputs as ``abs_diff``, except that it averages instead of summing.

    Args:
        x: tensor whose first dim is the batch, reshapeable to ``(batch, q, L)``.
        x0: reference tensor reshapeable to ``(batch0, q, L)`` (broadcast over batch).
        idx: positions to compare; defaults to the global ``nnz_idx``.
        q: alphabet size; 21 in this notebook.

    Returns:
        Float tensor in [0, 1] per (broadcast) batch element.
    """
    if idx is None:
        idx = nnz_idx
    x = x.reshape(x.size(0), q, -1)
    x0 = x0.reshape(x0.size(0), q, -1)
    return (x[:, :, idx].argmax(1) != x0[:, :, idx].argmax(1)).float().mean(-1)

def isd(x, nnz=None, zero=None, q=21):
    """Gap-placement score from alphabet row 0 (presumably the gap symbol).

    Returns the mean row-0 value over the ``nnz`` positions plus one minus the
    mean row-0 value over the ``zero`` positions. It is 0 when row 0 is off at
    every ``nnz`` position and on at every ``zero`` position.

    Args:
        x: tensor whose first dim is the batch, reshapeable to ``(batch, q, L)``.
        nnz: positions expected without row 0; defaults to the global ``nnz_idx``.
        zero: positions expected with row 0; defaults to the global ``zero_idx``.
        q: alphabet size; 21 in this notebook.

    Returns:
        Tensor with one score per batch element.
    """
    if nnz is None:
        nnz = nnz_idx
    if zero is None:
        zero = zero_idx
    x = x.reshape(x.size(0), q, -1)
    return x[:, 0, nnz].mean(-1) + (1 - x[:, 0, zero].mean(-1))

def litpam_to_pam(s):
    """Alias of ``lit_to_pam``, kept so existing callers keep working.

    The body was a verbatim copy of ``lit_to_pam``. Delegating keeps a single
    implementation, which still reads ``Npam``, ``NAd_idx`` and ``device`` at
    call time.
    """
    return lit_to_pam(s)

def find_closest_sequence(x, existing=None, q=21):
    """Hamming distance from each sequence in ``x`` to its nearest reference sequence.

    Args:
        x: tensor whose first dim is the batch, reshapeable to ``(batch, q, L)``.
        existing: reference set reshapeable to ``(M, q, L)``; defaults to the
            global ``existing_sequences``. It is reshaped too, so a flattened
            ``(M, q * L)`` set also works.
        q: alphabet size; 21 in this notebook.

    Returns:
        Integer tensor of shape ``(batch,)`` with the minimum number of
        mismatching positions (argmax over the alphabet) to any reference.
    """
    if existing is None:
        existing = existing_sequences
    seqs = x.reshape(len(x), q, -1).argmax(-2)
    refs = existing.reshape(len(existing), q, -1).argmax(-2)
    # (batch, 1, L) vs (1, M, L) -> (batch, M) mismatch counts -> nearest reference.
    return (seqs[:, None] != refs[None]).int().sum(-1).min(1)[0]

In [None]:
# NOTE(review): writes x_cas9 back to the file it was loaded from at the top of the
# notebook. Redundant on a linear run. If this runs after a cell that rebinds x_cas9
# (e.g. from X_rbm[0] in the gamma sweep), it overwrites the saved wild type.
torch.save(x_cas9, f"{DATA}/x_cas9.pt")

In [None]:
# Grid of classifier strengths (gamma) for which RBM-SSL checkpoints exist.
gammas = (
    [1]
    + [2 * 1.05**i for i in range(50)]
    + [0]
    + [1e-7 * 1.05**i for i in range(100)]
    + [1.4**i / 1000 for i in range(50)]
    + [0.01 * 1.05**i for i in range(150)]
    + [0.0001 * 1.05**i for i in range(100)]
    + [1e-5 * 1.05**i for i in range(50)]
    + [10000 * 1.1**i for i in range(50)]
    + [20 * 1.08**i for i in range(50)]
)

In [None]:
# Walker with constant constraint targets (distance bound 80, RBM bound -0.5;
# exact ConstantPostprocessor semantics to be confirmed in cld.postprocessing).
T = 0.3 * torch.ones(1, 1, len(kept_idx))

objective = RbmCriterion(model_rbm, postprocessing=ConstantPostprocessor(0))
constraints = [
    SimCriterion(x_cas9, nnz_idx, postprocessing=ConstantPostprocessor(None, 80)),
    RbmCriterion(model_rbm, postprocessing=ConstantPostprocessor(-0.5, None)),
]

walker = Walker(
    x_cas9.view(21, -1).clone(), model_rbm, objective, constraints, zero_idx,
    gamma=1, n=1, a=1, c=1e-2, eps=1, target=target.cpu(), T=T,
)

In [None]:
gammas = (
    [1]
    + [2 * 1.05**i for i in range(50)]
    + [0]
    + [1e-7 * 1.05**i for i in range(100)]
    + [1.4**i / 1000 for i in range(50)]
    + [0.01 * 1.05**i for i in range(150)]
    + [0.0001 * 1.05**i for i in range(100)]
    + [1e-5 * 1.05**i for i in range(50)]
    + [10000 * 1.1**i for i in range(50)]
    + [20 * 1.08**i for i in range(50)]
)
# Slice of the sorted grid used for the sweep below.
selected_gammas = sorted(gammas)[260:542]
# Distance-to-wild-type windows [10, 15), ..., [45, 50).
sim_plage = [(lo, lo + 5) for lo in range(10, 50, 5)]


In [None]:
from itertools import product
# NOTE(review): absolute, user-specific path; may differ from `folder` set at the top.
folder = "/home/malbranke/data/cas9/vink"
for gamma in selected_gammas:
    # Fresh results for every gamma; saved to tracks_{gamma}.pt at the end of the iteration.
    x = []
    TRACKS = []
    model_full_name = f"rbmssl2_pid_h{Nh}_npam{Npam}_gamma{gamma}"

    model_rbm = PI_RBM_SSL(classifier, layers={pi.name: pi, h.name: h}, edges=E, name=model_full_name)
    model_rbm = model_rbm.to(device)
    model_rbm.load(f"{folder}/weights/{model_full_name}_{best_epoch}.h5")
    model_rbm.eval()
    model_rbm.ais()
    model_rbm = model_rbm.to("cpu")

    # Use the pi -> hidden edge of the model just loaded. This cell used to read a
    # global `edge` that is only assigned in a later cell. The criterion was then
    # built from whatever model that edge belonged to, or raised NameError on a
    # fresh kernel.
    edge = model_rbm.edges["pi -> hidden"]
    criterion = classifier_criterion(classifier, edge, target)

    # Wild type = first row of X_rbm. Previously 512 identical copies were
    # concatenated only to take the first one.
    x_cas9 = X_rbm[0].reshape(-1).clone()
    h_cas9 = edge(x_cas9[None], False)

    with torch.no_grad():
        e0 = (model_rbm({"pi": x_cas9[None]}) / N_pi - model_rbm.Z)[0].item()
    e_plage = [(e0 - 0.02, e0 + 0.02)]

    for s, e in product(sim_plage, e_plage):
        T = 0.1 * torch.ones(1, 1, len(kept_idx))
        objective = RbmCriterion(model_rbm, postprocessing=ConstantPostprocessor(0))
        constraints = [
            SimCriterion(x_cas9, list(range(N_pi)), postprocessing=LinearPostprocessor(100, 0, s[0], 0, s[1])),
            RbmCriterion(model_rbm, postprocessing=LinearPostprocessor(100, e0, e[0], e0, e[1])),
        ]

        walker = Walker(x_cas9.view(21, -1).clone(), model_rbm, objective, constraints, zero_idx, gamma=1, n=1, a=1,
                        c=1e-2, eps=1, target=target.cpu(), T=T)

        x.append(walker.run(16, n_epochs=200, verbose=False))
        # `e` / `sim` are rebound to the recorded traces; `product` rebinds `e` next iteration.
        e = np.concatenate([track["e"][None] for track in walker.TRACKS])
        for e_ in e.T[:10]:
            plt.plot(e_)
        plt.show()
        sim = np.concatenate([track["abs_diff"][None] for track in walker.TRACKS])
        for sim_ in sim.T[:10]:
            plt.plot(sim_)
        plt.show()
        TRACKS.append((deepcopy(e.T[:10]), deepcopy(sim.T[:10])))
    torch.save((x, TRACKS), f"tracks_{gamma}.pt")

In [None]:
# Raw (unsmoothed) traces. TRACKS here holds only the runs of the last gamma processed.
for (track_e, track_s), (s, e) in zip(TRACKS, product(sim_plage, e_plage)):
    plt.figure(figsize=(10, 5))

    # Left: negated RBM score traces against the target bounds.
    plt.subplot(121)
    print(e, s)
    for e_ in track_e[:5]:
        plt.plot(-e_)
    for bound in e:
        plt.plot([0, 100], [.15, -bound], color="black")
        plt.plot([100, 200], [-bound, -bound], color="black")
    plt.ylabel("E_RBM")
    plt.xlabel("Langevin Dynamic Step")

    # Right: distance-to-wild-type traces against the target bounds.
    plt.subplot(122)
    for sim_ in track_s[:5]:
        plt.plot(sim_)
    for bound in s:
        plt.plot([0, 100], [0, bound], color="black")
        plt.plot([100, 200], [bound, bound], color="black")
    plt.ylabel("Distance to SpyCas9")
    plt.xlabel("Langevin Dynamic Step")
    plt.show()

In [None]:
# Score every batch of sampled sequences: distance to wild type, RBM score (raw `eb`
# and linearly corrected `e`), FoldX energy with DNA, and classifier criterion.
# NOTE(review): `a` and `b` are only assigned (a = b = 0) in a later cell.
# `walker` is the last walker created, so all batches are scored with its model.
e = []
eb = []
sim = []
fx = []        # NOTE(review): never filled here, but later cells torch.cat it
fx_dna = []
crit = []
best_sim = []  # NOTE(review): never filled here, but later cells torch.cat it
for x_ in tqdm_notebook(x):
    x_seq = x_.reshape(len(x_), 21, -1)
    sim.append((x_seq.argmax(1) != walker.x0.argmax(1)).sum(-1))
    # One forward pass without autograd serves both `e` and `eb`. It used to run
    # twice and keep autograd graphs alive for every batch.
    with torch.no_grad():
        e_rbm = walker.model({"pi": x_seq[:, :, walker.kept_idx]}) / N_pi - walker.Z
    e.append(e_rbm - a * sim[-1] - b)
    fx_dna.append(foldx_energy_with_dna(x_.reshape(-1, 21, N_pi), [deepcopy(structure_complex) for _ in x_]))
    crit.append(criterion(x_))
    # Small random jitter on the stored distances, presumably to de-overlap scatter points.
    sim[-1] = sim[-1] + 0.1 * torch.randn(len(sim[-1]))
    eb.append(e_rbm)
torch.save(fx_dna, "all_fx_dna.pt")

In [None]:
# Reload the FoldX-with-DNA energies of the gamma sweep.
# NOTE(review): the scoring cell saves to the relative path "all_fx_dna.pt", while this
# loads an absolute, user-specific path — make them consistent.
fx_dna = torch.load("/home/malbranke/cas9/all_fx_dna.pt")

In [None]:
# Baseline (gamma ~ 0) sweep over the ten smallest values of the grid: 0 followed by
# the nine smallest 1e-7 * 1.05**i. Samples of all ten models are accumulated in
# `x` / TRACKS and saved together to tracks_0.pt.
x = []
TRACKS = []
folder = "/home/malbranke/data/cas9/vink"
a = b = 0  # slope / offset of the linear correction e = eb - a * sim - b (disabled)
# Loop-invariant settings, hoisted out of the loop (they were re-created for every gamma).
target = litpam_to_pam("NGG").to(device)
sim_plage = [(5*i, 5*i+5) for i in range(2, 10)]
for gamma in sorted(gammas)[:10]:
    model_full_name = f"rbmssl2_pid_h{Nh}_npam{Npam}_gamma{gamma}"

    model_rbm = PI_RBM_SSL(classifier, layers={pi.name: pi, h.name: h}, edges=E, name=model_full_name)
    model_rbm = model_rbm.to(device)
    model_rbm.load(f"{folder}/weights/{model_full_name}_{best_epoch}.h5")
    model_rbm.eval()
    model_rbm.ais()
    model_rbm = model_rbm.to("cpu")

    edge = model_rbm.edges["pi -> hidden"]
    model_rbm = model_rbm.to(device)

    criterion = classifier_criterion(classifier, edge, target)
    # Wild type = first row of X_rbm. Previously 512 identical copies were
    # concatenated only to take the first one.
    x_cas9 = X_rbm[0].reshape(-1).clone()
    h_cas9 = edge(x_cas9[None], False)

    with torch.no_grad():
        e0 = (model_rbm({"pi": x_cas9[None]}) / N_pi - model_rbm.Z)[0].item()
    e_plage = [(e0 - 0.02, e0 + 0.02)]

    for s, e in product(sim_plage, e_plage):
        main = None  # no main objective: only the constraints are passed
        criterions = [
            SimCriterion(x_cas9, list(range(N_pi)), postprocessing=LinearPostprocessor(100, 0, s[0], 0, s[1])),
            RbmCriterion(model_rbm, postprocessing=LinearPostprocessor(100, e0, e[0], e0, e[1])),
        ]
        T = 0.1 * torch.ones(1, 1, len(kept_idx))
        walker = Walker(x_cas9.view(21, -1).clone(), model_rbm, main, criterions, zero_idx, gamma=.1,
                        n=1, a=.3, c=1e-4, eps=.03, T=T, n_samples=8, target=target.cpu())
        x.append(walker.run(16, n_epochs=200, verbose=False))
        # `e` / `sim` are rebound to the recorded traces; `product` rebinds `e` next iteration.
        e = np.concatenate([track["e"][None] for track in walker.TRACKS])
        for e_ in e.T[:10]:
            plt.plot(e_)
        plt.show()
        sim = np.concatenate([track["abs_diff"][None] for track in walker.TRACKS])
        for sim_ in sim.T[:10]:
            plt.plot(sim_)
        plt.show()
        TRACKS.append((deepcopy(e.T[:10]), deepcopy(sim.T[:10])))
torch.save((x, TRACKS), "tracks_0.pt")

In [None]:
# Score the gamma ~ 0 baseline samples (same procedure as for the gamma sweep).
# NOTE(review): `walker` is the last walker of the baseline loop, so every batch is
# scored with the model of the last (largest) baseline gamma.
e = []
eb = []
sim = []
fx = []        # NOTE(review): never filled here
fx_dna_0 = []
crit = []
best_sim = []  # NOTE(review): never filled here
for x_ in tqdm_notebook(x):
    x_seq = x_.reshape(len(x_), 21, -1)
    sim.append((x_seq.argmax(1) != walker.x0.argmax(1)).sum(-1))
    # Single forward pass without autograd for both `e` and `eb`.
    with torch.no_grad():
        e_rbm = walker.model({"pi": x_seq[:, :, walker.kept_idx]}) / N_pi - walker.Z
    e.append(e_rbm - a * sim[-1] - b)
    fx_dna_0.append(foldx_energy_with_dna(x_.reshape(-1, 21, N_pi), [deepcopy(structure_complex) for _ in x_]))
    crit.append(criterion(x_))
    # Small random jitter on the stored distances, presumably to de-overlap scatter points.
    sim[-1] = sim[-1] + 0.1 * torch.randn(len(sim[-1]))
    eb.append(e_rbm)
torch.save(fx_dna_0, "all_fx_dna_0.pt")

In [None]:
# Reload the baseline FoldX-with-DNA energies.
# NOTE(review): saved to the relative path "all_fx_dna_0.pt" above but loaded from an
# absolute, user-specific path here.
fx_dna_0 = torch.load("/home/malbranke/cas9/all_fx_dna_0.pt")

In [None]:
# Share of generated sequences with FoldX-with-DNA energy below 82, per gamma
# (averaged over all distance windows and samples).
values = (
    (torch.stack(fx_dna, 0) < 82)
    .reshape(len(selected_gammas), len(sim_plage), -1)
    .float()
    .mean((-1, -2))
)

In [None]:
from scipy.ndimage import *

In [None]:
# gamma value -> its position in selected_gammas (later duplicates win).
selected_gammas_inv = dict(zip(selected_gammas, range(len(selected_gammas))))

In [None]:
# Positions of three specific gamma values in the sweep. Raises KeyError unless
# each float literal matches a grid value exactly.
selected_gammas_inv[50.36340233637961], selected_gammas_inv[8.232271190763177], selected_gammas_inv[0.05669391237529596]

In [None]:
# One row of FoldX-with-DNA energies per gamma (all distance windows and samples flattened).
fx_dna_stacked = torch.stack(fx_dna, dim=0).reshape(len(selected_gammas), -1)

In [None]:
plt.figure()

# FoldX-with-DNA energy distributions on a shared density axis: the gamma = 0
# baseline, then two slices of the sweep (rows of fx_dna_stacked = positions
# in selected_gammas).
plt.subplot(131)
plt.hist(torch.stack(fx_dna_0, 0).flatten(), bins=np.arange(65, 100, 5), density=True)
plt.ylim(0, 0.06)

# Blank y tick labels for the right-hand panels, ticks at 0, 0.01, ..., 0.06.
# np.arange(0, 0.07, 0.01) yields 8 values (0.07 sneaks in through floating-point
# rounding) against 7 labels, which recent matplotlib rejects as a tick/label mismatch.
hidden_yticks = np.linspace(0, 0.06, 7)

plt.subplot(132)
plt.hist(fx_dna_stacked[120:140].flatten(), bins=np.arange(65, 100, 5), density=True)
plt.ylim(0, 0.06)
plt.yticks(hidden_yticks, [""] * len(hidden_yticks))

plt.subplot(133)
plt.hist(fx_dna_stacked[240:250].flatten(), bins=np.arange(65, 100, 5), density=True)
plt.ylim(0, 0.06)
plt.yticks(hidden_yticks, [""] * len(hidden_yticks))


In [None]:
# Share of sequences whose FoldX-with-DNA energy is below several thresholds, as a
# function of gamma (Gaussian-smoothed, with a rolling-std band), compared with the
# gamma = 0 baseline (dashed lines).
n=20  # smoothing width / rolling window, in gamma-grid steps
plt.figure(figsize=(5,5))

thresholds = [70, 75, 80, 85]
colors = ["blue", "red", "green", "orange"]
legends = [f"fx_dna < {thr}" for thr in thresholds]
for thr, c in zip(thresholds, colors):
    values0 = (torch.stack(fx_dna_0,0)<thr).int().float()[:,:].mean()
    values = (torch.stack(fx_dna,0)<thr).reshape(len(selected_gammas), len(sim_plage), -1).int().float()[:,:].mean((-1,-2))
    fvalues = gaussian_filter1d(values[1:],n, mode="nearest")
    # Rolling standard deviation over windows of n gammas.
    errors = np.array([np.std(values[i:i+n].numpy()) for i in range(1,len(values)-n)])
    plt.plot(selected_gammas[1:], gaussian_filter1d(values[1:],n,  mode="nearest"), c=c)
    #plt.plot([0,1000], [values0,values0], linestyle='dashed', c=c)
    plt.fill_between(selected_gammas[n//2+1:-n//2], fvalues[n//2:-n//2]-errors, fvalues[n//2:-n//2]+errors, color=c, alpha = 0.2)
    plt.xscale("log")

# Zero-length grey line at x = 0, apparently a legend proxy for the baseline entry.
# NOTE(review): plt.legend(labels) assigns labels to artists in creation order, and the
# fill_between collections sit between the curves. Depending on the matplotlib version
# the labels may not line up with the intended lines — consider label=... per curve.
plt.plot([0,0], [values0,values0], linestyle='dashed', c="gray")

# Baselines: only values0 and the dashed line are used; the rest is recomputed
# but unused (its plotting calls are commented out).
for thr, c in zip(thresholds, colors):
    values0 = (torch.stack(fx_dna_0,0)<thr).int().float()[:,:].mean()
    values = (torch.stack(fx_dna,0)<thr).reshape(len(selected_gammas), len(sim_plage), -1).int().float()[:,:].mean((-1,-2))
    fvalues = gaussian_filter1d(values[1:],n, mode="nearest")
    errors = np.array([np.std(values[i:i+n].numpy()) for i in range(1,len(values)-n)])
    #plt.plot(selected_gammas[1:], gaussian_filter1d(values[1:],n,  mode="nearest"), c=c)
    plt.plot([0,1000], [values0,values0], linestyle='dashed', c=c)
    #plt.fill_between(selected_gammas[n//2+1:-n//2], fvalues[n//2:-n//2]-errors, fvalues[n//2:-n//2]+errors, color="blue", alpha = 0.2)
    plt.xscale("log")
legends = [f"fx_dna < {thr}" for thr in thresholds]+["with standard RBM \n(gamma = 0)"]
plt.legend(legends, loc='center left', bbox_to_anchor=(1, 0.5))

plt.xlabel("Gamma (strength of the classifier)")
plt.ylabel("Share of sequences")
plt.xlim(0.05,50)
plt.title("Share of sequences with FoldX energy below thresholds \n(wild type FoldX energy: 67.43)", fontsize=12)

plt.ylim(0.,0.8)

In [None]:
from scipy.ndimage import *

n = 20  # Gaussian smoothing width / rolling window, in gamma-grid steps
# Mean FoldX-with-DNA energy: gamma = 0 baseline, and per gamma averaged over all
# distance windows and samples.
values0 = torch.stack(fx_dna_0, 0).float().mean()
# `.float()` only: the former `.int()` truncated the energies to integers before
# averaging, biasing the curve relative to the untruncated baseline.
values = torch.stack(fx_dna, 0).reshape(len(selected_gammas), len(sim_plage), -1).float().mean((-1, -2))
fvalues = gaussian_filter1d(values[1:], n, mode="nearest")
# Rolling standard deviation over windows of n gammas, drawn as a band.
errors = np.array([np.std(values[i:i+n].numpy()) for i in range(1, len(values) - n)])
plt.plot(selected_gammas[1:], fvalues, c="blue")
plt.plot([0, 1000], [values0, values0], c="black")

plt.fill_between(selected_gammas[n//2+1:-n//2], fvalues[n//2:-n//2] - errors, fvalues[n//2:-n//2] + errors, color="blue", alpha=0.2)
plt.xscale("log")
plt.xlabel("Gamma (strength of the classifier)")
plt.ylabel("Mean FoldX energy with DNA")

In [None]:
# Share of sequences with FoldX-with-DNA energy < 82 in the last five distance
# windows of sim_plage (the largest distances), per gamma.
plt.plot(
    selected_gammas,
    (torch.stack(fx_dna, 0) < 82)
    .reshape(len(selected_gammas), len(sim_plage), -1)
    .float()[:, -5:]
    .mean((-1, -2)),
)
plt.xscale("log")

In [None]:
# Same as above with threshold 80, zoomed to gamma in [1e-3, 10].
plt.plot(
    selected_gammas,
    (torch.stack(fx_dna, 0) < 80)
    .reshape(len(selected_gammas), len(sim_plage), -1)
    .float()[:, -5:]
    .mean((-1, -2)),
)
plt.xscale("log")
plt.xlim(1e-3, 10)

In [None]:
# Classifier criterion of every labelled variant (one-hot reshaped to (N, q_pi, N_pi)).
labelled_crit = criterion(X_labelled.reshape(-1, q_pi, N_pi))

In [None]:
# FoldX energies of the labelled variants: protein alone, and in complex with DNA.
# NOTE(review): `structure_alone` is only defined in a later cell.
labelled_one_hot = X_labelled.reshape(-1, q_pi, N_pi)
labelled_fx = foldx_energy(labelled_one_hot, [deepcopy(structure_alone) for _ in X_labelled])
labelled_fx_dna = foldx_energy_with_dna(labelled_one_hot, [deepcopy(structure_complex) for _ in X_labelled])
colors = ["red", "darkorange", "gold", "green"]  # one colour per mCherry class 0..3

def mcherry_threshold(x):
    """Bin an mCherry readout into 4 classes.

    Returns 0 for x < 0.2, 1 for x < 0.5, 2 for x < 0.8, and 3 otherwise
    (including NaN, which fails every comparison).
    """
    for level, upper in enumerate((0.2, 0.5, 0.8)):
        if x < upper:
            return level
    return 3
# mCherry class (0-3) of every labelled variant (rows with batch >= 0), as colour
# names and as integers.
# NOTE(review): `functionality` / `functionality_number` are not used in the visible cells.
functionality = df[df.batch  >= 0]["mcherry"].apply(lambda x:colors[mcherry_threshold(x)]).values
functionality_number = df[df.batch  >= 0]["mcherry"].apply(lambda x:mcherry_threshold(x)).values


# RBM score (model output / 736 - walker.Z) of each labelled variant.
# NOTE(review): `walker.model` is the model of whichever walker was created last.
labelled_ebrbm = (walker.model({"pi" : X_labelled.reshape(len(X_labelled),21,-1)[:,:,walker.kept_idx]})/736 - walker.Z).detach()
labelled_diff = df[df.batch>=0]["sim_cas9"]

normfl = df[df.batch>=0]["mcherry"].values
labels = df[df.batch>=0]["functional"].apply(lambda x : int(x))

# NOTE(review): indexes the 4-colour mCherry palette with the `functional` label. If
# that label is 0/1, functional variants are drawn "darkorange" rather than "green" —
# confirm intended.
labelled_colors = np.array([colors[i] for i in labels])
#labelled_erbm = (labelled_ebrbm - clf.predict(labelled_diff[:,None]))
#e = torch.cat(eb,0).detach() - clf.predict(torch.cat(sim,0).detach()[:,None])
#e = torch.cat(eb,0).detach() - clf.predict(torch.cat(sim,0).detach()[:,None])

In [None]:
from sklearn.linear_model import *
from scipy.stats import *


with torch.no_grad():
    # Spearman correlation between measured mCherry signal and RBM score of the labelled variants.
    print(spearmanr(normfl, labelled_ebrbm))
    plt.figure(figsize=(7, 7))
    # Generated sequences (grey) and labelled variants (coloured) in the (distance, -score) plane.
    plt.scatter(torch.cat(sim, 0).detach(), -torch.cat(eb, 0).detach(), color="lightgray", s=7)
    plt.scatter(labelled_diff, -labelled_ebrbm, color=labelled_colors)
    # Wild type at distance 0, negated like the other two series. It used to be
    # plotted with the opposite sign, presumably outside the 0.14-0.26 y-range below.
    wt_score = walker.model({"pi": walker.x0.reshape(1, 21, -1)[:, :, walker.kept_idx]}) / N_pi - walker.Z
    plt.scatter([0], -wt_score, marker="+", color="black", s=160)
    plt.xlim(-1, 100)
    plt.ylim(0.14, 0.26)
    plt.xticks(size=14, rotation=45)
    plt.yticks(size=14)
    plt.xlabel("Distance to SpyCas9 PID", size=16)
    plt.ylabel("RBM", size=16)

In [None]:
# Wild-type RBM score: raw (eb0) and with the linear distance correction at distance 0 (e0).
eb0 = (walker.model({"pi" : walker.x0.reshape(1,21,-1)[:,:,walker.kept_idx]})/736 - walker.Z)
e0 = eb0 - a * 0 - b

In [None]:
# Wild-type reference values: RBM scores (same as the previous cell), distance 0, and
# FoldX energy.
e0 = (walker.model({"pi" : walker.x0.reshape(1,21,-1)[:,:,walker.kept_idx]})/736 - walker.Z)-a * 0-b
eb0 = (walker.model({"pi" : walker.x0.reshape(1,21,-1)[:,:,walker.kept_idx]})/736 - walker.Z)

sim0 = 0
# NOTE(review): `structure_alone` is already used by an earlier cell. It is scored here
# with the *with-DNA* helper although it is loaded from cas9_pid.pdb — confirm that the
# intended structure is used.
structure_alone = Structure(code="model", path = f"/home/malbranke/data/foldx/cas9_pid.pdb")
fx0 = foldx_energy_with_dna(walker.x0.reshape(-1,21,736), [deepcopy(structure_alone)])

In [None]:
from random import random
from tqdm._tqdm_notebook import tqdm_notebook
import pandas as pd 

tqdm_notebook.pandas()

# E. coli codon usage table. Codons with frequency <= 0.1 are discarded.
freq_df = pd.read_excel(f"{DATA}/cas9/codon_frequency_ecoli.xlsx")
freq_dict = {k: [] for k in AA}
codon_dict = {k: [] for k in AA}

for aa, codon, freq in zip(freq_df.aa, freq_df.codon, freq_df.frequency):
    if aa in AA and freq > 0.1:
        freq_dict[aa].append(float(freq))
        codon_dict[aa].append(codon)

# Renormalize the kept frequencies of each amino acid into a cumulative distribution
# (inverse-CDF sampling in aa_to_dna).
for aa in AA:
    freq_dict[aa] = np.cumsum(np.array(freq_dict[aa]) / sum(freq_dict[aa]))

# Fixed 5' / 3' flanks of the construct, and sites that must not be created in the
# coding sequence (these look like BsaI / BsmBI sites and their reverse complements).
start_seq = "AAACACGTGGCAAACATTCCttGGTCTCtAAAAGACTGAGGTACAG"
end_seq = "CGTATTGATCTGAGTCAGTTGGGCGGTGACTAATAAGGAGAGACCAAGTACTGTATGGCTCCGGTTT"
forbiddens = ["GGTCTC", "GAGACC", "CGTCTC", "GAGACG"]

def aa_to_dna(x, freqs=None, codons=None, prefix=None, suffix=None,
              forbidden_sites=None, rng=None, max_restarts=1000):
    """Back-translate a protein sequence into a full DNA construct.

    Codons are drawn by inverse-CDF sampling from the cumulative frequencies in
    ``freqs``. A codon is rejected if the preceding 5 nt plus the codon contain a
    forbidden site. For the last codon, the first 5 nt of ``suffix`` are also
    included in the check. After 6 rejected draws at one position, the whole
    sequence is restarted.

    Args:
        x: protein sequence (sized iterable of one-letter codes).
        freqs: aa -> cumulative codon frequencies (array); default ``freq_dict``.
        codons: aa -> codons aligned with ``freqs``; default ``codon_dict``.
        prefix: 5' flank; default ``start_seq``.
        suffix: 3' flank; default ``end_seq``.
        forbidden_sites: substrings that must not be created; default ``forbiddens``.
        rng: object with a ``random()`` method (e.g. ``random.Random(seed)``) for
            reproducible output; defaults to the module-level ``random()``.
        max_restarts: restarts allowed before giving up.

    Returns:
        str: ``prefix + coding sequence + suffix``.

    Raises:
        RuntimeError: if no valid sequence is found within ``max_restarts`` restarts.
            The previous recursive retry would instead exhaust the call stack.
    """
    freqs = freq_dict if freqs is None else freqs
    codons = codon_dict if codons is None else codons
    prefix = start_seq if prefix is None else prefix
    suffix = end_seq if suffix is None else suffix
    forbidden_sites = forbiddens if forbidden_sites is None else forbidden_sites
    draw = random if rng is None else rng.random

    for _attempt in range(max_restarts + 1):
        dna_seq = prefix
        for i, aa in enumerate(x):
            is_last = (i + 1) == len(x)
            for _ in range(6):
                u = draw()
                # Inverse-CDF lookup, clamped: rounding can leave the last cumulative
                # frequency slightly below u, which used to raise IndexError.
                codon_idx = min(int((u > freqs[aa]).sum()), len(codons[aa]) - 1)
                codon = codons[aa][codon_idx]
                window = dna_seq[-5:] + codon + (suffix[:5] if is_last else "")
                if not any(site in window for site in forbidden_sites):
                    dna_seq += codon
                    break
            else:
                break  # no acceptable codon at this position: restart from scratch
        else:
            return dna_seq + suffix
        print("try again")
    raise RuntimeError(f"aa_to_dna: no valid back-translation after {max_restarts} restarts")

In [None]:
# Assemble the generated sequences, their DNA constructs and scores into `df`.
# NOTE(review): this cell cannot run as written:
#   - `N` is undefined (presumably N_pi);
#   - `best_sim` and `fx` are never filled, so torch.cat on them raises;
#   - `e` is a list (no .numpy());
#   - `fx_dna_AA` is never defined;
#   - `df` also holds the labelled data, so a length mismatch is likely — presumably
#     a fresh DataFrame was intended.
# NOTE(review): argmax runs over all 21 rows and then AA[i-1] is used, so row 0
# (i = 0) maps to AA[-1], the last amino acid, instead of a gap.
with torch.no_grad():
    df["seq"] = ["".join(AA[i-1] for i in x_) for x_ in torch.cat(x,0).view(-1, 21,N)[:,:,nnz_idx].argmax(1).detach().numpy()]
    df["dna_seq"] = df["seq"].apply(lambda x : aa_to_dna(x))
    df["sim_cas9"] = torch.cat(sim,0).numpy()
    df["best_sim"] = torch.cat(best_sim,0).numpy()
    df["e_rbm"] = torch.cat(eb,0).numpy()
    df["e_rbm_unbiaised"] = e.numpy()
    df["fx"] = torch.cat(fx,0).numpy()
    df["fx_dna"] = torch.cat(fx_dna,0).numpy()
    df["fx_dna_AA"] = torch.cat(fx_dna_AA,0).numpy()

In [None]:
import importlib
# NOTE(review): reloading `sys` appears to be a leftover of the Python 2
# setdefaultencoding idiom and serves no visible purpose here; consider removing it.
importlib.reload(sys)

In [None]:
# Export the final batch to Excel.
# NOTE(review): `df_final` is not defined anywhere in the visible notebook (it only
# appears in commented-out code), so this fails on a fresh kernel.
df_final.to_excel("batch3.xlsx")

In [None]:
# Negated RBM score vs. FoldX-with-DNA energy: generated sequences with (jittered)
# distance > 20 in grey, labelled variants with distance > 0 coloured.
plt.figure(figsize = (7,7))
with torch.no_grad():

    # NOTE(review): `sim` holds jittered distances (+0.1 * randn), so the threshold is approximate.
    idx = torch.where(torch.cat(sim,0)>20)[0]
    plt.scatter(-torch.cat(eb,0)[idx],torch.cat(fx_dna,0).detach()[idx], color = "lightgray",s=7)
    plt.scatter([-x for x in labelled_ebrbm[np.where(labelled_diff > 0)]],labelled_fx_dna[np.where(labelled_diff > 0)], color=labelled_colors[np.where(labelled_diff > 0)], s=25)
    #plt.plot([-0.5,0.1],[85,85], "--", color = "red")
    #plt.plot([-0.22,-0.22],[10,500], "--", color = "green")

    plt.xlim(0.14,0.26)
    plt.ylim(55,200)
    plt.xticks(size = 14, rotation = 45)
    plt.yticks(size = 14)
    plt.xlabel("RBM score", size = 16, )
    plt.ylabel("Fx energy with DNA", size = 16, )
    #plt.yscale("log")


In [None]:
# Negated RBM score vs. exp(classifier criterion): generated sequences with
# distance > 0 in grey, labelled variants with distance > 0 coloured.
plt.figure(figsize = (10,7))
with torch.no_grad():

    # NOTE(review): `sim` holds jittered distances, so `> 0` also admits about half of
    # the unmutated (distance 0) samples.
    idx = torch.where(torch.cat(sim,0)>0)[0]
    plt.scatter(-torch.cat(eb,0)[idx],torch.cat(crit,0).detach()[idx].exp(), color = "lightgray",s=7)
    plt.scatter([-x for x in labelled_ebrbm[np.where(labelled_diff > 0)]],np.exp(labelled_crit[np.where(labelled_diff > 0)]), color=labelled_colors[np.where(labelled_diff > 0)], s=25)
    #plt.plot([-0.5,0.1],[85,85], "--", color = "red")
    #plt.plot([-0.22,-0.22],[10,500], "--", color = "green")

    plt.xlim(0.15,0.3)
   # plt.ylim(65,200)
    plt.xticks(size = 14, rotation = 45)
    plt.yticks(size = 14)
    plt.xlabel("RBM score", size = 16, )
    plt.ylabel("Crit", size = 16, )
    #plt.yscale("log")


In [None]:
# Multi-panel overview of generated (grey) and labelled (coloured) sequences:
# RBM score vs. distance, and FoldX energies vs. each other and vs. RBM score.
# NOTE(review): fails on a fresh kernel. `labelled_erbm` is only assigned in
# commented-out code. `fx` is never filled by the scoring cells, so torch.cat(fx, 0)
# raises. Constants such as -80 and the dashed guide lines are unexplained magic numbers.
plt.figure(figsize = (15,22))
with torch.no_grad():
    # Panel 321: RBM score vs. (jittered) distance to the wild type; wild type as "+".
    plt.subplot(321)
    #y = clf.predict([[0],[200]])
    plt.scatter(torch.cat(sim,0).detach(),torch.cat(eb,0), color = "lightgray",s=7)
    plt.scatter(labelled_diff, labelled_ebrbm, color=labelled_colors)
    #plt.scatter(df_final.sim_cas9, df_final.e_rbm, color = "blue")
    plt.scatter([0], (walker.model({"pi" : walker.x0.reshape(1,21,-1)[:,:,walker.kept_idx]})/736 - walker.Z), marker="+", color = "black", s = 160)
    plt.title(f"Spearman = {spearmanr(normfl,labelled_erbm)[0]:.3f}")
    plt.xlim(-1,100)
    plt.ylim(-0.3,0)
    plt.xticks(size = 14, rotation = 45)
    plt.yticks(size = 14)
    plt.xlabel("Sim", size = 16, )
    plt.ylabel("RBM", size = 16, )
    
    # Panel 322: FoldX energy (protein alone, shifted by -80) vs. FoldX energy with DNA, log-log.
    plt.subplot(322)
    idx = torch.where(torch.cat(sim,0)>0)[0]
    plt.scatter(torch.cat(fx,0).detach()[idx]-80,torch.cat(fx_dna,0).detach()[idx], color = "lightgray",s=10)
    plt.scatter(labelled_fx[np.where(labelled_diff > 0)]-80, labelled_fx_dna[np.where(labelled_diff > 0)], color=labelled_colors[np.where(labelled_diff > 0)], s=15)

    #plt.scatter(df_final.fx-80, df_final.fx_dna, color = "blue")
    plt.plot([0,500],[50,50], "--", color = "red")
    plt.plot([45,45],[0,500], "--", color = "green")

    plt.xlim(9,500)
    plt.ylim(9,500)
    plt.xticks(size = 14, rotation = 45)
    plt.yticks(size = 14)
    plt.xlabel("Fx energy", size = 16, )

    plt.ylabel("Fx energy with DNA", size = 16, )
    plt.yscale("log")
    plt.xscale("log")



    # Panel 323: RBM score vs. FoldX energy (protein alone, shifted by -80).
    plt.subplot(323)
    idx = torch.where(torch.cat(sim,0)>0)[0]
    plt.scatter(torch.cat(eb,0)[idx],torch.cat(fx,0).detach()[idx]-80, color = "lightgray",s=10)
    plt.scatter(labelled_ebrbm[np.where(labelled_diff > 0)],labelled_fx[np.where(labelled_diff > 0)]-80, color=labelled_colors[np.where(labelled_diff > 0)], s=15)
    #plt.scatter(df_final.e_rbm_unbiaised, df_final.fx-80, color = "blue")

    #plt.scatter([eb0-clf.predict([[0]])[0]], [fx0-80], marker="+", color = "black", s = 160)
    plt.plot([-0.5,0.1],[45,45], "--", color = "red")
    plt.plot([-0.15,-0.15],[10,500], "--", color = "green")

    plt.xlim(-0.3,0.)
    plt.ylim(9,500)
    plt.xticks(size = 14, rotation = 45)
    plt.yticks(size = 14)
    plt.xlabel("Unbiased RBM", size = 16, )
    plt.ylabel("Fx energy", size = 16, )
    plt.yscale("log")
    
    # Panel 324: RBM score vs. FoldX energy with DNA.
    plt.subplot(324)
    idx = torch.where(torch.cat(sim,0)>0)[0]
    plt.scatter(torch.cat(eb,0)[idx],torch.cat(fx_dna,0).detach()[idx], color = "lightgray",s=10)
    plt.scatter(labelled_ebrbm[np.where(labelled_diff > 0)],labelled_fx_dna[np.where(labelled_diff > 0)], color=labelled_colors[np.where(labelled_diff > 0)], s=15)
    plt.plot([-0.5,0.1],[50,50], "--", color = "red")
    plt.plot([-0.15,-0.15],[10,500], "--", color = "green")

    plt.xlim(-0.3,0.)
    plt.ylim(9,500)
    plt.xticks(size = 14, rotation = 45)
    plt.yticks(size = 14)
    plt.xlabel("Unbiased RBM", size = 16, )
    plt.ylabel("Fx energy with DNA", size = 16, )
    plt.yscale("log")
    
    # Panel 325: NOTE(review): an exact duplicate of panel 324 — confirm what was intended here.
    plt.subplot(325)
    idx = torch.where(torch.cat(sim,0)>0)[0]
    plt.scatter(torch.cat(eb,0)[idx],torch.cat(fx_dna,0).detach()[idx], color = "lightgray",s=10)
    plt.scatter(labelled_ebrbm[np.where(labelled_diff > 0)],labelled_fx_dna[np.where(labelled_diff > 0)], color=labelled_colors[np.where(labelled_diff > 0)], s=15)

    plt.plot([-0.5,0.1],[50,50], "--", color = "red")
    plt.plot([-0.15,-0.15],[10,500], "--", color = "green")

    plt.xlim(-0.3,0.)
    plt.ylim(9,500)
    plt.xticks(size = 14, rotation = 45)
    plt.yticks(size = 14)
    plt.xlabel("Unbiased RBM", size = 16, )
    plt.ylabel("Fx energy with DNA", size = 16, )
    plt.yscale("log")


In [None]:
import subprocess
import multiprocessing
import numpy as np
import pyfoldx as foldx
from pyfoldx.structure import Structure

from copy import deepcopy


# Weights of the current `edge` (the pi -> hidden edge of the last model loaded),
# with a leading singleton axis.
# NOTE(review): `W` is not used in any visible cell.
W = edge.get_weights()[None]
from torch.distributions.one_hot_categorical import OneHotCategorical
# (Re-definition of the code tables; identical to the earlier cell.)
# IUPAC nucleotide code -> the concrete bases it stands for.
NAd_in = {
    "A": "A", "T": "T", "C": "C", "G": "G",
    "W": "AT", "S": "CG", "M": "AC", "K": "TG", "R": "AG", "Y": "TC",
    "B": "TCG", "D": "ATG", "H": "ATC", "V": "ACG", "N": "ATCG",
}
# IUPAC codes ordered so that list index == bit mask over A=1, T=2, C=4, G=8 ("O" = none).
NAd = ["O", "A", "T", "W", "C", "M", "Y", "H", "G", "R", "K", "D", "S", "V", "B", "N"]

# PDB three-letter residue names <-> one-letter amino-acid codes.
pl3to1 = {
    'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
    'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N',
    'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W',
    'ALA': 'A', 'VAL': 'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M',
}
pl1to3 = dict(zip(pl3to1.values(), pl3to1.keys()))


def aux(args):
    """Thread a sequence onto chain A (numbering from residue 1) and return the FoldX total energy.

    Re-definition that shadows the earlier ``aux``: same contract, but residue
    numbering starts at 1 instead of 1102.

    Args:
        args: ``(seq, struct, repair)``. ``seq`` is one-letter amino-acid codes,
            ``struct`` is a pyfoldx ``Structure``, and ``repair`` is an unused
            ``(do_repair, idx)`` pair.

    Returns:
        float: the ``"total"`` column of the ``"model"`` row of ``getTotalEnergy()``.

    Notes:
        ``seq[k]`` goes to residue ``1 + k`` of chain A. Missing positions are
        skipped. ``struct`` is mutated in place.
    """
    seq, struct, repair = args
    chain = "A"
    do_repair, idx = repair  # kept for interface parity; unused here
    for position, residue_code in enumerate(seq, start=1):
        try:
            residues = struct.data[chain]
            if position in residues.keys():
                res = residues[position]
                res.code = pl1to3[residue_code]
                residues[position] = res
        except Exception:
            # Best effort: skip unmappable residues. Unlike the former bare
            # `except:`, this does not swallow KeyboardInterrupt/SystemExit.
            continue
    return float(struct.getTotalEnergy().loc["model"]["total"])

def foldx_energy(x_sampled, structs, repair=False, processes=32):
    """FoldX total energy of each sampled sequence threaded onto its paired structure (chain A).

    Args:
        x_sampled: iterable of encoded sequences. Each item ``x`` is decoded as
            ``AA[argmax over axis 0 of x[1:, nnz_idx]]``.
        structs: one structure per sequence, paired by position. Must contain at least
            as many items as ``x_sampled`` (previously extra sequences were silently dropped).
        repair: forwarded to the worker ``aux``, which does not use it.
        processes: size of the multiprocessing pool (previously fixed at 32).

    Returns:
        1-D tensor with one energy per sequence.

    Raises:
        ValueError: if fewer structures than sequences are supplied.
    """
    seqs = [[AA[x_] for x_ in x[1:, nnz_idx].cpu().argmax(0)] for x in x_sampled]
    structs = list(structs)
    if len(structs) < len(seqs):
        raise ValueError(f"got {len(seqs)} sequences but only {len(structs)} structures")
    jobs = [(seq, struct, (repair, i)) for i, (seq, struct) in enumerate(zip(seqs, structs))]
    with multiprocessing.Pool(processes=processes) as pool:
        energies = pool.map(aux, jobs)
    return torch.tensor(energies)

def aux_with_dna(args):
    """Pool worker: thread a sequence onto chain B of ``struct`` and return its FoldX total energy.

    Used with the ``*_with_dna`` structures.

    Args:
        args: ``(seq, struct, repair)``.
            ``seq[0]`` is skipped; ``seq[1:]`` is mapped to chain-B residue numbers
            1103, 1104, ...
            ``struct`` is mutated in place.
            ``repair`` (``(do_repair, idx)``) is ignored.

    Returns:
        float: ``struct.getTotalEnergy().loc["model"]["total"]`` after mutation.
    """
    seq, struct, _repair = args
    chain = "B"
    for position, target in enumerate(seq[1:], start=1103):
        try:
            if position in struct.data[chain].keys():
                res = struct.data[chain][position]
                res.code = pl1to3[target]
                struct.data[chain][position] = res
        except KeyError:
            # Unmapped one-letter symbols leave the residue unchanged; other errors propagate
            # instead of being swallowed by a bare `except:`.
            continue
    return float(struct.getTotalEnergy().loc["model"]["total"])

def foldx_energy_with_dna(x_sampled, structs, repair=False, processes=32):
    """FoldX total energy of each sequence threaded onto chain B of its paired structure.

    Args:
        x_sampled: iterable of encoded sequences, decoded as
            ``AA[argmax over axis 0 of x[1:, nnz_idx]]``.
        structs: one structure per sequence, paired by position; needs at least
            ``len(x_sampled)`` items.
        repair: forwarded to ``aux_with_dna``, which does not use it.
        processes: size of the multiprocessing pool (previously fixed at 32).

    Returns:
        1-D tensor with one energy per sequence.

    Raises:
        ValueError: if fewer structures than sequences are supplied.
    """
    seqs = [[AA[x_] for x_ in x[1:, nnz_idx].cpu().argmax(0)] for x in x_sampled]
    structs = list(structs)
    if len(structs) < len(seqs):
        raise ValueError(f"got {len(seqs)} sequences but only {len(structs)} structures")
    jobs = [(seq, struct, (repair, i)) for i, (seq, struct) in enumerate(zip(seqs, structs))]
    with multiprocessing.Pool(processes=processes) as pool:
        energies = pool.map(aux_with_dna, jobs)
    return torch.tensor(energies)

def aux_with_dna_interface(args):
    """Pool worker: thread a sequence onto chain A and return the FoldX B/D interaction energy.

    Args:
        args: ``(seq, struct, repair)``.
            ``seq`` holds one-letter codes for chain-A residues 1, 2, ...
            ``struct`` is mutated in place.
            ``repair`` is ``(do_repair, idx)``.

    Returns:
        float: ``getInterfaceEnergy(...)["Interaction Energy"].loc["B"].loc["D"]``.

    Note:
        The former ``do_repair`` branch assigned into ``structures``, which is not a
        module-level name in the visible code, so it raised NameError inside the pool
        worker. Even a successful write would only have changed the worker process's copy.
        It has been removed; no repair is performed.
    """
    seq, struct, _repair = args
    chain = "A"
    for position, target in enumerate(seq, start=1):
        try:
            if position in struct.data[chain].keys():
                res = struct.data[chain][position]
                res.code = pl1to3[target]
                struct.data[chain][position] = res
        except KeyError:
            # Unmapped one-letter symbols leave the residue unchanged.
            continue
    return float(struct.getInterfaceEnergy(verbose=False)["Interaction Energy"].loc["B"].loc["D"])

def interface_energy_with_dna(x_sampled, structs, repair=False, processes=32):
    """FoldX chain-B/chain-D interaction energy for each sequence threaded onto chain A.

    Args:
        x_sampled: iterable of encoded sequences, decoded as
            ``AA[argmax over axis 0 of x[1:, nnz_idx]]``.
        structs: one structure per sequence, paired by position. They are deep-copied
            first, so the caller's objects are never touched.
        repair: forwarded to ``aux_with_dna_interface`` as ``(repair, i)``.
        processes: size of the multiprocessing pool (previously fixed at 32).

    Returns:
        1-D tensor with one interaction energy per sequence.

    Raises:
        ValueError: if fewer structures than sequences are supplied.
    """
    seqs = [[AA[x_] for x_ in x[1:, nnz_idx].cpu().argmax(0)] for x in x_sampled]
    structs = [deepcopy(s) for s in structs]
    if len(structs) < len(seqs):
        raise ValueError(f"got {len(seqs)} sequences but only {len(structs)} structures")
    jobs = [(seq, struct, (repair, i)) for i, (seq, struct) in enumerate(zip(seqs, structs))]
    with multiprocessing.Pool(processes=processes) as pool:
        energies = pool.map(aux_with_dna_interface, jobs)
    return torch.tensor(energies)

def abs_diff(x, x0, idx=None, q=21):
    """Number of positions where the argmax channel of ``x`` differs from that of ``x0``.

    Args:
        x, x0: tensors with a leading batch axis, reshaped to (batch, q, -1).
        idx: positions to compare (defaults to the module-level ``nnz_idx``).
        q: number of channels per position (default 21).

    Returns:
        float tensor of shape (batch,) with the mismatch counts.
    """
    if idx is None:
        idx = nnz_idx
    x = x.reshape(x.size(0), q, -1)
    x0 = x0.reshape(x0.size(0), q, -1)
    return (x[:, :, idx].argmax(1) != x0[:, :, idx].argmax(1)).float().sum(-1)

def mean_diff(x, x0, idx=None, q=21):
    """Fraction of positions where the argmax channel of ``x`` differs from that of ``x0``.

    Args:
        x, x0: tensors with a leading batch axis, reshaped to (batch, q, -1).
        idx: positions to compare (defaults to the module-level ``nnz_idx``).
        q: number of channels per position (default 21).

    Returns:
        float tensor of shape (batch,) with values in [0, 1].
    """
    if idx is None:
        idx = nnz_idx
    x = x.reshape(x.size(0), q, -1)
    x0 = x0.reshape(x0.size(0), q, -1)
    return (x[:, :, idx].argmax(1) != x0[:, :, idx].argmax(1)).float().mean(-1)

def isd(x):
    """Channel-0 score per sequence.

    Returns the mean of channel 0 over ``nnz_idx`` plus one minus its mean over
    ``zero_idx``. ``x`` is reshaped to (batch, 21, -1); channel 0 is presumably the
    gap symbol (TODO confirm).
    """
    channel0 = x.reshape(x.size(0), 21, -1)[:, 0]
    on_kept_positions = channel0[:, nnz_idx].mean(-1)
    off_zero_positions = 1 - channel0[:, zero_idx].mean(-1)
    return on_kept_positions + off_zero_positions

def classifier_criterion(classifier, edge, target, eps=1e-7):
    """Build a criterion scoring sequences by the classifier's log-likelihood of ``target``.

    The returned ``crit(x)`` computes ``p = sigmoid(classifier(edge(x, False)))``. It then
    returns the Bernoulli log-likelihood of ``target`` under ``p``, averaged over the last axis.

    Args:
        classifier: callable mapping hidden inputs to logits.
        edge: callable invoked as ``edge(x, False)`` to project sequences to hidden inputs.
        target: 0/1 tensor broadcastable against ``p``.
        eps: small constant added inside both logs to avoid ``log(0)`` (default 1e-7).

    Returns:
        Callable ``crit(x) -> Tensor`` with the last axis reduced.
    """
    def crit(x):
        p = classifier(edge(x, False)).sigmoid()
        log_lik = target * (p + eps).log() + (1 - target) * (1 - p + eps).log()
        return log_lik.mean(-1)
    return crit

def litpam_to_pam(s):
    """Encode a PAM string as a (1, L) float tensor on ``device``.

    ``s`` is right-padded with "N" up to ``Npam`` letters, and the ``NAd_idx`` encodings
    of the letters are concatenated in order.
    """
    padded = s + "N" * max(0, Npam - len(s))
    flat = [bit for letter in padded for bit in NAd_idx[letter]]
    return torch.tensor(flat).float()[None].to(device)

def get_direction(model_rbm, edge, target, x, h, a, c, T=1e-1, step = 0):
    """Estimate the update direction in hidden space for one sampler step.

    Draws ``samples`` sequences per chain from the RBM conditional around ``h`` and uses
    them to compute two things:
    (i) the gradient of the PAM-classifier log-likelihood with respect to the sampled
        hidden inputs, averaged per chain;
    (ii) for each constraint phi_k, the covariance over samples between the hidden input
        ``Wv`` and phi_k.

    Args:
        model_rbm: RBM providing ``Z`` and the energy used for the log P(x) term.
        edge: ``pi -> hidden`` edge (forward projection and ``reverse``).
        target: PAM target vector used by the classifier term.
        x: current states, 2-D with a leading batch axis, reshaped below to (., 21, -1).
        h: current hidden states, (batch, Nh).
        a: divisor applied to the combined constraint value ``phi``.
        c: weight of the ``-c * h`` shrinkage term added to ``gradient_C``.
        T: temperature of the categorical proposal. It was previously ignored in favour of
            a hard-coded 1e-1, which is also its default.
        step: sampler progress; unused.

    Returns:
        ``(bt, gradient_C, gradient_phi)``, each with a leading singleton axis.

    NOTE(review): the module-level ``classifier`` is used, not a parameter.
    """
    samples = 16  # Monte-Carlo sequences drawn per chain
    batch = x.size(0)
    x0 = x[None].expand(samples, -1, -1).reshape(samples * batch, 21, -1)
    mut = edge.reverse(h[None].expand(samples, -1, -1).reshape(-1, Nh))
    mut = mut.reshape(mut.size(0), 21, -1)

    # Channel 0 logits: strongly negative everywhere, strongly positive at zero_idx.
    mut[:, 0] = -10000
    mut[:, 0, zero_idx] = 10000

    logits = (mut + pi.linear.weight.view(1, pi.q, pi.N)) / T
    distribution = OneHotCategorical(probs=F.softmax(logits, 1).permute(0, 2, 1))
    x_sampled = distribution.sample().permute(0, 2, 1)
    # NOTE(review): writes x_sampled[:, :, zero_idx] back onto itself; effectively a no-op.
    x_sampled[:, :, zero_idx] = x_sampled.reshape(x_sampled.size(0), 21, -1)[:, :, zero_idx].detach()
    h_sampled = edge(x_sampled, False).reshape(samples, -1, Nh).permute(1, 0, 2)

    Z = model_rbm.Z.cpu().item()
    N = pi.N

    # Classifier term: gradient of the PAM log-likelihood w.r.t. each sampled hidden vector.
    hs = [h_.detach() for h_ in h_sampled]
    for h_ in hs:
        h_.requires_grad = True
    p = [classifier(h_).sigmoid() for h_ in hs]
    crits = [((target * (p_ + 1e-7).log() + (1 - target) * (1 - p_ + 1e-7).log()).mean(-1)).mean(0) for p_ in p]

    TRACK["|h|"].append(h.detach().pow(2).sum(-1).sqrt() / h.size(-1))
    for crit_ in crits:
        crit_.backward()
    gradient_C = torch.stack([h_.grad.mean(0) for h_ in hs], 0)

    with torch.no_grad():
        n_total = batch * samples
        # One independent copy per sampled sequence: the FoldX workers overwrite residue codes.
        structures = [deepcopy(struct_cas9) for _ in range(n_total)]
        structures_with_dna = [deepcopy(struct_cas9_with_dna) for _ in range(n_total)]
        structures_with_dna_AA = [deepcopy(struct_cas9_with_dna_AA) for _ in range(n_total)]
        Wv = edge(x_sampled, False).reshape(samples, batch, -1)
        Wv_ = Wv.mean(0)

        def _cov_grad(phi_k):
            # Covariance over samples between the hidden input Wv and constraint phi_k.
            return ((Wv - Wv_) * (phi_k - phi_k.mean(0))).mean(0)

        # phi1: RBM log-likelihood falling below min_edca.
        pv = (model_rbm({"pi": x_sampled}).reshape(samples, batch, -1) / N - Z)
        phi1 = F.relu(min_edca - pv).abs()
        gradient_phi1 = _cov_grad(phi1)
        TRACK["E(log P(x))"].append(pv.mean(0))

        # phi2: slow down, penalising more than 5 mutations from the current state.
        sim = abs_diff(x_sampled, x0).reshape(samples, batch, -1)
        phi2 = F.relu(sim - 5).pow(2)
        gradient_phi2 = _cov_grad(phi2)
        TRACK["Diff"].append(sim.mean(0))

        # phi3: keep the Hamming distance to Cas9 within [45, 55].
        diff_cas9 = abs_diff(x_sampled, x_cas9[None].expand(x_sampled.size(0), -1)).reshape(samples, batch, -1)
        phi3 = F.relu(diff_cas9 - 55) + F.relu(45 - diff_cas9)
        gradient_phi3 = _cov_grad(phi3)
        TRACK["Diff cas9"].append(diff_cas9.mean(0))

        # phi4: FoldX energy of the protein alone, pulled toward 116.
        # NOTE(review): `.to(device)` while Wv lives on the edge's device; confirm they match.
        fx = foldx_energy(x_sampled, structures).to(device).reshape(samples, batch, -1)
        phi4 = (fx - 116).pow(2)
        gradient_phi4 = _cov_grad(phi4)
        TRACK["foldx"].append(fx.mean(0))

        # phi5: FoldX energy with the `struct_cas9_with_dna_AA` complex.
        fx_dna_AA = foldx_energy_with_dna(x_sampled, structures_with_dna_AA).to(device).reshape(samples, batch, -1)
        phi5 = -fx_dna_AA
        gradient_phi5 = _cov_grad(phi5)

        # phi6: FoldX energy with the `struct_cas9_with_dna` complex.
        fx_dna = foldx_energy_with_dna(x_sampled, structures_with_dna).to(device).reshape(samples, batch, -1)
        phi6 = -fx_dna
        gradient_phi6 = _cov_grad(phi6)

        # Noise + shrinkage, favouring lower fx_dna_AA and higher fx_dna.
        gradient_C += 3 * torch.randn_like(h) - c * h + 3 * gradient_phi5 - 3 * gradient_phi6
        # All constraint weights are currently 0 (tuning knobs), so gradient_phi == 0 and bt == 0.
        phi = 0*phi1.mean(0)[:,0] + 0*phi3.mean(0)[:,0] + 0*phi2.mean(0)[:,0] + 0*phi4.mean(0)[:,0] + 0* phi6.mean(0)[:,0]
        gradient_phi = 0*gradient_phi1 + 0*gradient_phi3 + 0*gradient_phi2 + 0*gradient_phi4 + 0*gradient_phi6

        norm_phi2 = gradient_phi.pow(2).sum(-1).detach()
        angle_C = (gradient_phi * gradient_C).sum(-1)
        bt = phi / a + angle_C
        idx = torch.where(norm_phi2 > 0)
        bt[idx] = (bt[idx] / norm_phi2[idx]).clip(-100, 100)
    return bt[None], gradient_C[None], gradient_phi[None]

def step_hidden(edge, classifier, x, h, a=10, c = 0.00001, T=1e-1, n = 1, step = 0):
    """Move hidden states along the ``get_direction`` update and emit ``n`` sequences per chain.

    Returns:
        ``(h_next, x_emitted, gradient_C, gradient_phi)``.
        ``h_next`` has shape (n * batch, Nh) in proposal-major order (row k * batch + b).
        ``x_emitted`` is one-hot sampled at temperature ``T`` from ``h_next``.

    ``classifier`` and ``step`` are not used here; ``get_direction`` reads the
    module-level ``classifier``.
    NOTE(review): ``gradient_C`` already contains a ``-c * h`` term from ``get_direction``,
    and ``-c * h`` is subtracted again below.
    """
    bt, grad_c, grad_phi = get_direction(model_rbm, edge, target, x, h.detach(), a, c)
    replicas = h[None].expand(n, *h.size())
    drift = grad_c - c * replicas - bt[:, :, None] * grad_phi
    stepped = replicas + a * drift
    h_next = stepped.view(stepped.size(0) * stepped.size(1), -1)

    logits = edge.reverse(h_next).reshape(h_next.size(0), 21, -1)
    logits[:, 0] = -10000
    logits[:, 0, zero_idx] = 10000
    scaled = (logits + pi.linear.weight.view(1, pi.q, pi.N)) / T
    probs = F.softmax(scaled, 1).permute(0, 2, 1)
    x_emitted = OneHotCategorical(probs=probs).sample().permute(0, 2, 1)
    return h_next, x_emitted, grad_c, grad_phi

def _foldx_energies(x_flat, q, N):
    """Return ``(fx, fx_dna, fx_dna_AA)`` FoldX energies for every row of ``x_flat``.

    ``x_flat`` is reshaped to (-1, q, N). Every call deep-copies the three reference
    structures once per row, because the FoldX workers overwrite residue codes on the
    structure they receive. No ``repair`` flag is forwarded: the workers never read it.
    """
    x = x_flat.reshape(-1, q, N)
    rows = x.size(0)
    structures = [deepcopy(struct_cas9) for _ in range(rows)]
    structures_with_dna = [deepcopy(struct_cas9_with_dna) for _ in range(rows)]
    structures_with_dna_AA = [deepcopy(struct_cas9_with_dna_AA) for _ in range(rows)]
    fx = foldx_energy(x, structures)
    fx_dna = foldx_energy_with_dna(x, structures_with_dna)
    fx_dna_AA = foldx_energy_with_dna(x, structures_with_dna_AA)
    return fx, fx_dna, fx_dna_AA


def sampling_through_criterion_and_phi(model, x_0, criterion, classifier, target = 5, T = 1e-5, n_sampling = 500, verbose = 1):
    """Evolve a population of sequences through guided proposals in the RBM hidden space.

    Each iteration proceeds as follows:
    1. ``step_hidden`` emits ``n`` proposals per chain.
    2. A proposal is admissible only if its FoldX energy alone is below ``min_fxe`` and its
       energy with the ``_with_dna_AA`` structure is below ``min_fxe_dna``.
    3. Among the current state and the admissible proposals, the one with the smallest
       ``fx_dna_AA - fx_dna`` is kept.
    4. Chains whose running acceptance score ``state_dynamic`` is low are randomly killed
       and replaced by clones of chains drawn proportionally to that score.

    Fixes relative to the previous version:
    - Proposals are grouped per chain using step_hidden's proposal-major layout.
      Previously each chain chose among other chains' proposals.
    - Killed chains now clone the donor's *current* state. Previously a donor that kept its
      state copied a neighbour's proposal, or the last proposal when the donor was chain 0.
    - ``TRACK["chains"][0]`` is a snapshot of the initial state rather than an alias of
      the live ``state_x``.
    - Dead structure lists, one loaded from a hard-coded absolute path, were removed.

    Args:
        model: RBM exposing ``edges["pi -> hidden"]``, ``Z`` and ``device``.
        x_0: initial one-hot sequences, shape (batch, q, N).
        criterion: callable returning per-sequence log-scores; their ``exp`` is tracked.
        classifier: forwarded to ``step_hidden``.
        target, T: unused; kept for backward compatibility.
        n_sampling: the loop runs ``n_sampling - 1`` steps.
        verbose: print a progress summary every step when truthy.

    Returns:
        ``(state_x, state_h, state_xs)``. ``state_x`` holds the final flattened sequences
        and ``state_h`` the hidden states. ``state_xs`` is always empty and is kept for
        backward compatibility.
    """
    edge = model.edges["pi -> hidden"]
    batch_size, q, N = x_0.size()
    x_0 = x_0.view(batch_size, -1).float()
    Z = model.Z.cpu().item()
    state_e = model({"pi": x_0.float().to(model.device)}).detach().cpu()/N - Z
    state_h = edge(x_0, False).detach()
    state_x = x_0.clone()
    state_p = criterion(x_0.clone())
    state_diff = torch.ones(len(x_0))*45
    state_dynamic = torch.ones(batch_size)

    state_xs = []
    n_mut = 0
    a = .3
    TRACK["chains"] = [state_x.clone()]
    n = 4  # proposals per chain and step
    min_fxe_dna = 45
    min_fxe = 125
    for i in range(1, n_sampling):
        h, x_emitted, gradient_C, gradient_phi = step_hidden(edge, classifier, state_x, state_h, a, n=n, step = i/n_sampling)
        p = criterion(x_emitted.clone())
        x_emitted = x_emitted.float().cpu().reshape(-1, q, N)
        x_emitted = x_emitted.reshape(x_emitted.size(0), -1)
        TRACK["chains"].append(x_emitted)
        TRACK["C"].append(p.exp().detach())

        e = (model({"pi" : x_emitted.float().clone()})/N - Z).detach().cpu()
        fx, fx_dna, fx_dna_AA = _foldx_energies(x_emitted, q, N)

        rejected = torch.where(~((fx<min_fxe) & (fx_dna_AA<min_fxe_dna)))[0]
        diffs = fx_dna_AA-fx_dna
        diffs[rejected] = 1000
        # Proposal k of chain b sits at flat row k * batch_size + b (see step_hidden).
        diffs = torch.cat([state_diff[:,None], diffs.reshape(n, batch_size).t()],-1)
        max_idx = diffs.argmin(1)

        state_dynamic = (max_idx > 0).int() + 0.8 * state_dynamic
        changed_idx = torch.where((max_idx > 0))[0]
        killed_idx = torch.where(torch.rand(batch_size) > state_dynamic)[0]

        # Inverse-CDF draw of donor chains, clamped against rounding of the last cumsum entry.
        probs = state_dynamic.cumsum(0)/state_dynamic.sum()
        replaced_idx = torch.tensor(
            [min(int((probs < random.random()).sum()), batch_size - 1) for _ in killed_idx],
            dtype=torch.long,
        )
        if len(changed_idx):
            prop = (max_idx[changed_idx] - 1) * batch_size + changed_idx
            state_x[changed_idx] = x_emitted[prop]
            state_h[changed_idx] = h[prop]
            state_p[changed_idx] = p[prop]
            state_e[changed_idx] = e[prop]
            state_diff[changed_idx] = diffs[changed_idx, max_idx[changed_idx]]
        if len(killed_idx):
            # Clone the donors' current (already updated) states into the killed slots.
            state_x[killed_idx] = state_x[replaced_idx]
            state_h[killed_idx] = state_h[replaced_idx]
            state_p[killed_idx] = state_p[replaced_idx]
            state_e[killed_idx] = state_e[replaced_idx]
            state_diff[killed_idx] = state_diff[replaced_idx]
            state_dynamic[killed_idx] = state_dynamic[replaced_idx]

        n_mut+=len(changed_idx)
        norm_GC = gradient_C.pow(2).sum(-1).sqrt().detach()
        norm_Gphi = gradient_phi.pow(2).sum(-1).sqrt().detach()
        angle = (gradient_C*gradient_phi).sum(-1).detach()

        fx, fx_dna, fx_dna_AA = _foldx_energies(state_x, q, N)

        TRACK["foldx"].append(fx)
        TRACK["foldx_dna"].append(fx_dna)
        TRACK["foldx_dna_AA"].append(fx_dna_AA)
        TRACK["x"].append(state_x.clone())

        if verbose:
            print(f"""{n_mut}/{batch_size*i} [{(100*n_mut)/(batch_size*i):.2f}%] 
            || Class : {state_p.exp().mean().cpu().item():.3f} 
            || E = {state_e.mean():.3f} 
            || Dynamic : {list(state_dynamic.detach().numpy())}
            || Diff : {list(state_diff.detach().numpy())}
            || Foldx = {fx.mean():.3f} 
            || Foldx DNA = {fx_dna.mean():.3f} 
            || Foldx DNA AA = {fx_dna_AA.mean():.3f} 
            || angle = {(angle/(norm_GC*norm_Gphi)).mean():.3f} 
            || h = {(state_h-h_cas9).pow(2).sum(-1).sqrt().mean():.3f}""")
    return state_x, state_h, state_xs

def display_dynamic(TRACK, phi, n = 8):
    """Plot the first ``n`` tracked chains in four panels, three of them with reference bands.

    Args:
        TRACK: dict of per-step histories; reads 'E(log P(x))', 'Sim', 'SQA' and 'C'.
        phi: ``(pphi1, pphi2, pphi3)``, the centres of the black bands in panels 1-3.
        n: number of chains (rows) drawn per panel.
    """
    centre_logp, centre_sim, centre_sqa = phi
    rows = list(range(n))
    plt.figure(figsize=(20, 5))

    def _band(centre, half_width):
        # Two horizontal black lines spanning the length of the log P(x) history.
        span = [0, len(TRACK['E(log P(x))'])]
        plt.plot(span, [centre - half_width, centre - half_width], c="black")
        plt.plot(span, [centre + half_width, centre + half_width], c="black")

    def _traces(series):
        for y in series[rows]:
            plt.plot(y)

    panels = [
        (141, centre_logp, 0.05, lambda: torch.cat(TRACK['E(log P(x))'], -1), 'log P(x)'),
        (142, centre_sim, 0.1, lambda: torch.cat(TRACK['Sim'], -1), 'Hamming Similarity'),
        (143, centre_sqa, 0.05, lambda: torch.cat(TRACK['SQA'][1:], -1), 'SQA'),
    ]
    for position, centre, half_width, load, title in panels:
        plt.subplot(position)
        _band(centre, half_width)
        _traces(load())
        plt.title(title)

    plt.subplot(144)
    _traces(torch.stack(TRACK['C'], -1))
    plt.title('Classifier')

    plt.show()
    
def find_closest_sequence(x):
    """Highest identity between each row of ``x`` and any sequence in ``existing_sequences``.

    The value is a similarity in [0, 1], despite the old local name ``distance``. It is the
    fraction of ``nnz_idx`` positions whose argmax channel matches, maximised over the bank.
    ``x`` is viewed as (batch, 21, -1).
    """
    query = x.view(len(x), 21, -1)[:, :, nnz_idx].argmax(-2)
    bank = existing_sequences[:, :, nnz_idx].argmax(-2)
    identity = (query.unsqueeze(1) == bank.unsqueeze(0)).float().mean(-1)
    return identity.max(dim=1).values