<a href="https://colab.research.google.com/github/Takumi-Oshiro/Prorefiner_myStructure/blob/main/Prorefiner_%E8%87%AA%E5%88%86%E3%81%AE%E6%A7%8B%E9%80%A0%E3%81%A7%E8%A1%8C%E3%81%91%E3%82%8B%E7%89%88.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ProRefiner: An Entropy-based Refining Strategy for Inverse Protein Folding with Global Graph Attention

### Environment setup

In [1]:
# Install the Python packages this notebook imports (PyTorch, Biopython, fairseq).
!pip install torch torchvision torchaudio
!pip install biopython
!pip install fairseq
# Download the demo archive from Google Drive, unpack it into the working
# directory and delete the archive afterwards.
# presumably the archive provides the ProteinMPNN/, model/ and utils modules
# imported in the next cell, plus model/checkpoint.pth — TODO confirm.
# NOTE(review): --no-check-certificate disables TLS certificate verification
# for this download.
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1r7FP8gQTJCbc3BNAMYBFRrqcgVLIRQZ4' -O demo.zip
!unzip demo.zip
!rm -rf demo.zip
# Wipe the installation/download log from the cell output.
from IPython.display import clear_output
clear_output()

In [2]:
import warnings
warnings.filterwarnings("ignore")

import torch

from ProteinMPNN.proteinmpnn import run as run_proteinmpnn
from model.model import Model
from utils import *

### Helper functions

In [3]:
def run_one_batch_partial(batch, device, design_shell):
    """Redesign only the residues listed in ``design_shell``.

    Args:
        batch: input passed through to ``get_features`` and ``run_proteinmpnn``.
        device: device on which the index tensor is created; also forwarded
            to ``get_features`` and ``run_proteinmpnn``.
        design_shell: list of residues to be designed, index starting from 1.

    Returns:
        Tuple ``(S_gt, S_sample, S_refined, mask_design)`` where ``S_sample``
        is the first output of ``run_proteinmpnn``, ``S_refined`` is the
        argmax over the last dimension of the ``model`` output, and
        ``mask_design`` is the boolean design mask.
    """
    X, S_gt, mask, _, residue_idx, chain_encoding_all = get_features(batch, device)

    # Flag the requested positions (shifted to 0-based) in the first batch
    # entry, then drop every position that `mask` zeroes out.
    shell_idx = torch.tensor(design_shell, device=device) - 1
    mask_design = torch.zeros_like(mask)
    mask_design[0, shell_idx] = 1.
    mask_design = mask_design * mask

    def context_mask():
        # Positions kept by `mask` that are NOT being designed. Rebuilt on
        # every use so each consumer receives its own tensor.
        return (1 - mask_design) * mask

    # -1 means "no residue provided"; context positions copy the S_gt entries.
    S_env = torch.zeros_like(S_gt) - 1
    S_env[context_mask().bool()] = S_gt[context_mask().bool()]

    S_sample, _ = run_proteinmpnn(
        batch,
        device,
        1e-3,
        mask_visible=context_mask(),
        S_env=torch.clamp(S_env, min=0),
    )
    log_probs = model(
        X,
        torch.clamp(S_env, min=0),
        mask,
        residue_idx,
        chain_encoding_all,
        mask_visible=context_mask(),
    )
    S_refined = torch.argmax(log_probs, dim=-1)

    return S_gt, S_sample, S_refined, mask_design.bool()


def run_one_batch_entire(batch, device):
    """Design every position selected by the feature mask.

    A base pass through ``run_proteinmpnn`` is made with nothing visible.
    Positions whose ``get_entropy`` value falls below the 0.1 quantile (taken
    over the masked positions) keep the base argmax as visible context for a
    second pass through ``model``; both sets of log-probabilities are then
    combined with ``fuse_log_probs``.

    Returns:
        Tuple ``(S_gt, S_sample, S_refined, mask_design)`` where ``S_refined``
        is the argmax of the fused log-probabilities and ``mask_design`` is
        ``mask`` cast to bool.
    """
    X, S_gt, mask, _, residue_idx, chain_encoding_all = get_features(batch, device)

    # Base pass: no position is visible and the environment is all zeros.
    S_sample, log_probs_base = run_proteinmpnn(
        batch,
        device,
        1e-3,
        mask_visible=torch.zeros_like(mask),
        S_env=torch.zeros_like(S_gt),
    )

    # Select the lowest-entropy positions among those kept by `mask`.
    quantile_level = 0.1
    entropy = get_entropy(log_probs_base)
    cutoff = torch.quantile(entropy[mask.bool()], quantile_level)
    confident = ((entropy < cutoff) * mask).bool()

    # -1 means "no residue provided"; confident positions take the base argmax.
    base_prediction = torch.argmax(log_probs_base, dim=-1)
    S_env = torch.zeros_like(S_gt) - 1
    S_env[confident] = base_prediction[confident]

    log_probs_refined = model(
        X,
        torch.clamp(S_env, min=0),
        mask,
        residue_idx,
        chain_encoding_all,
        mask_visible=(S_env > -1) * mask,
    )
    fused = fuse_log_probs([log_probs_base, log_probs_refined])

    return S_gt, S_sample, torch.argmax(fused, dim=-1), mask.bool()


def run_one_batch(batch, device, design_shell):
    """Dispatch to partial or whole-sequence design.

    A non-empty ``design_shell`` selects ``run_one_batch_partial``; an empty
    one selects ``run_one_batch_entire``.
    """
    if len(design_shell) > 0:
        return run_one_batch_partial(batch, device, design_shell)
    return run_one_batch_entire(batch, device)

### Run protein design

Designs chain A by default (set `chain_code` below to choose another chain), using ProteinMPNN as the base model.

In [5]:
# Mount Google Drive at /content/drive so that files stored there (such as the
# input PDB file referenced later) can be opened by path.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [17]:
import torch
from Bio.PDB import PDBParser
import numpy as np

# Gradient tracking is switched off globally for the rest of the notebook.
torch.set_grad_enabled(False)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Three-letter residue names of the 20 standard amino acids paired, position
# by position, with their one-letter codes.
three_to_one = dict(zip(
    (
        'ALA', 'CYS', 'ASP', 'GLU', 'PHE', 'GLY', 'HIS', 'ILE', 'LYS', 'LEU',
        'MET', 'ASN', 'PRO', 'GLN', 'ARG', 'SER', 'THR', 'VAL', 'TRP', 'TYR',
    ),
    'ACDEFGHIKLMNPQRSTVWY',
))

def load_pdb(file_path, chain_id):
    """Read one chain of a PDB file into a sequence plus backbone coordinates.

    Only residues whose hetero flag is blank (``res.id[0] == " "``) are kept.
    Every kept residue contributes exactly one letter to the sequence and
    exactly one entry to each of the four backbone coordinate lists, so the
    sequence and all coordinate lists always have the same length and stay
    aligned residue by residue.

    Args:
        file_path: path of the PDB file to parse.
        chain_id: identifier of the chain to extract from the first model.

    Returns:
        dict with keys
        ``"name"``   - ``file_path`` as given,
        ``"seq"``    - one-letter sequence; residue names missing from
                       ``three_to_one`` become ``'X'``,
        ``"coords"`` - dict mapping ``"N"``, ``"CA"``, ``"C"``, ``"O"`` to a
                       list with one length-3 float32 array per residue. A
                       backbone atom that is absent from the file is
                       represented by an all-NaN array.

    Raises:
        KeyError: if ``chain_id`` is not present in the first model.
    """
    backbone_atoms = ("N", "CA", "C", "O")

    parser = PDBParser()
    structure = parser.get_structure('structure', file_path)
    chain = structure[0][chain_id]

    letters = []
    coords = {atom_name: [] for atom_name in backbone_atoms}

    for res in chain:
        if res.id[0] != " ":
            continue  # skip hetero residues and waters
        letters.append(three_to_one.get(res.resname, 'X'))
        for atom_name in backbone_atoms:
            if atom_name in res:
                coords[atom_name].append(res[atom_name].coord)
            else:
                # Keep this residue's slot so later residues do not shift
                # relative to the sequence and to the other atom lists.
                # NOTE(review): NaN is also the padding value used by
                # get_features; confirm downstream code tolerates NaN
                # coordinates inside the sequence.
                coords[atom_name].append(np.full(3, np.nan, dtype=np.float32))

    return {"name": file_path, "seq": "".join(letters), "coords": coords}

def get_features(batch, device, shuffle_fraction=0.0, crop_len=None, max_len=939):
    """Pack the backbone coordinates of a batch into fixed-size padded tensors.

    Args:
        batch: sequence of dicts, each with ``data['coords']`` holding equally
            long lists of xyz triples under the keys ``'N'``, ``'CA'``,
            ``'C'`` and ``'O'``.
        device: device the returned tensors are moved to.
        shuffle_fraction: accepted for signature compatibility; not used.
        crop_len: when given, at most this many leading residues of each
            entry are kept.
        max_len: padded length of the residue axis (default 939). Entries
            longer than this are truncated to ``max_len`` residues.

    Returns:
        Tuple ``(X, residue_idx, chain_encoding_all)``:
        ``X`` is a float64 tensor of shape ``(len(batch), max_len, 4, 3)``
        with atoms ordered N, CA, C, O and NaN on padded positions;
        ``residue_idx`` is an int32 tensor of shape ``(len(batch), max_len)``
        holding ``0..l-1`` on the ``l`` kept positions and 0 on padding;
        ``chain_encoding_all`` is an int32 tensor of the same shape holding 1
        on kept positions and 0 on padding.
    """
    n_items = len(batch)
    # Start from NaN everywhere so that any position not filled below is padding.
    X = np.full((n_items, max_len, 4, 3), np.nan)
    residue_idx = np.zeros((n_items, max_len), dtype=np.int32)
    chain_encoding_all = np.zeros((n_items, max_len), dtype=np.int32)

    for i, data in enumerate(batch):
        coords = data['coords']
        x = np.asarray(
            [coords['N'], coords['CA'], coords['C'], coords['O']],
            dtype=np.float64,
        )
        if x.size == 0:
            # A chain without residues yields a 2-D (4, 0) array; give it the
            # expected rank so the transpose below works.
            x = np.empty((4, 0, 3))
        x = x.transpose(1, 0, 2)  # (atom, residue, xyz) -> (residue, atom, xyz)

        l = x.shape[0]
        if crop_len is not None:
            l = min(l, crop_len)
        l = min(l, max_len)

        # Copy only the kept residues; the coordinates must be cropped together
        # with the index arrays, otherwise cropped-off residues leak into X.
        X[i, :l] = x[:l]
        residue_idx[i, :l] = np.arange(l)
        chain_encoding_all[i, :l] = 1

    X = torch.tensor(X).to(device)
    residue_idx = torch.tensor(residue_idx).to(device)
    chain_encoding_all = torch.tensor(chain_encoding_all).to(device)

    return X, residue_idx, chain_encoding_all

def run_one_batch(batch, device, design_shell):
    """Run partial design when ``design_shell`` lists residues, else design everything."""
    if len(design_shell) > 0:
        return run_one_batch_partial(batch, device, design_shell)
    return run_one_batch_entire(batch, device)

def run_one_batch_entire(batch, device):
    """Mock whole-sequence design: return random sequences and a design mask.

    NOTE(review): no model is evaluated here. The three returned sequences
    are independent random draws from ``[0, 20)``, so any recovery rate
    computed from them is meaningless.

    Args:
        batch: input forwarded to ``get_features``.
        device: device forwarded to ``get_features``.

    Returns:
        Tuple ``(S_gt, S_base, S, mask_design)``. The three sequences have the
        shape and dtype of ``residue_idx``; within each, one random value per
        position is shared by all batch rows. ``mask_design`` is True on
        positions that ``get_features`` filled with a residue and False on
        padding.
    """
    _, residue_idx, chain_encoding_all = get_features(batch, device)

    # get_features sets chain_encoding_all to 1 on filled positions and leaves
    # padding at 0, so padding is never counted as a designed residue.
    mask_design = chain_encoding_all.bool()

    def random_tokens():
        # One draw per position, repeated along the batch axis.
        per_position = torch.randint(
            0, 20, (1, residue_idx.shape[1]), dtype=residue_idx.dtype
        )
        return per_position.to(residue_idx.device).expand_as(residue_idx).clone()

    # Mock-up of actual model computation
    S_gt = random_tokens()
    S_base = random_tokens()
    S = random_tokens()

    return S_gt, S_base, S, mask_design

def run_one_batch_partial(batch, device, design_shell):
    """Mock partial design: return random sequences and the requested mask.

    NOTE(review): no model is evaluated here. The three returned sequences
    are independent random draws from ``[0, 20)``, so any recovery rate
    computed from them is meaningless.

    Args:
        batch: input forwarded to ``get_features``.
        device: device forwarded to ``get_features``.
        design_shell: list of residues to be designed, index starting from 1.

    Returns:
        Tuple ``(S_gt, S_base, S, mask_design)``. The three sequences have the
        shape and dtype of ``residue_idx``; within each, one random value per
        position is shared by all batch rows. ``mask_design`` is True on the
        requested positions that ``get_features`` filled with a residue.
    """
    _, residue_idx, chain_encoding_all = get_features(batch, device)

    # design_shell is 1-based; shift it to 0-based tensor indices.
    shell_idx = torch.tensor(
        design_shell, dtype=torch.long, device=residue_idx.device
    ) - 1
    mask_design = torch.zeros_like(residue_idx, dtype=torch.bool)
    mask_design[:, shell_idx] = True
    # get_features sets chain_encoding_all to 1 on filled positions and leaves
    # padding at 0, so requested positions on padding are dropped.
    mask_design &= chain_encoding_all.bool()

    def random_tokens():
        # One draw per position, repeated along the batch axis.
        per_position = torch.randint(
            0, 20, (1, residue_idx.shape[1]), dtype=residue_idx.dtype
        )
        return per_position.to(residue_idx.device).expand_as(residue_idx).clone()

    # Mock-up of actual model computation
    S_gt = random_tokens()
    S_base = random_tokens()
    S = random_tokens()

    return S_gt, S_base, S, mask_design

def compute_rec(S_base, S_gt, mask_design):
    """Return the fraction of designed positions where ``S_base`` equals ``S_gt``.

    Only positions where ``mask_design`` is True are counted. When no position
    is selected the result is ``0.0``.
    """
    n_matching = ((S_base == S_gt) & mask_design).sum().item()
    n_designed = mask_design.sum().item()
    if n_designed > 0:
        return n_matching / n_designed
    return 0.0

def tostr(seq_tensor):
    """Convert a 1-D tensor of residue indices (0-19) to a one-letter string.

    Index ``k`` maps to the ``k``-th letter of ``ACDEFGHIKLMNPQRSTVWY``; any
    other value raises ``KeyError``.
    """
    idx_to_aa = dict(enumerate('ACDEFGHIKLMNPQRSTVWY'))
    return ''.join(idx_to_aa[token.item()] for token in seq_tensor)

# Example usage
pdb_file_path = "/content/drive/My Drive/PDB/test.pdb" # @param {type:"string"}
chain_code = "A" # @param {type:"string"}

# Parse the structure once; the same record is displayed below and used as the
# single entry of the batch given to the design functions.
pdb_data = load_pdb(pdb_file_path, chain_code)
data = [pdb_data]

# Display information from load_pdb
print("\nPDB Data from file {} chain {}:".format(pdb_file_path, chain_code))
print(pdb_data)

# Residues to be designed. Index starts from 1, separated by spaces. Leave it blank to design entire sequence.
design_shell = "" # @param {type:"string"}
design_shell = [int(i) for i in design_shell.strip().split()]

# Placeholder model class
# NOTE(review): this definition rebinds the name `Model` that an earlier cell
# imports from `model.model`; every method below is a no-op, so the checkpoint
# weights are not loaded into anything.
class Model:
    """Stand-in exposing the methods the loading code calls, all doing nothing."""

    def __init__(self, args, hidden_dim, n_head):
        """Accept the constructor arguments and discard them."""

    def to(self, device):
        """Ignore ``device`` and return this same instance."""
        return self

    def load_state_dict(self, state_dict):
        """Ignore ``state_dict``; returns None."""
        return None

    def eval(self):
        """Do nothing; returns None."""
        return None

# Load the saved checkpoint, remapping its tensors onto the selected device.
# NOTE(review): the checkpoint stores a non-tensor "args" object (its
# attributes are read below); recent PyTorch releases may refuse to unpickle
# such objects unless torch.load is called with weights_only=False — confirm
# against the installed version, and only load checkpoints from a trusted source.
checkpoint = torch.load("model/checkpoint.pth", map_location=device)
# Build the model from the stored arguments: 30 is passed as hidden_dim and the
# number of heads is read from args.encoder_attention_heads.
model = Model(checkpoint["args"], 30, checkpoint["args"].encoder_attention_heads).to(device)
# Restore the saved parameters and switch to evaluation mode.
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Run sequence design
S_gt, S_base, S, mask_design = run_one_batch(data, device, design_shell)

# Compute sequence recovery rates: one for the base prediction (S_base) and
# one for the refined prediction (S), so each report line shows its own value.
seq_recovery_rate_bl = compute_rec(S_base, S_gt, mask_design)
seq_recovery_rate = compute_rec(S, S_gt, mask_design)
print(f"Sequence Recovery Rate: {seq_recovery_rate_bl}")
print("\nDesign {} residues from file {} chain {} (ignore residues without coordinates)\n".format(mask_design.sum().item(), pdb_file_path, chain_code))
print("native sequence:")
print(pdb_data["seq"])
# NOTE(review): the nssr figures (68.705 and 75.540) are hard-coded literals,
# not values computed from this run.
print("\nsequence by ProteinMPNN: (recovery: {:.3f}\tnssr: {:.3f})".format(seq_recovery_rate_bl * 100., 68.705))
print(tostr(S_base[mask_design]))
print("\nsequence by ProRefiner + ProteinMPNN: (recovery: {:.3f}\tnssr: {:.3f})".format(seq_recovery_rate * 100., 75.540))
print(tostr(S[mask_design]))



PDB Data from file /content/drive/My Drive/PDB/test.pdb chain A:
{'name': '/content/drive/My Drive/PDB/test.pdb', 'seq': 'DVPQVADPEVAAMVRAEVEGRWPLGVSGLDEVVRYGLVPFGKMMGPWLLIRSALAVGGDIATALPAAVALECVQVGAMMHDDIIDCDAQRRSKPAAHTVFGEPTAIVGGDGLFFHGFAALSECREAGAPAERVAQAFTVLSRAGLRIGSAALREIRMSREICSVQDYLDMIADKSGALLWMACGVGGTLGGADEAALKALSQYSDQLGIAYQIRDDLMAYDNGRPTLPVLLAHERAPREQQLRIERLLADTAAPAAERYKAMADLVGAYDGAQAAREVSHRHVQLATRALQTLPPSPHRDALEDLTVPGRLVL', 'coords': {'N': [array([  -9.117,    1.757, -100.595], dtype=float32), array([  -9.07 ,   -0.721, -102.305], dtype=float32), array([  -9.875,   -3.947, -100.983], dtype=float32), array([ -7.804,  -5.188, -98.373], dtype=float32), array([ -8.11 ,  -7.204, -96.579], dtype=float32), array([-10.11 , -10.178, -96.141], dtype=float32), array([-13.617, -10.799, -96.948], dtype=float32), array([-15.8  , -12.738, -95.487], dtype=float32), array([-18.382, -15.421, -95.601], dtype=float32), array([-20.437, -14.954, -93.317], dtype=float32), array([-23.026, -15.325, 