# Overview
**BRAIN SEGMENTATION INFERENCE**
This Jupyter notebook runs inference on brain MRI scans using a pre-trained segmentation model. It downloads the model from the cloud (this requires AWS access keys to be set up properly; see below) and then runs inference on a single scan, or on multiple scans listed in a JSON file. The output is a NIfTI file containing the segmentation labels, which you may then want to view in a program such as 3D Slicer.

# Setup
First, set up a Python environment with the necessary packages.

For instance, in VSCode or Cursor, press Ctrl+Shift+P and type "Python: Create Environment" and follow the prompts.

Or, on the command line, run:
```bash
python -m venv brain_segmentation_inference_env
```
--or--
```bash
python3 -m venv brain_segmentation_inference_env
```
Then activate it with either
```bash
brain_segmentation_inference_env\Scripts\activate
```
on Windows, or
```bash
source brain_segmentation_inference_env/bin/activate
```
on macOS and Linux.

Then, install the necessary packages:
```bash
pip install -r requirements.txt
```

Once that is done, the cells below should run without error and import all the necessary packages.


## Logging

In [None]:
import logging

# Named logger for this notebook. Its level is set explicitly, so it keeps
# logging at DEBUG even though the root logger is raised to CRITICAL below.
logger = logging.getLogger('brain_segmentation')
logger.setLevel(logging.DEBUG)
# Don't forward records to the root logger: if any library attaches a root
# handler (e.g. via logging.basicConfig), every message would print twice.
logger.propagate = False

# Attach the console handler only once, so re-running this cell in the
# notebook does not stack duplicate handlers (and duplicate output lines).
if not logger.handlers:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

# Turn off logs from all other loggers. Raising the root level also covers
# loggers created *after* this cell (the heavy libraries are imported in the
# next cell), as long as they don't set a level of their own.
logging.getLogger().setLevel(logging.CRITICAL)
# Iterate over a snapshot: getLogger() may replace placeholder entries in the dict.
for name in list(logging.root.manager.loggerDict):
    if name != 'brain_segmentation':
        logging.getLogger(name).setLevel(logging.CRITICAL)

logger.info('🚀 Setting up logging...')
logger.debug(f'🔧 Current logging level: {logger.getEffectiveLevel()}')
logger.info('✅ Logging setup complete.')


## Imports

In [None]:
logger.info('📦 Importing outside packages...')
import os
import torch
import json, yaml
import tempfile
import boto3
import pprint
from tqdm import tqdm
from smart_open import open
from monai.transforms import AsDiscrete, Activations
from monai.utils.enums import MetricReduction
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
import nibabel as nib
import numpy as np
logger.info('✅ Outside packages imported.')

logger.info('📦 Importing local packages...')
from core_common import get_loader_val, datafold_read
import custom_model
logger.info('✅ Local packages imported.')

logger.info('📦 Setting up device...')
# Use the first CUDA GPU when one is visible; otherwise run on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
logger.info('✅ Device set up.')
logger.debug(f'🖥️ {device = }')

# Model
This part loads the model from the cloud.

## Download
This part downloads the model from the cloud.

You'll need to have an AWS profile named 'theta-model-downloader'. You can create this profile by running this command in your terminal:
```bash
aws configure --profile theta-model-downloader
```
and entering your AWS credentials.

If you do not have the AWS command line tools installed, you can install them by following the instructions [here](https://docs.aws.amazon.com/cli/latest/userguide/install-cliv2.html).

In [None]:
# Path (relative to the working directory) of the inference configuration.
config_pth = 'config/config.yaml'
logger.info('📂 Loading configuration file...')

logger.info('🔓 Opening configuration file...')
with open(config_pth, 'r') as config_stream:
    logger.info('📖 Reading YAML content...')
    # safe_load only builds plain Python objects (no arbitrary tags).
    inference_cfg = yaml.safe_load(config_stream)
logger.debug(f'🔧 {inference_cfg = }')
logger.info('✅ Configuration loaded.')

logger.info('🔧 Defining utility function...')


def check_optional_key(x: dict, key_name, true_val):
    """Return True when ``key_name`` is present in ``x`` and equals ``true_val``.

    A missing key yields False rather than raising, so optional config flags
    can be tested without a KeyError.
    """
    logger.debug(f'🔍 Checking for key: {key_name} with value: {true_val}')
    present = key_name in x
    result = present and x[key_name] == true_val
    logger.debug(f'🔍 Result of check: {result}')
    return result


logger.info('✅ Utility function defined.')

logger.info('☁️ Setting up AWS S3 connection...')
bucket = inference_cfg['model']['bucket']
key = inference_cfg['model']['key']
logger.info('🔐 Creating AWS session...')
session = boto3.Session(
    profile_name='theta-model-downloader',
    region_name=inference_cfg['model']['region'])
logger.info('🔗 Creating S3 client...')
s3_client = session.client('s3')
logger.debug(f'🪣 {bucket = }')
logger.debug(f'🔑 {key = }')
logger.info('✅ AWS S3 connection set up.')

logger.info('📦 Fetching model metadata...')
logger.info('🔍 Retrieving object metadata from S3...')
metadata = s3_client.head_object(Bucket=bucket, Key=key)
file_size = metadata['ContentLength']  # used as the progress-bar total
logger.info('🔧 Parsing model configuration from metadata...')
# The model's config is stored as YAML text in the object's 'cfg' user metadata.
model_cfg = yaml.safe_load(metadata['Metadata']['cfg'])
logger.debug(f'📏 {file_size = }')
logger.debug(f'🔧 {model_cfg = }')
logger.info('✅ Model metadata fetched.')

logger.info('⬇️ Downloading model checkpoint...')
logger.info('📁 Creating temporary file...')
# delete=False so the file can be reopened by name after it is closed (needed on
# Windows). Removal happens in the `finally` below, so the file is not leaked
# when the download or torch.load fails.
temp_file = tempfile.NamedTemporaryFile(mode='wb', suffix='.ckpt', delete=False)
try:
    with temp_file:
        logger.info('🔄 Setting up progress bar...')
        with tqdm(total=file_size, unit='B', unit_scale=True, desc='Downloading checkpoint...') as progress_bar:
            logger.info('🔄 Starting file download...')
            # boto3 calls Callback with the number of bytes transferred since the last call.
            s3_client.download_fileobj(Bucket=bucket, Key=key, Fileobj=temp_file, Callback=progress_bar.update)
        logger.info('💾 Flushing temporary file...')
        temp_file.flush()
    # The handle is closed at this point, so torch.load can reopen it on any OS.
    logger.info('🔄 Loading checkpoint into memory...')
    # NOTE(security): torch.load unpickles the checkpoint; only load files from a trusted bucket.
    checkpoint = torch.load(temp_file.name, map_location=device)
finally:
    logger.info('🗑️ Closing and removing temporary file...')
    os.unlink(temp_file.name)
logger.info('✅ Model checkpoint downloaded and loaded.')

logger.info('🔄 Processing state dictionary...')
state_dict = checkpoint['state_dict']
logger.info('🔄 Removing "model." prefix from state dict keys...')
# Strip only a *leading* 'model.'; str.replace would also remove the text
# wherever else it appears inside a key (e.g. 'submodel.weight' -> 'subweight').
_prefix = 'model.'
state_dict = {
    (name[len(_prefix):] if name.startswith(_prefix) else name): value
    for name, value in state_dict.items()
}
logger.debug(f'🔑 State dict keys: {state_dict.keys()}')
logger.info('✅ State dictionary processed.')

logger.info('🔄 Printing model configuration...')
pprint.pprint(model_cfg)
logger.info('✅ Model configuration printed.')

## Load
Once the model has been downloaded, this section loads the model into memory.

In [None]:
def listify_3d(x: dict):
    """Return the ``'h'``, ``'w'`` and ``'d'`` entries of ``x`` as ``[h, w, d]``."""
    logger.info('🧊 Listifying 3D dimensions...')
    dimensions = [x[axis] for axis in ('h', 'w', 'd')]
    logger.debug(f'📏 Listified dimensions: {dimensions = }')
    return dimensions

logger.info('🎛️ Fetching hyperparameters...')
hparams = model_cfg['hyperparameter']
logger.debug(f'🎛️ Hyperparameters: {hparams = }')

logger.info('🔗 Fetching label union...')
union = model_cfg['data']['label_union']
logger.debug(f'🔗 Label union: {union = }')

logger.info('📐 Calculating ROI size...')
roi_size = listify_3d(hparams['roi'])
logger.debug(f'📐 ROI size: {roi_size = }')

logger.info('🏗️ Creating model instance...')
# Architecture arguments taken verbatim (same names) from the training hyperparameters.
_ARCH_HPARAM_NAMES = (
    'feature_size', 'depths', 'num_heads', 'norm_name', 'normalize', 'downsample',
    'use_v2', 'mlp_ratio', 'qkv_bias', 'patch_size', 'window_size',
)
arch_kwargs = {name: hparams[name] for name in _ARCH_HPARAM_NAMES}
model = custom_model.CustomSwinUNETR(
    # One input channel per MRI modality: T1, T2, T1-Contrast, FLAIR.
    in_channels=4,
    img_size=roi_size,
    # One output per label (tumor core, whole tumor, enhancing tumor), plus one more when union is set.
    out_channels=4 if union else 3,
    use_checkpoint=True,
    **arch_kwargs,
)
logger.info('✅ Model instance created.')

# The first key of each dict is logged so a prefix mismatch is easy to spot.
logger.info('🔑 Fetching first model state key...')
first_model_state_key = next(iter(model.state_dict()))
logger.debug(f'✅ First model state key: {first_model_state_key = }')

logger.info('🔑 Fetching first state dict key...')
first_state_dict_key = next(iter(state_dict))
logger.debug(f'✅ First state dict key: {first_state_dict_key = }')

logger.info('💾 Loading state dictionary into model...')
model.load_state_dict(state_dict)
logger.info('✅ State dictionary loaded.')

logger.info('🖥️ Moving model to device...')
model.to(device)
logger.info('✅ Model moved to device.')

logger.info('🧠 Setting model to evaluation mode...')
model.eval()
logger.info('✅ Model set to evaluation mode.')

# Data
This section loads the data defined in your `config/config.yaml` file; see the example config file for more details. It requires a NIfTI file for each MRI modality (typically T1, T2, T1-Contrast, and FLAIR) and the location where the output NIfTI segmentation file should be stored. There are also instructions for loading multiple scans from a JSON file.


In [None]:
logger.info('🔢 Setting up fold and configuration...')
fold = 1  # Let the validation fold be 1 - same convention as during training.
logger.debug(f'📊 {fold = }')

logger.info('🔍 Checking optional configuration keys...')
do_ground_truth: bool = check_optional_key(inference_cfg['input'], 'compute_dice', True)
logger.debug(f'🎯 {do_ground_truth = }')
use_scan_list: bool = check_optional_key(inference_cfg['input'], 'mode', 'multi')
logger.debug(f'📋 {use_scan_list = }')

data_dir = inference_cfg['input']['data_dir']
logger.debug(f'📂 {data_dir = }')

logger.info('📁 Loading scan data...')
if use_scan_list:
    logger.info('📚 Using multiple scans from JSON...')
    json_path = inference_cfg['input']['multi_scan']['scan_list']
    _, validation_files = datafold_read(datalist=json_path, basedir=data_dir, fold=fold)
    with open(json_path) as f:
        test_instance = json.load(f)['training'][0]['image'][2]  # To get image size, grab the T1 scan.
else:
    logger.info('🖼️ Using single scan...')
    test_instance = inference_cfg['input']['single_scan']['t1']  # To get image size, grab the T1 scan.
    # Wrap the single scan in the same JSON layout a multi-scan list uses, so
    # datafold_read can be reused. Image order: flair, t1c, t1, t2.
    json_data = {
        'training': [
            {
                'fold': fold,
                'image': [
                    inference_cfg['input']['single_scan']['flair'],
                    inference_cfg['input']['single_scan']['t1c'],
                    inference_cfg['input']['single_scan']['t1'],
                    inference_cfg['input']['single_scan']['t2']
                ]
            }
        ]
    }
    if do_ground_truth:
        json_data['training'][0]['label'] = inference_cfg['input']['single_scan']['ground_truth']

    logger.info('📝 Creating temporary JSON file...')
    # Use delete=False to fix a permissions error on Windows. The file is closed
    # before datafold_read reopens it by name, and removed in `finally` so it is
    # not leaked if reading fails.
    temp_file = tempfile.NamedTemporaryFile(mode='w+', suffix='.json', delete=False)
    json_path = temp_file.name
    try:
        with temp_file:
            json.dump(json_data, temp_file, indent=4)
        _, validation_files = datafold_read(datalist=json_path, basedir=data_dir, fold=fold)
    finally:
        os.unlink(json_path)
    logger.info('✅ Temporary JSON file created and processed.')

logger.info('📏 Loading image size...')
resize_shape = list(nib.load(os.path.join(data_dir, test_instance)).shape)
logger.debug(f'📐 Scan size: {resize_shape = }')

logger.info('🔄 Setting up validation data loader...')
val_loader = get_loader_val(
    batch_size=1,
    files=validation_files,
    val_resize=resize_shape,
    union=union,
    workers=1,
    cache_dir='',
    dataset_type='Dataset',
    add_label=do_ground_truth,
)
logger.info('✅ Validation data loader setup complete.')

# Inference
This section runs the loaded model on the loaded scans and saves the output to the location specified in your `config/config.yaml` file.

In [None]:
with torch.no_grad():
    logger.info('🧮 Setting up accuracy function...')
    acc_func = DiceMetric(
        include_background=True,
        reduction=MetricReduction.MEAN_BATCH,
    )

    dice_scores = []  # One [tc, wt, et, mean] row per scan with ground truth.
    logger.debug(f'📊 {dice_scores = }')

    # The post-processing transforms are identical for every scan, so build them once.
    post_sigmoid = Activations(sigmoid=True)
    post_pred = AsDiscrete(argmax=False, threshold=0.5)

    # Resolve the T1 path whose affine is copied into each output up front,
    # instead of re-reading the JSON scan list on every iteration. Image index 2
    # is the T1 scan (the same index the data cell uses to read the scan size).
    if use_scan_list:
        with open(json_path) as f:
            fold_entries = [entry for entry in json.load(f)['training'] if entry['fold'] == fold]
        # NOTE(review): assumes val_loader yields this fold's scans in JSON order.
        fold_affine_paths = [os.path.join(data_dir, entry['image'][2]) for entry in fold_entries]
    else:
        single_affine_path = os.path.join(data_dir, inference_cfg['input']['single_scan']['t1'])

    for i, val_data in enumerate(val_loader):
        logger.debug(f'🔢 {i = }')
        logger.info('📥 Loading data...')
        val_images = val_data['image']
        logger.debug(f'🖼️ {val_images.shape = }')

        logger.info('📜 Loading affine matrix...')
        affine_path = fold_affine_paths[i] if use_scan_list else single_affine_path
        affine = nib.load(affine_path).affine
        logger.debug(f'📐 {affine.shape = }')

        logger.info('🔍 Performing sliding window inference...')
        val_outputs = sliding_window_inference(
            val_images.to(device),
            roi_size=roi_size,
            sw_batch_size=4,
            predictor=model,
        )
        logger.debug(f'🔢 {val_outputs.shape = }')

        logger.info('🔄 Applying post-processing...')
        val_outputs_convert = [post_pred(post_sigmoid(val_pred_tensor)) for val_pred_tensor in val_outputs]
        logger.debug(f'🔢 {len(val_outputs_convert) = }')

        if do_ground_truth:
            logger.info('📊 Calculating DICE scores...')
            ground_truth = val_data['label']
            acc_func.reset()
            acc_func(y_pred=val_outputs_convert, y=ground_truth.to(device))
            acc = acc_func.aggregate().cpu().numpy()
            num_zeroes = [acc[0], acc[1], acc[2]].count(0.0)
            mean = (acc[0] + acc[1] + acc[2]) / (3 - num_zeroes) if num_zeroes < 3 else 0  # Ignore cases with zero DICE

            logger.info(f'📊 DICE (tc): {acc[0]}')
            logger.info(f'📊 DICE (wt): {acc[1]}')
            logger.info(f'📊 DICE (et): {acc[2]}')
            logger.info(f'📊 DICE (mean): {mean}')
            dice_scores.append([acc[0], acc[1], acc[2], mean])

        logger.info('🔄 Processing output...')
        # squeeze() drops the batch axis (the loader uses batch_size=1).
        val_outputs = val_outputs.clone().cpu().numpy().squeeze()
        logger.debug(f'🔢 {val_outputs.shape = }')
        segmentation_mask = [post_pred(post_sigmoid(val_pred_tensor)) for val_pred_tensor in val_outputs]
        if union:
            # Only a union model has the extra, last channel; without a union the
            # last channel is a real label and must be kept.
            segmentation_mask = segmentation_mask[:-1]  # Discard the union
        segmentation_mask = np.sum(segmentation_mask, axis=0)
        # Remap the per-voxel channel sums 1 -> 2, 2 -> 1, 3 -> 4; 5 is a temporary
        # value so 1 and 2 can be swapped. NOTE(review): this assumes the label
        # channels are nested (et within tc within wt) - confirm.
        for k, v in {1: 5, 3: 4, 2: 1, 5: 2}.items():
            segmentation_mask[segmentation_mask == k] = v

        logger.info('🏷️ Identifying unique labels and their counts...')
        unique_labels, label_counts = np.unique(segmentation_mask, return_counts=True)
        logger.debug(f'🔢 {unique_labels = }')
        logger.debug(f'🔢 {label_counts = }')
        for label, count in zip(unique_labels, label_counts):
            logger.info(f'🏷️ Label {int(label)}: {count} voxels')
        total_voxels = segmentation_mask.size
        logger.info(f'📊 Total voxels: {total_voxels}')
        for label, count in zip(unique_labels, label_counts):
            percentage = (count / total_voxels) * 100
            logger.info(f'📊 Label {int(label)}: {percentage:.2f}% of total volume')

        logger.info('🧠 Creating NIfTI image...')
        nii = nib.Nifti1Image(segmentation_mask.astype(np.uint8), affine=affine)

        if use_scan_list:
            # If inferring on multiple scans, add the T1 file path to the name so they
            # can be easily distinguished. Both os.sep and '/' are replaced, so Windows
            # backslashes don't turn the file name into a nested path.
            scan_tag = affine_path.replace(os.sep, '.').replace('/', '.')
            out_dir, fname = os.path.split(inference_cfg['output']['file_path'])
            save_pth = os.path.join(out_dir, f'{scan_tag}_{fname}')
        else:
            save_pth = inference_cfg['output']['file_path']

        logger.info('📁 Creating output directory...')
        output_directory = os.path.dirname(save_pth)
        # A bare file name has no directory part, and os.makedirs('') would raise.
        if output_directory:
            os.makedirs(output_directory, exist_ok=True)
        logger.info(f'💾 Saving as {save_pth}...')
        nib.save(nii, save_pth)

# Skip the summary when no scan produced a score (e.g. an empty loader).
if do_ground_truth and dice_scores:
    logger.info('📊 Calculating mean DICE scores...')
    mask = np.ma.masked_equal(dice_scores, 0)  # Ignore cases with zero DICE
    # A column whose scores are all zero is fully masked and reported as NaN.
    mean = mask.mean(axis=0).filled(np.nan)

    logger.info('📊 Mean DICE scores:')
    logger.info(f'📊 DICE (tc): {mean[0]}')
    logger.info(f'📊 DICE (wt): {mean[1]}')
    logger.info(f'📊 DICE (et): {mean[2]}')
    logger.info(f'📊 DICE (mean): {mean[3]}')

logger.info('✅ Inference complete.')