In [1]:
# Library imports
import random
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import sys
import zarr
import json
import os
import napari
from rich import print as rprint  # Import rich's print function
import copy
from dataclasses import dataclass
from monai.transforms import Compose, RandFlipd, RandRotated, RandGaussianNoised, RandAdjustContrastd
# set torch and cuda seed for reproducibility
torch.manual_seed(37)
torch.cuda.manual_seed(37)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:

# -------------LOADING TOMOGRAM DATA AND PARTICLE COORDINATES-----------------#

# Define the experiment runs to load
experiment_runs = ["TS_5_4", "TS_69_2", "TS_6_4", "TS_6_6", "TS_73_6", "TS_86_3", "TS_99_9"]
particle_types = {"virus-like-particle": 1, "apo-ferritin": 2, "beta-amylase": 3, "beta-galactosidase": 4, "ribosome": 5, "thyroglobulin": 6}
particle_radii = {"virus-like-particle": 135, "apo-ferritin": 60, "beta-amylase": 65, "beta-galactosidase": 90, "ribosome": 150, "thyroglobulin": 130}
voxel_spacing = [10.0, 10.0, 10.0]  # 10 angstroms per voxel

# Initialize lists to store combined data
combined_tomogram_data = []
combined_particle_coords = {pt: [] for pt in particle_types}

# Track the cumulative z-depth for coordinate translation
cumulative_z_depth = 0

# Load and combine data from all experiment runs
for experiment_run in experiment_runs:
    zarr_file_path = os.path.join("train", "static", "ExperimentRuns", experiment_run, "VoxelSpacing10.000", "denoised.zarr")
    json_base_path = os.path.join("train", "overlay", "ExperimentRuns", experiment_run, "Picks")

    # Load the Zarr file
    try:
        tomogram = zarr.open(zarr_file_path, mode="r")
        tomogram_data = tomogram["0"][:]  # Load into memory as a NumPy array
        print(f"Tomogram shape for {experiment_run} (z, y, x):", tomogram_data.shape)
        tomogram_data = (tomogram_data - tomogram_data.mean()) / tomogram_data.std()
        combined_tomogram_data.append(tomogram_data)
    except Exception as e:
        print(f"Error loading Zarr file for {experiment_run}: {e}")
        continue

    # Load and transform particle coordinates for all types
    for particle_type in particle_types:
        json_file_path = os.path.join(json_base_path, f"{particle_type}.json")
        try:
            with open(json_file_path, "r") as file:
                data = json.load(file)
            points = data["points"]

            # Convert from real-world coordinates (angstroms) to voxel indices and reorder to (z, y, x)
            coords = np.round([
                [
                    (p["location"]["z"] / voxel_spacing[0]) + cumulative_z_depth,
                    p["location"]["y"] / voxel_spacing[1],
                    p["location"]["x"] / voxel_spacing[2],
                ]
                for p in points
            ]).astype(int)
            combined_particle_coords[particle_type].extend(coords)
            print(f"Loaded {len(coords)} points for {particle_type} in {experiment_run}.")
        except Exception as e:
            print(f"Error loading JSON file for {particle_type} in {experiment_run}: {e}")

    # Update cumulative_z_depth for the next tomogram
    cumulative_z_depth += tomogram_data.shape[0]

# Combine all tomogram data into a single array
combined_tomogram_data = np.concatenate(combined_tomogram_data, axis=0)
print("Combined tomogram shape (z, y, x):", combined_tomogram_data.shape)

# Print total number of particles
total_particles = sum(len(coords) for coords in combined_particle_coords.values())
print(f"Total number of particles: {total_particles}")

Tomogram shape for TS_5_4 (z, y, x): (184, 630, 630)
Loaded 11 points for virus-like-particle in TS_5_4.
Loaded 46 points for apo-ferritin in TS_5_4.
Loaded 10 points for beta-amylase in TS_5_4.
Loaded 12 points for beta-galactosidase in TS_5_4.
Loaded 31 points for ribosome in TS_5_4.
Loaded 30 points for thyroglobulin in TS_5_4.
Tomogram shape for TS_69_2 (z, y, x): (184, 630, 630)
Loaded 9 points for virus-like-particle in TS_69_2.
Loaded 35 points for apo-ferritin in TS_69_2.
Loaded 12 points for beta-amylase in TS_69_2.
Loaded 16 points for beta-galactosidase in TS_69_2.
Loaded 37 points for ribosome in TS_69_2.
Loaded 34 points for thyroglobulin in TS_69_2.
Tomogram shape for TS_6_4 (z, y, x): (184, 630, 630)
Loaded 10 points for virus-like-particle in TS_6_4.
Loaded 58 points for apo-ferritin in TS_6_4.
Loaded 9 points for beta-amylase in TS_6_4.
Loaded 12 points for beta-galactosidase in TS_6_4.
Loaded 74 points for ribosome in TS_6_4.
Loaded 30 points for thyroglobulin in TS_6

In [4]:
# -------------PRECOMPUTE LABEL CUBE-----------------#
label_cube = np.zeros(combined_tomogram_data.shape, dtype=int)

for particle_type, coords in combined_particle_coords.items():
    particle_id = particle_types[particle_type]
    radius = int(particle_radii[particle_type] / voxel_spacing[0])  # Convert radius to voxel units
    radius = 3
    for coord in coords:
        z, y, x = coord.astype(int)

        # Define the bounding box for the particle
        z_min, z_max = max(0, z - radius), min(label_cube.shape[0], z + radius + 1)
        y_min, y_max = max(0, y - radius), min(label_cube.shape[1], y + radius + 1)
        x_min, x_max = max(0, x - radius), min(label_cube.shape[2], x + radius + 1)

        # Mark the region with the particle ID
        label_cube[z_min:z_max, y_min:y_max, x_min:x_max] = particle_id

print("Label cube precomputed.")

Label cube precomputed.


In [5]:
# Function to visualize the tomogram and label cube
def visualize_tomogram_and_labels(tomogram_data, label_cube):
    """
    Visualize the combined tomogram data and the label cube using napari.

    Parameters:
        tomogram_data (numpy.ndarray): The combined tomogram data array.
        label_cube (numpy.ndarray): The label cube corresponding to the tomogram data.
    """
    # Initialize the napari viewer
    viewer = napari.Viewer()

    # Add the tomogram data as the first layer
    viewer.add_image(
        tomogram_data,
        name="Tomogram Data",
        contrast_limits=[np.min(tomogram_data), np.max(tomogram_data)],  # Adjust contrast
        colormap="gray",
    )

    # Add the label cube as the second layer
    viewer.add_labels(
        label_cube,
        name="Label Cube",
        opacity=0.5,  # Make it slightly transparent to see the tomogram data underneath
    )

    # Start the napari event loop
    napari.run()


# Call the function to visualize
visualize_tomogram_and_labels(combined_tomogram_data, label_cube)

In [6]:

# -------------Combine tomograms and sample cubes with particles in it-----------------#

# Dimensions of the combined tomogram data
data_shape = combined_tomogram_data.shape
cube_size = (96, 96, 96)
background_id = 0

# Calculate the number of cubes in each dimension
num_cubes_z = data_shape[0] // cube_size[0]
num_cubes_y = data_shape[1] // cube_size[1]
num_cubes_x = data_shape[2] // cube_size[2]

# Create a list of all possible cube indices
cubes = []
particle_cubes = []
non_particle_cubes = []

for z in range(num_cubes_z):
    for y in range(num_cubes_y):
        for x in range(num_cubes_x):
            cubes.append((z, y, x))

# Separate cubes into particle-containing and non-particle cubes
def contains_particle(cube_start, label_cube):
    z_start, y_start, x_start = cube_start
    z_end, y_end, x_end = z_start + cube_size[0], y_start + cube_size[1], x_start + cube_size[2]
    return np.any(label_cube[z_start:z_end, y_start:y_end, x_start:x_end] > 0)

for cz, cy, cx in cubes:
    cube_start = (cz * cube_size[0], cy * cube_size[1], cx * cube_size[2])
    if contains_particle(cube_start, label_cube):
        particle_cubes.append((cz, cy, cx))
    else:
        non_particle_cubes.append((cz, cy, cx))

# Limit non-particle cubes to 10% of the dataset
num_non_particle_cubes = int(len(particle_cubes) * 0.1)
selected_non_particle_cubes = random.sample(non_particle_cubes, num_non_particle_cubes)
selected_cubes = particle_cubes + selected_non_particle_cubes
print(f"Selected {len(selected_cubes)} cubes for the dataset. Where {len(particle_cubes)} contain particles and {len(selected_non_particle_cubes)} do not.")


Selected 398 cubes for the dataset. Where 362 contain particles and 36 do not.


In [8]:
# ------------------- VISUALIZE Combined Tomogram Data ----------------------------#

# Define a color map for label IDs
label_colors = {
    1: "red",        # virus-like-particle
    2: "green",      # apo-ferritin
    3: "blue",       # beta-amylase
    4: "yellow",     # beta-galactosidase
    5: "magenta",    # ribosome
    6: "cyan",       # thyroglobulin
}

# Function to visualize the combined tomogram with particles in 3D using napari
def visualize_combined_tomogram(tomogram_data, particle_coords):
    # Create a napari viewer
    viewer = napari.Viewer()

    # Add the combined tomogram data as a 3D volume
    viewer.add_image(tomogram_data, name="Combined Tomogram")

    # Collect all particle coordinates and their label IDs
    all_particles = []
    all_labels = []
    for particle_type, coords in particle_coords.items():
        label_id = particle_types[particle_type]
        all_particles.extend(coords)
        all_labels.extend([label_id] * len(coords))

    # Convert to numpy arrays
    all_particles = np.array(all_particles)
    all_labels = np.array(all_labels)

    # Assign colors to each particle based on its label ID
    colors = [label_colors[label] for label in all_labels]

    # Add the particles as a 3D points layer with different colors
    if all_particles.size > 0:
        viewer.add_points(
            all_particles,
            name="Particles",
            face_color=colors,
            size=5,
            opacity=0.8,
        )

    # Start the napari event loop
    napari.run()

# Visualize the combined tomogram with particles
print("Visualizing the combined tomogram with particles...")
visualize_combined_tomogram(combined_tomogram_data, combined_particle_coords)
# ---------------------------------------------------------------------------------#

Visualizing the combined tomogram with particles...


In [9]:
def visualize_selected_cubes(tomogram_data, particle_coords, selected_cubes, cube_size):
    # Create a napari viewer
    viewer = napari.Viewer()

    # Iterate through the 10 selected cubes and visualize them
    for idx, (cz, cy, cx) in enumerate(selected_cubes[:10]):  # Limit to 10 cubes
        # Define cube boundaries
        z_start, y_start, x_start = cz * cube_size[0], cy * cube_size[1], cx * cube_size[2]
        z_end, y_end, x_end = z_start + cube_size[0], y_start + cube_size[1], x_start + cube_size[2]

        # Extract cube data from the tomogram
        cube_data = tomogram_data[z_start:z_end, y_start:y_end, x_start:x_end]

        # Collect particle coordinates and labels within the cube
        cube_particles = []
        cube_labels = []
        for particle_type, coords in particle_coords.items():
            label_id = particle_types[particle_type]
            for coord in coords:
                z, y, x = coord.astype(int)
                if z_start <= z < z_end and y_start <= y < y_end and x_start <= x < x_end:
                    # Adjust coordinates to cube-local space
                    cube_particles.append([z - z_start, y - y_start, x - x_start])
                    cube_labels.append(label_id)

        # Convert to numpy arrays
        cube_particles = np.array(cube_particles)
        cube_labels = np.array(cube_labels)

        # Assign colors to each particle based on its label ID
        colors = [label_colors[label] for label in cube_labels]

        # Add the cube data as a volume
        viewer.add_image(cube_data, name=f"Cube {idx + 1}", colormap="gray")

        # Add the particles as a points layer
        if cube_particles.size > 0:
            viewer.add_points(
                cube_particles,
                name=f"Particles in Cube {idx + 1}",
                face_color=colors,
                size=5,
                opacity=0.8,
            )

    # Start the napari event loop
    napari.run()

# Visualize the first 10 selected cubes
print("Visualizing 10 selected cubes with particles...")
visualize_selected_cubes(combined_tomogram_data, combined_particle_coords, selected_cubes, cube_size)


Visualizing 10 selected cubes with particles...


In [11]:
COUNTER = 0

class DataCreator():
    def __init__(self, tomogram_data, label_cube, selected_cubes):
        self.tomogram_data = tomogram_data
        self.label_cube = label_cube
        self.selected_cubes = selected_cubes
        self.cube_size = (96, 96, 96)
        self.subcube_size = (6, 6, 6)
        self.background_id = 0

    def __len__(self):
        return len(self.selected_cubes)

    def getitem(self, idx):
        cz, cy, cx = self.selected_cubes[idx]
        z_start, z_end = cz * self.cube_size[0], (cz + 1) * self.cube_size[0]
        y_start, y_end = cy * self.cube_size[1], (cy + 1) * self.cube_size[1]
        x_start, x_end = cx * self.cube_size[2], (cx + 1) * self.cube_size[2]

        cube_data = self.tomogram_data[z_start:z_end, y_start:y_end, x_start:x_end]
        labels = self.generate_labels(z_start, y_start, x_start)

        cube_data = np.expand_dims(cube_data, axis=0)  # Add channel dimension
        return (
            torch.tensor(cube_data, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.int64),
        )

    def generate_labels(self, z_start, y_start, x_start):
        global COUNTER
        mini_cube_labels = []
        
        for z in range(0, self.cube_size[2], self.subcube_size[2]):  # Width first
            for y in range(0, self.cube_size[1], self.subcube_size[1]):  # Height second
                for x in range(0, self.cube_size[0], self.subcube_size[0]):  # Depth last
                    mini_cube = self.label_cube[
                        z_start + z:z_start + z + self.subcube_size[0],
                        y_start + y:y_start + y + self.subcube_size[1],
                        x_start + x:x_start + x + self.subcube_size[2],
                    ]
                    unique, counts = np.unique(mini_cube, return_counts=True)
                    label_coverage = dict(zip(unique, counts))
                    total_voxels = np.prod(self.subcube_size)

                    dominant_label = self.background_id
                    max_coverage = 0

                    for label, coverage in label_coverage.items():
                        if label != self.background_id and coverage / total_voxels >= 0.3 and coverage > max_coverage:
                            dominant_label = label
                            max_coverage = coverage
                    if dominant_label !=0 : COUNTER += 1
                    mini_cube_labels.append(dominant_label)

        return np.array(mini_cube_labels)

    def generate_data(self):
        tomogram_data = torch.zeros((len(self.selected_cubes), 1, *self.cube_size))
        segmentation_labels = torch.zeros((len(self.selected_cubes), int(self.cube_size[0]**3/self.subcube_size[0]**3)), dtype=torch.int64)

        for idx in range(len(self.selected_cubes)):
            data_tensor, label_tensor= self.getitem(idx)
            tomogram_data[idx] = data_tensor
            segmentation_labels[idx] = label_tensor
        return tomogram_data, segmentation_labels

# Data Creator
data_creator = DataCreator(
    combined_tomogram_data, label_cube, selected_cubes
)

input_data, segmentation_labels = data_creator.generate_data()


In [17]:


def upscale_labels(label_tensor, original_shape=(90, 90, 90), subcube_shape=(16, 16, 16)):
    """
    Upscale the labels from (4096,) to (16, 16, 16) and then to (90, 90, 90).
    """
    # Step 1: Reshape the flat label tensor to (16, 16, 16)
    reshaped_labels = label_tensor.view(subcube_shape)

    # Step 2: Upscale to (90, 90, 90) by assigning subcube regions
    upscaled_labels = np.zeros(original_shape, dtype=int)
    subcube_size = original_shape[0] // subcube_shape[0]  # Typically 6 if (90, 90, 90) and (16, 16, 16)

    for z in range(subcube_shape[0]):
        for y in range(subcube_shape[1]):
            for x in range(subcube_shape[2]):
                label = reshaped_labels[z, y, x]
                upscaled_labels[
                    z * subcube_size : (z + 1) * subcube_size,
                    y * subcube_size : (y + 1) * subcube_size,
                    x * subcube_size : (x + 1) * subcube_size
                ] = label

    return upscaled_labels

def visualize_with_napari(inputs, labels, sample_idx=0):
    """
    Visualize the input tomogram and the corresponding labels using napari.
    
    Parameters:
        inputs (torch.Tensor): Input tomogram data of shape (N, 1, 90, 90, 90).
        labels (torch.Tensor): Segmentation labels of shape (N, 4096).
        sample_idx (int): Index of the sample to visualize.
    """
    # Get the input data and corresponding label
    input_sample = inputs[sample_idx].squeeze(0).cpu().numpy()  # Shape: (90, 90, 90)
    label_sample = labels[sample_idx].cpu()

    # Upscale the labels from (4096,) -> (16, 16, 16) -> (90, 90, 90)
    upscaled_labels = upscale_labels(label_sample)


    viewer = napari.Viewer()

    # Add the input tomogram data as grayscale image
    viewer.add_image(input_sample, name=f"Sample {sample_idx} - Input", colormap="gray")

    # Add the upscaled labels as an overlay
    viewer.add_labels(upscaled_labels, name=f"Sample {sample_idx} - Labels")

    # Optionally, you can adjust opacity or blending mode for better visualization:
    viewer.layers[f"Sample {sample_idx} - Labels"].opacity = 0.5  # Semi-transparent overlay

# Example usage:
# Visualize the first sample from the generated data
visualize_with_napari(input_data, segmentation_labels, sample_idx=5)

In [4]:
class SelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=11, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False) # flash attention
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

class TransformerBlock(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = SelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class Transformer(nn.Module):
    def __init__(self, num_layers, decoder_layer):
        super().__init__()
        self.layers = clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, tgt):
        output = tgt

        for mod in self.layers:
            output = mod(output)
        return output

class CnnTokenizer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.conv_1 = nn.Conv3d(1, config.n_embd //  2, kernel_size=3, stride=1, padding=1)
        self.bn_1 = nn.BatchNorm3d(config.n_embd // 2)
        self.dropout_1 = nn.Dropout(0.2)
        self.conv_2 = nn.Conv3d(config.n_embd // 2, config.n_embd // 2, kernel_size=3, stride=1, padding=1)
        self.bn_2 = nn.BatchNorm3d(config.n_embd // 2)
        self.dropout_2 = nn.Dropout(0.2)
        self.conv_3 = nn.Conv3d(config.n_embd // 2, config.n_embd, kernel_size=3, stride=1, padding=1)
        self.bn_3 = nn.BatchNorm3d(config.n_embd)
        self.dropout_3 = nn.Dropout(0.2)
        self.downsize = nn.Conv3d(config.n_embd, config.n_embd, kernel_size=3, stride=2, padding=1)
        self.slice = nn.Conv3d(config.n_embd, config.n_embd, kernel_size=config.token_width, stride=config.token_width, padding=0) 

    def forward(self, x):
        x = self.conv_1(x)
        # print(f"shape after conv_1: {x.shape}")
        x = self.bn_1(x)
        x = F.gelu(x)
        # x = self.dropout_1(x)

        x = self.conv_2(x)
        # print(f"shape after conv_2: {x.shape}")
        x = self.bn_2(x)
        x = F.gelu(x)
        # x = self.dropout_2(x)

        x = self.conv_3(x)
        # print(f"shape after conv_3: {x.shape}")
        x = self.bn_3(x)
        x = F.gelu(x)
        # x = self.dropout_3(x)
        
        x = self.downsize(x)
        # print(f"shape after downsize: {x.shape}")
        x = self.slice(x)
        # print(f"shape after slice: {x.shape}")
        x = x.reshape(x.size(0), self.config.n_embd, -1).transpose(1, 2)
        return x
    
    
class LinearTokenizer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tokenizer = nn.Conv3d(
            in_channels=1,  # Input channels (C=1)
            out_channels=config.n_embd,
            kernel_size=config.token_width,
            stride=config.token_width,
            padding=0,
            bias=True  # Optional
        )
    
    def forward(self, x):
        # x shape: (B, 1, D=96, H=96, W=96)
        x = self.tokenizer(x)  # Output shape: (B, n_embd, D_blocks=8, H_blocks=8, W_blocks=8)
        x = x.permute(0, 2, 3, 4, 1)  # (B, D_blocks, H_blocks, W_blocks, n_embd)
        x = x.view(x.size(0), -1, x.size(-1))  # (B, D_blocks*H_blocks*W_blocks=512, n_embd)
        # print(f"Tokenizer output shape: {x.shape}")
        return x
@dataclass
class CNNConfig:
    block_size: int = 8**3 # max sequence length
    token_width: int = 6 # width of the cube
    n_layer: int = 4 # number of layers
    n_head: int = 16 # number of heads
    n_embd: int = 256 # embedding dimension

@dataclass
class LinearConfig:
    block_size: int = int((96**3)/(6**3)) # max sequence length
    token_width: int = 6 # width of the cube
    n_layer: int = 8 # number of layers
    n_head: int = 16 # number of heads
    n_embd: int = 128 # embedding dimension


class BaseModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tokenizer = LinearTokenizer(config)
        self.positional_embedding = nn.Parameter(torch.zeros(config.block_size, config.n_embd))
        self.transformer = Transformer(config.n_layer, TransformerBlock(config))
        self.decoder = nn.Linear(config.n_embd, len(particle_types) + 1)  # Output layer

    def ff(self, x):
        x = self.tokenizer(x) # (N, 1, 96, 96, 96) -> (N, 8 * 8 * 8, n_embd)
        x = x + self.positional_embedding # (N, n_embd, 8, 8, 8) -> (N, n_embd, 512) -> (N, 512, n_embd)
        x = self.transformer(x) # (N, 512, n_embd) -> (N, 512, n_embd)
        x = self.decoder(x) # (N, 512, n_embd) -> (N, 512, 7)        
        return x
    
    def forward(self, x):   
        return self.ff(x)

In [None]:
# ------------------- Contrastive Learning  ------------------------#

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

class ContrastiveTomogramDataset(Dataset):
    def __init__(self, tomogram_data, selected_cubes, transform=None):
        self.tomogram_data = tomogram_data
        self.selected_cubes = selected_cubes
        self.cube_size = (96, 96, 96)
        self.transform = transform

    def __len__(self):
        return len(self.selected_cubes)

    def __getitem__(self, idx):
        cz, cy, cx = self.selected_cubes[idx]
        z_start = cz * self.cube_size[0]
        y_start = cy * self.cube_size[1]
        x_start = cx * self.cube_size[2]

        # Extract original cube
        cube = self.tomogram_data[
            z_start:z_start+self.cube_size[0],
            y_start:y_start+self.cube_size[1],
            x_start:x_start+self.cube_size[2]
        ]
        
        # Create two augmented views
        return {
            'view1': self._apply_transforms(cube),
            'view2': self._apply_transforms(cube)
        }

    def _apply_transforms(self, cube):
        cube = np.expand_dims(cube, axis=0)  # Add channel dim
        cube = torch.tensor(cube, dtype=torch.float32)
        
        if self.transform:
            data = {'image': cube}
            data = self.transform(data)
            return data['image']
        return cube

# Define stronger augmentations for contrastive learning
contrastive_transforms = Compose([
    RandFlipd(keys=['image'], prob=0.5, spatial_axis=0),
    RandFlipd(keys=['image'], prob=0.5, spatial_axis=1),
    RandFlipd(keys=['image'], prob=0.5, spatial_axis=2),
    RandRotated(
        keys=['image'],
        prob=0.8,
        range_x=15.0,
        range_y=15.0,
        range_z=15.0,
        mode='bilinear',
        padding_mode='zeros'
    ),
    RandAdjustContrastd(keys=['image'], prob=0.5, gamma=(0.7, 1.3)),
    RandGaussianNoised(keys=['image'], prob=0.5, std=0.1),
])


class ContrastiveModel(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        # Summarizer token
        self.summarizer = nn.Parameter(torch.zeros(1, 1, config.n_embd))
        self.projection_head = nn.Sequential(
            nn.Linear(config.n_embd, config.n_embd),
            nn.ReLU(),
            nn.Linear(config.n_embd, 128)  # Project to lower-dimensional space
        )

    def forward(self, x):
        x = self.tokenizer(x) # (N, 1, 96, 96, 96) -> (N, 8 * 8 * 8, n_embd)
        x = x + self.positional_embedding # (N, n_embd, 8, 8, 8) -> (N, n_embd, n_tokens) -> (N, n_tokens, n_embd)
        
        # Append Summarizer token
        x = torch.cat([self.summarizer.repeat(x.size(0), 1, 1), x], dim=1)
        
        # Build representation using transformer
        x = self.transformer(x) # (N, n_tokens + 1, n_embd) -> (N, n_tokens + 1, n_embd)
        
        # Extract the summarizer token
        x = x[:, 0, :]  
        # Project to contrastive space
        return self.projection_head(x)

# NT-Xent Loss (Normalized Temperature-scaled Cross Entropy Loss)
class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, z1, z2):
        assert z1.size() == z2.size()
        # Concatenate both views
        z = torch.cat([z1, z2], dim=0)
        z = F.normalize(z, p=2, dim=1)
        size = z.size(0)
        
        # Compute similarity matrix
        sim = torch.mm(z, z.T) / self.temperature
        logits = torch.exp(sim)
        logits = torch.masked_fill(logits, torch.eye(size, dtype=torch.bool, device=device), 0)
        logits = logits / torch.sum(logits, dim=1, keepdim=True)

        labels = torch.remainder(torch.arange(size) + size//2, size).to(device)
        
        return self.criterion(logits, labels)

config = LinearConfig()

# Create dataset and loaders
contrastive_dataset = ContrastiveTomogramDataset(
    combined_tomogram_data,
    selected_cubes,
    transform=contrastive_transforms
)

train_loader = DataLoader(contrastive_dataset, batch_size=6, shuffle=True)

# Initialize model and loss
model = ContrastiveModel(config).to(device)
# print the # of parameters
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):.2e}")
try:
    model.load_state_dict(torch.load("contrastive_model.pth"))
    print("Loaded pre-trained model.")
except:
    pass
criterion = NTXentLoss(temperature=0.1).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-5)

# Training Loop
for epoch in range(50):
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        # Get both views
        x1 = batch['view1'].to(device)
        x2 = batch['view2'].to(device)
        # Forward passes
        z1 = model(x1)
        z2 = model(x2)
        # Compute loss
        loss = criterion(z1, z2)
        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

torch.save(model.state_dict(), f"contrastive_model.pth")


Using device: cuda
Number of parameters: 2.17e+06
Epoch 1, Loss: 2.4415
Epoch 2, Loss: 2.4301
Epoch 3, Loss: 2.4030
Epoch 4, Loss: 2.3751
Epoch 5, Loss: 2.3804
Epoch 6, Loss: 2.3978
Epoch 7, Loss: 2.4243
Epoch 8, Loss: 2.3816
Epoch 9, Loss: 2.3488
Epoch 10, Loss: 2.3569
Epoch 11, Loss: 2.3713
Epoch 12, Loss: 2.3695
Epoch 13, Loss: 2.3981
Epoch 14, Loss: 2.3652


KeyboardInterrupt: 

In [71]:
torch.save(model.state_dict(), f"contrastive_model.pth")

In [None]:
class SegmentationModel(ContrastiveModel):
    def __init__(self, config):
        super().__init__(config)
    
    def ff_without_summarizer(self, x):
        # ContrastiveModel inherits the ff (feedforward without summarizer) method from the BaseModel
        return self.ff(x)
    
    def ff_with_summarizer(self, x):
        x = self.tokenizer(x) # (N, 1, 96, 96, 96) -> (N, 8 * 8 * 8, n_embd)
        x = x + self.positional_embedding # (N, n_embd, 8, 8, 8) -> (N, n_embd, n_tokens) -> (N, n_tokens, n_embd)
        
        # Append Summarizer token
        x = torch.cat([self.summarizer.repeat(x.size(0), 1, 1), x], dim=1)
        
        # Build representation using transformer
        x = self.transformer(x) # (N, n_tokens + 1, n_embd) -> (N, n_tokens + 1, n_embd)     

        # Classify
        x = self.decoder(x[:, 1:, :])
        return x       

    def forward(self, x):
        return self.ff_with_summarizer(x)

In [6]:
# -------------DATASET IMPLEMENTATION-----------------#
class TomogramDatasetMiniCubes(Dataset):
    def __init__(self, tomogram_data, label_cube, selected_cubes):
        self.tomogram_data = tomogram_data
        self.label_cube = label_cube
        self.selected_cubes = selected_cubes
        self.cube_size = (96, 96, 96)
        self.subcube_size = (6, 6, 6)
        self.background_id = 0

    def __len__(self):
        return len(self.selected_cubes)

    def __getitem__(self, idx):
        cz, cy, cx = self.selected_cubes[idx]
        z_start, z_end = cz * self.cube_size[0], (cz + 1) * self.cube_size[0]
        y_start, y_end = cy * self.cube_size[1], (cy + 1) * self.cube_size[1]
        x_start, x_end = cx * self.cube_size[2], (cx + 1) * self.cube_size[2]

        cube_data = self.tomogram_data[z_start:z_end, y_start:y_end, x_start:x_end]
        labels = self.generate_labels(z_start, y_start, x_start)

        cube_data = np.expand_dims(cube_data, axis=0)  # Add channel dimension
        return (
            torch.tensor(cube_data, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.int64),
        )

    def generate_labels(self, z_start, y_start, x_start):
        mini_cube_labels = []
        for z in range(0, self.cube_size[0], self.subcube_size[0]):
            for y in range(0, self.cube_size[1], self.subcube_size[1]):
                for x in range(0, self.cube_size[2], self.subcube_size[2]):
                    mini_cube = self.label_cube[
                        z_start + z:z_start + z + self.subcube_size[0],
                        y_start + y:y_start + y + self.subcube_size[1],
                        x_start + x:x_start + x + self.subcube_size[2],
                    ]
                    unique, counts = np.unique(mini_cube, return_counts=True)
                    label_coverage = dict(zip(unique, counts))
                    total_voxels = np.prod(self.subcube_size)

                    dominant_label = self.background_id
                    max_coverage = 0

                    for label, coverage in label_coverage.items():
                        if label != self.background_id and coverage / total_voxels >= 0.2 and coverage > max_coverage:
                            dominant_label = label
                            max_coverage = coverage

                    mini_cube_labels.append(dominant_label)

        return np.array(mini_cube_labels)


# Create the dataset
particle_dataset_mini_cubes = TomogramDatasetMiniCubes(
    combined_tomogram_data, label_cube, selected_cubes
)

# Test the dataset
cube_data, labels = particle_dataset_mini_cubes[200]
print(f"len(particle_dataset_mini_cubes): {len(particle_dataset_mini_cubes)}")
print("Cube Data Shape:", cube_data.shape)  # Should be (1, 96, 96, 96)
print("Labels Shape:", labels.shape)        # Should be (4096,)

len(particle_dataset_mini_cubes): 459
Cube Data Shape: torch.Size([1, 96, 96, 96])
Labels Shape: torch.Size([4096])


In [20]:
# Get the first cube and its labels
cube_data, labels = particle_dataset_mini_cubes[15]
# print(labels)
sample_data = cube_data.squeeze().numpy()  # Remove the channel dimension and convert to numpy array

# Reshape the labels to match the mini-cube structure
labels = labels.numpy().reshape((16, 16, 16)) 

# Create a napari viewer
viewer = napari.Viewer()

# Add the cube data as a 3D volume
viewer.add_image(sample_data, name='Tomogram Cube')
# viewer.add_image(np.ones_like(labels), name='Tomogram Cube')
# Add the labels as a 3D labels layer
# viewer.add_labels(labels, name='Mini-Cube Labels')

# Start the napari event loop
napari.run()

In [None]:
segmentation_model = SegmentationModel(config).to(device)
segmentation_model.load_state_dict(torch.load("contrastive_model.pth"))

# Write train and validation dataloader for segmentation using particle_dataset_mini_cubes
train_size = int(0.8 * len(particle_dataset_mini_cubes))
val_size = len(particle_dataset_mini_cubes) - train_size
train_dataset, val_dataset = random_split(particle_dataset_mini_cubes, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

# Define optimizer, loss, and scheduler
optimizer = optim.AdamW(segmentation_model.parameters(), lr=1e-5)
weights = torch.tensor([0.001] + [1 for _ in range(1, len(particle_types) + 1)]).to(device)  # Lower weight for background
criterion = nn.CrossEntropyLoss(weight=weights)

# Training loop
epochs = 200
best_val_loss = float('inf')

for epoch in range(epochs):
    segmentation_model.train()
    train_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = segmentation_model(inputs)
        loss = criterion(outputs.view(-1, len(particle_types) + 1), labels.view(-1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    val_loss = 0.0
    segmentation_model.eval()
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = segmentation_model(inputs)
            loss = criterion(outputs.view(-1, len(particle_types) + 1), labels.view(-1))
            val_loss += loss.item()
    
    print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss / len(train_loader):.8f}, Validation Loss: {val_loss / len(val_loader):.8f}")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(segmentation_model.state_dict(), "segmentation_model_mini_cubes.pth")



Epoch 1/200, Train Loss: 1.57288820, Validation Loss: 1.28516688
Epoch 2/200, Train Loss: 1.22700006, Validation Loss: 1.13906887
Epoch 3/200, Train Loss: 1.10426477, Validation Loss: 1.05914401
Epoch 4/200, Train Loss: 1.01188412, Validation Loss: 1.00391569
Epoch 5/200, Train Loss: 0.96388233, Validation Loss: 0.96390377
Epoch 6/200, Train Loss: 0.91687117, Validation Loss: 0.93749539
Epoch 7/200, Train Loss: 0.89988101, Validation Loss: 0.91471113
Epoch 8/200, Train Loss: 0.85953799, Validation Loss: 0.89410546
Epoch 9/200, Train Loss: 0.84770854, Validation Loss: 0.88336953
Epoch 10/200, Train Loss: 0.83550239, Validation Loss: 0.87119049
Epoch 11/200, Train Loss: 0.81651862, Validation Loss: 0.86471188
Epoch 12/200, Train Loss: 0.80328981, Validation Loss: 0.85287727
Epoch 13/200, Train Loss: 0.79548920, Validation Loss: 0.85040967
Epoch 14/200, Train Loss: 0.78357425, Validation Loss: 0.84379584
Epoch 15/200, Train Loss: 0.77454838, Validation Loss: 0.84086553
Epoch 16/200, Train

KeyboardInterrupt: 

In [81]:
def visualize_inference(inputs, labels, predictions, sample_idx):
    """
    Visualize the input data, ground truth, and predictions using napari.
    """
    # Create a napari viewer
    viewer = napari.Viewer()

    # Add input tomogram
    viewer.add_image(inputs.cpu().numpy(), name=f"Sample {sample_idx} - Input", colormap="gray")

    # Add ground truth labels
    viewer.add_labels(labels.cpu().numpy(), name=f"Sample {sample_idx} - Ground Truth")

    # Add predictions
    viewer.add_labels(predictions.cpu().numpy(), name=f"Sample {sample_idx} - Predictions")

    # Start napari viewer
    napari.run()

# Load the trained model
segmentation_model.load_state_dict(torch.load("segmentation_model_mini_cubes.pth"))
# Select a few samples from training and validation datasets
segmentation_model.eval()
for i, (inputs, labels) in enumerate(train_loader):
    if i >= 5:  # Visualize only 5 samples
        break

    inputs, labels = inputs[0].to(device), labels[0].to(device)
    # print(inputs.shape, labels.shape)
    labels = labels.reshape(16, 16, 16)  # Reshape to (16, 16, 16) for visualization
    true_labels = torch.zeros_like(inputs[0]).to(torch.int64) # shape (96, 96, 96)
    for z in range(0, 96, 6):
        for y in range(0, 96, 6):
            for x in range(0, 96, 6):
                lz, ly, lx = z//6, y//6, x//6
                true_labels[z:z+6, y:y+6, x:x+6] = labels[lz, ly, lx]
    # visualize_inference(inputs[0], true_labels, true_labels, sample_idx=i)

    inputs = inputs.unsqueeze(1)  # Add channel dimension for model input

    with torch.no_grad():
        outputs = segmentation_model(inputs)
        predictions = torch.argmax(outputs, dim=-1)  # Get predicted class labels
        predictions = predictions.reshape(16, 16, 16)  # Reshape to (16, 16, 16) for visualization
        true_predictions = torch.zeros_like(inputs[0, 0]).to(torch.int64) # shape (96, 96, 96)
        for z in range(0, 96, 6):
            for y in range(0, 96, 6):
                for x in range(0, 96, 6):
                    lz, ly, lx = z//6, y//6, x//6
                    true_predictions[z:z+6, y:y+6, x:x+6] = predictions[lz, ly, lx]

    # Visualize the first sample in the batch
    visualize_inference(inputs[0, 0], true_labels, true_predictions, sample_idx=i)