# Lightweight Alaninesys Inference

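This notebook provides a minimal implementation for running inference with trained BoltzNCE alanine-system (`alaninesys`) models.
It loads a config and the trained models, then generates samples (`n_samples` × `n_sample_batches`, currently 100 × 1) without any plotting or energy calculations.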

In [1]:
# Report GPU model, driver/CUDA versions and memory usage (informational only).
!nvidia-smi


Wed Jul 30 21:56:16 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.77                 Driver Version: 565.77         CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L40                     Off |   00000000:45:00.0 Off |                    0 |
| N/A   39C    P8             36W /  300W |       1MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import os
import sys

import numpy as np
import torch
import tqdm
import yaml

sys.path.append("./BoltzNCE/")
from BoltzNCE.utils.utils import load_models
from BoltzNCE.dataset.alsys_dataloader import alaninesys_featurizer
from BoltzNCE.models.interpolant import Interpolant


****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************


Due to the on going maintenance burden of keeping command line application
wrappers up to date, we have decided to deprecate and eventually remove these
modules.

We instead now recommend building your command line and invoking it directly
with the subprocess module.


## Load Configuration

In [3]:
# Path to the training config. Set the BOLTZNCE_CONFIG environment variable to
# point at a different file instead of editing this machine-specific default.
config_path = os.environ.get(
    "BOLTZNCE_CONFIG",
    "/net/galaxy/home/koes/jmc530/koes_lab/BoltzNCE/BoltzNCE/saved_models/trained_vector_9_layer_al6_mc.yaml",
)

# Load config. safe_load only builds plain YAML types (no arbitrary objects).
with open(config_path, 'r') as f:
    args = yaml.safe_load(f)

# An empty or non-mapping YAML file would otherwise fail later with an
# unhelpful TypeError on the first args[...] lookup.
if not isinstance(args, dict):
    raise ValueError(
        f"Config {config_path} did not parse to a mapping (got {type(args).__name__})"
    )

print(f"Loaded config from {config_path}")

Loaded config from /net/galaxy/home/koes/jmc530/koes_lab/BoltzNCE/BoltzNCE/saved_models/trained_vector_9_layer_al6_mc.yaml


## Setup Parameters and Topology

In [4]:
# Scaling constant carried over from the original script.
# NOTE(review): not referenced by any other cell in this notebook - confirm it is still needed.
scaling = 30.0
data_path = args['dataloader']['data_path']
split = args['dataloader']['split']

# Featurize the dataset: yields the topology and the initial features, whose
# first axis is treated as the particle axis below.
topology, h_initial = alaninesys_featurizer(data_path, split=split)

# Bond list as an int32 tensor of (atom1_index, atom2_index) pairs.
adj_list = torch.from_numpy(
    np.array(
        [(bond.atom1.index, bond.atom2.index) for bond in topology.bonds],
        dtype=np.int32,
    )
)

# Integer element code per atom, read from the first character of the atom
# name (e.g. "CA" -> "C"). Unknown leading characters raise KeyError.
atom_dict = {"C": 0, "H": 1, "N": 2, "O": 3, "S": 4}
atom_types = torch.from_numpy(
    np.array([atom_dict[atom.name[0]] for atom in topology.atoms])
)

# Three coordinates per particle; the flattened size is what the model sees as `dim`.
n_particles = h_initial.shape[0]
dim = 3 * n_particles
args['dim'] = dim

print(f"Topology loaded with {n_particles} particles, dimension: {dim}")

Topology loaded with 63 particles, dimension: 189


## Update Interpolant Arguments

In [5]:
def update_interpolant_args(args):
    """Copy integration settings into the ``args['interpolant']`` sub-config.

    Reads ``rtol``, ``atol``, ``tmin`` and ``dim`` from the top level of
    ``args`` and writes them, plus ``num_particles = dim // 3``, into
    ``args['interpolant']``. The dict is modified in place and also returned.

    Raises:
        KeyError: if any of those top-level keys or ``'interpolant'`` is missing.
    """
    for key in ('rtol', 'atol', 'tmin', 'dim'):
        args['interpolant'][key] = args[key]
    args['interpolant']['num_particles'] = args['dim'] // 3
    return args

# ODE integration tolerances used for sampling.
# NOTE(review): rtol/atol are overwritten unconditionally, so any values in the
# loaded config are ignored; only tmin falls back to the config (default 0.0).
# Switch to args.setdefault(...) if config-provided tolerances should win.
INFERENCE_RTOL = 1e-3
INFERENCE_ATOL = 1e-3
args['rtol'] = INFERENCE_RTOL
args['atol'] = INFERENCE_ATOL
args['tmin'] = args.get('tmin', 0.0)

# Push rtol/atol/tmin/dim into args['interpolant'] (presumably consumed by
# load_models in the next cell).
args = update_interpolant_args(args)

print("Interpolant arguments updated")

Interpolant arguments updated


## Load Models

In [6]:
# Only request the potential model when the config says this is a
# 'potential'-type run; load_models may then hand back None for it.
use_potential = args['model_type'] == 'potential'
potential_model, vector_field, interpolant_obj = load_models(
    args,
    h_initial=h_initial,
    potential=use_potential,
)

# Inference only: put every loaded network into eval mode.
if potential_model is not None:
    potential_model.eval()
vector_field.eval()

print("Models loaded and set to evaluation mode")

Total number of parameters in vector field model: 1209431
Loaded vector field model from /net/galaxy/home/koes/rishal/nce/BoltzNCE/saved_models/trained_vector_9_layer_al6_mc.pt
Models loaded and set to evaluation mode


## Generate Samples

In [9]:
 %%time
def gen_samples(n_samples, n_sample_batches, interpolant_obj, integral_type='ode'):
    """Draw samples by integrating the interpolant's ODE in batches.

    While sampling, ``interpolant_obj.interpolant_type`` is temporarily set to
    ``interpolant_obj.integration_interpolant``. The original value is always
    restored on exit, including when integration raises or is interrupted
    (e.g. KeyboardInterrupt), so the object is never left in integration mode.

    Args:
        n_samples: Number of samples drawn per batch.
        n_sample_batches: Number of batches; total = n_samples * n_sample_batches.
        interpolant_obj: Object exposing ``ode_integral(n)``,
            ``ode_divergence_integral(n)``, ``interpolant_type``,
            ``integration_interpolant`` and ``dim``.
        integral_type: ``'ode'`` (samples only) or ``'ode_divergence'``
            (samples plus the log-prob terms returned by
            ``ode_divergence_integral``).

    Returns:
        samples_np: float64 array of shape ``(n_total, interpolant_obj.dim)``.
        dlogp_all: for ``'ode_divergence'``, the per-batch log-prob arrays
            concatenated along axis 0; for ``'ode'``, float32 zeros of shape
            ``(n_total, 1)`` (uniform weights).

    Raises:
        ValueError: if ``integral_type`` is not recognized.
    """
    if integral_type not in ('ode', 'ode_divergence'):
        raise ValueError("integral_type not recognized")

    sample_batches = []
    dlogp_batches = []

    # Sampling must run with the integration interpolant; remember the current
    # type so it can be restored no matter how the loop exits.
    original_type = interpolant_obj.interpolant_type
    interpolant_obj.interpolant_type = interpolant_obj.integration_interpolant
    try:
        for _ in tqdm.tqdm(range(n_sample_batches), desc="Generating samples"):
            if integral_type == 'ode':
                samples = interpolant_obj.ode_integral(n_samples)
            else:
                samples, logp_samples = interpolant_obj.ode_divergence_integral(n_samples)
                dlogp_batches.append(logp_samples.detach().cpu().numpy())
            # Flatten each batch; the final reshape restores (n, dim) rows.
            sample_batches.append(samples.detach().cpu().numpy().ravel())
    finally:
        interpolant_obj.interpolant_type = original_type

    # Concatenate once (instead of np.append per batch, which is quadratic) and
    # keep the float64 dtype the previous np.append-based accumulation produced.
    if sample_batches:
        flat = np.concatenate(sample_batches).astype(np.float64, copy=False)
    else:
        flat = np.empty(0, dtype=np.float64)
    samples_np = flat.reshape(-1, interpolant_obj.dim)

    if dlogp_batches:
        dlogp_all = np.concatenate(dlogp_batches, axis=0)
    else:
        # Uniform (zero) log-weights, one row per sample. This must use the row
        # count after reshaping: the flat length is n_total * dim.
        dlogp_all = np.zeros((samples_np.shape[0], 1), dtype=np.float32)

    return samples_np, dlogp_all

CPU times: user 4 μs, sys: 3 μs, total: 7 μs
Wall time: 12.6 μs


In [8]:
 %%time
# Inference parameters (matching the original script)
n_samples = 100  # samples per batch; total = n_samples * n_sample_batches
n_sample_batches = 1  # number of batches (as in original script)
divergence = True  # also return log-prob terms via the divergence integral (as in original script)
seed = None  # set to an int for reproducible sampling; None keeps the RNG state untouched

# 'ode_divergence' returns samples plus log-prob terms; 'ode' returns samples only.
integral_type = 'ode_divergence' if divergence else 'ode'

if seed is not None:
    # NOTE(review): assumes sampling draws from torch/numpy global RNGs - confirm.
    torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA
    np.random.seed(seed)

print("########## generating initial samples")
print(
    f"Generating {n_samples * n_sample_batches} samples in {n_sample_batches} batches "
    f"using {integral_type}..."
)

# Generate samples (matching original script call)
samples_np, dlogf_np = gen_samples(
    n_samples=n_samples,
    n_sample_batches=n_sample_batches,
    interpolant_obj=interpolant_obj,
    integral_type=integral_type,
)

print(f"Generated samples shape: {samples_np.shape}")
print(f"Divergence info shape: {dlogf_np.shape}")
print("Inference completed successfully!")

########## generating initial samples
Generating 100 samples in 1 batches using ode_divergence...


Generating samples:   0%|          | 0/1 [00:10<?, ?it/s]


KeyboardInterrupt: 

## Save Results (Optional)

In [None]:
# Optional: save the generated samples and their log-prob terms to disk.
save_results = False  # Set to True if you want to save

if save_results:
    save_prefix = './generated/'
    # np.save does not create missing parent directories.
    os.makedirs(save_prefix, exist_ok=True)
    np.save(f'{save_prefix}samples.npy', samples_np)
    # The generation cell binds the log-prob terms to `dlogf_np`.
    np.save(f'{save_prefix}dlogp.npy', dlogf_np)
    print(f"Results saved to {save_prefix}")
else:
    print("Results not saved (set save_results=True to save)")