### Setup

In [1]:
from pathlib import Path
import sys
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
import torch
import factorizer as ft
from monai import transforms
from tqdm import tqdm
from monai.metrics import DiceMetric
from monai.transforms import AsDiscrete

# Prefer the first GPU; fall back to the CPU so the notebook still runs without CUDA.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda', index=0)

### YAML tag registry (custom constructors for the experiment configs)

In [2]:
import yaml
from torch import nn, optim
import pytorch_lightning as pl
import monai
import factorizer as ft


def lambda_constructor(loader, node):
    """YAML constructor for the ``!lambda`` tag: turn scalar text into a lambda.

    ``!lambda "x: x * 2"`` produces the function ``lambda x: x * 2``.

    Security: the scalar text is handed to ``eval``, so any config using this tag
    can execute arbitrary Python. Only load configs from trusted sources.
    """
    return eval("lambda " + loader.construct_scalar(node))


def get_constructor(obj):
    """Build a PyYAML constructor that calls (or returns) ``obj``.

    The returned ``constructor(loader, node)`` maps the YAML node kind onto a
    call style:

    * scalar with text   -> ``obj(text)``      (e.g. ``!eval 2 * 3`` -> ``eval("2 * 3")``)
    * empty scalar       -> ``obj`` itself, uncalled (a bare tag)
    * sequence           -> ``obj(*items)``
    * mapping            -> ``obj(**items)``

    Args:
        obj: Any callable (class or function) to expose as a YAML tag.

    Returns:
        A function with the ``(loader, node)`` signature expected by
        ``Loader.add_constructor``.

    Raises (from the returned constructor):
        yaml.constructor.ConstructorError: for any other node type. Previously
        this fell through to ``return out`` with ``out`` unbound and surfaced
        as an opaque ``UnboundLocalError``.
    """

    def constructor(loader, node):
        if isinstance(node, yaml.nodes.ScalarNode):
            # An empty scalar means "the object itself", e.g. a class passed uninstantiated.
            if node.value:
                return obj(loader.construct_scalar(node))
            return obj
        if isinstance(node, yaml.nodes.SequenceNode):
            # deep=True so nested tagged nodes are fully built before the call.
            return obj(*loader.construct_sequence(node, deep=True))
        if isinstance(node, yaml.nodes.MappingNode):
            return obj(**loader.construct_mapping(node, deep=True))
        raise yaml.constructor.ConstructorError(
            None,
            None,
            f"unsupported node type {type(node).__name__} for {obj!r}",
            node.start_mark,
        )

    return constructor


def add_attributes(obj, prefix="", loader=None):
    """Register every public attribute of ``obj`` as a YAML tag.

    Each non-underscore attribute ``name`` becomes the tag ``!{prefix}{name}``,
    built with ``get_constructor``; e.g. ``add_attributes(nn, "nn.")`` exposes
    ``!nn.Conv3d``.

    Args:
        obj: Module or namespace whose public attributes are exposed.
        prefix: Text inserted between ``!`` and the attribute name.
        loader: Loader class to register the tags on. Defaults to the
            notebook-level ``Loader``, looked up at call time, so it may be
            defined after this function.
    """
    target = Loader if loader is None else loader
    for attr_name in dir(obj):
        if attr_name.startswith("_"):
            continue
        target.add_constructor(
            f"!{prefix}{attr_name}",
            get_constructor(getattr(obj, attr_name)),
        )


# Custom tags are registered on a private SafeLoader subclass, not on
# yaml.SafeLoader itself. PyYAML's add_constructor copies the constructor table
# onto the class it is called on. A subclass therefore keeps `!eval`/`!lambda`
# (which execute arbitrary Python) out of every other yaml.safe_load call in
# this process.
class Loader(yaml.SafeLoader):
    """SafeLoader extended with the experiment-config tags registered below."""


# general: `!eval <expr>` evaluates a Python expression, `!lambda <args>: <body>` builds a lambda
Loader.add_constructor("!eval", get_constructor(eval))
Loader.add_constructor("!lambda", lambda_constructor)


# pytorch (e.g. !nn.Conv3d, !optim.AdamW)
add_attributes(nn, "nn.")
add_attributes(optim, "optim.")


# pytorch lightning: callbacks and loggers share the "pl." prefix, so on a
# name clash the later registration wins
add_attributes(pl.callbacks, "pl.")
add_attributes(pl.loggers, "pl.")


# monai: losses and networks share the "monai." prefix (later call wins on a clash)
add_attributes(monai.losses, "monai.")
add_attributes(monai.networks.nets, "monai.")


# factorizer
add_attributes(ft, "ft.")


def read_config(path, loader=Loader):
    """Parse a YAML experiment config, building the objects its tags describe.

    Args:
        path: Location of the YAML file (str or path-like).
        loader: Loader class whose registered tags (``!nn.*``, ``!ft.*``, ...)
            are honoured while parsing.

    Returns:
        The constructed config (indexed like a mapping by the cells below).

    Security: the default loader understands ``!eval``/``!lambda``; only read
    trusted files.
    """
    with open(path, "rb") as stream:
        return yaml.load(stream, loader)

### Data module and inferer

In [3]:
# Build the experiment from its YAML config, then prepare the data module for the
# "fit" stage (Lightning convention; presumably train + val datasets).
config = read_config(
    "../configs/isles2022-dwi&adc/config_isles2022-dwi&adc_fold0_swin-factorizer.yaml"
)
dm = config["data"]
dm.setup("fit")

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33moz2102[0m ([33mimage-segmentation-factorizer[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Loading dataset: 100%|██| 4/4 [00:00<00:00, 24.39it/s]
Loading dataset: 100%|██| 4/4 [00:00<00:00, 37.25it/s]


In [4]:
# Inferer passed to the checkpoints below. The values are presumably meant to
# mirror the training config. TODO confirm.
inferer = ft.ISLESInferer(
    spacing=[2.0, 2.0, 2.0],  # target voxel spacing (units presumably mm)
    spatial_size=[64, 64, 64],  # presumably the inference patch size
    overlap=0.5,
    post="class",
    mode="constant",
)

In [5]:
# Shared Dice accumulator (background channel included) plus binarisation at 0.5
# applied to both predictions and labels before scoring.
dice_metric = DiceMetric(reduction="mean", include_background=True)
post_pred, post_label = AsDiscrete(threshold=0.5), AsDiscrete(threshold=0.5)

### Inference given a checkpoint path

In [6]:
def dice_score_chkpt(checkpoint_path):
    """Mean validation Dice of a checkpoint, evaluated exactly as saved.

    Args:
        checkpoint_path: Path to a Lightning checkpoint of ``ft.SemanticSegmentation``.

    Returns:
        Dice aggregated per channel over the whole validation set
        (``reduction="mean_batch"``) and then averaged into one scalar.

    Uses the notebook-level ``inferer``, ``dm``, ``dice_metric``, ``post_pred``,
    ``post_label`` and ``device``.
    """
    model = ft.SemanticSegmentation.load_from_checkpoint(checkpoint_path, inferer=inferer).to(device)
    model.eval()

    # dice_metric is shared across calls: drop anything left behind by an earlier
    # interrupted run, and always clear it on exit (even on error) so no stale
    # batches leak into the next score.
    dice_metric.reset()
    try:
        with torch.no_grad():
            for batch in tqdm(dm.val_dataloader()):
                batch["input"] = batch["input"].to(device)
                batch["target"] = batch["target"].to(device)
                pred = model.inferer.get_postprocessed(batch, model)
                dice_metric(
                    y_pred=post_pred(pred["input"]),
                    y=post_label(batch["target"]),
                )
        per_channel_dice = dice_metric.aggregate(reduction="mean_batch").cpu().numpy()
    finally:
        dice_metric.reset()

    return per_channel_dice.mean()

In [7]:
# Environment report. The per-device queries are guarded so this cell also runs on
# CPU-only machines, matching the CPU fallback used when choosing `device`.
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    print(torch.cuda.get_device_capability(0))

1.12.1+cu116
11.6
True
NVIDIA RTX A4000
(8, 6)


In [8]:
# Despite its name, this is the checkpoint *file*, not a directory.
checkpoint_root = Path("/Data/logs/isles2022-dwi&adc/fold0/image-segmentation-factorizer/51ij4wr5/checkpoints/epoch=999-step=100000.ckpt")

# Full validation pass for the checkpoint as trained. It is slow, so it is off by
# default. The author recorded that an earlier run matched the Dice logged in
# wandb; set the flag to True to re-check.
RUN_BASELINE_EVAL = False
if RUN_BASELINE_EVAL:
    print("Dataset Dice (mean):", dice_score_chkpt(checkpoint_root))

Nice, coherent with what we have in wandb.


### Evaluate the checkpoint with a different factorization rank at inference time

In [9]:
import os
# List the checkpoints stored next to `checkpoint_root`. The directory is derived
# from that path rather than re-typed, so the two cannot drift apart; sorted for a
# deterministic display.
sorted(os.listdir(checkpoint_root.parent))

['epoch=999-step=100000.ckpt']

In [10]:
# Interactive copy of the checkpoint used by the inspection cells further down
# (the evaluation helpers load their own copies).
model = ft.SemanticSegmentation.load_from_checkpoint(checkpoint_root, inferer=inferer).to(device)
model.eval();  # trailing `;` suppresses the module repr instead of printing a blank line

Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.





In [11]:
def dice_score_chkpt(checkpoint_path, rank=1):
    """Mean validation Dice of a checkpoint after overriding its factorization rank.

    This redefinition shadows any earlier ``dice_score_chkpt`` in the kernel.

    Args:
        checkpoint_path: Path to a Lightning checkpoint of ``ft.SemanticSegmentation``.
        rank: Rank passed to ``update_rank`` on the ``factorize`` module of every
            ``ft.FactorizerSubblock`` inside ``model.net``. Pass ``None`` to keep
            the ranks stored in the checkpoint.

    Returns:
        Dice aggregated per channel over the whole validation set
        (``reduction="mean_batch"``) and then averaged into one scalar.

    Uses the notebook-level ``inferer``, ``dm``, ``dice_metric``, ``post_pred``,
    ``post_label`` and ``device``.
    """
    model = ft.SemanticSegmentation.load_from_checkpoint(checkpoint_path, inferer=inferer).to(device)
    if rank is not None:
        for module in model.net.modules():
            if isinstance(module, ft.FactorizerSubblock):
                module.factorize.update_rank(rank=rank)
    model.eval()

    # dice_metric is shared across calls (this function is run in a loop over
    # ranks): clear leftovers from an interrupted run and always reset on exit.
    dice_metric.reset()
    try:
        with torch.no_grad():
            for batch in tqdm(dm.val_dataloader()):
                batch["input"] = batch["input"].to(device)
                batch["target"] = batch["target"].to(device)
                pred = model.inferer.get_postprocessed(batch, model)
                dice_metric(
                    y_pred=post_pred(pred["input"]),
                    y=post_label(batch["target"]),
                )
        per_channel_dice = dice_metric.aggregate(reduction="mean_batch").cpu().numpy()
    finally:
        dice_metric.reset()

    return per_channel_dice.mean()

In [None]:
# Sweep the factorization rank and report the validation Dice for each setting.
for rank in (2, 4, 8, 16, 32):
    mean_dice = dice_score_chkpt(checkpoint_root, rank=rank)
    print("Dataset Dice (mean) for rank =", rank, "is:", mean_dice)

  8%|█▍                | 4/50 [00:21<04:06,  5.36s/it]

In [None]:
# Intentionally empty: a stray debug `print(10)` was removed from this scratch cell.

### Scratch: inspecting the model and its factorizer ranks

In [None]:
# Inspect the network: its public attributes, the names of all its submodules,
# and the encoder.
print("The attributes of model.net are: ", dir(model.net))
# named_modules() returns a generator; materialise the names so they actually print
# (printing the generator itself only shows "<generator object ...>").
print("The named modules of model.net are: ", [name for name, _ in model.net.named_modules()])

print("The encoder at model net is:", model.net.encoder)

In [None]:
# Drill down one level: the first block of the encoder.
print(
    "The first encoderblock of the encoder is:",
    model.net.encoder.blocks[0],
)

In [None]:
# Two levels down: the object printed is the `nmf` attribute of the first inner
# block of the first encoder block.
print(
    "The first block of the first encoderblock of the encoder is:",
    model.net.encoder.blocks[0].blocks[0].blocks.nmf,
)

In [None]:
# Extracted this info from config
#   factorize: !ft.NMF
#   rank: 1
#   num_iters: 5
#   num_grad_steps: null
#   init: uniform
#   solver: hals
#   dropout: 0.1
# Scratch check: set every factorizer in the interactive `model` to rank 10 (this
# mutates `model` in place) and collect the resulting ranks.
ranks = []
for module in model.net.modules():
    if isinstance(module, ft.FactorizerSubblock):
        module.factorize.update_rank(rank=10)
        ranks.append(module.factorize.rank)

print("Factorizer ranks:", ranks)