In [3]:
import numpy as np
import os
import json
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Conv3D, MaxPooling3D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import classification_report
import copick
import zarr
from tqdm import tqdm  # For progress bars

In [4]:
def load_picks(picks_folder, voxel_spacing):
    """
    Loads particle picks from JSON files.

    Args:
        picks_folder (str): Path to the Picks folder.
        voxel_spacing: Voxel spacing object from copick.

    Returns:
        dict: Dictionary with particle names as keys and numpy arrays of locations as values.
    """
    picks = {}
    for json_file in os.listdir(picks_folder):
        if json_file.endswith('.json'):
            json_path = os.path.join(picks_folder, json_file)
            with open(json_path, 'r') as file:
                pick_data = json.load(file)
            picks[json_file[:-5]] = np.array([
                [
                    point['location']['x'] / voxel_spacing.voxel_size,
                    point['location']['y'] / voxel_spacing.voxel_size,
                    point['location']['z'] / voxel_spacing.voxel_size
                ]
                for point in pick_data['points']
            ])
    return picks


def extract_patches(data, picks, patch_size=16):
    """
    Extracts cubic patches around particle locations.

    Args:
        data (numpy array): 3D tomogram volume.
        picks (dict): Dictionary with particle types as keys and 3D coordinates as values.
        patch_size (int): Size of the cubic patch.

    Returns:
        tuple: (patches, labels)
            patches (numpy array): Extracted patches.
            labels (list): Corresponding labels for the patches.
    """
    patches = []
    labels = []
    half_size = patch_size // 2

    for particle, locations in picks.items():
        for loc in locations:
            x, y, z = loc.astype(int)
            # Ensure patch is within bounds
            if (x - half_size >= 0 and x + half_size < data.shape[2] and
                    y - half_size >= 0 and y + half_size < data.shape[1] and
                    z - half_size >= 0 and z + half_size < data.shape[0]):
                patch = data[z - half_size:z + half_size,
                        y - half_size:y + half_size,
                        x - half_size:x + half_size]
                patches.append(patch)
                labels.append(particle)
    return np.array(patches), np.array(labels)


def process_dataset(config_path, dataset_type='train', tomogram_keys=None):
    """
    Processes the dataset to extract patches and labels from specified tomograms.

    Args:
        config_path (str): Path to the copick configuration file.
        dataset_type (str): Dataset type ('train' or 'test').
        tomogram_keys (list or None): List of tomogram keys to process. If None, process all.

    Returns:
        tuple: (patches, labels)
    """
    copick_root = copick.from_file(config_path)
    runs = copick_root.runs
    run = copick_root.get_run(runs[0].name)

    voxel_spacing = run.get_voxel_spacing(10.000)

    # Access the specific tomogram(s)
    tomogram = voxel_spacing.get_tomogram("denoised")

    # Access the Zarr data
    zarr_store = tomogram.zarr()
    zarr_group = zarr.open(zarr_store)

    patches = []
    labels = []

    # Path to the Picks folder
    picks_folder_base = os.path.join(
        '/Users/jake.brannigan/Documents/Kaggle/CryoET/Data/czii-cryo-et-object-identification',
        dataset_type,
        'overlay',
        'ExperimentRuns',
        'TS_5_4',
        'Picks'
    )
    
    picks = load_picks(picks_folder_base, voxel_spacing)
    
    # Determine tomogram keys
    if tomogram_keys is None:
        tomogram_keys = list(picks.keys())

    for tomogram_key in tomogram_keys:
        # Load the tomogram data
        tomogram_vals = zarr_group[tomogram_key]

        # Define picks folder for this tomogram (assuming separate Picks per tomogram)
        picks_folder = os.path.join(picks_folder_base, tomogram_key)
        if not os.path.exists(picks_folder):
            print(f"Picks folder {picks_folder} does not exist. Skipping tomogram {tomogram_key}.")
            continue

        picks = load_picks(picks_folder, voxel_spacing)
        tomogram_patches, tomogram_labels = extract_patches(tomogram_vals, picks, patch_size=16)

        patches.append(tomogram_patches)
        labels.append(tomogram_labels)

    if patches:
        patches = np.concatenate(patches, axis=0)
        labels = np.concatenate(labels, axis=0)
    else:
        patches = np.array([])
        labels = np.array([])

    return patches, labels


def prepare_label_maps(labels):
    """
    Creates label maps for encoding and decoding labels.

    Args:
        labels (numpy array): Array of label names.

    Returns:
        tuple: (label_map, inverse_label_map)
            label_map (dict): Mapping from label name to index.
            inverse_label_map (dict): Mapping from index to label name.
    """
    unique_labels = sorted(set(labels))
    label_map = {name: idx for idx, name in enumerate(unique_labels)}
    inverse_label_map = {idx: name for name, idx in label_map.items()}
    return label_map, inverse_label_map


def normalize_data(data, max_val):
    """
    Normalizes the data by dividing by the maximum value.

    Args:
        data (numpy array): Data to normalize.
        max_val (float): Maximum value for normalization.

    Returns:
        numpy array: Normalized data.
    """
    return data / max_val if max_val != 0 else data

In [5]:
# Path to your copick configuration
config_path = '../../copick_config.json'

In [6]:
# =======================
# Training Phase
# =======================

# Process training data
print("Processing training data...")
train_patches, train_labels = process_dataset(config_path, dataset_type='train')

if train_patches.size == 0:
    raise ValueError("No training patches were extracted. Check your Picks folder and data paths.")

print(f"Extracted {train_patches.shape[0]} training patches.")

# Normalize patches based on global max from training data
global_max = train_patches.max()
train_patches = normalize_data(train_patches, global_max)

# Encode labels
label_map, inverse_label_map = prepare_label_maps(train_labels)
encoded_train_labels = np.array([label_map[label] for label in train_labels])

print(f"Label mapping: {label_map}")

# Split into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(
    train_patches, encoded_train_labels, test_size=0.2, random_state=42
)

# Add a channel dimension for CNN
X_train = X_train[..., np.newaxis]
X_val = X_val[..., np.newaxis]

print(f"Training set: {X_train.shape}, Validation set: {X_val.shape}")

Processing training data...
Picks folder /Users/jake.brannigan/Documents/Kaggle/CryoET/Data/czii-cryo-et-object-identification/train/overlay/ExperimentRuns/TS_5_4/Picks/0 does not exist. Skipping tomogram 0.
Picks folder /Users/jake.brannigan/Documents/Kaggle/CryoET/Data/czii-cryo-et-object-identification/train/overlay/ExperimentRuns/TS_5_4/Picks/1 does not exist. Skipping tomogram 1.
Picks folder /Users/jake.brannigan/Documents/Kaggle/CryoET/Data/czii-cryo-et-object-identification/train/overlay/ExperimentRuns/TS_5_4/Picks/2 does not exist. Skipping tomogram 2.


ValueError: No training patches were extracted. Check your Picks folder and data paths.

In [None]:
# =======================
# Model Definition and Training
# =======================

# Define the 3D CNN model
model = Sequential([
    Conv3D(32, kernel_size=(3, 3, 3), activation='relu', input_shape=(16, 16, 16, 1)),
    MaxPooling3D(pool_size=(2, 2, 2)),
    Conv3D(64, kernel_size=(3, 3, 3), activation='relu'),
    MaxPooling3D(pool_size=(2, 2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(len(label_map), activation='softmax')  # Output layer for classification
])

# Compile the model
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model
print("Training the model...")
history = model.fit(
    X_train, y_train,
    epochs=20,
    batch_size=32,
    validation_data=(X_val, y_val),
    verbose=1
)

# Save the trained model
model_save_path = 'particle_detection_model.h5'
# model.save(model_save_path)
print(f"Model saved to {model_save_path}")

In [None]:
# =======================
# Prediction Phase
# =======================

def sliding_window_predict(model, tomogram, patch_size=16, step=8, threshold=0.5, label_map=None, global_max=1.0):
    """
    Performs sliding window predictions on the test tomogram.

    Args:
        model (Keras model): Trained CNN model.
        tomogram (numpy array): 3D tomogram volume.
        patch_size (int): Size of cubic patches.
        step (int): Step size for sliding window.
        threshold (float): Confidence threshold for predictions.
        label_map (dict): Map of label indices to particle names.
        global_max (float): Global maximum value for normalization.

    Returns:
        list: Predicted particles with their positions and labels.
    """
    half_size = patch_size // 2
    z_max, y_max, x_max = tomogram.shape

    predictions = []
    for z in tqdm(range(half_size, z_max - half_size, step), desc="Z-axis"):
        for y in range(half_size, y_max - half_size, step):
            for x in range(half_size, x_max - half_size, step):
                patch = tomogram[z - half_size:z + half_size,
                        y - half_size:y + half_size,
                        x - half_size:x + half_size]
                patch = patch[np.newaxis, ..., np.newaxis]  # Add batch and channel dimensions
                patch = normalize_data(patch, global_max)  # Normalize using training global max

                pred = model.predict(patch, verbose=0)
                max_prob = np.max(pred)
                label_idx = np.argmax(pred)

                if max_prob > threshold:
                    predictions.append({
                        "position": (int(x), int(y), int(z)),
                        "label": label_map[label_idx],
                        "confidence": float(max_prob)
                    })

    return predictions


def prepare_test_data_multiple(config_path, dataset_type='test', num_tomograms=5):
    """
    Prepares multiple test tomograms for predictions.

    Args:
        config_path (str): Path to the copick configuration file.
        dataset_type (str): Dataset type ('test').
        num_tomograms (int): Number of tomograms to process.

    Returns:
        list: List of tomogram data arrays.
        list: List of tomogram keys.
    """
    copick_root = copick.from_file(config_path)
    runs = copick_root.runs
    run = copick_root.get_run(runs[0].name)

    voxel_spacing = run.get_voxel_spacing(10.000)

    # Access the specific tomogram
    tomogram = voxel_spacing.get_tomogram("denoised")

    # Access the Zarr data
    zarr_store = tomogram.zarr()
    zarr_group = zarr.open(zarr_store)

    # Get tomogram keys and limit to first `num_tomograms`
    all_keys = list(zarr_group.keys())
    selected_keys = all_keys[:num_tomograms]

    tomogram_data = []
    for key in selected_keys:
        tomogram_vals = zarr_group[key][:]
        tomogram_data.append(tomogram_vals)

    return tomogram_data, selected_keys


def save_predictions(predictions, output_path):
    """
    Saves predictions in overlay format.

    Args:
        predictions (list): List of predicted particles with positions and labels.
        output_path (str): Path to save the JSON file.
    """
    overlay_data = {
        "points": [
            {
                "location": {"x": pos[0], "y": pos[1], "z": pos[2]},
                "label": label,
                "confidence": confidence
            }
            for pred in predictions
            for pos, label, confidence in [(pred["position"], pred["label"], pred["confidence"])]
        ]
    }

    with open(output_path, 'w') as json_file:
        json.dump(overlay_data, json_file, indent=4)


# Function to predict on multiple tomograms
def predict_on_multiple_tomograms(model, config_path, num_tomograms=5, patch_size=16, step=8, threshold=0.5,
                                  output_dir='predictions'):
    """
    Predicts particle locations on multiple tomograms and saves the results.

    Args:
        model (Keras model): Trained CNN model.
        config_path (str): Path to the copick configuration file.
        num_tomograms (int): Number of tomograms to process.
        patch_size (int): Size of cubic patches.
        step (int): Step size for sliding window.
        threshold (float): Confidence threshold for predictions.
        output_dir (str): Directory to save prediction JSON files.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Prepare test data
    tomograms, tomogram_keys = prepare_test_data_multiple(config_path, dataset_type='test', num_tomograms=num_tomograms)

    if not tomograms:
        print("No tomograms found for prediction.")
        return

    for idx, (tomogram_vals, tomogram_key) in enumerate(zip(tomograms, tomogram_keys)):
        print(f"\nPredicting on tomogram {idx + 1}/{num_tomograms}: {tomogram_key}")

        # Perform prediction
        predictions = sliding_window_predict(
            model=model,
            tomogram=tomogram_vals,
            patch_size=patch_size,
            step=step,
            threshold=threshold,
            label_map=inverse_label_map,
            global_max=global_max
        )

        # Save predictions
        output_path = os.path.join(output_dir, f'predictions_{tomogram_key}.json')
        save_predictions(predictions, output_path)
        print(f"Saved predictions to {output_path}. Detected {len(predictions)} particles.")


In [None]:
# =======================
# Execute Prediction on First 5 Tomograms
# =======================

# Define parameters
NUM_TOMOGRAMS = 5  # Number of tomograms to predict on
PATCH_SIZE = 16
STEP = 8
THRESHOLD = 0.5
OUTPUT_DIR = 'predictions'

# Load the trained model (optional, if not already in memory)
# model = load_model('particle_detection_model.h5')

# Perform predictions
print("Starting predictions on test tomograms...")
predict_on_multiple_tomograms(
    model=model,
    config_path=config_path,
    num_tomograms=NUM_TOMOGRAMS,
    patch_size=PATCH_SIZE,
    step=STEP,
    threshold=THRESHOLD,
    output_dir=OUTPUT_DIR
)
print("Prediction completed.")
