In [None]:
import sys, os, warnings
sys.path.append(os.path.dirname(os.getcwd()))
warnings.filterwarnings("ignore")

from tqdm import tqdm_notebook
from torchpgm.model import *
from torchpgm.layers import *

from cld.postprocessing import *
from cld.criterion import *
from cld.walker import *

from utils import *
from config import *

from itertools import product

import seaborn as sns

sns.set_style("whitegrid")

# Render inline figures and saved figures at the same (high) resolution.
plt.rcParams.update({"figure.dpi": 300, "savefig.dpi": 300})

In [None]:
# --- Run configuration -------------------------------------------------------
device = "cuda"            # where checkpoints are loaded and the PAM target is built
folder = f"{DATA}/vink"    # holds the weights/ and tracks/ subfolders used below

Nh = 200                   # size of the Gaussian hidden layer
Npam = 5                   # PAM length; lit_to_pam pads literals to this many symbols
best_epoch = 90            # epoch suffix of the checkpoint files to load

q_pi = 21                  # one-hot alphabet size of the PI layer (state 0 decodes to the gap "-")
N_pi = 736                 # number of positions in the PI layer

# Reference model; its gamma value is part of the checkpoint file name.
model_full_name = f"rbmssl_pid_h{Nh}_npam{Npam}_gamma5.306595410288844"


def lit_to_pam(s):
    """Encode a PAM literal (e.g. ``"NGG"``) as a classifier target tensor.

    The literal is right-padded with ``"N"`` up to ``Npam`` symbols. Each
    symbol is expanded through ``NAd_idx`` and the pieces are concatenated.
    The result is returned as a float tensor with a leading batch dimension
    of 1, placed on ``device``.

    Parameters
    ----------
    s : str
        PAM literal, at most ``Npam`` symbols long. Every symbol must be a
        key of ``NAd_idx``.

    Returns
    -------
    torch.Tensor
        Float tensor of shape ``(1, total_encoding_length)`` on ``device``.

    Raises
    ------
    ValueError
        If ``s`` is longer than ``Npam``. Padding cannot shorten it, and the
        encoded target would not match the ``Npam * 4`` classifier output.
    """
    if len(s) > Npam:
        raise ValueError(f"PAM literal {s!r} is longer than Npam={Npam}")
    padded = s + "N" * (Npam - len(s))
    pam = []
    for symbol in padded:
        pam += NAd_idx[symbol]
    return torch.tensor(pam).float()[None].to(device)

In [None]:
# Keep a contiguous window of the sorted gamma values (ranks 260 .. 541).
_gamma_lo, _gamma_hi = 260, 542
selected_gammas = sorted(gammas)[_gamma_lo:_gamma_hi]

## Demo of the Constrained Langevin Dynamics

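We first load the trained model and run AIS on it (`model_rbm.ais()`); `model_rbm.Z` is then used below to normalise energies.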

In [None]:
# Layers: a one-hot PI layer (N_pi positions, q_pi states) and a Gaussian hidden layer.
pi = OneHotLayer(None, N=N_pi, q=q_pi, name="pi")
h = GaussianLayer(N=Nh, name="hidden")
classifier = PAM_classifier(Nh, Npam * 4)
E = sorted([(pi.name, h.name)])

model_rbm = PI_RBM_SSL(
    classifier,
    layers={pi.name: pi, h.name: h},
    edges=E,
    name=model_full_name,
).to(device)
# Load the checkpoint on `device`, then evaluate on CPU.
model_rbm.load(f"{folder}/weights/{model_full_name}_{best_epoch}.h5")
model_rbm.eval()
model_rbm = model_rbm.to("cpu")
model_rbm.ais()

In [None]:
# Reference SpyCas9 PI sequence; the rest of the notebook lays it out as (q_pi, N_pi).
x_cas9 = torch.load(f"{DATA}/x_cas9.pt")
zero_idx = torch.load(f"{DATA}/zero_idx.pt")
kept_idx = list(range(N_pi))
target = lit_to_pam("NGG")
assert x_cas9.numel() == q_pi * N_pi, "x_cas9 does not hold q_pi * N_pi entries"
# Use the same (q_pi, -1) view as the walker; view(-1, q_pi) would scramble the layout.
x_cas9.view(q_pi, -1).shape

We then define the objective and the constraints of the walk (a target distance to SpyCas9 and a target RBM energy window), and set up the walker.

In [None]:
# Reference energy of SpyCas9: model output divided by N_pi, minus model_rbm.Z.
with torch.no_grad():
    e0 = (model_rbm({"pi": x_cas9[None]}) / N_pi - model_rbm.Z)[0].item()

# Target windows: energy within +/-0.01 of the reference; distance to SpyCas9 in [30, 35].
e = (e0 - 0.01, e0 + 0.01)
s = (30, 35)
T = 0.1 * torch.ones(1, 1, len(kept_idx))

objective = RbmCriterion(model_rbm, postprocessing=ConstantPostprocessor(0))
sim_constraint = SimCriterion(
    x_cas9, list(range(N_pi)),
    postprocessing=LinearPostprocessor(100, 0, s[0], 0, s[1]),
)
energy_constraint = RbmCriterion(
    model_rbm,
    postprocessing=LinearPostprocessor(100, e0, e[0], e0, e[1]),
)
constraints = [sim_constraint, energy_constraint]
weight_constraints = [10, 1000]

walker = Walker(
    x_cas9.view(q_pi, -1).clone(), model_rbm, objective, constraints, zero_idx,
    gamma=.1, n=1, a=1, c=1e-2, eps=1,
    target=target.cpu(), T=T, weight_constraints=weight_constraints,
)

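We run walks for two (distance, energy) target windows and plot the energy and distance trajectories of the first 10 walks in each case.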

In [None]:
# Run 16 walks for each (distance window, energy window) pair and plot the
# trajectories of the first 10 walks. The window tuples `s`/`e` are kept
# intact; the tracks go into separate arrays.
with torch.no_grad():
    e0 = (model_rbm({"pi": x_cas9[None]}) / N_pi - model_rbm.Z)[0].item()
e_plage = [(e0 - 0.01, e0 + 0.01), (e0 - 0.03, e0 - 0.02)]
sim_plage = [(30, 35), (50, 55)]
x = []
TRACKS = []
for s, e in zip(sim_plage, e_plage):
    objective = RbmCriterion(model_rbm, postprocessing=ConstantPostprocessor(0))
    constraints = [
        SimCriterion(x_cas9, list(range(N_pi)), postprocessing=LinearPostprocessor(100, 0, s[0], 0, s[1])),
        RbmCriterion(model_rbm, postprocessing=LinearPostprocessor(100, e0, e[0], e0, e[1])),
    ]
    weight_constraints = [10, 1000]

    walker = Walker(x_cas9.view(q_pi, -1).clone(), model_rbm, objective, constraints, zero_idx, gamma=.1, n=1, a=1,
                    c=1e-2, eps=1, target=target.cpu(), T=T, weight_constraints=weight_constraints)
    x.append(walker.run(16, n_epochs=200, verbose=False))

    # One row per entry of walker.TRACKS; after .T, one row per walk.
    e_track = np.concatenate([track["e"][None] for track in walker.TRACKS])
    sim_track = np.concatenate([track["abs_diff"][None] for track in walker.TRACKS])

    fig, (ax_e, ax_s) = plt.subplots(1, 2, figsize=(10, 4))
    for e_ in e_track.T[:10]:
        ax_e.plot(e_)
    ax_e.set(title=f"Energy (target {e[0]:.3f} .. {e[1]:.3f})", xlabel="Step", ylabel="e")
    for sim_ in sim_track.T[:10]:
        ax_s.plot(sim_)
    ax_s.set(title=f"Distance (target {s[0]} .. {s[1]})", xlabel="Step", ylabel="Distance to SpyCas9")
    plt.show()

    TRACKS.append((e_track.T[:10].copy(), sim_track.T[:10].copy()))

## Generative capacities given $\gamma$

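For each selected value of $\gamma$, we load the corresponding model and run walks for a range of distance-to-SpyCas9 windows, keeping the energy close to that of SpyCas9. The generated sequences and trajectories are saved to disk, and the next cell plots the trajectories obtained with the last model.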

In [None]:
# For each selected gamma:
#   - load the matching checkpoint;
#   - run 16 walks per distance window, with the energy kept within +/-0.02 of SpyCas9's;
#   - save (generated sequences, first-10 trajectories) to {folder}/tracks/tracks_{gamma}.pt.
sim_plage = [(5 * i, 5 * i + 5) for i in range(2, 10)]
os.makedirs(f"{folder}/tracks", exist_ok=True)  # torch.save does not create missing directories
for gamma in selected_gammas:
    x = []
    TRACKS = []
    model_full_name = f"rbmssl_pid_h{Nh}_npam{Npam}_gamma{gamma}"

    # NOTE(review): `classifier`, `pi` and `h` are the same objects for every gamma
    # and for the reference model above. Presumably each load writes into these
    # shared layers; confirm that this is intended.
    model_rbm = PI_RBM_SSL(classifier, layers={pi.name: pi, h.name: h}, edges=E, name=model_full_name)
    model_rbm = model_rbm.to(device)
    model_rbm.load(f"{folder}/weights/{model_full_name}_{best_epoch}.h5")
    model_rbm.eval()
    model_rbm.ais()
    model_rbm = model_rbm.to("cpu")
    model_rbm.Z = model_rbm.Z.cpu()
    edge = model_rbm.edges["pi -> hidden"]

    # Flat, fresh copy of the reference sequence. This replaces stacking 512
    # identical copies of which only the first was used.
    x_cas9 = x_cas9.view(q_pi, -1).reshape(-1).clone()
    h_cas9 = edge(x_cas9[None], False)  # NOTE(review): not used by the visible cells

    with torch.no_grad():
        e0 = (model_rbm({"pi": x_cas9[None]}) / N_pi - model_rbm.Z)[0].item()
    e_plage = [(e0 - 0.02, e0 + 0.02)]

    for s, e in product(sim_plage, e_plage):
        T = 0.1 * torch.ones(1, 1, len(kept_idx))
        objective = RbmCriterion(model_rbm, postprocessing=ConstantPostprocessor(0))
        constraints = [
            SimCriterion(x_cas9, list(range(N_pi)), postprocessing=LinearPostprocessor(100, 0, s[0], 0, s[1])),
            RbmCriterion(model_rbm, postprocessing=LinearPostprocessor(100, e0, e[0], e0, e[1])),
        ]

        # NOTE(review): unlike the other walks in this notebook, this uses gamma=1
        # and the Walker's default weight_constraints; confirm that this is intended.
        walker = Walker(x_cas9.view(q_pi, -1).clone(), model_rbm, objective, constraints, zero_idx, gamma=1, n=1, a=1,
                        c=1e-2, eps=1, target=target.cpu(), T=T)

        x.append(walker.run(16, n_epochs=200, verbose=False))
        e_track = np.concatenate([track["e"][None] for track in walker.TRACKS])
        sim_track = np.concatenate([track["abs_diff"][None] for track in walker.TRACKS])
        TRACKS.append((e_track.T[:10].copy(), sim_track.T[:10].copy()))
    torch.save((x, TRACKS), f"{folder}/tracks/tracks_{gamma}.pt")

In [None]:
# Plot the trajectories of the last model processed above.
#   - Left: energy, plotted as -e.
#   - Right: distance to SpyCas9.
# The black guides show the target windows. The constraints ramp linearly over
# the first RAMP_STEPS steps (LinearPostprocessor(100, ...)); the walks last
# N_STEPS steps (n_epochs=200). Walks start from SpyCas9, whose energy is e0,
# so the energy ramp starts at -e0.
RAMP_STEPS, N_STEPS = 100, 200
for (track_e, track_s), (s, e) in zip(TRACKS, product(sim_plage, e_plage)):
    fig, (ax_e, ax_s) = plt.subplots(1, 2, figsize=(10, 5))
    fig.suptitle(f"energy window {e}, distance window {s}")

    for e_ in track_e[:5]:
        ax_e.plot(-e_)
    for bound in e:
        ax_e.plot([0, RAMP_STEPS], [-e0, -bound], color="black")
        ax_e.plot([RAMP_STEPS, N_STEPS], [-bound, -bound], color="black")
    ax_e.set(xlabel="Step", ylabel="E_RBM")

    for sim_ in track_s[:5]:
        ax_s.plot(sim_)
    for bound in s:
        ax_s.plot([0, RAMP_STEPS], [0, bound], color="black")
        ax_s.plot([RAMP_STEPS, N_STEPS], [bound, bound], color="black")
    ax_s.set(xlabel="Step", ylabel="Distance to SpyCas9")
    plt.show()

## Experimentally tested data and generated data

In [None]:
# Map each experimentally tested sequence onto the (q_pi, N_pi) one-hot layout.
# Columns listed in `zero_idx` stay in state 0 (gap); the remaining `nnz_idx`
# columns receive the sequence itself.
df = pd.read_excel(f"{DATA}/ML-designed PID.xlsx")
batch = 0
zero_set = {int(i) for i in zero_idx}  # O(1) membership instead of scanning zero_idx per position
nnz_idx = [i for i in range(N_pi) if i not in zero_set]
X_labelled = []
for seq in df.seq[df.batch >= batch]:
    if len(seq) != len(nnz_idx):
        raise ValueError(
            f"sequence of length {len(seq)} does not match the {len(nnz_idx)} non-gap positions"
        )
    seq_onehot = torch.tensor(to_onehot([AA_IDS[aa] + 1 if aa in AA else 0 for aa in seq], (None, q_pi)).T)
    x_ = torch.zeros(q_pi, N_pi)
    x_[0] = 1                            # every column starts in state 0 ...
    x_[:, nnz_idx] = seq_onehot.float()  # ... then the non-gap columns are overwritten by the sequence
    X_labelled.append(x_)
X_labelled = torch.stack(X_labelled, 0)

In [None]:
# Grid of walks: 10 distance windows x 6 energy windows, 16 walks each.
# NOTE(review): `model_rbm` is whichever model was loaded last (the last gamma
# of the sweep above, if it was run); confirm that this is the intended model.
with torch.no_grad():
    e0 = (model_rbm({"pi": x_cas9[None]}) / N_pi - model_rbm.Z)[0].item()
e_plage = [(e0 - 0.01 - 0.02 * i, e0 + 0.01 - 0.02 * i) for i in range(6)]
sim_plage = [(10 + 5 * i, 10 + 5 * i + 5) for i in range(10)]
T = 0.1 * torch.ones(1, 1, len(kept_idx))  # same temperature as above, defined here so the cell stands alone
x = []
TRACKS = []
for s, e in product(sim_plage, e_plage):
    objective = RbmCriterion(model_rbm, postprocessing=ConstantPostprocessor(0))
    constraints = [
        SimCriterion(x_cas9, list(range(N_pi)), postprocessing=LinearPostprocessor(100, 0, s[0], 0, s[1])),
        RbmCriterion(model_rbm, postprocessing=LinearPostprocessor(100, e0, e[0], e0, e[1])),
    ]
    weight_constraints = [10, 1000]

    walker = Walker(x_cas9.view(q_pi, -1).clone(), model_rbm, objective, constraints, zero_idx, gamma=.1, n=1, a=1,
                    c=1e-2, eps=1, target=target.cpu(), T=T, weight_constraints=weight_constraints)

    x.append(walker.run(16, n_epochs=200, verbose=False))
    e_track = np.concatenate([track["e"][None] for track in walker.TRACKS])
    sim_track = np.concatenate([track["abs_diff"][None] for track in walker.TRACKS])
    TRACKS.append((e_track.T[:10].copy(), sim_track.T[:10].copy()))

In [None]:
# Decode the generated batches: state 0 becomes the gap "-", state k > 0 becomes AA[k - 1].
xs = (
    torch.cat(x, 0)
    .reshape(-1, 21, 736)
    .argmax(1)
    .numpy()
)
aligned_seqs, unaligned_seqs = [], []
for x_ in xs:
    residues = ["-" if state == 0 else AA[state - 1] for state in x_]
    aligned_seqs.append("".join(residues))
    unaligned_seqs.append("".join(res for res, state in zip(residues, x_) if state != 0))