# Brain Tumor Segmentation with CMDiff (Conditional Diffusion Model)

Implementation of the paper: "Conditional diffusion model for high-accuracy brain tumor segmentation in MRI images"

**Key Features:**
- Conditional diffusion model with DDPM framework
- Channel attention mechanism at UNet bottleneck
- Fourier filtering preprocessing
- 4-class segmentation: Background, Necrotic/Core, Edema, Enhancing
- Dataset: BraTS 2020 (a 30-case subset of the training data is used in this notebook)

In [None]:
# Install required packages
!pip install kaggle nibabel scipy -q

In [None]:
# Setup Kaggle credentials.
# Credentials are read from Colab Secrets (key icon in the left sidebar) named
# KAGGLE_USERNAME / KAGGLE_KEY, or from environment variables of the same names
# that are already set. Never hard-code an API key in a notebook: any key that
# was previously committed here should be treated as leaked and regenerated on
# kaggle.com.
import os

try:
    from google.colab import userdata
except ImportError:  # running outside Colab
    userdata = None

for _name in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    # An already-exported environment variable wins over the Colab secret.
    if os.environ.get(_name) or userdata is None:
        continue
    try:
        _value = userdata.get(_name)
    except Exception as exc:  # e.g. secret missing or notebook access not granted
        print(f"Could not read Colab secret {_name}: {exc}")
        continue
    if _value:
        os.environ[_name] = _value

if os.environ.get("KAGGLE_USERNAME") and os.environ.get("KAGGLE_KEY"):
    print("Kaggle credentials set successfully.")
else:
    print(
        "Kaggle credentials not found: add KAGGLE_USERNAME and KAGGLE_KEY "
        "as Colab secrets or environment variables."
    )

Kaggle credentials set successfully.


In [None]:
# Download dataset (zip archive) from Kaggle.
# If a more recently modified local copy of the zip already exists, the CLI
# skips the download (see the cell output); pass --force to re-download.
!kaggle datasets download awsaf49/brats20-dataset-training-validation

Dataset URL: https://www.kaggle.com/datasets/awsaf49/brats20-dataset-training-validation
License(s): CC0-1.0
brats20-dataset-training-validation.zip: Skipping, found more recently modified local copy (use --force to force download)


In [None]:
# Extract subset of data (first 60 cases for faster training)
!unzip -q brats20-dataset-training-validation.zip "BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_0[0-5][0-9]/*"

replace BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_flair.nii? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [None]:
# Import libraries
import csv
import glob
import math
import os
import random
from collections import defaultdict
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import ndimage
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Reproducibility helpers
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) with ``seed``
    and switch cuDNN to deterministic kernels with autotuning disabled."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# Label id -> name for the 4 segmentation classes (BraTS label 4 is remapped
# to 3 when the masks are loaded).
SEGMENT_CLASSES = {0: 'NOT tumor', 1: 'NECROTIC/CORE', 2: 'EDEMA', 3: 'ENHANCING'}

print("Segment classes:")
for k, v in SEGMENT_CLASSES.items():
    print(f"  {k}: {v}")

# Compute device: the first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'\nUsing device: {device}')

Segment classes:
  0: NOT tumor
  1: NECROTIC/CORE
  2: EDEMA
  3: ENHANCING

Using device: cuda


In [None]:
# Root folder of the extracted BraTS2020 training cases (one sub-folder per case).
TRAIN_DATASET_PATH = '/content/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/'

In [None]:
# Slice selection and image size:
#   VOLUME_SLICES   - number of slices to take from each volume
#   VOLUME_START_AT - index of the first slice to include
#   IMG_SIZE        - side length (pixels) of the resized slices
# NOTE(review): BrainDataset hard-codes its own slice range and takes `dim` as
# an argument, so these constants are not referenced by the dataset class.
VOLUME_SLICES, VOLUME_START_AT, IMG_SIZE = 100, 22, 128

## Fourier Preprocessing and Dataset Class

In [None]:
def fourier_low_pass_filter(image_slice, radius_ratio=0.1):
    """
    Remove high-frequency content from a 2-D slice with an ideal circular
    low-pass filter in the Fourier domain (Eq. 7 of the CMDiff paper).

    Args:
        image_slice: real 2-D array of shape (H, W).
        radius_ratio: cutoff radius as a fraction of ``min(H, W)``. Frequency
            bins farther than ``int(min(H, W) * radius_ratio)`` from the
            centred DC component are zeroed.

    Returns:
        Real-valued float array of shape (H, W).

    Raises:
        ValueError: if ``image_slice`` is not 2-D.
    """
    image_slice = np.asarray(image_slice)
    if image_slice.ndim != 2:
        raise ValueError(f"expected a 2-D slice, got shape {image_slice.shape}")

    # 1. Transform to the frequency domain, DC term moved to (rows // 2, cols // 2).
    spectrum = np.fft.fftshift(np.fft.fft2(image_slice))

    # 2. Boolean disk of radius r around the centre = pass band.
    rows, cols = image_slice.shape
    crow, ccol = rows // 2, cols // 2
    r = int(min(rows, cols) * radius_ratio)
    y, x = np.ogrid[:rows, :cols]
    mask = (x - ccol) ** 2 + (y - crow) ** 2 <= r * r

    # 3. Apply the mask and invert.
    filtered = np.fft.ifft2(np.fft.ifftshift(spectrum * mask))

    # The disk is symmetric about the centre, so for a real input the inverse
    # transform is real up to round-off. Keep the real part: np.abs would fold
    # the legitimately negative ringing of the ideal filter back into positive
    # intensities and bias the image upward.
    return filtered.real

class BrainDataset(Dataset):
    """
    Slice-level BraTS dataset: each item is one slice (last volume axis) of one case.

    Items are ``(X, y)``:
      * ``X`` - float32 tensor (4, H, W) with the FLAIR, T1, T2 and T1CE slices,
        each resized to ``dim``, optionally Fourier low-pass filtered, then
        z-scored per slice.
      * ``y`` - int64 tensor (H, W) of labels {0, 1, 2, 3}; BraTS label 4
        (enhancing tumour) is remapped to 3.

    Args:
        list_IDs: case folder names, e.g. "BraTS20_Training_001".
        root_dir: directory containing one sub-folder per case.
        dim: output size passed to ``cv2.resize`` (width, height).
        use_fourier: apply ``fourier_low_pass_filter`` before normalisation.
        slice_range: half-open ``(start, stop)`` range of slice indices to use;
            defaults to the middle slices 60-99, where tumours are most likely.
    """
    def __init__(self, list_IDs, root_dir, dim=(128, 128), use_fourier=True, slice_range=(60, 100)):
        self.dim = dim
        self.list_IDs = list_IDs
        self.root_dir = root_dir
        self.use_fourier = use_fourier
        # Paper uses 4 modalities: FLAIR, T1, T2, T1CE
        self.modalities = ['flair', 't1', 't2', 't1ce']
        self.slice_range = tuple(slice_range)

        start, stop = self.slice_range
        self.samples = [(case_id, s) for case_id in list_IDs for s in range(start, stop)]

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _read_slice(path, slice_idx):
        """Read slice ``[:, :, slice_idx]`` of a NIfTI volume as float64."""
        # ``dataobj`` is nibabel's lazy array proxy: indexing it reads only the
        # requested slice, whereas ``get_fdata()`` would load the whole volume
        # as float64 for every single sample.
        return np.asarray(nib.load(path).dataobj[:, :, slice_idx], dtype=np.float64)

    def __getitem__(self, idx):
        case_id, slice_idx = self.samples[idx]
        case_path = os.path.join(self.root_dir, case_id)

        channels = []
        for mod in self.modalities:
            # e.g. BraTS20_Training_001_flair.nii
            file_path = os.path.join(case_path, f"{case_id}_{mod}.nii")
            slice_img = cv2.resize(self._read_slice(file_path, slice_idx), self.dim)

            if self.use_fourier:
                slice_img = fourier_low_pass_filter(slice_img)

            # Per-slice z-score; constant slices (e.g. all background) are left as-is.
            std = np.std(slice_img)
            if std > 0:
                slice_img = (slice_img - np.mean(slice_img)) / std

            channels.append(slice_img)

        X = np.stack(channels, axis=0)  # (4, H, W)

        # Segmentation mask: nearest-neighbour resize keeps the labels discrete.
        seg_path = os.path.join(case_path, f"{case_id}_seg.nii")
        mask_slice = cv2.resize(
            self._read_slice(seg_path, slice_idx), self.dim, interpolation=cv2.INTER_NEAREST
        )

        # BraTS Re-mapping: 4 -> 3 (enhancing tumor)
        mask_slice[mask_slice == 4] = 3

        return torch.from_numpy(X.astype(np.float32)), torch.from_numpy(mask_slice.astype(np.int64))

## Diffusion Utilities and Channel-Attention Model (from paper)

In [None]:
class DiffusionUtils:
    """
    DDPM noise schedule with the forward (``q_sample``) and single reverse
    (``p_sample``) steps.

    Every schedule table is a 1-D tensor of length ``timesteps`` placed on the
    module-level ``device``; ``extract`` gathers one entry per batch sample.
    """

    def __init__(self, timesteps=1000):
        self.timesteps = timesteps

        # Cosine schedule (often preferred for medical images over linear)
        betas = self.cosine_beta_schedule(timesteps).to(device)
        alphas = 1. - betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        # alpha_bar shifted right by one, with alpha_bar_{-1} := 1.
        alpha_bar_prev = F.pad(alpha_bar[:-1], (1, 0), value=1.0)

        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alpha_bar
        self.alphas_cumprod_prev = alpha_bar_prev
        # Coefficients of q(x_t | x_0) = N(sqrt(alpha_bar) * x_0, (1 - alpha_bar) * I).
        self.sqrt_alphas_cumprod = torch.sqrt(alpha_bar)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alpha_bar)
        # 1 / sqrt(alpha_t), used by the reverse-step mean.
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
        # Variance of the posterior q(x_{t-1} | x_t, x_0); exactly 0 at t = 0.
        self.posterior_variance = betas * (1. - alpha_bar_prev) / (1. - alpha_bar)

    def cosine_beta_schedule(self, timesteps, s=0.008):
        """Return ``timesteps`` betas derived from a squared-cosine alpha_bar
        curve with offset ``s``, clipped to [0.0001, 0.9999]."""
        grid = torch.linspace(0, timesteps, timesteps + 1)
        f = torch.cos(((grid / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alpha_bar = f / f[0]
        # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}.
        # NOTE(review): f at the last grid point is ~0, so the final beta is
        # presumably clipped at 0.9999, making sqrt_recip_alphas ~100 there;
        # confirm this upper clip is intended.
        return torch.clip(1 - (alpha_bar[1:] / alpha_bar[:-1]), 0.0001, 0.9999)

    def extract(self, a, t, x_shape):
        """Gather ``a[t]`` per sample and reshape to (B, 1, ..., 1) so it
        broadcasts against a tensor of shape ``x_shape``."""
        gathered = a.gather(-1, t)
        trailing_ones = (1,) * (len(x_shape) - 1)
        return gathered.reshape(t.shape[0], *trailing_ones)

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion: return sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
        drawing standard normal noise when none is given."""
        noise = torch.randn_like(x_start) if noise is None else noise
        signal_coef = self.extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_coef = self.extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return signal_coef * x_start + noise_coef * noise

    @torch.no_grad()
    def p_sample(self, model, x, t, condition, t_index):
        """One reverse step x_t -> x_{t-1}.

        ``model(x, t, condition)`` must return predicted noise. When ``t_index``
        is 0 the posterior mean is returned without adding noise.
        """
        # Equation 6: Predict mean
        beta_t = self.extract(self.betas, t, x.shape)
        sigma_bar_t = self.extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
        inv_sqrt_alpha_t = self.extract(self.sqrt_recip_alphas, t, x.shape)

        # Model predicts noise, conditioned on MRI image
        eps = model(x, t, condition)
        mean = inv_sqrt_alpha_t * (x - beta_t * eps / sigma_bar_t)

        if t_index != 0:
            var_t = self.extract(self.posterior_variance, t, x.shape)
            return mean + torch.sqrt(var_t) * torch.randn_like(x)
        return mean

In [None]:
class ChannelAttention(nn.Module):
    """
    Channel Attention Mechanism (Fig. 3 in paper).

    Squeezes each channel with global average- and max-pooling, passes both
    descriptors through a shared bottleneck MLP (1x1 convs), and rescales the
    input channels by ``sigmoid(mlp(avg) + mlp(max))``.

    Args:
        in_planes: number of input (and output) channels.
        ratio: channel reduction factor of the bottleneck. The hidden width is
            clamped to at least 1 so ``in_planes < ratio`` still builds a
            valid layer.
    """
    def __init__(self, in_planes, ratio=16):
        super().__init__()
        hidden = max(in_planes // ratio, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Same MLP on both pooled descriptors, each of shape (B, C, 1, 1).
        attn = self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x))
        return x * self.sigmoid(attn)


class CMDiffUNet2D(nn.Module):
    """
    Two-branch conditional UNet that predicts the noise in a noisy one-hot
    segmentation map x_t, conditioned on the MRI slice.

    * Condition encoder: 3 conv stages over the MRI slice.
    * Noisy-mask encoder: 3 conv stages over x_t; the time embedding is added
      after the first stage.
    * Bottleneck: concatenated deepest features -> conv -> ChannelAttention -> conv.
    * Decoder: 2 upsampling stages with skip connections from both encoders.

    H and W must be divisible by 4 (two 2x max-poolings).

    Args:
        in_channels: channels of the conditioning image (MRI modalities).
        out_channels: channels of x_t and of the predicted noise (classes).
        base: width of the first stage; deeper stages use 2x, 4x and 8x.
        timesteps: length of the diffusion schedule. Integer timesteps are
            divided by this before the time MLP, so it should match the
            ``timesteps`` used by the diffusion schedule.
    """
    def __init__(self, in_channels=4, out_channels=4, base=64, timesteps=1000):
        super().__init__()

        # Scale factor that maps integer timesteps into [0, 1).
        self.timesteps = timesteps

        # Time embedding MLP: scalar t / timesteps -> (B, base)
        self.time_mlp = nn.Sequential(
            nn.Linear(1, base),
            nn.GELU(),
            nn.Linear(base, base),
        )

        # --- Condition Encoder (MRI Image) ---
        self.cond_enc1 = nn.Sequential(nn.Conv2d(in_channels, base, 3, padding=1), nn.ReLU())
        self.cond_enc2 = nn.Sequential(nn.MaxPool2d(2), nn.Conv2d(base, base*2, 3, padding=1), nn.ReLU())
        self.cond_enc3 = nn.Sequential(nn.MaxPool2d(2), nn.Conv2d(base*2, base*4, 3, padding=1), nn.ReLU())

        # --- Noisy Mask Encoder ---
        self.mask_enc1 = nn.Sequential(nn.Conv2d(out_channels, base, 3, padding=1), nn.ReLU())
        self.mask_enc2 = nn.Sequential(nn.MaxPool2d(2), nn.Conv2d(base, base*2, 3, padding=1), nn.ReLU())
        self.mask_enc3 = nn.Sequential(nn.MaxPool2d(2), nn.Conv2d(base*2, base*4, 3, padding=1), nn.ReLU())

        # --- Bottleneck with Attention ---
        self.bottleneck = nn.Sequential(
            nn.Conv2d(base*4 * 2, base*8, 3, padding=1),  # *2: mask + condition features
            nn.ReLU(),
            ChannelAttention(base*8),  # Paper's Attention Block
            nn.Conv2d(base*8, base*4, 3, padding=1),
            nn.ReLU()
        )

        # --- Decoder ---
        self.up1 = nn.ConvTranspose2d(base*4, base*2, 2, stride=2)
        self.dec1 = nn.Sequential(nn.Conv2d(base*2 * 3, base*2, 3, padding=1), nn.ReLU())  # *3: up + mask skip + cond skip

        self.up2 = nn.ConvTranspose2d(base*2, base, 2, stride=2)
        self.dec2 = nn.Sequential(nn.Conv2d(base * 3, base, 3, padding=1), nn.ReLU())

        self.final = nn.Conv2d(base, out_channels, 1)

    def forward(self, x, t, condition):
        """
        Args:
            x: noisy mask x_t, (B, out_channels, H, W).
            t: integer timesteps, one per sample, (B,).
            condition: MRI slice, (B, in_channels, H, W).

        Returns:
            Predicted noise, (B, out_channels, H, W).
        """
        # 1. Time Embedding. Normalise t to [0, 1): raw indices up to
        # timesteps - 1 fed into Linear(1, base) can give activations far
        # larger than the conv features they are added to.
        t_in = t.float().view(-1, 1) / self.timesteps
        t_emb = self.time_mlp(t_in)  # (B, base)
        t_emb = t_emb[:, :, None, None].expand(-1, -1, x.shape[2], x.shape[3])

        # 2. Encode Condition
        c1 = self.cond_enc1(condition)
        c2 = self.cond_enc2(c1)
        c3 = self.cond_enc3(c2)

        # 3. Encode Noisy Mask (inject time into first layer)
        m1 = self.mask_enc1(x) + t_emb
        m2 = self.mask_enc2(m1)
        m3 = self.mask_enc3(m2)

        # 4. Bottleneck Fusion
        b = torch.cat([m3, c3], dim=1)
        b = self.bottleneck(b)

        # 5. Decoder with skip connections from both the mask and condition encoders
        d1 = self.up1(b)
        d1 = torch.cat([d1, m2, c2], dim=1)
        d1 = self.dec1(d1)

        d2 = self.up2(d1)
        d2 = torch.cat([d2, m1, c1], dim=1)
        d2 = self.dec2(d2)

        return self.final(d2)

## Loss Functions and Metrics

In [None]:
def _mean_foreground_score(pred, target, num_classes, smooth, score_fn):
    """Evaluate ``score_fn(intersection, pred_count, target_count, smooth)`` for
    each class id in ``range(num_classes)`` and return the mean over classes
    1.. (class 0, background, is excluded)."""
    scores = []
    for cls in range(num_classes):
        in_pred = (pred == cls).float()
        in_target = (target == cls).float()
        overlap = (in_pred * in_target).sum()
        scores.append(score_fn(overlap, in_pred.sum(), in_target.sum(), smooth).item())
    return np.mean(scores[1:])


def _dice_from_counts(overlap, n_pred, n_target, smooth):
    """Smoothed Dice: (2*|A & B| + s) / (|A| + |B| + s)."""
    return (2.0 * overlap + smooth) / (n_pred + n_target + smooth)


def _iou_from_counts(overlap, n_pred, n_target, smooth):
    """Smoothed IoU: (|A & B| + s) / (|A| + |B| - |A & B| + s)."""
    return (overlap + smooth) / (n_pred + n_target - overlap + smooth)


def dice_coefficient(pred, target, num_classes=4, smooth=1e-6):
    """
    Mean smoothed Dice coefficient over the non-background classes.

    ``pred`` and ``target`` are label tensors of the same shape. A class absent
    from both scores smooth / smooth = 1.
    """
    return _mean_foreground_score(pred, target, num_classes, smooth, _dice_from_counts)


def iou_score(pred, target, num_classes=4, smooth=1e-6):
    """
    Mean smoothed IoU (Jaccard index) over the non-background classes.

    ``pred`` and ``target`` are label tensors of the same shape. A class absent
    from both scores smooth / smooth = 1.
    """
    return _mean_foreground_score(pred, target, num_classes, smooth, _iou_from_counts)

print("Metrics defined!")

Metrics defined!


In [None]:
import torch
import torch.nn as nn

def foreground_weighted_mse(
    pred_noise: torch.Tensor,
    target_noise: torch.Tensor,
    masks: torch.Tensor,
    fg_weight: float = 40.0,
    bg_weight: float = 1.0,
) -> torch.Tensor:
    """
    Noise-space MSE that up-weights tumour (non-zero label) pixels.

    Every channel of a pixel gets weight ``fg_weight`` where ``masks`` is
    non-zero and ``bg_weight`` elsewhere. The result is the weighted mean
    ``sum(w * (pred - target)**2) / sum(w)``, so its scale does not depend on
    how much of the batch is foreground.

    Args:
        pred_noise: (B, C, H, W) noise predicted by the model.
        target_noise: (B, C, H, W) noise actually used in ``q_sample``.
        masks: (B, H, W) ground-truth labels; 0 = background, >0 = foreground.
        fg_weight: weight applied to foreground pixels.
        bg_weight: weight applied to background pixels.

    Returns:
        Scalar loss tensor.
    """
    squared_err = (pred_noise - target_noise) ** 2

    # The weights are constants with respect to the model, so build them
    # outside autograd.
    with torch.no_grad():
        is_fg = (masks != 0).float().unsqueeze(1).expand_as(pred_noise)  # (B, C, H, W)
        pixel_w = is_fg * fg_weight + (1.0 - is_fg) * bg_weight

    return (pixel_w * squared_err).sum() / pixel_w.sum()


## Training Setup, Validation, and Training Loop

In [None]:
import cv2
# --- Configuration ---
# UPDATE THIS PATH to your actual data location
ROOT_DIR = "/content/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/"

# Case folder names to use (e.g. "BraTS20_Training_001"); cases 001-030 here.
train_ids = [f"BraTS20_Training_{i:03d}" for i in range(1, 31)]

# Hold out the last `val_fraction` of the cases (at least one) for validation.
# Guard against num_val == 0: `ids[-0:]` / `ids[:-0]` would put *every* case
# into validation and leave training empty.
val_fraction = 0.2
num_val = max(1, int(len(train_ids) * val_fraction))
if num_val >= len(train_ids):
    raise ValueError(
        f"need at least 2 cases to split into train/val, got {len(train_ids)}"
    )
split_at = len(train_ids) - num_val
val_ids = train_ids[split_at:]
train_ids = train_ids[:split_at]

train_dataset = BrainDataset(train_ids, ROOT_DIR, use_fourier=True)
val_dataset   = BrainDataset(val_ids, ROOT_DIR, use_fourier=True)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset,   batch_size=8, shuffle=False, num_workers=2)

diffuser = DiffusionUtils(timesteps=1000)

model = CMDiffUNet2D().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Plain MSE kept for reference: the training loop uses foreground_weighted_mse
# and validate() creates its own nn.MSELoss.
criterion = nn.MSELoss()

In [None]:
import numpy as np
import torch.nn.functional as F

def validate(model, diffuser, val_loader, device, seed=0):
    """
    One-step validation of the noise-prediction model.

    For each batch: draw a timestep t and Gaussian noise, corrupt the one-hot
    ground truth with ``diffuser.q_sample``, then
      * score the predicted noise with plain (unweighted) MSE, and
      * invert the forward process in a single step to get an x0 estimate,
        take its channel-wise argmax as the label map, and compute per-slice
        Dice / IoU (background excluded) with ``dice_coefficient`` / ``iou_score``.

    This is not full reverse-diffusion sampling. At large t the one-step x0
    estimate is poor, so Dice/IoU here are only a cheap proxy.

    Args:
        model: module called as ``model(x_t, t, condition=images)`` that returns
            predicted noise.
        diffuser: provides ``timesteps``, ``q_sample``, ``extract``,
            ``sqrt_alphas_cumprod`` and ``sqrt_one_minus_alphas_cumprod``.
        val_loader: iterable of ``(images, masks)`` with masks as (B, H, W) labels.
        device: device to run on.
        seed: seed of a private RNG used to draw t and the noise. Every call
            then evaluates the same (t, noise) pairs, so epochs are comparable,
            and the global RNG used by training is not advanced. Pass ``None``
            to draw from the global RNG instead.

    Returns:
        ``(avg_loss, avg_dice, avg_iou)`` as Python floats; 0.0 for an empty loader.
    """
    was_training = model.training
    model.eval()
    criterion = nn.MSELoss()

    generator = None
    if seed is not None:
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

    all_losses = []
    all_dices = []
    all_ious = []

    try:
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)  # (B, H, W)

                # One-hot encode masks, scale to [-1, 1] (same as training)
                masks_onehot = F.one_hot(masks.long(), num_classes=4).permute(0, 3, 1, 2).float()
                masks_scaled = masks_onehot * 2 - 1  # (B, 4, H, W)

                # Random t and noise as in training, drawn from the private generator
                t = torch.randint(
                    0, diffuser.timesteps, (images.shape[0],),
                    device=device, generator=generator,
                ).long()
                noise = torch.randn(
                    masks_scaled.shape, device=masks_scaled.device,
                    dtype=masks_scaled.dtype, generator=generator,
                )
                x_t = diffuser.q_sample(masks_scaled, t, noise)

                # Plain (unweighted) noise MSE. Training uses foreground_weighted_mse,
                # so train and val losses are not directly comparable.
                pred_noise = model(x_t, t, condition=images)
                all_losses.append(criterion(pred_noise, noise).item())

                # One-step x0 estimate from x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
                sqrt_alpha_t = diffuser.extract(diffuser.sqrt_alphas_cumprod, t, x_t.shape)
                sqrt_one_minus_alpha_t = diffuser.extract(diffuser.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
                x0_pred = (x_t - sqrt_one_minus_alpha_t * pred_noise) / (sqrt_alpha_t + 1e-8)

                # Discrete labels via argmax over the class channels
                preds = torch.argmax(x0_pred, dim=1)  # (B, H, W)

                for pred_i, mask_i in zip(preds, masks):
                    all_dices.append(dice_coefficient(pred_i, mask_i))
                    all_ious.append(iou_score(pred_i, mask_i))
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    avg_loss = float(np.mean(all_losses)) if all_losses else 0.0
    avg_dice = float(np.mean(all_dices)) if all_dices else 0.0
    avg_iou  = float(np.mean(all_ious))  if all_ious  else 0.0

    return avg_loss, avg_dice, avg_iou

In [None]:
# Main training loop. Each epoch runs one pass of noise-prediction training
# over train_loader, then calls validate() and prints a one-line summary.
# NOTE: every variable below is a notebook global; later cells may read them.
epochs = 50

for epoch in range(epochs):
    model.train()
    train_losses = []

    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)  # (B, H, W)

        # One-hot encode masks and scale to [-1, 1]
        # -> x_0 of the diffusion process, shape (B, 4, H, W)
        masks_onehot = F.one_hot(masks.long(), num_classes=4).permute(0, 3, 1, 2).float()
        masks_scaled = masks_onehot * 2 - 1

        # Sample t and noise
        # (one uniform random timestep per sample, then corrupt x_0 -> x_t)
        t = torch.randint(0, diffuser.timesteps, (images.shape[0],), device=device).long()
        noise = torch.randn_like(masks_scaled)
        x_t = diffuser.q_sample(masks_scaled, t, noise)

        # Predict noise
        predicted_noise = model(x_t, t, condition=images)

        # Foreground-weighted noise MSE. fg_weight=4.0 here overrides the
        # function's default of 40.0.
        loss = foreground_weighted_mse(
            pred_noise=predicted_noise,
            target_noise=noise,
            masks=masks,          # (B, H, W) GT labels
            fg_weight=4.0,        # try 3–10, tune later
            bg_weight=1.0
        )

        train_losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean training loss over the epoch, then validation on val_loader.
    avg_train_loss = np.mean(train_losses)
    val_loss, val_dice, val_iou = validate(model, diffuser, val_loader, device)

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Train Loss: {avg_train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | "
        f"Val Dice: {val_dice:.4f} | "
        f"Val IoU: {val_iou:.4f}"
    )


Epoch 1/50 | Train Loss: 22.8748 | Val Loss: 1.0518 | Val Dice: 0.3306 | Val IoU: 0.3014
Epoch 2/50 | Train Loss: 1.0060 | Val Loss: 0.9910 | Val Dice: 0.2872 | Val IoU: 0.2568
Epoch 3/50 | Train Loss: 0.9902 | Val Loss: 0.9845 | Val Dice: 0.3280 | Val IoU: 0.2997
Epoch 4/50 | Train Loss: 0.9814 | Val Loss: 0.9767 | Val Dice: 0.3754 | Val IoU: 0.3489
Epoch 5/50 | Train Loss: 0.9725 | Val Loss: 0.9663 | Val Dice: 0.3235 | Val IoU: 0.2967
Epoch 6/50 | Train Loss: 1.0105 | Val Loss: 0.9621 | Val Dice: 0.3263 | Val IoU: 0.3015
Epoch 7/50 | Train Loss: 0.9572 | Val Loss: 0.9522 | Val Dice: 0.3578 | Val IoU: 0.3341
Epoch 8/50 | Train Loss: 0.9469 | Val Loss: 0.9433 | Val Dice: 0.3602 | Val IoU: 0.3348
Epoch 9/50 | Train Loss: 0.9319 | Val Loss: 0.9240 | Val Dice: 0.3553 | Val IoU: 0.3282
Epoch 10/50 | Train Loss: 0.9189 | Val Loss: 0.9067 | Val Dice: 0.3569 | Val IoU: 0.3303
Epoch 11/50 | Train Loss: 0.8985 | Val Loss: 0.8864 | Val Dice: 0.3612 | Val IoU: 0.3323
Epoch 12/50 | Train Loss: 1.0