In [None]:
import dataclasses
import io
import math
import os
import pathlib
import random
import time
import traceback
from typing import Any

import jax
import jax.numpy as jnp
import jax.random as jrandom
import jaxtyping
import matplotlib.pyplot as plt
import numpy as np
import optax
import orbax.checkpoint as ocp
import pandas as pd
import py3Dmol
from flax import nnx
from tqdm.auto import tqdm

In [None]:
# One unbatched example: feature name -> NumPy array.
NpExampleType = dict[str, np.ndarray]
# Loss history as (step, loss) pairs, in the form consumed by plot_loss_curves.
LossValuesType = list[tuple[int, float]]
# One batch of examples: feature name -> JAX array.
BatchType = dict[str, jax.Array]

# Model choice

In [None]:
# @title Choose model type
# @markdown Default is a simple Point Cloud model.
# @markdown <font color='red'>**The simple EGNN model needs more work.**</font>
# NOTE: the `# @param` annotation is kept on the assignment line so that Colab
# can attach the dropdown to `model_type`.
model_type = "Point Cloud Protein Diffusion Model"  # @param ["Point Cloud Protein Diffusion Model","EGNN Protein Diffusion Model"]
if model_type == "EGNN Protein Diffusion Model":
    # The EGNN variant is not usable yet; the intended construction is kept for reference:
    # model = EGNNProteinDiffusionModel(config, rngs=rngs)
    raise NotImplementedError("EGNN model not working yet.")

# Configuration

**Note on data features**

*  'atom_positions':

    Shape = [batch_size, number of residues or amino acids, number of atom types, coordinates].

    There are 37 possible different atom types in proteins. We want to restrict
    ourselves to the backbone atoms ['N', 'CA', 'C', 'O'], which are at index 0, 1,
    2 and 4 of axis 2 of the 'atom_positions' array. The positions of all other atom types are set to 0. For each protein in the batch, the positions are centered on the center of mass of the CA atoms. If the protein chain is shorter than 'max_seq_length', the sequence dimension (= axis 1) is padded with zeros up to 'max_seq_length'. If the protein chain is longer than 'max_seq_length', it is truncated to its first 'max_seq_length' residues.

* 'atom_mask':

    Denotes which atoms have a known position. Non-backbone atoms are masked, i.e. set to 0 in 'atom_mask' (see above).

    Shape = [batch_size, number of residues or amino acids, number of atom types].

* 'residue_index':

    Shape = [batch_size, number of residues or amino acids].

    This is simply arange(seq_length): the position of each residue (i.e. amino acid) along the protein chain, increasing linearly from 0.

In [None]:
# @title Configuration and constants

# @markdown --- Configuration Todos--- \\
# @markdown - Use a better way to make configs. Potential needs
# @markdown   -   Allowing multiple models with model specific configs


@dataclasses.dataclass
class TrainingConfig:
    """Settings for data handling, optimization, the diffusion schedule and the model.

    Every field has a default, so `TrainingConfig()` yields a complete configuration.
    """

    backbone_atom_indices: list[int] = dataclasses.field(
        default_factory=lambda: [0, 1, 2, 4]
    )  # N, CA, C, O indices in atom_types
    max_seq_length: int = 128  # Shorter sequence length for faster training initially
    batch_size: int = 128  # A good compromise for T4, Adjust based on GPU memory
    eval_batch_size: int = 32  # Batch size used for evaluation
    num_epochs: int = 100  # Reduced epochs for quick testing
    # Gradient accumulation
    gradient_accumulation_steps: int = 4  # Number of steps to accumulate gradients over
    # Whether to use effective batch size for scheduler calculations
    use_effective_batch_size: bool = True
    learning_rate: float = 1e-4  # Adjust later, smaller value for debugging loss
    warmup_steps: int = (
        500  # Number of warmup steps for learning rate, around 10-20% of total steps
    )
    # TODO: debug this, sometimes crashes in colab if the value is larger than one
    num_workers: int = 0  # For DataLoader,
    log_freq: int = 20  # Log training loss more frequently, adjust later
    eval_freq: int = 200  # Evaluate and sample more frequently for debugging
    save_freq: int = 100  # Save model checkpoint frequently
    output_dir: str = "protein_diffusion_jax_output"
    seed: int = 42  # for reproducibility

    # Diffusion specific
    timesteps: int = 200  # Number of diffusion steps (reduced for speed and debugging)
    beta_schedule: str = "linear"  # Noise schedule type ('linear' or 'cosine')

    # Model specific
    model_dim: int = 128  # Embedding dimension in the Transformer
    num_layers: int = 8  # Number of Transformer layers
    num_heads: int = 8  # Number of attention heads


# --- Data Processing and Loading ---
# Constants for protein structure representation
# Atom names of the 37-slot per-residue layout; the list position is the atom index.
atom_types = [
    "N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG", "CD",
    "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD", "CE", "CE1", "CE2", "CE3",
    "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2", "OH", "CZ", "CZ2",
    "CZ3", "NZ", "OXT",
]  # fmt: skip

# One-letter residue codes; the list position is the residue type index.
restypes = list("ARNDCQEGHILKMFPSTWYV")

restype_order = dict(zip(restypes, range(len(restypes))))
restype_num = len(restypes)  # should be 20.

# One-letter code -> three-letter residue name, in the same order as `restypes`.
restype_1to3 = dict(
    zip(
        restypes,
        [
            "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
            "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
        ],  # fmt: skip
    )
)

# Backbone atoms and their positions inside `atom_types`.
bb_atom_types = ["N", "CA", "C", "O"]  # indices: 0, 1, 2, 4
bb_indices = sorted(atom_types.index(bb_name) for bb_name in bb_atom_types)


@dataclasses.dataclass(frozen=True)
class Protein:
    """Protein structure representation.

    Immutable container of per-residue arrays. `num_res` is the number of
    residues and `num_atom_type` the number of atom slots stored per residue.
    """

    # Cartesian coordinates of every atom slot.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]
    # Residue type of every residue.
    aatype: np.ndarray  # [num_res]
    # Per-atom mask; `to_pdb` skips atoms whose mask value is below 0.5.
    atom_mask: np.ndarray  # [num_res, num_atom_type]
    # Residue number of every residue (written to the PDB residue-number column by `to_pdb`).
    residue_index: np.ndarray  # [num_res]
    # Chain id index of every residue.
    chain_index: np.ndarray  # [num_res]
    # Per-atom B-factor (written to the PDB temperature-factor column by `to_pdb`).
    b_factors: np.ndarray  # [num_res, num_atom_type]

# Utilities

In [None]:
# @title Utilities


def calculate_radius_of_gyration(coords: jax.Array, mask: jax.Array) -> float | jax.Array:
    """Calculates Radius of Gyration (Rg). coords=[L, 4, 3], mask=[L, 4]"""
    # Use CA atoms for Rg calculation (index 1)
    ca_coords = coords[:, 1, :]  # [L, 3]
    ca_mask = mask[:, 1]  # [L]

    valid_ca_coords = ca_coords[ca_mask > 0.5]  # [N_valid, 3]
    if valid_ca_coords.shape[0] < 2:
        return 0.0  # Not enough atoms to calculate Rg

    center_of_mass = jnp.mean(valid_ca_coords, axis=0)  # [3]
    diff_sq = (valid_ca_coords - center_of_mass) ** 2  # [N_valid, 3]
    rg_sq = jnp.mean(jnp.sum(diff_sq, axis=1))  # Scalar
    return jnp.sqrt(rg_sq)


def linear_beta_schedule(timesteps: int) -> jax.Array:
    """Linearly spaced betas (DDPM schedule), rescaled for the number of steps.

    The end points 1e-4 and 0.02 are defined for 1000 steps and are multiplied
    by `1000 / timesteps` for any other step count.

    Args:
        timesteps: Number of diffusion steps.

    Returns:
        Array of `timesteps` beta values.
    """
    rescale = 1000 / timesteps
    first_beta, last_beta = rescale * 0.0001, rescale * 0.02
    return jnp.linspace(first_beta, last_beta, timesteps)


def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> jax.Array:
    """Cosine noise schedule from the Improved DDPM paper.

    Args:
        timesteps: Number of diffusion steps.
        s: Small offset added to the normalized time before taking the cosine.

    Returns:
        Array of `timesteps` beta values, clipped to the range [0, 0.999].
    """
    grid = jnp.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * math.pi * 0.5
    cumulative_alpha = jnp.cos(phase) ** 2
    # Normalize so that the cumulative product starts at exactly 1.
    cumulative_alpha = cumulative_alpha / cumulative_alpha[0]
    step_ratio = cumulative_alpha[1:] / cumulative_alpha[:-1]
    return jnp.clip(1 - step_ratio, 0, 0.999)


def get_diffusion_variables(beta_schedule_type: str, timesteps: int) -> dict[str, jax.Array]:
    """Precomputes the per-timestep constants of the diffusion process.

    Args:
        beta_schedule_type: Either "linear" or "cosine".
        timesteps: Number of diffusion steps.

    Returns:
        Dictionary mapping the name of each schedule quantity to an array with
        one entry per timestep.

    Raises:
        ValueError: If `beta_schedule_type` is not a known schedule.
    """
    if beta_schedule_type == "linear":
        make_betas = linear_beta_schedule
    elif beta_schedule_type == "cosine":
        make_betas = cosine_beta_schedule
    else:
        raise ValueError(f"Unknown beta schedule: {beta_schedule_type}")
    betas = make_betas(timesteps)

    alphas = 1.0 - betas
    cum_alpha = jnp.cumprod(alphas)
    # Cumulative product shifted by one step, with 1.0 for the step before the first.
    cum_alpha_prev = jnp.pad(cum_alpha[:-1], (1, 0), constant_values=1.0)
    one_minus_cum_alpha = 1.0 - cum_alpha

    # Posterior q(x_{t-1} | x_t, x_0); the variance is floored at 1e-20 so its log is finite.
    post_var = jnp.clip(betas * (1.0 - cum_alpha_prev) / one_minus_cum_alpha, 1e-20)

    schedule = {}
    schedule["betas"] = betas
    schedule["alphas_cumprod"] = cum_alpha
    # Quantities for the forward process q(x_t | x_0) and for recovering x_0.
    schedule["sqrt_alphas_cumprod"] = jnp.sqrt(cum_alpha)
    schedule["sqrt_one_minus_alphas_cumprod"] = jnp.sqrt(one_minus_cum_alpha)
    schedule["log_one_minus_alphas_cumprod"] = jnp.log(one_minus_cum_alpha)
    schedule["sqrt_recip_alphas_cumprod"] = jnp.sqrt(1.0 / cum_alpha)
    schedule["sqrt_recipm1_alphas_cumprod"] = jnp.sqrt(1.0 / cum_alpha - 1)
    schedule["sqrt_recip_alphas"] = jnp.sqrt(1.0 / alphas)
    # Posterior statistics.
    schedule["posterior_variance"] = post_var
    schedule["posterior_log_variance_clipped"] = jnp.log(post_var)
    schedule["posterior_mean_coef1"] = betas * jnp.sqrt(cum_alpha_prev) / one_minus_cum_alpha
    schedule["posterior_mean_coef2"] = (
        (1.0 - cum_alpha_prev) * jnp.sqrt(alphas) / one_minus_cum_alpha
    )
    return schedule


# Extract backbone coordinates and mask
def select_backbone(batch: BatchType, backbone_atom_indices: list[int]):
    """Restricts the atom axis of a batch to the given backbone atoms.

    The entries "positions" ([B, L, 37, 3] -> [B, L, 4, 3]) and "mask"
    ([B, L, 37] -> [B, L, 4]) are replaced inside `batch` itself.

    Args:
        batch: Batch dictionary holding "positions" and "mask".
        backbone_atom_indices: Indices to keep along the atom axis (axis 2).

    Returns:
        The same `batch` object, modified in place.
    """
    every = slice(None)
    atom_axis_selection = (
        ("positions", (every, every, backbone_atom_indices, every)),
        ("mask", (every, every, backbone_atom_indices)),
    )
    for feature_name, selector in atom_axis_selection:
        batch[feature_name] = batch[feature_name][selector]
    return batch


def denormalize_coords(coords: jax.Array, scale: jax.Array) -> jax.Array:
    """Maps normalized coordinates back to the original scale.

    Args:
        coords: Normalized coordinates.
        scale: Factor the coordinates are multiplied by.

    Returns:
        `coords` multiplied by `scale`.
    """
    rescaled = coords * scale
    return rescaled


def calculate_stats(loader, backbone_atom_indices: list[int], max_batches: int = 100) -> jax.Array:
    """Estimates a single coordinate scale (std dev) used for normalization.

    Backbone coordinates of up to `max_batches` batches are collected, and the
    per-axis standard deviations are pooled into one scalar
    (`sqrt(mean(std ** 2))`). The per-axis mean is only printed.

    Args:
        loader: Iterable of batches accepted by `select_backbone`.
        backbone_atom_indices: Atom indices forwarded to `select_backbone`.
        max_batches: Maximum number of batches to consume.

    Returns:
        Scalar coordinate scale.

    Raises:
        ValueError: If no batch contains a coordinate with mask > 0.5.
    """
    valid_coords = []
    for batch_idx, batch in enumerate(tqdm(loader)):
        if batch_idx >= max_batches:
            break
        batch = select_backbone(batch, backbone_atom_indices)
        # Read the entries that select_backbone has just restricted to the
        # backbone atoms ("positions" / "mask").
        coords = batch["positions"]  # B, L, 4, 3
        mask = batch["mask"]  # B, L, 4

        # Keep only coordinates of atoms that exist.
        masked_coords = coords[mask > 0.5]
        if masked_coords.size > 0:
            valid_coords.append(masked_coords.reshape(-1, 3))

    if not valid_coords:
        raise ValueError("No valid coordinates found to calculate statistics.")

    all_coords_concat = jnp.concatenate(valid_coords, axis=0)
    mean = jnp.mean(all_coords_concat, axis=0)  # Should be close to [0,0,0] due to centering
    print(f"Calculated mean: {mean}")
    std = jnp.std(all_coords_concat, axis=0)
    # Use a single std dev for all coordinates for simplicity
    coord_scale = jnp.sqrt(jnp.mean(std**2))
    print(f"Calculated coordinate scale (std dev): {coord_scale}")
    return coord_scale


def calculate_radius_of_gyration_distribution(
    batch: BatchType, coord_scale: jax.Array, config: TrainingConfig
) -> list[float]:
    """Collects the radius of gyration of every structure in a batch.

    The batch is reduced to backbone atoms, its normalized coordinates are
    denormalized with `coord_scale`, and Rg is computed per structure.

    Args:
        batch: Batch with normalized "positions" and "mask".
        coord_scale: Scale used to denormalize the coordinates.
        config: Configuration providing `backbone_atom_indices`.

    Returns:
        The strictly positive Rg values, in batch order.
    """
    backbone = select_backbone(batch, config.backbone_atom_indices)
    masks = backbone["mask"]
    # Rg must be measured on denormalized coordinates.
    coords = denormalize_coords(backbone["positions"], coord_scale)

    rg_per_structure = (
        calculate_radius_of_gyration(coords[k], masks[k]) for k in range(coords.shape[0])
    )
    # A value of 0 marks a structure without enough CA atoms; drop it.
    return [rg for rg in rg_per_structure if rg > 0]

# Visualization

In [None]:
# @title Visualization


def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
    chain_end = "TER"
    return f"{chain_end:<6}{atom_index:>5}      {end_resname:>3} {chain_name:>1}{residue_index:>4}"


def to_pdb(prot: Protein) -> str:
    """Converts a backbone-only `Protein` instance to a PDB string.

    Only the atoms N, CA, C and O are written; they are paired by position with
    the first four atom slots of `prot.atom_positions`, `prot.atom_mask` and
    `prot.b_factors`. Atoms with a mask value below 0.5 are skipped. The residue
    type is not used: every residue is written as the placeholder "GLY".

    Args:
        prot: Structure to serialize.

    Returns:
        PDB text containing one MODEL with ATOM and TER records, terminated by
        ENDMDL / END. All lines are padded to 80 characters, except in the
        zero-residue case where only the bare MODEL / ENDMDL / END lines are returned.
    """
    pdb_atom_types = ["N", "CA", "C", "O"]
    # Residue name written for every residue. It is used directly as a
    # three-letter name; passing it through the one-letter -> three-letter
    # lookup would turn it into "UNK".
    placeholder_res_name = "GLY"

    PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
    PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)

    atom_mask = prot.atom_mask
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(np.int32)
    chain_index = prot.chain_index.astype(np.int32)
    b_factors = prot.b_factors

    pdb_lines = ["MODEL        1"]

    num_residues = len(prot.aatype)
    if num_residues == 0:
        print("Warning: No residues to write to PDB.")
        pdb_lines.append("ENDMDL")
        pdb_lines.append("END")
        return "\n".join(pdb_lines) + "\n"

    # Map chain indices to one-character chain IDs, wrapping around when there
    # are more chains than available characters.
    chain_ids = {idx: PDB_CHAIN_IDS[idx % PDB_MAX_CHAINS] for idx in np.unique(chain_index)}

    atom_index = 1
    last_chain_index = chain_index[0] if len(chain_index) > 0 else 0

    for i in range(num_residues):
        current_chain_idx = chain_index[i]
        # Close the previous chain with a TER record when the chain changes.
        if i > 0 and last_chain_index != current_chain_idx and (i - 1) < len(residue_index):
            pdb_lines.append(
                _chain_end(
                    atom_index,
                    placeholder_res_name,
                    chain_ids.get(last_chain_index, "A"),
                    residue_index[i - 1],
                )
            )
            atom_index += 1  # The TER record consumes a serial number.

        last_chain_index = current_chain_idx
        chain_id = chain_ids.get(current_chain_idx, "A")

        for atom_name, pos, mask, b_factor in zip(
            pdb_atom_types, atom_positions[i], atom_mask[i], b_factors[i]
        ):
            if mask < 0.5:
                continue

            record_type = "ATOM"
            # Atom names shorter than four characters start in the second name column.
            name = atom_name.ljust(4) if len(atom_name) >= 4 else f" {atom_name}".ljust(4)

            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            element = atom_name.strip()[0] if atom_name.strip() else " "
            charge = ""

            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{placeholder_res_name:>3} {chain_id:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1}   "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)
            atom_index += 1

    # Close the final chain.
    if (num_residues - 1) < len(residue_index):
        pdb_lines.append(
            _chain_end(
                atom_index,
                placeholder_res_name,
                chain_ids.get(last_chain_index, "A"),
                residue_index[-1],
            )
        )

    pdb_lines.append("ENDMDL")
    pdb_lines.append("END")

    # Pad all lines to 80 characters
    pdb_lines = [line.ljust(80) for line in pdb_lines]
    return "\n".join(pdb_lines) + "\n"


def visualize_pdb(pdb_string: str, width=600, height=400):
    """Renders a PDB string as a spectrum-colored cartoon with py3Dmol.

    Args:
        pdb_string: PDB text; an empty value only prints a message.
        width: Viewer width passed to py3Dmol.
        height: Viewer height passed to py3Dmol.
    """
    if pdb_string:
        viewer = py3Dmol.view(width=width, height=height)
        viewer.addModel(pdb_string, "pdb")
        cartoon_style = {"cartoon": {"color": "spectrum"}}
        viewer.setStyle(cartoon_style)
        viewer.zoomTo()
        viewer.show()
    else:
        print("Cannot visualize empty PDB string.")


def flat_coords_to_protein(
    flat_coords: np.ndarray,
    flat_mask: np.ndarray,
    target_seq_len: int,
    num_bb_atoms: int = 4,
) -> Protein:
    """Converts flattened coordinates back to Protein object format for PDB saving.

    Args:
        flat_coords: Point coordinates, [target_seq_len * num_bb_atoms, 3].
            NumPy and JAX arrays are both accepted.
        flat_mask: Point mask, [target_seq_len * num_bb_atoms].
        target_seq_len: Number of residues the flat arrays are reshaped to.
        num_bb_atoms: Number of backbone atoms per residue.

    Returns:
        A single-chain `Protein` with dummy residue types and B-factors of 50.
        Its length is the number of residues whose CA mask exceeds 0.5, and the
        first that many residues are kept (this presumably assumes that valid
        residues form a prefix of the sequence). If no residue is valid, one
        residue with an all-ones mask is returned.

    Note:
        The inputs are never modified.
    """
    # Always work on host NumPy arrays. np.asarray accepts JAX arrays directly,
    # so no attribute probing of the array type is needed.
    flat_coords = np.asarray(flat_coords)
    flat_mask = np.asarray(flat_mask)

    expected_num_res = target_seq_len
    expected_total_points = expected_num_res * num_bb_atoms

    # Handle potential size mismatch
    num_total_points = flat_coords.shape[0]
    if num_total_points != expected_total_points:
        print(
            f"Warning: flat_coords/flat_mask length ({num_total_points}) doesn't match"
            f" target_seq_len * num_bb_atoms ({expected_total_points}). Truncating or padding."
        )

        if num_total_points > expected_total_points:
            # Too long: drop the surplus points.
            flat_coords = flat_coords[:expected_total_points]
            flat_mask = flat_mask[:expected_total_points]
        else:
            # Too short: pad with zero coordinates / zero mask.
            pad_len = expected_total_points - num_total_points
            flat_coords = np.pad(flat_coords, ((0, pad_len), (0, 0)), mode="constant")
            flat_mask = np.pad(flat_mask, ((0, pad_len),), mode="constant")

    # Reshape back to [N_res, num_bb_atoms, 3]
    atom_positions_bb = flat_coords.reshape(expected_num_res, num_bb_atoms, 3)
    atom_mask_bb = flat_mask.reshape(expected_num_res, num_bb_atoms)

    # Determine actual length based on CA mask (index 1 within bb atoms)
    ca_mask = atom_mask_bb[:, 1]
    actual_len = int(np.sum(ca_mask > 0.5))

    if actual_len == 0:
        print(
            "Warning: Generated mask suggests zero length within target_seq_len."
            " Creating PDB with length 1."
        )
        actual_len = 1
        atom_positions_bb = atom_positions_bb[:1]
        # Build a fresh all-ones mask instead of writing into a view of the
        # caller's (possibly read-only) array.
        atom_mask_bb = np.ones_like(atom_mask_bb[:1])
    else:
        # Truncate to the actual determined length
        atom_positions_bb = atom_positions_bb[:actual_len]
        atom_mask_bb = atom_mask_bb[:actual_len]

    # Create dummy data needed for Protein object
    aatype = np.zeros(actual_len, dtype=np.int32)  # Dummy aatype
    residue_index = np.arange(1, actual_len + 1)  # PDB standard 1-based indexing
    chain_index = np.zeros(actual_len, dtype=np.int32)  # Single chain 'A'
    b_factors = np.ones((actual_len, num_bb_atoms)) * 50.0  # Dummy B-factors

    return Protein(
        atom_positions=atom_positions_bb,
        aatype=aatype,
        atom_mask=atom_mask_bb,
        residue_index=residue_index,
        chain_index=chain_index,
        b_factors=b_factors,
    )


def plot_loss_curves(
    train_losses: LossValuesType,
    val_losses: LossValuesType,
    plot_filename: str | None = None,
    current_step: int | None = None,
):
    """
    Plots training and validation loss curves.

    Args:
        train_losses: list of tuples (step, loss).
        val_losses: list of tuples (step, loss).
        plot_filename: If given, the figure is also saved to this path.
        current_step: The current global step, used for title. Defaults to the
            last training step when omitted.
    """
    fig = plt.figure(figsize=(6, 3))
    has_curves = False

    # Plot training loss
    if train_losses:
        train_steps, train_vals = zip(*train_losses)
        # Compare against None so that an explicit step 0 is kept.
        if current_step is None:
            current_step = train_steps[-1]
        plt.plot(train_steps, train_vals, label="Training Loss", alpha=0.8)
        has_curves = True

    # Plot validation loss
    if val_losses:
        val_steps, val_vals = zip(*val_losses)
        # Use markers for validation points as they are less frequent
        plt.plot(val_steps, val_vals, label="Validation Loss", marker="o", linestyle="--")
        has_curves = True

    plt.xlabel("Global Step")
    plt.ylabel("Loss")
    plt.title(f"Loss Curves up to Step {current_step}")
    if has_curves:
        # A legend without any labeled curve only produces a warning.
        plt.legend()
    plt.grid(True, which="both", linestyle="--", linewidth=0.5)
    plt.yscale("log")  # Log scale often helpful for losses
    plt.tight_layout()

    # Save before showing: plt.show() may release the figure, after which
    # saving would write an empty image.
    if plot_filename:
        try:
            fig.savefig(plot_filename)
            print(f"Saved loss plot to {plot_filename}")
        except Exception as e:
            print(f"Error saving loss plot: {e}")

    # Show plot if in interactive environment
    try:
        plt.show()
    except Exception as e:
        print(f"Error showing plot: {e}")


def plot_radius_of_gyration(rg_train_hist: list[float], rg_samples_hist: list[float]):
    """Plots overlaid Rg histograms of training structures and generated samples.

    Both histograms share 30 bin edges spanning the combined value range and are
    normalized to densities. If either list is empty, a placeholder text is
    drawn instead.

    Args:
        rg_train_hist: Rg values of training-set structures.
        rg_samples_hist: Rg values of generated samples.

    Returns:
        `io.BytesIO` buffer, rewound to the start, holding the figure as PNG.
    """
    # A 1x1 grid yields a single Axes object (not an array), so it is used directly.
    fig, ax = plt.subplots(1, 1, figsize=(6, 3))

    # Rg Histogram plot
    if rg_train_hist and rg_samples_hist:
        all_values = rg_train_hist + rg_samples_hist
        bins = np.linspace(min(all_values), max(all_values), 30)
        ax.hist(rg_train_hist, bins=bins, alpha=0.6, label="Train Set Rg", density=True)
        ax.hist(
            rg_samples_hist,
            bins=bins,
            alpha=0.6,
            label="Generated Samples Rg",
            density=True,
        )
        ax.set_xlabel("Radius of Gyration (Angstrom)")
        ax.set_ylabel("Density")
        ax.set_title("Rg Distribution")
        ax.legend()
        ax.grid(True)
    else:
        ax.text(
            0.5,
            0.5,
            "Rg data not available yet",
            horizontalalignment="center",
            verticalalignment="center",
        )
        ax.set_title("Rg Distribution")

    fig.tight_layout()

    # Save plot to buffer (optional, useful if running non-interactively)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    plt.show()  # Display the plot
    return buf

# Data

In [None]:
# @title Data
# @markdown TODOs: \\
# @markdown - Migrate to [grain](https://github.com/google/grain) for a full JAX/Flax implementation
# @markdown -


def make_np_example(coords_dict: dict[str, np.ndarray]) -> NpExampleType | None:
    """Make a dictionary of non-batched numpy protein features.

    Only the backbone atoms (`bb_atom_types`) are read from `coords_dict`; all
    other atom slots stay zero and masked.

    Args:
        coords_dict: Maps an atom name to its per-residue coordinates
            ([num_res, 3]). The entry "N" defines the number of residues.

    Returns:
        Dictionary with "atom_positions" ([num_res, 37, 3], float32),
        "atom_mask" ([num_res, 37], float32) and "residue_index" ([num_res],
        int32), or None when "N" is missing, None or empty.
    """
    if "N" not in coords_dict or coords_dict["N"] is None:
        # Signals the caller to skip this item.
        return None

    num_res = np.array(coords_dict["N"]).shape[0]
    if num_res == 0:
        return None

    atom_positions = np.zeros([num_res, len(atom_types), 3], dtype=np.float32)
    atom_mask = np.zeros([num_res, len(atom_types)], dtype=np.float32)

    for atom_idx in bb_indices:
        raw_coords = coords_dict.get(atom_types[atom_idx])
        if raw_coords is None:
            continue
        try:
            pos_array = np.array(raw_coords)
            # Accept the atom only if it covers every residue.
            if pos_array.shape[0] == num_res:
                atom_positions[:, atom_idx, :] = pos_array
                # Mark the atom type as present only once its coordinates were
                # actually stored; otherwise it would be treated as a real atom
                # sitting at the origin.
                atom_mask[:, atom_idx] = 1.0
        except (ValueError, TypeError, IndexError):
            # Malformed coordinates: leave this atom type zeroed and masked.
            atom_positions[:, atom_idx, :] = 0.0

    # Mask atoms for which ANY of the three coordinates is NaN.
    nan_pos = np.isnan(atom_positions).any(axis=-1)
    atom_positions[nan_pos] = 0.0
    atom_mask[nan_pos] = 0.0

    return {
        "atom_positions": atom_positions,
        "atom_mask": atom_mask,
        "residue_index": np.arange(num_res, dtype=np.int32),
    }


def make_fixed_size(np_example: NpExampleType, max_seq_length: int = 500) -> NpExampleType:
    """Brings every feature to exactly `max_seq_length` entries along axis 0.

    Shorter features are zero-padded at the end, longer ones are cut off. The
    dictionary is updated in place.

    Args:
        np_example: Feature dictionary, or None.
        max_seq_length: Target length of the leading axis.

    Returns:
        The same dictionary, or None if `np_example` is None or any feature is
        None or empty.
    """
    if np_example is None:
        return None

    for feature_name in list(np_example):
        feature = np_example[feature_name]
        if feature is None or feature.shape[0] == 0:
            return None

        missing = max_seq_length - feature.shape[0]
        if missing < 0:
            feature = feature[:max_seq_length]
        elif missing > 0:
            # Pad only the end of the leading axis.
            pad_widths = [(0, missing), *([(0, 0)] * (len(feature.shape) - 1))]
            feature = np.pad(feature, pad_widths, mode="constant")
        np_example[feature_name] = feature

    return np_example


def center_positions(np_example: NpExampleType) -> NpExampleType:
    """Shifts 'atom_positions' so that the masked CA centroid is at the origin.

    Masked-out atoms are set to zero afterwards. The dictionary is updated in
    place.

    Args:
        np_example: Feature dictionary with "atom_positions" and "atom_mask",
            or None.

    Returns:
        The same dictionary (unchanged when the arrays are missing or empty, or
        when no CA atom is present), or None if `np_example` is None.
    """
    if np_example is None:
        return None

    positions = np_example["atom_positions"]
    mask = np_example["atom_mask"]

    arrays_usable = (
        positions is not None
        and mask is not None
        and positions.shape[0] != 0
        and mask.shape[0] != 0
    )
    if not arrays_usable:
        return np_example

    ca_weights = mask[:, 1]  # CA is at index 1
    num_ca = np.sum(ca_weights)
    # Without any CA atom there is no centroid to divide for.
    if num_ca < 1e-9:
        return np_example

    ca_centroid = np.sum(ca_weights[..., None] * positions[:, 1, :], axis=0) / num_ca

    # Shift everything, then zero the atoms that do not exist.
    np_example["atom_positions"] = (positions - ca_centroid[None, None, :]) * mask[..., None]
    return np_example


class DatasetFromDataframe:
    """Protein backbone point-cloud dataset backed by a DataFrame.

    Each row's `coords` entry is turned into flattened backbone features.
    Rows that cannot be converted are filtered out once at construction time,
    and processed examples are cached per index.
    """

    def __init__(self, data_frame: pd.DataFrame, max_seq_length: int = 512):
        self.data = data_frame
        self.max_seq_length = max_seq_length
        self.valid_indices = self._preprocess_data()
        self._cache = {}  # idx -> processed example

    def _preprocess_data(self):
        """Returns the row positions whose `coords` can be converted to an example."""
        print("Preprocessing dataset to check for valid entries...")
        total_rows = len(self.data)
        usable_rows = []

        for row in tqdm(range(total_rows), desc="Preprocessing"):
            coords = self.data.iloc[row].coords
            # Cheap structural checks first, then the real conversion.
            has_n_atoms = (
                isinstance(coords, dict)
                and "N" in coords
                and coords["N"] is not None
                and len(coords["N"]) > 0
            )
            if has_n_atoms and make_np_example(coords) is not None:
                usable_rows.append(row)

        num_skipped = total_rows - len(usable_rows)
        print(
            f"Preprocessing complete. Found {len(usable_rows)} valid entries."
            f" Skipped {num_skipped} invalid/problematic entries."
        )
        return usable_rows

    def __len__(self):
        """Number of valid (pre-filtered) entries."""
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Returns the processed example for position `idx` among the valid entries."""
        cached = self._cache.get(idx)
        if cached is not None:
            return cached

        # Translate the dataset position into a DataFrame row.
        actual_idx = self.valid_indices[idx]
        coords_dict = self.data.iloc[actual_idx].coords

        # Conversion -> fixed length -> centering; each stage may reject by returning None.
        pipeline = (
            ("make_np_example", lambda _: make_np_example(coords_dict)),
            ("make_fixed_size", lambda ex: make_fixed_size(ex, self.max_seq_length)),
            ("center_positions", center_positions),
        )
        np_example = None
        for stage_name, stage in pipeline:
            np_example = stage(np_example)
            if np_example is None:
                raise RuntimeError(f"{stage_name} failed for pre-validated index {actual_idx}")

        full_mask = np_example["atom_mask"]
        backbone_positions = np_example["atom_positions"][:, bb_indices, :]
        backbone_mask = full_mask[:, bb_indices]  # [N_res, 4]
        ca_residue_mask = full_mask[:, 1]  # [N_res]

        # Flatten residues x backbone atoms into one point axis.
        atoms_per_res = len(bb_indices)  # Should be 4
        num_points = backbone_positions.shape[0] * atoms_per_res
        flat_positions = backbone_positions.reshape(num_points, 3)
        flat_mask = backbone_mask.reshape(num_points)
        # Every point carries the index of the residue it belongs to.
        flat_residue_index = np.repeat(np_example["residue_index"], atoms_per_res)

        example = {
            "positions": jnp.array(flat_positions, dtype=jnp.float32),
            "mask": jnp.array(flat_mask, dtype=jnp.float32),  # per point
            "residue_index": jnp.array(flat_residue_index, dtype=jnp.int32),
            "res_mask": jnp.array(ca_residue_mask, dtype=jnp.float32),  # per residue
        }

        self._cache[idx] = example
        return example

    def get_batch(self, indices: list[int]) -> BatchType:
        """Stacks the examples at `indices` into one batch along a new leading axis."""
        members = [self[i] for i in indices]
        feature_names = ("positions", "mask", "residue_index", "res_mask")
        return {
            feature: jnp.stack([member[feature] for member in members])
            for feature in feature_names
        }


class JAXDataLoader:
    """Minimal batch iterator over a dataset, modeled on PyTorch's DataLoader.

    Each pass over the loader yields batches from `dataset.get_batch`. When
    shuffling is enabled, the stored PRNG key is split on every pass so that
    each pass uses a different order.
    """

    def __init__(
        self,
        dataset: DatasetFromDataframe,
        batch_size: int,
        shuffle: bool = False,
        drop_last: bool = False,
        rng_key=None,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        if rng_key is None:
            rng_key = jrandom.PRNGKey(0)
        self.rng_key = rng_key

    def __iter__(self):
        order = list(range(len(self.dataset)))

        if self.shuffle:
            # Advance the stored key so the next pass is shuffled differently.
            self.rng_key, shuffle_key = jrandom.split(self.rng_key)
            order = jrandom.permutation(shuffle_key, jnp.array(order)).tolist()

        total = len(order)
        # A trailing incomplete chunk is kept unless drop_last is set.
        chunks = [
            order[start : start + self.batch_size]
            for start in range(0, total, self.batch_size)
            if not self.drop_last or start + self.batch_size <= total
        ]

        for chunk in chunks:
            yield self.dataset.get_batch(chunk)

    def __len__(self):
        num_examples = len(self.dataset)
        if self.drop_last:
            return num_examples // self.batch_size
        # Round up so the incomplete final batch is counted.
        return (num_examples + self.batch_size - 1) // self.batch_size

# Diffusion Models

In [None]:
# @title Abstract Protein Diffusion Model
class ProteinDiffusionModel(nnx.Module):
    """Base class for the notebook's protein diffusion modules.

    It adds nothing to `nnx.Module`; it only provides a common parent type.
    """

    pass

In [None]:
# @title Time step embedding


# --- Timestep Embedding ---
# Source: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py
# Modified slightly for Flax NNX
class TimestepEmbedding(nnx.Module):
    """
    Parameter-free sinusoidal timestep embedding.

    The first ``dim // 2`` output columns are cosines and the next ``dim // 2``
    are sines of the timestep scaled by geometrically spaced frequencies; an
    odd ``dim`` gets one trailing zero column.
    """

    def __init__(self, dim, max_period=10000, *, rngs=None):
        """
        Args:
            dim: Width of the embedding returned by ``__call__``.
            max_period: Controls the lowest frequency; frequencies run from 1
                towards ``1 / max_period``.
            rngs: Accepted only so the constructor looks like other modules';
                unused because this module has no parameters to initialise.
        """
        self.dim = dim
        self.max_period = max_period

    def __call__(self, t) -> jax.Array:
        """
        Create sinusoidal timestep embeddings.

        Args:
            t: 1D array of timesteps with shape [batch]

        Returns:
            Timestep embeddings with shape [batch, dim]
        """
        half = self.dim // 2
        freqs = jnp.exp(-math.log(self.max_period) * jnp.arange(0, half) / half)
        args = t[:, None] * freqs[None, :]
        embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)

        # Odd dimensions need one extra column. Build it with an explicit
        # (batch, 1) shape: slicing it out of `embedding` yields a zero-width
        # column when dim == 1 (half == 0) and would leave the output too narrow.
        if self.dim % 2:
            pad = jnp.zeros((embedding.shape[0], 1), dtype=embedding.dtype)
            embedding = jnp.concatenate([embedding, pad], axis=-1)

        return embedding

In [None]:
# @title Point Cloud Protein Diffusion Model

# --- Model Architecture (using Flax NNX) ---


class PointCloudTransformerBlock(ProteinDiffusionModel):
    """Post-norm transformer block (self-attention + feed-forward) for point sets.

    The timestep embedding is added to the attention inputs only; both residual
    connections are taken from the un-conditioned point features.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1, *, rngs: nnx.Rngs):
        """
        Args:
            embed_dim: Feature width of the inputs and outputs.
            num_heads: Number of attention heads.
            dropout: Rate used for attention dropout and for the shared dropout layer.
            rngs: NNX RNG container used to initialise parameters and dropout.
        """
        # Multi-head self-attention; query/key/value and output all use embed_dim.
        self.attention = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=embed_dim,
            qkv_features=embed_dim,
            out_features=embed_dim,
            dropout_rate=dropout,
            rngs=rngs,
        )

        # One layer norm after each residual connection (post-norm).
        self.norm1 = nnx.LayerNorm(embed_dim, rngs=rngs)
        self.norm2 = nnx.LayerNorm(embed_dim, rngs=rngs)

        # A single dropout layer reused after attention and inside the FFN.
        self.dropout = nnx.Dropout(dropout, rngs=rngs)

        # Feed-forward network with a 4x wider hidden layer.
        self.ffn_linear1 = nnx.Linear(embed_dim, embed_dim * 4, rngs=rngs)
        self.ffn_linear2 = nnx.Linear(embed_dim * 4, embed_dim, rngs=rngs)

    def __call__(
        self,
        x: jax.Array,
        time_emb: jax.Array,
        mask: jax.Array | None = None,
        *,
        rngs: nnx.Rngs | None = None,
        deterministic: bool = False,
    ) -> jax.Array:
        """
        Apply the block.

        Args:
            x: Point features, [batch, num_points, embed_dim].
            time_emb: Timestep features, [batch, embed_dim]; broadcast over points.
            mask: Optional [batch, num_points] validity mask; entries >= 0.5 are
                treated as valid for attention.
            rngs: Optional RNG container forwarded to attention and dropout.
            deterministic: Forwarded to attention and dropout to disable dropout.

        Returns:
            Updated point features with the same shape as ``x``.
        """
        # Condition the attention inputs on the timestep.
        x_with_time = x + time_emb[:, None, :]  # [batch, num_points, embed_dim]

        attn_mask = None
        if mask is not None:
            # Pairwise mask: a (query, key) pair is allowed only if both are valid.
            attn_mask = nnx.make_attention_mask(
                mask >= 0.5,  # query side
                mask >= 0.5,  # key side
                dtype=jnp.float32,
            )

        # Self-attention: query, key and value are the same tensor.
        attn_output = self.attention(
            x_with_time,
            x_with_time,
            x_with_time,
            deterministic=deterministic,
            mask=attn_mask,
            decode=False,
            rngs=rngs,
        )
        attn_output = self.dropout(attn_output, rngs=rngs, deterministic=deterministic)
        x = x + attn_output
        x = self.norm1(x)

        # Feed-forward sub-layer.
        ffn_output = self.ffn_linear1(x)
        ffn_output = jax.nn.gelu(ffn_output)
        ffn_output = self.dropout(ffn_output, rngs=rngs, deterministic=deterministic)
        ffn_output = self.ffn_linear2(ffn_output)
        ffn_output = self.dropout(ffn_output, rngs=rngs, deterministic=deterministic)

        # Single residual + norm. The previous version repeated this step, adding
        # the FFN output twice and normalising twice; parameter shapes are the
        # same, but outputs differ from that version.
        x = x + ffn_output
        x = self.norm2(x)

        return x


class PointCloudProteinDiffusionModel(nnx.Module):
    """Noise-prediction network: a stack of transformer blocks over 3D points."""

    def __init__(self, config: TrainingConfig, *, rngs: nnx.Rngs):
        """
        Initialize the protein diffusion model.

        Args:
            config: Configuration object; ``model_dim``, ``num_heads`` and
                ``num_layers`` are read here.
            rngs: NNX RNG container used to initialise all sub-modules.
        """
        self.config = config

        # Input projection: map 3D coords to model_dim
        self.coord_proj = nnx.Linear(3, config.model_dim, rngs=rngs)

        # Timestep embedding followed by a two-layer MLP (model_dim -> 4x -> model_dim)
        self.time_embed = TimestepEmbedding(config.model_dim, rngs=rngs)
        self.time_embed_mlp_1 = nnx.Linear(config.model_dim, config.model_dim * 4, rngs=rngs)
        self.time_embed_mlp_2 = nnx.Linear(config.model_dim * 4, config.model_dim, rngs=rngs)

        # Blocks are stored as individual attributes block_0 .. block_{n-1};
        # __call__ looks them up by the same names.
        for i in range(config.num_layers):
            setattr(
                self,
                f"block_{i}",
                PointCloudTransformerBlock(
                    config.model_dim, config.num_heads, dropout=0.1, rngs=rngs
                ),
            )

        # Output projection
        self.final_norm = nnx.LayerNorm(config.model_dim, rngs=rngs)
        self.final_proj = nnx.Linear(config.model_dim, 3, rngs=rngs)

    def __call__(
        self,
        x_t: jax.Array,
        t: jax.Array,
        point_mask: jax.Array,
        *,
        rngs=None,
        deterministic=None,
    ) -> jax.Array:
        """
        Predict the noise contained in ``x_t``.

        Args:
            x_t: [batch, num_points, 3] - noisy coordinates
            t: [batch] - timesteps
            point_mask: [batch, num_points] - 1.0 for valid points, 0.0 for padding
            rngs: Optional RNG container forwarded unchanged to every block
            deterministic: Forwarded unchanged to every block (including ``None``,
                in which case the sub-modules decide for themselves)

        Returns:
            [batch, num_points, 3] predicted noise, zeroed at padded points.

        Raises:
            ValueError: If ``x_t`` is not a rank-3 array.
        """
        # Explicit rank check (the old tuple-unpack of x_t.shape raised the same
        # exception type, with a less helpful message).
        if len(x_t.shape) != 3:
            raise ValueError(
                f"x_t must have shape [batch, num_points, 3], got {tuple(x_t.shape)}"
            )

        # Project coordinates to embedding dimension
        x_emb = self.coord_proj(x_t)

        # Process timestep embedding
        time_emb = self.time_embed(t)
        time_emb = self.time_embed_mlp_1(time_emb)
        time_emb = jax.nn.gelu(time_emb)
        time_emb = self.time_embed_mlp_2(time_emb)

        # Apply transformer blocks sequentially
        h = x_emb
        for i in range(self.config.num_layers):
            block = getattr(self, f"block_{i}")
            h = block(h, time_emb, point_mask, rngs=rngs, deterministic=deterministic)

        # Final normalization and projection to 3D coordinates
        h = self.final_norm(h)
        predicted_noise = self.final_proj(h)

        # Apply mask to output (zero out predictions for masked/padded points)
        predicted_noise = predicted_noise * point_mask[:, :, None]

        return predicted_noise

In [None]:
# @title EGNN Protein Diffusion Model

# @markdown <font color='red'>**NOT WORKING YET**</font>


class EGNNLayer(nnx.Module):
    """
    Simplified EGNN-style message-passing layer over a fully connected graph.

    Takes node features ``h`` and node coordinates ``x`` and returns updated
    versions of both. Every ordered pair of nodes exchanges a message;
    self-pairs are always excluded.
    """

    def __init__(self, feature_dim: int, hidden_dim: int, *, rngs: nnx.Rngs):
        """
        Args:
            feature_dim: Width of the node features going in and coming out.
            hidden_dim: Width of the messages and of the MLP hidden layers.
            rngs: NNX random number generators.
        """
        self.feature_dim = feature_dim

        # phi_e: edge MLP on [h_i, h_j, ||x_i - x_j||^2] -> message
        edge_in = nnx.Linear(feature_dim * 2 + 1, hidden_dim, rngs=rngs)
        edge_out = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
        self.phi_e = nnx.Sequential(edge_in, nnx.silu, edge_out, nnx.silu)

        # phi_h: node MLP on [h_i, summed messages] -> feature update
        node_in = nnx.Linear(feature_dim + hidden_dim, hidden_dim, rngs=rngs)
        node_out = nnx.Linear(hidden_dim, feature_dim, rngs=rngs)
        self.phi_h = nnx.Sequential(node_in, nnx.silu, node_out)

        # phi_x: message -> one scalar coordinate weight (no bias on the output)
        coord_in = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
        coord_out = nnx.Linear(hidden_dim, 1, use_bias=False, rngs=rngs)
        self.phi_x = nnx.Sequential(coord_in, nnx.silu, coord_out)

        # Layer norm applied to the features after the residual update.
        # Coordinates are deliberately not normalised in this layer.
        self.norm_h = nnx.LayerNorm(feature_dim, rngs=rngs)

    def __call__(
        self,
        h: jax.Array,
        x: jax.Array,
        mask: jax.Array | None = None,
        *,
        rngs: nnx.Rngs | None = None,
    ) -> tuple[jax.Array, jax.Array]:
        """
        Run one round of message passing.

        Args:
            h: Node features [batch, num_nodes, feature_dim]
            x: Node coordinates [batch, num_nodes, 3]
            mask: Optional node mask [batch, num_nodes] (1 for real nodes, 0 for padding)
            rngs: Accepted but not used by this layer.

        Returns:
            ``(h_new, x_new)`` with the same shapes as ``h`` and ``x``.
        """
        _, n_nodes, _ = h.shape
        has_mask = mask is not None

        # Pairwise geometry: rel_pos[b, i, j] = x_i - x_j, plus squared length.
        rel_pos = x[:, :, None, :] - x[:, None, :, :]  # [B, N, N, 3]
        dist_sq = jnp.sum(rel_pos**2, axis=-1, keepdims=True)  # [B, N, N, 1]

        # Tile features so every (i, j) pair sees both endpoints.
        feats_i = jnp.repeat(h[:, :, None, :], n_nodes, axis=2)  # [B, N, N, F]
        feats_j = jnp.repeat(h[:, None, :, :], n_nodes, axis=1)  # [B, N, N, F]

        # Edge messages m_ij = phi_e([h_i, h_j, d_ij^2])
        messages = self.phi_e(
            jnp.concatenate([feats_i, feats_j, dist_sq], axis=-1)
        )  # [B, N, N, hidden_dim]

        # Drop pairs that touch a padded node (only when a mask is given) ...
        pair_mask = None
        if has_mask:
            pair_mask = mask[:, :, None, None] * mask[:, None, :, None]  # [B, N, N, 1]
            messages = messages * pair_mask
            rel_pos = rel_pos * pair_mask

        # ... and always drop self-pairs (zero diagonal).
        off_diag = (1.0 - jnp.eye(n_nodes, dtype=messages.dtype))[None, :, :, None]
        rel_pos = rel_pos * off_diag
        messages = messages * off_diag

        # Feature update: residual MLP on [h_i, sum_j m_ij], then layer norm.
        aggregated = jnp.sum(messages, axis=2)  # [B, N, hidden_dim]
        feature_delta = self.phi_h(jnp.concatenate([h, aggregated], axis=-1))
        h_new = self.norm_h(h + feature_delta)
        if has_mask:
            h_new = h_new * mask[:, :, None]

        # Per-pair scalar weights, scaled by 1/distance (epsilon guards the
        # zero-distance case) and squashed with tanh.
        weights = self.phi_x(messages)  # [B, N, N, 1]
        weights = jax.nn.tanh(weights / jnp.sqrt(dist_sq + 1e-8))
        if has_mask:
            weights = weights * pair_mask * off_diag

        # Coordinate update: x_i + C * sum_j (x_i - x_j) * w_ij
        shift = jnp.sum(rel_pos * weights, axis=2)  # [B, N, 3]
        if has_mask:
            # C = 1 / (real nodes - 1), floored at 1 to avoid dividing by zero.
            n_real = jnp.sum(mask, axis=1, keepdims=True)  # [B, 1]
            scale = (1.0 / jnp.maximum(n_real - 1, 1))[:, :, None]  # [B, 1, 1]
        else:
            scale = 1.0 / (n_nodes - 1 + 1e-8)

        x_new = x + scale * shift
        if has_mask:
            x_new = x_new * mask[:, :, None]

        return h_new, x_new


class EGNNProteinDiffusionModel(ProteinDiffusionModel):
    """Noise-prediction network built from a stack of ``EGNNLayer`` modules."""

    def __init__(self, config: TrainingConfig, *, rngs: nnx.Rngs):
        """
        Build the sub-modules.

        Args:
            config: Configuration object; ``model_dim`` and ``num_layers`` are read here.
            rngs: NNX random number generator key wrapper.
        """
        self.config = config
        width = config.model_dim

        # Initial node features are a linear embedding of the 3D coordinates.
        self.coord_embed = nnx.Linear(3, width, rngs=rngs)

        # Sinusoidal timestep embedding followed by a two-layer MLP.
        self.time_embed = TimestepEmbedding(width, rngs=rngs)
        self.time_embed_mlp_1 = nnx.Linear(width, width * 4, rngs=rngs)
        self.time_embed_mlp_2 = nnx.Linear(width * 4, width, rngs=rngs)

        # Message-passing stack, kept as a plain Python list.
        layers = []
        for _ in range(config.num_layers):
            layers.append(
                EGNNLayer(
                    feature_dim=config.model_dim,
                    hidden_dim=config.model_dim // 2,  # half-width messages
                    rngs=rngs,
                )
            )
        self.egnn_layers = layers

        # Output head: the noise is read out from the final node features.
        self.final_norm = nnx.LayerNorm(width, rngs=rngs)
        self.final_proj = nnx.Linear(width, 3, rngs=rngs)

    def __call__(self, x_t, t, point_mask, *, rngs=None, deterministic=None):
        """
        Predict the noise contained in ``x_t``.

        Args:
            x_t: [batch, num_points, 3] - noisy coordinates at timestep t.
            t: [batch] - timesteps.
            point_mask: [batch, num_points] - 1.0 for valid points, 0.0 for padding;
                may be None, in which case no masking is applied here.
            rngs: Accepted but not used in this method.
            deterministic: Accepted but not used in this method.

        Returns:
            Predicted noise [batch, num_points, 3].
        """
        # Unpacking doubles as a check that x_t is rank 3.
        _n_batch, _n_points, _ = x_t.shape
        masked = point_mask is not None

        # Timestep features: embedding -> linear -> SiLU -> linear, [B, F].
        t_feat = self.time_embed_mlp_2(
            jax.nn.silu(self.time_embed_mlp_1(self.time_embed(t)))
        )

        # Node features: coordinate embedding plus the broadcast timestep features.
        node_feats = self.coord_embed(x_t) + t_feat[:, None, :]  # [B, N, F]
        if masked:
            node_feats = node_feats * point_mask[:, :, None]

        # Each layer updates features and coordinates together; only the final
        # features are used for the prediction below.
        coords = x_t
        for egnn in self.egnn_layers:
            node_feats, coords = egnn(node_feats, coords, mask=point_mask)

        # Read the noise out of the final node features.
        predicted_noise = self.final_proj(self.final_norm(node_feats))  # [B, N, 3]
        if masked:
            predicted_noise = predicted_noise * point_mask[:, :, None]

        return predicted_noise

# Diffusion process

In [None]:
# @title Diffusion Process


# --- Diffusion Process ---
class DiffusionProcess:
    """Forward (noising) and reverse (denoising) diffusion steps for point coordinates.

    All per-timestep coefficients come from the table returned by
    ``get_diffusion_variables`` and are looked up by integer timestep.
    """

    def __init__(self, config: TrainingConfig):
        """
        Args:
            config: Configuration object; ``timesteps`` and ``beta_schedule`` are read.
        """
        self.config = config
        self.timesteps = config.timesteps
        self.vars = get_diffusion_variables(config.beta_schedule, config.timesteps)

    def q_sample(
        self,
        x_start: jax.Array,
        t: jax.Array,
        noise: jax.Array | None = None,
        key: jax.random.PRNGKey | None = None,
    ):
        """Forward diffusion process: draw x_t from q(x_t | x_0).

        Args:
            x_start: Clean sample x_0; the first axis is the batch.
            t: Integer timesteps, shape [batch].
            noise: Noise to mix in. When omitted it is drawn with ``key``.
            key: PRNG key used only when ``noise`` is None. Falls back to
                ``PRNGKey(0)``, i.e. the same noise on every call - pass a key
                for real use.

        Returns:
            ``sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise``.
        """
        if noise is None:
            if key is None:
                key = jrandom.PRNGKey(0)
            noise = jrandom.normal(key, shape=x_start.shape)

        sqrt_alphas_cumprod_t = self._extract(self.vars["sqrt_alphas_cumprod"], t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(
            self.vars["sqrt_one_minus_alphas_cumprod"], t, x_start.shape
        )

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def _extract(self, a: jax.Array, t: jax.Array, x_shape: tuple[int]) -> jax.Array:
        """Gather ``a[t]`` per batch item and reshape to broadcast against ``x_shape``.

        Returns an array of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)`` axes.
        """
        batch_size = t.shape[0]

        # Gather along the last axis of the coefficient table.
        out = jnp.take_along_axis(a, t, axis=-1)

        # One leading batch axis, then singleton axes for everything else.
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))

    def p_losses(
        self,
        model_fn,
        x_start: jax.Array,
        t: jax.Array,
        point_mask: jax.Array,
        deterministic: bool = False,
        noise: jax.Array | None = None,
        key: jax.random.PRNGKey | None = None,
    ) -> jax.Array:
        """Calculate the weighted, masked noise-prediction loss.

        Args:
            model_fn: Callable ``(x_t, t, point_mask, deterministic=...)`` returning
                predicted noise shaped like ``x_start``.
            x_start: Clean coordinates, [batch, num_points, 3].
            t: Integer timesteps, [batch].
            point_mask: [batch, num_points]; zero entries are excluded from the loss.
            deterministic: Forwarded to ``model_fn``.
            noise: Target noise; drawn with ``key`` when omitted.
            key: PRNG key used only when ``noise`` is None (``PRNGKey(0)`` fallback,
                which repeats the same noise on every call).

        Returns:
            Scalar loss: mean over the batch of the per-sample masked average.
        """
        if noise is None:
            if key is None:
                key = jrandom.PRNGKey(0)
            noise = jrandom.normal(key, shape=x_start.shape)

        # Apply forward process to get noisy input x_t
        x_noisy = self.q_sample(x_start, t, noise, key)

        # Predict noise using the model
        predicted_noise = model_fn(x_noisy, t, point_mask, deterministic=deterministic)

        # Squared error, restricted to valid points
        loss = (noise - predicted_noise) ** 2
        loss = loss * point_mask[:, :, None]

        # Per-timestep loss weight.
        # NOTE(review): going by the table names this ratio is sqrt(alpha_bar) /
        # sqrt(1 - alpha_bar), i.e. the square root of the usual SNR, and
        # ratio / (1 + ratio) stays below 1 for non-negative ratios, so only the
        # lower clip bound can take effect - confirm this is the intended weighting.
        snr = self._extract(self.vars["sqrt_alphas_cumprod"], t, x_start.shape) / self._extract(
            self.vars["sqrt_one_minus_alphas_cumprod"], t, x_start.shape
        )
        loss_weights = jnp.clip(snr / (1.0 + snr), 0.5, 5.0)
        loss = loss * loss_weights

        # Average over valid coordinates per sample (epsilon guards empty masks),
        # then over the batch.
        per_sample_loss = jnp.sum(loss, axis=(1, 2)) / (jnp.sum(point_mask, axis=1) * 3 + 1e-8)
        return jnp.mean(per_sample_loss)

    def _predict_noise(self, model_fn, x_t, t, point_mask, dropout_key):
        """Call ``model_fn`` for inference, adapting to the keywords it accepts.

        Tries, in order: ``deterministic=True`` plus ``rngs``; ``rngs`` only; no
        extra keywords. A ``TypeError`` moves on to the next form, so a genuine
        ``TypeError`` raised inside the model also triggers a retry.
        """
        try:
            # Preferred form: explicitly disable dropout while sampling.
            return model_fn(
                x_t, t, point_mask, deterministic=True, rngs=nnx.Rngs(dropout=dropout_key)
            )
        except TypeError:
            pass

        try:
            # A fresh Rngs per attempt, so a failed attempt cannot leave behind
            # advanced RNG state.
            return model_fn(x_t, t, point_mask, rngs=nnx.Rngs(dropout=dropout_key))
        except TypeError:
            return model_fn(x_t, t, point_mask)

    def p_sample(
        self,
        model_fn,
        x_t: jax.Array,
        t: jax.Array,
        point_mask: jax.Array,
        key=None,
        dropout_key=None,
    ):
        """
        Reverse process step: Sample x_{t-1} from p(x_{t-1} | x_t).

        Args:
            model_fn: Noise-prediction callable; see ``_predict_noise`` for the
                call forms that are attempted.
            x_t: Current noisy coordinates, [batch, num_points, 3].
            t: Integer timesteps, [batch].
            point_mask: [batch, num_points]; sampling noise is zeroed where it is 0.
            key: PRNG key for the sampling noise (``PRNGKey(0)`` when omitted).
            dropout_key: PRNG key handed to the model; split from ``key`` when omitted.

        Returns:
            The posterior mean when every entry of ``t`` is 0, otherwise the mean
            plus masked Gaussian noise scaled by the posterior standard deviation.
        """
        if key is None:
            key = jrandom.PRNGKey(0)

        # Create dropout RNG if not provided
        if dropout_key is None:
            key, dropout_key = jrandom.split(key)

        # Extract required coefficients for timestep t
        betas_t = self._extract(self.vars["betas"], t, x_t.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(
            self.vars["sqrt_one_minus_alphas_cumprod"], t, x_t.shape
        )
        sqrt_recip_alphas_t = self._extract(self.vars["sqrt_recip_alphas"], t, x_t.shape)

        predicted_noise = self._predict_noise(model_fn, x_t, t, point_mask, dropout_key)

        # Posterior mean of x_{t-1} given the predicted noise
        model_mean = sqrt_recip_alphas_t * (
            x_t - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
        )

        # Concrete boolean test: this method is meant to run outside jit.
        if jnp.all(t == 0):
            return model_mean
        else:
            posterior_variance_t = self._extract(self.vars["posterior_variance"], t, x_t.shape)
            noise = jrandom.normal(key, shape=x_t.shape)
            noise = noise * point_mask[:, :, None]
            return model_mean + jnp.sqrt(posterior_variance_t) * noise

    def sample(
        self, model_fn, shape: tuple[int], point_mask: jax.Array, *, rng_key=None
    ) -> list[jax.Array]:
        """
        Generate samples starting from pure noise (x_T).

        Args:
            model_fn: Function that takes (x_t, t, point_mask) and returns predicted noise
            shape: Shape of the output tensor [batch, num_points, 3]
            point_mask: Mask for valid points [batch, num_points]
            rng_key: JAX PRNG key (``PRNGKey(0)`` when omitted)

        Returns:
            Snapshots taken after every step whose index is a multiple of 50, in
            sampling order; the last element is the final sample (step 0). Empty
            only if ``timesteps`` is 0.
        """
        if rng_key is None:
            rng_key = jrandom.PRNGKey(0)

        batch_size = shape[0]
        # Start from pure Gaussian noise ~ N(0, I)
        rng_key, noise_key = jrandom.split(rng_key)
        denoised_p = jrandom.normal(noise_key, shape=shape)

        # Snapshots of the trajectory
        denoised_ps = []

        # Loop from T-1 down to 0 (reverse diffusion process)
        for i in tqdm(
            reversed(range(0, self.timesteps)),
            desc="Sampling",
            total=self.timesteps,
            leave=False,
        ):
            # Create PRNG keys for this step - one for sampling noise, one for dropout
            rng_key, step_key, dropout_key = jrandom.split(rng_key, 3)

            # Create a tensor of the current timestep for the batch
            t = jnp.full((batch_size,), i, dtype=jnp.int32)

            # Perform one denoising step with both keys
            denoised_p = self.p_sample(
                model_fn,
                denoised_p,
                t,
                point_mask,
                key=step_key,
                dropout_key=dropout_key,
            )

            # Save intermediate steps at regular intervals. i == 0 always
            # qualifies, so the final sample is recorded here and needs no extra
            # append after the loop (that used to duplicate the last entry
            # whenever timesteps was not a multiple of 50).
            if i % 50 == 0:
                denoised_ps.append(denoised_p)

        return denoised_ps

# Checkpointing utilities

In [None]:
# @title Checkpointing utilities


def setup_checkpoint_manager(base_dir):
    """Create the checkpoint directory and an Orbax ``CheckpointManager`` for it.

    Args:
        base_dir: Directory for checkpoints; resolved to an absolute path.

    Returns:
        ``(checkpoint_manager, absolute_directory_path)``.

    Raises:
        Exception: Any failure is printed with its traceback and re-raised.
    """
    try:
        base_dir_abs = os.path.abspath(base_dir)
        print(f"Setting up checkpoint manager in absolute path: {base_dir_abs}...")

        # Make sure the directory (and any missing parents) exists up front.
        os.makedirs(base_dir_abs, exist_ok=True)

        # Keep at most the five most recent checkpoints.
        manager = ocp.CheckpointManager(
            directory=base_dir_abs,
            options=ocp.CheckpointManagerOptions(max_to_keep=5, create=True),
        )

        print(f"Successfully created checkpoint manager in {base_dir_abs}")
        return manager, base_dir_abs
    except Exception as err:
        print(f"Error setting up checkpoint manager: {err}")
        traceback.print_exc()
        raise


def save_checkpoint(checkpoint_manager, model, step):
    """Write the model's NNX state as checkpoint ``step`` and wait for completion.

    Args:
        checkpoint_manager: Orbax ``CheckpointManager`` to save through.
        model: ``nnx.Module`` (or ``nnx.GraphDef``) whose state is saved under
            the ``"model"`` item.
        step: Step number to file the checkpoint under.

    Returns:
        The same ``checkpoint_manager`` that was passed in.

    Raises:
        TypeError: If ``model`` is neither an ``nnx.Module`` nor an ``nnx.GraphDef``.
        Exception: Any other failure is printed with its traceback and re-raised.
    """
    print(f"Attempting to save checkpoint for step {step}...")

    # Stage 1: pull the state out of the model.
    try:
        accepted_types = (nnx.Module, nnx.GraphDef)
        if not isinstance(model, accepted_types):
            raise TypeError(f"Expected model to be nnx.Module or nnx.GraphDef, got {type(model)}")

        model_state = nnx.state(model)
        print("Successfully extracted model state.")
    except Exception as err:
        print(f"Error getting model state: {err}")
        traceback.print_exc()
        raise

    # Stage 2: hand the state to Orbax and block until it is on disk.
    try:
        checkpoint_manager.save(
            step,
            # Further items (e.g. optimizer state) could be added to the composite.
            args=ocp.args.Composite(model=ocp.args.StandardSave(model_state)),
        )

        # Saves may be asynchronous; do not return before they finish.
        checkpoint_manager.wait_until_finished()

        print(f"Successfully saved checkpoint for step {step} to {checkpoint_manager.directory}")
    except Exception as err:
        print(f"Error during checkpoint_manager.save or wait: {err}")
        traceback.print_exc()
        raise

    return checkpoint_manager


def load_checkpoint(checkpoint_manager, target_model_template=None, step=None):
    """
    Load model checkpoint using Orbax.

    Args:
        checkpoint_manager: The Orbax CheckpointManager instance.
        target_model_template: An optional NNX Module instance or GraphDef
                               with the same structure as the saved model.
                               If provided, the loaded state is applied to
                               this template. If None, the raw restored
                               state is returned.
        step: The specific step to restore. If None, restores the latest step.

    Returns:
        A tuple ``(restored, step)``:
        - with a template: the template itself, updated with the restored state;
        - without a template: the restored ``"model"`` item, or None if the
          restored data has no ``"model"`` key.
        Returns (None, None) if ``step`` is None and no checkpoint exists.

    Raises:
        TypeError: If a template is given that is neither nnx.Module nor nnx.GraphDef.
        Exception: Any failure is printed with its traceback and re-raised.
    """
    try:
        if step is None:
            step = checkpoint_manager.latest_step()
            if step is None:
                print(f"No checkpoints found in {checkpoint_manager.directory} to restore.")
                return None, None

        print(
            "Attempting to restore checkpoint from step"
            f" {step} in {checkpoint_manager.directory}..."
        )

        # Compare against None explicitly: whether a template was supplied must
        # not depend on the template object's truthiness.
        if target_model_template is not None:
            if not isinstance(target_model_template, (nnx.Module, nnx.GraphDef)):
                raise TypeError(
                    "Expected target_model_template to be nnx.Module"
                    f" or nnx.GraphDef, got {type(target_model_template)}"
                )
            # The template's current state tells Orbax the tree structure to restore.
            target_state = nnx.state(target_model_template)
            restore_args = ocp.args.Composite(
                model=ocp.args.StandardRestore(target_state)
                # Add args for other items if they were saved, e.g.:
                # optimizer=ocp.args.StandardRestore(optimizer_template.state)
            )
            print("Restoring checkpoint into provided model template.")
            restored_data = checkpoint_manager.restore(
                step,
                args=restore_args,
            )
            # Use the tree returned by restore() (rather than relying on
            # target_state having been changed) and write it into the template.
            nnx.update(target_model_template, restored_data["model"])

            print(f"Successfully restored checkpoint from step {step} into the template.")
            return target_model_template, step
        else:
            # No template: restore whatever was saved under "model" as-is.
            restore_args = ocp.args.Composite(
                model=ocp.args.StandardRestore()
                # Add args for other items if they were saved
                # optimizer=ocp.args.StandardRestore()
            )
            print("Restoring checkpoint as raw state dictionary.")
            restored_data = checkpoint_manager.restore(step, args=restore_args)
            print(f"Successfully restored raw checkpoint data from step {step}")
            if "model" not in restored_data:
                print("Warning: 'model' key not found in restored data dictionary.")
                return None, step
            return restored_data["model"], step

    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        traceback.print_exc()
        raise

# Optimization utilities

In [None]:
# @title Optimization utilities


class GradAccumulationState(nnx.Module):
    """Running sum of gradient pytrees, used for gradient accumulation.

    Call :meth:`update` once per micro-batch, :meth:`get_averaged_grads` when
    the accumulated gradients should be applied, and :meth:`reset` afterwards.
    """

    # Element-wise sum of every gradient tree passed to ``update`` since the
    # last ``reset``; ``None`` while nothing has been accumulated.
    grads: Any | None = None
    # Number of gradient trees summed into ``grads``.
    count: int = 0

    def update(self, new_grads: Any) -> None:
        """Add ``new_grads`` (a pytree of arrays) to the running sum.

        Args:
            new_grads: Gradient pytree with the same structure as the trees
                accumulated so far.
        """
        if self.grads is None:
            # First contribution: store a copy of every leaf.
            self.grads = jax.tree.map(lambda g: g.copy(), new_grads)
        else:
            self.grads = jax.tree.map(lambda g1, g2: g1 + g2, self.grads, new_grads)
        self.count += 1

    def get_averaged_grads(self) -> Any:
        """Return the accumulated gradients divided by the number of updates.

        The state is left untouched; call :meth:`reset` to start a new
        accumulation window.

        Returns:
            A pytree with the same structure as the accumulated gradients.

        Raises:
            ValueError: If no gradients have been accumulated yet.
        """
        if self.grads is None or self.count == 0:
            raise ValueError("No gradients have been accumulated; call update() first.")
        return jax.tree.map(lambda g: g / self.count, self.grads)

    def reset(self) -> None:
        """Discard the accumulated gradients and zero the counter."""
        self.grads = None
        self.count = 0


def accumulate_gradients(
    grad_state: GradAccumulationState, grads: Any
) -> GradAccumulationState:
    """Add one gradient pytree to ``grad_state`` and return the same state object.

    Args:
        grad_state: Accumulation state; it is mutated in place through
            ``grad_state.update``.
        grads: Gradient pytree for the current micro-batch.

    Returns:
        ``grad_state`` itself (not a copy), so the call can be written in
        assignment style.
    """
    grad_state.update(grads)
    return grad_state


def apply_accumulated_gradients(
    optimizer: nnx.Optimizer, grad_state: GradAccumulationState
) -> nnx.Optimizer:
    """Apply the averaged accumulated gradients, then clear the accumulator.

    Nothing happens when ``grad_state.count`` is zero.

    Args:
        optimizer: Optimizer whose ``update`` method receives the averaged
            gradients.
        grad_state: Accumulation state; it is reset after a successful update.

    Returns:
        The same ``optimizer`` object that was passed in.
    """
    if grad_state.count > 0:
        averaged_grads = grad_state.get_averaged_grads()
        optimizer.update(averaged_grads)
        grad_state.reset()
    return optimizer

# Training functions

In [None]:
# @title Training functions

# @markdown ToDo: \\
# @markdown   - Start from a checkpoint for training


def create_linear_warmup_cosine_decay_schedule(
    init_value, peak_value, warmup_steps, decay_steps, end_value
):
    """Create a learning rate schedule with linear warmup and cosine decay.

    Args:
        init_value: Learning rate at step 0.
        peak_value: Learning rate reached at the end of the warmup.
        warmup_steps: Number of steps of the linear warmup; also the boundary
            at which the cosine decay takes over.
        decay_steps: Number of steps of the cosine decay that follows the
            warmup (the warmup steps are NOT included).
        end_value: Learning rate the cosine decay converges to.

    Returns:
        The joined optax schedule (warmup followed by cosine decay).
    """

    # Linear ramp from init_value to peak_value.
    warmup_schedule = optax.linear_schedule(
        init_value=init_value, end_value=peak_value, transition_steps=warmup_steps
    )

    # optax expresses the decay floor as a fraction of the decay's initial
    # value. A zero peak would make that ratio undefined, so fall back to 0.
    alpha = end_value / peak_value if peak_value != 0 else 0.0

    # Cosine decay from peak_value towards end_value.
    decay_schedule = optax.cosine_decay_schedule(
        init_value=peak_value,
        decay_steps=decay_steps,
        alpha=alpha,
    )

    # Switch from the warmup to the decay once warmup_steps have elapsed.
    return optax.join_schedules(
        schedules=[warmup_schedule, decay_schedule], boundaries=[warmup_steps]
    )


def calculate_batch_size_info(config: TrainingConfig, dataset_size: int):
    """Report the batch-size setup and derive the number of training steps.

    Args:
        config: Training configuration; reads ``batch_size``,
            ``gradient_accumulation_steps``, ``use_effective_batch_size`` and
            ``num_epochs``.
        dataset_size: Number of training examples.

    Returns:
        ``(effective_batch_size, total_steps)`` where the effective batch size
        is ``batch_size * gradient_accumulation_steps`` and ``total_steps`` is
        the floor of ``dataset_size * num_epochs`` divided by either the
        effective or the nominal batch size, depending on
        ``config.use_effective_batch_size``.
    """
    nominal_batch_size = config.batch_size
    accumulation_steps = config.gradient_accumulation_steps
    effective_batch_size = nominal_batch_size * accumulation_steps

    # Samples consumed per counted step: effective or nominal batch size.
    if config.use_effective_batch_size:
        samples_per_step = effective_batch_size
    else:
        samples_per_step = nominal_batch_size
    total_steps = (dataset_size * config.num_epochs) // samples_per_step

    report_lines = (
        "Batch size configuration:",
        f"  • Nominal batch size: {nominal_batch_size}",
        f"  • Gradient accumulation steps: {accumulation_steps}",
        f"  • Effective batch size: {effective_batch_size}",
        f"  • Total training update steps: {total_steps}",
    )
    for line in report_lines:
        print(line)

    return effective_batch_size, total_steps


@nnx.jit(static_argnames=["diffusion"])
def compute_grads_step(
    model: ProteinDiffusionModel, diffusion: DiffusionProcess, batch: BatchType, rng_key
):
    """Compute the diffusion loss and its gradients without updating the model.

    Args:
        model: Model that is differentiated.
        diffusion: Diffusion process providing ``timesteps`` and ``p_losses``.
        batch: Mapping with the entries ``"positions"`` and ``"mask"``.
        rng_key: PRNG key; it is split into a carry-over key plus keys for
            timestep sampling, noise and dropout.

    Returns:
        ``(loss, grads, model, rng_key)`` where ``rng_key`` is the carry-over
        key produced by the split.
    """
    clean_positions = batch["positions"]
    valid_points = batch["mask"]

    # One split yields the key handed back to the caller and three work keys.
    next_key, t_key, eps_key, drop_key = jrandom.split(rng_key, 4)

    # One integer timestep per batch element, drawn from [0, diffusion.timesteps).
    num_examples = clean_positions.shape[0]
    timesteps = jrandom.randint(
        t_key, (num_examples,), 0, diffusion.timesteps, dtype=jnp.int32
    )

    # Gaussian noise with the same shape as the clean positions.
    eps = jrandom.normal(eps_key, shape=clean_positions.shape)

    def diffusion_loss(net):
        # Forward callable handed to the diffusion process (training mode).
        def denoise(x_t, t, mask, deterministic=False):
            return net(
                x_t, t, mask, deterministic=deterministic, rngs=nnx.Rngs(dropout=drop_key)
            )

        return diffusion.p_losses(
            denoise, clean_positions, timesteps, valid_points, deterministic=False, noise=eps
        )

    loss, grads = nnx.value_and_grad(diffusion_loss)(model)

    return loss, grads, model, next_key


# --- Training Functions ---
@nnx.jit(static_argnames=["diffusion"])
def train_step(
    model: ProteinDiffusionModel,
    optimizer: nnx.Optimizer,
    diffusion: DiffusionProcess,
    batch: BatchType,
    rng_key,
):
    """
    One jitted optimisation step: loss, gradients and optimizer update.

    Args:
        model: ProteinDiffusionModel that is differentiated
        optimizer: Flax NNX optimizer whose ``update`` receives the gradients
        diffusion: DiffusionProcess providing ``timesteps`` and ``p_losses``
        batch: dictionary containing 'positions' and 'mask'
        rng_key: JAX PRNG key

    Returns:
        loss: Training loss for this batch
        model: The model passed in
        optimizer: The optimizer passed in, after ``update``
        rng_key: Carry-over PRNG key produced by splitting the input key
    """
    clean_positions = batch["positions"]
    valid_points = batch["mask"]

    # One split yields the key handed back to the caller and three work keys.
    next_key, t_key, eps_key, drop_key = jrandom.split(rng_key, 4)

    # One integer timestep per batch element, drawn from [0, diffusion.timesteps).
    num_examples = clean_positions.shape[0]
    timesteps = jrandom.randint(
        t_key, (num_examples,), 0, diffusion.timesteps, dtype=jnp.int32
    )

    # Gaussian noise with the same shape as the clean positions.
    eps = jrandom.normal(eps_key, shape=clean_positions.shape)

    def diffusion_loss(net):
        # Forward callable handed to the diffusion process (training mode);
        # a dropout Rngs object is created for every call.
        def denoise(x_t, t, mask, deterministic=False):
            return net(
                x_t, t, mask, deterministic=deterministic, rngs=nnx.Rngs(dropout=drop_key)
            )

        return diffusion.p_losses(
            denoise, clean_positions, timesteps, valid_points, deterministic=False, noise=eps
        )

    loss, grads = nnx.value_and_grad(diffusion_loss)(model)

    # The NNX optimizer is updated in place.
    optimizer.update(grads)

    return loss, model, optimizer, next_key


@nnx.jit(static_argnames=["diffusion"])
def eval_step(
    model: ProteinDiffusionModel,
    diffusion: DiffusionProcess,
    batch: BatchType,
    *,
    rng_key,
):
    """
    One jitted evaluation step: diffusion loss in deterministic mode, no gradients.

    Args:
        model: The ProteinDiffusionModel to evaluate
        diffusion: DiffusionProcess providing ``timesteps`` and ``p_losses``
        batch: dictionary containing 'positions' and 'mask'
        rng_key: JAX PRNG key (keyword-only)

    Returns:
        loss: Validation loss for this batch
        rng_key: Carry-over PRNG key produced by splitting the input key
    """
    clean_positions = batch["positions"]
    valid_points = batch["mask"]

    # Same four-way split as in the training steps.
    next_key, t_key, eps_key, drop_key = jrandom.split(rng_key, 4)

    # One integer timestep per batch element, drawn from [0, diffusion.timesteps).
    num_examples = clean_positions.shape[0]
    timesteps = jrandom.randint(
        t_key, (num_examples,), 0, diffusion.timesteps, dtype=jnp.int32
    )

    # Gaussian noise with the same shape as the clean positions.
    eps = jrandom.normal(eps_key, shape=clean_positions.shape)

    # Forward callable handed to the diffusion process (evaluation mode).
    def denoise(x_t, t, mask, deterministic=True):
        return model(
            x_t, t, mask, deterministic=deterministic, rngs=nnx.Rngs(dropout=drop_key)
        )

    loss = diffusion.p_losses(
        denoise, clean_positions, timesteps, valid_points, deterministic=True, noise=eps
    )

    return loss, next_key


def train(
    config: TrainingConfig,
    model: ProteinDiffusionModel,
    diffusion: DiffusionProcess,
    train_loader: JAXDataLoader,
    val_loader: JAXDataLoader,
    *,
    rng_key,
    plot_rgs=False,
):
    """Training loop for the protein diffusion model (with gradient accumulation).

    Gradients are computed per batch, summed in a ``GradAccumulationState`` and
    applied once ``config.gradient_accumulation_steps`` batches were
    accumulated or the last batch of an epoch was reached. Every
    ``config.eval_freq`` batches the validation loss is computed and a sample
    structure is generated and written as PDB; every ``config.save_freq``
    parameter updates a checkpoint is saved. After the last epoch one final
    sample is generated.

    Args:
        config: Training configuration.
        model: Model to train; it is updated through the NNX optimizer.
        diffusion: Diffusion process used for the loss and for sampling.
        train_loader: Loader yielding training batches; ``train_loader.dataset``
            must support ``len``.
        val_loader: Loader yielding validation batches; evaluation is skipped
            when it is falsy.
        rng_key: JAX PRNG key (keyword-only).
        plot_rgs: Not implemented; must stay ``False``.

    Returns:
        ``(model, train_losses, val_losses)`` where both loss lists hold
        ``(global_step, loss)`` tuples with Python floats.

    Raises:
        NotImplementedError: If ``plot_rgs`` is true.
    """

    # Create output directory
    os.makedirs(config.output_dir, exist_ok=True)
    checkpoint_manager, _checkpoint_dir = setup_checkpoint_manager(config.output_dir)

    _effective_batch_size, total_steps = calculate_batch_size_info(
        config, len(train_loader.dataset)
    )
    # With a small dataset the warmup can be longer than the whole run; a
    # non-positive length is not a valid cosine decay, so keep at least 1 step.
    decay_steps = max(1, total_steps - config.warmup_steps)

    # Create learning rate schedule based on effective or nominal batch size
    schedule_fn = create_linear_warmup_cosine_decay_schedule(
        init_value=config.learning_rate / 3,
        peak_value=config.learning_rate,
        warmup_steps=config.warmup_steps,
        decay_steps=decay_steps,
        end_value=config.learning_rate / 1e3,
    )

    # Initialize optimizer with Flax NNX Optimizer
    # This creates an optimizer that holds a reference to the model
    optimizer = nnx.Optimizer(
        model,
        optax.chain(
            optax.clip_by_global_norm(1.0),  # Gradient clipping.
            optax.adamw(
                learning_rate=schedule_fn,
                weight_decay=2e-5,
                b1=0.9,
                b2=0.99,
                eps=1e-8,
            ),
        ),
    )

    metrics = nnx.MultiMetric(
        loss=nnx.metrics.Average("loss"),
    )

    # Training metrics, stored as (global_step, float) tuples
    train_losses: LossValuesType = []
    val_losses: LossValuesType = []

    if plot_rgs:
        raise NotImplementedError("Plotting Rg distributions not implemented yet.")

    print(f"Starting training for {config.num_epochs} epochs...")

    # Initialize state tracking
    global_step = 0  # Total batches processed
    update_step = 0  # Parameter updates performed
    start_time = time.time()

    # Initialize gradient accumulation state
    grad_state = GradAccumulationState()

    def inference_forward_fn(x_t, t, mask, deterministic=True, rngs=None):
        """Forward pass used for sampling, during evaluation and at the end.

        ``rngs`` is only forwarded when given. ``model`` is looked up at call
        time, so the callable always uses the current model binding.
        """
        if rngs is not None:
            return model(x_t, t, mask, deterministic=deterministic, rngs=rngs)
        return model(x_t, t, mask, deterministic=deterministic)

    for epoch in range(config.num_epochs):
        # Training loop
        epoch_loss = 0.0
        train_batch_count = 0

        # Get dataloader length if available
        try:
            num_batches = len(train_loader)
            has_known_length = True
        except (TypeError, AttributeError):
            has_known_length = False
            num_batches = "unknown"

        print(f"Epoch {epoch + 1}/{config.num_epochs}: {num_batches} batches")
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{config.num_epochs}")

        # batch_idx is the 1-based position of the batch within the epoch
        for batch_idx, batch in enumerate(progress_bar, start=1):
            # Generate a new RNG key for this step
            rng_key, step_key = jrandom.split(rng_key)

            # Compute gradients without applying them
            loss, grads, model, rng_key = compute_grads_step(model, diffusion, batch, step_key)

            # Accumulate gradients
            grad_state = accumulate_gradients(grad_state, grads)

            # Update metrics
            metrics.update(loss=loss)
            epoch_loss += loss
            train_batch_count += 1
            global_step += 1

            # Update progress bar
            progress_bar.set_postfix(
                loss=f"{loss:.4f}",
                acc=f"{grad_state.count}/{config.gradient_accumulation_steps}",
            )

            # Log training progress
            if global_step % config.log_freq == 0:
                print(f"Step: {global_step}, Train Loss: {loss:.4f}")
                train_losses.append((global_step, float(loss)))

            # Check if it's time to apply gradients
            # 1. We've reached the specified number of accumulation steps, OR
            # 2. We're at the end of an epoch (if we can determine this)
            is_last_batch = has_known_length and batch_idx == num_batches
            should_apply_grads = (grad_state.count >= config.gradient_accumulation_steps) or (
                is_last_batch and grad_state.count > 0
            )
            if should_apply_grads:
                # Applies the averaged gradients and resets the accumulator
                optimizer = apply_accumulated_gradients(optimizer, grad_state)
                update_step += 1

            # Evaluation and sampling
            if global_step > 0 and global_step % config.eval_freq == 0 and val_loader:
                print(
                    f"\n--- Evaluating at update step {update_step} (batch step {global_step}) ---"
                )

                # Validation
                total_val_loss = 0.0
                val_batch_count = 0

                for val_batch in tqdm(val_loader, desc="Validation", leave=False):
                    # Generate a new RNG key for evaluation
                    rng_key, eval_key = jrandom.split(rng_key)

                    # Perform evaluation step
                    val_loss, rng_key = eval_step(model, diffusion, val_batch, rng_key=eval_key)

                    # Update validation metrics
                    total_val_loss += val_loss
                    val_batch_count += 1

                # Calculate average validation loss
                avg_val_loss = (
                    float(total_val_loss / val_batch_count)
                    if val_batch_count > 0
                    else float("inf")
                )
                val_losses.append((global_step, avg_val_loss))
                print(f"Step: {global_step}, Avg Validation Loss: {avg_val_loss:.4f}")

                # Generate a sample structure
                print("Generating a sample structure...")

                # Sample a chain of config.max_seq_length residues, 4 points each
                sample_len_res = config.max_seq_length
                num_sample_points = sample_len_res * 4
                sample_shape = 1, num_sample_points, 3

                # Create a mask for this length (all valid points)
                sample_point_mask = jnp.ones(sample_shape[:-1])

                # Split RNG key for sampling
                rng_key, sample_key = jrandom.split(rng_key)

                # Sampling is best-effort: a failure here is reported but must
                # not abort the training run.
                try:
                    generated_samples = diffusion.sample(
                        inference_forward_fn,
                        sample_shape,
                        sample_point_mask,
                        rng_key=sample_key,
                    )

                    # Last entry of the returned trajectory, first batch element
                    final_sample_flat = np.array(generated_samples[-1][0])
                    final_mask_flat = np.array(sample_point_mask[0])

                    # Convert flat coordinates back to Protein object
                    protein_sample = flat_coords_to_protein(
                        final_sample_flat, final_mask_flat, sample_len_res
                    )

                    # Convert to PDB string format
                    pdb_string = to_pdb(protein_sample)

                    if pdb_string:  # Check if PDB string was generated successfully
                        print("--- Generated Sample PDB (first 15 lines) ---")
                        print("\n".join(pdb_string.splitlines()[:15]))
                        print("---------------------------------------------")

                        # Save PDB file
                        pdb_filename = os.path.join(
                            config.output_dir, f"sample_step_{global_step}.pdb"
                        )

                        with open(pdb_filename, "w") as f:
                            f.write(pdb_string)
                        print(f"Saved sample PDB to {pdb_filename}")

                        # Try to visualize if in interactive environment
                        try:
                            if COLAB_ENV:
                                print("Plotting loss curves...")
                                plot_loss_curves(train_losses, val_losses, current_step=global_step)
                                visualize_pdb(pdb_string)

                        except ImportError:
                            print("py3Dmol not available for visualization")
                    else:
                        print("PDB string generation failed for the sample.")
                except Exception as e:
                    print(f"Error generating, saving, or visualizing sample: {e}")
                    traceback.print_exc()

            # Save model checkpoint, at most once per parameter update: without
            # the should_apply_grads guard the same update_step would be saved
            # again on every batch of the following accumulation window.
            if should_apply_grads and update_step % config.save_freq == 0:
                # Save using Orbax checkpoint manager
                save_checkpoint(checkpoint_manager, model, update_step)
                print(f"Saved checkpoint at update step {update_step}")

        # End of epoch logging
        avg_epoch_loss = epoch_loss / train_batch_count if train_batch_count > 0 else float("inf")
        print(f"--- Epoch {epoch + 1} Finished ---")
        print(f"Average Training Loss for Epoch: {avg_epoch_loss:.4f}")
        print("-----------------------------")

    # Calculate total training time
    end_time = time.time()
    print(f"Total training time: {(end_time - start_time) / 60:.2f} minutes")

    # Final sample generation (similar to the evaluation sample generation)
    print("\n--- Generating final sample ---")
    # NOTE(review): fixed length, independent of config.max_seq_length —
    # confirm the model supports 150 residues when max_seq_length is smaller.
    sample_len_res = 150
    num_sample_points = sample_len_res * 4
    sample_shape = (1, num_sample_points, 3)
    sample_point_mask = jnp.ones(sample_shape[:-1])

    rng_key, final_sample_key = jrandom.split(rng_key)

    # Best-effort as well: a sampling failure must not discard the trained
    # model and the recorded losses.
    try:
        generated_samples = diffusion.sample(
            inference_forward_fn, sample_shape, sample_point_mask, rng_key=final_sample_key
        )

        final_sample_flat = np.array(generated_samples[-1][0])
        final_mask_flat = np.array(sample_point_mask[0])

        protein_sample = flat_coords_to_protein(final_sample_flat, final_mask_flat, sample_len_res)

        pdb_string = to_pdb(protein_sample)

        if pdb_string:
            pdb_filename = os.path.join(config.output_dir, f"final_sample_{sample_len_res}res.pdb")

            with open(pdb_filename, "w") as f:
                f.write(pdb_string)
            print(f"Saved final sample PDB ({sample_len_res} residues) to {pdb_filename}")

            visualize_pdb(pdb_string)
        else:
            print("Final PDB string generation failed.")
    except Exception as e:
        print(f"Error generating final sample: {e}")
        traceback.print_exc()

    return model, train_losses, val_losses

# Load CATH data

In [None]:
# @title Load CATH data

# Detect Google Colab by trying to import its Drive helper. Outside Colab the
# import fails and `drive` is set to None, so later cells can test either name.
try:
    from google.colab import drive
except ImportError:
    drive = None

COLAB_ENV = drive is not None


def read_cath_data() -> tuple[pd.DataFrame, pd.DataFrame]:
    """Mount Google Drive and read the CATH chain set and its split definition.

    Loading is best-effort: any failure is reported on stdout (with a
    traceback for unexpected errors) and whatever was not read stays an empty
    DataFrame.

    Returns:
        ``(df, cath_splits)`` read from ``chain_set.jsonl`` and
        ``chain_set_splits.json`` in the ``CATH_ml_takehome`` folder of
        'My Drive'.
    """
    df, cath_splits = pd.DataFrame(), pd.DataFrame()
    data_path_prefix = "/content/drive/MyDrive/CATH_ml_takehome/"
    try:
        # Outside Colab there is no Drive helper; say so instead of failing
        # with an AttributeError on None.
        if drive is None:
            print("Google Drive is not available (not running in Colab); no data loaded.")
            return df, cath_splits

        drive.mount("/content/drive", force_remount=True)

        # Check if path exists
        if not os.path.exists(data_path_prefix):
            raise FileNotFoundError(f"Data path not found: {data_path_prefix}")
        print("Google Drive mounted successfully.")

        print("Reading chain_set.jsonl...")
        df = pd.read_json(os.path.join(data_path_prefix, "chain_set.jsonl"), lines=True)
        cath_splits = pd.read_json(
            os.path.join(data_path_prefix, "chain_set_splits.json"), lines=True
        )

        print("Read data.")
    except FileNotFoundError as e:
        print(f"Error: {e}")
        print("Please ensure the CATH_ml_takehome folder exists in 'My Drive'.")
    except Exception as e:
        print(f"An error occurred during data loading: {e}")
        traceback.print_exc()

    return df, cath_splits

# Initialize model and data

In [None]:
# @title Initialize model and data


def initialize_model(
    config: TrainingConfig,
) -> tuple[ProteinDiffusionModel, DiffusionProcess, jaxtyping.PRNGKeyArray]:
    """Build the model selected by the global ``model_type`` and the diffusion process.

    Args:
        config: Training configuration; ``config.seed`` seeds the PRNG key.

    Returns:
        ``(model, diffusion, rng_key)`` where ``rng_key`` is the key left over
        after the keys for model initialisation were split off.

    Raises:
        NotImplementedError: For the EGNN model type.
        ValueError: For an unknown ``model_type``.
    """
    rng_key = jrandom.PRNGKey(config.seed)

    # This first split used to feed an nnx.Rngs object that was overwritten
    # before use. The split itself is kept so that the keys derived below, and
    # therefore the seeded initialisation, stay the same.
    rng_key, _, _ = jrandom.split(rng_key, 3)

    # Key stream used to initialise the model parameters.
    rng_key, model_key = jrandom.split(rng_key)
    rngs = nnx.Rngs(params=model_key)

    # Initialize model
    if model_type == "Point Cloud Protein Diffusion Model":
        model = PointCloudProteinDiffusionModel(config, rngs=rngs)
    elif model_type == "EGNN Protein Diffusion Model":
        # model = EGNNProteinDiffusionModel(config, rngs=rngs)
        raise NotImplementedError("EGNN model not implemented yet.")
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Initialize diffusion process
    diffusion = DiffusionProcess(config)

    return model, diffusion, rng_key


def read_data_from_colab() -> pd.DataFrame:
    """Read the CATH data and annotate every chain with its split and length.

    Adds two columns to the frame returned by ``read_cath_data``:

    * ``split``: ``"train"``, ``"validation"`` or ``"test"`` (checked in that
      order) when the chain name is listed in the first row of the matching
      column of the splits frame, otherwise the string ``"None"``.
    * ``seq_len``: length of ``seq`` when it is a string, otherwise 0.

    Returns:
        The annotated DataFrame.
    """
    df, cath_splits = read_cath_data()

    def split_members(split_name):
        """Return the chain names listed for ``split_name`` (empty if absent)."""
        if split_name not in cath_splits or len(cath_splits) == 0:
            return frozenset()
        names = cath_splits[split_name].iloc[0]
        if isinstance(names, (list, tuple, set)):
            try:
                # Hash-based lookup instead of scanning the list for every chain.
                return frozenset(names)
            except TypeError:
                # Unhashable entries: fall back to the container itself.
                return names
        return names if names else frozenset()

    # Resolved once, in lookup-precedence order.
    members_by_split = [
        (split_name, split_members(split_name))
        for split_name in ("train", "validation", "test")
    ]

    # Function to determine split for each PDB
    def get_split(pdb_name):
        for split_name, members in members_by_split:
            if pdb_name in members:
                return split_name
        return "None"

    # Add split and sequence length information
    df["split"] = df.name.apply(get_split)
    df["seq_len"] = df.seq.apply(lambda x: len(x) if isinstance(x, str) else 0)

    return df

# Main

In [None]:
# @title Main Execution


# --- Main Execution ---
def main(config):
    """Run the whole pipeline: seed, build the model, load data, train, plot.

    Data loading is only implemented for Google Colab with Drive available.

    Args:
        config: Training configuration (seed, output directory, batch sizes,
            maximum sequence length, ...).

    Returns:
        ``(model, diffusion)``. The model is the trained one when training
        ran, otherwise the freshly initialised one.

    Raises:
        NotImplementedError: When not running in Colab / without Google Drive.
    """
    # Set random seed for reproducibility
    random.seed(config.seed)
    np.random.seed(config.seed)

    # Create output directory
    os.makedirs(config.output_dir, exist_ok=True)

    model, diffusion, rng_key = initialize_model(config)

    # Count model parameters
    params_state = nnx.state(model, nnx.Param)
    param_count = sum(p.size for p in jax.tree_util.tree_leaves(params_state))
    print(f"Total trainable parameters: {param_count:,}")

    train_loader, val_loader = None, None  # Initialize to None

    # Load data if in Colab environment
    if COLAB_ENV and drive:
        df = read_data_from_colab()
        # Filter based on split
        df_train = df[df.split == "train"].reset_index(drop=True)
        df_val = df[df.split == "validation"].reset_index(drop=True)

        print(f"Total train PDBs found: {len(df_train)}")
        print(f"Total validation PDBs found: {len(df_val)}")

        # Create datasets using our custom DatasetFromDataframe class
        train_set = DatasetFromDataframe(df_train, max_seq_length=config.max_seq_length)
        val_set = DatasetFromDataframe(df_val, max_seq_length=config.max_seq_length)

        # Get new RNG keys for data loaders
        rng_key, train_key, val_key = jrandom.split(rng_key, 3)

        # Create JAX DataLoaders only if datasets are not empty
        if len(train_set) > 0:
            train_loader = JAXDataLoader(
                train_set,
                batch_size=config.batch_size,
                shuffle=True,
                drop_last=False,
                rng_key=train_key,
            )
            print(f"Train DataLoader created with {len(train_set)} valid samples.")

            # Check a sample batch to validate the data pipeline
            try:
                sample_batch = next(iter(train_loader))
                print("Sample train batch keys:", list(sample_batch.keys()))
                print("Sample positions shape:", sample_batch["positions"].shape)  # [B, N_res*4, 3]
                print("Sample mask shape:", sample_batch["mask"].shape)  # [B, N_res*4]
                print(
                    "Sample residue index shape:", sample_batch["residue_index"].shape
                )  # [B, N_res*4]
                print("Sample residue mask shape:", sample_batch["res_mask"].shape)  # [B, N_res]
            except StopIteration:
                print("Warning: Train DataLoader is empty after filtering.")
                train_loader = None
            except Exception as e:
                print(f"Error fetching sample batch from train_loader: {e}")
                train_loader = None  # Disable loader if fetching fails
        else:
            print("Warning: Training dataset is empty after preprocessing.")
            train_loader = None

        if len(val_set) > 0:
            val_loader = JAXDataLoader(
                val_set,
                batch_size=config.eval_batch_size,
                shuffle=False,
                drop_last=False,
                rng_key=val_key,
            )
            print(f"Validation DataLoader created with {len(val_set)} valid samples.")
        else:
            print("Warning: Validation dataset is empty after preprocessing.")
            val_loader = None

    else:
        print("Not running in Colab or Google Drive not available. Data loading skipped.")
        print("Please adapt data loading if running locally.")

        # For local environment, you could implement alternative data loading logic here
        # Example:
        # local_data_path = "path/to/local/data"
        # if os.path.exists(local_data_path):
        #     # Load and process data similarly to the Colab approach

        raise NotImplementedError("Data loading not implemented for local environment.")

    # Perform training if datasets are available
    if train_loader and val_loader:
        model, train_losses, val_losses = train(
            config, model, diffusion, train_loader, val_loader, rng_key=rng_key
        )

        # Plot training and validation losses if data exists
        if train_losses:
            plot_filename = os.path.join(config.output_dir, "loss_plot.png")
            plot_loss_curves(train_losses, val_losses, plot_filename=plot_filename)

    else:
        # The two literals are concatenated, so the first one must end with a
        # space to keep the words apart.
        print(
            "Could not start training because the training data loader or validation data loader "
            "is empty or failed to initialize."
        )

    return model, diffusion


# Entry point: build a default TrainingConfig and run the full pipeline.
# `config`, `model` and `diffusion` are bound at module level by this block.
if __name__ == "__main__":
    config = TrainingConfig()
    model, diffusion = main(config)