## Standardize the embeddings dictionary

In [1]:
# Import necessary libraries
import pickle
import numpy as np
import torch
from tqdm import tqdm
import os
from sklearn.preprocessing import StandardScaler
import gc

# Set the working directory to the project root
# Get the current working directory and navigate to project root
current_dir = os.getcwd()
print(f"Current directory: {current_dir}")

# Find the CS182-Final-Project directory
if 'CS182-Final-Project' in current_dir:
    # If we're already in a subdirectory of the project, navigate to root
    project_root = current_dir.split('CS182-Final-Project')[0] + 'CS182-Final-Project'
else:
    # If we're not in the project directory, assume we need to navigate
    project_root = '/home/richard/projects/CS182-Final-Project'

os.chdir(project_root)
print(f"Changed to project root: {os.getcwd()}")

Current directory: /home/richard/projects/CS182-Final-Project/src/data_process
Changed to project root: /home/richard/projects/CS182-Final-Project


In [2]:
# Define file paths (relative to project root)
input_file = 'data/full_dataset/embeddings/embeddings_merged.pkl'
output_file = 'data/full_dataset/embeddings/embeddings_standardized.pkl'

# Verify we're in the correct directory
print(f"Current working directory: {os.getcwd()}")
print(f"Project root contents: {os.listdir('.')}")

# Verify input file exists
if os.path.exists(input_file):
    file_size = os.path.getsize(input_file) / (1024**3)  # Size in GB
    print(f"✓ Input file found: {input_file}")
    print(f"✓ File size: {file_size:.2f} GB")
else:
    print(f"✗ Error: Input file not found at {input_file}")
    print(f"Available files in data/full_dataset/embeddings/:")
    if os.path.exists('data/full_dataset/embeddings/'):
        print(os.listdir('data/full_dataset/embeddings/'))
    
print(f"Output will be saved to: {output_file}")

Current working directory: /home/richard/projects/CS182-Final-Project
Project root contents: ['results', 'experiments', '.gitignore', 'README.md', 'data', 'requirements.txt', '.git', 'docs', 'src']
✓ Input file found: data/full_dataset/embeddings/embeddings_merged.pkl
✓ File size: 11.23 GB
Output will be saved to: data/full_dataset/embeddings/embeddings_standardized.pkl


In [3]:
# Load and examine the embeddings data
print("Loading embeddings data...")
with open(input_file, 'rb') as f:
    embeddings_data = pickle.load(f)

print(f"Data type: {type(embeddings_data)}")
print(f"Number of proteins: {len(embeddings_data)}")

# Examine a sample embedding
sample_key = list(embeddings_data.keys())[0]
sample_embedding = embeddings_data[sample_key]
print(f"\nSample protein ID: {sample_key}")
print(f"Sample embedding type: {type(sample_embedding)}")
print(f"Sample embedding shape: {sample_embedding.shape}")
print(f"Sample embedding dtype: {sample_embedding.dtype}")
print(f"Sample values (first 5): {sample_embedding[0, :5]}")

Loading embeddings data...
Data type: <class 'dict'>
Number of proteins: 12026

Sample protein ID: O14796
Sample embedding type: <class 'torch.Tensor'>
Sample embedding shape: torch.Size([134, 960])
Sample embedding dtype: torch.float16
Sample values (first 5): tensor([ 0.0030, -0.0006,  0.0012, -0.0003,  0.0016], dtype=torch.float16)


In [4]:
# Analyze embedding dimensions and compute statistics
print("Analyzing embedding dimensions...")

# Collect all embeddings to compute global statistics
all_embeddings = []
sequence_lengths = []
embedding_dim = None

print("Processing embeddings to collect statistics...")
for protein_id, embedding in tqdm(embeddings_data.items(), desc="Processing proteins"):
    # Convert to numpy if it's a torch tensor
    if isinstance(embedding, torch.Tensor):
        embedding_np = embedding.cpu().numpy().astype(np.float32)
    else:
        embedding_np = embedding.astype(np.float32)
    
    sequence_lengths.append(embedding_np.shape[0])
    if embedding_dim is None:
        embedding_dim = embedding_np.shape[1]
    
    # Flatten the embedding to add to the collection
    all_embeddings.append(embedding_np.flatten())

print(f"\nEmbedding dimension: {embedding_dim}")
print(f"Sequence length range: {min(sequence_lengths)} - {max(sequence_lengths)}")
print(f"Average sequence length: {np.mean(sequence_lengths):.2f}")
print(f"Total embedding values: {sum(len(emb) for emb in all_embeddings):,}")

Analyzing embedding dimensions...
Processing embeddings to collect statistics...


Processing proteins: 100%|██████████| 12026/12026 [00:29<00:00, 404.34it/s]



Embedding dimension: 960
Sequence length range: 26 - 1502
Average sequence length: 522.11
Total embedding values: 6,027,745,920


In [5]:
# Compute global statistics for standardization (memory-efficient)
print("Computing global statistics for standardization...")

# Initialize variables for incremental computation
total_sum = 0.0
total_sum_sq = 0.0
total_count = 0
min_val = float('inf')
max_val = float('-inf')

print("Computing statistics incrementally...")
for i, embeddings in enumerate(all_embeddings):
    # Convert to numpy array if needed
    if not isinstance(embeddings, np.ndarray):
        embeddings = np.array(embeddings)
    
    # Update statistics
    total_sum += np.sum(embeddings)
    total_sum_sq += np.sum(embeddings ** 2)
    total_count += embeddings.size
    min_val = min(min_val, np.min(embeddings))
    max_val = max(max_val, np.max(embeddings))
    
    if (i + 1) % 100 == 0:  # Progress indicator
        print(f"Processed {i + 1} embedding arrays...")

# Compute final statistics
global_mean = total_sum / total_count
global_std = np.sqrt((total_sum_sq / total_count) - (global_mean ** 2))

print(f"\nGlobal statistics:")
print(f"Mean: {global_mean:.6f}")
print(f"Standard deviation: {global_std:.6f}")
print(f"Min value: {min_val:.6f}")
print(f"Max value: {max_val:.6f}")
print(f"Total values: {total_count:,}")

print("\nStatistics computed without memory issues.")

Computing global statistics for standardization...
Computing statistics incrementally...
Processed 100 embedding arrays...
Processed 200 embedding arrays...
Processed 300 embedding arrays...
Processed 400 embedding arrays...
Processed 500 embedding arrays...
Processed 600 embedding arrays...
Processed 700 embedding arrays...
Processed 800 embedding arrays...
Processed 900 embedding arrays...
Processed 1000 embedding arrays...
Processed 1100 embedding arrays...
Processed 1200 embedding arrays...
Processed 1300 embedding arrays...
Processed 1400 embedding arrays...
Processed 1500 embedding arrays...
Processed 1600 embedding arrays...
Processed 1700 embedding arrays...
Processed 1800 embedding arrays...
Processed 1900 embedding arrays...
Processed 2000 embedding arrays...
Processed 2100 embedding arrays...
Processed 2200 embedding arrays...
Processed 2300 embedding arrays...
Processed 2400 embedding arrays...
Processed 2500 embedding arrays...
Processed 2600 embedding arrays...
Processed 

In [6]:
# Apply standardization to all embeddings (minimal memory)
print("Applying standardization to all embeddings...")

os.makedirs("standardized_embeddings", exist_ok=True)

processed_count = 0
sample_key = None
sample_standardized = None

for protein_id, embedding in tqdm(embeddings_data.items(), desc="Standardizing embeddings"):
    # Convert to numpy if it's a torch tensor
    if isinstance(embedding, torch.Tensor):
        embedding_np = embedding.cpu().numpy().astype(np.float32)
    else:
        embedding_np = embedding.astype(np.float32)
    
    # Apply z-score standardization: (x - mean) / std
    standardized_embedding = (embedding_np - global_mean) / global_std
    
    # Save immediately and clear from memory
    np.save(f"standardized_embeddings/{protein_id}.npy", standardized_embedding)
    
    # Keep first one for verification (copy to avoid reference)
    if sample_key is None:
        sample_key = protein_id
        sample_standardized = standardized_embedding.copy()
    
    processed_count += 1
    
    # Clear variables
    del embedding_np, standardized_embedding
    
    # Periodic garbage collection
    if processed_count % 50 == 0:
        gc.collect()

print(f"\nStandardization complete!")
print(f"Number of standardized embeddings: {processed_count}")

print(f"\nSample standardized embedding:")
print(f"Shape: {sample_standardized.shape}")
print(f"Mean: {np.mean(sample_standardized):.6f}")
print(f"Std: {np.std(sample_standardized):.6f}")
print(f"Min: {np.min(sample_standardized):.6f}")
print(f"Max: {np.max(sample_standardized):.6f}")

Applying standardization to all embeddings...


Standardizing embeddings: 100%|██████████| 12026/12026 [02:16<00:00, 88.42it/s] 



Standardization complete!
Number of standardized embeddings: 12026

Sample standardized embedding:
Shape: (134, 960)
Mean: -0.000438
Std: 0.985250
Min: -14.494860
Max: 18.379944


In [7]:
# Memory cleanup before verification
print("Performing comprehensive memory cleanup...")

# List of variables that might be consuming memory
variables_to_clean = [
    'embeddings_data', 'all_embeddings', 'standardized_embeddings', 
    'standardized_embedding', 'embedding_np', 'standardized_embedding_var',
    'sample_standardized', 'current_batch', 'batch_embeddings',
    'all_embeddings', 'final_dict', 'embedding', 'flat_embedding',
    'chunk_values', 'sample_values', 'sample_concat'
]

# Clean up variables if they exist
cleaned_vars = []
for var_name in variables_to_clean:
    if var_name in globals():
        del globals()[var_name]
        cleaned_vars.append(var_name)
    if var_name in locals():
        del locals()[var_name]
        cleaned_vars.append(var_name)

if cleaned_vars:
    print(f"Cleaned up variables: {cleaned_vars}")

# Force garbage collection multiple times
import gc
for i in range(3):
    collected = gc.collect()
    if collected > 0:
        print(f"Garbage collection round {i+1}: freed {collected} objects")

# Check memory usage if psutil is available
try:
    import psutil
    import os
    process = psutil.Process(os.getpid())
    memory_mb = process.memory_info().rss / 1024 / 1024
    print(f"Current memory usage: {memory_mb:.1f} MB")
except ImportError:
    print("psutil not available for memory monitoring")

# Clear any matplotlib figures if they exist
try:
    import matplotlib.pyplot as plt
    plt.close('all')
    print("Closed all matplotlib figures")
except ImportError:
    pass

# Clear CUDA cache if using PyTorch
try:
    import torch
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("Cleared CUDA cache")
except ImportError:
    pass

print("Memory cleanup completed.\n")

Performing comprehensive memory cleanup...
Cleaned up variables: ['embeddings_data', 'all_embeddings', 'sample_standardized', 'embedding']
Current memory usage: 26562.4 MB
Closed all matplotlib figures
Memory cleanup completed.



In [8]:
# Verify global standardization properties (ultra memory-efficient)
print("Verifying global standardization properties...")

# Get list of saved embedding files
import glob
embedding_files = glob.glob("standardized_embeddings/*.npy")
sample_size = min(1000, len(embedding_files))  # Sample up to 1000 proteins

# Use online/streaming statistics computation
import random

# Randomly sample protein files for verification
sampled_files = random.sample(embedding_files, sample_size) if len(embedding_files) > sample_size else embedding_files

total_sum = 0.0
total_sum_sq = 0.0
total_count = 0
min_val = float('inf')
max_val = float('-inf')

print(f"Computing verification statistics on sample of {len(sampled_files)} proteins...")

for file_path in sampled_files:
    # Load embedding from file
    embedding = np.load(file_path)
    
    # Process in small chunks to avoid memory spikes
    flat_embedding = embedding.flatten()
    
    total_sum += np.sum(flat_embedding)
    total_sum_sq += np.sum(flat_embedding ** 2)
    total_count += flat_embedding.size
    min_val = min(min_val, np.min(flat_embedding))
    max_val = max(max_val, np.max(flat_embedding))
    
    # Clean up
    del embedding, flat_embedding

# Compute final verification statistics
sample_mean = total_sum / total_count
sample_std = np.sqrt((total_sum_sq / total_count) - (sample_mean ** 2))

print(f"\nVerification on sample of {len(sampled_files)} proteins:")
print(f"Sample mean: {sample_mean:.6f} (should be close to 0)")
print(f"Sample std: {sample_std:.6f} (should be close to 1)")
print(f"Sample min: {min_val:.6f}")
print(f"Sample max: {max_val:.6f}")

# Clean up
del sampled_files
gc.collect()

Verifying global standardization properties...
Computing verification statistics on sample of 1000 proteins...

Verification on sample of 1000 proteins:
Sample mean: 0.000037 (should be close to 0)
Sample std: 1.001788 (should be close to 1)
Sample min: -28.569386
Sample max: 29.463774


229

In [9]:
# Save standardized embeddings (ultra memory-efficient streaming)
print(f"Saving standardized embeddings to {output_file} using streaming...")

# Create output directory if it doesn't exist
os.makedirs(os.path.dirname(output_file), exist_ok=True)

# Get all embedding files
embedding_files = glob.glob("standardized_embeddings/*.npy")
total_files = len(embedding_files)
print(f"Found {total_files} standardized embedding files...")

# Use a temporary file approach to build the pickle incrementally
import tempfile
temp_file = output_file + ".tmp"

try:
    # Create dictionary in chunks and save incrementally
    final_dict = {}
    chunk_size = 20  # Very small chunks
    
    print("Processing embeddings in small chunks...")
    for i in tqdm(range(0, total_files, chunk_size), desc="Processing chunks"):
        chunk_files = embedding_files[i:i+chunk_size]
        
        # Process tiny chunk
        for file_path in chunk_files:
            protein_id = os.path.basename(file_path).replace('.npy', '')
            embedding = np.load(file_path)
            final_dict[protein_id] = embedding
            
            # Immediate cleanup
            del embedding
        
        # Force garbage collection every chunk
        gc.collect()
    
    # Save final result
    print("Saving final pickle file...")
    with open(temp_file, 'wb') as f:
        pickle.dump(final_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
    
    # Move temp file to final location
    import shutil
    shutil.move(temp_file, output_file)
    
    # Verify the saved file
    if os.path.exists(output_file):
        output_size = os.path.getsize(output_file) / (1024**3)  # Size in GB
        print(f"\nStandardized embeddings saved successfully!")
        print(f"Output file: {output_file}")
        print(f"Output file size: {output_size:.2f} GB")
        print(f"Total embeddings saved: {len(final_dict)}")
        
        # Clean up individual files after successful save
        print("\nCleaning up individual embedding files...")
        shutil.rmtree("standardized_embeddings")
        print("✓ Individual embedding files cleaned up successfully.")
        
    else:
        print(f"\nError: Failed to save file at {output_file}")
        print("Individual files preserved for safety.")

except Exception as e:
    print(f"\nError during processing: {e}")
    print("Individual files preserved for safety.")
    # Clean up temp file if it exists
    if os.path.exists(temp_file):
        os.remove(temp_file)
    
finally:
    # Clean up memory
    if 'final_dict' in locals():
        del final_dict
    gc.collect()

Saving standardized embeddings to data/full_dataset/embeddings/embeddings_standardized.pkl using streaming...
Found 12026 standardized embedding files...
Processing embeddings in small chunks...


Processing chunks: 100%|██████████| 602/602 [02:01<00:00,  4.94it/s]


Saving final pickle file...

Standardized embeddings saved successfully!
Output file: data/full_dataset/embeddings/embeddings_standardized.pkl
Output file size: 22.46 GB
Total embeddings saved: 12026

Cleaning up individual embedding files...
✓ Individual embedding files cleaned up successfully.


In [14]:
# Simple verification of standardized embeddings
import pickle
import numpy as np

# Load the standardized embeddings
print("Loading standardized embeddings...")
with open('data/full_dataset/embeddings/embeddings_standardized.pkl', 'rb') as f:
    standardized_data = pickle.load(f)

print(f"Number of proteins: {len(standardized_data)}")

# Sample a subset to check statistics (to avoid memory issues)
sample_size = min(100, len(standardized_data))
protein_ids = list(standardized_data.keys())[:sample_size]

all_values = []
for protein_id in protein_ids:
    embedding = standardized_data[protein_id]
    all_values.extend(embedding.flatten())

all_values = np.array(all_values)

print(f"\nStandardized embeddings verification (sample of {sample_size} proteins):")
print(f"Mean: {np.mean(all_values):.6f} (should be close to 0)")
print(f"Std: {np.std(all_values):.6f} (should be close to 1)")
print(f"Min: {np.min(all_values):.6f}")
print(f"Max: {np.max(all_values):.6f}")

# Clean up
del standardized_data, all_values

Loading standardized embeddings...
Number of proteins: 12026

Standardized embeddings verification (sample of 100 proteins):
Mean: -0.000497 (should be close to 0)
Std: 1.003686 (should be close to 1)
Min: -28.067131
Max: 25.947996
