In [None]:
import torch
import esm
from Bio import SeqIO
import numpy as np

## ESM2 variant embeddings

In [None]:
import torch
import esm
from Bio import SeqIO
import os


# Define models and FASTA files

# ESM-2 checkpoint name -> layer whose representation is extracted (passed as
# repr_layers). Each value matches the "tNN" layer count in the checkpoint name,
# presumably the final layer of that model.
models = {
    "esm2_t6_8M_UR50D": 6,
    "esm2_t12_35M_UR50D": 12,
    "esm2_t30_150M_UR50D": 30,
    "esm2_t33_650M_UR50D": 33,
    "esm2_t36_3B_UR50D": 36,
}

# Input FASTA files. The casing matches the Ankh and ProtBert cells, so every model
# reads the same files on case-sensitive filesystems and the output names line up.
fasta_files = [
    "TF_training.txt",
    "NTF_training.txt",
    "TF_ind.txt",
    "NTF_ind.txt",
]

In [None]:
%%time

# Function to generate embeddings

def generate_embeddings(model_name, layer, fasta_file, batch_size=8):
    """Embed every sequence of a FASTA file with an ESM-2 model and save two dicts.

    Both dicts are keyed by FASTA record id. They are written to the current working
    directory, where ``base`` is the FASTA file name without directory or extension:

    * ``{model_name}_{base}_mean.pt``: mean of the layer-``layer`` representation
      over token positions 1 .. seq_len-2 (padding and the two end tokens excluded);
    * ``{model_name}_{base}_cls.pt``: the representation at token position 0.

    Both files are reloaded afterwards and their size and tensor shape are printed
    as a sanity check. Duplicate record ids keep the last occurrence.

    Args:
        model_name: checkpoint name passed to ``esm.pretrained.load_model_and_alphabet``.
        layer: representation layer to extract (``repr_layers=[layer]``).
        fasta_file: path to the input FASTA file.
        batch_size: number of sequences per forward pass (default 8).
    """
    print(f"\n🔹 Processing {fasta_file} with {model_name} ...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the model once and move it to the device once, not on every batch.
    model, alphabet = esm.pretrained.load_model_and_alphabet(model_name)
    batch_converter = alphabet.get_batch_converter()
    model = model.eval().to(device)

    sequences = [(record.id, str(record.seq)) for record in SeqIO.parse(fasta_file, "fasta")]
    # Both dicts must exist here. mean_embeddings was previously commented out while
    # still being filled and saved below, which raised NameError on a fresh kernel.
    mean_embeddings = {}
    cls_embeddings = {}

    with torch.no_grad():
        for start in range(0, len(sequences), batch_size):
            batch_seqs = sequences[start : start + batch_size]
            _, _, batch_tokens = batch_converter(batch_seqs)
            batch_tokens = batch_tokens.to(device)

            results = model(batch_tokens, repr_layers=[layer])
            token_embeddings = results["representations"][layer]

            # Number of non-padding tokens in each row of the batch.
            token_counts = (batch_tokens != alphabet.padding_idx).sum(dim=1).tolist()

            for j, (seq_id, _seq) in enumerate(batch_seqs):
                seq_len = token_counts[j]
                # .clone() gives a compact tensor. On CPU, .cpu() returns a view of the
                # whole batch tensor, which torch.save would store in full.
                cls_embeddings[seq_id] = token_embeddings[j, 0].cpu().clone()
                # Skip position 0 and the last non-padding position (assumed to be
                # ESM's BOS/EOS tokens) and average over the residues in between.
                mean_embeddings[seq_id] = token_embeddings[j, 1 : seq_len - 1].mean(0).cpu()

    base_name = os.path.splitext(os.path.basename(fasta_file))[0]
    mean_file = f"{model_name}_{base_name}_mean.pt"
    cls_file = f"{model_name}_{base_name}_cls.pt"

    torch.save(mean_embeddings, mean_file)
    torch.save(cls_embeddings, cls_file)

    # Release this model before the caller loads the next (possibly larger) one.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Verify saved files
    loaded_mean = torch.load(mean_file)
    loaded_cls = torch.load(cls_file)
    if not loaded_mean:
        # Previously an empty FASTA crashed here with IndexError on [0].
        print(f"⚠️ No sequences found in {fasta_file}; saved empty {mean_file} and {cls_file}")
        return
    first_id = next(iter(loaded_mean))

    print(f"✅ Saved {len(loaded_mean)} mean embeddings to {mean_file} (shape: {loaded_mean[first_id].shape})")
    print(f"✅ Saved {len(loaded_cls)} CLS embeddings to {cls_file} (shape: {loaded_cls[first_id].shape})")



# Run for all models and FASTAs
# Every (model, FASTA) pair is processed. generate_embeddings loads the model inside
# each call, so each model is instantiated once per FASTA file.
# NOTE: the loop variables (model_name, layer, fasta_file) remain bound in the notebook
# namespace afterwards; later cells rebind model_name.

for model_name, layer in models.items():
    for fasta_file in fasta_files:
        generate_embeddings(model_name, layer, fasta_file)


## Ankh-base model embeddings

In [None]:
# import torch
# from transformers import AutoTokenizer, AutoModel, BertTokenizer, BertModel
# from Bio import SeqIO

# 
# # Config
# 
# fasta_files = ["TF_training.txt", "NTF_training.txt", "TF_ind.txt", "NTF_ind.txt"]

# models = {
#     "protbert": {
#         "name": "Rostlab/prot_bert",
#         "tokenizer": lambda: BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False),
#         "model": lambda: BertModel.from_pretrained("Rostlab/prot_bert"),
#         "hidden_size": 1024,
#     },
#     "ankh_base": {
#         "name": "ankh-models/ankh-base",
#         "tokenizer": lambda: AutoTokenizer.from_pretrained("ElnaggarLab/ankh-base"),
#         "model": lambda: AutoModel.from_pretrained("ElnaggarLab/ankh-base"),
#         "hidden_size": 768,
#     }
# }

# 
# # Embedding extraction function
# 
# def extract_embeddings(model_key, fasta_file):
#     print(f"▶ Processing {model_key} on {fasta_file}")

#     tokenizer = models[model_key]["tokenizer"]()
#     model = models[model_key]["model"]()
#     model.eval()
#     if torch.cuda.is_available():
#         model = model.cuda()

#     sequences = [(record.id, str(record.seq)) for record in SeqIO.parse(fasta_file, "fasta")]

#     mean_embeddings = {}
#     #cls_embeddings = {}

#     with torch.no_grad():
#         for seq_id, seq in sequences:
#             seq_spaced = " ".join(list(seq))  # space-separate amino acids
#             inputs = tokenizer(seq_spaced, return_tensors="pt", add_special_tokens=True)
#             if torch.cuda.is_available():
#                 inputs = {k: v.cuda() for k, v in inputs.items()}

#             outputs = model(**inputs)
#             last_hidden = outputs.last_hidden_state  # (1, L, hidden_size)

#             # [CLS] embedding = first token
#             #cls_emb = last_hidden[:, 0, :].squeeze().cpu()
#             # Mean embedding (excluding CLS + SEP)
#             mean_emb = last_hidden[:, 1:-1, :].mean(1).squeeze().cpu()

#             #cls_embeddings[seq_id] = cls_emb
#             mean_embeddings[seq_id] = mean_emb

#     # Save
#     base_name = fasta_file.replace(".txt", "")
#     torch.save(mean_embeddings, f"{model_key}_{base_name}_mean.pt")
#     #torch.save(cls_embeddings, f"{model_key}_{base_name}_cls.pt")

#     print(f"✅ Saved: {model_key}_{base_name}_mean.pt")


# 
# # Run extraction
# 
# for model_key in models.keys():
#     for fasta_file in fasta_files:
#         extract_embeddings(model_key, fasta_file)


In [None]:
import torch
from transformers import AutoTokenizer, AutoModel
from Bio import SeqIO


# Config

# Inputs: one pair of embedding files is written per FASTA file.
fasta_files = ["TF_training.txt", "NTF_training.txt", "TF_ind.txt", "NTF_ind.txt"]

# Ankh-base checkpoint. Only its encoder stack is used, so no decoder inputs are needed.
# NOTE(review): trust_remote_code=True allows the hub repository to execute code
# locally; confirm it is actually required for this checkpoint.
model_name = "ElnaggarLab/ankh-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

encoder = model.encoder.eval()  # nn.Module.eval() returns the module itself
encoder = encoder.cuda() if torch.cuda.is_available() else encoder


# Embedding extraction function

def extract_embeddings(fasta_file):
    """Embed each sequence of a FASTA file with the Ankh-base encoder and save two dicts.

    Uses the module-level ``tokenizer`` and ``encoder``. For every FASTA record, the
    encoder's last hidden state is reduced in two ways and stored under ``record.id``:

    * ``ankh_base_{base}_mean.pt``: mean over all tokens the tokenizer did NOT flag
      as special (i.e. the residue tokens);
    * ``ankh_base_{base}_cls.pt``: the hidden state at token position 0.

    Here ``base`` is ``fasta_file`` with ``".txt"`` removed. Files are written to the
    current working directory.

    Args:
        fasta_file: path to a FASTA-formatted input file.
    """
    print(f"▶ Processing {fasta_file}")

    sequences = [(record.id, str(record.seq)) for record in SeqIO.parse(fasta_file, "fasta")]

    mean_embeddings = {}
    # Must be initialised: this dict is filled and saved below. It was previously
    # commented out, which raised NameError.
    cls_embeddings = {}

    with torch.no_grad():
        for seq_id, seq in sequences:
            # NOTE(review): Ankh's reference usage tokenizes list(seq) with
            # is_split_into_words=True. Confirm that this space-joined string yields
            # one token per residue with this tokenizer.
            seq_spaced = " ".join(seq)
            inputs = tokenizer(
                seq_spaced,
                return_tensors="pt",
                add_special_tokens=True,
                return_special_tokens_mask=True,
            )
            # 1 = token added by the tokenizer itself. Removed before the forward pass
            # because the encoder does not accept this key.
            is_special = inputs.pop("special_tokens_mask")[0].bool()
            if torch.cuda.is_available():
                inputs = {k: v.cuda() for k, v in inputs.items()}

            hidden = encoder(**inputs).last_hidden_state[0].cpu()  # (tokens, hidden_size)

            # NOTE(review): Ankh is T5-based, and its tokenizer is expected to append only
            # an EOS token (no leading [CLS]). Position 0 is then the first residue, not a
            # summary token; confirm before treating this as a CLS embedding.
            # .clone() keeps torch.save from storing the whole (tokens, hidden) buffer.
            cls_embeddings[seq_id] = hidden[0].clone()
            # Pool over residue tokens only. The old fixed 1:-1 slice assumed a leading
            # special token; if the tokenizer adds none, that slice dropped the first residue.
            mean_embeddings[seq_id] = hidden[~is_special].mean(0)

    # Save
    base_name = fasta_file.replace(".txt", "")
    mean_file = f"ankh_base_{base_name}_mean.pt"
    cls_file = f"ankh_base_{base_name}_cls.pt"
    torch.save(mean_embeddings, mean_file)
    torch.save(cls_embeddings, cls_file)

    print(f"✅ Saved: {mean_file}, {cls_file}")



# Run for all fasta files
# Writes ankh_base_<name>_mean.pt and ankh_base_<name>_cls.pt for every entry of fasta_files.
# NOTE: the loop variable fasta_file stays bound in the notebook namespace.

for fasta_file in fasta_files:
    extract_embeddings(fasta_file)


## ProtBert embeddings

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel
from Bio import SeqIO


# Config

# Inputs: one pair of embedding files is written per FASTA file.
fasta_files = ["TF_training.txt", "NTF_training.txt", "TF_ind.txt", "NTF_ind.txt"]

# ProtBert checkpoint. Its vocabulary is upper-case amino acids, so lower-casing is disabled.
model_name = "Rostlab/prot_bert"
tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
model = AutoModel.from_pretrained(model_name).eval()  # eval() returns the module itself
model = model.cuda() if torch.cuda.is_available() else model


# Embedding extraction function

def extract_embeddings(fasta_file):
    """Embed each sequence of a FASTA file with ProtBert and save two dicts.

    Uses the module-level ``tokenizer`` and ``model``. For every FASTA record, the last
    hidden state is reduced in two ways and stored under ``record.id``:

    * ``protbert_{base}_mean.pt``: mean over all tokens the tokenizer did NOT flag as
      special (for BERT this excludes [CLS] and [SEP], like the former 1:-1 slice);
    * ``protbert_{base}_cls.pt``: the hidden state at position 0 ([CLS]).

    Here ``base`` is ``fasta_file`` with ``".txt"`` removed. Files are written to the
    current working directory.

    Args:
        fasta_file: path to a FASTA-formatted input file.
    """
    import re

    print(f"▶ Processing {fasta_file}")

    sequences = [(record.id, str(record.seq)) for record in SeqIO.parse(fasta_file, "fasta")]

    mean_embeddings = {}
    cls_embeddings = {}

    with torch.no_grad():
        for seq_id, seq in sequences:
            # The ProtBert model card maps the rare amino acids U, Z, O and B to X
            # before tokenizing, because the model was trained that way.
            seq_clean = re.sub(r"[UZOB]", "X", seq)
            seq_spaced = " ".join(seq_clean)  # ProtBert expects space-separated AAs
            inputs = tokenizer(
                seq_spaced,
                return_tensors="pt",
                add_special_tokens=True,
                return_special_tokens_mask=True,
            )
            # 1 = [CLS]/[SEP]. Removed before the forward pass because the model does
            # not accept this key.
            is_special = inputs.pop("special_tokens_mask")[0].bool()
            if torch.cuda.is_available():
                inputs = {k: v.cuda() for k, v in inputs.items()}

            hidden = model(**inputs).last_hidden_state[0].cpu()  # (tokens, hidden_size)

            # [CLS] embedding = first token. .clone() keeps torch.save from storing the
            # whole (tokens, hidden) buffer.
            cls_embeddings[seq_id] = hidden[0].clone()
            # Mean over residue tokens only.
            mean_embeddings[seq_id] = hidden[~is_special].mean(0)

    # Save
    base_name = fasta_file.replace(".txt", "")
    mean_file = f"protbert_{base_name}_mean.pt"
    cls_file = f"protbert_{base_name}_cls.pt"
    torch.save(mean_embeddings, mean_file)
    torch.save(cls_embeddings, cls_file)

    print(f"✅ Saved: {mean_file}, {cls_file}")



# Run for all fasta files
# Writes protbert_<name>_mean.pt and protbert_<name>_cls.pt for every entry of fasta_files.
# NOTE: the loop variable fasta_file stays bound in the notebook namespace.

for fasta_file in fasta_files:
    extract_embeddings(fasta_file)