This notebook demonstrates how to use the [ESM-2](https://github.com/facebookresearch/esm) transformer protein language model to extract embeddings from protein sequences. The code below computes a mean-pooled embedding for each protein sequence in the training and test FASTA files and saves each one as an individual `.pt` file.

The resulting embedding can be loaded using the following code:
```
import torch

embedding = torch.load('[EntryID].pt')
embedding = embedding['mean_representations'][33].numpy()
```

Computing the embeddings and subsequently reading in the `.pt` files can take a while. The resulting numpy arrays can be found [here](https://www.kaggle.com/datasets/viktorfairuschin/cafa-5-ems-2-embeddings-numpy).

**Note** that the test FASTA file contains duplicate entries. For this reason, this notebook uses cleaned FASTA files, which can be found [here](https://www.kaggle.com/datasets/viktorfairuschin/cafa-5-fasta-files).

In [None]:
!pip install -q fair-esm

In [None]:
import pathlib
import torch

from esm import FastaBatchedDataset, pretrained

In [None]:
def extract_embeddings(model_name, fasta_file, output_dir, tokens_per_batch=4096, seq_length=1022, repr_layers=(33,), num_gpus=2):
    """Compute mean-pooled ESM embeddings for every sequence in a FASTA file.

    For each record, one file ``<output_dir>/<entry_id>.pt`` is written with
    ``torch.save``. It holds a dict with the keys ``"entry_id"`` and
    ``"mean_representations"``; the latter maps each requested layer index to
    the mean of that layer's per-token representations over the (possibly
    truncated) sequence.

    Args:
        model_name: Name passed to ``pretrained.load_model_and_alphabet``.
        fasta_file: Path of the FASTA file to embed.
        output_dir: Directory for the ``.pt`` files (``str`` or ``Path``);
            created if missing. Existing files with the same name are
            overwritten.
        tokens_per_batch: Token budget handed to
            ``FastaBatchedDataset.get_batch_indices``.
        seq_length: Truncation length handed to the batch converter; also
            caps the number of positions averaged per sequence.
        repr_layers: Layer indices whose representations are extracted.
        num_gpus: Number of GPUs to use through ``torch.nn.DataParallel``.
            Ignored (single device) when fewer GPUs are visible or when
            CUDA is unavailable.
    """
    # Hand the model a fresh list so the shared default is never exposed.
    repr_layers = list(repr_layers)
    output_dir = pathlib.Path(output_dir)

    model, alphabet = pretrained.load_model_and_alphabet(model_name)
    model.eval()

    use_cuda = torch.cuda.is_available()
    multi_gpu = use_cuda and num_gpus > 1 and torch.cuda.device_count() >= num_gpus
    if use_cuda:
        # Wrap model with DataParallel for multi-GPU usage
        if multi_gpu:
            print(f"Using {num_gpus} GPUs: {list(range(num_gpus))}")
            model = torch.nn.DataParallel(model, device_ids=list(range(num_gpus)))
        model = model.cuda()

    dataset = FastaBatchedDataset.from_file(fasta_file)
    batches = dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(seq_length),
        batch_sampler=batches,
        num_workers=0,  # Keep 0 to avoid issues with CUDA and multiprocessing
        pin_memory=use_cuda,  # Pinned memory only helps host-to-GPU copies
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            if batch_idx % 1000 == 0:
                print(f'Processing batch {batch_idx + 1} of {len(batches)}')

            if use_cuda:
                toks = toks.to(device="cuda", non_blocking=True)

            # NOTE(review): DataParallel can split a batch into at most as
            # many chunks as it has rows, yet replicates non-tensor keyword
            # arguments for every device. A batch with fewer rows than
            # devices (e.g. a single long sequence) would then leave a
            # replica without tokens, so run such batches on the primary
            # device directly.
            if multi_gpu and toks.size(0) < num_gpus:
                runner = model.module
            else:
                runner = model
            out = runner(toks, repr_layers=repr_layers, return_contacts=False)

            # Only the hidden representations are needed; the logits are
            # deliberately left on the device instead of being copied back.
            representations = {
                layer: t.to(device="cpu")
                for layer, t in out["representations"].items()
            }

            for i, label in enumerate(labels):
                # The entry ID is the first whitespace-separated header token.
                entry_id = label.split()[0]

                filename = output_dir / f"{entry_id}.pt"
                truncate_len = min(seq_length, len(strs[i]))
                result = {"entry_id": entry_id}
                # Average positions 1..truncate_len; position 0 is skipped
                # (presumably the prepended start token - confirm against
                # the alphabet). clone() detaches the small mean vector from
                # the batch tensor before it is serialised.
                result["mean_representations"] = {
                    layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                    for layer, t in representations.items()
                }
                torch.save(result, filename)

## Process train file

In [None]:
# Embed the training sequences; one .pt file per entry is written to
# the 'train_embeddings' directory (relative to the working directory).
model_name = 'esm2_t33_650M_UR50D'
fasta_file = pathlib.Path('/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta')
output_dir = pathlib.Path('train_embeddings')

extract_embeddings(
    model_name=model_name,
    fasta_file=fasta_file,
    output_dir=output_dir,
)

## Process test file

In [None]:
# Embed the test sequences with the same checkpoint; one .pt file per
# entry is written to the 'test_embeddings' directory.
model_name = 'esm2_t33_650M_UR50D'
fasta_file = pathlib.Path('/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta')
output_dir = pathlib.Path('test_embeddings')

extract_embeddings(
    model_name=model_name,
    fasta_file=fasta_file,
    output_dir=output_dir,
)