# Equifold Demo

>How to run inference and visualize model outputs

In [1]:
import gzip
import json
import os
import subprocess

import numpy as np
import pandas as pd
import py3Dmol
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from data.equifold_process_input import process_one
from models import NN
from openfold_light.residue_constants import restype_3to1
from refine import refine
from sequence_checks import number_sequences
from utils_data import collate_fn, x_to_pdb

### Refinement

Model outputs are passed through a refinement pipeline defined in `refine.py`. We'll use multiple predictions from different seeds to generate a final averaged prediction.

In [2]:
from utils import compute_prediction_error, to_atom37

# Number of seeded forward passes per input: the inference loop seeds with
# 0..n_seeds-1 and hands every resulting prediction to compute_prediction_error.
n_seeds = 3

### Sample data from an internal dataset

Sequences are preprocessed with helper functions from the original Equifold repo.

In [3]:
# Benchmark input table; later cells read its "uid", "seq" and "chain_id" columns.
input_csv = "benchmarking/equifold_int_nb_test_input_59_7_3.csv"
df = pd.read_csv(input_csv)

In [4]:
# Use the first `batch_size` rows of the table as the demo inputs.
batch_size = 5
sample = df.head(batch_size)

# Pull the columns needed for preprocessing out as plain Python lists.
uids, seqs_1, chain_ids = (
    sample[column].tolist() for column in ("uid", "seq", "chain_id")
)
# No second sequence is supplied for any entry, so pad with None per uid.
seqs_2 = [None for _ in uids]

# Display the selected rows as the cell output.
sample

Unnamed: 0,uid,seq,chain_id
0,XDBv2_exp-23219-IL13_hs-4b02_4b06,DVQLVESGGGVVQPGGSLRLSCAASGRTFSSYRMGWFRQAPGKERE...,B
1,RANKL-R13h5,EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYPMGWFRQAPGKGRE...,A
2,XDBv2_exp-23219-IL13_hs-4b02_4b06,DVQLVESGGGVVQPGGSLRLSCAASGFTFNNYAMKWVRQAPGKGLE...,K
3,XDBv2_exp-22647-Human4-1BB24-160_C121SN138DN14...,QVQLQESGGGLVQPGGSLRLSCAASGGLFSINTGGWYRQAPGKQRE...,A
4,XDBv2_exp-24200-40C01-Nb40C01,EVQLVESGGGLVQPGGSLSLSCAASGDTFGTKAMGWFRQAPGKGRE...,N


In [5]:
from data.equifold_process_input import process_one


In [6]:
# NOTE: If you face an error with ANARCI here, remove the value for differential_weight_fwr_cdr_FAPE in the config.ini file
# Each input handed to process_one is one (uid, seq_1, seq_2, chain_id) tuple.
dataset = list(map(process_one, zip(uids, seqs_1, seqs_2, chain_ids)))

In [7]:
# Serve the preprocessed entries one at a time, in their original order, so
# each loop iteration downstream corresponds to exactly one input row.
loader = DataLoader(
    dataset,
    batch_size=1,
    drop_last=False,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only machine is pointless and makes PyTorch emit a warning.
    pin_memory=torch.cuda.is_available(),
)

## Load saved model

In [8]:
# Locations of the saved model artefacts and of the benchmarking outputs.
config_path = "models/models_with_recycling/config.json"
model_path = "models/models_with_recycling/7_11_run_3_last.ckpt"
output_dir = "benchmarking"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The JSON config is passed to the model constructor as keyword arguments.
with open(config_path, "r") as config_file:
    config = json.load(config_file)
model = NN(**config)

# Read the checkpoint onto the CPU first so loading works without a GPU.
# NOTE(review): torch.load relies on pickle-based deserialization; only load
# checkpoints that come from a trusted source.
checkpoint = torch.load(model_path, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])

# Move the weights to the target device and switch to evaluation mode.
model = model.to(device).eval()

### Run inference

In [9]:
# Output files are named after the entry's uid. The input table can contain the
# same uid more than once (e.g. several chains of one complex), which would
# make later predictions silently overwrite earlier ones. Track how often each
# uid has been written so repeats get a numeric suffix; the first occurrence
# keeps its plain uid.
# NOTE(review): assumes data[0]["uid"] is the hashable uid string passed to
# process_one — confirm against the preprocessing code.
uid_counts = {}

with torch.no_grad():
    for data in tqdm(loader):
        data = data.to(device)
        # The loader is built with batch_size=1, so the batch holds one entry.
        entry = data[0]

        # Collect one atom37 prediction per seed for this entry.
        x_preds_atom37 = []
        for seed in range(n_seeds):
            # Seeding
            pl.seed_everything(seed)
            results_dict = model(
                data, compute_loss=False, return_struct=True, set_RT_to_ground_truth=False
            )

            # get pred
            # NOTE(review): [0] presumably selects the single batch entry and
            # [-1] the final predicted structure — confirm against the model.
            x_pred = results_dict["x_pred"][0][-1]
            atom37 = to_atom37(
                x_pred,
                entry["dst_resnum"],
                entry["dst_atom"],
                entry["dst_resname"],
            )[0]
            x_preds_atom37.append(torch.squeeze(atom37))

        # Get sequence names: take the first atom index of every distinct
        # residue number, in order of appearance, and map the three-letter
        # residue names to one-letter codes ("X" for anything unknown).
        _, first_atom_idx = np.unique(
            entry["dst_resnum"].cpu().numpy(), return_index=True
        )
        seq_full_name = entry["dst_resname"][np.sort(first_atom_idx)]
        seq_short_name = "".join(
            restype_3to1.get(resname, "X") for resname in seq_full_name
        )
        numbered_sequences = number_sequences({"H": seq_short_name}, allowed_species=None)
        # Only an "H" sequence is numbered; the "L" entry is left empty.
        numbered_sequences["L"] = []

        obj = compute_prediction_error(
            numbered_sequences,
            x_preds_atom37,
            refine=True,
        )

        # Disambiguate repeated uids so every entry keeps its own output files.
        uid = entry["uid"]
        uid_counts[uid] = uid_counts.get(uid, 0) + 1
        output_uid = uid if uid_counts[uid] == 1 else f"{uid}_{uid_counts[uid]}"
        obj.save_all(uid=output_uid, dirname="test_inference")

100%|██████████| 5/5 [00:15<00:00,  3.03s/it]


## Visualize resulting structures

In [10]:
filename = "test_inference/XDBv2_exp-23219-IL13_hs-4b02_4b06_final_model.pdb"
colour_by = "predicted_error"  # @param ["predicted_error", "chain", "rainbow"]
show_sidechains = False  # @param {type:"boolean"}
show_mainchains = False  # @param {type:"boolean"}

# Style shared by every atom-level overlay (sticks / spheres) drawn below.
atom_style = {"colorscheme": "WhiteCarbon", "radius": 0.3}

# Create the viewer and load the predicted structure from the PDB file.
view = py3Dmol.view()
with open(filename, "r") as f:
    mol = f.read()
view.addModel(mol, "pdb")
# Here we set the background color as white
view.setBackgroundColor("white")

if colour_by == "chain":
    # One flat cartoon colour per chain: H in purple, L in green.
    view.setStyle({"chain": "H"}, {"cartoon": {"color": "purple"}})
    view.setStyle({"chain": "L"}, {"cartoon": {"color": "green"}})
elif colour_by == "rainbow":
    view.setStyle({"cartoon": {"arrows": True, "color": "spectrum"}})
elif colour_by == "predicted_error":
    # Here we set visualization by b factor
    # NOTE(review): min (2) is larger than max (0), which presumably reverses
    # the gradient direction — confirm against the 3Dmol.js colorscheme docs.
    print("The error is calulated by comparing how much different models agree or disagree on the placement of each residue")
    view.setStyle({"cartoon": { "arrows": True, "colorscheme": {"prop": "b", "gradient": "roygb", "min": 2, "max": 0}}})

if show_sidechains:
    # GLY and PRO get dedicated rules; every other residue is drawn as sticks
    # for all atoms except C, O and N.
    sidechain_excluded_atoms = ["C", "O", "N"]
    view.addStyle(
        {
            "and": [
                {"resn": ["GLY", "PRO"], "invert": True},
                {"atom": sidechain_excluded_atoms, "invert": True},
            ]
        },
        {"stick": atom_style},
    )
    # GLY: mark the CA atom with a sphere.
    view.addStyle(
        {"and": [{"resn": "GLY"}, {"atom": "CA"}]},
        {"sphere": atom_style},
    )
    # PRO: sticks for every atom except C and O.
    view.addStyle(
        {"and": [{"resn": "PRO"}, {"atom": ["C", "O"], "invert": True}]},
        {"stick": atom_style},
    )

if show_mainchains:
    backbone_atoms = ["C", "O", "N", "CA"]
    view.addStyle({"atom": backbone_atoms}, {"stick": atom_style})

# Zoom once, after all styles are applied, then render the viewer.
view.zoomTo()
view.show()

The error is calulated by comparing how much different models agree or disagree on the placement of each residue
