# PLINDER Sample Visualization

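Interactive notebook for inspecting and visualizing samples generated by the TaPR-Diff pipeline: it loads a saved `samples.pt` file, rebuilds RDKit molecules from the atom-type and coordinate tensors, and draws them in a grid.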

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
from pathlib import Path
import json

# RDKit imports
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, Draw
import matplotlib.pyplot as plt
from IPython.display import display, HTML

print("Libraries loaded successfully")

## 1. Load and Inspect Sample Data

In [None]:
# Load the samples file
samples_file = Path("../runs/plinder_debug/samples.pt")
# map_location="cpu": tensors written from a GPU run would otherwise be
# restored onto CUDA (or fail to load on a CPU-only machine), and the
# .numpy() conversions later in this notebook require CPU tensors.
samples = torch.load(samples_file, map_location="cpu")

print(f"Loaded samples from: {samples_file}")
print(f"\nAvailable keys: {samples.keys()}")

# Extract tensors
X = samples["X"]  # Coordinates [B, Nl, 3]
A = samples["A"]  # Atom types [B, Nl]
B = samples["B"]  # Bond types [B, E]
# "system_id" is optional in the file; fall back to an empty list
system_ids = samples.get("system_id", [])

print("\nShape information:")
print(f"  X (coordinates): {X.shape} - dtype: {X.dtype}")
print(f"  A (atom types):  {A.shape} - dtype: {A.dtype}")
print(f"  B (bonds):       {B.shape} - dtype: {B.dtype}")
print(f"  System IDs:      {len(system_ids)} samples")

## 2. Examine Data Structures and Types

In [None]:
# Inspect sample 0 in detail
sample_idx = 0
# .detach().cpu() first: Tensor.numpy() raises for CUDA tensors and for
# tensors that still require grad.
coords = X[sample_idx].detach().cpu().numpy()  # Convert to numpy [Nl, 3]
atoms = A[sample_idx].detach().cpu().numpy()   # Convert to numpy [Nl]

print(f"Sample {sample_idx} analysis:")
print(f"  Coordinates numpy array shape: {coords.shape}, dtype: {coords.dtype}")
print(f"  Atom types numpy array shape: {atoms.shape}, dtype: {atoms.dtype}")

# Check for NaN or invalid values
print("\n  Coordinates stats:")
print(f"    Min: {coords.min():.3f}, Max: {coords.max():.3f}")
print(f"    NaN count: {np.isnan(coords).sum()}")
print(f"    Inf count: {np.isinf(coords).sum()}")

print("\n  Atom types stats:")
print(f"    Min: {atoms.min()}, Max: {atoms.max()}")
print(f"    Non-zero atoms: {(atoms > 0).sum()} / {len(atoms)}")
print(f"    Unique types: {np.unique(atoms)}")

# Check a single coordinate
print(f"\n  First coordinate: {coords[0]}")
print(f"  Type: {type(coords[0])}, dtype: {coords[0].dtype}")
print(f"  As tuple: {tuple(coords[0].astype(float))}")

## 3. Atom Symbol Mapping

In [None]:
# Lookup table from atom-type index to element symbol (must match the
# training config). Layout of the 64 entries:
#   0-8   : C, N, O, S, P, F, Cl, Br, I
#   9-14  : H, followed by a second copy of C, N, O, S, P
#   15-63 : 'X' placeholder entries
# NOTE(review): index 0 maps to 'C', yet this notebook filters with
# `atoms > 0`, i.e. treats code 0 as padding -- confirm against the
# training config.
ATOM_SYMBOLS = (
    ['C', 'N', 'O', 'S', 'P', 'F', 'Cl', 'Br', 'I']
    + ['H', 'C', 'N', 'O', 'S', 'P']
    + ['X'] * 49
)

print(f"Atom vocabulary size: {len(ATOM_SYMBOLS)}")
print(f"First 20 symbols: {ATOM_SYMBOLS[:20]}")

# Report which non-zero atom codes occur in the currently inspected sample
unique_atoms = np.unique(atoms[atoms > 0])
print(f"\nAtoms in sample {sample_idx}: {unique_atoms}")
print("Symbols:")
for code in map(int, unique_atoms):
    # Codes beyond the table are silently left out of the listing
    if code >= len(ATOM_SYMBOLS):
        continue
    print(f"  {code}: {ATOM_SYMBOLS[code]}")

## 4. Reconstruct Molecules from Tensors

In [None]:
def build_molecule(atom_indices, coords):
    """Build an RDKit molecule from atom-type indices and 3D coordinates.

    Atoms are skipped when their type index is <= 0, falls outside
    ``ATOM_SYMBOLS``, or maps to the placeholder symbol ``'X'``. Connectivity
    is guessed purely from geometry: every pair of kept atoms closer than 1.8
    (in the units of ``coords``) is joined by a single bond.

    Args:
        atom_indices: [Nl] numpy array of atom type indices
        coords: [Nl, 3] numpy array of 3D coordinates

    Returns:
        RDKit Mol object (with one conformer holding the input coordinates),
        or None if fewer than two atoms could be added. The molecule is
        returned even when sanitization fails.
    """
    # Filter valid (non-padding) atoms
    valid = atom_indices > 0
    if not valid.any():
        print("No valid atoms!")
        return None

    valid_idx = np.where(valid)[0]
    print(f"Valid atoms: {len(valid_idx)}/{len(atom_indices)}")

    # Keep the editable RWMol until all bonds and the conformer are in place;
    # the read-only Mol returned by GetMol() has no AddBond().
    rw_mol = Chem.RWMol()

    # (original array index, RDKit atom index) for every atom actually added.
    # Recorded at insertion time so the mapping cannot drift from the molecule
    # if an atom is skipped or fails to be created.
    valid_atoms = []
    for idx in valid_idx:
        atom_code = int(atom_indices[idx])
        if atom_code >= len(ATOM_SYMBOLS):
            # Unknown code: skip it rather than guessing an element
            print(f"Warning: atom code {atom_code} exceeds vocab size")
            continue

        symbol = ATOM_SYMBOLS[atom_code]
        if symbol == 'X':
            print("Skipping padding atom")
            continue

        try:
            mol_idx = rw_mol.AddAtom(Chem.Atom(symbol))
        except Exception as e:
            print(f"Failed to add atom {symbol}: {e}")
            continue
        valid_atoms.append((idx, mol_idx))

    if rw_mol.GetNumAtoms() < 2:
        print(f"Too few atoms: {rw_mol.GetNumAtoms()}")
        return None

    # Add bonds based on distance, visiting each unordered pair once
    for pos, (orig1, mol_i) in enumerate(valid_atoms):
        for orig2, mol_j in valid_atoms[pos + 1:]:
            dist = np.linalg.norm(coords[orig1] - coords[orig2])
            if dist < 1.8:  # Rough bond distance
                rw_mol.AddBond(mol_i, mol_j, Chem.BondType.SINGLE)

    # Set 3D coordinates
    conf = Chem.Conformer(rw_mol.GetNumAtoms())
    for orig, mol_i in valid_atoms:
        x, y, z = (float(v) for v in coords[orig])
        conf.SetAtomPosition(mol_i, (x, y, z))
    rw_mol.AddConformer(conf, assignId=True)

    mol = rw_mol.GetMol()

    # Sanitization often fails for generated molecules; report it but still
    # return the molecule so it can be inspected and drawn.
    try:
        Chem.SanitizeMol(mol)
    except Exception as e:
        print(f"Sanitization failed (continuing with unsanitized molecule): {e}")

    return mol

# Try the builder on the sample currently held in `atoms` / `coords`
print(f"Building molecule for sample {sample_idx}...")
mol = build_molecule(atoms, coords)

if not mol:
    print("Failed to create molecule")
else:
    print("\nMolecule created!")
    print(f"  Atoms: {mol.GetNumAtoms()}")
    print(f"  Bonds: {mol.GetNumBonds()}")
    smi = Chem.MolToSmiles(mol)
    # Only the first 80 characters of the SMILES are shown
    print(f"  SMILES: {smi[:80]}")

## 5. Visualize All Samples

In [None]:
# Build a molecule for every sample in the batch, keeping the successful
# ones (and their SMILES, used as legends when drawing) in parallel lists.
mols = []
smiles_list = []
for i in range(X.shape[0]):
    print(f"\n{'='*60}")
    print(f"Sample {i} ({system_ids[i] if i < len(system_ids) else 'unknown'})")
    print(f"{'='*60}")

    # .detach().cpu() first: Tensor.numpy() raises for CUDA tensors and for
    # tensors that still require grad.
    coords = X[i].detach().cpu().numpy()
    atoms = A[i].detach().cpu().numpy()

    mol = build_molecule(atoms, coords)
    if mol:
        smi = Chem.MolToSmiles(mol)
        smiles_list.append(smi)
        mols.append(mol)
        print("✓ Successfully built molecule")
    else:
        print("✗ Failed to build molecule")

print(f"\n\n{'='*60}")
print(f"Summary: {len(mols)} / {X.shape[0]} molecules built successfully")
print(f"{'='*60}")

In [None]:
# Draw grid of molecules
if mols:
    img = Draw.MolsToGridImage(
        mols,
        molsPerRow=2,
        subImgSize=(400, 400),
        legends=smiles_list,
        returnPNG=False,
    )
    # A bare `img` expression nested inside `if` is not the cell's last
    # top-level expression, so Jupyter would not render it; display explicitly.
    display(img)