# In [None]:
import pickle
import numpy as np
from tqdm.auto import tqdm
from rdkit import Chem
from rdkit.Chem import Descriptors
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import pandas as pd
from rdkit.Chem.rdmolops import RemoveHs
from rdkit.Chem import rdMolAlign as MA

# ===================================================================
# 1. Core computation and helper functions
# ===================================================================

def set_rdmol_positions(rdkit_mol, pos):
    """Return a new molecule built from ``rdkit_mol`` with coordinates taken from ``pos``.

    A fresh ``Chem.Mol`` is constructed from ``rdkit_mol`` and the atom
    positions of its conformer 0 are overwritten row by row.

    Args:
        rdkit_mol: RDKit molecule that already owns a conformer with id 0.
        pos: Indexable coordinate container whose rows expose ``.tolist()``
            (e.g. a NumPy array or torch tensor); row ``i`` is assigned to
            atom ``i``.

    Returns:
        The newly constructed molecule carrying the updated conformer.
    """
    updated_mol = Chem.Mol(rdkit_mol)
    conf = updated_mol.GetConformer(0)
    for atom_idx, xyz in enumerate(pos):
        conf.SetAtomPosition(atom_idx, xyz.tolist())
    return updated_mol

def get_best_rmsd(probe, ref):
    """Return ``MA.GetBestRMS`` between the hydrogen-stripped ``probe`` and ``ref``.

    Both molecules are passed through ``RemoveHs`` first, so the RMSD is
    evaluated on the molecules without their hydrogens.

    Args:
        probe: RDKit molecule to compare against the reference.
        ref: RDKit reference molecule.

    Returns:
        The value returned by ``rdMolAlign.GetBestRMS`` for the two
        hydrogen-free molecules.
    """
    heavy_probe = RemoveHs(probe)
    heavy_ref = RemoveHs(ref)
    return MA.GetBestRMS(heavy_probe, heavy_ref)

def calculate_molecule_metrics(ref_pos, gen_pos, rdmol, delta=0.5):
    """Compute the recall-style coverage and matching scores for one molecule.

    An RMSD matrix of shape ``(num_refs, num_samples)`` is filled by comparing
    every reference conformer with every generated conformer via
    ``get_best_rmsd``. For each reference the smallest RMSD over all generated
    conformers is kept; the two returned scores are derived from those minima.

    Args:
        ref_pos: Reference coordinates, indexed along axis 0 by conformer.
        gen_pos: Generated coordinates, indexed along axis 0 by conformer.
        rdmol: RDKit molecule used as the template for every conformer.
        delta: RMSD threshold under (or at) which a reference counts as covered.

    Returns:
        ``(cov_r, mat_r)`` where ``cov_r`` is the fraction of references whose
        best RMSD is ``<= delta`` and ``mat_r`` is the mean of the best RMSDs.
        If either ensemble is empty, ``(0.0, inf)`` is returned.
    """
    num_refs, num_samples = ref_pos.shape[0], gen_pos.shape[0]
    if num_refs == 0 or num_samples == 0:
        return 0.0, float('inf')

    # Build each generated-conformer molecule once rather than once per
    # reference. This is safe to reuse: get_best_rmsd works on the copies
    # produced by RemoveHs, so the molecules stored here are never modified.
    gen_mols = [set_rdmol_positions(rdmol, gen_pos[j]) for j in range(num_samples)]

    rmsd_matrix = np.zeros([num_refs, num_samples])
    for i in range(num_refs):
        ref_mol = set_rdmol_positions(rdmol, ref_pos[i])
        for j, gen_mol in enumerate(gen_mols):
            rmsd_matrix[i, j] = get_best_rmsd(gen_mol, ref_mol)

    # Best (smallest) RMSD achieved by any generated conformer for each reference.
    min_rmsd_for_each_ref = rmsd_matrix.min(axis=1)
    mat_r = min_rmsd_for_each_ref.mean()
    cov_r = np.sum(min_rmsd_for_each_ref <= delta) / num_refs
    return cov_r, mat_r

def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Return the per-step beta values of a diffusion noise schedule.

    Args:
        beta_schedule: Schedule name. ``"sigmoid"`` selects the sigmoid-shaped
            schedule; any other value yields a linear ramp.
        beta_start: Lower bound of the schedule.
        beta_end: Upper bound of the schedule.
        num_diffusion_timesteps: Number of entries in the returned array.

    Returns:
        A 1-D NumPy array with ``num_diffusion_timesteps`` entries. For the
        sigmoid schedule, a logistic curve sampled on ``[-6, 6]`` is rescaled
        to span ``beta_end - beta_start`` and shifted by ``beta_start``;
        otherwise the values are evenly spaced from ``beta_start`` to
        ``beta_end`` as float64.
    """
    if beta_schedule != "sigmoid":
        return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)

    grid = np.linspace(-6, 6, num_diffusion_timesteps)
    logistic = 1 / (np.exp(-grid) + 1)
    return logistic * (beta_end - beta_start) + beta_start

# ===================================================================
# 2. Data loading and precomputation
# ===================================================================
print("--- 1. Loading data and performing precomputations ---")
your_results_path = 'checkpoints/qm9_condition/samples/sample_all.pkl'
baseline_results_path = 'checkpoints/subgdiff_baseline/samples_all.pkl'
test_set_path = 'data/GEOM/QM9/test_data_1k.pkl'

try:
    # SECURITY NOTE: pickle.load can execute arbitrary code contained in the
    # file. Only point these paths at pickles produced by trusted runs.
    with open(your_results_path, 'rb') as f:
        your_results = {d.smiles: d for d in pickle.load(f)}
    with open(baseline_results_path, 'rb') as f:
        baseline_results = {d.smiles: d for d in pickle.load(f)}
    with open(test_set_path, 'rb') as f:
        raw_test_list = pickle.load(f)
except FileNotFoundError as e:
    # Downstream analysis checks for df being None and skips itself.
    print(f"Error: could not find data file! {e}")
    df = None
else:
    # Group the test entries by SMILES: every entry contributes one reference
    # conformer, and the rdmol of the first entry seen is kept as template.
    test_set_by_smiles = {}
    for data in raw_test_list:
        entry = test_set_by_smiles.setdefault(
            data.smiles, {'pos_ref': [], 'rdmol': data.rdmol}
        )
        entry['pos_ref'].append(data.pos)
    for entry in test_set_by_smiles.values():
        entry['pos_ref'] = torch.stack(entry['pos_ref'])
    print("Data loaded successfully!")

    # Cumulative product of (1 - beta) over the sigmoid schedule; the
    # second-to-last entry is the value used to scale the noise below.
    betas = get_beta_schedule("sigmoid", beta_start=1.e-7, beta_end=2.e-3, num_diffusion_timesteps=5000)
    alphas = torch.from_numpy((1. - betas).cumprod(axis=0)).float()
    alpha_t = alphas[-2]
    # Loop-invariant factors of the perturbation, computed once.
    noise_scale = (1.0 - alpha_t).sqrt()
    signal_scale = alpha_t.sqrt()

    # Only molecules present in BOTH result sets AND the test set can be
    # scored; without the last intersection a missing SMILES raises KeyError.
    # Sorting makes the iteration (and DataFrame row) order deterministic.
    shared_smiles = your_results.keys() & baseline_results.keys()
    comparable_smiles = sorted(shared_smiles & test_set_by_smiles.keys())
    num_missing = len(shared_smiles) - len(comparable_smiles)
    if num_missing:
        print(f"Warning: skipping {num_missing} molecule(s) without reference conformers in the test set.")

    analysis_data = []
    for smiles in tqdm(comparable_smiles, desc="Pre-calculating all metrics"):
        ref_data = test_set_by_smiles[smiles]
        rdmol = ref_data['rdmol']
        pos_0_ensemble = ref_data['pos_ref']

        # Intrinsic properties of the molecule.
        num_atoms = rdmol.GetNumAtoms()
        num_rot_bonds = Descriptors.NumRotatableBonds(rdmol)

        # GND: for each reference conformer, perturb it with scaled Gaussian
        # noise and record the largest per-atom displacement; average over
        # the ensemble.
        # NOTE(review): the noise is drawn from torch's global RNG without a
        # fixed seed, so gnd differs between runs — seed if reproducibility
        # is required.
        max_displacements = []
        for pos_0 in pos_0_ensemble:
            pos_noise = torch.randn_like(pos_0)
            pos_perturbed = pos_0 + pos_noise * noise_scale / signal_scale
            delta_pos = pos_perturbed - pos_0
            max_displacements.append(torch.norm(delta_pos, dim=1).max().item())
        gnd = np.mean(max_displacements)

        # Performance of both models against the same reference ensemble.
        pos_0_ensemble_np = pos_0_ensemble.numpy()
        your_gen_pos = your_results[smiles].pos_gen.reshape(-1, num_atoms, 3).numpy()
        baseline_gen_pos = baseline_results[smiles].pos_gen.reshape(-1, num_atoms, 3).numpy()
        your_cov_r, your_mat_r = calculate_molecule_metrics(pos_0_ensemble_np, your_gen_pos, rdmol)
        baseline_cov_r, baseline_mat_r = calculate_molecule_metrics(pos_0_ensemble_np, baseline_gen_pos, rdmol)

        analysis_data.append({
            "smiles": smiles, "gnd": gnd, "num_atoms": num_atoms, "num_rot_bonds": num_rot_bonds,
            "your_cov_r": your_cov_r, "baseline_cov_r": baseline_cov_r,
            "your_mat_r": your_mat_r, "baseline_mat_r": baseline_mat_r
        })

    df = pd.DataFrame(analysis_data)
    print("Finished precomputing metrics for all molecules!\n")

# ===================================================================
# 3. Threshold scanning and detecting performance reversal points
# ===================================================================
if df is not None and not df.empty:
    print("--- 3. Performing threshold scan and plotting absolute performance comparison ---")

    # Complexity indicators to scan, each mapped to the x-axis label used for
    # its figure. Both panels of a figure share the same thresholds, so they
    # must share the same x-axis label as well.
    # NOTE(review): 'Drift' is taken to be the display name for the gnd
    # column, as in the original axis label — confirm.
    metric_axis_labels = {'gnd': 'Drift', 'num_atoms': 'Num of Atoms'}

    for metric, x_label in metric_axis_labels.items():
        print(f"\n--- Analyzing metric: {metric} ---")

        # Scan thresholds between the median and the 95th percentile.
        min_thresh = df[metric].quantile(0.50)
        max_thresh = df[metric].quantile(0.95)
        thresholds = np.linspace(min_thresh, max_thresh, 20)

        # Average performance of each model on the subset above each threshold.
        your_avg_cov_list = []
        baseline_avg_cov_list = []
        your_avg_mat_list = []
        baseline_avg_mat_list = []

        for thresh in thresholds:
            # Select all molecules with complexity > current threshold.
            subset = df[df[metric] > thresh]
            # Thresholds ascend, so once a subset is too small every later
            # one is too; stop to avoid averaging over a handful of molecules.
            if len(subset) < 5:
                break

            # cov_r is stored as a fraction in [0, 1]; convert to percent so
            # the plotted values agree with the 'COV-R (%)' axis label.
            your_avg_cov_list.append(100.0 * subset['your_cov_r'].mean())
            baseline_avg_cov_list.append(100.0 * subset['baseline_cov_r'].mean())

            your_avg_mat_list.append(subset['your_mat_r'].mean())
            baseline_avg_mat_list.append(subset['baseline_mat_r'].mean())

        if not your_avg_cov_list:
            # Nothing to draw: even the lowest threshold left fewer than 5 molecules.
            print(f"Not enough molecules above the lowest threshold for {metric}; skipping plot.")
            continue

        valid_thresholds = thresholds[:len(your_avg_cov_list)]

        # Plotting
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 7))

        # --- Plot 1: COV-R ---
        ax1.plot(valid_thresholds, your_avg_cov_list, marker='o', label='SSD-SubGDiff', linewidth=2)
        ax1.plot(valid_thresholds, baseline_avg_cov_list, marker='o', label='SubGDiff', linewidth=2)
        ax1.set_xlabel(x_label, fontsize=24, fontweight="bold")
        ax1.set_ylabel('COV-R (%)', fontsize=24, fontweight="bold")
        ax1.legend(fontsize=18)

        # --- Plot 2: MAT-R ---
        # Same model-to-label mapping (and plotting order) as the COV-R panel:
        # "your" results are SSD-SubGDiff, the baseline is SubGDiff.
        ax2.plot(valid_thresholds, your_avg_mat_list, marker='o', label='SSD-SubGDiff', linewidth=2)
        ax2.plot(valid_thresholds, baseline_avg_mat_list, marker='o', label='SubGDiff', linewidth=2)
        ax2.set_xlabel(x_label, fontsize=24, fontweight="bold")
        ax2.set_ylabel('MAT-R (\u00c5)', fontsize=24, fontweight="bold")
        ax2.legend(fontsize=18)

        # Apply the same bold tick-label styling to both panels.
        for ax in (ax1, ax2):
            for label in ax.get_xticklabels():
                label.set_fontsize(16)
                label.set_fontweight("bold")

            for label in ax.get_yticklabels():
                label.set_fontsize(20)
                label.set_fontweight("bold")

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()
