# Imports

In [1]:
import os
import pickle
import numpy as np
from PIL import Image
from contextlib import redirect_stdout

from pycocotools.coco import COCO
from pycocotools import mask as maskUtils
from pycocotools.cocoeval import COCOeval

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import models, transforms
import torchmetrics

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Pre-processing

In [2]:
annotation_file = "./data/LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train_val.json"
annotation_test_file = "./data/LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_test.json"

image_dir = "./data/LIVECell_dataset_2021/images/livecell_train_val_images"
image_test_dir = "./data/LIVECell_dataset_2021/images/livecell_test_images"

In [3]:
class CustomImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.tif')]
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_filenames[idx])
        image = Image.open(img_name).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

In [4]:
# Define a dataset and loader without transformations
dataset = CustomImageDataset(image_dir=image_dir, transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Initialize running means and standard deviations
mean = 0.0
std = 0.0
num_samples = 0

for images in loader:
    batch_samples = images.size(0)
    images = images.view(batch_samples, images.size(1), -1)  # Flatten H and W
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)
    num_samples += batch_samples

mean /= num_samples
std /= num_samples

print("Mean:", mean)
print("Std:", std)

Mean: tensor([0.5021, 0.5021, 0.5021])
Std: tensor([0.0422, 0.0422, 0.0422])


In [5]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

In [6]:
models_path = "./models"
transform_path = os.path.join(models_path, "transform.pkl")

with open(transform_path, 'wb') as f:
    pickle.dump(transform, f)

# Data-loader

In [7]:
class LIVECellDataModule(pl.LightningDataModule):
    def __init__(self, annotation_file, annotation_test_file, image_test_dir, image_dir, batch_size=4, transform=None, val_split=0.2):
        super().__init__()
        self.annotation_file = annotation_file
        self.annotation_test_file = annotation_test_file
        self.image_dir = image_dir
        self.image_test_dir = image_test_dir
        
        self.batch_size = batch_size
        self.transform = transform
        self.val_split = val_split
            
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            with open(os.devnull, 'w') as fnull:
                with redirect_stdout(fnull):
                    coco = COCO(self.annotation_file)
            all_image_ids = coco.getImgIds()
    
            np.random.seed(123)
            np.random.shuffle(all_image_ids)
            num_val = int(len(all_image_ids) * self.val_split)
    
            train_image_ids = all_image_ids[num_val:]
            val_image_ids = all_image_ids[:num_val]
        
            self.train_dataset = LIVECellDataset(coco, self.image_dir, image_ids=train_image_ids, transform=self.transform)
            self.val_dataset = LIVECellDataset(coco, self.image_dir, image_ids=val_image_ids, transform=self.transform)
        elif stage == 'test':
            with open(os.devnull, 'w') as fnull:
                with redirect_stdout(fnull):
                    coco_test = COCO(self.annotation_test_file)
                all_image_ids = coco_test.getImgIds()
            self.test_dataset = LIVECellDataset(coco_test, self.image_test_dir, image_ids=all_image_ids, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=self.collate_fn)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=self.collate_fn)

    @staticmethod
    def collate_fn(batch):
        return tuple(zip(*batch))


class LIVECellDataset(Dataset):
    def __init__(self, coco, image_dir, image_ids=None, transform=None):
        self.coco = coco
        self.image_dir = image_dir
        self.transform = transform
        # Use provided image IDs or all if not provided
        self.image_ids = image_ids if image_ids is not None else self.coco.getImgIds()

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_info = self.coco.loadImgs(image_id)[0]
        image_path = os.path.join(self.image_dir, image_info['file_name'])
        image = Image.open(image_path).convert('RGB')
        
        annotation_ids = self.coco.getAnnIds(imgIds=image_id)
        annotations = self.coco.loadAnns(annotation_ids)
        
        mask = np.zeros((image_info['height'], image_info['width']), dtype=np.uint8)
        for ann in annotations:
            # Decode the binary mask from RLE
            rle_mask = self.coco.annToRLE(ann)
            binary_mask = maskUtils.decode(rle_mask)
            mask[binary_mask == 1] = ann['category_id']
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        # One-hot encode the mask to have a probability distribution at every pixel
        one_hot_mask = torch.zeros((2, mask.size(0), mask.size(1)), dtype=torch.float32)        
        one_hot_mask[0, :, :] = (mask == 0).float()
        one_hot_mask[1, :, :] = (mask == 1).float()
        if self.transform is not None:
            image = self.transform(image)
        return image, one_hot_mask

## Example

In [8]:
data_module = LIVECellDataModule(
    annotation_file=annotation_file, annotation_test_file=annotation_test_file, 
    image_dir=image_dir, image_test_dir=image_test_dir, 
    batch_size=4, transform=transform
)
data_module.setup(stage='fit')
dataloader = data_module.train_dataloader()

for images, targets in dataloader:
    images = torch.stack(images)
    targets = torch.stack(targets)
    HEIGHT, WIDTH = images[0].shape[1:]
    print()
    print(f"Width = {WIDTH}, Height = {HEIGHT}")
    print("Image batch shape:", images.shape)
    print("Mask batch shape:", targets.shape)
    break


Width = 704, Height = 520
Image batch shape: torch.Size([4, 3, 520, 704])
Mask batch shape: torch.Size([4, 2, 520, 704])


# Model

In [9]:
from model import SegmentationModule

## Training

In [10]:
BATCH_SIZE = 4
N_CLASSES = 2  # cell vs. background

checkpoint_path = os.path.join(models_path, "checkpoints")
checkpoint_filename = "MobileNetV2"

In [11]:
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    min_delta=0.01, 
    verbose=True,
    mode='min'
)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    mode='min',
    dirpath=checkpoint_path,
    filename=checkpoint_filename
)

In [12]:
model = SegmentationModule(batch_size=BATCH_SIZE, num_classes=N_CLASSES, height=HEIGHT, width=WIDTH)
# Save the model object
model_path = os.path.join(models_path, checkpoint_filename) + ".pkl"
with open(model_path, 'wb') as f:
    pickle.dump(model, f)

trainer = pl.Trainer(
    max_epochs=20,
    devices=1,
    accelerator="cuda",
    callbacks=[early_stop_callback, checkpoint_callback],
    check_val_every_n_epoch=1
)

trainer.fit(model, data_module)

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /home/pcss/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|█████████████████████████████████████████████████| 13.6M/13.6M [00:00<00:00, 19.1MB/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/pcss/Downloads/anadea_inst_segm/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/home/pcss/Downloads/anadea_inst_segm/.venv/lib/python3.11/site-packages/pytorch_lightnin

Sanity Checking: |                                                   | 0/? [00:00<?, ?it/s]

/home/pcss/Downloads/anadea_inst_segm/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/home/pcss/Downloads/anadea_inst_segm/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 0.323


Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Monitored metric val_loss did not improve in the last 3 records. Best score: 0.323. Signaling Trainer to stop.


# Evaluation

In [13]:
data_module.setup(stage='test')
dataloader = data_module.test_dataloader()

trainer.test(model=model, dataloaders=dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/pcss/Downloads/anadea_inst_segm/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |                                                           | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_loss_epoch        0.3248176574707031
         val_IoU            0.5862681269645691
         val_MaP                    0.0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss_epoch': 0.3248176574707031,
  'val_IoU': 0.5862681269645691,
  'val_MaP': 0.0}]