### 🌐 Colab

This notebook is designed to run on Google Colab (it runs fastest on a GPU).

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/tutorials/hyrax_demo.ipynb)

# Hyrax Demo Notebook

This notebook demonstrates the basic functionality of the Hyrax package for scalable ML and unsupervised discovery.

**Note:** For optimal performance, especially during model training and inference, it is highly recommended to change your Colab runtime to use a GPU. You can do this by going to `Runtime > Change runtime type` and selecting 'GPU' as the hardware accelerator.

## Installation & Imports

In [None]:
# If you are NOT running on Colab, create a Conda environment first, e.g.:
#   conda create -n hyrax python=3.12   (or your preferred Python 3.9+ version)
#   conda activate hyrax
# The code below installs hyrax only if it is not already installed (e.g. on a fresh Colab runtime).

import subprocess
import sys

def is_package_installed(package_name):
    try:
        subprocess.run([sys.executable, "-m", "pip", "show", package_name], capture_output=True, check=True)
        return True
    except subprocess.CalledProcessError:
        return False

if not is_package_installed("hyrax"):
    print("Installing hyrax...")
    !pip install hyrax
else:
    print("hyrax is already installed.")

In [1]:
# This step is for NGrok, which helps expose local services (like MLflow or TensorBoard) to a public URL.
# It's especially useful in environments like Google Colab where direct localhost access is tricky for embedded UIs.
# If running locally and not needing public URLs, you might not need this.

import subprocess
import sys

def is_package_installed(package_name):
    try:
        subprocess.run([sys.executable, "-m", "pip", "show", package_name], capture_output=True, check=True)
        return True
    except subprocess.CalledProcessError:
        return False

if not is_package_installed("pyngrok"):
    print("Installing pyngrok...")
    !pip install pyngrok # Not needed if not doing this in Google Colab
else:
    print("pyngrok is already installed.")

pyngrok is already installed.


In [2]:
# Replace with your actual NGROK authentication token if you plan to use ngrok
# You can get one from dashboard.ngrok.com
!ngrok authtoken YOUR_NGROK_AUTHTOKEN

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


In [3]:
import hyrax
import pooch # We'll use this to retrieve some example Hyper Suprime Cam data from Zenodo
import numpy as np
import time
import matplotlib.pyplot as plt
import subprocess
from pathlib import Path
from IPython.display import IFrame, display

# Colab Interactive Output for HoloViews/Bokeh and MLflow via ngrok
try:
    from google.colab import output
    output.enable_custom_widget_manager()
    COLAB_ENV = True
except ImportError:
    COLAB_ENV = False

if COLAB_ENV:
    from pyngrok import ngrok

## Downloading a Small HSC Dataset

For this demo, we use a small dataset of ~1000 cutouts from Hyper Suprime-Cam (HSC). These are all extended sources in the redshift range 0.25 < z < 0.50 with mag < 23.

Note that we don't use Hyrax's `download` module here; downloading a fixed archive makes this live demo more reproducible.

In [4]:
# DOI for the example HSC dataset on Zenodo
file_paths = pooch.retrieve(
    url="doi:10.5281/zenodo.14498536/hsc_demo_data.zip",
    known_hash="md5:1be05a6b49505054de441a7262a09671",
    fname="example_hsc_new.zip",
    path="./data",  # Download location; ./data/ is created if it doesn't exist
    processor=pooch.Unzip(extract_dir="."),  # Relative to the archive's folder, so files land in ./data/
)
# With the Unzip processor, pooch returns the list of extracted files.
# The archive unpacks into ./data/hsc_8asec_1000/.
print(f"Dataset downloaded and extracted to: {Path('./data/hsc_8asec_1000').resolve()}")

Dataset downloaded and extracted to: /content/data/hsc_8asec_1000
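`pooch.retrieve` compares the download against `known_hash` and raises an error on a mismatch. To check the archive by hand, for example after copying it from another machine, the same MD5 digest can be computed with `hashlib`. This is a minimal sketch: `file_md5` is a hypothetical helper, and the archive path and digest are taken from the cell above.

```python
import hashlib
from pathlib import Path


def file_md5(path, chunk_size=1 << 20):
    """Hypothetical helper: return the hex MD5 digest of a file, read in chunks."""
    digest = hashlib.md5()
    with Path(path).open("rb") as f:
        # Read fixed-size chunks until f.read() returns b"" at end of file
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Archive name and expected digest come from the pooch.retrieve call above
zip_path = Path("./data/example_hsc_new.zip")
if zip_path.exists():
    ok = file_md5(zip_path) == "1be05a6b49505054de441a7262a09671"
    print("MD5 matches known_hash:", ok)
else:
    print(f"{zip_path} not found; run the download cell first.")
```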


In [20]:
# Verify that the extracted data directory exists and count its FITS files
import os

# Absolute path to the folder the archive was extracted into
data_dir = "/content/data/hsc_8asec_1000"

if os.path.exists(data_dir):
    print(f"The directory '{data_dir}' exists.")
    if os.path.isdir(data_dir):
        print(f"'{data_dir}' is a directory.")
        # Count the .fits files in the directory
        fits_file_count = sum(1 for item in os.listdir(data_dir) if item.endswith(".fits"))
        print(f"Number of .fits files: {fits_file_count}")
    else:
        print(f"'{data_dir}' exists but is not a directory.")
else:
    print(f"The directory '{data_dir}' does not exist.")

The directory '/content/data/hsc_8asec_1000' exists.
'/content/data/hsc_8asec_1000' is a directory.
Number of .fits files: 2980


## Step 0 - Create a Hyrax Object

The first step in almost any Hyrax workflow is to create a `Hyrax` object.

In [21]:
h = hyrax.Hyrax()

[2025-05-23 01:24:14,664 hyrax:INFO] Runtime Config read from: /usr/local/lib/python3.11/dist-packages/hyrax/hyrax_default_config.toml
INFO:hyrax:Runtime Config read from: /usr/local/lib/python3.11/dist-packages/hyrax/hyrax_default_config.toml


Now that we have the object, let's configure it for training. Instead of setting options interactively, you can also supply a configuration file. Any parameter you don't specify falls back to Hyrax's default settings.

In [23]:
# Specify the location of the data to use for training
# Adjust this path if your data is located elsewhere
h.config["general"]["data_dir"] = "/content/data/hsc_8asec_1000"

# Specify the dataset class that represents the data
h.config["data_set"]["name"] = "HSCDataSet"
h.config["data_set"]["train_size"] = 0.8
h.config["data_set"]["validate_size"] = 0.2
h.config["data_set"]["test_size"] = 0.0

# Select the model to use for training
h.config["model"]["name"] = "HyraxAutoencoder"

# Set the number of epochs and batch size for training.
h.config["train"]["epochs"] = 20
h.config["data_loader"]["batch_size"] = 32

## Step 1 - Training Your Model

In [None]:
%%time
# With a GPU runtime selected in Colab, this takes ~30 seconds
h.train()

### 1.1 (If You Care) Let's Check Some System Metrics

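Hyrax uses MLflow in the backend to log experiments, and you can check GPU/CPU utilization and other metrics in the MLflow UI (see section 1.3). The cell below opens TensorBoard on the results directory, which offers another view of the training logs written there.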

In [None]:
%load_ext tensorboard
%tensorboard --logdir {h.config['general']['results_dir']}

### 1.2 - Let's Define A Slightly Different Model & Train It

This demonstrates how to define your own custom PyTorch model and integrate it with Hyrax. The key is the `@hyrax_model` decorator and defining `forward` and `train_step` methods.

In [28]:
import torch.nn as nn
from hyrax.models.model_registry import hyrax_model

@hyrax_model  # Registers the model with Hyrax so it can be selected by name in the config
class TrialAutoencoder(nn.Module):
    def __init__(self, config, shape):  # shape is not used by this fixed-size model
        super().__init__()
        self.config = config

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), # -> [16, 48, 48]
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # -> [32, 24, 24]
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # -> [64, 12, 12]
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # -> [128, 6, 6]
            nn.ReLU(),
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), # -> [64, 12, 12]
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # -> [32, 24, 24]
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # -> [16, 48, 48]
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1), # -> [3, 96, 96]
            nn.Sigmoid(), # Normalize output to [0, 1]
        )

    def eval_encoder(self, x):
        return self.encoder(x)

    def eval_decoder(self, x):
        return self.decoder(x)

    def forward(self, x):
        return self.eval_decoder(self.eval_encoder(x))

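    def train_step(self, x):
        # self.criterion and self.optimizer are set up from the
        # "criterion" and "optimizer" config sections (see the next code cell).
        z = self.eval_encoder(x)
        x_hat = self.eval_decoder(z)

        # With reduction="none", the criterion returns per-pixel losses;
        # average them into a single scalar before backpropagating.
        loss = self.criterion(x_hat, x).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}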
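The shape comments in `TrialAutoencoder` follow from PyTorch's output-size formulas for `Conv2d` and `ConvTranspose2d` (with dilation 1). A quick pure-Python check, assuming 96 × 96 input cutouts as stated in those comments, confirms that the decoder restores the input size. `conv2d_out` and `conv_transpose2d_out` are hypothetical helpers.

```python
def conv2d_out(size, kernel=3, stride=2, padding=1):
    """Hypothetical helper: spatial output size of nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size, kernel=3, stride=2, padding=1, output_padding=1):
    """Hypothetical helper: spatial output size of nn.ConvTranspose2d (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


sizes = [96]  # assumed input cutout size, as in the layer comments
for _ in range(4):  # four strided Conv2d layers in the encoder
    sizes.append(conv2d_out(sizes[-1]))
for _ in range(4):  # four ConvTranspose2d layers in the decoder
    sizes.append(conv_transpose2d_out(sizes[-1]))
print(sizes)  # [96, 48, 24, 12, 6, 12, 24, 48, 96]
```

With an odd input size the round trip breaks: `conv2d_out(95)` is 48, so the decoder would produce 96 × 96 images that no longer match a 95 × 95 input.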

In addition to `@hyrax_model`, other decorators that extend Hyrax and reduce boilerplate code have been developed (named after FIBAD, Hyrax's former name), including:

- `@fibad_dataset`, for rapid development of new dataset interfaces
- `@fibad_verb`, for new core actions, e.g. `f.custom_train(...)` or `f.bespoke_predict(...)`

In [None]:
%%time
# Specify that we now want to train the model defined in this notebook
h.config["model"]["name"] = "TrialAutoencoder"

# Define loss and optimizer functions for easy experimentation
h.config["criterion"]["name"] = "torch.nn.MSELoss"
h.config["criterion"]["reduction"] = "none"

h.config["optimizer"]["name"] = "torch.optim.Adam"
h.config["optimizer"]["lr"] = 1e-3

# train the new model
h.train()
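The `criterion` and `optimizer` settings name their classes by dotted import path (`torch.nn.MSELoss`, `torch.optim.Adam`), and the remaining keys (`reduction`, `lr`) look like constructor arguments. One common way to turn such settings into objects is `importlib`. The sketch below illustrates that general technique under this assumption; it is not Hyrax's implementation, `build_from_dotted_name` is hypothetical, and a standard-library class stands in for PyTorch so the example runs anywhere.

```python
import importlib


def build_from_dotted_name(dotted_name, *args, **kwargs):
    """Hypothetical helper: import the object at a dotted path and call it.

    Illustrates config-driven construction; this is not Hyrax's code.
    """
    module_name, _, attr_name = dotted_name.rpartition(".")
    module = importlib.import_module(module_name)
    cls = getattr(module, attr_name)
    return cls(*args, **kwargs)


# With PyTorch installed, this would mirror the config above:
#   criterion = build_from_dotted_name("torch.nn.MSELoss", reduction="none")
#   optimizer = build_from_dotted_name("torch.optim.Adam", model.parameters(), lr=1e-3)
# A standard-library class keeps the example runnable without PyTorch:
delta = build_from_dotted_name("datetime.timedelta", days=1, hours=2)
print(delta)  # 1 day, 2:00:00
```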

### 1.3 Compare Models (using MLflow)

MLflow is automatically used by Hyrax to log experiments. We can launch its UI to compare different runs, parameters, and metrics.

In [None]:
# Point the MLflow UI at the mlflow/ folder inside Hyrax's results directory.
backend_store_uri = f"file://{Path(h.config['general']['results_dir']).resolve()}/mlflow"
mlflow_ui_cmd = f"mlflow ui --backend-store-uri {backend_store_uri} --host 0.0.0.0 --port 5000"

# Stop anything already listening on port 5000 (e.g. an earlier MLflow UI) to avoid conflicts.
# This is a simple approach; a more robust version would check the process or port first.
try:
    subprocess.run("kill $(lsof -t -i:5000)", shell=True, check=False, capture_output=True)
    time.sleep(2)  # Give it a moment to shut down
except Exception:
    pass  # lsof might not be available, or nothing is running

subprocess.Popen(f"nohup {mlflow_ui_cmd} > mlflow_ui.log 2>&1 &", shell=True,  close_fds=True)
time.sleep(5) # Give MLflow UI a moment to start

public_url_mlflow = "http://localhost:5000"
if COLAB_ENV:
    try:
        # Disconnect previous ngrok tunnel for port 5000 if it exists to avoid errors
        tunnels = ngrok.get_tunnels()
        for tunnel in tunnels:
            if 'localhost:5000' in tunnel.config['addr']:
                ngrok.disconnect(tunnel.public_url)
                ngrok.kill() # ensure ngrok process is stopped
                time.sleep(1)
                break
        public_url_mlflow = ngrok.connect(addr="5000", proto="http").public_url
        print(f"MLflow App is live at: {public_url_mlflow}")
        display(IFrame(src=public_url_mlflow, width="100%", height=800))
    except Exception as e:
        print(f"Could not (re)start ngrok for MLflow: {e}. You might need to set an NGROK authtoken or check ngrok processes.")
        print(f"You can still try accessing MLflow locally at http://localhost:5000 if running locally.")
else:
    print(f"MLflow App should be live at: {public_url_mlflow}")
    display(IFrame(src=public_url_mlflow, width="100%", height=800))
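The cell above tries to kill whatever is listening on port 5000 before starting MLflow. A gentler alternative is to check whether the port is in use first and reuse an MLflow UI that is already running. This is a minimal standard-library sketch; `is_port_in_use` is a hypothetical helper that only reports whether a TCP connection to the port succeeds, not which program owns it.

```python
import socket


def is_port_in_use(port, host="127.0.0.1", timeout=0.5):
    """Hypothetical helper: True if something accepts TCP connections on host:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)
        # connect_ex returns 0 on success instead of raising an exception
        return sock.connect_ex((host, port)) == 0


if is_port_in_use(5000):
    print("Port 5000 is busy; an MLflow UI (or something else) may already be running.")
else:
    print("Port 5000 is free.")
```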

## Step 2 - Running Inference

Let's now run inference with our trained model. The split sizes below assign the entire dataset to the test split, so that inference covers every object.

In [None]:
%%time
h.config["general"]["dev_mode"] = True # To ensure it uses the latest model from the demo
h.config["data_set"]["test_size"] = 1.0
h.config["data_set"]["train_size"] = 0.0
h.config["data_set"]["validate_size"] = 0.0
h.config["data_loader"]["batch_size"] = 128

h.infer()

## Step 3 - Visualization

First, we reduce the dimensionality of the latent space (embeddings) using UMAP.

In [None]:
%%time
h.umap()

Now, we can visualize the UMAP-projected latent space.

**Note for Colab users:** The interactive HoloViews/Bokeh plot might not render correctly in the Colab output cell because of iframe sandboxing or JavaScript conflicts. The 3D visualizer launches in a separate browser tab or window (or through an ngrok tunnel, if one is set up).

In [16]:
import holoviews as hv
import panel as pn
pn.extension(comms="colab" if COLAB_ENV else "default")  # Use Colab comms only when running in Colab
hv.extension('bokeh')

# Configure visualization parameters
h.config["visualize"]["display_images"] = True
h.config["visualize"]["fields"] = ["ra", "dec"]
h.config["visualize"]["three_d"] = False # For 2D visualization

viz = h.visualize(width=700, height=700)
display(viz)

# To get the selected data points (if you made a selection in the plot):
# selected_df = viz.get_selected_df()
# if selected_df is not None:
#     print(selected_df.head())

RuntimeError: Could not find a results directory. Run infer or use [results] inference_dir config to specify a directory.

#### 3D Visualization (Launches in Browser)
This will launch an interactive 3D scatter plot in your browser. You can explore it, select points, and color by different features.

In [None]:
# Configure for 3D visualization
h.config["visualize"]["three_d"] = True
h.config["visualize"]["color_features"] = ["x", "y", "z", "ra", "dec", "g_cmodel_mag", "r_cmodel_mag", "i_cmodel_mag", "photz_best"]
h.config["visualize"]["default_color_feature"] = "g_cmodel_mag"
h.config["visualize"]["image_path_column"] = "filename" # if you have image paths
h.config["visualize"]["local_data_path"] = Path(h.config['general']['data_dir']).parent # Parent of hsc_8asec_1000

if COLAB_ENV:
    # For Colab, we need to use ngrok to expose the port
    try:
        public_url_3d = ngrok.connect(addr="8080", proto="http").public_url
        print(f"Hyrax 3D Explorer (Colab) will be available at: {public_url_3d}")
        h.config["visualize"]["public_url"] = public_url_3d
    except Exception as e:
        print(f"Could not start ngrok for 3D visualizer: {e}")
else:
    h.config["visualize"]["public_url"] = "http://localhost:8080"

h.visualize() # This will print instructions to open in browser

## Creating a Vector Database

By calling the `index` verb, we can populate a vector database with the results of inference. This vector database can be used for efficient similarity or nearest neighbor searches.

In [None]:
h.index()

## Performing a Similarity Search

To search for objects that are similar to a given object, we can exploit the efficiency of the vector database that was just created. Here, we create an instance of a ChromaDB object that connects to the database that was just created.

In [None]:
from hyrax.vector_dbs import ChromaDB
from hyrax.config_utils import find_most_recent_results_dir

context = {"results_directory": find_most_recent_results_dir(h.config, "index")}
db = ChromaDB(context=context)

With an object id, we can search for its k nearest neighbors. Here we pick an example object id and request its 5 nearest neighbors.

Note: Because the query object is itself in the database, its closest "neighbor" is the object itself. The nearest neighbor other than the original object is therefore the second element of the returned list.

Note: The ids of the neighbors are returned in order of increasing distance.
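Conceptually, a k-nearest-neighbor query ranks the stored embeddings by their distance to the query embedding. A vector database such as ChromaDB typically uses an index to avoid scanning everything, but a brute-force version shows both notes above in action: the query object comes back first at distance 0, and results are sorted by increasing distance. This is a minimal NumPy sketch on made-up 2-D vectors; `knn_search` is a hypothetical helper, and it assumes Euclidean distance, which may differ from the metric the database is configured to use.

```python
import numpy as np


def knn_search(embeddings, ids, query_id, k=5):
    """Hypothetical helper: brute-force k-NN by Euclidean distance.

    Returns (neighbor_ids, distances) sorted by increasing distance.
    The query itself is included, at distance 0, because it is in the data.
    """
    query = embeddings[ids.index(query_id)]
    distances = np.linalg.norm(embeddings - query, axis=1)
    order = np.argsort(distances, kind="stable")[:k]
    return [ids[i] for i in order], distances[order]


# Toy 2-D "embeddings" standing in for inference results (made-up values)
toy_ids = ["a", "b", "c", "d"]
toy_embeddings = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 0.0], [0.0, 2.0]])
neighbor_ids, dists = knn_search(toy_embeddings, toy_ids, "a", k=3)
print(neighbor_ids)  # ['a', 'c', 'd']
print(dists)         # [0. 1. 2.]
```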

In [None]:
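# Example of the structure returned by db.search_by_id (values from an earlier run).
# It is a one-element tuple, so the neighbor ids are in search_results[0]['ids'][0].
#({'ids': [['36412406317975358', '38562363067159808', '39913401664672536', '37489369367447921', '36416396342612022']], 'distances': [[0.0, 100.23, 105.34, 106.78, 108.91]], 'metadatas': None, 'embeddings': None, 'documents': None, 'uris': None, 'data': None},)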

In [None]:
# Example object ID from the HSC dataset (you can pick any ID from your inference results)
search_object_id = "36412406317975358"
search_results = db.search_by_id(search_object_id, k=5)
print(search_results)

## Display the Objects

Let's check that the nearest neighbor seems reasonable. We'll plot the original object, and then the nearest neighbors for visual comparison.

To plot the images we'll need to get some information from our dataset.

Using the `prepare` verb to return an instance of the dataset, we can get the list of object ids.

In [None]:
hsc_dataset = h.prepare()
all_ids = list(hsc_dataset.ids())

We can use the ids from the hsc_dataset object to get the dataset index of the object ids returned from the vector database query, and plot those.
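The next cell looks up each neighbor with `all_ids.index(r_id)`, which scans the list from the start on every call. That is fine for five lookups, but for many queries it can help to build an id-to-index dictionary once. This is a minimal sketch; `build_id_index` is a hypothetical helper, and it assumes object ids are unique (with duplicates, the dictionary keeps the last position, whereas `list.index` returns the first).

```python
def build_id_index(ids):
    """Hypothetical helper: map each object id to its dataset position (built once)."""
    return {obj_id: i for i, obj_id in enumerate(ids)}


# Hypothetical usage with the variables from the plotting cell:
#   id_to_index = build_id_index(all_ids)
#   indx = id_to_index.get(r_id)  # None if the id is not in the dataset
example = build_id_index(["36412406317975358", "38562363067159808"])
print(example.get("38562363067159808"))  # 1
print(example.get("missing-id"))         # None
```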

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

# search_results was defined in a previous cell as:
# search_object_id = "36412406317975358" # Or any other valid ID from your dataset
# search_results = db.search_by_id(search_object_id, k=5)

# hsc_dataset and all_ids were also defined previously:
# hsc_dataset = h.prepare()
# all_ids = list(hsc_dataset.ids())

fig, axes = plt.subplots(3, 5, figsize=(25, 5*3)) # 3 rows for bands, 5 columns for original + 4 neighbors

for ni, r_id in enumerate(search_results[0]['ids'][0]): # search_results[0]['ids'][0] contains the list of object IDs
    try:
        indx = all_ids.index(r_id) # Get the index of the object in the hsc_dataset
    except ValueError:
        # Handle case where object ID is not found in the dataset
        print(f"Warning: Object ID {r_id} not found in hsc_dataset.ids(). Skipping.")
        # Fill the plot with a placeholder or leave blank if an ID is not found
        for i in range(3):
            axes[i, ni].text(0.5, 0.5, 'ID Not Found', horizontalalignment='center', verticalalignment='center', transform=axes[i, ni].transAxes)
            axes[i, ni].set_xticks([])
            axes[i, ni].set_yticks([])
        if ni == 0:
            axes[0, ni].set_title(f"Original search object\nID: {search_object_id} (Not Found)")
        else:
            axes[0, ni].set_title(f"Neighbor {ni}\nID: {r_id} (Not Found)")
        continue

    data = hsc_dataset[indx].numpy() # Retrieve the data for the object - a [3, 96, 96] numpy array

    # Normalize the data for display (per image, not per band for this visualization)
    # This might need adjustment based on how you want to display the image.
    # The demo appeared to normalize per image to make features visible across bands.
    min_val = np.min(data)
    max_val = np.max(data)
    if max_val > min_val: # Avoid division by zero for blank images
        data = (data - min_val) / (max_val - min_val)
    else:
        data = np.zeros_like(data) # Or handle as appropriate

    for i in range(3): # For each band (assuming 3 bands: G, R, I)
        ax = axes[i, ni]
        im = ax.imshow(data[i], origin='lower', norm=LogNorm(), cmap='Greys')
        ax.set_xticks([])
        ax.set_yticks([])

        if i == 0:  # Title only the top panel of each column
            if ni == 0:  # First column is the original search object
                ax.set_title(f"Original search object\nID: {search_object_id}, Indx: {indx}")
            else:
                ax.set_title(f"Neighbor {ni}\nID: {r_id}, Indx: {indx}")

        if ni == 0: # Add Y-label for bands only to the first column
            ax.set_ylabel(f"Band {i+1}") # Assuming bands are 0, 1, 2 -> G, R, I

plt.tight_layout()
plt.show()

Explanation of the code and its assumptions:

- `search_results`: populated in an earlier cell by `db.search_by_id(search_object_id, k=5)`. The neighbor IDs are read from `search_results[0]['ids'][0]`.
- `hsc_dataset` and `all_ids`: also created in earlier cells, by `h.prepare()` and `list(hsc_dataset.ids())`.
- `plt.subplots(3, 5, figsize=(25, 5*3))`: creates a grid with 3 rows (one per band: G, R, I) and 5 columns (the original object plus its 4 nearest neighbors), giving each panel roughly 5 × 5 inches.
- Looping through neighbors: the code iterates over the IDs returned by the similarity search.
- `all_ids.index(r_id)`: finds the position of each neighbor ID in `hsc_dataset`. The `try`/`except` block guards against an ID missing from `all_ids`, which shouldn't happen when the search results come from the same indexed data.
- `hsc_dataset[indx].numpy()`: retrieves the image data for the object, a 3-channel NumPy array of shape `(3, height, width)`.
- Normalization: the image is scaled to [0, 1] for display. The whole 3-channel image is normalized together rather than band by band.
- `imshow`: each band is displayed with `LogNorm` scaling and the `'Greys'` colormap. `LogNorm` cannot show zero-valued pixels, so each image's minimum pixel (0 after normalization) is left blank.
- Titles and labels: subplot titles indicate whether a column shows the original object or a neighbor, along with its ID and dataset index; the y-axis labels on the first column give the band number.
- `plt.tight_layout()` and `plt.show()`: standard matplotlib calls to tidy up the layout and display the figure.
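The plotting cell normalizes all three bands of an image together, so a band that is much fainter than the others can look almost empty. Scaling each band to [0, 1] independently is one alternative. This is a minimal NumPy sketch on a made-up array; `minmax_per_band` is a hypothetical helper whose result could replace the normalized `data` array before plotting.

```python
import numpy as np


def minmax_per_band(cube):
    """Hypothetical helper: scale each band of a (bands, height, width) array to [0, 1].

    Bands with a constant value are set to zero to avoid dividing by zero.
    """
    cube = np.asarray(cube, dtype=float)
    lo = cube.min(axis=(1, 2), keepdims=True)  # per-band minimum, shape (bands, 1, 1)
    hi = cube.max(axis=(1, 2), keepdims=True)  # per-band maximum, shape (bands, 1, 1)
    span = hi - lo
    safe_span = np.where(span > 0, span, 1.0)  # placeholder divisor for flat bands
    return np.where(span > 0, (cube - lo) / safe_span, 0.0)


# Made-up 3-band example: band 1 is 10x brighter than band 0, band 2 is flat
cube = np.stack([
    np.arange(4.0).reshape(2, 2),
    10 * np.arange(4.0).reshape(2, 2),
    np.full((2, 2), 7.0),
])
scaled = minmax_per_band(cube)
print(scaled[0])  # identical to scaled[1] after per-band scaling
```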

In [9]:
import os

# List the contents of the directory where the zip file was downloaded and extracted
data_dir_path = "./data"
print(f"Contents of {data_dir_path}:")
if os.path.exists(data_dir_path):
    for item in os.listdir(data_dir_path):
        print(f"- {item}")
else:
    print(f"The directory {data_dir_path} does not exist.")

# Also check whether ./data/example_hsc_new exists (the archive actually extracts into ./data/hsc_8asec_1000)
extracted_dir_path = "./data/example_hsc_new"
print(f"\nContents of {extracted_dir_path}:")
if os.path.exists(extracted_dir_path):
    for item in os.listdir(extracted_dir_path):
        print(f"- {item}")
else:
    print(f"The directory {extracted_dir_path} does not exist.")

Contents of ./data:
- example_hsc_new.zip
- hsc_8asec_1000

Contents of ./data/example_hsc_new:
The directory ./data/example_hsc_new does not exist.
