In [26]:
from pathlib import Path
import torch
import itertools

root = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB")
split_files = {
    "train": root / "train_pdbch.pt",
    "val"  : root / "val_pdbch.pt",
    "test" : root / "test_pdbch.pt",
}

# 1) Load and unwrap
splits = {}
for name, path in split_files.items():
    obj = torch.load(path, map_location="cpu")
    # if it’s a 1-key dict, pull out the sole value
    if isinstance(obj, dict) and len(obj)==1:
        key = next(iter(obj))
        ids = obj[key]
    else:
        ids = obj
    # ensure it’s a flat list of strings
    assert isinstance(ids, (list, tuple))
    splits[name] = list(ids)
    print(f"{name:5s} → {len(ids):5d} IDs  (sample keys: {list(ids)[:3]})")

# 2) Concatenate and basic stats
all_ids = list(itertools.chain.from_iterable(splits.values()))
print(f"\nTotal chains across splits: {len(all_ids):,}")
print("  (should be train + val + test)")
# This duplicate check is case-sensitive
dupes = len(all_ids) - len(set(all_ids))
# 3) Check for duplicates
# Case-insensitive duplicate check
all_ids_lower = [id.lower() for id in all_ids]
case_insensitive_dupes = len(all_ids_lower) - len(set(all_ids_lower))
print(f"Case-insensitive duplicate IDs: {case_insensitive_dupes}")

# Also check for case variations of the same ID
case_variations = len(set(all_ids)) - len(set(all_ids_lower))
print(f"IDs that differ only by case: {case_variations}")
print(f"Duplicate IDs across splits? {dupes}")

# 4) Peek at first 10 test IDs
print("\nFirst 10 test-chain IDs:")
for cid in splits["test"][:10]:
    print(" ", cid)


train → 29893 IDs  (sample keys: ['154L-A', '155C-A', '16PK-A'])
val   →  3322 IDs  (sample keys: ['192L-A', '1A0A-A', '1A21-A'])
test  →  3414 IDs  (sample keys: ['11AS-A', '18GS-A', '1A0P-A'])

Total chains across splits: 36,629
  (should be train + val + test)
Case-insensitive duplicate IDs: 233
IDs that differ only by case: 233
Duplicate IDs across splits? 0

First 10 test-chain IDs:
  11AS-A
  18GS-A
  1A0P-A
  1A22-A
  1A4E-A
  1A6F-A
  1A6J-A
  1A8Y-A
  1A9C-A
  1A9W-E


In [27]:
from pathlib import Path
import torch

def encode_protein_id(protein_id):
    """Convert to Windows-safe: uppercase preserved, lowercase gets tilde after"""
    return ''.join(c + '~' if c.islower() else c for c in protein_id)

def decode_protein_id(encoded_name):
    """Restore original by removing tildes"""
    return encoded_name.replace('~', '')

# ── CONFIGURATION ──────────────────────────────────────────────────────────────
# Source: HEAL_PDB split files
source_root = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB")
split_files = {
    "train": source_root / "train_pdbch.pt",
    "val": source_root / "val_pdbch.pt", 
    "test": source_root / "test_pdbch.pt",
}

# Destination: New PDBCH folder structure
dest_root = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\PDBCH")
folder_names = {
    "train": "train_pdbch",
    "val": "val_pdbch",
    "test": "test_pdbch"
}

print("🧬 Creating PDBCH Folder Structure with Tilde Encoding")
print("=" * 60)

# ── 1. CREATE ROOT DIRECTORY ──────────────────────────────────────────────────
dest_root.mkdir(parents=True, exist_ok=True)
print(f"📁 Created root directory: {dest_root}")

# ── 2. LOAD SPLITS & CREATE FOLDERS ───────────────────────────────────────────
split_stats = {}

for split_name, source_file in split_files.items():
    print(f"\n🔄 Processing {split_name} split...")
    
    # Load protein IDs from .pt file
    obj = torch.load(source_file, map_location="cpu")
    if isinstance(obj, dict) and len(obj) == 1:
        # Unwrap single-key dict
        protein_ids = next(iter(obj.values()))
    else:
        protein_ids = obj
    
    protein_ids = list(protein_ids)
    print(f"   📋 Loaded {len(protein_ids):,} protein IDs")
    
    # Create split folder
    split_folder = dest_root / folder_names[split_name]
    split_folder.mkdir(exist_ok=True)
    print(f"   📁 Created split folder: {split_folder.name}")
    
    # Create individual protein folders with tilde encoding
    created_count = 0
    case_conflicts = 0
    
    for protein_id in protein_ids:
        # Apply tilde encoding for Windows case-safety
        safe_name = encode_protein_id(protein_id)
        protein_folder = split_folder / safe_name
        
        # Track if this is a case-conflict resolution
        if safe_name != protein_id:
            case_conflicts += 1
        
        # Create empty folder
        if not protein_folder.exists():
            protein_folder.mkdir(exist_ok=True)
            created_count += 1
        
        # Save original ID for reference
        original_id_file = protein_folder / "original_id.txt"
        with open(original_id_file, 'w', encoding='utf-8') as f:
            f.write(protein_id)
    
    print(f"   ✅ Created {created_count:,} unique folders")
    if case_conflicts > 0:
        print(f"   🔀 Applied tilde encoding to {case_conflicts:,} lowercase IDs")
    
    split_stats[split_name] = {
        'protein_ids': len(protein_ids),
        'folders_created': created_count,
        'case_conflicts': case_conflicts
    }

# ── 3. VERIFICATION ────────────────────────────────────────────────────────────
print(f"\n🔍 VERIFICATION RESULTS")
print("=" * 60)

for split_name in split_files.keys():
    split_folder = dest_root / folder_names[split_name]
    
    # Count actual subfolders
    actual_folders = len([d for d in split_folder.iterdir() if d.is_dir()])
    expected_folders = split_stats[split_name]['protein_ids']
    case_conflicts = split_stats[split_name]['case_conflicts']
    
    status = "✅ SUCCESS" if actual_folders == expected_folders else "❌ MISMATCH"
    
    print(f"{split_name.upper():5s}: {actual_folders:,} folders | {expected_folders:,} proteins | {status}")
    if case_conflicts > 0:
        print(f"      (Including {case_conflicts:,} tilde-encoded for case conflicts)")

# ── 4. SAMPLE FOLDER NAMES ────────────────────────────────────────────────────
print(f"\n📋 SAMPLE ENCODED FOLDER NAMES")
print("=" * 60)

for split_name in ["train", "val", "test"]:
    split_folder = dest_root / folder_names[split_name]
    sample_folders = sorted([d.name for d in split_folder.iterdir() if d.is_dir()])[:5]
    
    print(f"{split_name.upper()}:")
    for folder_name in sample_folders:
        original_id = decode_protein_id(folder_name)
        if folder_name != original_id:
            print(f"  {original_id:12s} → {folder_name:15s} (tilde encoded)")
        else:
            print(f"  {original_id:12s} → {folder_name:15s} (unchanged)")

print(f"\n🎉 PDBCH folder structure created successfully!")
print(f"📍 Location: {dest_root}")
print(f"📁 Structure:")
print(f"   {dest_root.name}/")
for folder_name in folder_names.values():
    folder_path = dest_root / folder_name
    folder_count = len([d for d in folder_path.iterdir() if d.is_dir()])
    print(f"   ├── {folder_name}/ ({folder_count:,} protein folders)")

🧬 Creating PDBCH Folder Structure with Tilde Encoding
📁 Created root directory: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\PDBCH

🔄 Processing train split...
   📋 Loaded 29,893 protein IDs
   📁 Created split folder: train_pdbch
   ✅ Created 29,893 unique folders
   🔀 Applied tilde encoding to 675 lowercase IDs

🔄 Processing val split...
   📋 Loaded 3,322 protein IDs
   📁 Created split folder: val_pdbch
   ✅ Created 3,322 unique folders
   🔀 Applied tilde encoding to 80 lowercase IDs

🔄 Processing test split...
   📋 Loaded 3,414 protein IDs
   📁 Created split folder: test_pdbch
   ✅ Created 3,414 unique folders
   🔀 Applied tilde encoding to 137 lowercase IDs

🔍 VERIFICATION RESULTS
TRAIN: 29,893 folders | 29,893 proteins | ✅ SUCCESS
      (Including 675 tilde-encoded for case conflicts)
VAL  : 3,322 folders | 3,322 proteins | ✅ SUCCESS
      (Including 80 tilde-encoded for case conflicts)
TEST : 3,414 folders | 3,414 proteins | ✅ SUCCESS
      (Including 137 tilde-encoded for 

In [24]:
from pathlib import Path
import torch
from collections import Counter
import re

# Load your splits
root = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB")
split_files = {
    "train": root / "train_pdbch.pt",
    "val": root / "val_pdbch.pt", 
    "test": root / "test_pdbch.pt",
}

splits = {}
for name, path in split_files.items():
    obj = torch.load(path, map_location="cpu")
    if isinstance(obj, dict) and len(obj) == 1:
        obj = next(iter(obj.values()))
    splits[name] = list(obj)

print("=== DUPLICATE ANALYSIS ===")
for split_name, ids in splits.items():
    print(f"\n{split_name.upper()}:")
    print(f"  Total IDs: {len(ids)}")
    print(f"  Unique IDs: {len(set(ids))}")
    print(f"  Duplicates: {len(ids) - len(set(ids))}")
    
    # Find the actual duplicates
    counts = Counter(ids)
    duplicates = {k: v for k, v in counts.items() if v > 1}
    if duplicates:
        print(f"  Duplicate IDs found: {len(duplicates)}")
        for dup_id, count in list(duplicates.items())[:5]:  # Show first 5
            print(f"    '{dup_id}': appears {count} times")
        if len(duplicates) > 5:
            print(f"    ... and {len(duplicates) - 5} more")

print("\n=== CASE SENSITIVITY ANALYSIS ===")
for split_name, ids in splits.items():
    print(f"\n{split_name.upper()}:")
    unique_original = len(set(ids))
    unique_lowercase = len(set(id.lower() for id in ids))
    case_variations = unique_original - unique_lowercase
    print(f"  Case variations: {case_variations}")
    
    if case_variations > 0:
        # Find examples of case variations
        by_lower = {}
        for id in set(ids):
            key = id.lower()
            if key not in by_lower:
                by_lower[key] = []
            by_lower[key].append(id)
        
        variations = {k: v for k, v in by_lower.items() if len(v) > 1}
        print(f"  IDs with case variations: {len(variations)}")
        for lower_key, variants in list(variations.items())[:3]:  # Show first 3
            print(f"    {variants}")

print("\n=== INVALID FILENAME CHARACTER ANALYSIS ===")
# Windows invalid characters: < > : " | ? * and control chars (0-31)
invalid_chars = set('<>:"|?*') | set(chr(i) for i in range(32))

for split_name, ids in splits.items():
    print(f"\n{split_name.upper()}:")
    invalid_ids = []
    for id in set(ids):
        if any(char in invalid_chars for char in id):
            invalid_ids.append(id)
        # Check for trailing dots or spaces (also invalid on Windows)
        if id.endswith('.') or id.endswith(' '):
            invalid_ids.append(id)
    
    print(f"  IDs with invalid filename chars: {len(invalid_ids)}")
    if invalid_ids:
        for invalid_id in invalid_ids[:5]:  # Show first 5
            invalid_chars_found = [c for c in invalid_id if c in invalid_chars]
            print(f"    '{invalid_id}' (chars: {invalid_chars_found})")

print("\n=== PREDICTED FOLDER COUNTS ===")
for split_name, ids in splits.items():
    # Simulate Windows filesystem behavior
    unique_folders = set()
    for id in ids:
        # Convert to lowercase (Windows case-insensitive)
        folder_name = id.lower()
        # Remove invalid characters (simplified - just replace with underscore)
        folder_name = re.sub(r'[<>:"|?*\x00-\x1f]', '_', folder_name)
        # Handle trailing dots/spaces
        folder_name = folder_name.rstrip('. ')
        unique_folders.add(folder_name)
    
    print(f"{split_name}: {len(ids)} IDs → ~{len(unique_folders)} predicted folders")

=== DUPLICATE ANALYSIS ===

TRAIN:
  Total IDs: 29893
  Unique IDs: 29893
  Duplicates: 0

VAL:
  Total IDs: 3322
  Unique IDs: 3322
  Duplicates: 0

TEST:
  Total IDs: 3414
  Unique IDs: 3414
  Duplicates: 0

=== CASE SENSITIVITY ANALYSIS ===

TRAIN:
  Case variations: 155
  IDs with case variations: 155
    ['5IT7-pp', '5IT7-PP']
    ['6GCS-j', '6GCS-J']
    ['5AJ4-Ap', '5AJ4-AP']

VAL:
  Case variations: 4
  IDs with case variations: 4
    ['5IT7-jj', '5IT7-JJ']
    ['4V8M-Bm', '4V8M-BM']
    ['3J9M-l', '3J9M-L']

TEST:
  Case variations: 16
  IDs with case variations: 16
    ['6GIQ-D', '6GIQ-d']
    ['4V6W-CO', '4V6W-Co']
    ['4V6W-CB', '4V6W-Cb']

=== INVALID FILENAME CHARACTER ANALYSIS ===

TRAIN:
  IDs with invalid filename chars: 0

VAL:
  IDs with invalid filename chars: 0

TEST:
  IDs with invalid filename chars: 0

=== PREDICTED FOLDER COUNTS ===
train: 29893 IDs → ~29738 predicted folders
val: 3322 IDs → ~3318 predicted folders
test: 3414 IDs → ~3398 predicted folders


In [20]:
from pathlib import Path

# Base path
base_path = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB")

# Check each split
splits = ["train", "test", "val"]

total_subfolders = 0

for split in splits:
    split_folder = base_path / f"protein_{split}_pdb"
    
    if split_folder.exists():
        # Count subfolders only
        subfolders = [p for p in split_folder.iterdir() if p.is_dir()]
        print(f"{split:5s}: {len(subfolders):5d} subfolders")
        total_subfolders += len(subfolders)
    else:
        print(f"{split:5s}: Folder not found")

print(f"\nTotal: {total_subfolders:5d} subfolders")

train: 29738 subfolders
test :  3398 subfolders
val  :  3318 subfolders

Total: 36454 subfolders


In [11]:
"""
Build per-protein GO-term folders for the PDBch split
----------------------------------------------------

Root/
├── train_pdbch.pt          ← HEAL ID list (dict or list)
├── val_pdbch.pt
├── test_pdbch.pt
├── pdbch_go.tsv            ← your TSV with three ontology columns
└── (generated)
    ├── protein_train_pdb/
    │   └── 154L-A/
    │       ├── mf_go.txt
    │       ├── bp_go.txt
    │       └── cc_go.txt
    ├── protein_val_pdb/
    └── protein_test_pdb/
"""

from pathlib import Path
import torch, csv, os

# ── CONFIG ───────────────────────────────────────────────────────────────────
ROOT = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB")  # adjust
GO_TSV = ROOT / "nrPDB-GO_2019.06.18_annot.tsv"                                       # adjust
SPLIT_PTS = {
    "train": ROOT / "train_pdbch.pt", 
    "val"  : ROOT / "val_pdbch.pt",
    "test" : ROOT / "test_pdbch.pt",
}
OUT_TPL = ROOT / "protein_{split}_pdb"                               # parent dirs
TXT_NAMES = {"mf": "mf_go.txt", "bp": "bp_go.txt", "cc": "cc_go.txt"}

# ── 1. LOAD SPLIT ID LISTS ───────────────────────────────────────────────────
split_ids = {}
for split, fpath in SPLIT_PTS.items():
    obj = torch.load(fpath, map_location="cpu")
    if isinstance(obj, dict) and len(obj) == 1:      # unwrap 1-key dict
        obj = next(iter(obj.values()))
    split_ids[split] = list(obj)
    print(f"{split:5s} : {len(split_ids[split]):6d} IDs")

# ── 2. BUILD LOOK-UP FROM CHAIN → {mf:[], bp:[], cc:[]} ──────────────────────
lookup = {cid: {"mf": [], "bp": [], "cc": []} for cid in
          split_ids["train"] + split_ids["val"] + split_ids["test"]}

with GO_TSV.open(newline='', encoding="utf-8") as fh:
    tsv = csv.reader(fh, delimiter='\t')
    for row in tsv:
        if not row or row[0].startswith('#'):
            continue                         # skip comments / headers
        chain = row[0].strip()
        if chain not in lookup:
            continue                         # GO line not relevant to PDBch
        mf = [g.strip() for g in row[1].split(',') if g.strip()]
        bp = [g.strip() for g in row[2].split(',') if g.strip()] if len(row) > 2 else []
        cc = [g.strip() for g in row[3].split(',') if g.strip()] if len(row) > 3 else []
        lookup[chain] = {"mf": mf, "bp": bp, "cc": cc}

print(f"Loaded GO terms for {sum(bool(v['mf'] or v['bp'] or v['cc'])
                                for v in lookup.values()):,} / {len(lookup):,} chains")

# ── 3. CREATE FOLDERS & WRITE FILES ──────────────────────────────────────────
for split, cids in split_ids.items():
    split_dir = OUT_TPL.with_name(OUT_TPL.name.format(split=split))
    split_dir.mkdir(parents=True, exist_ok=True)
    for cid in cids:
        cid_dir = split_dir / cid
        cid_dir.mkdir(exist_ok=True)
        terms = lookup[cid]
        for ont, fname in TXT_NAMES.items():
            with open(cid_dir / fname, 'w', encoding='utf-8') as fh:
                fh.write('\n'.join(terms[ont]))
print("✓ Folder hierarchy populated")


train :  29893 IDs
val   :   3322 IDs
test  :   3414 IDs
Loaded GO terms for 36,629 / 36,629 chains
✓ Folder hierarchy populated


In [28]:
from pathlib import Path
import csv

def decode_protein_id(encoded_name):
    """Restore original protein ID by removing tildes"""
    return encoded_name.replace('~', '')

def encode_protein_id(protein_id):
    """Convert to Windows-safe: uppercase preserved, lowercase gets tilde after"""
    return ''.join(c + '~' if c.islower() else c for c in protein_id)

# ── CONFIGURATION ──────────────────────────────────────────────────────────────
# Source: TSV file with GO annotations
GO_TSV = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB\nrPDB-GO_2019.06.18_annot.tsv")

# Destination: New PDBCH folder structure
PDBCH_ROOT = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\PDBCH")
SPLIT_FOLDERS = {
    "train": PDBCH_ROOT / "train_pdbch",
    "val": PDBCH_ROOT / "val_pdbch", 
    "test": PDBCH_ROOT / "test_pdbch"
}

# GO term file names
GO_FILES = {"mf": "mf_go.txt", "bp": "bp_go.txt", "cc": "cc_go.txt"}

print("🧬 Populating PDBCH Folders with GO Terms (Case-Sensitive)")
print("=" * 70)

# ── 1. BUILD CASE-SENSITIVE LOOKUP FROM TSV ───────────────────────────────────
print("📖 Loading GO annotations from TSV...")

# Initialize lookup dictionary for all proteins we'll encounter
go_lookup = {}

with GO_TSV.open(newline='', encoding="utf-8") as fh:
    tsv_reader = csv.reader(fh, delimiter='\t')
    
    for row_num, row in enumerate(tsv_reader, 1):
        if not row or row[0].startswith('#'):
            continue  # Skip empty rows and comments
        
        # Extract protein ID (case-sensitive)
        protein_id = row[0].strip()
        
        # Extract GO terms from columns
        mf_terms = [term.strip() for term in row[1].split(',') if term.strip()] if len(row) > 1 else []
        bp_terms = [term.strip() for term in row[2].split(',') if term.strip()] if len(row) > 2 else []
        cc_terms = [term.strip() for term in row[3].split(',') if term.strip()] if len(row) > 3 else []
        
        # Store in case-sensitive lookup
        go_lookup[protein_id] = {
            "mf": mf_terms,
            "bp": bp_terms, 
            "cc": cc_terms
        }

print(f"✅ Loaded GO annotations for {len(go_lookup):,} proteins from TSV")

# ── 2. PROCESS EACH SPLIT FOLDER ──────────────────────────────────────────────
total_stats = {"processed": 0, "found_go": 0, "missing_go": 0, "files_written": 0}

for split_name, split_folder in SPLIT_FOLDERS.items():
    print(f"\n🔄 Processing {split_name} split...")
    
    if not split_folder.exists():
        print(f"   ❌ Split folder not found: {split_folder}")
        continue
    
    # Get all protein subfolders (tilde-encoded names)
    protein_folders = [d for d in split_folder.iterdir() if d.is_dir()]
    print(f"   📁 Found {len(protein_folders):,} protein folders")
    
    split_stats = {"processed": 0, "found_go": 0, "missing_go": 0, "files_written": 0}
    
    for protein_folder in protein_folders:
        # STEP 1: Decode tilde-encoded folder name to get original protein ID
        encoded_folder_name = protein_folder.name
        original_protein_id = decode_protein_id(encoded_folder_name)
        
        split_stats["processed"] += 1
        
        # STEP 2: Case-sensitive lookup in TSV data
        if original_protein_id in go_lookup:
            go_terms = go_lookup[original_protein_id]
            split_stats["found_go"] += 1
            
            # STEP 3: Write GO term files back to tilde-encoded folder
            for ontology, filename in GO_FILES.items():
                file_path = protein_folder / filename
                terms = go_terms[ontology]
                
                with open(file_path, 'w', encoding='utf-8') as f:
                    f.write('\n'.join(terms))
                
                split_stats["files_written"] += 1
        else:
            # No GO terms found for this protein
            split_stats["missing_go"] += 1
            
            # Create empty GO files
            for ontology, filename in GO_FILES.items():
                file_path = protein_folder / filename
                with open(file_path, 'w', encoding='utf-8') as f:
                    f.write('')  # Empty file
                
                split_stats["files_written"] += 1
    
    # Update total stats
    for key in total_stats:
        total_stats[key] += split_stats[key]
    
    print(f"   ✅ Processed: {split_stats['processed']:,} proteins")
    print(f"   📊 Found GO terms: {split_stats['found_go']:,}")
    print(f"   ❓ Missing GO terms: {split_stats['missing_go']:,}")
    print(f"   📝 Files written: {split_stats['files_written']:,}")

# ── 3. FINAL VERIFICATION ─────────────────────────────────────────────────────
print(f"\n🎉 FINAL RESULTS")
print("=" * 70)
print(f"📊 Total proteins processed: {total_stats['processed']:,}")
print(f"✅ Proteins with GO terms: {total_stats['found_go']:,}")
print(f"❓ Proteins without GO terms: {total_stats['missing_go']:,}")
print(f"📝 Total GO files written: {total_stats['files_written']:,}")

coverage_pct = (total_stats['found_go'] / total_stats['processed'] * 100) if total_stats['processed'] > 0 else 0
print(f"📈 GO term coverage: {coverage_pct:.1f}%")

# ── 4. SAMPLE VERIFICATION ────────────────────────────────────────────────────
print(f"\n🔍 SAMPLE VERIFICATION")
print("=" * 70)

# Check a few folders to verify the structure
for split_name, split_folder in list(SPLIT_FOLDERS.items())[:1]:  # Just check train
    sample_folders = sorted([d for d in split_folder.iterdir() if d.is_dir()])[:3]
    
    print(f"Sample from {split_name}:")
    for folder in sample_folders:
        encoded_name = folder.name
        original_id = decode_protein_id(encoded_name)
        
        # Check if GO files exist
        go_files_exist = all((folder / filename).exists() for filename in GO_FILES.values())
        
        # Check if any GO files have content
        has_content = False
        for filename in GO_FILES.values():
            file_path = folder / filename
            if file_path.exists() and file_path.stat().st_size > 0:
                has_content = True
                break
        
        status = "✅ WITH GO" if has_content else "⭕ EMPTY GO"
        print(f"  {original_id:12s} → {encoded_name:15s} | {status}")

print(f"\n🎯 CRITICAL CHANGES MADE:")
print("1. ✅ Decoded tilde-encoded folder names to get original protein IDs")
print("2. ✅ Used case-sensitive lookup in TSV file (no case conversion)")
print("3. ✅ Wrote GO term files back to tilde-encoded folders")
print("4. ✅ Created empty files for proteins without GO annotations")
print("5. ✅ Preserved exact case sensitivity throughout the process")

🧬 Populating PDBCH Folders with GO Terms (Case-Sensitive)
📖 Loading GO annotations from TSV...
✅ Loaded GO annotations for 36,647 proteins from TSV

🔄 Processing train split...
   📁 Found 29,893 protein folders
   ✅ Processed: 29,893 proteins
   📊 Found GO terms: 29,893
   ❓ Missing GO terms: 0
   📝 Files written: 89,679

🔄 Processing val split...
   📁 Found 3,322 protein folders
   ✅ Processed: 3,322 proteins
   📊 Found GO terms: 3,322
   ❓ Missing GO terms: 0
   📝 Files written: 9,966

🔄 Processing test split...
   📁 Found 3,414 protein folders
   ✅ Processed: 3,414 proteins
   📊 Found GO terms: 3,414
   ❓ Missing GO terms: 0
   📝 Files written: 10,242

🎉 FINAL RESULTS
📊 Total proteins processed: 36,629
✅ Proteins with GO terms: 36,629
❓ Proteins without GO terms: 0
📝 Total GO files written: 109,887
📈 GO term coverage: 100.0%

🔍 SAMPLE VERIFICATION
Sample from train:
  154L-A       → 154L-A          | ✅ WITH GO
  155C-A       → 155C-A          | ✅ WITH GO
  16PK-A       → 16PK-A     

In [29]:
"""
Populate sequence.txt and L.csv for every PDBCH protein folder.
Handles tilde encoding/decoding and case-sensitive FASTA lookup.

Directory layout produced:
PDBCH/
    train_pdbch/<ENCODED_CID>/{sequence.txt, L.csv, *_go.txt}
    val_pdbch/<ENCODED_CID>/...
    test_pdbch/<ENCODED_CID>/...
"""

from pathlib import Path
import csv, os, concurrent.futures as cf
from tqdm.auto import tqdm

def decode_protein_id(encoded_name):
    """Restore original protein ID by removing tildes"""
    return encoded_name.replace('~', '')

def encode_protein_id(protein_id):
    """Convert to Windows-safe: uppercase preserved, lowercase gets tilde after"""
    return ''.join(c + '~' if c.islower() else c for c in protein_id)

# ── CONFIGURATION ──────────────────────────────────────────────────────────────
ROOT = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data")
FASTA_PATH = ROOT / "nrPDB-GO_2019.06.18_sequences.fasta"

# PDBCH folder structure (tilde-encoded folders)
PDBCH_ROOT = ROOT / "PDBCH"
SPLIT_FOLDERS = {
    "train": PDBCH_ROOT / "train_pdbch",
    "val": PDBCH_ROOT / "val_pdbch", 
    "test": PDBCH_ROOT / "test_pdbch"
}

MAX_WORKERS = os.cpu_count() or 8

print("🧬 Populating PDBCH Sequences with Case-Sensitive FASTA Lookup")
print("=" * 80)

# ── 1. COLLECT ALL TILDE-ENCODED FOLDERS & DECODE TO GET TARGET IDs ───────────
print("📁 Scanning existing PDBCH folders...")

tasks = []  # Will store (encoded_folder_path, original_protein_id, split_name)
all_target_ids = set()  # Original protein IDs we need sequences for

for split_name, split_folder in SPLIT_FOLDERS.items():
    if not split_folder.exists():
        print(f"   ❌ Split folder not found: {split_folder}")
        continue
    
    # Get all tilde-encoded protein folders
    protein_folders = [d for d in split_folder.iterdir() if d.is_dir()]
    print(f"   📂 {split_name}: {len(protein_folders):,} folders found")
    
    for protein_folder in protein_folders:
        # CRITICAL: Decode tilde-encoded folder name to get original protein ID
        encoded_folder_name = protein_folder.name
        original_protein_id = decode_protein_id(encoded_folder_name)
        
        # Store task info
        tasks.append((protein_folder, original_protein_id, split_name))
        all_target_ids.add(original_protein_id)

print(f"🎯 Total folders to process: {len(tasks):,}")
print(f"🔍 Unique protein IDs needed: {len(all_target_ids):,}")

# ── 2. PARSE FASTA FILE (CASE-SENSITIVE LOOKUP) ───────────────────────────────
print(f"\n📖 Parsing FASTA file: {FASTA_PATH.name}")

sequences = {}  # Case-sensitive lookup: original_protein_id -> sequence

with FASTA_PATH.open() as fh:
    current_id, sequence_lines = None, []
    
    for line in tqdm(fh, desc="Reading FASTA", unit="lines"):
        if line.startswith('>'):
            # Save previous sequence if it's one we need
            if current_id and current_id in all_target_ids:
                sequences[current_id] = ''.join(sequence_lines).upper()
            
            # Start new sequence
            current_id = line[1:].split()[0].strip()  # Extract ID (case-sensitive)
            sequence_lines = []
        else:
            sequence_lines.append(line.strip())
    
    # Don't forget the last sequence
    if current_id and current_id in all_target_ids:
        sequences[current_id] = ''.join(sequence_lines).upper()

print(f"✅ Sequences found: {len(sequences):,} / {len(all_target_ids):,}")

# Identify missing sequences
missing_sequences = all_target_ids - sequences.keys()
if missing_sequences:
    print(f"❌ Missing sequences: {len(missing_sequences):,}")
    # Show a few examples
    for missing_id in sorted(missing_sequences)[:5]:
        print(f"   • {missing_id}")
    if len(missing_sequences) > 5:
        print(f"   • ... and {len(missing_sequences) - 5} more")

# ── 3. PARALLEL WRITER FUNCTION ───────────────────────────────────────────────
def write_sequence_files(task):
    """Write sequence.txt and L.csv to tilde-encoded folder"""
    protein_folder, original_protein_id, split_name = task
    
    try:
        # Check if we have the sequence for this protein
        if original_protein_id not in sequences:
            return f"MISSING: {original_protein_id} (folder: {protein_folder.name})"
        
        sequence = sequences[original_protein_id]
        
        # Write sequence.txt to the TILDE-ENCODED folder
        sequence_file = protein_folder / "sequence.txt"
        sequence_file.write_text(sequence + "\n", encoding="utf-8")
        
        # Write L.csv (sequence length) to the TILDE-ENCODED folder
        length_file = protein_folder / "L.csv"
        with length_file.open('w', newline='', encoding="utf-8") as f:
            csv.writer(f).writerow([len(sequence)])
        
        return None  # Success
        
    except Exception as e:
        return f"ERROR: {original_protein_id} -> {str(e)}"

# ── 4. FILTER TASKS (ONLY PROCESS THOSE WITH AVAILABLE SEQUENCES) ─────────────
valid_tasks = []
skipped_tasks = []

for task in tasks:
    protein_folder, original_protein_id, split_name = task
    if original_protein_id in sequences:
        valid_tasks.append(task)
    else:
        skipped_tasks.append(task)

print(f"\n🚀 Processing {len(valid_tasks):,} proteins with sequences...")
if skipped_tasks:
    print(f"⏭️  Skipping {len(skipped_tasks):,} proteins without sequences")

# ── 5. PARALLEL EXECUTION WITH PROGRESS BAR ───────────────────────────────────
errors = []

with cf.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # Use tqdm to track progress
    results = list(tqdm(
        executor.map(write_sequence_files, valid_tasks, chunksize=256),
        total=len(valid_tasks),
        desc="Writing sequence files",
        unit="protein"
    ))
    
    # Collect errors
    errors = [result for result in results if result is not None]

# ── 6. SUMMARY STATISTICS ─────────────────────────────────────────────────────
print(f"\n🎉 FINAL RESULTS")
print("=" * 80)

successful_writes = len(valid_tasks) - len(errors)
print(f"📊 Total proteins targeted: {len(tasks):,}")
print(f"✅ Successful writes: {successful_writes:,}")
print(f"❌ Failed writes: {len(errors):,}")
print(f"⏭️  Skipped (no sequence): {len(skipped_tasks):,}")

# Break down by split
split_stats = {"train": 0, "val": 0, "test": 0}
for task in valid_tasks:
    split_name = task[2]
    split_stats[split_name] += 1

print(f"\n📁 Files written per split:")
for split_name, count in split_stats.items():
    print(f"   {split_name}: {count:,} proteins")

# Show errors if any
if errors:
    print(f"\n❌ Errors encountered:")
    for error_msg in errors[:10]:  # Show first 10 errors
        print(f"   • {error_msg}")
    if len(errors) > 10:
        print(f"   • ... and {len(errors) - 10} more errors")

# ── 7. VERIFICATION SAMPLE ────────────────────────────────────────────────────
print(f"\n🔍 VERIFICATION SAMPLE")
print("=" * 80)

# Check a few folders to verify files were created correctly
for split_name, split_folder in list(SPLIT_FOLDERS.items())[:1]:  # Just check train
    sample_folders = sorted([d for d in split_folder.iterdir() if d.is_dir()])[:3]
    
    print(f"Sample from {split_name}:")
    for folder in sample_folders:
        encoded_name = folder.name
        original_id = decode_protein_id(encoded_name)
        
        # Check if files exist
        sequence_file = folder / "sequence.txt"
        length_file = folder / "L.csv"
        
        seq_exists = sequence_file.exists()
        len_exists = length_file.exists()
        
        if seq_exists and len_exists:
            # Read sequence length for verification
            try:
                with length_file.open('r') as f:
                    length = int(next(csv.reader(f))[0])
                status = f"✅ SEQ_LEN={length}"
            except:
                status = "⚠️  FILES_EXIST_BUT_UNREADABLE"
        else:
            status = "❌ MISSING_FILES"
        
        print(f"  {original_id:12s} → {encoded_name:15s} | {status}")

print(f"\n🎯 CRITICAL WORKFLOW EXECUTED:")
print("1. ✅ Scanned tilde-encoded folders in PDBCH structure")
print("2. ✅ Decoded folder names to get original protein IDs")
print("3. ✅ Performed case-sensitive FASTA sequence lookup")
print("4. ✅ Wrote sequence files back to tilde-encoded folders")
print("5. ✅ Preserved exact case sensitivity throughout process")

🧬 Populating PDBCH Sequences with Case-Sensitive FASTA Lookup
📁 Scanning existing PDBCH folders...
   📂 train: 29,893 folders found
   📂 val: 3,322 folders found
   📂 test: 3,414 folders found
🎯 Total folders to process: 36,629
🔍 Unique protein IDs needed: 36,629

📖 Parsing FASTA file: nrPDB-GO_2019.06.18_sequences.fasta


Reading FASTA: 0lines [00:00, ?lines/s]

✅ Sequences found: 36,629 / 36,629

🚀 Processing 36,629 proteins with sequences...


Writing sequence files:   0%|          | 0/36629 [00:00<?, ?protein/s]


🎉 FINAL RESULTS
📊 Total proteins targeted: 36,629
✅ Successful writes: 36,629
❌ Failed writes: 0
⏭️  Skipped (no sequence): 0

📁 Files written per split:
   train: 29,893 proteins
   val: 3,322 proteins
   test: 3,414 proteins

🔍 VERIFICATION SAMPLE
Sample from train:
  154L-A       → 154L-A          | ✅ SEQ_LEN=185
  155C-A       → 155C-A          | ✅ SEQ_LEN=135
  16PK-A       → 16PK-A          | ✅ SEQ_LEN=415

🎯 CRITICAL WORKFLOW EXECUTED:
1. ✅ Scanned tilde-encoded folders in PDBCH structure
2. ✅ Decoded folder names to get original protein IDs
3. ✅ Performed case-sensitive FASTA sequence lookup
4. ✅ Wrote sequence files back to tilde-encoded folders
5. ✅ Preserved exact case sensitivity throughout process


In [30]:
"""
Transfer final_filtered_256_stripped.a3m files from old HEAL_PDB structure 
to new PDBCH structure with proper tilde encoding and case-sensitive matching.

Critical workflow:
1. Scan tilde-encoded PDBCH folders
2. Decode folder names to get original protein IDs  
3. Find EXACT case-sensitive matches in old HEAL_PDB structure
4. Copy .a3m files to tilde-encoded destination folders
"""

from pathlib import Path
import shutil
from tqdm.auto import tqdm

def decode_protein_id(encoded_name):
    """Restore original protein ID by removing tildes"""
    return encoded_name.replace('~', '')

def encode_protein_id(protein_id):
    """Convert to Windows-safe: uppercase preserved, lowercase gets tilde after"""
    return ''.join(c + '~' if c.islower() else c for c in protein_id)

# ── CONFIGURATION ──────────────────────────────────────────────────────────────
HEAL_PDB_ROOT = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB")
PDBCH_ROOT = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\PDBCH")

# Source: Old HEAL_PDB structure (non-tilde encoded)
OLD_FOLDERS = {
    "train": HEAL_PDB_ROOT / "protein_train_pdb",
    "val": HEAL_PDB_ROOT / "protein_val_pdb", 
    "test": HEAL_PDB_ROOT / "protein_test_pdb"
}

# Destination: New PDBCH structure (tilde-encoded)
NEW_FOLDERS = {
    "train": PDBCH_ROOT / "train_pdbch",
    "val": PDBCH_ROOT / "val_pdbch",
    "test": PDBCH_ROOT / "test_pdbch"
}

TARGET_FILE = "final_filtered_256_stripped.a3m"

print("🔄 Transferring A3M Files: HEAL_PDB → PDBCH (Case-Sensitive)")
print("=" * 80)

# ── 1. BUILD CASE-SENSITIVE LOOKUP OF AVAILABLE A3M FILES ─────────────────────
print("🔍 Scanning old HEAL_PDB structure for available A3M files...")

old_files_lookup = {}  # Structure: {split: {original_protein_id: file_path}}

for split_name, old_folder in OLD_FOLDERS.items():
    if not old_folder.exists():
        print(f"   ❌ Old folder not found: {old_folder}")
        continue
    
    old_files_lookup[split_name] = {}
    
    # Scan all protein folders in this split
    protein_folders = [d for d in old_folder.iterdir() if d.is_dir()]
    found_files = 0
    
    for protein_folder in protein_folders:
        original_protein_id = protein_folder.name  # Case-sensitive protein ID
        target_file_path = protein_folder / TARGET_FILE
        
        if target_file_path.exists():
            old_files_lookup[split_name][original_protein_id] = target_file_path
            found_files += 1
    
    print(f"   📂 {split_name}: {found_files:,} proteins with {TARGET_FILE}")

total_available = sum(len(split_files) for split_files in old_files_lookup.values())
print(f"✅ Total A3M files available: {total_available:,}")

# ── 2. PROCESS EACH PDBCH SPLIT ───────────────────────────────────────────────
overall_stats = {"copied": 0, "missing": 0, "errors": 0}

for split_name, new_folder in NEW_FOLDERS.items():
    if not new_folder.exists():
        print(f"\n❌ New folder not found: {new_folder}")
        continue
    
    print(f"\n🔄 Processing {split_name} split...")
    
    # Get corresponding old folder lookup
    if split_name not in old_files_lookup:
        print(f"   ❌ No old files found for {split_name} split")
        continue
    
    old_split_files = old_files_lookup[split_name]
    
    # Get all tilde-encoded protein folders in new structure
    protein_folders = [d for d in new_folder.iterdir() if d.is_dir()]
    print(f"   📁 Found {len(protein_folders):,} tilde-encoded folders")
    
    split_stats = {"copied": 0, "missing": 0, "errors": 0}
    examples = {"copied": [], "missing": []}
    
    # Process each tilde-encoded folder
    for protein_folder in tqdm(protein_folders, desc=f"Transferring {split_name}", unit="protein"):
        # STEP 1: Decode tilde-encoded folder name to get original protein ID
        encoded_folder_name = protein_folder.name
        original_protein_id = decode_protein_id(encoded_folder_name)
        
        # STEP 2: Look for EXACT case-sensitive match in old structure
        if original_protein_id in old_split_files:
            # Found exact match!
            source_file_path = old_split_files[original_protein_id]
            destination_file_path = protein_folder / TARGET_FILE
            
            try:
                # STEP 3: Copy file to tilde-encoded destination folder
                shutil.copy2(source_file_path, destination_file_path)
                split_stats["copied"] += 1
                
                # Store example for verification
                if len(examples["copied"]) < 3:
                    examples["copied"].append({
                        "original_id": original_protein_id,
                        "encoded_folder": encoded_folder_name,
                        "source": source_file_path,
                        "dest": destination_file_path
                    })
                
            except Exception as e:
                split_stats["errors"] += 1
                print(f"   ❌ Error copying {original_protein_id}: {e}")
        else:
            # No exact match found
            split_stats["missing"] += 1
            
            # Store example for debugging
            if len(examples["missing"]) < 3:
                examples["missing"].append({
                    "original_id": original_protein_id,
                    "encoded_folder": encoded_folder_name
                })
    
    # Update overall stats
    for key in overall_stats:
        overall_stats[key] += split_stats[key]
    
    # Report split results
    print(f"   ✅ Copied: {split_stats['copied']:,}")
    print(f"   ❓ Missing: {split_stats['missing']:,}")
    print(f"   ❌ Errors: {split_stats['errors']:,}")
    
    # Show examples
    if examples["copied"]:
        print(f"   📋 Copy examples:")
        for ex in examples["copied"]:
            print(f"      {ex['original_id']:12s} → {ex['encoded_folder']:15s} ✅")
    
    if examples["missing"]:
        print(f"   📋 Missing examples:")
        for ex in examples["missing"]:
            print(f"      {ex['original_id']:12s} → {ex['encoded_folder']:15s} ❓")

# ── 3. FINAL SUMMARY ───────────────────────────────────────────────────────────
print(f"\n🎉 TRANSFER COMPLETE")
print("=" * 80)
print(f"✅ Files successfully copied: {overall_stats['copied']:,}")
print(f"❓ Files not found in source: {overall_stats['missing']:,}")
print(f"❌ Copy errors: {overall_stats['errors']:,}")

success_rate = (overall_stats['copied'] / (overall_stats['copied'] + overall_stats['missing']) * 100) if (overall_stats['copied'] + overall_stats['missing']) > 0 else 0
print(f"📈 Success rate: {success_rate:.1f}%")

# ── 4. VERIFICATION SAMPLE ────────────────────────────────────────────────────
print(f"\n🔍 VERIFICATION SAMPLE")
print("=" * 80)

# Check a few transferred files
for split_name, new_folder in list(NEW_FOLDERS.items())[:1]:  # Just check train
    if not new_folder.exists():
        continue
    
    sample_folders = sorted([d for d in new_folder.iterdir() if d.is_dir()])[:3]
    
    print(f"Sample verification from {split_name}:")
    for folder in sample_folders:
        encoded_name = folder.name
        original_id = decode_protein_id(encoded_name)
        a3m_file = folder / TARGET_FILE
        
        if a3m_file.exists():
            file_size = a3m_file.stat().st_size
            status = f"✅ A3M_SIZE={file_size:,}B"
        else:
            status = "❌ A3M_MISSING"
        
        print(f"  {original_id:12s} → {encoded_name:15s} | {status}")

print(f"\n🎯 CRITICAL WORKFLOW EXECUTED:")
print("1. ✅ Scanned tilde-encoded PDBCH folders")
print("2. ✅ Decoded folder names to get original protein IDs")
print("3. ✅ Performed case-sensitive exact matching with old structure")
print("4. ✅ Copied A3M files to tilde-encoded destination folders")
print("5. ✅ Preserved split correspondence (train→train, val→val, test→test)")
print("6. ✅ Maintained exact case sensitivity throughout process")

🔄 Transferring A3M Files: HEAL_PDB → PDBCH (Case-Sensitive)
🔍 Scanning old HEAL_PDB structure for available A3M files...
   📂 train: 29,733 proteins with final_filtered_256_stripped.a3m
   📂 val: 3,316 proteins with final_filtered_256_stripped.a3m
   📂 test: 3,398 proteins with final_filtered_256_stripped.a3m
✅ Total A3M files available: 36,447

🔄 Processing train split...
   📁 Found 29,893 tilde-encoded folders


Transferring train:   0%|          | 0/29893 [00:00<?, ?protein/s]

   ✅ Copied: 29,733
   ❓ Missing: 160
   ❌ Errors: 0
   📋 Copy examples:
      154L-A       → 154L-A          ✅
      155C-A       → 155C-A          ✅
      16PK-A       → 16PK-A          ✅
   📋 Missing examples:
      1W8X-M       → 1W8X-M          ❓
      2WWX-B       → 2WWX-B          ❓
      3J79-f       → 3J79-f~         ❓

🔄 Processing val split...
   📁 Found 3,322 tilde-encoded folders


Transferring val:   0%|          | 0/3322 [00:00<?, ?protein/s]

   ✅ Copied: 3,316
   ❓ Missing: 6
   ❌ Errors: 0
   📋 Copy examples:
      192L-A       → 192L-A          ✅
      1A0A-A       → 1A0A-A          ✅
      1A21-A       → 1A21-A          ✅
   📋 Missing examples:
      1KVE-A       → 1KVE-A          ❓
      3J9M-l       → 3J9M-l~         ❓
      4BTP-A       → 4BTP-A          ❓

🔄 Processing test split...
   📁 Found 3,414 tilde-encoded folders


Transferring test:   0%|          | 0/3414 [00:00<?, ?protein/s]

   ✅ Copied: 3,398
   ❓ Missing: 16
   ❌ Errors: 0
   📋 Copy examples:
      11AS-A       → 11AS-A          ✅
      18GS-A       → 18GS-A          ✅
      1A0P-A       → 1A0P-A          ✅
   📋 Missing examples:
      3H4P-a       → 3H4P-a~         ❓
      4V6W-Ab      → 4V6W-Ab~        ❓
      4V6W-Ca      → 4V6W-Ca~        ❓

🎉 TRANSFER COMPLETE
✅ Files successfully copied: 36,447
❓ Files not found in source: 182
❌ Copy errors: 0
📈 Success rate: 99.5%

🔍 VERIFICATION SAMPLE
Sample verification from train:
  154L-A       → 154L-A          | ✅ A3M_SIZE=61,594B
  155C-A       → 155C-A          | ✅ A3M_SIZE=45,033B
  16PK-A       → 16PK-A          | ✅ A3M_SIZE=121,666B

🎯 CRITICAL WORKFLOW EXECUTED:
1. ✅ Scanned tilde-encoded PDBCH folders
2. ✅ Decoded folder names to get original protein IDs
3. ✅ Performed case-sensitive exact matching with old structure
4. ✅ Copied A3M files to tilde-encoded destination folders
5. ✅ Preserved split correspondence (train→train, val→val, test→test)
6. ✅ M

In [39]:
"""
Hierarchical mapping of problematic PDBCH proteins to OpenProtein Set sequences.

Maps proteins using:
1. Stage 1: Exact ID match (format conversion)
2. Stage 2: Exact sequence match  
3. Stage 3: Similarity match (±15 AA tolerance)

Critical: Preserves exact case sensitivity and tilde encoding throughout.
"""

from pathlib import Path
import csv, time
from collections import defaultdict, Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.auto import tqdm

def decode_protein_id(encoded_name):
    """Restore original protein ID by removing tildes"""
    return encoded_name.replace('~', '')

def format_protein_id_for_openfold(original_id):
    """
    Format protein ID for OpenProtein Set lookup:
    - Make first 4 characters lowercase
    - Convert - to _
    - Leave chain part exactly as-is (no case change)
    
    Examples:
        3H4P-a  → 3h4p_a
        5IT7-aa → 5it7_aa
        4V6W-Co → 4v6w_Co
    """
    if '-' in original_id:
        pdb_part, chain_part = original_id.split('-', 1)
        return f"{pdb_part[:4].lower()}_{chain_part}"
    elif '_' in original_id:
        pdb_part, chain_part = original_id.split('_', 1)
        return f"{pdb_part[:4].lower()}_{chain_part}"
    else:
        return original_id[:4].lower()

# ── CONFIGURATION ──────────────────────────────────────────────────────────────
ROOT = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing")
FASTA_FILE = ROOT / "data" / "openfold_pdb_query_sequences.fasta"
PDBCH_ROOT = ROOT / "data" / "PDBCH"

SPLIT_FOLDERS = {
    "train": PDBCH_ROOT / "train_pdbch",
    "val": PDBCH_ROOT / "val_pdbch",
    "test": PDBCH_ROOT / "test_pdbch"
}

OUTPUT_TSV = ROOT / "problematic_proteins_openfold_mapping.tsv"

# Specific problematic proteins (tilde-encoded folder names + additional from percentage list)
PROBLEMATIC_PROTEINS = {
    # Original tilde-encoded list
    "3H4P-a~", "3J79-f~", "3J79-i~", "3J7Y-f~", "3J92-f~", "3J9M-l~", "3JB9-g~",
    "3JB9-i~", "3JB9-r~", "3JCS-d~", "3JCS-e~", "3JD5-f~", "3JD5-j~", "3JD5-n~",
    "4CE4-u~", "4V3P-La~", "4V3P-Lb~", "4V3P-Le~", "4V3P-Lj~", "4V4N-Ad~", "4V4N-Aj~",
    "4V6U-Be~", "4V6U-Bl~", "4V6W-Ab~", "4V6W-Af~", "4V6W-Ca~", "4V6W-Cb~", "4V6W-Cc~",
    "4V6W-Cg~", "4V6W-Cj~", "4V6W-Ck~", "4V6W-Co~", "4V6W-Cr~", "4V7E-Ba~", "4V7E-Bb~",
    "4V7E-Bf~", "4V7E-Bg~", "4V7E-Ce~", "4V7E-Ci~", "4V7E-Co~", "4V7E-Cq~", "4V7E-Cu~",
    "4V8M-Bi~", "4V8M-Bj~", "4V8M-Bk~", "4V8M-Bl~", "4V8M-Bm~", "4V8M-Bp~", "4V8M-Br~",
    "4V8M-Bs~", "4V8M-Bt~", "4V8M-Bu~", "4V8M-Bv~", "4V8M-By~", "5AJ4-Ak~", "5AJ4-Ap~",
    "5GUP-l~", "5GUP-m~", "5GUP-v~", "5IT7-a~a~", "5IT7-c~c~", "5IT7-e~e~", "5IT7-f~f~",
    "5IT7-h~h~", "5IT7-i~i~", "5IT7-j~j~", "5IT7-o~o~", "5IT7-p~p~", "5IT7-r~r~", "5KZ5-a~",
    "5L9W-b~", "5LJ5-s~", "5LNK-h~", "5LNK-l~", "5LNK-m~", "5LNK-n~", "5MMM-i~", "5NGM-Ao~",
    "5O31-m~", "5OOL-w~", "5OPT-l~", "5OPT-p~", "5OPT-r~", "5OPT-t~", "5OQL-e~", "5OQL-h~",
    "5OQL-l~", "5OQL-p~", "5OQL-u~", "5OQL-v~", "5OQL-x~", "5OQL-y~", "5T2A-n~", "5T2A-v~",
    "5T5H-i~", "5T5H-l~", "5T5H-n~", "5T5H-t~", "5T5H-v~", "5T5H-w~", "5V93-e~", "5V93-f~",
    "5V93-n~", "5V93-p~", "5V93-q~", "5VK2-a~", "5XXB-b~", "5XXB-c~", "5XXB-e~", "5XXB-f~",
    "5XXB-g~", "5XXB-h~", "5XXB-i~", "5XXB-n~", "5XXB-o~", "5XXB-p~", "5XXU-a~", "5XXU-b~",
    "5XXU-c~", "5XXU-f~", "5XY3-c~", "5XY3-d~", "5XY3-e~", "5XY3-f~", "5XY3-g~", "5XY3-h~",
    "5XY3-i~", "5XY3-j~", "5XY3-m~", "5XY3-o~", "5XY3-p~", "5XYI-a~", "5XYI-b~", "5XYI-c~",
    "5XYI-e~", "5YZG-w~", "5ZWN-y~", "6AZ1-a~", "6AZ1-b~", "6AZ1-c~", "6AZ1-e~", "6AZ3-d~",
    "6AZ3-h~", "6DZI-t~", "6DZI-z~", "6ERI-Az~", "6G2J-e~", "6G72-b~", "6G72-g~", "6G72-h~",
    "6G72-m~", "6GAZ-An~", "6GAZ-Ao~", "6GAZ-Ap~", "6GB2-Be~", "6GB2-Bw~", "6GCS-c~", "6GCS-d~",
    "6GCS-f~", "6GCS-h~", "6GCS-j~", "6GIQ-d~", "6GIQ-e~", "6GIQ-h~", "6HA1-p~", "6HA8-d~",
    "6HHQ-p~", "6HIV-Bb~", "6HIV-Bc~", "6HIX-Av~", "6HIX-Bg~", "6HIZ-Ci~", "6OKK-b~", "6OKK-c~",
    "6QDV-d~",
    
    # Additional proteins from percentage list (uppercase, no tilde encoding needed)
    "3H4P-A", "3J79-I", "3JB9-I", "3JB9-R", "3JD5-F", "3JD5-J", "4UDF-1B", "4V6U-BL", 
    "4V7E-CE", "4V8T-O", "5IT7-CC", "5IT7-EE", "5IT7-HH", "5IT7-II", "5IT7-RR", "5LNK-N", 
    "5NGM-AO", "5OOL-W", "5OPT-P", "5OPT-R", "5OQL-U", "5OQL-X", "5T5H-I", "5T5H-L", 
    "5T5H-V", "5T5H-W", "5V93-F", "5XXB-E", "5XXB-F", "5XXB-G", "5XXB-O", "5XXU-A", 
    "5XXU-B", "5XXU-C", "5XXU-F", "5XY3-E", "5XY3-F", "5XY3-G", "5XY3-I", "5XY3-J", 
    "5XY3-O", "5XYI-B", "5YZG-W", "5ZWN-Y", "6AZ1-A", "6AZ1-B", "6AZ1-C", "6DZI-T", 
    "6GAZ-AN", "6GAZ-AP", "6GB2-BE", "6GCS-D", "6GCS-F", "6GCS-H", "6GCS-J", "6GIQ-D", 
    "6HIX-AV", "6HIX-BG"
}

LENGTH_TOLERANCE = 15
SIMILARITY_THRESHOLD = 0.0
MAX_THREADS = 8

# ── SIMILARITY FUNCTION ────────────────────────────────────────────────────────
from Bio import pairwise2
def similarity(seq1, seq2):
    """Calculate sequence similarity using Biopython"""
    if not seq1 or not seq2:
        return 0.0
    score = pairwise2.align.globalxx(seq1, seq2, score_only=True)
    return score / max(len(seq1), len(seq2))

def similarity_worker(args):
    """Worker function for parallel similarity calculation"""
    p_seq, f_id, f_seq = args
    try:
        return f_id, similarity(p_seq, f_seq)
    except Exception as e:
        return f_id, 0.0

print("🧬 Mapping Problematic Proteins to OpenProtein Set")
print("=" * 70)

# ── 1. LOAD OPENFOLD FASTA SEQUENCES ───────────────────────────────────────────
print("📖 Loading OpenProtein Set FASTA sequences...")

fasta_seqs = {}
fasta_len = {}

with FASTA_FILE.open() as fh:
    current_id, sequence_lines = None, []
    
    for line in tqdm(fh, desc="Reading FASTA", unit="lines"):
        line = line.strip()
        if line.startswith('>'):
            # Save previous sequence
            if current_id and sequence_lines:
                seq = ''.join(sequence_lines)
                fasta_seqs[current_id] = seq
                fasta_len[current_id] = len(seq)
            
            # Start new sequence
            current_id = line[1:]  # Remove '>' and keep exact ID
            sequence_lines = []
        else:
            sequence_lines.append(line)
    
    # Save last sequence
    if current_id and sequence_lines:
        seq = ''.join(sequence_lines)
        fasta_seqs[current_id] = seq
        fasta_len[current_id] = len(seq)

print(f"✅ Loaded {len(fasta_seqs):,} OpenProtein Set sequences")

# Build quick lookup indices
seq_to_ids = defaultdict(list)
len_index = defaultdict(list)

for fasta_id, seq in fasta_seqs.items():
    seq_to_ids[seq].append(fasta_id)
    len_index[len(seq)].append(fasta_id)

# ── 2. COLLECT TARGET PROTEINS FROM PDBCH FOLDERS ─────────────────────────────
print(f"\n🔍 Scanning PDBCH folders for {len(PROBLEMATIC_PROTEINS):,} problematic proteins...")

target_proteins = []  # List of (tilde_encoded_id, protein_folder_path, original_id, sequence)

for split_name, split_folder in SPLIT_FOLDERS.items():
    if not split_folder.exists():
        continue
    
    split_found = 0
    for tilde_encoded_id in PROBLEMATIC_PROTEINS:
        protein_folder = split_folder / tilde_encoded_id
        
        if protein_folder.exists() and protein_folder.is_dir():
            # Decode tilde to get original protein ID
            original_id = decode_protein_id(tilde_encoded_id)
            
            # Read sequence from sequence.txt
            sequence_file = protein_folder / "sequence.txt"
            if sequence_file.exists():
                try:
                    sequence = ''.join(c for c in sequence_file.read_text().strip() if c.isalpha())
                    target_proteins.append((tilde_encoded_id, protein_folder, original_id, sequence))
                    split_found += 1
                except Exception as e:
                    print(f"   ❌ Error reading sequence for {tilde_encoded_id}: {e}")
    
    print(f"   📂 {split_name}: Found {split_found} proteins")

print(f"✅ Total proteins to map: {len(target_proteins):,}")

# ── 3. HIERARCHICAL MAPPING ────────────────────────────────────────────────────
print(f"\n🔄 Performing hierarchical mapping...")

results = []
stats = Counter(stage1=0, stage2=0, stage3=0, nomatch=0, no_sequence=0)
start_time = time.time()

for tilde_encoded_id, protein_folder, original_id, sequence in tqdm(target_proteins, desc="Mapping proteins", unit="protein"):
    
    if not sequence:
        results.append([tilde_encoded_id, "NO_SEQUENCE", "NA", "No Sequence Available"])
        stats["no_sequence"] += 1
        continue
    
    # STAGE 1: Exact ID match after format conversion
    formatted_id = format_protein_id_for_openfold(original_id)
    
    if formatted_id in fasta_seqs:
        results.append([tilde_encoded_id, formatted_id, "NA", "Stage 1 - Exact ID"])
        stats["stage1"] += 1
        continue
    
    # STAGE 2: Exact sequence match
    matching_ids = seq_to_ids.get(sequence, [])
    if matching_ids:
        results.append([tilde_encoded_id, matching_ids[0], "NA", "Stage 2 - Exact Sequence"])
        stats["stage2"] += 1
        continue
    
    # STAGE 3: Similarity search (±15 AA tolerance)
    protein_length = len(sequence)
    candidates = []
    
    # Collect candidates within length tolerance
    for length in range(max(1, protein_length - LENGTH_TOLERANCE), 
                       protein_length + LENGTH_TOLERANCE + 1):
        candidates.extend(len_index[length])
    
    best_id, best_similarity = None, SIMILARITY_THRESHOLD
    
    if candidates:
        if MAX_THREADS > 1 and len(candidates) > 10:
            # Parallel similarity calculation
            args = [(sequence, fasta_id, fasta_seqs[fasta_id]) for fasta_id in candidates]
            with ThreadPoolExecutor(max_workers=MAX_THREADS) as executor:
                for fasta_id, sim_score in executor.map(similarity_worker, args):
                    if sim_score > best_similarity:
                        best_similarity, best_id = sim_score, fasta_id
        else:
            # Sequential similarity calculation
            for fasta_id in candidates:
                sim_score = similarity(sequence, fasta_seqs[fasta_id])
                if sim_score > best_similarity:
                    best_similarity, best_id = sim_score, fasta_id
    
    if best_id:
        results.append([tilde_encoded_id, best_id, f"{best_similarity:.4f}", "Stage 3 - Similarity"])
        stats["stage3"] += 1
    else:
        results.append([tilde_encoded_id, "NO_MATCH", "NA", "No Match Found"])
        stats["nomatch"] += 1

elapsed_time = time.time() - start_time

# ── 4. WRITE RESULTS TO TSV ────────────────────────────────────────────────────
print(f"\n📝 Writing results to TSV...")

with OUTPUT_TSV.open("w", newline="", encoding="utf-8") as tsv_file:
    writer = csv.writer(tsv_file, delimiter="\t")
    writer.writerow(["original_id_tilde", "matched_openfold_id", "similarity_score", "mapping_stage"])
    writer.writerows(results)

# ── 5. SUMMARY REPORT ──────────────────────────────────────────────────────────
total_processed = sum(stats.values())

print(f"\n🎉 MAPPING COMPLETE")
print("=" * 70)
print(f"📊 Mapping Statistics:")
print(f"   Stage 1 (Exact ID)      : {stats['stage1']:3d} ({stats['stage1']/total_processed*100:5.1f}%)")
print(f"   Stage 2 (Exact Sequence): {stats['stage2']:3d} ({stats['stage2']/total_processed*100:5.1f}%)")
print(f"   Stage 3 (Similarity)    : {stats['stage3']:3d} ({stats['stage3']/total_processed*100:5.1f}%)")
print(f"   No Sequence Available   : {stats['no_sequence']:3d} ({stats['no_sequence']/total_processed*100:5.1f}%)")
print(f"   No Match Found          : {stats['nomatch']:3d} ({stats['nomatch']/total_processed*100:5.1f}%)")

successful_matches = stats['stage1'] + stats['stage2'] + stats['stage3']
print(f"\n✅ Total successful matches: {successful_matches:,} / {total_processed:,} ({successful_matches/total_processed*100:.1f}%)")
print(f"⏱️  Processing time: {elapsed_time:.1f} seconds")
print(f"📁 Results saved to: {OUTPUT_TSV}")

# ── 6. SAMPLE RESULTS ──────────────────────────────────────────────────────────
print(f"\n📋 SAMPLE MAPPING RESULTS:")
print("-" * 70)
for i, (tilde_id, matched_id, score, stage) in enumerate(results[:10]):
    original_id = decode_protein_id(tilde_id)
    formatted_id = format_protein_id_for_openfold(original_id) 
    print(f"{original_id:8s} → {tilde_id:10s} → {formatted_id:8s} → {matched_id:15s} | {stage}")

if len(results) > 10:
    print(f"... and {len(results) - 10} more results in TSV file")

🧬 Mapping Problematic Proteins to OpenProtein Set
📖 Loading OpenProtein Set FASTA sequences...


Reading FASTA: 0lines [00:00, ?lines/s]

✅ Loaded 131,487 OpenProtein Set sequences

🔍 Scanning PDBCH folders for 233 problematic proteins...
   📂 train: Found 208 proteins
   📂 val: Found 5 proteins
   📂 test: Found 20 proteins
✅ Total proteins to map: 233

🔄 Performing hierarchical mapping...


Mapping proteins:   0%|          | 0/233 [00:00<?, ?protein/s]


📝 Writing results to TSV...

🎉 MAPPING COMPLETE
📊 Mapping Statistics:
   Stage 1 (Exact ID)      : 189 ( 81.1%)
   Stage 2 (Exact Sequence):  43 ( 18.5%)
   Stage 3 (Similarity)    :   1 (  0.4%)
   No Sequence Available   :   0 (  0.0%)
   No Match Found          :   0 (  0.0%)

✅ Total successful matches: 233 / 233 (100.0%)
⏱️  Processing time: 7.4 seconds
📁 Results saved to: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\problematic_proteins_openfold_mapping.tsv

📋 SAMPLE MAPPING RESULTS:
----------------------------------------------------------------------
5IT7-EE  → 5IT7-EE    → 5it7_EE  → 5it7_EE         | Stage 1 - Exact ID
5XY3-m   → 5XY3-m~    → 5xy3_m   → 5xy3_m          | Stage 1 - Exact ID
6GAZ-Ap  → 6GAZ-Ap~   → 6gaz_Ap  → 5aj3_p          | Stage 2 - Exact Sequence
4V8M-Bp  → 4V8M-Bp~   → 4v8m_Bp  → 4v8m_Bp         | Stage 1 - Exact ID
5XYI-b   → 5XYI-b~    → 5xyi_b   → 5xyi_b          | Stage 1 - Exact ID
3JB9-i   → 3JB9-i~    → 3jb9_i   → 3jb9_i          | Stage 1 - Ex

In [38]:
"""
Delete existing final_filtered_256_stripped.a3m files for specific problematic proteins
in the PDBCH folder structure.
"""

from pathlib import Path
from tqdm.auto import tqdm

# ── CONFIGURATION ──────────────────────────────────────────────────────────────
PDBCH_ROOT = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\PDBCH")

SPLIT_FOLDERS = {
    "train": PDBCH_ROOT / "train_pdbch",
    "val": PDBCH_ROOT / "val_pdbch",
    "test": PDBCH_ROOT / "test_pdbch"
}

TARGET_FILE = "final_filtered_256_stripped.a3m"

# List of problematic proteins (tilde-encoded folder names + additional from percentage list)
PROBLEMATIC_PROTEINS = {
    # Original tilde-encoded list
    "3H4P-a~", "3J79-f~", "3J79-i~", "3J7Y-f~", "3J92-f~", "3J9M-l~", "3JB9-g~",
    "3JB9-i~", "3JB9-r~", "3JCS-d~", "3JCS-e~", "3JD5-f~", "3JD5-j~", "3JD5-n~",
    "4CE4-u~", "4V3P-La~", "4V3P-Lb~", "4V3P-Le~", "4V3P-Lj~", "4V4N-Ad~", "4V4N-Aj~",
    "4V6U-Be~", "4V6U-Bl~", "4V6W-Ab~", "4V6W-Af~", "4V6W-Ca~", "4V6W-Cb~", "4V6W-Cc~",
    "4V6W-Cg~", "4V6W-Cj~", "4V6W-Ck~", "4V6W-Co~", "4V6W-Cr~", "4V7E-Ba~", "4V7E-Bb~",
    "4V7E-Bf~", "4V7E-Bg~", "4V7E-Ce~", "4V7E-Ci~", "4V7E-Co~", "4V7E-Cq~", "4V7E-Cu~",
    "4V8M-Bi~", "4V8M-Bj~", "4V8M-Bk~", "4V8M-Bl~", "4V8M-Bm~", "4V8M-Bp~", "4V8M-Br~",
    "4V8M-Bs~", "4V8M-Bt~", "4V8M-Bu~", "4V8M-Bv~", "4V8M-By~", "5AJ4-Ak~", "5AJ4-Ap~",
    "5GUP-l~", "5GUP-m~", "5GUP-v~", "5IT7-a~a~", "5IT7-c~c~", "5IT7-e~e~", "5IT7-f~f~",
    "5IT7-h~h~", "5IT7-i~i~", "5IT7-j~j~", "5IT7-o~o~", "5IT7-p~p~", "5IT7-r~r~", "5KZ5-a~",
    "5L9W-b~", "5LJ5-s~", "5LNK-h~", "5LNK-l~", "5LNK-m~", "5LNK-n~", "5MMM-i~", "5NGM-Ao~",
    "5O31-m~", "5OOL-w~", "5OPT-l~", "5OPT-p~", "5OPT-r~", "5OPT-t~", "5OQL-e~", "5OQL-h~",
    "5OQL-l~", "5OQL-p~", "5OQL-u~", "5OQL-v~", "5OQL-x~", "5OQL-y~", "5T2A-n~", "5T2A-v~",
    "5T5H-i~", "5T5H-l~", "5T5H-n~", "5T5H-t~", "5T5H-v~", "5T5H-w~", "5V93-e~", "5V93-f~",
    "5V93-n~", "5V93-p~", "5V93-q~", "5VK2-a~", "5XXB-b~", "5XXB-c~", "5XXB-e~", "5XXB-f~",
    "5XXB-g~", "5XXB-h~", "5XXB-i~", "5XXB-n~", "5XXB-o~", "5XXB-p~", "5XXU-a~", "5XXU-b~",
    "5XXU-c~", "5XXU-f~", "5XY3-c~", "5XY3-d~", "5XY3-e~", "5XY3-f~", "5XY3-g~", "5XY3-h~",
    "5XY3-i~", "5XY3-j~", "5XY3-m~", "5XY3-o~", "5XY3-p~", "5XYI-a~", "5XYI-b~", "5XYI-c~",
    "5XYI-e~", "5YZG-w~", "5ZWN-y~", "6AZ1-a~", "6AZ1-b~", "6AZ1-c~", "6AZ1-e~", "6AZ3-d~",
    "6AZ3-h~", "6DZI-t~", "6DZI-z~", "6ERI-Az~", "6G2J-e~", "6G72-b~", "6G72-g~", "6G72-h~",
    "6G72-m~", "6GAZ-An~", "6GAZ-Ao~", "6GAZ-Ap~", "6GB2-Be~", "6GB2-Bw~", "6GCS-c~", "6GCS-d~",
    "6GCS-f~", "6GCS-h~", "6GCS-j~", "6GIQ-d~", "6GIQ-e~", "6GIQ-h~", "6HA1-p~", "6HA8-d~",
    "6HHQ-p~", "6HIV-Bb~", "6HIV-Bc~", "6HIX-Av~", "6HIX-Bg~", "6HIZ-Ci~", "6OKK-b~", "6OKK-c~",
    "6QDV-d~",
    
    # Additional proteins from percentage list (uppercase, no tilde encoding needed)
    "3H4P-A", "3J79-I", "3JB9-I", "3JB9-R", "3JD5-F", "3JD5-J", "4UDF-1B", "4V6U-BL", 
    "4V7E-CE", "4V8T-O", "5IT7-CC", "5IT7-EE", "5IT7-HH", "5IT7-II", "5IT7-RR", "5LNK-N", 
    "5NGM-AO", "5OOL-W", "5OPT-P", "5OPT-R", "5OQL-U", "5OQL-X", "5T5H-I", "5T5H-L", 
    "5T5H-V", "5T5H-W", "5V93-F", "5XXB-E", "5XXB-F", "5XXB-G", "5XXB-O", "5XXU-A", 
    "5XXU-B", "5XXU-C", "5XXU-F", "5XY3-E", "5XY3-F", "5XY3-G", "5XY3-I", "5XY3-J", 
    "5XY3-O", "5XYI-B", "5YZG-W", "5ZWN-Y", "6AZ1-A", "6AZ1-B", "6AZ1-C", "6DZI-T", 
    "6GAZ-AN", "6GAZ-AP", "6GB2-BE", "6GCS-D", "6GCS-F", "6GCS-H", "6GCS-J", "6GIQ-D", 
    "6HIX-AV", "6HIX-BG"
}

print("🗑️  Deleting A3M Files for Problematic Proteins")
print("=" * 60)

# ── SCAN AND DELETE ────────────────────────────────────────────────────────────
total_deleted = 0
total_checked = 0

for split_name, split_folder in SPLIT_FOLDERS.items():
    if not split_folder.exists():
        print(f"❌ Split folder not found: {split_folder}")
        continue
    
    print(f"\n🔍 Processing {split_name} split...")
    
    split_deleted = 0
    split_found = 0
    
    # Check each problematic protein in this split
    for protein_id in tqdm(PROBLEMATIC_PROTEINS, desc=f"Checking {split_name}", unit="protein"):
        total_checked += 1
        
        # Look for tilde-encoded folder
        protein_folder = split_folder / protein_id
        
        if protein_folder.exists() and protein_folder.is_dir():
            split_found += 1
            a3m_file = protein_folder / TARGET_FILE
            
            if a3m_file.exists():
                try:
                    # Delete the A3M file
                    a3m_file.unlink()
                    split_deleted += 1
                    total_deleted += 1
                    print(f"   ✅ Deleted: {protein_id}/{TARGET_FILE}")
                except Exception as e:
                    print(f"   ❌ Error deleting {protein_id}/{TARGET_FILE}: {e}")
            # If A3M doesn't exist, that's fine - we wanted to delete it anyway
    
    print(f"   📊 {split_name}: Found {split_found} proteins, deleted {split_deleted} A3M files")

# ── SUMMARY ────────────────────────────────────────────────────────────────────
print(f"\n🎉 DELETION SUMMARY")
print("=" * 60)
print(f"🔍 Total proteins checked: {len(PROBLEMATIC_PROTEINS):,}")
print(f"🗑️  Total A3M files deleted: {total_deleted:,}")

if total_deleted > 0:
    print(f"\n✅ Successfully cleared A3M files for {total_deleted} problematic proteins")
    print("   These proteins are now ready for fresh A3M generation")
else:
    print(f"\n📝 No A3M files found to delete - proteins may already be cleared")

print("\nNext step: Run the mapping script to identify OpenProtein Set matches")

🗑️  Deleting A3M Files for Problematic Proteins

🔍 Processing train split...


Checking train:   0%|          | 0/233 [00:00<?, ?protein/s]

   ✅ Deleted: 5IT7-EE/final_filtered_256_stripped.a3m
   ✅ Deleted: 3J79-I/final_filtered_256_stripped.a3m
   ✅ Deleted: 5IT7-HH/final_filtered_256_stripped.a3m
   ✅ Deleted: 5OPT-P/final_filtered_256_stripped.a3m
   ✅ Deleted: 5XY3-I/final_filtered_256_stripped.a3m
   ✅ Deleted: 5OQL-U/final_filtered_256_stripped.a3m
   ✅ Deleted: 5XXU-A/final_filtered_256_stripped.a3m
   ✅ Deleted: 5XY3-E/final_filtered_256_stripped.a3m
   ✅ Deleted: 6AZ1-C/final_filtered_256_stripped.a3m
   ✅ Deleted: 6AZ1-B/final_filtered_256_stripped.a3m
   ✅ Deleted: 5T5H-V/final_filtered_256_stripped.a3m
   ✅ Deleted: 6GAZ-AP/final_filtered_256_stripped.a3m
   ✅ Deleted: 5XXB-G/final_filtered_256_stripped.a3m
   ✅ Deleted: 3JD5-J/final_filtered_256_stripped.a3m
   ✅ Deleted: 6HIX-BG/final_filtered_256_stripped.a3m
   ✅ Deleted: 5OPT-R/final_filtered_256_stripped.a3m
   ✅ Deleted: 5XXB-E/final_filtered_256_stripped.a3m
   ✅ Deleted: 5XXU-B/final_filtered_256_stripped.a3m
   ✅ Deleted: 5XXB-F/final_filtered_256_st

Checking val:   0%|          | 0/233 [00:00<?, ?protein/s]

   ✅ Deleted: 4UDF-1B/final_filtered_256_stripped.a3m
   📊 val: Found 5 proteins, deleted 1 A3M files

🔍 Processing test split...


Checking test:   0%|          | 0/233 [00:00<?, ?protein/s]

   ✅ Deleted: 3H4P-A/final_filtered_256_stripped.a3m
   ✅ Deleted: 4V8T-O/final_filtered_256_stripped.a3m
   ✅ Deleted: 5ZWN-Y/final_filtered_256_stripped.a3m
   ✅ Deleted: 6GIQ-D/final_filtered_256_stripped.a3m
   📊 test: Found 20 proteins, deleted 4 A3M files

🎉 DELETION SUMMARY
🔍 Total proteins checked: 233
🗑️  Total A3M files deleted: 58

✅ Successfully cleared A3M files for 58 problematic proteins
   These proteins are now ready for fresh A3M generation

Next step: Run the mapping script to identify OpenProtein Set matches


In [41]:
"""
Resumable MSA-fetch + hhfilter pipeline for PDBCH problematic proteins
====================================================================

Input TSV  : problematic_proteins_openfold_mapping.tsv   (original_id_tilde, matched_openfold_id, ...)
Output TSV : problematic_proteins_openfold_mapping_updated.tsv  (+ status columns)
A3M output : <PDBCH>/<split>/<TILDE_ENCODED_ID>/final_filtered_256_stripped.a3m

Concurrency
-----------
• BATCH_SIZE      – chains processed per outer loop
• DL_CONCURRENCY  – parallel S3 downloads (IO-bound)
• HH_PARALLEL     – parallel hhfilter+diversity jobs (CPU-bound)

The script is *idempotent*; rerun to resume unfinished work.
"""

from __future__ import annotations
import csv, os, re, shutil, subprocess, tempfile, itertools, math, time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Tuple, Dict, Set
from collections import defaultdict

import boto3
from boto3.s3.transfer import TransferConfig
from botocore.config import Config
from botocore import UNSIGNED

import numpy as np
from scipy.spatial.distance import pdist, squareform
from tqdm.auto import tqdm

# ─── PATHS & CONSTANTS ──────────────────────────────────────────────────
ROOT        = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing")
PDBCH_ROOT  = ROOT / "data" / "PDBCH"

# PDBCH folder structure (tilde-encoded folders)
SPLITS = {
    "train": PDBCH_ROOT / "train_pdbch",
    "val": PDBCH_ROOT / "val_pdbch",
    "test": PDBCH_ROOT / "test_pdbch"
}

MATCH_FILE   = ROOT / "problematic_proteins_openfold_mapping.tsv"
UPDATED_OUT  = ROOT / "problematic_proteins_openfold_mapping_updated.tsv"

MAX_ROWS       = 256          # target rows post-hhfilter
BATCH_SIZE     = 50           # chains per outer loop batch
DL_CONCURRENCY = 14           # parallel S3 downloads
HH_PARALLEL    = 10           # parallel hhfilter jobs

BUCKET      = "openfold"
A3M_FILES   = ("bfd_uniclust_hits.a3m", "mgnify_hits.a3m", "uniref90_hits.a3m")

# ─── UTILS ──────────────────────────────────────────────────────────────
LOWER = ''.join(chr(c) for c in range(97, 123))

def parse_chain_id(s: str) -> Tuple[str, str]:
    """Parse OpenFold chain ID to get PDB and chain parts"""
    s = s.strip().replace('_', '-')
    if '-' in s:
        pdb, chain = s.split('-', 1)
    elif re.fullmatch(r"[0-9][A-Za-z0-9]{3}[A-Za-z0-9]{1,3}", s):
        pdb, chain = s[:4], s[4:]
    else:
        raise ValueError(f"Cannot parse chain ID from '{s}'")
    return pdb.upper(), chain     # pdb upper-case, chain left *as-is*

def to_wsl(path: Path) -> str:
    """Convert Windows path to WSL path"""
    posix = path.resolve().as_posix()
    return posix if posix.startswith("/mnt/") else f"/mnt/{path.drive[0].lower()}{posix[2:]}"

def strip_insertions_a3m(seq: str) -> str:
    """Remove lowercase insertions and dots from A3M sequence"""
    return seq.translate({ord(c): None for c in LOWER + ".*"})

def load_msa_from_a3m(p: Path):
    """Load MSA from A3M file"""
    msa, hdr, buf = [], None, []
    with p.open() as fh:
        for ln in fh:
            ln = ln.rstrip()
            if ln.startswith('>'):
                if hdr is not None:
                    msa.append((hdr, ''.join(buf)))
                hdr, buf = ln[1:], []
            else:
                buf.append(ln)
        if hdr is not None:
            msa.append((hdr, ''.join(buf)))
    return msa

def diversity_max_subsample(msa, k):
    """Subsample MSA to k sequences using maximum diversity"""
    if len(msa) <= k:
        return msa
    seqs = np.array([list(s) for _, s in msa], dtype='U1')
    uniq = {aa: i for i, aa in enumerate(sorted({c for row in seqs for c in row}))}
    arr  = np.vectorize(uniq.get)(seqs)
    dist = squareform(pdist(arr, metric='hamming'))
    keep = [0]
    sel  = np.zeros(len(msa), bool); sel[0] = True
    while sel.sum() < k:
        mean = dist[:, sel].mean(1); mean[sel] = -1
        idx  = int(mean.argmax())
        if mean[idx] <= 0:
            break
        sel[idx] = True; keep.append(idx)
    return [msa[i] for i in keep]

# ─── RESUMABILITY ───────────────────────────────────────────────────────
def load_existing() -> Set[str]:
    """Load already processed proteins from output TSV"""
    if not UPDATED_OUT.exists():
        return set()
    done = set()
    with UPDATED_OUT.open(newline='') as fh:
        for row in csv.DictReader(fh, delimiter='\t'):
            if row.get('status', '') == 'Success':
                done.add(row['original_id_tilde'])
    return done

def append_results(rows: List[Dict[str,str]], first: bool):
    """Append results to output TSV"""
    mode = 'w' if first else 'a'
    with UPDATED_OUT.open(mode, newline='') as fh:
        wr = csv.DictWriter(fh, delimiter='\t', fieldnames=rows[0].keys())
        if first:
            wr.writeheader()
        wr.writerows(rows)

# ─── AWS S3 ─────────────────────────────────────────────────────────────
transfer_cfg = TransferConfig(max_concurrency=DL_CONCURRENCY)
s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))

def download_one(key: str, dest: Path) -> bool:
    """Download single file from S3"""
    try:
        dest.parent.mkdir(parents=True, exist_ok=True)
        s3.download_file(BUCKET, key, str(dest), Config=transfer_cfg)
        return True
    except Exception:
        return False

# ─── MAIN PER-CHAIN PROCESSOR ───────────────────────────────────────────
def process_chain(chain_dir: Path, row: Dict[str,str]) -> Dict[str,str]:
    """Process single protein chain: download MSAs, run hhfilter, save result"""
    try:
        matched_id = row['matched_openfold_id']
        
        # Skip if no match found
        if matched_id in ('NO_MATCH', 'NO_SEQUENCE'):
            return {**row, **{
                'n_rows_dropped': '', 'hhfilter_rows': '',
                'diversity_filtered_rows': '', 'status': f'Skipped: {matched_id}'
            }}
        
        pdb, chain = parse_chain_id(matched_id)
        tmp_dir = chain_dir / "_tmp_a3m"
        tmp_dir.mkdir(exist_ok=True)

        # Check if we downloaded A3M files
        downloaded = [tmp_dir / f for f in A3M_FILES if (tmp_dir / f).exists()]
        if not downloaded:
            return {**row, **{
                'n_rows_dropped': '', 'hhfilter_rows': '',
                'diversity_filtered_rows': '', 'status': 'Failed: download'
            }}

        # Concatenate raw A3Ms
        raw = tmp_dir / "concat_raw.a3m"
        with raw.open('w') as out:
            for f in downloaded:
                out.write(f.read_text())

        # Run hhfilter
        hh_out = tmp_dir / f"hhfiltered_{MAX_ROWS}.a3m"
        subprocess.run([
            "wsl", "hhfilter", "-i", to_wsl(raw), "-o", to_wsl(hh_out),
            "-diff", str(MAX_ROWS)
        ], check=True, capture_output=True)

        # Load MSA and check if empty
        msa = load_msa_from_a3m(hh_out)
        hh_rows = len(msa)
        if not msa:
            return {**row, **{
                'n_rows_dropped': '', 'hhfilter_rows': hh_rows,
                'diversity_filtered_rows': '', 'status': 'Failed: hhfilter-empty'
            }}

        # Strip insertions and filter by length
        tgt_len = len(strip_insertions_a3m(msa[0][1]))
        kept, dropped = [], 0
        for hdr, seq in msa:
            clean = strip_insertions_a3m(seq)
            if len(clean) == tgt_len:
                kept.append((hdr, clean))
            else:
                dropped += 1
                
        if len(kept) < 2:
            return {**row, **{
                'n_rows_dropped': dropped, 'hhfilter_rows': hh_rows,
                'diversity_filtered_rows': '',
                'status': f'Failed: {len(kept)} rows post-strip'
            }}

        # Apply diversity filtering if needed
        div_filtered = 0
        if len(kept) > MAX_ROWS:
            before = len(kept)
            kept = diversity_max_subsample(kept, MAX_ROWS)
            div_filtered = before - len(kept)

        # Save final A3M file to tilde-encoded folder
        final_path = chain_dir / "final_filtered_256_stripped.a3m"
        with final_path.open('w') as fh:
            for h, s in kept:
                fh.write(f">{h}\n{s}\n")

        return {**row, **{
            'n_rows_dropped': dropped,
            'hhfilter_rows': hh_rows,
            'diversity_filtered_rows': div_filtered,
            'status': 'Success'
        }}

    except Exception as e:
        return {**row, **{
            'n_rows_dropped': '', 'hhfilter_rows': '',
            'diversity_filtered_rows': '',
            'status': f'Failed: {str(e)[:80]}'
        }}
    finally:
        shutil.rmtree(chain_dir / "_tmp_a3m", ignore_errors=True)

# ─── DRIVER ─────────────────────────────────────────────────────────────
def find_protein_folder(tilde_encoded_id: str) -> Path | None:
    """Find protein folder in PDBCH structure using tilde-encoded ID"""
    for split_folder in SPLITS.values():
        protein_folder = split_folder / tilde_encoded_id
        if protein_folder.exists() and protein_folder.is_dir():
            return protein_folder
    return None

def grouper(n: int, it):
    """Group iterator into batches of size n"""
    it = iter(it)
    while True:
        chunk = list(itertools.islice(it, n))
        if not chunk: 
            return
        yield chunk

def main():
    print("🧬 PDBCH MSA Fetch and HHFilter Pipeline")
    print("=" * 60)
    
    print("🔄 Loading mapping TSV …")
    if not MATCH_FILE.exists():
        raise FileNotFoundError(f"Mapping file not found: {MATCH_FILE}")
    
    mapping = {}
    with MATCH_FILE.open(newline='') as fh:
        reader = csv.DictReader(fh, delimiter='\t')
        for row in reader:
            mapping[row['original_id_tilde']] = row

    print(f"📋 Loaded {len(mapping):,} protein mappings")

    # Load already processed proteins
    finished = load_existing()
    todo = [pid for pid in mapping if pid not in finished]

    # Verify hhfilter is available
    if subprocess.run(["wsl", "which", "hhfilter"], capture_output=True).returncode:
        raise RuntimeError("❌ hhfilter not found inside WSL. Please install HH-suite in WSL.")

    print(f"✅ Total proteins to process: {len(todo):,}")
    print(f"⏭️  Already completed: {len(finished):,}")
    
    if not todo:
        print("🎉 Nothing to do – all done!")
        return

    is_first_batch = (not UPDATED_OUT.exists())
    outer = tqdm(total=len(todo), desc="Processing proteins", unit="prot")

    for batch_num, batch_pids in enumerate(grouper(BATCH_SIZE, todo), 1):
        # Stage 1: Prepare downloads and validate folders
        dl_jobs, valid_dirs, batch_results = [], [], []

        for tilde_encoded_id in batch_pids:
            row = mapping[tilde_encoded_id]
            
            # Find protein folder in PDBCH structure
            protein_folder = find_protein_folder(tilde_encoded_id)
            if protein_folder is None:
                batch_results.append({**row, **{
                    'n_rows_dropped': '', 'hhfilter_rows': '',
                    'diversity_filtered_rows': '',
                    'status': 'Failed: folder-not-found'
                }})
                outer.update(1)
                continue

            # Skip if no valid match
            matched_id = row['matched_openfold_id']
            if matched_id in ('NO_MATCH', 'NO_SEQUENCE'):
                batch_results.append({**row, **{
                    'n_rows_dropped': '', 'hhfilter_rows': '',
                    'diversity_filtered_rows': '',
                    'status': f'Skipped: {matched_id}'
                }})
                outer.update(1)
                continue

            # Prepare download jobs
            try:
                pdb, chain = parse_chain_id(matched_id)
                tmp_dir = protein_folder / "_tmp_a3m"
                for fname in A3M_FILES:
                    key = f"pdb/{pdb.lower()}_{chain}/a3m/{fname}"
                    dl_jobs.append((key, tmp_dir / fname))
                valid_dirs.append(protein_folder)
            except Exception as e:
                batch_results.append({**row, **{
                    'n_rows_dropped': '', 'hhfilter_rows': '',
                    'diversity_filtered_rows': '',
                    'status': f'Failed: parse-id ({str(e)[:50]})'
                }})
                outer.update(1)
                continue

        # Stage 2: Download A3M files from S3
        if dl_jobs:
            with ThreadPoolExecutor(max_workers=DL_CONCURRENCY) as pool:
                futs = [pool.submit(download_one, k, d) for k, d in dl_jobs]
                successful_downloads = 0
                for fut in tqdm(as_completed(futs), total=len(futs),
                              desc=f"Batch {batch_num} downloads", leave=False, unit="file"):
                    if fut.result():
                        successful_downloads += 1
                
                print(f"   📥 Downloaded {successful_downloads}/{len(dl_jobs)} A3M files")

        # Stage 3: Process with hhfilter and diversity filtering
        if valid_dirs:
            with ThreadPoolExecutor(max_workers=HH_PARALLEL) as pool:
                futs2 = {pool.submit(process_chain, d, mapping[d.name]): d for d in valid_dirs}
                for fut in tqdm(as_completed(futs2), total=len(futs2),
                                desc=f"Batch {batch_num} hhfilter", leave=False, unit="prot"):
                    res = fut.result()
                    batch_results.append(res)
                    outer.update(1)
                    
                    # Show current status
                    status_short = res['status'].split(':')[0] if ':' in res['status'] else res['status']
                    outer.set_postfix_str(f"{res['original_id_tilde'][:12]} ({status_short})")

        # Stage 4: Save batch results
        if batch_results:
            append_results(batch_results, is_first_batch)
            is_first_batch = False
            
            # Count successes
            succ = sum(r['status'] == "Success" for r in batch_results)
            skipped = sum(r['status'].startswith("Skipped") for r in batch_results)
            failed = len(batch_results) - succ - skipped
            
            print(f"📝 Batch {batch_num}: ✅ {succ} Success | ⏭️ {skipped} Skipped | ❌ {failed} Failed")

    outer.close()
    
    # Final summary
    print(f"\n🎉 Pipeline completed!")
    print(f"📁 Results saved to: {UPDATED_OUT}")
    
    # Show final statistics
    if UPDATED_OUT.exists():
        final_stats = {'Success': 0, 'Skipped': 0, 'Failed': 0}
        with UPDATED_OUT.open(newline='') as fh:
            for row in csv.DictReader(fh, delimiter='\t'):
                status = row.get('status', 'Unknown')
                if status == 'Success':
                    final_stats['Success'] += 1
                elif status.startswith('Skipped'):
                    final_stats['Skipped'] += 1
                else:
                    final_stats['Failed'] += 1
        
        total = sum(final_stats.values())
        print(f"\n📊 Final Statistics:")
        print(f"   ✅ Success: {final_stats['Success']:,} ({final_stats['Success']/total*100:.1f}%)")
        print(f"   ⏭️  Skipped: {final_stats['Skipped']:,} ({final_stats['Skipped']/total*100:.1f}%)")
        print(f"   ❌ Failed:  {final_stats['Failed']:,} ({final_stats['Failed']/total*100:.1f}%)")

if __name__ == "__main__":
    main()

🧬 PDBCH MSA Fetch and HHFilter Pipeline
🔄 Loading mapping TSV …
📋 Loaded 233 protein mappings
✅ Total proteins to process: 233
⏭️  Already completed: 0


Processing proteins:   0%|          | 0/233 [00:00<?, ?prot/s]

Batch 1 downloads:   0%|          | 0/150 [00:00<?, ?file/s]

   📥 Downloaded 150/150 A3M files


Batch 1 hhfilter:   0%|          | 0/50 [00:00<?, ?prot/s]

📝 Batch 1: ✅ 50 Success | ⏭️ 0 Skipped | ❌ 0 Failed


Batch 2 downloads:   0%|          | 0/150 [00:00<?, ?file/s]

   📥 Downloaded 150/150 A3M files


Batch 2 hhfilter:   0%|          | 0/50 [00:00<?, ?prot/s]

📝 Batch 2: ✅ 50 Success | ⏭️ 0 Skipped | ❌ 0 Failed


Batch 3 downloads:   0%|          | 0/150 [00:00<?, ?file/s]

   📥 Downloaded 148/150 A3M files


Batch 3 hhfilter:   0%|          | 0/50 [00:00<?, ?prot/s]

📝 Batch 3: ✅ 50 Success | ⏭️ 0 Skipped | ❌ 0 Failed


Batch 4 downloads:   0%|          | 0/150 [00:00<?, ?file/s]

   📥 Downloaded 150/150 A3M files


Batch 4 hhfilter:   0%|          | 0/50 [00:00<?, ?prot/s]

📝 Batch 4: ✅ 50 Success | ⏭️ 0 Skipped | ❌ 0 Failed


Batch 5 downloads:   0%|          | 0/99 [00:00<?, ?file/s]

   📥 Downloaded 99/99 A3M files


Batch 5 hhfilter:   0%|          | 0/33 [00:00<?, ?prot/s]

📝 Batch 5: ✅ 33 Success | ⏭️ 0 Skipped | ❌ 0 Failed

🎉 Pipeline completed!
📁 Results saved to: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\problematic_proteins_openfold_mapping_updated.tsv

📊 Final Statistics:
   ✅ Success: 233 (100.0%)
   ⏭️  Skipped: 0 (0.0%)
   ❌ Failed:  0 (0.0%)


In [3]:
import os
import glob

# Base path to your protein folders
base_path = r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB"

# Protein folders to search
folders = ["protein_test_pdb", "protein_train_pdb", "protein_val_pdb"]

# Protein ID to find
target_protein = "3J92-F"

print(f"Searching for protein: {target_protein}")
print("-" * 50)

found = False
for folder in folders:
    folder_path = os.path.join(base_path, folder)
    if os.path.exists(folder_path):
        print(f"Searching in {folder}...")
        
        # Search for the protein ID in folder names/files
        search_pattern = os.path.join(folder_path, f"*{target_protein}*")
        matches = glob.glob(search_pattern)
        
        if matches:
            print(f"  ✓ FOUND in {folder}:")
            for match in matches:
                print(f"    {match}")
            found = True
        else:
            print(f"  ✗ Not found in {folder}")
    else:
        print(f"  ! Folder {folder} doesn't exist")

if not found:
    print(f"\nProtein {target_protein} was not found in any folder.")
else:
    print(f"\nSearch complete!")

Searching for protein: 3J92-F
--------------------------------------------------
Searching in protein_test_pdb...
  ✗ Not found in protein_test_pdb
Searching in protein_train_pdb...
  ✓ FOUND in protein_train_pdb:
    C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB\protein_train_pdb\3J92-F
Searching in protein_val_pdb...
  ✗ Not found in protein_val_pdb

Search complete!


In [None]:
#!/usr/bin/env python3
"""
Re-run MSA pipeline for the last 28 troublesome chains.

• Removes any stale final_filtered_256_stripped.a3m
• Downloads raw A3Ms from the OpenFold bucket
• hhfilter → length strip → diversity (≤256)
• Over-writes the row *in whichever TSV the ID lives*.
"""

from __future__ import annotations
import csv, re, shutil, subprocess, itertools
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Tuple, Optional
from datetime import datetime

import boto3
from boto3.s3.transfer import TransferConfig
from botocore.config import Config
from botocore import UNSIGNED
import numpy as np
from scipy.spatial.distance import pdist, squareform
from tqdm.auto import tqdm

# ─── File system layout ────────────────────────────────────────────────
ROOT = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing")
HEAL_ROOT = ROOT / "data" / "HEAL_PDB"
SPLITS = {
    "train": HEAL_ROOT / "protein_train_pdb",
    "val"  : HEAL_ROOT / "protein_val_pdb",
    "test" : HEAL_ROOT / "protein_test_pdb",
}

TSV_REMAINING = ROOT / "heal_remaining_protein_matches_updated.tsv"
TSV_FAILED    = ROOT / "heal_failed_protein_matches_processed.tsv"

# ─── Chains to fix ─────────────────────────────────────────────────────
TODO = {
    "3H4P-A", "3J79-I", "3JB9-I", "3JB9-R", "3JD5-F", "3JD5-J", "4UDF-1B",
    "4V6U-BL", "4V7E-CE", "4V8T-O", "5IT7-CC", "5IT7-EE", "5IT7-HH", "5IT7-II",
    "5IT7-RR", "5LNK-N", "5NGM-AO", "5OOL-W", "5OPT-P", "5OPT-R", "5OQL-U",
    "5OQL-X", "5T5H-I", "5T5H-L", "5T5H-V", "5T5H-W", "5V93-F", "5XXB-E",
    "5XXB-F", "5XXB-G", "5XXB-O", "5XXU-A", "5XXU-B", "5XXU-C", "5XXU-F",
    "5XY3-E", "5XY3-F", "5XY3-G", "5XY3-I", "5XY3-J", "5XY3-O", "5XYI-B",
    "5YZG-W", "5ZWN-Y", "6AZ1-A", "6AZ1-B", "6AZ1-C", "6DZI-T", "6GAZ-AN",
    "6GAZ-AP", "6GB2-BE", "6GCS-D", "6GCS-F", "6GCS-H", "6GCS-J", "6GIQ-D",
    "6HIX-AV", "6HIX-BG"
}

# ─── OpenFold bucket info ──────────────────────────────────────────────
BUCKET       = "openfold"
A3M_FILES    = ("bfd_uniclust_hits.a3m","mgnify_hits.a3m","uniref90_hits.a3m")
DL_THREADS   = 14
PROC_THREADS = 8
MAX_ROWS     = 256
TRANSFER_CFG = TransferConfig(max_concurrency=DL_THREADS)
S3           = boto3.client('s3', config=Config(signature_version=UNSIGNED))

LOWER = ''.join(map(chr, range(97,123)))

# ─── Helper functions ──────────────────────────────────────────────────
def to_wsl(path: Path) -> str:
    """Convert Windows path to WSL path."""
    posix = path.resolve().as_posix()
    return posix if posix.startswith("/mnt/") else f"/mnt/{path.drive[0].lower()}{posix[2:]}"

def parse_chain_id(pid:str)->Tuple[str,str]:
    pid = pid.replace('_','-')
    if '-' in pid:
        pdb, chain = pid.split('-',1)
    elif re.fullmatch(r"[0-9][A-Za-z0-9]{3}[A-Za-z0-9]{1,3}", pid):
        pdb, chain = pid[:4], pid[4:]
    else:
        raise ValueError(f"Bad chain id: {pid}")
    return pdb.lower(), chain            # ← chain case kept!

def strip_insertions(seq:str)->str:
    return seq.translate({ord(c):None for c in LOWER+".*"})

def load_a3m(p:Path):
    out, hdr, buf = [], None, []
    for ln in p.read_text().splitlines():
        if ln.startswith('>'):
            if hdr: out.append((hdr,''.join(buf)))
            hdr, buf = ln[1:], []
        else:
            buf.append(ln)
    if hdr: out.append((hdr,''.join(buf)))
    return out

def diversity_max(msa:list[Tuple[str,str]], k:int):
    if len(msa)<=k: return msa
    arr = np.array([list(s) for _,s in msa],dtype='U1')
    alpha = sorted({c for row in arr for c in row})
    to_i  = {a:i for i,a in enumerate(alpha)}
    intarr= np.vectorize(to_i.get)(arr)
    dist  = squareform(pdist(intarr,metric='hamming'))
    keep  = [0]; sel = np.zeros(len(msa),bool); sel[0]=True
    while sel.sum()<k:
        mean = dist[:,sel].mean(1); mean[sel]=-1
        nxt=int(mean.argmax())
        if mean[nxt]<=0: break
        sel[nxt]=True; keep.append(nxt)
    return [msa[i] for i in keep]

def bucket_download(key:str, dest:Path)->bool:
    try:
        if dest.exists(): return True
        dest.parent.mkdir(parents=True,exist_ok=True)
        S3.download_file(BUCKET, key, str(dest), Config=TRANSFER_CFG)
        return True
    except Exception as e:
        print(f"  ⚠️  Download failed for {key}: {e}")
        return False

def find_chain_dir(original_id:str)->Optional[Path]:
    """Find directory with exact case-sensitive matching."""
    print(f"  🔍 Looking for: {original_id}")
    
    for split_name, split in SPLITS.items():
        if not split.exists(): 
            continue
            
        # Direct check with exact case
        direct_path = split / original_id
        if direct_path.exists() and direct_path.is_dir():
            print(f"    ✅ Found in {split_name}: {direct_path}")
            return direct_path
            
        # Scan all directories for exact match
        for d in split.iterdir():
            if d.is_dir() and d.name == original_id:
                print(f"    ✅ Found in {split_name}: {d}")
                return d
                
        # Log case-insensitive matches for debugging
        case_insensitive = [d for d in split.iterdir() 
                           if d.is_dir() and d.name.lower() == original_id.lower()]
        if case_insensitive:
            print(f"    ⚠️  Case-insensitive matches in {split_name}: {[d.name for d in case_insensitive]}")
    
    print(f"    ❌ Not found!")
    return None

def delete_stale(chain_dir:Path):
    (chain_dir/"final_filtered_256_stripped.a3m").unlink(missing_ok=True)
    shutil.rmtree(chain_dir/"_tmp_a3m", ignore_errors=True)

def process(chain_dir:Path, matched_id:str):
    print(f"\n📋 Processing {chain_dir.name} → {matched_id}")
    pdb, chain = parse_chain_id(matched_id)
    tmp = chain_dir/"_tmp_a3m"; tmp.mkdir(exist_ok=True)

    # Download all A3M files
    downloaded = []
    for f in A3M_FILES:
        key  = f"pdb/{pdb}_{chain}/a3m/{f}"  # pdb lower, chain as-is
        dest = tmp/f
        if bucket_download(key, dest):
            downloaded.append(f)
            print(f"    ✓ Downloaded {f}")
        else:
            print(f"    ✗ Failed to download {f}")
    
    if not downloaded:
        shutil.rmtree(tmp,ignore_errors=True)
        return 0,0,0,"Failed: no downloads"

    # Concatenate available files
    raw = tmp/"all_raw.a3m"
    with raw.open('w') as out:
        for f in downloaded:
            out.write((tmp/f).read_text())

    # Run hhfilter with WSL path conversion
    hh_out = tmp/f"all_hhfiltered_{MAX_ROWS}.a3m"
    res = subprocess.run([
        "wsl","hhfilter",
        "-i",to_wsl(raw),
        "-o",to_wsl(hh_out),
        "-diff",str(MAX_ROWS)
    ], capture_output=True, text=True)
    
    if res.returncode:
        print(f"    ❌ hhfilter failed: {res.stderr}")
        shutil.rmtree(tmp,ignore_errors=True)
        return 0,0,0,f"Failed: hhfilter ({res.returncode})"

    msa = load_a3m(hh_out)
    hh_rows = len(msa)
    print(f"    📊 hhfilter output: {hh_rows} rows")
    
    if not msa:
        shutil.rmtree(tmp,ignore_errors=True)
        return 0,hh_rows,0,"Failed: hh-empty"

    # Strip insertions
    tgt_len = len(strip_insertions(msa[0][1]))
    kept, dropped = [], 0
    for h,s in msa:
        clean = strip_insertions(s)
        if len(clean) == tgt_len: 
            kept.append((h,clean))
        else: 
            dropped += 1
    
    print(f"    🧹 Stripped: {len(kept)} kept, {dropped} dropped")
    
    if len(kept) < 2:
        shutil.rmtree(tmp,ignore_errors=True)
        return dropped,hh_rows,0,f"Failed: only {len(kept)} rows"

    # Diversity filtering
    div = 0
    if len(kept) > MAX_ROWS:
        bef = len(kept)
        kept = diversity_max(kept, MAX_ROWS)
        div = bef - len(kept)
        print(f"    🎯 Diversity filtered: {bef} → {len(kept)}")

    # Write final A3M
    final_path = chain_dir/"final_filtered_256_stripped.a3m"
    with final_path.open('w') as fh:
        for h,s in kept: 
            fh.write(f">{h}\n{s}\n")
    
    print(f"    ✅ Success! Final MSA: {len(kept)} rows")
    shutil.rmtree(tmp,ignore_errors=True)
    return dropped,hh_rows,div,"Success"

# ─── TSV helpers ────────────────────────────────────────────────────────
def load_tsv(p:Path): 
    return list(csv.DictReader(p.open(),delimiter='\t')) if p.exists() else []

def save_tsv(p:Path, rows):
    if not rows: return
    # Create backup
    if p.exists():
        backup = p.with_suffix(f'.{datetime.now().strftime("%Y%m%d_%H%M%S")}.bak')
        shutil.copy2(p, backup)
        print(f"  💾 Created backup: {backup.name}")
    
    with p.open('w',newline='') as fh:
        wr = csv.DictWriter(fh,delimiter='\t',fieldnames=rows[0].keys())
        wr.writeheader()
        wr.writerows(rows)

# ─── Main ───────────────────────────────────────────────────────────────
def main():
    print("🔧 Re-running MSA pipeline for 28 proteins")
    print("=" * 60)
    
    # Verify hhfilter
    if subprocess.run(["wsl","which","hhfilter"],capture_output=True).returncode:
        raise RuntimeError("hhfilter not found in WSL")

    # Load TSVs
    rows_rem = load_tsv(TSV_REMAINING)
    rows_fail = load_tsv(TSV_FAILED)
    
    # Build index
    index: Dict[str,Tuple[str,int]] = {}
    for i,r in enumerate(rows_rem): 
        index[r['original_id']] = ('rem',i)
    for i,r in enumerate(rows_fail): 
        index[r['original_id']] = ('fail',i)

    # Find proteins in TSVs
    todo = [pid for pid in TODO if pid in index]
    missing = TODO - set(todo)
    
    print(f"\n📊 Status:")
    print(f"  - To process: {len(todo)}/{len(TODO)}")
    if missing:
        print(f"  - Not found in TSVs: {missing}")

    # Clean up old files
    print(f"\n🧹 Cleaning old A3M files...")
    for pid in todo:
        d = find_chain_dir(pid)
        if d: 
            delete_stale(d)

    # Process proteins
    print(f"\n🚀 Processing {len(todo)} proteins...")
    with ThreadPoolExecutor(max_workers=PROC_THREADS) as pool:
        fut2pid = {}
        
        for pid in todo:
            which, idx = index[pid]
            row = rows_rem[idx] if which=='rem' else rows_fail[idx]
            d = find_chain_dir(pid)
            
            if d is None:
                print(f"\n❌ Directory not found for {pid}")
                # Update status in the appropriate list
                if which == 'rem':
                    rows_rem[idx]['status'] = "Failed: dir-missing"
                else:
                    rows_fail[idx]['status'] = "Failed: dir-missing"
                continue
                
            fut2pid[pool.submit(process, d, row['matched_id'])] = pid

        # Process results
        for fut in tqdm(as_completed(fut2pid), total=len(fut2pid),
                        desc="Processing", unit="prot"):
            pid = fut2pid[fut]
            which, idx = index[pid]
            
            try:
                dropped, hh, div, status = fut.result()
                row = rows_rem[idx] if which=='rem' else rows_fail[idx]
                row['n_rows_dropped'] = str(dropped)
                row['hhfilter_rows'] = str(hh)
                row['diversity_filtered_rows'] = str(div)
                row['status'] = status
            except Exception as e:
                print(f"\n❌ Error processing {pid}: {e}")
                row = rows_rem[idx] if which=='rem' else rows_fail[idx]
                row['status'] = f"Failed: {str(e)[:50]}"

    # Save updated TSVs
    print(f"\n💾 Saving updated TSVs...")
    save_tsv(TSV_REMAINING, rows_rem)
    save_tsv(TSV_FAILED, rows_fail)

    # Summary
    success_count = 0
    for pid in todo:
        which, idx = index[pid]
        row = rows_rem[idx] if which=='rem' else rows_fail[idx]
        if row['status'] == "Success":
            success_count += 1
    
    print(f"\n✅ Complete!")
    print(f"📊 Results: {success_count}/{len(todo)} succeeded")
    
    # Show failures
    failures = []
    for pid in todo:
        which, idx = index[pid]
        row = rows_rem[idx] if which=='rem' else rows_fail[idx]
        if row['status'] != "Success":
            failures.append(f"{pid}: {row['status']}")
    
    if failures:
        print(f"\n❌ Failed proteins:")
        for f in failures[:10]:
            print(f"  - {f}")
        if len(failures) > 10:
            print(f"  ... and {len(failures)-10} more")

if __name__=="__main__":
    main()

🔧 Re-running MSA pipeline for 28 proteins

📊 Status:
  - To process: 28/28

🧹 Cleaning old A3M files...
  🔍 Looking for: 4V3P-LB
    ✅ Found in train: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB\protein_train_pdb\4V3P-LB
  🔍 Looking for: 4V6W-CR
    ✅ Found in test: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB\protein_test_pdb\4V6W-CR
  🔍 Looking for: 4UDF-1B
    ✅ Found in val: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB\protein_val_pdb\4UDF-1B
  🔍 Looking for: 5LJ5-S
    ✅ Found in test: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB\protein_test_pdb\5LJ5-S
  🔍 Looking for: 5AJ4-AK
    ✅ Found in train: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB\protein_train_pdb\5AJ4-AK
  🔍 Looking for: 4V6W-CC
    ✅ Found in test: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PDB\protein_test_pdb\4V6W-CC
  🔍 Looking for: 4V8T-O
    ✅ Found in test: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\HEAL_PD

Processing:   0%|          | 0/28 [00:00<?, ?prot/s]

    ✓ Downloaded uniref90_hits.a3m
    📊 hhfilter output: 1559 rows
    🧹 Stripped: 1559 kept, 0 dropped
    ✓ Downloaded bfd_uniclust_hits.a3m
    ✓ Downloaded bfd_uniclust_hits.a3m
    ✓ Downloaded bfd_uniclust_hits.a3m
    ✓ Downloaded mgnify_hits.a3m
    ✓ Downloaded mgnify_hits.a3m
    ✓ Downloaded mgnify_hits.a3m
    🎯 Diversity filtered: 1559 → 256
    ✅ Success! Final MSA: 256 rows

📋 Processing 4V8M-BU → 4v8m_BU
    ✓ Downloaded uniref90_hits.a3m
    ✓ Downloaded bfd_uniclust_hits.a3m
    ✓ Downloaded uniref90_hits.a3m    ✓ Downloaded uniref90_hits.a3m

    ✓ Downloaded mgnify_hits.a3m
    📊 hhfilter output: 315 rows
    🧹 Stripped: 315 kept, 0 dropped
    🎯 Diversity filtered: 315 → 256
    ✅ Success! Final MSA: 256 rows

📋 Processing 5T2A-V → 5t2a_V
    📊 hhfilter output: 780 rows
    🧹 Stripped: 780 kept, 0 dropped
    ✓ Downloaded uniref90_hits.a3m
    ✓ Downloaded bfd_uniclust_hits.a3m
    📊 hhfilter output: 321 rows
    🧹 Stripped: 321 kept, 0 dropped
    🎯 Diversity fil

In [42]:
# %%time
"""
QA check for HEAL_PDB A3Ms
--------------------------
1. Verify every protein (except the exclusions) has final_filtered_256_stripped.a3m
2. For each protein, check Biopython pairwise identity between sequence.txt and
   the stripped query sequence from that A3M.
   • Pass threshold: 95 %
If any A3M is missing → raises RuntimeError after printing offenders.
"""

from pathlib import Path
from tqdm.auto import tqdm
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
import re, csv, textwrap

# ─── paths ──────────────────────────────────────────────────────────────
ROOT      = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing")
HEAL_ROOT = ROOT / "data" / "PDBCH"
SPLITS = {
    "train": HEAL_ROOT / "train_pdbch",
    "val"  : HEAL_ROOT / "val_pdbch",
    "test" : HEAL_ROOT / "test_pdbch",
}

A3M_NAME   = "final_filtered_256_stripped.a3m"
EXCLUDE = {
    "1W8X-M", "2WWX-B", "3M7G-A", "5GAO-E", "5OXE-A", "1KVE-A", "4BTP-A",
}
THRESH     = 0.95

# ─── helper: strip insertions from A3M sequence ─────────────────────────
LOWER = ''.join(chr(c) for c in range(97,123))  # a-z
def strip_a3m(seq: str) -> str:
    return seq.translate({ord(c):None for c in LOWER + ".*"})

def first_a3m_sequence(p: Path) -> str:
    hdr_seen = False
    with p.open() as fh:
        for ln in fh:
            if ln.startswith('>'):
                hdr_seen = True
            elif hdr_seen:
                return strip_a3m(ln.rstrip())
    return ""

# ─── Biopython similarity (globalxx) ────────────────────────────────────
from Bio import pairwise2
def similarity(seq1:str, seq2:str)->float:
    score = pairwise2.align.globalxx(seq1, seq2, score_only=True)
    return score / max(len(seq1), len(seq2)) if max(len(seq1),len(seq2)) else 0.0

# ─── collect all protein dirs ───────────────────────────────────────────
dirs = []
for folder in SPLITS.values():
    dirs.extend([d for d in folder.iterdir() if d.is_dir() and d.name not in EXCLUDE])

print(f"Total proteins to inspect (after exclusions): {len(dirs):,}")

# ─── pass 1: ensure A3M exists ──────────────────────────────────────────
missing = []
for d in tqdm(dirs, desc="Checking A3M presence", unit="prot"):
    if not (d / A3M_NAME).exists():
        missing.append(d.name)

if missing:
    print("\n❌ The following proteins are missing the hhfilter A3M:")
    print(textwrap.fill(' '.join(sorted(missing)), width=100))
    raise RuntimeError("A3M file(s) missing – fix before continuing!")

print("✓ Every protein has the hhfilter A3M\n")

# ─── pass 2: pairwise identity check (multithreaded) ────────────────────
def check_identity_worker(d: Path) -> tuple[str, float] | None:
    """Worker function for identity checking"""
    seq_txt = d / "sequence.txt"
    a3m     = d / A3M_NAME
    try:
        seq_query = first_a3m_sequence(a3m)
        seq_local = ''.join(c for c in seq_txt.read_text().strip() if c.isalpha())
        sim = similarity(seq_query, seq_local)
        if sim < THRESH:
            return (d.name, sim)
        return None
    except Exception as e:
        tqdm.write(f"[ERR] {d.name}: {e}")
        return (d.name, 0.0)

fails = []
with ThreadPoolExecutor(max_workers=8) as executor:
    future_to_dir = {executor.submit(check_identity_worker, d): d for d in dirs}
    
    for future in tqdm(as_completed(future_to_dir), total=len(dirs), 
                       desc="Pairwise identity check", unit="prot"):
        result = future.result()
        if result is not None:
            fails.append(result)

# ─── summary ────────────────────────────────────────────────────────────
print("\n" + "="*60)
print("QA SUMMARY")
print("="*60)
print(f"Proteins checked : {len(dirs):,}")
print(f"Failed identity  : {len(fails):,} (< {THRESH:.2f})")

if fails:
    print("\n⟹ Proteins failing identity threshold:")
    for pid, s in sorted(fails, key=lambda x: x[0]):
        print(f"  {pid:10s}  {s*100:5.1f} %")

else:
    print("\n🎉 All proteins ≥ 95 % identity – good to go!")

Total proteins to inspect (after exclusions): 36,622


Checking A3M presence:   0%|          | 0/36622 [00:00<?, ?prot/s]

✓ Every protein has the hhfilter A3M



Pairwise identity check:   0%|          | 0/36622 [00:00<?, ?prot/s]


QA SUMMARY
Proteins checked : 36,622
Failed identity  : 2 (< 0.95)

⟹ Proteins failing identity threshold:
  4UDF-1B      79.2 %
  4V8T-O       90.9 %


SystemExit: ❌ Root directory not found: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\notebooks\--f=c:\Users\rfrjo\AppData\Roaming\jupyter\runtime\kernel-v323960d60bfa4c55e8e4d0be3c857aed9127406ae.json

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [11]:
#!/usr/bin/env python3
"""
Fix corrupted sequence.txt files by matching with exact case-sensitive IDs
==========================================================================

This script:
1. Reads all protein folders from train/val/test splits
2. Loads sequences from nrPDB-GO_2019.06.18_sequences.fasta
3. Compares sequence.txt with the correct sequence (case-sensitive match)
4. Fixes mismatches by replacing sequence.txt with correct content
5. Reports all fixes made
"""

from pathlib import Path
from tqdm.auto import tqdm
from collections import defaultdict
import textwrap

# ─── PATHS ──────────────────────────────────────────────────────────────
ROOT = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing")
HEAL_ROOT = ROOT / "data" / "HEAL_PDB"
SPLITS = {
    "train": HEAL_ROOT / "protein_train_pdb",
    "val": HEAL_ROOT / "protein_val_pdb",
    "test": HEAL_ROOT / "protein_test_pdb",
}

FASTA_FILE = ROOT / "data" / "nrPDB-GO_2019.06.18_sequences.fasta"

# ─── LOAD FASTA WITH EXACT CASE ─────────────────────────────────────────
def load_fasta_exact_case():
    """Load FASTA file preserving exact case of IDs."""
    print("📄 Loading FASTA file with case-sensitive IDs...")
    sequences = {}
    current_id = None
    current_seq = []
    
    with FASTA_FILE.open() as fh:
        for line in fh:
            line = line.strip()
            if line.startswith('>'):
                # Save previous sequence
                if current_id:
                    sequences[current_id] = ''.join(current_seq)
                
                # Parse new ID - take everything after '>' up to first space
                # This preserves exact case
                full_header = line[1:]
                current_id = full_header.split()[0] if ' ' in full_header else full_header
                current_seq = []
            else:
                current_seq.append(line)
        
        # Don't forget the last sequence
        if current_id:
            sequences[current_id] = ''.join(current_seq)
    
    print(f"  ✓ Loaded {len(sequences):,} sequences")
    
    # Show some examples to verify case preservation
    examples = list(sequences.keys())[:5]
    print("\n  Sample IDs (showing exact case):")
    for ex in examples:
        print(f"    {ex}")
    
    return sequences

# ─── CHECK AND FIX SEQUENCES ────────────────────────────────────────────
def check_and_fix_sequences(fasta_sequences):
    """Check all protein folders and fix sequence mismatches."""
    print("\n🔍 Checking all protein folders...")
    
    # Collect all protein directories
    all_dirs = []
    for split_name, split_path in SPLITS.items():
        if not split_path.exists():
            print(f"  ⚠️  Split path doesn't exist: {split_path}")
            continue
        
        dirs = [d for d in split_path.iterdir() if d.is_dir()]
        print(f"  Found {len(dirs)} proteins in {split_name}")
        all_dirs.extend([(split_name, d) for d in dirs])
    
    print(f"\n📊 Total protein folders to check: {len(all_dirs)}")
    
    # Check each directory
    mismatches = []
    not_found = []
    fixed = []
    errors = []
    
    for split_name, protein_dir in tqdm(all_dirs, desc="Checking sequences"):
        protein_id = protein_dir.name  # This is the exact case-sensitive folder name
        seq_file = protein_dir / "sequence.txt"
        
        # Check if sequence.txt exists
        if not seq_file.exists():
            errors.append((protein_id, split_name, "No sequence.txt"))
            continue
        
        # Read current sequence
        try:
            current_seq = ''.join(c for c in seq_file.read_text().strip() if c.isalpha())
        except Exception as e:
            errors.append((protein_id, split_name, f"Read error: {e}"))
            continue
        
        # Look up correct sequence (exact case match)
        if protein_id not in fasta_sequences:
            not_found.append((protein_id, split_name))
            # Try case-insensitive search to help debug
            case_insensitive_matches = [k for k in fasta_sequences.keys() 
                                       if k.lower() == protein_id.lower()]
            if case_insensitive_matches:
                tqdm.write(f"  ⚠️  {protein_id} not found, but found case variants: {case_insensitive_matches}")
            continue
        
        correct_seq = fasta_sequences[protein_id]
        
        # Compare sequences
        if current_seq != correct_seq:
            mismatches.append((protein_id, split_name))
            
            # Fix the sequence
            try:
                seq_file.write_text(correct_seq + '\n')
                fixed.append((protein_id, split_name))
                
                # Show details for first few fixes
                if len(fixed) <= 3:
                    tqdm.write(f"\n  🔧 Fixed {protein_id}:")
                    tqdm.write(f"     Old length: {len(current_seq)}")
                    tqdm.write(f"     New length: {len(correct_seq)}")
                    if len(current_seq) > 0 and len(correct_seq) > 0:
                        # Check if it's completely different or just truncated
                        if current_seq[:10] != correct_seq[:10]:
                            tqdm.write(f"     Old start: {current_seq[:30]}...")
                            tqdm.write(f"     New start: {correct_seq[:30]}...")
                        else:
                            tqdm.write(f"     Sequences had same start but different lengths")
            except Exception as e:
                errors.append((protein_id, split_name, f"Write error: {e}"))
    
    # Print summary
    print("\n" + "="*60)
    print("SEQUENCE FIX SUMMARY")
    print("="*60)
    print(f"Total checked      : {len(all_dirs):,}")
    print(f"Sequences correct  : {len(all_dirs) - len(mismatches) - len(not_found) - len(errors):,}")
    print(f"Mismatches found   : {len(mismatches):,}")
    print(f"Successfully fixed : {len(fixed):,}")
    print(f"Not found in FASTA : {len(not_found):,}")
    print(f"Errors             : {len(errors):,}")
    
    # Show all mismatches
    if mismatches:
        print(f"\n❌ All {len(mismatches)} proteins with sequence mismatches:")
        by_split = defaultdict(list)
        for pid, split in mismatches:
            by_split[split].append(pid)
        
        for split in ['train', 'val', 'test']:
            if split in by_split:
                print(f"\n  {split} ({len(by_split[split])} proteins):")
                # Group into lines of ~10 proteins each
                proteins = sorted(by_split[split])
                for i in range(0, len(proteins), 10):
                    batch = proteins[i:i+10]
                    print(f"    {' '.join(batch)}")
    
    # Show not found
    if not_found:
        print(f"\n⚠️  {len(not_found)} proteins not found in FASTA (need investigation):")
        for pid, split in not_found[:20]:  # Show first 20
            print(f"    {pid:15s} ({split})")
        if len(not_found) > 20:
            print(f"    ... and {len(not_found)-20} more")
    
    # Show errors
    if errors:
        print(f"\n❌ {len(errors)} errors encountered:")
        for pid, split, err in errors[:10]:
            print(f"    {pid:15s} ({split}): {err}")
        if len(errors) > 10:
            print(f"    ... and {len(errors)-10} more")
    
    # Final status
    if fixed:
        print(f"\n✅ Successfully fixed {len(fixed)} sequence files!")
        if len(fixed) != len(mismatches):
            print(f"⚠️  Warning: Found {len(mismatches)} mismatches but only fixed {len(fixed)}")
    else:
        print("\n✅ No sequence fixes needed - all sequences match!")
    
    return mismatches, fixed, not_found, errors

# ─── MAIN ───────────────────────────────────────────────────────────────
def main():
    print("🔧 Case-Sensitive Sequence Fix Tool")
    print("=" * 60)
    
    # Load FASTA with exact case preservation
    fasta_sequences = load_fasta_exact_case()
    
    # Check and fix sequences
    mismatches, fixed, not_found, errors = check_and_fix_sequences(fasta_sequences)
    
    # If there were fixes, suggest re-running the QA check
    if fixed:
        print("\n💡 Recommendation: Re-run your QA check script to verify all sequences now pass!")
    
    print("\n✨ Done!")

if __name__ == "__main__":
    main()

🔧 Case-Sensitive Sequence Fix Tool
📄 Loading FASTA file with case-sensitive IDs...
  ✓ Loaded 36,641 sequences

  Sample IDs (showing exact case):
    11AS-A
    154L-A
    155C-A
    16PK-A
    16VP-A

🔍 Checking all protein folders...
  Found 29738 proteins in train
  Found 3318 proteins in val
  Found 3398 proteins in test

📊 Total protein folders to check: 36454


Checking sequences:   0%|          | 0/36454 [00:00<?, ?it/s]


  🔧 Fixed 3J79-I:
     Old length: 104
     New length: 221
     Old start: MVNVPKTRKTYCSNKCKKHTMHKVSQYKKG...
     New start: MTNTSNELKHYNVKGKKKVLVPVNAKKTIN...

  🔧 Fixed 3J92-F:
     Old length: 110
     New length: 250
     Old start: MSGRLWSKAIFAGYKRGLRNQREHTALLKI...
     New start: MEGADVKEKKKKVPAVPETLKKKRKNFAEL...

  🔧 Fixed 3JB9-G:
     Old length: 558
     New length: 115
     Old start: MLVANYSSDSEEQENSQSPNIQPLLHTENL...
     New start: MADLVDKPRSELSEIELARLEEYEFSAGPL...

SEQUENCE FIX SUMMARY
Total checked      : 36,454
Sequences correct  : 36,372
Mismatches found   : 82
Successfully fixed : 82
Not found in FASTA : 0
Errors             : 0

❌ All 82 proteins with sequence mismatches:

  train (74 proteins):
    3J79-I 3J92-F 3JB9-G 3JB9-I 3JB9-R 3JD5-F 3JD5-J 4V3P-LA 4V3P-LB 4V3P-LE
    4V3P-LJ 4V6U-BL 4V6W-AF 4V7E-BG 4V7E-CE 4V7E-CU 4V8M-BK 4V8M-BR 4V8M-BS 4V8M-BU
    4V8M-BV 5AJ4-AK 5AJ4-AP 5IT7-CC 5IT7-EE 5IT7-HH 5IT7-II 5IT7-RR 5LNK-N 5NGM-AO
    5OOL-W 5OPT-L 5OPT-P 5OPT-R 