## Setup Environment

In [1]:
!python -c "import monai" || pip install -q "monai-weekly[tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

## Setup Imports

In [None]:
import os
import csv
import sys
import time
import random
import shutil
import tempfile
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast

from monai.transforms import (
    Compose,
    Lambdad,
    Resized,
    Randomizable,
    EnsureChannelFirstd,
    ScaleIntensityRanged
)
from monai.config import print_config
from monai.utils import first, set_determinism
from monai.data import Dataset, CacheDataset, DataLoader

from PIL import Image
from tqdm import tqdm

from generative.inferers import DiffusionInferer
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

# Print the MONAI configuration (library/dependency versions) so the run's environment is recorded.
print_config()

## Setup Data Directory

In [None]:
# Prefer the user-provided MONAI_DATA_DIRECTORY; otherwise create a fresh temporary
# directory (the clean-up cell at the end deletes it only in that case).
# NOTE(review): the CamCAN data below is read from `data_dir`, not from `root_dir`.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is not None:
    root_dir = directory
else:
    root_dir = tempfile.mkdtemp()
print(root_dir)

## Set Deterministic Training for Reproducibility

In [4]:
# Seed the global RNGs through MONAI so sampling is reproducible.
# The dataset split is seeded separately via CamcanDataset's own `seed` argument.
SEED = 42
set_determinism(seed=SEED)

## CamcanDataset

In [5]:
class CamcanDataset(Randomizable, CacheDataset):
    """Cached dataset of central axial T1 slices from CamCAN, each paired with a condition vector.

    Every item is a dict with:
      - ``"image"``: one 2D slice ``volume[:, :, k]`` of ``sub-<Subject>_defaced_T1.nii.gz`` in ``root_dir``.
      - ``"condition"``: float32 array of shape (1, 3) holding ``[Age, Sex, k]`` taken from
        ``csv_file`` (columns ``Subject``, ``Age``, ``Sex``).

    Subjects (not individual slices) are shuffled with ``seed`` and split into test / validation /
    training, so all slices of one subject land in the same section. Build every section with the
    same ``seed`` to get disjoint splits. Only the volumes of the requested section are loaded.

    Args:
        root_dir: directory containing the NIfTI volumes.
        csv_file: phenotype CSV with ``Subject``, ``Age`` and ``Sex`` columns.
        section: ``"training"``, ``"validation"`` or ``"test"``.
        transform: transform applied to each item (handled by ``CacheDataset``).
        seed: seed for the subject/item shuffles.
        val_frac: fraction of subjects assigned to validation.
        test_frac: fraction of subjects assigned to test.
        cache_num, cache_rate, num_workers, progress: forwarded to ``CacheDataset``.
        condition_prob: probability that ``__getitem__`` returns the item with its condition
            replaced by all ``-1`` values (condition dropout).

    Raises:
        ValueError: if ``root_dir`` is not a directory or ``section`` is unsupported.
    """

    # Number of slices taken on each side of the volume's middle index along axis 2.
    slices_per_side = 20

    def __init__(
        self,
        root_dir,
        csv_file,
        section,
        transform=None,
        seed=0,
        val_frac=0.2,
        test_frac=0.2,
        cache_num=sys.maxsize,
        cache_rate=1.0,
        num_workers=0,
        progress: bool = True,
        condition_prob = 0,
    ) -> None:
        if not os.path.isdir(root_dir):
            raise ValueError("Root directory root_dir must be a directory.")
        self.root_dir = root_dir
        self.csv_file = csv_file
        self.section = section
        self.val_frac = val_frac
        self.test_frac = test_frac
        self.condition_prob = condition_prob
        self.set_random_state(seed=seed)

        data = self._generate_data_list()

        CacheDataset.__init__(
            self,
            data=data,
            transform=transform,
            cache_num=cache_num,
            cache_rate=cache_rate,
            num_workers=num_workers,
            progress=progress,
        )

    def randomize(self, data: np.ndarray) -> None:
        """Shuffle ``data`` in place with this dataset's random state."""
        self.R.shuffle(data)

    def _generate_data_list(self):
        """Return the shuffled list of ``{"image", "condition"}`` items for ``self.section``."""
        subjects = self._list_subjects()
        subject_indices = np.arange(len(subjects))
        self.randomize(subject_indices)

        datalist = [
            item
            for i in self._section_indices(subject_indices)
            for item in self._load_slices(*subjects[i])
        ]

        # Also shuffle the items so a sequential (shuffle=False) loader sees subjects interleaved.
        order = np.arange(len(datalist))
        self.randomize(order)
        return [datalist[i] for i in order]

    def _list_subjects(self):
        """Return ``(image_path, csv_row)`` for every CSV subject whose volume exists, in CSV order."""
        subjects = []
        with open(self.csv_file, mode='r', newline='') as file:
            for row in csv.DictReader(file):
                image_path = os.path.join(self.root_dir, f"sub-{row['Subject']}_defaced_T1.nii.gz")
                if os.path.exists(image_path):
                    subjects.append((image_path, row))
        return subjects

    def _section_indices(self, indices):
        """Select the slice of shuffled subject ``indices`` that belongs to ``self.section``."""
        length = len(indices)
        test_length = int(length * self.test_frac)
        val_length = int(length * self.val_frac)
        if self.section == "test":
            return indices[:test_length]
        if self.section == "validation":
            return indices[test_length : test_length + val_length]
        if self.section == "training":
            return indices[test_length + val_length :]
        raise ValueError(
            f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].'
        )

    def _load_slices(self, image_path, row):
        """Load one volume and return an item per slice around the middle of axis 2."""
        img_data = nib.load(image_path).get_fdata()
        depth = img_data.shape[2]  # assumes axis 2 is the axial axis — TODO confirm per dataset
        middle = depth // 2
        # Clamp so shallow volumes neither wrap to negative indices nor run past the end.
        start = max(0, middle - self.slices_per_side)
        stop = min(depth, middle + self.slices_per_side)
        age, sex = int(row['Age']), int(row['Sex'])
        return [
            {
                # copy(): a plain view would keep the whole float64 volume alive in memory.
                "image": img_data[:, :, slice_idx].copy(),
                "condition": np.array([[age, sex, slice_idx]], dtype=np.float32),
            }
            for slice_idx in range(start, stop)
        ]

    def __getitem__(self, index):
        """Return the cached, transformed item at ``index``, applying condition dropout.

        With probability ``condition_prob`` the returned item's ``"condition"`` is replaced by an
        all ``-1`` tensor/array of the same shape. The stored/cached item is never modified.
        """
        sample = super().__getitem__(index)
        if not isinstance(sample, dict):
            # Slice/sequence indexing returns a subset; dropout applies to single items only.
            return sample
        if random.random() < self.condition_prob:
            condition = sample["condition"]
            # Copy the dict before assigning: CacheDataset may return its cached dict itself,
            # and writing into it would erase this sample's condition for all later epochs.
            sample = dict(sample)
            if torch.is_tensor(condition):
                sample["condition"] = torch.full_like(condition, -1.0)
            else:
                sample["condition"] = np.full_like(np.asarray(condition), -1)
        return sample


## Set Up the CamCAN Datasets and the Training and Validation DataLoaders

In [None]:
# Locations of the CamCAN T1 volumes and the phenotype table (Subject, Age, Sex).
data_dir = "./dataset_camcan_sy"
csv_file = "./phenotype.csv"

# Training preprocessing: add a channel axis, rescale intensities from [0, 255] to [0, 1],
# turn the condition into a float32 tensor, and resize each slice to 96 x 128.
# NOTE(review): assumes slice intensities span roughly [0, 255] — confirm for this data.
train_transform_steps = [
    EnsureChannelFirstd(keys=["image"], channel_dim='no_channel'),
    ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0),
    Lambdad(keys=["condition"], func=lambda x: torch.tensor(x, dtype=torch.float32)),
    Resized(keys=["image"], spatial_size=(96, 128)),
]
train_transforms = Compose(train_transform_steps)

# Training split; each access replaces the condition with -1s with probability 0.2.
train_ds = CamcanDataset(
    root_dir=data_dir,
    csv_file=csv_file,
    transform=train_transforms,
    section="training",
    condition_prob=0.2,
)
train_loader = DataLoader(
    train_ds,
    batch_size=8,
    shuffle=True,
    num_workers=8,
    persistent_workers=True,
)


In [None]:
# Validation preprocessing mirrors training: channel axis, [0, 255] -> [0, 1] rescale,
# float32 condition tensor, resize to 96 x 128.
val_transform_steps = [
    EnsureChannelFirstd(keys=["image"], channel_dim='no_channel'),
    ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0),
    Lambdad(keys=["condition"], func=lambda x: torch.tensor(x, dtype=torch.float32)),
    Resized(keys=["image"], spatial_size=(96, 128)),
]
val_transforms = Compose(val_transform_steps)

# Validation split; condition_prob keeps its default of 0, so conditions are never dropped.
val_ds = CamcanDataset(
    root_dir=data_dir,
    csv_file=csv_file,
    transform=val_transforms,
    section="validation",
)
val_loader = DataLoader(
    val_ds,
    batch_size=8,
    shuffle=False,
    num_workers=8,
    persistent_workers=True,
)

## Define Network, Scheduler, Optimizer and Inferer

In [8]:
device = torch.device("cuda")

# 2D conditional UNet: 1-channel in/out, attention only at the deepest level, and
# cross-attention over a 3-dim context (the dataset's [age, sex, slice index] condition).
unet_config = dict(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(256, 256, 512),
    attention_levels=(False, False, True),
    num_res_blocks=2,
    num_head_channels=(0, 0, 512),
    with_conditioning=True,
    cross_attention_dim=3,
)
model = DiffusionModelUNet(**unet_config)
model.to(device)

# 1000-step DDPM noise schedule, Adam optimizer and the matching inferer.
scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)
inferer = DiffusionInferer(scheduler)

## Load the Pretrained Model

In [None]:
# Restore pretrained weights. A full checkpoint (model + optimizer state + epoch) takes
# precedence; otherwise fall back to the weights-only state dict. Each file is loaded
# at most once and always mapped onto `device`.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
pretrained_model_path = 'pretrained_model_275.pth'
checkpoint_path = 'pretrained_model_checkpoint_275.pth'

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
else:
    state_dict = torch.load(pretrained_model_path, map_location=device)
    model.load_state_dict(state_dict)
    start_epoch = None  # no epoch is read from the weights-only file
print(start_epoch)

## Guidance Scale = 7.0

In [None]:
# Generate up to `max_samples` images with classifier-free guidance (scale 7.0), conditioned
# on each validation item's [age, sex, slice index], saved as generated_<n>_<slice>.png.
model.eval()  # explicit, so this cell does not depend on another cell's model mode
guidance_scale = 7.0

generated_images_path = './generated_images/275_7.0'
os.makedirs(generated_images_path, exist_ok=True)

max_samples = 1800
sample_count = 0

# The 1000-step schedule is identical for every batch, so set it once.
scheduler.set_timesteps(num_inference_steps=1000)

for batch in val_loader:
    if sample_count >= max_samples:
        break

    # Denoise the whole batch at once, truncated to the number of samples still needed.
    conditions = batch['condition'][: max_samples - sample_count].to(device)  # (n, 1, 3)
    n_batch = conditions.shape[0]
    # Unconditional half first, conditional half second; -1s mark "no condition".
    conditioning = torch.cat([torch.full_like(conditions, -1.0), conditions], dim=0)

    noise = torch.randn((n_batch, 1, 96, 128), device=device)
    with torch.no_grad():
        for t in tqdm(scheduler.timesteps, leave=False):
            noise_input = torch.cat([noise] * 2)
            # One timestep value per element of the doubled batch.
            timesteps = torch.full((noise_input.shape[0],), float(t), device=device)
            with autocast(enabled=True):
                model_output = model(noise_input, timesteps=timesteps, context=conditioning)
            noise_pred_uncond, noise_pred_text = model_output.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            noise, _ = scheduler.step(noise_pred, t, noise)

    # Save each generated image, named by running sample index and slice index.
    generated = noise[:, 0].float().cpu().numpy()
    slice_numbers = conditions[:, 0, 2].cpu().tolist()
    for image, slice_number in zip(generated, slice_numbers):
        g_file_name = f"generated_{sample_count}_{int(slice_number):03d}.png"
        generated_file = os.path.join(generated_images_path, g_file_name)
        plt.imsave(generated_file, image, vmin=0, vmax=1, cmap='gray')
        sample_count += 1

## Guidance Scale = 5.0

In [None]:
# Generate up to `max_samples` images with classifier-free guidance (scale 5.0), conditioned
# on each validation item's [age, sex, slice index], saved as generated_<n>_<slice>.png.
model.eval()  # explicit, so this cell does not depend on another cell's model mode
guidance_scale = 5.0

generated_images_path = './generated_images/275_5.0'
os.makedirs(generated_images_path, exist_ok=True)

max_samples = 1800
sample_count = 0

# The 1000-step schedule is identical for every batch, so set it once.
scheduler.set_timesteps(num_inference_steps=1000)

for batch in val_loader:
    if sample_count >= max_samples:
        break

    # Denoise the whole batch at once, truncated to the number of samples still needed.
    conditions = batch['condition'][: max_samples - sample_count].to(device)  # (n, 1, 3)
    n_batch = conditions.shape[0]
    # Unconditional half first, conditional half second; -1s mark "no condition".
    conditioning = torch.cat([torch.full_like(conditions, -1.0), conditions], dim=0)

    noise = torch.randn((n_batch, 1, 96, 128), device=device)
    with torch.no_grad():
        for t in tqdm(scheduler.timesteps, leave=False):
            noise_input = torch.cat([noise] * 2)
            # One timestep value per element of the doubled batch.
            timesteps = torch.full((noise_input.shape[0],), float(t), device=device)
            with autocast(enabled=True):
                model_output = model(noise_input, timesteps=timesteps, context=conditioning)
            noise_pred_uncond, noise_pred_text = model_output.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            noise, _ = scheduler.step(noise_pred, t, noise)

    # Save each generated image, named by running sample index and slice index.
    generated = noise[:, 0].float().cpu().numpy()
    slice_numbers = conditions[:, 0, 2].cpu().tolist()
    for image, slice_number in zip(generated, slice_numbers):
        g_file_name = f"generated_{sample_count}_{int(slice_number):03d}.png"
        generated_file = os.path.join(generated_images_path, g_file_name)
        plt.imsave(generated_file, image, vmin=0, vmax=1, cmap='gray')
        sample_count += 1

## Guidance Scale = 1.0

In [None]:
# Generate up to `max_samples` images with guidance scale 1.0 (plain conditional sampling),
# conditioned on each validation item's [age, sex, slice index], saved as generated_<n>_<slice>.png.
model.eval()  # previously missing: the cell relied on an earlier cell having set eval mode
guidance_scale = 1.0

generated_images_path = './generated_images/275_1.0'
os.makedirs(generated_images_path, exist_ok=True)

max_samples = 1800
sample_count = 0

# The 1000-step schedule is identical for every batch, so set it once.
scheduler.set_timesteps(num_inference_steps=1000)

for batch in val_loader:
    if sample_count >= max_samples:
        break

    # Denoise the whole batch at once, truncated to the number of samples still needed.
    conditions = batch['condition'][: max_samples - sample_count].to(device)  # (n, 1, 3)
    n_batch = conditions.shape[0]
    # Unconditional half first, conditional half second; -1s mark "no condition".
    conditioning = torch.cat([torch.full_like(conditions, -1.0), conditions], dim=0)

    noise = torch.randn((n_batch, 1, 96, 128), device=device)
    with torch.no_grad():
        for t in tqdm(scheduler.timesteps, leave=False):
            noise_input = torch.cat([noise] * 2)
            # One timestep value per element of the doubled batch.
            timesteps = torch.full((noise_input.shape[0],), float(t), device=device)
            with autocast(enabled=True):
                model_output = model(noise_input, timesteps=timesteps, context=conditioning)
            noise_pred_uncond, noise_pred_text = model_output.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            noise, _ = scheduler.step(noise_pred, t, noise)

    # Save each generated image, named by running sample index and slice index.
    generated = noise[:, 0].float().cpu().numpy()
    slice_numbers = conditions[:, 0, 2].cpu().tolist()
    for image, slice_number in zip(generated, slice_numbers):
        g_file_name = f"generated_{sample_count}_{int(slice_number):03d}.png"
        generated_file = os.path.join(generated_images_path, g_file_name)
        plt.imsave(generated_file, image, vmin=0, vmax=1, cmap='gray')
        sample_count += 1

## Guidance Scale = 0.5

In [None]:
# Generate up to `max_samples` images with guidance scale 0.5, conditioned on each validation
# item's [age, sex, slice index], and save each alongside its real validation slice.
model.eval()  # previously missing: the cell relied on an earlier cell having set eval mode
guidance_scale = 0.5

generated_images_path = './generated_images/275_0.5'
os.makedirs(generated_images_path, exist_ok=True)
# `real_images_path` was never defined before (NameError on the first save).
# NOTE(review): point this at the directory your evaluation expects.
real_images_path = './generated_images/real'
os.makedirs(real_images_path, exist_ok=True)

max_samples = 1800
sample_count = 0

# The 1000-step schedule is identical for every batch, so set it once.
scheduler.set_timesteps(num_inference_steps=1000)

for batch in val_loader:
    if sample_count >= max_samples:
        break

    # Denoise the whole batch at once, truncated to the number of samples still needed.
    conditions = batch['condition'][: max_samples - sample_count].to(device)  # (n, 1, 3)
    n_batch = conditions.shape[0]
    # Unconditional half first, conditional half second; -1s mark "no condition".
    conditioning = torch.cat([torch.full_like(conditions, -1.0), conditions], dim=0)

    noise = torch.randn((n_batch, 1, 96, 128), device=device)
    with torch.no_grad():
        for t in tqdm(scheduler.timesteps, leave=False):
            noise_input = torch.cat([noise] * 2)
            # One timestep value per element of the doubled batch.
            timesteps = torch.full((noise_input.shape[0],), float(t), device=device)
            with autocast(enabled=True):
                model_output = model(noise_input, timesteps=timesteps, context=conditioning)
            noise_pred_uncond, noise_pred_text = model_output.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            noise, _ = scheduler.step(noise_pred, t, noise)

    # Save each generated image and its real counterpart under the same index / slice suffix.
    generated = noise[:, 0].float().cpu().numpy()
    real_batch = batch['image'][:n_batch, 0].cpu().numpy()
    slice_numbers = conditions[:, 0, 2].cpu().tolist()
    for image, real_image, slice_number in zip(generated, real_batch, slice_numbers):
        g_file_name = f"generated_{sample_count}_{int(slice_number):03d}.png"
        generated_file = os.path.join(generated_images_path, g_file_name)
        plt.imsave(generated_file, image, vmin=0, vmax=1, cmap='gray')

        r_file_name = f"real_{sample_count}_{int(slice_number):03d}.png"
        real_file = os.path.join(real_images_path, r_file_name)
        plt.imsave(real_file, real_image, vmin=0, vmax=1, cmap='gray')

        sample_count += 1

## Guidance Scale = 0.0

In [None]:
# Generate up to `max_samples` images with guidance scale 0.0 (only the unconditional branch
# contributes), and save each alongside its real validation slice.
model.eval()  # previously missing: the cell relied on an earlier cell having set eval mode
guidance_scale = 0.0

generated_images_path = './generated_images/275_0.0'
os.makedirs(generated_images_path, exist_ok=True)
# `real_images_path` was never defined before (NameError on the first save).
# NOTE(review): point this at the directory your evaluation expects.
real_images_path = './generated_images/real'
os.makedirs(real_images_path, exist_ok=True)

max_samples = 1800
sample_count = 0

# The 1000-step schedule is identical for every batch, so set it once.
scheduler.set_timesteps(num_inference_steps=1000)

for batch in val_loader:
    if sample_count >= max_samples:
        break

    # Denoise the whole batch at once, truncated to the number of samples still needed.
    conditions = batch['condition'][: max_samples - sample_count].to(device)  # (n, 1, 3)
    n_batch = conditions.shape[0]
    # Unconditional half first, conditional half second; -1s mark "no condition".
    conditioning = torch.cat([torch.full_like(conditions, -1.0), conditions], dim=0)

    noise = torch.randn((n_batch, 1, 96, 128), device=device)
    with torch.no_grad():
        for t in tqdm(scheduler.timesteps, leave=False):
            noise_input = torch.cat([noise] * 2)
            # One timestep value per element of the doubled batch.
            timesteps = torch.full((noise_input.shape[0],), float(t), device=device)
            with autocast(enabled=True):
                model_output = model(noise_input, timesteps=timesteps, context=conditioning)
            noise_pred_uncond, noise_pred_text = model_output.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            noise, _ = scheduler.step(noise_pred, t, noise)

    # Save each generated image and its real counterpart under the same index / slice suffix.
    generated = noise[:, 0].float().cpu().numpy()
    real_batch = batch['image'][:n_batch, 0].cpu().numpy()
    slice_numbers = conditions[:, 0, 2].cpu().tolist()
    for image, real_image, slice_number in zip(generated, real_batch, slice_numbers):
        g_file_name = f"generated_{sample_count}_{int(slice_number):03d}.png"
        generated_file = os.path.join(generated_images_path, g_file_name)
        plt.imsave(generated_file, image, vmin=0, vmax=1, cmap='gray')

        r_file_name = f"real_{sample_count}_{int(slice_number):03d}.png"
        real_file = os.path.join(real_images_path, r_file_name)
        plt.imsave(real_file, real_image, vmin=0, vmax=1, cmap='gray')

        sample_count += 1

## Clean up Data Directory

In [12]:
# Delete the data directory only when this notebook created it as a temporary
# directory (no MONAI_DATA_DIRECTORY given); a user-supplied directory is kept.
created_temp_dir = directory is None
if created_temp_dir:
    shutil.rmtree(root_dir)