In [3]:
# %% [markdown]
# Analysis of Sample Runs
# This notebook loads samples from four runs, computes energies, and generates Ramachandran plots and energy histograms.

# %%
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import mdtraj as md
import torch
from bgflow import XTBEnergy, XTBBridge
from bgflow.utils import as_numpy
from .dataset.ad2_dataset import get_alanine_atom_types, get_alanine_implicit_dataset




ImportError: attempted relative import with no known parent package

In [None]:
# -----------------------------------------------------------------------------
# Auxiliary functions from the main script
# -----------------------------------------------------------------------------

def get_energies(samples, energies_holdout=False,
                 holdout_path="../data/AD2_relaxed_holdout.npy"):
    """
    Compute XTB energies for given samples.

    Coordinates are divided by ``scaling`` (10.0) before being passed to the
    XTB target, and a constant ``energy_offset`` (34600) is added to every
    returned energy. The same transformation is applied to the holdout set.

    Args:
        samples: array-like of flattened configurations, presumably of shape
            (N, 66) = 22 atoms x 3 coordinates (TODO confirm; the /10 scaling
            suggests Angstrom -> nm, also unconfirmed).
        energies_holdout: bool, if True also compute energies for the MD
            holdout set loaded from ``holdout_path``.
        holdout_path: path to the ``.npy`` file with holdout configurations.
            The array is reshaped to ``(-1, samples.shape[1])`` before use.
    Returns:
        energies_np: np.ndarray with one energy per sample (the exact trailing
            shape comes from ``XTBEnergy.energy`` - TODO confirm (N,) vs (N, 1)).
        If ``energies_holdout`` is True, the tuple
        ``(energies_np, energies_holdout_np)`` is returned instead.
    """
    scaling = 10.0          # coordinates (not energies) are divided by this
    energy_offset = 34600   # constant shift added to every returned energy
    temperature = 300

    # Dataset atom-type indices 0..3 -> atomic numbers of H, C, N, O.
    number_dict = {0: 1, 1: 6, 2: 7, 3: 8}
    numbers = np.array([number_dict[int(atom)] for atom in get_alanine_atom_types()])
    target_xtb = XTBEnergy(
        XTBBridge(numbers=numbers, temperature=temperature, solvent="water"),
        two_event_dims=False,
    )

    def _offset_energy(coords):
        # Scale the coordinates, evaluate XTB, then shift the energies.
        return as_numpy(target_xtb.energy(coords / scaling)) + energy_offset

    # torch.from_numpy rejects lists and negatively strided views; a
    # contiguous ndarray is accepted (no copy if it already is one).
    samples = np.ascontiguousarray(samples)
    energies_np = _offset_energy(torch.from_numpy(samples))
    if not energies_holdout:
        return energies_np

    holdout_samples = torch.from_numpy(np.load(holdout_path)).reshape(-1, samples.shape[1])
    return energies_np, _offset_energy(holdout_samples)


def plot_energy_histogram(energies, title="Energy distribution", bins=100, ax=None):
    """
    Plot a density-normalized histogram of ``energies``.

    Args:
        energies: array-like of energies, e.g. the output of ``get_energies``.
        title: axes title.
        bins: number of histogram bins (default 100).
        ax: optional matplotlib Axes to draw into (e.g. one panel of a
            multi-run comparison). When omitted, a new 8x6 figure is created
            and shown immediately; when given, the caller shows the figure.
    """
    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(figsize=(8, 6))
    ax.hist(energies, bins=bins, density=True, alpha=0.7)
    ax.set_xlabel("Energy")
    ax.set_ylabel("Density")
    ax.set_title(title)
    if owns_figure:
        plt.show()


def get_ramachandran_angles(samples):
    """
    Compute (phi, psi) backbone dihedrals for alanine-dipeptide samples.

    Each configuration is labelled 'L' or 'D' from the sign of a triple
    product centred on atom 8. The product is built from atoms 6 and 14 and
    the non-backbone carbon closest to atom 8. 'D' configurations are
    mirrored through the origin (coordinates multiplied by -1) before the
    dihedrals are computed, so every sample is reported in the 'L' frame.

    Args:
        samples: array whose size is a multiple of 22 * 3; it is reshaped to
            (N, 22, 3). Dihedrals are scale-invariant, so the length unit
            does not affect the result.
    Returns:
        np.ndarray of shape (N, 2) from ``md.compute_dihedrals``: column 0 is
        phi (atoms 4-6-8-14), column 1 is psi (atoms 6-8-14-16), in radians.
    """
    phi_idx, psi_idx = [4, 6, 8, 14], [6, 8, 14, 16]

    atom_types = get_alanine_atom_types().copy()
    # Give the backbone atoms unique labels 4..8 so the type-1 (carbon)
    # search below only returns the remaining, non-backbone carbons.
    atom_types[[4, 6, 8, 14, 16]] = np.arange(4, 9)
    carbon_pos = np.where(atom_types == 1)[0]

    coords = samples.reshape(-1, 22, 3)
    n_frames = coords.shape[0]

    # Nearest remaining carbon to atom 8 in each frame (presumably C-beta
    # bonded to C-alpha - TODO confirm). argmin yields exactly one index per
    # frame, so `cb` stays aligned with `coords` even on exact distance ties.
    carbon_distances = np.linalg.norm(coords[:, [8]] - coords[:, carbon_pos], axis=-1)
    cb = coords[np.arange(n_frames), carbon_pos[carbon_distances.argmin(axis=1)]]

    def determine_chirality_batch(batch_coords):
        # Vectors from the first atom of each frame to the other three; the
        # sign of (v1 x v2) . v3 decides the handedness label.
        vectors = batch_coords - batch_coords[:, :1, :]
        normals = np.cross(vectors[:, 1, :], vectors[:, 2, :])
        dots = np.einsum('...i,...i->...', normals, vectors[:, 3, :])
        return np.where(dots > 0, 'L', 'D')

    chirality_input = np.concatenate([coords[:, [8, 6, 14]], cb[:, None, :]], axis=1)
    chirality = determine_chirality_batch(chirality_input)

    # Point inversion flips handedness, mapping 'D' frames onto 'L'.
    mapped = coords.copy()
    mapped[chirality == 'D'] *= -1
    traj = md.Trajectory(mapped, topology=get_alanine_implicit_dataset().system.mdtraj_topology)
    return md.compute_dihedrals(traj, [phi_idx, psi_idx])


def plot_ramachandran(angles, title="Ramachandran plot", bins=100, ax=None):
    """
    Plot a log-scaled 2D histogram of (phi, psi) angles.

    Args:
        angles: array of shape (N, 2) with phi in column 0 and psi in
            column 1, in radians (e.g. from ``get_ramachandran_angles``).
        title: axes title.
        bins: number of bins per axis over [-pi, pi] (default 100).
        ax: optional matplotlib Axes to draw into. When omitted, a new 6x6
            figure is created and shown immediately; when given, the caller
            shows the figure.
    """
    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(figsize=(6, 6))
    # hist2d returns (counts, xedges, yedges, image); only the image is needed
    # for the colorbar. LogNorm leaves empty bins uncoloured.
    *_, image = ax.hist2d(
        angles[:, 0], angles[:, 1], bins=bins,
        norm=LogNorm(), range=[[-np.pi, np.pi], [-np.pi, np.pi]],
    )
    ax.figure.colorbar(image, ax=ax, label="Count")
    ax.set_xlabel(r"$\varphi$ (rad)")
    ax.set_ylabel(r"$\psi$ (rad)")
    ax.set_title(title)
    if owns_figure:
        plt.show()

# %%
# %% [markdown]
# Load samples and analyze
# Only sample files are provided; energies will be computed on the fly.


