In [1]:
%pip install more-itertools

You should consider upgrading via the '/home/tom.osika/ultrasound_tbi/med_env/bin/python -m pip install --upgrade pip' command.[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
import tbitk.ai.deep_learning as dl
import tbitk.ai.dl_cli as dl_cli
import more_itertools
import pickle
import torch
import torchvision
import matplotlib.pyplot as plt
import gc

from pathlib import Path
from monai.transforms import (
    RandRotated,
    RandScaleIntensityd,
    ThresholdIntensityd,
    RandAffined,
    LoadImaged,
    EnsureTyped,
    AddChanneld,
    Resized,
    ScaleIntensityd,
    RandFlipd,
    Compose,
    MapTransform
)
from tbitk.ai.transforms import eval_transforms, eval_transforms
from tbitk.ai.constants import DEFAULT_BEST_MODEL_NAME

Distutils was imported before Setuptools. This usage is discouraged and may exhibit undesirable behaviors or errors. Please use Setuptools' objects directly or at least import Setuptools first.


In [3]:
# Root of the data tree, given relative to the notebook's working directory.
ALL_DATA_DIR = Path("../../data/")

# Absolute glob patterns selecting the preprocessed .mha files handed to the
# "extract" command. Each relative pattern is anchored at ALL_DATA_DIR and
# resolved to an absolute path string.
file_patterns = [
    str((ALL_DATA_DIR / relative_glob).resolve())
    for relative_glob in (
        "HRPO-E01240.1a/preprocessed/ONUS-00[12]HV/butterfly-iq/**/*.mha",
        "training_head_phantom-20220121/preprocessed/butterfly-iq/[AB]/**/*.mha",
        "training_head_phantom-20220121/preprocessed/butterfly-iq/E/e-[12].mha",
    )
]

# Truthy -> the extraction cell below runs; 0 skips it.
NEED_TO_EXTRACT = 0

In [4]:
if NEED_TO_EXTRACT:
    # Assemble the argv-style argument list for the tbitk CLI "extract"
    # sub-command and run it in-process.
    extract_cmd = [
        "extract",
        # Destination root for the extracted data (local "data/" directory).
        "--root_dir", str(Path("data/").resolve()),
        # One positional value per glob pattern.
        "--file_patterns", *file_patterns,
        "--print_found_files",
    ]

    dl_cli.main(extract_cmd)

In [5]:
class PrintTransform(MapTransform):
    """Debugging transform: prints whatever it receives and passes it through."""

    def __init__(self):
        """Take no arguments; the parent initializer is deliberately not invoked."""

    def __call__(self, data):
        """Print ``data`` to stdout and return the very same object unchanged."""
        print(data)
        return data

# Spatial size every image/label pair is resized to before entering the network.
NETWORK_INPUT_SHAPE = (256, 256)


def get_transforms(l):
    """Compose the training transform pipeline for a set of augmentations.

    Args:
        l: Container of augmentation names. Membership is tested with ``in``
            for "gain", "randflip", "randtranslate" and "randrotate"; any
            other entries are ignored.

    Returns:
        A ``Compose`` of: loading/resizing of the "x" and "y" entries, the
        requested random augmentations (always in the order listed above),
        and a final intensity rescaling of "x".
    """
    # Optional stages, kept in the fixed order in which they are applied.
    # Factories are used so a transform is only constructed when requested.
    # NOTE(review): no interpolation `mode` is passed for the "y" key in the
    # affine/rotate stages -- confirm the library default suits label masks.
    optional_stages = (
        (
            "gain",
            lambda: [
                RandScaleIntensityd(keys=["x"], prob=0.5, factors=(0, 0.75)),
                # presumably caps the rescaled intensities at 1 -- TODO confirm
                ThresholdIntensityd(keys=["x"], threshold=1, above=False, cval=1),
            ],
        ),
        (
            "randflip",
            lambda: [RandFlipd(keys=["x", "y"], prob=0.5, spatial_axis=1)],
        ),
        (
            "randtranslate",
            lambda: [
                RandAffined(keys=["x", "y"], prob=0.5, translate_range=(0, 50), padding_mode="zeros")
            ],
        ),
        (
            "randrotate",
            lambda: [
                RandRotated(keys=["x", "y"], prob=0.5, range_x=0.35, padding_mode="zeros")
            ],
        ),
    )

    # Always-on preprocessing: load, convert, add a channel dim, resize.
    pipeline = [
        LoadImaged(keys=["x", "y"], image_only=True),
        EnsureTyped(keys=["x", "y"]),
        AddChanneld(keys=["x", "y"]),
        Resized(keys=["x", "y"], spatial_size=NETWORK_INPUT_SHAPE, mode="nearest"),
    ]

    for stage_name, build_stage in optional_stages:
        if stage_name in l:
            pipeline += build_stage()

    # Always-on tail: rescale the image intensities, then convert both entries.
    pipeline += [
        EnsureTyped(keys=["x"], data_type="numpy"),
        ScaleIntensityd(keys=["x"]),
        EnsureTyped(keys=["x", "y"]),
    ]

    return Compose(pipeline)


# Augmentation switches recognised by get_transforms().
augmentation_names = ["gain", "randflip", "randtranslate", "randrotate"]
# Every subset of the switches (the empty subset included); one run per subset.
combos = [*more_itertools.powerset(augmentation_names)]

In [6]:
# ---------------------------------------------------------------------------
# Augmentation ablation: train one model per augmentation combo, evaluate the
# saved model on the test split, and record its Dice score.
# ---------------------------------------------------------------------------

LOG_TO_FILE = True
NUM_EPOCHS = 40
RESULTS_PATH = "results.p"

# Loop-invariant locations and device, computed once instead of per combo.
train_dir = Path("data/train").resolve()
val_dir = Path("data/val").resolve()
test_dir = Path("data/test").resolve()
models_dir = Path("models/").resolve()
logfile = train_dir / "log.txt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _run_name(combo):
    """Return the identifier of a combo: its names concatenated, or "no_aug" if empty."""
    return "".join(combo) or "no_aug"


def _append_to_log(line):
    """Append ``line`` plus a newline to the status log file."""
    with open(str(logfile), "a+") as f:
        f.write(line + "\n")


# Maps "<run name>.pt" -> test Dice score returned by dl.test_model.
model_name_to_dice = {}

for combo in combos:
    transforms = get_transforms(combo)

    model = dl.get_model()

    # Validation and test both use the `eval_transforms` imported at the top of
    # the notebook, so the two evaluation paths cannot drift apart.
    train_loader = dl.get_data_loader([str(train_dir)], transforms)
    val_loader = dl.get_data_loader([str(val_dir)], eval_transforms)

    if LOG_TO_FILE:
        _append_to_log(f"###### {combo} ######")
    else:
        print(combo)

    # TODO: optionally visualise one training batch here as a sanity check, e.g.
    #   batch = next(iter(train_loader))
    #   plt.imshow(torchvision.utils.make_grid(batch["x"]).permute(1, 2, 0)); plt.show()
    #   plt.imshow(torchvision.utils.make_grid(batch["y"]).permute(1, 2, 0)); plt.show()

    # A single name drives both the TensorBoard sub-directory and the model file.
    run_name = _run_name(combo)
    combo_tb_experiment_subdir = Path("runs") / run_name
    model_name = run_name + ".pt"

    dl.train_model(
        model,
        train_loader,
        val_loader,
        device,
        model_name,
        models_dir,
        monitor_with_tb=True,
        tb_logdir=combo_tb_experiment_subdir,
        status_logfile=logfile,
        num_epochs=NUM_EPOCHS,
    )

    # Evaluate the model that training saved under models_dir / model_name.
    best_model = dl.load_model(models_dir / model_name)
    test_loader = dl.get_data_loader([str(test_dir)], eval_transforms)
    test_dice = dl.test_model(best_model, test_loader, device)
    model_name_to_dice[model_name] = test_dice

    s = f"###### {model_name}: {test_dice} ######"
    if LOG_TO_FILE:
        _append_to_log(s)
    print(s)

    # Persist after every run so a crash part-way through the sweep does not
    # discard the scores already obtained; the final file content is unchanged.
    with open(RESULTS_PATH, "wb") as f:
        pickle.dump(model_name_to_dice, f)

    # Drop this run's references *before* collecting, otherwise the model and
    # loaders stay alive until the next iteration rebinds the names.
    del model, best_model, train_loader, val_loader, test_loader, transforms
    gc.collect()
    if device.type == "cuda":
        # Hand cached GPU blocks back so the next run starts from a clean slate.
        torch.cuda.empty_cache()


# Final write; also guarantees the file exists when `combos` is empty.
with open(RESULTS_PATH, "wb") as f:
    pickle.dump(model_name_to_dice, f)

0.9586665034294128
###### no_aug.pt: 0.9586665034294128 + ######
0.956766664981842
###### gain.pt: 0.956766664981842 + ######
0.9400314092636108
###### randflip.pt: 0.9400314092636108 + ######
0.9489437341690063
###### randtranslate.pt: 0.9489437341690063 + ######
0.9391183257102966
###### randrotate.pt: 0.9391183257102966 + ######
0.9442896246910095
###### gainrandflip.pt: 0.9442896246910095 + ######
0.9435879588127136
###### gainrandtranslate.pt: 0.9435879588127136 + ######
0.9424556493759155
###### gainrandrotate.pt: 0.9424556493759155 + ######
0.9363917708396912
###### randfliprandtranslate.pt: 0.9363917708396912 + ######
0.931233823299408
###### randfliprandrotate.pt: 0.931233823299408 + ######
0.9301144480705261
###### randtranslaterandrotate.pt: 0.9301144480705261 + ######
0.932685375213623
###### gainrandfliprandtranslate.pt: 0.932685375213623 + ######
0.9157265424728394
###### gainrandfliprandrotate.pt: 0.9157265424728394 + ######
0.9287465214729309
###### gainrandtranslateran