# Overview
**BRAIN SEGMENTATION INFERENCE**
This Jupyter notebook runs inference on brain MRI scans using a pre-trained segmentation model. It downloads the model from the cloud (this requires AWS access keys to be set up properly) and then runs inference on a single scan or on a list of scans. The output is a NIfTI file with the segmentation labels for each scan, which you can then open in a viewer such as 3D Slicer.

# Setup
First, set up a Python environment with the necessary packages.

For instance, in VS Code or Cursor, press Ctrl+Shift+P, type "Python: Create Environment", and follow the prompts.

Or, on the command line, run:
```bash
python -m venv brain_segmentation_inference_env
```
--or--
```bash
python3 -m venv brain_segmentation_inference_env
```
Then activate it, either with
```bash
brain_segmentation_inference_env\Scripts\activate
```
on Windows, or
```bash
source brain_segmentation_inference_env/bin/activate
```
on macOS and Linux.

Then install the necessary packages from the folder that contains `requirements.txt` (one level above this notebook):
```bash
pip install -r requirements.txt
```
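
To confirm that the notebook kernel is actually running inside the environment you just created, one quick check is to compare `sys.prefix` with `sys.base_prefix`, which differ inside a `venv`. A minimal sketch (the helper name `in_virtualenv` is hypothetical, not part of this project):

```python
import sys


def in_virtualenv() -> bool:
    """Return True when the running interpreter belongs to a venv/virtualenv."""
    # Inside a venv, sys.prefix points at the environment folder while
    # sys.base_prefix still points at the base Python installation.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("Running inside a virtual environment:", in_virtualenv())
    print("Interpreter prefix:", sys.prefix)
```

If this prints `False` from inside the notebook, the kernel is most likely not the `brain_segmentation_inference_env` interpreter.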

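Once that is done, the cells below should run without error and import all the necessary packages. The first cell installs the same requirements from inside the notebook, in case you skipped the step above.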


In [1]:
%pip install -r ../requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [2]:
#Check GPU availability
import torch
print(torch.__version__)
print('cuda devices (count) ....',torch.cuda.device_count() )

2.6.0+cu124
cuda devices (count) .... 1


In [3]:
!nvidia-smi

Tue Sep 23 03:10:49 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.51.03              Driver Version: 576.28         CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4070 ...    On  |   00000000:01:00.0  On |                  N/A |
| N/A   51C    P8              2W /  115W |     181MiB /   8188MiB |     17%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## Logging

In [4]:
import logging

# Create a named logger
logger = logging.getLogger('brain_segmentation')
logger.setLevel(logging.DEBUG)

# Create a console handler only once, so re-running this cell does not duplicate log lines
if not logger.handlers:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)

    # Create a formatter and add it to the console handler
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)

    # Add the console handler to the logger
    logger.addHandler(console_handler)

# Silence every other logger that exists at this point
# (loggers created by later imports are not affected by this loop)
for name in logging.root.manager.loggerDict:
    if name != 'brain_segmentation':
        logging.getLogger(name).setLevel(logging.CRITICAL)

logger.info('🚀 Setting up logging...')
logger.debug(f'🔧 Current logging level: {logger.getEffectiveLevel()}')
logger.info('✅ Logging setup complete.')

2025-09-23 03:10:49,886 - brain_segmentation - INFO - 🚀 Setting up logging...
2025-09-23 03:10:49,887 - brain_segmentation - DEBUG - 🔧 Current logging level: 10
2025-09-23 03:10:49,888 - brain_segmentation - INFO - ✅ Logging setup complete.
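
The loop above only touches loggers that already exist when the cell runs. Packages imported later (for example MONAI or botocore in the next cell) create new loggers, which inherit the root logger's level unless the library sets its own. A minimal sketch of an alternative, assuming you want every other logger muted regardless of import order, is to raise the root level and stop the project logger from propagating to it (the function name `setup_logging` is hypothetical):

```python
import logging


def setup_logging(name: str = "brain_segmentation") -> logging.Logger:
    """Configure a project logger at DEBUG and mute other loggers via the root level."""
    # Loggers left at NOTSET (the default) take their effective level from
    # the root logger, so this also covers loggers created by later imports.
    logging.getLogger().setLevel(logging.CRITICAL)

    project_logger = logging.getLogger(name)
    project_logger.setLevel(logging.DEBUG)
    project_logger.propagate = False  # don't hand records on to root handlers

    if not project_logger.handlers:  # avoid duplicate handlers when the cell is re-run
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        project_logger.addHandler(handler)
    return project_logger


if __name__ == "__main__":
    log = setup_logging()
    late = logging.getLogger("some_library.submodule")  # hypothetical, created after setup
    log.info("project logger level: %s", logging.getLevelName(log.getEffectiveLevel()))
    log.info("late logger level: %s", logging.getLevelName(late.getEffectiveLevel()))
```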


## Imports

In [5]:
logger.info('📦 Importing outside packages...')
import os, sys
import torch
import glob
import json, yaml
import tempfile
import pandas as pd
import numpy as np
import boto3
import botocore
import pprint
from pathlib import Path
from tqdm import tqdm
from smart_open import open  # replaces the built-in open(); it also handles local paths
from monai.transforms import AsDiscrete, Activations
from monai.utils.enums import MetricReduction
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
import nibabel as nib
logger.info('✅ Outside packages imported.')

logger.info('📦 Importing local packages...')
from core_common import get_loader_val, datafold_read
import model
logger.info('✅ Local packages imported.')

logger.info('📦 Setting up device...')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
logger.info('✅ Device set up.')
logger.debug(f'🖥️ {device = }')

# make output directory
output_dir = "data/output_my"
os.makedirs(output_dir, exist_ok=True)
logger.info(f"📂 Output directory set to: {output_dir}")

2025-09-23 03:10:49,894 - brain_segmentation - INFO - 📦 Importing outside packages...
  from .autonotebook import tqdm as notebook_tqdm
2025-09-23 03:10:51,793 - brain_segmentation - INFO - ✅ Outside packages imported.
2025-09-23 03:10:51,793 - brain_segmentation - INFO - 📦 Importing local packages...
2025-09-23 03:10:51,802 - brain_segmentation - INFO - ✅ Local packages imported.
2025-09-23 03:10:51,803 - brain_segmentation - INFO - 📦 Setting up device...
2025-09-23 03:10:51,891 - brain_segmentation - INFO - ✅ Device set up.
2025-09-23 03:10:51,891 - brain_segmentation - DEBUG - 🖥️ device = device(type='cuda', index=0)
2025-09-23 03:10:51,894 - brain_segmentation - INFO - 📂 Output directory set to: data/output_my


# Model
This part loads the model from the cloud.

## Download
This part downloads the model from the cloud.

You'll need to have an AWS profile named 'theta-model-downloader'. You can create this profile by running this command in your terminal:
```bash
aws configure --profile theta-model-downloader
```
and entering your AWS credentials.

If you do not have the AWS command line tools installed, you can install them by following the instructions [here](https://docs.aws.amazon.com/cli/latest/userguide/install-cliv2.html).

In [6]:
!aws --version

aws-cli/2.30.1 Python/3.13.7 Linux/5.15.167.4-microsoft-standard-WSL2 exe/x86_64.ubuntu.24
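
The config lists the checkpoint location under `model` (`bucket`, `key`, `region`), but the cell above only checks the CLI version. A minimal sketch of fetching that object with boto3 through the `theta-model-downloader` profile follows. `boto3.Session(profile_name=..., region_name=...)` and the S3 client's `download_file(Bucket, Key, Filename)` are standard boto3 calls; the helper names and the local `model_cache` folder are hypothetical, and the client is passed in so the download logic can be exercised without AWS access.

```python
from pathlib import Path


def make_s3_client(profile: str = "theta-model-downloader", region: str = "us-east-1"):
    """Create an S3 client from a named AWS CLI profile (needs boto3 and credentials)."""
    import boto3  # imported lazily so the rest of this sketch works without boto3

    session = boto3.Session(profile_name=profile, region_name=region)
    return session.client("s3")


def download_checkpoint(client, bucket: str, key: str, dest_dir: str = "model_cache") -> Path:
    """Download s3://bucket/key into dest_dir, reusing an existing local copy."""
    dest = Path(dest_dir) / Path(key).name
    if dest.exists():
        return dest  # already downloaded earlier
    dest.parent.mkdir(parents=True, exist_ok=True)
    client.download_file(bucket, key, str(dest))
    return dest


# Usage with the values from inference_cfg['model'] (requires valid AWS credentials):
# client = make_s3_client(region=inference_cfg['model']['region'])
# ckpt_path = download_checkpoint(client, inference_cfg['model']['bucket'],
#                                 inference_cfg['model']['key'])
```

Caching by file name keeps repeated runs from downloading the same checkpoint again; delete the local file to force a fresh copy.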


## Writing the input data dict

In [7]:
# Choose whether to run inference on the time1 or time2 scans
time_num = input("Run inference on 'time1' or 'time2'? ").strip()
while time_num not in ('time1', 'time2'):
    time_num = input("Please type exactly 'time1' or 'time2': ").strip()
print("User chose to run inference on " + time_num)

# Folder with one sub-folder per case; keep it in sync with input.data_dir in the config file
main_path = "/mnt/d/A1_RainSun_20240916/1-UWMadison/IDiA-Lab/Tasks/tumor_seg/tumor_seg/data/Preprocessed"
cases_list = sorted(glob.glob(os.path.join(main_path, "*")))

new_ids = []

for case_path in cases_list:
    case_id = Path(case_path).name  # folder name, e.g. "100001"

    # The four MRI modalities every case needs for the chosen time point
    files = [
        os.path.join(case_path, f"{case_id}_{time_num}_flair.nii.gz"),
        os.path.join(case_path, f"{case_id}_{time_num}_t1ce.nii.gz"),
        os.path.join(case_path, f"{case_id}_{time_num}_t1.nii.gz"),
        os.path.join(case_path, f"{case_id}_{time_num}_t2.nii.gz"),
    ]

    all_exist = all(os.path.exists(f) for f in files)
    if all_exist:
        new_ids.append(case_id)
    else:
        print(f"❌ Missing files for {case_id}")
        for f in files:
            if not os.path.exists(f):
                print(f"   - {f}")

print(f"\n✅ Total valid cases: {len(new_ids)}")
print("🧪 Sample valid case IDs:", new_ids[:5])

Run inference on 'time1' or 'time2'? time1


User chose to run inference on time1
❌ Missing files for UCSF_PostopGlioma_Table S1 R1 V5.0_UNBLINDED_FINAL.xlsx
   - /mnt/d/A1_RainSun_20240916/1-UWMadison/IDiA-Lab/Tasks/tumor_seg/tumor_seg/data/Preprocessed/UCSF_PostopGlioma_Table S1 R1 V5.0_UNBLINDED_FINAL.xlsx/UCSF_PostopGlioma_Table S1 R1 V5.0_UNBLINDED_FINAL.xlsx_time1_flair.nii.gz
   - /mnt/d/A1_RainSun_20240916/1-UWMadison/IDiA-Lab/Tasks/tumor_seg/tumor_seg/data/Preprocessed/UCSF_PostopGlioma_Table S1 R1 V5.0_UNBLINDED_FINAL.xlsx/UCSF_PostopGlioma_Table S1 R1 V5.0_UNBLINDED_FINAL.xlsx_time1_t1ce.nii.gz
   - /mnt/d/A1_RainSun_20240916/1-UWMadison/IDiA-Lab/Tasks/tumor_seg/tumor_seg/data/Preprocessed/UCSF_PostopGlioma_Table S1 R1 V5.0_UNBLINDED_FINAL.xlsx/UCSF_PostopGlioma_Table S1 R1 V5.0_UNBLINDED_FINAL.xlsx_time1_t1.nii.gz
   - /mnt/d/A1_RainSun_20240916/1-UWMadison/IDiA-Lab/Tasks/tumor_seg/tumor_seg/data/Preprocessed/UCSF_PostopGlioma_Table S1 R1 V5.0_UNBLINDED_FINAL.xlsx/UCSF_PostopGlioma_Table S1 R1 V5.0_UNBLINDED_FINAL.xl
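
As the output shows, the scan above also visits plain files in the data folder (here an `.xlsx` spreadsheet) and reports them as cases with missing scans. A minimal sketch of the same check that only considers sub-folders, using `pathlib`; the function name and the `MODALITIES` tuple are assumptions that mirror the file names used above:

```python
from pathlib import Path

# File-name suffixes used by this dataset, in the order the model expects them
MODALITIES = ("flair", "t1ce", "t1", "t2")


def find_complete_cases(root: str, time_point: str):
    """Return (complete_ids, missing) for the case folders under root.

    missing maps a case ID to the list of expected files that do not exist.
    """
    complete, missing = [], {}
    for case_dir in sorted(Path(root).iterdir()):
        if not case_dir.is_dir():
            continue  # skip spreadsheets and other stray files
        case_id = case_dir.name
        expected = [case_dir / f"{case_id}_{time_point}_{m}.nii.gz" for m in MODALITIES]
        absent = [str(p) for p in expected if not p.exists()]
        if absent:
            missing[case_id] = absent
        else:
            complete.append(case_id)
    return complete, missing
```

In the notebook this would be called as `new_ids, missing = find_complete_cases(main_path, time_num)`.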

In [8]:
# Optional: restrict inference to a subset of cases by overriding new_ids, e.g.
# new_ids = ["0171", "0178", "0179", "0190", "0196", "0198", "0199", "0245",
#            "0281", "0296", "0325", "0338", "355"]
# new_ids = ["0190"]

In [9]:
data = {
    'training': []
}

for id_num in new_ids:
    new_entry = {
        'fold': 1,  # every case goes in the validation fold (fold = 1 in the data-loading cell)
        'label': f'{id_num}/{id_num}_{time_num}_seg.nii.gz',
        'image': [
            f'{id_num}/{id_num}_{time_num}_flair.nii.gz',
            f'{id_num}/{id_num}_{time_num}_t1ce.nii.gz',
            f'{id_num}/{id_num}_{time_num}_t1.nii.gz',
            f'{id_num}/{id_num}_{time_num}_t2.nii.gz'
        ]
    }
    # Append the new entry to the 'training' list
    data['training'].append(new_entry)

# Write the list where the config's input.multi_scan.scan_list points
os.makedirs('./data', exist_ok=True)
with open('./data/multi_example2.json', 'w') as json_file:
    json.dump(data, json_file, indent=4)

In [10]:
files_path="./data/multi_example2.json"
with open(files_path, 'r') as file:
    data = json.load(file)

# check first instance JSON data
data["training"][0:1]

[{'fold': 1,
  'label': '100001/100001_time1_seg.nii.gz',
  'image': ['100001/100001_time1_flair.nii.gz',
   '100001/100001_time1_t1ce.nii.gz',
   '100001/100001_time1_t1.nii.gz',
   '100001/100001_time1_t2.nii.gz']}]
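
Both the data-loading cell and the inference loop pick the T1 scan with `['image'][2]`, which only works because every entry lists its modalities in the order FLAIR, T1ce, T1, T2. A sketch that builds an entry from a single modality tuple keeps that index in one place; the names `MODALITY_ORDER`, `T1_INDEX`, and `make_entry` are hypothetical:

```python
MODALITY_ORDER = ("flair", "t1ce", "t1", "t2")  # order written to the 'image' list
T1_INDEX = MODALITY_ORDER.index("t1")           # position of the T1 scan in 'image'


def make_entry(case_id: str, time_point: str, fold: int = 1) -> dict:
    """Build one data-list entry in the same layout as the cell that writes the JSON."""
    prefix = f"{case_id}/{case_id}_{time_point}"
    return {
        "fold": fold,
        "label": f"{prefix}_seg.nii.gz",
        "image": [f"{prefix}_{m}.nii.gz" for m in MODALITY_ORDER],
    }
```

For example, `make_entry("100001", "time1")` produces the same dictionary as the first entry printed above, and `entry["image"][T1_INDEX]` is its T1 path.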

In [11]:
logger.info('📂 Loading configuration file...')
config_pth = '../config/example_config.yaml'

logger.info('🔓 Opening configuration file...')
with open(config_pth, 'r') as file:
    logger.info('📖 Reading YAML content...')
    inference_cfg = yaml.safe_load(file)
logger.debug(f'🔧 {inference_cfg = }')
logger.info('✅ Configuration loaded.')

2025-09-23 03:11:00,336 - brain_segmentation - INFO - 📂 Loading configuration file...
2025-09-23 03:11:00,336 - brain_segmentation - INFO - 🔓 Opening configuration file...
2025-09-23 03:11:00,340 - brain_segmentation - INFO - 📖 Reading YAML content...
2025-09-23 03:11:00,342 - brain_segmentation - DEBUG - 🔧 inference_cfg = {'input': {'compute_dice': False, 'mode': 'multi', 'multi_scan': {'scan_list': 'data/multi_example2.json'}, 'single_scan': {'ground_truth': '100001/100001_time1_seg.nii.gz', 'flair': '100001/100001_time1_flair.nii.gz', 't1c': '100001/100001_time1_t1ce.nii.gz', 't1': '100001/100001_time1_t1.nii.gz', 't2': '100001/100001_time1_t2.nii.gz'}, 'data_dir': '/mnt/d/A1_RainSun_20240916/1-UWMadison/IDiA-Lab/Tasks/tumor_seg/tumor_seg/data/Preprocessed'}, 'output': {'file_path': '/mnt/d/A1_RainSun_20240916/1-UWMadison/IDiA-Lab/Tasks/tumor_seg/tumor_seg/data/out_seg.nii.gz'}, 'model': {'bucket': 'theta-trained-models', 'key': 'model.ckpt', 'region': 'us-east-1'}}
2025-09-23 03:11
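
The rest of the notebook reads a handful of nested keys from `inference_cfg` (its contents are logged just above). A small check like the following fails early with a readable message if one is missing. It is a sketch: the key list is taken from the cells in this notebook, not from any published schema.

```python
def check_inference_cfg(cfg: dict) -> list:
    """Return a list of problems found in an inference config dict (empty if none)."""
    problems = []
    inp = cfg.get("input", {})
    for key in ("data_dir", "mode"):
        if key not in inp:
            problems.append(f"input.{key} is missing")

    if inp.get("mode") == "multi":
        # Multi-scan mode reads all scans (and labels) from a JSON data list
        if "scan_list" not in inp.get("multi_scan", {}):
            problems.append("input.multi_scan.scan_list is missing (required for mode 'multi')")
    else:
        # Any other mode is treated as single-scan mode by the data-loading cell
        single = inp.get("single_scan", {})
        for key in ("flair", "t1c", "t1", "t2"):
            if key not in single:
                problems.append(f"input.single_scan.{key} is missing")
        if inp.get("compute_dice") is True and "ground_truth" not in single:
            problems.append("input.single_scan.ground_truth is required when compute_dice is true")
    return problems
```

Usage in the notebook: `problems = check_inference_cfg(inference_cfg)` followed by `assert not problems, problems`.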

# Download the model

In [12]:
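# Run this once in a terminal (not here) to create the AWS profile:
#   aws configure --profile theta-model-downloader
# and enter your access key ID and secret access key (from the credentials .txt file) when prompted.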

In [13]:
import pprint

# ✅ 1. Show what's in the YAML config
pprint.pprint(inference_cfg)

# ✅ 2. Load local checkpoint
checkpoint_path = "../data/model_checkpoints/4-17-17/swinunetr-epoch=159.ckpt"
logger.info(f"📁 Loading checkpoint from {checkpoint_path}...")
checkpoint = torch.load(checkpoint_path, map_location=device)

# ✅ 3. Print top-level keys in checkpoint
print("📦 Checkpoint top-level keys:", checkpoint.keys())

# ✅ 4. Extract state_dict if checkpoint is from PyTorch Lightning
if 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint  # regular PyTorch

# ✅ 5. Strip the leading 'model.' prefix (added by PyTorch Lightning), if present
prefix = 'model.'
state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}

# ✅ 6. Inspect parameter shapes in the state_dict
print("\n🔍 Inspecting state_dict parameter shapes:")
for k, v in state_dict.items():
    print(f"{k:60s} {tuple(v.shape)}")

# ✅ 7. Define model hyperparameters manually (since config file doesn't include them)
hparams = {
    'roi': {'h': 224, 'w': 224, 'd': 96},
    'feature_size': 24,
    'drop_rate': 0.01,
    'attn_drop_rate': 0.01,
    'dropout_path_rate': 0.01,
    'depths': [2, 2, 2, 2],
    'num_heads': [3, 6, 6, 6],
    'norm_name': 'instance',
    'normalize': True,
    'downsample': 'merging',
    'use_v2': False,
    'mlp_ratio': 4,
    'qkv_bias': True,
    'patch_size': 2,
    'window_size': 3
}

# ✅ 8. Match output channels to checkpoint
out_channels = 4  # based on your checkpoint, not 3

# ✅ 9. Convert ROI dict to list
def listify_3d(x: dict):
    return [x['h'], x['w'], x['d']]

roi_size = listify_3d(hparams['roi'])
logger.info(f"📐 ROI size set to: {roi_size}")

# ✅ 10. Instantiate model
model_instance = model.CustomSwinUNETR(
    in_channels=4,
    img_size=roi_size,
    out_channels=out_channels,
    feature_size=hparams['feature_size'],
    use_checkpoint=True,
    depths=hparams['depths'],
    num_heads=hparams['num_heads'],
    norm_name=hparams['norm_name'],
    normalize=hparams['normalize'],
    downsample=hparams['downsample'],
    use_v2=hparams['use_v2'],
    mlp_ratio=hparams['mlp_ratio'],
    qkv_bias=hparams['qkv_bias'],
    patch_size=hparams['patch_size'],
    window_size=hparams['window_size'],
)

# ✅ 11. Load model weights with strict=False: keys missing from the checkpoint or unknown
#        to the model are skipped (shape mismatches on shared keys still raise an error)
logger.info("🧠 Loading model weights with strict=False (skipping missing/unexpected keys)...")
load_result = model_instance.load_state_dict(state_dict, strict=False)
# load_result.missing_keys / load_result.unexpected_keys list whatever was skipped
model_instance.to(device)
model_instance.eval()
logger.info("✅ Custom checkpoint loaded and model ready.")


2025-09-23 03:11:00,351 - brain_segmentation - INFO - 📁 Loading checkpoint from ../data/model_checkpoints/4-17-17/swinunetr-epoch=159.ckpt...


{'input': {'compute_dice': False,
           'data_dir': '/mnt/d/A1_RainSun_20240916/1-UWMadison/IDiA-Lab/Tasks/tumor_seg/tumor_seg/data/Preprocessed',
           'mode': 'multi',
           'multi_scan': {'scan_list': 'data/multi_example2.json'},
           'single_scan': {'flair': '100001/100001_time1_flair.nii.gz',
                           'ground_truth': '100001/100001_time1_seg.nii.gz',
                           't1': '100001/100001_time1_t1.nii.gz',
                           't1c': '100001/100001_time1_t1ce.nii.gz',
                           't2': '100001/100001_time1_t2.nii.gz'}},
 'model': {'bucket': 'theta-trained-models',
           'key': 'model.ckpt',
           'region': 'us-east-1'},
 'output': {'file_path': '/mnt/d/A1_RainSun_20240916/1-UWMadison/IDiA-Lab/Tasks/tumor_seg/tumor_seg/data/out_seg.nii.gz'}}


2025-09-23 03:11:01,221 - brain_segmentation - INFO - 📐 ROI size set to: [224, 224, 96]
2025-09-23 03:11:01,280 - brain_segmentation - INFO - 🧠 Loading model weights with strict=False (skipping missing/unexpected keys)...
2025-09-23 03:11:01,311 - brain_segmentation - INFO - ✅ Custom checkpoint loaded and model ready.


📦 Checkpoint top-level keys: dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers'])

🔍 Inspecting state_dict parameter shapes:
swinViT.patch_embed.proj.weight                              (24, 4, 2, 2, 2)
swinViT.patch_embed.proj.bias                                (24,)
swinViT.layers1.0.blocks.0.norm1.weight                      (24,)
swinViT.layers1.0.blocks.0.norm1.bias                        (24,)
swinViT.layers1.0.blocks.0.attn.relative_position_bias_table (125, 3)
swinViT.layers1.0.blocks.0.attn.relative_position_index      (27, 27)
swinViT.layers1.0.blocks.0.attn.qkv.weight                   (72, 24)
swinViT.layers1.0.blocks.0.attn.qkv.bias                     (72,)
swinViT.layers1.0.blocks.0.attn.proj.weight                  (24, 24)
swinViT.layers1.0.blocks.0.attn.proj.bias                    (24,)
swinViT.layers1.0.blocks.0.norm2.weight                      (24,)
swinViT.layers1.0.blocks.0.nor
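
`load_state_dict(..., strict=False)` silently skips checkpoint keys the model does not have and model keys the checkpoint does not provide, so a mismatch between the checkpoint and the hand-written hyperparameters can go unnoticed; shape mismatches on shared keys, by contrast, still raise a `RuntimeError`. A comparison of parameter names and shapes makes those differences explicit before loading. It is sketched here on plain `{name: shape}` dictionaries so it runs without PyTorch; in the notebook the inputs would be `{k: tuple(v.shape) for k, v in state_dict.items()}` and the same for `model_instance.state_dict()`. The helper name is hypothetical.

```python
def compare_shapes(ckpt_shapes: dict, model_shapes: dict) -> dict:
    """Compare {param_name: shape} maps from a checkpoint and a freshly built model."""
    ckpt_keys, model_keys = set(ckpt_shapes), set(model_shapes)
    # Keys present in both whose shapes differ (these would make load_state_dict raise)
    mismatched = {
        k: (tuple(ckpt_shapes[k]), tuple(model_shapes[k]))
        for k in sorted(ckpt_keys & model_keys)
        if tuple(ckpt_shapes[k]) != tuple(model_shapes[k])
    }
    return {
        "missing_in_checkpoint": sorted(model_keys - ckpt_keys),
        "unexpected_in_checkpoint": sorted(ckpt_keys - model_keys),
        "shape_mismatch": mismatched,
    }
```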

## Load
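The model is now built and its weights are loaded in the cell above. The cell below is the earlier, config-driven version of that step: it expects a `model_cfg` with hyperparameters, which the current config does not provide, so it is kept as a disabled string for reference.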

In [14]:
'''
def listify_3d(x: dict):
    logger.info('🧊 Listifying 3D dimensions...')
    dimensions = [x['h'], x['w'], x['d']]
    logger.debug(f'📏 Listified dimensions: {dimensions = }')
    return dimensions

logger.info('🎛️ Fetching hyperparameters...')
hparams = model_cfg['hyperparameter']
logger.debug(f'🎛️ Hyperparameters: {hparams = }')

logger.info('🔗 Fetching label union...')
union = model_cfg['data']['label_union']
logger.debug(f'🔗 Label union: {union = }')

logger.info('📐 Calculating ROI size...')
roi_size = listify_3d(hparams['roi'])
logger.debug(f'📐 ROI size: {roi_size = }')

logger.info('🏗️ Creating model instance...')
model = model.CustomSwinUNETR(
    in_channels       = 4, # one per MRI modality: T1, T2, T1-Contrast, FLAIR
    img_size          = roi_size,
    out_channels      = 4 if union else 3, # one per label: tumor core, whole tumor, enhancing tumor
    feature_size      = hparams['feature_size'],
    use_checkpoint    = True,
    depths            = hparams['depths'],
    num_heads         = hparams['num_heads'],
    norm_name         = hparams['norm_name'],
    normalize         = hparams['normalize'],
    downsample        = hparams['downsample'],
    use_v2            = hparams['use_v2'],
    mlp_ratio         = hparams['mlp_ratio'],
    qkv_bias          = hparams['qkv_bias'],
    patch_size        = hparams['patch_size'],
    window_size       = hparams['window_size'],
)
logger.info('✅ Model instance created.')

logger.info('🔑 Fetching first model state key...')
first_model_state_key = next(iter(model.state_dict().keys()))
logger.debug(f'✅ First model state key: {first_model_state_key = }')

logger.info('🔑 Fetching first state dict key...')
first_state_dict_key = next(iter(state_dict.keys()))
logger.debug(f'✅ First state dict key: {first_state_dict_key = }')

logger.info('💾 Loading state dictionary into model...')
model.load_state_dict(state_dict)
logger.info('✅ State dictionary loaded.')

logger.info('🖥️ Moving model to device...')
model.to(device)
logger.info('✅ Model moved to device.')

logger.info('🧠 Setting model to evaluation mode...')
model.eval()
logger.info('✅ Model set to evaluation mode.')
'''

# Data
This section loads the data described in the config file (`../config/example_config.yaml` here; see it for an example). Each scan needs one NIfTI file per MRI modality (FLAIR, T1-Contrast, T1, and T2), with paths relative to `input.data_dir`. With `input.mode: multi`, the scans are instead read from the JSON list at `input.multi_scan.scan_list`, such as the one written above.


In [15]:
# 🔧 Utility to safely extract optional boolean keys from nested config
def check_optional_key(x: dict, key_name, true_val):
    logger.debug(f'🔍 Checking for key: {key_name} with value: {true_val}')
    result = (key_name in x.keys()) and (x[key_name] == true_val)
    logger.debug(f'🔍 Result of check: {result}')
    return result


logger.info('🔢 Setting up fold and configuration...')
fold = 1  # Let the validation fold be 1 - same convention as during training.
logger.debug(f'📊 {fold = }')

logger.info('🔍 Checking optional configuration keys...')
do_ground_truth: bool = check_optional_key(inference_cfg['input'], 'compute_dice', True)
logger.debug(f'🎯 {do_ground_truth = }')
use_scan_list: bool = check_optional_key(inference_cfg['input'], 'mode', 'multi')
logger.debug(f'📋 {use_scan_list = }')

data_dir = inference_cfg['input']['data_dir']
logger.debug(f'📂 {data_dir = }')

logger.info('📁 Loading scan data...')
if use_scan_list:
    logger.info('📚 Using multiple scans from JSON...')
    json_path = inference_cfg['input']['multi_scan']['scan_list']
    _, validation_files = datafold_read(datalist=json_path, basedir=data_dir, fold=fold)
    with open(json_path) as f:
        test_instance = json.load(f)['training'][0]['image'][2]  # To get image size, grab the T1 scan.
else:
    logger.info('🖼️ Using single scan...')
    test_instance = inference_cfg['input']['single_scan']['t1']  # To get image size, grab the T1 scan.
    json_data = {
        'training': [
            {
                'fold': fold,
                'image': [
                    inference_cfg['input']['single_scan']['flair'],
                    inference_cfg['input']['single_scan']['t1c'],
                    inference_cfg['input']['single_scan']['t1'],
                    inference_cfg['input']['single_scan']['t2']
                ]
            }
        ]
    }
    if do_ground_truth:
        json_data['training'][0]['label'] = inference_cfg['input']['single_scan']['ground_truth']

    logger.info('📝 Creating temporary JSON file...')
    with tempfile.NamedTemporaryFile(mode='w+', suffix='.json', delete=False) as temp_file:  # Use delete=False to fix a permissions error on Windows.
        json.dump(json_data, temp_file, indent=4)
        temp_file.flush()
        json_path = temp_file.name
        _, validation_files = datafold_read(datalist=json_path, basedir=data_dir, fold=fold)
        temp_file.close()
        os.unlink(temp_file.name)
    logger.info('✅ Temporary JSON file created and processed.')

logger.info('📏 Loading image size...')
resize_shape = list(nib.load(os.path.join(data_dir, test_instance)).shape)
logger.debug(f'📐 Scan size: {resize_shape = }')

logger.info('🔄 Setting up validation data loader...')
val_loader = get_loader_val(
    batch_size=1,
    files=validation_files,
    val_resize=resize_shape,
    # union=union,
    workers=1,
    cache_dir='',
    dataset_type='Dataset',
    add_label=do_ground_truth,
)
logger.info('✅ Validation data loader setup complete.')

2025-09-23 03:11:01,322 - brain_segmentation - INFO - 🔢 Setting up fold and configuration...
2025-09-23 03:11:01,323 - brain_segmentation - DEBUG - 📊 fold = 1
2025-09-23 03:11:01,323 - brain_segmentation - INFO - 🔍 Checking optional configuration keys...
2025-09-23 03:11:01,323 - brain_segmentation - DEBUG - 🔍 Checking for key: compute_dice with value: True
2025-09-23 03:11:01,324 - brain_segmentation - DEBUG - 🔍 Result of check: False
2025-09-23 03:11:01,324 - brain_segmentation - DEBUG - 🎯 do_ground_truth = False
2025-09-23 03:11:01,325 - brain_segmentation - DEBUG - 🔍 Checking for key: mode with value: multi
2025-09-23 03:11:01,325 - brain_segmentation - DEBUG - 🔍 Result of check: True
2025-09-23 03:11:01,325 - brain_segmentation - DEBUG - 📋 use_scan_list = True
2025-09-23 03:11:01,326 - brain_segmentation - DEBUG - 📂 data_dir = '/mnt/d/A1_RainSun_20240916/1-UWMadison/IDiA-Lab/Tasks/tumor_seg/tumor_seg/data/Preprocessed'
2025-09-23 03:11:01,326 - brain_segmentation - INFO - 📁 Loadin

Training set size: 0
Validation set size: 298
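
The reported sizes (0 training, 298 validation) follow from every entry in the data list carrying `fold: 1` while `fold = 1` is passed in. `core_common.datafold_read` is not shown in this notebook; the sketch below assumes it behaves like the `datafold_read` helper in MONAI's Swin UNETR BraTS tutorial, which prefixes relative paths with `basedir` and sends entries whose `fold` equals the requested fold to the validation list. The name `datafold_read_sketch` is hypothetical.

```python
import json
import os


def datafold_read_sketch(datalist: str, basedir: str, fold: int = 0, key: str = "training"):
    """Split a JSON data list into (train, val) by fold, prefixing paths with basedir."""
    with open(datalist) as f:
        entries = json.load(f)[key]

    # Turn relative image/label paths into paths under basedir
    for entry in entries:
        for k, v in entry.items():
            if isinstance(v, list):
                entry[k] = [os.path.join(basedir, p) for p in v]
            elif isinstance(v, str) and v:
                entry[k] = os.path.join(basedir, v)

    train = [e for e in entries if e.get("fold") != fold]
    val = [e for e in entries if e.get("fold") == fold]
    return train, val
```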


# Inference
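This section runs the loaded model on each scan with sliding-window inference and saves one `<case_id>_seg.nii.gz` segmentation per scan to `output_dir`, which is set at the top of the inference cell (the `output.file_path` entry in the config is not used by this version). If `input.compute_dice` is true, it also computes Dice scores against the ground truth and writes them to CSV files in the same folder.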

In [16]:
# Save dice score and output it
import csv 
import os

def save_dice_scores_to_csv(dice_scores, case_ids, output_csv_path):
    """
    Save per-case Dice scores to a CSV file.

    Args:
        dice_scores (List[List[float]]): A list of [DICE_tc, DICE_wt, DICE_et, DICE_mean] for each case.
        case_ids (List[str]): A list of case IDs corresponding to each Dice score row.
        output_csv_path (str): Path to the output CSV file.
    """
    # dirname is '' for a bare file name, so fall back to the current directory
    os.makedirs(os.path.dirname(output_csv_path) or '.', exist_ok=True)

    with open(output_csv_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['CaseID', 'DICE_tc', 'DICE_wt', 'DICE_et', 'DICE_mean'])
        for case_id, score in zip(case_ids, dice_scores):
            writer.writerow([case_id] + list(score))

    logger.info(f'📄 Dice scores saved to {output_csv_path}')
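
The columns written by this helper are Dice coefficients, $2|P \cap G| / (|P| + |G|)$ for a predicted mask $P$ and a ground-truth mask $G$, one per label. The sketch below computes them with NumPy and reproduces the notebook's per-case mean, which leaves out channels whose score is exactly 0. Returning 1.0 when both masks are empty is an assumed convention and may differ from MONAI's `DiceMetric` defaults; the function names are hypothetical.

```python
import numpy as np


def dice(pred: np.ndarray, gt: np.ndarray) -> float:
    """Dice coefficient of two binary masks; 1.0 when both are empty (assumed convention)."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    denom = pred.sum() + gt.sum()
    if denom == 0:
        return 1.0
    return float(2.0 * np.logical_and(pred, gt).sum() / denom)


def case_scores(preds, gts) -> list:
    """Per-channel Dice for one case, plus the mean over non-zero channels (as in the notebook)."""
    scores = [dice(p, g) for p, g in zip(preds, gts)]
    nonzero = [s for s in scores if s != 0.0]
    mean = sum(nonzero) / len(nonzero) if nonzero else 0.0
    return scores + [mean]
```

For one case, `case_scores([tc, wt, et], [tc_gt, wt_gt, et_gt])` gives a row in the `[DICE_tc, DICE_wt, DICE_et, DICE_mean]` layout that `save_dice_scores_to_csv` expects.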

In [None]:
output_dir = "../outputs/output_0923_time1_03"
os.makedirs(output_dir, exist_ok=True)
logger.info(f"📂 Output directory set to: {output_dir}")

case_ids = []
dice_scores = []

# 🔁 If using scan list, load and filter it once outside the loop
if use_scan_list:
    with open(json_path) as f:
        d = json.load(f)
    scan_entries = [entry for entry in d['training'] if entry['fold'] == fold]

with torch.no_grad():
    logger.info('🧮 Setting up accuracy function...')
    acc_func = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH)

    # Pair each batch with its data-list entry (None in single-scan mode)
    entries = scan_entries if use_scan_list else [None] * len(val_loader)

    for i, (val_data, scan_entry) in enumerate(zip(val_loader, entries)):

        logger.info(f'📥 Loading data for case {i}...')
        val_images = val_data['image']
        logger.debug(f'🖼️ val_images.shape = {val_images.shape}')

        # 🔍 Load affine matrix
        if use_scan_list:
            affine_path = os.path.join(data_dir, scan_entry['image'][2])
            case_id = os.path.basename(os.path.dirname(affine_path))
        else:
            affine_path = os.path.join(data_dir, inference_cfg['input']['single_scan']['t1'])
            case_id = 'single_scan'

        affine = nib.load(affine_path).affine
        logger.debug(f'📐 affine.shape = {affine.shape}')

        # 🔄 Perform inference
        val_outputs = sliding_window_inference(
            val_images.to(device),
            roi_size=roi_size,
            sw_batch_size=4,
            predictor=model_instance,
        )
        logger.debug(f'🔢 val_outputs.shape = {val_outputs.shape}')

        # 💡 Post-processing predictions
        val_outputs = val_outputs.cpu()[0]  # remove batch dim: [C, H, W, D]
        val_outputs_per_channel = torch.unbind(val_outputs, dim=0)
        post_sigmoid = Activations(sigmoid=True)
        post_pred = AsDiscrete(argmax=False, threshold=0.5)
        val_outputs_convert = [post_pred(post_sigmoid(c)) for c in val_outputs_per_channel]

        # 🧮 DICE calculation
        if do_ground_truth:
            logger.info('📊 Calculating DICE scores...')
            ground_truth = val_data['label']
            acc_func.reset()
            val_outputs_used = val_outputs_convert[:-1] if len(val_outputs_convert) > 3 else val_outputs_convert
            val_outputs_used = [v.to(device) for v in val_outputs_used]
            y_pred_tensor = torch.stack(val_outputs_used, dim=0).unsqueeze(0)  # shape [1, C, H, W, D]
            acc_func(y_pred=y_pred_tensor, y=ground_truth.to(device))
            acc = acc_func.aggregate().cpu().numpy()

            num_zeroes = list(acc[:3]).count(0.0)
            mean = (acc[0] + acc[1] + acc[2]) / (3 - num_zeroes) if num_zeroes < 3 else 0
            logger.info(f'📊 DICE (tc): {acc[0]}')
            logger.info(f'📊 DICE (wt): {acc[1]}')
            logger.info(f'📊 DICE (et): {acc[2]}')
            logger.info(f'📊 DICE (mean): {mean}')

            dice_scores.append([acc[0], acc[1], acc[2], mean])
            case_ids.append(case_id)

        # 🧠 Segmentation mask generation (channel-wise, no argmax)
        logger.info('🧠 Generating segmentation mask...')
        segmentation_mask = val_outputs_convert[:-1] if len(val_outputs_convert) > 3 else val_outputs_convert

        # Map channel index -> output label value (edit to suit your label convention).
        # Where channel masks overlap, the channel written last wins.
        channel_label_map = {0: 0, 1: 2, 2: 4, 3: 3, 4: 1}

        remapped = np.zeros_like(segmentation_mask[0].cpu().numpy(), dtype=np.uint8)
        for ch, mask in enumerate(segmentation_mask):
            binary_mask = mask.cpu().numpy().astype(bool)
            if ch in channel_label_map:
                remapped[binary_mask] = channel_label_map[ch]

        segmentation_mask = remapped

        # 🧾 Log label distribution
        unique_labels, label_counts = np.unique(segmentation_mask, return_counts=True)
        for label, count in zip(unique_labels, label_counts):
            logger.info(f'🏷️ Label {int(label)}: {count} voxels')

        total_voxels = np.prod(segmentation_mask.shape)
        for label, count in zip(unique_labels, label_counts):
            percentage = (count / total_voxels) * 100
            logger.info(f'📊 Label {int(label)}: {percentage:.2f}% of total volume')

        # 💾 Save NIfTI
        save_pth = os.path.join(output_dir, f'{case_id}_seg.nii.gz')
        nib.save(nib.Nifti1Image(segmentation_mask.astype(np.uint8), affine=affine), save_pth)
        logger.info(f'💾 Saved segmentation to: {save_pth}')

# ✅ Final Dice summary
if do_ground_truth and dice_scores:
    logger.info('📊 Calculating mean DICE scores...')
    mask = np.ma.masked_equal(dice_scores, 0)
    mean = mask.mean(axis=0).filled(np.nan)
    logger.info(f'📊 DICE (tc): {mean[0]}')
    logger.info(f'📊 DICE (wt): {mean[1]}')
    logger.info(f'📊 DICE (et): {mean[2]}')
    logger.info(f'📊 DICE (mean): {mean[3]}')

    # 💾 Save per-case DICE to CSV
    logger.info("💾 Saving per-case DICE scores to CSV...")
    dice_df = pd.DataFrame(dice_scores, columns=["DICE_tc", "DICE_wt", "DICE_et", "DICE_mean"])
    dice_df.insert(0, "Case_ID", case_ids)
    dice_df.to_csv(os.path.join(output_dir, "dice_scores.csv"), index=False)
    logger.info(f"📄 Saved per-case DICE scores to: {os.path.join(output_dir, 'dice_scores.csv')}")

    # 💾 Save mean DICE score to CSV
    mean_df = pd.DataFrame([{
        "DICE_tc": mean[0],
        "DICE_wt": mean[1],
        "DICE_et": mean[2],
        "DICE_mean": mean[3],
    }])
    mean_df.to_csv(os.path.join(output_dir, "mean_dice_score.csv"), index=False)
    logger.info(f"📄 Saved mean DICE score to: {os.path.join(output_dir, 'mean_dice_score.csv')}")

logger.info('✅ Inference complete.')

2025-09-23 03:11:01,358 - brain_segmentation - INFO - 📂 Output directory set to: ../outputs/output_0923_time1_03
2025-09-23 03:11:01,367 - brain_segmentation - INFO - 🧮 Setting up accuracy function...
2025-09-23 03:11:02,663 - brain_segmentation - INFO - 📥 Loading data for case 0...
2025-09-23 03:11:02,665 - brain_segmentation - DEBUG - 🖼️ val_images.shape = torch.Size([1, 4, 240, 240, 155])
2025-09-23 03:11:02,673 - brain_segmentation - DEBUG - 📐 affine.shape = (4, 4)
2025-09-23 03:12:00,617 - brain_segmentation - DEBUG - 🔢 val_outputs.shape = torch.Size([1, 4, 240, 240, 155])
2025-09-23 03:12:00,720 - brain_segmentation - INFO - 🧠 Generating segmentation mask...
2025-09-23 03:12:00,782 - brain_segmentation - INFO - 🏷️ Label 0: 8877345 voxels
2025-09-23 03:12:00,782 - brain_segmentation - INFO - 🏷️ Label 2: 23514 voxels
2025-09-23 03:12:00,783 - brain_segmentation - INFO - 🏷️ Label 4: 27141 voxels
2025-09-23 03:12:00,783 - brain_segmentation - INFO - 📊 Label 0: 99.43% of total volume
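
The remapping step in the loop turns the thresholded channel masks into one label volume: each channel index is looked up in `channel_label_map`, and where masks overlap, the channel processed last wins. Isolated from MONAI and PyTorch, the same logic looks like this on NumPy arrays (the function name `remap_channels` is hypothetical):

```python
import numpy as np


def remap_channels(masks, channel_label_map: dict) -> np.ndarray:
    """Collapse a sequence of binary channel masks into a single uint8 label volume."""
    labels = np.zeros(np.shape(masks[0]), dtype=np.uint8)
    for ch, mask in enumerate(masks):
        if ch in channel_label_map:
            # Later channels overwrite earlier ones wherever masks overlap
            labels[np.asarray(mask).astype(bool)] = channel_label_map[ch]
    return labels
```

With the notebook's map, channel 0 is written as label 0, so its voxels end up as background; this matches the log above, which only reports labels 0, 2 and 4.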


In [None]:
# for i, mask in enumerate(segmentation_mask):
#     print(f"Element {i} type: {type(mask)}")

In [None]:
# After remapping, segmentation_mask is a single label volume with the scan's shape
assert segmentation_mask.shape == tuple(resize_shape), f"Shape mismatch: {segmentation_mask.shape} vs {resize_shape}"

In [None]:
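# Count how many channel masks are active at each voxel (values above 1 mean channels overlap)
channel_masks = val_outputs_convert[:-1] if len(val_outputs_convert) > 3 else val_outputs_convert
segmentation_mask_sum = torch.sum(torch.stack(channel_masks, dim=0), dim=0)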