The following notebook is based on: https://github.com/pashtari/factorizer-isles22/blob/master/get_started.ipynb

In [5]:
# !pip install -r ../requirements.txt

In [6]:
# pip install lightning

In [7]:
# pip install git+https://github.com/pashtari/factorizer.git@0.0.1

In [8]:
import torch
from torch import nn
import factorizer as ft

import os

import numpy as np
import torch
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import pytorch_lightning as pl
from monai import transforms
from monai.data import Dataset, DataLoader
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import SlidingWindowInferer
import SimpleITK as sitk
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

In [9]:
print(12*12*12)

1728


In [10]:
import einops

## Check the model

In [11]:
# swin_factorizer = ft.factorizer(
#     in_channels=4,
#     out_channels=3,
#     spatial_size=(8, 8, 8),
#     encoder_depth=(1, 1, 1),
#     encoder_width=(2, 2, 2),
#     strides=(1, 2, 2),
#     decoder_depth=(1, 1, 1, 1),
#     norm=ft.LayerNorm,
#     reshape=(ft.SWMatricize, {'head_dim': 2, 'patch_size': 2}),
#     act=nn.ReLU,
#     factorize=ft.NMF,
#     rank=1,
#     num_iters=5,
#     init="uniform",
#     solver="hals",
#     mlp_ratio=2,
#     dropout=0.1
# )

# x = torch.rand((1, 4, 8, 8, 8))

# y = swin_factorizer(x)
# print("Output shape: ", y.shape)

In [12]:
# print(swin_factorizer)

## Check image shape and visualize

In [13]:
import os
print(os.path.exists("/Data/"))
print(os.path.exists("/Data/data/isles/sub-strokecase0001/ses-0001/anat/sub-strokecase0001_ses-0001_flair_registered.nii.gz"))
print(os.path.exists("Data/data/isles/sub-strokecase0100/ses-0001/dwi/sub-strokecase0100_ses-0001_dwi.nii.gz"))

True
True
False


In [14]:
# set data path and
dataset_dir = "/Data/data/isles"

# set patient ID and images path
id_ = "sub-strokecase0001"
dwi_path = f"{dataset_dir}/{id_}/ses-0001/dwi/{id_}_ses-0001_dwi.nii.gz"
adc_path = f"{dataset_dir}/{id_}/ses-0001/dwi/{id_}_ses-0001_adc.nii.gz"

msk_path = f"{dataset_dir}/derivatives/{id_}/ses-0001/{id_}_ses-0001_msk.nii.gz"

# make data dictionary
data = {
    "image": [dwi_path, adc_path],
    "mask": msk_path,
}

load_image = transforms.LoadImaged(
    ["image", "mask"],
    ensure_channel_first=True,
    allow_missing_keys=True,
)

# load image data
data = load_image(data)
print(f"image shape: {data['image'].shape}")
print(f"mask shape: {data['mask'].shape}")


dwi_image = data["image"][0]
adc_image = data["image"][1]
msk_image = data["mask"][0]

# pick a slice with the largest lesion volume for visualization
slc = msk_image.sum((0, 1)).argmax()

fig, ax = plt.subplots(1, 3, dpi=200)
# visulize DWI image
ax[0].imshow(dwi_image[:, :, slc], cmap="gray", origin="lower")
ax[0].set_title("DWI")
ax[0].set_axis_off()

# visulize ADC image
ax[1].imshow(adc_image[:, :, slc], cmap="gray", origin="lower")
ax[1].set_title("ADC")
ax[1].set_axis_off()

# visulize mask
ax[2].imshow(dwi_image[:, :, slc], "gray", origin="lower")
masked = np.ma.masked_where(msk_image[:, :, slc] == 0, dwi_image[:, :, slc])
ax[2].imshow(masked, ListedColormap(["red"]), alpha=0.9, origin="lower")
ax[2].set_title("Ground Truth")
ax[2].set_axis_off()

RuntimeError: Unknown type: itkMatrixF44

## Setup transforms for training and visualization

In [15]:
def get_train_transform():
    train_transform = [
        ft.ReadImaged(["image", "label"], ensure_channel_first=True),
        transforms.SqueezeDimd("image", dim=1),
        transforms.CropForegroundd(["image", "label"], source_key="image"),
        transforms.NormalizeIntensityd("image", nonzero=True, channel_wise=True),
        transforms.Spacingd(
            ["image", "label"],
            pixdim=(2.0, 2.0, 2.0),
            mode=("bilinear", "bilinear"),
        ),
        transforms.RandSpatialCropd(
            ["image", "label"], roi_size=(64, 64, 64), random_size=False
        ),
        transforms.RandAffined(
            ["image", "label"],
            prob=0.15,
            spatial_size=(64, 64, 64),
            rotate_range=[30 * np.pi / 180] * 3,
            scale_range=[0.3] * 3,
            mode=("bilinear", "bilinear"),
            as_tensor_output=False,
        ),
        transforms.RandFlipd(["image", "label"], prob=0.5, spatial_axis=0),
        transforms.RandFlipd(["image", "label"], prob=0.5, spatial_axis=1),
        transforms.RandFlipd(["image", "label"], prob=0.5, spatial_axis=2),
        transforms.RandGaussianNoised("image", prob=0.15, std=0.1),
        transforms.RandGaussianSmoothd(
            "image",
            prob=0.15,
            sigma_x=(0.5, 1.5),
            sigma_y=(0.5, 1.5),
            sigma_z=(0.5, 1.5),
        ),
        transforms.RandScaleIntensityd("image", prob=0.15, factors=0.3),
        transforms.RandShiftIntensityd("image", prob=0.15, offsets=0.1),
        transforms.RandAdjustContrastd("image", prob=0.15, gamma=(0.7, 1.5)),
        transforms.AsDiscreted("label", threshold=0.5),
        transforms.ToTensord(["image", "label"]),
    ]
    train_transform = transforms.Compose(train_transform)
    return train_transform


def get_val_transform():
    val_transform = [
        ft.ReadImaged(
            ["image", "label"], ensure_channel_first=True, allow_missing_keys=True
        ),
        transforms.SqueezeDimd("image", dim=1),
        transforms.NormalizeIntensityd("image", nonzero=True, channel_wise=True),
        transforms.ToTensord(["image", "label"], allow_missing_keys=True),
    ]
    val_transform = transforms.Compose(val_transform)
    return val_transform

## Registry & Read config function

In [16]:
import yaml
from torch import nn, optim
import pytorch_lightning as pl
import monai
import factorizer as ft

def lambda_constructor(loader, node):
    lambda_expr = "lambda " + loader.construct_scalar(node)
    return eval(lambda_expr)


def get_constructor(obj):
    """Get constructor for an object."""

    def constructor(loader, node):
        if isinstance(node, yaml.nodes.ScalarNode):
            if node.value:
                out = obj(loader.construct_scalar(node))
            else:
                out = obj
        elif isinstance(node, yaml.nodes.SequenceNode):
            out = obj(*loader.construct_sequence(node, deep=True))
        elif isinstance(node, yaml.nodes.MappingNode):
            out = obj(**loader.construct_mapping(node, deep=True))

        return out

    return constructor


def add_attributes(obj, prefix=""):
    for attr_name in dir(obj):
        if not attr_name.startswith("_"):
            Loader.add_constructor(
                f"!{prefix}{attr_name}",
                get_constructor(getattr(obj, attr_name)),
            )


Loader = yaml.SafeLoader


# general
Loader.add_constructor("!eval", get_constructor(eval))
Loader.add_constructor("!lambda", lambda_constructor)


# pytorch
add_attributes(nn, "nn.")
add_attributes(optim, "optim.")


# pytorch lightning
add_attributes(pl.callbacks, "pl.")
add_attributes(pl.loggers, "pl.")


# monai
add_attributes(monai.losses, "monai.")
add_attributes(monai.networks.nets, "monai.")


# factorizer
add_attributes(ft, "ft.")


def read_config(path, loader=Loader):
    with open(path, "rb") as file:
        config = yaml.load(file, loader)

    return config

## Quick checks

In [17]:
# ft.ISLESDataModule()

In [18]:
monai.__version__

'0.9.1'

In [19]:
import numpy as np
np.__version__

'1.26.4'

In [20]:
### Lightning module

## Main

In [21]:
# from argparse import ArgumentParser, Namespace

# get config
# parser = ArgumentParser(description="""Train the model.""", add_help=False)
# parser.add_argument("--config", type=str, required=True)
# args = parser.parse_args()
path_config = "/users/eleves-a/2022/oussama.zouhry/factorizer-project/image-segmentation-factorizer/factorizer/configs/isles2022-dwi&adc/config_isles2022-dwi&adc_fold0_swin-factorizer.yaml"
config = read_config(path_config)

# data
dm = config["data"]

# init model
task_cls, task_params = config["task"]
if "checkpoint_path" in task_params:
    model = task_cls.load_from_checkpoint(strict=False, **task_params)
else:
    model = task_cls(**task_params)

# init trainer
trainer = Trainer(**config["training"])

# fit model
trainer.fit(model, dm)

Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.


NameError: name 'Trainer' is not defined