In [1]:
cd ..

/home/ubuntu/anindya/esmfold-lite


In [2]:
from esmfold.model import ESMFold
import torch
from pathlib import Path

# Checkpoint identifier and the public URL it is served from.
model_name = "esmfold_3B_v1"
url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"

# Fetch the checkpoint with a progress bar, mapping its tensors onto the CPU.
model_data = torch.hub.load_state_dict_from_url(
    url,
    map_location="cpu",
    progress=True,
)

In [3]:
# Pull the two pieces used below out of the checkpoint dictionary:
# the saved weights and the "model" section of the saved configuration.
model_state = model_data["model"]
cfg = model_data["cfg"]["model"]

In [4]:
# Build ESMFold from the configuration stored in the checkpoint. No checkpoint
# weights are applied here; that happens in the load_state_dict call that follows.
model = ESMFold(esmfold_config=cfg)

In [5]:
# strict=False tolerates parameters that are absent from the checkpoint instead of
# raising; the returned summary (shown as the cell output) lists them. The `esm.*`
# entries in that list are parameters this checkpoint file does not contain.
model.load_state_dict(model_state, strict=False)

_IncompatibleKeys(missing_keys=['esm.embed_tokens.weight', 'esm.layers.0.self_attn.k_proj.weight', 'esm.layers.0.self_attn.k_proj.bias', 'esm.layers.0.self_attn.v_proj.weight', 'esm.layers.0.self_attn.v_proj.bias', 'esm.layers.0.self_attn.q_proj.weight', 'esm.layers.0.self_attn.q_proj.bias', 'esm.layers.0.self_attn.out_proj.weight', 'esm.layers.0.self_attn.out_proj.bias', 'esm.layers.0.self_attn.rot_emb.inv_freq', 'esm.layers.0.self_attn_layer_norm.weight', 'esm.layers.0.self_attn_layer_norm.bias', 'esm.layers.0.fc1.weight', 'esm.layers.0.fc1.bias', 'esm.layers.0.fc2.weight', 'esm.layers.0.fc2.bias', 'esm.layers.0.final_layer_norm.weight', 'esm.layers.0.final_layer_norm.bias', 'esm.layers.1.self_attn.k_proj.weight', 'esm.layers.1.self_attn.k_proj.bias', 'esm.layers.1.self_attn.v_proj.weight', 'esm.layers.1.self_attn.v_proj.bias', 'esm.layers.1.self_attn.q_proj.weight', 'esm.layers.1.self_attn.q_proj.bias', 'esm.layers.1.self_attn.out_proj.weight', 'esm.layers.1.self_attn.out_proj.bias'

In [6]:
# Compare the parameter names the model expects with those the checkpoint provides.
expected_keys = set(model.state_dict().keys())
found_keys = set(model_state.keys())

# Keep only the absent keys outside the "esm." namespace. The result is sorted so it
# is reproducible: iterating a set of strings directly can yield a different order
# on every kernel start, which makes the displayed list unstable between runs.
missing_essential_keys = sorted(
    key for key in expected_keys - found_keys if not key.startswith("esm.")
)

In [7]:
# Display every expected non-"esm." parameter that the checkpoint did not provide.
# The weights were loaded with strict=False, so any name listed here was skipped
# without an error and that parameter still holds its constructor-initialised value.
# NOTE(review): the recorded output lists four `ipa.linear_*_points.linear.*` entries,
# possibly a parameter-naming mismatch with the checkpoint. Confirm (and remap if so)
# before trusting predictions.
missing_essential_keys

['trunk.structure_module.ipa.linear_kv_points.linear.weight',
 'trunk.structure_module.ipa.linear_q_points.linear.bias',
 'trunk.structure_module.ipa.linear_q_points.linear.weight',
 'trunk.structure_module.ipa.linear_kv_points.linear.bias']

In [None]:
# Prefer the GPU this notebook originally targeted (index 5). On hosts that expose
# fewer GPUs fall back to the default CUDA device, and to the CPU when CUDA is not
# available at all, instead of failing on an invalid device ordinal.
if torch.cuda.is_available():
    device = torch.device("cuda:5" if torch.cuda.device_count() > 5 else "cuda")
else:
    device = torch.device("cpu")

# Switch to eval mode and move the model; being the last expression, the returned
# module is displayed as the cell output.
model.eval().to(device)

In [None]:
# Input sequence for the structure prediction.
# For a multimer, the chains would be joined with ':' inside the same string.
sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"

# Run the prediction with autograd disabled.
with torch.no_grad():
    output = model.infer_pdb(sequence)

# Write the text returned by infer_pdb to disk.
with open("result.pdb", "w") as pdb_file:
    pdb_file.write(output)

In [None]:
!pip install biotite

In [None]:
import biotite.structure.io as bsio

# Read the saved prediction back, asking biotite to load the B-factor annotation too.
struct = bsio.load_structure("result.pdb", extra_fields=["b_factor"])

# The B-factor values carry the pLDDT, so their mean is the average confidence.
mean_plddt = struct.b_factor.mean()
print(mean_plddt)

In [None]:
import torch
import esm

# Load ESM-2 model
# NOTE: this rebinds `model`, which until this cell referred to the ESMFold instance
# created earlier in the notebook.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# Callable that turns (label, sequence) pairs into a token batch.
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
# Each entry is a (label, sequence) pair; the last two entries contain a <mask> token.
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
# The converter returns three values; judging by the names these are the labels,
# the sequence strings and the token batch.
batch_labels, batch_strs, batch_tokens = batch_converter(data)
# Number of non-padding tokens in each row of the batch.
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Forward pass without gradients (on CPU), requesting the layer-33 representations
# together with the contact predictions.
with torch.no_grad():
    results = model(batch_tokens, return_contacts=True, repr_layers=[33])
token_representations = results["representations"][33]

# Collapse each sequence to one vector by averaging over its tokens.
# Token 0 is always the beginning-of-sequence token, so residues start at index 1;
# the slice also stops one short of the row's last non-padding token.
sequence_representations = [
    token_representations[row, 1 : n_tokens - 1].mean(0)
    for row, n_tokens in enumerate(batch_lens)
]

# Look at the unsupervised self-attention map contact predictions
import matplotlib.pyplot as plt

for (_, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):
    # batch_lens counts the beginning- and end-of-sequence tokens as well, whereas the
    # contact map has one row/column per residue (those two special tokens removed).
    # Slicing with tokens_len would therefore include two padding rows/columns for
    # every sequence shorter than the longest one in the batch; use the residue
    # count instead, matching the `1 : tokens_len - 1` residue range used for the
    # averaged representations.
    n_residues = int(tokens_len) - 2
    plt.matshow(attention_contacts[:n_residues, :n_residues])
    plt.title(seq)
    plt.show()

In [None]:
import torch
from esmfold.model import ESMFold

In [None]:
import torch

# Load the pretrained ESMFold v1 model, then switch it to eval mode and move it to
# the default CUDA device.
# NOTE: `esm` is not imported in this cell; it relies on the `import esm` executed
# in an earlier cell.
model = esm.pretrained.esmfold_v1()
model = model.eval().cuda()

# Optionally, uncomment to set a chunk size for axial attention. This can help reduce memory.
# Lower sizes will have lower memory requirements at the cost of reduced speed.
# model.set_chunk_size(128)

# Input sequence for the structure prediction.
# For a multimer, the chains would be joined with ':' inside the same string.
sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"

# Run the prediction with autograd disabled.
with torch.no_grad():
    output = model.infer_pdb(sequence)

# Write the text returned by infer_pdb to disk.
with open("result.pdb", "w") as pdb_file:
    pdb_file.write(output)

import biotite.structure.io as bsio

# Read the saved prediction back, asking biotite to load the B-factor annotation too.
struct = bsio.load_structure("result.pdb", extra_fields=["b_factor"])

# The B-factor values carry the pLDDT, so their mean is the average confidence.
mean_plddt = struct.b_factor.mean()
print(mean_plddt)
# 88.3