In [2]:
"""
Generate ESM-C (600 M) sequence embeddings for every protein chain that
already has an *_msa_emb.pt file under data/protein_data_pdb/.

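ROOT below is an absolute path, so the working directory does not matter; edit ROOT to point at your local data/protein_data_pdb directory before running.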

"""

import os
from pathlib import Path
from typing import List

import torch
from tqdm.auto import tqdm

# ──────────────────────────────── config ──────────────────────────────── #
# Absolute path to the directory that holds one sub-directory per protein chain.
ROOT   = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\protein_data_pdb")
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "esmc_600m"        # model name handed to ESMC.from_pretrained() in main()
SAVE_TAG   = "esmc_emb"         # output files are named <chain_id>_<SAVE_TAG>.pt
# ──────────────────────────────────────────────────────────────────────── #

def collect_target_dirs(root: Path) -> List[Path]:
    """List the chain directories under *root* that hold an MSA embedding.

    A direct child of *root* qualifies when it is a directory and at least
    one file inside it matches ``*_msa_emb.pt``.  Plain files sitting in
    *root* are ignored.  The result is sorted.
    """
    return sorted(
        entry
        for entry in root.iterdir()
        if entry.is_dir() and any(entry.glob("*_msa_emb.pt"))
    )

def load_sequence(chain_dir: Path) -> str:
    """Read the amino-acid sequence stored in ``<chain_dir>/sequence.txt``.

    The file may be a bare sequence or FASTA-like text.  Header lines (those
    whose first non-blank character is ``>``) are dropped, and all whitespace
    (line breaks as well as spaces/tabs inside a line) is removed, so the
    result is one contiguous string.

    Args:
        chain_dir: Directory of a single protein chain.

    Returns:
        The concatenated sequence; an empty string if the file holds no
        sequence characters.

    Raises:
        FileNotFoundError: If ``sequence.txt`` does not exist in *chain_dir*.
    """
    seq_file = chain_dir / "sequence.txt"
    if not seq_file.exists():
        raise FileNotFoundError(f"{seq_file} is missing")

    pieces = []
    # Explicit encoding so the result does not depend on the platform default.
    with seq_file.open(encoding="utf-8") as fh:
        for raw_line in fh:
            line = raw_line.strip()
            # Test for a header *after* stripping, so an indented ">" line is
            # still recognised and never leaks into the sequence.
            if not line or line.startswith(">"):
                continue
            # split()/join removes whitespace embedded inside the line too.
            pieces.append("".join(line.split()))
    return "".join(pieces)

@torch.inference_mode()
def embed_sequence(model, seq: str) -> torch.Tensor:
    """Embed one sequence with *model* and return the embeddings on the CPU.

    The tensor is moved off the compute device straight away so that device
    memory is free again before the next call.

    NOTE(review): the run recorded with this script shows a 3-D result of
    shape (1, len(seq) + 2, 1152) for "esmc_600m", i.e. a leading batch axis
    plus two extra positions (presumably special tokens - confirm against the
    ESM documentation).  Slice downstream if a plain per-residue [L, D]
    matrix is required.

    Args:
        model: Object exposing ``encode`` and ``logits`` (main() passes an
            ``ESMC`` instance).
        seq: Amino-acid sequence to embed.

    Returns:
        The ``embeddings`` attribute of the model's logits output, on CPU.
    """
    # Imported lazily so that merely importing this module stays fast.
    from esm.sdk.api import ESMProtein, LogitsConfig

    logits_cfg = LogitsConfig(sequence=True, return_embeddings=True)
    encoded = model.encode(ESMProtein(sequence=seq))
    result = model.logits(encoded, logits_cfg)
    return result.embeddings.cpu()

def main():
    """Embed every chain under ROOT that already has an MSA embedding.

    For each directory returned by ``collect_target_dirs(ROOT)`` the chain's
    sequence is loaded, embedded with the ESM-C model and written to
    ``<chain_id>_<SAVE_TAG>.pt`` inside that directory.  Chains whose output
    file already exists are skipped, which makes the script resumable.

    Robustness notes:
      * The embedding is first written to a ``.tmp`` sibling and then renamed
        into place.  Because existing outputs are skipped, a half-written
        file left behind by an interrupted run would otherwise never be
        regenerated.
      * A chain with a missing or empty ``sequence.txt`` is reported and
        skipped instead of aborting the whole batch; skipped chains are
        listed at the end.
    """
    print(f"⇢ Scanning {ROOT} …")
    chain_dirs = collect_target_dirs(ROOT)
    if not chain_dirs:
        print("No *_msa_emb.pt files found – nothing to do.")
        return
    print(f"Found {len(chain_dirs)} chains with MSA embeddings\n")

    print(f"⇢ Loading {MODEL_NAME} on {DEVICE} …")
    from esm.models.esmc import ESMC  # local import: only needed once work exists
    model = ESMC.from_pretrained(MODEL_NAME).to(DEVICE).eval()

    skipped = []  # chain ids that could not be embedded
    for chain_dir in tqdm(chain_dirs, desc="Embedding chains"):
        chain_id = chain_dir.name                      # e.g. "1FIO-A"
        out_file = chain_dir / f"{chain_id}_{SAVE_TAG}.pt"
        if out_file.exists():
            tqdm.write(f"• {chain_id}: {out_file.name} already exists – skipping")
            continue

        # 1) load sequence
        try:
            seq = load_sequence(chain_dir)
        except FileNotFoundError as err:
            tqdm.write(f"✗ {chain_id}: {err} – skipping")
            skipped.append(chain_id)
            continue
        if not seq:
            tqdm.write(f"✗ {chain_id}: sequence.txt contains no sequence – skipping")
            skipped.append(chain_id)
            continue

        # 2) embed (tensor is returned on the CPU)
        emb = embed_sequence(model, seq)

        # 3) save atomically: write next to the target, then rename into place
        tmp_file = out_file.with_name(out_file.name + ".tmp")
        try:
            torch.save(emb, tmp_file)
            os.replace(tmp_file, out_file)
        finally:
            # Only still present if saving/renaming failed or was interrupted.
            if tmp_file.exists():
                tmp_file.unlink()

        # 4) feedback
        tqdm.write(f"✓ {chain_id}: len={len(seq):4d}, shape={tuple(emb.shape)}  →  {out_file.name}")

    if skipped:
        print(f"\nSkipped {len(skipped)} chain(s): {skipped}")
    print("\nAll done!")


if __name__ == "__main__":
    main()


⇢ Scanning C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\protein_data_pdb …
Found 12 chains with MSA embeddings

⇢ Loading esmc_600m on cpu …


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Embedding chains:   0%|          | 0/12 [00:00<?, ?it/s]

✓ 1FIO-A: len= 196, shape=(1, 198, 1152)  →  1FIO-A_esmc_emb.pt
✓ 1V29-A: len= 220, shape=(1, 222, 1152)  →  1V29-A_esmc_emb.pt
✓ 1WFX-A: len= 186, shape=(1, 188, 1152)  →  1WFX-A_esmc_emb.pt
✓ 2B0T-A: len= 738, shape=(1, 740, 1152)  →  2B0T-A_esmc_emb.pt
✓ 2FVA-A: len=  82, shape=(1, 84, 1152)  →  2FVA-A_esmc_emb.pt
✓ 2JJ9-A: len= 788, shape=(1, 790, 1152)  →  2JJ9-A_esmc_emb.pt
✓ 2L0T-B: len= 163, shape=(1, 165, 1152)  →  2L0T-B_esmc_emb.pt
✓ 2PBZ-A: len= 320, shape=(1, 322, 1152)  →  2PBZ-A_esmc_emb.pt
✓ 2VZ6-A: len= 313, shape=(1, 315, 1152)  →  2VZ6-A_esmc_emb.pt
✓ 3BJV-A: len= 161, shape=(1, 163, 1152)  →  3BJV-A_esmc_emb.pt
✓ 5Y88-T: len= 283, shape=(1, 285, 1152)  →  5Y88-T_esmc_emb.pt
✓ 6FOK-A: len= 723, shape=(1, 725, 1152)  →  6FOK-A_esmc_emb.pt

All done!


In [4]:
import pandas as pd
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Path to the protein data directory (one sub-directory per protein).
# analyze_protein_lengths() reads this module-level name, so set it before calling.
data_dir = Path("C:/Users/rfrjo/Documents/Codebases/PFP_Testing/data/protein_data_pdb")

def analyze_protein_lengths(root=None):
    """Collect the length recorded in every ``<protein>/L.csv`` file.

    Each sub-directory of the data directory is visited in sorted order (so
    the returned lists are deterministic across runs and platforms).  The
    length is taken from the first row of the ``L`` column of ``L.csv``.

    Args:
        root: Directory to scan.  Defaults to the module-level ``data_dir``
            when omitted, which preserves the original zero-argument call.

    Returns:
        A ``(protein_ids, lengths, missing_files)`` tuple: the first two are
        parallel lists for the proteins that were read successfully; the last
        lists the proteins whose ``L.csv`` is absent or unreadable.
    """
    base_dir = data_dir if root is None else Path(root)

    protein_ids = []
    lengths = []
    missing_files = []

    print("Scanning protein directories for L.csv files...")
    print("=" * 60)

    for protein_dir in sorted(base_dir.iterdir()):
        if not protein_dir.is_dir():
            continue

        protein_id = protein_dir.name
        l_csv_path = protein_dir / "L.csv"

        if not l_csv_path.exists():
            print(f"✗ Missing L.csv: {protein_id}")
            missing_files.append(protein_id)
            continue

        try:
            length_df = pd.read_csv(l_csv_path)
            protein_length = int(length_df["L"].iloc[0])
        except (OSError, KeyError, IndexError, ValueError, TypeError, OverflowError) as e:
            # Unreadable file, malformed/empty CSV, missing "L" column or a
            # non-integer value: record the protein and carry on.
            print(f"✗ Error reading {protein_id}/L.csv: {e}")
            missing_files.append(protein_id)
            continue

        protein_ids.append(protein_id)
        lengths.append(protein_length)

    print("=" * 60)
    print(f"Successfully processed: {len(lengths)} proteins")
    print(f"Missing/Error files: {len(missing_files)}")

    if missing_files:
        print(f"Problematic proteins: {missing_files}")

    return protein_ids, lengths, missing_files

# Gather the per-protein lengths, then print a textual report about them
protein_ids, lengths, missing_files = analyze_protein_lengths()

if not lengths:
    print("No valid L.csv files found!")
else:
    # Tabular view of the same data; used for binning and ranking below
    df = pd.DataFrame({'protein_id': protein_ids, 'length': lengths})

    print("\n" + "=" * 60)
    print("SUMMARY STATISTICS")
    print("=" * 60)

    # Central tendency and spread (np.std is the population standard deviation)
    print(f"Total proteins analyzed: {len(lengths)}")
    for stat_label, stat_fn in (
        ("Mean length", np.mean),
        ("Median length", np.median),
        ("Standard deviation", np.std),
    ):
        print(f"{stat_label}: {stat_fn(lengths):.2f} residues")
    for extreme_label, extreme_fn in (("Minimum length", np.min), ("Maximum length", np.max)):
        print(f"{extreme_label}: {extreme_fn(lengths)} residues")

    # Selected percentiles
    print("\nPercentiles:")
    for level in (25, 50, 75, 90, 95, 99):
        print(f"  {level}th percentile: {np.percentile(lengths, level):.0f} residues")

    # Histogram over left-closed bins: [0, 100), [100, 200), ..., [1000, inf)
    print("\nLength Distribution:")
    bins = [0, 100, 200, 300, 400, 500, 750, 1000, float('inf')]
    labels = ['<100', '100-199', '200-299', '300-399', '400-499', '500-749', '750-999', '1000+']

    df['length_bin'] = pd.cut(df['length'], bins=bins, labels=labels, right=False)
    length_counts = df['length_bin'].value_counts().sort_index()

    for bucket, n_in_bucket in length_counts.items():
        share = (n_in_bucket / len(df)) * 100
        print(f"  {bucket} residues: {n_in_bucket} proteins ({share:.1f}%)")

    # A peek at the table, followed by both extremes
    print("\nDetailed Data (first 10 proteins):")
    print(df.head(10).to_string(index=False))

    for heading, ranked in (
        ("\nLongest proteins:", df.nlargest(5, 'length')),
        ("\nShortest proteins:", df.nsmallest(5, 'length')),
    ):
        print(heading)
        print(ranked.to_string(index=False))

Scanning protein directories for L.csv files...
Successfully processed: 29740 proteins
Missing/Error files: 0

SUMMARY STATISTICS
Total proteins analyzed: 29740
Mean length: 276.22 residues
Median length: 243.00 residues
Standard deviation: 165.85 residues
Minimum length: 60 residues
Maximum length: 1000 residues

Percentiles:
  25th percentile: 146 residues
  50th percentile: 243 residues
  75th percentile: 365 residues
  90th percentile: 495 residues
  95th percentile: 592 residues
  99th percentile: 821 residues

Length Distribution:
  <100 residues: 3125 proteins (10.5%)
  100-199 residues: 8629 proteins (29.0%)
  200-299 residues: 6746 proteins (22.7%)
  300-399 residues: 5330 proteins (17.9%)
  400-499 residues: 3034 proteins (10.2%)
  500-749 residues: 2337 proteins (7.9%)
  750-999 residues: 538 proteins (1.8%)
  1000+ residues: 1 proteins (0.0%)

Detailed Data (first 10 proteins):
protein_id  length length_bin
    154L-A     185    100-199
    155C-A     135    100-199
    16P

In [1]:
"""
Comprehensive Data Inspection Script
Run this in your Jupyter notebook to understand your protein data format.
"""

import torch
import pandas as pd
from pathlib import Path
from typing import Any, Dict, Union
import pprint

# Configuration
PROTEIN_ID = "1FIO-A"  # name of the protein sub-directory to inspect
# Directory that holds one sub-directory per protein.
DATA_DIR = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\protein_data_pdb")
# NOTE(review): not referenced by the code visible in this cell;
# inspect_protein_data() rebuilds the same path from DATA_DIR itself.
PROTEIN_DIR = DATA_DIR / PROTEIN_ID

def _print_tensor_summary(tensor, indent: str) -> None:
    """Print shape, dtype, device and value range of *tensor*, one per line."""
    print(f"{indent}Shape: {tuple(tensor.shape)}")
    print(f"{indent}Dtype: {tensor.dtype}")
    print(f"{indent}Device: {tensor.device}")
    if tensor.numel() == 0:
        # min()/max() raise on an empty tensor; report that instead of failing.
        print(f"{indent}Min/Max: n/a (empty tensor)")
        return
    try:
        lowest, highest = tensor.min().item(), tensor.max().item()
    except (RuntimeError, TypeError) as e:
        # e.g. dtypes for which min/max is not defined
        print(f"{indent}Min/Max: n/a ({e})")
        return
    print(f"{indent}Min/Max: {lowest:.4f} / {highest:.4f}")


def inspect_file(file_path: Path, description: str):
    """Load a ``.pt`` file and print a human-readable summary of its content.

    Dictionaries are listed key by key, tensors are summarised (shape, dtype,
    device, value range) and anything else is printed as-is.

    Args:
        file_path: Path of the file to load.
        description: Label shown in the printed banner.

    Returns:
        The loaded object, or ``None`` if the file does not exist or cannot
        be loaded.  Problems while merely *summarising* a tensor no longer
        discard the successfully loaded data.
    """
    print(f"\n{'='*60}")
    print(f"INSPECTING: {description}")
    print(f"File: {file_path}")
    print(f"{'='*60}")

    if not file_path.exists():
        print(f"❌ File does not exist: {file_path}")
        return None

    # SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
    # code.  Only inspect files from a trusted source, or pass
    # weights_only=True if every file holds nothing but tensors/plain
    # containers.
    # Only the load itself is guarded; this is a diagnostic boundary, so any
    # load failure is reported rather than propagated.
    try:
        data = torch.load(file_path, map_location='cpu')
    except Exception as e:
        print(f"❌ Error loading {file_path}: {e}")
        return None

    print(f"📁 Data type: {type(data)}")

    if isinstance(data, dict):
        print(f"🔑 Dictionary with {len(data)} keys:")
        for position, (key, value) in enumerate(data.items(), start=1):
            print(f"  [{position}] Key: '{key}'")
            print(f"      Type: {type(value)}")
            if torch.is_tensor(value):
                _print_tensor_summary(value, "      ")
            else:
                print(f"      Value: {value}")
            print()
    elif torch.is_tensor(data):
        print(f"📊 Direct tensor:")
        _print_tensor_summary(data, "   ")
    else:
        print(f"📦 Other data type: {type(data)}")
        print(f"   Content: {data}")

    return data

def _first_tensor(data, min_ndim: int):
    """Return *data* if it is a tensor, else the first dict value that is a
    tensor with at least *min_ndim* dimensions, else ``None``."""
    if torch.is_tensor(data):
        return data
    if isinstance(data, dict):
        for value in data.values():
            if torch.is_tensor(value) and value.dim() >= min_ndim:
                return value
    return None


def inspect_protein_data(protein_id: str):
    """Inspect all relevant files for a protein.

    Prints the directory listing, summarises the ESM-C and MSA embedding
    files and any ``mf``/``bp``/``cc`` label files via ``inspect_file``,
    shows the sequence and ``L.csv`` content, and finally prints a few
    cross-checks between them.

    Args:
        protein_id: Name of the protein's sub-directory under ``DATA_DIR``;
            also the prefix of its embedding file names.

    Returns:
        ``(esmc_data, msa_data)`` as returned by ``inspect_file``.  Either
        element is ``None`` when the file is missing or unreadable, and both
        are ``None`` when the protein directory does not exist, so callers
        can always unpack the result.
    """
    protein_dir = DATA_DIR / protein_id

    print(f"🧬 PROTEIN DATA INSPECTION")
    print(f"Protein ID: {protein_id}")
    print(f"Directory: {protein_dir}")

    if not protein_dir.exists():
        print(f"❌ Protein directory does not exist: {protein_dir}")
        # Same arity as the normal return so tuple-unpacking callers work.
        return None, None

    # List all files in the protein directory
    print(f"\n📂 Files in protein directory:")
    for file_path in sorted(protein_dir.iterdir()):
        print(f"   {file_path.name}")

    # 1. Inspect ESM-C embeddings
    esmc_file = protein_dir / f"{protein_id}_esmc_emb.pt"
    esmc_data = inspect_file(esmc_file, "ESM-C Embeddings")

    # 2. Inspect MSA embeddings
    msa_file = protein_dir / f"{protein_id}_msa_emb.pt"
    msa_data = inspect_file(msa_file, "MSA Embeddings")

    # 3. Inspect labels (printed only; the loaded labels are not needed here)
    for task in ["mf", "bp", "cc"]:
        label_file = protein_dir / f"{task}_labels.pt"
        if label_file.exists():
            inspect_file(label_file, f"{task.upper()} Labels")
        else:
            print(f"\n⚠️  {task.upper()} labels file not found: {label_file}")

    # 4. Inspect sequence file (None = file absent)
    sequence = None
    seq_file = protein_dir / "sequence.txt"
    if seq_file.exists():
        print(f"\n{'='*60}")
        print(f"INSPECTING: Sequence File")
        print(f"{'='*60}")
        with open(seq_file, 'r') as f:
            sequence = f.read().strip()
        print(f"📝 Sequence length: {len(sequence)}")
        print(f"📝 First 50 chars: {sequence[:50]}...")
        print(f"📝 Last 50 chars: ...{sequence[-50:]}")

    # 5. Inspect length file (None = file absent)
    protein_length = None
    length_file = protein_dir / "L.csv"
    if length_file.exists():
        print(f"\n{'='*60}")
        print(f"INSPECTING: Length File")
        print(f"{'='*60}")
        length_df = pd.read_csv(length_file)
        print(f"📊 L.csv content:")
        print(length_df)
        # Prefer the named "L" column; fall back to the first column so files
        # without that header are still read as before.
        length_col = "L" if "L" in length_df.columns else length_df.columns[0]
        protein_length = int(length_df[length_col].iloc[0])
        print(f"📏 Reported length: {protein_length}")

    # 6. Cross-validation checks
    print(f"\n{'='*60}")
    print(f"CROSS-VALIDATION CHECKS")
    print(f"{'='*60}")

    # Check if sequence length matches L.csv
    if sequence is not None and protein_length is not None:
        seq_len = len(sequence)
        print(f"🔍 Sequence file length: {seq_len}")
        print(f"🔍 L.csv length: {protein_length}")
        print(f"🔍 Match: {'✅' if seq_len == protein_length else '❌'}")

    # Check embedding dimensions vs sequence length
    if esmc_data is not None and sequence is not None:
        main_tensor = _first_tensor(esmc_data, min_ndim=2)
        if main_tensor is not None:
            print(f"🔍 ESM-C tensor shape: {tuple(main_tensor.shape)}")
            print(f"🔍 Expected: sequence length + special tokens")

    if msa_data is not None:
        main_tensor = _first_tensor(msa_data, min_ndim=3)
        if main_tensor is not None:
            print(f"🔍 MSA tensor shape: {tuple(main_tensor.shape)}")
            print(f"🔍 Expected: [N_sequences, sequence_length + special_tokens, embedding_dim]")

    return esmc_data, msa_data

# Run the inspection
print("🚀 Starting comprehensive protein data inspection...")
# Tolerate a bare None result as well as an (esmc, msa) pair, so a missing
# protein directory cannot turn into a TypeError during unpacking.
_inspection = inspect_protein_data(PROTEIN_ID)
esmc_data, msa_data = _inspection if _inspection is not None else (None, None)

print(f"\n{'='*60}")
print(f"SUMMARY RECOMMENDATIONS")
print(f"{'='*60}")
print("Based on the inspection above, here's what we found:")
print("(Copy this output and share it for proper data loading implementation)")

🚀 Starting comprehensive protein data inspection...
🧬 PROTEIN DATA INSPECTION
Protein ID: 1FIO-A
Directory: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\protein_data_pdb\1FIO-A

📂 Files in protein directory:
   1FIO-A_esmc_emb.pt
   1FIO-A_msa_emb.pt
   bp_labels.pt
   ca_dist_matrix.pt
   cb_dist_matrix.pt
   cc_labels.pt
   L.csv
   mf_labels.pt
   msa_raw
   sequence.txt

INSPECTING: ESM-C Embeddings
File: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\protein_data_pdb\1FIO-A\1FIO-A_esmc_emb.pt
📁 Data type: <class 'torch.Tensor'>
📊 Direct tensor:
   Shape: (1, 198, 1152)
   Dtype: torch.float32
   Device: cpu
   Min/Max: -1.1053 / 1.2845

INSPECTING: MSA Embeddings
File: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\protein_data_pdb\1FIO-A\1FIO-A_msa_emb.pt
📁 Data type: <class 'dict'>
🔑 Dictionary with 2 keys:
  [1] Key: 'embeddings'
      Type: <class 'torch.Tensor'>
      Shape: (256, 197, 768)
      Dtype: torch.float32
      Device: cpu
      Min/Max: -13.3794 