In [2]:
# Library imports
import random
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import sys
import zarr
import json
import os
import napari
from rich import print as rprint  # Import rich's print function
import copy
from dataclasses import dataclass
from monai.transforms import Compose, RandFlipd, RandRotated, RandGaussianNoised, RandAdjustContrastd
# set torch and cuda seed for reproducibility
torch.manual_seed(37)
torch.cuda.manual_seed(37)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# -------------LOADING TOMOGRAM DATA AND SEGMENTATION MASKS-----------------#

experiment_runs = ["TS_" + str(i) for i in range(27)]
particle_types = {"virus-like-particle": 1, "apo-ferritin": 2, "beta-amylase": 3, 
                  "beta-galactosidase": 4, "ribosome": 5, "thyroglobulin": 6}
type_mapping = {
    "virus-like-particle": ["pp7_vlp", 6], "apo-ferritin": ["ferritin_complex", 1], 
    "beta-amylase": ["beta_amylase", 2], "beta-galactosidase": ["beta_galactosidase", 3], 
    "ribosome": ["cytosolic_ribosome", 4], "thyroglobulin": ["thyroglobulin", 5]
}

# Initialize lists to store tomograms and labels
combined_tomogram_data = []
combined_label_data = []

for experiment_run in experiment_runs:
    zarr_file_path = os.path.join("aws_data", "10441", experiment_run, "Reconstructions", 
                                  "VoxelSpacing10.000", "Tomograms", "100", experiment_run)

    try:
        # Load tomogram
        tomogram = zarr.open(zarr_file_path + ".zarr", mode="r")
        tomogram_data = tomogram["0"][:]
        # tomogram_data = (tomogram_data - tomogram_data.mean()) / tomogram_data.std()
        
        # Initialize label cube for this tomogram
        label_cube = np.zeros_like(tomogram_data, dtype=np.uint8)

        # Process segmentation masks
        for particle_type, particle_id in particle_types.items():
            name, number = type_mapping[particle_type]
            segmentation_mask_path = os.path.join("aws_data", "10441", experiment_run, "Reconstructions", 
                                                  "VoxelSpacing10.000", "Annotations", 
                                                  "10" + str(number), name + "-1.0_segmentationmask.zarr")
            try:
                segmentation_mask = zarr.open(segmentation_mask_path, mode="r")["0"][:]
                label_cube[segmentation_mask > 0] = particle_id  # Assign particle ID where mask exists
            except Exception as e:
                print(f"Error loading segmentation mask for {particle_type} in {experiment_run}: {e}")

        # Append processed data
        combined_tomogram_data.append(tomogram_data)
        combined_label_data.append(label_cube)

    except Exception as e:
        print(f"Error loading Zarr file for {experiment_run}: {e}")

# Concatenate all tomograms and labels along the Z-dimension
combined_tomogram_data = np.concatenate(combined_tomogram_data, axis=0)
combined_label_data = np.concatenate(combined_label_data, axis=0)

print("Final combined tomogram shape:", combined_tomogram_data.shape)
print("Final combined label shape:", combined_label_data.shape)

Final combined tomogram shape: (5400, 630, 630)
Final combined label shape: (5400, 630, 630)


In [3]:
# -------------VISUALIZATION WITH NAPARI-----------------#

def visualize_tomogram_and_labels(tomogram_data, label_data):
    viewer = napari.Viewer()
    viewer.add_image(
        tomogram_data,
        name="Tomogram Data",
        # contrast_limits=[np.min(tomogram_data), np.max(tomogram_data)],
        colormap="gray",
    )
    viewer.add_labels(
        label_data,
        name="Label Data",
        opacity=0.5,
    )
    napari.run()

# Visualize the dataset
visualize_tomogram_and_labels(combined_tomogram_data, combined_label_data)



In [4]:

# -------------Combine tomograms and sample cubes with particles in it-----------------#
label_cube = combined_label_data
# Dimensions of the combined tomogram data
data_shape = combined_tomogram_data.shape
cube_size = (96, 96, 96)
background_id = 0

# Calculate the number of cubes in each dimension
num_cubes_z = data_shape[0] // cube_size[0]
num_cubes_y = data_shape[1] // cube_size[1]
num_cubes_x = data_shape[2] // cube_size[2]

# Create a list of all possible cube indices
cubes = []
particle_cubes = []
non_particle_cubes = []

for z in range(num_cubes_z):
    for y in range(num_cubes_y):
        for x in range(num_cubes_x):
            cubes.append((z, y, x))

# Separate cubes into particle-containing and non-particle cubes
def contains_particle(cube_start, label_cube):
    z_start, y_start, x_start = cube_start
    z_end, y_end, x_end = z_start + cube_size[0], y_start + cube_size[1], x_start + cube_size[2]
    return np.any(label_cube[z_start:z_end, y_start:y_end, x_start:x_end] > 0)

for cz, cy, cx in cubes:
    cube_start = (cz * cube_size[0], cy * cube_size[1], cx * cube_size[2])
    if contains_particle(cube_start, label_cube):
        particle_cubes.append((cz, cy, cx))
    else:
        non_particle_cubes.append((cz, cy, cx))

# Limit non-particle cubes to 10% of the dataset
num_non_particle_cubes = min(len(non_particle_cubes), int(len(particle_cubes) * 0.1))

selected_non_particle_cubes = random.sample(non_particle_cubes, num_non_particle_cubes)
selected_cubes = particle_cubes + selected_non_particle_cubes
print(f"Selected {len(selected_cubes)} cubes for the dataset. Where {len(particle_cubes)} contain particles and {len(selected_non_particle_cubes)} do not.")


Selected 2016 cubes for the dataset. Where 1965 contain particles and 51 do not.


In [5]:
class DataCreator():
    def __init__(self, tomogram_data, label_cube, selected_cubes):
        self.tomogram_data = tomogram_data
        self.label_cube = label_cube
        self.selected_cubes = selected_cubes
        self.cube_size = (96, 96, 96)
        self.subcube_size = (6, 6, 6)
        self.background_id = 0

    def __len__(self):
        return len(self.selected_cubes)

    def getitem(self, idx):
        cz, cy, cx = self.selected_cubes[idx]
        z_start, z_end = cz * self.cube_size[0], (cz + 1) * self.cube_size[0]
        y_start, y_end = cy * self.cube_size[1], (cy + 1) * self.cube_size[1]
        x_start, x_end = cx * self.cube_size[2], (cx + 1) * self.cube_size[2]

        cube_data = self.tomogram_data[z_start:z_end, y_start:y_end, x_start:x_end]
        labels = self.generate_labels(z_start, y_start, x_start)

        cube_data = np.expand_dims(cube_data, axis=0)  # Add channel dimension
        return (
            torch.tensor(cube_data, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.int64),
        )

    def generate_labels(self, z_start, y_start, x_start):
        mini_cube_labels = []
        
        for z in range(0, self.cube_size[2], self.subcube_size[2]):  # Width first
            for y in range(0, self.cube_size[1], self.subcube_size[1]):  # Height second
                for x in range(0, self.cube_size[0], self.subcube_size[0]):  # Depth last
                    mini_cube = self.label_cube[
                        z_start + z:z_start + z + self.subcube_size[0],
                        y_start + y:y_start + y + self.subcube_size[1],
                        x_start + x:x_start + x + self.subcube_size[2],
                    ]
                    unique, counts = np.unique(mini_cube, return_counts=True)
                    label_coverage = dict(zip(unique, counts))
                    total_voxels = np.prod(self.subcube_size)

                    dominant_label = self.background_id
                    max_coverage = 0

                    for label, coverage in label_coverage.items():
                        if label != self.background_id and coverage / total_voxels >= 0.37 and coverage > max_coverage:
                            dominant_label = label
                            max_coverage = coverage

                    mini_cube_labels.append(dominant_label)

        return np.array(mini_cube_labels)

    def generate_data(self):
        tomogram_data = torch.zeros((len(self.selected_cubes), 1, *self.cube_size))
        segmentation_labels = torch.zeros((len(self.selected_cubes), int(self.cube_size[0]**3/self.subcube_size[0]**3)), dtype=torch.int64)

        for idx in range(len(self.selected_cubes)):
            data_tensor, label_tensor= self.getitem(idx)
            tomogram_data[idx] = data_tensor
            segmentation_labels[idx] = label_tensor
        return tomogram_data, segmentation_labels

# Data Creator
data_creator = DataCreator(
    combined_tomogram_data, label_cube, selected_cubes
)

input_data, segmentation_labels = data_creator.generate_data()

print(input_data.shape, segmentation_labels.shape)

torch.Size([2016, 1, 96, 96, 96]) torch.Size([2016, 4096])


In [8]:
del combined_tomogram_data
del label_cube
del combined_label_data

NameError: name 'combined_tomogram_data' is not defined

In [3]:
# -------------DATASET IMPLEMENTATION-----------------#
class TomogramDatasetMiniCubes(Dataset):
    def __init__(self, tomogram_data, segmentation_labels):
        assert tomogram_data.size(0) == segmentation_labels.size(0), f"{tomogram_data.size(0)}, {segmentation_labels.size(0)}"
        self.tomogram_data = tomogram_data
        self.segmentation_labels = segmentation_labels

    def __len__(self):
        return self.tomogram_data.size(0)

    def __getitem__(self, idx):
        return (
            self.tomogram_data[idx],
            self.segmentation_labels[idx],
        )


particle_types = {"virus-like-particle": 1, "apo-ferritin": 2, "beta-amylase": 3, 
                  "beta-galactosidase": 4, "ribosome": 5, "thyroglobulin": 6}
input_data, segmentation_labels = torch.load("segmentation_input_data.pt"), torch.load("segmentation_labels.pt")
# Create the dataset
particle_dataset_mini_cubes = TomogramDatasetMiniCubes(input_data, segmentation_labels)

# Test the dataset
cube_data, segmentation_labels = particle_dataset_mini_cubes[200]
print(f"len(particle_dataset_mini_cubes): {len(particle_dataset_mini_cubes)}")
print("Cube Data Shape:", cube_data.shape)  # Should be (1, 96, 96, 96)
print("Labels Shape:", segmentation_labels.shape)        # Should be (4096,)

len(particle_dataset_mini_cubes): 2016
Cube Data Shape: torch.Size([1, 96, 96, 96])
Labels Shape: torch.Size([4096])


In [4]:
# Get the first cube and its labels
cube_data, labels = particle_dataset_mini_cubes[550]

# Convert cube data to numpy (removing the channel dimension)
sample_data = cube_data.squeeze().numpy()

# Reshape labels from (16, 16, 16) to match subcube structure
labels = labels.numpy().reshape((16, 16, 16))

# Create an empty array for the upscaled labels
scaled_labels = np.zeros_like(sample_data, dtype=np.int64)

# Iterate over the (16, 16, 16) labels and assign them to (6, 6, 6) regions
for z in range(16):
    for y in range(16):
        for x in range(16):
            scaled_labels[z*6:(z+1)*6, y*6:(y+1)*6, x*6:(x+1)*6] = labels[z, y, x]

# Create a napari viewer
viewer = napari.Viewer()

# Add the cube data as a 3D volume
viewer.add_image(
    sample_data, name='Tomogram Cube', colormap="gray",
    contrast_limits=[np.min(sample_data), np.max(sample_data)]
)

# Add the upsampled labels
viewer.add_labels(scaled_labels, name='Scaled Mini-Cube Labels')

# Start the napari event loop
napari.run()

In [7]:
class SelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=11, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False) # flash attention
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

class TransformerBlock(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = SelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class Transformer(nn.Module):
    def __init__(self, num_layers, decoder_layer):
        super().__init__()
        self.layers = clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, tgt):
        output = tgt

        for mod in self.layers:
            output = mod(output)
        return output
    
class LinearTokenizer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tokenizer = nn.Conv3d(
            in_channels=1,  # Input channels (C=1)
            out_channels=config.n_embd,
            kernel_size=config.token_width,
            stride=config.token_width,
            padding=0,
            bias=True  # Optional
        )

        
    def forward(self, x):
        # x shape: (B, 1, D=96, H=96, W=96)
        x = self.tokenizer(x)  # Output shape: (B, n_embd, D_blocks=8, H_blocks=8, W_blocks=8)
        
        # Efficient reshaping and transposing
        x = x.reshape(x.size(0), x.size(1), -1).transpose(-2, -1)  # (B, 512, n_embd)
        
        return x

@dataclass
class LinearConfig:
    block_size: int = int((96**3)/(6**3)) # max sequence length
    token_width: int = 6 # width of the cube
    n_layer: int = 4 # number of layers
    n_head: int = 16 # number of heads
    n_embd: int = 144 # embedding dimension


class BaseModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tokenizer = LinearTokenizer(config)
        self.positional_embedding = nn.Parameter(torch.zeros(config.block_size, config.n_embd))
        self.transformer = Transformer(config.n_layer, TransformerBlock(config))
        self.decoder = nn.Linear(config.n_embd, len(particle_types) + 1)  # Output layer

    def ff(self, x):
        x = self.tokenizer(x) # (N, 1, 96, 96, 96) -> (N, 8 * 8 * 8, n_embd)
        x = x + self.positional_embedding # (N, n_embd, 8, 8, 8) -> (N, n_embd, 512) -> (N, 512, n_embd)
        x = self.transformer(x) # (N, 512, n_embd) -> (N, 512, n_embd)
        x = self.decoder(x) # (N, 512, n_embd) -> (N, 512, 7)        
        return x
    
    def forward(self, x):   
        return self.ff(x)
    

class ContrastiveModel(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        # Summarizer token
        self.summarizer = nn.Parameter(torch.zeros(1, 1, config.n_embd))
        self.projection_head = nn.Sequential(
            nn.Linear(config.n_embd, config.n_embd),
            nn.ReLU(),
            nn.Linear(config.n_embd, 128)  # Project to lower-dimensional space
        )

    def forward(self, x):
        x = self.tokenizer(x) # (N, 1, 96, 96, 96) -> (N, 8 * 8 * 8, n_embd)
        x = x + self.positional_embedding # (N, n_embd, 8, 8, 8) -> (N, n_embd, n_tokens) -> (N, n_tokens, n_embd)
        
        # Append Summarizer token
        x = torch.cat([self.summarizer.repeat(x.size(0), 1, 1), x], dim=1)
        
        # Build representation using transformer
        x = self.transformer(x) # (N, n_tokens + 1, n_embd) -> (N, n_tokens + 1, n_embd)
        
        # Extract the summarizer token
        x = x[:, 0, :]  
        # Project to contrastive space
        return self.projection_head(x)
    
class SegmentationModel(ContrastiveModel):
    def __init__(self, config):
        super().__init__(config)
    
    def ff_without_summarizer(self, x):
        # ContrastiveModel inherits the ff (feedforward without summarizer) method from the BaseModel
        return self.ff(x)
    
    def ff_with_summarizer(self, x):
        x = self.tokenizer(x) # (N, 1, 96, 96, 96) -> (N, 8 * 8 * 8, n_embd)
        x = x + self.positional_embedding # (N, n_embd, 8, 8, 8) -> (N, n_embd, n_tokens) -> (N, n_tokens, n_embd)
        
        # Append Summarizer token
        x = torch.cat([self.summarizer.repeat(x.size(0), 1, 1), x], dim=1)
        
        # Build representation using transformer
        x = self.transformer(x) # (N, n_tokens + 1, n_embd) -> (N, n_tokens + 1, n_embd)     

        # Classify
        x = self.decoder(x[:, 1:, :])
        return x       

    def forward(self, x):
        return self.ff_without_summarizer(x)

In [8]:
import time

config = LinearConfig()
device = "cuda" if torch.cuda.is_available() else "cpu"
segmentation_model = SegmentationModel(config).to(device)

In [12]:


try:
    # Load the trained model
    segmentation_model.load_state_dict(torch.load("segmentation_model_mini_cubes.pth"))
    print("Loaded pretrained model")
except:
    pass
# print the # of parameters
print(f"Number of parameters: {sum(p.numel() for p in segmentation_model.parameters() if p.requires_grad):.2e}")
print(segmentation_model)
# Write train and validation dataloader for segmentation using particle_dataset_mini_cubes
train_size = int(0.8 * len(particle_dataset_mini_cubes))
val_size = len(particle_dataset_mini_cubes) - train_size
train_dataset, val_dataset = random_split(particle_dataset_mini_cubes, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

# Define optimizer, loss, and scheduler
optimizer = optim.AdamW(segmentation_model.parameters(), lr=1e-3)
# weights = torch.tensor([0.5] + [1 for _ in range(1, len(particle_types) + 1)]).to(device)  # Lower weight for background

"""
particle_types = {"virus-like-particle": 1, "apo-ferritin": 2, "beta-amylase": 3, 
                  "beta-galactosidase": 4, "ribosome": 5, "thyroglobulin": 6}
"""


# background, virus-like particle, "apo-ferritin", "beta-amylase","beta-galactosidase","ribosome","thyroglobulin"
weights = torch.tensor([0.08, 0.16, 0.16, 0.16, 0.66, 0.16, 0.66]).to(device)

criterion = nn.CrossEntropyLoss(weight=weights)

# Training loop
epochs = 5
best_val_loss = float('inf')

for epoch in range(epochs):
    segmentation_model.train()
    train_loss = 0.0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        # Measure time to load batch
        # st = time.time()
        # torch.cuda.synchronize()
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = segmentation_model(inputs)
        loss = criterion(outputs.view(-1, len(particle_types) + 1), labels.view(-1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        # torch.cuda.synchronize()
        # et = time.time()
        # print(f"{(inputs.shape[0] * labels.shape[-1]) * 1000 / (et - st):.2f} tokens/ms")


    
    val_loss = 0.0
    segmentation_model.eval()
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = segmentation_model(inputs)
            loss = criterion(outputs.view(-1, len(particle_types) + 1), labels.view(-1))
            val_loss += loss.item()
    
    # Log epoch performance
    print(f"Epoch [{epoch + 1}/{epochs}] Train Loss: {train_loss / len(train_loader):.8f}, Validation Loss: {val_loss / len(val_loader):.8f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(segmentation_model.state_dict(), "segmentation_model_mini_cubes.pth")
        print(f"...saved model at Epoch {epoch + 1}\n")


Number of parameters: 1.66e+06
SegmentationModel(
  (tokenizer): LinearTokenizer(
    (tokenizer): Conv3d(1, 144, kernel_size=(6, 6, 6), stride=(6, 6, 6))
  )
  (transformer): Transformer(
    (layers): ModuleList(
      (0-3): 4 x TransformerBlock(
        (ln_1): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
        (attn): SelfAttention(
          (c_attn): Linear(in_features=144, out_features=432, bias=True)
          (c_proj): Linear(in_features=144, out_features=144, bias=True)
        )
        (ln_2): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=144, out_features=576, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=576, out_features=144, bias=True)
        )
      )
    )
  )
  (decoder): Linear(in_features=144, out_features=7, bias=True)
  (projection_head): Sequential(
    (0): Linear(in_features=144, out_features=144, bias=True)
    (1): ReLU()
    (2): Line

KeyboardInterrupt: 

In [54]:
def visualize_inference(inputs, labels, predictions, sample_idx):
    """
    Visualize the input data, ground truth, and predictions using napari.
    """
    # Create a napari viewer
    viewer = napari.Viewer()

    # Add input tomogram
    viewer.add_image(inputs.cpu().numpy(), name=f"Sample {sample_idx} - Input", colormap="gray")

    # Add ground truth labels
    viewer.add_labels(labels.cpu().numpy(), name=f"Sample {sample_idx} - Ground Truth")

    # Add predictions
    viewer.add_labels(predictions.cpu().numpy(), name=f"Sample {sample_idx} - Predictions")

    # Start napari viewer
    napari.run()
print(segmentation_model)
# Load the trained model
segmentation_model.load_state_dict(torch.load("segmentation_model_mini_cubes.pth"))
# Select a few samples from training and validation datasets
segmentation_model.eval()
for i, (inputs, labels) in enumerate(val_loader):
    if i >= 5:  # Visualize only 5 samples
        break

    inputs, labels = inputs[0].to(device), labels[0].to(device)
    # print(inputs.shape, labels.shape)
    labels = labels.reshape(16, 16, 16)  # Reshape to (16, 16, 16) for visualization
    true_labels = torch.zeros_like(inputs[0]).to(torch.int64) # shape (96, 96, 96)
    # Iterate over the (16, 16, 16) labels and assign them to (6, 6, 6) regions
    for z in range(16):
        for y in range(16):
            for x in range(16):
                true_labels[z*6:(z+1)*6, y*6:(y+1)*6, x*6:(x+1)*6] = labels[z, y, x]
    # visualize_inference(inputs[0], true_labels, true_labels, sample_idx=i)

    inputs = inputs.unsqueeze(1)  # Add channel dimension for model input

    with torch.no_grad():
        outputs = segmentation_model(inputs)
        predictions = torch.argmax(outputs, dim=-1)  # Get predicted class labels
        predictions = predictions.reshape(16, 16, 16)  # Reshape to (16, 16, 16) for visualization
        true_predictions = torch.zeros_like(inputs[0, 0]).to(torch.int64) # shape (96, 96, 96)
        for z in range(16):
            for y in range(16):
                for x in range(16):
                    true_predictions[z*6:(z+1)*6, y*6:(y+1)*6, x*6:(x+1)*6] = predictions[z, y, x]

    # Visualize the first sample in the batch
    visualize_inference(inputs[0, 0], true_labels, true_predictions, sample_idx=i)

SegmentationModel(
  (tokenizer): LinearTokenizer(
    (tokenizer): Conv3d(1, 128, kernel_size=(6, 6, 6), stride=(6, 6, 6))
  )
  (transformer): Transformer(
    (layers): ModuleList(
      (0-9): 10 x TransformerBlock(
        (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): SelfAttention(
          (c_attn): Linear(in_features=128, out_features=384, bias=True)
          (c_proj): Linear(in_features=128, out_features=128, bias=True)
        )
        (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=128, out_features=512, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=512, out_features=128, bias=True)
        )
      )
    )
  )
  (decoder): Linear(in_features=128, out_features=7, bias=True)
  (projection_head): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_featur

RuntimeError: Error(s) in loading state_dict for SegmentationModel:
	Missing key(s) in state_dict: "positional_embedding", "summarizer", "tokenizer.tokenizer.weight", "tokenizer.tokenizer.bias", "transformer.layers.0.ln_1.weight", "transformer.layers.0.ln_1.bias", "transformer.layers.0.attn.c_attn.weight", "transformer.layers.0.attn.c_attn.bias", "transformer.layers.0.attn.c_proj.weight", "transformer.layers.0.attn.c_proj.bias", "transformer.layers.0.ln_2.weight", "transformer.layers.0.ln_2.bias", "transformer.layers.0.mlp.c_fc.weight", "transformer.layers.0.mlp.c_fc.bias", "transformer.layers.0.mlp.c_proj.weight", "transformer.layers.0.mlp.c_proj.bias", "transformer.layers.1.ln_1.weight", "transformer.layers.1.ln_1.bias", "transformer.layers.1.attn.c_attn.weight", "transformer.layers.1.attn.c_attn.bias", "transformer.layers.1.attn.c_proj.weight", "transformer.layers.1.attn.c_proj.bias", "transformer.layers.1.ln_2.weight", "transformer.layers.1.ln_2.bias", "transformer.layers.1.mlp.c_fc.weight", "transformer.layers.1.mlp.c_fc.bias", "transformer.layers.1.mlp.c_proj.weight", "transformer.layers.1.mlp.c_proj.bias", "transformer.layers.2.ln_1.weight", "transformer.layers.2.ln_1.bias", "transformer.layers.2.attn.c_attn.weight", "transformer.layers.2.attn.c_attn.bias", "transformer.layers.2.attn.c_proj.weight", "transformer.layers.2.attn.c_proj.bias", "transformer.layers.2.ln_2.weight", "transformer.layers.2.ln_2.bias", "transformer.layers.2.mlp.c_fc.weight", "transformer.layers.2.mlp.c_fc.bias", "transformer.layers.2.mlp.c_proj.weight", "transformer.layers.2.mlp.c_proj.bias", "transformer.layers.3.ln_1.weight", "transformer.layers.3.ln_1.bias", "transformer.layers.3.attn.c_attn.weight", "transformer.layers.3.attn.c_attn.bias", "transformer.layers.3.attn.c_proj.weight", "transformer.layers.3.attn.c_proj.bias", "transformer.layers.3.ln_2.weight", "transformer.layers.3.ln_2.bias", "transformer.layers.3.mlp.c_fc.weight", "transformer.layers.3.mlp.c_fc.bias", "transformer.layers.3.mlp.c_proj.weight", "transformer.layers.3.mlp.c_proj.bias", "transformer.layers.4.ln_1.weight", "transformer.layers.4.ln_1.bias", "transformer.layers.4.attn.c_attn.weight", "transformer.layers.4.attn.c_attn.bias", "transformer.layers.4.attn.c_proj.weight", "transformer.layers.4.attn.c_proj.bias", "transformer.layers.4.ln_2.weight", "transformer.layers.4.ln_2.bias", "transformer.layers.4.mlp.c_fc.weight", "transformer.layers.4.mlp.c_fc.bias", "transformer.layers.4.mlp.c_proj.weight", "transformer.layers.4.mlp.c_proj.bias", "transformer.layers.5.ln_1.weight", "transformer.layers.5.ln_1.bias", "transformer.layers.5.attn.c_attn.weight", "transformer.layers.5.attn.c_attn.bias", "transformer.layers.5.attn.c_proj.weight", "transformer.layers.5.attn.c_proj.bias", "transformer.layers.5.ln_2.weight", "transformer.layers.5.ln_2.bias", "transformer.layers.5.mlp.c_fc.weight", "transformer.layers.5.mlp.c_fc.bias", "transformer.layers.5.mlp.c_proj.weight", "transformer.layers.5.mlp.c_proj.bias", "transformer.layers.6.ln_1.weight", "transformer.layers.6.ln_1.bias", "transformer.layers.6.attn.c_attn.weight", "transformer.layers.6.attn.c_attn.bias", "transformer.layers.6.attn.c_proj.weight", "transformer.layers.6.attn.c_proj.bias", "transformer.layers.6.ln_2.weight", "transformer.layers.6.ln_2.bias", "transformer.layers.6.mlp.c_fc.weight", "transformer.layers.6.mlp.c_fc.bias", "transformer.layers.6.mlp.c_proj.weight", "transformer.layers.6.mlp.c_proj.bias", "transformer.layers.7.ln_1.weight", "transformer.layers.7.ln_1.bias", "transformer.layers.7.attn.c_attn.weight", "transformer.layers.7.attn.c_attn.bias", "transformer.layers.7.attn.c_proj.weight", "transformer.layers.7.attn.c_proj.bias", "transformer.layers.7.ln_2.weight", "transformer.layers.7.ln_2.bias", "transformer.layers.7.mlp.c_fc.weight", "transformer.layers.7.mlp.c_fc.bias", "transformer.layers.7.mlp.c_proj.weight", "transformer.layers.7.mlp.c_proj.bias", "transformer.layers.8.ln_1.weight", "transformer.layers.8.ln_1.bias", "transformer.layers.8.attn.c_attn.weight", "transformer.layers.8.attn.c_attn.bias", "transformer.layers.8.attn.c_proj.weight", "transformer.layers.8.attn.c_proj.bias", "transformer.layers.8.ln_2.weight", "transformer.layers.8.ln_2.bias", "transformer.layers.8.mlp.c_fc.weight", "transformer.layers.8.mlp.c_fc.bias", "transformer.layers.8.mlp.c_proj.weight", "transformer.layers.8.mlp.c_proj.bias", "transformer.layers.9.ln_1.weight", "transformer.layers.9.ln_1.bias", "transformer.layers.9.attn.c_attn.weight", "transformer.layers.9.attn.c_attn.bias", "transformer.layers.9.attn.c_proj.weight", "transformer.layers.9.attn.c_proj.bias", "transformer.layers.9.ln_2.weight", "transformer.layers.9.ln_2.bias", "transformer.layers.9.mlp.c_fc.weight", "transformer.layers.9.mlp.c_fc.bias", "transformer.layers.9.mlp.c_proj.weight", "transformer.layers.9.mlp.c_proj.bias", "decoder.weight", "decoder.bias", "projection_head.0.weight", "projection_head.0.bias", "projection_head.2.weight", "projection_head.2.bias". 
	Unexpected key(s) in state_dict: "_orig_mod.positional_embedding", "_orig_mod.summarizer", "_orig_mod.tokenizer.tokenizer.weight", "_orig_mod.tokenizer.tokenizer.bias", "_orig_mod.transformer.layers.0.ln_1.weight", "_orig_mod.transformer.layers.0.ln_1.bias", "_orig_mod.transformer.layers.0.attn.c_attn.weight", "_orig_mod.transformer.layers.0.attn.c_attn.bias", "_orig_mod.transformer.layers.0.attn.c_proj.weight", "_orig_mod.transformer.layers.0.attn.c_proj.bias", "_orig_mod.transformer.layers.0.ln_2.weight", "_orig_mod.transformer.layers.0.ln_2.bias", "_orig_mod.transformer.layers.0.mlp.c_fc.weight", "_orig_mod.transformer.layers.0.mlp.c_fc.bias", "_orig_mod.transformer.layers.0.mlp.c_proj.weight", "_orig_mod.transformer.layers.0.mlp.c_proj.bias", "_orig_mod.transformer.layers.1.ln_1.weight", "_orig_mod.transformer.layers.1.ln_1.bias", "_orig_mod.transformer.layers.1.attn.c_attn.weight", "_orig_mod.transformer.layers.1.attn.c_attn.bias", "_orig_mod.transformer.layers.1.attn.c_proj.weight", "_orig_mod.transformer.layers.1.attn.c_proj.bias", "_orig_mod.transformer.layers.1.ln_2.weight", "_orig_mod.transformer.layers.1.ln_2.bias", "_orig_mod.transformer.layers.1.mlp.c_fc.weight", "_orig_mod.transformer.layers.1.mlp.c_fc.bias", "_orig_mod.transformer.layers.1.mlp.c_proj.weight", "_orig_mod.transformer.layers.1.mlp.c_proj.bias", "_orig_mod.transformer.layers.2.ln_1.weight", "_orig_mod.transformer.layers.2.ln_1.bias", "_orig_mod.transformer.layers.2.attn.c_attn.weight", "_orig_mod.transformer.layers.2.attn.c_attn.bias", "_orig_mod.transformer.layers.2.attn.c_proj.weight", "_orig_mod.transformer.layers.2.attn.c_proj.bias", "_orig_mod.transformer.layers.2.ln_2.weight", "_orig_mod.transformer.layers.2.ln_2.bias", "_orig_mod.transformer.layers.2.mlp.c_fc.weight", "_orig_mod.transformer.layers.2.mlp.c_fc.bias", "_orig_mod.transformer.layers.2.mlp.c_proj.weight", "_orig_mod.transformer.layers.2.mlp.c_proj.bias", "_orig_mod.transformer.layers.3.ln_1.weight", "_orig_mod.transformer.layers.3.ln_1.bias", "_orig_mod.transformer.layers.3.attn.c_attn.weight", "_orig_mod.transformer.layers.3.attn.c_attn.bias", "_orig_mod.transformer.layers.3.attn.c_proj.weight", "_orig_mod.transformer.layers.3.attn.c_proj.bias", "_orig_mod.transformer.layers.3.ln_2.weight", "_orig_mod.transformer.layers.3.ln_2.bias", "_orig_mod.transformer.layers.3.mlp.c_fc.weight", "_orig_mod.transformer.layers.3.mlp.c_fc.bias", "_orig_mod.transformer.layers.3.mlp.c_proj.weight", "_orig_mod.transformer.layers.3.mlp.c_proj.bias", "_orig_mod.transformer.layers.4.ln_1.weight", "_orig_mod.transformer.layers.4.ln_1.bias", "_orig_mod.transformer.layers.4.attn.c_attn.weight", "_orig_mod.transformer.layers.4.attn.c_attn.bias", "_orig_mod.transformer.layers.4.attn.c_proj.weight", "_orig_mod.transformer.layers.4.attn.c_proj.bias", "_orig_mod.transformer.layers.4.ln_2.weight", "_orig_mod.transformer.layers.4.ln_2.bias", "_orig_mod.transformer.layers.4.mlp.c_fc.weight", "_orig_mod.transformer.layers.4.mlp.c_fc.bias", "_orig_mod.transformer.layers.4.mlp.c_proj.weight", "_orig_mod.transformer.layers.4.mlp.c_proj.bias", "_orig_mod.transformer.layers.5.ln_1.weight", "_orig_mod.transformer.layers.5.ln_1.bias", "_orig_mod.transformer.layers.5.attn.c_attn.weight", "_orig_mod.transformer.layers.5.attn.c_attn.bias", "_orig_mod.transformer.layers.5.attn.c_proj.weight", "_orig_mod.transformer.layers.5.attn.c_proj.bias", "_orig_mod.transformer.layers.5.ln_2.weight", "_orig_mod.transformer.layers.5.ln_2.bias", "_orig_mod.transformer.layers.5.mlp.c_fc.weight", "_orig_mod.transformer.layers.5.mlp.c_fc.bias", "_orig_mod.transformer.layers.5.mlp.c_proj.weight", "_orig_mod.transformer.layers.5.mlp.c_proj.bias", "_orig_mod.transformer.layers.6.ln_1.weight", "_orig_mod.transformer.layers.6.ln_1.bias", "_orig_mod.transformer.layers.6.attn.c_attn.weight", "_orig_mod.transformer.layers.6.attn.c_attn.bias", "_orig_mod.transformer.layers.6.attn.c_proj.weight", "_orig_mod.transformer.layers.6.attn.c_proj.bias", "_orig_mod.transformer.layers.6.ln_2.weight", "_orig_mod.transformer.layers.6.ln_2.bias", "_orig_mod.transformer.layers.6.mlp.c_fc.weight", "_orig_mod.transformer.layers.6.mlp.c_fc.bias", "_orig_mod.transformer.layers.6.mlp.c_proj.weight", "_orig_mod.transformer.layers.6.mlp.c_proj.bias", "_orig_mod.transformer.layers.7.ln_1.weight", "_orig_mod.transformer.layers.7.ln_1.bias", "_orig_mod.transformer.layers.7.attn.c_attn.weight", "_orig_mod.transformer.layers.7.attn.c_attn.bias", "_orig_mod.transformer.layers.7.attn.c_proj.weight", "_orig_mod.transformer.layers.7.attn.c_proj.bias", "_orig_mod.transformer.layers.7.ln_2.weight", "_orig_mod.transformer.layers.7.ln_2.bias", "_orig_mod.transformer.layers.7.mlp.c_fc.weight", "_orig_mod.transformer.layers.7.mlp.c_fc.bias", "_orig_mod.transformer.layers.7.mlp.c_proj.weight", "_orig_mod.transformer.layers.7.mlp.c_proj.bias", "_orig_mod.transformer.layers.8.ln_1.weight", "_orig_mod.transformer.layers.8.ln_1.bias", "_orig_mod.transformer.layers.8.attn.c_attn.weight", "_orig_mod.transformer.layers.8.attn.c_attn.bias", "_orig_mod.transformer.layers.8.attn.c_proj.weight", "_orig_mod.transformer.layers.8.attn.c_proj.bias", "_orig_mod.transformer.layers.8.ln_2.weight", "_orig_mod.transformer.layers.8.ln_2.bias", "_orig_mod.transformer.layers.8.mlp.c_fc.weight", "_orig_mod.transformer.layers.8.mlp.c_fc.bias", "_orig_mod.transformer.layers.8.mlp.c_proj.weight", "_orig_mod.transformer.layers.8.mlp.c_proj.bias", "_orig_mod.transformer.layers.9.ln_1.weight", "_orig_mod.transformer.layers.9.ln_1.bias", "_orig_mod.transformer.layers.9.attn.c_attn.weight", "_orig_mod.transformer.layers.9.attn.c_attn.bias", "_orig_mod.transformer.layers.9.attn.c_proj.weight", "_orig_mod.transformer.layers.9.attn.c_proj.bias", "_orig_mod.transformer.layers.9.ln_2.weight", "_orig_mod.transformer.layers.9.ln_2.bias", "_orig_mod.transformer.layers.9.mlp.c_fc.weight", "_orig_mod.transformer.layers.9.mlp.c_fc.bias", "_orig_mod.transformer.layers.9.mlp.c_proj.weight", "_orig_mod.transformer.layers.9.mlp.c_proj.bias", "_orig_mod.decoder.weight", "_orig_mod.decoder.bias", "_orig_mod.projection_head.0.weight", "_orig_mod.projection_head.0.bias", "_orig_mod.projection_head.2.weight", "_orig_mod.projection_head.2.bias". 