In [43]:
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv3D, MaxPooling3D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
import copick
import zarr
import os
import json

In [None]:
# Load the copick project described by the JSON config. The path is relative,
# so it resolves against the kernel's current working directory.
copickRoot = copick.from_file('../../copick_config.json')
runs = copickRoot.runs
# Work with the first run only; it is looked up again by its name.
# NOTE(review): assumes the project lists at least one run.
run = copickRoot.get_run(runs[0].name)

# NOTE(review): 10.000 is presumably the voxel spacing in angstroms — confirm
# against the copick config.
voxel_spacing = run.get_voxel_spacing(10.000)

# Access the specific tomogram
tomogram = voxel_spacing.get_tomogram("denoised")

# Access the Zarr data
zarr_store = tomogram.zarr()
zarr_group = zarr.open(zarr_store)
# Load the tomogram data
# NOTE(review): '0' is presumably the full-resolution level of a multiscale
# store — confirm before relying on it.
tomogram_vals = zarr_group['0']  # Adjust the key if needed
num_slices = tomogram_vals.shape[0]  # Length of axis 0, which extract_patches indexes as z



In [18]:
# Load the tomogram data
# NOTE(review): these two assignments repeat the ones at the end of the
# previous cell; this cell still requires `zarr_group` from that cell.
tomogram_vals = zarr_group['0']  # Adjust the key if needed
num_slices = tomogram_vals.shape[0]  # Length of axis 0, which extract_patches indexes as z

# Path to the Picks folder.
# Set the CZII_PICKS_FOLDER environment variable to point the notebook at a
# different data location; when it is unset, the original machine-specific
# path below is used so existing behavior is unchanged.
picks_folder = os.environ.get(
    'CZII_PICKS_FOLDER',
    '/Users/jake.brannigan/Documents/Kaggle/CryoET/Data/czii-cryo-et-object-identification/train/overlay/ExperimentRuns/TS_5_4/Picks',
)

def load_picks(picks_folder, voxel_spacing):
    """
    Loads pick locations from every ``.json`` file in a folder, scaled by the voxel size.

    Each file is expected to hold a top-level ``points`` list whose entries
    carry a ``location`` mapping with ``x``, ``y`` and ``z`` values.

    Args:
        picks_folder (str): Directory whose ``.json`` files are read; other files are ignored.
        voxel_spacing: Object exposing a numeric ``voxel_size`` attribute. Every
            coordinate read from the files is divided by this value.

    Returns:
        dict: Maps each file name (without its ``.json`` suffix) to a float numpy
        array of shape (N, 3) with columns ``x, y, z``. A file with an empty
        ``points`` list maps to an array of shape (0, 3), so column indexing
        such as ``locations[:, 2]`` stays valid.
    """
    voxel_size = voxel_spacing.voxel_size
    picks = {}
    # Sorted so the key order does not depend on the filesystem's listing order.
    for json_file in sorted(os.listdir(picks_folder)):
        if not json_file.endswith('.json'):
            continue
        json_path = os.path.join(picks_folder, json_file)
        with open(json_path, 'r', encoding='utf-8') as file:
            pick_data = json.load(file)
        coords = [
            [point['location'][axis] / voxel_size for axis in ('x', 'y', 'z')]
            for point in pick_data['points']
        ]
        # reshape(-1, 3) keeps the result two-dimensional even when there are no points.
        picks[json_file[:-5]] = np.array(coords, dtype=float).reshape(-1, 3)
    return picks

def map_picks_to_slices(picks, num_slices, copick_config, voxel_size=10.0):
    """
    Maps particle picks to slices based on their z-coordinates and particle radius.

    A pick is assigned to a slice when the absolute difference between its
    z-coordinate (column 2) and the slice index is strictly smaller than the
    particle's radius divided by ``voxel_size``.

    Args:
        picks (dict): Dictionary where keys are particle names and values are their 3D locations
            (numpy arrays of shape (N, 3), columns x, y, z).
        num_slices (int): Number of slices to consider along the z-axis.
        copick_config: The copick configuration object containing `pickable_objects`,
            each exposing ``name`` and ``radius``.
        voxel_size (float): Divisor that converts ``radius`` into slice units.
            Defaults to 10.0, the value that was previously hard-coded.

    Returns:
        dict: Dictionary mapping each slice index to a dictionary of
        ``particle name -> list of [x, y, z]`` locations that fall on that slice.
        Particles whose configured radius is ``None``, or that have no picks,
        map to an empty list.

    Raises:
        KeyError: If a key of ``picks`` is not the name of a pickable object.
    """
    # Particle name -> radius in slice units (None when the config has no radius).
    particle_radii = {
        obj.name: (obj.radius / voxel_size if obj.radius is not None else None)
        for obj in copick_config.pickable_objects
    }

    def _locations_on_slice(particle, locations, slice_idx):
        """Returns the picks of one particle type that intersect one slice."""
        radius = particle_radii[particle]
        locations = np.asarray(locations)
        # Without a radius there is no extent to test against, and an empty
        # array may be one-dimensional, which would break the column lookup.
        if radius is None or locations.size == 0:
            return []
        on_slice = np.abs(locations[:, 2] - slice_idx) < radius
        return locations[on_slice].tolist()

    return {
        slice_idx: {
            particle: _locations_on_slice(particle, locations, slice_idx)
            for particle, locations in picks.items()
        }
        for slice_idx in range(num_slices)
    }

# Usage: read every .json pick file in `picks_folder`; load_picks divides the
# coordinates by `voxel_spacing.voxel_size`.
picks = load_picks(picks_folder, voxel_spacing)

In [44]:
# Example: Extract patches around particles for training
def extract_patches(data, picks, patch_size=16):
    """
    Extracts cubic patches around particle locations.

    Each location is truncated to integer voxel indices and a cube of
    ``patch_size`` voxels per side is cut out starting ``patch_size // 2``
    voxels before that index on every axis. Locations whose cube would extend
    past the volume are skipped; a cube that ends exactly on the volume edge
    is kept.

    Args:
        data (numpy array): 3D tomogram volume, indexed as ``data[z, y, x]``.
        picks (dict): Dictionary with particle types as keys and 3D coordinates
            (rows of ``x, y, z``) as values.
        patch_size (int): Size of the cubic patch.

    Returns:
        patches (numpy array): Extracted patches, shape
            ``(n, patch_size, patch_size, patch_size)``; ``n`` is 0 when no
            location yields a full patch.
        labels (numpy array): Corresponding particle type for each patch.
    """
    patches = []
    labels = []
    half_size = patch_size // 2

    for particle, locations in picks.items():
        for loc in locations:
            x, y, z = map(int, loc)
            # Derive the end from the start so odd patch sizes keep their full width.
            x0, y0, z0 = x - half_size, y - half_size, z - half_size
            x1, y1, z1 = x0 + patch_size, y0 + patch_size, z0 + patch_size
            # Slice ends are exclusive, so an end equal to the axis length is still in bounds.
            if (x0 >= 0 and x1 <= data.shape[2] and
                y0 >= 0 and y1 <= data.shape[1] and
                z0 >= 0 and z1 <= data.shape[0]):
                patches.append(data[z0:z1, y0:y1, x0:x1])
                labels.append(particle)

    if not patches:
        # Keep the patch dimensions so downstream shape handling still works.
        empty = np.empty((0, patch_size, patch_size, patch_size),
                         dtype=getattr(data, 'dtype', float))
        return empty, np.array(labels)

    return np.array(patches), np.array(labels)

patches, labels = extract_patches(tomogram_vals, picks, patch_size=16)

# Normalize patches
# NOTE(review): this scales by the single largest voxel value across all
# patches; it presumably assumes that maximum is positive — confirm for this data.
patches = patches / np.max(patches)

# Encode labels to integers
# Sort the class names first: iterating a bare set of strings can give a
# different order in every kernel session, which would silently change which
# integer id (and therefore which softmax output) belongs to which particle.
label_map = {name: idx for idx, name in enumerate(sorted(set(labels)))}
encoded_labels = np.array([label_map[label] for label in labels])

# Hold out 20% of the patches as a test set; the fixed seed keeps the split repeatable
X_train, X_test, y_train, y_test = train_test_split(
    patches,
    encoded_labels,
    test_size=0.2,
    random_state=42,
)

# Append a single-channel axis as the last dimension for the CNN
X_train, X_test = (np.expand_dims(subset, axis=-1) for subset in (X_train, X_test))

# Define a 3D CNN
# The input shape is taken from the training data (everything after the batch
# axis) so it always matches the patch size used during extraction instead of
# repeating it as a literal.
# NOTE(review): each Conv3D/MaxPooling3D stage shrinks the spatial size, so
# very small patches may leave nothing to flatten — confirm before lowering
# the patch size.
model = Sequential([
    Conv3D(32, kernel_size=(3, 3, 3), activation='relu', input_shape=X_train.shape[1:]),
    MaxPooling3D(pool_size=(2, 2, 2)),
    Conv3D(64, kernel_size=(3, 3, 3), activation='relu'),
    MaxPooling3D(pool_size=(2, 2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(len(label_map), activation='softmax')  # One output per particle class
])

# Compile the model
# The sparse categorical loss is paired with the integer class ids produced
# when the labels were encoded.
model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the model
# NOTE(review): the held-out test split is also passed as the per-epoch
# validation set, so the accuracy printed below comes from data that was
# already monitored during training.
history = model.fit(X_train, y_train, epochs=200, batch_size=32, validation_data=(X_test, y_test), verbose=1)

# Evaluate the model on the same held-out split
test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f"Test Accuracy: {test_accuracy:.2f}")

# Save the model (currently disabled)
# model.save('particle_detection_model.h5')

Epoch 1/200
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 96ms/step - accuracy: 0.2128 - loss: 1.7651 - val_accuracy: 0.4286 - val_loss: 1.4845
Epoch 2/200
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 78ms/step - accuracy: 0.4043 - loss: 1.5129 - val_accuracy: 0.5714 - val_loss: 1.3438
Epoch 3/200
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 78ms/step - accuracy: 0.5192 - loss: 1.4204 - val_accuracy: 0.7143 - val_loss: 1.2446
Epoch 4/200
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 79ms/step - accuracy: 0.5952 - loss: 1.2679 - val_accuracy: 0.7500 - val_loss: 1.1034
Epoch 5/200
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 78ms/step - accuracy: 0.6056 - loss: 1.0668 - val_accuracy: 0.6429 - val_loss: 0.9634
Epoch 6/200
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 78ms/step - accuracy: 0.7336 - loss: 0.9246 - val_accuracy: 0.7143 - val_loss: 0.8205
Epoch 7/200
[1m4/4[0m [32m━━━━━━━━━━━