In [1]:
import os
import sys

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch

sys.path.append("..")

from models import ESMIF, MPNN
from models.revor import ReVor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
checkpoint_path = "../model_params/ligandmpnn/ligandmpnn_v_32_020_25.pt"
mpnn = MPNN(
    checkpoint_path,
    model_type="ligand_mpnn",
    device=device,
    ligand_mpnn_use_side_chain_context=True,
)
# esmif = ESMIF(device=device)

In [None]:
pdb_path = "../pdbs/FlexID.pdb"
redesigned_residues = (
    "A9 A26 A30 A35 A83 A106 A128 A131 A173 A219 A223 A225 A254 A256 A316 A321 A328"
)
seqs_wt = "STLRLLISDSHDPWFNLAVEECIFRQMPATQRVLFLVRNADTVVIGRNQNPWKECNIRRMEEDNVRLARRSSGGGAVFHDLGNTCFTFMAGKPEYDKTISTSIVLNALNALGVSAEASGRNDLVVKTVEGDRKVSGSAYRETKDRGLHHGTLLLNADLSRLANYLNPDKKKLAAKGITSVRSRVTNLTELLPGITHEQVCEAITEAFFAHYGERVEAEIISPNETPDLPNFAETFARQSSWEWNFGQSPAFSHLLDERFTWGGVELRFDVEKGHITRAQVFTDSLNPAPLEALAGRLQGCLYRADELQQECEALLVDFPEQEKELRELSAWMAGAVR"
repeat = 4

In [None]:
revor = ReVor(
    mpnn, pdb_path, seqs_wt, repeat=repeat, redesigned_residues=redesigned_residues
)
revor.revert(
    "../results/alignments.fasta",
    cutoff=0.1,
    batch_size=8,
    max_step=2,
    n_samples=8,
    temperature=2,
    checkpoint_path="./checkpoint.pkl",
    save_checkpoint_interval=20,
)

In [None]:
revor.plot(nx.multipartite_layout, subset_key="iteration")
revor.plot(nx.multipartite_layout, subset_key="distance")
revor.plot(nx.multipartite_layout, subset_key="topology")

In [None]:
revor.get_results()

In [None]:
revor.save("../results/alignments")