<h1>Initial Prot: GO, Graph</h1>

In [3]:
import tensorflow as tf
import torch
import numpy as np
import os
import csv
from tqdm import tqdm
import math
import warnings
from pathlib import Path

def process_tfrecords_to_pytorch_format():
    """
    Processes all PDB-GO TFRecord files and organizes data into PyTorch-friendly format.
    
    Directory structure created:
    protein_data_pdb/
    ├── {protein_id}/
    │   ├── mf_labels.pt (Molecular Function labels)
    │   ├── cc_labels.pt (Cellular Component labels) 
    │   ├── bp_labels.pt (Biological Process labels)
    │   ├── ca_dist_matrix.pt (C-alpha distance matrix)
    │   ├── cb_dist_matrix.pt (C-beta distance matrix)
    │   └── L.csv (sequence length)
    """
    
    # Configuration
    base_path = r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\PDB-GO"
    output_dir = "protein_data_pdb"
    num_files = 30
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    print(f"Created output directory: {output_dir}")
    
    # Statistics tracking
    total_proteins_processed = 0
    failed_proteins = 0
    files_processed = 0
    
    # Process all 30 TFRecord files
    print(f"Processing {num_files} TFRecord files...")
    
    for file_idx in tqdm(range(num_files), desc="Processing TFRecord files", position=0):
        file_path = os.path.join(base_path, f"PDB_GO_train_{file_idx:02d}-of-30.tfrecords")
        
        # Check if file exists
        if not os.path.exists(file_path):
            print(f"Warning: File {file_path} not found, skipping...")
            continue
        
        try:
            # Count records first for progress bar
            print(f"\nCounting records in file {file_idx:02d}...")
            dataset = tf.data.TFRecordDataset(file_path)
            total_records = sum(1 for _ in dataset)
            print(f"Found {total_records} records in file {file_idx:02d}")
            
            # Process records
            dataset = tf.data.TFRecordDataset(file_path)
            file_proteins_processed = 0
            
            for raw_record in tqdm(dataset, desc=f"File {file_idx:02d}", total=total_records, position=1, leave=False):
                try:
                    # Parse the TFRecord
                    example = tf.train.Example()
                    example.ParseFromString(raw_record.numpy())
                    
                    # Extract protein ID
                    prot_id = example.features.feature['prot_id'].bytes_list.value[0].decode('utf-8')
                    
                    # Create protein directory
                    protein_dir = os.path.join(output_dir, prot_id)
                    os.makedirs(protein_dir, exist_ok=True)
                    
                    # Extract L value first (sequence length)
                    L_value = None
                    if 'L' in example.features.feature:
                        L_value = example.features.feature['L'].int64_list.value[0]
                        
                        # Save L value to CSV
                        with open(os.path.join(protein_dir, 'L.csv'), 'w', newline='') as csvfile:
                            writer = csv.writer(csvfile)
                            writer.writerow(['L'])
                            writer.writerow([L_value])
                    
                    # Extract and save label features
                    label_features = {
                        'mf_labels': 'Molecular Function',
                        'cc_labels': 'Cellular Component', 
                        'bp_labels': 'Biological Process'
                    }
                    
                    for label_type, label_name in label_features.items():
                        if label_type in example.features.feature:
                            labels = list(example.features.feature[label_type].int64_list.value)
                            labels_tensor = torch.tensor(labels, dtype=torch.long)
                            torch.save(labels_tensor, os.path.join(protein_dir, f"{label_type}.pt"))
                    
                    # Extract and save distance matrices
                    matrix_features = ['ca_dist_matrix', 'cb_dist_matrix']
                    
                    for matrix_type in matrix_features:
                        if matrix_type in example.features.feature:
                            matrix_data = list(example.features.feature[matrix_type].float_list.value)
                            matrix_tensor = torch.tensor(matrix_data, dtype=torch.float32)
                            
                            # Try to reshape to square matrix if L is available
                            if L_value is not None:
                                expected_size = L_value * L_value
                                if len(matrix_data) == expected_size:
                                    matrix_tensor = matrix_tensor.reshape(L_value, L_value)
                                    # print(f"  Reshaped {matrix_type} to ({L_value}, {L_value}) for {prot_id}")
                                else:
                                    print(f"  Warning: {matrix_type} size {len(matrix_data)} doesn't match L²={expected_size} for {prot_id}")
                            
                            torch.save(matrix_tensor, os.path.join(protein_dir, f"{matrix_type}.pt"))
                    
                    file_proteins_processed += 1
                    total_proteins_processed += 1
                    
                except Exception as e:
                    print(f"\nError processing protein record in file {file_idx:02d}: {e}")
                    failed_proteins += 1
                    continue
            
            files_processed += 1
            print(f"Completed file {file_idx:02d}: {file_proteins_processed} proteins processed")
            
        except Exception as e:
            print(f"Error processing file {file_idx:02d}: {e}")
            continue
    
    # Print final statistics
    print("\n" + "="*60)
    print("PROCESSING COMPLETE!")
    print("="*60)
    print(f"Files processed: {files_processed}/{num_files}")
    print(f"Total proteins processed: {total_proteins_processed}")
    print(f"Failed proteins: {failed_proteins}")
    print(f"Success rate: {(total_proteins_processed/(total_proteins_processed + failed_proteins)*100):.2f}%")
    print(f"Output directory: {os.path.abspath(output_dir)}")
    
    # Verify a few sample outputs
    print("\nVerifying sample outputs...")
    sample_dirs = []
    for item in os.listdir(output_dir):
        item_path = os.path.join(output_dir, item)
        if os.path.isdir(item_path):
            sample_dirs.append(item)
            if len(sample_dirs) >= 3:  # Check first 3 found
                break
    
    for sample_dir in sample_dirs:
        sample_path = os.path.join(output_dir, sample_dir)
        files = os.listdir(sample_path)
        print(f"  {sample_dir}: {len(files)} files - {', '.join(files)}")

def verify_pytorch_format(protein_id=None, output_dir="protein_data_pdb"):
    """
    Utility function to verify the PyTorch format for a specific protein.
    """
    if protein_id is None:
        # Find the first available protein
        for item in os.listdir(output_dir):
            if os.path.isdir(os.path.join(output_dir, item)):
                protein_id = item
                break
    
    if protein_id is None:
        print("No protein directories found!")
        return
    
    protein_dir = os.path.join(output_dir, protein_id)
    print(f"\nVerifying PyTorch format for protein: {protein_id}")
    print("-" * 50)
    
    # Check each file type
    file_types = {
        'L.csv': 'Sequence length',
        'mf_labels.pt': 'Molecular Function labels',
        'cc_labels.pt': 'Cellular Component labels',
        'bp_labels.pt': 'Biological Process labels',
        'ca_dist_matrix.pt': 'C-alpha distance matrix',
        'cb_dist_matrix.pt': 'C-beta distance matrix'
    }
    
    for filename, description in file_types.items():
        filepath = os.path.join(protein_dir, filename)
        if os.path.exists(filepath):
            if filename.endswith('.pt'):
                tensor = torch.load(filepath)
                print(f"✓ {description}: shape={tensor.shape}, dtype={tensor.dtype}")
            elif filename.endswith('.csv'):
                with open(filepath, 'r') as f:
                    content = f.read().strip().split('\n')
                    print(f"✓ {description}: {content[1] if len(content) > 1 else 'N/A'}")
        else:
            print(f"✗ {description}: File not found")

if __name__ == "__main__":
    # Run the main processing
    process_tfrecords_to_pytorch_format()
    
    # Verify the format for one protein
    print("\n" + "="*60)
    print("VERIFICATION")
    print("="*60)
    verify_pytorch_format()
    
    print("\nScript completed successfully!")
    print("You can now use the organized data for PyTorch model inference.")

Created output directory: protein_data_pdb
Processing 30 TFRecord files...


Processing TFRecord files:   0%|          | 0/30 [00:00<?, ?it/s]


Counting records in file 00...
Found 996 records in file 00


Processing TFRecord files:   3%|▎         | 1/30 [01:07<32:48, 67.87s/it]

Completed file 00: 996 proteins processed

Counting records in file 01...
Found 996 records in file 01


Processing TFRecord files:   7%|▋         | 2/30 [01:49<24:30, 52.53s/it]

Completed file 01: 996 proteins processed

Counting records in file 02...
Found 996 records in file 02


Processing TFRecord files:  10%|█         | 3/30 [02:38<22:58, 51.04s/it]

Completed file 02: 996 proteins processed

Counting records in file 03...
Found 996 records in file 03


Processing TFRecord files:  13%|█▎        | 4/30 [03:26<21:26, 49.49s/it]

Completed file 03: 996 proteins processed

Counting records in file 04...
Found 994 records in file 04


Processing TFRecord files:  17%|█▋        | 5/30 [04:11<19:59, 48.00s/it]

Completed file 04: 994 proteins processed

Counting records in file 05...
Found 995 records in file 05


Processing TFRecord files:  20%|██        | 6/30 [04:56<18:51, 47.15s/it]

Completed file 05: 995 proteins processed

Counting records in file 06...
Found 996 records in file 06


Processing TFRecord files:  23%|██▎       | 7/30 [05:37<17:15, 45.01s/it]

Completed file 06: 996 proteins processed

Counting records in file 07...
Found 996 records in file 07


Processing TFRecord files:  27%|██▋       | 8/30 [06:20<16:19, 44.52s/it]

Completed file 07: 996 proteins processed

Counting records in file 08...
Found 996 records in file 08


Processing TFRecord files:  30%|███       | 9/30 [07:07<15:47, 45.13s/it]

Completed file 08: 996 proteins processed

Counting records in file 09...
Found 996 records in file 09


Processing TFRecord files:  33%|███▎      | 10/30 [07:52<15:01, 45.09s/it]

Completed file 09: 996 proteins processed

Counting records in file 10...
Found 996 records in file 10


Processing TFRecord files:  37%|███▋      | 11/30 [08:39<14:28, 45.72s/it]

Completed file 10: 996 proteins processed

Counting records in file 11...
Found 996 records in file 11


Processing TFRecord files:  40%|████      | 12/30 [09:24<13:39, 45.54s/it]

Completed file 11: 996 proteins processed

Counting records in file 12...
Found 996 records in file 12


Processing TFRecord files:  43%|████▎     | 13/30 [10:09<12:50, 45.35s/it]

Completed file 12: 996 proteins processed

Counting records in file 13...
Found 995 records in file 13


Processing TFRecord files:  47%|████▋     | 14/30 [10:50<11:42, 43.88s/it]

Completed file 13: 995 proteins processed

Counting records in file 14...
Found 996 records in file 14


Processing TFRecord files:  50%|█████     | 15/30 [11:31<10:46, 43.07s/it]

Completed file 14: 996 proteins processed

Counting records in file 15...
Found 995 records in file 15


Processing TFRecord files:  53%|█████▎    | 16/30 [12:17<10:15, 44.00s/it]

Completed file 15: 995 proteins processed

Counting records in file 16...
Found 996 records in file 16


Processing TFRecord files:  57%|█████▋    | 17/30 [13:01<09:33, 44.13s/it]

Completed file 16: 996 proteins processed

Counting records in file 17...
Found 996 records in file 17


Processing TFRecord files:  60%|██████    | 18/30 [13:51<09:10, 45.91s/it]

Completed file 17: 996 proteins processed

Counting records in file 18...
Found 996 records in file 18


Processing TFRecord files:  63%|██████▎   | 19/30 [14:36<08:21, 45.59s/it]

Completed file 18: 996 proteins processed

Counting records in file 19...
Found 996 records in file 19


Processing TFRecord files:  67%|██████▋   | 20/30 [15:24<07:41, 46.18s/it]

Completed file 19: 996 proteins processed

Counting records in file 20...
Found 995 records in file 20


Processing TFRecord files:  70%|███████   | 21/30 [16:09<06:52, 45.80s/it]

Completed file 20: 995 proteins processed

Counting records in file 21...
Found 996 records in file 21


Processing TFRecord files:  73%|███████▎  | 22/30 [16:54<06:05, 45.64s/it]

Completed file 21: 996 proteins processed

Counting records in file 22...
Found 996 records in file 22


Processing TFRecord files:  77%|███████▋  | 23/30 [18:19<06:41, 57.34s/it]

Completed file 22: 996 proteins processed

Counting records in file 23...
Found 995 records in file 23


Processing TFRecord files:  80%|████████  | 24/30 [19:01<05:17, 52.84s/it]

Completed file 23: 995 proteins processed

Counting records in file 24...
Found 996 records in file 24


Processing TFRecord files:  83%|████████▎ | 25/30 [19:40<04:03, 48.77s/it]

Completed file 24: 996 proteins processed

Counting records in file 25...
Found 996 records in file 25


Processing TFRecord files:  87%|████████▋ | 26/30 [20:19<03:03, 45.85s/it]

Completed file 25: 996 proteins processed

Counting records in file 26...
Found 996 records in file 26


Processing TFRecord files:  90%|█████████ | 27/30 [20:58<02:10, 43.64s/it]

Completed file 26: 996 proteins processed

Counting records in file 27...
Found 996 records in file 27


Processing TFRecord files:  93%|█████████▎| 28/30 [21:46<01:29, 44.93s/it]

Completed file 27: 996 proteins processed

Counting records in file 28...
Found 996 records in file 28


Processing TFRecord files:  97%|█████████▋| 29/30 [22:31<00:45, 45.05s/it]

Completed file 28: 996 proteins processed

Counting records in file 29...
Found 1018 records in file 29


Processing TFRecord files: 100%|██████████| 30/30 [23:19<00:00, 46.65s/it]

Completed file 29: 1018 proteins processed

PROCESSING COMPLETE!
Files processed: 30/30
Total proteins processed: 29895
Failed proteins: 0
Success rate: 100.00%
Output directory: c:\Users\rfrjo\Documents\Codebases\PFP_Testing\protein_data_pdb

Verifying sample outputs...
  154L-A: 6 files - bp_labels.pt, ca_dist_matrix.pt, cb_dist_matrix.pt, cc_labels.pt, L.csv, mf_labels.pt
  155C-A: 6 files - bp_labels.pt, ca_dist_matrix.pt, cb_dist_matrix.pt, cc_labels.pt, L.csv, mf_labels.pt
  16PK-A: 6 files - bp_labels.pt, ca_dist_matrix.pt, cb_dist_matrix.pt, cc_labels.pt, L.csv, mf_labels.pt

VERIFICATION

Verifying PyTorch format for protein: 154L-A
--------------------------------------------------
✓ Sequence length: 185
✓ Molecular Function labels: shape=torch.Size([489]), dtype=torch.int64
✓ Cellular Component labels: shape=torch.Size([320]), dtype=torch.int64
✓ Biological Process labels: shape=torch.Size([1943]), dtype=torch.int64
✓ C-alpha distance matrix: shape=torch.Size([185, 185]), dt




In [4]:
import tensorflow as tf
import torch
import numpy as np
import os
import csv
from tqdm import tqdm
import math
import warnings
from pathlib import Path

def process_tfrecords_to_pytorch_format():
    """
    Processes all PDB-GO validation TFRecord files and organizes data into PyTorch-friendly format.
    
    Directory structure created:
    protein_data_pdb_val/
    ├── {protein_id}/
    │   ├── mf_labels.pt (Molecular Function labels)
    │   ├── cc_labels.pt (Cellular Component labels) 
    │   ├── bp_labels.pt (Biological Process labels)
    │   ├── ca_dist_matrix.pt (C-alpha distance matrix)
    │   ├── cb_dist_matrix.pt (C-beta distance matrix)
    │   └── L.csv (sequence length)
    """
    
    # Configuration
    base_path = r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\PDB-GO"
    output_dir = "protein_data_pdb_val"
    num_files = 3
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    print(f"Created output directory: {output_dir}")
    
    # Statistics tracking
    total_proteins_processed = 0
    failed_proteins = 0
    files_processed = 0
    
    # Process all 3 validation TFRecord files
    print(f"Processing {num_files} validation TFRecord files...")
    
    for file_idx in tqdm(range(num_files), desc="Processing TFRecord files", position=0):
        file_path = os.path.join(base_path, f"PDB_GO_valid_{file_idx:02d}-of-03.tfrecords")
        
        # Check if file exists
        if not os.path.exists(file_path):
            print(f"Warning: File {file_path} not found, skipping...")
            continue
        
        try:
            # Count records first for progress bar
            print(f"\nCounting records in file {file_idx:02d}...")
            dataset = tf.data.TFRecordDataset(file_path)
            total_records = sum(1 for _ in dataset)
            print(f"Found {total_records} records in file {file_idx:02d}")
            
            # Process records
            dataset = tf.data.TFRecordDataset(file_path)
            file_proteins_processed = 0
            
            for raw_record in tqdm(dataset, desc=f"File {file_idx:02d}", total=total_records, position=1, leave=False):
                try:
                    # Parse the TFRecord
                    example = tf.train.Example()
                    example.ParseFromString(raw_record.numpy())
                    
                    # Extract protein ID
                    prot_id = example.features.feature['prot_id'].bytes_list.value[0].decode('utf-8')
                    
                    # Create protein directory
                    protein_dir = os.path.join(output_dir, prot_id)
                    os.makedirs(protein_dir, exist_ok=True)
                    
                    # Extract L value first (sequence length)
                    L_value = None
                    if 'L' in example.features.feature:
                        L_value = example.features.feature['L'].int64_list.value[0]
                        
                        # Save L value to CSV
                        with open(os.path.join(protein_dir, 'L.csv'), 'w', newline='') as csvfile:
                            writer = csv.writer(csvfile)
                            writer.writerow(['L'])
                            writer.writerow([L_value])
                    
                    # Extract and save label features
                    label_features = {
                        'mf_labels': 'Molecular Function',
                        'cc_labels': 'Cellular Component', 
                        'bp_labels': 'Biological Process'
                    }
                    
                    for label_type, label_name in label_features.items():
                        if label_type in example.features.feature:
                            labels = list(example.features.feature[label_type].int64_list.value)
                            labels_tensor = torch.tensor(labels, dtype=torch.long)
                            torch.save(labels_tensor, os.path.join(protein_dir, f"{label_type}.pt"))
                    
                    # Extract and save distance matrices
                    matrix_features = ['ca_dist_matrix', 'cb_dist_matrix']
                    
                    for matrix_type in matrix_features:
                        if matrix_type in example.features.feature:
                            matrix_data = list(example.features.feature[matrix_type].float_list.value)
                            matrix_tensor = torch.tensor(matrix_data, dtype=torch.float32)
                            
                            # Try to reshape to square matrix if L is available
                            if L_value is not None:
                                expected_size = L_value * L_value
                                if len(matrix_data) == expected_size:
                                    matrix_tensor = matrix_tensor.reshape(L_value, L_value)
                                    # print(f"  Reshaped {matrix_type} to ({L_value}, {L_value}) for {prot_id}")
                                else:
                                    print(f"  Warning: {matrix_type} size {len(matrix_data)} doesn't match L²={expected_size} for {prot_id}")
                            
                            torch.save(matrix_tensor, os.path.join(protein_dir, f"{matrix_type}.pt"))
                    
                    file_proteins_processed += 1
                    total_proteins_processed += 1
                    
                except Exception as e:
                    print(f"\nError processing protein record in file {file_idx:02d}: {e}")
                    failed_proteins += 1
                    continue
            
            files_processed += 1
            print(f"Completed file {file_idx:02d}: {file_proteins_processed} proteins processed")
            
        except Exception as e:
            print(f"Error processing file {file_idx:02d}: {e}")
            continue
    
    # Print final statistics
    print("\n" + "="*60)
    print("PROCESSING COMPLETE!")
    print("="*60)
    print(f"Files processed: {files_processed}/{num_files}")
    print(f"Total proteins processed: {total_proteins_processed}")
    print(f"Failed proteins: {failed_proteins}")
    print(f"Success rate: {(total_proteins_processed/(total_proteins_processed + failed_proteins)*100):.2f}%")
    print(f"Output directory: {os.path.abspath(output_dir)}")
    
    # Verify a few sample outputs
    print("\nVerifying sample outputs...")
    sample_dirs = []
    for item in os.listdir(output_dir):
        item_path = os.path.join(output_dir, item)
        if os.path.isdir(item_path):
            sample_dirs.append(item)
            if len(sample_dirs) >= 3:  # Check first 3 found
                break
    
    for sample_dir in sample_dirs:
        sample_path = os.path.join(output_dir, sample_dir)
        files = os.listdir(sample_path)
        print(f"  {sample_dir}: {len(files)} files - {', '.join(files)}")

def verify_pytorch_format(protein_id=None, output_dir="protein_data_pdb_val"):
    """
    Utility function to verify the PyTorch format for a specific protein.
    """
    if protein_id is None:
        # Find the first available protein
        for item in os.listdir(output_dir):
            if os.path.isdir(os.path.join(output_dir, item)):
                protein_id = item
                break
    
    if protein_id is None:
        print("No protein directories found!")
        return
    
    protein_dir = os.path.join(output_dir, protein_id)
    print(f"\nVerifying PyTorch format for protein: {protein_id}")
    print("-" * 50)
    
    # Check each file type
    file_types = {
        'L.csv': 'Sequence length',
        'mf_labels.pt': 'Molecular Function labels',
        'cc_labels.pt': 'Cellular Component labels',
        'bp_labels.pt': 'Biological Process labels',
        'ca_dist_matrix.pt': 'C-alpha distance matrix',
        'cb_dist_matrix.pt': 'C-beta distance matrix'
    }
    
    for filename, description in file_types.items():
        filepath = os.path.join(protein_dir, filename)
        if os.path.exists(filepath):
            if filename.endswith('.pt'):
                tensor = torch.load(filepath)
                print(f"✓ {description}: shape={tensor.shape}, dtype={tensor.dtype}")
            elif filename.endswith('.csv'):
                with open(filepath, 'r') as f:
                    content = f.read().strip().split('\n')
                    print(f"✓ {description}: {content[1] if len(content) > 1 else 'N/A'}")
        else:
            print(f"✗ {description}: File not found")

if __name__ == "__main__":
    # Run the main processing
    process_tfrecords_to_pytorch_format()
    
    # Verify the format for one protein
    print("\n" + "="*60)
    print("VERIFICATION")
    print("="*60)
    verify_pytorch_format()
    
    print("\nScript completed successfully!")
    print("You can now use the organized validation data for PyTorch model inference.")

Created output directory: protein_data_pdb_val
Processing 3 validation TFRecord files...


Processing TFRecord files:   0%|          | 0/3 [00:00<?, ?it/s]


Counting records in file 00...
Found 1106 records in file 00


Processing TFRecord files:  33%|███▎      | 1/3 [00:54<01:49, 54.74s/it]

Completed file 00: 1106 proteins processed

Counting records in file 01...
Found 1107 records in file 01


Processing TFRecord files:  67%|██████▋   | 2/3 [01:47<00:53, 53.85s/it]

Completed file 01: 1107 proteins processed

Counting records in file 02...
Found 1108 records in file 02


Processing TFRecord files: 100%|██████████| 3/3 [02:36<00:00, 52.24s/it]


Completed file 02: 1108 proteins processed

PROCESSING COMPLETE!
Files processed: 3/3
Total proteins processed: 3321
Failed proteins: 0
Success rate: 100.00%
Output directory: c:\Users\rfrjo\Documents\Codebases\PFP_Testing\protein_data_pdb_val

Verifying sample outputs...
  192L-A: 6 files - bp_labels.pt, ca_dist_matrix.pt, cb_dist_matrix.pt, cc_labels.pt, L.csv, mf_labels.pt
  1A0A-A: 6 files - bp_labels.pt, ca_dist_matrix.pt, cb_dist_matrix.pt, cc_labels.pt, L.csv, mf_labels.pt
  1A21-A: 6 files - bp_labels.pt, ca_dist_matrix.pt, cb_dist_matrix.pt, cc_labels.pt, L.csv, mf_labels.pt

VERIFICATION

Verifying PyTorch format for protein: 192L-A
--------------------------------------------------
✓ Sequence length: 164
✓ Molecular Function labels: shape=torch.Size([489]), dtype=torch.int64
✓ Cellular Component labels: shape=torch.Size([320]), dtype=torch.int64
✓ Biological Process labels: shape=torch.Size([1943]), dtype=torch.int64
✓ C-alpha distance matrix: shape=torch.Size([164, 164]), d

<h1>Fasta Sequences: ESM-C Embeddings</h1>

In [15]:
import os
from pathlib import Path
import re
from tqdm import tqdm

def parse_fasta_sequences(fasta_path):
    """
    Parse FASTA file and create a mapping of protein_id -> sequence.
    Expected format: >{protein_id} nrPDB
    """
    protein_sequences = {}
    current_protein_id = None
    current_sequence = []
    
    print(f"Parsing FASTA file: {fasta_path}")
    
    try:
        with open(fasta_path, 'r') as f:
            lines = f.readlines()
            
        for line in tqdm(lines, desc="Reading FASTA"):
            line = line.strip()
            
            if line.startswith('>'):
                # Save previous protein sequence if exists
                if current_protein_id is not None:
                    protein_sequences[current_protein_id] = ''.join(current_sequence)
                
                # Start new protein
                parts = line.split()
                if len(parts) >= 1:
                    current_protein_id = parts[0][1:]  # Remove the '>' character
                    current_sequence = []
            
            elif line and current_protein_id is not None:
                # Add sequence line (remove any whitespace)
                current_sequence.append(line.strip())
        
        # Don't forget the last protein
        if current_protein_id is not None:
            protein_sequences[current_protein_id] = ''.join(current_sequence)
    
    except FileNotFoundError:
        print(f"Error: FASTA file not found at {fasta_path}")
        return {}
    except Exception as e:
        print(f"Error reading FASTA file: {e}")
        return {}
    
    print(f"Successfully parsed {len(protein_sequences)} protein sequences")
    return protein_sequences

def append_sequences_to_directory(protein_data_dir, protein_sequences):
    """
    Append sequences to each protein directory.
    """
    if not os.path.exists(protein_data_dir):
        print(f"Error: Protein data directory not found at {protein_data_dir}")
        return 0, 0
    
    # Get all protein directories
    protein_dirs = [d for d in os.listdir(protein_data_dir) 
                   if os.path.isdir(os.path.join(protein_data_dir, d))]
    
    print(f"\nProcessing {len(protein_dirs)} proteins in {protein_data_dir}")
    
    sequences_added = 0
    sequences_missing = 0
    
    for protein_id in tqdm(protein_dirs, desc=f"Adding sequences to {protein_data_dir}"):
        protein_path = os.path.join(protein_data_dir, protein_id)
        
        # Check if sequence exists for this protein
        if protein_id in protein_sequences:
            sequence = protein_sequences[protein_id]
            
            # Save sequence as text file
            sequence_file_path = os.path.join(protein_path, 'sequence.txt')
            try:
                with open(sequence_file_path, 'w') as f:
                    f.write(sequence)
                sequences_added += 1
            except Exception as e:
                print(f"Error writing sequence for {protein_id}: {e}")
                sequences_missing += 1
        else:
            print(f"Warning: No sequence found for protein {protein_id}")
            sequences_missing += 1
    
    return sequences_added, sequences_missing

def verify_sequence_addition(protein_data_dir, sample_size=5):
    """
    Verify that sequences were added correctly by checking a few samples.
    """
    print(f"\nVerifying sequence addition in {protein_data_dir}:")
    print("-" * 60)
    
    if not os.path.exists(protein_data_dir):
        print(f"Directory not found: {protein_data_dir}")
        return
    
    # Get sample proteins
    protein_dirs = [d for d in os.listdir(protein_data_dir) 
                   if os.path.isdir(os.path.join(protein_data_dir, d))]
    
    sample_proteins = protein_dirs[:sample_size]
    
    for protein_id in sample_proteins:
        protein_path = os.path.join(protein_data_dir, protein_id)
        sequence_file = os.path.join(protein_path, 'sequence.txt')
        
        if os.path.exists(sequence_file):
            try:
                with open(sequence_file, 'r') as f:
                    sequence = f.read().strip()
                
                files = os.listdir(protein_path)
                print(f"✓ {protein_id}: {len(files)} files, sequence length: {len(sequence)}")
                print(f"  Files: {', '.join(sorted(files))}")
                if len(sequence) > 50:
                    print(f"  Sequence preview: {sequence[:50]}...")
                else:
                    print(f"  Sequence: {sequence}")
                print()
            except Exception as e:
                print(f"✗ {protein_id}: Error reading sequence - {e}")
        else:
            files = os.listdir(protein_path)
            print(f"✗ {protein_id}: No sequence.txt found ({len(files)} files total)")

def main():
    """
    Main function to append sequences to both protein data directories.
    """
    
    # Configuration
    fasta_path = r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\nrPDB-GO_2019.06.18_sequences.fasta"
    directories = ["protein_data_pdb", "protein_data_pdb_val"]
    
    print("="*70)
    print("PROTEIN SEQUENCE APPENDER")
    print("="*70)
    
    # Parse FASTA file once
    protein_sequences = parse_fasta_sequences(fasta_path)
    
    if len(protein_sequences) == 0:
        print("Error: No sequences found. Exiting.")
        return
    
    # Process each directory
    total_added = 0
    total_missing = 0
    
    for directory in directories:
        print(f"\n{'='*70}")
        print(f"PROCESSING DIRECTORY: {directory}")
        print(f"{'='*70}")
        
        if os.path.exists(directory):
            added, missing = append_sequences_to_directory(directory, protein_sequences)
            total_added += added
            total_missing += missing
            
            print(f"\nResults for {directory}:")
            print(f"  ✓ Sequences added: {added}")
            print(f"  ✗ Sequences missing: {missing}")
            
            # Verify a few samples
            verify_sequence_addition(directory)
        else:
            print(f"Directory {directory} not found, skipping...")
    
    # Final summary
    print("="*70)
    print("SEQUENCE ADDITION COMPLETE!")
    print("="*70)
    print(f"Total sequences added: {total_added}")
    print(f"Total sequences missing: {total_missing}")
    print(f"Success rate: {(total_added/(total_added + total_missing)*100):.2f}%")
    
    print(f"\nSequences saved as 'sequence.txt' in each protein directory.")
    print(f"Ready for ESM embedding preprocessing!")

if __name__ == "__main__":
    main()

PROTEIN SEQUENCE APPENDER
Parsing FASTA file: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\nrPDB-GO_2019.06.18_sequences.fasta


Reading FASTA: 100%|██████████| 227386/227386 [00:00<00:00, 1529900.19it/s]


Successfully parsed 36641 protein sequences

PROCESSING DIRECTORY: protein_data_pdb


KeyboardInterrupt: 

In [None]:
import os
import time
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

def generate_esmc600m_hf(seqs, device):
    """
    Generate ESM-C 600M embeddings via ESM package.
    Returns a list of [L, D] numpy arrays.
    """
    # Load model
    client = ESMC.from_pretrained("esmc_600m").to(device)
    
    embeddings = []
    for seq in tqdm(seqs, desc="Embedding sequences"):
        # Create protein object
        protein = ESMProtein(sequence=seq)
        
        # Get embeddings
        with torch.no_grad():
            protein_tensor = client.encode(protein)
            logits_output = client.logits(
                protein_tensor, 
                LogitsConfig(sequence=True, return_embeddings=True)
            )
            # Get per-residue embeddings
            emb = logits_output.embeddings.cpu().numpy()
        
        embeddings.append(emb)
    
    return embeddings

def main():
    # --- Configuration ---
    dataset_dir = "protein_data/proteins"
    output_dir = "esmc_embeddings"
    num_proteins = 100

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Gather first N protein IDs and sequences
    protein_ids = sorted(os.listdir(dataset_dir))[:num_proteins]
    seqs = []
    for pid in protein_ids:
        fasta_path = os.path.join(dataset_dir, pid, "sequence.fasta")
        with open(fasta_path, "r") as f:
            # Skip header lines and join sequence lines
            seq = "".join(line.strip() for line in f if not line.startswith(">"))
            seqs.append(seq)

    # Generate and time embeddings
    start = time.perf_counter()
    embeddings = generate_esmc600m_hf(seqs, device)
    elapsed = time.perf_counter() - start
    print(f"\nGenerated embeddings for {len(seqs)} proteins in {elapsed:.2f} seconds")

    # Save per-protein average embeddings
    for pid, emb in zip(protein_ids, embeddings):
        avg_emb = emb.mean(axis=0)  # [D]
        out_path = os.path.join(output_dir, f"{pid}_avg.npy")
        np.save(out_path, avg_emb)

    # Optionally, save full timing stats to CSV
    df = pd.DataFrame({
        "protein_id": protein_ids,
        "sequence_length": [emb.shape[0] for emb in embeddings],
        "avg_embedding_dim": [emb.shape[1] for emb in embeddings]
    })
    df.to_csv(os.path.join(output_dir, "embedding_stats.csv"), index=False)

if __name__ == "__main__":
    main()

<h1>MSA DATA</h1>

In [13]:
#!/usr/bin/env python3
"""
Check which PDB-chain IDs in a local `protein_data_pdb` directory
have pre-computed MSAs in the OpenProteinSet S3 bucket.

Changes vs. original script
---------------------------
* Preserve the chain identifier’s original case (only the 4-char
  PDB code is lower-cased) → 1A0C-A ➜ 1a0c_A
* Use `KeyCount` instead of `'Contents' in response` for clarity.
* Optional: set `Delimiter="/"` so that a single request is enough
  to know whether *any* object exists under the prefix.
"""

import os
import boto3
from botocore import UNSIGNED
from botocore.config import Config
from botocore.exceptions import ClientError


BUCKET = "openfold"
PREFIX_ROOT = "pdb/"                       # as documented

EXPECTED_MSA_KEYS = [
    "a3m/bfd_uniclust_hits.a3m",
    "a3m/mgnify_hits.a3m",
    "a3m/uniref90_hits.a3m",
    "hhr/pdb70_hits.hhr",
]


def get_local_protein_ids(protein_data_dir="protein_data_pdb", sample_size=1000):
    """Return up to `sample_size` sub-directory names (PDB-chain IDs)."""
    if not os.path.isdir(protein_data_dir):
        raise FileNotFoundError(f"{protein_data_dir} not found")
    dirs = sorted(d for d in os.listdir(protein_data_dir)
                  if os.path.isdir(os.path.join(protein_data_dir, d)))
    return dirs[:sample_size]


def convert_protein_id_format(protein_id: str) -> str:
    """
    Convert '1A0C-A' → '1a0c_A'
      • pdb id (first 4 chars) → lower case
      • chain id (after last '-') → keep original case
    """
    pdb, chain = protein_id.rsplit("-", 1)
    return f"{pdb.lower()}_{chain}"


def prefix_exists(s3, bucket: str, prefix: str) -> bool:
    """Cheap existence check – we just need to know if *something* is there."""
    resp = s3.list_objects_v2(
        Bucket=bucket,
        Prefix=prefix,
        Delimiter="/",   # don’t pull the entire sub-tree
        MaxKeys=1,
    )
    return resp.get("KeyCount", 0) > 0


def check_msa_files(s3, bucket: str, converted_id: str):
    """Return which of the 4 expected MSA files are present."""
    present, missing = [], []
    for relkey in EXPECTED_MSA_KEYS:
        key = f"{PREFIX_ROOT}{converted_id}/{relkey}"
        try:
            s3.head_object(Bucket=bucket, Key=key)
            present.append(relkey)
        except ClientError as e:
            if e.response["Error"]["Code"] != "404":
                raise
            missing.append(relkey)
    return present, missing


def main(sample_size=100):
    s3 = boto3.client("s3", region_name="us-east-1",
                      config=Config(signature_version=UNSIGNED))

    local_ids = get_local_protein_ids(sample_size=sample_size)

    matches, misses = [], []
    for pid in local_ids:
        cid = convert_protein_id_format(pid)
        pref = f"{PREFIX_ROOT}{cid}/"
        if prefix_exists(s3, BUCKET, pref):
            present, missing = check_msa_files(s3, BUCKET, cid)
            matches.append((pid, cid, present, missing))
        else:
            misses.append((pid, cid))

    # --- summary ------------------------------------------------------------
    print(f"\nScanned: {len(local_ids)} chains")
    print(f"Found : {len(matches)}  ({len(matches)/len(local_ids):.1%})")
    if matches:
        print("\n✔ Matches:")
        for pid, cid, pres, miss in matches:
            print(f"  {pid:10} → {cid:10}  {len(pres)}/4 MSA files")
    if misses:
        print("\n✘ Not found:")
        for pid, cid in misses:
            print(f"  {pid:10} → {cid}")


if __name__ == "__main__":
    main(sample_size=1000)   # adjust or remove arg to scan all



Scanned: 1000 chains
Found : 837  (83.7%)

✔ Matches:
  155C-A     → 155c_A      4/4 MSA files
  16VP-A     → 16vp_A      4/4 MSA files
  1914-A     → 1914_A      4/4 MSA files
  19HC-A     → 19hc_A      4/4 MSA files
  1A05-A     → 1a05_A      4/4 MSA files
  1A0C-A     → 1a0c_A      4/4 MSA files
  1A0D-A     → 1a0d_A      4/4 MSA files
  1A0E-A     → 1a0e_A      4/4 MSA files
  1A0H-A     → 1a0h_A      4/4 MSA files
  1A0I-A     → 1a0i_A      4/4 MSA files
  1A0J-A     → 1a0j_A      4/4 MSA files
  1A0Q-L     → 1a0q_L      4/4 MSA files
  1A0R-P     → 1a0r_P      4/4 MSA files
  1A14-H     → 1a14_H      4/4 MSA files
  1A14-L     → 1a14_L      4/4 MSA files
  1A17-A     → 1a17_A      3/4 MSA files
  1A1S-A     → 1a1s_A      4/4 MSA files
  1A1Z-A     → 1a1z_A      4/4 MSA files
  1A25-A     → 1a25_A      4/4 MSA files
  1A2A-A     → 1a2a_A      4/4 MSA files
  1A2O-A     → 1a2o_A      4/4 MSA files
  1A2Z-A     → 1a2z_A      4/4 MSA files
  1A3W-A     → 1a3w_A      4/4 MSA files
  

In [16]:
import boto3
from botocore import UNSIGNED
from botocore.config import Config
from botocore.exceptions import NoCredentialsError, ClientError

def test_aws_s3_access():
    """Test AWS S3 access to OpenFold database."""
    print("Accessing OpenFold database...")
    
    # Use the working configuration
    bucket_config = {'name': 'openfold', 'region': 'us-east-1'}
    
    try:
        print(f"Using bucket: {bucket_config['name']} in {bucket_config['region']}")
        
        # Configure S3 client for unsigned access
        s3_client = boto3.client(
            's3',
            region_name=bucket_config['region'],
            config=Config(signature_version=UNSIGNED)
        )
        
        # Try to list top-level contents
        print("  Listing top-level contents...")
        response = s3_client.list_objects_v2(
            Bucket=bucket_config['name'],
            Delimiter='/',
            MaxKeys=10
        )
        
        if 'CommonPrefixes' in response:
            print("  Top-level directories:")
            for prefix in response['CommonPrefixes']:
                print(f"    - {prefix['Prefix']}")
        
        # Try to list PDB directory contents
        print("  Checking for 'pdb/' directory...")
        print("  Getting first 3000 PDB IDs...")
        
        pdb_ids = []
        continuation_token = None
        
        # Keep fetching until we have 3000 or no more results
        while len(pdb_ids) < 3000:
            if continuation_token:
                pdb_response = s3_client.list_objects_v2(
                    Bucket=bucket_config['name'],
                    Prefix='pdb/',
                    Delimiter='/',
                    MaxKeys=1000,
                    ContinuationToken=continuation_token
                )
            else:
                pdb_response = s3_client.list_objects_v2(
                    Bucket=bucket_config['name'],
                    Prefix='pdb/',
                    Delimiter='/',
                    MaxKeys=1000
                )
            
            if 'CommonPrefixes' in pdb_response:
                for prefix in pdb_response['CommonPrefixes']:
                    if len(pdb_ids) >= 3000:
                        break
                    pdb_id = prefix['Prefix'].replace('pdb/', '').rstrip('/')
                    pdb_ids.append(pdb_id)
                
                print(f"    Retrieved {len(pdb_ids)} PDB IDs so far...")
                
                # Check if there are more results
                if pdb_response.get('IsTruncated', False):
                    continuation_token = pdb_response.get('NextContinuationToken')
                else:
                    break
            else:
                print("  ✗ No PDB directories found")
                break
        
        if pdb_ids:
            print(f"  All {len(pdb_ids)} PDB IDs found:")
            for i, pdb_id in enumerate(pdb_ids):
                print(f"    {i+1}. {pdb_id}")
            
            print(f"  ✓ Successfully found {len(pdb_ids)} PDB entries in {bucket_config['name']}")
            return bucket_config, pdb_ids
        else:
            print("  ✗ No PDB directories found")
            
    except ClientError as e:
        error_code = e.response['Error']['Code']
        if error_code == 'NoSuchBucket':
            print(f"  ✗ Bucket {bucket_config['name']} does not exist")
        elif error_code == 'AccessDenied':
            print(f"  ✗ Access denied to {bucket_config['name']}")
        else:
            print(f"  ✗ Error: {error_code} - {e.response['Error']['Message']}")
    except Exception as e:
        print(f"  ✗ Unexpected error: {str(e)}")
    
    return None, []

def main():
    """Main function to access OpenFold database and get 3000 PDB IDs."""
    print("="*60)
    print("OPENFOLD DATABASE ACCESS - GET 3000 PDB IDs")
    print("="*60)
    
    print("Required packages: boto3, requests")
    print("Install with: pip install boto3 requests")
    
    try:
        # Get PDB IDs from OpenFold database
        bucket_config, pdb_ids = test_aws_s3_access()
        
        if bucket_config and pdb_ids:
            print(f"\n✓ SUCCESS! Retrieved {len(pdb_ids)} PDB IDs from OpenFold database")
            print(f"  Bucket: {bucket_config['name']}")
            print(f"  Region: {bucket_config['region']}")
            return bucket_config, pdb_ids
        else:
            print(f"\n✗ Failed to retrieve PDB IDs")
            return None, []
        
    except ImportError as e:
        print(f"Import error: {e}")
        print("Please install required packages: pip install boto3 requests")
    except Exception as e:
        print(f"Unexpected error: {e}")
        return None, []

if __name__ == "__main__":
    main()

OPENFOLD DATABASE ACCESS - GET 3000 PDB IDs
Required packages: boto3, requests
Install with: pip install boto3 requests
Accessing OpenFold database...
Using bucket: openfold in us-east-1
  Listing top-level contents...
  Top-level directories:
    - alignment_db/
    - data_caches/
    - openfold_params/
    - openfold_soloseq_params/
    - pdb/
    - soloseq_embeddings/
    - uniclust30/
  Checking for 'pdb/' directory...
  Getting first 3000 PDB IDs...
    Retrieved 1000 PDB IDs so far...
    Retrieved 2000 PDB IDs so far...
    Retrieved 3000 PDB IDs so far...
  All 3000 PDB IDs found:
    1. 101m_A
    2. 102l_A
    3. 102m_A
    4. 103l_A
    5. 104l_A
    6. 104m_A
    7. 106m_A
    8. 107l_A
    9. 108l_A
    10. 109l_A
    11. 109m_A
    12. 10gs_A
    13. 10mh_A
    14. 110l_A
    15. 111l_A
    16. 112l_A
    17. 113l_A
    18. 114l_A
    19. 115l_A
    20. 117e_A
    21. 118l_A
    22. 119l_A
    23. 11as_A
    24. 11ba_A
    25. 11gs_A
    26. 120l_A
    27. 121p_A
    28. 

In [2]:
#!/usr/bin/env python3
"""
OpenProteinSet MSA-matcher *with sequence-rescue fallback*
[OPTIMIZED VERSION with parallel S3 lookups]

Step 1 – exact ID match (PARALLELIZED)
---------------------------------------
- convert `1A0C-A` → `1a0c_A`, look for `pdb/1a0c_A/…` in the public
  S3 bucket **openfold** using parallel threads for 5-10x speedup

Step 2 – identical-sequence rescue (PARALLELIZED for S3 checks)
---------------------------------------------------------------
- read the sequence of the missing chain from your
  `nrPDB-GO_2019.06.18_sequences.fasta` file  
- stream **once** through  
  https://www.bx.psu.edu/~thanh/Downloads/Protein-DNA%20Docking/Scripts%20and%20Lists/pdb_seqres.all.txt  
  and collect *all* PDB-chain IDs that have *exactly* the same sequence  
- for every such twin ID: check S3 existence in parallel batches

The script prints  
  – ORIGINAL hits (step 1)  
  – rescued mappings ("use MSA of …")  
  – still-missing chains (need a new MSA)

Adjust `SAMPLE_SIZE` at the top – `None` means "scan everything".
"""

import os, sys, hashlib, requests, boto3, itertools
from collections import defaultdict
from botocore.config import Config
from botocore import UNSIGNED
from botocore.exceptions import ClientError
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# ——————————————————————————————————————————————————————  CONFIG  ——— #
PROT_DIR     = "protein_data_pdb"            # local hierarchy of chains
LOCAL_FASTA  = "nrPDB-GO_2019.06.18_sequences.fasta"
REMOTE_URL   = ("https://www.bx.psu.edu/~thanh/Downloads/Protein-DNA%20"
                "Docking/Scripts%20and%20Lists/pdb_seqres.all.txt")
BUCKET       = "openfold"
PREFIX_ROOT  = "pdb/"
REGION       = "us-east-1"
SAMPLE_SIZE  = None          # None → full scan; else first N sub-dirs
EXPECTED_MSA = ["a3m/bfd_uniclust_hits.a3m", "a3m/mgnify_hits.a3m",
                "a3m/uniref90_hits.a3m",  "hhr/pdb70_hits.hhr"]
MAX_WORKERS  = 20           # Number of parallel S3 threads
# ——————————————————————————————————————————————————————————————— #

def convert_id(pid_hyphen: str) -> str:
    pdb, chain = pid_hyphen.rsplit("-", 1)
    return f"{pdb.lower()}_{chain}"

def prefix_exists(client, cid: str) -> bool:
    resp = client.list_objects_v2(Bucket=BUCKET, Prefix=f"{PREFIX_ROOT}{cid}/",
                                  Delimiter="/", MaxKeys=1)
    return resp.get("KeyCount", 0) > 0

def check_protein_exists(pid: str, s3_client) -> tuple:
    """Check if a protein exists and return (pid, cid, exists)"""
    cid = convert_id(pid)
    exists = prefix_exists(s3_client, cid)
    return pid, cid, exists

def check_twin_exists(twin_id: str, s3_client) -> tuple:
    """Check if a twin ID exists and return (twin_id, exists)"""
    exists = prefix_exists(s3_client, twin_id)
    return twin_id, exists

# ---------- FASTA helpers -------------------------------------------------- #
def parse_fasta(stream):
    """yield (header, seq) from an iterable of lines"""
    header, seq_lines = None, []
    for line in stream:
        line = line.strip()
        if line.startswith(">"):
            if header:
                yield header, "".join(seq_lines)
            header, seq_lines = line[1:], []
        else:
            seq_lines.append(line)
    if header:
        yield header, "".join(seq_lines)

def load_local_fasta(path=LOCAL_FASTA):
    """id → sequence  (id in hyphen form, e.g. 1A0C-A)"""
    id2seq = {}
    with open(path, "r", encoding="utf8") as fh:
        for hdr, seq in tqdm(parse_fasta(fh), desc="Loading local FASTA"):
            # header is like  >154L-A nrPDB
            pid = hdr.split()[0]           # keep '154L-A'
            id2seq[pid.upper()] = seq.upper()
    return id2seq

def md5(s: str) -> str:
    return hashlib.md5(s.encode()).hexdigest()

# ---------- main ----------------------------------------------------------- #
def main():
    if not os.path.isdir(PROT_DIR):
        sys.exit(f"ERROR: '{PROT_DIR}' not found")

    # 0. collect local IDs -------------------------------------------------- #
    local_ids = sorted(d for d in os.listdir(PROT_DIR)
                       if os.path.isdir(os.path.join(PROT_DIR, d)))
    if SAMPLE_SIZE:
        local_ids = local_ids[:SAMPLE_SIZE]
    print(f"Scanning {len(local_ids)} chains …")

    s3 = boto3.client("s3", region_name=REGION,
                      config=Config(signature_version=UNSIGNED))

    hits, misses = [], []                 # (orig_id_hyphen, cid)

    # 1. exact-ID lookup [PARALLELIZED] ------------------------------------ #
    print("Checking exact ID matches (parallel)...")
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Submit all tasks
        future_to_pid = {executor.submit(check_protein_exists, pid, s3): pid 
                         for pid in local_ids}
        
        # Process results with progress bar
        for future in tqdm(as_completed(future_to_pid), total=len(local_ids), 
                          desc="Checking exact ID matches"):
            pid, cid, exists = future.result()
            if exists:
                hits.append((pid, cid))
            else:
                misses.append((pid, cid))

    print(f"  exact matches : {len(hits)}")
    print(f"  missing       : {len(misses)}")

    if not misses:
        print("All chains found – nothing to rescue.")
        return

    # 2. build seq → representative-cid dict from the hits ----------------- #
    print("\nBuilding hash map for sequences that are already present …")
    local_fasta = load_local_fasta()
    seq_hash_to_rep = {}
    for pid_h, cid in tqdm(hits, desc="Building sequence hash map"):
        seq = local_fasta.get(pid_h.upper())
        if not seq:           # not in the FASTA file (rare)
            continue
        seq_hash_to_rep.setdefault(md5(seq), cid)

    # 3. gather sequences of the missing set ------------------------------- #
    miss_seq_map = {}       # pid_hyphen → sequence (upper‐case)
    seqs_needed  = set()    # hashes we must look for in remote file
    for pid_h, _ in tqdm(misses, desc="Gathering missing sequences"):
        seq = local_fasta.get(pid_h.upper())
        if not seq:
            continue
        miss_seq_map[pid_h] = seq
        seqs_needed.add(md5(seq))

    if not seqs_needed:
        print("No sequences for the missing set found in local FASTA.")
        return

    # 4. scan the remote pdb_seqres file *once* ---------------------------- #
    print(f"Streaming remote sequence catalogue to look for "
          f"{len(seqs_needed)} unique sequences …")
    remote_seq_map = defaultdict(list)   # hash → [remote_id1, …]

    with requests.get(REMOTE_URL, stream=True, timeout=30) as r:
        r.raise_for_status()
        for hdr, seq in tqdm(parse_fasta(r.iter_lines(decode_unicode=True)), 
                             desc="Scanning remote sequences"):
            h = md5(seq.upper())
            if h in seqs_needed:
                rid = hdr.split()[0]      # e.g. '101m_A'
                remote_seq_map[h].append(rid)

    # 5. rescue [PARALLELIZED S3 CHECKS] ----------------------------------- #
    rescued, still = {}, []
    
    # First pass: check sequences already in our hash map
    fast_rescued = []
    needs_s3_check = []
    
    for pid_h, cid in misses:
        seq = miss_seq_map.get(pid_h)
        if not seq:
            still.append(pid_h)
            continue
        h = md5(seq)
        # first try a representative already present
        if h in seq_hash_to_rep:
            rescued[pid_h] = seq_hash_to_rep[h]
            fast_rescued.append(pid_h)
        else:
            needs_s3_check.append((pid_h, cid, h))
    
    # Second pass: parallel S3 checks for remaining sequences
    if needs_s3_check:
        print(f"Checking S3 for {len(needs_s3_check)} potential rescues (parallel)...")
        
        # Collect all unique twin IDs to check
        twins_to_check = {}  # twin_id -> [(pid_h, cid, h), ...]
        for pid_h, cid, h in needs_s3_check:
            for twin_id in remote_seq_map.get(h, []):
                if twin_id.lower() != cid.lower():
                    if twin_id not in twins_to_check:
                        twins_to_check[twin_id] = []
                    twins_to_check[twin_id].append((pid_h, cid, h))
        
        # Check all twins in parallel
        twin_exists = {}
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            future_to_twin = {executor.submit(check_twin_exists, twin_id, s3): twin_id 
                             for twin_id in twins_to_check}
            
            for future in tqdm(as_completed(future_to_twin), total=len(twins_to_check),
                              desc="Checking twin IDs in S3"):
                twin_id, exists = future.result()
                twin_exists[twin_id] = exists
        
        # Process results
        for pid_h, cid, h in needs_s3_check:
            rescued_flag = False
            for twin_id in remote_seq_map.get(h, []):
                if twin_id.lower() != cid.lower() and twin_exists.get(twin_id, False):
                    rescued[pid_h] = twin_id
                    seq_hash_to_rep[h] = twin_id   # cache
                    rescued_flag = True
                    break
            if not rescued_flag:
                still.append(pid_h)

    # 6. report ------------------------------------------------------------ #
    print("\n================  RESULT  =================")
    print(f"hit by ID           : {len(hits)}")
    print(f"rescued by sequence : {len(rescued)}")
    print(f"still missing       : {len(still)}\n")

    if rescued:
        print("Rescued mappings (reuse MSA of):")
        for m, r in sorted(rescued.items()):
            print(f"  {m:10}  →  {r}")
    if still:
        print("\nNo identical MSA available for:")
        for m in still:
            print(f"  {m}")

    # 7. save results to TSV file ----------------------------------------- #
    print("\nSaving results to openfold_results.tsv ...")
    
    with open("openfold_results__sample.tsv", "w") as f:
        # Write header
        f.write("status\toriginal_id\tmatched_id\tconverted_id\tnotes\n")
        
        # Write direct hits
        for pid_h, cid in sorted(hits):
            f.write(f"DIRECT\t{pid_h}\t{pid_h}\t{cid}\tExact ID match\n")
        
        # Write rescued entries
        for pid_h, rescue_cid in sorted(rescued.items()):
            f.write(f"RESCUED\t{pid_h}\t{rescue_cid}\t{rescue_cid}\tUsing sequence-identical protein\n")
        
        # Write still missing entries
        for pid_h in sorted(still):
            cid = convert_id(pid_h)  # Convert for consistency
            f.write(f"FAILED\t{pid_h}\t\t{cid}\tNo sequence match found\n")
    
    print(f"Results saved: {len(hits)} direct + {len(rescued)} rescued + {len(still)} failed")
    print(f"Total MSAs available: {len(hits) + len(rescued)} / {len(local_ids)} ({(len(hits) + len(rescued))/len(local_ids)*100:.1f}%)")

# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    main()

Scanning 1000 chains …
Checking exact ID matches (parallel)...


Checking exact ID matches: 100%|██████████| 1000/1000 [00:09<00:00, 103.11it/s]


  exact matches : 837
  missing       : 163

Building hash map for sequences that are already present …


Loading local FASTA: 36641it [00:00, 200940.47it/s]
Building sequence hash map: 100%|██████████| 837/837 [00:00<00:00, 209215.28it/s]
Gathering missing sequences: 100%|██████████| 163/163 [00:00<00:00, 161167.27it/s]


Streaming remote sequence catalogue to look for 163 unique sequences …


Scanning remote sequences: 277869it [00:07, 36752.61it/s]


Checking S3 for 163 potential rescues (parallel)...


Checking twin IDs in S3: 100%|██████████| 806/806 [00:09<00:00, 81.80it/s] 


hit by ID           : 837
rescued by sequence : 161
still missing       : 2

Rescued mappings (reuse MSA of):
  154L-A      →  153l_A
  16PK-A      →  13pk_A
  1A4B-A      →  1a4a_A
  1A63-A      →  1a62_A
  1A6E-A      →  1a6d_A
  1AD4-A      →  1ad1_A
  1AFJ-A      →  1afi_A
  1AG8-A      →  1a4z_A
  1AK7-A      →  1ak6_A
  1AL8-A      →  1al7_A
  1ANV-A      →  1adu_A
  1AOM-A      →  1aof_A
  1AS8-A      →  1aq8_A
  1ASZ-A      →  1asy_A
  1ATB-A      →  1ata_A
  1AUX-A      →  1auv_A
  1AVA-A      →  1amy_A
  1AW2-A      →  1aw1_A
  1AY8-A      →  1ay4_A
  1AYN-2      →  1aym_2
  1B15-A      →  1a4u_A
  1B4N-A      →  1b25_A
  1B5E-A      →  1b49_A
  1B8H-A      →  1b77_A
  1B8U-A      →  1b8p_A
  1B9T-A      →  1b9s_A
  1BBD-L      →  1a3r_L
  1BC8-C      →  1bc7_C
  1BCO-A      →  1bcm_A
  1BCS-B      →  1bcr_B
  1BD3-D      →  1bd3_A
  1BD6-A      →  1bc6_A
  1BDD-A      →  1bdc_A
  1BFM-A      →  1a7w_A
  1BHB-A      →  1bha_A
  1BJT-A      →  1bgw_A
  1BLD-A      →  1bla_A
 




In [3]:
#!/usr/bin/env python3
"""
OpenProteinSet MSA-matcher  (ID-exact + identical-sequence rescue)

• Step 1: 1A0C-A → 1a0c_A  → look in s3://openfold/pdb/…   (parallel)
• Step 2: identical sequence search in wwPDB `pdb_seqres.txt`
          → for every “twin” chain try s3 in parallel again

Outputs:  openfold_results_updated.tsv   (DIRECT / RESCUED / FAILED)
"""

import os, sys, hashlib, requests, boto3
from collections import defaultdict
from botocore.config import Config
from botocore import UNSIGNED
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# ─────────────────────────── CONFIG ────────────────────────── #
PROT_DIR     = "protein_data_pdb"
LOCAL_FASTA  = "nrPDB-GO_2019.06.18_sequences.fasta"
REMOTE_URL   = "https://files.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt"
BUCKET       = "openfold"
PREFIX_ROOT  = "pdb/"
REGION       = "us-east-1"
SAMPLE_SIZE  = None          # None → scan all local chains
MAX_WORKERS  = 20            # parallel S3 threads
# ───────────────────────────────────────────────────────────── #

def convert_id(pid_hyphen: str) -> str:
    """1A0C-A  →  1a0c_A  (what the bucket uses)"""
    pdb, chain = pid_hyphen.rsplit("-", 1)
    return f"{pdb.lower()}_{chain}"

def prefix_exists(client, cid: str) -> bool:
    resp = client.list_objects_v2(Bucket=BUCKET,
                                  Prefix=f"{PREFIX_ROOT}{cid}/",
                                  Delimiter="/",
                                  MaxKeys=1)
    return resp.get("KeyCount", 0) > 0

def check_protein_exists(pid: str, s3) -> tuple[str, str, bool]:
    cid = convert_id(pid)
    return pid, cid, prefix_exists(s3, cid)

def check_twin_exists(twin_cid: str, s3) -> tuple[str, bool]:
    return twin_cid, prefix_exists(s3, twin_cid)

# ───────── FASTA helpers ───────── #
def parse_fasta(lines):
    hdr, seq = None, []
    for ln in lines:
        ln = ln.strip()
        if ln.startswith(">"):
            if hdr:
                yield hdr, "".join(seq)
            hdr, seq = ln[1:], []
        else:
            seq.append(ln)
    if hdr:
        yield hdr, "".join(seq)

def load_local_fasta(path=LOCAL_FASTA):
    id2seq = {}
    with open(path, encoding="utf8") as fh:
        for hdr, seq in tqdm(parse_fasta(fh),
                             desc="Loading local FASTA"):
            pid = hdr.split()[0]      # header like  >154L-A nrPDB
            id2seq[pid.upper()] = seq.upper()
    return id2seq

md5 = lambda s: hashlib.md5(s.encode()).hexdigest()

# ───────── main ───────── #
def main():
    if not os.path.isdir(PROT_DIR):
        sys.exit(f"ERROR: directory '{PROT_DIR}' not found")

    # 0) local chain list
    local_ids = sorted(d for d in os.listdir(PROT_DIR)
                       if os.path.isdir(os.path.join(PROT_DIR, d)))
    if SAMPLE_SIZE:
        local_ids = local_ids[:SAMPLE_SIZE]
    print(f"Scanning {len(local_ids)} local chains\n")

    s3 = boto3.client("s3", region_name=REGION,
                      config=Config(signature_version=UNSIGNED))

    hits, misses = [], []        # (pid-hyphen, cid)

    # 1) exact ID look-up (parallel)
    print("Step 1  ─ exact-ID look-up")
    with ThreadPoolExecutor(MAX_WORKERS) as exe:
        fut = {exe.submit(check_protein_exists, pid, s3): pid
               for pid in local_ids}
        for f in tqdm(as_completed(fut), total=len(fut)):
            pid, cid, ok = f.result()
            (hits if ok else misses).append((pid, cid))
    print(f"  found   : {len(hits)}")
    print(f"  missing : {len(misses)}\n")
    if not misses:
        return

    # 2) sequence hashes of what we already have
    local_fasta = load_local_fasta()
    seqhash2cid = {}
    for pid, cid in hits:
        seq = local_fasta.get(pid.upper())
        if seq:
            seqhash2cid.setdefault(md5(seq), cid)

    # 3) gather sequences of missing set
    miss_seq, needed = {}, set()
    for pid, _ in misses:
        seq = local_fasta.get(pid.upper())
        if seq:
            miss_seq[pid] = seq
            needed.add(md5(seq))
    if not needed:
        print("No sequences for missing set in local FASTA")
        return

    # 4) stream wwPDB catalogue once
    print("Step 2  ─ streaming wwPDB pdb_seqres.txt")
    remote_seq = defaultdict(list)   # hash → [cid, …]

    with requests.get(REMOTE_URL, stream=True, timeout=60) as r:
        r.raise_for_status()
        for hdr, seq in tqdm(parse_fasta(r.iter_lines(decode_unicode=True)),
                             desc="Scanning remote catalogue"):
            h = md5(seq.upper())
            if h in needed:
                rid = hdr.split()[0]             # e.g. 1B3Z_A
                # unify: first 4 chars lower-case to match bucket
                rid = f"{rid[:4].lower()}{rid[4:]}"
                remote_seq[h].append(rid)

    # 5) rescue attempt
    rescued, still = {}, []
    fast, todo = [], []
    for pid, cid in misses:
        seq = miss_seq.get(pid)
        if not seq:
            still.append(pid); continue
        h = md5(seq)
        if h in seqhash2cid:                 # already have same seq
            rescued[pid] = seqhash2cid[h]
            fast.append(pid)
        else:
            todo.append((pid, cid, h))

    if todo:
        # flatten twin list
        twins = {}
        for pid, cid, h in todo:
            for twin in remote_seq.get(h, []):
                if twin.lower() != cid.lower():
                    twins.setdefault(twin, []).append((pid, h))
        print(f"Step 2b ─ checking {len(twins)} twin IDs in S3 (parallel)")
        twin_ok = {}
        with ThreadPoolExecutor(MAX_WORKERS) as exe:
            fut = {exe.submit(check_twin_exists, t, s3): t for t in twins}
            for f in tqdm(as_completed(fut), total=len(fut)):
                tid, ok = f.result()
                twin_ok[tid] = ok
        # assign
        for pid, cid, h in todo:
            for twin in remote_seq.get(h, []):
                if twin_ok.get(twin):
                    rescued[pid] = twin
                    seqhash2cid[h] = twin
                    break
            else:
                still.append(pid)

    # 6) report
    print("\n────────── SUMMARY ──────────")
    print(f"exact ID hits      : {len(hits)}")
    print(f"rescued by sequence: {len(rescued)}")
    print(f"still missing      : {len(still)}")
    print("─────────────────────────────\n")

    # 7) TSV
    out = "openfold_results_updated.tsv"
    with open(out, "w") as fh:
        fh.write("status\toriginal_id\tmatched_id\tconverted_id\tnotes\n")
        for pid, cid in sorted(hits):
            fh.write(f"DIRECT\t{pid}\t{pid}\t{cid}\tExact ID match\n")
        for pid, twin in sorted(rescued.items()):
            fh.write(f"RESCUED\t{pid}\t{twin}\t{twin}\tSequence-identical chain\n")
        for pid in sorted(still):
            fh.write(f"FAILED\t{pid}\t\t{convert_id(pid)}\tNo identical sequence\n")
    print(f"Results written → {out}")

# ───────────────────────────────────────────────────────────── #
if __name__ == "__main__":
    main()


Scanning 29740 local chains

Step 1  ─ exact-ID look-up


100%|██████████| 29740/29740 [03:55<00:00, 126.28it/s]


  found   : 22913
  missing : 6827



Loading local FASTA: 36641it [00:00, 168029.52it/s]


Step 2  ─ streaming wwPDB pdb_seqres.txt


Scanning remote catalogue: 1002995it [00:23, 41987.59it/s]


Step 2b ─ checking 110773 twin IDs in S3 (parallel)


100%|██████████| 110773/110773 [13:41<00:00, 134.83it/s]



────────── SUMMARY ──────────
exact ID hits      : 22913
rescued by sequence: 6802
still missing      : 25
─────────────────────────────

Results written → openfold_results_updated.tsv


In [10]:
#!/usr/bin/env python3
"""
Search for amino acid sequence using RCSB PDB API
"""

import requests
import json

def search_sequence_rcsb(sequence):
    """Search sequence using RCSB PDB sequence search API."""
    
    # Clean sequence
    clean_seq = sequence.replace(" ", "").replace("\n", "").upper()
    
    # RCSB PDB sequence search endpoint
    url = "https://search.rcsb.org/rcsbsearch/v2/query"
    
    # Query for exact sequence match
    query = {
        "query": {
            "type": "terminal",
            "service": "sequence",
            "parameters": {
                "evalue_cutoff": 1e-10,
                "identity_cutoff": 1.0,  # 100% identity
                "sequence_type": "protein",
                "value": clean_seq
            }
        },
        "return_type": "entry"
    }
    
    headers = {'Content-Type': 'application/json'}
    
    try:
        response = requests.post(url, data=json.dumps(query), headers=headers, timeout=30)
        response.raise_for_status()
        
        results = response.json()
        
        if results.get("result_set"):
            print("✓ FOUND matches!")
            for i, result in enumerate(results["result_set"][:5], 1):  # Show first 5
                pdb_id = result["identifier"]
                print(f"{i}. PDB ID: {pdb_id}")
                
                # Get more details
                details_url = f"https://data.rcsb.org/rest/v1/core/entry/{pdb_id}"
                try:
                    details = requests.get(details_url, timeout=10).json()
                    title = details.get("struct", {}).get("title", "No title")
                    print(f"   Title: {title}")
                except:
                    print("   Title: Could not fetch")
                print()
            
            return results["result_set"]
        else:
            print("✗ No exact matches found")
            return []
            
    except Exception as e:
        print(f"Error: {e}")
        return []

# Your sequence
target_sequence = """FTGVQGRVIGYDILRSPEVDKAKPLFTETQWDGSELPIYDAKPLQDALVEYFGTEQDRRH
YPAPGSFIVCANKGVTAERPKNDADMKPGQGYGVWSAIAISFAKDPTKDSSMFVEDAGVW
ETPNEDELLEYLEGRRKAMAKSIAECGQDAHASFESSWIGFAYTMMEPGQIGNAITVAPY
VSLPIDSIPGGSILTPDKDMEIMENLTMPEWLEKMGYKSLSANNALKY"""

# Search
results = search_sequence_rcsb(target_sequence)

✓ FOUND matches!
1. PDB ID: 1HQ6
   Title: STRUCTURE OF PYRUVOYL-DEPENDENT HISTIDINE DECARBOXYLASE AT PH 8

2. PDB ID: 1IBT
   Title: STRUCTURE OF THE D53,54N MUTANT OF HISTIDINE DECARBOXYLASE AT-170 C

3. PDB ID: 1IBU
   Title: STRUCTURE OF THE D53,54N MUTANT OF HISTIDINE DECARBOXYLASE AT 25 C

4. PDB ID: 1IBV
   Title: STRUCTURE OF THE D53,54N MUTANT OF HISTIDINE DECARBOXYLASE BOUND WITH HISTIDINE METHYL ESTER AT-170 C

5. PDB ID: 1IBW
   Title: STRUCTURE OF THE D53,54N MUTANT OF HISTIDINE DECARBOXYLASE BOUND WITH HISTIDINE METHYL ESTER AT 25 C



In [4]:
#!/usr/bin/env python3
"""
Rescue FAILED proteins from openfold_results_updated.tsv by:
1. Getting their amino acid sequences from local FASTA
2. Searching RCSB PDB API for identical sequences
3. Checking if alternative IDs exist in OpenFold S3 bucket
4. Updating the TSV with successful matches
"""

import csv
import requests
import json
import boto3
import time
from botocore import UNSIGNED
from botocore.config import Config
from collections import defaultdict

# File paths
TSV_FILE = r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\openfold_results_updated.tsv"
FASTA_FILE = r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\nrPDB-GO_2019.06.18_sequences.fasta"
OUTPUT_TSV = r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\openfold_results_final.tsv"

# S3 configuration
BUCKET = "openfold"
REGION = "us-east-1"

def load_fasta_sequences(fasta_path):
    """Load all sequences from FASTA file into a dictionary."""
    print("Loading FASTA sequences...")
    sequences = {}
    
    with open(fasta_path, 'r', encoding='utf-8') as f:
        current_id = None
        current_seq = []
        
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                # Save previous sequence
                if current_id:
                    sequences[current_id] = ''.join(current_seq)
                
                # Start new sequence
                current_id = line[1:].split()[0]  # Get ID (first part after >)
                current_seq = []
            else:
                current_seq.append(line)
        
        # Save last sequence
        if current_id:
            sequences[current_id] = ''.join(current_seq)
    
    print(f"Loaded {len(sequences)} sequences from FASTA")
    return sequences

def search_rcsb_by_sequence(sequence, exclude_id=None):
    """Search RCSB PDB for proteins with identical sequence."""
    print(f"  Searching RCSB for sequence (length: {len(sequence)})...")
    
    url = "https://search.rcsb.org/rcsbsearch/v2/query"
    
    query = {
        "query": {
            "type": "terminal",
            "service": "sequence",
            "parameters": {
                "evalue_cutoff": 1e-10,
                "identity_cutoff": 1.0,  # 100% identity for exact matches
                "sequence_type": "protein", 
                "value": sequence
            }
        },
        "return_type": "entry"
    }
    
    headers = {'Content-Type': 'application/json'}
    
    try:
        response = requests.post(url, data=json.dumps(query), headers=headers, timeout=60)
        response.raise_for_status()
        
        results = response.json()
        
        if results.get("result_set"):
            found_ids = []
            for result in results["result_set"]:
                pdb_id = result["identifier"]
                
                # Exclude the original failed ID
                if exclude_id and pdb_id.upper() == exclude_id.upper()[:4]:
                    continue
                    
                found_ids.append(pdb_id)
            
            print(f"    Found {len(found_ids)} alternative PDB entries")
            return found_ids
        else:
            print("    No matches found in RCSB")
            return []
            
    except Exception as e:
        print(f"    Error searching RCSB: {e}")
        return []

def convert_pdb_to_s3_format(pdb_id, chain):
    """Convert PDB ID format to S3 bucket format."""
    # PDB format: 1ABC, chain: A
    # S3 format: 1abc_A
    return f"{pdb_id.lower()}_{chain}"

def check_s3_bucket_exists(s3_client, s3_id):
    """Check if a protein ID exists in the OpenFold S3 bucket."""
    prefix = f"pdb/{s3_id}/"
    
    try:
        response = s3_client.list_objects_v2(
            Bucket=BUCKET,
            Prefix=prefix,
            MaxKeys=1
        )
        
        return response.get("KeyCount", 0) > 0
        
    except Exception as e:
        print(f"    Error checking S3 for {s3_id}: {e}")
        return False

def get_pdb_chains(pdb_id):
    """Get all chains for a PDB ID from RCSB."""
    print(f"    Getting chains for {pdb_id}...")
    
    url = f"https://data.rcsb.org/rest/v1/core/entry/{pdb_id}"
    
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        
        # For now, try common chains - getting exact chains requires more complex API calls
        common_chains = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
        return common_chains
        
    except Exception as e:
        print(f"    Error getting chains for {pdb_id}: {e}")
        return ['A', 'B']  # Default fallback

def rescue_failed_proteins():
    """Main function to rescue failed proteins."""
    
    # Initialize S3 client
    s3_client = boto3.client("s3", region_name=REGION,
                            config=Config(signature_version=UNSIGNED))
    
    # Load FASTA sequences
    sequences = load_fasta_sequences(FASTA_FILE)
    
    # Read TSV file
    print("\nReading TSV file...")
    rows = []
    failed_proteins = []
    
    with open(TSV_FILE, 'r', newline='', encoding='utf-8') as f:
        reader = csv.DictReader(f, delimiter='\t')
        for row in reader:
            rows.append(row)
            if row['status'] == 'FAILED':
                failed_proteins.append(row)
    
    print(f"Found {len(failed_proteins)} FAILED proteins to rescue")
    
    rescued_count = 0
    
    # Process each failed protein
    for i, failed_row in enumerate(failed_proteins, 1):
        original_id = failed_row['original_id']
        print(f"\n[{i}/{len(failed_proteins)}] Processing: {original_id}")
        
        # Get sequence from FASTA
        sequence = sequences.get(original_id)
        if not sequence:
            print(f"  ❌ Sequence not found in FASTA for {original_id}")
            continue
        
        # Search RCSB for alternative IDs
        alternative_pdbs = search_rcsb_by_sequence(sequence, exclude_id=original_id)
        
        if not alternative_pdbs:
            print(f"  ❌ No alternative PDB IDs found for {original_id}")
            continue
        
        # Check each alternative PDB in S3 bucket
        found_match = False
        for pdb_id in alternative_pdbs[:5]:  # Check first 5 alternatives
            print(f"  Checking PDB: {pdb_id}")
            
            # Get chains for this PDB
            chains = get_pdb_chains(pdb_id)
            
            # Check each chain
            for chain in chains:
                s3_id = convert_pdb_to_s3_format(pdb_id, chain)
                
                if check_s3_bucket_exists(s3_client, s3_id):
                    print(f"  ✅ FOUND MATCH: {s3_id}")
                    
                    # Update the row
                    for row in rows:
                        if row['original_id'] == original_id and row['status'] == 'FAILED':
                            row['status'] = 'RESCUED'
                            row['matched_id'] = s3_id
                            row['converted_id'] = s3_id
                            row['notes'] = f'Found via RCSB PDB sequence search: {pdb_id} chain {chain}'
                            break
                    
                    rescued_count += 1
                    found_match = True
                    break
            
            if found_match:
                break
            
            # Small delay to be nice to APIs
            time.sleep(0.5)
        
        if not found_match:
            print(f"  ❌ No S3 matches found for {original_id}")
    
    # Write updated TSV
    print(f"\nWriting updated results to {OUTPUT_TSV}")
    with open(OUTPUT_TSV, 'w', newline='', encoding='utf-8') as f:
        fieldnames = ['status', 'original_id', 'matched_id', 'converted_id', 'notes']
        writer = csv.DictWriter(f, fieldnames=fieldnames, delimiter='\t')
        writer.writeheader()
        writer.writerows(rows)
    
    print(f"\n🎉 RESCUE COMPLETE!")
    print(f"Rescued: {rescued_count} proteins")
    print(f"Still failed: {len(failed_proteins) - rescued_count} proteins")
    print(f"Updated results saved to: {OUTPUT_TSV}")

def summarize_results():
    """Print a summary of the final results."""
    print(f"\n" + "="*60)
    print("FINAL SUMMARY")
    print("="*60)
    
    status_counts = defaultdict(int)
    
    try:
        with open(OUTPUT_TSV, 'r', newline='', encoding='utf-8') as f:
            reader = csv.DictReader(f, delimiter='\t')
            for row in reader:
                status_counts[row['status']] += 1
        
        total = sum(status_counts.values())
        print(f"Total proteins: {total}")
        for status, count in sorted(status_counts.items()):
            percentage = (count / total) * 100 if total > 0 else 0
            print(f"{status:12}: {count:4d} ({percentage:5.1f}%)")
            
    except FileNotFoundError:
        print("Output file not found. Run the rescue process first.")

if __name__ == "__main__":
    print("="*60)
    print("PROTEIN RESCUE VIA RCSB PDB SEQUENCE SEARCH")
    print("="*60)
    
    try:
        rescue_failed_proteins()
        summarize_results()
        
    except KeyboardInterrupt:
        print("\n\nProcess interrupted by user")
    except Exception as e:
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()

PROTEIN RESCUE VIA RCSB PDB SEQUENCE SEARCH
Loading FASTA sequences...
Loaded 36641 sequences from FASTA

Reading TSV file...
Found 25 FAILED proteins to rescue

[1/25] Processing: 1BVS-A
  Searching RCSB for sequence (length: 203)...
    Error searching RCSB: Expecting value: line 1 column 1 (char 0)
  ❌ No alternative PDB IDs found for 1BVS-A

[2/25] Processing: 1IBW-B
  Searching RCSB for sequence (length: 228)...
    Found 5 alternative PDB entries
  Checking PDB: 1HQ6
    Getting chains for 1HQ6...
  ✅ FOUND MATCH: 1hq6_A

[3/25] Processing: 1PSP-A
  Searching RCSB for sequence (length: 106)...
    Found 2 alternative PDB entries
  Checking PDB: 1POS
    Getting chains for 1POS...
  Checking PDB: 2PSP
    Getting chains for 2PSP...
  ❌ No S3 matches found for 1PSP-A

[4/25] Processing: 4ROS-A
  Searching RCSB for sequence (length: 320)...
    Found 2 alternative PDB entries
  Checking PDB: 5UJK
    Getting chains for 5UJK...
  Checking PDB: 5ULV
    Getting chains for 5ULV...
  ❌ 

In [6]:
#!/usr/bin/env python3
"""
Print the first 5000 protein IDs from the OpenFold S3 bucket
"""

import boto3
from botocore import UNSIGNED
from botocore.config import Config

def get_first_5k_protein_ids():
    """Get and print the first 5000 protein IDs from OpenFold S3 bucket."""
    
    # Configure S3 client for anonymous access
    s3 = boto3.client("s3",
                      region_name="us-east-1",
                      config=Config(signature_version=UNSIGNED))
    
    bucket = "openfold"
    prefix = "pdb/"
    target_count = 20000
    
    print(f"Fetching first {target_count:,} protein IDs from s3://{bucket}/{prefix}")
    print("="*60)
    
    protein_ids = []
    continuation_token = None
    
    try:
        while len(protein_ids) < target_count:
            # Prepare list_objects_v2 parameters
            params = {
                'Bucket': bucket,
                'Prefix': prefix,
                'Delimiter': '/',
                'MaxKeys': 1000  # Max per request
            }
            
            if continuation_token:
                params['ContinuationToken'] = continuation_token
            
            # Make the request
            response = s3.list_objects_v2(**params)
            
            # Extract protein IDs from CommonPrefixes
            if 'CommonPrefixes' in response:
                for prefix_obj in response['CommonPrefixes']:
                    if len(protein_ids) >= target_count:
                        break
                    
                    # Extract protein ID (remove "pdb/" prefix and trailing "/")
                    protein_id = prefix_obj['Prefix'].replace(prefix, '').rstrip('/')
                    protein_ids.append(protein_id)
                
                # Print progress every 1000 entries
                if len(protein_ids) % 1000 == 0:
                    print(f"Progress: {len(protein_ids):,} protein IDs collected...")
            
            # Check if there are more results
            if response.get('IsTruncated', False) and len(protein_ids) < target_count:
                continuation_token = response.get('NextContinuationToken')
            else:
                break
        
        # Print all collected protein IDs
        print(f"\nFirst {len(protein_ids):,} protein IDs in OpenFold S3 bucket:")
        print("-" * 60)
        
        for i, protein_id in enumerate(protein_ids, 1):
            print(f"{i:5d}. {protein_id}")
        
        print("-" * 60)
        print(f"Total: {len(protein_ids):,} protein IDs")
        
        # Some basic statistics
        print(f"\nBASIC STATISTICS:")
        print(f"Total proteins: {len(protein_ids):,}")
        
        # Count by chain
        chain_counts = {}
        pdb_codes = set()
        
        for pid in protein_ids:
            if '_' in pid:
                pdb_code, chain = pid.rsplit('_', 1)
                pdb_codes.add(pdb_code)
                chain_counts[chain] = chain_counts.get(chain, 0) + 1
        
        print(f"Unique PDB codes: {len(pdb_codes):,}")
        print(f"Most common chains: {sorted(chain_counts.items(), key=lambda x: x[1], reverse=True)[:5]}")
        
        return protein_ids
        
    except Exception as e:
        print(f"Error: {e}")
        return []

def save_to_file(protein_ids, filename="openfold_first_5k_proteins.txt"):
    """Save the protein IDs to a text file."""
    if protein_ids:
        with open(filename, 'w') as f:
            f.write(f"First {len(protein_ids):,} protein IDs from OpenFold S3 bucket\n")
            f.write("="*60 + "\n\n")
            
            for i, protein_id in enumerate(protein_ids, 1):
                f.write(f"{i:5d}. {protein_id}\n")
        
        print(f"\nProtein IDs saved to: {filename}")

if __name__ == "__main__":
    print("OPENFOLD S3 BUCKET - FIRST 5K PROTEIN IDs")
    print("="*60)
    
    # Get the protein IDs
    protein_ids = get_first_5k_protein_ids()
    

OPENFOLD S3 BUCKET - FIRST 5K PROTEIN IDs
Fetching first 20,000 protein IDs from s3://openfold/pdb/
Progress: 1,000 protein IDs collected...
Progress: 2,000 protein IDs collected...
Progress: 3,000 protein IDs collected...
Progress: 4,000 protein IDs collected...
Progress: 5,000 protein IDs collected...
Progress: 6,000 protein IDs collected...
Progress: 7,000 protein IDs collected...
Progress: 8,000 protein IDs collected...
Progress: 9,000 protein IDs collected...
Progress: 10,000 protein IDs collected...
Progress: 11,000 protein IDs collected...
Progress: 12,000 protein IDs collected...
Progress: 13,000 protein IDs collected...
Progress: 14,000 protein IDs collected...
Progress: 15,000 protein IDs collected...
Progress: 16,000 protein IDs collected...
Progress: 17,000 protein IDs collected...
Progress: 18,000 protein IDs collected...
Progress: 19,000 protein IDs collected...
Progress: 20,000 protein IDs collected...

First 20,000 protein IDs in OpenFold S3 bucket:
--------------------

In [10]:
#!/usr/bin/env python3
"""
Quick peek into the public OpenFold bucket:

• print the first 10 UniClust30 cluster IDs
• probe for one specific UniParc record  (UPI00001735AD)
"""

import boto3
from botocore import UNSIGNED
from botocore.config import Config

BUCKET  = "openfold"
REGION  = "us-east-1"
CLUST_PREFIX = "uniclust30/"      # directory that holds the filtered UniRef/UniClust files
UNIPARC_ID   = "A0A009J662"    # the record we want to probe for

# ───────── S3 client (anonymous, read-only) ────────────────────────────────
s3 = boto3.client("s3",
                  region_name=REGION,
                  config=Config(signature_version=UNSIGNED))

# ───────── 1. first 10 cluster dirs ────────────────────────────────────────
print("First 10 cluster-level directories under", CLUST_PREFIX)
paginator = s3.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=BUCKET, Prefix=CLUST_PREFIX, Delimiter="/")

clusters = []
for page in pages:
    for p in page.get("CommonPrefixes", []):
        # "uniclust30/A0A009EX06/"  →  "A0A009EX06"
        clusters.append(p["Prefix"].split("/")[-2])
    if len(clusters) >= 10:
        break

for i, cid in enumerate(clusters[:10], 1):
    print(f"{i:2d}. {cid}")

# ───────── 2. probe for the UniParc record ────────────────────────────────
upi_prefix = f"uniparc/{UNIPARC_ID}/"      # path used in OpenFold
resp = s3.list_objects_v2(Bucket=BUCKET, Prefix=upi_prefix, MaxKeys=1)

if resp.get("KeyCount", 0):
    found_key = resp["Contents"][0]["Key"]
    print(f"\n✅  {UNIPARC_ID} found  (example object: {found_key})")
else:
    print(f"\n✗  {UNIPARC_ID} not present in bucket")


First 10 cluster-level directories under uniclust30/
 1. A0A009EX06
 2. A0A009EY75
 3. A0A009FAV8
 4. A0A009GC30
 5. A0A009GC83
 6. A0A009H2D3
 7. A0A009H847
 8. A0A009HRS4
 9. A0A009IPF9
10. A0A009J662

✗  A0A009J662 not present in bucket


In [13]:
#!/usr/bin/env python3
"""
Download a3m files for UniProt ID A0A009EX06 and extract original amino acid sequence
"""

import boto3
import os
import re
from botocore import UNSIGNED
from botocore.config import Config

def find_uniprot_in_uniclust(uniprot_id="A0A009EX06"):
    """Find UniProt ID variants in the uniclust30 directory."""
    
    s3 = boto3.client("s3",
                      region_name="us-east-1",
                      config=Config(signature_version=UNSIGNED))
    
    bucket = "openfold"
    search_prefix = f"uniclust30/{uniprot_id}"
    
    print(f"Searching for {uniprot_id} variants in s3://{bucket}/uniclust30/")
    
    try:
        # Search for entries starting with the UniProt ID
        response = s3.list_objects_v2(
            Bucket=bucket,
            Prefix=search_prefix,
            Delimiter="/"
        )
        
        variants = []
        if 'CommonPrefixes' in response:
            for prefix_obj in response['CommonPrefixes']:
                variant = prefix_obj['Prefix'].replace("uniclust30/", "").rstrip("/")
                variants.append(variant)
                print(f"  Found variant: {variant}")
        
        if not variants:
            print(f"  ❌ No variants found for {uniprot_id}")
        
        return variants
        
    except Exception as e:
        print(f"Error searching for {uniprot_id}: {e}")
        return []

def download_a3m_files(uniprot_variant, download_dir="a3m_downloads"):
    """Download a3m files for a specific UniProt variant."""
    
    s3 = boto3.client("s3",
                      region_name="us-east-1",
                      config=Config(signature_version=UNSIGNED))
    
    bucket = "openfold"
    base_prefix = f"uniclust30/{uniprot_variant}/"
    
    # Common a3m file paths in OpenFold
    a3m_files = [
        "a3m/uniclust30.a3m",
        "a3m/bfd_uniclust_hits.a3m", 
        "a3m/mgnify_hits.a3m",
        "a3m/uniref90_hits.a3m"
    ]
    
    # Create download directory
    os.makedirs(download_dir, exist_ok=True)
    variant_dir = os.path.join(download_dir, uniprot_variant)
    os.makedirs(variant_dir, exist_ok=True)
    
    downloaded_files = []
    
    print(f"\nDownloading a3m files for {uniprot_variant}...")
    
    for a3m_file in a3m_files:
        s3_key = base_prefix + a3m_file
        local_path = os.path.join(variant_dir, os.path.basename(a3m_file))
        
        try:
            print(f"  Downloading: {a3m_file}")
            s3.download_file(bucket, s3_key, local_path)
            
            # Check if file has content
            if os.path.getsize(local_path) > 0:
                downloaded_files.append(local_path)
                print(f"    ✅ Downloaded: {local_path} ({os.path.getsize(local_path):,} bytes)")
            else:
                print(f"    ❌ Empty file: {local_path}")
                os.remove(local_path)
                
        except Exception as e:
            print(f"    ❌ Failed to download {a3m_file}: {e}")
    
    return downloaded_files

def extract_original_sequence(a3m_file_path):
    """Extract the original amino acid sequence from a3m file (first sequence)."""
    
    print(f"\nExtracting original sequence from: {os.path.basename(a3m_file_path)}")
    
    try:
        with open(a3m_file_path, 'r') as f:
            lines = f.readlines()
        
        # Find the first sequence (should be the query/original sequence)
        original_sequence = ""
        in_first_sequence = False
        sequence_count = 0
        
        for i, line in enumerate(lines):
            line = line.strip()
            
            if line.startswith('>'):
                sequence_count += 1
                if sequence_count == 1:
                    print(f"  First header: {line}")
                    in_first_sequence = True
                else:
                    break  # We've finished the first sequence
            elif in_first_sequence and line:
                # Remove gaps and lowercase letters (insertions)
                clean_line = re.sub(r'[a-z\-\.]', '', line)
                original_sequence += clean_line
        
        if original_sequence:
            print(f"  ✅ Original sequence found (length: {len(original_sequence)})")
            print(f"  First 60 chars: {original_sequence[:60]}...")
            if len(original_sequence) > 60:
                print(f"  Last 60 chars:  ...{original_sequence[-60:]}")
            
            return original_sequence
        else:
            print(f"  ❌ No sequence found in {a3m_file_path}")
            return None
            
    except Exception as e:
        print(f"  ❌ Error reading {a3m_file_path}: {e}")
        return None

def process_uniprot_id(uniprot_id="A0A009EX06"):
    """Complete process: find, download, and extract sequence for UniProt ID."""
    
    print("="*70)
    print(f"PROCESSING UNIPROT ID: {uniprot_id}")
    print("="*70)
    
    # Step 1: Find variants
    variants = find_uniprot_in_uniclust(uniprot_id)
    
    if not variants:
        print(f"❌ No variants found for {uniprot_id}")
        return
    
    # Step 2: Process each variant
    all_sequences = {}
    
    for variant in variants:
        print(f"\n{'-'*50}")
        print(f"Processing variant: {variant}")
        print(f"{'-'*50}")
        
        # Download a3m files
        downloaded_files = download_a3m_files(variant)
        
        if not downloaded_files:
            print(f"❌ No a3m files downloaded for {variant}")
            continue
        
        # Extract sequences from each a3m file
        variant_sequences = {}
        for a3m_file in downloaded_files:
            sequence = extract_original_sequence(a3m_file)
            if sequence:
                file_type = os.path.basename(a3m_file).replace('.a3m', '')
                variant_sequences[file_type] = sequence
        
        all_sequences[variant] = variant_sequences
    
    # Step 3: Summary
    print("\n" + "="*70)
    print("SUMMARY")
    print("="*70)
    
    for variant, sequences in all_sequences.items():
        print(f"\n{variant}:")
        if sequences:
            for file_type, sequence in sequences.items():
                print(f"  {file_type}: {len(sequence)} amino acids")
                
            # Show the sequence from the main uniclust30 file if available
            if 'uniclust30' in sequences:
                print(f"\n  Main sequence ({variant}):")
                print(f"  {sequences['uniclust30']}")
        else:
            print("  No sequences extracted")
    
    return all_sequences

def save_sequences_to_fasta(all_sequences, output_file="A0A009EX06_sequences.fasta"):
    """Save extracted sequences to a FASTA file."""
    
    if not all_sequences:
        print("No sequences to save")
        return
    
    print(f"\nSaving sequences to: {output_file}")
    
    with open(output_file, 'w') as f:
        for variant, sequences in all_sequences.items():
            for file_type, sequence in sequences.items():
                header = f">{variant}_{file_type}"
                f.write(f"{header}\n")
                
                # Write sequence in 80-character lines
                for i in range(0, len(sequence), 80):
                    f.write(f"{sequence[i:i+80]}\n")
                
                f.write("\n")
    
    print(f"✅ Sequences saved to: {output_file}")

if __name__ == "__main__":
    # Process the UniProt ID
    uniprot_id = "A0A009EX06"
    sequences = process_uniprot_id(uniprot_id)
    
    # Save to FASTA file
    if sequences:
        save_sequences_to_fasta(sequences)
    
    print(f"\n🎉 Processing complete for {uniprot_id}!")

PROCESSING UNIPROT ID: A0A009EX06
Searching for A0A009EX06 variants in s3://openfold/uniclust30/
  Found variant: A0A009EX06

--------------------------------------------------
Processing variant: A0A009EX06
--------------------------------------------------

Downloading a3m files for A0A009EX06...
  Downloading: a3m/uniclust30.a3m
    ✅ Downloaded: a3m_downloads\A0A009EX06\uniclust30.a3m (1,001,304 bytes)
  Downloading: a3m/bfd_uniclust_hits.a3m
    ❌ Failed to download a3m/bfd_uniclust_hits.a3m: An error occurred (404) when calling the HeadObject operation: Not Found
  Downloading: a3m/mgnify_hits.a3m
    ❌ Failed to download a3m/mgnify_hits.a3m: An error occurred (404) when calling the HeadObject operation: Not Found
  Downloading: a3m/uniref90_hits.a3m
    ❌ Failed to download a3m/uniref90_hits.a3m: An error occurred (404) when calling the HeadObject operation: Not Found

Extracting original sequence from: uniclust30.a3m
  First header: >tr|A0A009EX06|A0A009EX06_ACIBA Uncharacteriz

In [14]:
#!/usr/bin/env python3
"""
Download a3m files for PDB ID 1eyy_A and extract original amino acid sequence
"""

import boto3
import os
import re
from botocore import UNSIGNED
from botocore.config import Config

def check_pdb_exists(pdb_id="1eyy_A"):
    """Check if PDB ID exists in OpenFold bucket and show available files."""
    
    s3 = boto3.client("s3",
                      region_name="us-east-1",
                      config=Config(signature_version=UNSIGNED))
    
    bucket = "openfold"
    pdb_prefix = f"pdb/{pdb_id}/"
    
    print(f"Checking PDB ID: {pdb_id}")
    print(f"Looking in: s3://{bucket}/{pdb_prefix}")
    
    try:
        # List all contents under this PDB ID
        response = s3.list_objects_v2(
            Bucket=bucket,
            Prefix=pdb_prefix,
            MaxKeys=100
        )
        
        if response.get('Contents'):
            print(f"✅ {pdb_id} exists! Available files:")
            
            # Organize files by type
            files_by_type = {}
            for obj in response['Contents']:
                file_path = obj['Key'].replace(pdb_prefix, '')
                file_size = obj['Size']
                
                if '/' in file_path:
                    file_type = file_path.split('/')[0]
                    files_by_type.setdefault(file_type, []).append((file_path, file_size))
                else:
                    files_by_type.setdefault('root', []).append((file_path, file_size))
            
            # Print organized file listing
            for file_type, files in sorted(files_by_type.items()):
                print(f"\n  📁 {file_type}/")
                for file_path, file_size in files:
                    print(f"    - {file_path} ({file_size:,} bytes)")
            
            return True
        else:
            print(f"❌ {pdb_id} not found in OpenFold bucket")
            return False
            
    except Exception as e:
        print(f"Error checking {pdb_id}: {e}")
        return False

def download_pdb_a3m_files(pdb_id="1eyy_A", download_dir="pdb_a3m_downloads"):
    """Download a3m files for a specific PDB ID."""
    
    s3 = boto3.client("s3",
                      region_name="us-east-1",
                      config=Config(signature_version=UNSIGNED))
    
    bucket = "openfold"
    base_prefix = f"pdb/{pdb_id}/"
    
    # Common a3m file paths in OpenFold PDB structure
    a3m_files = [
        "a3m/bfd_uniclust_hits.a3m",
        "a3m/mgnify_hits.a3m", 
        "a3m/uniref90_hits.a3m"
    ]
    
    # Create download directory
    os.makedirs(download_dir, exist_ok=True)
    pdb_dir = os.path.join(download_dir, pdb_id)
    os.makedirs(pdb_dir, exist_ok=True)
    
    downloaded_files = []
    
    print(f"\nDownloading a3m files for {pdb_id}...")
    
    # First, let's see what a3m files actually exist
    try:
        response = s3.list_objects_v2(
            Bucket=bucket,
            Prefix=base_prefix + "a3m/",
            MaxKeys=50
        )
        
        actual_a3m_files = []
        if response.get('Contents'):
            for obj in response['Contents']:
                if obj['Key'].endswith('.a3m'):
                    a3m_path = obj['Key'].replace(base_prefix, '')
                    actual_a3m_files.append(a3m_path)
                    print(f"  Found: {a3m_path} ({obj['Size']:,} bytes)")
        
        if not actual_a3m_files:
            print(f"  ❌ No a3m files found for {pdb_id}")
            return []
        
        # Download the actual a3m files
        for a3m_file in actual_a3m_files:
            s3_key = base_prefix + a3m_file
            local_filename = a3m_file.replace('/', '_')  # flatten path
            local_path = os.path.join(pdb_dir, local_filename)
            
            try:
                print(f"  Downloading: {a3m_file}")
                s3.download_file(bucket, s3_key, local_path)
                
                # Check if file has content
                if os.path.getsize(local_path) > 0:
                    downloaded_files.append(local_path)
                    print(f"    ✅ Downloaded: {local_path} ({os.path.getsize(local_path):,} bytes)")
                else:
                    print(f"    ❌ Empty file: {local_path}")
                    os.remove(local_path)
                    
            except Exception as e:
                print(f"    ❌ Failed to download {a3m_file}: {e}")
        
    except Exception as e:
        print(f"Error listing a3m files: {e}")
    
    return downloaded_files

def extract_original_sequence_pdb(a3m_file_path):
    """Extract the original amino acid sequence from PDB a3m file (first sequence)."""
    
    print(f"\nExtracting original sequence from: {os.path.basename(a3m_file_path)}")
    
    try:
        with open(a3m_file_path, 'r') as f:
            lines = f.readlines()
        
        # Find the first sequence (should be the query/original sequence)
        original_sequence = ""
        in_first_sequence = False
        sequence_count = 0
        first_header = ""
        
        for i, line in enumerate(lines):
            line = line.strip()
            
            if line.startswith('>'):
                sequence_count += 1
                if sequence_count == 1:
                    first_header = line
                    print(f"  First header: {line}")
                    in_first_sequence = True
                else:
                    break  # We've finished the first sequence
            elif in_first_sequence and line:
                # Remove gaps and lowercase letters (insertions)
                # Keep only uppercase amino acids
                clean_line = re.sub(r'[a-z\-\.]', '', line)
                original_sequence += clean_line
        
        if original_sequence:
            print(f"  ✅ Original sequence found (length: {len(original_sequence)})")
            print(f"  First 60 chars: {original_sequence[:60]}...")
            if len(original_sequence) > 60:
                print(f"  Last 60 chars:  ...{original_sequence[-60:]}")
            
            return original_sequence, first_header
        else:
            print(f"  ❌ No sequence found in {a3m_file_path}")
            return None, None
            
    except Exception as e:
        print(f"  ❌ Error reading {a3m_file_path}: {e}")
        return None, None

def process_pdb_id(pdb_id="1eyy_A"):
    """Complete process: check, download, and extract sequence for PDB ID."""
    
    print("="*70)
    print(f"PROCESSING PDB ID: {pdb_id}")
    print("="*70)
    
    # Step 1: Check if PDB exists
    if not check_pdb_exists(pdb_id):
        return
    
    # Step 2: Download a3m files
    print(f"\n{'-'*50}")
    print(f"Downloading a3m files for: {pdb_id}")
    print(f"{'-'*50}")
    
    downloaded_files = download_pdb_a3m_files(pdb_id)
    
    if not downloaded_files:
        print(f"❌ No a3m files downloaded for {pdb_id}")
        return
    
    # Step 3: Extract sequences from each a3m file
    all_sequences = {}
    
    for a3m_file in downloaded_files:
        sequence, header = extract_original_sequence_pdb(a3m_file)
        if sequence:
            file_type = os.path.basename(a3m_file).replace('.a3m', '')
            all_sequences[file_type] = {
                'sequence': sequence,
                'header': header
            }
    
    # Step 4: Summary
    print("\n" + "="*70)
    print("SUMMARY")
    print("="*70)
    
    print(f"\n{pdb_id}:")
    if all_sequences:
        for file_type, data in all_sequences.items():
            sequence = data['sequence']
            header = data['header']
            print(f"\n  📄 {file_type}:")
            print(f"     Header: {header}")
            print(f"     Length: {len(sequence)} amino acids")
            print(f"     Sequence: {sequence}")
    else:
        print("  No sequences extracted")
    
    return all_sequences

def save_pdb_sequences_to_fasta(pdb_id, all_sequences, output_file=None):
    """Save extracted PDB sequences to a FASTA file."""
    
    if not all_sequences:
        print("No sequences to save")
        return
    
    if output_file is None:
        output_file = f"{pdb_id}_sequences.fasta"
    
    print(f"\nSaving sequences to: {output_file}")
    
    with open(output_file, 'w') as f:
        for file_type, data in all_sequences.items():
            sequence = data['sequence']
            header = data['header']
            
            # Clean header and add file type info
            clean_header = header.replace('>', '').strip()
            fasta_header = f">{pdb_id}_{file_type}|{clean_header}"
            
            f.write(f"{fasta_header}\n")
            
            # Write sequence in 80-character lines
            for i in range(0, len(sequence), 80):
                f.write(f"{sequence[i:i+80]}\n")
            
            f.write("\n")
    
    print(f"✅ Sequences saved to: {output_file}")

def download_other_files(pdb_id="1eyy_A", file_types=["hhr"]):
    """Download other file types like .hhr files."""
    
    s3 = boto3.client("s3",
                      region_name="us-east-1",
                      config=Config(signature_version=UNSIGNED))
    
    bucket = "openfold"
    base_prefix = f"pdb/{pdb_id}/"
    
    print(f"\nChecking for other file types: {file_types}")
    
    for file_type in file_types:
        try:
            response = s3.list_objects_v2(
                Bucket=bucket,
                Prefix=base_prefix + f"{file_type}/",
                MaxKeys=20
            )
            
            if response.get('Contents'):
                print(f"\n  📁 {file_type}/ files:")
                for obj in response['Contents'][:5]:  # Show first 5
                    file_path = obj['Key'].replace(base_prefix, '')
                    print(f"    - {file_path} ({obj['Size']:,} bytes)")
            else:
                print(f"  ❌ No {file_type} files found")
                
        except Exception as e:
            print(f"  Error checking {file_type}: {e}")

if __name__ == "__main__":
    # Process the PDB ID
    pdb_id = "1eyy_A"
    sequences = process_pdb_id(pdb_id)
    
    # Save to FASTA file
    if sequences:
        save_pdb_sequences_to_fasta(pdb_id, sequences)
    
    # Check for other file types
    download_other_files(pdb_id, ["hhr", "sto"])
    
    print(f"\n🎉 Processing complete for {pdb_id}!")

PROCESSING PDB ID: 1eyy_A
Checking PDB ID: 1eyy_A
Looking in: s3://openfold/pdb/1eyy_A/
✅ 1eyy_A exists! Available files:

  📁 a3m/
    - a3m/bfd_uniclust_hits.a3m (1,944,936 bytes)
    - a3m/mgnify_hits.a3m (2,921,551 bytes)
    - a3m/uniref90_hits.a3m (6,612,125 bytes)

  📁 hhr/
    - hhr/pdb70_hits.hhr (1,107,346 bytes)

--------------------------------------------------
Downloading a3m files for: 1eyy_A
--------------------------------------------------

Downloading a3m files for 1eyy_A...
  Found: a3m/bfd_uniclust_hits.a3m (1,944,936 bytes)
  Found: a3m/mgnify_hits.a3m (2,921,551 bytes)
  Found: a3m/uniref90_hits.a3m (6,612,125 bytes)
  Downloading: a3m/bfd_uniclust_hits.a3m
    ✅ Downloaded: pdb_a3m_downloads\1eyy_A\a3m_bfd_uniclust_hits.a3m (1,944,936 bytes)
  Downloading: a3m/mgnify_hits.a3m
    ✅ Downloaded: pdb_a3m_downloads\1eyy_A\a3m_mgnify_hits.a3m (2,921,551 bytes)
  Downloading: a3m/uniref90_hits.a3m
    ✅ Downloaded: pdb_a3m_downloads\1eyy_A\a3m_uniref90_hits.a3m (6,612

In [None]:
#!/usr/bin/env python3
"""
fetch_1dp5_A.py  – print the amino-acid sequence of PDB chain 1dp5_A

Order of battle
---------------
1. Anonymous S3 call to the OpenFold bucket:
   • look under  pdb_mmcif/1dp5_A/
   • grab the first *.a3m  file (there is exactly one per chain)
   • read the first two lines (header + sequence)

2. If the chain isn’t present in OpenFold yet, fall back to RCSB:
   • https://www.rcsb.org/fasta/entry/1dp5?chain=A
"""

import boto3, botocore, requests
from botocore import UNSIGNED
from botocore.config import Config

PDB_CHAIN   = "1dp5_A"
BUCKET      = "openfold"
PREFIX      = f"pdb_mmcif/{PDB_CHAIN}/"
S3          = boto3.client("s3", config=Config(signature_version=UNSIGNED))

def seq_from_openfold():
    """Return (header, seq) if 1dp5_A is present in the OpenFold bucket."""
    resp = S3.list_objects_v2(Bucket=BUCKET, Prefix=PREFIX)
    if resp.get("KeyCount", 0) == 0:
        return None                          # nothing under that prefix

    # pick the first alignment object that ends in .a3m
    a3m_key = next(o["Key"] for o in resp["Contents"] if o["Key"].endswith(".a3m"))
    body    = S3.get_object(Bucket=BUCKET, Key=a3m_key)["Body"].read().decode()
    header, seq = body.splitlines()[:2]      # only need the rep-sequence
    return header, seq

def seq_from_rcsb():
    """Fallback – fetch chain-specific FASTA from RCSB PDB."""
    url   = "https://www.rcsb.org/fasta/entry/1dp5?chain=A"
    text  = requests.get(url, timeout=30).text
    if text.startswith(">"):
        header, seq = text.splitlines()[:2]
        return header, seq
    return None

for grab in (seq_from_openfold, seq_from_rcsb):
    try:
        result = grab()
        if result:
            header, seq = result
            print(header)
            print(seq)
            break
    except (StopIteration, botocore.exceptions.ClientError):
        pass
else:
    raise RuntimeError(f"Could not retrieve the sequence for {PDB_CHAIN}")


>1DP5_1|Chain A|PROTEINASE A|Saccharomyces cerevisiae (4932)
GGHDVPLTNYLNAQYYTDITLGTPPQNFKVILDTGSSNLWVPSNECGSLACFLHSKYDHEASSSYKANGTEFAIQYGTGSLEGYISQDTLSIGDLTIPKQDFAEATSEPGLTFAFGKFDGILGLGYDTISVDKVVPPFYNAIQQDLLDEKRFAFYLGDTSKDTENGGEATFGGIDESKFKGDITWLPVRRKAYWEVKFEGIGLGDEYAELESHGAAIDTGTSLITLPSGLAEMINAEIGAKKGWTGQYTLDCNTRDNLPDLIFNFNGYNFTIGPYDYTLEVSGSCISAITPMDFPEPVGPLAIVGDAFLRKYYSIYDLGNNAVGLAKAI


In [12]:
#!/usr/bin/env python3
"""
Extract amino acid sequence from OpenFold PDB dataset for a specific protein ID
"""
import boto3
from botocore import UNSIGNED
from botocore.config import Config
import io
import gzip

BUCKET = "openfold"
REGION = "us-east-1"

def parse_a3m(a3m_content):
    """
    Parse A3M format file and extract sequences.
    Returns a list of tuples (header, sequence)
    """
    sequences = []
    current_header = None
    current_seq = []
    
    lines = a3m_content.strip().split('\n')
    
    for line in lines:
        line = line.strip()
        if line.startswith('>'):
            # Save previous sequence if exists
            if current_header is not None:
                sequences.append((current_header, ''.join(current_seq)))
            # Start new sequence
            current_header = line[1:]  # Remove '>' character
            current_seq = []
        else:
            # Add to current sequence (remove gaps and lowercase for query)
            current_seq.append(line)
    
    # Don't forget the last sequence
    if current_header is not None:
        sequences.append((current_header, ''.join(current_seq)))
    
    return sequences

def get_query_sequence_from_a3m(a3m_content):
    """
    Extract the query sequence (first sequence) from an A3M file.
    The query sequence is typically uppercase without gaps.
    """
    sequences = parse_a3m(a3m_content)
    if sequences:
        header, seq = sequences[0]
        # Remove gaps and lowercase letters (which represent insertions in MSA)
        query_seq = ''.join([aa for aa in seq if aa.isupper() and aa != '-'])
        return header, query_seq
    return None, None

def list_pdb_contents(pdb_id, s3_client):
    """List contents of a PDB directory to understand structure"""
    prefix = f"pdb/{pdb_id}/"
    print(f"\nListing contents of {prefix}")
    
    try:
        response = s3_client.list_objects_v2(
            Bucket=BUCKET,
            Prefix=prefix,
            MaxKeys=20
        )
        
        if 'Contents' in response:
            for obj in response['Contents']:
                print(f"  {obj['Key']}")
        else:
            print(f"  No objects found with prefix {prefix}")
    except Exception as e:
        print(f"  Error listing: {e}")

def get_pdb_sequence(pdb_id, s3_client):
    """
    Get the amino acid sequence for a PDB entry from OpenFold bucket.
    
    Args:
        pdb_id: PDB ID in format like '1dp5_A'
        s3_client: Boto3 S3 client
    
    Returns:
        tuple: (header, amino_acid_sequence) or (None, None) if not found
    """
    # First, let's check what files exist for this PDB ID
    list_pdb_contents(pdb_id, s3_client)
    
    # Try to get one of the MSA files (let's use bfd_uniclust_hits.a3m)
    a3m_key = f"pdb/{pdb_id}/a3m/bfd_uniclust_hits.a3m"
    
    try:
        print(f"Attempting to fetch: {a3m_key}")
        response = s3_client.get_object(Bucket=BUCKET, Key=a3m_key)
        
        # Read and decompress if needed
        content = response['Body'].read()
        
        # Check if content is gzipped
        if content[:2] == b'\x1f\x8b':  # gzip magic number
            content = gzip.decompress(content)
        
        # Decode to string
        a3m_content = content.decode('utf-8')
        
        # Extract query sequence
        header, query_seq = get_query_sequence_from_a3m(a3m_content)
        
        return header, query_seq
        
    except s3_client.exceptions.NoSuchKey:
        print(f"File not found: {a3m_key}")
        # Try other MSA files
        other_msas = ["mgnify_hits.a3m", "uniref90_hits.a3m"]
        
        for msa_file in other_msas:
            try:
                a3m_key = f"pdb/{pdb_id}/a3m/{msa_file}"
                print(f"Trying alternative: {a3m_key}")
                response = s3_client.get_object(Bucket=BUCKET, Key=a3m_key)
                content = response['Body'].read()
                
                if content[:2] == b'\x1f\x8b':
                    content = gzip.decompress(content)
                
                a3m_content = content.decode('utf-8')
                header, query_seq = get_query_sequence_from_a3m(a3m_content)
                
                if query_seq:
                    return header, query_seq
                    
            except s3_client.exceptions.NoSuchKey:
                continue
    
    except Exception as e:
        print(f"Error accessing S3: {e}")
    
    return None, None

def main():
    # Create S3 client with anonymous access
    s3 = boto3.client(
        "s3",
        region_name=REGION,
        config=Config(signature_version=UNSIGNED)
    )
    
    # Test with PDB ID 1dp5_A
    pdb_id = "1dp5_A"
    print(f"\nFetching amino acid sequence for PDB ID: {pdb_id}")
    print("=" * 60)
    
    header, sequence = get_pdb_sequence(pdb_id, s3)
    
    if sequence:
        print(f"\nHeader: {header}")
        print(f"\nAmino Acid Sequence ({len(sequence)} residues):")
        print(sequence)
        
        # Print in chunks of 60 for readability
        print("\nFormatted sequence:")
        for i in range(0, len(sequence), 60):
            print(f"{i+1:4d} {sequence[i:i+60]}")
    else:
        print(f"\nCould not retrieve sequence for {pdb_id}")
    
    # Example of how to use this for matching against a dataset
    print("\n" + "=" * 60)
    print("To match against your dataset, you can:")
    print("1. Extract query sequences from MSA files")
    print("2. Compare with your target sequences")
    print("3. Use the matching MSA for further analysis")

if __name__ == "__main__":
    main()


Fetching amino acid sequence for PDB ID: 1dp5_A

Listing contents of pdb/1dp5_A/
  pdb/1dp5_A/a3m/bfd_uniclust_hits.a3m
  pdb/1dp5_A/a3m/mgnify_hits.a3m
  pdb/1dp5_A/a3m/uniref90_hits.a3m
  pdb/1dp5_A/hhr/pdb70_hits.hhr
Attempting to fetch: pdb/1dp5_A/a3m/bfd_uniclust_hits.a3m

Header: query

Amino Acid Sequence (329 residues):
GGHDVPLTNYLNAQYYTDITLGTPPQNFKVILDTGSSNLWVPSNECGSLACFLHSKYDHEASSSYKANGTEFAIQYGTGSLEGYISQDTLSIGDLTIPKQDFAEATSEPGLTFAFGKFDGILGLGYDTISVDKVVPPFYNAIQQDLLDEKRFAFYLGDTSKDTENGGEATFGGIDESKFKGDITWLPVRRKAYWEVKFEGIGLGDEYAELESHGAAIDTGTSLITLPSGLAEMINAEIGAKKGWTGQYTLDCNTRDNLPDLIFNFNGYNFTIGPYDYTLEVSGSCISAITPMDFPEPVGPLAIVGDAFLRKYYSIYDLGNNAVGLAKAI

Formatted sequence:
   1 GGHDVPLTNYLNAQYYTDITLGTPPQNFKVILDTGSSNLWVPSNECGSLACFLHSKYDHE
  61 ASSSYKANGTEFAIQYGTGSLEGYISQDTLSIGDLTIPKQDFAEATSEPGLTFAFGKFDG
 121 ILGLGYDTISVDKVVPPFYNAIQQDLLDEKRFAFYLGDTSKDTENGGEATFGGIDESKFK
 181 GDITWLPVRRKAYWEVKFEGIGLGDEYAELESHGAAIDTGTSLITLPSGLAEMINAEIGA
 241 KKGWTGQYTLDCNTRDNLPDLIFNFNGYNFTIGPYDYTLEVSGSCISAI

In [None]:
#!/usr/bin/env python3
"""
Match failed proteins from OpenFold results with PDB sequences in S3 bucket
by comparing amino acid sequences.
"""
import boto3
from botocore import UNSIGNED
from botocore.config import Config
import pandas as pd
import os
import gzip
from concurrent.futures import ThreadPoolExecutor, as_completed
from collections import defaultdict
import time
from tqdm import tqdm

BUCKET = "openfold"
REGION = "us-east-1"

class ProteinMatcher:
    def __init__(self, tsv_path, fasta_path, output_path):
        self.tsv_path = tsv_path
        self.fasta_path = fasta_path
        self.output_path = output_path
        
        # S3 client with anonymous access
        self.s3 = boto3.client(
            "s3",
            region_name=REGION,
            config=Config(signature_version=UNSIGNED)
        )
        
        # Cache for PDB sequences to avoid repeated S3 calls
        self.pdb_sequence_cache = {}
        
        # Dictionary to store original_id -> sequence mapping
        self.target_sequences = {}
        
        # List of available MSA files in order of preference
        self.msa_files = ["bfd_uniclust_hits.a3m", "mgnify_hits.a3m", "uniref90_hits.a3m"]
    
    def parse_fasta(self):
        """Parse the FASTA file and create a mapping of ID to sequence."""
        print("Parsing FASTA file...")
        current_id = None
        current_seq = []
        
        with open(self.fasta_path, 'r') as f:
            for line in f:
                line = line.strip()
                if line.startswith('>'):
                    # Save previous sequence
                    if current_id:
                        self.target_sequences[current_id] = ''.join(current_seq)
                    # Extract ID (first part after >)
                    current_id = line[1:].split()[0]
                    current_seq = []
                else:
                    current_seq.append(line)
            
            # Don't forget the last sequence
            if current_id:
                self.target_sequences[current_id] = ''.join(current_seq)
        
        print(f"Loaded {len(self.target_sequences)} sequences from FASTA file")
    
    def get_query_sequence_from_a3m(self, a3m_content):
        """Extract the query sequence (first sequence) from an A3M file."""
        lines = a3m_content.strip().split('\n')
        
        # Find the first sequence (query)
        in_first_seq = False
        seq_lines = []
        
        for line in lines:
            line = line.strip()
            if line.startswith('>'):
                if in_first_seq:
                    # We've reached the second sequence, stop
                    break
                in_first_seq = True
                continue
            elif in_first_seq and line:
                seq_lines.append(line)
        
        if seq_lines:
            # Join all sequence lines and remove gaps and lowercase
            full_seq = ''.join(seq_lines)
            # Keep only uppercase letters (actual residues in query)
            query_seq = ''.join([aa for aa in full_seq if aa.isupper() and aa != '-'])
            return query_seq
        
        return None
    
    def get_pdb_sequence(self, pdb_id):
        """Get sequence from PDB entry, with caching."""
        # Check cache first
        if pdb_id in self.pdb_sequence_cache:
            return self.pdb_sequence_cache[pdb_id]
        
        # Try each MSA file in order
        for msa_file in self.msa_files:
            a3m_key = f"pdb/{pdb_id}/a3m/{msa_file}"
            
            try:
                response = self.s3.get_object(Bucket=BUCKET, Key=a3m_key)
                content = response['Body'].read()
                
                # Handle gzip if needed
                if content[:2] == b'\x1f\x8b':
                    content = gzip.decompress(content)
                
                a3m_content = content.decode('utf-8')
                query_seq = self.get_query_sequence_from_a3m(a3m_content)
                
                if query_seq:
                    # Cache the result
                    self.pdb_sequence_cache[pdb_id] = query_seq
                    return query_seq
                
            except self.s3.exceptions.NoSuchKey:
                continue
            except Exception as e:
                print(f"Error accessing {a3m_key}: {e}")
                continue
        
        # Cache None result to avoid repeated failed attempts
        self.pdb_sequence_cache[pdb_id] = None
        return None
    
    def get_all_pdb_ids(self):
        """Get list of all PDB IDs in the bucket (more efficient than searching all)."""
        print("Fetching list of PDB IDs from S3...")
        pdb_ids = set()
        
        paginator = self.s3.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=BUCKET, Prefix="pdb/", Delimiter="/")
        
        for page in pages:
            if 'CommonPrefixes' in page:
                for prefix in page['CommonPrefixes']:
                    # Extract PDB ID from path like "pdb/1abc_A/"
                    pdb_id = prefix['Prefix'].split('/')[-2]
                    pdb_ids.add(pdb_id)
        
        print(f"Found {len(pdb_ids)} PDB IDs in bucket")
        return list(pdb_ids)
    
    def search_sequence_in_pdbs(self, target_sequence, pdb_ids, batch_size=100, max_workers=10):
        """Search for target sequence in PDB entries using parallel processing."""
        # Create a progress bar for PDB search
        pbar = tqdm(total=len(pdb_ids), desc="Searching PDBs", leave=False)
        
        # Process in batches to avoid overwhelming the system
        for i in range(0, len(pdb_ids), batch_size):
            batch = pdb_ids[i:i + batch_size]
            
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Submit tasks
                future_to_pdb = {
                    executor.submit(self.get_pdb_sequence, pdb_id): pdb_id 
                    for pdb_id in batch
                }
                
                # Process results as they complete
                for future in as_completed(future_to_pdb):
                    pdb_id = future_to_pdb[future]
                    pbar.update(1)
                    
                    try:
                        pdb_sequence = future.result()
                        if pdb_sequence and pdb_sequence == target_sequence:
                            pbar.close()
                            return pdb_id  # Return immediately when found
                    except Exception as e:
                        # Don't print errors during search to keep output clean
                        pass
        
        pbar.close()
        return None
    
    def build_sequence_to_pdb_index(self, pdb_ids, batch_size=1000, max_workers=20):
        """Build an index of sequence -> PDB ID for faster lookups."""
        print("Building sequence index (this may take a few minutes)...")
        sequence_to_pdb = {}
        
        # Process in batches with progress bar
        with tqdm(total=len(pdb_ids), desc="Indexing PDB sequences") as pbar:
            for i in range(0, len(pdb_ids), batch_size):
                batch = pdb_ids[i:i + batch_size]
                
                with ThreadPoolExecutor(max_workers=max_workers) as executor:
                    future_to_pdb = {
                        executor.submit(self.get_pdb_sequence, pdb_id): pdb_id 
                        for pdb_id in batch
                    }
                    
                    for future in as_completed(future_to_pdb):
                        pdb_id = future_to_pdb[future]
                        pbar.update(1)
                        
                        try:
                            sequence = future.result()
                            if sequence:
                                if sequence not in sequence_to_pdb:
                                    sequence_to_pdb[sequence] = []
                                sequence_to_pdb[sequence].append(pdb_id)
                        except Exception:
                            pass
        
        print(f"Index built: {len(sequence_to_pdb)} unique sequences found")
        return sequence_to_pdb
    
    def process_failed_proteins(self, use_index=True):
        """Main processing function."""
        # Load TSV file
        print(f"Loading TSV file: {self.tsv_path}")
        df = pd.read_csv(self.tsv_path, sep='\t')
        
        # Parse FASTA file
        self.parse_fasta()
        
        # Filter for failed entries
        failed_df = df[df['status'] == 'FAILED'].copy()
        print(f"Found {len(failed_df)} failed proteins to process")
        
        # Get all PDB IDs once
        pdb_ids = self.get_all_pdb_ids()
        
        # Build sequence index if requested (much faster for multiple searches)
        sequence_to_pdb = {}
        if use_index:
            sequence_to_pdb = self.build_sequence_to_pdb_index(pdb_ids)
        
        # Process each failed protein
        results = []
        for idx, row in tqdm(failed_df.iterrows(), total=len(failed_df), desc="Processing failed proteins"):
            original_id = row['original_id']
            
            # Get target sequence from FASTA
            if original_id not in self.target_sequences:
                results.append({
                    'original_id': original_id,
                    'status': 'NOT_IN_FASTA',
                    'converted_id': None,
                    'notes': 'Original ID not found in FASTA file'
                })
                continue
            
            target_sequence = self.target_sequences[original_id]
            
            # Search for matching sequence
            start_time = time.time()
            
            if use_index:
                # Use pre-built index (much faster)
                if target_sequence in sequence_to_pdb:
                    matched_pdb = sequence_to_pdb[target_sequence][0]  # Take first match
                    search_time = time.time() - start_time
                    results.append({
                        'original_id': original_id,
                        'status': 'FOUND',
                        'converted_id': matched_pdb,
                        'notes': f'Sequence match found (indexed search: {search_time:.4f}s)'
                    })
                else:
                    search_time = time.time() - start_time
                    results.append({
                        'original_id': original_id,
                        'status': 'NOT_FOUND',
                        'converted_id': None,
                        'notes': f'No sequence match in index ({search_time:.4f}s)'
                    })
            else:
                # Direct search (slower)
                print(f"\nSearching for {original_id} (length: {len(target_sequence)})")
                matched_pdb = self.search_sequence_in_pdbs(target_sequence, pdb_ids)
                search_time = time.time() - start_time
                
                if matched_pdb:
                    results.append({
                        'original_id': original_id,
                        'status': 'FOUND',
                        'converted_id': matched_pdb,
                        'notes': f'Sequence match found in {search_time:.2f}s'
                    })
                else:
                    results.append({
                        'original_id': original_id,
                        'status': 'NOT_FOUND',
                        'converted_id': None,
                        'notes': f'No sequence match in PDB database'
                    })
        
        # Create results dataframe
        results_df = pd.DataFrame(results)
        
        # Save results
        results_df.to_csv(self.output_path, sep='\t', index=False)
        print(f"\nResults saved to: {self.output_path}")
        
        # Print summary
        print("\nSummary:")
        print(f"Total failed proteins: {len(failed_df)}")
        print(f"Found matches: {len(results_df[results_df['status'] == 'FOUND'])}")
        print(f"Not found: {len(results_df[results_df['status'] == 'NOT_FOUND'])}")
        print(f"Not in FASTA: {len(results_df[results_df['status'] == 'NOT_IN_FASTA'])}")
        
        return results_df

def main():
    # File paths
    tsv_path = r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\openfold_results_updated.tsv"
    fasta_path = r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\nrPDB-GO_2019.06.18_sequences.fasta"
    output_path = r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\openfold_failed_proteins_matched.tsv"
    
    # Create matcher and process
    matcher = ProteinMatcher(tsv_path, fasta_path, output_path)
    results = matcher.process_failed_proteins()
    
    print("\nProcessing complete!")

if __name__ == "__main__":
    main()

Loading TSV file: C:\Users\rfrjo\Documents\Codebases\PFP_Testing\openfold_results_updated.tsv
Parsing FASTA file...
Loaded 36641 sequences from FASTA file
Found 25 failed proteins to process
Fetching list of PDB IDs from S3...
Found 131487 PDB IDs in bucket
Building sequence index (this may take a few minutes)...


Indexing PDB sequences:   3%|▎         | 3524/131487 [01:30<1:01:44, 34.54it/s]

In [2]:
#!/usr/bin/env python3
"""
Rescue FAILED chains in openfold_results_updated.tsv by locating a
sequence-identical PDB chain in the OpenFold S3 bucket.

• Downloads only 4 kB from each .a3m (query line lives there)
• Builds / loads a sha1(sequence) → [pdb_id,…] index, cached on disk
• Patches the original TSV in-place (status, converted_id, notes)
"""

import gzip, hashlib, pickle, time
from collections import defaultdict
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor

import boto3
from botocore import UNSIGNED
from botocore.config import Config
import pandas as pd
from tqdm import tqdm

# ───────── constants ──────────────────────────────────────────────────────
BUCKET   = "openfold"
REGION   = "us-east-1"
MSA_FILES = ("bfd_uniclust_hits.a3m", "mgnify_hits.a3m", "uniref90_hits.a3m")
RANGE     = "bytes=0-4095"
PDB_LIST_CACHE  = "pdb_ids.pkl.gz"
SEQ_INDEX_CACHE = "seq2pdb.pkl.gz"          # NEW: cache for the big index

# ───────── tiny helpers ───────────────────────────────────────────────────
norm = lambda s: s.lower().replace("-", "_")
sha  = lambda seq: hashlib.sha1(seq.encode()).hexdigest()

def first_query_seq(buf: bytes) -> str | None:
    if buf.startswith(b"\x1f\x8b"):
        buf = gzip.decompress(buf)
    lines = buf.split(b"\n", 3)
    if len(lines) < 2:
        return None
    raw = lines[1].decode("utf-8", "ignore").strip()
    return "".join(c for c in raw if c.isupper() and c != "-")

# ───────── class ─────────────────────────────────────────────────────────
class ProteinMatcher:
    def __init__(self, tsv, fasta, out_tsv, max_workers=32):
        self.tsv      = Path(tsv)
        self.fasta    = Path(fasta)
        self.out_tsv  = Path(out_tsv)
        self.workers  = max_workers

        self.s3 = boto3.client(
            "s3", region_name=REGION,
            config=Config(signature_version=UNSIGNED))

        self.fasta_map: dict[str, str] = {}
        self.pdb_seq_cache: dict[str, str | None] = {}

    # ─── FASTA loader ────────────────────────────────────────────
    def _load_fasta(self):
        print("Parsing FASTA …")
        cur, buf = None, []
        with self.fasta.open() as fh:
            for ln in fh:
                ln = ln.strip()
                if ln.startswith(">"):
                    if cur:
                        self.fasta_map[norm(cur)] = "".join(buf)
                    cur, buf = ln[1:].split()[0], []
                else:
                    buf.append(ln)
            if cur:
                self.fasta_map[norm(cur)] = "".join(buf)
        print(f"FASTA records: {len(self.fasta_map):,}")

    # ─── S3 listing (cached) ────────────────────────────────────
    def _pdb_ids(self) -> list[str]:
        if Path(PDB_LIST_CACHE).exists():
            with gzip.open(PDB_LIST_CACHE, "rb") as fh:
                ids = pickle.load(fh)
            print(f"PDB list from cache ({len(ids):,})")
            return ids

        print("Listing pdb/* prefixes in S3 …")
        ids = set()
        paginator = self.s3.get_paginator("list_objects_v2")
        for pg in paginator.paginate(Bucket=BUCKET, Prefix="pdb/", Delimiter="/"):
            for pref in pg.get("CommonPrefixes", []):
                ids.add(pref["Prefix"].split("/")[1])   # 1dp5_A
        ids = sorted(ids)
        with gzip.open(PDB_LIST_CACHE, "wb") as fh:
            pickle.dump(ids, fh)
        print(f"PDB list cached ({len(ids):,})")
        return ids

    # ─── download query seq for one chain ───────────────────────
    def _pdb_query_seq(self, pdb_id: str) -> str | None:
        if pdb_id in self.pdb_seq_cache:
            return self.pdb_seq_cache[pdb_id]

        for msa in MSA_FILES:
            key = f"pdb/{pdb_id}/a3m/{msa}"
            try:
                buf = self.s3.get_object(
                    Bucket=BUCKET, Key=key, Range=RANGE)["Body"].read()
                seq = first_query_seq(buf)
                if seq:
                    self.pdb_seq_cache[pdb_id] = seq
                    return seq
            except self.s3.exceptions.NoSuchKey:
                continue
        self.pdb_seq_cache[pdb_id] = None
        return None

    # ─── build / load big index ─────────────────────────────────
    def _seq_index(self, pdb_ids: list[str]) -> dict[str, list[str]]:
        if Path(SEQ_INDEX_CACHE).exists():
            with gzip.open(SEQ_INDEX_CACHE, "rb") as fh:
                idx = pickle.load(fh)
            print(f"Sequence index loaded ({len(idx):,} unique hashes)")
            return idx

        print("Building sequence→PDB index (one-off, threaded) …")
        idx: dict[str, list[str]] = defaultdict(list)
        with ThreadPoolExecutor(self.workers) as ex:
            for pid, seq in tqdm(
                zip(pdb_ids, ex.map(self._pdb_query_seq, pdb_ids)),
                total=len(pdb_ids), desc="Indexing"):
                if seq:
                    idx[sha(seq)].append(pid)

        with gzip.open(SEQ_INDEX_CACHE, "wb") as fh:
            pickle.dump(idx, fh)
        print(f"Index cached ({len(idx):,} unique hashes)")
        return idx

    # ─── main routine ───────────────────────────────────────────
    def run(self):
        self._load_fasta()
        df = pd.read_csv(self.tsv, sep="\t")
        failed = df[df["status"] == "FAILED"]
        print(f"FAILED rows: {len(failed):,}")

        pdb_ids = self._pdb_ids()
        seq2pdb = self._seq_index(pdb_ids)

        # patch dataframe in-place
        for idx in tqdm(failed.index, desc="Rescuing"):
            orig_id = norm(df.at[idx, "original_id"])
            seq = self.fasta_map.get(orig_id)
            if not seq:
                df.at[idx, "notes"]  = "seq missing in FASTA"
                df.at[idx, "status"] = "NOT_IN_FASTA"
                continue

            hit = seq2pdb.get(sha(seq))
            if hit:
                df.at[idx, "status"]       = "RESCUED"
                df.at[idx, "converted_id"] = hit[0]
                df.at[idx, "notes"]        = "sequence-identical chain"
            else:
                df.at[idx, "status"] = "NOT_FOUND"
                df.at[idx, "notes"]  = "no identical seq"

        df.to_csv(self.out_tsv, sep="\t", index=False)
        print(f"\nSaved → {self.out_tsv}")
        print(df["status"].value_counts())

# ───────── CLI ───────────────────────────────────────────────────────────
if __name__ == "__main__":
    BASE = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing")
    ProteinMatcher(
        BASE / "openfold_results_updated.tsv",
        BASE / "nrPDB-GO_2019.06.18_sequences.fasta",
        BASE / "openfold_results_rescued.tsv",
        max_workers=128,
    ).run()


Parsing FASTA …
FASTA records: 36,408
FAILED rows: 25
PDB list from cache (131,487)
Sequence index loaded (131,482 unique hashes)


Rescuing: 100%|██████████| 25/25 [00:00<00:00, 8324.67it/s]


Saved → C:\Users\rfrjo\Documents\Codebases\PFP_Testing\openfold_results_rescued.tsv
status
DIRECT       22913
RESCUED       6823
NOT_FOUND        4
Name: count, dtype: int64





In [1]:
# Cell 1  ────────────────────────────────────────────────────────────────
from __future__ import annotations
import gzip, pickle, hashlib
from pathlib import Path
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

import boto3, pandas as pd
from botocore.config import Config
from botocore import UNSIGNED
from tqdm.auto import tqdm   # pretty progress bars in notebooks

# ───── EDIT THESE PATHS IF NEEDED ───────────────────────────────────────
BASE          = Path(r"C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data")
PDB_DIR       = BASE / "protein_data_pdb"                    # folder of IDs
FASTA_FILE    = BASE / "nrPDB-GO_2019.06.18_sequences.fasta"

MAP_TSV       = BASE / "protein_data_pdb_mapping_updated.tsv"        # output

PDB_LIST_CACHE  = BASE / "pdb_ids.pkl.gz"
SEQ_INDEX_CACHE = BASE / "seq2pdb.pkl.gz"

# ───── S3 & bucket constants ────────────────────────────────────────────
BUCKET   = "openfold"; REGION = "us-east-1"
MSA_FILES = ("bfd_uniclust_hits.a3m", "mgnify_hits.a3m", "uniref90_hits.a3m")
RANGE     = "bytes=0-4095"         # first 4 kB holds query sequence
THREADS   = 96                     # parallel S3 workers

# ───── tiny helpers ─────────────────────────────────────────────────────
norm = lambda s: s.lower().replace("-", "_")
sha  = lambda s: hashlib.sha1(s.encode()).hexdigest()

def first_query_seq(buf: bytes) -> str | None:
    if buf[:2] == b"\x1f\x8b":             # gzip magic
        import gzip as _gz; buf = _gz.decompress(buf)
    parts = buf.split(b"\n", 3)
    if len(parts) < 2: return None
    raw = parts[1].decode("utf-8", "ignore").strip()
    return "".join(c for c in raw if c.isupper() and c != "-")

# ───── main callable ----------------------------------------------------
def run_mapping(refresh: bool = False, workers: int = THREADS):
    """
    Map every chain in protein_data_pdb/ to a sequence-identical chain
    present in the OpenFold bucket.

    Parameters
    ----------
    refresh : bool
        If True, ignore on-disk caches and rebuild pdb_ids + seq2pdb index.
    workers : int
        Degree of S3 parallelism when (re)building the big index.
    """
    # 1. gather local chain IDs (folder / file names)
    local_ids = sorted(p.name for p in PDB_DIR.iterdir())
    print(f"Local chains: {len(local_ids):,}")

    # 2. load FASTA into dict
    fasta_map: dict[str, str] = {}
    cur, buf = None, []
    with FASTA_FILE.open() as fh:
        for ln in fh:
            ln = ln.rstrip()
            if ln.startswith(">"):
                if cur: fasta_map[norm(cur)] = "".join(buf)
                cur, buf = ln[1:].split()[0], []
            else:
                buf.append(ln)
        if cur:
            fasta_map[norm(cur)] = "".join(buf)
    print(f"FASTA entries loaded: {len(fasta_map):,}")

    # 3. anonymous S3 client
    s3 = boto3.client("s3",
                      region_name=REGION,
                      config=Config(signature_version=UNSIGNED))

    # 4. all chain IDs present in bucket  (use / rebuild cache)
    if PDB_LIST_CACHE.exists() and not refresh:
        with gzip.open(PDB_LIST_CACHE, "rb") as fh:
            bucket_ids = pickle.load(fh)
        print(f"PDB list from cache ({len(bucket_ids):,})")
    else:
        print("Listing pdb/* prefixes in S3 …")
        bucket_ids = set()
        paginator = s3.get_paginator("list_objects_v2")
        for pg in paginator.paginate(Bucket=BUCKET, Prefix="pdb/", Delimiter="/"):
            for pref in pg.get("CommonPrefixes", []):
                bucket_ids.add(pref["Prefix"].split("/")[1])
        bucket_ids = sorted(bucket_ids)
        with gzip.open(PDB_LIST_CACHE, "wb") as fh:
            pickle.dump(bucket_ids, fh)
        print(f"PDB list cached ({len(bucket_ids):,})")

    # 5. sequence hash → bucket chains index  (use / rebuild cache)
    if SEQ_INDEX_CACHE.exists() and not refresh:
        with gzip.open(SEQ_INDEX_CACHE, "rb") as fh:
            seq2pdb: dict[str, list[str]] = pickle.load(fh)
        print(f"Sequence index from cache ({len(seq2pdb):,} hashes)")
    else:
        print("Building sequence→PDB index (one-off) …")
        seq2pdb = defaultdict(list)
        def query_seq(pid):
            for msa in MSA_FILES:
                key=f"pdb/{pid}/a3m/{msa}"
                try:
                    buf = s3.get_object(Bucket=BUCKET, Key=key,
                                        Range=RANGE)["Body"].read()
                    seq = first_query_seq(buf)
                    if seq: return seq
                except s3.exceptions.NoSuchKey: pass
            return None
        with ThreadPoolExecutor(workers) as ex:
            for pid, seq in tqdm(
                zip(bucket_ids, ex.map(query_seq, bucket_ids)),
                total=len(bucket_ids), desc="Indexing"):
                if seq:
                    seq2pdb[sha(seq)].append(pid)
        with gzip.open(SEQ_INDEX_CACHE, "wb") as fh:
            pickle.dump(seq2pdb, fh)
        print(f"Index cached ({len(seq2pdb):,} hashes)")

    # 6. map every local chain
    rows=[]
    for lid in tqdm(local_ids, desc="Mapping chains"):
        key = norm(lid)
        seq = fasta_map.get(key)
        if not seq:
            rows.append((lid, None, "NOT_IN_FASTA", "no sequence in FASTA"))
            continue
        hits = seq2pdb.get(sha(seq))
        if hits:
            rows.append((lid, hits[0], "MATCH", "sequence-identical chain"))
        else:
            rows.append((lid, None, "NO_MATCH", "not present in bucket"))

    df = pd.DataFrame(rows, columns=["original_id","matched_id",
                                     "status","notes"])
    df.to_csv(MAP_TSV, sep="\t", index=False)
    display(df.head())
    print(f"✔ mapping saved → {MAP_TSV}")
    print(df["status"].value_counts())

# ───── call the function from another cell whenever you want ────────────
run_mapping()



Local chains: 29,740
FASTA entries loaded: 36,408
PDB list from cache (131,487)
Building sequence→PDB index (one-off) …


Indexing:   0%|          | 0/131487 [00:00<?, ?it/s]

Index cached (131,482 hashes)


Mapping chains:   0%|          | 0/29740 [00:00<?, ?it/s]

Unnamed: 0,original_id,matched_id,status,notes
0,154L-A,153l_A,MATCH,sequence-identical chain
1,155C-A,155c_A,MATCH,sequence-identical chain
2,16PK-A,13pk_A,MATCH,sequence-identical chain
3,16VP-A,16vp_A,MATCH,sequence-identical chain
4,1914-A,1914_A,MATCH,sequence-identical chain


✔ mapping saved → C:\Users\rfrjo\Documents\Codebases\PFP_Testing\data\protein_data_pdb_mapping_updated.tsv
status
MATCH       29558
NO_MATCH      182
Name: count, dtype: int64


In [13]:
#!/usr/bin/env python3
"""
Simple script to fetch and display the amino acid sequence for a specific PDB ID
from the OpenFold S3 bucket.
"""

import boto3
import gzip
from botocore import UNSIGNED
from botocore.config import Config

# Constants
BUCKET = "openfold"
REGION = "us-east-1"
PDB_ID = "5y8q_A"

# MSA files to try in order
MSA_FILES = ["bfd_uniclust_hits.a3m", "mgnify_hits.a3m", "uniref90_hits.a3m"]

def get_pdb_sequence(pdb_id):
    """Fetch the amino acid sequence for a PDB ID from OpenFold S3."""
    
    # Create S3 client with anonymous access
    s3 = boto3.client(
        "s3",
        region_name=REGION,
        config=Config(signature_version=UNSIGNED)
    )
    
    # Try each MSA file
    for msa_file in MSA_FILES:
        key = f"pdb/{pdb_id}/a3m/{msa_file}"
        
        try:
            print(f"Trying: {key}")
            
            # Get the full file (not just range)
            response = s3.get_object(Bucket=BUCKET, Key=key)
            content = response['Body'].read()
            
            # Decompress if gzipped
            if content[:2] == b'\x1f\x8b':  # gzip magic number
                content = gzip.decompress(content)
            
            # Parse the A3M file
            lines = content.decode('utf-8').split('\n')
            
            # Find the first sequence (query)
            in_first_seq = False
            sequence_lines = []
            
            for line in lines:
                line = line.strip()
                if line.startswith('>'):
                    if in_first_seq:
                        # We've hit the second sequence, stop
                        break
                    in_first_seq = True
                    print(f"Header: {line}")
                elif in_first_seq and line:
                    sequence_lines.append(line)
            
            if sequence_lines:
                # Join all sequence lines
                full_sequence = ''.join(sequence_lines)
                
                # Extract only uppercase letters (no gaps or lowercase)
                clean_sequence = ''.join(c for c in full_sequence if c.isupper() and c != '-')
                
                print(f"\nFound sequence in: {msa_file}")
                return clean_sequence
                
        except s3.exceptions.NoSuchKey:
            print(f"  File not found: {key}")
            continue
        except Exception as e:
            print(f"  Error: {e}")
            continue
    
    return None

# Main execution
if __name__ == "__main__":
    print(f"Fetching amino acid sequence for PDB ID: {PDB_ID}")
    print("=" * 60)
    
    sequence = get_pdb_sequence(PDB_ID)
    
    if sequence:
        print(f"\nAmino Acid Sequence ({len(sequence)} residues):")
        print(sequence)
        
        # Also print in formatted chunks for readability
        print("\nFormatted (60 residues per line):")
        for i in range(0, len(sequence), 60):
            print(f"{i+1:4d} {sequence[i:i+60]}")
    else:
        print(f"\nCould not retrieve sequence for {PDB_ID}")

Fetching amino acid sequence for PDB ID: 5y8q_A
Trying: pdb/5y8q_A/a3m/bfd_uniclust_hits.a3m
Header: >query

Found sequence in: bfd_uniclust_hits.a3m

Amino Acid Sequence (232 residues):
GSHMAHSKHGLKEEMTMKYHMEGCVNGHKFVITGEGIGYPFKGKQTINLCVIEGGPLPFSEDILSAGFXDRIFTEYPQDIVDYFKNSCPAGYTWGRSFLFEDGAVCICNVDITVSVKENCIYHKSIFNGVNFPADGPVMKKMTTNWEASCEKIMPVPKQGILKGDVSMYLLLKDGGRYRCQFDTVYKAKSVPSKMPEWHFIQHKLLREDRSDAKNQKWQLTEHAIAFPSALA

Formatted (60 residues per line):
   1 GSHMAHSKHGLKEEMTMKYHMEGCVNGHKFVITGEGIGYPFKGKQTINLCVIEGGPLPFS
  61 EDILSAGFXDRIFTEYPQDIVDYFKNSCPAGYTWGRSFLFEDGAVCICNVDITVSVKENC
 121 IYHKSIFNGVNFPADGPVMKKMTTNWEASCEKIMPVPKQGILKGDVSMYLLLKDGGRYRC
 181 QFDTVYKAKSVPSKMPEWHFIQHKLLREDRSDAKNQKWQLTEHAIAFPSALA


In [42]:
#!/usr/bin/env python3
"""
Simple script to fetch and display the amino acid sequence for a specific PDB ID
from all available MSA files in the OpenFold S3 bucket.
"""

import boto3
import gzip
from botocore import UNSIGNED
from botocore.config import Config

# Constants
BUCKET = "openfold"
REGION = "us-east-1"
PDB_ID = "4ror_A"

# MSA files to try
MSA_FILES = ["bfd_uniclust_hits.a3m", "mgnify_hits.a3m", "uniref90_hits.a3m"]

def get_pdb_sequence_from_file(s3, pdb_id, msa_file):
    """Fetch the amino acid sequence for a PDB ID from a specific MSA file."""
    
    key = f"pdb/{pdb_id}/a3m/{msa_file}"
    
    try:
        print(f"Trying: {key}")
        
        # Get the full file (not just range)
        response = s3.get_object(Bucket=BUCKET, Key=key)
        content = response['Body'].read()
        
        # Decompress if gzipped
        if content[:2] == b'\x1f\x8b':  # gzip magic number
            content = gzip.decompress(content)
        
        # Parse the A3M file
        lines = content.decode('utf-8').split('\n')
        
        # Find the first sequence (query)
        in_first_seq = False
        sequence_lines = []
        header = ""
        
        for line in lines:
            line = line.strip()
            if line.startswith('>'):
                if in_first_seq:
                    # We've hit the second sequence, stop
                    break
                in_first_seq = True
                header = line
                print(f"  Header: {line}")
            elif in_first_seq and line:
                sequence_lines.append(line)
        
        if sequence_lines:
            # Join all sequence lines
            full_sequence = ''.join(sequence_lines)
            
            # Extract only uppercase letters (no gaps or lowercase)
            clean_sequence = ''.join(c for c in full_sequence if c.isupper() and c != '-')
            
            print(f"  ✓ Found sequence ({len(clean_sequence)} residues)")
            return clean_sequence, header
            
    except s3.exceptions.NoSuchKey:
        print(f"  ✗ File not found: {key}")
        return None, None
    except Exception as e:
        print(f"  ✗ Error: {e}")
        return None, None
    
    return None, None

def get_all_pdb_sequences(pdb_id):
    """Fetch amino acid sequences from all available MSA files for a PDB ID."""
    
    # Create S3 client with anonymous access
    s3 = boto3.client(
        "s3",
        region_name=REGION,
        config=Config(signature_version=UNSIGNED)
    )
    
    results = {}
    
    # Try each MSA file
    for msa_file in MSA_FILES:
        sequence, header = get_pdb_sequence_from_file(s3, pdb_id, msa_file)
        if sequence:
            results[msa_file] = {
                'sequence': sequence,
                'header': header
            }
        print()  # blank line for readability
    
    return results

# Main execution
if __name__ == "__main__":
    print(f"Fetching amino acid sequences for PDB ID: {PDB_ID}")
    print("=" * 60)
    
    sequences = get_all_pdb_sequences(PDB_ID)
    
    if sequences:
        print("SUMMARY:")
        print("=" * 60)
        
        for msa_file, data in sequences.items():
            sequence = data['sequence']
            header = data['header']
            
            print(f"\n📁 {msa_file}")
            print(f"Header: {header}")
            print(f"Sequence ({len(sequence)} residues):")
            print(sequence)
            
            # Also print in formatted chunks for readability
            print("Formatted (60 residues per line):")
            for i in range(0, len(sequence), 60):
                print(f"{i+1:4d} {sequence[i:i+60]}")
            print("-" * 60)
        
        # Check if all sequences are identical
        all_sequences = [data['sequence'] for data in sequences.values()]
        if len(set(all_sequences)) == 1:
            print(f"\n✓ All {len(sequences)} sequences are identical!")
        else:
            print(f"\n⚠ Found {len(set(all_sequences))} different sequences across {len(sequences)} files")
            
    else:
        print(f"\nCould not retrieve any sequences for {PDB_ID}")

Fetching amino acid sequences for PDB ID: 4ror_A
Trying: pdb/4ror_A/a3m/bfd_uniclust_hits.a3m
  Header: >query
  ✓ Found sequence (320 residues)

Trying: pdb/4ror_A/a3m/mgnify_hits.a3m
  Header: >query
  ✓ Found sequence (320 residues)

Trying: pdb/4ror_A/a3m/uniref90_hits.a3m
  Header: >query
  ✓ Found sequence (320 residues)

SUMMARY:

📁 bfd_uniclust_hits.a3m
Header: >query
Sequence (320 residues):
MARSKIALIGAGQIGGTLAHLAGLKELGDVVLFDIVDGVPQGKALDIAESAPVDGFDAKYSGASDYSAIAGADVVIVTAGVPRKPGMSRDDLIGINLKVMEAVGAGIKEHAPDAFVICITNPLDAMVWALQKFSGLPTNKVVGMAGVLDSARFRHFLAEEFGVSVEDVTAFVLGGHGDDMVPLTRYSTVAGVPLTDLVKLGWTTQEKLDAMVERTRKGGGEIVNLLKTGSAFYAPAASAIAMAESYLRDKKRVLPCAAYLDGQYGIDGLYVGVPVVIGENGVERVLEVTFNDDEKAMFEKSVNSVKGLIEACKSVNDKLA
Formatted (60 residues per line):
   1 MARSKIALIGAGQIGGTLAHLAGLKELGDVVLFDIVDGVPQGKALDIAESAPVDGFDAKY
  61 SGASDYSAIAGADVVIVTAGVPRKPGMSRDDLIGINLKVMEAVGAGIKEHAPDAFVICIT
 121 NPLDAMVWALQKFSGLPTNKVVGMAGVLDSARFRHFLAEEFGVSVEDVTAFVLGGHGDDM
 181 VPLTRYSTVAGVPLTDLVKLGWTTQEKLDAMVERTRKG

In [40]:
# ─── Enhanced OpenFold Canonicaliser ─────────────────────────────────────────
# Make sure you've restarted the kernel after running:
# pip install "openfold @ git+https://github.com/aqlaboratory/openfold"

# 1️⃣ EDIT THIS LINE ONLY -------------------------------------------------
seq = "QTSCDQWATFTGNGYTVSNNLWGASAGSGFGCVTAVSLSGGASWHADWQWSGGQNNVKSYQNSQIAIPQKRTVNSISSMPTTASWSYSGSNIRANVAYDLFTAANPNHVTYSGDYELMIWLGKYGDIGPIGSSQGTVNVGGQSWTLYYGYNGAMQVYSFVAQTNTTNYSGDVKNFFNYLRDNKGYNAAGQYVLSYQFGTECFTGSGTLNVASWTASIN"
# ------------------------------------------------------------------------

from openfold.np import residue_constants as rc
import re

# Get authentic OpenFold mappings
RESTYPE_3TO1 = rc.restype_3to1  # OpenFold's official 3-to-1 mapping
CANON = set("ARNDCQEGHILKMFPSTWYV")  # 20 standard amino acids

# Enhanced mapping for common non-canonical variants
ADDITIONAL_3TO1 = {
    # Common protonation states and variants
    'HID': 'H', 'HIE': 'H', 'HIP': 'H',  # Histidine variants
    'ASH': 'D', 'GLH': 'E',              # Protonated Asp/Glu
    'CYX': 'C', 'CYM': 'C',              # Cysteine variants
    'LYN': 'K',                          # Neutral lysine
    # Selenocysteine and pyrrolysine (21st and 22nd amino acids)
    'SEC': 'U', 'PYL': 'O',
    # Terminal modifications
    'ACE': 'X', 'NH2': 'X',              # Acetyl/amide caps
}

def to_openfold(s: str) -> str:
    """
    Convert 1- or 3-letter sequence to OpenFold MSA-query compatible form.
    
    Handles:
    - Standard 20 amino acids
    - Common non-canonical variants (HID, HIE, etc.)
    - Selenocysteine (U) and Pyrrolysine (O)
    - Terminal modifications
    - Mixed case input
    
    Args:
        s: Input sequence (1-letter or 3-letter format)
        
    Returns:
        Canonicalized 1-letter sequence with unknown residues as 'X'
    """
    s = s.strip().upper()
    
    # Enhanced heuristic for format detection
    looks_three = (
        bool(re.search(r"\s", s)) or  # Contains whitespace
        (len(s) % 3 == 0 and s.isalpha() and len(s) > 3) or  # Multiple of 3, all letters
        bool(re.search(r"[A-Z]{3,}", s))  # Contains 3+ consecutive letters
    )
    
    if looks_three:
        # Enhanced 3-letter parsing
        # Handle various separators and formats
        s_clean = re.sub(r'[^A-Z]', ' ', s)  # Replace non-letters with spaces
        triplets = re.findall(r'[A-Z]{3}', s_clean)
        
        result = []
        for triplet in triplets:
            # Try OpenFold mapping first
            if triplet in RESTYPE_3TO1:
                result.append(RESTYPE_3TO1[triplet])
            # Try additional mappings
            elif triplet in ADDITIONAL_3TO1:
                result.append(ADDITIONAL_3TO1[triplet])
            # Unknown -> X
            else:
                result.append('X')
        
        return "".join(result)
    
    else:
        # 1-letter input processing
        # Keep only letters, convert non-canonical to X
        result = []
        for aa in s:
            if aa.isalpha():
                if aa in CANON:
                    result.append(aa)
                elif aa in 'UO':  # Selenocysteine, Pyrrolysine
                    result.append(aa)
                else:
                    result.append('X')
        
        return "".join(result)

# Test the function
print("Original :", seq)
print("Canonical:", to_openfold(seq))

# Additional test cases for validation
test_cases = [
    "ARNDCQEGHILKMFPSTWYV",  # All 20 standard
    "HIDHIEGLY",  # Mixed case with histidine variants (if 3-letter)``
    "HID HIE GLY",  # 3-letter with spaces
    "METHYLYSARGTRP",  # 3-letter without spaces
    "ACDEFGX",  # 1-letter with unknown
    "MKUO",  # Including selenocysteine and pyrrolysine
]

print("\n--- Validation Tests ---")
for i, test in enumerate(test_cases, 1):
    result = to_openfold(test)
    print(f"Test {i}: '{test}' -> '{result}'")
    
# Check if OpenFold mapping is available
print(f"\n--- Mapping Info ---")
print(f"OpenFold RESTYPE_3TO1 entries: {len(RESTYPE_3TO1) if hasattr(rc, 'restype_3to1') else 'Not available'}")
print(f"Standard amino acids: {sorted(CANON)}")
print(f"Additional mappings: {len(ADDITIONAL_3TO1)}")

ModuleNotFoundError: No module named 'tree'

In [19]:
#!/usr/bin/env python3
"""
OpenFold Database-Driven Sequence Transformer

This script downloads the actual databases that OpenFold uses and attempts to 
reproduce the Q→E transformation by accessing the same data sources.

The key insight: OpenFold likely uses PDB SeqRes sequences (official sequences)
rather than sequences extracted from structure coordinates.
"""

import os
import sys
import gzip
import urllib.request
import subprocess
from pathlib import Path
import re
from typing import Dict, Optional, Tuple
import tempfile
import shutil


class OpenFoldDatabaseManager:
    """Manages downloading and parsing OpenFold databases."""
    
    def __init__(self, data_dir: str = "openfold_data"):
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(exist_ok=True)
        
        # Database URLs (from OpenFold scripts)
        self.databases = {
            'pdb_seqres': {
                'url': 'ftp://ftp.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt',
                'local_path': self.data_dir / 'pdb_seqres.txt',
                'description': 'PDB SeqRes - Official sequences from PDB entries'
            },
            'pdb_mmcif_sample': {
                'url': 'https://files.rcsb.org/download/1olq.cif',
                'local_path': self.data_dir / '1olq.cif', 
                'description': 'Sample PDB structure file for 1olq'
            }
        }
    
    def download_database(self, db_name: str, force_redownload: bool = False) -> bool:
        """Download a specific database if not already present."""
        if db_name not in self.databases:
            print(f"Unknown database: {db_name}")
            return False
            
        db_info = self.databases[db_name]
        local_path = db_info['local_path']
        
        if local_path.exists() and not force_redownload:
            print(f"✓ {db_info['description']} already exists at {local_path}")
            return True
            
        print(f"📥 Downloading {db_info['description']}...")
        print(f"   URL: {db_info['url']}")
        print(f"   Destination: {local_path}")
        
        try:
            # Create directory if needed
            local_path.parent.mkdir(parents=True, exist_ok=True)
            
            # Download with progress
            def progress_hook(block_num, block_size, total_size):
                if total_size > 0:
                    percent = min(100, (block_num * block_size * 100) // total_size)
                    print(f"   Progress: {percent}%", end='\r')
            
            urllib.request.urlretrieve(db_info['url'], local_path, progress_hook)
            print(f"\n✅ Successfully downloaded {db_name}")
            return True
            
        except Exception as e:
            print(f"❌ Failed to download {db_name}: {e}")
            return False
    
    def download_all_databases(self, force_redownload: bool = False) -> bool:
        """Download all necessary databases."""
        print("🗄️  OpenFold Database Download Manager")
        print("=" * 50)
        
        success = True
        for db_name in self.databases.keys():
            if not self.download_database(db_name, force_redownload):
                success = False
        
        print("\n" + "=" * 50)
        if success:
            print("✅ All databases downloaded successfully!")
        else:
            print("❌ Some database downloads failed!")
        
        return success


class PDBSeqResParser:
    """Parses PDB SeqRes database to extract official sequences."""
    
    def __init__(self, seqres_file: Path):
        self.seqres_file = seqres_file
        self.sequences = {}
        self._parse_seqres()
    
    def _parse_seqres(self):
        """Parse the PDB SeqRes file format."""
        print(f"📖 Parsing PDB SeqRes database: {self.seqres_file}")
        
        if not self.seqres_file.exists():
            print(f"❌ SeqRes file not found: {self.seqres_file}")
            return
        
        current_entry = None
        current_sequence = ""
        
        try:
            with open(self.seqres_file, 'r') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    
                    if line.startswith('>'):
                        # Save previous entry
                        if current_entry and current_sequence:
                            self.sequences[current_entry] = current_sequence
                        
                        # Parse header: >PDB_ID:CHAIN_ID or >PDB_ID_CHAIN_ID
                        header = line[1:]  # Remove '>'
                        
                        # Handle different formats
                        if ':' in header:
                            current_entry = header.replace(':', '_')
                        elif '_' in header:
                            current_entry = header
                        else:
                            # Some entries might just be PDB_ID
                            current_entry = header
                        
                        current_sequence = ""
                        
                    elif line and not line.startswith('#'):
                        # Sequence line
                        current_sequence += line
                    
                    # Progress indicator for large files
                    if line_num % 50000 == 0:
                        print(f"   Processed {line_num:,} lines...")
                
                # Don't forget the last entry
                if current_entry and current_sequence:
                    self.sequences[current_entry] = current_sequence
        
        except Exception as e:
            print(f"❌ Error parsing SeqRes file: {e}")
            return
        
        print(f"✅ Parsed {len(self.sequences):,} sequences from PDB SeqRes")
        
        # Show some examples
        if self.sequences:
            print("\n📋 Sample entries found:")
            for i, (entry_id, seq) in enumerate(list(self.sequences.items())[:5]):
                print(f"   {entry_id}: {seq[:50]}{'...' if len(seq) > 50 else ''}")
    
    def get_sequence(self, pdb_id: str, chain_id: str) -> Optional[str]:
        """Get sequence for a specific PDB ID and chain."""
        # Try different formats
        possible_keys = [
            f"{pdb_id}_{chain_id}",
            f"{pdb_id.upper()}_{chain_id.upper()}",
            f"{pdb_id.lower()}_{chain_id.lower()}",
            f"{pdb_id}:{chain_id}",
            f"{pdb_id.upper()}:{chain_id.upper()}",
            f"{pdb_id.lower()}:{chain_id.lower()}",
        ]
        
        for key in possible_keys:
            if key in self.sequences:
                return self.sequences[key]
        
        return None
    
    def search_similar_entries(self, query: str) -> Dict[str, str]:
        """Search for entries that match the query pattern."""
        query_lower = query.lower()
        matches = {}
        
        for entry_id, sequence in self.sequences.items():
            if query_lower in entry_id.lower():
                matches[entry_id] = sequence
        
        return matches


class OpenFoldSequenceTransformer:
    """Main class that attempts to reproduce OpenFold transformations."""
    
    def __init__(self, data_dir: str = "openfold_data"):
        self.db_manager = OpenFoldDatabaseManager(data_dir)
        self.seqres_parser = None
    
    def setup_databases(self, force_redownload: bool = False) -> bool:
        """Download and set up all necessary databases."""
        print("🚀 Setting up OpenFold databases...")
        
        # Download databases
        if not self.db_manager.download_all_databases(force_redownload):
            return False
        
        # Initialize parsers
        seqres_file = self.db_manager.databases['pdb_seqres']['local_path']
        self.seqres_parser = PDBSeqResParser(seqres_file)
        
        return True
    
    def analyze_1olq_transformation(self) -> Dict:
        """Analyze the specific 1olq_A transformation."""
        print("\n🔬 Analyzing 1olq_A Transformation")
        print("=" * 50)
        
        # Input data
        structure_sequence = "QTSCDQWATFTGNGYTVSNNLWGASAGSGFGCVTAVSLSGGASWHADWQWSGGQNNVKSYQNSQIAIPQKRTVNSISSMPTTASWSYSGSNIRANVAYDLFTAANPNHVTYSGDYELMIWLGKYGDIGPIGSSQGTVNVGGQSWTLYYGYNGAMQVYSFVAQTNTTNYSGDVKNFFNYLRDNKGYNAAGQYVLSYQFGTECFTGSGTLNVASWTASIN"
        expected_msa_sequence = "ETSCDQWATFTGNGYTVSNNLWGASAGSGFGCVTAVSLSGGASWHADWQWSGGQNNVKSYQNSQIAIPQKRTVNSISSMPTTASWSYSGSNIRANVAYDLFTAANPNHVTYSGDYELMIWLGKYGDIGPIGSSQGTVNVGGQSWTLYYGYNGAMQVYSFVAQTNTTNYSGDVKNFFNYLRDNKGYNAAGQYVLSYQFGTECFTGSGTLNVASWTASIN---"
        
        result = {
            'structure_sequence': structure_sequence,
            'expected_msa_sequence': expected_msa_sequence,
            'seqres_sequence': None,
            'transformation_found': False,
            'source_of_transformation': 'unknown',
            'differences': []
        }
        
        # Try to find the SeqRes sequence
        if not self.seqres_parser:
            print("❌ SeqRes parser not initialized")
            return result
        
        print("🔍 Searching for 1olq_A in PDB SeqRes database...")
        
        # Direct lookup
        seqres_sequence = self.seqres_parser.get_sequence("1olq", "A")
        
        if seqres_sequence:
            print(f"✅ Found 1olq_A in SeqRes database!")
            print(f"   Length: {len(seqres_sequence)}")
            print(f"   Sequence: {seqres_sequence[:60]}{'...' if len(seqres_sequence) > 60 else ''}")
            
            result['seqres_sequence'] = seqres_sequence
            
            # Compare sequences
            self._compare_sequences(structure_sequence, seqres_sequence, "Structure vs SeqRes", result)
            
            # Check if SeqRes matches the expected MSA sequence (minus trailing ---)
            expected_clean = expected_msa_sequence.rstrip('-')
            if seqres_sequence == expected_clean:
                result['transformation_found'] = True
                result['source_of_transformation'] = 'PDB SeqRes database'
                print("🎉 TRANSFORMATION EXPLAINED: MSA sequence comes from PDB SeqRes!")
            
        else:
            print("❌ 1olq_A not found in direct lookup")
            
            # Search for similar entries
            print("🔍 Searching for similar entries...")
            similar = self.seqres_parser.search_similar_entries("1olq")
            
            if similar:
                print(f"📋 Found {len(similar)} similar entries:")
                for entry_id, seq in list(similar.items())[:10]:  # Show first 10
                    print(f"   {entry_id}: {seq[:50]}{'...' if len(seq) > 50 else ''}")
                
                # Try the first similar entry
                first_entry = list(similar.items())[0]
                result['seqres_sequence'] = first_entry[1]
                self._compare_sequences(structure_sequence, first_entry[1], f"Structure vs {first_entry[0]}", result)
            else:
                print("❌ No similar entries found")
        
        return result
    
    def _compare_sequences(self, seq1: str, seq2: str, comparison_name: str, result: Dict):
        """Compare two sequences and record differences."""
        print(f"\n📊 {comparison_name}:")
        print(f"   Seq1 length: {len(seq1)}")
        print(f"   Seq2 length: {len(seq2)}")
        
        differences = []
        min_len = min(len(seq1), len(seq2))
        
        for i in range(min_len):
            if seq1[i] != seq2[i]:
                differences.append({
                    'position': i + 1,
                    'seq1_aa': seq1[i],
                    'seq2_aa': seq2[i]
                })
        
        if differences:
            print(f"   Differences found: {len(differences)}")
            for diff in differences[:10]:  # Show first 10
                print(f"     Position {diff['position']}: {diff['seq1_aa']} → {diff['seq2_aa']}")
            if len(differences) > 10:
                print(f"     ... and {len(differences) - 10} more")
        else:
            print("   ✅ Sequences are identical!")
        
        result['differences'].extend(differences)
    
    def transform_sequence(self, pdb_id: str, chain_id: str, input_sequence: str) -> Tuple[str, str]:
        """
        Transform a sequence using the database-driven approach.
        
        Returns:
            (transformed_sequence, source_explanation)
        """
        if not self.seqres_parser:
            return input_sequence + "---", "No databases loaded"
        
        # Try to get the SeqRes sequence
        seqres_sequence = self.seqres_parser.get_sequence(pdb_id, chain_id)
        
        if seqres_sequence:
            # Use SeqRes sequence + add trailing gaps
            return seqres_sequence + "---", f"PDB SeqRes database for {pdb_id}_{chain_id}"
        else:
            # Fall back to input sequence + add trailing gaps
            return input_sequence + "---", "Fallback: input sequence (SeqRes not found)"


def main():
    """Main function to test the transformation."""
    print("🧬 OpenFold Database-Driven Sequence Transformer")
    print("=" * 60)
    
    # Initialize transformer
    transformer = OpenFoldSequenceTransformer()
    
    # Setup databases
    print("📥 Setting up databases...")
    if not transformer.setup_databases():
        print("❌ Failed to setup databases. Exiting.")
        return 1
    
    # Analyze the 1olq transformation
    result = transformer.analyze_1olq_transformation()
    
    # Print final results
    print("\n" + "=" * 60)
    print("🎯 FINAL ANALYSIS RESULTS")
    print("=" * 60)
    
    print(f"Structure sequence: {result['structure_sequence'][:50]}...")
    print(f"Expected MSA seq:   {result['expected_msa_sequence'][:50]}...")
    
    if result['seqres_sequence']:
        print(f"SeqRes sequence:    {result['seqres_sequence'][:50]}...")
    
    print(f"\nTransformation found: {result['transformation_found']}")
    print(f"Source: {result['source_of_transformation']}")
    
    if result['differences']:
        print(f"Key differences: {len(result['differences'])}")
        for diff in result['differences'][:5]:
            print(f"  Position {diff['position']}: {diff['seq1_aa']} → {diff['seq2_aa']}")
    
    # Test the transformer
    print("\n🧪 Testing transformer on 1olq_A:")
    transformed, source = transformer.transform_sequence(
        "1olq", "A", result['structure_sequence']
    )
    
    print(f"Input:       {result['structure_sequence'][:50]}...")
    print(f"Transformed: {transformed[:50]}...")
    print(f"Expected:    {result['expected_msa_sequence'][:50]}...")
    print(f"Source:      {source}")
    
    match = transformed == result['expected_msa_sequence']
    print(f"Perfect match: {'✅ YES' if match else '❌ NO'}")
    
    if match:
        print("\n🎉 SUCCESS! The transformation is explained by PDB SeqRes database!")
    else:
        print("\n🤔 The transformation is more complex than just PDB SeqRes lookup.")
        print("   Additional factors may be involved (e.g., other databases, preprocessing steps)")
    
    return 0


if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)

🧬 OpenFold Database-Driven Sequence Transformer
📥 Setting up databases...
🚀 Setting up OpenFold databases...
🗄️  OpenFold Database Download Manager
📥 Downloading PDB SeqRes - Official sequences from PDB entries...
   URL: ftp://ftp.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt
   Destination: openfold_data\pdb_seqres.txt
❌ Failed to download pdb_seqres: <urlopen error [Errno 11001] getaddrinfo failed>
📥 Downloading Sample PDB structure file for 1olq...
   URL: https://files.rcsb.org/download/1olq.cif
   Destination: openfold_data\1olq.cif

✅ Successfully downloaded pdb_mmcif_sample

❌ Some database downloads failed!
❌ Failed to setup databases. Exiting.


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
