# Calculate embeddings from ESM

The embedding-extraction code below is adapted from the ESM repository's `scripts/extract.py`: https://github.com/facebookresearch/esm/blob/main/scripts/extract.py

## Prepare inputs

In [1]:
from datasets import Dataset, load_dataset
import pandas as pd
from tqdm import tqdm

In [2]:
# Fetch the knots_AF dataset from the Hugging Face Hub and keep a handle on
# each of the two splits used in the rest of the notebook.
dss = load_dataset('EvaKlimentova/knots_AF')
train, test = dss['train'], dss['test']

Using custom data configuration EvaKlimentova--knots_AF-265fee554925f78a
Found cached dataset parquet (/home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-265fee554925f78a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Per-split file names: the FASTA file written from the dataset (input to the
# embedding step) and the CSV file that receives the embeddings.
FASTA_TRAIN, OUTPUT_TRAIN = "Knots_AF_train.fasta", "Knots_AF_train_embeddings.csv"
FASTA_TEST, OUTPUT_TEST = "Knots_AF_test.fasta", "Knots_AF_test_embeddings.csv"

In [4]:
# Dump each split to a FASTA file. The header line is "<ID>,<label>" so that
# both fields travel with the sequence into the embedding step.
# The files are opened in `with` blocks so they are flushed and closed before
# anything reads them back; a handle left open can leave the tail of the
# data sitting in the write buffer.
with open(FASTA_TRAIN, 'w') as file_train:
    for row in train:
        file_train.write(f'>{row["ID"]},{row["label"]}\n{row["uniprotSequence"]}\n')

with open(FASTA_TEST, 'w') as file_test:
    for row in test:
        file_test.write(f'>{row["ID"]},{row["label"]}\n{row["uniprotSequence"]}\n')


## Compute embeddings

In [5]:
import torch
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
from pathlib import Path
import time

def compute_embeddings(inp, output, model_name="esm2_t36_3B_UR50D",
                       toks_per_batch=4096, truncation_seq_length=1022):
    """Embed every sequence of a FASTA file with ESM and write the result as CSV.

    Adapted from ``scripts/extract.py`` of the facebookresearch/esm repository.
    For each sequence, the token representations of the model's last layer are
    averaged over the (possibly truncated) sequence positions and written as
    one line of the form ``<label>,<v_0>,<v_1>,...``, where ``<label>`` is the
    label ``FastaBatchedDataset`` yields for that record (presumably the FASTA
    header text -- confirm against the esm documentation).

    Args:
        inp: Path of the FASTA file to read.
        output: Path of the CSV file to write; an existing file is overwritten.
        model_name: torch.hub entry point of the ESM model to load from
            ``facebookresearch/esm:main``.
        toks_per_batch: Token budget handed to ``get_batch_indices`` when
            grouping sequences into batches.
        truncation_seq_length: Maximum number of sequence positions kept per
            sequence; longer sequences are truncated by the batch converter.

    Returns:
        None. The embeddings are written to ``output``.
    """
    model, alphabet = torch.hub.load("facebookresearch/esm:main", model_name)
    model.eval()

    # Query the device once instead of on every batch.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()
        print("Transferred model to GPU")

    dataset = FastaBatchedDataset.from_file(inp)
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(truncation_seq_length),
        batch_sampler=batches,
    )
    print(f"Read {inp} with {len(dataset)} sequences")

    # Request only the last layer (-1), converted to a non-negative index.
    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in [-1])
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in [-1]]
    # Index the output by the computed layer rather than a literal, so the
    # function keeps working for models whose depth differs from the default.
    mean_layer = repr_layers[-1]

    # The context manager guarantees the CSV is flushed and closed even if a
    # batch fails part-way through the run.
    with open(output, 'w') as file, torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if use_cuda:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            token_reprs = out["representations"][mean_layer].to(device="cpu")

            for i, label in enumerate(labels):
                truncate_len = min(truncation_seq_length, len(strs[i]))
                # Skip position 0 (the prepended start token) and any padding
                # past the end of this sequence before averaging.
                mean_repr = token_reprs[i, 1: truncate_len + 1].mean(0)

                file.write(label + ',' + ','.join(str(e) for e in mean_repr.numpy()))
                file.write('\n')



In [None]:
# Embed the test split first, then the train split, and record wall-clock
# timestamps around the whole run.
start = time.time()

embedding_jobs = [(FASTA_TEST, OUTPUT_TEST), (FASTA_TRAIN, OUTPUT_TRAIN)]
for fasta_path, csv_path in embedding_jobs:
    compute_embeddings(fasta_path, csv_path)

end = time.time()

Using cache found in /home/jovyan/.cache/torch/hub/facebookresearch_esm_main


Transferred model to GPU
Read Knots_AF_test.fasta with 37401 sequences
Processing 1 of 3937 batches (48 sequences)
Processing 2 of 3937 batches (45 sequences)
Processing 3 of 3937 batches (43 sequences)
Processing 4 of 3937 batches (41 sequences)
Processing 5 of 3937 batches (40 sequences)
Processing 6 of 3937 batches (39 sequences)
Processing 7 of 3937 batches (37 sequences)
Processing 8 of 3937 batches (36 sequences)
Processing 9 of 3937 batches (35 sequences)
Processing 10 of 3937 batches (35 sequences)
Processing 11 of 3937 batches (34 sequences)
Processing 12 of 3937 batches (33 sequences)
Processing 13 of 3937 batches (32 sequences)
Processing 14 of 3937 batches (32 sequences)
Processing 15 of 3937 batches (31 sequences)
Processing 16 of 3937 batches (31 sequences)
Processing 17 of 3937 batches (30 sequences)
Processing 18 of 3937 batches (30 sequences)
Processing 19 of 3937 batches (30 sequences)
Processing 20 of 3937 batches (29 sequences)
Processing 21 of 3937 batches (29 sequ

In [5]:
# Report the wall-clock duration, in seconds, of the embedding run.
elapsed_seconds = end - start
print("Calculating embedding time: ", elapsed_seconds)

Calculating embedding time:  76031.86687970161


---------------------------------------