In [1]:
import torch

from MAG_Cascade import grow_forest
from MAGIC import MAGIC_framework, out_caller
from Networks.nnUNet_wSD_decomposed_splitOutput import DynEncoder_wSD, DynDecoder_wSD
import os, glob

In [2]:
# model_hyperparams = dict(
#     scale_factors=[1, 2, 4, 8, 16],
#     spatial_dims=3,
#     filters = [32, 64, 128, 256, 512, 512],
#     strides = [(1, 1, 1)] + [(2, 2, 2) for _ in range(5)],
#     kernel_size = [(3, 3, 3) for _ in range(6)],
# ) 

# src = 'SavedModels/MAGIC_Published'
# forest = grow_forest(
#     roi_size=[96, 96, 96],
#     num_cascade_layers=1,
#     modality_names=['VR', 'simCT', 'CCTA'],
#     agnostic_name='Mask',
#     num_classes_set=[1 + 9 + 8 + 2 + 4],
#     encoder_module=(DynEncoder_wSD, model_hyperparams | {'attention_block': False}),
#     decoder_module=(DynDecoder_wSD, model_hyperparams | {'split_level': 4, 'out_groups': [2, 10, 9, 3]}),
#     in_channels_key='in_channels',
#     out_channels_key='out_channels',
#     detach_branch_inputs=True,
# )
# weight_paths = sorted(glob.glob(os.path.join(src, "Modules", "*.pth")))
# for path in weight_paths:
#     name = os.path.split(path)[-1].split('_WEIGHTS')[0]
#     forest.twigs[name].load_state_dict(torch.load(path))

In [3]:
# Setting up hyperparameters for the MAGIC framework: full multi-structure model.
# NOTE(review): the following cell rebuilds `magic` (and every variable below) for the
# 2-class localizer, so the model constructed here is discarded when running top-to-bottom.
modality_names: list[str] = ['VR', 'simCT', 'CCTA']  # e.g. ['VR', 'CCTA', 'simCT']
agnostic_name: str = 'Mask'
num_classes_set: int = 1 + 9 + 8 + 2 + 4  # Total number of classes involved, i.e. WH, Chambers/great vessels, Coronary arteries/valves, nodes, AND backgrounds
split_level = 4  # How deep in the network to split the model for each output group
# -> Configured to be a 6 layer nnU-Net base, i=4 puts the split right after the bottleneck
out_groups = [2, 10, 9, 3]  # Number of classes + bkg for each group to guide the output layers of the decoders
# 1 WH, 9 Chambers + Great Vessels, 8 coronary arteries + valves, 2 nodes, 4 "groups"

# The per-group channel counts add up to the total output channels (2+10+9+3 == 24).
# NOTE(review): presumably DynDecoder_wSD splits `out_channels` into `out_groups`,
# so a mismatch would only surface later as a shape error — confirm.
assert sum(out_groups) == num_classes_set, (
    f"out_groups {out_groups} sum to {sum(out_groups)}, expected num_classes_set={num_classes_set}"
)

# Needed parameters for the backbone blocks
model_hyperparams = dict(
    scale_factors=[1, 2, 4, 8, 16],
    spatial_dims=3,
    filters=[32, 64, 128, 256, 512, 512],
    in_channels=1,
    out_channels=num_classes_set,
    strides=[(1, 1, 1)] + [(2, 2, 2) for _ in range(5)],
    kernel_size=[(3, 3, 3) for _ in range(6)],
)

# Initialize the MAGIC framework (no optimizer: this notebook only loads and saves weights)
magic = MAGIC_framework(
    roi_size=[96, 96, 96],
    modality_names=modality_names,
    agnostic_name=agnostic_name,  # use the configured name instead of a duplicated literal
    encoder_module=(DynEncoder_wSD, model_hyperparams | {'attention_block': False}),
    decoder_module=(DynDecoder_wSD, model_hyperparams | {'split_level': split_level, 'out_groups': out_groups}),
    out_caller=out_caller,
    optimizer_class=None,
)

In [5]:
# Setting up hyperparameters for the MAGIC framework: single-group whole-heart (WH) localizer.
# This cell rebuilds `magic` and all configuration variables, replacing the
# multi-structure model from the previous cell.
modality_names: list[str] = ['VR', 'simCT', 'CCTA']  # e.g. ['VR', 'CCTA', 'simCT']
agnostic_name: str = 'Mask'
num_classes_set: int = 2  # Total output channels; presumably 1 WH class + background (cf. the `2` WH group of the full model)
split_level = 4  # How deep in the network to split the model for each output group
# -> Configured to be a 6 layer nnU-Net base, i=4 puts the split right after the bottleneck
out_groups = [2]  # A single output group (classes + bkg) guiding the decoder's output layer

# The per-group channel counts add up to the total output channels.
# NOTE(review): presumably DynDecoder_wSD splits `out_channels` into `out_groups` — confirm.
assert sum(out_groups) == num_classes_set, (
    f"out_groups {out_groups} sum to {sum(out_groups)}, expected num_classes_set={num_classes_set}"
)

# Needed parameters for the backbone blocks
model_hyperparams = dict(
    scale_factors=[1, 2, 4, 8, 16],
    spatial_dims=3,
    filters=[32, 64, 128, 256, 512, 512],
    in_channels=1,
    out_channels=num_classes_set,
    strides=[(1, 1, 1)] + [(2, 2, 2) for _ in range(5)],
    kernel_size=[(3, 3, 3) for _ in range(6)],
)

# Initialize the MAGIC framework (no optimizer: this notebook only loads and saves weights)
magic = MAGIC_framework(
    roi_size=[96, 96, 96],
    modality_names=modality_names,
    agnostic_name=agnostic_name,  # use the configured name instead of a duplicated literal
    encoder_module=(DynEncoder_wSD, model_hyperparams | {'attention_block': False}),
    decoder_module=(DynDecoder_wSD, model_hyperparams | {'split_level': split_level, 'out_groups': out_groups}),
    out_caller=out_caller,
    optimizer_class=None,
)

In [6]:
# Load the pretrained localizer weights into the matching MAGIC modules.
# src = '/mnt/data/Summerfield/UseMAGIC/SavedModels/MAGIC_Published'
# NOTE(review): absolute, machine-specific path — consider moving it to a config cell.
src = "/mnt/data/Summerfield/UseMAGIC/SavedModels/Localizer"
weight_paths = sorted(glob.glob(os.path.join(src, "Modules", "*.pth")))

# Fail loudly on a wrong/missing directory: with no matches the loop below would be a
# no-op and the next cell would save a model that still has its initial weights.
assert weight_paths, f"No *.pth weight files found in {os.path.join(src, 'Modules')}"

for path in weight_paths:
    # The module key is the file-name prefix before the first '_'
    # (presumably files are named like '<Module>_..._WEIGHTS.pth').
    name = os.path.basename(path).split('_')[0]
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only machines;
    # load_state_dict then copies the tensors into each module's existing parameters.
    # NOTE(review): torch.load unpickles its input — only point it at trusted checkpoints.
    magic.magic_modules[name].load_state_dict(torch.load(path, map_location='cpu'))

In [7]:
# Persist the assembled localizer under a fixed name via the framework's own save method.
localizer_save_name = 'MAGIC_WH_Localizer'
magic.save(localizer_save_name)