In [None]:
import json

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

In [None]:
# 1-based residue positions grouped by pocket (presumably numbered from the first
# residue of the mature HLA chain -- see the alignment offsets in the extraction cell).
pockets = [
    [5, 7, 59, 63, 66, 159, 163, 167, 171],
    [7, 9, 24, 34, 45, 63, 66, 67, 70, 99],
    [9, 70, 73, 74, 97],
    [99, 114, 155, 156, 159, 160],
    [97, 114, 147, 152, 156],
    [77, 80, 81, 84, 95, 123, 143, 146, 147],
]

# Additional 1-based positions included in the pseudo-sequence beyond the pockets.
pseudo_extra = [62, 69, 76, 116, 118, 150, 158]

# Flatten pockets + extras, deduplicate and sort (np.unique does both), then shift to 0-based.
pseudo_indices = np.unique(np.concatenate([*pockets, pseudo_extra])) - 1

# Load ESM-2 (esm2_t6_8M_UR50D) and its tokenizer alphabet from torch.hub.
# NOTE: ":main" tracks the upstream default branch, so a re-run may fetch different code;
# pin a tag or commit here if strict reproducibility matters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm2_t6_8M_UR50D")
batch_converter = alphabet.get_batch_converter()  # turns [(label, seq), ...] into token tensors
model = model.to(device)
model.eval();  # trailing ';' stops Jupyter from printing the entire module repr as cell output

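### Download HLA protein sequences from IPD-IMGT/HLA (https://www.ebi.ac.uk/ipd/imgt/hla/) and save them as `hla_seq_data.csv`, with the allele name in the first column and the amino-acid sequence in the second.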

In [None]:
batch_size = 8
REPR_LAYER = 6     # final layer of esm2_t6_8M; the layer whose token representations we keep
MIN_SEQ_LEN = 171  # highest 1-based position referenced by `pockets`


def _pseudo_offset(seq: str) -> int:
    """Return the shift to add to `pseudo_indices` so they index residues of `seq`.

    `pseudo_indices` are 0-based positions counted from the residue that precedes
    "SH" at the start of the mature chain (the G/C of a GSHS/CSHS start). Each rule
    estimates where that residue sits in `seq`; the rules are checked in order.
    """
    if seq.startswith('M'):
        # Presumably a full precursor with a 24-residue signal peptide in front.
        # NOTE(review): some loci may have a different leader length -- confirm.
        return 24
    if seq.startswith('SH'):
        return -1  # first mature residue absent
    if seq.startswith('HS'):
        return -2  # first two mature residues absent
    for motif in ('GSHS', 'CSHS', 'GCHS'):
        motif_pos = seq.find(motif)
        if motif_pos != -1:
            # Align on where the motif actually occurs rather than assuming position 0 (or 21).
            return motif_pos
    return -1


# ESM's batch converter expects a list of (label, sequence) tuples, not a DataFrame.
# Assumes column 0 is the allele name and column 1 the protein sequence -- TODO confirm CSV schema.
data = pd.read_csv("hla_seq_data.csv")
records = list(data.iloc[:, :2].itertuples(index=False, name=None))

sequence_representations = {}

# tqdm infers the (ceil-divided) batch count from the range's length.
for start_idx in tqdm(range(0, len(records), batch_size)):
    batch_data = records[start_idx:start_idx + batch_size]
    batch_labels, batch_strs, batch_tokens = batch_converter(batch_data)
    batch_tokens = batch_tokens.to(device)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

    with torch.no_grad():
        # Contact maps are never used, so don't ask the model to compute them.
        results = model(batch_tokens, repr_layers=[REPR_LAYER], return_contacts=False)
    token_representations = results["representations"][REPR_LAYER]

    for i, tokens_len in enumerate(batch_lens.tolist()):
        hla, seq = batch_labels[i], batch_strs[i]
        if len(seq) < MIN_SEQ_LEN:
            print(f'{hla} has less than {MIN_SEQ_LEN} residues')
            continue

        # Strip the first token and the last non-padding token (added by the batch converter).
        embeddings = token_representations[i, 1:tokens_len - 1]
        residue_idx = pseudo_indices + _pseudo_offset(seq)
        # Explicit bounds check instead of a bare except: e.g. short M-start sequences.
        if residue_idx.min() < 0 or residue_idx.max() >= embeddings.shape[0]:
            print(f'{hla} does not contain full residues')
            continue

        sequence_representations[hla] = embeddings[residue_idx, :].cpu().numpy().tolist()


In [None]:
# Write every allele's pseudo-sequence embedding into one on-disk array plus a JSON index.
allele_to_index = {}
index_to_allele = {}
sorted_alleles = sorted(sequence_representations.keys())  # sorted -> stable allele-to-row mapping

num_alleles = len(sequence_representations)
embedding_shape = (40, 320)  # (number of pseudo positions, ESM-2 t6_8M hidden size)
full_shape = (num_alleles, *embedding_shape)
store_dtype = np.dtype(np.float64)

if num_alleles == 0:
    # np.memmap cannot create a zero-sized file; fail with a clear message instead.
    raise ValueError("sequence_representations is empty; no embeddings to save")

mmap_array = np.memmap('hla_esm.mmap',
                       dtype=store_dtype,
                       mode='w+',
                       shape=full_shape)

for idx, allele in enumerate(sorted_alleles):
    embedding = np.asarray(sequence_representations[allele], dtype=store_dtype)
    # Validate explicitly: a broadcastable but wrong shape would otherwise be written silently.
    if embedding.shape != embedding_shape:
        raise ValueError(
            f"{allele}: expected embedding shape {embedding_shape}, got {embedding.shape}")
    allele_to_index[allele] = idx
    index_to_allele[idx] = allele
    mmap_array[idx] = embedding

mmap_array.flush()  # write to disk explicitly rather than relying on garbage collection
del mmap_array

index_data = {
    'hla_to_idx': allele_to_index,
    'shape': list(full_shape),
    'dtype': store_dtype.name,
}

with open('hla_esm_index.json', 'w') as f:
    json.dump(index_data, f, indent=2)

# Size uses the real element width (8 bytes for float64), not a hard-coded 4.
size_mb = num_alleles * embedding_shape[0] * embedding_shape[1] * store_dtype.itemsize / (1024**2)
print("Saved embeddings to memory-mapped file: hla_esm.mmap")
print(f"Shape: {full_shape}")
print(f"Size: {size_mb:.2f} MB")