In [None]:
import random
import numpy as np
import torch

# Seed every RNG source used by the notebook (Python `random`, NumPy, PyTorch
# CPU and CUDA generators) so that downstream stochastic steps are repeatable.
SEED = 0

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
for _seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_cuda(SEED)

# Trade cuDNN autotuning speed for deterministic convolution algorithms.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
import os
from dann_utils import get_vendor_info
from dann_utils import get_splits
from utils import generate_patient_info, crop_image
from utils import preprocess, preprocess_image, inSplit


# Dataset information CSV shipped under data/MnM2.
VENDOR_CSV = "./data/MnM2/dataset_information.csv"

vendor_info = get_vendor_info(VENDOR_CSV)
vendor_info  # last expression: rendered as the cell output

In [None]:
# Root directory for every DANN preprocessing artefact (created on first run).
os.makedirs("dann_preprocessed", exist_ok=True)

# Train/val/test patient split; splits.pkl in this directory is what the
# evaluation section later reloads with pickle.
splits = get_splits(
    vendor_info,
    os.path.join("dann_preprocessed", "splits.pkl"),
)

In [None]:
# Per-patient metadata keyed as "<3-digit id>_<axis>" (e.g. "001_SA"); each
# entry carries at least a "spacing" field used below.
patient_info = generate_patient_info(
    vendor_info,
    os.path.join("dann_preprocessed", "patient_info.pkl"),
)

In [None]:
# Target voxel spacing = per-dimension median (50th percentile) of the SA and
# LA spacings of the patients listed in _spacing_ids.
# NOTE(review): splits["train"] appears twice, so training patients are counted
# double in the median. This may be a typo for splits["train"] + splits["val"];
# the evaluation section recomputes the spacing with the same expression, so a
# fix must be applied to both places together.
_spacing_ids = splits["train"] + splits["train"] + splits["val"]

spacings = [
    patient_info["{:03d}_{}".format(patient_id, view)]["spacing"]
    for view in ["SA", "LA"]
    for patient_id in _spacing_ids
]
spacing_target = np.percentile(np.vstack(spacings), 50, 0)

In [None]:
# (output directory, split key, use soft preprocessing?)
# The validation split is written twice: once with default preprocessing and
# once with soft_preprocessing=True; the test split only in the soft variant.
_PREPROCESS_JOBS = [
    ("dann_preprocessed/training/", "train", False),
    ("dann_preprocessed/validation/", "val", False),
    ("dann_preprocessed/soft_validation/", "val", True),
    ("dann_preprocessed/testing/", "test", True),
]

# Create every output directory before any preprocessing starts.
for out_dir, _, _ in _PREPROCESS_JOBS:
    os.makedirs(out_dir, exist_ok=True)

for out_dir, split_key, soft in _PREPROCESS_JOBS:
    split_info = {
        key: info for key, info in patient_info.items()
        if inSplit(key, splits[split_key])
    }
    # The non-soft jobs rely on preprocess()'s own default for soft_preprocessing.
    if soft:
        preprocess(split_info, spacing_target, out_dir, soft_preprocessing=True)
    else:
        preprocess(split_info, spacing_target, out_dir)


In [None]:
import torch.nn as nn
import os
import torch

from baseline_1 import Baseline_1
from unet_model import Baseline_DANN
from utils import AttrDict
from utils import GDLoss, CELoss
from utils import device
from utils import Validator, Checkpointer
from utils import dann_training
from dann_loader import DANNDataLoader, DANNAllPatients
from utils import BATCH_SIZE, EPOCHS, CKPT
from utils import transform_augmentation_downsample, transform
from utils import plot_history


# In[ ]:


Model = Baseline_DANN


def _dann_hparams():
    """Return a fresh hyper-parameter bundle (learning rate and GDLoss/CELoss)."""
    return AttrDict(lr=0.01, functions=[GDLoss, CELoss])


# One independent network per axis, built in order "SA" then "LA". A list of
# (name, module) pairs is used so the ModuleDict keeps exactly this order.
model = nn.ModuleDict(
    [[axis, Model(_dann_hparams())] for axis in ["SA", "LA"]]
).to(device)

In [None]:
# Optional resume: map each axis ("SA"/"LA") to a checkpoint file named
# "<epoch>.pth". Leave as None (or empty) to train from scratch.
ckpts = None

# First training epoch. It is set unconditionally so that an empty `ckpts`
# dict (or a stale value from an earlier kernel run) cannot leave it undefined/wrong.
start = 1
if ckpts:
    for axis, ckpt_path in ckpts.items():
        # The file-name stem is the epoch number, e.g. ".../050.pth" -> 50.
        start = int(os.path.basename(ckpt_path).replace(".pth", ""))
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        ckpt = torch.load(ckpt_path, map_location=device)
        model[axis].load_state_dict(ckpt["M_dann"])
        model[axis].optimizer.load_state_dict(ckpt["M_dann_optim"])
    # NOTE(review): one `start` is shared by both axes in the training loop; it
    # ends up as the epoch of the last checkpoint iterated. Training resumes AT
    # that epoch (range(start, ...)) — confirm it was not already completed.

In [None]:
# One Validator per axis; its "val" history is plotted after training.
validators = {axis: Validator(5) for axis in ("SA", "LA")}

for axis in ["SA", "LA"]:
    # Training loader over this axis' preprocessed training data, using
    # transform_augmentation_downsample.
    # NOTE(review): shuffle=False is unusual for a training loader; confirm that
    # DANNAllPatients handles sample ordering/shuffling itself.
    train_loader = torch.utils.data.DataLoader(
        DANNAllPatients(
            os.path.join("dann_preprocessed/training/", axis),
            transform=transform_augmentation_downsample,
        ),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
    )
    val_loader = DANNDataLoader(
        os.path.join("dann_preprocessed/validation/", axis),
        batch_size=BATCH_SIZE,
        transform=transform,
    )
    checkpointer = Checkpointer(os.path.join(CKPT, "M_dann", axis))

    dann_training(
        model[axis],
        range(start, EPOCHS),
        train_loader,
        val_loader,
        validators[axis],
        checkpointer,
    )

    plot_history(validators[axis].get_history("val"), "M-dann")

In [1]:
import torch.nn as nn
import os
import torch
import pickle
import numpy as np
import pandas as pd

from unet_model import Baseline_DANN
from utils import device
from dann_loader import DANNDataLoader
from utils import BATCH_SIZE, CKPT
from utils import transform
from dann_utils import infer_predictions
from dann_utils import get_splits
from utils import postprocess_predictions, display_results


# In[ ]:


Model = Baseline_DANN

model = nn.ModuleDict([
    [axis, Model()] for axis in ["SA", "LA"]
]).to(device)

for axis in ["SA", "LA"]:
    ckpt = os.path.join(CKPT, "M_dann", axis, "best_000.pth")
    # Load the checkpoint once (previously it was read twice per axis).
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine;
    # load_state_dict copies into the parameters already placed on `device`.
    state = torch.load(ckpt, map_location=device)
    model[axis].load_state_dict(state["M"])
    # Inference mode for the network that is about to be used.
    model[axis].eval()

    infer_predictions(
        os.path.join("dann_inference", axis),
        DANNDataLoader(
            f"dann_preprocessed/testing/{axis}",
            batch_size = BATCH_SIZE,
            transform = transform,
            transform_gt = False
        ),
        model[axis]
    )


In [4]:
# Reload the artefacts written during preprocessing so this section can run
# in a fresh kernel.
# NOTE(review): pickle.load can execute arbitrary code; only load pickles
# produced by this pipeline.
with open(os.path.join("dann_preprocessed", "splits.pkl"), "rb") as f:
    splits = pickle.load(f)

with open(os.path.join("dann_preprocessed", "patient_info.pkl"), "rb") as f:
    patient_info = pickle.load(f)

# Recompute the median spacing with the same expression as the preprocessing
# cell (presumably so postprocess_predictions can undo that resampling — confirm).
# NOTE(review): splits["train"] is listed twice (double-weighting training
# patients). It is kept identical to the preprocessing cell on purpose; change
# both together if this is a typo.
spacings = [
    patient_info["{:03d}_{}".format(pid, axis)]["spacing"] for axis in ["SA", "LA"] for pid in (
        splits["train"] + splits["train"] + splits["val"]
    )
]
spacing_target = np.percentile(np.vstack(spacings), 50, 0)

# Same value as spacing_target: computed once instead of running the identical
# percentile twice. The copy keeps the two names independent arrays.
current_spacing = spacing_target.copy()

In [5]:
# Post-process the raw predictions of each axis and gather the per-axis results
# (rendered below by display_results as DC/HD scores per structure and phase).
results = {
    axis: postprocess_predictions(
        os.path.join("dann_inference", axis),
        patient_info,
        current_spacing,
        os.path.join("dann_postprocessed", axis),
    )
    for axis in ("SA", "LA")
}

# Persist the results so they can be reloaded without re-running inference.
with open("dann_postprocessed/results.pkl", "wb") as results_file:
    pickle.dump(results, results_file)

display_results(results)

      RV_ED_DC   RV_ED_HD  RV_ES_DC   RV_ES_HD     RV_DC      RV_HD  LV_ED_DC  \
axis                                                                            
SA    0.798823  39.658234  0.723117  40.663955  0.760970  40.161095  0.899953   
LA    0.855576  16.365214  0.792488  19.870313  0.824032  18.117763  0.865388   

       LV_ED_HD  LV_ES_DC   LV_ES_HD     LV_DC      LV_HD  MYO_ED_DC  \
axis                                                                   
SA    11.399328  0.862108  11.450422  0.881030  11.424875   0.712162   
LA          inf  0.849597        inf  0.857493        inf   0.031715   

       MYO_ED_HD  MYO_ES_DC   MYO_ES_HD    MYO_DC      MYO_HD  
axis                                                           
SA     21.247111   0.758861   22.285677  0.735512   21.766394  
LA    164.651470   0.035318  171.330777  0.033516  167.991123  
