In [None]:
import pickle
import numpy as np
from tqdm.auto import tqdm
from rdkit import Chem
from rdkit.Chem import Descriptors
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import pandas as pd
from rdkit.Chem.rdmolops import RemoveHs
from rdkit.Chem import rdMolAlign as MA
import py3Dmol
from copy import deepcopy

# ===================================================================
# 1. Core computation and visualization helper functions
# ===================================================================

def set_rdmol_positions(rdkit_mol, pos):
    """Return a copy of ``rdkit_mol`` whose conformer 0 carries the coordinates ``pos``.

    ``pos`` is indexed along its first axis, one row per atom; each row is
    converted with ``.tolist()`` before being written, so numpy arrays and
    torch tensors are both accepted. The input molecule is not modified.
    """
    mol_copy = Chem.Mol(rdkit_mol)
    conf = mol_copy.GetConformer(0)
    for atom_idx, xyz in enumerate(pos):
        conf.SetAtomPosition(atom_idx, xyz.tolist())
    return mol_copy

def get_best_rmsd(probe, ref):
    """Return the symmetry-aware best RMSD between two molecules, ignoring hydrogens.

    Hydrogens are stripped via ``RemoveHs``, which yields new molecule objects.
    Any alignment performed by ``GetBestRMS`` therefore acts on those copies,
    not on the caller's molecules.
    """
    heavy_probe = RemoveHs(probe)
    heavy_ref = RemoveHs(ref)
    return MA.GetBestRMS(heavy_probe, heavy_ref)

def calculate_molecule_metrics(ref_pos, gen_pos, rdmol, delta=1.25):
    """Compute recall coverage (COV-R) and matching (MAT-R) for one molecule.

    Builds the heavy-atom best-RMSD matrix between every reference conformer
    (rows) and every generated conformer (columns). Each reference is then
    matched to its closest generated conformer.

    Args:
        ref_pos: reference coordinates, one conformer per ``ref_pos[i]``
            (presumably ``(num_refs, num_atoms, 3)`` -- callers reshape to that).
        gen_pos: generated coordinates, same layout as ``ref_pos``.
        rdmol: RDKit molecule template; conformer 0 of a copy receives each
            coordinate set.
        delta: RMSD threshold (Å). A reference counts as covered when its
            closest generated conformer is within ``delta`` (inclusive).
            NOTE(review): GeoDiff/ConfGF report GEOM-QM9 coverage at 0.5 Å and
            GEOM-Drugs at 1.25 Å -- confirm which threshold is intended here.

    Returns:
        ``(cov_r, mat_r, best_pair)``:
        - ``cov_r``: fraction of covered references.
        - ``mat_r``: mean closest-match RMSD over references.
        - ``best_pair``: ``(ref_idx, gen_idx)`` of the single smallest RMSD.
        Returns ``(0.0, inf, None)`` when either conformer set is empty.
    """
    num_refs, num_samples = ref_pos.shape[0], gen_pos.shape[0]
    if num_refs == 0 or num_samples == 0:
        return 0.0, float('inf'), None

    # Build every generated molecule once, not once per reference. Reuse is
    # safe: get_best_rmsd aligns RemoveHs copies, and best RMSD is minimised
    # over rigid motions anyway.
    gen_mols = [set_rdmol_positions(rdmol, pos) for pos in gen_pos]

    rmsd_matrix = np.zeros((num_refs, num_samples))
    for i, pos in enumerate(ref_pos):
        ref_mol = set_rdmol_positions(rdmol, pos)
        for j, gen_mol in enumerate(gen_mols):
            rmsd_matrix[i, j] = get_best_rmsd(gen_mol, ref_mol)

    # Closest generated conformer for each reference (recall direction).
    min_rmsd_for_each_ref = rmsd_matrix.min(axis=1)
    mat_r = min_rmsd_for_each_ref.mean()
    cov_r = np.sum(min_rmsd_for_each_ref <= delta) / num_refs

    # Overall best match: (reference conformer index, generated conformer index).
    best_ref_idx, best_gen_idx = np.unravel_index(rmsd_matrix.argmin(), rmsd_matrix.shape)
    return cov_r, mat_r, (best_ref_idx, best_gen_idx)

def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Build a per-step beta (noise variance) schedule for a diffusion model.

    Args:
        beta_schedule: one of ``"sigmoid"``, ``"linear"``, ``"quad"``,
            ``"const"`` or ``"jsd"``.
        beta_start: beta at the first step (unused by ``"const"`` and ``"jsd"``).
        beta_end: beta at the last step; the constant value for ``"const"``
            (unused by ``"jsd"``).
        num_diffusion_timesteps: number of steps T.

    Returns:
        float64 ndarray of shape ``(T,)``.

    Raises:
        ValueError: if ``beta_schedule`` is not one of the names above. Unknown
            names are rejected rather than silently treated as linear.
    """
    def sigmoid(x):
        return 1 / (np.exp(-x) + 1)

    steps = num_diffusion_timesteps
    if beta_schedule == "sigmoid":
        # Logistic ramp over [-6, 6], rescaled to [beta_start, beta_end].
        ramp = np.linspace(-6, 6, steps)
        return sigmoid(ramp) * (beta_end - beta_start) + beta_start
    if beta_schedule == "linear":
        return np.linspace(beta_start, beta_end, steps, dtype=np.float64)
    if beta_schedule == "quad":
        # Linear in sqrt(beta), then squared.
        return np.linspace(beta_start ** 0.5, beta_end ** 0.5, steps, dtype=np.float64) ** 2
    if beta_schedule == "const":
        return beta_end * np.ones(steps, dtype=np.float64)
    if beta_schedule == "jsd":
        # 1/T, 1/(T-1), ..., 1
        return 1.0 / np.linspace(steps, 1, steps, dtype=np.float64)
    raise ValueError(f"Unknown beta_schedule: {beta_schedule!r}")

def visualize_mol(mol, size=(400, 400)):
    """Return a py3Dmol viewer showing ``mol`` in stick + sphere style.

    ``size`` supplies the viewer width (``size[0]``) and height (``size[1]``).
    The molecule is passed to py3Dmol as an MDL mol block.
    """
    mol_block = Chem.MolToMolBlock(mol)
    view = py3Dmol.view(width=size[0], height=size[1])
    view.addModel(mol_block, 'mol')
    style_spec = {'stick': {}, 'sphere': {'radius': 0.35}}
    view.setStyle(style_spec)
    view.zoomTo()
    return view

# ===================================================================
# 2. Data loading and automatic case study selection
# ===================================================================

def group_references_by_smiles(records):
    """Group reference conformers by SMILES.

    Args:
        records: iterable of objects with ``.smiles``, ``.pos`` (a torch tensor)
            and ``.rdmol`` attributes, as unpickled from the test set.

    Returns:
        dict mapping each SMILES to ``{'pos_ref': Tensor, 'rdmol': Mol}``.
        ``pos_ref`` stacks every ``.pos`` for that SMILES along a new leading
        axis. ``rdmol`` is taken from the first record with that SMILES.
    """
    grouped = {}
    for record in records:
        entry = grouped.setdefault(record.smiles, {'pos_ref': [], 'rdmol': record.rdmol})
        entry['pos_ref'].append(record.pos)
    for entry in grouped.values():
        entry['pos_ref'] = torch.stack(entry['pos_ref'])
    return grouped


def mean_max_noise_displacement(pos_ensemble, noise_scale, generator=None):
    """Return the "geometric noise distance" (GND) of a conformer ensemble.

    For each conformer ``pos``:
    1. Draw Gaussian noise of the same shape, scaled by ``noise_scale``.
    2. Take the largest row-wise (per-atom) norm.
    The result is the mean of these maxima over the ensemble.

    Args:
        pos_ensemble: iterable of 2-D tensors, rows = atoms (presumably xyz columns).
        noise_scale: multiplier applied to unit Gaussian noise.
        generator: optional CPU ``torch.Generator`` for reproducible draws.
    """
    maxima = []
    for pos in pos_ensemble:
        noise = torch.randn(pos.shape, generator=generator, dtype=pos.dtype, device=pos.device) * noise_scale
        maxima.append(torch.norm(noise, dim=1).max().item())
    return float(np.mean(maxima))


print("--- 1. Loading data and performing precomputations ---")
your_results_path = 'checkpoints/qm9_condition/samples/sample_all.pkl'
baseline_results_path = 'checkpoints/subgdiff_baseline/samples_all.pkl'
test_set_path = 'data/GEOM/QM9/test_data_1k.pkl'

# Fixed seed for the GND noise draws so that the selected case study is reproducible.
GND_NOISE_SEED = 0

try:
    # NOTE: pickle can execute arbitrary code -- only load trusted, locally produced files.
    with open(your_results_path, 'rb') as f:
        your_results = {d.smiles: d for d in pickle.load(f)}
    with open(baseline_results_path, 'rb') as f:
        baseline_results = {d.smiles: d for d in pickle.load(f)}
    with open(test_set_path, 'rb') as f:
        raw_test_list = pickle.load(f)
    test_set_by_smiles = group_references_by_smiles(raw_test_list)
    print("Data loaded successfully!")

    betas = get_beta_schedule("sigmoid", beta_start=1.e-7, beta_end=2.e-3, num_diffusion_timesteps=5000)
    alphas = torch.from_numpy((1. - betas).cumprod(axis=0)).float()
    # Second-to-last cumulative alpha -- TODO(review): confirm why [-2] rather than [-1].
    alpha_t = alphas[-2]
    # Perturbation scale sqrt(1 - a) / sqrt(a). This is the noise added to x_0
    # after rescaling x_t by 1/sqrt(alpha_bar), presumably, in DDPM terms.
    noise_scale = (1.0 - alpha_t).sqrt() / alpha_t.sqrt()
    noise_generator = torch.Generator().manual_seed(GND_NOISE_SEED)

    analysis_data = []
    # Keep only molecules that have both sets of samples AND reference conformers.
    # Sorting gives a stable iteration order, hence stable RNG consumption.
    comparable_smiles = sorted(set(your_results) & set(baseline_results) & set(test_set_by_smiles))

    for smiles in tqdm(comparable_smiles, desc="Analyzing molecules"):
        ref_data = test_set_by_smiles[smiles]
        pos_0_ensemble = ref_data['pos_ref']
        num_atoms = ref_data['rdmol'].GetNumAtoms()
        gnd = mean_max_noise_displacement(pos_0_ensemble, noise_scale, generator=noise_generator)

        ref_pos = pos_0_ensemble.numpy()
        your_gen = your_results[smiles].pos_gen.reshape(-1, num_atoms, 3).numpy()
        baseline_gen = baseline_results[smiles].pos_gen.reshape(-1, num_atoms, 3).numpy()
        # Uses calculate_molecule_metrics' default delta (see its NOTE on QM9 thresholds).
        your_cov_r, _, _ = calculate_molecule_metrics(ref_pos, your_gen, ref_data['rdmol'])
        baseline_cov_r, _, _ = calculate_molecule_metrics(ref_pos, baseline_gen, ref_data['rdmol'])

        analysis_data.append({"smiles": smiles, "gnd": gnd, "improvement": your_cov_r - baseline_cov_r})

    # 1. Based on previous findings, set the crossover threshold (Å)
    gnd_crossover_threshold = 36.54

    # 2. Select all hard cases with GND at or above this threshold.
    # Explicit columns keep the filter valid even when no molecule was analysed.
    df = pd.DataFrame(analysis_data, columns=["smiles", "gnd", "improvement"])
    hard_cases = df[df['gnd'] >= gnd_crossover_threshold]

    if hard_cases.empty:
        print(f"\nWarning: No cases found with GND >= {gnd_crossover_threshold}Å in the dataset. Consider lowering the threshold.")
        best_case_smiles = None
    else:
        # 3. Among these hard cases, find the molecule with the largest performance improvement
        best_case_smiles = hard_cases.loc[hard_cases['improvement'].idxmax(), 'smiles']

        print(f"\nAmong all hard cases with GND >= {gnd_crossover_threshold}Å, found the molecule with the most significant performance improvement.")
        print(f"Best case study SMILES: {best_case_smiles}\n")

# EOFError: pickle.load on an empty file.
except (FileNotFoundError, EOFError, ValueError) as e:
    print(f"Error: could not find data or data is empty! {e}"); best_case_smiles = None

# ===================================================================
# 3. Case study visualization
# ===================================================================
if best_case_smiles:
    print(f"--- 3. Visualizing case study molecule {best_case_smiles} ---")

    # Extract data
    ref_data = test_set_by_smiles[best_case_smiles]
    rdmol = ref_data['rdmol']
    num_nodes = rdmol.GetNumAtoms()
    ref_pos_ensemble = ref_data['pos_ref'].numpy()
    your_gen_pos = your_results[best_case_smiles].pos_gen.reshape(-1, num_nodes, 3).numpy()
    baseline_gen_pos = baseline_results[best_case_smiles].pos_gen.reshape(-1, num_nodes, 3).numpy()

    # Compute metrics and the best-matching (reference idx, generated idx) pair.
    # A *_best_indices value is None when that model produced no conformers.
    your_cov_r, your_mat_r, your_best_indices = calculate_molecule_metrics(ref_pos_ensemble, your_gen_pos, rdmol)
    baseline_cov_r, baseline_mat_r, baseline_best_indices = calculate_molecule_metrics(ref_pos_ensemble, baseline_gen_pos, rdmol)

    # Ground truth: the reference conformer best matched by your model.
    # Falls back to the baseline's match, then to the first reference.
    if your_best_indices is not None:
        truth_idx = your_best_indices[0]
    elif baseline_best_indices is not None:
        truth_idx = baseline_best_indices[0]
    else:
        truth_idx = 0
    ground_truth_pos = ref_pos_ensemble[truth_idx]
    # Each model shows its own best-matched generated conformer (None if it generated nothing).
    # NOTE(review): the baseline's best match may pair with a different
    # reference than the ground truth shown here.
    your_best_gen_pos = None if your_best_indices is None else your_gen_pos[your_best_indices[1]]
    baseline_best_gen_pos = None if baseline_best_indices is None else baseline_gen_pos[baseline_best_indices[1]]

    # Create RDKit molecule objects
    mol_truth = set_rdmol_positions(rdmol, ground_truth_pos)
    mol_yours = None if your_best_gen_pos is None else set_rdmol_positions(rdmol, your_best_gen_pos)
    mol_baseline = None if baseline_best_gen_pos is None else set_rdmol_positions(rdmol, baseline_best_gen_pos)

    # Print performance report
    print("\nPerformance comparison report:")
    print("-" * 58)
    print("Metric\t\t | Your Model\t | Baseline")
    print("-" * 58)
    print(f"COV-R (higher is better)\t | {your_cov_r:.4f}\t | {baseline_cov_r:.4f}")
    print(f"MAT-R (lower is better)\t | {your_mat_r:.4f}\t | {baseline_mat_r:.4f}")
    print("-" * 58)

    # Visualization (display() is the notebook/IPython builtin).
    for title, shown_mol in (
        ("\nGround Truth Conformation:", mol_truth),
        ("Best Conformation from SSD:", mol_yours),
        ("Best Conformation from Baseline:", mol_baseline),
    ):
        print(title)
        if shown_mol is None:
            print("  (no generated conformers to display)")
        else:
            display(visualize_mol(shown_mol))
