In [None]:
import numpy as np
import torch as pt
from glob import glob
from tqdm import tqdm
import random
import h5py
import os

from utils.feature_extraction import extract_dynamic_features, encode_sequence, mean_coordinates, extract_topology
from utils.PDB_processing import split_nmr_pdb, make_pdb, read_pdb, get_sasa_unbound, fill_nan_with_neighbors, std_elements

from utils.data_handler import collate_batch_features
from utils.model import Model
from utils.configs import config_model, config_data, config_runtime
from utils.for_visualization import p_to_bfactor
from utils.scoring import bc_scoring

In [None]:
# Run on the GPU whenever CUDA is usable; otherwise stay on the CPU.
if pt.cuda.is_available():
    device = pt.device("cuda")
else:
    device = pt.device("cpu")
print(f"Using device: {device}")

In [None]:
# Input structures and model checkpoint (paths are relative to the notebook's
# working directory).
# glob() yields paths in arbitrary, filesystem-dependent order; sorting makes
# the processing order (and the order of entries in `results`) reproducible.
pdb_files = sorted(glob('input_structures/*.pdb'))
model_path = 'model_307.pt'

In [None]:
#ids = np.genfromtxt('/home/omokhtar/Desktop/revision/IDRBind_ambiguous_labels/ids_comp.txt', dtype=np.dtype('U'))

In [None]:
# Inference over every input structure.
#
# For each PDB file the loop reads the structure, derives the geometric and
# dynamic features, collates them into a batch of one and runs the model.
# `results` maps the PDB path to
#     [probabilities, atom_coords, aa_map, sequence]
# which is the layout create_enhanced_protein_visualization() indexes
# (entries 0, 1, 2 and 3 respectively).
# NOTE(review): the expected layout of utils.for_visualization.p_to_bfactor is
# not visible in this notebook -- confirm it before re-enabling that call.

KNN_MAX = 64        # upper bound on the number of nearest neighbours per node
NOISE_SCALE = 0.0   # amplitude of the uniform noise added to dynamic features (0.0 = off)

# The network and its weights are identical for every structure, so build and
# load them once instead of once per file.
model = Model(config_model)
model.load_state_dict(pt.load(model_path, map_location=device, weights_only=True))
model = model.eval().to(device)

results = {}

for each_pdb in tqdm(pdb_files):
    chain_key = each_pdb.split(r"/")[-1].split('.')[0]
    # Optional filter (requires an `ids` array of chain keys to be defined first):
    # if chain_key not in ids: continue
    pdb_chains = split_nmr_pdb(each_pdb)
    models = list(pdb_chains.values())[0]  # Single chain PDB structures
    pdb_file = make_pdb(models)
    aa_map, seq, atom_type, atoms_xyz = read_pdb(pdb_file)
    seq = ''.join(seq)
    # Keep the raw atom -> residue map: `aa_map` is rebound to the collated
    # tensor further down, but the visualisation needs the original.
    atom_to_residue = np.array(aa_map)

    # Get features
    mean_xyz = mean_coordinates(atoms_xyz)
    R, D = extract_topology(mean_xyz)
    # Indices of nearest neighbors (smallest distances first)
    knn = min(KNN_MAX, D.shape[0])
    D_nn, nn_topk = pt.topk(pt.tensor(D), knn, dim=1, largest=False)
    R_nn = pt.gather(pt.tensor(R), 1, nn_topk.unsqueeze(2).repeat(1, 1, R.shape[2])).to(pt.float32)
    motion_v_nn, motion_s_nn, rmsf, de, CP_nn = extract_dynamic_features(atoms_xyz, nn_topk.numpy())
    sasa_dic_unbound, labeled_seqs2 = get_sasa_unbound(each_pdb)
    assert list(labeled_seqs2.values())[0] == seq, f"sequence mismatch for {each_pdb}"
    rsa = fill_nan_with_neighbors(np.array(list(sasa_dic_unbound.values())[0]))
    # aa_map is 1-based, so `aa_map - 1` selects each atom's residue value.
    rsa = np.array(rsa)[atom_to_residue - 1]
    onehot_seq = pt.tensor(encode_sequence(atom_type, std_elements))

    # Optional perturbation of the dynamic features; each array keeps its dtype.
    # (Previously the perturbed CP_nn was written to an unused name, `CP_nn_no`,
    # so it would have been skipped for any non-zero NOISE_SCALE.)
    motion_v_nn = (motion_v_nn + NOISE_SCALE * np.random.rand(*motion_v_nn.shape)).astype(motion_v_nn.dtype)
    rmsf = (rmsf + NOISE_SCALE * np.random.rand(*rmsf.shape)).astype(rmsf.dtype)
    CP_nn = (CP_nn + NOISE_SCALE * np.random.rand(*CP_nn.shape)).astype(CP_nn.dtype)

    # Convert everything to tensors with the trailing feature axis added.
    rmsf = pt.tensor(rmsf).unsqueeze(1)
    de = pt.tensor(de).unsqueeze(1)
    rsa = pt.tensor(rsa).to(pt.float64).unsqueeze(1)
    D_nn = D_nn.to(pt.float32).unsqueeze(2)
    nn_topk = nn_topk.to(pt.int64)
    R_nn = R_nn.to(pt.float32)
    motion_v_nn = pt.tensor(motion_v_nn).to(pt.float32)
    motion_s_nn = pt.tensor(motion_s_nn).to(pt.float32).unsqueeze(2)
    CP_nn = pt.tensor(CP_nn).to(pt.float32).unsqueeze(2)

    # Collate as a batch containing this single structure.
    features = collate_batch_features([[onehot_seq, rmsf, de, rsa, nn_topk, D_nn, R_nn, motion_v_nn, motion_s_nn, CP_nn, pt.tensor(aa_map)]])
    onehot_seq, rmsf, de, rsa, nn_topk, D_nn, R_nn, motion_v_nn, motion_s_nn, CP_nn, aa_map = features

    # Apply the model
    with pt.no_grad():
        z, _, _ = model(onehot_seq.to(device), rmsf.to(device), de.to(device), rsa.to(device), nn_topk.to(device), D_nn.to(device), R_nn.to(device), motion_v_nn.to(device), motion_s_nn.to(device), CP_nn.to(device), aa_map.to(device))
        # NOTE(review): mean_xyz is assumed to be the per-atom (n_atoms, 3)
        # coordinate array that matches atom_to_residue -- TODO confirm against
        # utils.feature_extraction.mean_coordinates.
        results[each_pdb] = [pt.sigmoid(z).detach(), np.asarray(mean_xyz), atom_to_residue, seq]

In [5]:
## To visualize
#p_to_bfactor(results, device)

In [6]:
## To evaluate
#dataset = h5py.File('db_aflow_all_with_dists.h5', 'r')
#for i in results:
#    pred = results[i][0]
#    i = i.split('/')[-1][:-4]
#    print (i)
#    y = pt.tensor(dataset['data']['labels'][i]['label'])
#    print (bc_scoring(y.to(device), pred)[6])

In [7]:
## DBSCAN clustering
from sklearn.cluster import DBSCAN
from mpl_toolkits.mplot3d import Axes3D
import py3Dmol
from matplotlib.colors import to_hex

In [None]:
def _probabilities_as_array(prob):
    """Return the first prediction channel as a 1-D NumPy array (one value per row)."""
    if hasattr(prob, "detach"):  # torch tensor, possibly living on the GPU
        prob = prob.detach().cpu().numpy()
    prob = np.asarray(prob)
    # Keep axis 0 and take element 0 along every remaining axis, so both
    # (n,) and (n, 1) inputs are accepted.
    return prob[(slice(None),) + (0,) * (prob.ndim - 1)]


def _residue_centroids(coords, aa_map):
    """Average per-atom coordinates into one centroid per residue (aa_map is 1-based)."""
    residue_idx = aa_map - 1
    atoms_per_residue = np.bincount(residue_idx)
    coord_sums = np.column_stack(
        [np.bincount(residue_idx, weights=coords[:, axis]) for axis in range(3)]
    )
    return coord_sums / atoms_per_residue[:, None]


def _point(xyz_row):
    """Convert an (x, y, z) row into the coordinate dict used by py3Dmol shape specs."""
    x, y, z = xyz_row
    return {'x': float(x), 'y': float(y), 'z': float(z)}


def _add_spheres(view, points, colors, radius, opacity):
    """Add one sphere per point to `view`; `colors` supplies one colour per point."""
    for xyz_row, color in zip(points, colors):
        view.addSphere({
            'center': _point(xyz_row),
            'radius': radius,
            'color': color,
            'opacity': opacity
        })


def _cluster_colors(cluster_labels):
    """Map every cluster label to a hex colour; the DBSCAN noise label (-1) is dark red."""
    color_map = {-1: '#8B0000'}
    if cluster_labels:
        # Imported here because the notebook never imports pyplot at the top;
        # the original code referenced `plt` without it and raised NameError.
        import matplotlib.pyplot as plt
        palette = [to_hex(c) for c in plt.cm.Set1(np.linspace(0, 1, max(len(cluster_labels), 8)))]
        for position, label in enumerate(cluster_labels):
            color_map[label] = palette[position % len(palette)]
    return color_map


def create_enhanced_protein_visualization(results, prob_threshold=0.5, eps=10.0, min_samples=3):
    """Show one py3Dmol view per structure with predicted residues clustered by DBSCAN.

    Parameters
    ----------
    results : dict
        Maps a structure identifier to a sequence whose first three items are
        ``[probabilities, atom_coords, aa_map, ...]``:
        probabilities -- per-residue predictions (torch tensor or array; the
        first channel is used), atom_coords -- per-atom ``(n_atoms, 3)``
        coordinates, aa_map -- 1-based residue index of every atom.
        Further items (e.g. the sequence) are ignored.
    prob_threshold : float
        Residues with a probability strictly above this value are drawn as
        high-confidence.
    eps, min_samples : float, int
        Passed to ``sklearn.cluster.DBSCAN``. Clustering is only attempted when
        at least ``min_samples`` high-confidence residues exist; otherwise they
        are drawn in plain red.

    Raises
    ------
    ValueError
        If an entry has fewer than three items, or the number of probabilities
        differs from the number of residues implied by ``aa_map``.
    """
    for each, entry in results.items():
        if len(entry) < 3:
            raise ValueError(
                f"results[{each!r}] must contain [probabilities, atom_coords, aa_map, ...]; "
                f"got {len(entry)} item(s)"
            )

        prob = _probabilities_as_array(entry[0])
        coords = np.asarray(entry[1], dtype=float)
        aa_map = np.asarray(entry[2], dtype=np.int64)

        # Calculate average coordinates per residue
        xyz = _residue_centroids(coords, aa_map)
        if len(prob) != len(xyz):
            raise ValueError(
                f"results[{each!r}]: {len(prob)} probabilities but {len(xyz)} residues in aa_map"
            )

        # Separate high and low confidence residues
        high_conf = prob > prob_threshold
        coords_high = xyz[high_conf]
        coords_low = xyz[~high_conf]

        print(f"Processing {each}: {len(coords_high)} high-confidence, {len(coords_low)} low-confidence points")

        # Create viewer
        view = py3Dmol.view(width=800, height=600)

        # Trace through the high-confidence residues. Boolean masking keeps
        # residue order, so no extra sorting is needed.
        if len(coords_high) > 1:
            trace = [_point(row) for row in coords_high]
            view.addLine({
                'start': trace[0],
                'end': trace[-1],
                'color': 'lightblue',
                'radius': 0.2
            })
            # Connect consecutive high-confidence residues
            for start, end in zip(trace, trace[1:]):
                view.addCylinder({
                    'start': start,
                    'end': end,
                    'color': 'lightblue',
                    'radius': 0.1,
                    'opacity': 0.6
                })

        # Low-confidence residues: smaller, transparent grey spheres
        _add_spheres(view, coords_low, ['lightgray'] * len(coords_low), radius=0.8, opacity=0.4)

        cluster_labels = None
        if len(coords_high) >= max(min_samples, 1):
            # Cluster high-confidence residues in 3-D space
            labels = DBSCAN(eps=eps, min_samples=min_samples).fit(coords_high).labels_
            # np.unique returns the labels sorted, which keeps colours stable.
            cluster_labels = [label for label in np.unique(labels) if label != -1]

            # Report the average predicted probability per cluster
            prob_high = prob[high_conf]
            for label in cluster_labels:
                avg = np.mean(prob_high[labels == label])
                print(f"  Cluster {label}: average predicted probability = {avg:.2f}")

            color_map = _cluster_colors(cluster_labels)
            _add_spheres(
                view,
                coords_high,
                [color_map.get(label, '#ff0000') for label in labels],
                radius=1.2,
                opacity=0.8,
            )
        elif len(coords_high) > 0:
            # Too few high-confidence points to cluster: draw them uniformly
            _add_spheres(view, coords_high, ['red'] * len(coords_high), radius=1.2, opacity=0.8)

        view.setBackgroundColor('white')
        # NOTE(review): no molecular model is added to the viewer in this
        # function, so this style presumably only matters if one is loaded.
        view.setStyle({'cartoon': {'color': 'spectrum'}})

        view.zoomTo()
        view.show()

        # Print summary
        print(f"Visualization for {each}:")
        print(f"  Total residues: {len(xyz)}")
        print(f"  High confidence: {len(coords_high)}")
        print(f"  Low confidence: {len(coords_low)}")
        if cluster_labels is not None:
            print(f"  Clusters found: {len(cluster_labels)}")
        print("---")


# Show one py3Dmol view and print a summary for every entry in `results`.
create_enhanced_protein_visualization(results)