# Make ESM embeddings of the entire reviewed UniProt

In [46]:
# Enable autoreload extension
%load_ext autoreload
# Automatically reload all modules (except built-ins) before executing code
%autoreload 2

import sys
import os
# Switch the kernel's working directory to the project checkout.
# Presumably this is what lets the `utils.*` imports below resolve — TODO confirm.
os.chdir("/home/gdallagl/myworkdir/ESMSec")
print(f"Working directory successfully changed to: {os.getcwd()}")

import utils.my_functions as mf
import utils.models as my_models
import utils.dataset as my_dataset
import utils.embeddings_functions as my_embs
import utils.scanning as my_scanning

import torch
import torch.nn as nn
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
import seaborn as sns

print(torch.__version__)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Working directory successfully changed to: /home/gdallagl/myworkdir/ESMSec
2.5.0+cu121


In [None]:
# Configuration / hyperparameters
# NOTE: every key must appear exactly once. A Python dict literal silently
# keeps only the LAST value of a repeated key, which makes it easy to believe
# a different input file is being used than the one actually read.
config = {
    "SEED": 42,  # Random seed shared by `random`, NumPy and PyTorch

    "PROTEIN_MAX_LENGTH": 1000,  # Max protein length (for ESM2)
    # ESM2 checkpoint name --> if not installed, it is downloaded automatically.
    # Smaller alternatives: "facebook/esm2_t12_35M_UR50D", "facebook/esm2_t6_8M_UR50D"
    "PRETRAIN_ESM_CHECKPOINT_NAME": "facebook/esm2_t33_650M_UR50D",
    "PRETRAIN_ESM_CACHE_DIR": "/home/gdallagl/myworkdir/data/esm2-models",  # ESM2 model cache dir
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",  # Device to use (cuda or cpu)
    "BATCH_SIZE": 128,

    # Input table with the proteins to embed (a .tsv file, read with sep="\t")
    "UNIPROT_PATH": "/home/gdallagl/myworkdir/ESMSec/data/UniProt/human_proteome.tsv",
    # Which embedding to derive from the contextualized ESM output. Values mentioned
    # so far: "cls", "concat(agg_mean, agg_max)", "contextualized_embs",
    # "stacked_linearized_all"
    "TYPE_EMB_FOR_CLASSIFICATION": "cls",
}

# Output file: one .pt per (checkpoint, embedding type) combination.
# The "/" in the checkpoint name is replaced so it is a valid file name.
config["PRECOMPUTED_EMBS_PATH"] = os.path.join(
    "/home/gdallagl/myworkdir/ESMSec/data/cell_cycle/precomputed_embs",
    f"entire_reviewed_uniprot_{config['PRETRAIN_ESM_CHECKPOINT_NAME'].replace('/', '-')}_"
    f"{config['TYPE_EMB_FOR_CLASSIFICATION']}.pt"
)
print(config["PRECOMPUTED_EMBS_PATH"])

# Initializations: seed every RNG in use
random.seed(config["SEED"])
np.random.seed(config["SEED"])
torch.manual_seed(config["SEED"])
# Let cuDNN benchmark and pick its convolution algorithms
torch.backends.cudnn.benchmark = True

/home/gdallagl/myworkdir/ESMSec/data/cell_cycle/precomputed_embs/entire_reviewed_uniprot_facebook-esm2_t6_8M_UR50D_cls.pt


# Instantiate ESM model

In [48]:
# Load the pre-trained ESM model from the configured cache dir and move it
# to the configured device
esm_model = AutoModel.from_pretrained(
    config["PRETRAIN_ESM_CHECKPOINT_NAME"],
    cache_dir=config["PRETRAIN_ESM_CACHE_DIR"],
).to(config["DEVICE"])
# Check which concrete model class AutoModel.from_pretrained() resolved to
print("\nESM model type", type(esm_model), "\n")

# Load the matching tokenizer. It uses the same cache dir as the model so that
# both artifacts of the checkpoint are stored (and found offline) in one place.
tokenizer = AutoTokenizer.from_pretrained(
    config["PRETRAIN_ESM_CHECKPOINT_NAME"],
    cache_dir=config["PRETRAIN_ESM_CACHE_DIR"],
)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



ESM model type <class 'transformers.models.esm.modeling_esm.EsmModel'> 



# Load dataset and Tokenize

In [53]:
# Read the table, keep only the reviewed entries, and reduce it to the two
# columns needed downstream under their new names.
data = (
    pd.read_csv(config["UNIPROT_PATH"], sep="\t")
    .loc[lambda table: table["Reviewed"] == "reviewed", ["Entry", "Sequence"]]
    .rename(columns={"Entry": "protein", "Sequence": "sequence"})
)

# Row count, reused below to size the placeholder tensors
num_samples = len(data)

# ATTENTION: sequences go through my_dataset.truncate_sequence first
# (presumably to respect the ESM max context — verify against the helper).
data["truncated_sequence"] = data["sequence"].apply(my_dataset.truncate_sequence)

# Tokenize the truncated sequences.
# ATTENTION: every row is padded (and, if needed, truncated) to exactly
# PROTEIN_MAX_LENGTH tokens, so the result is a single rectangular tensor.
print("Tokenizing...")
encoded = tokenizer(
    data["truncated_sequence"].tolist(),
    padding='max_length',
    max_length=config["PROTEIN_MAX_LENGTH"],
    truncation=True,
    return_tensors="pt"
)

# Keep the batched tensors for the cache dict ...
input_ids_tensor, attention_mask_tensor = encoded["input_ids"], encoded["attention_mask"]
# ... and also attach one row-tensor per protein to the dataframe
data["input_ids"] = list(encoded["input_ids"])
data["attention_mask"] = list(encoded["attention_mask"])

#####################

# Everything the dataloader/model needs. 'embedding', 'label' and 'set' are
# FAKE all-zero placeholders of shape (num_samples, 1); each gets its own tensor.
cache_data = {
    'protein': data["protein"].tolist(),
    'sequence': data["sequence"].tolist(),
    'truncated_sequence': data["truncated_sequence"].tolist(),
    'input_ids': input_ids_tensor,
    'attention_mask': attention_mask_tensor,
    **{
        fake_key: torch.zeros((num_samples, 1), dtype=torch.float32)
        for fake_key in ('embedding', 'label', 'set')
    },
}


Tokenizing...


# Create embeddings


In [54]:
# Inference only: switch the ESM model to eval mode
esm_model.eval()

# Make sure the output directory exists BEFORE the long embedding loop, so a
# missing folder fails fast instead of losing the whole run at torch.save().
output_dir = os.path.dirname(config["PRECOMPUTED_EMBS_PATH"])
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Create dataloader for batched processing; shuffle=False keeps the input row order
dataloader = my_dataset.create_dataloader(
    cache_data,
    batch_size=config.get("BATCH_SIZE", 32),
    shuffle=False
)

embeddings_list = []  # one CPU tensor per batch
protein_names = []    # protein identifiers, in the same order as the embeddings

for batch in tqdm(dataloader, desc="Processing protein batches"):

    batch_input_ids = batch["input_ids"].to(config["DEVICE"])
    batch_attention_mask = batch["attention_mask"].to(config["DEVICE"])
    # NOTE(review): cache_data stores identifiers under 'protein', while the batch
    # exposes them as "name" — presumably create_dataloader renames; confirm.
    batch_proteins = batch["name"]

    # No gradients are needed for embedding extraction
    with torch.no_grad():
        outputs_esm = esm_model(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask,
            return_dict=True
        )

        # Derive the configured embedding type from the contextualized embeddings.
        # NOTE(review): exclude_cls=True is passed even when the type is "cls" —
        # confirm the helper's semantics for that combination.
        batch_embeddings = my_models.get_embs_from_context_embs(
                                                    context_embs_esm = outputs_esm.last_hidden_state,
                                                    attention_mask = batch_attention_mask,
                                                    type_embs = config["TYPE_EMB_FOR_CLASSIFICATION"],
                                                    exclude_cls=True)

        # Store batch results (moved to CPU so GPU memory is released per batch)
        protein_names.extend(batch_proteins)
        embeddings_list.append(batch_embeddings.detach().cpu())

print("Saving embeddings to fast PyTorch cache...")
torch.save(
    {
        "protein": protein_names,
        # Concatenate the per-batch tensors along the batch dimension
        "embedding": torch.cat(embeddings_list, dim=0),
    },
    config["PRECOMPUTED_EMBS_PATH"])



Processing protein batches:   0%|          | 0/160 [00:00<?, ?it/s]

Processing protein batches:   9%|▉         | 14/160 [00:26<04:37,  1.90s/it]


KeyboardInterrupt: 

# How to Load Embeddings in Downstream Scripts

Save embeddings in PyTorch format.

Save a CSV with Protein | index in array.

In [None]:
print(f"Reading back from file: {config['PRECOMPUTED_EMBS_PATH']}...")

# Load the precomputed embeddings.
# NOTE: weights_only=False allows arbitrary pickled objects to be executed on
# load — only use it on files produced by this project, never on untrusted ones.
emb_dict_precomputed = torch.load(config["PRECOMPUTED_EMBS_PATH"], weights_only=False)
all_proteins = emb_dict_precomputed["protein"]  # List of protein names
all_embeddings = emb_dict_precomputed["embedding"]  # Tensor: (N_all, emb_dim)

print(f"✓ Loaded {len(all_proteins)} embeddings (dim: {all_embeddings.shape[1]})")

# Create a mapping for fast lookup: protein_name -> index
protein_to_idx = {protein: idx for idx, protein in enumerate(all_proteins)}

# Extract embeddings for proteins in cache_data in the correct order
subset_embeddings = []
missing_proteins = []

for p in cache_data["protein"]:
    idx = protein_to_idx.get(p)
    if idx is None:
        missing_proteins.append(p)
        print(f"⚠ Protein {p} not found in precomputed embeddings.")
    else:
        subset_embeddings.append(all_embeddings[idx])

# Stack into a single tensor
if subset_embeddings:
    subset_embeddings_tensor = torch.stack(subset_embeddings)  # Shape: (N_subset, emb_dim)

    # Only overwrite cache_data["embedding"] when EVERY protein was found.
    # With missing proteins the stacked tensor has fewer rows than
    # cache_data["protein"], so row i would no longer belong to protein i.
    if not missing_proteins:
        cache_data["embedding"] = subset_embeddings_tensor

    print(f"✓ Loaded {len(subset_embeddings)} embeddings for your dataset")
    print(f"  Shape: {subset_embeddings_tensor.shape}")

if missing_proteins:
    print(f"⚠ Total missing: {len(missing_proteins)} proteins")
    print("⚠ cache_data['embedding'] was NOT updated: rows would be misaligned with cache_data['protein'].")

Reading back from file: /home/gdallagl/myworkdir/ESMSec/data/cell_cycle/precomputed_embs/entire_reviewed_uniprot_facebook-esm2_t6_8M_UR50D_cls.pt...
✓ Loaded 1000 embeddings (dim: 320)
✓ Loaded 1000 embeddings for your dataset
  Shape: torch.Size([1000, 320])
