This notebook shows how to visualize rotation invariance for the rotation-invariant image models. We illustrate this with the cellPACK synthetic dataset. First, let's import the necessary libraries.

In [None]:
%load_ext autoreload
%autoreload 2
import os

# GPU selection is configured through environment variables *before* torch is
# imported below; keep these assignments above the torch import.
# NOTE(review): presumably CUDA only reads these when it is first initialised.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
# os.environ["CUDA_VISIBLE_DEVICES"] = INSERT_YOUR_DEVICE
# The value below is a machine-specific device identifier; replace it with your
# own device (see the template line above) when running elsewhere.
os.environ["CUDA_VISIBLE_DEVICES"] = "MIG-5dae7a03-dc20-5083-8df1-b583c1a5e8b8"
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from hydra.utils import instantiate
from br.features.rotation_invariance import rotation_image_batch_z, rotation_pc_batch_z
from br.models.load_models import load_model_from_path

# Device string passed to `.to(device)` when tensors are moved for inference.
# NOTE(review): presumably "cuda:0" is the single device made visible above.
device = "cuda:0"

Set the CYTODL_CONFIG_PATH environment variable as described in the documentation associated with this repo

In [None]:
# Path configuration
# CYTODL_CONFIG_PATH = insert_your_path
CYTODL_CONFIG_PATH = "/allen/aics/modeling/ritvik/projects/latest_clones/benchmarking_representations/configs/"

# Enter the config directory, then step up to its parent so that the relative
# `save_path` below resolves against that parent directory.
os.chdir(CYTODL_CONFIG_PATH)
os.chdir(os.pardir)

# save_path = insert_save_path
save_path = "./test_cellpack_recons/"
# Trained-model results are looked up beneath the config directory.
results_path = f"{CYTODL_CONFIG_PATH}/results/"

Load a batch of the cellPACK image dataset

In [None]:
# Load the data yaml and instantiate the datamodule it describes
cellpack_data_path = CYTODL_CONFIG_PATH + "data/cellpack/image.yaml"
with open(cellpack_data_path) as stream:
    cellpack_data = yaml.safe_load(stream)
data = instantiate(cellpack_data)

# Collect the first few training batches whose leading item has the CellId below.
TARGET_CELL_ID = '7abfecf1-44db-468a-b799-4959a23cfb0d'
N_BATCHES = 3

batches = []
for batch in data.train_dataloader():
    if batch['orig_CellId'][0] == TARGET_CELL_ID:
        batches.append(batch)
        if len(batches) == N_BATCHES:
            # Enough batches collected; no need to iterate the rest of the dataloader.
            break

Save examples of raw data for each of the 6 packing rules

In [None]:
from pathlib import Path

# Raw inputs are written to <save_path>/panel_a
this_save_path = Path(save_path) / Path("panel_a")
this_save_path.mkdir(parents=True, exist_ok=True)

# Two examples are taken from every batch. The file index combines the batch
# index and the within-batch index (2 * batch_ind + i) so that examples from
# different batches get distinct file names instead of overwriting each other.
for batch_ind, batch in enumerate(batches):
    for i in range(2):
        rule = 2 * batch_ind + i
        np_arr = batch["pcloud"][i].numpy().squeeze()
        np.save(this_save_path / Path(f"image_{rule}.npy"), np_arr)

Make utility functions for computing and saving reconstructions across four 90-degree rotations

In [None]:
# Utility: transfer the entries of a batch to the compute device
def move(batch):
    """Send the movable entries of ``batch`` to ``device`` and return the batch.

    Entries stored under one of the metadata keys listed below, and entries
    whose value is a ``list``, are left untouched. Every other value is
    replaced by ``value.to(device)``. ``batch`` is updated in place and the
    same object is returned.
    """
    metadata_keys = (
        "split",
        "bf_meta_dict",
        "egfp_meta_dict",
        "filenames",
        "image_meta_dict",
        "cell_id",
    )
    for name in tuple(batch.keys()):
        value = batch[name]
        if name in metadata_keys or isinstance(value, list):
            continue
        batch[name] = value.to(device)
    return batch

def plot_image(this_image, axes, z_ind=2, transpose=False, max_size=118):
    """Draw a maximum projection of ``this_image`` on ``axes`` without decorations.

    Args:
        this_image: Array supporting ``.max(axis)``; callers pass a squeezed
            volume (presumably a 3D NumPy array -- confirm against callers).
        axes: Matplotlib axes to draw on.
        z_ind: Axis that is collapsed by the max projection.
        transpose: If True, the projection is transposed before display.
        max_size: Upper bound applied to both the x and y axis limits.
            NOTE(review): the default of 118 presumably matches the image
            side length -- confirm.

    Returns:
        None. The image is drawn on ``axes`` in place.
    """
    projection = this_image.max(z_ind)
    if transpose:
        projection = projection.T
    axes.imshow(projection, cmap='gray_r')

    # Hide the frame around the image.
    for side in ('top', 'right', 'bottom', 'left'):
        axes.spines[side].set_visible(False)

    axes.set_aspect('equal', adjustable='box')
    axes.set_ylim([0, max_size])
    axes.set_xlim([0, max_size])
    axes.set_yticks([])
    axes.set_xticks([])

def _save_projection(image, title, out_file):
    """Plot one image with ``plot_image``, title it and save the figure to ``out_file``."""
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    plot_image(image, ax)
    ax.set_title(title)
    fig.savefig(out_file)


def save_recons_across_rotations(save_path, this_name, this_key, model, batch, batch_ind):
    """Save input, canonical and reconstruction images for four rotations about z.

    For each angle in 0, 90, 180 and 270 degrees the batch is rotated with
    ``rotation_image_batch_z``, encoded and decoded by ``model``, and the first
    two examples are written as PNG files under ``<save_path>/Recons_<this_name>``.
    Files are named ``input_``, ``canon_`` or ``recon_`` followed by
    ``<rule>_theta_<theta>.png``, where ``rule = 2 * batch_ind + example index``.

    Args:
        save_path: Root output directory.
        this_name: Model name. Used in the output folder name; canonical images
            are only saved when it contains "Rotation".
        this_key: Key under which the model expects its input and returns its
            reconstruction.
        model: Model exposing ``encode``, ``sample_z`` and ``decode``.
        batch: Batch passed to ``rotation_image_batch_z``; ``batch[this_key]``
            is read for its trailing dimension (assumes at least two examples
            per batch -- TODO confirm).
        batch_ind: Index of the batch, used to number the saved examples.

    Returns:
        None. Figures are saved to disk and left open (they are not closed here).
    """
    # Imported once per call; the name is only needed inside this function.
    from cyto_dl.image.transforms import RotationMask

    this_save_path = Path(save_path) / Path(f"Recons_{this_name}")
    this_save_path.mkdir(parents=True, exist_ok=True)

    all_thetas = [0, 90, 180, 270]
    examples_per_batch = 2
    # Global example number for each of the within-batch examples.
    rules = [2 * batch_ind + ind for ind in range(examples_per_batch)]

    with torch.no_grad():
        # The mask depends only on the input size, so it is built once and
        # reused for every rotation.
        mask = RotationMask(
            'so3',
            3,
            batch[this_key].shape[-1],
            background=0,
        )

        for theta in all_thetas:
            this_input_rot = rotation_image_batch_z(batch, theta, False)
            batch_input = {this_key: torch.tensor(this_input_rot).to(device).float()}
            z_params = model.encode(batch_input)
            z = model.sample_z(z_params, inference=True)
            xhat = model.decode(z, return_canonical=True)

            # Only mask the canonical output when the decoder actually returned one.
            has_canonical = "canonical" in xhat
            if has_canonical:
                xhat['canonical'] = mask(xhat['canonical'])

            for ind, rule in enumerate(rules):
                _save_projection(
                    this_input_rot[ind].squeeze(),
                    f'Input rule {rule} theta {theta}',
                    this_save_path / f"input_{rule}_theta_{theta}.png",
                )

            if has_canonical and ("Rotation" in this_name):
                this_canon = xhat["canonical"].detach().cpu().numpy()
                for ind, rule in enumerate(rules):
                    _save_projection(
                        this_canon[ind].squeeze(),
                        f'Canonical rule {rule} theta {theta}',
                        this_save_path / f"canon_{rule}_theta_{theta}.png",
                    )

            this_recon = xhat[this_key]
            if isinstance(this_recon, torch.Tensor):
                this_recon = this_recon.detach().cpu().numpy()
            for ind, rule in enumerate(rules):
                _save_projection(
                    this_recon[ind].squeeze(),
                    f'Recon. rule {rule} theta {theta}',
                    this_save_path / f"recon_{rule}_theta_{theta}.png",
                )

Load the trained models

In [None]:
# Load everything returned for the "cellpack" dataset from `results_path`
(
    models,
    names,
    sizes,
    model_manifest,
    x_labels,
    latent_dims,
) = load_model_from_path("cellpack", results_path)

Save reconstructions across rotations for the classical image model

In [None]:
# Select the classical image model and save its reconstructions for every batch
ind = names.index('Classical_image')
model = models[ind]
this_name = names[ind]
this_key = x_labels[ind]
# A separate loop variable keeps `ind` pointing at the selected model.
for batch_ind, batch in enumerate(batches):
    save_recons_across_rotations(save_path, this_name, this_key, model, batch, batch_ind)

Save reconstructions across rotations for the rotation invariant image model

In [None]:
# Select the rotation invariant image model and save its reconstructions for every batch
ind = names.index('Rotation_invariant_image')
model = models[ind]
this_name = names[ind]
this_key = x_labels[ind]
# A separate loop variable keeps `ind` pointing at the selected model.
for batch_ind, batch in enumerate(batches):
    save_recons_across_rotations(save_path, this_name, this_key, model, batch, batch_ind)