# Out-of-Distribution Detection for Model Refinement in Cardiac Image Segmentation

In [1]:
# !jupyter nbconvert --to script Model_Refinement.ipynb
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"Is CUDA available: {torch.cuda.is_available()}")
from utils import device
print(device)

PyTorch version: 2.4.0
Is CUDA available: True
cuda


In [6]:
!nvidia-smi -L

GPU 0: NVIDIA TITAN Xp (UUID: GPU-921f5bc6-fb1c-0b1a-5dcb-e1a61d694298)
GPU 1: NVIDIA TITAN Xp (UUID: GPU-14a6fb47-7208-937e-a68b-afae647baea4)
GPU 2: NVIDIA TITAN Xp (UUID: GPU-00871fef-e5bf-78f4-6dad-66205263da35)
GPU 3: NVIDIA TITAN Xp (UUID: GPU-607f1c50-94fe-5c77-2907-9c17eb0d1b42)
GPU 4: NVIDIA TITAN Xp (UUID: GPU-d82980ac-2a8f-8539-d663-89f819822a01)
GPU 5: NVIDIA TITAN Xp (UUID: GPU-15de566d-9114-12cd-2933-5bb551bd05b4)
GPU 6: NVIDIA TITAN Xp (UUID: GPU-e6702f9f-f275-8499-8700-e90423f98c65)


In [7]:
import random
import numpy as np
import torch

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Data preparation

In [8]:
import os
from utils import get_vendor_info
from utils import get_splits
from utils import generate_patient_info, crop_image
from utils import preprocess, preprocess_image, inSplit

The challenge cohort was composed of 360 patients with different rigth ventricle and left ventricle pathologies as well as healthy subjects. All subjects were scanned in four clinical centres in two different countries (Spain, Germany) using four different magnetic resonance scanner vendors (Siemens, General Electric and Philips ).

The training set contained 200 annotated images from four different centres. The CMR images have been segmented by experienced clinicians from the respective institutions, including contours for the left (LV) and right ventricle (RV) blood pools, as well as for the left ventricular myocardium (MYO). Labels are: 1 (LV), 2 (MYO) and 3 (RV) in both short-axis and long-axis views with a variety of difficult RV pathologies and remodelling as well as LV pathologies. This year we will focus on RV segmentation. Labels 1 and 2 will be provided but will not score in the final challenge results. 40 cases, 5 for each pathology, will be used to create a public leaderboard and will be added to the final testing set. Two pathologies (Tricuspidal Regurgitation and Congenital Arrhythmogenesis) will be not present in the training set but in the validation and testing sets to evaluate generalisation to unseen pathologies.

<table style="width:70%; margin: auto">
        <tbody><tr>
            <th>Pathology</th><th>Num. studies training</th> <th>Num. studies validation</th>
        </tr>
        <tr>
            <td style="text-align:left">Normal subjects</td><td>40</td><td>5</td>
        </tr>
        <tr>
            <td style="text-align:left">Dilated Left Ventricle</td><td>30</td><td>5</td>
        </tr>
        <tr>
            <td style="text-align:left">Hypertrophic Cardiomyopathy</td><td>30</td><td>5</td>
        </tr>
        <tr>
            <td style="text-align:left">Congenital Arrhythmogenesis</td><td>20</td><td>5</td>
        </tr>
        <tr>
            <td style="text-align:left">Tetralogy of Fallot</td><td>20</td><td>5</td>
        </tr>
        <tr>      
            <td style="text-align:left">Interatrial Comunication</td><td>20</td><td>5</td>
        </tr>
        <tr>       
            <td style="text-align:left">Dilated Right Ventricle</td><td>0</td><td>5</td>
        </tr>
        <tr>
            <td style="text-align:left">Tricuspidal Regurgitation</td><td>0</td><td>5</td>
        </tr>
</tbody></table>


In [9]:
vendor_info = get_vendor_info("./data/MnM2/dataset_information.csv")
vendor_info 

Unnamed: 0,SUBJECT_CODE,DISEASE,VENDOR,SCANNER,FIELD,PATH
0,001_SA,NOR,GE MEDICAL SYSTEMS,SIGNA EXCITE,1.5,./data/MnM2/dataset/001/001_SA_{}.nii.gz
1,001_LA,NOR,GE MEDICAL SYSTEMS,SIGNA EXCITE,1.5,./data/MnM2/dataset/001/001_LA_{}.nii.gz
2,002_SA,NOR,GE MEDICAL SYSTEMS,SIGNA EXCITE,1.5,./data/MnM2/dataset/002/002_SA_{}.nii.gz
3,002_LA,NOR,GE MEDICAL SYSTEMS,SIGNA EXCITE,1.5,./data/MnM2/dataset/002/002_LA_{}.nii.gz
4,003_SA,NOR,GE MEDICAL SYSTEMS,SIGNA EXCITE,1.5,./data/MnM2/dataset/003/003_SA_{}.nii.gz
...,...,...,...,...,...,...
795,358_LA,TRI,SIEMENS,Symphony,1.5,./data/MnM2/dataset/358/358_LA_{}.nii.gz
796,359_SA,TRI,SIEMENS,SymphonyTim,1.5,./data/MnM2/dataset/359/359_SA_{}.nii.gz
797,359_LA,TRI,SIEMENS,SymphonyTim,1.5,./data/MnM2/dataset/359/359_LA_{}.nii.gz
798,360_SA,TRI,SIEMENS,SymphonyTim,1.5,./data/MnM2/dataset/360/360_SA_{}.nii.gz


In [10]:
vendor_info[['VENDOR','SUBJECT_CODE']].groupby(['VENDOR']).count().rename(columns={'SUBJECT_CODE':'NUM_STUDIES'})

Unnamed: 0_level_0,NUM_STUDIES
VENDOR,Unnamed: 1_level_1
GE MEDICAL SYSTEMS,106
Philips Medical Systems,176
SIEMENS,438


In [11]:
if not os.path.isdir("preprocessed"):
    os.makedirs("preprocessed")
splits = get_splits(os.path.join("preprocessed", "splits.pkl"))

In [12]:
patient_info = generate_patient_info(vendor_info, os.path.join("preprocessed", "patient_info.pkl"))

In [29]:
spacings = [
    patient_info["{:03d}_{}".format(id, axis)]["spacing"] for axis in ["SA", "LA"] for id in (
        splits["train"]["lab"] + splits["train"]["ulab"] + splits["val"]
    )
]
spacing_target = np.percentile(np.vstack(spacings), 50, 0)

if not os.path.isdir("preprocessed/training/labelled"): os.makedirs("preprocessed/training/labelled")
if not os.path.isdir("preprocessed/training/unlabelled"): os.makedirs("preprocessed/training/unlabelled")
if not os.path.isdir("preprocessed/validation/"): os.makedirs("preprocessed/validation/")
if not os.path.isdir("preprocessed/soft_validation/"): os.makedirs("preprocessed/soft_validation/")
if not os.path.isdir("preprocessed/testing/"): os.makedirs("preprocessed/testing/")

preprocess(
    {k:v for k,v in patient_info.items() if inSplit(k, splits["train"]["lab"])},
    spacing_target, "preprocessed/training/labelled/"
)
preprocess(
    {k:v for k,v in patient_info.items() if inSplit(k, splits["train"]["ulab"])},
    spacing_target, "preprocessed/training/unlabelled/"
)
preprocess(
    {k:v for k,v in patient_info.items() if inSplit(k, splits["val"])},
    spacing_target, "preprocessed/validation/"
)
preprocess(
    {k:v for k,v in patient_info.items() if inSplit(k, splits["val"])},
    spacing_target, "preprocessed/soft_validation/", soft_preprocessing=True
)
preprocess(
    {k:v for k,v in patient_info.items() if inSplit(k, splits["test"])},
    spacing_target, "preprocessed/testing/", soft_preprocessing=True
)

## $\mathcal{M}$ - Supervised Training

In [28]:
import torch.nn as nn
import os
import torch

from baseline_1 import Baseline_1
from baseline_2 import Baseline_2
from utils import AttrDict
from utils import GDLoss, CELoss
from utils import device
from utils import Validator, Checkpointer
from utils import supervised_training
from utils import ACDCDataLoader, ACDCAllPatients
from utils import BATCH_SIZE, EPOCHS, CKPT
from utils import transform_augmentation_downsample, transform
from utils import plot_history

In [29]:
print(device)
Model = Baseline_2

model = nn.ModuleDict([
    [axis, Model(
        AttrDict(**{
            "lr": 0.01,
            "functions": [GDLoss, CELoss]
        })
    )] for axis in ["SA", "LA"]
]).to(device)

ckpts = None
if ckpts is not None:
    for axis, ckpt in ckpts.items():
        _, start = os.path.split(ckpt)
        start = int(start.replace(".pth", ""))
        ckpt = torch.load(ckpt)
        model[axis].load_state_dict(ckpt["M"])
        model[axis].optimizer.load_state_dict(ckpt["M_optim"])
else:
    start = 1
print(model)

cuda
ModuleDict(
  (SA): Baseline_2(
    (unet): Generic_UNet(
      (conv_blocks_localization): ModuleList(
        (0-1): 2 x Sequential(
          (0): StackedConvLayers(
            (blocks): Sequential(
              (0): ConvDropoutNormNonlin(
                (conv): Conv2d(960, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (instnorm): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
          )
          (1): StackedConvLayers(
            (blocks): Sequential(
              (0): ConvDropoutNormNonlin(
                (conv): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (instnorm): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
          )
        )
    

In [30]:
validators = {
    "SA": Validator(5),
    "LA": Validator(5)
}

for axis in ["SA", "LA"]:
    supervised_training(
        model[axis],
        range(start, EPOCHS),
        torch.utils.data.DataLoader(
            ACDCAllPatients(
                os.path.join("preprocessed/training/labelled/", axis),
                transform=transform_augmentation_downsample
            ),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=0
        ),
        ACDCDataLoader(
            os.path.join("preprocessed/validation/", axis),
            batch_size=BATCH_SIZE, transform=transform
        ),
        validators[axis],
        Checkpointer(os.path.join(CKPT, "M", axis))
    )

    plot_history(validators[axis].get_history("val"))
torch.save(model.state_dict(), './model_saved.pth')

KeyboardInterrupt: 

## $\mathcal{M}$ - Testing

In [5]:
import torch.nn as nn
import os
import torch
import pickle
import numpy as np
import pandas as pd

from baseline_1 import Baseline_1
from unet_model import Baseline_2
from utils import device
from utils import ACDCDataLoader
from utils import BATCH_SIZE, CKPT
from utils import transform
from utils import infer_predictions
from utils import get_splits
from utils import postprocess_predictions, display_results

In [3]:
Model = Baseline_2

model = nn.ModuleDict([
    [axis, Model()] for axis in ["SA", "LA"]
]).to(device)

for axis in ["SA", "LA"]:
    ckpt = os.path.join(CKPT, "M", axis, "best_000.pth")
    model[axis].load_state_dict(torch.load(ckpt)["M"])
    model[axis].to(device)
    model.eval()

    infer_predictions(
        os.path.join("inference", axis),
        ACDCDataLoader(
            f"preprocessed/testing/{axis}",
            batch_size = BATCH_SIZE,
            transform = transform,
            transform_gt = False
        ),
        model[axis]
    )

  model[axis].load_state_dict(torch.load(ckpt)["M"])


In [7]:
print(device)
splits = get_splits(os.path.join(CKPT, "splits.pkl"))

with open(os.path.join("preprocessed", "patient_info.pkl"),'rb') as f:
    patient_info = pickle.load(f)

spacings = [
    patient_info["{:03d}_{}".format(id, axis)]["spacing"] for axis in ["SA", "LA"] for id in (
        splits["train"]["lab"] + splits["train"]["ulab"] + splits["val"]
    )
]

current_spacing = np.percentile(np.vstack(spacings), 50, 0)

cuda
[4.5250006 1.1986301 1.1986301]


In [5]:
results = {}
for axis in ["SA", "LA"]:
    results[axis] = postprocess_predictions(
        os.path.join("inference", axis),
        patient_info,
        current_spacing,
        os.path.join("postprocessed", axis),
    )

with open("postprocessed/results.pkl", "wb") as f:
    pickle.dump(results,f)

display_results(results)

      RV_ED_DC   RV_ED_HD  RV_ES_DC   RV_ES_HD     RV_DC      RV_HD  LV_ED_DC  \
axis                                                                            
SA    0.881428  12.326966  0.842821  12.412815  0.862125  12.369890  0.931895   
LA    0.915473   6.854922  0.895974   7.070609  0.905724   6.962766  0.955872   

      LV_ED_HD  LV_ES_DC  LV_ES_HD     LV_DC     LV_HD  MYO_ED_DC  MYO_ED_HD  \
axis                                                                           
SA    4.330233  0.888338  6.248102  0.910116  5.289167   0.797672   7.393551   
LA    4.416826  0.935375  4.129138  0.945623  4.272982   0.843709   6.268225   

      MYO_ES_DC  MYO_ES_HD    MYO_DC    MYO_HD  
axis                                            
SA     0.825067   7.765695  0.811369  7.579623  
LA     0.870662   5.209975  0.857186  5.739100  


Here we test our model on the testing set provided in the M&Ms-2 Challenge. For this reason, no GT is available, and validation metrics cannot be directly evaluated (this is why all values in the table above are NaN). The code in the cells below display the results reported in the <a href="https://www.ub.edu/mnms-2/#:~:text=the%20competition%20in-,Codalab,-to%20submit%20your"> Codalab platform</a>.

In [5]:
for axis in ["SA", "LA"]:
    for src in os.listdir(os.path.join("postprocessed", axis)):
        id = src.split("_")[0]
        if int(id) < 161:
            continue
        nib_image = nib.load(os.path.join("postprocessed", axis, src))
        image = np.around(nib_image.get_fdata()).astype(int)
        image = np.where(image==3, 1, 0)
        dst = os.path.join("submission", id, src.split(".nii.gz")[0] + "_pred.nii.gz")
        if not os.path.isdir(os.path.split(dst)[0]):
            os.makedirs(os.path.split(dst)[0])
        nib.save(nib.Nifti1Image(image, nib_image.affine, nib_image.header), dst)

In [7]:
!zip -rq submission.zip submission

/usr/bin/sh: 1: zip: not found


In [6]:
print("\033[1mBaseline\033[0m")
pd.DataFrame.from_dict({
    "axis": ["SA", "LA", "avg"],
    "RV_DC": [0.903681832739, 0.899725671050, 0.902692792316],
    "RV_HD": [13.350667610394, 7.404528745594, 11.864132894194],
}).set_index("axis")

[1mBaseline[0m


Unnamed: 0_level_0,RV_DC,RV_HD
axis,Unnamed: 1_level_1,Unnamed: 2_level_1
SA,0.903682,13.350668
LA,0.899726,7.404529
avg,0.902693,11.864133


## $\mathcal{R}$ - Training

In [1]:
import torch.nn as nn
import torch
import os

from reconstructor import Reconstructor
from utils import AttrDict
from utils import device
from utils import plot_history
from utils import ACDCAllPatients, ACDCDataLoader
from utils import transform
from utils import BATCH_SIZE, CKPT

In [2]:
ae = nn.ModuleDict([
    [axis, Reconstructor(
        AttrDict(**{
            "latent_size": 100,
            "lr": 2e-4,
            "last_layer": [4,2,1],
            "in_channels": 4,
            "weighted_epochs": 0
        })
    )] for axis in ["SA", "LA"]
]).to(device)

ckpts = None
if ckpts is not None:
    for axis, ckpt in ckpts.items():
        _, start = os.path.split(ckpt)
        start = int(start.replace(".pth", ""))
        ckpt = torch.load(ckpt)
        ae[axis].load_state_dict(ckpt["R"])
        ae[axis].optimizer.load_state_dict(ckpt["R_optim"])
else:
    start = 0
print(ae)

ModuleDict(
  (SA): Reconstructor(
    (encoder): Sequential(
      (0): Conv2d(4, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Dropout(p=0.5, inplace=False)
      (4): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): LeakyReLU(negative_slope=0.2)
      (7): Dropout(p=0.5, inplace=False)
      (8): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (9): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): LeakyReLU(negative_slope=0.2)
      (11): Dropout(p=0.5, inplace=False)
      (12): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): Leaky

In [None]:
for axis in ["SA", "LA"]:
    plot_history(ae[axis].training_routine(
        range(start, 500),
        torch.utils.data.DataLoader(
            ACDCAllPatients(
                os.path.join("preprocessed/training/labelled/", axis),
                transform=transform
            ),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=0
        ),
        ACDCDataLoader(
            os.path.join("preprocessed/validation/", axis),
            batch_size=BATCH_SIZE, transform=transform
        ),
        os.path.join(CKPT, "R", axis)
    ))

{'MSELoss': np.float64(0.01941307244822383), 'GDLoss': np.float64(-0.013654394899245302), 'Total': np.float64(0.005758677525445819), 'LV_dc': np.float64(0.0), 'LV_hd': np.float64(inf), 'MYO_dc': np.float64(0.0), 'MYO_hd': np.float64(inf), 'RV_dc': np.float64(0.0), 'RV_hd': np.float64(inf)}
[1mEpoch [0][0m
MSELos	GDLoss	Total	LV_dc	LV_hd	MYO_dc	MYO_hd	RV_dc	RV_hd	
0.0194	-0.013	0.0058	0.0000	inf	0.0000	inf	0.0000	inf	
{'MSELoss': np.float64(0.01941307244822383), 'GDLoss': np.float64(-0.013654394899245302), 'Total': np.float64(0.005758677525445819), 'LV_dc': np.float64(0.0), 'LV_hd': np.float64(inf), 'MYO_dc': np.float64(0.0), 'MYO_hd': np.float64(inf), 'RV_dc': np.float64(0.0), 'RV_hd': np.float64(inf)}
{'MSELoss': np.float64(0.013412956171669066), 'GDLoss': np.float64(-0.008331803836471181), 'Total': np.float64(0.005081152333877981), 'LV_dc': np.float64(0.0), 'LV_hd': np.float64(inf), 'MYO_dc': np.float64(0.0), 'MYO_hd': np.float64(inf), 'RV_dc': np.float64(0.0), 'RV_hd': np.float64(

## QC-based Candidate Selection

In [None]:
import torch.nn as nn
import os
import torch
import pickle
import numpy as np

from baseline_1 import Baseline_1
from baseline_2 import Baseline_2
from reconstructor import Reconstructor
from utils import device
from utils import AttrDict
from utils import Validator
from utils import ACDCDataLoader
from utils import BATCH_SIZE, CKPT
from utils import transform
from utils import GDLoss, CELoss, GDLoss_RV, CELoss_RV
from utils import infer_predictions
from utils import get_splits
from utils import postprocess_predictions
from utils import display_results

In [None]:
Model = Baseline_2

model = nn.ModuleDict([
    [axis, Model(
        AttrDict(**{
            "lr": 0.01,

            "functions": [GDLoss, CELoss],
            "functions_RV": [GDLoss_RV, CELoss_RV]
        })
    )] for axis in ["SA", "LA"]
]).to(device)

ae = nn.ModuleDict([
    [axis, Reconstructor(
        AttrDict(**{
            "latent_size": 100,
            "lr": 2e-4,
            "last_layer": [4,2,1],
            "in_channels": 4,
            "weighted_epochs": 0
        })
    )] for axis in ["SA", "LA"]
]).to(device)

validators = {
    "SA": Validator(5),
    "LA": Validator(5)
}

for axis in ["SA", "LA"]:
    ckpt = os.path.join(CKPT, "R", axis)
    ckpt = os.path.join(ckpt, sorted([file for file in os.listdir(ckpt) if "_best" in file])[-1])
    ckpt = torch.load(ckpt)
    ae[axis].load_state_dict(ckpt["R"])
    ae.eval()
    
    ckpt = os.path.join(CKPT, "M_refinement")
    if not os.path.isdir(ckpt):
        ckpt = os.path.join(CKPT, "M")
    for file in os.listdir(os.path.join(ckpt, axis)):
        if "best_" not in file or not file.endswith(".pth"):
            continue
        model[axis].load_state_dict(torch.load(os.path.join(ckpt, axis, file))["M"])
        model.eval()
        with torch.no_grad():
            validators[axis].domain_evaluation(
                "test",
                model[axis],
                ACDCDataLoader(
                    f"preprocessed/testing/{axis}",
                    batch_size=BATCH_SIZE,
                    transform=transform,
                    transform_gt=False
                ),    
                reconstructor=ae[axis]
            )


In [None]:
for axis in ["SA", "LA"]:
    infer_predictions(
        os.path.join("inference", axis),
        ACDCDataLoader(
            f"preprocessed/testing/{axis}",
            batch_size=BATCH_SIZE,
            transform=transform,
            transform_gt=False
        ),
        validator=validators[axis]
    )

In [None]:
splits = get_splits(os.path.join(CKPT, "splits.pkl"))

with open(os.path.join("preprocessed", "patient_info.pkl"),'rb') as f:
    patient_info = pickle.load(f)

spacings = [
    patient_info["{:03d}_{}".format(id, axis)]["spacing"] for axis in ["SA", "LA"] for id in (
        splits["train"]["lab"] + splits["train"]["ulab"] + splits["val"]
    )
]

current_spacing = np.percentile(np.vstack(spacings), 50, 0)

In [None]:
results = {}
for axis in ["SA", "LA"]:
    results[axis] = postprocess_predictions(
        os.path.join("inference", axis),
        patient_info,
        current_spacing,
        os.path.join("postprocessed", axis),
    )

with open("postprocessed/results.pkl", "wb") as f:
    pickle.dump(results,f)

display_results(results)

In [None]:
for axis in ["SA", "LA"]:
    for src in os.listdir(os.path.join("postprocessed", axis)):
        id = src.split("_")[0]
        if int(id) < 161:
            continue
        nib_image = nib.load(os.path.join("postprocessed", axis, src))
        image = np.around(nib_image.get_fdata()).astype(int)
        image = np.where(image==3, 1, 0)
        dst = os.path.join("submission", id, src.split(".nii.gz")[0] + "_pred.nii.gz")
        if not os.path.isdir(os.path.split(dst)[0]):
            os.makedirs(os.path.split(dst)[0])
        nib.save(nib.Nifti1Image(image, nib_image.affine, nib_image.header), dst)

In [None]:
!zip -rq submission.zip submission

We display below the results reported in the <a href="https://www.ub.edu/mnms-2/#:~:text=the%20competition%20in-,Codalab,-to%20submit%20your"> Codalab platform</a> after submitting the .zip file generated above.

In [None]:
print("\033[1mQC-based Candidate Selection\033[0m")
pd.DataFrame.from_dict({
    "axis": ["SA", "LA", "avg"],
    "RV_DC": [0.898024197115, 0.904328292327, 0.899600220918],
    "RV_HD": [12.194214048970, 7.522092252961, 11.026183599968],
}).set_index("axis")

## Semi-Supervised Refinement

In [None]:
import torch.nn as nn
import os
import torch
import pickle

from baseline_1 import Baseline_1
from baseline_2 import Baseline_2
from reconstructor import Reconstructor
from utils import AttrDict
from utils import GDLoss, CELoss, GDLoss_RV, CELoss_RV
from utils import device
from utils import Validator, Checkpointer
from utils import semisupervised_refinement
from utils import ACDCDataLoader, ACDCAllPatients
from utils import BATCH_SIZE, EPOCHS, CKPT
from utils import transform_augmentation_downsample, transform
from utils import plot_history

In [None]:
Model = Baseline_2

model = nn.ModuleDict([
    [axis, Model(
        AttrDict(**{
            "lr": 0.01,
            "functions": [GDLoss, CELoss],
            "functions_RV": [GDLoss_RV, CELoss_RV]
        })
    )] for axis in ["SA", "LA"]
]).to(device)

ae = nn.ModuleDict([
    [axis, Reconstructor(
        AttrDict(**{
            "latent_size": 100,
            "lr": 2e-4,
            "last_layer": [4,2,1],
            "in_channels": 4,
            "weighted_epochs": 0
        })
    )] for axis in ["SA", "LA"]
]).to(device)

validators = {
    "SA": Validator(5),
    "LA": Validator(5)
}

for axis in ["SA", "LA"]:
    ckpt = os.path.join(CKPT, "R", axis)
    ckpt = os.path.join(ckpt, sorted([file for file in os.listdir(ckpt) if "_best" in file])[-1])
    ckpt = torch.load(ckpt)
    ae[axis].load_state_dict(ckpt["R"])
    ae.eval()

    ckpt = os.path.join(CKPT, "M", axis, "200.pth")
    _, start = os.path.split(ckpt)
    start = int(start.replace(".pth", ""))
    ckpt = torch.load(ckpt)
    model[axis].load_state_dict(ckpt["M"])
    model[axis].optimizer.load_state_dict(ckpt["M_optim"])

    ckpt = os.path.join(CKPT, "M", axis, "200_val.pkl")
    with open(ckpt, "rb") as f:
        validators[axis] = pickle.load(f)

print(model)

In [None]:
for axis in ["SA", "LA"]:
    semisupervised_refinement(
        model[axis],
        ae[axis],
        range(start, EPOCHS),
        torch.utils.data.DataLoader(
            ACDCAllPatients(
                os.path.join("preprocessed/training/labelled/", axis),
                transform=transform_augmentation_downsample
            ),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=0
        ),
        ACDCDataLoader(
            os.path.join("preprocessed/validation/", axis),
            batch_size=BATCH_SIZE, transform=transform
        ),
        ACDCDataLoader(
            os.path.join("preprocessed/training/unlabelled/", axis),
            batch_size = BATCH_SIZE, transform=transform
        ),
        validators[axis],
        Checkpointer(os.path.join(CKPT, "M_refinement", axis))
    )

    plot_history(validators[axis].get_history("val"))

After semi-supervised refinement, go back to QC-based Candidate Selection to validate your model. We report below the results from the <a href="https://www.ub.edu/mnms-2/#:~:text=the%20competition%20in-,Codalab,-to%20submit%20your"> Codalab platform</a>.

In [None]:
print("\033[1mSemi-Supervised Refinement\033[0m")
pd.DataFrame.from_dict({
    "axis": ["SA", "LA", "avg"],
    "RV_DC": [0.900794533587, 0.899223080762, 0.900401670381],
    "RV_HD": [12.268410479696, 7.016155479374, 10.955346729616],
}).set_index("axis")