# Imports and functions

In [None]:
import sys
sys.path.append("..")
import os
from monai.transforms import Compose, LoadImage, CropForeground, EnsureChannelFirst, ResizeWithPadOrCrop, ScaleIntensityRange
from guided_diffusion.c_unet import SuperResModel, UNetModel, EncoderUNetModel
import torch
import torch as th
from diffusers import DDPMScheduler, DPMSolverMultistepScheduler
from DWT_IDWT.DWT_IDWT_layer import IDWT_3D, DWT_3D
import nibabel as nib
import numpy as np
idwt = IDWT_3D("haar")
dwt = DWT_3D("haar")
from monai.data import load_decathlon_datalist, DataLoader, CacheDataset
from monai.transforms import (
    Compose, 
    LoadImaged,
    EnsureChannelFirstd, 
    EnsureTyped,
    Orientationd,
    ScaleIntensityRanged, 
    ResizeWithPadOrCropd,
    CopyItemsd
    )
from utils.data_loader_utils import ConvertToMultiChannel_BackandForeground_Contrastd
from tqdm import tqdm
import torch
from monai.transforms import Resize

In [None]:
def get_tensor(file_path, norm, clip):
    """
    Load a NIfTI volume as a NumPy array, optionally clipping and/or normalising it.

    Arguments:
        file_path (str): Path to the .nii.gz file.
        norm (bool): If True, clip intensities to [-200, 200] and map them linearly
            onto [-1, 1].
        clip (bool): If True, clip intensities to [-200, 200] without changing
            their scale. Applied before ``norm`` when both are set.
    Return:
        NumPy array of the first channel (index 0) of the channel-first volume.
    """
    pipeline = [LoadImage(image_only=True), EnsureChannelFirst()]
    # (enabled, b_min, b_max): both optional steps clip the input to [-200, 200];
    # they only differ in the output range.
    optional_steps = ((clip, -200, 200), (norm, -1, 1))
    pipeline.extend(
        ScaleIntensityRange(a_min=-200, a_max=200, b_min=low, b_max=high, clip=True)
        for enabled, low, high in optional_steps
        if enabled
    )
    volume = Compose(pipeline)(file_path)
    return volume[0].numpy()

def get_segmentation(file_path):
    """
    Load a segmentation, crop it to its foreground and pad/crop it to 128x128x128.

    Cropping to the bounding box of the non-zero voxels before padding places the
    segmented structure in the middle of the output volume.
    Arguments:
        file_path (str): Path to the segmentation file.
    Return:
        NumPy array of the first channel of the 128x128x128 segmentation.
    """
    loader = LoadImage(image_only=True)
    to_channel_first = EnsureChannelFirst()
    # margin=0: crop tightly to the bounding box of voxels > 0.
    crop_to_foreground = CropForeground(select_fn=lambda voxel: voxel > 0, margin=0)
    pad_or_crop = ResizeWithPadOrCrop(spatial_size=(128, 128, 128))
    pipeline = Compose([loader, to_channel_first, crop_to_foreground, pad_or_crop])
    return pipeline(file_path)[0].numpy()

def rescale_array(arr, minv, maxv):  # monai function adapted
    """
    Linearly rescale the values of ``arr`` so they span ``[minv, maxv]``.

    The array's own minimum is mapped to ``minv`` and its maximum to ``maxv``.

    Arguments:
        arr (numpy array, torch tensor or array-like): Values to rescale. Anything
            that is not a ``torch.Tensor`` is converted with ``np.asarray``.
        minv (number): Target minimum.
        maxv (number): Target maximum.
    Return:
        Rescaled array of the same container type as the (converted) input. If all
        values are equal, returns ``arr * minv`` (as MONAI's ``rescale_array`` does).
    """
    if isinstance(arr, th.Tensor):
        mina, maxa = th.min(arr), th.max(arr)
    else:
        # Previously any non-ndarray input left mina/maxa unbound (UnboundLocalError).
        arr = np.asarray(arr)
        mina, maxa = np.min(arr), np.max(arr)
    if mina == maxa:
        return arr * minv
    # Normalise to [0, 1] first, then stretch/shift to [minv, maxv].
    norm = (arr - mina) / (maxa - mina)
    return norm * (maxv - minv) + minv

from scipy.ndimage import center_of_mass
def get_crop_tensors(healthy_ct_scan_full_res, region_to_place_tumour_mask, segmentation, device):
    """
    Crop a 128x128x128 window from a healthy CT scan and convert the pieces to tensors.

    Both volumes are padded by 64 voxels on every side (CT with -200, mask with 1)
    and the window ``[c - 64, c + 64)`` is taken around the integer-truncated
    centre of mass ``c`` of ``segmentation``.

    Arguments:
        healthy_ct_scan_full_res (numpy array): Healthy volume.
        region_to_place_tumour_mask (numpy array): Mask of the region where the
            tumour can be placed; voxels equal to 1 are forbidden. Assumed to have
            the same shape as the CT.
        segmentation (numpy array): Tumour segmentation with the same shape as the
            crop (128x128x128). The caller's array is not modified.
        device: Torch device the returned tensors are moved to.
    Returns:
        tuple: ``(padded full-resolution scan, crop rescaled to [-1, 1], crop with
        the original intensities, segmentation zeroed where the mask is 1,
        crop centre)``. Every tensor gets two leading singleton dimensions; the
        crop centre is a length-3 integer NumPy array indexing the padded volume.
    Raises:
        ValueError: If ``segmentation`` has no foreground (undefined centre of
            mass) or the window does not fit inside the padded volume.
    """
    centroid = center_of_mass(segmentation)
    if np.isnan(centroid).any():
        raise ValueError("segmentation has no foreground voxels; its centre of mass is undefined")
    random_x, random_y, random_z = (int(c) for c in centroid)
    # NOTE(review): the centre is computed in the segmentation's own index frame but
    # is used below as an index into the *padded* CT/mask. Candidate voxels where the
    # mask equals 2 (for a random centre) are not used. Confirm the intended frame.

    # Pad the volumes so that no region outside of the original volume is selected.
    healthy_ct_scan_full_res = np.pad(healthy_ct_scan_full_res, pad_width=64, mode='constant', constant_values=-200)  # -200 background
    region_to_place_tumour_mask = np.pad(region_to_place_tumour_mask, pad_width=64, mode='constant', constant_values=1)  # 1 means that the tumour cannot be placed there

    window = tuple(slice(c - 64, c + 64) for c in (random_x, random_y, random_z))
    for axis, (sl, size) in enumerate(zip(window, healthy_ct_scan_full_res.shape)):
        # A negative start would silently wrap around instead of cropping.
        if sl.start < 0 or sl.stop > size:
            raise ValueError(
                f"crop window {sl.start}:{sl.stop} on axis {axis} lies outside the padded volume of size {size}"
            )

    # Crop the full-resolution scan and the mask.
    healthy_ct_scan = healthy_ct_scan_full_res[window]
    region_to_place_tumour_mask_crop = region_to_place_tumour_mask[window]

    # Work on a copy so the caller's segmentation is left untouched, then keep the
    # tumour inside the anatomical boundaries defined by the cropped mask.
    segmentation = np.copy(segmentation)
    segmentation[region_to_place_tumour_mask_crop == 1] = 0

    def _batched(tensor):
        # Add two leading singleton dimensions and move to the target device.
        return tensor.unsqueeze(0).unsqueeze(0).to(device)

    # Keep an independent copy of the original intensities of the cropped region.
    healthy_ct_scan_origin_intensities = _batched(th.from_numpy(np.copy(healthy_ct_scan)))
    healthy_ct_scan_norm = _batched(rescale_array(th.from_numpy(healthy_ct_scan), minv=-1, maxv=1))
    segmentation_tensor = _batched(th.from_numpy(segmentation))
    healthy_ct_scan_full_res_tensor = _batched(th.from_numpy(healthy_ct_scan_full_res))

    # Previously an undefined name was returned here (NameError on every call).
    random_voxel_indices = np.array([random_x, random_y, random_z])
    return (
        healthy_ct_scan_full_res_tensor,
        healthy_ct_scan_norm,
        healthy_ct_scan_origin_intensities,
        segmentation_tensor,
        random_voxel_indices,
    )

def get_affine_and_header(file_path):
    """
    Read the affine matrix and header of a NIfTI file.

    Args:
        file_path (str): The path to the NIfTI file.
    Returns:
        tuple: ``(affine, header)`` as exposed by the loaded nibabel image.
    """
    nifti = nib.load(file_path)
    return nifti.affine, nifti.header

In [None]:
def get_model(in_channels=11, out_channels=8, channel_mult=None, label_cond_in_channels=0, use_label_cond_conv=False, pretrained_weights_path=None):
    """
    Build the 3-D UNet used for wavelet diffusion and optionally load pre-trained weights.

    Arguments:
        in_channels (int): Input channels of the UNet (the 8 noisy wavelet
            sub-bands plus any concatenated condition channels).
        out_channels (int): Output channels (8 wavelet sub-bands by default).
        channel_mult (sequence of int, optional): Per-level channel multipliers;
            defaults to ``[1, 2, 2, 4, 4, 4]``.
        label_cond_in_channels (int): Forwarded to ``UNetModel``.
        use_label_cond_conv (bool): Forwarded to ``UNetModel``.
        pretrained_weights_path (str, optional): Path to a saved ``state_dict``.
            If None, the model keeps its freshly initialised weights.
    Returns:
        The ``UNetModel`` instance. Callers move it to the GPU themselves
        (e.g. ``model.cuda()``).
    """
    if channel_mult is None:
        # Fresh list per call instead of a shared mutable default argument.
        channel_mult = [1, 2, 2, 4, 4, 4]
    model = UNetModel(
        image_size=128,
        in_channels=in_channels,
        model_channels=64,
        out_channels=out_channels,
        num_res_blocks=2,
        attention_resolutions=(),
        dropout=0.0,
        channel_mult=channel_mult,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=True,
        use_new_attention_order=False,
        dims=3,
        num_groups=32,
        bottleneck_attention=False,
        additive_skips=True,
        resample_2d=False,
        label_cond_in_channels=label_cond_in_channels,
        use_label_cond_conv=use_label_cond_conv,
    )
    if pretrained_weights_path is not None:
        # Load onto the CPU: load_state_dict copies the values into the model's own
        # parameters anyway, so no GPU is needed here. torch.load unpickles the
        # file, so only load trusted checkpoints.
        state_dict = torch.load(pretrained_weights_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)

    return model

In [None]:
def get_scheduler(sch, num_inference_steps):
    """
    Build a DPM-Solver++ multistep scheduler with its timesteps already set.

    Arguments:
        sch (str): One of "DPM++_2M", "DPM++_2M_Karras", "DPM++_2M_SDE",
            "DPM++_2M_SDE_Karras".
        num_inference_steps (int): Number of denoising steps.
    Returns:
        DPMSolverMultistepScheduler configured for ``num_inference_steps`` steps.
    Raises:
        ValueError: If ``sch`` is not one of the supported names (previously an
            UnboundLocalError was raised).
    """
    variants = {
        # name: (use_karras_sigmas, algorithm_type)
        "DPM++_2M": (False, "dpmsolver++"),
        "DPM++_2M_Karras": (True, "dpmsolver++"),
        "DPM++_2M_SDE": (False, "sde-dpmsolver++"),
        "DPM++_2M_SDE_Karras": (True, "sde-dpmsolver++"),
    }
    try:
        use_karras_sigmas, algorithm_type = variants[sch]
    except KeyError:
        raise ValueError(f"Unknown scheduler {sch!r}; expected one of {sorted(variants)}") from None

    scheduler = DPMSolverMultistepScheduler(
        num_train_timesteps=1000,
        variance_type="fixed_large",
        # The model predicts the denoised sample directly, not the noise.
        prediction_type="sample",
        use_karras_sigmas=use_karras_sigmas,
        algorithm_type=algorithm_type,
        # Possible alternative: use_beta_sigmas=True (https://huggingface.co/papers/2407.12173)
    )
    scheduler.set_timesteps(num_inference_steps=num_inference_steps)
    return scheduler

# Prediction for CT scans

## Whole CT scan generation

In [None]:
def get_loader(image_key, label_key, clip_min, clip_max, image_size, no_seg, full_background, data_split_json, base_dir):
    """
    Build the MONAI dataset and loader for the "training" list of a decathlon-style JSON.

    Images and labels are loaded together, made channel-first, cast to float32 and
    oriented to RAS. Image intensities in [clip_min, clip_max] are clipped and mapped
    onto [-1, 1]. Both image and label are padded/cropped to ``image_size``, and the
    label is converted with ``ConvertToMultiChannel_BackandForeground_Contrastd``.

    Arguments:
        image_key (str): Dictionary key of the image.
        label_key (str): Dictionary key of the label.
        clip_min (number): Lower intensity bound for clipping/scaling.
        clip_max (number): Upper intensity bound for clipping/scaling.
        image_size (tuple of int): Spatial size after padding/cropping.
        no_seg (bool): Forwarded to the label conversion transform.
        full_background (bool): Forwarded to the label conversion transform.
        data_split_json (str): Path to the data split JSON file.
        base_dir (str): Base directory that the JSON paths are relative to.
    Returns:
        tuple: ``(DataLoader, CacheDataset, list of data dictionaries)``.
    """
    paired_keys = [image_key, label_key]
    pipeline = Compose([
        LoadImaged(keys=paired_keys, meta_key_postfix="meta_dict", image_only=False),
        EnsureChannelFirstd(keys=paired_keys),
        EnsureTyped(keys=paired_keys, dtype=torch.float32),
        Orientationd(keys=paired_keys, axcodes="RAS"),
        ScaleIntensityRanged(
            keys=[image_key],
            a_min=float(clip_min),
            a_max=float(clip_max),
            b_min=-1.0,
            b_max=1.0,
            clip=True,
        ),
        # Constant padding with -1 (applied to both image and label).
        ResizeWithPadOrCropd(keys=paired_keys, spatial_size=image_size, mode="constant", value=-1),
        ConvertToMultiChannel_BackandForeground_Contrastd(
            keys=[label_key], no_seg=no_seg, full_background=full_background
        ),
        # Final cast so every output is float32 after the label conversion.
        EnsureTyped(keys=paired_keys, dtype=torch.float32),
    ])

    data_set = load_decathlon_datalist(
        data_split_json,
        is_segmentation=True,
        data_list_key="training",
        base_dir=base_dir,
    )

    print(f"Training cases: {len(data_set)}")
    print(data_set[-1:])
    print(f"TOTAL cases {len(data_set)}")

    # cache_rate=0: nothing is cached, transforms run on every access.
    dataset = CacheDataset(
        data=data_set,
        transform=pipeline,
        cache_rate=0,
        copy_cache=False,
        progress=True,
        num_workers=4,
    )
    loader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
        shuffle=False,
    )
    return loader, dataset, data_set

---
### Tumour generation

#### hnn_CT_conv_before_concat__DA_tumorW_0_28_11_2024_11:19:14
* HU between -200 and 200. Tumour weight 0. DA ROI. ROI and segmentation as condition, fed first to a conv layer.

In [None]:
# Fixed for CT
# Data settings shared by the CT experiments: `seg` holds the label condition
# that run_inference() passes to the model.
image_key = "image"
label_key = "seg"
image_size = (256, 256, 256)
data_split_json = "../../HnN_cancer_data/HnN_cancer_data_1_1_1_256_256_256/data_split.json"
base_dir = "../../HnN_cancer_data/HnN_cancer_data_1_1_1_256_256_256"
full_background = False
no_seg = False

# Checkpoint-specific settings (hnn_CT_conv_before_concat__DA_tumorW_0_28_11_2024_11:19:14):
# clip_min/clip_max are the intensity window get_loader() maps onto [-1, 1];
# use_label_cond_conv=True enables the model's label-condition conv path.
clip_min = -200
clip_max = 200
in_channels = 32
label_cond_in_channels = 3
use_label_cond_conv = True
pretrained_weights_path = '../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/runs/hnn_CT_conv_before_concat__DA_tumorW_0_28_11_2024_11:19:14/checkpoints/hnn_1300000.pt'  # Specify the correct path
    
# Build the network, load the checkpoint, then switch to eval mode on the GPU.
model = get_model(in_channels=in_channels, 
                  label_cond_in_channels=label_cond_in_channels, 
                  use_label_cond_conv=use_label_cond_conv,
                  pretrained_weights_path=pretrained_weights_path)
model.eval()
model.cuda()

# `dl` is the module-level loader that run_inference() iterates over.
dl, ds, data_set = get_loader(image_key=image_key, 
                              label_key=label_key, 
                              clip_min=clip_min, 
                              clip_max=clip_max, 
                              image_size=image_size, 
                              no_seg=no_seg, 
                              full_background=full_background, 
                              data_split_json=data_split_json, 
                              base_dir=base_dir)
print("Loaded")

In [None]:
def run_inference(model, scheduler_list, n, num_inference_steps, clip_min, clip_max, out_path, data_loader=None):
    """
    Generate CT volumes with the label-conditioned diffusion model and save them as NIfTI.

    For each scheduler name and each of the first ``n`` batches, 8-channel 128^3
    Gaussian noise is denoised for ``num_inference_steps`` steps, with ``batch["seg"]``
    passed to the model as ``label_condition``. The 8 channels are recombined with
    the inverse 3-D Haar wavelet transform, clipped to [-1, 1], rescaled to
    [clip_min, clip_max] and written to ``out_path`` together with the segmentation.

    Arguments:
        model: Diffusion UNet, called as ``model(x, timesteps=t, label_condition=c)``.
        scheduler_list (list of str): Scheduler names understood by ``get_scheduler``.
        n (int): Number of batches to generate per scheduler.
        num_inference_steps (int): Number of denoising steps.
        clip_min (int): Lower bound of the saved intensity range.
        clip_max (int): Upper bound of the saved intensity range.
        out_path (str): Existing directory the .nii.gz files are written to.
        data_loader (iterable, optional): Batches with a "seg" entry; defaults to
            the module-level ``dl``.
    """
    loader = dl if data_loader is None else data_loader
    model.cuda()
    for sch in scheduler_list:
        scheduler = get_scheduler(sch, num_inference_steps)

        for idx, batch in enumerate(loader):
            # Start from pure Gaussian noise over the 8 wavelet sub-bands.
            final_scan = torch.randn(1, 8, 128, 128, 128).cuda()

            label_condition = batch["seg"].cuda()
            # Channel 2 is saved as the segmentation; channel 1 ("contrast") selects the file tag.
            segmentation = label_condition[0][2]
            contrast_tensor = label_condition[0][1]

            with torch.no_grad():
                # Reverse diffusion: each model prediction gives the previous (less noisy) sample.
                for timestep in tqdm(scheduler.timesteps, desc="Processing timesteps"):
                    t = torch.tensor([timestep] * final_scan.shape[0]).cuda()
                    model_output = model(final_scan, timesteps=t, label_condition=label_condition)
                    final_scan = scheduler.step(model_output=model_output, timestep=timestep, sample=final_scan)["prev_sample"]

                # Inverse Haar DWT. Slicing with i:i + 1 keeps the (B, 1, D, H, W) layout;
                # reshaping with .view(B, 1, H, W, D) would scramble non-cubic volumes.
                # The low-pass band is scaled by 3, presumably undoing a /3 applied at
                # training time - TODO confirm.
                subbands = [final_scan[:, i:i + 1] for i in range(8)]
                subbands[0] = subbands[0] * 3.
                final_scan = idwt(*subbands)

            final_image_np = final_scan.squeeze().cpu().numpy()  # drop singleton batch/channel dims

            # File-name tag depends only on whether the "contrast" channel is non-empty.
            ending_name = "_Contrast" if th.sum(contrast_tensor) != 0 else "out_contrast"

            sample_denorm = np.clip(final_image_np, a_min=-1, a_max=1)  # remove very high and low values
            # NOTE(review): rescale_array stretches the sample's own min/max (not the fixed
            # [-1, 1] range) onto [clip_min, clip_max]; both agree only when the clipped
            # sample actually reaches -1 and 1.
            sample_denorm = rescale_array(arr=sample_denorm, minv=int(clip_min), maxv=int(clip_max))

            # Identity affine: the outputs carry no spatial metadata.
            synth_ct_scan_output = os.path.join(out_path, f'CT_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(sample_denorm, affine=np.eye(4)), synth_ct_scan_output)

            seg_ct_scan_output = os.path.join(out_path, f'label_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(segmentation.cpu().numpy(), affine=np.eye(4)), seg_ct_scan_output)

            if idx + 1 == n:
                break
            
# Inference settings for this checkpoint; outputs are written to `out_path`.
scheduler_list = ["DPM++_2M", "DPM++_2M_Karras", "DPM++_2M_SDE", "DPM++_2M_SDE_Karras"]
n = 1  # number of batches generated per scheduler
num_inference_steps = 100
out_path = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/notebooks/trash/hnn_CT_conv_before_concat__DA_tumorW_0_28_11_2024_11:19:14"
os.makedirs(out_path, exist_ok=True)
# Pass the variable (not a hard-coded 100) so editing it above takes effect.
run_inference(model=model,
              scheduler_list=scheduler_list,
              n=n,
              num_inference_steps=num_inference_steps,
              clip_min=clip_min,
              clip_max=clip_max,
              out_path=out_path)
        

#### hnn_CT_concat_cond__DA_tumorW_0_28_11_2024_11:36:09 
* HU between -200 and 200. tumour weight 0. DA ROI. downsampled ROI and segmentation as condition.

In [None]:
def run_inference(model, scheduler_list, n, num_inference_steps, clip_min, clip_max, out_path, data_loader=None):
    """
    Generate CT volumes with a model conditioned on a downsampled label and save them as NIfTI.

    For each scheduler name and each of the first ``n`` batches, 8-channel 128^3
    Gaussian noise is denoised for ``num_inference_steps`` steps. At every step the
    model input is the noisy sample concatenated (along channels) with the label
    condition resized to 128^3 by nearest-neighbour interpolation. The result is
    inverse Haar wavelet transformed, clipped to [-1, 1], rescaled to
    [clip_min, clip_max] and written to ``out_path`` with the segmentation.

    Arguments:
        model: Diffusion UNet, called as ``model(x, timesteps=t, label_condition=c)``.
        scheduler_list (list of str): Scheduler names understood by ``get_scheduler``.
        n (int): Number of batches to generate per scheduler.
        num_inference_steps (int): Number of denoising steps.
        clip_min (int): Lower bound of the saved intensity range.
        clip_max (int): Upper bound of the saved intensity range.
        out_path (str): Existing directory the .nii.gz files are written to.
        data_loader (iterable, optional): Batches with a "seg" entry; defaults to
            the module-level ``dl``.
    """
    loader = dl if data_loader is None else data_loader
    # Built once: the transform does not depend on the batch.
    resize = Resize((128, 128, 128), size_mode='all', mode="nearest", align_corners=None, anti_aliasing=False, anti_aliasing_sigma=None, dtype=torch.float32, lazy=False)
    model.cuda()
    for sch in scheduler_list:
        scheduler = get_scheduler(sch, num_inference_steps)

        for idx, batch in enumerate(loader):
            # Start from pure Gaussian noise over the 8 wavelet sub-bands.
            final_scan = torch.randn(1, 8, 128, 128, 128).cuda()

            label_condition = batch["seg"].cuda()
            # Channel 2 is saved as the segmentation; channel 1 ("contrast") selects the file tag.
            segmentation = label_condition[0][2]
            contrast_tensor = label_condition[0][1]

            # Label condition resized to the 128^3 grid of the noisy sample.
            label_cond_down = resize(label_condition[0]).unsqueeze(0).cuda()

            with torch.no_grad():
                # Reverse diffusion: each model prediction gives the previous (less noisy) sample.
                for timestep in tqdm(scheduler.timesteps, desc="Processing timesteps"):
                    t = torch.tensor([timestep] * final_scan.shape[0]).cuda()
                    input_model = torch.cat((final_scan, label_cond_down), dim=1)
                    model_output = model(input_model, timesteps=t, label_condition=label_condition)
                    final_scan = scheduler.step(model_output=model_output, timestep=timestep, sample=final_scan)["prev_sample"]

                # Inverse Haar DWT. Slicing with i:i + 1 keeps the (B, 1, D, H, W) layout;
                # reshaping with .view(B, 1, H, W, D) would scramble non-cubic volumes.
                # The low-pass band is scaled by 3, presumably undoing a /3 applied at
                # training time - TODO confirm.
                subbands = [final_scan[:, i:i + 1] for i in range(8)]
                subbands[0] = subbands[0] * 3.
                final_scan = idwt(*subbands)

            final_image_np = final_scan.squeeze().cpu().numpy()  # drop singleton batch/channel dims

            # File-name tag depends only on whether the "contrast" channel is non-empty.
            ending_name = "_Contrast" if th.sum(contrast_tensor) != 0 else "out_contrast"

            sample_denorm = np.clip(final_image_np, a_min=-1, a_max=1)  # remove very high and low values
            # NOTE(review): rescale_array stretches the sample's own min/max (not the fixed
            # [-1, 1] range) onto [clip_min, clip_max]; both agree only when the clipped
            # sample actually reaches -1 and 1.
            sample_denorm = rescale_array(arr=sample_denorm, minv=int(clip_min), maxv=int(clip_max))

            # Identity affine: the outputs carry no spatial metadata.
            synth_ct_scan_output = os.path.join(out_path, f'CT_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(sample_denorm, affine=np.eye(4)), synth_ct_scan_output)

            seg_ct_scan_output = os.path.join(out_path, f'label_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(segmentation.cpu().numpy(), affine=np.eye(4)), seg_ct_scan_output)

            if idx + 1 == n:
                break
            
           
                

In [None]:
# Fixed for CT
# Data settings shared by the CT experiments: `seg` holds the label condition
# that run_inference() passes to the model.
image_key = "image"
label_key = "seg"
image_size = (256, 256, 256)
data_split_json = "../../HnN_cancer_data/HnN_cancer_data_1_1_1_256_256_256/data_split.json"
base_dir = "../../HnN_cancer_data/HnN_cancer_data_1_1_1_256_256_256"
full_background = False
no_seg = False

# To change
# Checkpoint-specific settings (hnn_CT_concat_cond__DA_tumorW_0_28_11_2024_11:36:09).
# in_channels = 11 presumably = 8 noisy wavelet sub-bands + 3 downsampled label
# channels concatenated by run_inference() - TODO confirm.
clip_min = -200
clip_max = 200
in_channels = 11
label_cond_in_channels = 0
use_label_cond_conv = False
pretrained_weights_path = '../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/runs/hnn_CT_concat_cond__DA_tumorW_0_28_11_2024_11:36:09/checkpoints/hnn_1355000.pt'  # Specify the correct path
    
# Build the network, load the checkpoint, then switch to eval mode on the GPU.
model = get_model(in_channels=in_channels, 
                  label_cond_in_channels=label_cond_in_channels, 
                  use_label_cond_conv=use_label_cond_conv,
                  pretrained_weights_path=pretrained_weights_path)
model.eval()
model.cuda()

# `dl` is the module-level loader that run_inference() iterates over.
dl, ds, data_set = get_loader(image_key=image_key, 
                              label_key=label_key, 
                              clip_min=clip_min, 
                              clip_max=clip_max, 
                              image_size=image_size, 
                              no_seg=no_seg, 
                              full_background=full_background, 
                              data_split_json=data_split_json, 
                              base_dir=base_dir)
print("Loaded model and data loader")

# Control inference parameters
# n = number of batches generated per scheduler; outputs go to `out_path`.
scheduler_list = ["DPM++_2M", "DPM++_2M_Karras", "DPM++_2M_SDE", "DPM++_2M_SDE_Karras"]
n = 1
num_inference_steps = 100
out_path = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/notebooks/trash/hnn_CT_concat_cond__DA_tumorW_0_28_11_2024_11:36:09"
os.makedirs(out_path, exist_ok=True)
run_inference(model=model,
              scheduler_list=scheduler_list,
               n=n, 
               num_inference_steps=num_inference_steps, 
               clip_min=clip_min,
               clip_max=clip_max, 
               out_path=out_path) 

#### hnn_CT_wavelet_cond__DA_tumorW_0_3_12_2024_15:36:29 
* HU between -200 and 200. Tumour weight 0. DA ROI. Wavelet-transformed ROI and segmentation as condition.

In [None]:
def run_inference(model, scheduler_list, n, num_inference_steps, clip_min, clip_max, out_path, data_loader=None):
    """
    Generate CT volumes with a model conditioned on a wavelet-transformed label and save them.

    For each scheduler name and each of the first ``n`` batches, 8-channel 128^3
    Gaussian noise is denoised for ``num_inference_steps`` steps. Every channel of the
    label condition is decomposed with the 3-D Haar DWT (low-pass band divided by 3)
    and the resulting sub-bands are concatenated to the noisy sample at every step.
    The result is inverse Haar wavelet transformed, clipped to [-1, 1], rescaled to
    [clip_min, clip_max] and written to ``out_path`` with the segmentation.

    Arguments:
        model: Diffusion UNet, called as ``model(x, timesteps=t, label_condition=c)``.
        scheduler_list (list of str): Scheduler names understood by ``get_scheduler``.
        n (int): Number of batches to generate per scheduler.
        num_inference_steps (int): Number of denoising steps.
        clip_min (int): Lower bound of the saved intensity range.
        clip_max (int): Upper bound of the saved intensity range.
        out_path (str): Existing directory the .nii.gz files are written to.
        data_loader (iterable, optional): Batches with a "seg" entry; defaults to
            the module-level ``dl``.
    """
    loader = dl if data_loader is None else data_loader
    model.cuda()
    for sch in scheduler_list:
        scheduler = get_scheduler(sch, num_inference_steps)

        for idx, batch in enumerate(loader):
            # Start from pure Gaussian noise over the 8 wavelet sub-bands.
            final_scan = torch.randn(1, 8, 128, 128, 128).cuda()

            label_condition = batch["seg"].cuda()
            # Channel 2 is saved as the segmentation; channel 1 ("contrast") selects the file tag.
            segmentation = label_condition[0][2]
            contrast_tensor = label_condition[0][1]

            with torch.no_grad():
                # 8 sub-bands per label channel, in DWT output order, low-pass band / 3.
                # Collected in a list and concatenated once (no tensor == None checks).
                cond_bands = []
                for condition in label_condition[0]:
                    bands = list(dwt(condition.unsqueeze(0).unsqueeze(0)))
                    bands[0] = bands[0] / 3.
                    cond_bands.extend(bands)
                cond_dwt = th.cat(cond_bands, dim=1).cuda()

                # Reverse diffusion: each model prediction gives the previous (less noisy) sample.
                for timestep in tqdm(scheduler.timesteps, desc="Processing timesteps"):
                    t = torch.tensor([timestep] * final_scan.shape[0]).cuda()
                    input_model = torch.cat((final_scan, cond_dwt), dim=1)
                    model_output = model(input_model, timesteps=t, label_condition=label_condition)
                    final_scan = scheduler.step(model_output=model_output, timestep=timestep, sample=final_scan)["prev_sample"]

                # Inverse Haar DWT. Slicing with i:i + 1 keeps the (B, 1, D, H, W) layout;
                # reshaping with .view(B, 1, H, W, D) would scramble non-cubic volumes.
                # Multiplying the low-pass band by 3 mirrors the /3 applied above.
                subbands = [final_scan[:, i:i + 1] for i in range(8)]
                subbands[0] = subbands[0] * 3.
                final_scan = idwt(*subbands)

            final_image_np = final_scan.squeeze().cpu().numpy()  # drop singleton batch/channel dims

            # File-name tag depends only on whether the "contrast" channel is non-empty.
            ending_name = "_Contrast" if th.sum(contrast_tensor) != 0 else "out_contrast"

            sample_denorm = np.clip(final_image_np, a_min=-1, a_max=1)  # remove very high and low values
            # NOTE(review): rescale_array stretches the sample's own min/max (not the fixed
            # [-1, 1] range) onto [clip_min, clip_max]; both agree only when the clipped
            # sample actually reaches -1 and 1.
            sample_denorm = rescale_array(arr=sample_denorm, minv=int(clip_min), maxv=int(clip_max))

            # Identity affine: the outputs carry no spatial metadata.
            synth_ct_scan_output = os.path.join(out_path, f'CT_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(sample_denorm, affine=np.eye(4)), synth_ct_scan_output)

            seg_ct_scan_output = os.path.join(out_path, f'label_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(segmentation.cpu().numpy(), affine=np.eye(4)), seg_ct_scan_output)

            if idx + 1 == n:
                break
            
           
                

In [None]:
# Fixed for CT
# Data settings shared by the CT experiments: `seg` holds the label condition
# that run_inference() passes to the model.
image_key = "image"
label_key = "seg"
image_size = (256, 256, 256)
data_split_json = "../../HnN_cancer_data/HnN_cancer_data_1_1_1_256_256_256/data_split.json"
base_dir = "../../HnN_cancer_data/HnN_cancer_data_1_1_1_256_256_256"
full_background = False
no_seg = False

# To change
# Checkpoint-specific settings (hnn_CT_wavelet_cond__DA_tumorW_0_3_12_2024_15:36:29).
# in_channels = 32 presumably = 8 noisy sub-bands + 3 label channels x 8 DWT
# sub-bands concatenated by run_inference() - TODO confirm.
clip_min = -200
clip_max = 200
in_channels = 32
label_cond_in_channels = 0
use_label_cond_conv = False
pretrained_weights_path = '../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/runs/hnn_CT_wavelet_cond__DA_tumorW_0_3_12_2024_15:36:29/checkpoints/hnn_985000.pt'  # Specify the correct path
    
# Build the network, load the checkpoint, then switch to eval mode on the GPU.
model = get_model(in_channels=in_channels, 
                  label_cond_in_channels=label_cond_in_channels, 
                  use_label_cond_conv=use_label_cond_conv,
                  pretrained_weights_path=pretrained_weights_path)
model.eval()
model.cuda()

# `dl` is the module-level loader that run_inference() iterates over.
dl, ds, data_set = get_loader(image_key=image_key, 
                              label_key=label_key, 
                              clip_min=clip_min, 
                              clip_max=clip_max, 
                              image_size=image_size, 
                              no_seg=no_seg, 
                              full_background=full_background, 
                              data_split_json=data_split_json, 
                              base_dir=base_dir)
print("Loaded model and data loader")

# Control inference parameters
# n = number of batches generated per scheduler; outputs go to `out_path`.
scheduler_list = ["DPM++_2M", "DPM++_2M_Karras", "DPM++_2M_SDE", "DPM++_2M_SDE_Karras"]
n = 1
num_inference_steps = 100
out_path = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/notebooks/trash/hnn_CT_wavelet_cond__DA_tumorW_0_3_12_2024_15:36:29"
os.makedirs(out_path, exist_ok=True)
run_inference(model=model,
              scheduler_list=scheduler_list,
               n=n, 
               num_inference_steps=num_inference_steps, 
               clip_min=clip_min,
               clip_max=clip_max, 
               out_path=out_path) 

---
### Bone segmentation 

#### hnn_CT_concat_cond__DA_tumorW_0_28_11_2024_14:15:39
* HU between -200 and 200. DA ROI. downsampled ROI as condition. 

In [None]:
def run_inference(model, scheduler_list, n, num_inference_steps, clip_min, clip_max, out_path, data_loader=None):
    """
    Generates CT volumes with a wavelet-domain diffusion model and saves them as nii.gz files.

    For every scheduler in scheduler_list, the first n batches of the data loader are used as
    conditions. Each sample starts from Gaussian noise of shape (1, 8, 128, 128, 128). It is
    denoised over the scheduler's timesteps and mapped back to image space with the inverse
    3D Haar DWT (idwt). It is then clipped to [-1, 1] and rescaled to [clip_min, clip_max]. The
    segmentation channel of the condition is saved next to each generated volume.
    Arguments:
        model: Denoising network, called as model(x, timesteps=t, label_condition=c).
        scheduler_list (list of str): Scheduler names passed to get_scheduler.
        n (int): Number of batches (samples) generated per scheduler.
        num_inference_steps (int): Number of denoising steps per sample.
        clip_min (int): Lower bound of the output intensity range.
        clip_max (int): Upper bound of the output intensity range.
        out_path (str): Directory the .nii.gz files are written to.
        data_loader (iterable, optional): Loader yielding dicts with a "seg" tensor whose channels
            are (no-contrast ROI, contrast ROI, segmentation). Defaults to the notebook-global dl.
    Return:
        None. Writes CT_* and label_* files into out_path.
    """
    loader = dl if data_loader is None else data_loader
    # Nearest-neighbour resampling keeps the ROI masks binary at the 128^3 wavelet resolution.
    resize = Resize((128, 128, 128), size_mode='all', mode="nearest", align_corners=None,
                    anti_aliasing=False, anti_aliasing_sigma=None, dtype=torch.float32, lazy=False)
    model.cuda()
    for sch in scheduler_list:
        for idx, batch in enumerate(loader):
            # Fresh scheduler per sample. Multistep DPM-Solver schedulers carry state (step index,
            # previous model outputs) across step() calls, so reusing one instance breaks every
            # sample after the first. Assumes get_scheduler calls set_timesteps.
            scheduler = get_scheduler(sch, num_inference_steps)

            final_scan = torch.randn(1, 8, 128, 128, 128).cuda()

            label_condition = batch["seg"].cuda()
            segmentation = label_condition[0][2]
            contrast_tensor = label_condition[0][1]
            # Only the two ROI channels (no-contrast, contrast) condition the model.
            label_condition = label_condition[:, 0:2, :, :, :]
            print(f"label_condition: {label_condition.shape}")
            label_cond_down = resize(label_condition[0]).unsqueeze(0)

            with torch.no_grad():
                # Reverse diffusion: the downsampled ROI is concatenated to the noisy wavelet
                # coefficients at every step; the un-resized ROI goes in as label_condition.
                for timestep in tqdm(scheduler.timesteps, desc="Processing timesteps"):
                    input_model = torch.cat((final_scan, label_cond_down), dim=1)
                    t = torch.tensor([timestep] * final_scan.shape[0]).cuda()
                    noise_pred = model(input_model, timesteps=t, label_condition=label_condition)
                    final_scan = scheduler.step(model_output=noise_pred, timestep=timestep,
                                                sample=final_scan)['prev_sample']

                # Inverse DWT over the 8 sub-band channels. Channel 0 is multiplied by 3 before
                # the transform. Slicing keeps the (D, H, W) layout intact for any volume shape.
                bands = [final_scan[:, c:c + 1] for c in range(8)]
                bands[0] = bands[0] * 3.
                final_scan = idwt(*bands)

            ending_name = "_Contrast" if torch.sum(contrast_tensor) != 0 else "out_contrast"

            # Remove very high and low values, then map into the [clip_min, clip_max] window.
            # NOTE(review): if rescale_array is MONAI's, it stretches the array's own min/max (not
            # a fixed [-1, 1]) onto [clip_min, clip_max] — confirm this is intended.
            sample_denorm = np.clip(final_scan.squeeze().cpu().numpy(), a_min=-1, a_max=1)
            sample_denorm = rescale_array(arr=sample_denorm, minv=int(clip_min), maxv=int(clip_max))

            # The whole volume is saved (no ROI cropping); identity affine for simplicity.
            synth_ct_scan_output = os.path.join(out_path, f'CT_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(sample_denorm, affine=np.eye(4)), synth_ct_scan_output)

            seg_ct_scan_output = os.path.join(out_path, f'label_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(segmentation.cpu().numpy(), affine=np.eye(4)), seg_ct_scan_output)

            if idx + 1 == n:
                break
            
           
                

In [None]:
# Fixed for CT
# Dataset keys, volume size and data-split locations shared by the CT experiments.
image_key = "image"
label_key = "seg"
image_size = (256, 256, 256)
data_split_json = "../../HnN_cancer_data/HnN_cancer_data_1_1_1_256_256_256/data_split.json"
base_dir = "../../HnN_cancer_data/HnN_cancer_data_1_1_1_256_256_256"
full_background = False
no_seg = False

# To change
# Intensity window passed to get_loader and used by run_inference to rescale the output.
clip_min = -200
clip_max = 200
# 8 noisy wavelet sub-bands + 2 downsampled ROI channels, concatenated in run_inference.
in_channels = 10
label_cond_in_channels = 0
use_label_cond_conv = False
pretrained_weights_path = '../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/runs/hnn_CT_concat_cond__DA_tumorW_0_28_11_2024_14:15:39/checkpoints/hnn_1360000.pt'  # Specify the correct path

# Build the network from the checkpoint and move it to the GPU in evaluation mode.
model = get_model(in_channels=in_channels, 
                  label_cond_in_channels=label_cond_in_channels, 
                  use_label_cond_conv=use_label_cond_conv,
                  pretrained_weights_path=pretrained_weights_path)
model.eval()
model.cuda()

# run_inference (previous cell) iterates over the global dl created here.
dl, ds, data_set = get_loader(image_key=image_key, 
                              label_key=label_key, 
                              clip_min=clip_min, 
                              clip_max=clip_max, 
                              image_size=image_size, 
                              no_seg=no_seg, 
                              full_background=full_background, 
                              data_split_json=data_split_json, 
                              base_dir=base_dir)
print("Loaded model and data loader")

# Control inference parameters
# For each scheduler, n samples are generated with num_inference_steps denoising steps each.
scheduler_list = ["DPM++_2M", "DPM++_2M_Karras", "DPM++_2M_SDE", "DPM++_2M_SDE_Karras"]
n = 1
num_inference_steps = 100
out_path = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/notebooks/trash/hnn_CT_concat_cond__DA_tumorW_0_28_11_2024_14:15:39/"
os.makedirs(out_path, exist_ok=True)
run_inference(model=model,
              scheduler_list=scheduler_list,
               n=n, 
               num_inference_steps=num_inference_steps, 
               clip_min=clip_min,
               clip_max=clip_max, 
               out_path=out_path) 

#### hnn_CT_concat_cond__DA_tumorW_0_28_11_2024_14:18:20
* HU clipped to [-1000, 1000]. DA ROI. The downsampled ROI is used as the condition.

In [None]:
def run_inference(model, scheduler_list, n, num_inference_steps, clip_min, clip_max, out_path, data_loader=None):
    """
    Generates CT volumes with a wavelet-domain diffusion model and saves them as nii.gz files.

    For every scheduler in scheduler_list, the first n batches of the data loader are used as
    conditions. Each sample starts from Gaussian noise of shape (1, 8, 128, 128, 128). It is
    denoised over the scheduler's timesteps and mapped back to image space with the inverse
    3D Haar DWT (idwt). It is then clipped to [-1, 1] and rescaled to [clip_min, clip_max]. The
    segmentation channel of the condition is saved next to each generated volume.
    Arguments:
        model: Denoising network, called as model(x, timesteps=t, label_condition=c).
        scheduler_list (list of str): Scheduler names passed to get_scheduler.
        n (int): Number of batches (samples) generated per scheduler.
        num_inference_steps (int): Number of denoising steps per sample.
        clip_min (int): Lower bound of the output intensity range.
        clip_max (int): Upper bound of the output intensity range.
        out_path (str): Directory the .nii.gz files are written to.
        data_loader (iterable, optional): Loader yielding dicts with a "seg" tensor whose channels
            are (no-contrast ROI, contrast ROI, segmentation). Defaults to the notebook-global dl.
    Return:
        None. Writes CT_* and label_* files into out_path.
    """
    loader = dl if data_loader is None else data_loader
    # Nearest-neighbour resampling keeps the ROI masks binary at the 128^3 wavelet resolution.
    resize = Resize((128, 128, 128), size_mode='all', mode="nearest", align_corners=None,
                    anti_aliasing=False, anti_aliasing_sigma=None, dtype=torch.float32, lazy=False)
    model.cuda()
    for sch in scheduler_list:
        for idx, batch in enumerate(loader):
            # Fresh scheduler per sample. Multistep DPM-Solver schedulers carry state (step index,
            # previous model outputs) across step() calls, so reusing one instance breaks every
            # sample after the first. Assumes get_scheduler calls set_timesteps.
            scheduler = get_scheduler(sch, num_inference_steps)

            final_scan = torch.randn(1, 8, 128, 128, 128).cuda()

            label_condition = batch["seg"].cuda()
            segmentation = label_condition[0][2]
            contrast_tensor = label_condition[0][1]
            # Only the two ROI channels (no-contrast, contrast) condition the model.
            label_condition = label_condition[:, 0:2, :, :, :]
            print(f"label_condition: {label_condition.shape}")
            label_cond_down = resize(label_condition[0]).unsqueeze(0)

            with torch.no_grad():
                # Reverse diffusion: the downsampled ROI is concatenated to the noisy wavelet
                # coefficients at every step; the un-resized ROI goes in as label_condition.
                for timestep in tqdm(scheduler.timesteps, desc="Processing timesteps"):
                    input_model = torch.cat((final_scan, label_cond_down), dim=1)
                    t = torch.tensor([timestep] * final_scan.shape[0]).cuda()
                    noise_pred = model(input_model, timesteps=t, label_condition=label_condition)
                    final_scan = scheduler.step(model_output=noise_pred, timestep=timestep,
                                                sample=final_scan)['prev_sample']

                # Inverse DWT over the 8 sub-band channels. Channel 0 is multiplied by 3 before
                # the transform. Slicing keeps the (D, H, W) layout intact for any volume shape.
                bands = [final_scan[:, c:c + 1] for c in range(8)]
                bands[0] = bands[0] * 3.
                final_scan = idwt(*bands)

            ending_name = "_Contrast" if torch.sum(contrast_tensor) != 0 else "out_contrast"

            # Remove very high and low values, then map into the [clip_min, clip_max] window.
            # NOTE(review): if rescale_array is MONAI's, it stretches the array's own min/max (not
            # a fixed [-1, 1]) onto [clip_min, clip_max] — confirm this is intended.
            sample_denorm = np.clip(final_scan.squeeze().cpu().numpy(), a_min=-1, a_max=1)
            sample_denorm = rescale_array(arr=sample_denorm, minv=int(clip_min), maxv=int(clip_max))

            # The whole volume is saved (no ROI cropping); identity affine for simplicity.
            synth_ct_scan_output = os.path.join(out_path, f'CT_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(sample_denorm, affine=np.eye(4)), synth_ct_scan_output)

            seg_ct_scan_output = os.path.join(out_path, f'label_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(segmentation.cpu().numpy(), affine=np.eye(4)), seg_ct_scan_output)

            if idx + 1 == n:
                break
            
           
                

In [None]:
# Fixed for CT
# Dataset keys, volume size and data-split locations shared by the CT experiments.
image_key = "image"
label_key = "seg"
image_size = (256, 256, 256)
data_split_json = "../../HnN_cancer_data/HnN_cancer_data_1_1_1_256_256_256/data_split.json"
base_dir = "../../HnN_cancer_data/HnN_cancer_data_1_1_1_256_256_256"
full_background = False
no_seg = False

# To change
# Intensity window passed to get_loader and used by run_inference to rescale the output.
clip_min = -1000
clip_max = 1000
# 8 noisy wavelet sub-bands + 2 downsampled ROI channels, concatenated in run_inference.
in_channels = 10
label_cond_in_channels = 0
use_label_cond_conv = False
pretrained_weights_path = '../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/runs/hnn_CT_concat_cond__DA_tumorW_0_28_11_2024_14:18:20/checkpoints/hnn_1355000.pt'  # Specify the correct path

# Build the network from the checkpoint and move it to the GPU in evaluation mode.
model = get_model(in_channels=in_channels, 
                  label_cond_in_channels=label_cond_in_channels, 
                  use_label_cond_conv=use_label_cond_conv,
                  pretrained_weights_path=pretrained_weights_path)
model.eval()
model.cuda()

# run_inference (previous cell) iterates over the global dl created here.
dl, ds, data_set = get_loader(image_key=image_key, 
                              label_key=label_key, 
                              clip_min=clip_min, 
                              clip_max=clip_max, 
                              image_size=image_size, 
                              no_seg=no_seg, 
                              full_background=full_background, 
                              data_split_json=data_split_json, 
                              base_dir=base_dir)
print("Loaded model and data loader")

# Control inference parameters
# For each scheduler, n samples are generated with num_inference_steps denoising steps each.
scheduler_list = ["DPM++_2M", "DPM++_2M_Karras", "DPM++_2M_SDE", "DPM++_2M_SDE_Karras"]
n = 1
num_inference_steps = 100
out_path = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/notebooks/trash/hnn_CT_concat_cond__DA_tumorW_0_28_11_2024_14:18:20/"
os.makedirs(out_path, exist_ok=True)
run_inference(model=model,
              scheduler_list=scheduler_list,
               n=n, 
               num_inference_steps=num_inference_steps, 
               clip_min=clip_min,
               clip_max=clip_max, 
               out_path=out_path) 

#### hnn_CT_concat_cond__data_augment_27_11_2024_17:23:26 - model retrained from the -200 to 200 HU version
* HU clipped to [-1000, 1000]. DA ROI. The downsampled ROI is used as the condition.

In [None]:
def run_inference(model, scheduler_list, n, num_inference_steps, clip_min, clip_max, out_path, data_loader=None):
    """
    Generates CT volumes with a wavelet-domain diffusion model and saves them as nii.gz files.

    For every scheduler in scheduler_list, the first n batches of the data loader are used as
    conditions. Each sample starts from Gaussian noise of shape (1, 8, 128, 128, 128). It is
    denoised over the scheduler's timesteps and mapped back to image space with the inverse
    3D Haar DWT (idwt). It is then clipped to [-1, 1] and rescaled to [clip_min, clip_max]. The
    segmentation channel of the condition is saved next to each generated volume.
    Arguments:
        model: Denoising network, called as model(x, timesteps=t, label_condition=c).
        scheduler_list (list of str): Scheduler names passed to get_scheduler.
        n (int): Number of batches (samples) generated per scheduler.
        num_inference_steps (int): Number of denoising steps per sample.
        clip_min (int): Lower bound of the output intensity range.
        clip_max (int): Upper bound of the output intensity range.
        out_path (str): Directory the .nii.gz files are written to.
        data_loader (iterable, optional): Loader yielding dicts with a "seg" tensor whose channels
            are (no-contrast ROI, contrast ROI, segmentation). Defaults to the notebook-global dl.
    Return:
        None. Writes CT_* and label_* files into out_path.
    """
    loader = dl if data_loader is None else data_loader
    # Nearest-neighbour resampling keeps the ROI masks binary at the 128^3 wavelet resolution.
    resize = Resize((128, 128, 128), size_mode='all', mode="nearest", align_corners=None,
                    anti_aliasing=False, anti_aliasing_sigma=None, dtype=torch.float32, lazy=False)
    model.cuda()
    for sch in scheduler_list:
        for idx, batch in enumerate(loader):
            # Fresh scheduler per sample. Multistep DPM-Solver schedulers carry state (step index,
            # previous model outputs) across step() calls, so reusing one instance breaks every
            # sample after the first. Assumes get_scheduler calls set_timesteps.
            scheduler = get_scheduler(sch, num_inference_steps)

            final_scan = torch.randn(1, 8, 128, 128, 128).cuda()

            label_condition = batch["seg"].cuda()
            segmentation = label_condition[0][2]
            contrast_tensor = label_condition[0][1]
            # Only the two ROI channels (no-contrast, contrast) condition the model.
            label_condition = label_condition[:, 0:2, :, :, :]
            print(f"label_condition: {label_condition.shape}")
            label_cond_down = resize(label_condition[0]).unsqueeze(0)

            with torch.no_grad():
                # Reverse diffusion: the downsampled ROI is concatenated to the noisy wavelet
                # coefficients at every step; the un-resized ROI goes in as label_condition.
                for timestep in tqdm(scheduler.timesteps, desc="Processing timesteps"):
                    input_model = torch.cat((final_scan, label_cond_down), dim=1)
                    t = torch.tensor([timestep] * final_scan.shape[0]).cuda()
                    noise_pred = model(input_model, timesteps=t, label_condition=label_condition)
                    final_scan = scheduler.step(model_output=noise_pred, timestep=timestep,
                                                sample=final_scan)['prev_sample']

                # Inverse DWT over the 8 sub-band channels. Channel 0 is multiplied by 3 before
                # the transform. Slicing keeps the (D, H, W) layout intact for any volume shape.
                bands = [final_scan[:, c:c + 1] for c in range(8)]
                bands[0] = bands[0] * 3.
                final_scan = idwt(*bands)

            ending_name = "_Contrast" if torch.sum(contrast_tensor) != 0 else "out_contrast"

            # Remove very high and low values, then map into the [clip_min, clip_max] window.
            # NOTE(review): if rescale_array is MONAI's, it stretches the array's own min/max (not
            # a fixed [-1, 1]) onto [clip_min, clip_max] — confirm this is intended.
            sample_denorm = np.clip(final_scan.squeeze().cpu().numpy(), a_min=-1, a_max=1)
            sample_denorm = rescale_array(arr=sample_denorm, minv=int(clip_min), maxv=int(clip_max))

            # The whole volume is saved (no ROI cropping); identity affine for simplicity.
            synth_ct_scan_output = os.path.join(out_path, f'CT_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(sample_denorm, affine=np.eye(4)), synth_ct_scan_output)

            seg_ct_scan_output = os.path.join(out_path, f'label_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(segmentation.cpu().numpy(), affine=np.eye(4)), seg_ct_scan_output)

            if idx + 1 == n:
                break
            
           
                

In [None]:
# Fixed for CT
# Dataset keys, volume size and data-split locations shared by the CT experiments.
image_key = "image"
label_key = "seg"
image_size = (256, 256, 256)
data_split_json = "../../HnN_cancer_data/HnN_cancer_data_1_1_1_256_256_256/data_split.json"
base_dir = "../../HnN_cancer_data/HnN_cancer_data_1_1_1_256_256_256"
full_background = False
no_seg = False

# To change
# Intensity window passed to get_loader and used by run_inference to rescale the output.
clip_min = -1000
clip_max = 1000
# 8 noisy wavelet sub-bands + 2 downsampled ROI channels, concatenated in run_inference.
in_channels = 10
label_cond_in_channels = 0
use_label_cond_conv = False
pretrained_weights_path = '../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/runs/hnn_CT_concat_cond__data_augment_27_11_2024_17:23:26/checkpoints/hnn_7380000.pt'  # Specify the correct path

# Build the network from the checkpoint and move it to the GPU in evaluation mode.
model = get_model(in_channels=in_channels, 
                  label_cond_in_channels=label_cond_in_channels, 
                  use_label_cond_conv=use_label_cond_conv,
                  pretrained_weights_path=pretrained_weights_path)
model.eval()
model.cuda()

# run_inference (previous cell) iterates over the global dl created here.
dl, ds, data_set = get_loader(image_key=image_key, 
                              label_key=label_key, 
                              clip_min=clip_min, 
                              clip_max=clip_max, 
                              image_size=image_size, 
                              no_seg=no_seg, 
                              full_background=full_background, 
                              data_split_json=data_split_json, 
                              base_dir=base_dir)
print("Loaded model and data loader")

# Control inference parameters
# For each scheduler, n samples are generated with num_inference_steps denoising steps each.
scheduler_list = ["DPM++_2M", "DPM++_2M_Karras", "DPM++_2M_SDE", "DPM++_2M_SDE_Karras"]
n = 5
num_inference_steps = 100
out_path = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/notebooks/trash/hnn_CT_concat_cond__data_augment_27_11_2024_17:23:26/"
os.makedirs(out_path, exist_ok=True)
run_inference(model=model,
              scheduler_list=scheduler_list,
               n=n, 
               num_inference_steps=num_inference_steps, 
               clip_min=clip_min,
               clip_max=clip_max, 
               out_path=out_path) 

---
## Cropped CT scans - Inpainting

In [None]:
from utils.convert_head_n_neck_cancer import ConvertHeadNNeckCancerd as LABEL_TRANSFORM
from utils.crop_scan_center_in_tumour import CropScanCenterInTumour
from monai.transforms import ToTensord
from monai.data import CSVDataset, CacheDataset, DataLoader, load_decathlon_datalist
from monai.data.utils import pad_list_data_collate

def get_loader(scan_name, col_names, col_types, use_dilation, clip_min, clip_max, CSV_PATH):
    """
    Builds the tumour-inpainting data loader from a CSV listing scans and labels.

    The scan and label are loaded and made channel-first. A copy of the scan is kept under
    "<scan_name>_origin", clipped to [clip_min, clip_max]. The label is converted with
    LABEL_TRANSFORM and the scan is cropped around the tumour with CropScanCenterInTumour. The
    cropped scan is then rescaled from [clip_min, clip_max] to [-1, 1], and all model inputs and
    metadata entries are converted to tensors.
    Arguments:
        scan_name (str): CSV column / dictionary key holding the scan path.
        col_names (list of str): Columns read from the CSV.
        col_types (dict): Per-column type specification forwarded to CSVDataset.
        use_dilation (bool): Forwarded as `dilation` to CropScanCenterInTumour.
        clip_min (int): Lower bound of the intensity window.
        clip_max (int): Upper bound of the intensity window.
        CSV_PATH (str): Path to the CSV file.
    Return:
        (DataLoader, CacheDataset): batch size 1, unshuffled loader and its dataset.
    """
    low, high = int(clip_min), int(clip_max)
    origin_key = f"{scan_name}_origin"
    scan_and_label = [scan_name, "label"]
    tensor_keys = [scan_name, 'no_contrast_tensor', 'contrast_tensor', 'scan_volume_crop', 'scan_volume_crop_pad',
                   'label', 'label_crop_pad', 'center_x', 'center_y', 'center_z',
                   'x_extreme_min', 'x_extreme_max', 'y_extreme_min', 'y_extreme_max',
                   'z_extreme_min', 'z_extreme_max', 'x_size', 'y_size', 'z_size']

    pipeline = [
        LoadImaged(keys=scan_and_label, image_only=False),
        EnsureChannelFirstd(keys=scan_and_label),
        EnsureTyped(keys=scan_and_label),
        # Untouched (only clipped) copy of the scan, kept alongside the normalised one.
        CopyItemsd(keys=[scan_name], names=[origin_key]),
        ScaleIntensityRanged(keys=[origin_key], a_min=low, a_max=high, b_min=low, b_max=high, clip=True),
        LABEL_TRANSFORM(keys="label"),
        CropScanCenterInTumour(keys=scan_name, dilation=use_dilation, translate_range=None),
        ScaleIntensityRanged(keys=[scan_name], a_min=low, a_max=high, b_min=-1.0, b_max=1.0, clip=True),
        ToTensord(keys=tensor_keys),
    ]

    csv_rows = CSVDataset(src=CSV_PATH, col_names=col_names, col_types=col_types)
    # cache_rate=0: nothing is cached, transforms run on every access.
    dataset = CacheDataset(csv_rows, transform=pipeline, cache_rate=0, num_workers=8, progress=True)
    loader = DataLoader(dataset, batch_size=1, num_workers=4, drop_last=True, shuffle=False,
                        collate_fn=pad_list_data_collate)
    return loader, dataset

#### hnn_tumour_inpainting_CT_default_tumour_inpainting__data_augment_20_11_2024_11:07:31
* HU between -200 and 200. tumour weight 10.

In [None]:
def run_inference(train_loader, model, scheduler_list, n, num_inference_steps, clip_min, clip_max, out_path):
    """
    Inpaints CT crops with an image-domain diffusion model and saves the results as nii.gz files.

    For every scheduler in scheduler_list, the first n batches of train_loader are used as
    conditions. Each sample starts from Gaussian noise of shape (1, 1, 128, 128, 128). At every
    denoising step it is concatenated with the condition (no-contrast ROI, contrast ROI and
    batch["label_crop_pad"]). The result is clipped to [-1, 1], rescaled to [clip_min, clip_max]
    and written together with the label.
    Arguments:
        train_loader (iterable): Yields dicts with "label_crop_pad", "no_contrast_tensor" and
            "contrast_tensor" tensors.
        model: Denoising network, called as model(x, timesteps=t, label_condition=c).
        scheduler_list (list of str): Scheduler names passed to get_scheduler.
        n (int): Number of batches (samples) generated per scheduler.
        num_inference_steps (int): Number of denoising steps per sample.
        clip_min (int): Lower bound of the output intensity range.
        clip_max (int): Upper bound of the output intensity range.
        out_path (str): Directory the .nii.gz files are written to.
    Return:
        None. Writes CT_* and label_* files into out_path.
    """
    model.cuda()
    for sch in scheduler_list:
        for idx, batch in enumerate(train_loader):
            # Fresh scheduler per sample. Multistep DPM-Solver schedulers carry state (step index,
            # previous model outputs) across step() calls, so reusing one instance breaks every
            # sample after the first. Assumes get_scheduler calls set_timesteps.
            scheduler = get_scheduler(sch, num_inference_steps)

            final_scan = torch.randn(1, 1, 128, 128, 128).cuda()

            segmentation = batch["label_crop_pad"].cuda()
            no_contrast_tensor = batch["no_contrast_tensor"].cuda()
            contrast_tensor = batch["contrast_tensor"].cuda()
            label_condition = torch.cat((no_contrast_tensor, contrast_tensor, segmentation), dim=1)

            with torch.no_grad():
                # Reverse diffusion: the condition is concatenated to the noisy scan at every step.
                for timestep in tqdm(scheduler.timesteps, desc="Processing timesteps"):
                    input_model = torch.cat((final_scan, label_condition), dim=1)
                    t = torch.tensor([timestep] * final_scan.shape[0]).cuda()
                    noise_pred = model(input_model, timesteps=t, label_condition=label_condition)
                    final_scan = scheduler.step(model_output=noise_pred, timestep=timestep,
                                                sample=final_scan)['prev_sample']

            ending_name = "_Contrast" if torch.sum(contrast_tensor) != 0 else "out_contrast"

            # Remove very high and low values, then map into the [clip_min, clip_max] window.
            # NOTE(review): if rescale_array is MONAI's, it stretches the array's own min/max (not
            # a fixed [-1, 1]) onto [clip_min, clip_max] — confirm this is intended.
            sample_denorm = np.clip(final_scan.squeeze().cpu().numpy(), a_min=-1, a_max=1)
            sample_denorm = rescale_array(arr=sample_denorm, minv=int(clip_min), maxv=int(clip_max))

            # Identity affine for simplicity.
            synth_ct_scan_output = os.path.join(out_path, f'CT_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(sample_denorm, affine=np.eye(4)), synth_ct_scan_output)

            seg_ct_scan_output = os.path.join(out_path, f'label_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            seg_volume = segmentation.cpu().numpy().astype(float)[0][0]
            nib.save(nib.Nifti1Image(seg_volume, affine=np.eye(4)), seg_ct_scan_output)

            if idx + 1 == n:
                break
            
           
                

In [None]:
# Intensity window: get_loader rescales [clip_min, clip_max] to [-1, 1]; run_inference maps back.
clip_min = -200
clip_max = 200
# 1 noisy scan channel + 3 condition channels (no-contrast ROI, contrast ROI, label).
in_channels = 4
out_channels = 1
label_cond_in_channels = 0
use_label_cond_conv = False
pretrained_weights_path = '../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/runs/hnn_tumour_inpainting_CT_default_tumour_inpainting__data_augment_20_11_2024_11:07:31/checkpoints/hnn_tumour_inpainting_2000000.pt'  # Specify the correct path
channel_mult=[1, 2, 2, 4, 4]

# Build the network from the checkpoint and move it to the GPU in evaluation mode.
model = get_model(in_channels=in_channels, 
                  out_channels=out_channels,
                  channel_mult=channel_mult,
                  label_cond_in_channels=label_cond_in_channels, 
                  use_label_cond_conv=use_label_cond_conv,
                  pretrained_weights_path=pretrained_weights_path)
model.eval()
model.cuda()

# CSV schema: scan and label paths plus integer tumour-box metadata columns.
scan_name = "scan_ct"
col_names = ['scan_ct', 'label', 'center_x', 'center_y', 'center_z', 'x_extreme_min', 'x_extreme_max', 'y_extreme_min', 'y_extreme_max', 'z_extreme_min', 'z_extreme_max', 'x_size', 'y_size', 'z_size', 'contrast']
col_types= {'center_x': {'type': int}, 'center_y': {'type': int}, 'center_z': {'type': int}, 'x_extreme_min': {'type': int}, 'x_extreme_max': {'type': int}, 'y_extreme_min': {'type': int}, 'y_extreme_max': {'type': int}, 'z_extreme_min': {'type': int}, 'z_extreme_max': {'type': int}, 'x_size': {'type': int}, 'y_size': {'type': int}, 'z_size': {'type': int}, 'contrast': {'type': int}}  
use_dilation = False
CSV_PATH = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/utils/hnn.csv"

train_loader, train_CSVdataset = get_loader(scan_name=scan_name,
           col_names=col_names,
           col_types=col_types, 
           use_dilation=use_dilation, 
           clip_min=clip_min, 
           clip_max=clip_max, 
           CSV_PATH=CSV_PATH)

print("Loaded model and data loader")

# Control inference parameters
# For each scheduler, n samples are generated with num_inference_steps denoising steps each.
scheduler_list = ["DPM++_2M", "DPM++_2M_Karras", "DPM++_2M_SDE", "DPM++_2M_SDE_Karras"]
n = 1
num_inference_steps = 100
out_path = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/notebooks/trash/hnn_tumour_inpainting_CT_default_tumour_inpainting__data_augment_20_11_2024_11:07:31"
os.makedirs(out_path, exist_ok=True)

run_inference(train_loader=train_loader,
              model=model,
              scheduler_list=scheduler_list,
               n=n, 
               num_inference_steps=num_inference_steps, 
               clip_min=clip_min,
               clip_max=clip_max, 
               out_path=out_path)
 

#### hnn_tumour_inpainting_CT_default_tumour_inpainting__DA_tumorW_10_28_11_2024_14:37:59
* HU between -1000 and 1000. tumour weight 10. 

In [None]:
def run_inference(train_loader, model, scheduler_list, n, num_inference_steps, clip_min, clip_max, out_path):
    """
    Sample synthetic CT volumes for the first ``n`` batches of ``train_loader``
    with every scheduler in ``scheduler_list`` and save them as NIfTI files.

    The conditioning is the channel-wise concatenation of the batch's
    ``no_contrast_tensor``, ``contrast_tensor`` and ``label_crop_pad``. At every
    denoising step it is concatenated to the current noisy volume to form the
    network input, and it is also passed as ``label_condition``.

    Arguments:
        train_loader: Iterable of batch dicts with the keys "label_crop_pad",
            "no_contrast_tensor" and "contrast_tensor".
        model: Denoising network, called as
            ``model(x, timesteps=t, label_condition=cond)``.
        scheduler_list (list[str]): Scheduler names understood by ``get_scheduler``.
        n (int): Number of batches to sample per scheduler.
        num_inference_steps (int): Number of denoising steps.
        clip_min, clip_max: Range passed to ``rescale_array`` for the output
            intensities; both also appear in the output file names.
        out_path (str): Directory the NIfTI files are written to.
    """
    model.cuda()
    for sch in scheduler_list:
        for idx, batch in enumerate(train_loader):
            # Build a fresh scheduler for every sample. Multistep solvers keep
            # internal state (step index, previous model outputs) across
            # ``step`` calls, so reusing one instance for a second sample would
            # continue from the end of the previous trajectory.
            scheduler = get_scheduler(sch, num_inference_steps)

            # Start from pure Gaussian noise: one sample, one channel, 128^3 voxels.
            final_scan = torch.randn(1, 1, 128, 128, 128).cuda()

            segmentation = batch["label_crop_pad"].cuda()
            no_contrast_tensor = batch["no_contrast_tensor"].cuda()
            contrast_tensor = batch["contrast_tensor"].cuda()
            label_condition = torch.cat((no_contrast_tensor, contrast_tensor, segmentation), dim=1)

            # Reverse process: denoise step by step from the initial noise.
            with torch.no_grad():
                for timestep in tqdm(scheduler.timesteps, desc="Processing timesteps"):
                    # One timestep value per batch element.
                    t = torch.tensor([timestep] * final_scan.shape[0]).cuda()
                    # The network sees the noisy volume concatenated with the condition.
                    input_model = torch.cat((final_scan, label_condition), dim=1)
                    noise_pred = model(input_model, timesteps=t, label_condition=label_condition)
                    final_scan = scheduler.step(model_output=noise_pred, timestep=timestep, sample=final_scan)['prev_sample']

            # Drop the batch/channel dims and move to CPU for saving.
            final_image_np = final_scan.squeeze().cpu().numpy()

            # NOTE(review): these produce "CT__Contrast_..." and "CT_out_contrast_..."
            # file names (and "label__Contrast_..." / "label_out_contrast_...").
            # Presumably "with"/"without" was intended; the strings are kept so
            # existing output names stay stable.
            if torch.sum(contrast_tensor) != 0:
                ending_name = "_Contrast"
            else:
                ending_name = "out_contrast"

            synth_ct_scan_output = os.path.join(out_path, f'CT_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            # Remove values outside [-1, 1] before rescaling.
            sample_denorm = np.clip(final_image_np, a_min=-1, a_max=1)
            # rescale_array is defined elsewhere in the notebook; presumably it maps to [clip_min, clip_max].
            sample_denorm = rescale_array(
                arr=sample_denorm,
                minv=int(clip_min),
                maxv=int(clip_max),
            )
            # Identity affine: no spatial metadata is carried over.
            nib.save(nib.Nifti1Image(sample_denorm, affine=np.eye(4)), synth_ct_scan_output)

            seg_np = segmentation.cpu().numpy().astype(float)[0][0]
            seg_ct_scan_output = os.path.join(out_path, f'label_{ending_name}_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(seg_np, affine=np.eye(4)), seg_ct_scan_output)

            if idx + 1 == n:
                break
            
           
                

In [None]:

# HU window, passed to both the data loader and run_inference.
clip_min = -1000
clip_max = 1000

# Network configuration.
# NOTE(review): in_channels = 4 presumably is the 1-channel noisy CT plus the
# 3 condition channels concatenated in run_inference — confirm against training.
in_channels = 4
out_channels = 1
label_cond_in_channels = 0
use_label_cond_conv = False
pretrained_weights_path = '../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/runs/hnn_tumour_inpainting_CT_default_tumour_inpainting__DA_tumorW_10_28_11_2024_14:37:59/checkpoints/hnn_tumour_inpainting_1310000.pt'  # Specify the correct path
channel_mult = [1, 2, 2, 4, 4]

model = get_model(
    in_channels=in_channels,
    out_channels=out_channels,
    channel_mult=channel_mult,
    label_cond_in_channels=label_cond_in_channels,
    use_label_cond_conv=use_label_cond_conv,
    pretrained_weights_path=pretrained_weights_path,
)
model.eval().cuda()

# CSV-driven loader configuration.
scan_name = "scan_ct"
col_names = ['scan_ct', 'label', 'center_x', 'center_y', 'center_z', 'x_extreme_min', 'x_extreme_max', 'y_extreme_min', 'y_extreme_max', 'z_extreme_min', 'z_extreme_max', 'x_size', 'y_size', 'z_size', 'contrast']
# Every column after the scan path and the label is declared as int.
col_types = {column: {'type': int} for column in col_names[2:]}
use_dilation = False
CSV_PATH = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/utils/hnn.csv"

train_loader, train_CSVdataset = get_loader(
    scan_name=scan_name,
    col_names=col_names,
    col_types=col_types,
    use_dilation=use_dilation,
    clip_min=clip_min,
    clip_max=clip_max,
    CSV_PATH=CSV_PATH,
)
print("Loaded model and data loader")

# Sampling settings: the four DPM++ 2M scheduler variants, first n cases only.
scheduler_list = ["DPM++_2M", "DPM++_2M_Karras", "DPM++_2M_SDE", "DPM++_2M_SDE_Karras"]
n = 1
num_inference_steps = 100
out_path = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/notebooks/trash/hnn_tumour_inpainting_CT_default_tumour_inpainting__DA_tumorW_10_28_11_2024_14:37:59"
os.makedirs(out_path, exist_ok=True)

run_inference(
    train_loader=train_loader,
    model=model,
    scheduler_list=scheduler_list,
    n=n,
    num_inference_steps=num_inference_steps,
    clip_min=clip_min,
    clip_max=clip_max,
    out_path=out_path,
)
 

#### Test blur mask

In [None]:
import torch.nn.functional as F

def blur_mask_3d(mask, label_crop_pad, blur_factor, blur_type):
    """
    Apply a Gaussian blur to a 3D mask.

    Args:
        mask (torch.Tensor): Float mask of shape (1, 1, H, W, D). The
            single-channel kernel also works for a batch (N, 1, H, W, D).
        label_crop_pad (torch.Tensor | numpy.ndarray): Binary label, used only
            for ``blur_type="edge_blur"``. It must have as many elements as
            ``mask`` and is reshaped to ``mask``'s shape, so (H, W, D) arrays
            are accepted.
        blur_factor (int): Kernel size. Even values are increased by one to
            make the kernel odd; 0 and 1 mean "no blur".
        blur_type (str): "full_blur" returns the blurred mask as is.
            "edge_blur" additionally resets every voxel where
            ``label_crop_pad == 1`` to 1, so only the surroundings of the label
            are softened.

    Returns:
        torch.Tensor: Blurred mask with the same shape, device and dtype as ``mask``.

    Raises:
        ValueError: If ``blur_type`` is unknown or ``blur_factor`` is negative.
    """
    if blur_type not in ("edge_blur", "full_blur"):
        raise ValueError(f"blur_type must be edge_blur or full_blur not {blur_type}")
    if blur_factor < 0:
        raise ValueError(f"blur_factor must be non-negative, got {blur_factor}")

    # Ensure blur_factor is odd so the kernel has a centre voxel.
    if blur_factor % 2 == 0:
        blur_factor += 1

    if blur_factor == 1:
        # A size-1 kernel is the identity. The Gaussian formula below would
        # underflow to 0 here (x = [-3], sigma = 1/12), and normalising would give NaN.
        kernel_1d = torch.ones(1)
    else:
        # The sample points always span [-3, 3] and sigma = blur_factor / 12
        # (the original rule of thumb), then the kernel is normalised to sum 1.
        sigma = blur_factor / 12.0
        x = torch.linspace(-3, 3, blur_factor)
        kernel_1d = torch.exp(-0.5 * x**2 / sigma**2)
        kernel_1d = kernel_1d / kernel_1d.sum()

    # Separable 3D kernel as the outer product of the 1D kernel along each axis.
    kernel_3d = kernel_1d[:, None, None] * kernel_1d[None, :, None] * kernel_1d[None, None, :]
    kernel_3d = kernel_3d.to(device=mask.device, dtype=mask.dtype)
    kernel_3d = kernel_3d.unsqueeze(0).unsqueeze(0)  # (out_channels=1, in_channels=1, k, k, k)

    # Replicate-pad so the output keeps the input's spatial size and the
    # borders are not pulled towards zero.
    padding = blur_factor // 2
    mask_padded = F.pad(mask, (padding,) * 6, mode="replicate") if padding else mask

    blurred_mask = F.conv3d(mask_padded, kernel_3d, padding=0)

    if blur_type == "edge_blur":
        label = torch.as_tensor(label_crop_pad, device=blurred_mask.device).reshape(blurred_mask.shape)
        # Voxels inside the label keep their label value (1); only the surroundings stay blurred.
        blurred_mask = torch.where(label == 1, label.to(blurred_mask.dtype), blurred_mask)
    return blurred_mask

In [None]:
from scipy.ndimage import center_of_mass
from scipy.ndimage import binary_dilation 
import torch
import numpy as np
from monai.transforms import Compose, LoadImage, CropForeground, EnsureChannelFirst, ResizeWithPadOrCrop, ScaleIntensityRange
import nibabel as nib
# Visual check of blur_mask_3d: crop a 128^3 window around the centre of mass of
# one head-and-neck segmentation, dilate it, blur it with increasing kernel
# sizes and save each result (plus the undilated crop) as NIfTI for inspection.
transforms = [
    LoadImage(image_only=True),
    EnsureChannelFirst()
    ]
apply_transforms = Compose(transforms)

# The segmentation, the crop and the dilation do not depend on the blur factor,
# so they are computed once instead of on every loop iteration.
segmentation = apply_transforms("../../HnN_cancer_data/HnN_cancer_data_1_1_1_256_256_256/seg/anderson_89f9e1d8eeae55b7ab93ac1fe0cf3801.nii.gz")[0].numpy()

random_voxel_indices = np.array(center_of_mass(segmentation)).astype(int)
crop_half = 64
# Clamp the crop centre so the 128^3 window stays inside the volume. A centre
# closer than 64 voxels to a border would otherwise give a negative slice start,
# which Python interprets as counting from the end (empty or short crop).
random_x, random_y, random_z = (
    min(max(int(center), crop_half), max(int(size) - crop_half, crop_half))
    for center, size in zip(random_voxel_indices, segmentation.shape)
)
label_crop_pad = segmentation[
        random_x-crop_half:random_x+crop_half,
        random_y-crop_half:random_y+crop_half,
        random_z-crop_half:random_z+crop_half
        ]

# 3x3x3 all-True structuring element (26-connectivity), applied 5 times.
# It must be boolean; the previous dtype=str produced an array of '1' strings.
structuring_element = np.ones((3, 3, 3), dtype=bool)
dilated_mask = binary_dilation(label_crop_pad, structure=structuring_element, iterations=5)
dilated_mask = torch.from_numpy(dilated_mask).float().unsqueeze(dim=0).unsqueeze(dim=0)

out_dir = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/trash"
os.makedirs(out_dir, exist_ok=True)

for i in range(0, 40, 5):
    if i == 0:
        # Reference output: the undilated, unblurred crop.
        img = nib.Nifti1Image(label_crop_pad, affine=np.eye(4))
        nib.save(img=img, filename=os.path.join(out_dir, "real_mask"))
    else:
        blurred_mask = blur_mask_3d(dilated_mask, label_crop_pad, blur_factor=i, blur_type="full_blur")
        img = nib.Nifti1Image(blurred_mask[0][0].numpy(), affine=np.eye(4))
        nib.save(img=img, filename=os.path.join(out_dir, f"blurred_mask_{i}"))

---
# Prediction for brain tumour scans

In [None]:
from utils.data_loader_utils import ConvertToMultiChannelBasedOnBratsClasses2023d, QuantileAndScaleIntensityd
def get_brats_loader(in_keys, all_image_keys, label_key, base_dir, data_split_json, no_seg, image_size,
                     batch_size=1, num_workers=4):
    """
    Build the BraTS dataset and data loader used for inference.

    Pipeline: load the images, move the channel first, cast to float32,
    reorient to RAS, then pad/crop to ``image_size``. Next, intensity-normalise
    ``all_image_keys`` with QuantileAndScaleIntensityd and convert ``label_key``
    with ConvertToMultiChannelBasedOnBratsClasses2023d (both project
    transforms), and cast to float32 again.

    Arguments:
        in_keys (list[str]): Keys that are loaded and spatially transformed
            (images and label), e.g. ['t1c', 'seg'].
        all_image_keys (list[str]): Image keys passed to QuantileAndScaleIntensityd.
        label_key (str): Key of the segmentation.
        base_dir (str): Base directory passed to ``load_decathlon_datalist``.
        data_split_json (str): Decathlon-style JSON; its "training" list is used.
        no_seg (bool): Forwarded to ConvertToMultiChannelBasedOnBratsClasses2023d.
        image_size (tuple[int, int, int]): Spatial size every volume is padded/cropped to.
        batch_size (int): Loader batch size (default 1).
        num_workers (int): Worker count for both the dataset and the loader (default 4).
    Return:
        tuple: (DataLoader, CacheDataset, list[dict]) — the loader, the dataset
        and the raw data list.
    """
    train_transforms = Compose(
        [
            LoadImaged(keys=in_keys, meta_key_postfix="meta_dict", image_only=False),
            EnsureChannelFirstd(keys=in_keys),
            EnsureTyped(keys=in_keys, dtype=torch.float32),
            Orientationd(keys=in_keys, axcodes="RAS"),
            ResizeWithPadOrCropd(
                keys=in_keys,
                spatial_size=image_size,
                mode="constant",
                value=0
            ),
            QuantileAndScaleIntensityd(keys=all_image_keys),
            ConvertToMultiChannelBasedOnBratsClasses2023d(
                keys=[label_key], no_seg=no_seg,
            ),
            EnsureTyped(keys=in_keys, dtype=torch.float32)
        ]
    )

    data_set = load_decathlon_datalist(
        data_split_json,
        is_segmentation=True,
        data_list_key="training",
        base_dir=base_dir,
    )

    print(f"Training cases: {len(data_set)}")
    # Show the last entry as a quick sanity check of the resolved paths.
    print(data_set[-1:])

    # Creating training dataset. With cache_rate=0 nothing is pre-cached:
    # the transforms run every time an item is accessed.
    ds = CacheDataset(
        data=data_set,
        transform=train_transforms,
        cache_rate=0,
        copy_cache=False,
        progress=True,
        num_workers=num_workers,
    )

    # Creating data loader (deterministic order, no shuffling).
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        shuffle=False,
    )
    return dl, ds, data_set

#### c_brats_t1c_conv_before_concat__tumorW_0_28_11_2024_13:02:05 
* Tumour weight 0. Three-channel segmentation as condition, fed first through a convolutional layer.


In [None]:
def run_inference(model, scheduler_list, n, num_inference_steps, out_path, data_loader=None):
    """
    Sample t1c MRI volumes conditioned on BraTS segmentations and save them.

    Variant for the model that receives the 3-channel segmentation only through
    its ``label_condition`` argument; the network input is the noisy
    wavelet-domain sample alone.

    For each of the first ``n`` cases and every scheduler in ``scheduler_list``:
    denoise an 8-channel sample, recombine the channels with ``idwt``, and
    rescale the result to the min/max intensity of the source t1c file. The
    volume and a single-channel label map are written to ``out_path``.

    Arguments:
        model: Denoising network, called as
            ``model(x, timesteps=t, label_condition=cond)``.
        scheduler_list (list[str]): Scheduler names understood by ``get_scheduler``.
        n (int): Number of cases to sample per scheduler.
        num_inference_steps (int): Number of denoising steps.
        out_path (str): Output directory for the NIfTI files.
        data_loader: Optional loader yielding batches with "seg" and
            "t1c_meta_dict"; defaults to the notebook-global ``dl``.
    """
    # Existing calls rely on the notebook-global ``dl``; keep that as the default.
    loader = dl if data_loader is None else data_loader
    model.cuda()
    for sch in scheduler_list:
        for idx, batch in enumerate(loader):
            # Build a fresh scheduler per case. Multistep solvers keep internal
            # state (step index, previous model outputs) across ``step`` calls,
            # so reusing one instance would corrupt every case after the first.
            scheduler = get_scheduler(sch, num_inference_steps)

            case_path = batch['t1c_meta_dict']['filename_or_obj'][0]
            print(f"case_path: {case_path}")

            # Gaussian noise with 8 channels (recombined by idwt below), each 128^3.
            final_scan = torch.randn(1, 8, 128, 128, 128).cuda()

            label_condition = batch["seg"].cuda()
            tumour_core = label_condition[0][0]
            whole_tumour = label_condition[0][1]
            enhancing_tumour = label_condition[0][2]

            with torch.no_grad():
                # Reverse process: denoise step by step from the initial noise.
                for timestep in tqdm(scheduler.timesteps, desc="Processing timesteps"):
                    # One timestep value per batch element.
                    t = torch.tensor([timestep] * final_scan.shape[0]).cuda()
                    # This model receives the segmentation only via ``label_condition``.
                    noise_pred = model(final_scan, timesteps=t, label_condition=label_condition)
                    final_scan = scheduler.step(model_output=noise_pred, timestep=timestep, sample=final_scan)['prev_sample']

                # Recombine the 8 channels into one volume. Channel 0 is multiplied
                # by 3 first, presumably undoing a /3 scaling of the low-pass band
                # used in training — confirm against the trainer.
                bands = [final_scan[:, k:k + 1] for k in range(8)]
                bands[0] = bands[0] * 3.
                final_scan = idwt(*bands)

            # Drop the batch/channel dims and move to CPU for saving.
            final_image_np = final_scan.squeeze().cpu().numpy()

            # Intensity range of the original t1c volume on disk.
            data = nib.load(case_path).get_fdata()
            clip_min = np.min(data)
            clip_max = np.max(data)

            synth_ct_scan_output = os.path.join(out_path, f'MRI_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            # Remove values outside [-1, 1] before rescaling.
            sample_denorm = np.clip(final_image_np, a_min=-1, a_max=1)
            # rescale_array is defined elsewhere in the notebook; presumably it maps to [clip_min, clip_max].
            sample_denorm = rescale_array(
                arr=sample_denorm,
                minv=int(clip_min),
                maxv=int(clip_max),
            )
            # Identity affine: the source file's geometry is not carried over.
            nib.save(nib.Nifti1Image(sample_denorm, affine=np.eye(4)), synth_ct_scan_output)

            # Collapse the three binary channels into one label map. Later
            # assignments win, so voxels in several channels end up as 3 over 1 over 2.
            segmentation = torch.zeros_like(tumour_core)
            segmentation[whole_tumour == 1] = 2
            segmentation[tumour_core == 1] = 1
            segmentation[enhancing_tumour == 1] = 3

            seg_ct_scan_output = os.path.join(out_path, f'label_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(segmentation.cpu().numpy(), affine=np.eye(4)), seg_ct_scan_output)
            if idx + 1 == n:
                break
            
           
                

In [None]:
# Fixed settings (BraTS 2023 glioma data; volumes padded/cropped to 256^3).
image_size = (256, 256, 256)
full_background = False
no_seg = False

# Keys to load, keys to normalise as images, and the label key.
in_keys = ['t1c', 'seg']
all_image_keys = ['t1c']
label_key = 'seg'
base_dir = "/projects/brats2023_a_f/BRAINTUMOUR/data/brats2023/ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData"
# The split JSON lives in the parent directory of the training data.
data_split_json = os.path.join(os.path.dirname(base_dir), "BraTS2023_GLI_data_split.json")

# Model: segmentation passed through a conv layer before concatenation.
in_channels = 32
label_cond_in_channels = 3
use_label_cond_conv = True
pretrained_weights_path = '../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/runs/c_brats_t1c_conv_before_concat__tumorW_0_28_11_2024_16:46:58/checkpoints/c_brats_1295000.pt'  # Specify the correct path

model = get_model(
    in_channels=in_channels,
    label_cond_in_channels=label_cond_in_channels,
    use_label_cond_conv=use_label_cond_conv,
    pretrained_weights_path=pretrained_weights_path,
)
model.eval().cuda()

dl, ds, data_set = get_brats_loader(
    in_keys=in_keys,
    all_image_keys=all_image_keys,
    label_key=label_key,
    base_dir=base_dir,
    data_split_json=data_split_json,
    no_seg=no_seg,
    image_size=image_size,
)
print("Loaded model and data loader")

# Sampling settings: the four DPM++ 2M scheduler variants, first n cases only.
scheduler_list = ["DPM++_2M", "DPM++_2M_Karras", "DPM++_2M_SDE", "DPM++_2M_SDE_Karras"]
n = 1
num_inference_steps = 100
out_path = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/notebooks/trash/c_brats_t1c_conv_before_concat__tumorW_0_28_11_2024_16:46:58/"
os.makedirs(out_path, exist_ok=True)

run_inference(
    model=model,
    scheduler_list=scheduler_list,
    n=n,
    num_inference_steps=num_inference_steps,
    out_path=out_path,
)

#### c_brats_t1c_concat_cond__tumorW_0_28_11_2024_16:45:44
* Tumour weight 0. Downsampled three-channel segmentation as condition.


In [None]:
def run_inference(model, scheduler_list, n, num_inference_steps, out_path, data_loader=None):
    """
    Sample t1c MRI volumes conditioned on BraTS segmentations and save them.

    Variant for the model whose input is the noisy wavelet-domain sample
    concatenated with the 3-channel segmentation, nearest-neighbour resized to
    128^3. The full-resolution segmentation is also passed as ``label_condition``.

    For each of the first ``n`` cases and every scheduler in ``scheduler_list``:
    denoise an 8-channel sample, recombine the channels with ``idwt``, and
    rescale the result to the min/max intensity of the source t1c file. The
    volume and a single-channel label map are written to ``out_path``.

    Arguments:
        model: Denoising network, called as
            ``model(x, timesteps=t, label_condition=cond)``.
        scheduler_list (list[str]): Scheduler names understood by ``get_scheduler``.
        n (int): Number of cases to sample per scheduler.
        num_inference_steps (int): Number of denoising steps.
        out_path (str): Output directory for the NIfTI files.
        data_loader: Optional loader yielding batches with "seg" and
            "t1c_meta_dict"; defaults to the notebook-global ``dl``.
    """
    # Existing calls rely on the notebook-global ``dl``; keep that as the default.
    loader = dl if data_loader is None else data_loader
    model.cuda()
    # Nearest-neighbour resize of the segmentation to the 128^3 grid of the
    # sample; it is stateless, so build it once.
    resize = Resize((128, 128, 128), size_mode='all', mode="nearest", align_corners=None, anti_aliasing=False, anti_aliasing_sigma=None, dtype=torch.float32, lazy=False)
    for sch in scheduler_list:
        for idx, batch in enumerate(loader):
            # Build a fresh scheduler per case. Multistep solvers keep internal
            # state (step index, previous model outputs) across ``step`` calls,
            # so reusing one instance would corrupt every case after the first.
            scheduler = get_scheduler(sch, num_inference_steps)

            case_path = batch['t1c_meta_dict']['filename_or_obj'][0]
            print(f"case_path: {case_path}")

            # Gaussian noise with 8 channels (recombined by idwt below), each 128^3.
            final_scan = torch.randn(1, 8, 128, 128, 128).cuda()

            label_condition = batch["seg"].cuda()
            tumour_core = label_condition[0][0]
            whole_tumour = label_condition[0][1]
            enhancing_tumour = label_condition[0][2]

            label_cond_down = resize(label_condition[0]).unsqueeze(0)

            with torch.no_grad():
                # Reverse process: denoise step by step from the initial noise.
                for timestep in tqdm(scheduler.timesteps, desc="Processing timesteps"):
                    # One timestep value per batch element.
                    t = torch.tensor([timestep] * final_scan.shape[0]).cuda()
                    # Noisy sample concatenated with the downsampled segmentation.
                    input_model = torch.cat((final_scan, label_cond_down), dim=1)
                    noise_pred = model(input_model, timesteps=t, label_condition=label_condition)
                    final_scan = scheduler.step(model_output=noise_pred, timestep=timestep, sample=final_scan)['prev_sample']

                # Recombine the 8 channels into one volume. Channel 0 is multiplied
                # by 3 first, presumably undoing a /3 scaling of the low-pass band
                # used in training — confirm against the trainer.
                bands = [final_scan[:, k:k + 1] for k in range(8)]
                bands[0] = bands[0] * 3.
                final_scan = idwt(*bands)

            # Drop the batch/channel dims and move to CPU for saving.
            final_image_np = final_scan.squeeze().cpu().numpy()

            # Intensity range of the original t1c volume on disk.
            data = nib.load(case_path).get_fdata()
            clip_min = np.min(data)
            clip_max = np.max(data)

            synth_ct_scan_output = os.path.join(out_path, f'MRI_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            # Remove values outside [-1, 1] before rescaling.
            sample_denorm = np.clip(final_image_np, a_min=-1, a_max=1)
            # rescale_array is defined elsewhere in the notebook; presumably it maps to [clip_min, clip_max].
            sample_denorm = rescale_array(
                arr=sample_denorm,
                minv=int(clip_min),
                maxv=int(clip_max),
            )
            # Identity affine: the source file's geometry is not carried over.
            nib.save(nib.Nifti1Image(sample_denorm, affine=np.eye(4)), synth_ct_scan_output)

            # Collapse the three binary channels into one label map. Later
            # assignments win, so voxels in several channels end up as 3 over 1 over 2.
            segmentation = torch.zeros_like(tumour_core)
            segmentation[whole_tumour == 1] = 2
            segmentation[tumour_core == 1] = 1
            segmentation[enhancing_tumour == 1] = 3

            seg_ct_scan_output = os.path.join(out_path, f'label_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(segmentation.cpu().numpy(), affine=np.eye(4)), seg_ct_scan_output)
            if idx + 1 == n:
                break
            
           
                

In [None]:
# Fixed settings (BraTS 2023 glioma data; volumes padded/cropped to 256^3).
image_size = (256, 256, 256)
full_background = False
no_seg = False

# Keys to load, keys to normalise as images, and the label key.
in_keys = ['t1c', 'seg']
all_image_keys = ['t1c']
label_key = 'seg'
base_dir = "/projects/brats2023_a_f/BRAINTUMOUR/data/brats2023/ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData"
# The split JSON lives in the parent directory of the training data.
data_split_json = os.path.join(os.path.dirname(base_dir), "BraTS2023_GLI_data_split.json")

# Model: downsampled segmentation concatenated to the input, no conditioning conv.
# NOTE(review): 11 presumably = 8 sample channels + 3 segmentation channels — confirm.
in_channels = 11
label_cond_in_channels = 0
use_label_cond_conv = False
pretrained_weights_path = '../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/runs/c_brats_t1c_concat_cond__tumorW_0_9_12_2024_21:19:23/checkpoints/c_brats_625000.pt'  # Specify the correct path

model = get_model(
    in_channels=in_channels,
    label_cond_in_channels=label_cond_in_channels,
    use_label_cond_conv=use_label_cond_conv,
    pretrained_weights_path=pretrained_weights_path,
)
model.eval().cuda()

dl, ds, data_set = get_brats_loader(
    in_keys=in_keys,
    all_image_keys=all_image_keys,
    label_key=label_key,
    base_dir=base_dir,
    data_split_json=data_split_json,
    no_seg=no_seg,
    image_size=image_size,
)
print("Loaded model and data loader")

# Sampling settings: the four DPM++ 2M scheduler variants, first n cases only.
scheduler_list = ["DPM++_2M", "DPM++_2M_Karras", "DPM++_2M_SDE", "DPM++_2M_SDE_Karras"]
n = 1
num_inference_steps = 100
out_path = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/notebooks/trash/c_brats_t1c_concat_cond__tumorW_0_9_12_2024_21:19:23/"
os.makedirs(out_path, exist_ok=True)

run_inference(
    model=model,
    scheduler_list=scheduler_list,
    n=n,
    num_inference_steps=num_inference_steps,
    out_path=out_path,
)

#### c_brats_t1c_wavelet_cond__tumorW_0_3_12_2024_15:36:12
* Tumour weight 0. Wavelet-transformed three-channel segmentation as condition.


In [None]:
def run_inference(model, scheduler_list, n, num_inference_steps, out_path, data_loader=None):
    """
    Sample t1c MRI volumes conditioned on BraTS segmentations and save them.

    Variant for the model whose input is the noisy wavelet-domain sample
    concatenated with the wavelet transform (``dwt``) of each of the 3
    segmentation channels, i.e. 8 + 3 * 8 = 32 channels. The full-resolution
    segmentation is also passed as ``label_condition``.

    For each of the first ``n`` cases and every scheduler in ``scheduler_list``:
    denoise an 8-channel sample, recombine the channels with ``idwt``, and
    rescale the result to the min/max intensity of the source t1c file. The
    volume and a single-channel label map are written to ``out_path``.

    Arguments:
        model: Denoising network, called as
            ``model(x, timesteps=t, label_condition=cond)``.
        scheduler_list (list[str]): Scheduler names understood by ``get_scheduler``.
        n (int): Number of cases to sample per scheduler.
        num_inference_steps (int): Number of denoising steps.
        out_path (str): Output directory for the NIfTI files.
        data_loader: Optional loader yielding batches with "seg" and
            "t1c_meta_dict"; defaults to the notebook-global ``dl``.
    """
    # Existing calls rely on the notebook-global ``dl``; keep that as the default.
    loader = dl if data_loader is None else data_loader
    model.cuda()
    for sch in scheduler_list:
        for idx, batch in enumerate(loader):
            # Build a fresh scheduler per case. Multistep solvers keep internal
            # state (step index, previous model outputs) across ``step`` calls,
            # so reusing one instance would corrupt every case after the first.
            scheduler = get_scheduler(sch, num_inference_steps)

            case_path = batch['t1c_meta_dict']['filename_or_obj'][0]
            print(f"case_path: {case_path}")

            # Gaussian noise with 8 channels (recombined by idwt below), each 128^3.
            final_scan = torch.randn(1, 8, 128, 128, 128).cuda()

            label_condition = batch["seg"].cuda()
            tumour_core = label_condition[0][0]
            whole_tumour = label_condition[0][1]
            enhancing_tumour = label_condition[0][2]

            with torch.no_grad():
                # Wavelet-transform each segmentation channel and stack all
                # sub-bands channel-wise. The first band of each transform is
                # divided by 3, matching the x3 applied before idwt below.
                subbands = []
                for condition in label_condition[0]:
                    LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = dwt(condition.unsqueeze(0).unsqueeze(0))
                    subbands.extend([LLL / 3., LLH, LHL, LHH, HLL, HLH, HHL, HHH])
                cond_dwt = torch.cat(subbands, dim=1)

                # Reverse process: denoise step by step from the initial noise.
                for timestep in tqdm(scheduler.timesteps, desc="Processing timesteps"):
                    # One timestep value per batch element.
                    t = torch.tensor([timestep] * final_scan.shape[0]).cuda()
                    # Noisy sample concatenated with the wavelet-domain segmentation.
                    input_model = torch.cat((final_scan, cond_dwt), dim=1)
                    noise_pred = model(input_model, timesteps=t, label_condition=label_condition)
                    final_scan = scheduler.step(model_output=noise_pred, timestep=timestep, sample=final_scan)['prev_sample']

                # Recombine the 8 channels into one volume; channel 0 is scaled
                # back by 3 (see the /3 on LLL above).
                bands = [final_scan[:, k:k + 1] for k in range(8)]
                bands[0] = bands[0] * 3.
                final_scan = idwt(*bands)

            # Drop the batch/channel dims and move to CPU for saving.
            final_image_np = final_scan.squeeze().cpu().numpy()

            # Intensity range of the original t1c volume on disk.
            data = nib.load(case_path).get_fdata()
            clip_min = np.min(data)
            clip_max = np.max(data)

            synth_ct_scan_output = os.path.join(out_path, f'MRI_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            # Remove values outside [-1, 1] before rescaling.
            sample_denorm = np.clip(final_image_np, a_min=-1, a_max=1)
            # rescale_array is defined elsewhere in the notebook; presumably it maps to [clip_min, clip_max].
            sample_denorm = rescale_array(
                arr=sample_denorm,
                minv=int(clip_min),
                maxv=int(clip_max),
            )
            # Identity affine: the source file's geometry is not carried over.
            nib.save(nib.Nifti1Image(sample_denorm, affine=np.eye(4)), synth_ct_scan_output)

            # Collapse the three binary channels into one label map. Later
            # assignments win, so voxels in several channels end up as 3 over 1 over 2.
            segmentation = torch.zeros_like(tumour_core)
            segmentation[whole_tumour == 1] = 2
            segmentation[tumour_core == 1] = 1
            segmentation[enhancing_tumour == 1] = 3

            seg_ct_scan_output = os.path.join(out_path, f'label_{clip_min}_{clip_max}_{idx}_{sch}.nii.gz')
            nib.save(nib.Nifti1Image(segmentation.cpu().numpy(), affine=np.eye(4)), seg_ct_scan_output)
            if idx + 1 == n:
                break
            
           
                

In [None]:
# Fixed settings (BraTS 2023 glioma data; volumes padded/cropped to 256^3).
image_size = (256, 256, 256)
full_background = False
no_seg = False

# Keys to load, keys to normalise as images, and the label key.
in_keys = ['t1c', 'seg']
all_image_keys = ['t1c']
label_key = 'seg'
base_dir = "/projects/brats2023_a_f/BRAINTUMOUR/data/brats2023/ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData"
# The split JSON lives in the parent directory of the training data.
data_split_json = os.path.join(os.path.dirname(base_dir), "BraTS2023_GLI_data_split.json")

# Model: wavelet-transformed segmentation concatenated to the input.
# NOTE(review): 32 presumably = 8 sample channels + 3 x 8 segmentation sub-bands — confirm.
in_channels = 32
label_cond_in_channels = 0
use_label_cond_conv = False
pretrained_weights_path = '../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/runs/c_brats_t1c_wavelet_cond__tumorW_0_3_12_2024_15:36:12/checkpoints/c_brats_955000.pt'  # Specify the correct path

model = get_model(
    in_channels=in_channels,
    label_cond_in_channels=label_cond_in_channels,
    use_label_cond_conv=use_label_cond_conv,
    pretrained_weights_path=pretrained_weights_path,
)
model.eval().cuda()

dl, ds, data_set = get_brats_loader(
    in_keys=in_keys,
    all_image_keys=all_image_keys,
    label_key=label_key,
    base_dir=base_dir,
    data_split_json=data_split_json,
    no_seg=no_seg,
    image_size=image_size,
)
print("Loaded model and data loader")

# Sampling settings: the four DPM++ 2M scheduler variants, first n cases only.
scheduler_list = ["DPM++_2M", "DPM++_2M_Karras", "DPM++_2M_SDE", "DPM++_2M_SDE_Karras"]
n = 1
num_inference_steps = 100
out_path = "../../aritifcial-head-and-neck-cts/WDM3D/wdm-3d/notebooks/trash/c_brats_t1c_wavelet_cond__tumorW_0_3_12_2024_15:36:12"
os.makedirs(out_path, exist_ok=True)

run_inference(
    model=model,
    scheduler_list=scheduler_list,
    n=n,
    num_inference_steps=num_inference_steps,
    out_path=out_path,
)