In [39]:
# ResX: Fast PDB-UniProt Residue Mapping
# Initial prototype and testing

import os
import sqlite3
import time
from contextlib import closing
from typing import Optional

import polars as pl

# Configuration
# Location of the SQLite mapping database. Set the RESX_DB_PATH environment
# variable to use a different copy; otherwise the original developer-machine
# path is used unchanged.
DATABASE_PATH = os.environ.get(
    "RESX_DB_PATH", "C:/Users/Miche/Desktop/TopUniPDBMapper/topunipdbmapper.db"
)

def get_db(path: Optional[str] = None) -> sqlite3.Connection:
    """Open a new SQLite connection to the mapping database.

    Args:
        path: Database file to open. When omitted, the module-level
            ``DATABASE_PATH`` is used; it is read at call time, so
            reassigning it (e.g. from a notebook) takes effect.

    Returns:
        A fresh ``sqlite3.Connection``. The caller owns it and is responsible
        for closing it: ``with conn:`` only commits/rolls back a transaction
        and does NOT close the connection (use ``contextlib.closing``).
    """
    return sqlite3.connect(DATABASE_PATH if path is None else path)

# Core mapping functions
def get_mapping_window(
    pdb_id: str, 
    chain_id: str, 
    residue_number: int, 
    window: int = 0,
    insertion_code: Optional[str] = None
) -> pl.DataFrame:
    """Get PDB→UniProt residue mappings, optionally with surrounding residues.

    Args:
        pdb_id: PDB entry ID; lower-cased before matching.
        chain_id: PDB chain identifier, matched exactly.
        residue_number: PDB residue number at the centre of the query.
        window: If > 0, return every residue whose number lies in
            ``residue_number - window .. residue_number + window`` (inclusive,
            all insertion codes). Otherwise return only ``residue_number``.
        insertion_code: Insertion code of the requested residue. Applied only
            to the single-residue lookup (``window <= 0``); when omitted there,
            every insertion variant of ``residue_number`` is returned.

    Returns:
        A polars DataFrame with columns pdb_residue_number,
        pdb_residue_insertion_code, pdb_residue_name, uniprot_accession_id,
        uniprot_residue_number and uniprot_residue_name, ordered by residue
        number then insertion code.
    """
    query = """
    SELECT 
        pdb_residue_number,
        pdb_residue_insertion_code,
        pdb_residue_name,
        uniprot_accession_id,
        uniprot_residue_number,
        uniprot_residue_name
    FROM residues 
    WHERE pdb_accession_id = ? 
    AND pdb_chain_id = ? 
    """
    params: list = [pdb_id.lower(), chain_id]

    if window > 0:
        # The window is purely numeric. Filtering it by one insertion code
        # would discard almost every neighbouring residue, so the code is
        # deliberately not applied here.
        query += "AND pdb_residue_number BETWEEN ? AND ? "
        params.extend([residue_number - window, residue_number + window])
    else:
        query += "AND pdb_residue_number = ? "
        params.append(residue_number)
        if insertion_code:
            query += "AND pdb_residue_insertion_code = ? "
            params.append(insertion_code)

    query += "ORDER BY pdb_residue_number, pdb_residue_insertion_code"

    # A sqlite3 connection's own context manager only handles transactions;
    # closing() guarantees the connection is released after the read.
    with closing(get_db()) as conn:
        return pl.read_database(query, conn, execute_options={"parameters": params})

def get_mapping_range(
    pdb_id: str, 
    chain_id: str, 
    start_res: int, 
    end_res: int,
    insertion_code_start: Optional[str] = None,
    insertion_code_end: Optional[str] = None
) -> pl.DataFrame:
    """Get residue mappings for an inclusive range of PDB residues.

    Args:
        pdb_id: PDB entry ID; lower-cased before matching.
        chain_id: PDB chain identifier, matched exactly.
        start_res: First residue number of the range (inclusive).
        end_res: Last residue number of the range (inclusive).
        insertion_code_start: If given, residues numbered ``start_res`` are
            kept only when their insertion code is >= this value.
        insertion_code_end: If given, residues numbered ``end_res`` are kept
            only when their insertion code is <= this value.

    Each insertion-code bound is applied independently, so either may be
    supplied alone, and both may target the same residue number.

    Returns:
        A polars DataFrame with the same columns as ``get_mapping_window``,
        ordered by residue number then insertion code.
    """
    query = """
    SELECT 
        pdb_residue_number,
        pdb_residue_insertion_code,
        pdb_residue_name,
        uniprot_accession_id,
        uniprot_residue_number,
        uniprot_residue_name
    FROM residues 
    WHERE pdb_accession_id = ? 
    AND pdb_chain_id = ? 
    AND pdb_residue_number BETWEEN ? AND ?
    """
    params: list = [pdb_id.lower(), chain_id, start_res, end_res]

    # Assumes insertion codes order lexically after the bare residue number
    # (e.g. 52 < 52A < 52B). A missing code is normalised to '' so that it
    # sorts first whether it is stored as NULL or as an empty string.
    if insertion_code_start:
        query += """
    AND (pdb_residue_number > ?
         OR COALESCE(pdb_residue_insertion_code, '') >= ?)
    """
        params.extend([start_res, insertion_code_start])

    if insertion_code_end:
        query += """
    AND (pdb_residue_number < ?
         OR COALESCE(pdb_residue_insertion_code, '') <= ?)
    """
        params.extend([end_res, insertion_code_end])

    query += "ORDER BY pdb_residue_number, pdb_residue_insertion_code"

    # closing() actually releases the connection; sqlite3's own context
    # manager only commits/rolls back.
    with closing(get_db()) as conn:
        return pl.read_database(query, conn, execute_options={"parameters": params})

def get_mapping_from_uniprot(
    uniprot_id: str,
    residue_number: int,
    window: int = 0
) -> pl.DataFrame:
    """Get all PDB residues mapped to a UniProt residue (optionally ± window).

    Args:
        uniprot_id: UniProt accession, compared as given (no case
            normalisation is applied).
        residue_number: UniProt residue number to look up.
        window: If > 0, include UniProt residues in
            ``residue_number - window .. residue_number + window`` (inclusive);
            otherwise match ``residue_number`` only.

    Returns:
        A polars DataFrame with columns pdb_accession_id, pdb_chain_id,
        pdb_residue_number, pdb_residue_insertion_code, pdb_residue_name,
        uniprot_residue_name and uniprot_residue_number. The last column
        identifies which UniProt position each row maps to, which is needed
        to interpret window queries. Rows are ordered by entry, chain, residue
        number and insertion code.
    """
    query = """
    SELECT 
        pdb_accession_id,
        pdb_chain_id,
        pdb_residue_number,
        pdb_residue_insertion_code,
        pdb_residue_name,
        uniprot_residue_name,
        uniprot_residue_number
    FROM residues 
    WHERE uniprot_accession_id = ? 
    """
    params: list = [uniprot_id]

    if window > 0:
        query += "AND uniprot_residue_number BETWEEN ? AND ? "
        params.extend([residue_number - window, residue_number + window])
    else:
        query += "AND uniprot_residue_number = ? "
        params.append(residue_number)

    # Insertion code is the final tie-breaker so that row order is
    # deterministic, matching the PDB-side query functions.
    query += (
        "ORDER BY pdb_accession_id, pdb_chain_id, "
        "pdb_residue_number, pdb_residue_insertion_code"
    )

    # closing() releases the connection; sqlite3's context manager does not.
    with closing(get_db()) as conn:
        return pl.read_database(query, conn, execute_options={"parameters": params})

# Database inspection: list the tables and preview a few residue rows.
# closing() ensures the connection is released afterwards; a bare
# `with sqlite3.Connection` would only manage the transaction.
print("Database tables:")
with closing(get_db()) as conn:
    tables = pl.read_database("SELECT name FROM sqlite_master WHERE type='table';", conn)
    print(tables)

    print("\nSample data (first 5 rows):")
    sample = pl.read_database("SELECT * FROM residues LIMIT 5;", conn)
    print(sample)

# Test cases
def run_tests():
    """Print the output of one smoke-test query per mapping function.

    Uses PDB entry 101m chain A and UniProt accession P02185. Each result
    DataFrame is printed under a numbered heading; nothing is asserted.
    """
    cases = (
        (
            "\n1. Testing PDB→UniProt window query (residue 5 with window=2):",
            lambda: get_mapping_window("101m", "A", 5, window=2),
        ),
        (
            "\n2. Testing PDB→UniProt range query (residues 1-5):",
            lambda: get_mapping_range("101m", "A", 1, 5),
        ),
        (
            "\n3. Testing UniProt→PDB mapping (P02185, residue 2):",
            lambda: get_mapping_from_uniprot("P02185", 2),
        ),
        (
            "\n4. Testing UniProt→PDB mapping with window (P02185, residue 2, window=1):",
            lambda: get_mapping_from_uniprot("P02185", 2, window=1),
        ),
    )
    for heading, query in cases:
        # Heading first, then the query runs and its result is printed.
        print(heading)
        print(query())

# Performance testing
def _mean_call_ms(func, n_iterations: int) -> float:
    """Return the mean wall-clock time of ``func()`` over ``n_iterations`` calls, in ms."""
    # perf_counter() is monotonic and high-resolution; time.time() can jump
    # with system clock changes and is coarse on some platforms.
    start = time.perf_counter()
    for _ in range(n_iterations):
        func()
    return (time.perf_counter() - start) * 1000 / n_iterations


def test_performance(n_iterations: int = 100):
    """Benchmark each query type and print its mean latency per call.

    Args:
        n_iterations: Number of times each query is executed; must be > 0.

    Raises:
        ValueError: If ``n_iterations`` is not positive. A zero count would
            otherwise cause a division by zero when averaging.
    """
    if n_iterations <= 0:
        raise ValueError(f"n_iterations must be positive, got {n_iterations}")

    print("\nPerformance testing:")

    benchmarks = [
        ("Single residue query", lambda: get_mapping_window("101m", "A", 5)),
        ("Window query (±2)", lambda: get_mapping_window("101m", "A", 5, window=2)),
        ("Range query (5 residues)", lambda: get_mapping_range("101m", "A", 1, 5)),
        ("UniProt query", lambda: get_mapping_from_uniprot("P02185", 2)),
    ]
    for label, query in benchmarks:
        mean_ms = _mean_call_ms(query, n_iterations)
        print(f"{label}: {mean_ms:.2f}ms")

if __name__ == "__main__":
    # Functional smoke tests first, then the timing benchmarks.
    run_tests()
    test_performance(n_iterations=100)

Database tables:
shape: (2, 1)
┌───────────────┐
│ name          │
│ ---           │
│ str           │
╞═══════════════╡
│ residues      │
│ processed_ids │
└───────────────┘

Sample data (first 5 rows):
shape: (5, 17)
┌───────────┬───────────┬──────────┬───────────┬───┬───────────┬───────────┬───────────┬───────────┐
│ entity_nu ┆ entity_ty ┆ entityId ┆ segId     ┆ … ┆ pdb_resid ┆ uniprot_a ┆ uniprot_r ┆ uniprot_r │
│ m         ┆ pe        ┆ ---      ┆ ---       ┆   ┆ ue_name   ┆ ccession_ ┆ esidue_nu ┆ esidue_na │
│ ---       ┆ ---       ┆ str      ┆ str       ┆   ┆ ---       ┆ id        ┆ mber      ┆ me        │
│ i64       ┆ str       ┆          ┆           ┆   ┆ str       ┆ ---       ┆ ---       ┆ ---       │
│           ┆           ┆          ┆           ┆   ┆           ┆ str       ┆ i64       ┆ str       │
╞═══════════╪═══════════╪══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡
│ 1         ┆ protein   ┆ A        ┆ 101m_A_1_ ┆ … ┆ MET       ┆ P02185   