In [None]:
# import torch
# import torchvision as tv


# import matplotlib.pyplot as plt
from src.BBBC039.datamodule import BBBC039DataModule

In [None]:
import torchvision as tv  # local import: the top-cell `torchvision` import is commented out

# Image pipeline: convert to a tensor, then standardise the single channel.
# `(value)` without a trailing comma is a bare float, not a tuple; `Normalize`
# is documented to take one mean/std entry per channel, so use 1-tuples.
# NOTE(review): the statistics are presumably computed over the raw training
# images - confirm where these two numbers come from.
img_transform = tv.transforms.Compose(
    [
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(
            mean=(262.3408031194739,),
            std=(220.18462229587527,),
        ),
    ]
)

# One dataset per metadata split, all sharing the same image transform.
# NOTE(review): `BBBC039Dataset` is defined in a later cell of this notebook
# (the first cell only imports `BBBC039DataModule`); that cell has to run first.
train_dataset = BBBC039Dataset("data", subset="training", img_transform=img_transform)
val_dataset = BBBC039Dataset("data", subset="validation", img_transform=img_transform)
test_dataset = BBBC039Dataset("data", subset="test", img_transform=img_transform)

In [None]:
# Placeholders for the loss function and the optimiser; neither is set up yet.
criterion = optimiser = None

In [None]:
import argparse

# TODO

- Running a model with wandb logging
- Crossing over to Lightning
- Setting up test.py
- Setting up visualisations
- Writing up report sections


In [None]:
import torch

# U-Net from PyTorch Hub configured for this task: one input channel and three
# output channels.
# The hub entry point loads its published checkpoint with a strict
# `load_state_dict`; that checkpoint belongs to the 3-input / 1-output network,
# so requesting it together with `in_channels=1, out_channels=3` fails on
# mismatched tensor shapes. The network is therefore created untrained.
model = torch.hub.load(
    "mateuszbuda/brain-segmentation-pytorch",
    "unet",
    in_channels=1,
    out_channels=3,
    init_features=32,
    pretrained=False,
)

In [None]:
# `torch.nn.Module` has no `summary()` method (that is a Keras API); the module
# repr already lists every layer, so display it as the cell output instead.
model

In [None]:
import matplotlib.pyplot as plt  # local import: the top-cell pyplot import is commented out
import numpy as np
import torchvision as tv

# Build the tensor -> PIL converter once instead of twice per iteration.
to_pil = tv.transforms.ToPILImage()

# Preview the first five image/mask pairs side by side.
# NOTE(review): the original looped over `dataset`, a name no cell in this
# notebook defines; `train_dataset` from the dataset-construction cell is
# assumed to be what was meant - confirm.
for i in range(5):
    image, mask = train_dataset[i]
    _, (ax1, ax2) = plt.subplots(ncols=2)
    ax1.imshow(np.array(to_pil(image)), cmap="gray")
    ax2.imshow(np.array(to_pil(mask)), cmap="gray")
    plt.show()

In [None]:
# Thresholds for the mask conversion; same values as
# `BBBC039Dataset.NUCLEI_MINIMUM_SIZE` and `BBBC039Dataset.NUCLEI_BORDER_WIDTH`.
MIN_NUCLEUS_SIZE = 25
BOUNDARY_WIDTH = 2

# NOTE(review): `x` is not assigned in any cell of this notebook; it is assumed
# to be a mask image loaded as an array with a trailing channel axis - confirm.
annot = x[:, :, 0]
annot = skimage.morphology.label(annot)

# filter small objects, e.g. micronuclei
annot = skimage.morphology.remove_small_objects(annot, min_size=MIN_NUCLEUS_SIZE)

# find boundaries
boundaries = skimage.segmentation.find_boundaries(annot)

# Thicken the boundaries once per step of `range(2, BOUNDARY_WIDTH, 2)`; with
# the width of 2 used here the range is empty and no dilation is applied
# (the original hard-coded this as `range(2, 2, 2)`).
for _ in range(2, BOUNDARY_WIDTH, 2):
    boundaries = skimage.morphology.binary_dilation(boundaries)

# BINARY LABEL

# prepare buffer for binary label: one channel per class
label_binary = np.zeros((annot.shape + (3,)))

# write binary label: channel 0 = background, 1 = nuclei interior, 2 = boundary
label_binary[(annot == 0) & (boundaries == 0), 0] = 1
label_binary[(annot != 0) & (boundaries == 0), 1] = 1
label_binary[boundaries == 1, 2] = 1

In [None]:
# Shape of the three-channel label after conversion to 8-bit integers.
np.asarray(label_binary, dtype=np.uint8).shape

In [None]:
# import torchvision as tv
# import torch
# plt.imshow(tv.transforms.ToPILImage()
# (
#     tv.transforms.AutoAugment()(torch.Tensor((label_binary*).astype(np.uint8).reshape(3, 520, 696)).type(torch.uint8))
#     )
# )
# plt.show()

In [None]:
# plt.imshow(boundaries)
# plt.show()
# plt.imshow(label_binary)
# plt.show()
# np.unique(label_binary[None,None,:], axis=2)
# Shape of the label buffer built in the mask-conversion cell.
np.shape(label_binary)

In [None]:
# Show the first mask channel, then the same channel after connected-component
# labelling, as two separate figures.
for panel in (x[:, :, 0], skimage.morphology.label(x[:, :, 0])):
    plt.imshow(panel)
    plt.show()

In [None]:
# Boolean view of the first mask channel: True wherever the value is at least 1.
plt.imshow(np.greater_equal(x[:, :, 0], 1))

In [None]:
#!/usr/bin/env python3
# -*- coding:utf-8 -*-


import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import skimage

from enum import StrEnum, auto, unique
from pathlib import Path
from typing import Callable, TypeAlias, Final, final

import math
import requests
import zipfile
from pathlib import Path
from tqdm import tqdm

import numpy as np
import numpy.typing as npt
import torch
import torchvision as tv
from PIL import Image
from torch.utils.data import Dataset


_IDArray: TypeAlias = npt.NDArray[np.unicode_]
_ImageArray: TypeAlias = npt.NDArray[np.float_]


class BBBC039Dataset(Dataset):
    """
    BBBC039 dataset.

    Segmentation mask dataset for U2OS cell nuclei obtained through fluorescence
    microscopy. Contains approximately 23,000 individually manually annotated nuclei.
    TIF image and PNG mask files are 520x696 pixels.

    Sourced from: https://bbbc.broadinstitute.org/BBBC039
        Caicedo et al. 2018, available from the Broad Bioimage Benchmark Collection
        [Ljosa et al., Nature Methods, 2012]

    The dataset is made available across 3 archives consisting of image, mask and meta
    directories which are all required when loading the dataset. Each archive should be
    extracted and placed in a root directory. The minimal directory layout requirement
    is described as follows:

        ```
        <root directory>
        ├── images
        │   ├── ...
        ├── masks
        │   ├── ...
        └── metadata
            ├── test.txt
            ├── training.txt
            └── validation.txt
        ```
    """

    # Dataset source url prefix
    SOURCE_URL: Final[str] = "https://data.broadinstitute.org/bbbc/BBBC039/"
    NUCLEI_MINIMUM_SIZE: Final[int] = 25
    NUCLEI_BORDER_WIDTH: Final[int] = 2

    # Required file extensions
    @final
    @unique
    class FileExtension(StrEnum):
        Images = ".tif"
        Masks = ".png"

    # Required subdirectories
    @final
    @unique
    class Subdirectory(StrEnum):
        Images = auto()
        Masks = auto()
        Metadata = auto()

    # Required subsets
    @final
    @unique
    class Subset(StrEnum):
        Training = auto()
        Validation = auto()
        Test = auto()

    def __init__(
        self,
        root: str,
        subset: str = Subset.Training.value,
        img_transform: Callable = None,
        msk_transform: Callable = None,
        download_dataset: bool = False,
        force: bool = False,
    ):
        super().__init__()

        # Resolve dataset root directory
        self.root: Path = Path(root).resolve(strict=True)

        # Download dataset if requested
        if download_dataset:
            self._download_dataset(root, force=force)

        # Resolve required dataset directories
        imgs_dir: Path = (self.root / self.Subdirectory.Images.value).resolve(
            strict=True
        )
        msks_dir: Path = (self.root / self.Subdirectory.Masks.value).resolve(
            strict=True
        )
        metadata_dir: Path = (self.root / self.Subdirectory.Metadata.value).resolve(
            strict=True
        )

        # Extract subset ids
        subsets: list[str] = [s.value for s in self.Subset]
        subset_id_file: Path = metadata_dir / f"{subset}.txt"
        match subset:
            case None:  # Entire dataset
                self.imgs_ids = [imgs_dir / i for i in imgs_dir if i.is_file()]
                self.msks_ids = [msks_dir / i for i in msks_dir if i.is_file()]
            case subset if subset in subsets:  # Dataset subset
                self.imgs_ids, self.msks_ids = self._read_subset(subset_id_file)
                self.imgs_ids = [imgs_dir / i for i in self.imgs_ids]
                self.msks_ids = [msks_dir / i for i in self.msks_ids]
            case _:
                raise ValueError(f"dataset subset '{subset}' is invalid -> {subsets}")

        # Define additional transformations
        self.img_transform = img_transform
        self.msk_transform = msk_transform

        # Preprocess the dataset masks
        self._preprocess_masks()

    def _download_dataset(self, directory: Path, force: bool = False):
        """
        Download the BBBC039 dataset.

        Args:
            directory (Path): Directory for the location of the downloaded dataset.
            force (bool, optional): Overwrite any existing files. Defaults to False.
        """
        # Prepare each required dataset archive
        archives: list[str] = [f"{i.value}.zip" for i in self.Subdirectory]
        for archive in archives:
            # Download archive
            self._download_file(self.SOURCE_URL + archive, directory / archive, force)
            # Extract archive
            if not force:
                with zipfile.ZipFile(directory / archive, "r") as archive:
                    archive.extractall(directory)

    def _download_file(self, url: str, filepath: Path, force: bool = False):
        """
        Download utility function.

        Args:
            url (str): URL for the download file.
            filepath (Path): Filepath to write the output file.
            force (bool, optional): Overwrite any existing file. Defaults to False.
        """
        # Relative path used for concise stdout
        relpath: Path = Path(filepath).relative_to(Path().absolute())
        # Skip download if possible
        if relpath.exists() and not force:
            print("{:>20}: already downloaded (not forced)".format(str(relpath)))
            return
        # Stream response for the given chunk size
        chunk_size: int = 1024
        response: requests.models.Response = requests.get(url, stream=True)
        # Write response stream to file
        with open(relpath, "wb") as write_file:
            for chunk in tqdm(
                response.iter_content(chunk_size=chunk_size),
                "{:>20}".format(str(relpath)),
                unit="B",
                total=math.ceil(int(response.headers["Content-Length"]) / chunk_size),
                unit_scale=True,
                unit_divisor=chunk_size,
            ):
                write_file.write(chunk)

    def _read_subset(self, filepath: Path) -> tuple[_IDArray, _IDArray]:
        """
        Extract a specific BBBC039 dataset subset of image and mask ids.

        Args:
            filepath (Path): Path to metadata id split file.

        Returns:
            tuple[_IDArray, _IDArray]: Image and mask id arrays.
        """
        # File ids are mask filepaths, and image ids replace the file extension
        with open(filepath, "r") as split:
            masks_ids: list[str] = split.read().splitlines()
            image_ids: list[str] = [
                mask_id.replace(self.FileExtension.Masks, self.FileExtension.Images)
                for mask_id in masks_ids
            ]
        # Return the ids as numpy arrays for faster runtime
        return np.array(image_ids), np.array(masks_ids)

    def _preprocess_masks(self):
        """
        Preprocess masks to have the correct segmentation classes. Defined classes are
        specified for the background, nuclei and boundaries. These new images are saved
        to the 'masks_preprocessed' directory and the msks_ids are updated.

        Adapted from (CarpenterLab, 2018): https://github.com/carpenterlab/unet4nuclei
        """

        # Define the new mask directory
        msks_dir_processed: Path = self.root / "masks_preprocessed"
        msks_dir_processed.mkdir(exist_ok=True)

        # Process each file with a progress bar
        for mask_filepath in tqdm(
            self.msks_ids,
            total=len(self.msks_ids),
            desc="Preprocessing masks",
            unit="img",
        ):
            # Define the masks filepath and skip if it exists
            new_mask_filepath: Path = msks_dir_processed / mask_filepath.name
            if new_mask_filepath.exists():
                continue

            # Read the mask file and cut off extra channels and alpha
            mask: _ImageArray = np.array(Image.open(mask_filepath))
            mask = mask[:, :, 0]
            # Label all individual nuclei and remove smaller ones
            mask = skimage.morphology.label(mask)
            mask = skimage.morphology.remove_small_objects(
                mask, min_size=self.NUCLEI_MINIMUM_SIZE
            )
            # Extract nuclei boundaries
            mask_boundaries: _ImageArray = skimage.segmentation.find_boundaries(mask)
            for _ in range(2, self.NUCLEI_BORDER_WIDTH, 2):
                mask_boundaries = skimage.morphology.binary_dilation(mask_boundaries)
            # Create the mask channels (background, nuclei, boundary)
            mask_label: _ImageArray = np.zeros(mask.shape + (3,))
            mask_label[(mask == 0) & (mask_boundaries == 0), 0] = 1
            mask_label[(mask != 0) & (mask_boundaries == 0), 1] = 1
            mask_label[mask_boundaries == 1, 2] = 1

            # Save the new image
            Image.fromarray(mask_label.astype(np.uint8)).save(new_mask_filepath)

        # Update the masks lookup dir
        self.msks_ids = [msks_dir_processed / path.name for path in self.msks_ids]

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Read into memory and transform an image and mask pair.

        Args:
            index (int): Index of the image/mask pair to load and transform.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Loaded and transformed image and mask.
        """

        # Load image and mask, convert to float for performing any further transforms
        img: _ImageArray = np.array(Image.open(self.imgs_ids[index]), dtype=np.float_)
        msk: _ImageArray = np.array(Image.open(self.msks_ids[index]), dtype=np.float_)

        # Apply transformations and coerce image and mask into tensors
        img = self.img_transform(img) if self.img_transform else img
        img = img if isinstance(img, torch.Tensor) else tv.transforms.ToTensor()(img)
        msk = self.msk_transform(msk) if self.msk_transform else msk
        msk = msk if isinstance(msk, torch.Tensor) else tv.transforms.ToTensor()(msk)
        return img, msk

    def __len__(self):
        # Dunder required for torch dataloaders
        return len(self.imgs_ids)

    def extract_to_numpy(self) -> tuple[_ImageArray, _ImageArray]:
        """
        Convert the image and masks of the dataset to a numpy arrays with any applied
        transformations. This allows for easy loading of the dataset into a simple numpy
        format for any testing or other exploratory analyses.

        Returns:
            tuple[ImageArray, ImageArray]: Transformed image and mask numpy arrays.
        """
        loaded_imgs = []
        loaded_msks = []
        for img, msk in self:
            # Load and transform the image and masks to numpy arrays
            loaded_imgs.append(np.array(tv.transforms.ToPILImage()(img)))
            loaded_msks.append(np.array(tv.transforms.ToPILImage()(msk)))
        # Convert the final lists to arrays
        return np.array(loaded_imgs), np.array(loaded_msks)

In [None]:
# Build the dataset from the local "data" directory with every other argument
# left at its default; the instance repr is the cell output.
BBBC039Dataset(root="data")

In [None]:
# NOTE(review): `i` is not assigned in this cell. On a top-to-bottom run it is
# whatever an earlier cell left behind (for example the integer index of the
# `for i in range(5)` preview loop) - confirm which image array was meant to be
# scaled by 255 and displayed here.
plt.imshow(i * 255)