<a href="https://colab.research.google.com/github/alim98/MPI/blob/main/MPI_Vit_f.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Essential downloads

In [1]:
# Download the dataset archive and store it locally as downloaded_file.zip.
# NOTE(review): the URL carries `uuid` and `at` query values that look session-specific —
# confirm the link still resolves on a fresh run.
!wget -O downloaded_file.zip "https://drive.usercontent.google.com/download?id=1iHPBdBOPEagvPTHZmrN__LD49emXwReY&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000"

# Unpack the archive quietly into the current working directory.
!unzip -q downloaded_file.zip

--2025-01-04 09:56:24--  https://drive.usercontent.google.com/download?id=1iHPBdBOPEagvPTHZmrN__LD49emXwReY&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 142.251.12.132, 2404:6800:4003:c00::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|142.251.12.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1264688649 (1.2G) [application/octet-stream]
Saving to: ‘downloaded_file.zip’


2025-01-04 09:56:53 (46.3 MB/s) - ‘downloaded_file.zip’ saved [1264688649/1264688649]



In [2]:

# Install the notebook's Python dependencies; funlib.learn.torch is installed straight from GitHub.
# NOTE(review): no versions are pinned, so a fresh run may resolve different releases.
!pip install transformers scikit-learn matplotlib seaborn torch torchvision umap-learn git+https://github.com/funkelab/funlib.learn.torch.git
# openpyxl — presumably the engine pandas uses for the .xlsx synapse tables read later; confirm.
!pip install openpyxl


Collecting git+https://github.com/funkelab/funlib.learn.torch.git
  Cloning https://github.com/funkelab/funlib.learn.torch.git to /tmp/pip-req-build-1cjuoxa2
  Running command git clone --filter=blob:none --quiet https://github.com/funkelab/funlib.learn.torch.git /tmp/pip-req-build-1cjuoxa2
  Resolved https://github.com/funkelab/funlib.learn.torch.git to commit 049729151c7a2c0320a446dc9d3244ac830f7ea8
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting umap-learn
  Downloading umap_learn-0.5.7-py3-none-any.whl.metadata (21 kB)
Collecting pynndescent>=0.5 (from umap-learn)
  Downloading pynndescent-0.5.13-py3-none-any.whl.metadata (6.8 kB)
Downloading umap_learn-0.5.7-py3-none-any.whl (88 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.8/88.8 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pynndescent-0.5.13-py3-none-any.w

In [3]:

import os
import glob
import imageio.v2 as iio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch, Rectangle
from torch.utils.data import Dataset, DataLoader
from transformers import ViTImageProcessor, ViTModel
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import seaborn as sns
from umap import UMAP
import torch.nn.functional as F

# ViViT video feature extraction

In [27]:

# Root folders of the raw and segmentation slices, one sub-folder per bounding box.
raw_base_dir, seg_base_dir = '/content/raw', '/content/seg'
# Bounding boxes are named bbox1 .. bbox7.
bbox_names = ['bbox' + str(n) for n in range(1, 8)]
# Output folder for the CSV files written by this notebook.
os.makedirs('csv_outputs', exist_ok=True)

def load_bbox_data(bbox_name, max_slices=None):
    """
    Load raw and segmentation volumes for a bounding box name.

    Slice files matching ``slice_*.tif`` are read from
    ``raw_base_dir/<bbox_name>`` and ``seg_base_dir/<bbox_name>`` and stacked
    along a new first axis.

    Args:
        bbox_name: Sub-folder name of the bounding box (e.g. ``'bbox1'``).
        max_slices: If given, only the first ``max_slices`` files of each
            folder are loaded.

    Returns:
        (raw_vol, seg_vol): the stacked raw slices and the stacked
        segmentation slices (cast to ``uint32``), slice index first.

    Raises:
        AssertionError: if the two folders hold a different number of slices.
        ValueError: if no slice file is selected.
    """
    import re  # local import: only needed for the natural sort key

    def _natural_key(path):
        # Digit runs compare as integers so that slice_2 sorts before slice_10
        # even when file names are not zero-padded; the name breaks ties.
        name = os.path.basename(path)
        parts = re.split(r'(\d+)', name)
        # re.split with a capturing group alternates text / digit-run tokens.
        return [int(tok) if pos % 2 else tok for pos, tok in enumerate(parts)], name

    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)

    raw_tif_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')), key=_natural_key)
    seg_tif_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')), key=_natural_key)

    if max_slices is not None:
        raw_tif_files = raw_tif_files[:max_slices]
        seg_tif_files = seg_tif_files[:max_slices]

    assert len(raw_tif_files) == len(seg_tif_files), f"Raw/Seg mismatch in {bbox_name}"

    if not raw_tif_files:
        # np.stack would fail anyway; say which folder is the problem.
        raise ValueError(f"No slice_*.tif files selected for {bbox_name} in {raw_dir}")

    raw_slices = [iio.imread(f) for f in raw_tif_files]
    seg_slices = [iio.imread(f).astype(np.uint32) for f in seg_tif_files]

    raw_vol = np.stack(raw_slices, axis=0)  # shape: (Z, Y, X)
    seg_vol = np.stack(seg_slices, axis=0)  # shape: (Z, Y, X)
    return raw_vol, seg_vol

def create_segment_masks(seg_vol, side1_coord, side2_coord):
    """
    Build one boolean mask per synapse side.

    Each coordinate is given as (x, y, z) and is used to look up a segment id
    in ``seg_vol`` (indexed [z, y, x]). The mask marks every voxel carrying
    that id; a looked-up id of 0 yields an all-False mask instead.

    Returns:
        (mask_1, mask_2), both with the shape of ``seg_vol``.
    """
    col_a, row_a, plane_a = map(int, side1_coord)
    col_b, row_b, plane_b = map(int, side2_coord)

    label_a = seg_vol[plane_a, row_a, col_a]
    label_b = seg_vol[plane_b, row_b, col_b]

    def _mask_for(label):
        # Id 0 is not treated as a segment: return an empty mask for it.
        if label == 0:
            return np.zeros_like(seg_vol, dtype=bool)
        return seg_vol == label

    first_mask = _mask_for(label_a)
    second_mask = _mask_for(label_b)
    return first_mask, second_mask

class SynapseDataset(Dataset):
    """
    Loads entries from a DataFrame that has columns:
      - bbox_index: Which bounding box volume to use
      - central_coord_(1,2,3)
      - side_1_coord_(1,2,3), side_2_coord_(1,2,3)
    Coordinates are read as (x, y, z); volumes are indexed [z, y, x].

    For every row a cubic sub-volume of edge ``subvol_size`` centred on
    'central_coord' is cut from the raw volume (zero-padded where the cube
    leaves the volume, keeping the centre in place). Each z-slice of that
    sub-volume carries 3 channels:
      Channel 0: side1 mask
      Channel 1: raw intensity
      Channel 2: side2 mask

    Each item is ``(tensor, info)`` where ``tensor`` is a float tensor of
    shape (subvol_size, 3, subvol_size, subvol_size) and ``info`` is the
    DataFrame row as a dict.
    """

    def __init__(self, vol_data_list, synapse_df, subvol_size=80):
        self.vol_data_list = vol_data_list
        self.synapse_df = synapse_df.reset_index(drop=True)
        self.subvol_size = subvol_size
        self.half_size = subvol_size // 2

    def __len__(self):
        return len(self.synapse_df)

    def _axis_window(self, center, extent):
        """Return (source slice in the volume, destination slice in the padded cube) for one axis."""
        lo = center - self.half_size
        start = max(lo, 0)
        # Never let stop fall below start: a negative stop would wrap around.
        stop = max(min(center + self.half_size, extent), start)
        # Shift by the amount clipped at the low edge so the centre stays put.
        offset = start - lo
        return slice(start, stop), slice(offset, offset + (stop - start))

    def __getitem__(self, idx):
        syn_info = self.synapse_df.iloc[idx]
        # int(): the row may hold the index as a float or numpy scalar.
        raw_vol, seg_vol = self.vol_data_list[int(syn_info['bbox_index'])]

        cx, cy, cz = (int(syn_info[f'central_coord_{k}']) for k in (1, 2, 3))
        x1, y1, z1 = (int(syn_info[f'side_1_coord_{k}']) for k in (1, 2, 3))
        x2, y2, z2 = (int(syn_info[f'side_2_coord_{k}']) for k in (1, 2, 3))

        # Segment ids under the two side coordinates (0 means "no segment").
        seg_id_1 = seg_vol[z1, y1, x1]
        seg_id_2 = seg_vol[z2, y2, x2]

        windows = [self._axis_window(c, n) for c, n in zip((cz, cy, cx), raw_vol.shape)]
        src = tuple(w[0] for w in windows)
        dst = tuple(w[1] for w in windows)

        desired_shape = (self.subvol_size, self.subvol_size, self.subvol_size)
        padded_sub_raw = np.zeros(desired_shape, dtype=raw_vol.dtype)
        padded_sub_mask1 = np.zeros(desired_shape, dtype=np.uint8)
        padded_sub_mask2 = np.zeros(desired_shape, dtype=np.uint8)

        padded_sub_raw[dst] = raw_vol[src]
        # Compare ids on the crop only; a full-volume mask per item is wasteful.
        sub_seg = seg_vol[src]
        if seg_id_1 != 0:
            padded_sub_mask1[dst] = sub_seg == seg_id_1
        if seg_id_2 != 0:
            padded_sub_mask2[dst] = sub_seg == seg_id_2

        # Stack along axis=1 (the "channel" axis): (Z, 3, Y, X).
        sub_3d = np.stack([padded_sub_mask1, padded_sub_raw, padded_sub_mask2], axis=1)

        # Cast in NumPy first so unsigned raw dtypes convert cleanly.
        sub_3d_tensor = torch.from_numpy(sub_3d.astype(np.float32))

        # Convert syn_info row to dict
        syn_info_dict = syn_info.to_dict()

        return sub_3d_tensor, syn_info_dict


In [35]:
import torch
import numpy as np
from transformers import VivitModel, VivitImageProcessor

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Kinetics-400 ViViT checkpoint, requesting the SDPA attention implementation,
# and move it to the selected device.
model_vivit = VivitModel.from_pretrained(
    "google/vivit-b-16x2-kinetics400",
    attn_implementation="sdpa",
    # torch_dtype=torch.float16
).to(device)

# Image processor belonging to the same checkpoint.
processor_vivit = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")

# Inference only: switch the model to evaluation mode.
model_vivit.eval()
print("ViViT model loaded and in eval mode.")


Some weights of VivitModel were not initialized from the model checkpoint at google/vivit-b-16x2-kinetics400 and are newly initialized: ['vivit.pooler.dense.bias', 'vivit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViViT model loaded and in eval mode.


In [4]:
import os
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import imageio.v2 as iio
from transformers import VivitModel, VivitImageProcessor
from PIL import Image

# Prefer the GPU when one is visible to PyTorch; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Root folders of the raw and segmentation slices, one sub-folder per bounding box.
raw_base_dir, seg_base_dir = '/content/raw', '/content/seg'
# Bounding boxes are named "bbox1" .. "bbox7".
bbox_names = ['bbox' + str(n) for n in range(1, 8)]

# Output folder for the per-bbox and merged feature CSVs.
os.makedirs('csv_outputs', exist_ok=True)

def load_bbox_data(bbox_name, max_slices=None):
    """
    Load raw and segmentation volumes for a bounding box.

    Slice files matching ``slice_*.tif`` are read from
    ``raw_base_dir/<bbox_name>`` and ``seg_base_dir/<bbox_name>`` and stacked
    along a new first axis.

    Args:
        bbox_name: Sub-folder name of the bounding box (e.g. ``'bbox1'``).
        max_slices: If given, only the first ``max_slices`` files of each
            folder are loaded.

    Returns:
        (raw_vol, seg_vol): the stacked raw slices and the stacked
        segmentation slices (cast to ``uint32``), slice index first.

    Raises:
        AssertionError: if the two folders hold a different number of slices.
        ValueError: if no slice file is selected.
    """
    import re  # local import: only needed for the natural sort key

    def _natural_key(path):
        # Digit runs compare as integers so that slice_2 sorts before slice_10
        # even when file names are not zero-padded; the name breaks ties.
        name = os.path.basename(path)
        parts = re.split(r'(\d+)', name)
        # re.split with a capturing group alternates text / digit-run tokens.
        return [int(tok) if pos % 2 else tok for pos, tok in enumerate(parts)], name

    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)

    raw_tif_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')), key=_natural_key)
    seg_tif_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')), key=_natural_key)

    if max_slices is not None:
        raw_tif_files = raw_tif_files[:max_slices]
        seg_tif_files = seg_tif_files[:max_slices]

    assert len(raw_tif_files) == len(seg_tif_files), f"Raw/Seg mismatch in {bbox_name}"

    if not raw_tif_files:
        # np.stack would fail anyway; say which folder is the problem.
        raise ValueError(f"No slice_*.tif files selected for {bbox_name} in {raw_dir}")

    raw_slices = [iio.imread(f) for f in raw_tif_files]
    seg_slices = [iio.imread(f).astype(np.uint32) for f in seg_tif_files]

    raw_vol = np.stack(raw_slices, axis=0)  # shape (Z, Y, X)
    seg_vol = np.stack(seg_slices, axis=0)  # shape (Z, Y, X)
    return raw_vol, seg_vol

def create_segment_masks(seg_vol, side1_coord, side2_coord):
    """
    Return boolean masks for side 1 and side 2 of a synapse.

    Both coordinates are (x, y, z) triples; the segment id found at
    ``seg_vol[z, y, x]`` selects the voxels of the mask. When that id is 0 the
    corresponding mask is all False.
    """
    xa, ya, za = tuple(int(v) for v in side1_coord)
    xb, yb, zb = tuple(int(v) for v in side2_coord)

    id_a = seg_vol[za, ya, xa]
    id_b = seg_vol[zb, yb, xb]

    masks = []
    for seg_id in (id_a, id_b):
        if seg_id != 0:
            masks.append(seg_vol == seg_id)
        else:
            # Id 0 is not treated as a segment.
            masks.append(np.zeros_like(seg_vol, dtype=bool))

    return masks[0], masks[1]

class SynapseDataset(Dataset):
    """
    Loads each synapse from synapse_df, extracts a cubic subvolume of edge
    ``subvol_size`` centred on central_coord, and merges raw intensity with the
    side1/side2 masks into a 3-channel volume of shape
    [subvol_size, 3, subvol_size, subvol_size]. Regions of the cube that fall
    outside the source volume are zero-padded with the centre kept in place.

    ``num_frames`` evenly spaced z-slices are then sampled (32 by default, the
    frame count the notebook feeds to 'google/vivit-b-16x2-kinetics400').

    Coordinates in synapse_df are read as (x, y, z); volumes are indexed
    [z, y, x]. Each item is ``(tensor, info)`` where ``tensor`` is a float
    tensor of shape [num_frames, 3, subvol_size, subvol_size] and ``info`` is
    the DataFrame row as a dict.
    """

    def __init__(self, vol_data_list, synapse_df, subvol_size=80, num_frames=32):
        self.vol_data_list = vol_data_list
        self.synapse_df = synapse_df.reset_index(drop=True)
        self.subvol_size = subvol_size
        self.half_size = subvol_size // 2
        self.num_frames = num_frames

    def __len__(self):
        return len(self.synapse_df)

    def _axis_window(self, center, extent):
        """Return (source slice in the volume, destination slice in the padded cube) for one axis."""
        lo = center - self.half_size
        start = max(lo, 0)
        # Never let stop fall below start: a negative stop would wrap around.
        stop = max(min(center + self.half_size, extent), start)
        # Shift by the amount clipped at the low edge so the centre stays put.
        offset = start - lo
        return slice(start, stop), slice(offset, offset + (stop - start))

    def __getitem__(self, idx):
        syn_info = self.synapse_df.iloc[idx]
        # int(): the row may hold the index as a float or numpy scalar.
        raw_vol, seg_vol = self.vol_data_list[int(syn_info['bbox_index'])]

        cx, cy, cz = (int(syn_info[f'central_coord_{k}']) for k in (1, 2, 3))
        x1, y1, z1 = (int(syn_info[f'side_1_coord_{k}']) for k in (1, 2, 3))
        x2, y2, z2 = (int(syn_info[f'side_2_coord_{k}']) for k in (1, 2, 3))

        # Segment ids under the two side coordinates (0 means "no segment").
        seg_id_1 = seg_vol[z1, y1, x1]
        seg_id_2 = seg_vol[z2, y2, x2]

        windows = [self._axis_window(c, n) for c, n in zip((cz, cy, cx), raw_vol.shape)]
        src = tuple(w[0] for w in windows)
        dst = tuple(w[1] for w in windows)

        desired_shape = (self.subvol_size, self.subvol_size, self.subvol_size)
        padded_sub_raw = np.zeros(desired_shape, dtype=raw_vol.dtype)
        padded_sub_mask1 = np.zeros(desired_shape, dtype=np.uint8)
        padded_sub_mask2 = np.zeros(desired_shape, dtype=np.uint8)

        padded_sub_raw[dst] = raw_vol[src]
        # Compare ids on the crop only; a full-volume mask per item is wasteful.
        sub_seg = seg_vol[src]
        if seg_id_1 != 0:
            padded_sub_mask1[dst] = sub_seg == seg_id_1
        if seg_id_2 != 0:
            padded_sub_mask2[dst] = sub_seg == seg_id_2

        # Channel 0 = side1 mask, channel 1 = raw intensities, channel 2 = side2 mask.
        # NOTE(review): masks are 0/1 while raw keeps its native intensity range —
        # confirm the downstream preprocessing weights the channels as intended.
        sub_3d = np.stack([padded_sub_mask1, padded_sub_raw, padded_sub_mask2], axis=1)

        # Pick num_frames evenly spaced z-slices across the whole cube.
        z_indices = np.linspace(0, self.subvol_size - 1, self.num_frames, dtype=int)
        sub_3d_sampled = sub_3d[z_indices]

        # Cast in NumPy first so unsigned raw dtypes convert cleanly.
        sub_3d_tensor = torch.from_numpy(sub_3d_sampled.astype(np.float32))

        # Convert syn_info row to dict for convenience
        syn_info_dict = syn_info.to_dict()

        return sub_3d_tensor, syn_info_dict

# Checkpoint used for both the model and its image processor.
model_name = "google/vivit-b-16x2-kinetics400"

print("Loading Vivit Model...")
# Load the pretrained weights and move the model to the selected device.
model_vivit = VivitModel.from_pretrained(
    model_name,
).to(device)
# Inference only: switch the model to evaluation mode.
model_vivit.eval()

# Image processor belonging to the same checkpoint.
processor_vivit = VivitImageProcessor.from_pretrained(model_name)
print("ViViT model & processor loaded.\n")

def get_vivit_features(video_batch):
    """
    Embed a batch of sub-volume "videos" with the ViViT model.

    Args:
        video_batch: tensor of shape [B, T, 3, H, W]. Each video is passed
            individually through ``processor_vivit`` and ``model_vivit``.

    Returns:
        np.ndarray of shape (B, hidden_size): the final-layer embedding of the
        first (CLS) token of every video.

    Note:
        ``pooler_output`` is deliberately not used. The checkpoint ships no
        pooler weights (the load-time warning lists ``vivit.pooler.dense.*`` as
        newly initialised), so that output would pass through a randomly
        initialised layer and differ from run to run.
    """
    batch_size = video_batch.size(0)
    if batch_size == 0:
        # Nothing to embed; keep the (B, hidden_size) contract.
        return np.empty((0, model_vivit.config.hidden_size), dtype=np.float32)

    features_list = []
    for i in range(batch_size):
        # video_batch[i]: shape [T, 3, H, W]
        frames_3d = video_batch[i].cpu().numpy()

        # The processor takes one video as a list of frames, each [3, H, W].
        frame_list = list(frames_3d)

        # Preprocess with the checkpoint's own image processor.
        processed = processor_vivit(
            frame_list,
            return_tensors="pt"
        )
        pixel_values = processed["pixel_values"].to(device)

        with torch.no_grad():
            # interpolate_pos_encoding=True is kept from the original call.
            outputs = model_vivit(pixel_values=pixel_values, interpolate_pos_encoding=True)
            # First token of the final hidden states: pretrained, deterministic.
            cls_embedding = outputs.last_hidden_state[:, 0]  # shape [1, hidden_size]

        features_list.append(cls_embedding.cpu().numpy())

    # Combine into [B, hidden_size]
    features_array = np.concatenate(features_list, axis=0)
    return features_array

# Extract ViViT features for every bounding box, write one CSV per box,
# then merge all per-box CSVs into a single file.
all_csv_paths = []

for bbox_name in bbox_names:
    print(f"Processing {bbox_name}...")
    excel_file = f'/content/{bbox_name}.xlsx'
    syn_df = pd.read_excel(excel_file)

    if syn_df.empty:
        # No rows means no features; concatenating empty lists below would fail.
        print(f"No synapses listed for {bbox_name}; skipping.")
        continue

    raw_vol, seg_vol = load_bbox_data(bbox_name)

    # Only one volume is held at a time, so every row points at index 0.
    syn_df['bbox_index'] = 0
    syn_df['bbox_name']  = bbox_name

    vol_data_list = [(raw_vol, seg_vol)]
    dataset_bbox = SynapseDataset(vol_data_list, syn_df, subvol_size=80)
    dataloader_bbox = DataLoader(dataset_bbox, batch_size=2, shuffle=False, num_workers=2)

    bbox_features = []
    bbox_syn_info = []

    for batch_idx, (video_batch, syn_infos) in enumerate(dataloader_bbox):
        # video_batch: [B, 32, 3, 80, 80]
        feats = get_vivit_features(video_batch)  # => shape (B, hidden_size)
        bbox_features.append(feats)

        # The DataLoader's default collate has already merged the per-sample
        # dicts into one dict of batched columns; turn it into a DataFrame.
        syn_infos_df = pd.DataFrame(syn_infos)
        bbox_syn_info.append(syn_infos_df)

    # Concatenate
    bbox_features = np.concatenate(bbox_features, axis=0)  # shape [N, hidden_size]
    bbox_syn_info = pd.concat(bbox_syn_info, axis=0).reset_index(drop=True)

    feature_cols = [f'feat_{j}' for j in range(bbox_features.shape[1])]
    features_df = pd.DataFrame(bbox_features, columns=feature_cols)

    # shuffle=False keeps metadata rows and feature rows in the same order.
    output_df = pd.concat([bbox_syn_info, features_df], axis=1)

    # Write CSV
    output_csv_name = f'csv_outputs/{bbox_name}_features.csv'
    output_df.to_csv(output_csv_name, index=False)
    all_csv_paths.append(output_csv_name)
    print(f"Saved features for {bbox_name} -> {output_csv_name}")

if not all_csv_paths:
    # pd.concat would raise a less helpful ValueError on an empty list.
    raise ValueError("No per-bbox feature CSVs were written; nothing to merge.")

merged_df = pd.concat([pd.read_csv(p) for p in all_csv_paths], ignore_index=True)
print(f"\nMerged {len(all_csv_paths)} CSVs into one DataFrame with {len(merged_df)} rows.")

merged_csv = 'csv_outputs/all_features_merged.csv'
merged_df.to_csv(merged_csv, index=False)
print(f"Final merged CSV: {merged_csv}")


Loading Vivit Model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/18.6k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/356M [00:00<?, ?B/s]

Some weights of VivitModel were not initialized from the model checkpoint at google/vivit-b-16x2-kinetics400 and are newly initialized: ['vivit.pooler.dense.bias', 'vivit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


preprocessor_config.json:   0%|          | 0.00/401 [00:00<?, ?B/s]

ViViT model & processor loaded.

Processing bbox1...
Saved features for bbox1 -> csv_outputs/bbox1_features.csv
Processing bbox2...
Saved features for bbox2 -> csv_outputs/bbox2_features.csv
Processing bbox3...
Saved features for bbox3 -> csv_outputs/bbox3_features.csv
Processing bbox4...
Saved features for bbox4 -> csv_outputs/bbox4_features.csv
Processing bbox5...
Saved features for bbox5 -> csv_outputs/bbox5_features.csv
Processing bbox6...
Saved features for bbox6 -> csv_outputs/bbox6_features.csv
Processing bbox7...
Saved features for bbox7 -> csv_outputs/bbox7_features.csv

Merged 7 CSVs into one DataFrame with 509 rows.
Final merged CSV: csv_outputs/all_features_merged.csv


# Dimensionality reduction and visualization

In [52]:
# List the contents of the current working directory.
!dir

CHANGELOG.md  examples	   MinkowskiEngine   setup.py	      vit_umap_x_vs_y.html
csv_outputs   LICENSE	   pybind	     src	      vit_umap_x_vs_z.html
docker	      Makefile	   README.md	     tests	      vit_umap_y_vs_z.html
docs	      MANIFEST.in  requirements.txt  vit_umap3d.html


In [5]:
import pandas as pd
import numpy as np

from sklearn.decomposition import PCA
import umap.umap_ as umap
import plotly.express as px

# 1) Load the merged feature table written by the extraction step.
merged_csv = 'csv_outputs/all_features_merged.csv'
df = pd.read_csv(merged_csv)

# 2) Collect the embedding columns (feat_0, feat_1, ...).
feat_cols = [c for c in df.columns if c.startswith('feat_')]
X = df[feat_cols].values  # shape: [N, hidden_size] (e.g. [N, 768])

# 3) PCA down to at most 50 dims. PCA cannot return more components than
#    min(n_samples, n_features), so cap the request for small tables.
n_pca_components = min(50, X.shape[0], X.shape[1])
pca = PCA(n_components=n_pca_components, random_state=42)
X_pca = pca.fit_transform(X)  # shape => [N, n_pca_components]

# 4) UMAP from the PCA space -> 3 dims (for 3D visualization)
umap_3d = umap.UMAP(
    n_components=3,
    n_neighbors=15,     # can tune
    min_dist=0.1,       # can tune
    random_state=42
)
X_umap3 = umap_3d.fit_transform(X_pca)  # shape => [N, 3]

# 5) Add UMAP coordinates back to the DataFrame
df['umap_x'] = X_umap3[:,0]
df['umap_y'] = X_umap3[:,1]
df['umap_z'] = X_umap3[:,2]

# 6) Interactive 3D scatter coloured by bounding box; also saved as HTML.
fig = px.scatter_3d(
    df,
    x='umap_x',
    y='umap_y',
    z='umap_z',
    color='bbox_name',
    hover_data=['central_coord_1', 'central_coord_2', 'central_coord_3']
)
fig.update_traces(marker=dict(size=3))
fig.update_layout(width=800, height=600)
fig.write_html("vit_umap3d.html")

fig.show()


  warn(


In [6]:
import plotly.express as px

def _save_and_show_umap_2d(x_col, y_col, title, html_path):
    """Scatter two UMAP columns of `df` coloured by bbox_name, save the figure as HTML, display it, and return it."""
    fig_2d = px.scatter(
        df,
        x=x_col,
        y=y_col,
        color="bbox_name",  # color by bbox_name => discrete legend
        title=title,
        hover_data=[x_col, y_col, "bbox_name", "Var1"]
    )
    fig_2d.write_html(html_path)
    fig_2d.show()
    return fig_2d

# 1) UMAP (x vs. y)
fig_xy = _save_and_show_umap_2d("umap_x", "umap_y", "UMAP (x vs y)", "vit_umap_x_vs_y.html")

# 2) UMAP (x vs. z)
fig_xz = _save_and_show_umap_2d("umap_x", "umap_z", "UMAP (x vs z)", "vit_umap_x_vs_z.html")

# 3) UMAP (y vs. z)
# The helper writes the figure it just built, so the y-vs-z file now holds
# the y-vs-z plot (previously the x-vs-z figure was written to this file).
fig_yz = _save_and_show_umap_2d("umap_y", "umap_z", "UMAP (y vs z)", "vit_umap_y_vs_z.html")


In [53]:
from google.colab import files

# Send the saved x-vs-y UMAP plot to the browser as a download (Colab `files` helper).
# The merged CSV download below is left disabled.
# files.download("csv_outputs/all_features_merged.csv")
files.download("vit_umap_x_vs_y.html")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [54]:

# Send the saved x-vs-z UMAP plot to the browser as a download.
files.download("vit_umap_x_vs_z.html")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [55]:

# Send the saved y-vs-z UMAP plot to the browser as a download.
files.download("vit_umap_y_vs_z.html")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [56]:

# Send the saved 3D UMAP plot to the browser as a download.
files.download("vit_umap3d.html")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [41]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# One row with three panels, one per pair of UMAP components.
fig = make_subplots(
    rows=1, cols=3,
    subplot_titles=[
        "UMAP (x vs y)",
        "UMAP (x vs z)",
        "UMAP (y vs z)"
    ]
)

# Integer code per bounding box, used to colour the markers.
cat_codes = df["bbox_name"].astype("category").cat.codes


def _pairwise_trace(x_col, y_col, trace_name, hover, with_colorbar):
    """Build the marker trace for one pair of UMAP columns of `df`."""
    marker_style = dict(
        color=cat_codes,
        colorscale="Viridis",
        showscale=with_colorbar,
        size=5
    )
    return go.Scatter(
        x=df[x_col],
        y=df[y_col],
        mode="markers",
        name=trace_name,
        marker=marker_style,
        text=df["bbox_name"],    # hover text
        hovertemplate=hover
    )


# Only the first panel shows the colorbar; the other two reuse the same colours.
trace_xy = _pairwise_trace(
    "umap_x", "umap_y", "(x vs y)",
    "<b>bbox_name:%{text}</b><br>umap_x=%{x}<br>umap_y=%{y}<extra></extra>",
    True,
)
trace_xz = _pairwise_trace(
    "umap_x", "umap_z", "(x vs z)",
    "<b>bbox_name:%{text}</b><br>umap_x=%{x}<br>umap_z=%{y}<extra></extra>",
    False,
)
trace_yz = _pairwise_trace(
    "umap_y", "umap_z", "(y vs z)",
    "<b>bbox_name:%{text}</b><br>umap_y=%{x}<br>umap_z=%{y}<extra></extra>",
    False,
)

# Place the traces left to right in the single row.
for panel_col, panel_trace in enumerate((trace_xy, trace_xz, trace_yz), start=1):
    fig.add_trace(panel_trace, row=1, col=panel_col)

# Wide layout so the three panels sit side by side; the legend is hidden.
layout_options = dict(
    title="2D UMAP Projections (All Pairwise Components)",
    width=1800,
    height=600,
    showlegend=False
)
fig.update_layout(**layout_options)

fig.show()
