This notebook shows how to visualize rotation invariance for the rotation-invariant point cloud models, using the synthetic cellPACK dataset as an example. First, let's import the necessary libraries.

In [None]:
%load_ext autoreload
%autoreload 2
import os
from pathlib import Path

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = INSERT_YOUR_DEVICE
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from hydra.utils import instantiate
from br.features.rotation_invariance import rotation_image_batch_z, rotation_pc_batch_z
from br.models.load_models import load_model_from_path

# Device that model inputs and batch tensors are moved to; with CUDA_VISIBLE_DEVICES set above,
# "cuda:0" refers to the first *visible* GPU, not necessarily physical GPU 0.
device = "cuda:0"

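In the cell below, set `CYTODL_CONFIG_PATH` (the directory containing the `configs/` and `results/` folders) and `save_path` (where output files are written), as described in the documentation associated with this repo.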

In [None]:
# Set paths
# Resolve the config root to an absolute path *before* changing directory: a relative path
# would otherwise be re-interpreted against the new working directory when results_path
# is built from it below.
CYTODL_CONFIG_PATH = os.path.abspath(os.path.expanduser(insert_your_path))
os.chdir(CYTODL_CONFIG_PATH)
# NOTE: a relative save_path is resolved against CYTODL_CONFIG_PATH (the cwd after chdir).
save_path = insert_save_path
# Trained-model results live under <config root>/results/ (trailing separator kept).
results_path = os.path.join(CYTODL_CONFIG_PATH, "results", "")

Load a test batch from the cellPACK point cloud dataset.

In [None]:
# Load data yaml and test batch.
# The path is relative to CYTODL_CONFIG_PATH, which is the cwd after the previous cell.
cellpack_data_config_path = "./configs/data/cellpack/pc.yaml"
with open(cellpack_data_config_path) as stream:
    # Parsed data config; kept under its original name so later references still work.
    cellpack_data = yaml.safe_load(stream)
data = instantiate(cellpack_data)
batch = next(iter(data.test_dataloader()))

Save raw point cloud examples for each of the 6 packing rules (panel a).

In [None]:
this_save_path = Path(save_path) / "panel_a"
this_save_path.mkdir(parents=True, exist_ok=True)

all_arr = []
# One example per packing rule (6 rules, per the markdown above) -- the first 6 batch entries.
for i in range(6):
    # .cpu() keeps this cell re-runnable after the batch has been moved to the GPU further down.
    np_arr = batch["pcloud"][i].cpu().numpy()
    # Cyclically permute the first three coordinate columns (c0, c1, c2) -> (c2, c0, c1).
    # np.zeros gives a float64 result, and any columns beyond the third are left as zeros.
    new_array = np.zeros(np_arr.shape)
    new_array[:, :3] = np_arr[:, [2, 0, 1]]
    all_arr.append(new_array)
    np.save(this_save_path / f"{i}.npy", new_array)

Define utility functions that compute reconstructions across four 90-degree rotations and save plots of the inputs and reconstructions.

In [None]:
# utility function for plotting
def plot_pc(this_p, axes, max_size, color="gray", x_ind=2, y_ind=1):
    """Scatter two coordinate columns of a point cloud onto a frameless, square Axes.

    Args:
        this_p: 2D array of points; columns ``x_ind`` and ``y_ind`` are plotted.
        axes: Matplotlib Axes to draw on.
        max_size: Both axes are limited to [-max_size, max_size].
        color: Marker color passed to ``scatter``.
        x_ind: Column used for the horizontal axis.
        y_ind: Column used for the vertical axis.
    """
    horizontal = this_p[:, x_ind]
    vertical = this_p[:, y_ind]
    axes.scatter(horizontal, vertical, c=color, s=1)

    # Strip the frame so only the points remain visible.
    for side in ("top", "right", "bottom", "left"):
        axes.spines[side].set_visible(False)

    axes.set_aspect("equal", adjustable="box")
    limits = [-max_size, max_size]
    axes.set_ylim(limits)
    axes.set_xlim(limits)
    axes.set_yticks([])
    axes.set_xticks([])

def _save_slab_plots(points, max_z, max_size, out_dir, prefix, theta, slab_axis, x_ind, y_ind):
    """Save one PNG per example showing a slab of its points projected to 2D.

    For each ``ind`` in ``max_z``, keeps the rows of ``points[ind]`` whose
    ``slab_axis`` coordinate lies strictly between ``-max_z[ind]`` and ``max_z[ind]``,
    plots columns ``x_ind``/``y_ind`` with ``plot_pc`` and writes
    ``<out_dir>/<prefix>_<ind>_theta_<theta>.png``. Each figure is closed after saving.
    """
    for ind, this_max_z in max_z.items():
        this_p = points[ind]
        slab_coord = this_p[:, slab_axis]
        this_p = this_p[(slab_coord < this_max_z) & (slab_coord > -this_max_z)]
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        plot_pc(this_p, ax, max_size, "black", x_ind=x_ind, y_ind=y_ind)
        fig.savefig(out_dir / f"{prefix}_{ind}_theta_{theta}.png")
        # Close every figure: a single call produces dozens of plots, and unclosed pyplot
        # figures stay alive in memory (matplotlib warns once more than 20 are open).
        plt.close(fig)


def save_recons_across_rotations(save_path, this_name, this_key, model, batch):
    """Save input / reconstruction (and optionally canonical) plots for four rotations.

    For each angle in 0, 90, 180 and 270 degrees, the batch is rotated with
    ``rotation_pc_batch_z`` (presumably about the z axis), encoded with
    ``model.get_embeddings`` and decoded with ``model.decode_embeddings``.
    For the first six examples, a slab of points is plotted and saved under
    ``<save_path>/Recons_<this_name>/`` as ``input_*``, ``recon_*`` and, only when
    ``this_name`` contains "Rotation" and the decoder output has a "canonical" entry,
    ``canon_*`` PNG files.

    Args:
        save_path: Root output directory (str or path-like).
        this_name: Model name; used for the output folder and to decide whether
            canonical-frame outputs are plotted.
        this_key: Key under which the point cloud is fed to, and read back from, the model.
        model: Trained model exposing ``get_embeddings`` and ``decode_embeddings``.
        batch: Batch passed to ``rotation_pc_batch_z``.
    """
    this_save_path = Path(save_path) / f"Recons_{this_name}"
    this_save_path.mkdir(parents=True, exist_ok=True)

    # Per-example half-width of the plotted slab (keys are batch indices 0-5).
    # NOTE(review): example 3 uses a much thinner slab (1 vs 20); presumably tuned for
    # that packing rule -- confirm.
    max_z = {0: 20, 1: 20, 2: 20, 3: 1, 4: 20, 5: 20}
    max_size = 10  # plot limits are [-max_size, max_size] on both axes
    all_thetas = [0, 90, 180, 270]  # four 90-degree rotations
    # Canonical-frame outputs are only plotted for models whose name contains "Rotation"
    # (e.g. the rotation-invariant model).
    plot_canonical = "Rotation" in this_name

    with torch.no_grad():
        for theta in all_thetas:
            this_input_rot = rotation_pc_batch_z(batch, theta)
            batch_input = {this_key: torch.tensor(this_input_rot).to(device).float()}
            _, z_params = model.get_embeddings(batch_input, inference=True)
            xhat = model.decode_embeddings(
                z_params, batch_input, decode=True, return_canonical=True
            )

            # Inputs: slab along column 0, plot column 2 (horizontal) vs column 1 (vertical).
            _save_slab_plots(
                this_input_rot, max_z, max_size, this_save_path, "input", theta,
                slab_axis=0, x_ind=2, y_ind=1,
            )

            if plot_canonical and "canonical" in xhat:
                this_canon = xhat["canonical"].detach().cpu().numpy()
                # Canonical frame: slab along column 1, plot column 0 vs column 2.
                _save_slab_plots(
                    this_canon, max_z, max_size, this_save_path, "canon", theta,
                    slab_axis=1, x_ind=0, y_ind=2,
                )

            this_recon = xhat[this_key].detach().cpu().numpy()
            # Reconstructions use the same slab/projection as the inputs for direct comparison.
            _save_slab_plots(
                this_recon, max_z, max_size, this_save_path, "recon", theta,
                slab_axis=0, x_ind=2, y_ind=1,
            )

Load the trained models

In [None]:
# Load every trained cellPACK model; the returned sequences are indexed in parallel
# (names.index(...) below is used to pick entries from models and x_labels).
(
    models,
    names,
    sizes,
    model_manifest,
    x_labels,
    latent_dims,
) = load_model_from_path("cellpack", results_path)

Move batch tensors to the GPU (metadata and list-valued entries are left unchanged).

In [None]:
# Metadata / identifier entries that are left where the dataloader put them.
cpu_only_keys = {
    "split",
    "bf_meta_dict",
    "egfp_meta_dict",
    "filenames",
    "image_meta_dict",
    "cell_id",
}
for key, value in list(batch.items()):
    # Skip excluded keys and list-valued entries; everything else is moved to `device`.
    if key in cpu_only_keys or isinstance(value, list):
        continue
    batch[key] = value.to(device)

Save reconstructions across rotations for the classical point cloud model

In [None]:
# Classical point cloud model: pick it by name, then save its reconstructions per rotation.
ind = names.index("Classical_pointcloud")
model, this_name, this_key = models[ind], names[ind], x_labels[ind]
save_recons_across_rotations(save_path, this_name, this_key, model, batch)

Save reconstructions across rotations for the rotation invariant point cloud model

In [None]:
# Rotation-invariant point cloud model: pick it by name, then save its reconstructions
# (including canonical-frame plots, since its name contains "Rotation").
ind = names.index("Rotation_invariant_pointcloud")
model, this_name, this_key = models[ind], names[ind], x_labels[ind]
save_recons_across_rotations(save_path, this_name, this_key, model, batch)