In [1]:
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
import torch
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 1) Load the four “full” splits
train_full  = pd.read_csv('../data/full_dataset/train_data.csv')
val_full    = pd.read_csv('../data/full_dataset/validation_data.csv')
test1_full  = pd.read_csv('../data/full_dataset/test1_data.csv')
test2_full  = pd.read_csv('../data/full_dataset/test2_data.csv')


def _sample_split(full_df, n_pos, n_neg, class_seed, shuffle_seed):
    """Draw a fixed number of positive and negative pairs and shuffle them.

    Args:
        full_df: DataFrame with an 'isInteraction' column holding 0/1 labels.
        n_pos: Number of rows to draw where 'isInteraction' == 1.
        n_neg: Number of rows to draw where 'isInteraction' == 0.
        class_seed: random_state used for both per-class draws.
        shuffle_seed: random_state used for the final shuffle.

    Returns:
        A new DataFrame of n_pos + n_neg rows in shuffled order with a fresh
        0..n-1 index.

    Raises:
        ValueError: raised by DataFrame.sample when a class has fewer rows
            than requested (sampling is without replacement).
    """
    labels = full_df['isInteraction']
    positives = full_df[labels == 1].sample(n=n_pos, random_state=class_seed)
    negatives = full_df[labels == 0].sample(n=n_neg, random_state=class_seed)
    return (
        pd.concat([positives, negatives])
          .sample(frac=1, random_state=shuffle_seed)
          .reset_index(drop=True)
    )


# 2) Build a 10 000-example train set, 50/50
n_train = 10000
n_pos_train = n_neg_train = n_train // 2
train_data = _sample_split(train_full, n_pos_train, n_neg_train,
                           class_seed=42, shuffle_seed=43)

# 3) Build a “reasonable” validation set: 2 000 examples, 50/50
n_val = 2000
n_pos_val = n_neg_val = n_val // 2
cv_data = _sample_split(val_full, n_pos_val, n_neg_val,
                        class_seed=44, shuffle_seed=45)
# Only the validation split has its 'trainTest' tag overwritten here.
cv_data['trainTest'] = 'validation'

# 4) Build test1: 2 000 examples, 50/50
n_test1 = 2000
n_pos_test1 = n_neg_test1 = n_test1 // 2
test1_data = _sample_split(test1_full, n_pos_test1, n_neg_test1,
                           class_seed=46, shuffle_seed=47)

# 5) Build test2: 10 000 examples, ~9% pos, ~91% neg
n_test2 = 10000
n_pos_test2 = int(n_test2 * 0.09)
n_neg_test2 = n_test2 - n_pos_test2
test2_data = _sample_split(test2_full, n_pos_test2, n_neg_test2,
                           class_seed=48, shuffle_seed=49)

# 6) Print class‐balance stats
print(f"Train set:      {train_data.shape}, pos-ratio={train_data['isInteraction'].mean():.3f}")
print(f"Validation set: {cv_data.shape},   pos-ratio={cv_data['isInteraction'].mean():.3f}")
print(f"Test1 set:      {test1_data.shape}, pos-ratio={test1_data['isInteraction'].mean():.3f}")
print(f"Test2 set:      {test2_data.shape}, pos-ratio={test2_data['isInteraction'].mean():.3f}")

# 7) Save to disk
out_dir = '../data/medium_set'
os.makedirs(out_dir, exist_ok=True)

# Each split is written twice: a pickle for the encoding cells and a CSV copy.
for df, name in [
    (train_data, 'train_data'),
    (cv_data,    'validation_data'),
    (test1_data, 'test1_data'),
    (test2_data, 'test2_data'),
]:
    df.to_pickle(f'{out_dir}/{name}.pkl')
    df.to_csv(f'{out_dir}/{name}.csv', index=False)

print("\nAll datasets successfully saved to 'data/medium_set'")

Train set:      (10000, 6), pos-ratio=0.500
Validation set: (2000, 6),   pos-ratio=0.500
Test1 set:      (2000, 6), pos-ratio=0.500
Test2 set:      (10000, 6), pos-ratio=0.090

All datasets successfully saved to 'data/medium_set'


In [3]:
# Preview the sampled training set; the rendered table is truncated to the first and last rows.
display(train_data)

Unnamed: 0,uniprotID_A,uniprotID_B,isInteraction,trainTest,sequence_A,sequence_B
0,Q92529,Q9H6L4,0,train,MLPRTKYNRFRNDSVTSVDDLLHSLSVSGGGGKVSAARATPAAAPY...,MAQKPKVDPHVGRLGYLQALVTEFQETQSQDAKEQVLANLANFAYD...
1,P09326,Q02446,0,train,MCSRGWDSCLALELLLLPLSLLVTSIQGHLVHMTVVSGSNVTLNIS...,MSDQKKEEEEEAAAAAAMATEGGKTSEPENNNKKPKTSGSQDSQPS...
2,Q4V328,Q9H2H9,0,train,MAQALSEEEFQRMQAQLLELRTNNYQLSDELRKNGVELTSLRQKVA...,MMHFKSGLELTELQNMTVPEDDNISNDSNDFTEVENGQINSKFISD...
3,O95835,Q9NTJ5,0,train,MKRSEKPEGYRQMRPKTFPASNYTVSSRQMLQEIRESLRNLSKPSD...,MATAAYEQLKLHITPEKFYVEACDDGADDVLTIDRVSTEVTLAVKK...
4,O95125,Q96JC9,1,train,MATAVEPEDQDLWEEEGILMVKLEDDFTCRPESVLQRDDPVLETSH...,MNGTANPLLDREEHCLRLGESFEKRPRASFHTIRYDFKPASIDTSC...
...,...,...,...,...,...,...
9995,P78310,Q13094,0,train,MALLLCFVLLCGVVDFARSLSITTPEEMIEKAKGETAYLPCKFTLS...,MALRNVPFRSEVLGWDPDSLADYFKKLNYKDCEKAVKKYHIDGARF...
9996,Q5T7W7,Q9Y2D8,1,train,MPSSTSPDQGDDLENCILRFSDLDLKDMSLINPSSSLKAELDGSTK...,MGDWMTVTDPGLSSESKTISQYTSETKMSPSSLYSQQVLCSSIPLS...
9997,Q3KNW5,Q8WXI8,0,train,MRANCSSSSACPANSSEEELPVGLEVHGNLELVFTVVSTVMMGLLM...,MGLEKPQSKLEGGMHPQLIPSVIAVVFILLLSVCFIASCLVTHHNF...
9998,Q9UIG4,Q9Y224,1,train,MILNWKLLGILVLCLHTRGISGSEGHPSHPPAEDREEAGSPTLPQG...,MFRRKLTALDYHNPAGFNCKDETEFRNFIVWLEDQKIRHYKIEDRG...


In [4]:
import pandas as pd
from collections import Counter

# Precondition: the 'train_data' DataFrame must already exist in this session.
# If it does not, re-run the cell that builds it, or reload it from disk, e.g.:
# train_data = pd.read_csv('../data/medium_set/train_data.csv')

print("--- Checking Individual Protein Frequencies in train_data ---")

# Problems with the input are reported with a message instead of an exception,
# so the closing banner below is always printed.
if 'train_data' not in locals() or not isinstance(train_data, pd.DataFrame):
    print("Error: 'train_data' DataFrame not found. Please ensure it is loaded.")
elif not ('uniprotID_A' in train_data.columns and 'uniprotID_B' in train_data.columns):
    print("Error: 'train_data' must contain 'uniprotID_A' and 'uniprotID_B' columns.")
else:
    # Pool the IDs from both slots: a protein counts once per slot it fills.
    all_proteins_in_pairs = train_data['uniprotID_A'].tolist() + train_data['uniprotID_B'].tolist()
    protein_counts = Counter(all_proteins_in_pairs)

    # Keep only the proteins seen more than once in the pooled list.
    proteins_in_multiple_pairs = {
        protein: count
        for protein, count in protein_counts.items()
        if count > 1
    }

    num_unique_proteins_total = len(protein_counts)
    num_proteins_in_multiple = len(proteins_in_multiple_pairs)

    print(f"Total unique proteins involved in any pair: {num_unique_proteins_total}")

    if num_proteins_in_multiple == 0:
        print("No protein appears in more than one pair context (each protein is unique to a single slot in a single pair, or appears only once across all pairs).")
    else:
        print(f"Found {num_proteins_in_multiple} unique proteins that appear in more than one pair context.")
        print("Displaying these proteins and their frequencies (how many pairs they are part of):")

        # Most frequent first; ties are broken alphabetically by protein ID.
        sorted_multiple_proteins = sorted(
            proteins_in_multiple_pairs.items(),
            key=lambda item: (-item[1], item[0]),
        )

        # Show at most `display_limit` entries, then summarise the remainder.
        display_limit = 20
        for protein, count in sorted_multiple_proteins[:display_limit]:
            print(f"  Protein ID: {protein}, Appears in: {count} pairs")

        if num_proteins_in_multiple > display_limit:
            print(f"  ... and {num_proteins_in_multiple - display_limit} more proteins.")
        else:
            print(f"  (Displayed all {num_proteins_in_multiple} proteins occurring in multiple pairs)")

print("\n--- End of Individual Protein Frequency Check ---")

--- Checking Individual Protein Frequencies in train_data ---
Total unique proteins involved in any pair: 5589
Found 3199 unique proteins that appear in more than one pair context.
Displaying these proteins and their frequencies (how many pairs they are part of):
  Protein ID: O76024, Appears in: 70 pairs
  Protein ID: P05067, Appears in: 61 pairs
  Protein ID: Q08379, Appears in: 59 pairs
  Protein ID: A8MQ03, Appears in: 58 pairs
  Protein ID: Q96HA8, Appears in: 52 pairs
  Protein ID: Q7Z699, Appears in: 51 pairs
  Protein ID: P60410, Appears in: 50 pairs
  Protein ID: Q9NRD5, Appears in: 50 pairs
  Protein ID: Q15323, Appears in: 48 pairs
  Protein ID: O60333, Appears in: 44 pairs
  Protein ID: Q0VD86, Appears in: 43 pairs
  Protein ID: Q8TBB1, Appears in: 43 pairs
  Protein ID: P60409, Appears in: 42 pairs
  Protein ID: P61981, Appears in: 41 pairs
  Protein ID: Q6A162, Appears in: 41 pairs
  Protein ID: Q04864, Appears in: 40 pairs
  Protein ID: P13473, Appears in: 39 pairs
  Pro

## Encode proteins using ESM C model

In [4]:
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    # Limit this process to a 0.7 fraction of GPU 0's memory before the model is loaded.
    target_gpu_index_for_fraction = 0

    print(f"Attempting to set memory fraction to 0.7 for GPU {target_gpu_index_for_fraction}")
    torch.cuda.set_per_process_memory_fraction(
        0.7,
        device=target_gpu_index_for_fraction,
    )
    print(f"Call to set_per_process_memory_fraction for GPU {target_gpu_index_for_fraction} completed.")

# Fetch the pretrained "esmc_300m" checkpoint and place it on the chosen device.
model = ESMC.from_pretrained("esmc_300m")
model = model.to(device)
print(f"[ESM-C] Loaded locally on {device}")
model_type = "local"

Attempting to set memory fraction to 0.7 for GPU 0
Call to set_per_process_memory_fraction for GPU 0 completed.


Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 97541.95it/s]


[ESM-C] Loaded locally on cuda


In [5]:
def get_protein_embedding(sequence, mean_pool=False):
    """Embed one protein sequence with the globally loaded ESM model.

    The sequence is wrapped in an ``ESMProtein``, encoded with ``model.encode``
    and passed through ``model.logits`` with ``return_embeddings=True``. All of
    this runs under ``torch.no_grad()`` because it is inference only.

    Args:
        sequence: Amino-acid sequence string.
        mean_pool: When False (the default, and the previous behaviour) the
            embeddings tensor is returned exactly as the model produced it;
            no pooling is applied. When True the tensor is averaged over
            axis 1 to give one vector per protein.

    Returns:
        The embeddings as a NumPy array on the CPU.
    """
    with torch.no_grad():  # inference only: no gradients are tracked
        protein = ESMProtein(sequence=sequence)
        protein_tensor = model.encode(protein)
        logits_output = model.logits(
            protein_tensor,
            LogitsConfig(sequence=False, return_embeddings=True)
        )
        embedding = logits_output.embeddings
        if mean_pool:
            # NOTE(review): assumes the layout is (batch, tokens, dim), so the
            # mean also covers any special tokens the model adds — confirm.
            embedding = embedding.mean(dim=1)
    return embedding.cpu().numpy()  # move to CPU and convert to a NumPy array


In [6]:
from tqdm.auto import tqdm # Use tqdm.auto for better notebook compatibility

def process_dataset(df, sample_size=None):
    """Return a copy of ``df`` with embeddings attached for both sequences.

    Every row's 'sequence_A' and 'sequence_B' values are passed to
    ``get_protein_embedding``; the results are stored in new 'embedding_A' and
    'embedding_B' columns. The input DataFrame itself is never modified.

    NOTE: a later cell in this notebook defines ``process_dataset`` again; once
    that cell has run, it replaces this version.

    Args:
        df: DataFrame with 'sequence_A' and 'sequence_B' columns.
        sample_size: Optional number of rows to process. A value strictly
            between 0 and len(df) draws a random sample (random_state=42);
            None, non-positive or too-large values process every row.

    Returns:
        The (possibly sampled) copy of ``df`` with the two embedding columns.

    Raises:
        ValueError: if either sequence column is missing.
    """
    n_rows = len(df)
    sample_requested = sample_size is not None

    # Decide between a random subset and a full copy.
    if sample_requested and sample_size >= n_rows:
        print(f"sample_size ({sample_size}) is >= DataFrame length ({n_rows}). Processing all examples.")
        result_df = df.copy()
    elif sample_requested and sample_size > 0:
        result_df = df.sample(n=sample_size, random_state=42).copy()
        print(f"Processing a sample of {len(result_df)} examples.")
    else:
        # sample_size is None, zero or otherwise unusable
        result_df = df.copy()
        print(f"Processing all {len(result_df)} examples.")

    # Both sequence columns are required before any encoding starts.
    if not ('sequence_A' in result_df.columns and 'sequence_B' in result_df.columns):
        raise ValueError("DataFrame must contain 'sequence_A' and 'sequence_B' columns.")

    embeddings_A, embeddings_B = [], []
    sequence_pairs = zip(result_df['sequence_A'], result_df['sequence_B'])
    for seq_a, seq_b in tqdm(sequence_pairs, total=len(result_df), desc="Encoding proteins"):
        embeddings_A.append(get_protein_embedding(seq_a))
        embeddings_B.append(get_protein_embedding(seq_b))

    # Attach the collected embeddings as new columns on the copy.
    result_df['embedding_A'] = embeddings_A
    result_df['embedding_B'] = embeddings_B

    return result_df

In [None]:
print("\n--- Encoding train split ---")

# Create the output folder before the long encoding run: to_pickle does not
# create missing directories, and failing after the encode would waste the work.
os.makedirs('../data/medium_set/embeddings', exist_ok=True)

df = pd.read_pickle("../data/medium_set/train_data.pkl")
train_data_with_embeddings = process_dataset(df)
train_data_with_embeddings.to_pickle('../data/medium_set/embeddings/train_data_with_embeddings.pkl')



--- Encoding train split ---
Processing all 10000 examples.


Encoding proteins: 100%|██████████| 10000/10000 [10:05<00:00, 16.51it/s]


In [None]:
# Cell 2: Validation split
# Release cached GPU memory left over from the previous split before starting.
if 'cuda' in str(device):
    torch.cuda.empty_cache()
print("\n--- Encoding validation split ---")

# Create the output folder up front: to_pickle does not create missing
# directories, and failing after the encode would waste the work.
os.makedirs('../data/medium_set/embeddings', exist_ok=True)

df = pd.read_pickle("../data/medium_set/validation_data.pkl")
val_data_with_embeddings = process_dataset(df)
val_data_with_embeddings.to_pickle('../data/medium_set/embeddings/validation_data_with_embeddings.pkl')

In [None]:
# Cell 3: Test1 split
# Release cached GPU memory left over from the previous split before starting.
if 'cuda' in str(device):
    torch.cuda.empty_cache()
print("\n--- Encoding test1 split ---")

# Create the output folder up front: to_pickle does not create missing
# directories, and failing after the encode would waste the work.
os.makedirs('../data/medium_set/embeddings', exist_ok=True)

# read_pickle restores the DataFrame (index included) exactly as it was saved;
# it has no `index_col` parameter, so passing one raises TypeError.
df = pd.read_pickle("../data/medium_set/test1_data.pkl")
test1_data_with_embeddings = process_dataset(df)
test1_data_with_embeddings.to_pickle('../data/medium_set/embeddings/test1_data_with_embeddings.pkl')

In [9]:
def process_dataset(df, sample_size=None, device_to_check=None):
    """Embed both sequences of every row, tolerating per-row failures.

    NOTE: this redefines the ``process_dataset`` from an earlier cell. This
    version catches embedding errors per row and empties the CUDA cache after
    each pair.

    Args:
        df: DataFrame with 'sequence_A' and 'sequence_B' columns.
        sample_size: Optional number of rows to process. A value strictly
            between 0 and len(df) draws a random sample (random_state=42);
            None, non-positive or too-large values process every row.
        device_to_check: Device whose string form decides whether the CUDA
            cache is emptied. When None, the global ``device`` is used; if
            that is undefined too, "cpu" is assumed and a warning is printed.

    Returns:
        A (possibly sampled) copy of ``df`` with 'embedding_A' and
        'embedding_B' columns. When a row raises, the error is printed and
        whichever embedding was not obtained is stored as None.

    Raises:
        ValueError: if either sequence column is missing.
    """
    global device  # read-only fallback; may be undefined, hence the NameError guard

    # --- Resolve the device string that controls cache clearing ---
    if device_to_check is not None:
        current_device_str = str(device_to_check)
    else:
        try:
            current_device_str = str(device)
        except NameError:
            print("Warning: 'device_to_check' not provided and global 'device' not found. GPU cache clearing will be skipped.")
            current_device_str = "cpu"  # unknown device: treat as CPU
        else:
            if 'cuda' not in current_device_str:
                print("Warning: device_to_check not provided, and global 'device' is not CUDA. CUDA cache will not be cleared.")
    clear_cuda_cache = 'cuda' in current_device_str

    # --- Pick the rows to process (random subset or everything) ---
    n_rows = len(df)
    sample_requested = sample_size is not None
    if sample_requested and sample_size >= n_rows:
        print(f"sample_size ({sample_size}) is >= DataFrame length ({n_rows}). Processing all examples.")
        result_df = df.copy()
    elif sample_requested and sample_size > 0:
        result_df = df.sample(n=sample_size, random_state=42).copy()
        print(f"Processing a sample of {len(result_df)} examples.")
    else:
        # sample_size is None, zero or otherwise unusable
        result_df = df.copy()
        print(f"Processing all {len(result_df)} examples.")

    if not ('sequence_A' in result_df.columns and 'sequence_B' in result_df.columns):
        raise ValueError("DataFrame must contain 'sequence_A' and 'sequence_B' columns.")

    embeddings_A, embeddings_B = [], []

    # --- Encode each pair; failures are logged and leave None placeholders ---
    labelled_pairs = zip(result_df.index, result_df['sequence_A'], result_df['sequence_B'])
    for row_label, seq_a, seq_b in tqdm(labelled_pairs, total=len(result_df), desc="Encoding proteins"):
        emb_a = emb_b = None
        try:
            emb_a = get_protein_embedding(seq_a)
            emb_b = get_protein_embedding(seq_b)
        except Exception as e:
            row_identifier = row_label if row_label is not None else f"at numerical index {row_label}"
            print(f"Error getting embedding for row {row_identifier}: {e}")
        finally:
            # Always record one entry per row so the lists stay aligned with
            # result_df, even when an error interrupted the pair.
            embeddings_A.append(emb_a)
            embeddings_B.append(emb_b)

            # Hand cached-but-unused GPU blocks back after every pair. This
            # does not free tensors that are still referenced.
            if clear_cuda_cache:
                torch.cuda.empty_cache()

    # --- Attach the collected embeddings to the copy ---
    result_df['embedding_A'] = embeddings_A
    result_df['embedding_B'] = embeddings_B

    return result_df


In [None]:
print(f"Current device from previous cells: {device if 'device' in locals() else 'device not defined yet'}")

if 'cuda' in str(device if 'device' in locals() else ""): # Check if the globally defined 'device' is CUDA
    print(f"Operating on CUDA device: {device}. Clearing cache before processing test2 split.")
    # Release cached GPU memory left over from earlier cells.
    torch.cuda.empty_cache()

print("\n--- Encoding test2 split ---")

# Create the output folder before the long encoding run so a missing
# directory cannot make the final save fail after the work is done.
output_dir = '../data/medium_set/embeddings/'
os.makedirs(output_dir, exist_ok=True) # Create the directory if it doesn't exist

df_path = "../data/medium_set/test2_data.pkl"
# read_pickle restores the DataFrame (index included) exactly as it was saved;
# it has no `index_col` parameter, so passing one raises TypeError.
df = pd.read_pickle(df_path)
print(f"Loaded '{df_path}' with {len(df)} rows.")

if torch.cuda.is_available():
    target_gpu_index_for_fraction = 0

    print(f"Attempting to set memory fraction to 0.7 for GPU {target_gpu_index_for_fraction} (Timing is critical and likely too late here).")
    torch.cuda.set_per_process_memory_fraction(0.7, device=target_gpu_index_for_fraction)
    print(f"Call to set_per_process_memory_fraction for GPU {target_gpu_index_for_fraction} completed.")

test2_data_with_embeddings = process_dataset(df)

output_file_path = os.path.join(output_dir, 'test2_data_with_embeddings.pkl')
test2_data_with_embeddings.to_pickle(output_file_path)
print(f"Processed data and saved embeddings to: {output_file_path}")

In [9]:
import pickle

csv_path = '../data/medium_set/embeddings/test2_data_with_embeddings.csv'
pkl_path = '../data/medium_set/embeddings/test2_data_with_embeddings.pkl'

# NOTE(review): the pickle at `pkl_path` is also what the test2 encoding cell
# writes. Rebuilding it from CSV loses the array type of the embedding columns
# (read_csv returns them as text), so an existing pickle is never overwritten.
# Delete the pickle first if a rebuild from CSV is really intended.
if os.path.exists(pkl_path):
    print(f"'{pkl_path}' already exists; not overwriting it with the CSV copy.")
else:
    df = pd.read_csv(csv_path)
    with open(pkl_path, 'wb') as f:
        pickle.dump(df, f)

In [None]:
# Reload the saved test2 embeddings and print a quick structural summary.
df = pd.read_pickle('../data/medium_set/embeddings/test2_data_with_embeddings.pkl')
print(df.head())
print(df.dtypes)
df.info()  # info() prints its own report and returns None, so it is not wrapped in print()
print(df.describe())
# .iloc[0] takes the first row by position, whatever the index labels are.
print(type(df['embedding_A'].iloc[0]))
print(type(df['embedding_B'].iloc[0]))

In [None]:
import ast
import numpy as np

# Convert each stored embedding to a nested Python list so it is written to
# the CSV in full. np.asarray accepts both arrays and lists, so re-running
# this cell after the columns were already converted no longer fails.
for column in ('embedding_A', 'embedding_B'):
    df[column] = df[column].apply(lambda x: np.asarray(x).tolist())

# Write the CSV once, after both columns have been converted.
df.to_csv('fixed_embeddings.csv', index=False)

# If the embeddings were instead read back from a CSV they would be strings
# and would first need parsing, e.g. with ast.literal_eval.

# Stack the per-row embeddings into one 2-D array per column.
# NOTE(review): if a row's embedding is still 2-D after squeeze(), vstack
# contributes several rows for that protein — confirm the printed shapes.
embedding_A = np.vstack([np.array(x).squeeze() for x in df['embedding_A']])
embedding_B = np.vstack([np.array(x).squeeze() for x in df['embedding_B']])

print(embedding_A.shape)
print(embedding_B.shape)


