In [1]:
import os

import timm
import torch
from huggingface_hub import login
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

In [2]:
# Authenticate with the Hugging Face Hub so the UNI2-h weights can be downloaded.
# huggingface_hub already picks up the HF_TOKEN environment variable when it is set,
# so only fall back to the interactive prompt when no token is supplied that way.
# This keeps "Restart & Run All" from blocking on a login widget.
if not os.environ.get("HF_TOKEN"):
    login()  # login with your User Access Token, found at https://huggingface.co/settings/tokens

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
# pretrained=True needed to load UNI2-h weights (and download weights for the first time)
timm_kwargs = {
            'img_size': 224, 
            'patch_size': 14, 
            'depth': 24,
            'num_heads': 24,
            'init_values': 1e-5, 
            'embed_dim': 1536,
            'mlp_ratio': 2.66667*2,
            'num_classes': 0, 
            'no_embed_class': True,
            'mlp_layer': timm.layers.SwiGLUPacked, 
            'act_layer': torch.nn.SiLU, 
            'reg_tokens': 8, 
            'dynamic_img_size': True
        }
model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
model.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1536, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1536, out_features=4608, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=1536, out_features=1536, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (mlp): GluMlp(
        (fc1): Linear(in_features=1536, out_features=8192, bias=True)
        (act): SiLU()
        (drop1): Dropout(p=0.0, inplace=False)
    

In [4]:
# Sanity check: push one random 224x224, 3-channel image through the encoder.
# A fixed-seed generator makes the dummy input (and therefore the output values)
# reproducible across kernel restarts without touching the global RNG state.
dummy_input = torch.randn(1, 3, 224, 224, generator=torch.Generator().manual_seed(0))

# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    output = model(dummy_input)

print("Output shape:", output.shape)
# With num_classes=0 the model returns one feature vector per image; fail loudly
# (instead of just printing) if the configuration ever stops matching that.
assert output.shape == (1, model.num_features), f"unexpected embedding shape {tuple(output.shape)}"


Output shape: torch.Size([1, 1536])


In [5]:
import os
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display
import tifffile
import numpy as np

# Directory holding the volumes and their `_tissue` masks.
# Set the THYROID_SOURCE_DIR environment variable to point elsewhere (e.g. on a machine
# without the K: drive mapped); otherwise the project's default location is used.
source_dir = Path(os.environ.get(
    "THYROID_SOURCE_DIR",
    r"K:\499-ProjectData\2025\P25-0048_Thyroid_Recurrence\04-Processed_Datasets\split_cleared_dataset",
))

def get_tiff_files(directory):
    """List the TIFF volumes in ``directory``, skipping tissue-mask companions.

    A file counts as a TIFF when its final suffix is ``.tif`` or ``.tiff``
    (case-insensitive, so ``.TIF`` is found on case-sensitive filesystems too).
    Files whose stem ends with ``_tissue`` are masks and are excluded.
    Sub-directories are ignored. The scan is not recursive.

    Args:
        directory: ``pathlib.Path`` of the folder to scan.

    Returns:
        Sorted list of file names (not full paths). Returns an empty list, after
        printing a message, when ``directory`` does not exist or is not a directory.
    """
    if not directory.is_dir():
        print(f"Directory does not exist: {directory}")
        return []

    tiff_suffixes = {".tif", ".tiff"}
    # Filter on the real suffix rather than a "*.tif*" glob, which would also
    # match names such as "x.tif.bak" or "x.tiff_old".
    return [
        path.name
        for path in sorted(directory.iterdir())
        if path.is_file()
        and path.suffix.lower() in tiff_suffixes
        and not path.stem.endswith("_tissue")
    ]

def load_volume_and_mask(filename):
    """Load the selected volume and its corresponding tissue mask from ``source_dir``.

    The mask is expected next to the volume as ``<volume stem>_tissue<ext>``.
    The volume's own extension is tried first, then ``.tif`` and ``.tiff``.
    So ``x.tiff`` pairs with ``x_tissue.tiff``, and ``.tif`` volumes behave as before.

    Args:
        filename: Volume file name, relative to ``source_dir``.

    Returns:
        ``(volume, tissue_mask)`` as returned by ``tifffile.imread``.
        ``tissue_mask`` is ``None`` when no mask file is found.
        Both are ``None`` when ``filename`` is empty.
    """
    if not filename:
        print("No file selected")
        return None, None

    volume_path = source_dir / filename
    base_name = Path(filename).stem

    # Default to the historical `.tif` name so the warning below reports where
    # the mask was expected when none of the candidates exist.
    tissue_path = source_dir / f"{base_name}_tissue.tif"
    # dict.fromkeys de-duplicates while keeping preference order; empty suffixes are skipped.
    for suffix in dict.fromkeys(s for s in (volume_path.suffix, ".tif", ".tiff") if s):
        candidate = source_dir / f"{base_name}_tissue{suffix}"
        if candidate.exists():
            tissue_path = candidate
            break

    print(f"Loading volume: {volume_path}")
    volume = tifffile.imread(volume_path)
    print(f"Volume shape: {volume.shape}, dtype: {volume.dtype}")

    if tissue_path.exists():
        print(f"Loading tissue mask: {tissue_path}")
        tissue_mask = tifffile.imread(tissue_path)
        print(f"Tissue mask shape: {tissue_mask.shape}, dtype: {tissue_mask.dtype}")
        # Slice viewers index volume and mask with the same index along axis 0,
        # so flag a mismatch there up front instead of failing mid-browse.
        if tissue_mask.shape[:1] != volume.shape[:1]:
            print(f"Warning: tissue mask shape {tissue_mask.shape} does not match "
                  f"volume shape {volume.shape} along axis 0")
    else:
        print(f"Warning: Tissue mask not found at {tissue_path}")
        tissue_mask = None

    return volume, tissue_mask

def on_file_selected(change):
    """Dropdown observer: load the chosen volume and mask into module globals.

    Sets ``current_volume`` and ``current_tissue_mask`` for the viewer cells.
    Empty values and the ``'Select a file...'`` placeholder entry are ignored.
    Otherwise re-selecting the placeholder would try to open a file of that
    name and raise.

    Args:
        change: ipywidgets change notification; only ``change['new']`` (the
            newly selected option) is read.
    """
    global current_volume, current_tissue_mask
    selected_file = change['new']
    # Must match the placeholder option offered by the file dropdown.
    if not selected_file or selected_file == 'Select a file...':
        return
    current_volume, current_tissue_mask = load_volume_and_mask(selected_file)

# Discover which volumes are available to pick from.
tiff_files = get_tiff_files(source_dir)
print(f"Found {len(tiff_files)} TIFF files")

# Offer them in a dropdown; choosing an entry fires on_file_selected.
if not tiff_files:
    print(f"No TIFF files found in: {source_dir}")
else:
    file_dropdown = widgets.Dropdown(
        options=['Select a file...', *tiff_files],
        value='Select a file...',
        description='TIFF File:',
        disabled=False,
        style=dict(description_width='100px'),
        layout=widgets.Layout(width='500px'),
    )
    file_dropdown.observe(on_file_selected, names='value')
    display(file_dropdown)


Found 212 TIFF files


Dropdown(description='TIFF File:', layout=Layout(width='500px'), options=('Select a file...', '002_B05.20964B.…

In [6]:
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider

def plot_slice(slice_idx):
    """Plot one slice of the loaded volume next to its tissue mask.

    Reads the ``current_volume`` and ``current_tissue_mask`` globals set by the
    file dropdown and indexes both along axis 0.

    The slider's range is fixed when its cell runs. After a shallower volume is
    loaded from the dropdown, a stale index prints a message instead of raising
    ``IndexError``. A mask with fewer slices than requested is reported in its panel.

    Args:
        slice_idx: Index along axis 0 of the volume.
    """
    volume = globals().get('current_volume')
    if volume is None:
        print("No volume loaded. Please select a file from the dropdown above.")
        return

    n_slices = volume.shape[0]
    if not 0 <= slice_idx < n_slices:
        print(f"Slice {slice_idx} is out of range for the loaded volume "
              f"(0-{n_slices - 1}). Re-run the slider cell to rebuild it.")
        return

    mask = globals().get('current_tissue_mask')

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: volume slice
    axes[0].imshow(volume[slice_idx], cmap='gray')
    axes[0].set_title(f'Volume - Slice {slice_idx}/{n_slices-1}')
    axes[0].axis('off')

    # Right panel: mask slice, or an explanatory message when it cannot be shown
    if mask is None:
        axes[1].text(0.5, 0.5, 'No tissue mask available',
                     ha='center', va='center', transform=axes[1].transAxes)
        axes[1].set_title('Tissue Mask - Not Available')
    elif slice_idx < mask.shape[0]:
        axes[1].imshow(mask[slice_idx], cmap='gray')
        axes[1].set_title(f'Tissue Mask - Slice {slice_idx}/{mask.shape[0]-1}')
    else:
        axes[1].text(0.5, 0.5, f'Tissue mask has only {mask.shape[0]} slices',
                     ha='center', va='center', transform=axes[1].transAxes)
        axes[1].set_title('Tissue Mask - Slice Out of Range')
    axes[1].axis('off')

    plt.tight_layout()
    plt.show()

# Scroll through slices with a slider; it opens on the middle slice.
# Only built once a volume has been chosen in the dropdown.
if globals().get('current_volume') is None:
    print("No volume loaded yet. Please select a file from the dropdown above first.")
else:
    depth = current_volume.shape[0]
    interact(
        plot_slice,
        slice_idx=IntSlider(
            min=0,
            max=depth - 1,
            step=1,
            value=depth // 2,
            description='Slice:',
        ),
    )

interactive(children=(IntSlider(value=52, description='Slice:', max=103), Output()), _dom_classes=('widget-int…