In [1]:
"""
The purpose of this Jupyter notebook is to familiarize oneself with ESM
C (ESM Cambrian).
"""

'\nThe purpose of this Jupyter notebook is to familiarize oneself with ESM\nC (ESM Cambrian).\n'

In [2]:
import os

import torch
from esm.models.esmc import ESMC
from esm.sdk import batch_executor
from esm.sdk.api import ESMProtein, LogitsConfig

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
DEVICE = "mps"
model = ESMC.from_pretrained("esmc_600m").to(DEVICE).eval()

Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 81840.08it/s]


In [4]:
def embed_sequence(model, sequence):
    """Run an ESM C model on one protein sequence and return its embeddings.

    Args:
        model: Model exposing ``encode`` and ``logits`` (e.g. an ``ESMC``
            instance).
        sequence: Amino-acid sequence as a one-letter-code string.

    Returns:
        Whatever ``model.logits`` returns (a ``LogitsOutput`` for ESM C),
        with ``embeddings`` populated and sequence logits skipped.
    """
    # Only embeddings are wanted: skip the sequence-logits head.
    config = LogitsConfig(sequence=False, return_embeddings=True)
    # Inference only, so don't build an autograd graph.
    with torch.no_grad():
        encoded = model.encode(ESMProtein(sequence=sequence))
        return model.logits(encoded, config)

In [29]:
some_seqs = ["A" * 3] * 100

with batch_executor() as executor:
    outputs = executor.execute_batch(
        user_func=embed_sequence, model=model, sequence=some_seqs
    )

Processing  100%|████████████████████████| 100/100 [Elapsed: 00:10 | Remaining: 00:00] , Success=100 Fail=0 Retry=0


In [30]:
# The batch should yield exactly one result per submitted sequence.
assert len(outputs) == len(some_seqs), "batch returned a different number of results"
len(outputs)

100


In [31]:
# Rich display of the first result: a LogitsOutput whose `embeddings` field
# holds the per-token tensor; the logits tracks are None because the config
# disabled sequence logits.
outputs[0]

LogitsOutput(logits=ForwardTrackData(sequence=None, structure=None, secondary_structure=None, sasa=None, function=None), embeddings=tensor([[[ 0.0042, -0.0045,  0.0032,  ...,  0.0155,  0.0032, -0.0076],
         [ 0.0288, -0.0025,  0.0080,  ...,  0.0262,  0.0543,  0.0122],
         [ 0.0279, -0.0257, -0.0047,  ...,  0.0200,  0.0187,  0.0086],
         [ 0.0280, -0.0160, -0.0238,  ..., -0.0051,  0.0120, -0.0143],
         [-0.0065, -0.0131, -0.0015,  ...,  0.0060,  0.0086, -0.0147]]],
       device='mps:0'), residue_annotation_logits=None, hidden_states=None)


In [28]:
# Show the first few per-token embedding shapes next to their residue counts.
# Every shape observed in this notebook has two more positions than the
# sequence has residues (e.g. 3 -> 5, 10000 -> 10002, 1 -> 3) -- presumably
# BOS/EOS special tokens; TODO confirm against the ESM C docs. The assert
# turns that observation into a checked invariant.
# Slicing (instead of range(3) + indexing) also tolerates a batch of fewer
# than three sequences rather than raising IndexError.
for i, seq in enumerate(some_seqs[:3]):
    current_embedding = outputs[i].embeddings
    print(current_embedding.shape)
    print(f"Sequence length: {len(seq)}")
    print()
    assert current_embedding.shape[1] == len(seq) + 2, (
        f"sequence {i}: expected {len(seq) + 2} token positions, "
        f"got {current_embedding.shape[1]}"
    )

torch.Size([1, 5, 960])
Sequence length: 3

torch.Size([1, 8, 960])
Sequence length: 6

torch.Size([1, 4, 960])
Sequence length: 2



In [5]:
humongous_seq = "G" * 10000
humongous_output = embed_sequence(model, humongous_seq)

In [7]:
# Long inputs should follow the same "residues + 2 positions" pattern seen
# for short ones.
assert humongous_output.embeddings.shape[1] == len(humongous_seq) + 2
humongous_output.embeddings.shape

torch.Size([1, 10002, 1152])


In [8]:
tiny_seq = "A"
tiny_output = embed_sequence(model, tiny_seq)

In [9]:
# Even a one-residue input gets the two extra token positions.
assert tiny_output.embeddings.shape[1] == len(tiny_seq) + 2
tiny_output.embeddings.shape

torch.Size([1, 3, 1152])
