In [1]:
!pip install torch torchvision torchaudio einops pillow tqdm accelerate numpy matplotlib



In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image

In [None]:
# Device configuration: use a CUDA GPU when one is visible, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Hyperparameters
batch_size = 16
epochs = 200  # Increase epochs as needed for better results
# Image side length (e.g. 64, 128 or 256). Must be a multiple of 4: the model
# downsamples H and W by 4 (two stride-2 convolutions) and upsamples by 4.
image_size = 64
learning_rate = 1e-4

# Seed torch's global RNG (all devices). This RNG drives weight initialisation,
# DataLoader shuffling and the generation noise, so Restart & Run All gives
# repeatable results (GPU kernels may still be nondeterministic).
seed = 42
torch.manual_seed(seed)

In [4]:
# Root folder of the training images; it holds one sub-folder per class
# (the dataset reads os.path.join(root_dir, target_class)). Set the
# TUMOR_DATA_DIR environment variable to use a different copy without editing.
root_dir = os.environ.get("TUMOR_DATA_DIR", "data/generation_data/Training/")


In [5]:
class TumorDataset(Dataset):
    """Grayscale images from one class sub-folder, resized and scaled to [-1, 1].

    Every file in ``os.path.join(root_dir, target_class)`` whose name ends in
    .png/.jpg/.jpeg (case-insensitive) is included. Files are listed in sorted
    order so that indices are stable across runs and filesystems.
    Items are ``(image_tensor, label)`` pairs.
    """

    def __init__(self, root_dir, target_class='notumor', image_size=64):
        """
        Args:
            root_dir (str): Root directory containing one sub-folder per class
                (e.g. 'tumor' and 'notumor').
            target_class (str): Sub-folder to load, e.g. 'tumor' or 'notumor'.
                'notumor' is labelled 0; any other value is labelled 1.
            image_size (int): Side length images are resized to (square).
                Defaults to 64, the size previously hard-coded here.
        """
        self.root_dir = os.path.join(root_dir, target_class)
        self.images = []
        self.labels = []
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),              # PIL grayscale -> float tensor in [0, 1]
            transforms.Normalize([0.5], [0.5])  # (x - 0.5) / 0.5 -> [-1, 1]
        ])

        label = 0 if target_class == 'notumor' else 1
        # os.listdir order is arbitrary; sort so indexing is reproducible.
        for fname in sorted(os.listdir(self.root_dir)):
            if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                self.images.append(os.path.join(self.root_dir, fname))
                self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)  # read as grayscale
        if img is None:
            # cv2.imread reports unreadable/corrupt files by returning None
            # rather than raising, which would otherwise fail obscurely in PIL.
            raise OSError(f"Could not read image file: {img_path}")
        img = Image.fromarray(img)   # numpy array -> PIL image for torchvision
        img = self.transform(img)
        return img, self.labels[idx]


# Load dataset: only the default target_class ('notumor') is used here.
dataset = TumorDataset(root_dir, image_size=image_size)
if len(dataset) == 0:
    # DataLoader(shuffle=True) would otherwise fail with an unclear sampler error.
    raise RuntimeError(f"No .png/.jpg/.jpeg images found in {dataset.root_dir}")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [None]:
# Network used as the "DDPM" model (see the class docstring for caveats)
class UNetDDPM(nn.Module):
    """Convolutional encoder-decoder mapping a 1-channel image to a 1-channel image.

    Layout:
    - The encoder halves H and W twice with two stride-2 3x3 convolutions
      (1 -> 64 -> 128 -> 256 channels).
    - A 3x3 convolution widens the bottleneck to 512 channels.
    - The decoder doubles H and W twice with transposed convolutions.
    - A final 3x3 convolution projects back to 1 channel with no output
      activation, so the output is unbounded.

    NOTE(review): despite the name, there are no skip connections (so it is
    not a U-Net), and ``forward`` takes no timestep, so the network cannot
    condition on a diffusion step. The attribute names ``encoder``/``middle``/
    ``decoder`` define the state_dict keys and are kept unchanged.

    Input H and W must be multiples of 4. Otherwise the decoder's output size
    differs from the input size, so ``forward`` rejects such inputs.
    """

    # Total spatial downsampling of the encoder (two stride-2 convolutions).
    _DOWNSAMPLE = 4

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.middle = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            # stride 2 + output_padding 1 exactly doubles H and W
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        """Map ``x`` of shape (N, 1, H, W) to a tensor of the same shape.

        Raises:
            ValueError: if H or W is not a multiple of 4.
        """
        height, width = x.shape[-2:]
        if height % self._DOWNSAMPLE or width % self._DOWNSAMPLE:
            raise ValueError(
                f"UNetDDPM needs height and width divisible by {self._DOWNSAMPLE}, "
                f"got input shape {tuple(x.shape)}"
            )
        enc = self.encoder(x)
        middle = self.middle(enc)
        return self.decoder(middle)

# Build the network, then move its parameters onto the selected device
model = UNetDDPM()
model = model.to(device)


In [None]:
# Training criterion: element-wise mean squared error
loss_fn = nn.MSELoss()

# Adam over every parameter of the model at the configured learning rate
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)


In [None]:
# Training loop
def train(model, dataloader, epochs):
    """Train ``model`` to reproduce its input images; return per-epoch mean loss.

    For each batch, the images are moved to the module-level ``device`` and fed
    through the model. The module-level ``loss_fn`` compares the output with
    those same images, and the module-level ``optimizer`` updates the weights.
    Labels are ignored.

    NOTE(review): this is a reconstruction (autoencoder) objective, not DDPM
    training. No noise is added and no timestep is sampled. TODO: switch to
    noise prediction if a real diffusion model is intended.

    Args:
        model: network whose parameters ``optimizer`` was built over.
        dataloader: iterable of ``(images, labels)`` batches.
        epochs: number of full passes over ``dataloader``.

    Returns:
        list[float]: mean training loss per epoch, weighted by batch size
        (``nan`` for an epoch that yielded no batches).
    """
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss, seen = 0.0, 0
        loop = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}")
        for imgs, _ in loop:
            imgs = imgs.to(device)
            optimizer.zero_grad()

            # Forward pass; the target is the input itself
            output = model(imgs)
            loss = loss_fn(output, imgs)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss * imgs.size(0)
            seen += imgs.size(0)
            loop.set_postfix(loss=batch_loss)
        epoch_losses.append(running_loss / seen if seen else float('nan'))
    return epoch_losses


# Start training; keep the per-epoch loss history for inspection/plotting
epoch_losses = train(model, dataloader, epochs)


In [None]:
# Persist the learned weights only (the state_dict), not the module object
state = model.state_dict()
torch.save(state, "unet_ddpm.pth")


In [None]:
# Function to generate synthetic MRI images
def generate_images(model, num_images=5, size=None):
    """Pass Gaussian noise through ``model`` once and plot the outputs in a row.

    NOTE(review): this is a single forward pass, not iterative DDPM reverse
    sampling. The training loop in this notebook only optimises reconstruction
    (output vs. input), so outputs are unlikely to resemble real MRI slices.
    TODO: replace with a proper diffusion sampler.

    Args:
        model: network applied to noise of shape (num_images, 1, size, size);
            assumed to return (N, C, H, W) and to have at least one parameter.
        num_images: number of samples to draw and plot; must be >= 1.
        size: side length of the noise images. Defaults to the notebook-level
            ``image_size`` hyperparameter.

    Returns:
        numpy.ndarray: the raw model outputs, moved to CPU.

    Raises:
        ValueError: if ``num_images`` < 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")
    side = image_size if size is None else size
    # Create the noise on whichever device holds the model's weights.
    model_device = next(model.parameters()).device

    model.eval()
    with torch.no_grad():
        z = torch.randn(num_images, 1, side, side, device=model_device)  # random noise
        generated = model(z).cpu().numpy()

    # squeeze=False keeps a 2-D axes array even when num_images == 1.
    fig, axes = plt.subplots(1, num_images, figsize=(2 * num_images, 2), squeeze=False)
    for ax, img in zip(axes[0], generated):
        # Fixed [-1, 1] display range matches the dataset's Normalize([0.5], [0.5])
        # scaling, so brightness is comparable across samples instead of
        # being auto-stretched per image.
        ax.imshow(img[0], cmap='gray', vmin=-1, vmax=1)
        ax.axis('off')
    plt.show()
    return generated


# Generate 5 synthetic images after training (assigned so the array isn't echoed)
samples = generate_images(model, num_images=5)
