# In [None]:
import numpy as np
import networkx as nx
from Bio import PDB
import os
import statistics


def calculate_distance(atom1, atom2):
    """Return the Euclidean distance between the coordinates of two atoms.

    Both arguments must expose ``get_coord()`` returning array-like
    coordinates that support subtraction.
    """
    displacement = atom1.get_coord() - atom2.get_coord()
    return np.linalg.norm(displacement)


def is_sidechain_atom(atom):
    """Return True when the atom's name is not one of N, CA, C or O.

    Any other atom name (whatever ``atom.get_name()`` returns) is treated
    as a side-chain atom.
    """
    return atom.get_name() not in ('N', 'CA', 'C', 'O')


def _residue_label(residue):
    """Return the node label ``<chain id>_<residue number>_<residue name>``."""
    return f"{residue.get_parent().get_id()}_{residue.get_id()[1]}_{residue.resname}"


def build_residue_interaction_network(pdb_file, cutoff=5):
    """Build a weighted side-chain contact graph from one PDB file.

    Every amino-acid residue whose id has a blank first field becomes a node
    labelled ``<chain id>_<residue number>_<residue name>``. Two residues of
    the same chain are joined by an edge when at least one pair of their
    side-chain atoms is closer than ``cutoff``. Residue pairs from different
    chains are never compared.

    The raw edge weight is the number of such atom pairs. Because labels do
    not contain the model, counts from every model in the file accumulate on
    the same edge. Finally all weights are divided by the largest raw weight,
    so the heaviest edge ends up with weight 1.

    NOTE(review): the label uses only the second field of the residue id, so
    residues that share chain id, number and name (e.g. differing only in the
    third id field, presumably the insertion code) map to the same node.

    Args:
        pdb_file: Path (or handle) passed to ``PDBParser.get_structure``.
        cutoff: Distance threshold, in the units of the file's coordinates;
            the comparison is strict (``distance < cutoff``).

    Returns:
        A ``networkx.Graph`` whose edges carry a normalized ``'weight'``.
    """
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_file)

    G = nx.Graph()

    for model in structure:
        for chain in model:
            # Compute label and side-chain atoms once per residue instead of
            # once per residue pair.
            candidates = []
            for residue in chain:
                if not (PDB.is_aa(residue) and residue.get_id()[0] == " "):
                    continue
                label = _residue_label(residue)
                # Every qualifying residue is a node, even when it has no
                # side-chain atoms and therefore can never get an edge.
                G.add_node(label)
                sidechain_atoms = [atom for atom in residue if is_sidechain_atom(atom)]
                if sidechain_atoms:
                    candidates.append((label, sidechain_atoms))

            for i, (label1, atoms1) in enumerate(candidates):
                for label2, atoms2 in candidates[i + 1:]:
                    contact_count = sum(
                        1
                        for atom1 in atoms1
                        for atom2 in atoms2
                        if calculate_distance(atom1, atom2) < cutoff
                    )
                    if contact_count == 0:
                        continue
                    # Accumulate: the same pair may be seen again in another model.
                    if G.has_edge(label1, label2):
                        G[label1][label2]['weight'] += contact_count
                    else:
                        G.add_edge(label1, label2, weight=contact_count)

    # Normalize by the largest raw contact count (1 when there are no edges,
    # in which case the loop below does nothing).
    max_contact = max((w for _, _, w in G.edges(data='weight')), default=1)
    for _, _, attrs in G.edges(data=True):
        attrs['weight'] = attrs['weight'] / max_contact

    return G


def compute_edge_variances(psns):
    """Return the population variance of each edge's weight across networks.

    The edge set is the union of ``psn.edges()`` over all networks. A network
    that lacks an edge contributes a weight of 0 for it.

    NOTE(review): edges are keyed by the exact ``(u, v)`` tuple each network
    reports; if two networks report the same undirected edge in opposite
    orientation they yield two separate keys.

    Args:
        psns: Sequence of graphs exposing ``edges()``, ``edges[edge]`` and
            ``has_edge(u, v)``. It is iterated more than once.

    Returns:
        Dict mapping each edge tuple to ``statistics.pvariance`` of its
        weights, one weight per network.
    """
    observed_edges = set()
    for network in psns:
        observed_edges.update(network.edges())

    def weight_in(network, edge):
        # Missing edges count as zero weight.
        if network.has_edge(*edge):
            return network.edges[edge]['weight']
        return 0

    return {
        edge: statistics.pvariance([weight_in(network, edge) for network in psns])
        for edge in observed_edges
    }


# Directory scanned for input ``.pdb`` files.
pdb_dir = "frames_wt1"

# os.listdir() returns entries in arbitrary order; sort (lexicographically) so
# the "network i" numbering printed below is reproducible between runs.
pdb_files = sorted(
    os.path.join(pdb_dir, f) for f in os.listdir(pdb_dir) if f.endswith('.pdb')
)


# One contact network per PDB file, in the same order as pdb_files.
group_psns = [build_residue_interaction_network(pdb_file) for pdb_file in pdb_files]


# Per-edge variance of the normalized weight across all networks.
group_edge_variances = compute_edge_variances(group_psns)
print("edge_variances:", group_edge_variances)
print("edge_num:", len(group_edge_variances))


# Size summary of each network.
for i, psn in enumerate(group_psns):
    print(f"network {i+1}: nodes={psn.number_of_nodes()}, edges={psn.number_of_edges()}")