In [2]:
# Library imports
import random
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import sys
import zarr
import json
import os
import napari
from rich import print as rprint  # Import rich's print function
import copy
from dataclasses import dataclass
from monai.networks.nets import UNet
# Seed every RNG this notebook draws from so Restart & Run All is reproducible:
# Python's `random` (used by `random.sample` for negative-cube selection),
# NumPy, and torch (CPU and all CUDA devices).
SEED = 37
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Build the segmentation network: a 3-D MONAI UNet with 1 input channel and
# 7 output channels (particle label IDs 1-6 plus background 0).
# No loss function or optimizer is created in this cell.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

unet_kwargs = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=6 + 1,  # 6 particle classes + background
    channels=(48, 64, 80, 80),
    strides=(2, 2, 1),
    num_res_units=1,
)
model = UNet(**unet_kwargs).to(device)

# Report the model size as the total number of parameter elements.
param_count = sum(param.numel() for param in model.parameters())
print(f"Number of parameters in the model: {param_count:.2e}")

Number of parameters in the model: 1.12e+06


In [1]:
!pip install cryoet_data_portal



In [4]:
# Work inside the synthetic-data folder so the portal download lands there.
# Guarded because os.chdir is not idempotent: re-running an unguarded
# os.chdir("./synthetic_data") fails, or descends into synthetic_data/synthetic_data.
# NOTE(review): later cells open relative paths such as "train/static/...";
# those resolve against this new working directory — confirm that is intended.
SYNTHETIC_DATA_DIR = "synthetic_data"
if os.path.basename(os.path.normpath(os.getcwd())) != SYNTHETIC_DATA_DIR:
    os.chdir(SYNTHETIC_DATA_DIR)

In [5]:
# Display the current working directory to confirm where the download will land.
os.getcwd()

'c:\\Users\\ariel\\Downloads\\czii-cryo-et-object-identification\\synthetic_data'

In [8]:
# Download a CryoET data-portal dataset into the current working directory.
# The portal's `Dataset` is imported under an alias: binding the bare name
# `Dataset` here would shadow `torch.utils.data.Dataset` (imported in the first
# cell), and `class TomogramDataset(Dataset)` would then silently subclass the
# portal model instead of torch's Dataset.
from cryoet_data_portal import Client
from cryoet_data_portal import Dataset as PortalDataset

PORTAL_DATASET_ID = 10441

client = Client()

dataset = PortalDataset.get_by_id(client, PORTAL_DATASET_ID)
# NOTE(review): the recorded run of this cell downloaded ~567M of 36.7G and then
# failed with FileNotFoundError on a file name containing ':' (not a valid
# character in Windows paths) — presumably needs a non-Windows host; confirm.
dataset.download_everything("./")

  2%|▏         | 567M/36.7G [00:54<58:10, 10.3MiB/s]   


FileNotFoundError: [Errno 2] No such file or directory: '1:0.5c8FD169'

In [None]:
#-------------LOADING TOMOGRAM DATA AND PARTICLE COORDINATES-----------------#

# Define the experiment runs to load
experiment_runs = ["TS_5_4", "TS_69_2", "TS_6_4", "TS_6_6", "TS_73_6", "TS_86_3", "TS_99_9"]
# Label ID painted for each particle type (0 is left free for background).
particle_types = {"virus-like-particle":1, "apo-ferritin":2, "beta-amylase":3, "beta-galactosidase":4, "ribosome":5, "thyroglobulin":6}
voxel_spacing = [10.0, 10.0, 10.0]  # angstroms per voxel, ordered (z, y, x)

# Per-run normalized volumes (concatenated after the loop) and, per particle type,
# a list of (z, y, x) coordinates in voxel units of the combined volume.
combined_tomogram_data = []
combined_particle_coords = {pt: [] for pt in particle_types}

# Runs are stacked along z, so each run's z coordinates are shifted by the total
# depth of the runs successfully loaded before it.
cumulative_z_depth = 0
voxel_spacing_zyx = np.asarray(voxel_spacing, dtype=float)

# NOTE(review): the paths below are relative to the current working directory,
# which an earlier cell changes with os.chdir — confirm "train/..." exists there.
for experiment_run in experiment_runs:
    zarr_file_path = os.path.join("train", "static", "ExperimentRuns", experiment_run, "VoxelSpacing10.000", "denoised.zarr")
    json_base_path = os.path.join("train", "overlay", "ExperimentRuns", experiment_run, "Picks")

    # Load and z-score the volume. A run whose volume fails is skipped entirely
    # (no picks, no z offset) so coordinates stay aligned with the stacked data.
    try:
        tomogram = zarr.open(zarr_file_path, mode="r")
        tomogram_data = tomogram["0"][:]  # Load array "0" into memory as a NumPy array
        print(f"Tomogram shape for {experiment_run} (z, y, x):", tomogram_data.shape)
        tomogram_std = tomogram_data.std()
        if tomogram_std == 0:
            # z-scoring a constant volume would fill it with NaN/inf.
            raise ValueError("volume has zero variance, cannot normalize")
        tomogram_data = (tomogram_data - tomogram_data.mean()) / tomogram_std
        combined_tomogram_data.append(tomogram_data)
    except Exception as e:
        print(f"Error loading Zarr file for {experiment_run}: {e}")
        continue

    # Load picks per particle type; a missing or malformed JSON only drops that type.
    for particle_type in particle_types:
        json_file_path = os.path.join(json_base_path, f"{particle_type}.json")
        try:
            with open(json_file_path, "r") as file:
                data = json.load(file)
            points = data["points"]

            # Angstrom locations -> voxel units, reordered to (z, y, x).
            # reshape(-1, 3) keeps a (0, 3) array when a file has no picks.
            coords = np.array(
                [[p["location"]["z"], p["location"]["y"], p["location"]["x"]] for p in points],
                dtype=float,
            ).reshape(-1, 3) / voxel_spacing_zyx
            coords[:, 0] += cumulative_z_depth  # translate into the stacked volume
            combined_particle_coords[particle_type].extend(coords)
            print(f"Loaded {len(coords)} points for {particle_type} in {experiment_run}.")
        except Exception as e:
            print(f"Error loading JSON file for {particle_type} in {experiment_run}: {e}")

    # Update cumulative_z_depth for the next tomogram
    cumulative_z_depth += tomogram_data.shape[0]

# Fail loudly instead of with np.concatenate's opaque "need at least one array".
if not combined_tomogram_data:
    raise RuntimeError("No tomograms were loaded; check the paths and working directory above.")

# Combine all tomogram data into a single array
combined_tomogram_data = np.concatenate(combined_tomogram_data, axis=0)
print("Combined tomogram shape (z, y, x):", combined_tomogram_data.shape)

# Print total number of particles
total_particles = sum(len(coords) for coords in combined_particle_coords.values())
print(f"Total number of particles: {total_particles}")

# --------------------------------------------------------------------------------------------#


In [None]:
#-------------Combine tomograms and sample cubes with particles in it-----------------#

# Tiling parameters: non-overlapping cubes of `cube_size` voxels, a label box of
# `particle_label_size` voxels around each particle, and the background label id.
data_shape = combined_tomogram_data.shape
cube_size = (96, 96, 96)
particle_label_size = (8, 8, 8)
background_id = 0

# Number of whole cubes along each axis; any remainder at the far edge is dropped.
# NOTE(review): the volume is several tomograms concatenated along z, so a z-cube
# can straddle two tomograms unless each tomogram's depth is a multiple of
# cube_size[0] — confirm that is acceptable.
num_cubes_z, num_cubes_y, num_cubes_x = (
    extent // side for extent, side in zip(data_shape, cube_size)
)

# Grid index (z, y, x) of every cube, in z-major order; the two lists below are
# filled by the classification step.
cubes = [
    (z, y, x)
    for z in range(num_cubes_z)
    for y in range(num_cubes_y)
    for x in range(num_cubes_x)
]
particle_cubes = []
non_particle_cubes = []

def contains_particle(cube_start, particle_coords):
    """Return True if any particle coordinate lies inside the given cube.

    The cube spans ``[cube_start[d], cube_start[d] + cube_size[d])`` on each axis,
    using the global ``cube_size``. Coordinates are truncated with ``astype(int)``
    before the test. ``particle_coords`` maps particle type -> iterable of
    (z, y, x) arrays. Stops at the first particle found inside.
    """
    def _in_cube(point):
        pz, py, px = point.astype(int)
        return all(
            start <= value < start + side
            for value, start, side in zip((pz, py, px), cube_start, cube_size)
        )

    return any(
        _in_cube(point)
        for group in particle_coords.values()
        for point in group
    )

# Bin every particle coordinate (truncated with astype(int), as in
# `contains_particle`) into the grid index of the cube containing it.
# This single pass replaces a full particle scan per cube. Out-of-range or negative
# indices never match an entry of `cubes`, so they are ignored.
occupied_cubes = {
    tuple(int(v) // side for v, side in zip(coord.astype(int), cube_size))
    for coords in combined_particle_coords.values()
    for coord in coords
}

# Visit cubes in the same z-major order as `cubes` so the lists (and hence the
# seeded random sample below) are reproducible.
for cube_index in cubes:
    if cube_index in occupied_cubes:
        particle_cubes.append(cube_index)
    else:
        non_particle_cubes.append(cube_index)

# Add a small random set of particle-free cubes as negatives. Their count is
# NON_PARTICLE_RATIO x the number of particle cubes (~9% of the final dataset at 0.1).
# It is capped at the number available, because random.sample raises ValueError
# when asked for more items than exist.
NON_PARTICLE_RATIO = 0.1
num_non_particle_cubes = min(int(len(particle_cubes) * NON_PARTICLE_RATIO), len(non_particle_cubes))
selected_non_particle_cubes = random.sample(non_particle_cubes, num_non_particle_cubes)
selected_cubes = particle_cubes + selected_non_particle_cubes
print(f"Selected {len(selected_cubes)} cubes for the dataset. Where {len(particle_cubes)} contain particles and {len(selected_non_particle_cubes)} do not.")


In [None]:
# Function to create labels for a cube
def create_labels(cube_start, cube_data, particle_coords,
                  label_size=None, type_ids=None, background=None):
    """Build a dense integer label volume for one cube.

    Each particle is painted as an axis-aligned box centred on its int-truncated
    coordinate. The box extends ``size // 2`` voxels on each side of the centre
    (so ``2 * (size // 2) + 1`` voxels per axis) and is clipped to the cube.
    A particle whose centre lies just outside the cube but whose box overlaps it
    is painted too, so labels stay continuous across neighbouring cubes.
    Where boxes overlap, later particles overwrite earlier ones.

    Args:
        cube_start: (z, y, x) global voxel index of the cube's first voxel.
        cube_data: the cube's array; only its ``shape`` is used.
        particle_coords: mapping particle type -> iterable of (z, y, x) coordinates
            in global voxel units.
        label_size: box size per axis; defaults to the global ``particle_label_size``.
        type_ids: particle type -> label id; defaults to the global ``particle_types``.
        background: fill value for unlabeled voxels; defaults to the global
            ``background_id``.

    Returns:
        An ``int`` array with ``cube_data.shape``.
    """
    if label_size is None:
        label_size = particle_label_size
    if type_ids is None:
        type_ids = particle_types
    if background is None:
        background = background_id

    shape = cube_data.shape
    half = [size // 2 for size in label_size]
    labels = np.full(shape, background, dtype=int)

    for particle_type, coords in particle_coords.items():
        for coord in coords:
            rel = [int(c) - start for c, start in zip(coord.astype(int), cube_start)]
            slices = tuple(
                slice(max(0, r - h), min(n, r + h + 1))
                for r, h, n in zip(rel, half, shape)
            )
            # Skip boxes that do not reach into the cube. This also prevents a
            # negative slice end from wrapping around to the far side of the array.
            if all(s.start < s.stop for s in slices):
                labels[slices] = type_ids[particle_type]  # Unique ID for particle type

    return labels

# Define a PyTorch dataset class.
# Subclass torch's Dataset by its full path so this class never depends on what
# the bare name `Dataset` happens to be bound to (other cells import a different
# `Dataset` from cryoet_data_portal).
class TomogramDataset(torch.utils.data.Dataset):
    """Map-style dataset of fixed-size cubes cut from one (combined) tomogram volume.

    Item ``i`` is the cube at grid index ``selected_cubes[i]`` = (cz, cy, cx),
    i.e. voxels ``[c * cube_size[d], (c + 1) * cube_size[d])`` on each axis (using
    the global ``cube_size``), paired with its dense labels from ``create_labels``.
    """

    def __init__(self, tomogram_data, selected_cubes, particle_coords):
        self.tomogram_data = tomogram_data  # 3-D array indexed (z, y, x)
        self.selected_cubes = selected_cubes  # sequence of (cz, cy, cx) grid indices
        self.particle_coords = particle_coords  # particle type -> (z, y, x) voxel coords

    def __len__(self):
        return len(self.selected_cubes)

    def __getitem__(self, idx):
        """Return ``(cube, labels)`` as float32 / int64 tensors.

        For cubes lying fully inside the volume, both tensors have shape
        ``cube_size``. NOTE(review): there is no channel axis, while the UNet is
        built with in_channels=1 — callers presumably need to unsqueeze(1) after
        batching; confirm.
        """
        cz, cy, cx = self.selected_cubes[idx]
        z_start, z_end = cz * cube_size[0], (cz + 1) * cube_size[0]
        y_start, y_end = cy * cube_size[1], (cy + 1) * cube_size[1]
        x_start, x_end = cx * cube_size[2], (cx + 1) * cube_size[2]

        cube_data = self.tomogram_data[z_start:z_end, y_start:y_end, x_start:x_end]
        cube_start = (z_start, y_start, x_start)
        labels = create_labels(cube_start, cube_data, self.particle_coords)

        return torch.tensor(cube_data, dtype=torch.float32), torch.tensor(labels, dtype=torch.int64)

# Wrap the combined volume, the chosen cube indices and the particle picks.
particle_dataset = TomogramDataset(
    tomogram_data=combined_tomogram_data,
    selected_cubes=selected_cubes,
    particle_coords=combined_particle_coords,
)

# ---------------------------------------------------------------------------------#

In [None]:
# ------------------- VISUALIZE Combined Tomogram Data ----------------------------#

# Display colour per particle label ID (the IDs assigned in `particle_types`).
label_colors = dict(
    enumerate(
        [
            "red",      # 1: virus-like-particle
            "green",    # 2: apo-ferritin
            "blue",     # 3: beta-amylase
            "yellow",   # 4: beta-galactosidase
            "magenta",  # 5: ribosome
            "cyan",     # 6: thyroglobulin
        ],
        start=1,
    )
)

def visualize_combined_tomogram(tomogram_data, particle_coords):
    """Show the whole volume in napari with every particle as a coloured point.

    Each point is coloured via ``label_colors``, keyed by the id its type has in
    ``particle_types``. The points layer is only added when there is at least one
    particle. Finishes by calling ``napari.run()``.
    """
    viewer = napari.Viewer()
    viewer.add_image(tomogram_data, name="Combined Tomogram")

    # Flatten the per-type coordinate lists into one list of points plus a
    # parallel list of label ids.
    points, ids = [], []
    for particle_type, coords in particle_coords.items():
        type_id = particle_types[particle_type]
        points += coords
        ids += [type_id] * len(coords)

    points = np.array(points)
    point_colors = [label_colors[type_id] for type_id in ids]

    if points.size:
        viewer.add_points(
            points,
            name="Particles",
            face_color=point_colors,
            size=5,
            opacity=0.8,
        )

    napari.run()

# Render the full stacked volume together with every picked particle.
print("Visualizing the combined tomogram with particles...")
visualize_combined_tomogram(tomogram_data=combined_tomogram_data, particle_coords=combined_particle_coords)
# ---------------------------------------------------------------------------------#

In [None]:
def visualize_selected_cubes(tomogram_data, particle_coords, selected_cubes, cube_size, max_cubes=10):
    """Show the first ``max_cubes`` selected cubes, each with its particles, in napari.

    For every cube grid index (cz, cy, cx), the voxel block
    ``[c * cube_size[d], (c + 1) * cube_size[d])`` per axis is added as a gray image
    layer. The particles whose int-truncated (z, y, x) coordinate falls inside it
    are added as a points layer in cube-local coordinates, coloured via
    ``label_colors`` by their type's id in ``particle_types``.

    Args:
        tomogram_data: 3-D volume indexed (z, y, x).
        particle_coords: particle type -> iterable of (z, y, x) voxel coordinates.
        selected_cubes: sequence of (cz, cy, cx) cube grid indices.
        cube_size: cube edge length per axis, ordered (z, y, x).
        max_cubes: how many cubes from the start of ``selected_cubes`` to show;
            ``None`` shows all of them. Defaults to 10.
    """
    # NOTE(review): layers are added without a translate, so all cubes overlap at
    # the viewer origin; toggle layer visibility to inspect one cube at a time.
    viewer = napari.Viewer()

    for idx, (cz, cy, cx) in enumerate(selected_cubes[:max_cubes]):
        # Define cube boundaries
        z_start, y_start, x_start = cz * cube_size[0], cy * cube_size[1], cx * cube_size[2]
        z_end, y_end, x_end = z_start + cube_size[0], y_start + cube_size[1], x_start + cube_size[2]

        # Extract cube data from the tomogram
        cube_data = tomogram_data[z_start:z_end, y_start:y_end, x_start:x_end]

        # Collect particle coordinates (shifted to cube-local space) and label ids
        cube_particles = []
        cube_labels = []
        for particle_type, coords in particle_coords.items():
            label_id = particle_types[particle_type]
            for coord in coords:
                z, y, x = coord.astype(int)
                if z_start <= z < z_end and y_start <= y < y_end and x_start <= x < x_end:
                    cube_particles.append([z - z_start, y - y_start, x - x_start])
                    cube_labels.append(label_id)

        cube_particles = np.array(cube_particles)
        colors = [label_colors[label] for label in cube_labels]

        viewer.add_image(cube_data, name=f"Cube {idx + 1}", colormap="gray")

        # Only add a points layer when the cube actually contains particles.
        if cube_particles.size > 0:
            viewer.add_points(
                cube_particles,
                name=f"Particles in Cube {idx + 1}",
                face_color=colors,
                size=5,
                opacity=0.8,
            )

    napari.run()

# Inspect a handful of training cubes (the first ones in `selected_cubes`).
print("Visualizing 10 selected cubes with particles...")
visualize_selected_cubes(
    tomogram_data=combined_tomogram_data,
    particle_coords=combined_particle_coords,
    selected_cubes=selected_cubes,
    cube_size=cube_size,
)
