In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - Recursion MAE Image Feature Extraction local inference


<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_recursion_mae_local_inference.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_recursion_mae_local_inference.ipynb">
      <img alt="GitHub logo" src="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates how to install the necessary libraries and run local inference with the Recursion MAE model in a [Colab Enterprise Instance](https://cloud.google.com/colab/docs) for Image Feature Extraction.


### OpenPhenom Model Licensing

* The OpenPhenom model is available under a Non-Commercial End User License Agreement. For full details, please refer to the [license documentation](https://huggingface.co/recursionpharma/OpenPhenom/blob/main/LICENSE) that governs the use of this model.


### Objective

* Run local inference with the Recursion MAE model for image feature extraction.


### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Install dependencies

Before you begin, make sure you are connected to a [Colab Enterprise runtime](https://cloud.google.com/colab/docs/connect-to-runtime) with CPU. GPU is not required for this model. If you do not have a runtime yet, we recommend [creating a runtime template](https://cloud.google.com/colab/docs/create-runtime-template) based on the `default` template (with a machine type of `e2-standard-4`).

In [None]:
# Upgrade pip itself before installing the pinned packages below.
! pip3 install --upgrade pip
# Each package is installed by its own pip invocation. `huggingface-hub` and
# `timm` are pinned exactly (==); the rest use compatible-release pins (~=).
! pip3 install huggingface-hub==0.25.2
! pip3 install timm==1.0.11
! pip3 install torch~=2.4.0
! pip3 install torchmetrics~=1.5.1
! pip3 install torchvision~=0.19.0
! pip3 install tqdm~=4.66.6
! pip3 install transformers~=4.46.1
! pip3 install pandas~=2.2.3
! pip3 install zarr~=2.18.3
! pip3 install hydra-core~=1.3.2
! pip3 install pytorch-lightning~=2.1
! pip3 install isort~=5.13.2
! pip3 install ruff~=0.7.2
# NOTE(review): this only refreshes the apt package index; no apt package is
# installed in this cell — confirm whether the step is still needed.
! apt-get update

In [None]:
# @title Git clone the official Recursion MAE repo

# The clone lands in the current working directory; the `%cd` below assumes
# that directory is /content — TODO confirm when running outside Colab.
! git clone https://github.com/recursionpharma/maes_microscopy.git
# Switch into the cloned repo for the rest of the notebook.
# NOTE(review): presumably this is what lets the later cell import
# `huggingface_mae` from the repo — confirm.
%cd /content/maes_microscopy
# Pin the repo to the commit on 2024-11-04
! git reset --hard 42cfc25290f0a09f6db2a27fdf238d064f5c0760

In [None]:
# @title Utility functions

import os
from typing import Iterator, Tuple

import numpy as np
import pandas as pd
from google.cloud import storage


def download_gcs_dir_to_local(gcs_dir_path: str, local_dir_path: str):
    """Downloads files in a GCS directory to a local directory.

    The directory layout below `gcs_dir_path` is reproduced below
    `local_dir_path`; missing local directories are created as needed.

    Args:
      gcs_dir_path: Source location of the form `gs://<bucket>[/<dir>]`. When no
        directory is given, every object in the bucket is downloaded.
      local_dir_path: Local directory that receives the files.

    Raises:
      AssertionError: If `gcs_dir_path` does not start with `gs://`.
    """
    assert gcs_dir_path.startswith("gs://"), "gcs_dir_path must start with `gs://`."
    bucket_name = gcs_dir_path.split("/")[2]
    dir_path = gcs_dir_path[len("gs://" + bucket_name) :].strip("/")
    # A bucket-root path must map to an empty prefix: a bare "/" prefix would
    # only match objects whose names literally start with a slash.
    prefix = f"{dir_path}/" if dir_path else ""
    client = storage.Client()
    for blob in client.list_blobs(bucket_name, prefix=prefix):
        # Names ending in "/" are directory placeholders, not files.
        if blob.name.endswith("/"):
            continue
        file_path = blob.name[len(prefix) :].strip("/")
        local_file_path = os.path.join(local_dir_path, file_path)
        parent_dir = os.path.dirname(local_file_path)
        # `os.makedirs("")` raises, so skip it when the file goes straight into
        # the current working directory.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        print(f"Downloading {file_path} to {local_file_path}")
        blob.download_to_filename(local_file_path)


def iter_border_patches(
    width: int, height: int, patch_size: int, border_trim_size: int = 0
) -> Iterator[Tuple[int, int]]:
    """Yields the top-left (x, y) corner of every patch in a regular grid.

    The grid tiles the area that remains after trimming `border_trim_size`
    pixels from each side, using non-overlapping square patches. Despite the
    function name, every grid position inside that area is produced, not only
    the positions touching the border. Only patches that fit completely are
    yielded.

    Because this is a generator, the argument check runs when iteration starts,
    not when the function is called.

    Args:
      width: Width of the image in pixels.
      height: Height of the image in pixels.
      patch_size: Edge length of each square patch in pixels.
      border_trim_size: Number of pixels to skip on each side of the image. Must
        be divisible by half the patch size.

    Yields:
      (x, y) tuples, iterating over all y values for the first x before moving
      on to the next x.

    Raises:
      ValueError: If `border_trim_size` is not divisible by half the patch size.
    """
    half_patch = patch_size / 2
    if border_trim_size % half_patch != 0:
        raise ValueError("Border trim size has to be divisible by half the patch size")

    # The exclusive stop leaves room for one complete patch before the trimmed
    # far edge.
    column_starts = range(
        border_trim_size, width - border_trim_size - patch_size + 1, patch_size
    )
    row_starts = range(
        border_trim_size, height - border_trim_size - patch_size + 1, patch_size
    )
    yield from ((left, top) for left in column_starts for top in row_starts)


def patch_image(
    image_array: np.ndarray, patch_size: int = 256, border_trim_size: int = 0
) -> Tuple[np.ndarray, pd.DataFrame]:
    """Splits an image into square patches and records where each one came from.

    Args:
      image_array: 3-D image array indexed as (row, column, channel), i.e.
        (height, width, channels).
      patch_size: Edge length of each square patch in pixels.
      border_trim_size: Number of pixels to skip on each side of the image;
        forwarded to `iter_border_patches`.

    Returns:
      A tuple `(patches, metadata)`. `patches` stacks the copied patches along a
      new leading axis. `metadata` has one row per patch, in the same order,
      with the columns `patch_top_left_y`, `patch_top_left_x`, `patch_width`
      and `patch_height`.

    Raises:
      ValueError: If `image_array` is not 3-D, if no complete patch fits into
        the image, or if `border_trim_size` is rejected by
        `iter_border_patches`.
    """
    # NumPy image arrays are (rows, columns, channels): axis 0 is the height and
    # axis 1 is the width. Unpacking them the other way round breaks non-square
    # images, because x would then be bounded by the height and y by the width.
    height, width, _ = image_array.shape
    output_rows = []
    output_patches = []
    for x, y in iter_border_patches(
        width, height, patch_size, border_trim_size=border_trim_size
    ):
        # Copy so the patch does not keep a view into the full image.
        output_patches.append(
            image_array[y : y + patch_size, x : x + patch_size, :].copy()
        )
        output_rows.append(
            {
                "patch_top_left_y": y,
                "patch_top_left_x": x,
                "patch_width": patch_size,
                "patch_height": patch_size,
            }
        )

    if not output_patches:
        # `np.stack` would raise a less helpful ValueError on an empty list.
        raise ValueError(
            f"No {patch_size}x{patch_size} patch fits into an image of shape"
            f" {image_array.shape} with border_trim_size={border_trim_size}."
        )

    return np.stack(output_patches, axis=0), pd.DataFrame(output_rows)

# Image Embedding Generation with MAE Model

The code generates embeddings for multi-channel microscopy images. It uses a Masked Autoencoder (MAE) model developed by Recursion to process images, dividing them into 256x256 patches and generating an embedding vector for each patch.

**Requirements:**

* **Input:** Multiple image files (different channels) with dimensions divisible by 256.
* **GCS Paths:**  The demonstration uses images from the `SAMPLE_IMAGES_GCS_PATH` GCS bucket.
    The model weights can also be pulled from a GCS bucket, or pulled directly from Hugging Face.

**Output:**

A Pandas DataFrame containing embeddings and metadata (e.g., patch coordinates) for each patch.


In [None]:
import glob
import os
import time

import torch
from huggingface_hub import snapshot_download
from huggingface_mae import MAEModel
from PIL import Image

# Edge length in pixels of the square patches fed to the model.
PATCH_SIZE = 256
MAX_IMAGE_CHANNELS = 11
LOCAL_MODEL_DIR = "/content/model"
LOCAL_IMAGES_DIR = "/content/images"

# Download the model weights to a local dir. `MODEL_ID` is either a Hugging Face
# repo id or a `gs://` directory; either way it ends up pointing at the local
# copy that the model is loaded from.
MODEL_ID = "recursionpharma/OpenPhenom"
print(f"Downloading model from {MODEL_ID} to {LOCAL_MODEL_DIR}")

if MODEL_ID.startswith("gs://"):
    download_gcs_dir_to_local(MODEL_ID, LOCAL_MODEL_DIR)
else:
    snapshot_download(
        repo_id=MODEL_ID,
        local_dir=LOCAL_MODEL_DIR,
    )
MODEL_ID = LOCAL_MODEL_DIR

# Inference runs on the CPU with the model in eval mode.
model = MAEModel.from_pretrained(MODEL_ID)
model = model.eval().cpu()

SAMPLE_IMAGES_GCS_PATH = (
    "gs://cloud-samples-data/vertex-ai/model-garden/recursion-mae-images"
)
print(f"Downloading images from {SAMPLE_IMAGES_GCS_PATH} to {LOCAL_IMAGES_DIR}")
download_gcs_dir_to_local(SAMPLE_IMAGES_GCS_PATH, LOCAL_IMAGES_DIR)

# Load and Validate Images
# Each file is treated as one channel; sorting keeps the channel order stable.
image_paths = sorted(glob.glob(f"{LOCAL_IMAGES_DIR}/*"))
cur_images = []
for img_path in image_paths:
    # The glob also matches sub-directories, which cannot be opened as images.
    if not os.path.isfile(img_path):
        continue
    # The context manager closes the file handle once the pixels are copied.
    with Image.open(img_path) as img_channel:
        # Check if image dimensions are divisible by the patch size (MAE requirement)
        width, height = img_channel.size
        if width % PATCH_SIZE != 0 or height % PATCH_SIZE != 0:
            print(
                f"Image [{img_path}] dimensions are not divisible by {PATCH_SIZE}."
            )
            continue
        cur_images.append(np.array(img_channel))

if not cur_images:
    raise ValueError("No valid images found in the input GCS folder.")
if len(cur_images) > MAX_IMAGE_CHANNELS:
    raise ValueError(
        f"Too many channel images found in the input GCS folder. Max allowed is {MAX_IMAGE_CHANNELS}."
    )


# Run Model Inference
embeddings = []
metadata_df = []
inference_start = time.time()

# Stack the channel images depth-wise, then divide the result into patches
# (MAE requirement).
stacked_images = np.dstack(cur_images)
stacked_patches, metadata = patch_image(stacked_images, patch_size=PATCH_SIZE)

# Convert to PyTorch tensor and move the channel axis in front of the two
# spatial axes.
torch_im = torch.from_numpy(stacked_patches).permute(0, 3, 1, 2)

# Gradients are not needed for inference. CUDA autocast is intentionally not
# used: the model was moved to the CPU above, so it would have no effect.
with torch.no_grad():
    embedding = model.predict(torch_im)

embeddings.append(embedding)
metadata_df.append(metadata)

# Process Embeddings and Metadata
embedding_df = pd.DataFrame(np.concatenate(embeddings, axis=0))
embedding_df.columns = [f"feature_{i}" for i in range(embedding_df.shape[1])]
metadata_df = pd.concat(metadata_df).reset_index(drop=True)

# Sanity check to ensure embeddings and metadata have the same length
assert len(metadata_df) == len(embedding_df), (
    f"Embedding and metadata don't match; got {len(embedding_df)} embs and"
    f" {len(metadata_df)} metadata"
)

print("Prediction was made in %.2f seconds." % (time.time() - inference_start))

# Combine Metadata and Embeddings
final_df = pd.concat([metadata_df, embedding_df], axis=1)
print(final_df)