# Constrained Search tutorial

This is a complete tutorial going through the process of locating the 40S small ribosomal subunit using the constrained orientation search from identified 60S large ribosomal subunits.
We step through the process of generating the reference template maps, running the template matching program, optimizing the template, and finally running the constrained orientation search.
Input data for this tutorial as well as the intermediary result files can be found at [this Zenodo dataset 10.5281/zenodo.15368246](https://zenodo.org/records/15368246).
We will be using the micrograph with filename `xenon_131_000_0.0_DWS.mrc`, but this process can be done for a whole cohort of images.

*Note: Some of the file links in the text are relative paths and may not work for the online documentation. Downloading the notebook and required data locally will fix this, or just inspect the files directly on Zenodo.*

### Tutorial Requirements

The following Python libraries are required:

* Leopard-EM v1.0 or above
* matplotlib
* TODO

You will also need a CUDA capable GPU for running some of the analyses.
However, you can alternatively download the intermediary results from Zenodo instead of running the GPU programs.

In [None]:
# Run this code cell to install required packages
# !pip install leopard-em matplotlib
# TODO: test this and verify which packages are needed

In [None]:
import warnings

import matplotlib.pyplot as plt
import mmdf
import mrcfile
import numpy as np
import pandas as pd
import roma
import torch
from IPython.display import Markdown, display
from scipy.spatial.transform import Rotation
from ttsim3d.models import Simulator, SimulatorConfig

from leopard_em.analysis.zscore_metric import gaussian_noise_zscore_cutoff
from leopard_em.pydantic_models.managers import (
    ConstrainedSearchManager,
    MatchTemplateManager,
    OptimizeTemplateManager,
    RefineTemplateManager,
)

## 1. Download and pre-process required data

The following cells download and pre-process all the data needed for this tutorial.
This will also create a directory structure to save the micrographs, models, maps, and configuration files.

We also include a few visualizations to see what data we are working with.

In [None]:
import os

import requests


def download_zenodo_file(url: str, out_dir: str) -> str:
    """Helper function to download a file hosted on Zenodo from a URL to given dir."""
    output_filename = url.split("/")[-1]

    response = requests.get(url, stream=True)
    response.raise_for_status()  # Check for request errors

    with open(f"{out_dir}/{output_filename}", "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)

    return output_filename

In [None]:
# fmt: off
file_downloads = [
    ("mgraphs", "https://zenodo.org/records/15368246/files/xenon_131_000_0.0_DWS.mrc"),
    ("models",  "https://zenodo.org/records/15368246/files/60S_aligned.pdb"),
    ("models",  "https://zenodo.org/records/15368246/files/6q8y_aligned.pdb"),
    ("models",  "https://zenodo.org/records/15368246/files/6q8y_SSU_no_head_aligned.pdb"),
    ("models",  "https://zenodo.org/records/15368246/files/3j77_SSU_aligned_zero.pdb"),
    ("models",  "https://zenodo.org/records/15368246/files/3j78_SSU_aligned_zero.pdb"),
    ("configs", "https://zenodo.org/records/15368246/files/match_template_config_crop.yaml"),
    ("configs", "https://zenodo.org/records/15368246/files/optimize_template_config.yaml"),
    ("configs", "https://zenodo.org/records/15368246/files/match_template_config_60S.yaml"),
    ("configs", "https://zenodo.org/records/15368246/files/match_template_config_40S.yaml"),
    ("configs", "https://zenodo.org/records/15368246/files/refine_template_config_60S.yaml"),
    ("configs", "https://zenodo.org/records/15368246/files/constrained_config_step1.yaml"),
    ("configs", "https://zenodo.org/records/15368246/files/constrained_config_step2.yaml"),
    ("configs", "https://zenodo.org/records/15368246/files/constrained_config_step3.yaml"),
    ("configs", "https://zenodo.org/records/15368246/files/constrained_config_step4.yaml"),
]
# fmt: on

# Loop through files list and download each
for out_dir, file_url in file_downloads:
    # Create the directory if it doesn't exist
    os.makedirs(out_dir, exist_ok=True)

    # Skip if the file already exists
    fname = file_url.split("/")[-1]
    if os.path.exists(f"{out_dir}/{fname}"):
        print(f"Skipped {fname}, it already exists in {out_dir}.")
        continue

    # Download the file
    filename = download_zenodo_file(file_url, out_dir)
    print(f"Downloaded {filename} to {out_dir}")

### Plot of the micrograph

Here, we use the `mrcfile` package to read a .mrc file into a numpy array.
Then, we visualize the micrograph with `matplotlib`.

In [None]:
# Read image into numpy array called 'data'
data = mrcfile.open("mgraphs/xenon_131_000_0.0_DWS.mrc", mode="r").data.copy()

# Plot the greyscale image
plt.figure(figsize=(10, 10))
plt.imshow(data, cmap="gray")
plt.axis("off")
plt.show()

### Pre-process PDB files

We've downloaded a PDB model of the 80S ribosome in the non-rotated state `6q8y_aligned.pdb`.
In addition to this, we have two additional PDB files which correspond to the 40S (`6q8y_SSU_no_head_aligned.pdb`) and 60S (`60S_aligned.pdb`) ribosomal subunits; these were generated externally using the ChimeraX program from the full non-rotated ribosome model.
For the 40S subunit, the head domain has also been removed to leave only the body.
Both the 40S and 60S models have been pre-aligned with respect to the 80S model using the matchmaker function in ChimeraX so the relative positions and orientations of the models match each other.

#### Center PDB models

The 80S PDB file is shifted such that the average atomic position is located at $(0, 0, 0)$.
This same shift is applied to the 40S and 60S models so they remain aligned throughout, and all transformed PDB files are written back to disk with an `_aligned_zero` suffix.
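The reason for applying one shared shift can be seen in a small sketch. The coordinates below are made-up toy values (not taken from the PDB files) and the helper names `centroid` and `shift` are illustrative only; the point is that centering the reference while translating the other model by the same vector leaves their relative offset unchanged.

```python
# Toy coordinates in Angstroms; hypothetical values for illustration only
reference = [(10.0, 0.0, 0.0), (14.0, 2.0, 0.0), (12.0, 4.0, 6.0)]
other = [(20.0, 1.0, 1.0), (22.0, 3.0, 5.0)]


def centroid(coords):
    """Return the mean (x, y, z) position of a list of coordinates."""
    n = len(coords)
    return tuple(sum(c[i] for c in coords) / n for i in range(3))


def shift(coords, vector):
    """Translate every coordinate by the same vector."""
    return [tuple(c[i] + vector[i] for i in range(3)) for c in coords]


# The shift that moves the reference centroid to the origin
center = centroid(reference)
shift_vector = tuple(-v for v in center)

reference_zero = shift(reference, shift_vector)
other_zero = shift(other, shift_vector)

# Offset between the two models before and after the shared shift
offset_before = tuple(
    b - a for a, b in zip(centroid(reference), centroid(other))
)
offset_after = tuple(
    b - a for a, b in zip(centroid(reference_zero), centroid(other_zero))
)

print("Reference centroid after shift:", centroid(reference_zero))
print("Offset before:", offset_before)
print("Offset after: ", offset_after)
```

If each model were centered on its own centroid instead, the offset would collapse to zero and the relative placement of the subunits would be lost.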

In [None]:
def center_pdb_files(pdb_ref: str, pdb_A: str, pdb_B: str) -> None:
    """Transform reference PDB file to average atomic position of (0, 0, 0).

    The same transformation is applied to the other PDB files, A and B, and all files
    are saved with a new '_aligned_zero' suffix.

    Parameters
    ----------
    pdb_ref : str
        Path to reference PDB file to center.
    pdb_A : str
        Additional PDB file to also transform based on reference centering.
    pdb_B : str
        Additional PDB file to also transform based on reference centering.
    """
    # Load PDB models into DataFrame objects
    df_ref = mmdf.read(pdb_ref)
    df_A = mmdf.read(pdb_A)
    df_B = mmdf.read(pdb_B)

    # Extract atom coordinates from reference PDB. Shape of (n_atoms, 3)
    coords = df_ref[["x", "y", "z"]].values
    center = np.mean(coords, axis=0)

    print(f"Center of reference PDB: {center}")

    # Now apply the centering transformation to PDB files
    shift_vector = -center
    df_ref[["x", "y", "z"]] += shift_vector
    df_A[["x", "y", "z"]] += shift_vector
    df_B[["x", "y", "z"]] += shift_vector

    # Save the transformed PDB files with a new name
    mmdf.write(pdb_ref.replace(".pdb", "_aligned_zero.pdb"), df_ref)
    mmdf.write(pdb_A.replace(".pdb", "_aligned_zero.pdb"), df_A)
    mmdf.write(pdb_B.replace(".pdb", "_aligned_zero.pdb"), df_B)


# Center the PDB files
center_pdb_files(
    pdb_ref="models/60S_aligned.pdb",
    pdb_A="models/6q8y_aligned.pdb",
    pdb_B="models/6q8y_SSU_no_head_aligned.pdb",
)

## 2. Initial match template with 60S model

Since we want to constrain the search space for a 40S small subunit (SSU) using the 60S large subunit (LSU), we first need to run full-orientation match template on the LSU model.
We will go through the steps of configuring and running the match template program in Python.
Further details about the match template program in Leopard-EM are located [here on the documentation](TODO-link).

### Generating 3D maps from models

The template matching program requires simulated 3D maps to generate projections from, and below we use the [ttsim3d](https://github.com/teamtomo/ttsim3d) Python package to generate these maps.
For a different dataset/structure, these simulation configurations need to be changed.
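One quick check when adapting the simulation to another structure is whether the particle fits inside the simulated box. The sketch below is plain arithmetic on the `volume_shape` and `pixel_spacing` values used in this tutorial; the particle extent of 300 Angstroms is a hypothetical placeholder, not a measured value, and the helper name is illustrative.

```python
def box_extent_angstroms(n_voxels: int, pixel_spacing: float) -> float:
    """Physical edge length of a cubic volume, in Angstroms."""
    return n_voxels * pixel_spacing


# Values used for the simulation in this tutorial
extent = box_extent_angstroms(n_voxels=512, pixel_spacing=0.95)

# Hypothetical maximum particle extent in Angstroms (assumption for illustration)
particle_extent = 300.0

# Empty margin left on each side of the particle
margin = (extent - particle_extent) / 2

print(f"Box edge length: {extent:.1f} A")
print(f"Margin per side: {margin:.1f} A")
```

A negative margin would mean the box is too small for the chosen pixel spacing and the volume shape should be increased.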

In [None]:
# Making a directory to save 3D map files
os.makedirs("maps", exist_ok=True)

In [None]:
# Instantiate the simulation configuration object
sim_conf = SimulatorConfig(
    voltage=300.0,  # in kV
    apply_dose_weighting=True,
    dose_start=0.0,  # in e-/A^2
    dose_end=50.0,  # in e-/A^2
    dose_filter_modify_signal="rel_diff",
    upsampling=-1,  # auto
    mtf_reference="falcon4EC_300kv",
)

# Instantiate the simulator
sim = Simulator(
    pdb_filepath="models/60S_aligned_aligned_zero.pdb",
    pixel_spacing=0.95,  # Angstroms
    volume_shape=(512, 512, 512),
    center_atoms=False,
    remove_hydrogens=True,
    b_factor_scaling=0.5,  # Multiply model b-factors by 1/2
    additional_b_factor=0,
    simulator_config=sim_conf,
)

# Run the simulation and write the output to a file
# We will read this file into memory later
mrc_filepath = "maps/60S_map_px0.95_bscale0.5.mrc"
sim.export_to_mrc(mrc_filepath)

### Plotting a z-slice of the simulated map

Just for visualization purposes, we plot the central z-slice of the simulated 60S map.

In [None]:
volume = mrcfile.open(mrc_filepath, mode="r").data.copy()

plt.imshow(volume[256, :, :], cmap="gray")
plt.axis("off")
plt.show()

### Initial test template matching run

Below, we run the match template program on a small image patch. This gives us a few peaks that we can use to optimize our template simulation before proceeding.
First, we crop out a central 1k by 1k patch from our 4k by 4k image and save it as a new MRC file.
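The index arithmetic for a centered crop can be sketched in isolation. The helper `central_crop_bounds` is a hypothetical name introduced here for illustration; it reproduces the same start and stop indices as slicing with `shape // 2 - 512` and `shape // 2 + 512` on a 4096-pixel axis.

```python
def central_crop_bounds(length: int, crop: int) -> tuple[int, int]:
    """Return (start, stop) indices of a centered window along one axis."""
    if crop > length:
        raise ValueError("crop size cannot exceed the axis length")
    start = length // 2 - crop // 2
    return start, start + crop


# A 1024-pixel window centered in a 4096-pixel axis
rows = central_crop_bounds(4096, 1024)
cols = central_crop_bounds(4096, 1024)

print("Row bounds:", rows)
print("Col bounds:", cols)
```

The resulting bounds can be used directly as `data[rows[0]:rows[1], cols[0]:cols[1]]` on a 2D array.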

In [None]:
data = mrcfile.open("mgraphs/xenon_131_000_0.0_DWS.mrc", mode="r").data.copy()

# Crop out a central (1024, 1024) region of the image
data_cropped = data[
    data.shape[0] // 2 - 512 : data.shape[0] // 2 + 512,
    data.shape[1] // 2 - 512 : data.shape[1] // 2 + 512,
]

# Save the cropped image to a new MRC file
# NOTE: This is not updating any of the header information
output_filename = "mgraphs/xenon_131_000_0.0_DWS_cropped_4.mrc"
with mrcfile.new(output_filename, overwrite=True) as mrc:
    mrc.set_data(data_cropped)

Below, we set up and run a full-orientation match template run based on the downloaded [configuration file](./configs/match_template_config_crop.yaml).
The programs section of the Leopard-EM documentation contains detailed explanations for each of these fields, so we will continue by running match template.

**Note: This config assumes you have 4 GPUs on your system! You may need to change the `gpu_ids` field depending on your system!**
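A small sketch for checking how many GPUs are visible before editing the config. It assumes `gpu_ids` is a list of integer device indices (check the Leopard-EM documentation for the exact format); the helper `usable_gpu_ids` is a hypothetical name, and `torch` is only used to count devices.

```python
def usable_gpu_ids(requested: list[int], n_available: int) -> list[int]:
    """Keep only the requested device indices that exist on this machine."""
    return [i for i in requested if 0 <= i < n_available]


try:
    import torch  # only needed to query the number of CUDA devices

    n_gpus = torch.cuda.device_count()
except ImportError:
    n_gpus = 0  # fall back gracefully if torch is not installed

print(f"Detected {n_gpus} CUDA device(s)")
print("Usable ids out of [0, 1, 2, 3]:", usable_gpu_ids([0, 1, 2, 3], n_gpus))
```

If the printed list is shorter than the one in the YAML file, reduce the `gpu_ids` entry accordingly.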

In [None]:
# Make directory to save program results
os.makedirs("results", exist_ok=True)

Now we run the match template program (this may take around 1 hour depending on GPU hardware).
Alternatively, you can skip the next cell and uncomment the code which downloads the already processed results from Zenodo.

In [None]:
YAML_CONFIG_PATH = "configs/match_template_config_crop.yaml"
ORIENTATION_BATCH_SIZE = 16


def run_match_template_cropped_image():
    """Main function to run the match template program on the cropped image."""
    mt_manager = MatchTemplateManager.from_yaml(YAML_CONFIG_PATH)
    mt_manager.run_match_template(ORIENTATION_BATCH_SIZE)
    df = mt_manager.results_to_dataframe(locate_peaks_kwargs={"false_positives": 1.0})
    df.to_csv("results/results_match_template_crop.csv")


# NOTE: invoking from `if __name__ == "__main__"` is necessary
# for proper multiprocessing/GPU-distribution behavior
if __name__ == "__main__":
    run_match_template_cropped_image()

In [None]:
# # Uncomment this to download the cropped match template result files
# # fmt: off
# file_downloads = [
#     ("results", "https://zenodo.org/records/15368246/files/results_match_template_60S.csv"),
#     ("results", "https://zenodo.org/records/15368246/files/output_correlation_average.mrc"),
#     ("results", "https://zenodo.org/records/15368246/files/output_correlation_variance.mrc"),
#     ("results", "https://zenodo.org/records/15368246/files/output_mip.mrc"),
#     ("results", "https://zenodo.org/records/15368246/files/output_orientation_phi.mrc"),
#     ("results", "https://zenodo.org/records/15368246/files/output_orientation_psi.mrc"),
#     ("results", "https://zenodo.org/records/15368246/files/output_orientation_theta.mrc"),
#     ("results", "https://zenodo.org/records/15368246/files/output_relative_defocus.mrc"),
#     ("results", "https://zenodo.org/records/15368246/files/output_scaled_mip.mrc"),
# ]
# # fmt: on

# # Loop through files list and download each
# for out_dir, file_url in file_downloads:
#     # Create the directory if it doesn't exist
#     os.makedirs(out_dir, exist_ok=True)

#     # Skip if the file already exists
#     fname = file_url.split("/")[-1]
#     if os.path.exists(f"{out_dir}/{fname}"):
#         print(f"Skipped {fname}, it already exists in {out_dir}.")
#         continue

#     # Download the file
#     filename = download_zenodo_file(file_url, out_dir)
#     print(f"Downloaded {filename} to {out_dir}")

## 3. Template Optimization

As detailed elsewhere, 2DTM is extremely sensitive to pixel size when simulating a reference map.
Deposited pixel sizes can be off by a few percent, either because the PDB model was built into a map with the wrong pixel size or because the microscope magnification was calibrated incorrectly (or both).

Before proceeding and to maximize our sensitivity, we will run the optimize template program on the cropped image match template run.

### Running the optimize template program

Like the match template program, the optimize template program is configured using a YAML file.
We've already downloaded the necessary [YAML file for optimize template](configs/optimize_template_config.yaml); the optimize template program and its configuration are described elsewhere in the documentation.

Alternatively, skip this and download the results files directly from Zenodo.

In [None]:
OPTIMIZE_YAML_PATH = "configs/optimize_template_config.yaml"


def run_optimize_template():
    """Main function to run the optimize template program."""
    ot_manager = OptimizeTemplateManager.from_yaml(OPTIMIZE_YAML_PATH)
    ot_manager.run_optimize_template(
        output_text_path="results/optimize_template_results.txt"
    )


# NOTE: invoking from `if __name__ == "__main__"` is necessary
# for proper multiprocessing/GPU-distribution behavior
if __name__ == "__main__":
    run_optimize_template()

In [None]:
# # Uncomment this to download the optimize template result files
# # fmt: off
# file_downloads = [
#     ("results", "https://zenodo.org/records/15368246/files/optimize_template_results.txt"),
#     ("results", "https://zenodo.org/records/15368246/files/optimize_template_results_all.csv"),
# ]

# # Loop through files list and download each
# for out_dir, file_url in file_downloads:
#     # Create the directory if it doesn't exist
#     os.makedirs(out_dir, exist_ok=True)

#     # Skip if the file already exists
#     fname = file_url.split("/")[-1]
#     if os.path.exists(f"{out_dir}/{fname}"):
#         print(f"Skipped {fname}, it already exists in {out_dir}.")
#         continue

#     # Download the file
#     filename = download_zenodo_file(file_url, out_dir)
#     print(f"Downloaded {filename} to {out_dir}")

By inspecting the optimize template result file, we see this gave us an optimized pixel size of $0.936$ Angstroms (actual results possibly $\pm 0.002$ off).
We will proceed using this pixel size rather than the original $0.95$ Angstroms value.
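To get a feel for the size of this correction, the arithmetic below compares the two pixel sizes. The 128-pixel distance from the particle center is a hypothetical value chosen for illustration; the idea is simply that a scale mismatch displaces template features further the farther they are from the center.

```python
original_px = 0.95  # Angstroms, value used for the first simulation
optimized_px = 0.936  # Angstroms, value reported by optimize template

# Fractional change in pixel size between the two simulations
relative_change = (original_px - optimized_px) / original_px

# Hypothetical distance from the particle center, in pixels
radius_px = 128
displacement_px = radius_px * relative_change

print(f"Relative pixel size change: {100 * relative_change:.2f}%")
print(f"Displacement at {radius_px} px from center: {displacement_px:.2f} px")
```

A change of about one and a half percent therefore shifts features near the particle edge by more than a pixel, which is consistent with the sensitivity described above.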

### Other parameter considerations

The other most obvious and important parameter to optimize is the contrast transfer function B-factor.
Since we don't need the most accurate map and highly optimized results for this tutorial, we will skip this for now and proceed with the default B-factor of $60.0$.

## 4. Re-simulating maps of reference templates

Now that we know the optimized pixel size of $0.936$ Angstroms, we will re-simulate the maps for both the 40S and 60S templates.
Note that the simulation conditions are the same, so we can re-use the simulation configuration object in both cases.

In [None]:
sim_conf = SimulatorConfig(
    voltage=300.0,  # in kV
    apply_dose_weighting=True,
    dose_start=0.0,  # in e-/A^2
    dose_end=50.0,  # in e-/A^2
    dose_filter_modify_signal="rel_diff",
    upsampling=-1,  # auto
    mtf_reference="falcon4EC_300kv",
)

In [None]:
# Instantiate the simulator for the centered 60S model
sim_60S = Simulator(
    pdb_filepath="models/60S_aligned_aligned_zero.pdb",
    pixel_spacing=0.936,  # Angstroms
    volume_shape=(512, 512, 512),
    center_atoms=False,
    remove_hydrogens=True,
    b_factor_scaling=0.5,
    additional_b_factor=0,
    simulator_config=sim_conf,
)

# Run the simulation and write the output to a file
mrc_filepath_60S = "maps/60S_map_px0.936_bscale0.5.mrc"
sim_60S.export_to_mrc(mrc_filepath_60S)

In [None]:
# Instantiate the simulator for the centered 40S model
sim_40S = Simulator(
    pdb_filepath="models/6q8y_SSU_no_head_aligned_aligned_zero.pdb",
    pixel_spacing=0.936,  # Angstroms
    volume_shape=(512, 512, 512),
    center_atoms=False,
    remove_hydrogens=True,
    b_factor_scaling=0.5,
    additional_b_factor=0,  # Add to all atoms
    simulator_config=sim_conf,
)

# Run the simulation and write the output to a file
mrc_filepath_40S = "maps/SSU-body_map_px0.936_bscale0.5.mrc"
sim_40S.export_to_mrc(mrc_filepath_40S)

## 5. Full-image match template

It's now time to run the match template program on the 60S ribosome map.
The YAML configuration file is similar to the one for the cropped image, but we've now updated the pixel size and referenced the full micrograph.

This next cell is a computationally expensive step and takes ~2.75 hours (wall time) on a machine equipped with 4xRTX A6000 ada GPUs.
Again, already processed results can be downloaded directly from Zenodo.

In [None]:
YAML_CONFIG_PATH = "configs/match_template_config_60S.yaml"
ORIENTATION_BATCH_SIZE = 8


def run_match_template_60S():
    """Main function to run the match template program."""
    mt_manager = MatchTemplateManager.from_yaml(YAML_CONFIG_PATH)
    mt_manager.run_match_template(ORIENTATION_BATCH_SIZE)
    df = mt_manager.results_to_dataframe(locate_peaks_kwargs={"false_positives": 1.0})
    df.to_csv("results/results_match_template_60S.csv")


# NOTE: invoking from `if __name__ == "__main__"` is necessary
# for proper multiprocessing/GPU-distribution behavior
if __name__ == "__main__":
    run_match_template_60S()

In [None]:
# Download the pre-computed full-image 60S match template result files
# fmt: off
file_downloads = [
    ("results", "https://zenodo.org/records/15368246/files/results_match_template_60S.csv"),
    ("results", "https://zenodo.org/records/15368246/files/output_correlation_average.mrc"),
    ("results", "https://zenodo.org/records/15368246/files/output_correlation_variance.mrc"),
    ("results", "https://zenodo.org/records/15368246/files/output_mip.mrc"),
    ("results", "https://zenodo.org/records/15368246/files/output_orientation_phi.mrc"),
    ("results", "https://zenodo.org/records/15368246/files/output_orientation_psi.mrc"),
    ("results", "https://zenodo.org/records/15368246/files/output_orientation_theta.mrc"),
    ("results", "https://zenodo.org/records/15368246/files/output_relative_defocus.mrc"),
    ("results", "https://zenodo.org/records/15368246/files/output_scaled_mip.mrc"),
]
# fmt: on

# Loop through files list and download each
for out_dir, file_url in file_downloads:
    # Create the directory if it doesn't exist
    os.makedirs(out_dir, exist_ok=True)

    # Skip if the file already exists
    fname = file_url.split("/")[-1]
    if os.path.exists(f"{out_dir}/{fname}"):
        print(f"Skipped {fname}, it already exists in {out_dir}.")
        continue

    # Download the file
    filename = download_zenodo_file(file_url, out_dir)
    print(f"Downloaded {filename} to {out_dir}")

### Looking at results from match template

The template matching process has found 408 peaks above the cutoff threshold.
However, this micrograph contains the edge of the lamella in the top-right corner, and the dark patch artificially inflates the z-scores in that region (low search variance).
We make a scatter plot of the peak locations superimposed on the original micrograph; points are colored by variance over search space.

In [None]:
# Read image into numpy array called 'data'
data = mrcfile.open("mgraphs/xenon_131_000_0.0_DWS.mrc", mode="r").data.copy()

# Get the x and y positions of particles from the results csv
df = pd.read_csv("results/results_match_template_60S.csv")
x = df["pos_x_img"].values
y = df["pos_y_img"].values
var = df["correlation_variance"].values

# Plot the greyscale image
plt.figure(figsize=(10, 8))
plt.imshow(data, cmap="gray")
plt.scatter(x, y, c=var, cmap="bwr", alpha=0.5)
plt.colorbar(label="Correlation Variance")
plt.xlabel("X Position (pixels)")
plt.ylabel("Y Position (pixels)")
plt.tight_layout()
plt.show()

Multiple data processing strategies can help account for these dark patches (or other imaging artifacts), including:

* Replacing that region in the image with Gaussian noise with the same mean and standard deviation as the rest of the image,
* Excluding peaks picked in the artifact region, and
* Filtering based on both the MIP and z-score.

For the purposes of this tutorial we simply impose that both MIP and z-score (scaled MIP) are above the cutoff threshold.
This will remove these false positives at the cost of increasing the false negative rate.
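The cutoff used here comes from a Gaussian noise model. As a rough sketch of the idea, and assuming the threshold is the one-sided Gaussian tail value at which the expected number of false positives equals the requested count, it can be computed with the standard library alone. The function name `zscore_cutoff` and the number of orientations and defocus planes are hypothetical; `gaussian_noise_zscore_cutoff` in Leopard-EM may differ in detail.

```python
from statistics import NormalDist


def zscore_cutoff(num_correlations: float, false_positives: float = 1.0) -> float:
    """Z-score exceeded `false_positives` times on average in pure Gaussian noise.

    Assumes independent, standard-normal correlation values.
    """
    tail_probability = false_positives / num_correlations
    # inv_cdf of the lower tail is negative; by symmetry the upper-tail
    # cutoff is its negation
    return -NormalDist().inv_cdf(tail_probability)


# Valid correlation positions for a 4096 px image and a 512 px template
n_positions = (4096 - 512 + 1) ** 2

# Hypothetical number of orientations times defocus planes searched
n_orientations_defocus = 1_500_000

cutoff_sketch = zscore_cutoff(n_positions * n_orientations_defocus)
print(f"Sketch cutoff: {cutoff_sketch:.3f}")
```

Requiring both the MIP and the scaled MIP to exceed this value is stricter than thresholding the z-score alone, which is why it trades false positives for false negatives.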

In [None]:
# Calculate the cutoff value
# Using the total_correlations from the first row since it's the same for all rows
# Also, manually calculating the number of pixels in the result files
num_ccg = df["total_correlations"].iloc[0] * (4096 - 512 + 1) ** 2
cutoff = gaussian_noise_zscore_cutoff(num_ccg, false_positives=1.0)

print(f"Cutoff value: {cutoff:.4f}")

filtered_df = df[(df["mip"] > cutoff) & (df["scaled_mip"] > cutoff)]
filtered_df.to_csv("results/results_match_template_60S_edit.csv")

# Print the original and filtered number of peaks
print(f"Original number of peaks: {len(df)}")
print(f"Filtered number of peaks: {len(filtered_df)}")

## 6. Running refine template

We will now run the refine template program on the identified 60S particles.
Note that this will not find additional LSU particles, but it will improve our estimates of the location and orientation for each already found particle.
The main difference between the two programs is that match template searches the entire micrograph over all orientations, whereas refine template operates on a particle stack of already-oriented particles and searches only local orientations around each one.

### Inspecting the refine template YAML file

We are again using a YAML file to configure our refine template run, and this time we will briefly inspect the YAML file.
The `particle_stack` field warrants some discussion.

Here, `particle_stack.df_path` is the path to the output csv file from the match template program; this csv file contains per-particle information on location and orientation, and we have already used this csv file to visualize results in the above plot.

The two other fields, `particle_stack.original_template_size` and `particle_stack.extracted_box_size`, are used to crop out boxes around a particle in the original image.
Our map was simulated as a $(512, 512, 512)$ box, so the generated projections and therefore the original template size field are both $[512, 512]$.
The extracted box size here is $[518, 518]$, six pixels wider and taller than the projections (the difference must be even), meaning the peak is constrained to a 6x6 region during the refinement.

Again, if your system has fewer/more GPUs, remember to adjust the `gpu_ids` field.

In [None]:
# Read the YAML file
with open("configs/refine_template_config_60S.yaml") as f:
    yaml_content = f.read()

# Display as markdown code block
display(Markdown(f"```yaml\n{yaml_content}\n```"))

### Running refine template

Template refinement is less computationally intensive than full match template, but the results can again be downloaded directly.

In [None]:
YAML_CONFIG_PATH = "configs/refine_template_config_60S.yaml"
DATAFRAME_OUTPUT_PATH = "results/results_refine_template_60S.csv"
PARTICLE_BATCH_SIZE = 64  # Adjust based on your GPU memory


def run_refine_template_60S():
    """Main function to run the refine template program."""
    rt_manager = RefineTemplateManager.from_yaml(YAML_CONFIG_PATH)

    # Ignore UserWarning during refinement call
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        rt_manager.run_refine_template(DATAFRAME_OUTPUT_PATH, PARTICLE_BATCH_SIZE)


# NOTE: Invoking program under `if __name__ == "__main__"` necessary for multiprocessing
if __name__ == "__main__":
    run_refine_template_60S()

In [None]:
# # Uncomment this to download the refine template result files
# # fmt: off
# file_downloads = [
#     ("results", "https://zenodo.org/records/15368246/files/results_refine_template_60S.csv"),
# ]

# # Loop through files list and download each
# for out_dir, file_url in file_downloads:
#     # Create the directory if it doesn't exist
#     os.makedirs(out_dir, exist_ok=True)

#     # Skip if the file already exists
#     fname = file_url.split("/")[-1]
#     if os.path.exists(f"{out_dir}/{fname}"):
#         print(f"Skipped {fname}, it already exists in {out_dir}.")
#         continue

#     # Download the file
#     filename = download_zenodo_file(file_url, out_dir)
#     print(f"Downloaded {filename} to {out_dir}")

## 7. Initial 40S match template search

The 40S SSU is smaller and more conformationally flexible than the 60S LSU, making it harder to identify using 2DTM.
Nonetheless, we will first do a full match template search for the 40S and look at the results.

Running an initial full match template search on the constrained particle is also necessary since we need the mean and variance of cross-correlation values over the search space for normalization.
Also, we are only searching for the SSU body since it is more rigid, which leads to higher quality 2DTM results.

In [None]:
YAML_CONFIG_PATH = "configs/match_template_config_40S.yaml"
ORIENTATION_BATCH_SIZE = 8


def run_match_template_40S():
    """Main function to run the match template program."""
    mt_manager = MatchTemplateManager.from_yaml(YAML_CONFIG_PATH)
    mt_manager.run_match_template(ORIENTATION_BATCH_SIZE)
    df = mt_manager.results_to_dataframe(locate_peaks_kwargs={"false_positives": 1.0})
    df.to_csv("results/results_match_template_40S.csv")


# NOTE: invoking from `if __name__ == "__main__"` is necessary
# for proper multiprocessing/GPU-distribution behavior
if __name__ == "__main__":
    run_match_template_40S()

In [None]:
# TODO: Download cell for pre-processed 40S results

In [None]:
# TODO: Scatter plot of the 40S results like for 60S

We've found 177 peaks, but many of these fall in the corner of the micrograph at the lamella edge, as visualized below.
Applying the same cutoff filtering process removes these false positives.

In [None]:
# TODO: code for filtering the 40S particles and re-plotting their locations along with
# 60S particles

Data for both the SSU and LSU have now been processed, and we could attempt to combine these results to classify complete 80S ribosomes based on rotation angle.
However, many of the SSU particles are obscured by noise and therefore unrecoverable.

A better strategy is a constrained search for the 40S particle based on the positions and orientations of already identified 60S particles.
By restricting the 2DTM search space (both location and orientation), we lower our noise floor thus increasing our sensitivity.
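The noise-floor argument can be written as a short worked relation. It assumes the normalized cross-correlation values of pure noise are independent and follow a standard normal distribution, which is an idealization rather than a statement about the exact Leopard-EM implementation. With $N$ the total number of correlations searched and $n_{\mathrm{fp}}$ the number of false positives we are willing to accept, the threshold $z$ satisfies

$$
n_{\mathrm{fp}} = N \cdot \tfrac{1}{2}\,\mathrm{erfc}\!\left(\frac{z}{\sqrt{2}}\right)
\quad\Longrightarrow\quad
z(N) = \sqrt{2}\,\mathrm{erfc}^{-1}\!\left(\frac{2\,n_{\mathrm{fp}}}{N}\right).
$$

Because $\mathrm{erfc}^{-1}$ grows as its argument shrinks, $z(N)$ increases with $N$. Searching fewer positions and orientations reduces $N$, so the same expected false-positive count is reached at a lower threshold, and weaker true peaks become detectable.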

## 8. Constrained Search

The 40S subunit has many degrees of freedom. It can rotate relative to the 60S, and within the 40S there are head and body domains that can 'swivel' with respect to one another.
For the purpose of this tutorial, we will ignore the latter and consider the 40S as one rigid body moving with respect to the 60S.

### ChimeraX processing

The first thing we need to do is find out the axis of rotation (between the LSU and SSU) using two PDB models, 3j77 and 3j78, that are in different rotational states.
These two models were first processed in ChimeraX using the matchmaker function to align them relative to the 60S model we template matched against.
After alignment, only the protein structures in the 40S subunit were selected and exported to the two files: `3j77_SSU_aligned_zero.pdb` and `3j78_SSU_aligned_zero.pdb` which were downloaded previously.

### Script for finding the rotation axis

We now run the following script, [`Leopard-EM/programs/constrained_search/utils/get_rot_axis.py`](https://github.com/Lucaslab-Berkeley/Leopard-EM/blob/main/programs/constrained_search/utils/get_rot_axis.py), which finds the relative rotation axis between two PDB models and calculates the rotation to align this rotation axis along the "Z" direction.
The script outputs the Euler angles needed for this alignment, which are a required input for the constrained search program.

In practice, you would run `python get_rot_axis.py <pdb_file1> <pdb_file2> <output_file>`, but to keep the tutorial self-contained, we copy the script functions below.
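Before reading the script, it can help to check the underlying geometry on a rotation whose axis is known. The sketch below is not the script's implementation (which relies on `roma` and SciPy); it is a standard-library sanity check using the relations $\mathrm{tr}(R) = 1 + 2\cos\alpha$ and the antisymmetric part of $R$. All function names are illustrative.

```python
import math


def rotation_about_x(angle_deg: float) -> list[list[float]]:
    """3x3 rotation matrix for a rotation about the x-axis."""
    a = math.radians(angle_deg)
    return [
        [1.0, 0.0, 0.0],
        [0.0, math.cos(a), -math.sin(a)],
        [0.0, math.sin(a), math.cos(a)],
    ]


def axis_angle_from_matrix(r):
    """Recover (axis, angle in degrees) from a rotation matrix.

    Valid away from 0 and 180 degrees, where sin(angle) vanishes.
    """
    cos_angle = (r[0][0] + r[1][1] + r[2][2] - 1.0) / 2.0
    angle = math.acos(max(-1.0, min(1.0, cos_angle)))
    s = 2.0 * math.sin(angle)
    axis = (
        (r[2][1] - r[1][2]) / s,
        (r[0][2] - r[2][0]) / s,
        (r[1][0] - r[0][1]) / s,
    )
    return axis, math.degrees(angle)


def axis_polar_angles(axis):
    """Azimuthal (phi) and polar (theta) angles of a unit axis, in degrees."""
    theta = math.degrees(math.acos(max(-1.0, min(1.0, axis[2]))))
    phi = math.degrees(math.atan2(axis[1], axis[0])) % 360.0
    return phi, theta


axis, angle = axis_angle_from_matrix(rotation_about_x(25.0))
phi, theta = axis_polar_angles(axis)

print("Axis:", axis)
print(f"Angle: {angle:.3f} degrees")
print(f"phi = {phi:.1f}, theta = {theta:.1f}")
```

For a rotation about the x-axis, the recovered axis lies in the xy-plane, so the polar angle from z is 90 degrees and the azimuthal angle is 0.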

In [None]:
def extract_rotation_axis_angle(
    rotmat: torch.Tensor | np.ndarray,
) -> tuple[np.ndarray, float]:
    """Extract rotation axis and angle from rotation matrix handling edge cases.

    Parameters
    ----------
    rotmat: torch.Tensor | np.ndarray
        The rotation matrix either as a torch tensor or numpy array.

    Returns
    -------
    tuple[np.ndarray, float]
        The rotation axis and angle with angle in units of radians.
    """
    rotmat = rotmat.numpy() if isinstance(rotmat, torch.Tensor) else rotmat

    rotation = Rotation.from_matrix(rotmat)
    rotvec = rotation.as_rotvec()

    angle = np.linalg.norm(rotvec)

    # Handle edge case for very small angles (near zero)
    if np.abs(angle) < 1e-6:
        return np.array([0.0, 0.0, 1.0]), angle

    # NOTE: Edge case for angles near 180 degrees handled by scipy internally
    axis = rotvec / angle

    return axis, angle


def calculate_axis_euler_angles(axis: torch.Tensor | np.ndarray) -> tuple[float, float]:
    """Calculate Euler angles (ZYZ) for the rotation axis.

    Parameters
    ----------
    axis: torch.Tensor | np.ndarray
        The rotation axis.

    Returns
    -------
    tuple[float, float]
        The Euler angles in units of degrees
    """
    z_axis = np.array([0.0, 0.0, 1.0], dtype=np.float32)
    axis = axis.numpy() if isinstance(axis, torch.Tensor) else axis

    # Edge case for axis already aligned with z-axis
    if np.linalg.norm(axis - z_axis) < 1e-6:
        return 0.0, 0.0

    # Edge case for axis anti-aligned with z-axis
    if np.linalg.norm(axis + z_axis) < 1e-6:
        return 0.0, 180.0

    # Calculate theta - angle from z-axis (polar angle)
    cos_theta = np.dot(axis, z_axis)
    theta = np.arccos(np.clip(cos_theta, -1.0, 1.0)) * 180 / np.pi

    # Calculate phi - angle in xy plane (azimuthal angle)
    phi = np.arctan2(axis[1], axis[0]) * 180 / np.pi
    if phi < 0:
        phi += 360.0  # Convert to 0-360 range

    return phi, theta


def process_pdb_files(
    pdb_file1: str, pdb_file2: str
) -> tuple[np.ndarray, float, float, float]:
    """Helper function to calculate the rotation axis and angle for two PDB files.

    Parameters
    ----------
    pdb_file1: str
        Path to the first PDB file.
    pdb_file2: str
        Path to the second PDB file.

    Returns
    -------
    tuple[np.ndarray, float, float, float]
        The rotation axis, rotation angle in degrees, and Euler angles (phi, theta).
    """
    # Read PDB files
    df1 = mmdf.read(pdb_file1)
    df2 = mmdf.read(pdb_file2)

    # Extract coordinates
    coords1 = torch.tensor(df1[["x", "y", "z"]].values, dtype=torch.float32)
    coords2 = torch.tensor(df2[["x", "y", "z"]].values, dtype=torch.float32)

    # Center coordinates
    centroid1 = coords1.mean(dim=0)
    centroid2 = coords2.mean(dim=0)
    coords1_centered = coords1 - centroid1
    coords2_centered = coords2 - centroid2

    # Calculate rotation matrix
    rotation_matrix, _ = roma.rigid_points_registration(
        coords1_centered, coords2_centered
    )

    # Extract rotation axis and angle plus Euler angles
    rotation_axis, rotation_angle = extract_rotation_axis_angle(rotation_matrix)
    phi, theta = calculate_axis_euler_angles(rotation_axis)

    # radians to degrees
    rotation_angle = np.rad2deg(rotation_angle)

    return rotation_axis, rotation_angle, phi, theta


def write_results(
    output_file: str,
    pdb_file1: str,
    pdb_file2: str,
    rotation_axis: np.ndarray,
    rotation_angle: float,
    phi: float,
    theta: float,
) -> None:
    """Helper function to write the script results to a file."""
    suggested_range = min(30.0, max(10.0, rotation_angle / 2))
    results_string = f"""# PDB Rotation Analysis Results\n
    Source PDB: {pdb_file1}
    Target PDB: {pdb_file2}

    ## Rotation Parameters
    Axis: {rotation_axis[0]:.6f} {rotation_axis[1]:.6f} {rotation_axis[2]:.6f}
    Angle: {rotation_angle:.6f} degrees\n

    ## Axis Orientation Angles (input for constrained search config)
    rotation_axis_euler_angles: [{phi:.2f}, {theta:.2f}, 0.0]\n

    ## Example constrained search config
    orientation_refinement_config:
      enabled: true
      out_of_plane_step: 1.0   # Step size around the rotation axis
      in_plane_step: 0.5       # Step size for fine adjustment angles
      rotation_axis_euler_angles: [{phi:.2f}, {theta:.2f}, 0.0]
      phi_min: -{suggested_range:.1f}  # Search range around the axis
      phi_max: {suggested_range:.1f}
      theta_min: -2.0  # Small adjustments perpendicular to axis (optional)
      theta_max: 2.0
      psi_min: -2.0    # Small in-plane adjustments (optional)
      psi_max: 2.0
    """

    # Print the script results to the console
    print(results_string)

    # And also write them to a file
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(results_string)

    print(f"Rotation analysis written to: {output_file}")

In [None]:
pdb_file1 = "models/3j77_SSU_aligned_zero.pdb"
pdb_file2 = "models/3j78_SSU_aligned_zero.pdb"
output_file = "results/rotation_axis.txt"

rotation_axis, rotation_angle, phi, theta = process_pdb_files(pdb_file1, pdb_file2)
write_results(
    output_file,
    pdb_file1,
    pdb_file2,
    rotation_axis,
    rotation_angle,
    phi,
    theta,
)

These Euler angles define the rotation matrix that aligns the rotation axis with the Z-axis, meaning the constrained search reduces to a simple angular search over a phi/psi range.
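
As a quick sanity check of this relationship, the sketch below rebuilds a rotation from `(phi, theta, 0)` and confirms that it carries the Z-axis onto the original axis. The axis used here is a made-up example, and the intrinsic ZYZ convention (uppercase `"ZYZ"` in scipy) is an assumption; whether Leopard-EM applies this matrix or its inverse internally is not shown here.

```python
import numpy as np
from scipy.spatial.transform import Rotation

# Hypothetical unit rotation axis, used only for illustration
axis = np.array([1.0, 2.0, 2.0]) / 3.0

# Same polar/azimuthal decomposition as in calculate_axis_euler_angles
theta = np.degrees(np.arccos(np.clip(axis[2], -1.0, 1.0)))
phi = np.degrees(np.arctan2(axis[1], axis[0])) % 360.0

# Intrinsic ZYZ rotation built from (phi, theta, 0); assumed convention
axis_rotation = Rotation.from_euler("ZYZ", [phi, theta, 0.0], degrees=True)

# Rotating the Z-axis by this matrix should reproduce the original axis
recovered_axis = axis_rotation.apply([0.0, 0.0, 1.0])
print(f"phi={phi:.2f}, theta={theta:.2f}")
print("Axis recovered:", np.allclose(recovered_axis, axis))
```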

### Determining relative defocus of constrained particle

As well as constraining orientations and $(x, y)$ positions, we can eliminate the defocus search.
The constrained search program uses the defocus and orientation of the 60S particle to work out the defocus of the 40S particle.
However, we need to provide the offset between the two subunits, which can be calculated using this script: [`Leopard-EM/programs/constrained_search/utils/get_center_vector.py`](https://github.com/Lucaslab-Berkeley/Leopard-EM/blob/main/programs/constrained_search/utils/get_center_vector.py).
The contents of the script are copied below to keep the tutorial self-contained.

In [None]:
def calculate_mean_position(df: pd.DataFrame) -> torch.Tensor:
    """Calculate the mean position of a PDB structure loaded into a dataframe."""
    coords = torch.tensor(df[["x", "y", "z"]].values, dtype=torch.float32)
    mean_pos = coords.mean(dim=0)

    return mean_pos


def calculate_relative_vectors(pdb_file1: str, pdb_file2: str) -> dict:
    """
    Calculate the relative position and orientation vectors between two PDB structures.

    Parameters
    ----------
    pdb_file1 : str
        Path to the first PDB file
    pdb_file2 : str
        Path to the second PDB file

    Returns
    -------
    dict
        Dictionary containing relative vector data including:
        - df1, df2: DataFrames for both PDB files
        - vector: Vector from PDB1 to PDB2
        - euler_angles: Phi, Theta, Psi angles
        - z_diff: Z height difference
        - defocus_description: Human-readable defocus description
    """
    # Parse PDB files using mmdf
    df1 = mmdf.read(pdb_file1)
    df2 = mmdf.read(pdb_file2)

    print(f"File 1: {pdb_file1} - {len(df1)} atoms")
    print(f"File 2: {pdb_file2} - {len(df2)} atoms")

    # Calculate mean positions at default orientation (0, 0, 0)
    mean_pos1 = calculate_mean_position(df1)
    mean_pos2 = calculate_mean_position(df2)

    # Calculate vector from PDB1 to PDB2
    vector = mean_pos2 - mean_pos1

    # Convert vector to Euler angles
    phi, theta, psi = roma.rotvec_to_euler(
        convention="ZYZ", rotvec=vector, degrees=True, as_tuple=True
    )

    # Calculate Z-height difference (defocus)
    z_diff = vector[2].item()
    defocus_description = (
        f"{abs(z_diff):.2f} Angstroms {'below' if z_diff < 0 else 'above'}"
    )

    # Print initial results
    initial_results = f"""Initial Analysis:
    Vector from PDB1 to PDB2: [{vector[0]:.6f}, {vector[1]:.6f}, {vector[2]:.6f}]
    Vector Euler angles (ZYZ, deg): Phi={phi:.2f}, Theta={theta:.2f}, Psi={psi:.2f}
    Z-height difference (defocus): {defocus_description}
    """
    print(initial_results)

    return {
        "df1": df1,
        "df2": df2,
        "vector": vector,
        "euler_angles": (phi, theta, psi),
        "z_diff": z_diff,
        "defocus_description": defocus_description,
    }


def process_rotations(vector: torch.Tensor, num_rotations: int) -> list:
    """
    Process each rotation and calculate the resulting defocus.

    Parameters
    ----------
    vector : torch.Tensor
        The original vector between structures
    num_rotations : int
        Number of random rotations to test (in addition to default orientation)

    Returns
    -------
    list
        List of dictionaries with defocus results for each rotation
    """
    print("\nDefocus changes for different rotations:")
    defocus_results = []

    for i in range(num_rotations + 1):
        if i == 0:
            rand_rotmat = torch.eye(3)
        else:
            rand_rotmat = roma.random_rotmat()

        rand_euler = roma.rotmat_to_euler("ZYZ", rand_rotmat, degrees=True)
        rotated_vector = rand_rotmat @ vector

        # Extract new z-component (defocus)
        new_z_diff = rotated_vector[2].item()
        new_defocus = (
            f"{abs(new_z_diff):.2f} Angstroms {'below' if new_z_diff < 0 else 'above'}"
        )
        print(f"Rotation #{i+1} - {rand_euler}: Defocus = {new_defocus}")

        defocus_results.append(
            {
                "rotation": i + 1,
                "euler_angles": [angle.item() for angle in rand_euler],
                "defocus": new_z_diff,
                "description": new_defocus,
            }
        )

    return defocus_results


def write_results_to_file(
    output_file: str,
    pdb_file1: str,
    pdb_file2: str,
    vector_info: dict,
    defocus_results: list,
) -> None:
    """
    Write analysis results to output file.

    Parameters
    ----------
    output_file : str
        Path to output file
    pdb_file1 : str
        Path to first PDB file
    pdb_file2 : str
        Path to second PDB file
    vector_info : dict
        Dictionary with vector data from calculate_relative_vectors
    defocus_results : list
        List of defocus results from process_rotations
    """
    vector = vector_info["vector"]
    phi, theta, psi = vector_info["euler_angles"]
    defocus_description = vector_info["defocus_description"]

    result_string = f"""# PDB Vector and Defocus Analysis
    Source PDB 1: {pdb_file1}
    Source PDB 2: {pdb_file2}

    ## Initial Vector Analysis
    Vector PDB1-PDB2: [{vector[0]:.6f}, {vector[1]:.6f}, {vector[2]:.6f}]
    Vector Eulers (ZYZ, deg): Phi={phi:.2f}, Theta={theta:.2f}, Psi={psi:.2f}
    Z-height difference (defocus): {defocus_description}

    ## Defocus changes for different rotations
    """
    for result in defocus_results:
        euler = result["euler_angles"]
        result_string += (
            f"    Rotation #{result['rotation']} - "
            f"    Euler({euler[0]:.2f}, {euler[1]:.2f}, {euler[2]:.2f}): "
        )
        result_string += f"Defocus = {result['description']}\n"

    # Write results to file
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(result_string)

    print(f"\nAnalysis results written to {output_file}")

In [None]:
output_file = "results/rotation_defocus_analysis.txt"
pdb_file1 = "models/60S_aligned_aligned_zero.pdb"
pdb_file2 = "models/6q8y_SSU_no_head_aligned_aligned_zero.pdb"

vector_info = calculate_relative_vectors(pdb_file1, pdb_file2)
defocus_results = process_rotations(vector_info["vector"], 5)
write_results_to_file(output_file, pdb_file1, pdb_file2, vector_info, defocus_results)

The piece of information we need for the constrained search is the relative vector from the first PDB (reference 60S model) to the second PDB (40S without head).
In this case, the positional vector is `[88.023109, 52.080257, 45.528008]`.
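
To illustrate how this vector could translate into a per-particle defocus, here is a minimal sketch that rotates it by a 60S orientation and reads off the Z-component. The orientation and defocus values are hypothetical, and both the intrinsic ZYZ convention and the sign with which the offset is applied are assumptions rather than a statement of what Leopard-EM does internally.

```python
import numpy as np
from scipy.spatial.transform import Rotation

# Relative vector from the 60S reference to the 40S model (Angstroms)
relative_vector = np.array([88.023109, 52.080257, 45.528008])

# Hypothetical 60S orientation (ZYZ Euler angles, degrees) and defocus (Angstroms)
euler_60S = [30.0, 45.0, 60.0]
defocus_60S = 12000.0

# Rotate the relative vector into the orientation of this particle
rotation = Rotation.from_euler("ZYZ", euler_60S, degrees=True)
rotated_vector = rotation.apply(relative_vector)

# Assumption: the Z-component of the rotated vector is added to the 60S defocus;
# the actual sign convention in Leopard-EM may differ.
z_offset = rotated_vector[2]
defocus_40S = defocus_60S + z_offset
print(f"Z offset: {z_offset:.2f} A, estimated 40S defocus: {defocus_40S:.2f} A")
```

Because a rotation preserves length, the Z offset can never exceed the length of the relative vector (roughly 112 Angstroms here), which bounds how far the 40S defocus can deviate from the 60S defocus.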

### Configuring and running the constrained search

Now that the hard work of obtaining the necessary information for configuring the constrained search is done, we can move on to actually running the program.
The constrained search is a balance between limiting the number of cross-correlations to minimize noise and increasing the 2DTM SNR with progressively finer sampling.
We employ a multi-step approach in this tutorial to strike this balance.
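
To make this trade-off concrete, the sketch below computes the z-score cutoff for a given number of cross-correlations under a Gaussian noise model, where the cutoff $t$ satisfies $N \cdot \tfrac{1}{2}\,\mathrm{erfc}(t/\sqrt{2}) = n_{FP}$. The function `zscore_cutoff` is a hypothetical helper written for this illustration; it is assumed, not confirmed, that the thresholding in Leopard-EM follows the same form.

```python
import numpy as np
from scipy.special import erfcinv


def zscore_cutoff(num_correlations: float, false_positives: float = 0.005) -> float:
    """Gaussian z-score above which `false_positives` values are expected by chance.

    Hypothetical helper for illustration; not part of the Leopard-EM API.
    """
    return float(np.sqrt(2.0) * erfcinv(2.0 * false_positives / num_correlations))


# The cutoff grows slowly but steadily with the size of the search space
for n in (1e2, 1e4, 1e6, 1e8):
    print(f"{n:.0e} cross-correlations -> cutoff {zscore_cutoff(n):.2f}")
```

Every additional cross-correlation per particle raises the cutoff a little, so a finer search only pays off when the gain in 2DTM SNR outpaces the rising threshold.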

Just like the other programs in Leopard-EM, we have a YAML configuration file for the constrained search program.
Our first search is a rotation around the Z-axis, configured by specifying a range for the angle psi.

*Note: The Euler angles phi and psi, in this case, are degenerate, and we could equivalently proceed using phi in place of psi.*
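
The degeneracy can be checked numerically: with theta fixed at zero, a ZYZ rotation collapses to a single rotation about Z by `phi + psi`. The sketch below uses scipy with the intrinsic ZYZ convention and an arbitrary test angle, both chosen only for illustration.

```python
import numpy as np
from scipy.spatial.transform import Rotation

angle = 12.5  # arbitrary test angle in degrees

# With theta = 0, the ZYZ rotation reduces to a rotation about Z by (phi + psi)
via_phi = Rotation.from_euler("ZYZ", [angle, 0.0, 0.0], degrees=True)
via_psi = Rotation.from_euler("ZYZ", [0.0, 0.0, angle], degrees=True)

same_rotation = np.allclose(via_phi.as_matrix(), via_psi.as_matrix())
print("phi-only and psi-only rotations agree:", same_rotation)
```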

In [None]:
# Read the YAML file
with open("configs/constrained_config_step1.yaml") as file:
    yaml_content = file.read()

# Display as markdown code block
display(Markdown(f"```yaml\n{yaml_content}\n```"))

In [None]:
YAML_CONFIG_PATH = "configs/constrained_config_step1.yaml"
DATAFRAME_OUTPUT_PATH = "results/constrained_search_results_step1.csv"
PARTICLE_BATCH_SIZE = 64
FALSE_POSITIVES = 0.005  # False positives per particle


def run_constrained_search_step1():
    """Main function to run the constrained search program."""
    cs_manager = ConstrainedSearchManager.from_yaml(YAML_CONFIG_PATH)
    cs_manager.run_constrained_search(
        output_dataframe_path=DATAFRAME_OUTPUT_PATH,
        false_positives=FALSE_POSITIVES,
        orientation_batch_size=PARTICLE_BATCH_SIZE,
    )


# NOTE: Invoking program under `if __name__ == "__main__"` necessary for multiprocessing
if __name__ == "__main__":
    run_constrained_search_step1()

We have increased the number of 40S picks by over an order of magnitude, from 16 to 262, using the constrained 2DTM search.

In [None]:
df_40S_full = pd.read_csv(
    "results/results_match_template_40S_filtered.csv"
)  # TODO: find appropriate name
df_40S_constrained_step1 = pd.read_csv("results/constrained_search_results_step1.csv")

print(f"Full match template 40S particles: {len(df_40S_full)}")
print(f"Constrained match template search: {len(df_40S_constrained_step1)}")

### Running the second constrained search

There could be a further benefit to searching over an axis orthogonal to the Z-axis.
We will do this in two steps, where a search over each rotation axis is considered independently, to minimize the total number of cross-correlations calculated and control the noise floor.
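
A rough count shows why splitting the search helps. The sketch below compares a joint grid over both axes with two independent one-axis searches; the ranges and step sizes are hypothetical and are not taken from the tutorial config files.

```python
import math


def n_samples(angle_min: float, angle_max: float, step: float) -> int:
    """Number of grid points from angle_min to angle_max (inclusive) at spacing step."""
    return math.floor((angle_max - angle_min) / step + 1e-9) + 1


# Hypothetical search ranges in degrees, for illustration only
n_z = n_samples(-15.0, 15.0, 1.0)  # rotation about the Z-axis
n_roll = n_samples(-10.0, 10.0, 1.0)  # rotation about the orthogonal axis

joint = n_z * n_roll  # both axes searched simultaneously
sequential = n_z + n_roll  # one axis per step

print(f"Joint grid: {joint} orientations per particle")
print(f"Sequential: {sequential} orientations per particle")
```

The sequential count grows additively rather than multiplicatively, at the cost of assuming that the optimum along one axis does not shift much when the other axis is refined.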

We can define the direction of this orthogonal axis using the `roll_axis` parameter (default is Y-axis with `[0, 1]`).
If the rotation in the orthogonal direction is significant, then we could find the best roll axis to search over by setting the `search_roll_axis` field to true.
However, we expect the rotation to be small, so we just use the default axis for now.
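
For intuition, a roll axis given as a 2D vector can be read as a direction in the XY plane, orthogonal to Z. The sketch below builds a small rotation about the default `[0, 1]` axis and measures how far it tilts the Z-axis. This is one plausible geometric reading, with a hypothetical roll angle, and is not taken from the Leopard-EM source.

```python
import numpy as np
from scipy.spatial.transform import Rotation

# Default roll axis from the text: the Y-axis, written as an (x, y) pair
roll_axis_2d = np.array([0.0, 1.0])

# Assumption: embed the 2D axis in the XY plane so that it is orthogonal to Z
roll_axis = np.array([roll_axis_2d[0], roll_axis_2d[1], 0.0])
roll_axis = roll_axis / np.linalg.norm(roll_axis)

roll_angle = 3.0  # hypothetical small roll in degrees
roll = Rotation.from_rotvec(np.radians(roll_angle) * roll_axis)

# A roll about an in-plane axis tilts the Z-axis by exactly the roll angle
z_tilted = roll.apply([0.0, 0.0, 1.0])
tilt = np.degrees(np.arccos(np.clip(z_tilted[2], -1.0, 1.0)))
print(f"Z-axis tilted by {tilt:.2f} degrees")
```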

The file [`configs/constrained_config_step2.yaml`](./configs/constrained_config_step2.yaml) contains the updated orientation search parameters for the roll axis, and we perform a second constrained search using the previous results.

In [None]:
YAML_CONFIG_PATH = "configs/constrained_config_step2.yaml"
DATAFRAME_OUTPUT_PATH = "results/constrained_search_results_step2.csv"
PARTICLE_BATCH_SIZE = 64
FALSE_POSITIVES = 0.005  # False positives per particle


def run_constrained_search_step2():
    """Main function to run the constrained search program."""
    cs_manager = ConstrainedSearchManager.from_yaml(YAML_CONFIG_PATH)
    cs_manager.run_constrained_search(
        output_dataframe_path=DATAFRAME_OUTPUT_PATH,
        false_positives=FALSE_POSITIVES,
        orientation_batch_size=PARTICLE_BATCH_SIZE,
    )


# NOTE: Invoking program under `if __name__ == "__main__"` necessary for multiprocessing
if __name__ == "__main__":
    run_constrained_search_step2()

### Running the third & fourth constrained searches

The 40S parameters can be further refined, and we could potentially identify more particles.
But at each iteration, we accumulate more cross-correlations, which in the 2DTM noise model raises our noise floor.
Again, this is a balance of maximizing sensitivity while controlling for noise.

We will perform two more successive constrained searches, each with a finer angular sampling step.

In [None]:
YAML_CONFIG_PATH = "configs/constrained_config_step3.yaml"
DATAFRAME_OUTPUT_PATH = "results/constrained_search_results_step3.csv"
PARTICLE_BATCH_SIZE = 64
FALSE_POSITIVES = 0.005  # False positives per particle


def run_constrained_search_step3():
    """Main function to run the constrained search program."""
    cs_manager = ConstrainedSearchManager.from_yaml(YAML_CONFIG_PATH)
    cs_manager.run_constrained_search(
        output_dataframe_path=DATAFRAME_OUTPUT_PATH,
        false_positives=FALSE_POSITIVES,
        orientation_batch_size=PARTICLE_BATCH_SIZE,
    )


# NOTE: Invoking program under `if __name__ == "__main__"` necessary for multiprocessing
if __name__ == "__main__":
    run_constrained_search_step3()

During the fourth and final constrained search, we also include a defocus search (configured within the step 4 YAML file).

In [None]:
YAML_CONFIG_PATH = "configs/constrained_config_step4.yaml"
DATAFRAME_OUTPUT_PATH = "results/constrained_search_results_step4.csv"
PARTICLE_BATCH_SIZE = 64
FALSE_POSITIVES = 0.005  # False positives per particle


def run_constrained_search_step4():
    """Main function to run the constrained search program."""
    cs_manager = ConstrainedSearchManager.from_yaml(YAML_CONFIG_PATH)
    cs_manager.run_constrained_search(
        output_dataframe_path=DATAFRAME_OUTPUT_PATH,
        false_positives=FALSE_POSITIVES,
        orientation_batch_size=PARTICLE_BATCH_SIZE,
    )


# NOTE: Invoking program under `if __name__ == "__main__"` necessary for multiprocessing
if __name__ == "__main__":
    run_constrained_search_step4()

## 9. Bringing all the results together

We now need to combine the results from the four steps of our constrained search.
There is a helper script [`Leopard-EM/programs/constrained_search/sequential_threshold_processing.py`](TODO-link) which does this automatically (assuming the file name formatting follows this tutorial).

In [None]:
# TODO: include code for the sequential thresholding