In [1]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
# from tqdm import tqdm
import medpy.metric as metric

import torch.nn as nn
import torch.nn.functional as F
import torch
import wandb
from tqdm import tqdm

import monai
from monai.data import DataLoader, Dataset
from omegaconf import OmegaConf
from monai.transforms.utils import allow_missing_keys_mode
from monai.transforms import BatchInverseTransform
from monai.networks.nets import DynUNet

In [2]:
from tl_2d3d.data.make_dataset import make_dataloaders
from tl_2d3d.utils import hd95, set_seed

In [10]:
# Device used for evaluation: test() moves every image/label batch here via `.to(device)`.
device = "cpu"

In [24]:
def test(model: torch.nn.Module, dataloader, inferer, config) -> None:
    """Evaluate `model` on `dataloader` and print the mean Dice and HD95 scores.

    For every batch, the ``'image'`` and ``'label'`` entries are moved to the
    notebook-level ``device`` and a prediction is obtained with
    ``inferer(inputs=x, network=model)`` under ``torch.no_grad()``. Labels and
    predictions are then collapsed to class-index maps with ``argmax`` over
    dim 1 before scoring. Per-batch scores are summed and divided by the
    number of batches that were iterated.

    This function does not change the model's train/eval mode; call
    ``model.eval()`` beforehand if that matters for the network.

    Args:
        model: Network passed to the inferer as ``network``.
        dataloader: Iterable yielding dict batches with ``'image'`` and
            ``'label'`` tensors.
        inferer: Callable accepting ``inputs=`` and ``network=`` keyword
            arguments and returning the raw prediction tensor.
        config: Forwarded unchanged to the project helper ``hd95``.

    Raises:
        ValueError: If the dataloader yields no batches, since an average
            cannot be computed.
    """
    dice_total = 0.0
    hd95_total = 0.0
    num_batches = 0

    for batch in tqdm(dataloader):
        x = batch['image'].to(device)
        y = batch['label'].to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            y_pred = inferer(inputs=x, network=model)

        # argmax over dim 1 turns per-class channels into a class-index map;
        # squeeze(-1) drops the last dim only when it has size 1.
        # Original note gives the result as (1, 256, 256) - TODO confirm for 3D volumes.
        y = y.argmax(dim=1).squeeze(-1)
        y_pred = y_pred.argmax(dim=1).squeeze(-1)

        # NOTE(review): medpy's dc is a binary Dice (inputs are binarised), so
        # this scores foreground vs. background - confirm that is intended if
        # the labels have more than two classes.
        dice_total += metric.dc(y_pred, y)
        hd95_total += hd95(y, y_pred, config)
        num_batches += 1

    # Fail with a clear message instead of a ZeroDivisionError in the f-string.
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot compute average metrics")

    print(f"avg dice: {dice_total / num_batches}, \navg hd95: {hd95_total / num_batches}")

In [26]:
# Evaluate the checkpoint stored under weights/baseline2d_101 with a 2D slice-wise inferer.
config = OmegaConf.load('2d-to-3d-transfer-learning/tl_2d3d/conf/config.yaml')
# SliceInferer feeds the network 2D slices taken along spatial_dim=2;
# roi_size=[-1, -1] keeps each slice at its full in-plane size.
inferer = monai.inferers.SliceInferer(roi_size=[-1, -1], spatial_dim=2, sw_batch_size=1)
# Seed before building the dataloaders so the data pipeline is reproducible.
set_seed(config.hyperparameters.seed)

_, _, test_dataloader = make_dataloaders(config, use_dataset_a=True) # contains 20 volumes

# NOTE: torch.load unpickles a full model object here - only load checkpoints
# from a trusted source. The model is mapped onto the same `device` that
# test() moves the batches to, so changing `device` cannot cause a
# model/input device mismatch.
model = torch.load("/work3/s204163/3dimaging_finalproject/weights/baseline2d_101/baseline2d_final.pt", map_location=device)
model.eval()

test(model, test_dataloader, inferer, config)

100%|█████████████████████████████████████████████████████████████████████| 20/20 [00:13<00:00,  1.50it/s]

avg dice: 0.703886128635445, 
avg hd95: 20.641459036681475





In [27]:
# Evaluate the checkpoint stored under weights/finetune2d_101 with a 2D slice-wise inferer.
config = OmegaConf.load('2d-to-3d-transfer-learning/tl_2d3d/conf/config.yaml')
# SliceInferer feeds the network 2D slices taken along spatial_dim=2;
# roi_size=[-1, -1] keeps each slice at its full in-plane size.
inferer = monai.inferers.SliceInferer(roi_size=[-1, -1], spatial_dim=2, sw_batch_size=1)
# Seed before building the dataloaders so the data pipeline is reproducible.
set_seed(config.hyperparameters.seed)

_, _, test_dataloader = make_dataloaders(config, use_dataset_a=True) # contains 20 volumes

# NOTE: torch.load unpickles a full model object here - only load checkpoints
# from a trusted source. The model is mapped onto the same `device` that
# test() moves the batches to, so changing `device` cannot cause a
# model/input device mismatch.
model = torch.load("/work3/s204163/3dimaging_finalproject/weights/finetune2d_101/finetune2d_final.pt", map_location=device)
model.eval()

test(model, test_dataloader, inferer, config)

100%|█████████████████████████████████████████████████████████████████████| 20/20 [00:13<00:00,  1.50it/s]

avg dice: 0.7755802148196972, 
avg hd95: 21.117527552336703





# 3D

In [None]:
# Evaluate the checkpoint stored under weights/baseline3d_101 on whole 3D volumes.
config = OmegaConf.load('2d-to-3d-transfer-learning/tl_2d3d/conf/config.yaml')
# Override the loaded config for the 3D setting before building the dataloaders.
config.data.image_dims = [256, 256, 32]
config.data.crop_size  = [-1, -1, -1]
config.model.num_dimensions = 3
# 256x256 in-plane window; -1 uses the full extent of the third spatial dim.
inferer = monai.inferers.SlidingWindowInferer(roi_size=(256, 256, -1), sw_batch_size=1)
# Seed before building the dataloaders, as the 2D evaluation cells do, so this
# cell does not depend on whatever RNG state earlier cells left behind.
set_seed(config.hyperparameters.seed)

_, _, test_dataloader = make_dataloaders(config, use_dataset_a=True) # contains 20 volumes

# NOTE: torch.load unpickles a full model object here - only load checkpoints
# from a trusted source. The model is mapped onto the same `device` that
# test() moves the batches to, so changing `device` cannot cause a
# model/input device mismatch.
model = torch.load("/work3/s204163/3dimaging_finalproject/weights/baseline3d_101/baseline3d_final.pt", map_location=device)
model.eval()
test(model, test_dataloader, inferer, config)

  0%|                                                                              | 0/20 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████| 20/20 [04:03<00:00, 12.20s/it]

avg dice: 0.9277031313190689, 
avg hd95: 9.565172027284092



