In [1]:
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
from esm.tokenization import EsmSequenceTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
client = ESMC.from_pretrained("esmc_300m").to("mps") # or "cpu"


Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 25653.24it/s]


In [4]:
protein = ESMProtein(sequence="AAGTAGCATCGACTCGACBBGTBRRRPPHPPA")
protein_tensor = client.encode(protein)
print(protein_tensor)
logits_output = client.logits(
   protein_tensor, LogitsConfig(sequence=True, return_embeddings=True, return_hidden_states=True)
)

ESMProteinTensor(sequence=tensor([ 0,  5,  5,  6, 11,  5,  6, 23,  5, 11, 23,  6,  5, 23, 11, 23,  6,  5,
        23, 25, 25,  6, 11, 25, 10, 10, 10, 14, 14, 21, 14, 14,  5,  2],
       device='mps:0'), structure=None, secondary_structure=None, sasa=None, function=None, residue_annotations=None, coordinates=None, potential_sequence_of_concern=False)


In [14]:
logits_output.hidden_states[29]

tensor([[[ 51.3490, -22.6709,  -1.0412,  ...,  60.8298,  24.8718, -27.4169],
         [ 58.7620,  -2.1394,  68.7211,  ..., 108.2104,  -3.0818,   8.6591],
         [104.7704, -35.2187,  13.8251,  ...,  56.3050,  -3.3278, -26.6943],
         ...,
         [ 39.5975, -61.2269, -35.7862,  ..., -14.3060,  82.8928,  31.1213],
         [ 79.9799,   8.7390, -48.7601,  ..., -29.2824,  -0.6710, -46.8744],
         [ 60.9359,  23.8559, -40.3378,  ..., -19.1161,  -5.2826, -53.2040]]],
       device='mps:0')

In [15]:
logits_output.embeddings

tensor([[[ 8.5242e-03, -4.4012e-03, -2.1846e-04,  ...,  1.0120e-02,
           4.2381e-03, -4.4994e-03],
         [ 2.2030e-02, -7.6427e-05,  2.5488e-02,  ...,  4.0026e-02,
          -4.2167e-04,  3.7494e-03],
         [ 3.8279e-02, -1.3659e-02,  5.6543e-03,  ...,  2.0940e-02,
          -4.7275e-04, -8.5971e-03],
         ...,
         [ 1.2608e-02, -1.9860e-02, -9.6806e-03,  ..., -3.3764e-03,
           2.6267e-02,  9.8276e-03],
         [ 2.5860e-02,  3.9596e-03, -1.4557e-02,  ..., -8.5287e-03,
           5.0059e-04, -1.3679e-02],
         [ 2.0643e-02,  9.7363e-03, -1.2523e-02,  ..., -5.6407e-03,
          -1.1276e-03, -1.6376e-02]]], device='mps:0')

In [14]:
logits_output.embeddings[0][33]
# Take mean across embeddings
logits_output.embeddings.shape

torch.Size([1, 34, 960])

In [19]:
logits_output.embeddings.mean(axis=0)  # (1, 1024)

tensor([[ 8.5242e-03, -4.4012e-03, -2.1846e-04,  ...,  1.0120e-02,
          4.2381e-03, -4.4994e-03],
        [ 2.2030e-02, -7.6427e-05,  2.5488e-02,  ...,  4.0026e-02,
         -4.2167e-04,  3.7494e-03],
        [ 3.8279e-02, -1.3659e-02,  5.6543e-03,  ...,  2.0940e-02,
         -4.7275e-04, -8.5971e-03],
        ...,
        [ 1.2608e-02, -1.9860e-02, -9.6806e-03,  ..., -3.3764e-03,
          2.6267e-02,  9.8276e-03],
        [ 2.5860e-02,  3.9596e-03, -1.4557e-02,  ..., -8.5287e-03,
          5.0059e-04, -1.3679e-02],
        [ 2.0643e-02,  9.7363e-03, -1.2523e-02,  ..., -5.6407e-03,
         -1.1276e-03, -1.6376e-02]], device='mps:0')

In [70]:
def embed_kmer(kmer):
		# Translate kmer to protein sequence
		# Embed using ESM-c
		
		protein_str = ESMProtein(sequence = str(kmer))
		
		protein_tensor = client.encode(protein_str)
		logits_output = client.logits(
   		protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
		)
		return logits_output.embeddings

In [55]:
embed_kmer("ACGHA").shape

torch.Size([1, 7, 960])

In [74]:
emd = [embed_kmer(kmer) for kmer in ["ACTSAAAHAC", "TTGACGTA", "CGTACGTA"]]

import torch

torch.hstack(emd).mean(axis=1)

tensor([[-1.2591e-02, -2.0148e-02,  1.5691e-03,  1.9135e-02,  1.8808e-02,
          3.4045e-02, -6.2626e-03, -1.4441e-02,  1.0434e-02,  9.9025e-03,
         -3.1191e-02, -1.6275e-02, -4.7973e-03,  2.0133e-02, -1.5195e-02,
         -8.9194e-03,  1.3775e-02, -1.3011e-02,  4.8211e-03, -7.3268e-03,
          4.1266e-02,  1.7345e-02,  2.1795e-03,  1.5547e-01, -4.0499e-03,
         -1.0740e-02, -1.0662e-02, -7.5316e-04, -1.0799e-02, -1.1512e-02,
          1.1699e-03,  1.9038e-02, -1.9944e-03,  1.2813e-02,  1.3455e-03,
         -2.3930e-02, -5.0376e-03,  8.4668e-03, -2.5288e-02,  1.1002e-02,
         -1.8229e-02, -2.8837e-03, -2.3210e-02,  4.6417e-03,  2.8611e-03,
          6.0433e-03, -6.0791e-03, -1.5512e-02,  5.1686e-03,  9.0001e-03,
         -2.6180e-02, -2.3974e-02,  2.1469e-04,  6.0970e-03, -1.6692e-02,
          1.6390e-02,  2.7780e-03,  5.0262e-03, -4.6557e-03,  1.9523e-02,
         -4.1174e-02, -3.5342e-04,  1.5403e-03, -2.8472e-02, -7.1807e-03,
          5.2228e-03, -1.7119e-02,  1.