# Generating Synthetic Brain Tumor MRI Images using Stable Diffusion

This notebook showcases the process of fine-tuning a Stable Diffusion model to generate synthetic Magnetic Resonance Imaging (MRI) images of brain tumors.

# Setup

This section handles the initial setup: installing the required packages, importing modules, and mounting Google Drive so the dataset can be accessed.

In [None]:
!pip install --upgrade diffusers transformers accelerate torch torchvision datasets

In [None]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import Dataset as TorchDataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm.auto import tqdm
import os
import random
from diffusers import StableDiffusionImg2ImgPipeline, DDPMScheduler
from datasets import load_dataset, Dataset, DatasetDict
from google.colab import drive, files

In [None]:
# Mount Google Drive at /content/drive so the dataset under DATASET_ROOT is readable.
from google.colab import drive as colab_drive

colab_drive.mount(mountpoint="/content/drive")

# Data Loading and Preparation

This section defines a custom PyTorch Dataset class to load the brain tumor MRI images and their corresponding text prompts. It also includes the data transformations to be applied to the images and initializes the training dataset.

In [None]:
# Input data is read from the mounted Google Drive; the fine-tuned pipeline is
# written to local Colab storage under /content.
_DRIVE_HOME = "/content/drive/MyDrive"
DATASET_ROOT = f"{_DRIVE_HOME}/MIT URTC 2025/Training_Dataset"

OUTPUT_DIR = "/content" + "/stable-diffusion-brain-tumor"

In [None]:
# Maps each class-folder name in the dataset to the text prompt used to
# condition the model during training and generation.
CLASS_LABELS = dict(
    meningioma="Meningioma tumor MRI",
    glioma="Glioma tumor MRI",
    pituitary_tumor="Pituitary tumor MRI",
)

In [None]:
IMAGE_SIZE = 256  # side length (pixels) that images are resized to before use

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed Python's and PyTorch's RNGs so data shuffling, noise/timestep sampling and
# the random choice of input images are repeatable across Restart & Run All.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class MRIDataset(TorchDataset):
    """Image/prompt dataset built from a ``root_dir/<class_name>/<image file>`` layout.

    Each (non-hidden) sub-directory of ``root_dir`` is one class; every image file
    inside it becomes a sample paired with that class's text prompt.

    Args:
        root_dir: Directory whose sub-directories are class folders.
        transform: Optional callable applied to each loaded PIL image.
        class_labels: Mapping from class-folder name to text prompt.
            Defaults to the notebook-level ``CLASS_LABELS``.

    Raises:
        KeyError: If ``root_dir`` contains a class folder with no entry in
            ``class_labels``.
    """

    # Only files with these extensions are treated as images; anything else
    # (e.g. .DS_Store, desktop.ini) is skipped so it never reaches Image.open.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")

    def __init__(self, root_dir, transform=None, class_labels=None):
        if class_labels is None:
            class_labels = CLASS_LABELS
        self.samples = []
        self.transform = transform

        # Sorted because os.listdir order is arbitrary; a stable sample order keeps
        # seeded shuffling reproducible.
        for class_name in sorted(os.listdir(root_dir)):
            class_dir = os.path.join(root_dir, class_name)
            # Skip plain files and hidden folders such as .ipynb_checkpoints.
            if class_name.startswith(".") or not os.path.isdir(class_dir):
                continue
            if class_name not in class_labels:
                raise KeyError(
                    f"No text prompt configured for class folder {class_name!r} in {root_dir}"
                )
            text_prompt = class_labels[class_name]
            for fname in sorted(os.listdir(class_dir)):
                if not fname.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                self.samples.append((os.path.join(class_dir, fname), text_prompt))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{"image": transformed image, "prompt": class text prompt}``."""
        img_path, text_prompt = self.samples[idx]
        # Context manager guarantees the file handle is released after decoding.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return {"image": image, "prompt": text_prompt}


# Preprocessing: resize to a square IMAGE_SIZE x IMAGE_SIZE image, convert to a
# float tensor, then normalise each RGB channel with mean 0.5 / std 0.5.
_mri_steps = [
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_mri_steps)

In [None]:
train_dir = os.path.join(DATASET_ROOT, "train")
train_dataset = MRIDataset(train_dir, transform=transform)

print(f"Train samples: {len(train_dataset)}")
# Fail fast with a clear message: an empty dataset would otherwise only surface
# later as a less obvious error inside the DataLoader / training loop.
assert len(train_dataset) > 0, f"No training images found under {train_dir}"

# Model Definition and Training

This section covers loading the pre-trained Stable Diffusion model pipeline, configuring it for training (enabling gradient checkpointing), defining the optimizer (AdamW) and the loss function (MSE), setting up the DDPMScheduler for noise scheduling, and implementing the training loop to fine-tune the UNet component of the pipeline on the custom dataset.

In [None]:
# Load the pretrained SD-Turbo weights as an img2img pipeline and move every
# component onto DEVICE.
MODEL_ID = "stabilityai/sd-turbo"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(MODEL_ID)
pipe = pipe.to(DEVICE)

In [None]:
# --- Training configuration ------------------------------------------------------
EPOCHS = 10  # TODO: adjust the number of fine-tuning epochs for your run
BATCH_SIZE = 8
LEARNING_RATE = 1e-5
# Mixed precision (autocast + GradScaler) is only enabled when training on CUDA.
USE_AMP = DEVICE == "cuda"

# Only the UNet is optimised. Freezing the VAE and text encoder stops autograd from
# tracking them, so no graph or .grad buffers are kept for weights AdamW never updates.
pipe.vae.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)

pipe.unet.enable_gradient_checkpointing()
pipe.unet.train()

optimizer = AdamW(pipe.unet.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.MSELoss()
scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# DDPM scheduler used only for the forward (noising) process during training, built
# from the pipeline's scheduler config. pipe.scheduler itself is left untouched, so
# inference keeps using the scheduler shipped with the model.
noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
num_train_timesteps = noise_scheduler.config.num_train_timesteps
prediction_type = noise_scheduler.config.prediction_type
if prediction_type not in ("epsilon", "v_prediction"):
    raise ValueError(f"Unsupported prediction_type: {prediction_type!r}")

# VAE latent scaling factor read from the model config instead of hard-coding 0.18215.
latent_scale = pipe.vae.config.scaling_factor

for epoch in range(EPOCHS):
    epoch_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for batch in pbar:
        optimizer.zero_grad(set_to_none=True)

        imgs = batch["image"].to(DEVICE)
        prompts = batch["prompt"]

        # Frozen encoders: latents and text embeddings need no gradients.
        with torch.no_grad():
            with torch.amp.autocast("cuda", enabled=USE_AMP):
                latents = pipe.vae.encode(imgs).latent_dist.sample()
            latents = latents.float() * latent_scale

            token_ids = pipe.tokenizer(
                prompts,
                padding="max_length",
                max_length=pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids.to(DEVICE)
            text_embeddings = pipe.text_encoder(token_ids)[0]

        # Sample noise and one random timestep per image, then noise the latents
        # according to the noise magnitude at each timestep.
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(0, num_train_timesteps, (bsz,), device=DEVICE).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # The regression target depends on what the UNet was trained to predict.
        if prediction_type == "epsilon":
            target = noise
        else:  # "v_prediction"
            target = noise_scheduler.get_velocity(latents, noise, timesteps)

        with torch.amp.autocast("cuda", enabled=USE_AMP):
            model_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
        loss = loss_fn(model_pred.float(), target.float())

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        epoch_loss += loss_value
        pbar.set_postfix({"loss": loss_value})

    avg_loss = epoch_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.6f}")

# Put the UNet back in inference mode for the generation cells below.
pipe.unet.eval()

In [None]:
# Persist the whole pipeline (including the fine-tuned UNet) to OUTPUT_DIR.
pipe.save_pretrained(save_directory=OUTPUT_DIR)
print("Model saved to " + OUTPUT_DIR)

# Image Generation

This section defines a function that generates synthetic images using the fine-tuned Stable Diffusion model. It takes an input image path and a text prompt, and uses the `StableDiffusionImg2ImgPipeline` to generate a new image conditioned on both. An example call demonstrates the generation process, and a second function generates images class by class.

In [None]:
def generate_image(input_image_path, text_prompt, strength=0.75, guidance_scale=7.5,
                   num_inference_steps=50, generator=None):
    """Generate one synthetic image with the img2img pipeline ``pipe``.

    Args:
        input_image_path: Path to the conditioning image; it is converted to RGB
            and resized to IMAGE_SIZE x IMAGE_SIZE before being passed to ``pipe``.
        text_prompt: Prompt passed to the pipeline (e.g. a CLASS_LABELS value).
        strength: img2img strength forwarded to the pipeline.
        guidance_scale: Classifier-free guidance scale forwarded to the pipeline.
        num_inference_steps: Denoising steps forwarded to the pipeline (default 50).
        generator: Optional ``torch.Generator`` for reproducible sampling; ``None``
            uses the pipeline's default randomness.

    Returns:
        The first image of the pipeline output (``output.images[0]``).
    """
    # Context manager releases the file handle as soon as the image is decoded.
    with Image.open(input_image_path) as src:
        init_image = src.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
    output = pipe(
        prompt=text_prompt,
        image=init_image,
        strength=strength,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    )
    return output.images[0]

# Example usage: run one pituitary-tumor test image through the fine-tuned pipeline.
example_dir = os.path.join(DATASET_ROOT, "test", "pituitary_tumor")
# Filter to image files and sort, so the same input is used on every run
# (os.listdir order is arbitrary and may include non-image files).
example_files = sorted(
    fname for fname in os.listdir(example_dir)
    if fname.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"))
)
assert example_files, f"No images found in {example_dir}"

sample_image = generate_image(
    input_image_path=os.path.join(example_dir, example_files[0]),
    text_prompt=CLASS_LABELS["pituitary_tumor"],
)
sample_image.save("/content/generated_sample_pituitary_tumor.png")
print("Generated image saved.")

In [None]:
def generate_images_for_class(
    class_name,
    num_images=5000,
    strength=0.75,
    guidance_scale=7.5,
    save_dir="/content/generated_samples"
):
    """Generate ``num_images`` synthetic images for one class and save them as PNGs.

    Each image is produced by ``generate_image`` from a randomly chosen training
    image of the same class (``DATASET_ROOT/train/<class_name>``) and the class's
    prompt from CLASS_LABELS. Files are written to ``save_dir/<class_name>/``.

    Args:
        class_name: Key of CLASS_LABELS and name of the training class folder.
        num_images: Number of images to generate.
        strength: img2img strength forwarded to ``generate_image``.
        guidance_scale: Guidance scale forwarded to ``generate_image``.
        save_dir: Root output directory; a per-class sub-folder is created.

    Raises:
        FileNotFoundError: If images are requested but the class folder contains
            no image files.
    """
    image_exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")

    class_dir = os.path.join(save_dir, class_name)
    os.makedirs(class_dir, exist_ok=True)

    input_dir = os.path.join(DATASET_ROOT, "train", class_name)
    # Only real image files are eligible inputs; sorting makes seeded
    # random.choice reproducible (os.listdir order is arbitrary).
    input_files = sorted(
        fname for fname in os.listdir(input_dir) if fname.lower().endswith(image_exts)
    )
    if num_images > 0 and not input_files:
        raise FileNotFoundError(f"No input images found in {input_dir}")

    # Zero-pad indices to the width of the largest index (at least 3 digits) so
    # the output files sort in generation order even for >1000 images.
    index_width = max(3, len(str(max(num_images - 1, 0))))

    for i in tqdm(range(num_images), desc=f"Generating {class_name}"):
        input_image_path = os.path.join(input_dir, random.choice(input_files))

        img = generate_image(
            input_image_path=input_image_path,
            text_prompt=CLASS_LABELS[class_name],
            strength=strength,
            guidance_scale=guidance_scale
        )

        out_path = os.path.join(class_dir, f"{class_name}_{i:0{index_width}d}.png")
        img.save(out_path)

    print(f"Saved {num_images} images for {class_name} to {class_dir}")

In [None]:
# Generate 5000 synthetic images for every configured tumor class.
for class_name in CLASS_LABELS:
    generate_images_for_class(
        class_name,
        num_images=5000,
        save_dir="/content/generated_samples",
    )

In [None]:
!zip -r /content/stable-diffusion-brain-tumor.zip /content/stable-diffusion-brain-tumor
!zip -r /content/generated_samples.zip /content/generated_samples

files.download('/content/generated_samples.zip')