In [1]:
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
import torch
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import pickle
import gc

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Get the absolute path of the current notebook
notebook_path = os.path.abspath('')

# Navigate to the project root (CS182-Final-Project)
project_root = os.path.dirname(notebook_path)
os.chdir(project_root)

# Now you can load the data using relative path from project root
print("Loading full dataset...")
full_data = pd.read_pickle('./data/benchmarkingGS_v1-0_similarityMeasure_sequence_v3-1.pkl')
display(full_data.head())


# Keep only the required columns
columns_to_keep = ['uniprotID_A', 'uniprotID_B', 'isInteraction', 'trainTest', 'sequence_A', 'sequence_B']
data = full_data[columns_to_keep]
display(data.head())

# Calculate dataset statistics
print("\n--- Full Dataset Statistics ---")

Loading full dataset...


Unnamed: 0,uniprotID_A,uniprotID_B,isInteraction,trainTest,RNAseqHPA,tissueHPA,tissueCellHPA,subcellularLocationHPA,bioProcessUniprot,cellCompUniprot,molFuncUniprot,domainUniprot,motifUniprot,Bgee,sequence_A,sequence_B
0,P28223,P41595,1,test2,0.160188,-0.44993,-0.060381,,0.400892,0.404061,0.680414,0.0,0.790569,0.422078,MDILCEENTSLSSTTNSLMQLNDDTRLYSNDFNSGEANTSDAFNWT...,MALSYRVSELQSTIPEHILQSTFVHVISSNWSGLQTESIPEEMKQI...
1,O00161,P56962,1,train,0.825131,0.85169,0.67588,0.0,0.190693,0.2,0.353553,0.0,0.0,0.922975,MDNLSSEEIQQRAHQITDESLESTRRILGLAIESQDAGIKTITMLD...,MSEDEEKVKLRRLEPAIQKFIKIVIPTDLERLRKHQINIEKYQRCR...
2,P82979,Q01081,1,train,0.93079,0.954869,0.911887,0.0,0.547723,0.365148,0.408248,0.0,0.0,,MATETVELHKLKLAELKQECLARGLETKGIKQDLIHRLQAYLEEHA...,MAEYLASIFGTEKDKVNCSFYFKIGACRHGDRCSRLHNKPTFSQTI...
3,O60678,Q14524,1,train,0.219384,,,,0.0,0.0,0.0,0.0,0.0,,MCSLASGATGGRGAVENEEDLPELSDSGDEAAWEDEDDADLPHGKQ...,MANFLLPRGTSSFRRFTRESLAAIEKRMAEKQARGSTTLQESREGL...
4,P10275,Q15648,1,train,0.500558,-0.311704,-0.371061,0.0,0.169811,0.375,0.157895,0.0,0.0,,MEVQLGLGRVYPRPPSKTYRGAFQNLFQSVREVIQNPGPRHPEAAS...,MKAQGETEESEKLSKMSSLLERLHAKFNQNRPWSETIKLVRQVMEK...


Unnamed: 0,uniprotID_A,uniprotID_B,isInteraction,trainTest,sequence_A,sequence_B
0,P28223,P41595,1,test2,MDILCEENTSLSSTTNSLMQLNDDTRLYSNDFNSGEANTSDAFNWT...,MALSYRVSELQSTIPEHILQSTFVHVISSNWSGLQTESIPEEMKQI...
1,O00161,P56962,1,train,MDNLSSEEIQQRAHQITDESLESTRRILGLAIESQDAGIKTITMLD...,MSEDEEKVKLRRLEPAIQKFIKIVIPTDLERLRKHQINIEKYQRCR...
2,P82979,Q01081,1,train,MATETVELHKLKLAELKQECLARGLETKGIKQDLIHRLQAYLEEHA...,MAEYLASIFGTEKDKVNCSFYFKIGACRHGDRCSRLHNKPTFSQTI...
3,O60678,Q14524,1,train,MCSLASGATGGRGAVENEEDLPELSDSGDEAAWEDEDDADLPHGKQ...,MANFLLPRGTSSFRRFTRESLAAIEKRMAEKQARGSTTLQESREGL...
4,P10275,Q15648,1,train,MEVQLGLGRVYPRPPSKTYRGAFQNLFQSVREVIQNPGPRHPEAAS...,MKAQGETEESEKLSKMSSLLERLHAKFNQNRPWSETIKLVRQVMEK...



--- Full Dataset Statistics ---


In [3]:
def get_unique_proteins(df):
    unique_proteins = {}
    
    for _, row in df.iterrows():
        protein_A_id = row['uniprotID_A']
        protein_A_seq = row['sequence_A']
        if protein_A_id not in unique_proteins:
            unique_proteins[protein_A_id] = protein_A_seq
            
        protein_B_id = row['uniprotID_B']
        protein_B_seq = row['sequence_B']
        if protein_B_id not in unique_proteins:
            unique_proteins[protein_B_id] = protein_B_seq
            
    return unique_proteins


unique_proteins_dict = get_unique_proteins(data)

# Print how many unique proteins were found
print(f"Found {len(unique_proteins_dict)} unique proteins")

# Save to pickle file
os.makedirs('data/full_dataset', exist_ok=True)
with open('data/full_dataset/unique_proteins.pkl', 'wb') as f:
    pickle.dump(unique_proteins_dict, f)

print("Unique proteins saved to 'data/full_dataset/unique_proteins.pkl'")



Found 12026 unique proteins
Unique proteins saved to 'data/full_dataset/unique_proteins.pkl'


## Encode proteins with ESM C

In [4]:
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

device = "cuda" if torch.cuda.is_available() else "cpu"

if torch.cuda.is_available():
    target_gpu_index_for_fraction = 0

    print(f"Attempting to set memory fraction to 0.7 for GPU {target_gpu_index_for_fraction}")
    torch.cuda.set_per_process_memory_fraction(0.7, device=target_gpu_index_for_fraction)
    print(f"Call to set_per_process_memory_fraction for GPU {target_gpu_index_for_fraction} completed.")

model = ESMC.from_pretrained("esmc_300m").to(device)
print(f"[ESM-C] Loaded locally on {device}")
model_type = "local"

Attempting to set memory fraction to 0.7 for GPU 0
Call to set_per_process_memory_fraction for GPU 0 completed.


Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 101067.57it/s]


[ESM-C] Loaded locally on cuda


In [5]:
def get_protein_embedding(sequence):
    """
    Get protein embedding for a given sequence using the loaded ESM model.
    Optimized with torch.no_grad() for inference.
    """
    with torch.no_grad():  # Ensures no gradients are computed for model operations
        protein = ESMProtein(sequence=sequence)
        protein_tensor = model.encode(protein)  # Model inference step
        logits_output = model.logits(           # Model inference step
            protein_tensor,
            LogitsConfig(sequence=False, return_embeddings=True)
        )
        # Get the per-protein representation by mean-pooling across sequence length
        embedding = logits_output.embeddings
        # Remove the batch dimension (first dimension)
        embedding = embedding.squeeze(0)
    return embedding.cpu().to(dtype=torch.float16)


In [6]:
# Cell 2: Efficient protein embedding function
def embed_proteins_efficiently(proteins_dict, output_path='data/full_dataset/embeddings', batch_size=100):
    """
    Efficiently embeds proteins and saves with minimal storage requirements.
    
    Args:
        proteins_dict: Dictionary of protein IDs and sequences
        output_path: Directory to save embeddings
        batch_size: Number of proteins to process in each batch
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)
    
    # Track proteins processed
    total_proteins = len(proteins_dict)
    proteins_processed = 0
    
    # Process proteins in batches to manage memory
    protein_ids = list(proteins_dict.keys())
    
    # Set up storage stats tracking
    total_bytes = 0
    
    # Process in batches
    for i in range(0, total_proteins, batch_size):
        batch_ids = protein_ids[i:i+batch_size]
        batch_embeddings = {}
        
        # Process each protein in the batch
        for protein_id in tqdm(batch_ids, desc=f"Processing batch {i//batch_size + 1}/{(total_proteins-1)//batch_size + 1}"):
            sequence = proteins_dict[protein_id]
            
            # Get embedding using your function that returns float16 tensors
            embedding = get_protein_embedding(sequence)
            
            # Store in batch dictionary
            batch_embeddings[protein_id] = embedding
            proteins_processed += 1
        
        # Save this batch with compression
        batch_filename = os.path.join(output_path, f"embeddings_batch_{i//batch_size + 1}.pt")
        torch.save(batch_embeddings, batch_filename, _use_new_zipfile_serialization=True)
        
        # Get file size and add to total
        batch_size_bytes = os.path.getsize(batch_filename)
        total_bytes += batch_size_bytes
        
        print(f"Saved batch {i//batch_size + 1} ({len(batch_ids)} proteins): {batch_size_bytes/1024/1024:.2f} MB")
        
        # Free up memory
        del batch_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
    # Save a manifest file with protein IDs and their batch locations
    manifest = {protein_id: f"embeddings_batch_{i//batch_size + 1}.pt" 
                for i, protein_id in enumerate(protein_ids)}
    
    with open(os.path.join(output_path, "embeddings_manifest.pkl"), "wb") as f:
        pickle.dump(manifest, f)
    
    print(f"\nEmbedding complete: {proteins_processed} proteins processed")
    print(f"Total storage used: {total_bytes/1024/1024/1024:.2f} GB")
    print(f"Average size per protein: {total_bytes/proteins_processed/1024:.2f} KB")

In [7]:
# Cell 3: Loading function
def load_protein_embedding(protein_id, embeddings_dir='data/full_dataset/embeddings'):
    """
    Efficiently load a single protein embedding when needed.
    """
    manifest_path = os.path.join(embeddings_dir, "embeddings_manifest.pkl")
    with open(manifest_path, "rb") as f:
        manifest = pickle.load(f)
    
    if protein_id not in manifest:
        raise KeyError(f"Protein ID {protein_id} not found in embeddings manifest")
    
    batch_file = manifest[protein_id]
    batch_path = os.path.join(embeddings_dir, batch_file)
    
    # Load only the specific batch file
    batch_data = torch.load(batch_path)
    
    # Return the specific embedding
    return batch_data[protein_id]

In [16]:
# Cell 4: Low memory implementation
def embed_proteins_low_memory(proteins_dict, output_path='data/full_dataset/embeddings', batch_size=10, 
                             start_batch=0, max_proteins_per_run=None):
    """
    Low-memory implementation to embed proteins and save them efficiently.
    Can resume from a specific batch and limits proteins processed in a single run.
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)
    
    # Track proteins processed
    total_proteins = len(proteins_dict)
    proteins_processed = 0
    
    # Process proteins in batches to manage memory
    protein_ids = list(proteins_dict.keys())
    
    # Load existing manifest if starting from a later batch
    manifest_path = os.path.join(output_path, "embeddings_manifest.pkl")
    if start_batch > 0 and os.path.exists(manifest_path):
        with open(manifest_path, "rb") as f:
            manifest = pickle.load(f)
    else:
        manifest = {}
    
    # Set up storage stats tracking
    total_bytes = 0
    
    # Apply limit if specified
    if max_proteins_per_run is not None:
        end_batch = min(start_batch + (max_proteins_per_run // batch_size) + 1, 
                       (total_proteins // batch_size) + 1)
    else:
        end_batch = (total_proteins // batch_size) + 1
    
    print(f"Will process from batch {start_batch+1} to {end_batch} (of {(total_proteins-1)//batch_size + 1} total)")
    
    # Process in batches
    total_batches = (total_proteins - 1) // batch_size + 1
    for batch_idx in range(start_batch, end_batch):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, total_proteins)
        batch_ids = protein_ids[start_idx:end_idx]
        
        if not batch_ids:
            continue
            
        batch_filename = os.path.join(output_path, f"embeddings_batch_{batch_idx + 1}.pt")
        
        # Skip if batch already exists (for resuming)
        if os.path.exists(batch_filename):
            print(f"Batch {batch_idx + 1} already exists, skipping...")
            # Update manifest entries for this batch
            for i, protein_id in enumerate(batch_ids):
                manifest[protein_id] = f"embeddings_batch_{batch_idx + 1}.pt"
            continue
            
        batch_embeddings = {}
        
        # Process each protein in the batch individually
        for protein_id in tqdm(batch_ids, desc=f"Processing batch {batch_idx + 1}/{total_batches}"):
            sequence = proteins_dict[protein_id]
        
        # Save this batch if we processed any proteins
        if batch_embeddings:
            torch.save(batch_embeddings, batch_filename, _use_new_zipfile_serialization=True)
            
            # Update manifest
            for pid in batch_embeddings.keys():
                manifest[pid] = f"embeddings_batch_{batch_idx + 1}.pt"
                
            # Get file size and add to total
            batch_size_bytes = os.path.getsize(batch_filename)
            total_bytes += batch_size_bytes
            
            print(f"Saved batch {batch_idx + 1} ({len(batch_embeddings)} proteins): {batch_size_bytes/1024/1024:.2f} MB")
        
        # Free up memory
        del batch_embeddings
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
        # Save manifest after each batch for safety
        with open(manifest_path, "wb") as f:
            pickle.dump(manifest, f)
    
    # Save final manifest file
    with open(manifest_path, "wb") as f:
        pickle.dump(manifest, f)
    
    print(f"\nEmbedding complete: {proteins_processed} proteins processed in this run")
    print(f"Total storage used in this run: {total_bytes/1024/1024/1024:.2f} GB")
    if proteins_processed > 0:
        print(f"Average size per protein: {total_bytes/proteins_processed/1024:.2f} KB")
    print(f"Total proteins in manifest: {len(manifest)}")

In [15]:
# Cell 5: Run embedding process
# Load unique proteins from the full dataset
with open('data/full_dataset/unique_proteins.pkl', 'rb') as f:
    unique_proteins = pickle.load(f)

# Set output directory
output_dir = 'data/full_dataset/embeddings'

# Run the embedding process
embed_proteins_low_memory(
    unique_proteins, 
    output_dir, 
    batch_size=100,
    start_batch=110,
    max_proteins_per_run=1000
)

Will process from batch 111 to 121 (of 121 total)


Processing batch 111/121:  10%|█         | 10/100 [00:02<00:19,  4.70it/s]

Truncating sequence for Q96MT7 from 1854 to 1500


Processing batch 111/121:  31%|███       | 31/100 [00:06<00:14,  4.67it/s]

Truncating sequence for P13942 from 1736 to 1500


Processing batch 111/121:  51%|█████     | 51/100 [00:10<00:10,  4.68it/s]

Truncating sequence for P15924 from 2871 to 1500


Processing batch 111/121:  91%|█████████ | 91/100 [00:18<00:01,  4.54it/s]

Truncating sequence for Q9UPQ9 from 1833 to 1500


Processing batch 111/121:  95%|█████████▌| 95/100 [00:19<00:01,  4.48it/s]

Truncating sequence for Q8TEQ6 from 1508 to 1500


Processing batch 111/121: 100%|██████████| 100/100 [00:20<00:00,  4.83it/s]


Saved batch 111 (100 proteins): 100.43 MB


Processing batch 112/121:   4%|▍         | 4/100 [00:00<00:20,  4.70it/s]

Truncating sequence for Q7RTP6 from 2002 to 1500


Processing batch 112/121:  10%|█         | 10/100 [00:02<00:20,  4.36it/s]

Truncating sequence for Q9BZ29 from 2069 to 1500


Processing batch 112/121:  15%|█▌        | 15/100 [00:03<00:19,  4.33it/s]

Truncating sequence for Q02224 from 2701 to 1500


Processing batch 112/121:  27%|██▋       | 27/100 [00:05<00:16,  4.48it/s]

Truncating sequence for Q5VT06 from 3117 to 1500


Processing batch 112/121:  41%|████      | 41/100 [00:08<00:13,  4.35it/s]

Truncating sequence for Q14517 from 4588 to 1500


Processing batch 112/121:  43%|████▎     | 43/100 [00:09<00:13,  4.35it/s]

Truncating sequence for Q96RW7 from 5635 to 1500


Processing batch 112/121:  50%|█████     | 50/100 [00:10<00:11,  4.46it/s]

Truncating sequence for P51587 from 3418 to 1500


Processing batch 112/121:  64%|██████▍   | 64/100 [00:13<00:08,  4.42it/s]

Truncating sequence for Q9NU22 from 5596 to 1500


Processing batch 112/121:  66%|██████▌   | 66/100 [00:14<00:07,  4.26it/s]

Truncating sequence for Q8NG31 from 2342 to 1500


Processing batch 112/121:  71%|███████   | 71/100 [00:15<00:06,  4.48it/s]

Truncating sequence for Q9NRD8 from 1548 to 1500


Processing batch 112/121:  87%|████████▋ | 87/100 [00:18<00:02,  4.46it/s]

Truncating sequence for Q96T58 from 3664 to 1500


Processing batch 112/121: 100%|██████████| 100/100 [00:21<00:00,  4.64it/s]

Truncating sequence for Q14573 from 2671 to 1500





Saved batch 112 (100 proteins): 134.90 MB


Processing batch 113/121:   2%|▏         | 2/100 [00:00<00:24,  4.04it/s]

Truncating sequence for Q9Y6R7 from 5405 to 1500


Processing batch 113/121:   5%|▌         | 5/100 [00:01<00:21,  4.35it/s]

Truncating sequence for P35580 from 1976 to 1500


Processing batch 113/121:  15%|█▌        | 15/100 [00:03<00:18,  4.62it/s]

Truncating sequence for Q13219 from 1627 to 1500


Processing batch 113/121:  17%|█▋        | 17/100 [00:03<00:18,  4.43it/s]

Truncating sequence for P08519 from 4548 to 1500
Truncating sequence for Q86UK0 from 2595 to 1500


Processing batch 113/121:  22%|██▏       | 22/100 [00:04<00:17,  4.39it/s]

Truncating sequence for P18583 from 2426 to 1500


Processing batch 113/121:  24%|██▍       | 24/100 [00:05<00:17,  4.37it/s]

Truncating sequence for O60346 from 1717 to 1500


Processing batch 113/121:  37%|███▋      | 37/100 [00:07<00:13,  4.61it/s]

Truncating sequence for Q99250 from 2005 to 1500
Truncating sequence for Q8IWU2 from 1503 to 1500


Processing batch 113/121:  46%|████▌     | 46/100 [00:09<00:12,  4.46it/s]

Truncating sequence for O43166 from 1804 to 1500


Processing batch 113/121:  60%|██████    | 60/100 [00:12<00:08,  4.45it/s]

Truncating sequence for Q562E7 from 1941 to 1500


Processing batch 113/121:  77%|███████▋  | 77/100 [00:16<00:04,  4.68it/s]

Truncating sequence for O75417 from 2590 to 1500


Processing batch 113/121:  98%|█████████▊| 98/100 [00:20<00:00,  4.63it/s]

Truncating sequence for Q2M3C7 from 1700 to 1500


Processing batch 113/121: 100%|██████████| 100/100 [00:21<00:00,  4.75it/s]


Saved batch 113 (100 proteins): 111.72 MB


Processing batch 114/121:   7%|▋         | 7/100 [00:01<00:20,  4.53it/s]

Truncating sequence for Q14004 from 1512 to 1500


Processing batch 114/121:  12%|█▏        | 12/100 [00:02<00:19,  4.60it/s]

Truncating sequence for Q92508 from 2521 to 1500


Processing batch 114/121:  19%|█▉        | 19/100 [00:04<00:18,  4.47it/s]

Truncating sequence for P13611 from 3396 to 1500


Processing batch 114/121:  24%|██▍       | 24/100 [00:05<00:16,  4.51it/s]

Truncating sequence for Q9Y4G6 from 2542 to 1500


Processing batch 114/121:  33%|███▎      | 33/100 [00:06<00:14,  4.69it/s]

Truncating sequence for O60486 from 1568 to 1500


Processing batch 114/121:  74%|███████▍  | 74/100 [00:15<00:05,  4.37it/s]

Truncating sequence for P53675 from 1640 to 1500


Processing batch 114/121:  87%|████████▋ | 87/100 [00:18<00:02,  4.60it/s]

Truncating sequence for P52948 from 1817 to 1500


Processing batch 114/121: 100%|██████████| 100/100 [00:20<00:00,  4.81it/s]


Saved batch 114 (100 proteins): 107.16 MB


Processing batch 115/121:   0%|          | 0/100 [00:00<?, ?it/s]

Truncating sequence for Q68CP9 from 1835 to 1500


Processing batch 115/121:   1%|          | 1/100 [00:00<00:24,  4.04it/s]

Truncating sequence for P39880 from 1505 to 1500


Processing batch 115/121:  20%|██        | 20/100 [00:04<00:17,  4.49it/s]

Truncating sequence for Q8IWJ2 from 1684 to 1500


Processing batch 115/121:  25%|██▌       | 25/100 [00:05<00:16,  4.59it/s]

Truncating sequence for P39060 from 1754 to 1500


Processing batch 115/121:  28%|██▊       | 28/100 [00:06<00:16,  4.41it/s]

Truncating sequence for O00763 from 2458 to 1500


Processing batch 115/121:  60%|██████    | 60/100 [00:12<00:08,  4.51it/s]

Truncating sequence for P17927 from 2039 to 1500


Processing batch 115/121:  68%|██████▊   | 68/100 [00:14<00:07,  4.48it/s]

Truncating sequence for P49792 from 3224 to 1500


Processing batch 115/121: 100%|██████████| 100/100 [00:21<00:00,  4.71it/s]


Saved batch 115 (100 proteins): 127.70 MB


Processing batch 116/121:  10%|█         | 10/100 [00:02<00:20,  4.36it/s]

Truncating sequence for Q9Y2H9 from 1570 to 1500


Processing batch 116/121:  12%|█▏        | 12/100 [00:02<00:21,  4.12it/s]

Truncating sequence for O75691 from 2785 to 1500


Processing batch 116/121:  19%|█▉        | 19/100 [00:04<00:18,  4.41it/s]

Truncating sequence for Q9UQD0 from 1980 to 1500


Processing batch 116/121:  27%|██▋       | 27/100 [00:05<00:16,  4.49it/s]

Truncating sequence for Q15652 from 2540 to 1500


Processing batch 116/121:  30%|███       | 30/100 [00:06<00:16,  4.36it/s]

Truncating sequence for Q4AC94 from 2353 to 1500


Processing batch 116/121:  32%|███▏      | 32/100 [00:07<00:15,  4.31it/s]

Truncating sequence for O60673 from 3130 to 1500


Processing batch 116/121:  49%|████▉     | 49/100 [00:10<00:11,  4.42it/s]

Truncating sequence for Q14766 from 1721 to 1500


Processing batch 116/121:  51%|█████     | 51/100 [00:11<00:11,  4.21it/s]

Truncating sequence for P08922 from 2347 to 1500


Processing batch 116/121:  55%|█████▌    | 55/100 [00:12<00:10,  4.34it/s]

Truncating sequence for Q9Y6J0 from 2220 to 1500


Processing batch 116/121:  61%|██████    | 61/100 [00:13<00:08,  4.41it/s]

Truncating sequence for Q6ZNJ1 from 2754 to 1500


Processing batch 116/121:  75%|███████▌  | 75/100 [00:16<00:05,  4.67it/s]

Truncating sequence for Q9NQT8 from 1826 to 1500


Processing batch 116/121:  99%|█████████▉| 99/100 [00:21<00:00,  4.32it/s]

Truncating sequence for Q9UGU0 from 1960 to 1500


Processing batch 116/121: 100%|██████████| 100/100 [00:21<00:00,  4.59it/s]


Saved batch 116 (100 proteins): 127.29 MB


Processing batch 117/121:   6%|▌         | 6/100 [00:01<00:20,  4.56it/s]

Truncating sequence for Q9Y4A5 from 3859 to 1500


Processing batch 117/121:  10%|█         | 10/100 [00:02<00:19,  4.54it/s]

Truncating sequence for Q8N1I0 from 1966 to 1500


Processing batch 117/121:  31%|███       | 31/100 [00:06<00:14,  4.64it/s]

Truncating sequence for Q8IWI9 from 3065 to 1500


Processing batch 117/121:  78%|███████▊  | 78/100 [00:16<00:04,  4.60it/s]

Truncating sequence for Q96JM2 from 2506 to 1500


Processing batch 117/121: 100%|██████████| 100/100 [00:20<00:00,  4.77it/s]

Truncating sequence for P48681 from 1621 to 1500





Saved batch 117 (100 proteins): 106.49 MB


Processing batch 118/121:   3%|▎         | 3/100 [00:00<00:22,  4.30it/s]

Truncating sequence for Q7Z5J4 from 1906 to 1500


Processing batch 118/121:   7%|▋         | 7/100 [00:01<00:20,  4.46it/s]

Truncating sequence for P27708 from 2225 to 1500


Processing batch 118/121:  57%|█████▋    | 57/100 [00:11<00:09,  4.64it/s]

Truncating sequence for A6NHR9 from 2005 to 1500


Processing batch 118/121:  60%|██████    | 60/100 [00:12<00:08,  4.48it/s]

Truncating sequence for P23467 from 1997 to 1500


Processing batch 118/121:  64%|██████▍   | 64/100 [00:13<00:08,  4.45it/s]

Truncating sequence for Q93008 from 2554 to 1500


Processing batch 118/121:  69%|██████▉   | 69/100 [00:14<00:07,  4.31it/s]

Truncating sequence for P35658 from 2090 to 1500


Processing batch 118/121:  75%|███████▌  | 75/100 [00:15<00:05,  4.51it/s]

Truncating sequence for Q8N3C0 from 2202 to 1500


Processing batch 118/121:  82%|████████▏ | 82/100 [00:17<00:03,  4.77it/s]

Truncating sequence for Q5H9F3 from 1711 to 1500


Processing batch 118/121:  84%|████████▍ | 84/100 [00:17<00:03,  4.39it/s]

Truncating sequence for Q9Y2I1 from 1504 to 1500


Processing batch 118/121:  99%|█████████▉| 99/100 [00:20<00:00,  4.74it/s]

Truncating sequence for Q5VST9 from 7968 to 1500


Processing batch 118/121: 100%|██████████| 100/100 [00:20<00:00,  4.78it/s]


Saved batch 118 (100 proteins): 111.68 MB


Processing batch 119/121:  11%|█         | 11/100 [00:02<00:19,  4.61it/s]

Truncating sequence for Q8NFC6 from 3051 to 1500
Truncating sequence for Q96QB1 from 1528 to 1500


Processing batch 119/121:  14%|█▍        | 14/100 [00:03<00:20,  4.29it/s]

Truncating sequence for Q7Z4S6 from 1674 to 1500


Processing batch 119/121:  35%|███▌      | 35/100 [00:07<00:13,  4.71it/s]

Truncating sequence for Q92614 from 2054 to 1500


Processing batch 119/121:  39%|███▉      | 39/100 [00:08<00:13,  4.56it/s]

Truncating sequence for Q17RW2 from 1714 to 1500
Truncating sequence for O75096 from 1905 to 1500


Processing batch 119/121:  68%|██████▊   | 68/100 [00:13<00:06,  4.68it/s]

Truncating sequence for Q8IWT3 from 2517 to 1500


Processing batch 119/121: 100%|██████████| 100/100 [00:20<00:00,  4.79it/s]


Saved batch 119 (100 proteins): 97.46 MB


Processing batch 120/121:  11%|█         | 11/100 [00:02<00:20,  4.43it/s]

Truncating sequence for O43151 from 1795 to 1500


Processing batch 120/121:  17%|█▋        | 17/100 [00:03<00:17,  4.69it/s]

Truncating sequence for Q9Y2I7 from 2098 to 1500


Processing batch 120/121:  31%|███       | 31/100 [00:06<00:15,  4.57it/s]

Truncating sequence for Q8TEU7 from 1601 to 1500


Processing batch 120/121:  35%|███▌      | 35/100 [00:07<00:14,  4.52it/s]

Truncating sequence for Q5VT25 from 1732 to 1500


Processing batch 120/121: 100%|██████████| 100/100 [00:20<00:00,  4.90it/s]


Saved batch 120 (100 proteins): 101.04 MB


Processing batch 121/121: 100%|██████████| 26/26 [00:05<00:00,  5.04it/s]


Saved batch 121 (26 proteins): 24.61 MB

Embedding complete: 1026 proteins processed in this run
Total storage used in this run: 1.12 GB
Average size per protein: 1148.24 KB
Total proteins in manifest: 12026


In [17]:
# Cell 6: Merging functions
def merge_embeddings(embeddings_dir='data/full_dataset/embeddings', output_file=None):
    """
    Merge all protein embeddings from batch files into a single pickle file.
    
    Args:
        embeddings_dir: Directory containing the embedding batch files and manifest
        output_file: Path to save the merged embeddings (default: embeddings_merged.pkl in embeddings_dir)
    
    Returns:
        Path to the merged embeddings file
    """
    # Set default output file path if not provided
    if output_file is None:
        output_file = os.path.join(embeddings_dir, "embeddings_merged.pkl")
    
    # Load the manifest file
    manifest_path = os.path.join(embeddings_dir, "embeddings_manifest.pkl")
    with open(manifest_path, "rb") as f:
        manifest = pickle.load(f)
    
    print(f"Found {len(manifest)} proteins in manifest")
    
    # Get unique batch files
    batch_files = set(manifest.values())
    print(f"Found {len(batch_files)} batch files")
    
    # Create a dictionary to hold all embeddings
    all_embeddings = {}
    
    # Track file size for reporting
    total_loaded_bytes = 0
    
    # Process each batch file
    for batch_file in tqdm(batch_files, desc="Loading batch files"):
        batch_path = os.path.join(embeddings_dir, batch_file)
        
        # Get file size
        file_size = os.path.getsize(batch_path)
        total_loaded_bytes += file_size
        
        # Load batch embeddings
        batch_embeddings = torch.load(batch_path)
        
        # Add embeddings to master dictionary
        all_embeddings.update(batch_embeddings)
        
        # Clear memory
        del batch_embeddings
        gc.collect()
    
    print(f"Loaded {len(all_embeddings)} protein embeddings")
    print(f"Total loaded: {total_loaded_bytes/1024/1024/1024:.2f} GB")
    
    # Save the merged embeddings
    print(f"Saving merged embeddings to {output_file}...")
    with open(output_file, "wb") as f:
        pickle.dump(all_embeddings, f)
    
    # Get size of the merged file
    merged_size = os.path.getsize(output_file)
    print(f"Merged file size: {merged_size/1024/1024/1024:.2f} GB")
    
    return output_file

def merge_embeddings_compressed(embeddings_dir='data/full_dataset/embeddings', output_file=None):
    """
    Merge all protein embeddings from batch files into a single compressed pickle file.
    Uses highest compression level to minimize size.
    
    Args:
        embeddings_dir: Directory containing the embedding batch files and manifest
        output_file: Path to save the merged embeddings (default: embeddings_merged_compressed.pkl in embeddings_dir)
    
    Returns:
        Path to the merged embeddings file
    """
    import gzip
    
    # Set default output file path if not provided
    if output_file is None:
        output_file = os.path.join(embeddings_dir, "embeddings_merged_compressed.pkl")
    
    # Load the manifest file
    manifest_path = os.path.join(embeddings_dir, "embeddings_manifest.pkl")
    with open(manifest_path, "rb") as f:
        manifest = pickle.load(f)
    
    print(f"Found {len(manifest)} proteins in manifest")
    
    # Get unique batch files
    batch_files = set(manifest.values())
    print(f"Found {len(batch_files)} batch files")
    
    # Create a dictionary to hold all embeddings
    all_embeddings = {}
    
    # Track file size for reporting
    total_loaded_bytes = 0
    
    # Process each batch file
    for batch_file in tqdm(batch_files, desc="Loading batch files"):
        batch_path = os.path.join(embeddings_dir, batch_file)
        
        # Get file size
        file_size = os.path.getsize(batch_path)
        total_loaded_bytes += file_size
        
        # Load batch embeddings
        batch_embeddings = torch.load(batch_path)
        
        # Add embeddings to master dictionary
        all_embeddings.update(batch_embeddings)
        
        # Clear memory
        del batch_embeddings
        gc.collect()
    
    print(f"Loaded {len(all_embeddings)} protein embeddings")
    print(f"Total loaded: {total_loaded_bytes/1024/1024/1024:.2f} GB")
    
    # Save the merged embeddings with compression
    print(f"Saving compressed merged embeddings to {output_file}...")
    with gzip.open(output_file, "wb", compresslevel=9) as f:
        pickle.dump(all_embeddings, f)
    
    # Get size of the merged file
    merged_size = os.path.getsize(output_file)
    print(f"Merged compressed file size: {merged_size/1024/1024/1024:.2f} GB")
    print(f"Compression ratio: {merged_size/total_loaded_bytes:.2f}")
    
    return output_file

In [18]:
# Cell 7: Merge the embeddings
# Choose one of these options:

# Option 1: Basic merge - faster but larger file
merged_file = merge_embeddings()
print(f"Merged embeddings saved to: {merged_file}")

# Option 2: Merge with maximum compression - slower but smaller file
# merged_compressed_file = merge_embeddings_compressed()
# print(f"Compressed merged embeddings saved to: {merged_compressed_file}")

Found 12026 proteins in manifest
Found 121 batch files


Loading batch files: 100%|██████████| 121/121 [00:15<00:00,  7.70it/s]


Loaded 12026 protein embeddings
Total loaded: 11.23 GB
Saving merged embeddings to data/full_dataset/embeddings/embeddings_merged.pkl...
Merged file size: 11.23 GB
Merged embeddings saved to: data/full_dataset/embeddings/embeddings_merged.pkl


In [12]:
protein_embeddings = pickle.load(open('data/full_dataset/embeddings/embeddings_merged.pkl', 'rb'))


In [13]:
for key, value in protein_embeddings.items():
    print(f"Protein ID: {key}, Embedding shape: {value.shape}")

Protein ID: P28223, Embedding shape: torch.Size([473, 960])
Protein ID: P41595, Embedding shape: torch.Size([483, 960])
Protein ID: O00161, Embedding shape: torch.Size([213, 960])
Protein ID: P56962, Embedding shape: torch.Size([304, 960])
Protein ID: P82979, Embedding shape: torch.Size([212, 960])
Protein ID: Q01081, Embedding shape: torch.Size([242, 960])
Protein ID: O60678, Embedding shape: torch.Size([533, 960])
Protein ID: Q14524, Embedding shape: torch.Size([1502, 960])
Protein ID: P10275, Embedding shape: torch.Size([922, 960])
Protein ID: Q15648, Embedding shape: torch.Size([1502, 960])
Protein ID: P05413, Embedding shape: torch.Size([135, 960])
Protein ID: P05556, Embedding shape: torch.Size([800, 960])
Protein ID: Q99873, Embedding shape: torch.Size([373, 960])
Protein ID: O95573, Embedding shape: torch.Size([722, 960])
Protein ID: P08588, Embedding shape: torch.Size([479, 960])
Protein ID: P21917, Embedding shape: torch.Size([421, 960])
Protein ID: O75365, Embedding shape: t