## Installing offline deps

As this is a code comp, there is no internet. 
So we have to do some silly things to get dependencies in here. 
Why is asciitree such a PITA? 

In [1]:
# Location of the offline dependency files mentioned in the notes above.
# NOTE(review): not referenced anywhere else in the visible part of the notebook.
deps_path = "./"

In [2]:
import random
import numpy as np
import torch

def set_random_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, the generators of all GPUs are seeded too and
    cuDNN is switched to deterministic mode with benchmarking disabled.

    Parameters:
    -----------
    seed : int
        Seed value handed to every generator (default 42).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    # Seed the random number generators of every GPU.
    torch.cuda.manual_seed_all(seed)
    # Deterministic mode, so that results can be reproduced.
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes kernels, which makes results non-reproducible.
    torch.backends.cudnn.benchmark = False

# Set the random seed for this run.
# NOTE(review): the original note listed "43 79" — presumably seeds tried earlier.
set_random_seed(seed=666677)

In [3]:
from typing import List, Tuple, Union
import numpy as np
import torch
from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
from monai.transforms import (  # 导入MONAI的变换模块，用于数据预处理和增强
    Compose,                   # 组合多个变换，方便链式调用
    EnsureChannelFirstd,        # 确保数据的通道维度位于第一位
    Orientationd,              # 调整数据的空间方向
    AsDiscrete,                # 将连续型数据离散化
    RandFlipd,                 # 随机水平或垂直翻转
    RandRotate90d,             # 随机旋转90度
    NormalizeIntensityd,        # 对图像强度进行标准化
    RandCropByLabelClassesd     # 依据标签类别进行随机采样
)

from czii_helper import *
from dataset import *
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt

## Define some helper functions


### Patching helper functions

These are mostly used to split large volumes into smaller ones and stitch them back together. 

In [4]:
def calculate_patch_starts(dimension_size: int, patch_size: int) -> List[int]:
    """
    Calculate the starting positions of patches along a single dimension
    with minimal overlap to cover the entire dimension.

    The smallest number of patches that can cover the dimension is used and
    the starts are spread evenly, so the first patch begins at 0 and the last
    patch ends exactly at ``dimension_size``.

    Parameters:
    -----------
    dimension_size : int
        Size of the dimension
    patch_size : int
        Size of the patch in this dimension

    Returns:
    --------
    List[int]
        List of starting positions for patches; ``[0]`` when a single patch
        already spans the dimension.
    """
    if dimension_size <= patch_size:
        return [0]

    dimension_size = int(dimension_size)
    patch_size = int(patch_size)

    # Integer ceiling division; dimension_size > patch_size guarantees >= 2 patches.
    n_patches = -(-dimension_size // patch_size)

    # The even step between starts is (dimension_size - patch_size) / (n_patches - 1).
    # Flooring it with integer arithmetic avoids float rounding, which could
    # otherwise truncate a start one position short and leave the end uncovered.
    span = dimension_size - patch_size
    return [(i * span) // (n_patches - 1) for i in range(n_patches)]

def extract_3d_patches_minimal_overlap(arrays: List[np.ndarray], patch_size: int) -> Tuple[List[np.ndarray], List[Tuple[int, int, int]]]:
    """
    Extract 3D patches from multiple arrays with minimal overlap to cover the entire array.

    Parameters:
    -----------
    arrays : List[np.ndarray]
        List of input arrays, each with shape (m, n, l)
    patch_size : int
        Size of cubic patches (a x a x a)

    Returns:
    --------
    patches : List[np.ndarray]
        List of all patches from all input arrays (slices of the inputs, not copies)
    coordinates : List[Tuple[int, int, int]]
        Starting index along axes 0, 1 and 2 for each patch. One entry is
        appended per patch, so the grid repeats once for every input array.

    Raises:
    -------
    ValueError
        If ``arrays`` is empty or not a list, if the arrays differ in shape,
        or if ``patch_size`` exceeds the smallest dimension.
    """
    if not arrays or not isinstance(arrays, list):
        raise ValueError("Input must be a non-empty list of arrays")

    # Verify all arrays have the same shape
    shape = arrays[0].shape
    if not all(arr.shape == shape for arr in arrays):
        raise ValueError("All input arrays must have the same shape")

    if patch_size > min(shape):
        raise ValueError(f"patch_size ({patch_size}) must be smaller than smallest dimension {min(shape)}")

    m, n, l = shape

    # Starting positions along each axis, driven by the requested patch size
    # (previously a literal 96 was used regardless of the argument).
    x_starts = calculate_patch_starts(m, patch_size)
    y_starts = calculate_patch_starts(n, patch_size)
    z_starts = calculate_patch_starts(l, patch_size)

    patches = []
    coordinates = []

    # Extract patches from each array
    for arr in arrays:
        for x in x_starts:
            for y in y_starts:
                for z in z_starts:
                    patches.append(
                        arr[
                            x:x + patch_size,
                            y:y + patch_size,
                            z:z + patch_size
                        ]
                    )
                    coordinates.append((x, y, z))

    return patches, coordinates

# Note: I should probably average the overlapping areas,
# but here they are just overwritten by the most recent patch.

def reconstruct_array(patches: List[np.ndarray],
                     coordinates: List[Tuple[int, int, int]],
                     original_shape: Tuple[int, int, int]) -> np.ndarray:
    """
    Reconstruct array from patches.

    Each patch is written at its starting coordinates; where patches overlap,
    the later patch in the list overwrites the earlier one (no averaging).

    Parameters:
    -----------
    patches : List[np.ndarray]
        List of 3D patches to reconstruct from
    coordinates : List[Tuple[int, int, int]]
        Starting coordinates for each patch
    original_shape : Tuple[int, int, int]
        Shape of the original array

    Returns:
    --------
    np.ndarray
        Reconstructed int64 array; all zeros when ``patches`` is empty
    """
    reconstructed = np.zeros(original_shape, dtype=np.int64)

    for patch, (x, y, z) in zip(patches, coordinates):
        # Use the patch's own extent instead of a fixed 96 so any patch size works.
        dx, dy, dz = np.asarray(patch).shape
        reconstructed[x:x + dx, y:y + dy, z:z + dz] = patch

    return reconstructed

## Submission helper functions

These help with getting the submission in the correct format

In [5]:
import pandas as pd

def dict_to_df(coord_dict, experiment_name):
    """
    Convert dictionary of coordinates to pandas DataFrame.

    Parameters:
    -----------
    coord_dict : dict
        Dictionary where keys are labels and values are Nx3 coordinate arrays
    experiment_name : str
        Value written to the 'experiment' column of every row

    Returns:
    --------
    pd.DataFrame
        DataFrame with columns ['experiment', 'particle_type', 'x', 'y', 'z'];
        it has zero rows when there are no coordinates at all.
    """
    coord_blocks = []
    all_labels = []

    # Process each label and its coordinates
    for label, coords in coord_dict.items():
        # Force an (N, 3) layout so empty inputs still stack cleanly.
        block = np.asarray(coords).reshape(-1, 3)
        coord_blocks.append(block)
        all_labels.extend([label] * len(block))

    # np.vstack rejects an empty list, so fall back to an empty (0, 3) array.
    if coord_blocks:
        all_coords = np.vstack(coord_blocks)
    else:
        all_coords = np.empty((0, 3))

    return pd.DataFrame({
        'experiment': [experiment_name] * len(all_labels),
        'particle_type': all_labels,
        'x': all_coords[:, 0],
        'y': all_coords[:, 1],
        'z': all_coords[:, 2]
    })

## Reading in the data

In [6]:
# Train and test data are both read from the working directory.
TRAIN_DATA_DIR = TEST_DATA_DIR = "./"

In [7]:
# train_names = ['TS_17','TS_6_6','TS_69_2',
#  'TS_18',
#  'TS_19',
#  'TS_2',
#  'TS_5_4',
#  'TS_73_6',
#  'TS_86_3',
#  'TS_99_9',
#  'TS_0',
#  'TS_1',
#  'TS_10',
#  'TS_11',
#  'TS_12',
#  'TS_13',
#  'TS_14',
#  'TS_15',
#  'TS_16',
#  'TS_20',
#  'TS_21',
#  'TS_22',
#  'TS_23',
#  'TS_24',
#  'TS_25',
#  'TS_26',
#  'TS_3',
#  'TS_4',
#  'TS_5',
#  'TS_6',
#  'TS_7',
#  'TS_8',
#  'TS_9']
# Runs used for training and for validation.
# NOTE(review): 'TS_6_4' appears in both lists, so the validation run is also
# trained on — confirm this overlap is intended.
train_names = ['TS_5_4', 'TS_69_2', 'TS_6_6', 'TS_73_6', 'TS_86_3', 'TS_99_9', 'TS_6_4']
valid_names = ['TS_6_4']#,'TS_6_6','TS_69_2']

train_files = []
valid_files = []

# Load each run's image/label volume pair from TRAIN_DATA_DIR: training runs
# first, then validation runs.
for names, files in ((train_names, train_files), (valid_names, valid_files)):
    for name in names:
        image = np.load(f"{TRAIN_DATA_DIR}/train_image_{name}.npy")
        label = np.load(f"{TRAIN_DATA_DIR}/train_label_{name}.npy")
        files.append({"image": image, "label": label})
    


### Create the training dataloader

I should probably find a way to create a dataloader that takes more batches. 

In [8]:
# Deterministic preprocessing, applied once and cached: put the channel axis
# first, normalize the image intensity, and reorient both volumes to RAS.
non_random_transforms = Compose(
    transforms=[
        EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
        NormalizeIntensityd(keys="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
    ]
)

# cache_rate=1.0 requests caching of the whole training set.
raw_train_ds = CacheDataset(
    data=train_files,
    transform=non_random_transforms,
    cache_rate=1.0,
)

Loading dataset: 100%|██████████| 7/7 [00:00<00:00, 18.27it/s]


In [9]:
# Crops requested from RandCropByLabelClassesd per item, and items per loader batch.
my_num_samples = 80
train_batch_size = 1

# Random augmentation applied on top of the cached volumes during training:
# label-class-driven 96^3 crops, then a random 90-degree rotation and a random flip.
random_transforms = Compose(
    transforms=[
        RandCropByLabelClassesd(
            keys=["image", "label"],
            label_key="label",
            spatial_size=[96, 96, 96],
            num_classes=7,
            num_samples=my_num_samples,
        ),
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    ]
)

train_ds = Dataset(data=raw_train_ds, transform=random_transforms)

# Shuffled training loader; memory is pinned only when a GPU is present.
train_loader = DataLoader(
    train_ds,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

### Create the validation dataloader

Here I deviate a little from the source notebooks. 

In the source, the validation dataloader also used the random transformations. This is bad practice and will result in noisy validation. 

Here I split the validation dataset into (slightly) overlapping blocks of `(96, 96, 96)` so that we have a consistent validation set that uses all the validation data. 


In [10]:
# Separate the validation volumes into parallel image and label lists.
val_images = [entry['image'] for entry in valid_files]
val_labels = [entry['label'] for entry in valid_files]

# Cut images and labels on the same patch grid; the coordinates are not needed here.
val_image_patches, _ = extract_3d_patches_minimal_overlap(val_images, 96)
val_label_patches, _ = extract_3d_patches_minimal_overlap(val_labels, 96)

val_patched_data = [
    dict(image=img, label=lbl)
    for img, lbl in zip(val_image_patches, val_label_patches)
]

valid_ds = CacheDataset(
    data=val_patched_data,
    transform=non_random_transforms,
    cache_rate=1.0,
)

# 2 * 49 = 98, the number of patches the cached dataset reports, so the whole
# validation set fits in a single batch.
valid_batch_size = 98
# Validation loader: fixed order, no shuffling.
valid_loader = DataLoader(
    valid_ds,
    batch_size=valid_batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

Loading dataset: 100%|██████████| 98/98 [00:00<00:00, 404.96it/s]


## Initialize the model

This model is pretty much directly copied from [3D U-Net PyTorch Lightning distributed training](https://www.kaggle.com/code/zhuowenzhao11/3d-u-net-pytorch-lightning-distributed-training)

In [11]:
# Image-only counterpart of the cached preprocessing: channel axis first,
# intensity normalization, RAS orientation. No label key is involved.
inference_transforms = Compose(
    transforms=[
        EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
        NormalizeIntensityd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
    ]
)

import cc3d

# Segmentation class id -> particle name. Ids start at 1; id 0 is not listed.
id_to_name = dict(
    enumerate(
        [
            "apo-ferritin",
            "beta-amylase",
            "beta-galactosidase",
            "ribosome",
            "thyroglobulin",
            "virus-like-particle",
        ],
        start=1,
    )
)

def do_one_eval(truth, predict, threshold):
    """
    Match predicted points to ground-truth points and tally the outcome.

    Predictions and truths are paired one-to-one by minimising the total
    Euclidean distance (``linear_sum_assignment``); a pair counts as a hit
    only when its distance is at most ``threshold``.

    Parameters:
    -----------
    truth : np.ndarray
        Ground-truth coordinates, reshaped internally to (T, 3)
    predict : np.ndarray
        Predicted coordinates, reshaped internally to (P, 3)
    threshold : float
        Largest distance at which a matched pair still counts as a hit

    Returns:
    --------
    hit : list
        ``[prediction_indices, truth_indices]`` of the accepted pairs
    fp : list
        Indices of predictions without an accepted match
    miss : list
        Indices of truths without an accepted match
    metric : list
        ``[P, T, n_hit, n_miss, n_fp]``
    """
    num_pred, num_truth = len(predict), len(truth)

    def _summary(hit, fp, miss):
        # Bundle the index lists with their counts (used for the F-beta computation).
        return hit, fp, miss, [num_pred, num_truth, len(hit[0]), len(miss), len(fp)]

    # Nothing predicted: every truth is a miss.
    if num_pred == 0:
        return _summary([[], []], [], list(range(num_truth)))

    # Nothing to find: every prediction is a false positive.
    if num_truth == 0:
        return _summary([[], []], list(range(num_pred)), [])

    # Pairwise Euclidean distances, shape (P, T).
    offsets = predict.reshape(num_pred, 1, 3) - truth.reshape(1, num_truth, 3)
    pairwise = np.sqrt(np.square(offsets).sum(axis=2))
    pred_idx, truth_idx = linear_sum_assignment(pairwise)

    # Drop assigned pairs that are farther apart than the threshold.
    within = pairwise[pred_idx, truth_idx] <= threshold
    pred_idx, truth_idx = pred_idx[within], truth_idx[within]

    unmatched_truth = np.ones(num_truth, dtype=bool)
    unmatched_truth[truth_idx] = False
    unmatched_pred = np.ones(num_pred, dtype=bool)
    unmatched_pred[pred_idx] = False

    return _summary(
        [pred_idx.tolist(), truth_idx.tolist()],
        np.flatnonzero(unmatched_pred).tolist(),
        np.flatnonzero(unmatched_truth).tolist(),
    )
    
def compute_lb(submit_df, overlay_dir):
    """
    Score a submission frame against the ground truth read from ``overlay_dir``.

    For every experiment in ``submit_df`` and every entry of ``PARTICLE``,
    predictions are matched to the truth with ``do_one_eval`` using half the
    particle radius as the distance threshold. Counts are summed per particle
    type, turned into precision, recall and an F-beta score with beta = 4, and
    combined into one weighted average.

    Parameters:
    -----------
    submit_df : pd.DataFrame
        Predictions with columns 'experiment', 'particle_type', 'x', 'y', 'z'
    overlay_dir : str
        Directory passed to ``read_one_truth`` for each experiment

    Returns:
    --------
    float
        Weighted mean of the per-type F-beta scores (undefined scores count as 0)
    """
    records = []
    for run_id in submit_df['experiment'].unique():
        truth = read_one_truth(run_id, overlay_dir)
        run_df = submit_df[submit_df['experiment'] == run_id]
        for particle in PARTICLE:
            particle = dotdict(particle)
            is_type = run_df['particle_type'] == particle.name
            xyz_predict = run_df.loc[is_type, ['x', 'y', 'z']].values
            _, _, _, counts = do_one_eval(truth[particle.name], xyz_predict, particle.radius * 0.5)
            n_pred, n_truth, n_hit, n_miss, n_fp = counts
            records.append({
                'particle_type': particle.name,
                'P': n_pred, 'T': n_truth, 'hit': n_hit, 'miss': n_miss, 'fp': n_fp,
            })

    eval_df = pd.DataFrame(records)
    gb = eval_df.groupby('particle_type').agg('sum')
    gb['precision'] = gb['hit'] / gb['P']
    gb['recall'] = gb['hit'] / gb['T']
    # F-beta with beta = 4: (1 + 16) * P * R / (16 * P + R).
    gb['f-beta4'] = 17 * gb['precision'] * gb['recall'] / (16 * gb['precision'] + gb['recall'])
    # NOTE(review): the weights are positional, so they rely on the grouped
    # index holding exactly six particle types in sorted order.
    gb['weight'] = [1, 0, 2, 1, 2, 1]

    weighted_scores = gb['f-beta4'].fillna(0) * gb['weight']
    return weighted_scores.sum() / gb['weight'].sum()

In [12]:
import lightning.pytorch as pl

from monai.networks.nets import UNet
from monai.losses import TverskyLoss,DiceLoss
from monai.metrics import DiceMetric

class Model(pl.LightningModule):
    """
    LightningModule wrapping a MONAI ``UNet`` trained with a Tversky loss.

    Validation does not score the batches it receives. Once training has
    passed ``LB_EVAL_START_EPOCH``, each validation step runs patch-wise
    inference over the tomogram loaded in ``__init__`` and reports the
    ``compute_lb`` score as ``val_metric``; before that the metric is 0.0.
    """

    # Epoch after which the full-tomogram evaluation is run during validation.
    LB_EVAL_START_EPOCH = 170
    # Factor applied to voxel-index centroids before they are scored.
    # NOTE(review): presumably the voxel spacing of the tomograms — confirm.
    VOXEL_SPACING = 10.012444

    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 7,
        channels: Union[Tuple[int, ...], List[int]] = (48, 64, 80, 80),
        strides: Union[Tuple[int, ...], List[int]] = (2, 2, 1),
        num_res_units: int = 1,
        lr: float=1e-3):
        """
        Build the network, loss and metric, and cache the evaluation tomogram.

        The first six parameters are forwarded to ``UNet``; ``lr`` is the
        AdamW learning rate. All of them are stored in ``self.hparams``.
        Construction also reads ``train_image_TS_6_4.npy`` from the working
        directory and caches its patches for validation.
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = UNet(
            spatial_dims=self.hparams.spatial_dims,
            in_channels=self.hparams.in_channels,
            out_channels=self.hparams.out_channels,
            channels=self.hparams.channels,
            strides=self.hparams.strides,
            num_res_units=self.hparams.num_res_units,
        )
        self.loss_fn = TverskyLoss(alpha=0.4, beta=0.6, include_background=True, to_onehot_y=True, softmax=True)
        self.metric_fn = DiceMetric(include_background=True, reduction="mean", ignore_empty=True)

        # Running totals, reset at the end of every epoch.
        self.train_loss = 0
        self.val_metric = 0
        self.num_train_batch = 0
        self.num_val_batch = 0

        # Tomogram used for the validation score, split into 96^3 patches once.
        tomo = np.load("train_image_TS_6_4.npy")
        self.tomo = tomo
        tomo_patches, self.coordinates = extract_3d_patches_minimal_overlap([tomo], 96)
        tomo_patched_data = [{"image": img} for img in tomo_patches]
        self.tomo_ds = CacheDataset(data=tomo_patched_data, transform=inference_transforms, cache_rate=1.0)

    def forward(self, x):
        """Return the raw network output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the Tversky loss for one batch and add it to the epoch total."""
        x, y = batch['image'], batch['label']
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        # Detach before accumulating so the running sum does not keep every
        # step's autograd graph alive until the end of the epoch.
        self.train_loss += loss.detach()
        self.num_train_batch += 1
        torch.cuda.empty_cache()
        return loss

    def on_train_epoch_end(self):
        """Log the mean training loss of the epoch and reset the accumulators."""
        loss_per_epoch = self.train_loss/self.num_train_batch
        self.log('train_loss', loss_per_epoch, prog_bar=True)
        self.train_loss = 0
        self.num_train_batch = 0

    def validation_step(self, batch, batch_idx):
        """
        Produce the validation metric; ``batch`` itself is not used.

        The score is recomputed on every call, so it runs once per batch of
        the validation loader.
        """
        with torch.no_grad():
            if self.current_epoch > self.LB_EVAL_START_EPOCH:
                lb_score = self.process_tomography_and_compute_lb(['TS_6_4'], 50, 0.1)
            else:
                lb_score = 0.0
            self.val_metric = lb_score
        torch.cuda.empty_cache()
        return {'val_metric': lb_score}

    def on_validation_epoch_end(self):
        """Print and log the latest validation metric, then reset it."""
        metric_per_epoch = self.val_metric
        print(f"Epoch {self.current_epoch} - Average Val Metric: {metric_per_epoch:.4f}")
        self.log('val_metric', metric_per_epoch, prog_bar=True, sync_dist=False) # sync_dist=True for distributed training
        self.val_metric = 0
        self.num_val_batch = 0

    def configure_optimizers(self):
        """Return an AdamW optimizer using the stored learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=1e-5)

    def process_tomography_and_compute_lb(self, valid_id, BLOB_THRESHOLD, CERTAINTY_THRESHOLD):
        """
        Segment the cached tomogram patch by patch and score the found particles.

        Parameters:
        -----------
        valid_id : list
            Run names used to label the predictions and to look up the truth.
            NOTE(review): the patches always come from the tomogram cached in
            ``__init__``, whatever names are passed — confirm callers only
            pass that run.
        BLOB_THRESHOLD : int
            Connected components with this many voxels or fewer are discarded
        CERTAINTY_THRESHOLD : float
            Averaged class probability a channel must exceed to be considered

        Returns:
        --------
        float
            Score returned by ``compute_lb`` for the predicted locations
        """
        location_frames = []

        for run in valid_id:
            pred_masks = []
            for i in range(len(self.tomo_ds)):
                # Follow the module's device rather than assuming "cuda".
                input_tensor = self.tomo_ds[i]['image'].unsqueeze(0).to(self.device)

                # Test-time augmentation: the patch plus one flip per spatial axis.
                batch = torch.cat(
                    [input_tensor] + [torch.flip(input_tensor, dims=[axis]) for axis in (2, 3, 4)],
                    dim=0,
                )
                model_output = self(batch)
                # Undo the flips. Sample k was flipped along batch axis k + 1,
                # which is axis k once the batch axis has been indexed away.
                for k in (1, 2, 3):
                    model_output[k] = torch.flip(model_output[k], dims=[k])

                probs = torch.softmax(model_output, dim=1)
                avg_probs = torch.mean(probs, dim=0)
                # Use the caller's threshold (a literal 0.1 was used before).
                thresh_probs = avg_probs > CERTAINTY_THRESHOLD
                # NOTE(review): this takes the max over boolean channels, so when
                # several classes pass the threshold the chosen index depends on
                # torch.max tie-breaking — confirm this is the intended rule.
                _, max_classes = thresh_probs.max(dim=0)

                pred_masks.append(max_classes.cpu().numpy())

            reconstructed_mask = reconstruct_array(pred_masks, self.coordinates, self.tomo.shape)

            location = {}
            for c in (1, 2, 3, 4, 5, 6):
                cc = cc3d.connected_components(reconstructed_mask == c)
                stats = cc3d.statistics(cc)
                # Skip the first entry of each statistic, then scale the centroids.
                zyx = stats['centroids'][1:] * self.VOXEL_SPACING
                zyx_large = zyx[stats['voxel_counts'][1:] > BLOB_THRESHOLD]
                # Reverse the column order (z, y, x) -> (x, y, z).
                location[id_to_name[c]] = np.ascontiguousarray(zyx_large[:, ::-1])

            location_frames.append(dict_to_df(location, run))

        location_df = pd.concat(location_frames)

        return compute_lb(location_df, './overlay/ExperimentRuns')

In [13]:
# Network shape and optimisation settings for this run.
channels = (48, 64, 80, 80)
strides_pattern = (2, 2, 1)
num_res_units = 1
learning_rate = 1e-3
num_epochs = 5000

# Remaining Model arguments keep their defaults.
model = Model(
    channels=channels,
    strides=strides_pattern,
    num_res_units=num_res_units,
    lr=learning_rate,
)

Loading dataset: 100%|██████████| 98/98 [00:00<00:00, 540.15it/s]


## Train the model



In [14]:
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

torch.set_float32_matmul_precision('medium')

# Check if CUDA is available and then count the GPUs.
# num_gpus is set on both branches so the lines below also work without a GPU
# (it used to be undefined there, raising NameError).
if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available: {num_gpus}")
else:
    num_gpus = 0
    print("No GPU available. Running on CPU.")
devices = list(range(num_gpus))
print(devices)

# The very large patience means early stopping effectively never triggers.
early_stopping_callback = EarlyStopping(
    monitor='val_metric',
    patience=10000,
    mode='max',
    #verbose=True
)

checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints/',  # directory where checkpoints are saved
    # File name template, filled in with the epoch and the logged val_metric
    filename='checkpoint_{epoch}-{val_metric:.2f}',
    monitor='val_metric',  # name of the metric to monitor
    save_top_k=1,  # keep only the best checkpoint
    mode='max',  # larger metric values are better
    verbose=True  # print a message whenever a checkpoint is saved
)

trainer = pl.Trainer(
    max_epochs=num_epochs,
    #strategy="ddp_notebook",
    # Train on the first GPU when one exists, otherwise fall back to the CPU.
    accelerator="gpu" if num_gpus else "cpu",
    devices=[0] if num_gpus else 1,  # `devices` (all GPUs) is deliberately not used
    num_nodes=1,
    log_every_n_steps=10,
    enable_progress_bar=True,
    callbacks=[early_stopping_callback, checkpoint_callback],
    check_val_every_n_epoch=3,  # run validation every 3 epochs
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Number of GPUs available: 1
[0]


Let there be gradients!

Locally this config seems to train for about 1000 steps before the model starts overfitting. 

In [None]:
# Start training; valid_loader drives the periodic validation runs.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /root/autodl-tmp/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type        | Params | Mode 
------------------------------------------------
0 | model   | UNet        | 1.1 M  | train
1 | loss_fn | TverskyLoss | 0      | train
------------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.480     Total estimated model params size (MB)
88        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Epoch 0 - Average Val Metric: 0.0000


/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (7) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 21: 'val_metric' reached 0.00000 (best 0.00000), saving model to '/root/autodl-tmp/checkpoints/checkpoint_epoch=2-val_metric=0.00-v6.ckpt' as top 1


Epoch 2 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 42: 'val_metric' was not in top 1


Epoch 5 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 63: 'val_metric' was not in top 1


Epoch 8 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 84: 'val_metric' was not in top 1


Epoch 11 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14, global step 105: 'val_metric' was not in top 1


Epoch 14 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17, global step 126: 'val_metric' was not in top 1


Epoch 17 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 20, global step 147: 'val_metric' was not in top 1


Epoch 20 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 23, global step 168: 'val_metric' was not in top 1


Epoch 23 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 26, global step 189: 'val_metric' was not in top 1


Epoch 26 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 29, global step 210: 'val_metric' was not in top 1


Epoch 29 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 32, global step 231: 'val_metric' was not in top 1


Epoch 32 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 35, global step 252: 'val_metric' was not in top 1


Epoch 35 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 38, global step 273: 'val_metric' was not in top 1


Epoch 38 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 41, global step 294: 'val_metric' was not in top 1


Epoch 41 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 44, global step 315: 'val_metric' was not in top 1


Epoch 44 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 47, global step 336: 'val_metric' was not in top 1


Epoch 47 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 50, global step 357: 'val_metric' was not in top 1


Epoch 50 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 53, global step 378: 'val_metric' was not in top 1


Epoch 53 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 56, global step 399: 'val_metric' was not in top 1


Epoch 56 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 59, global step 420: 'val_metric' was not in top 1


Epoch 59 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 62, global step 441: 'val_metric' was not in top 1


Epoch 62 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 65, global step 462: 'val_metric' was not in top 1


Epoch 65 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 68, global step 483: 'val_metric' was not in top 1


Epoch 68 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 71, global step 504: 'val_metric' was not in top 1


Epoch 71 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 74, global step 525: 'val_metric' was not in top 1


Epoch 74 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 77, global step 546: 'val_metric' was not in top 1


Epoch 77 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 80, global step 567: 'val_metric' was not in top 1


Epoch 80 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 83, global step 588: 'val_metric' was not in top 1


Epoch 83 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 86, global step 609: 'val_metric' was not in top 1


Epoch 86 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 89, global step 630: 'val_metric' was not in top 1


Epoch 89 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 92, global step 651: 'val_metric' was not in top 1


Epoch 92 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 95, global step 672: 'val_metric' was not in top 1


Epoch 95 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 98, global step 693: 'val_metric' was not in top 1


Epoch 98 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 101, global step 714: 'val_metric' was not in top 1


Epoch 101 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 104, global step 735: 'val_metric' was not in top 1


Epoch 104 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 107, global step 756: 'val_metric' was not in top 1


Epoch 107 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 110, global step 777: 'val_metric' was not in top 1


Epoch 110 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 113, global step 798: 'val_metric' was not in top 1


Epoch 113 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 116, global step 819: 'val_metric' was not in top 1


Epoch 116 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 119, global step 840: 'val_metric' was not in top 1


Epoch 119 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 122, global step 861: 'val_metric' was not in top 1


Epoch 122 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 125, global step 882: 'val_metric' was not in top 1


Epoch 125 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 128, global step 903: 'val_metric' was not in top 1


Epoch 128 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 131, global step 924: 'val_metric' was not in top 1


Epoch 131 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 134, global step 945: 'val_metric' was not in top 1


Epoch 134 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 137, global step 966: 'val_metric' was not in top 1


Epoch 137 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 140, global step 987: 'val_metric' was not in top 1


Epoch 140 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 143, global step 1008: 'val_metric' was not in top 1


Epoch 143 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 146, global step 1029: 'val_metric' was not in top 1


Epoch 146 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 149, global step 1050: 'val_metric' was not in top 1


Epoch 149 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 152, global step 1071: 'val_metric' was not in top 1


Epoch 152 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 155, global step 1092: 'val_metric' was not in top 1


Epoch 155 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 158, global step 1113: 'val_metric' was not in top 1


Epoch 158 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 161, global step 1134: 'val_metric' was not in top 1


Epoch 161 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 164, global step 1155: 'val_metric' was not in top 1


Epoch 164 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 167, global step 1176: 'val_metric' was not in top 1


Epoch 167 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 170, global step 1197: 'val_metric' was not in top 1


Epoch 170 - Average Val Metric: 0.0000


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 173, global step 1218: 'val_metric' reached 0.74400 (best 0.74400), saving model to '/root/autodl-tmp/checkpoints/checkpoint_epoch=173-val_metric=0.74.ckpt' as top 1


Epoch 173 - Average Val Metric: 0.7440


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 176, global step 1239: 'val_metric' was not in top 1


Epoch 176 - Average Val Metric: 0.7256


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 179, global step 1260: 'val_metric' reached 0.76166 (best 0.76166), saving model to '/root/autodl-tmp/checkpoints/checkpoint_epoch=179-val_metric=0.76.ckpt' as top 1


Epoch 179 - Average Val Metric: 0.7617


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 182, global step 1281: 'val_metric' was not in top 1


Epoch 182 - Average Val Metric: 0.7396


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 185, global step 1302: 'val_metric' was not in top 1


Epoch 185 - Average Val Metric: 0.7300


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 188, global step 1323: 'val_metric' was not in top 1


Epoch 188 - Average Val Metric: 0.7123


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 191, global step 1344: 'val_metric' was not in top 1


Epoch 191 - Average Val Metric: 0.7317


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 194, global step 1365: 'val_metric' was not in top 1


Epoch 194 - Average Val Metric: 0.7174


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 197, global step 1386: 'val_metric' reached 0.77452 (best 0.77452), saving model to '/root/autodl-tmp/checkpoints/checkpoint_epoch=197-val_metric=0.77.ckpt' as top 1


Epoch 197 - Average Val Metric: 0.7745


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 200, global step 1407: 'val_metric' was not in top 1


Epoch 200 - Average Val Metric: 0.7619


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 203, global step 1428: 'val_metric' reached 0.78442 (best 0.78442), saving model to '/root/autodl-tmp/checkpoints/checkpoint_epoch=203-val_metric=0.78.ckpt' as top 1


Epoch 203 - Average Val Metric: 0.7844


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 206, global step 1449: 'val_metric' was not in top 1


Epoch 206 - Average Val Metric: 0.7232


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 209, global step 1470: 'val_metric' was not in top 1


Epoch 209 - Average Val Metric: 0.7287


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 212, global step 1491: 'val_metric' was not in top 1


Epoch 212 - Average Val Metric: 0.7007


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 215, global step 1512: 'val_metric' was not in top 1


Epoch 215 - Average Val Metric: 0.7471


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 218, global step 1533: 'val_metric' was not in top 1


Epoch 218 - Average Val Metric: 0.7488


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 221, global step 1554: 'val_metric' was not in top 1


Epoch 221 - Average Val Metric: 0.7482


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 224, global step 1575: 'val_metric' was not in top 1


Epoch 224 - Average Val Metric: 0.7451


## Predict on the test set



In [None]:
# Restore trained weights and put the model in inference mode on the GPU.
# NOTE(review): this file name matches none of the checkpoints reported in the
# training log above (those were written under checkpoints/) — confirm the path.
checkpoint_file = "checkpoint_epoch=188-val_metric=0.77.ckpt"
weights = torch.load(checkpoint_file)['state_dict']
model.load_state_dict(weights)
model.eval()
model.to("cuda");

In [None]:
# Deterministic, image-only preprocessing for inference: channel axis first,
# intensity normalization, RAS orientation.
# NOTE(review): repeats the inference_transforms definition from the
# evaluation-helper cell earlier in the notebook.
inference_transforms = Compose(
    transforms=[
        EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
        NormalizeIntensityd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
    ]
)

In [None]:
import cc3d

# Segmentation class id -> particle type name; ids are numbered from 1.
id_to_name = dict(
    enumerate(
        (
            "apo-ferritin",
            "beta-amylase",
            "beta-galactosidase",
            "ribosome",
            "thyroglobulin",
            "virus-like-particle",
        ),
        start=1,
    )
)

In [None]:
def do_one_eval(truth, predict, threshold):
    """Match predicted points to ground-truth points one-to-one and count the outcome.

    Each prediction is paired with at most one truth point (and vice versa). A pair
    counts as a hit only when its Euclidean distance is <= ``threshold``. The
    assignment maximises the number of hits first and, among assignments with the
    same number of hits, minimises the summed hit distance.

    Parameters
    ----------
    truth : array-like of shape (T, 3)
        Ground-truth coordinates.
    predict : array-like of shape (P, 3)
        Predicted coordinates.
    threshold : float
        Largest distance at which a matched pair is accepted as a hit.

    Returns
    -------
    hit : list
        ``[pred_indices, truth_indices]`` - two parallel lists describing the hits.
    fp : list
        Indices of predictions without an accepted match.
    miss : list
        Indices of truth points without an accepted match.
    metric : list
        ``[P, T, n_hit, n_miss, n_fp]``.
    """
    P=len(predict)
    T=len(truth)

    if P==0:
        # Nothing predicted: every truth point is missed.
        hit=[[],[]]
        miss=np.arange(T).tolist()
        fp=[]
        metric = [P,T,len(hit[0]),len(miss),len(fp)]
        return hit, fp, miss, metric

    if T==0:
        # Nothing to find: every prediction is a false positive.
        hit=[[],[]]
        fp=np.arange(P).tolist()
        miss=[]
        metric = [P,T,len(hit[0]),len(miss),len(fp)]
        return hit, fp, miss, metric

    #---
    # Accept plain lists as well as arrays.
    predict = np.asarray(predict, dtype=float).reshape(P,1,3)
    truth = np.asarray(truth, dtype=float).reshape(1,T,3)

    # Pairwise Euclidean distances, shape (P, T).
    distance = np.sqrt(((predict-truth)**2).sum(axis=2))

    # Minimising the raw distance sum can pair a prediction with an out-of-threshold
    # truth point and thereby give up a valid in-threshold pairing. Out-of-threshold
    # pairs therefore all get the same penalty, chosen larger than any possible sum
    # of in-threshold distances, so that one extra hit always lowers the total cost.
    within = distance <= threshold
    penalty = min(P, T) * max(float(threshold), 0.0) + 1.0
    cost = np.where(within, distance, penalty)
    p_index, t_index = linear_sum_assignment(cost)

    # Keep only the assigned pairs that are genuinely within the threshold.
    valid = within[p_index, t_index]
    p_index = p_index[valid]
    t_index = t_index[valid]
    hit = [p_index.tolist(), t_index.tolist()]
    miss = np.arange(T)
    miss = miss[~np.isin(miss,t_index)].tolist()
    fp = np.arange(P)
    fp = fp[~np.isin(fp,p_index)].tolist()

    metric = [P,T,len(hit[0]),len(miss),len(fp)] # counts used for the leaderboard F-beta computation
    return hit, fp, miss, metric

In [None]:
# Per-class minimum connected-component size (voxel count): components of a class
# are kept only when their voxel count is strictly greater than this value.
id_to_thread = {
    1: 235,  # apo-ferritin
    2: 260,  # beta-amylase
    3: 355,  # beta-galactosidase
    4: 600,  # ribosome
    5: 510,  # thyroglobulin
    6: 530,  # virus-like-particle
}

In [None]:
def process_tomography_and_compute_lb(valid_id, model, BLOB_THRESHOLD=50, CERTAINTY_THRESHOLD=0.1):
    """Segment the given runs, extract particle centroids and score them against the truth.

    For every run the tomogram is loaded from ``train_image_{run}.npy``, split into
    patches of size 96, passed through ``model`` on the GPU, and the per-patch class
    masks are stitched back into a full-volume mask. Connected components of each
    class larger than ``id_to_thread[class]`` voxels are turned into centroid
    coordinates, which are then compared with the ground truth read from
    ``./overlay/ExperimentRuns``.

    Parameters
    ----------
    valid_id : iterable of str
        Run names to evaluate.
    model : callable
        Segmentation network; must already be on the "cuda" device.
    BLOB_THRESHOLD : int
        Currently unused - the per-class sizes in ``id_to_thread`` are used instead.
    CERTAINTY_THRESHOLD : float
        Currently unused - the probability cut-off is hard-coded to 0.01 below.

    Returns
    -------
    gb : pandas.DataFrame
        Per-particle-type counts, precision, recall, F-beta4 and weight.
    lb_score : float
        Weighted mean of the per-type F-beta4 scores.
    """
    location_df = []

    for run in valid_id:
        tomo = np.load(f"train_image_{run}.npy")
        tomo_patches, coordinates = extract_3d_patches_minimal_overlap([tomo], 96)
        tomo_patched_data = [{"image": img} for img in tomo_patches]
        tomo_ds = CacheDataset(data=tomo_patched_data, transform=inference_transforms, cache_rate=1.0)

        pred_masks = []

        # Inference only: disable autograd so no graph is built (and kept on the GPU)
        # for each patch.
        with torch.no_grad():
            for i in range(len(tomo_ds)):
                input_tensor = tomo_ds[i]['image'].unsqueeze(0).to("cuda")
                model_output = model(input_tensor)

                probs = torch.softmax(model_output[0], dim=0)
                # NOTE(review): the class index is taken from the max over a boolean
                # mask (prob > 0.01), not from the argmax of the probabilities, and
                # CERTAINTY_THRESHOLD is not consulted - confirm this is intended.
                thresh_probs = probs > 0.01
                _, max_classes = thresh_probs.max(dim=0)

                pred_masks.append(max_classes.cpu().numpy())

        reconstructed_mask = reconstruct_array(pred_masks, coordinates, tomo.shape)

        location = {}
        classes = [1, 2, 3, 4, 5, 6]

        for c in classes:
            cc = cc3d.connected_components(reconstructed_mask == c)
            stats = cc3d.statistics(cc)
            # Entry 0 of the statistics is skipped (presumably the background label).
            # 10.012444 is presumably the voxel spacing used to convert voxel indices
            # to physical coordinates - TODO confirm.
            zyx = stats['centroids'][1:] * 10.012444
            # Keep only components strictly larger than the per-class size cut-off.
            zyx_large = zyx[stats['voxel_counts'][1:] >id_to_thread[c]]
            # Reverse the axis order (z, y, x) -> (x, y, z).
            xyz = np.ascontiguousarray(zyx_large[:, ::-1])

            location[id_to_name[c]] = xyz

        df = dict_to_df(location, run)
        location_df.append(df)

    location_df = pd.concat(location_df)

    def compute_lb(submit_df, overlay_dir):
        """Score ``submit_df`` against the truth in ``overlay_dir``.

        Returns the per-type table, the weighted F-beta4 and the weighted precision.
        """
        eval_df = []
        for run_id in list(submit_df['experiment'].unique()):
            truth = read_one_truth(run_id, overlay_dir)
            id_df = submit_df[submit_df['experiment'] == run_id]
            for p in PARTICLE:
                p = dotdict(p)
                print('\r', run_id, p.name, end='', flush=True)
                xyz_truth = truth[p.name]
                xyz_predict = id_df[id_df['particle_type'] == p.name][['x', 'y', 'z']].values
                # A prediction counts as a hit within half the particle radius.
                hit, fp, miss, metric = do_one_eval(xyz_truth, xyz_predict, p.radius * 0.5)
                eval_df.append(dotdict(
                    id=run_id, particle_type=p.name,
                    P=metric[0], T=metric[1], hit=metric[2], miss=metric[3], fp=metric[4],
                ))
        eval_df = pd.DataFrame(eval_df)
        # Sum the counts over all runs per particle type.
        gb = eval_df.groupby('particle_type').agg('sum').drop(columns=['id'])
        # 0/0 yields NaN; treat those as a score of 0.
        gb.loc[:, 'precision'] = gb['hit'] / gb['P']
        gb.loc[:, 'precision'] = gb['precision'].fillna(0)
        gb.loc[:, 'recall'] = gb['hit'] / gb['T']
        gb.loc[:, 'recall'] = gb['recall'].fillna(0)
        # F-beta with beta = 4: (1 + 16) * P * R / (16 * P + R).
        gb.loc[:, 'f-beta4'] = 17 * gb['precision'] * gb['recall'] / (16 * gb['precision'] + gb['recall'])
        gb.loc[:, 'f-beta4'] = gb['f-beta4'].fillna(0)

        gb = gb.sort_values('particle_type').reset_index(drop=False)
        # Weights are positional and rely on the sorted particle_type order above;
        # they assume exactly six particle types are present.
        gb.loc[:, 'weight'] = [1, 0, 2, 1, 2, 1]
        lb_score = (gb['f-beta4'] * gb['weight']).sum() / gb['weight'].sum()
        lb_score2 = (gb['precision'] * gb['weight']).sum() / gb['weight'].sum()
        return gb, lb_score,lb_score2 

    gb, lb_score,lb_score2  = compute_lb(location_df, f'./overlay/ExperimentRuns')
    print('lb_score:', lb_score,lb_score2 )
    return gb, lb_score

# Score the loaded model on run TS_6_4 and keep the per-particle table and the score.
gb, lb_score = process_tomography_and_compute_lb(valid_id=['TS_6_4'], model=model)
print(lb_score)

In [None]:
# Show the per-particle-type score table returned by the evaluation.
gb

Weighted F-beta4 (`lb_score`): 0.7989678346892483 · weighted precision (`lb_score2`): 0.46827098951368556

*	particle_type	P	T	hit	miss	fp	precision	recall	f-beta4	weight
* 0	apo-ferritin	62	58	46	12	16	0.741935	0.793103	0.789899	1
* 1	beta-amylase	18	9	7	2	11	0.388889	0.777778	0.734568	0
* 2	beta-galactosidase	26	12	10	2	16	0.384615	0.833333	0.779817	2
* 3	ribosome	109	74	54	20	55	0.495413	0.729730	0.709977	1
* 4	thyroglobulin	86	30	26	4	60	0.302326	0.866667	0.780919	2
* 5	virus-like-particle	15	10	10	0	5	0.666667	1.000000	0.971429	1