In [1]:
import pandas as pd
import numpy as np 
import os 
import pathlib
import json

# Load Data

In [4]:
annot_df = pd.read_parquet("../data/annotation_table_with_uniprot.parquet")
annot_df.shape

In [24]:
annot_df['uniprot_ids_str'].nunique()

9430

In [25]:
annot_df['pdb_chain_key']

0         8grd_B
1         2grj_A
2         2grj_B
3         2grj_C
4         2grj_D
           ...  
114532    3lpp_B
114533    3lpp_D
114534    5lps_A
114535    5lps_A
114536    5lpv_A
Name: pdb_chain_key, Length: 114537, dtype: object

In [16]:
pdb_uniprot_mapping = pd.read_csv("../data/pdb_chain_uniprot.tsv", sep="\t", comment='#')
pdb_uniprot_mapping.head()

  pdb_uniprot_mapping = pd.read_csv("../data/pdb_chain_uniprot.tsv", sep="\t", comment='#')


Unnamed: 0,PDB,CHAIN,SP_PRIMARY,RES_BEG,RES_END,PDB_BEG,PDB_END,SP_BEG,SP_END
0,101m,A,P02185,1,154,0,153.0,1,154
1,102l,A,P00720,1,40,1,40.0,1,40
2,102l,A,P00720,42,165,41,,41,164
3,102m,A,P02185,1,154,0,153.0,1,154
4,103l,A,P00720,1,40,1,,1,40


# Build Bipartite Graph (UniProt - PDB mapping)
Using NetworkKit to create a bipartite graph where:
- One set of nodes: UniProt IDs (from annot_df)
- Other set of nodes: PDB IDs (with chain info)
- Edges: Exist between UniProt ID and PDB_chain if they map to each other

In [26]:
import networkit as nk

# Get unique UniProt-PDB_chain edges from annotation dataframe
edges_df = annot_df[['uniprot_ids_str', 'pdb_chain_key']].dropna().drop_duplicates()
edges_df.columns = ['uniprot', 'pdb_chain']

print(f"Number of unique edges: {len(edges_df)}")
print(f"Unique UniProt IDs: {edges_df['uniprot'].nunique()}")
print(f"Unique PDB chains: {edges_df['pdb_chain'].nunique()}")

Number of unique edges: 67487
Unique UniProt IDs: 9430
Unique PDB chains: 67487


In [27]:
# Create node ID mappings
# NetworkKit requires integer node IDs, so we create mappings
unique_uniprots = sorted(edges_df['uniprot'].unique())
unique_pdb_chains = sorted(edges_df['pdb_chain'].unique())

# UniProt IDs get IDs from 0 to n_uniprots-1
# PDB chains get IDs from n_uniprots to n_uniprots + n_pdb_chains - 1
uniprot_to_id = {uniprot: i for i, uniprot in enumerate(unique_uniprots)}
pdb_chain_to_id = {pdb_chain: i + len(unique_uniprots) for i, pdb_chain in enumerate(unique_pdb_chains)}

# Reverse mappings for lookup
id_to_uniprot = {v: k for k, v in uniprot_to_id.items()}
id_to_pdb_chain = {v: k for k, v in pdb_chain_to_id.items()}

# Combined mapping
id_to_node = {**id_to_uniprot, **id_to_pdb_chain}
node_to_id = {**uniprot_to_id, **pdb_chain_to_id}

print(f"UniProt node IDs: 0 to {len(unique_uniprots)-1}")
print(f"PDB chain node IDs: {len(unique_uniprots)} to {len(unique_uniprots) + len(unique_pdb_chains) - 1}")
print(f"Total nodes: {len(unique_uniprots) + len(unique_pdb_chains)}")

UniProt node IDs: 0 to 9429
PDB chain node IDs: 9430 to 76916
Total nodes: 76917


In [28]:
# Build the bipartite graph using NetworkKit
n_nodes = len(unique_uniprots) + len(unique_pdb_chains)
G = nk.Graph(n_nodes, weighted=False, directed=False)

# Add edges - vectorized approach for speed
for _, row in edges_df.iterrows():
    uniprot_id = uniprot_to_id[row['uniprot']]
    pdb_chain_id = pdb_chain_to_id[row['pdb_chain']]
    G.addEdge(uniprot_id, pdb_chain_id)

print(f"Graph created successfully!")
print(f"Number of nodes: {G.numberOfNodes()}")
print(f"Number of edges: {G.numberOfEdges()}")

Graph created successfully!
Number of nodes: 76917
Number of edges: 67487


In [29]:
# Store node type information for bipartite verification
node_types = {}
for node_id in range(len(unique_uniprots)):
    node_types[node_id] = 'uniprot'
for node_id in range(len(unique_uniprots), n_nodes):
    node_types[node_id] = 'pdb_chain'

# Verify bipartite structure (edges should only connect different node types)
is_bipartite = True
for u, v in G.iterEdges():
    if node_types[u] == node_types[v]:
        is_bipartite = False
        break

print(f"Graph is bipartite: {is_bipartite}")

# Basic graph statistics
print(f"\n--- Graph Statistics ---")
print(f"Number of UniProt nodes: {len(unique_uniprots)}")
print(f"Number of PDB chain nodes: {len(unique_pdb_chains)}")
print(f"Total edges: {G.numberOfEdges()}")
print(f"Average degree: {2 * G.numberOfEdges() / G.numberOfNodes():.2f}")

Graph is bipartite: True

--- Graph Statistics ---
Number of UniProt nodes: 9430
Number of PDB chain nodes: 67487
Total edges: 67487
Average degree: 1.75


In [30]:
# Helper functions to work with the graph
def get_pdb_chains_for_uniprot(uniprot_id):
    """Get all PDB chains mapped to a UniProt ID"""
    if uniprot_id not in uniprot_to_id:
        return []
    node_id = uniprot_to_id[uniprot_id]
    neighbors = list(G.iterNeighbors(node_id))
    return [id_to_pdb_chain[n] for n in neighbors]

def get_uniprots_for_pdb_chain(pdb_chain):
    """Get all UniProt IDs mapped to a PDB chain"""
    if pdb_chain not in pdb_chain_to_id:
        return []
    node_id = pdb_chain_to_id[pdb_chain]
    neighbors = list(G.iterNeighbors(node_id))
    return [id_to_uniprot[n] for n in neighbors]

# Example usage
example_uniprot = list(unique_uniprots)[0]
print(f"PDB chains mapped to {example_uniprot}: {get_pdb_chains_for_uniprot(example_uniprot)[:5]}...")

example_pdb = list(unique_pdb_chains)[0]
print(f"UniProt IDs mapped to {example_pdb}: {get_uniprots_for_pdb_chain(example_pdb)}")

PDB chains mapped to : ['6grg_B', '6grg_E', '5gr8_C', '5gr8_B', '3cjq_F']...
UniProt IDs mapped to 10gs_B: ['P09211']


In [32]:
# Validate mapping consistency with official PDB-UniProt TSV file
# Create comparable format: pdb_chain from TSV uses lowercase PDB + chain
pdb_uniprot_mapping['pdb_chain'] = pdb_uniprot_mapping['PDB'].str.lower() + '_' + pdb_uniprot_mapping['CHAIN'].astype(str)

# Get unique mappings from official TSV
official_mapping = pdb_uniprot_mapping[['pdb_chain', 'SP_PRIMARY']].drop_duplicates()
official_mapping.columns = ['pdb_chain', 'uniprot']

# Convert to sets for comparison
our_edges = set(zip(edges_df['pdb_chain'], edges_df['uniprot']))
official_edges = set(zip(official_mapping['pdb_chain'], official_mapping['uniprot']))

# Check consistency
consistent_edges = our_edges & official_edges
our_only = our_edges - official_edges
official_only_for_our_pdbs = {(p, u) for p, u in official_edges if p in edges_df['pdb_chain'].values} - our_edges

print(f"=== Mapping Consistency Check ===")
print(f"Total edges in our mapping (annot_df): {len(our_edges)}")
print(f"Total edges in official TSV: {len(official_edges)}")
print(f"\nConsistent edges (in both): {len(consistent_edges)}")
print(f"Edges only in our mapping: {len(our_only)}")
print(f"Official edges for our PDB chains not in our mapping: {len(official_only_for_our_pdbs)}")
print(f"\nConsistency rate: {len(consistent_edges) / len(our_edges) * 100:.2f}%")

=== Mapping Consistency Check ===
Total edges in our mapping (annot_df): 67487
Total edges in official TSV: 909287

Consistent edges (in both): 61154
Edges only in our mapping: 6333
Official edges for our PDB chains not in our mapping: 1506

Consistency rate: 90.62%


In [33]:
# Investigate the inconsistent edges
if len(our_only) > 0:
    print("Sample edges in our mapping but NOT in official TSV:")
    sample_our_only = list(our_only)[:5]
    for pdb_chain, uniprot in sample_our_only:
        # Check what UniProt IDs the official TSV has for this PDB chain
        official_uniprots = official_mapping[official_mapping['pdb_chain'] == pdb_chain]['uniprot'].tolist()
        print(f"  {pdb_chain} -> {uniprot} (Official has: {official_uniprots if official_uniprots else 'NOT FOUND'})")

print("\n" + "="*50)

if len(official_only_for_our_pdbs) > 0:
    print("\nSample edges in official TSV but NOT in our mapping (for PDBs we have):")
    sample_official_only = list(official_only_for_our_pdbs)[:5]
    for pdb_chain, uniprot in sample_official_only:
        # Check what UniProt IDs our mapping has for this PDB chain
        our_uniprots = edges_df[edges_df['pdb_chain'] == pdb_chain]['uniprot'].tolist()
        print(f"  {pdb_chain} -> {uniprot} (Our mapping has: {our_uniprots if our_uniprots else 'NOT FOUND'})")

Sample edges in our mapping but NOT in official TSV:
  6s0q_A -> P0ABE7, P29274 (Official has: ['P29274', 'P0ABE7'])
  3b9f_B ->  (Official has: NOT FOUND)
  6agg_A ->  (Official has: NOT FOUND)
  1wz1_B ->  (Official has: NOT FOUND)
  6ye3_B ->  (Official has: NOT FOUND)


Sample edges in official TSV but NOT in our mapping (for PDBs we have):
  2wkp_A -> O49003 (Our mapping has: ['O49003, P63000'])
  7sus_A -> P00268 (Our mapping has: ['P35414, P00268'])
  5osc_B -> P26048 (Our mapping has: ['P62812, P26048, Q7NDN8'])
  7opr_A -> P51159 (Our mapping has: ['Q9HCH5, P51159'])
  7zi0_B -> Q99835 (Our mapping has: ['Q99835, P0ABE7'])


# Correct UniProt Mapping Based on Official PDB-UniProt TSV
Update the `uniprot_ids_str` column in `annot_df` using the official mapping from the PDB library.

In [None]:
# Create a lookup dictionary from official mapping: pdb_chain -> list of UniProt IDs
official_pdb_to_uniprots = pdb_uniprot_mapping.groupby('pdb_chain')['SP_PRIMARY'].apply(lambda x: ', '.join(sorted(set(x)))).to_dict()

print(f"Official mapping covers {len(official_pdb_to_uniprots)} PDB chains")
print(f"Sample mappings:")
for i, (pdb, uniprots) in enumerate(list(official_pdb_to_uniprots.items())[:3]):
    print(f"  {pdb} -> {uniprots}")

In [None]:
# Create a copy of annot_df and update UniProt IDs based on official mapping
annot_df_corrected = annot_df.copy()

# Update uniprot_ids_str based on pdb_chain_key using official mapping
annot_df_corrected['uniprot_ids_str_original'] = annot_df_corrected['uniprot_ids_str']
annot_df_corrected['uniprot_ids_str'] = annot_df_corrected['pdb_chain_key'].map(official_pdb_to_uniprots)

# Check how many were updated
n_total = len(annot_df_corrected)
n_mapped = annot_df_corrected['uniprot_ids_str'].notna().sum()
n_changed = (annot_df_corrected['uniprot_ids_str'] != annot_df_corrected['uniprot_ids_str_original']).sum()

print(f"Total rows: {n_total}")
print(f"Rows with official UniProt mapping: {n_mapped} ({n_mapped/n_total*100:.2f}%)")
print(f"Rows where UniProt ID changed: {n_changed} ({n_changed/n_total*100:.2f}%)")
print(f"Rows with no official mapping (NaN): {n_total - n_mapped}")

In [None]:
# Show some examples of changes
changed_mask = (annot_df_corrected['uniprot_ids_str'] != annot_df_corrected['uniprot_ids_str_original']) & \
               annot_df_corrected['uniprot_ids_str'].notna() & \
               annot_df_corrected['uniprot_ids_str_original'].notna()

print("Sample of changed UniProt mappings:")
sample_changes = annot_df_corrected[changed_mask][['pdb_chain_key', 'uniprot_ids_str_original', 'uniprot_ids_str']].head(10)
for _, row in sample_changes.iterrows():
    print(f"  {row['pdb_chain_key']}: '{row['uniprot_ids_str_original']}' -> '{row['uniprot_ids_str']}'")

In [None]:
# Drop the original column and save the corrected dataframe
annot_df_corrected = annot_df_corrected.drop(columns=['uniprot_ids_str_original'])

# Save to a new parquet file
output_path = "../data/annotation_table_with_uniprot_corrected.parquet"
annot_df_corrected.to_parquet(output_path)

print(f"Saved corrected annotation table to: {output_path}")
print(f"Shape: {annot_df_corrected.shape}")
print(f"Unique UniProt IDs (corrected): {annot_df_corrected['uniprot_ids_str'].nunique()}")