In [57]:
import os
import random
from typing import Dict, List, Optional, Tuple

import numpy as np
import oddt
from oddt.fingerprints import SPLIF
from oddt.toolkits import ob
from tqdm import tqdm

def splif_to_bitvec_and_weights(splif_result: np.ndarray, size: int = 4096) -> Tuple[np.ndarray, np.ndarray]:
    """Fold a SPLIF fingerprint into a presence bit vector and a per-bit count vector.

    Args:
        splif_result: Structured array with a ``'hash'`` field (the object
            produced by ``SPLIF`` in this script). Each hash is used directly
            as a bit index, so hashes are assumed to lie in ``[0, size)``.
        size: Length of both output vectors.

    Returns:
        ``(bitvec, weights)``. ``bitvec[k]`` is True iff hash ``k`` occurs,
        and ``weights[k]`` is the number of times it occurs.
    """
    observed, multiplicity = np.unique(splif_result['hash'], return_counts=True)

    weights = np.zeros(size, dtype=int)
    weights[observed] = multiplicity
    # Every observed hash has a count >= 1, so "non-zero count" is exactly "present".
    bitvec = weights > 0
    return bitvec, weights

def weighted_tanimoto_similarity(splif1: np.ndarray, splif2: np.ndarray, size: int = 4096) -> float:
    """Count-weighted Tanimoto (Ruzicka) similarity of two SPLIF fingerprints.

    Both fingerprints are folded into per-bit count vectors of length ``size``
    via ``splif_to_bitvec_and_weights``. The score is
    ``sum(min(counts_a, counts_b)) / sum(max(counts_a, counts_b))``.

    Returns:
        A value in ``[0, 1]``, or 0.0 when both fingerprints are empty.
    """
    _, counts_a = splif_to_bitvec_and_weights(splif1, size)
    _, counts_b = splif_to_bitvec_and_weights(splif2, size)

    # A bit set in only one fingerprint contributes min(...) == 0, so summing the
    # element-wise minimum over every bit equals summing it over the shared bits.
    shared = np.minimum(counts_a, counts_b).sum()
    combined = np.maximum(counts_a, counts_b).sum()
    if combined <= 0:
        return 0.0
    return shared / combined

def load_molecule(file_path: str, file_type: str) -> oddt.toolkit.Molecule:
    """Read the first molecule from ``file_path`` with ODDT's Open Babel toolkit.

    Args:
        file_path: Path to the structure file.
        file_type: Open Babel format name, e.g. ``'pdb'`` or ``'sdf'``.
            Molecules read from ``'pdb'`` files are flagged as proteins.

    Returns:
        The first molecule in the file; any further records are ignored.

    Raises:
        ValueError: If the file yields no molecule at all.
    """
    # Use a default instead of a bare next(). A StopIteration escaping this
    # function would silently terminate an enclosing iterator (e.g. map())
    # instead of surfacing as an error.
    molecule = next(ob.readfile(file_type, file_path), None)
    if molecule is None:
        raise ValueError(f"No molecule could be read from {file_path!r} (format {file_type!r})")
    if file_type == 'pdb':
        # In this script PDB files hold the receptor; set ODDT's Molecule.protein flag.
        molecule.protein = True
    return molecule

def calculate_splif(ligand: oddt.toolkit.Molecule, protein: oddt.toolkit.Molecule) -> np.ndarray:
    """Compute the SPLIF fingerprint of ``ligand`` against ``protein``.

    Thin wrapper around ``oddt.fingerprints.SPLIF`` that uses its default settings.
    """
    fingerprint = SPLIF(ligand, protein)
    return fingerprint

def process_complex(pdb_path: str, sdf_path: str) -> np.ndarray:
    """Load one protein-ligand complex from disk and return its SPLIF fingerprint.

    Args:
        pdb_path: Receptor structure in PDB format.
        sdf_path: Ligand structure in SDF format (only its first record is read).
    """
    # Receptor first, then ligand: the same read order as before.
    receptor = load_molecule(pdb_path, 'pdb')
    ligand_mol = load_molecule(sdf_path, 'sdf')
    fingerprint = calculate_splif(ligand_mol, receptor)
    return fingerprint

def get_pdb_ids(base_path: str, n: Optional[int] = None, seed: Optional[int] = None) -> List[str]:
    """Return PDB IDs, taken as the names of the immediate subdirectories of ``base_path``.

    Args:
        base_path: Directory whose immediate subdirectories are PDB IDs.
            Plain files are ignored.
        n: Number of IDs to draw at random. ``None``, or a value >= the number
            of available IDs, returns all of them.
        seed: If given, the sample is drawn from a private ``random.Random(seed)``
            and is reproducible. If ``None``, the module-level ``random`` state
            is used, as before.

    Returns:
        All IDs sorted by name, or a random sample of ``n`` IDs in draw order.

    Raises:
        ValueError: If ``n`` is negative (raised by ``sample``).
    """
    # os.listdir order is filesystem-dependent. Sort so that the population,
    # and therefore any seeded sample, is identical across machines and runs.
    all_pdb_ids = sorted(
        d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))
    )
    if n is None or n >= len(all_pdb_ids):
        return all_pdb_ids
    rng = random if seed is None else random.Random(seed)
    return rng.sample(all_pdb_ids, n)

def calculate_similarity_matrix(base_path: str, pdb_ids: List[str]) -> np.ndarray:
    """Build the symmetric pairwise SPLIF similarity matrix for ``pdb_ids``.

    Each ID must be a subdirectory of ``base_path`` containing
    ``<id>_protein.pdb`` and ``<id>_ligand.sdf``. Row and column ``k`` of the
    returned ``(len(pdb_ids), len(pdb_ids))`` array correspond to ``pdb_ids[k]``.
    """
    count = len(pdb_ids)

    # Pass 1: one fingerprint per complex, kept in the same order as pdb_ids.
    fingerprints = []
    for pdb_id in tqdm(pdb_ids, desc="Processing complexes"):
        complex_dir = os.path.join(base_path, pdb_id)
        fingerprints.append(
            process_complex(
                os.path.join(complex_dir, f"{pdb_id}_protein.pdb"),
                os.path.join(complex_dir, f"{pdb_id}_ligand.sdf"),
            )
        )

    # Pass 2: fill the upper triangle (diagonal included) and mirror it.
    matrix = np.zeros((count, count))
    for row in tqdm(range(count), desc="Calculating similarities"):
        for col in range(row, count):
            score = weighted_tanimoto_similarity(fingerprints[row], fingerprints[col])
            matrix[row, col] = score
            matrix[col, row] = score
    return matrix

def main(base_path: str = '/mnt/data/pdbbind2020-PL', n_complexes: Optional[int] = 50) -> None:
    """Sample complexes, compute their SPLIF similarity matrix and save the results.

    Args:
        base_path: Directory containing one subdirectory per PDB ID.
        n_complexes: How many complexes to sample at random. ``None`` (or a
            number >= the available count) uses all of them.

    Side effects:
        Prints a summary. Writes ``splif_similarity_matrix_<N>.npy`` and
        ``pdb_ids_<N>.txt`` to the current working directory, where N is the
        number of complexes actually selected. The text file lists IDs in the
        matrix's row order.
    """
    pdb_ids = get_pdb_ids(base_path, n_complexes)

    print(f"Number of complexes selected: {len(pdb_ids)}")

    similarity_matrix = calculate_similarity_matrix(base_path, pdb_ids)

    print("Similarity matrix shape:", similarity_matrix.shape)
    print("Sample of similarity matrix:")
    print(similarity_matrix[:5, :5])

    # Persist the matrix together with the ID order its rows/columns follow.
    matrix_file = f"splif_similarity_matrix_{len(pdb_ids)}.npy"
    ids_file = f"pdb_ids_{len(pdb_ids)}.txt"
    np.save(matrix_file, similarity_matrix)
    with open(ids_file, "w", encoding="utf-8") as f:
        f.writelines(f"{pdb_id}\n" for pdb_id in pdb_ids)

    print(f"Results saved to '{matrix_file}' and '{ids_file}'")

# Script entry point. Inside a notebook kernel __name__ is also "__main__",
# so main() runs when this cell executes (its output follows below).
if __name__ == "__main__":
    main()

Number of complexes selected: 50


Processing complexes: 100%|██████████| 50/50 [00:37<00:00,  1.33it/s]
Calculating similarities: 100%|██████████| 50/50 [00:00<00:00, 1561.69it/s]

Similarity matrix shape: (50, 50)
Sample of similarity matrix:
[[1.         0.01486989 0.00342466 0.02312139 0.01038961]
 [0.01486989 1.         0.0130719  0.03055556 0.01754386]
 [0.00342466 0.0130719  1.         0.0801105  0.04668305]
 [0.02312139 0.03055556 0.0801105  1.         0.0993228 ]
 [0.01038961 0.01754386 0.04668305 0.0993228  1.        ]]
Results saved to 'splif_similarity_matrix_50.npy' and 'pdb_ids_50.txt'



