In [1]:
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import gc
import logging
import os
import json
import shutil
import sys
import time
from collections import OrderedDict
from datetime import datetime
import matplotlib.pyplot as plt
import monai.transforms as mt
import numpy as np
import torch
import torch.distributed as dist
import yaml
from monai.apps import get_logger
from monai.auto3dseg.utils import datafold_read
from monai.bundle import BundleWorkflow, ConfigParser
from monai.config import print_config
from monai.data import DataLoader, Dataset, decollate_batch
from monai.metrics import CumulativeAverage
from monai.utils import (
    BundleProperty,
    ImageMetaKey,
    convert_to_dst_type,
    ensure_tuple,
    look_up_option,
    optional_import,
    set_determinism,
)
from monai.inferers import SlidingWindowInfererAdapt
#from monai.networks.nets.cell_sam_wrapper
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import WeightedRandomSampler
from torch.utils.data.distributed import DistributedSampler
#from torch.utils.tensorboard import SummaryWriter

#if __package__ in (None, ""):
#    from components import LabelsToFlows, LoadTiffd, LogitsToLabels
#    from cell_sam_wrapper import CellSamWrapper
#else:
from components import LabelsToFlows, LoadTiffd, LogitsToLabels
from components import CellLoss, CellAcc

import importlib
import cell_sam_wrapper
importlib.reload(cell_sam_wrapper)
from cell_sam_wrapper import CellSamWrapper

The Cellpose dataset needs to be downloaded from the following link.
TODO: Write more text here 

The SAM weights need to be download as well, put the links etc

In [2]:
# Paths of training data, testing data and output log files
data_list_path = 'cellpose_toy_datalist.json'
data_root = os.path.normpath('/home/vnath/Downloads/cellpose_dataset/')
sam_weights_path = os.path.normpath('/home/vnath/Downloads/cellpose_dataset/sam_vit_b_01ec64.pth')

# Define the network, load SAM weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CellSamWrapper(checkpoint=sam_weights_path)
model.to(device)
print('SAM ViT-B weights loaded succesfully ...')



CellSamWrapper auto_resize_inputs True network_resize_roi [1024, 1024] checkpoint /home/vnath/Downloads/cellpose_dataset/sam_vit_b_01ec64.pth


  state_dict = torch.load(f)


SAM ViT-B weights loaded succesfully ...


In [3]:
# Create Required Data lists
# Append root path to training, validation and testing data list
with open(data_list_path, 'r') as f:
    data = json.load(f)

validation_fold = 0
training_list = []
validation_list = []
testing_list = []

# Process training data
for item in data.get("training", []):
    # Append the base path to image and label
    item["image"] = os.path.join(data_root, item["image"])
    item["label"] = os.path.join(data_root, item["label"])
    
    if item["fold"] == validation_fold:
        validation_list.append(item)
    else:
        training_list.append(item)

# Process testing data
for item in data.get("testing", []):
    # Append the base path to image and label
    item["image"] = os.path.join(data_root, item["image"])
    item["label"] = os.path.join(data_root, item["label"])
    testing_list.append(item)

print('Appended Data Root to Json file list ...')
print('Total Training Data: {}'.format(len(training_list)))
print('Total Validation Data: {}'.format(len(validation_list)))
print('Total Testing Data: {}'.format(len(testing_list)))

Appended Data Root to Json file list ...
Total Training Data: 40
Total Validation Data: 41
Total Testing Data: 68


In [4]:
# Training & Validation Transforms
roi_size = [256, 256]
train_transforms = mt.Compose([
    LoadTiffd(keys=["image", "label"]),
    mt.EnsureTyped(
        keys=["image", "label"], data_type="tensor", dtype=torch.float
    ),
    mt.ScaleIntensityd(keys="image", minv=0, maxv=1, channel_wise=True),
    mt.ScaleIntensityRangePercentilesd(
        keys="image",
        lower=1,
        upper=99,
        b_min=0.0,
        b_max=1.0,
        channel_wise=True,
        clip=True,
    ),
    mt.SpatialPadd(keys=["image", "label"], spatial_size=roi_size),
    mt.RandSpatialCropd(keys=["image", "label"], roi_size=roi_size),
    mt.RandAffined(
        keys=["image", "label"],
        prob=0.5,
        rotate_range=np.pi,
        scale_range=[-0.5, 0.5],
        mode=["bilinear", "nearest"],
        spatial_size=roi_size,
        cache_grid=True,
        padding_mode="border",
    ),
    mt.RandAxisFlipd(keys=["image", "label"], prob=0.5),
    mt.RandGaussianNoised(keys=["image"], prob=0.25, mean=0, std=0.1),
    mt.RandAdjustContrastd(keys=["image"], prob=0.25, gamma=(1, 2)),
    mt.RandGaussianSmoothd(keys=["image"], prob=0.25, sigma_x=(1, 2)),
    mt.RandHistogramShiftd(keys=["image"], prob=0.25, num_control_points=3),
    mt.RandGaussianSharpend(keys=["image"], prob=0.25),
    LabelsToFlows(keys="label", flow_key="flow")
])

val_transforms = mt.Compose([
    LoadTiffd(keys=["image", "label"], allow_missing_keys=True),
    mt.EnsureTyped(
                keys=["image", "label"],
                data_type="tensor",
                dtype=torch.float,
                allow_missing_keys=True,
            ),
    mt.ScaleIntensityRangePercentilesd(
                keys="image",
                lower=1,
                upper=99,
                b_min=0.0,
                b_max=1.0,
                channel_wise=True,
                clip=True,
            ),
    LabelsToFlows(keys="label", flow_key="flow", allow_missing_keys=True)
])

# Datasets & Dataloaders for training, validation and testing
train_dataset = Dataset(
                        data=training_list, 
                        transform=train_transforms
                    )
train_loader = DataLoader(
                        train_dataset,
                        batch_size = 1,
                        shuffle=True,
                        num_workers=2
                    )

val_dataset = Dataset(
                        data=validation_list, 
                        transform=val_transforms
                    )
val_loader = DataLoader(
                        val_dataset,
                        batch_size = 1,
                        shuffle=False,
                        num_workers=2
                    )

# Training loop with validation
loss_function = CellLoss()
acc_function = CellAcc

# Define the Sliding Window Inferer
sliding_inferrer = SlidingWindowInfererAdapt(
    roi_size=[256, 256],
    sw_batch_size=1,
    overlap=0.25,
    #mode="gaussian",
    cache_roi_weight_map=True,
    progress=False)

# TODO Just remove and hardset it to being True in the training loop
channels_last = True

# TODO This path need to be defined above
ckpt_path = os.path.join(data_root, 'sanity_model')
if os.path.exists(ckpt_path) == False:
    os.mkdir(ckpt_path)
num_epochs = 3

# TODO The validation is yet to be defined
num_epochs_per_validation = 1

# Optimizer 
optimizer = torch.optim.SGD(
                        params=model.parameters(),
                        momentum=0.9,
                        lr=0.01,
                        weight_decay=1e-5
                    )

best_ckpt_path = os.path.join(ckpt_path, "model.pt")
intermediate_ckpt_path = os.path.join(ckpt_path, "model_final.pt")

best_metric = -1
start_epoch = 0
best_metric_epoch = -1
epoch_loss_values = []
val_epoch_loss_value = []

for epoch in range(0, num_epochs):
    start_time = time.time()
    train_loss, train_acc = 0, 0

    model.train()
    memory_format = torch.channels_last if channels_last else torch.preserve_format
    run_loss = CumulativeAverage()
    avg_loss = avg_acc = 0

    for idx, batch_data in enumerate(train_loader):
        data = (
            batch_data["image"]
            .as_subclass(torch.Tensor)
            .to(memory_format=memory_format, device=device)
        )

        target = (
            batch_data["flow"]
            .as_subclass(torch.Tensor)
            .to(memory_format=memory_format, device=device)
        )

        optimizer.zero_grad(set_to_none=True)

        logits = model(data)

        # print('logits', logits.shape, logits.dtype)
        loss = loss_function(logits.float(), target)

        loss.backward()
        optimizer.step()
        
        batch_size = data.shape[0]
        run_loss.append(loss, count=batch_size)
        avg_loss = run_loss.aggregate() 

        print(f"Epoch {epoch}/{num_epochs} {idx}/{len(train_loader)} ")
        print(f"loss: {avg_loss:.4f} time {time.time() - start_time:.2f}s ")
    
    optimizer.zero_grad()

    epoch_loss_values.append(avg_loss)
    # Model Saving & Checkpointing
    if avg_loss < best_metric:
        best_metric = avg_loss
        state_dict = model.state_dict()
        torch.save({"state_dict": state_dict}, best_ckpt_path)


# Write the loss plot visualization here
%matplotlib inline
fig, axs = plt.subplots(1, 1, figsize=(8, 8))

# Plot 1: Loss
axs[0].plot(range(0, num_epochs), epoch_loss_values, marker='o')
axs[0].set_title('Training Loss')
axs[0].set_xlabel('Epoch')
axs[0].set_ylabel('Loss Value')
axs[0].grid(True)
axs[0].set_xticks(range(0, num_epochs))
#axs[0].set_yticks(sorted(set(loss_values)))

# Plot 2: Validation Loss
# TODO Update epoch_loss_values with val_loss_values variable, remove this eventually
#axs[1].plot(range(0, num_epochs), epoch_loss_values, marker='o', color='orange')
#axs[1].set_title('Validation Loss')
#axs[1].set_xlabel('Epoch')
#axs[1].set_ylabel('Val Loss Value')
#axs[1].grid(True)
#axs[1].set_xticks(epoch_values)
#axs[1].set_yticks(sorted(set(val_loss_values)))

# Adjust layout
plt.tight_layout()

# Show the plot
plt.show()

#print(run_loss)
#print(avg_loss)
# Testing

# Visualization of the dataset  



Epoch 0/3 0/40 
loss: 7.9356 time 1.53s 
Epoch 0/3 1/40 
loss: 6.5465 time 2.03s 
Epoch 0/3 2/40 
loss: 6.6404 time 2.54s 
Epoch 0/3 3/40 
loss: 5.9076 time 3.06s 
Epoch 0/3 4/40 
loss: 5.8856 time 3.57s 
Epoch 0/3 5/40 
loss: 5.8127 time 4.09s 
Epoch 0/3 6/40 
loss: 5.2565 time 4.61s 
Epoch 0/3 7/40 
loss: 5.2298 time 5.12s 
Epoch 0/3 8/40 
loss: 4.8683 time 5.64s 
Epoch 0/3 9/40 
loss: 4.5087 time 6.16s 
Epoch 0/3 10/40 
loss: 4.5209 time 6.66s 
Epoch 0/3 11/40 
loss: 4.3839 time 7.16s 
Epoch 0/3 12/40 
loss: 4.4858 time 7.66s 
Epoch 0/3 13/40 
loss: 4.4320 time 8.16s 
Epoch 0/3 14/40 
loss: 4.3583 time 8.67s 
Epoch 0/3 15/40 
loss: 4.2036 time 9.18s 
Epoch 0/3 16/40 
loss: 4.0403 time 9.70s 
Epoch 0/3 17/40 
loss: 4.1045 time 10.21s 
Epoch 0/3 18/40 
loss: 3.9796 time 10.73s 
Epoch 0/3 19/40 
loss: 3.8721 time 11.24s 
Epoch 0/3 20/40 
loss: 3.7623 time 11.75s 
Epoch 0/3 21/40 
loss: 3.7757 time 12.26s 
Epoch 0/3 22/40 
loss: 3.7683 time 12.77s 
Epoch 0/3 23/40 
loss: 3.7118 time 13.



2024-08-30 17:17:25,924 - INFO - CUDA out of memory. Tried to allocate 12.00 MiB. GPU 0 has a total capacity of 23.62 GiB of which 25.31 MiB is free. Including non-PyTorch memory, this process has 22.45 GiB memory in use. Of the allocated memory 22.03 GiB is allocated by PyTorch, and 220.34 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
2024-08-30 17:17:26,527 - INFO - CUDA out of memory. Tried to allocate 12.00 MiB. GPU 0 has a total capacity of 23.62 GiB of which 26.75 MiB is free. Including non-PyTorch memory, this process has 22.46 GiB memory in use. Of the allocated memory 22.03 GiB is allocated by PyTorch, and 228.26 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expan

OutOfMemoryError: CUDA out of memory. Tried to allocate 12.00 MiB. GPU 0 has a total capacity of 23.62 GiB of which 26.75 MiB is free. Including non-PyTorch memory, this process has 22.46 GiB memory in use. Of the allocated memory 22.04 GiB is allocated by PyTorch, and 224.94 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)