In [None]:
import pandas as pd
import numpy as np
import torch
import esm

In [36]:
import gc
import os

In [21]:
data = np.load("data/cullpdb+profile_5926_filtered.npy.gz", allow_pickle=True)

In [22]:
N = data.shape[0]
L, D = 700, 57
arr = data.reshape(N, L, D)

In [23]:
print("arr shape:", arr.shape)

arr shape: (5365, 700, 57)


In [24]:
arr.shape

(5365, 700, 57)

In [25]:
arr[0][500]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 1.])

In [26]:

# Amino acid mapping (20 + X + padding)
AA_CODES = list("ACDEFGHIKLMNPQRSTVWY") + ['X'] + ['-']
# Q8 code order from dataset description
Q8_ORDER = ['L','B','E','G','I','H','S','T','-']
# Q8 → Q3 mapping: 0=Helix, 1=Strand, 2=Coil, 3=Padding
Q8_TO_Q3 = np.array([2, 1, 1, 0, 0, 0, 2, 2, 3])

def extract_features_labels(data):
    """
    Extract AA one-hot, N/C terminal flags, and Q3 labels from CullPDB array.
    Returns:
        features: list of np.arrays (n_residues, 24)  # 22 AA + 2 terminal flags
        labels_q3: list of np.arrays (n_residues,)    # Q3 labels (0,1,2)
        sequences: list of str                       # decoded AA sequences
    """
    N = data.shape[0]
    L, D = 700, 57
    arr = data.reshape(N, L, D)

    features = []
    labels_q3 = []
    sequences = []

    for prot in arr:
        # Blocks from dataset description
        aa_onehot = prot[:, 0:22]
        q8_onehot = prot[:, 22:31]
        nc_flags = prot[:, 31:33]  # N-term / C-term

        # Decode Q8 index and mask padding
        q8_idx = q8_onehot.argmax(axis=1)
        mask = q8_idx != 8  # not "NoSeq"

        # Q3 labels
        q3_idx = Q8_TO_Q3[q8_idx]
        q3_idx = q3_idx[mask]  # remove padding

        # Features = AA one-hot + N/C flags
        feat = np.concatenate([aa_onehot[mask], nc_flags[mask]], axis=1)

        # Decode AA sequence (optional, for inspection)
        aa_idx = aa_onehot.argmax(axis=1)
        seq = "".join(AA_CODES[i] for i in aa_idx[mask])

        features.append(feat)
        labels_q3.append(q3_idx)
        sequences.append(seq)

    return features, labels_q3, sequences


# Example usage
data = np.load("data/cullpdb+profile_5926_filtered.npy.gz", allow_pickle=True)
features, q3_labels, sequences = extract_features_labels(data)

print("Protein 0:")
print("Sequence:", sequences[0])
print("Features shape:", features[0].shape)  # (len_protein, 24)
print("Q3 labels:", q3_labels[0])



Protein 0:
Sequence: GEYPTWYGANPYFMSTHDMFDRDGWENTMENPIKXWHKAAVFFFYTNSNNWWHNGKWEDRMCENMYGKETEPQMWQXQARYYTMARESHAHQKFPHXAFWDWPMTEEGGAEDKRHRDNYWHQMMXTNWPFAERHMPFKQWWDNQWMTNAFEVRHMQPNGWMYAWKYWNQVIFDYMSSFHEIATWAFTRKEHHSIDPGWDNEDWWNHRTKRXKENMMTMKTKEDRFKEHRYTWSMRGADEFRCTWIRFRPNWWWRFRFKGWRKFDKNRMFFKESNAHYMEYTWNMTENNHPMATKETMWMSTSNWYFFKRMDWWSK
Features shape: (315, 24)
Q3 labels: [2 2 2 2 1 1 1 2 2 2 2 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 0 0 0 0 2 2 1 1 1 1
 1 1 1 1 1 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 2 0 0 0 0 2 2 2 1 1 2 2 2 2 2 1 2
 1 1 2 2 0 0 0 1 1 1 2 2 2 1 1 1 1 2 2 2 2 2 1 1 1 1 1 1 1 1 2 0 0 0 0 0 2
 0 0 0 0 2 2 2 1 1 1 1 1 1 1 1 1 1 1 2 2 2 1 1 2 2 2 1 1 2 2 2 2 2 2 2 2 2
 0 0 0 1 1 1 2 2 2 2 1 1 1 1 1 1 1 1 1 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 2
 2 1 1 1 1 1 1 1 2 2 2 2 0 0 0 2 1 1 1 1 1 1 1 2 2 2 2 1 1 1 1 1 1 1 1 1 2
 2 2 2 2 1 1 1 1 1 1 1 1 1 1 2 2 2 2 1 1 1 1 1 1 2 2 2 2 2 1 1 1 1 1 1 1 1
 1 1 1 1 1 2 2 2 0 0 0 2 2 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 1 1 1 1 1 1 1 1
 1 1 1 1 1 2 2 2 2 2 1 1 1

In [27]:
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
batch_converter = alphabet.get_batch_converter()

In [28]:
model.eval()
for p in model.parameters():
    p.requires_grad = False

In [34]:
AA_seqs_data = [(f"protein_{i}", seq) for i, seq in enumerate(sequences)]

batch_labels, batch_strs, batch_tokens = batch_converter(AA_seqs_data[0:100])

In [35]:
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[6])
token_embeddings = results["representations"][6]

In [82]:
os.makedirs("embeddings", exist_ok=True)

model.eval()
for p in model.parameters():
    p.requires_grad = False

BATCH_SIZE = 100       # how many to process in memory at once
MAX_SEQUENCES = 500    # total sequences to process
LAYER = 6              # ESM layer to extract

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

total_to_process = min(MAX_SEQUENCES, len(sequences))

for start in range(0, total_to_process, BATCH_SIZE):
    end = min(start + BATCH_SIZE, total_to_process)
    print(f"Processing sequences {start} to {end-1}")

    AA_seqs_data = [(f"protein_{i}", seq) for i, seq in enumerate(sequences[start:end], start=start)]
    batch_labels, batch_strs, batch_tokens = batch_converter(AA_seqs_data)
    batch_tokens = batch_tokens.to(device)

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[LAYER])
        token_embeddings = results["representations"][LAYER].cpu().numpy()  # (batch_size, max_length, emb_dim)

    # Save each protein embedding separately, removing padding and special tokens
    for i, (label, seq) in enumerate(zip(batch_labels, batch_strs)):
        seq_len = len(seq)
        protein_emb = token_embeddings[i, 1:seq_len+1]  # remove BOS/EOS tokens
        np.save(f"embeddings/{label}_emb.npy", protein_emb)

    del batch_tokens, results, token_embeddings
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print(f"✅ Processed and saved embeddings for {total_to_process} sequences.")

Processing sequences 0 to 99
Processing sequences 100 to 199
Processing sequences 200 to 299
Processing sequences 300 to 399
Processing sequences 400 to 499
✅ Processed and saved embeddings for 500 sequences.


In [103]:
def load_embeddings_to_array(folder="embeddings", max_count=None, sort=True):
    """
    Load per-protein .npy embeddings from a folder and combine into a list.
    Handles variable-length protein embeddings (no padding).

    Args:
        folder (str): Path to the folder containing .npy embedding files.
        max_count (int or None): Number of embeddings to load. 
                                 If None, load all files.
        sort (bool): Sort files by name before loading.

    Returns:
        names (list[str]): Protein names in the same order as list entries.
        embeddings (list[np.ndarray]): Each entry is (protein_length, embedding_dim)
    """
    files = [f for f in os.listdir(folder) if f.endswith(".npy")]
    if sort:
        files.sort()
    if max_count is not None:
        files = files[:max_count]

    embeddings = []
    names = []
    for fname in files:
        path = os.path.join(folder, fname)
        emb = np.load(path)  # shape: (protein_length, embedding_dim)
        embeddings.append(emb)
        names.append(os.path.splitext(fname)[0])  # unique name per protein

    return names, embeddings

In [None]:
names, emb_array = load_embeddings_to_array("embeddings", max_count=500)
print("Shape:", len(emb_array))
print("First 5 names:", names[:5])
print(emb_array[0].shape)

Shape: 500
First 5 names: ['protein_0_emb', 'protein_100_emb', 'protein_101_emb', 'protein_102_emb', 'protein_103_emb']
(315, 320)
[ 0.08165033  0.10039306  0.02018797 -0.00969464 -0.48676437  0.59172815
  0.11124381 -0.16277462 -0.4684964  -0.53604203 -0.3474227  -0.07909001
  0.3967625   0.00250078  0.37723446  0.11119861 -0.18185318  0.20834017
 -0.50457394 -0.20081164 -0.07767012  0.37427914  0.01468289  0.4572351
  0.29783446 -0.2635855  -0.22743875  0.1735735  -0.38215002 -0.35645536
  0.3352471  -0.30648166 -0.4908226   0.20672217  0.68076754  0.27712142
  0.34008697 -0.4964391  -0.24952833  0.14109582  0.5502073  -0.03907342
  0.24146566 -0.01808552  0.23523512  0.07904153 -0.14599124 -0.20839697
 -0.10136926  0.47693837 -0.12188553  0.17358887  0.0193473  -0.2559413
 -0.03461693 -0.45603454  0.14921339  0.40606982  0.08036682  0.13388783
  0.60761714 -0.76782936 -0.16673733 -0.42674634 -0.3337297  -0.07469623
 -0.27109152  0.32472694 -0.4889805  -0.03554168 -0.5993802  -0.0384

In [113]:
import re

def extract_index(name):
    match = re.search(r'protein_(\d+)_emb', name)
    return int(match.group(1)) if match else -1

sorted_pairs = sorted(zip(names, emb_array), key=lambda x: extract_index(x[0]))
names_sorted, emb_array_sorted = zip(*sorted_pairs)
names = list(names_sorted)
emb_array = list(emb_array_sorted)

In [116]:

protein_inputs = []
for i, name in enumerate(names):
    emb = emb_array[i]  # shape: (protein_length, embedding_dim)
    labels = q3_labels[i]  # shape: (protein_length,)
    # Ensure lengths match
    if emb.shape[0] == len(labels):
        protein_inputs.append((torch.tensor(emb, dtype=torch.float32), torch.tensor(labels, dtype=torch.long)))
    else:
        print(f"Warning: Length mismatch for {name}")

# Now protein_inputs is a list of (embedding_matrix, label_vector) tuples
print(protein_inputs)

[(tensor([[ 0.1302,  0.1632,  0.2888,  ...,  0.1957, -0.1322, -0.5956],
        [ 0.2530, -0.3770,  0.1877,  ...,  0.0174, -0.0739, -0.1022],
        [-0.1744, -0.0730,  0.2751,  ..., -0.1269, -0.0565, -0.3023],
        ...,
        [ 0.3078,  0.2568, -0.2026,  ..., -0.0545, -0.0945, -0.1158],
        [-0.2438, -0.1002, -0.4044,  ...,  0.1521,  0.2968, -0.3561],
        [ 0.1615,  0.3238,  0.1479,  ..., -0.3572, -0.0415, -0.3630]]), tensor([2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 0, 0, 0, 0, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,
        2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 0, 0, 0, 0, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2,
        1, 2, 1, 1, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 2,
        1, 1, 1, 1, 1, 1, 1, 1, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 2, 2, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2,
   