Claude:

I'll create a Jupyter notebook for Google Colab that simulates colon images with polyps and then builds a PyTorch model for polyp detection.


This notebook creates a complete pipeline for polyp detection using synthetic data and deep learning. Here's what it does:

1. **Synthetic Data Generation**
   - Creates stylized colon-like images: a pinkish background with a radial gradient and noise texture
   - Adds concentric rings/folds to simulate colon anatomy
   - Adds 1 to 3 elliptical polyps of random size, aspect ratio, and position to about 70% of the images; the rest contain no polyp
   - Produces paired images and segmentation masks

2. **Dataset Preparation**
   - Splits the data 80/20 into training and validation sets
   - Loads image/mask pairs from disk and normalizes images with the ImageNet mean and standard deviation

3. **Model Architecture**
   - Implements a U-Net architecture for semantic segmentation
   - Uses skip connections to preserve spatial information

4. **Training Process**
   - Trains the model using Binary Cross Entropy loss
   - Tracks Dice coefficient for evaluation
   - Saves the best model based on validation performance

5. **Evaluation and Visualization**
   - Visualizes predictions alongside ground truth
   - Computes a Dice score for each validation image
   - Reports the mean and standard deviation of those scores and plots their distribution

6. **Inference Function**
   - Includes a function for making predictions on new images
   - Creates visualization overlays to highlight detected polyps
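The Dice coefficient tracked during training and evaluation is a smoothed *soft* Dice, matching what `dice_coefficient` computes in the notebook. Here $p_i$ are the predicted probabilities, $g_i$ the ground-truth mask values (0 or 1), and $\epsilon$ the `smooth` term ($10^{-6}$). The worked example uses made-up values for four pixels and ignores $\epsilon$:

$$
\mathrm{Dice}(p, g) = \frac{2\sum_i p_i g_i + \epsilon}{\sum_i p_i + \sum_i g_i + \epsilon}
$$

$$
p = (0.9, 0.8, 0.1, 0),\quad g = (1, 1, 0, 0):\qquad \mathrm{Dice} \approx \frac{2 \times 1.7}{1.8 + 2} = \frac{3.4}{3.8} \approx 0.895
$$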

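To run this notebook in Google Colab, upload it and execute the cells in order; a GPU runtime speeds up training considerably. With the default demo settings (100 images, 5 epochs), the model should learn to pick out the synthetic polyps, but scores vary between runs; increase `num_images` or `num_epochs` if the validation Dice is low.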

Would you like me to explain any specific part of the implementation in more detail?

In [1]:
# Colon Polyp Detection - Synthetic Data Generation and ML Model
# Google Colab Notebook

import os
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seed for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Create directories for saving images
os.makedirs('synthetic_colon_data/images', exist_ok=True)
os.makedirs('synthetic_colon_data/masks', exist_ok=True)

# Part 1: Synthetic Colon Image Generation
# ---------------------------------------

def generate_colon_background(img_size=256, noise_factor=0.1):
    """Generate a synthetic colon background with a pinkish color and texture."""
    # Base pink color for colon tissue
    base_color = np.array([210, 140, 140])  # Pinkish color

    # Generate noise for texture
    noise = np.random.randn(img_size, img_size, 3) * noise_factor

    # Allocate the image; the radial gradient below fills every channel
    img = np.ones((img_size, img_size, 3))

    # Add radial gradient for tube-like appearance
    y, x = np.ogrid[:img_size, :img_size]
    center = img_size // 2
    dist_from_center = np.sqrt((x - center)**2 + (y - center)**2)

    # Normalize to [0, 1]
    dist_from_center = dist_from_center / (np.sqrt(2) * center)

    # Apply radial gradient
    for c in range(3):
        img[:, :, c] = base_color[c] / 255.0 * (1 - 0.2 * dist_from_center**2)

    # Add noise for texture
    img += noise
    img = np.clip(img, 0, 1)

    return img
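
# Sketch (hypothetical helper, not used elsewhere in this notebook): the radial
# falloff applied in generate_colon_background. The distance map is divided by
# sqrt(2) * center, the center-to-corner distance, so it spans [0, 1]; the
# brightness factor 1 - 0.2 * d**2 therefore runs from 1.0 at the center to 0.8
# at the far corner. The `strength` parameter is an illustrative addition.
import numpy as np

def _radial_falloff(img_size=256, strength=0.2):
    """Per-pixel brightness factor mirroring generate_colon_background."""
    y, x = np.ogrid[:img_size, :img_size]
    center = img_size // 2
    dist = np.sqrt((x - center) ** 2 + (y - center) ** 2) / (np.sqrt(2) * center)
    return 1 - strength * dist ** 2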

def add_colon_folds(img, num_folds=5, fold_width=10):
    """Add concentric rings/folds to the colon image."""
    img_size = img.shape[0]
    center = img_size // 2

    # Convert to PIL for drawing
    pil_img = Image.fromarray((img * 255).astype(np.uint8))
    draw = ImageDraw.Draw(pil_img)

    # Draw concentric circles for colon folds
    for i in range(1, num_folds + 1):
        radius = (img_size // (num_folds + 1)) * i

        # Add slight randomness to fold placement
        radius += random.randint(-5, 5)

        # Draw darker pink ellipses for folds
        ellipse_bounds = [
            center - radius,
            center - radius * 0.7,  # Slightly squished for perspective
            center + radius,
            center + radius * 0.7
        ]

        # Draw an opaque outline (ImageDraw does not alpha-blend on an RGB image)
        draw.ellipse(ellipse_bounds, outline=(190, 120, 120), width=fold_width)

    # Convert back to numpy
    img = np.array(pil_img) / 255.0
    return img
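
# Sketch (hypothetical helper): nominal fold radii in add_colon_folds before the
# +/-5 px jitter. With the defaults (img_size=256, num_folds=5) the spacing is
# 256 // 6 = 42 px, giving radii 42, 84, 126, 168, 210. Any radius above
# img_size // 2 = 128 places part of the ellipse outside the frame, so PIL only
# draws the in-frame arcs of the outer folds.
def _fold_radii(img_size=256, num_folds=5):
    """Un-jittered fold radii used by add_colon_folds."""
    spacing = img_size // (num_folds + 1)
    return [spacing * i for i in range(1, num_folds + 1)]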

def generate_polyp(img_size=256, min_radius=10, max_radius=30):
    """Generate a polyp mask and the polyp appearance."""
    mask = np.zeros((img_size, img_size), dtype=np.uint8)

    # Random position for the polyp
    center_x = random.randint(img_size // 4, 3 * img_size // 4)
    center_y = random.randint(img_size // 4, 3 * img_size // 4)

    # Random radius for the polyp
    radius = random.randint(min_radius, max_radius)

    # Create a PIL image for drawing
    mask_img = Image.fromarray(mask)
    draw = ImageDraw.Draw(mask_img)

    # Draw an ellipse for the polyp (random aspect ratio via squish_factor)
    squish_factor = random.uniform(0.7, 1.0)
    ellipse_bounds = [
        center_x - radius,
        center_y - radius * squish_factor,
        center_x + radius,
        center_y + radius * squish_factor
    ]
    draw.ellipse(ellipse_bounds, fill=1)

    # Convert back to numpy
    mask = np.array(mask_img)

    # For polyp appearance (reddish/darker color with texture)
    polyp_color = np.array([180, 100, 100]) / 255.0  # Darker red

    # Add slight texture variation
    polyp_texture = generate_polyp_texture(mask, polyp_color)

    return mask, polyp_texture

def generate_polyp_texture(mask, base_color):
    """Generate texture for the polyp."""
    img_size = mask.shape[0]
    texture = np.zeros((img_size, img_size, 3))

    # Where mask is 1, add the polyp color
    for c in range(3):
        # Add some variation to polyp color
        color_var = base_color[c] + np.random.randn(img_size, img_size) * 0.05
        texture[:, :, c] = mask * color_var

    return texture

def add_polyp_to_image(img, mask, polyp_texture):
    """Add a polyp to the colon image."""
    # Blend the polyp with the background image
    result = img.copy()

    # Where mask is 1, blend the polyp
    for c in range(3):
        # The polyp replaces the background where mask is 1
        result[:, :, c] = result[:, :, c] * (1 - mask) + polyp_texture[:, :, c]

    return result
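
# Sketch (hypothetical helper): add_polyp_to_image performs hard-mask
# compositing. Because polyp_texture is already zero outside the mask, the
# formula img * (1 - mask) + texture is the same as taking the texture inside
# the mask and the background elsewhere, which np.where expresses directly.
import numpy as np

def _composite_with_where(img, mask, polyp_texture):
    """Equivalent to add_polyp_to_image when polyp_texture is 0 outside mask."""
    inside = mask.astype(bool)[..., None]  # broadcast the HxW mask over RGB
    return np.where(inside, polyp_texture, img)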

def generate_dataset(num_images=500, img_size=256):
    """Generate a synthetic dataset of colon images with and without polyps."""
    images = []
    masks = []

    for i in tqdm(range(num_images), desc="Generating synthetic data"):
        # Generate background
        img = generate_colon_background(img_size)
        img = add_colon_folds(img)

        # With 70% probability, add a polyp
        has_polyp = random.random() < 0.7

        if has_polyp:
            # Generate between 1 and 3 polyps
            num_polyps = random.randint(1, 3)
            combined_mask = np.zeros((img_size, img_size))

            for _ in range(num_polyps):
                polyp_mask, polyp_texture = generate_polyp(img_size)
                img = add_polyp_to_image(img, polyp_mask, polyp_texture)
                combined_mask = np.maximum(combined_mask, polyp_mask)
        else:
            combined_mask = np.zeros((img_size, img_size))

        # Save the images and masks
        img_array = (img * 255).astype(np.uint8)
        mask_array = (combined_mask * 255).astype(np.uint8)

        # Save images to disk
        Image.fromarray(img_array).save(f'synthetic_colon_data/images/{i:04d}.png')
        Image.fromarray(mask_array).save(f'synthetic_colon_data/masks/{i:04d}.png')

        # Keep in memory for visualization
        images.append(img)
        masks.append(combined_mask)

    return images, masks

# Generate the dataset
images, masks = generate_dataset(num_images=100)  # Reduced for notebook demonstration

# Visualize some examples
fig, axes = plt.subplots(3, 2, figsize=(12, 12))
for i in range(3):
    idx = random.randint(0, len(images) - 1)
    axes[i, 0].imshow(images[idx])
    axes[i, 0].set_title(f"Image {idx}")
    axes[i, 0].axis('off')

    axes[i, 1].imshow(masks[idx], cmap='gray')
    axes[i, 1].set_title(f"Mask {idx}")
    axes[i, 1].axis('off')

plt.tight_layout()
plt.show()
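
# Sketch (hypothetical helper): summarize class balance of the generated masks.
# About 70% of images are expected to contain a polyp (the random() < 0.7 draw),
# and polyps cover only a small fraction of each image's pixels, which is one
# reason an overlap metric such as Dice is tracked alongside the pixel-wise loss.
# Usage in this notebook: _mask_balance(masks)
import numpy as np

def _mask_balance(mask_list):
    """Return (fraction of images with any polyp pixel, mean polyp-pixel fraction)."""
    stack = np.stack([np.asarray(m) > 0 for m in mask_list])
    image_has_polyp = stack.reshape(len(mask_list), -1).any(axis=1)
    return float(image_has_polyp.mean()), float(stack.mean())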

# Part 2: PyTorch Dataset and DataLoader
# -------------------------------------

class ColonPolypsDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = sorted(os.listdir(img_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx])

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")  # Convert to grayscale

        if self.transform:
            image = self.transform(image)

        # Normalize mask to [0, 1] whether or not an image transform is given
        mask = np.array(mask, dtype=np.float32) / 255.0
        mask = torch.from_numpy(mask).float().unsqueeze(0)  # Add channel dim -> (1, H, W)

        return image, mask

# Preprocessing: convert to tensor and normalize with ImageNet statistics (no augmentation is applied)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
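
# Sketch (hypothetical helpers): transforms.Normalize computes (x - mean) / std
# per channel with the ImageNet statistics above, and visualize_predictions
# later undoes it with x * std + mean. A NumPy version of both directions for
# HxWx3 arrays in [0, 1] (torch tensors are CxHxW instead):
import numpy as np

_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
_IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def _imagenet_normalize(img_hwc):
    """(x - mean) / std, broadcast over the last (channel) axis."""
    return (img_hwc - _IMAGENET_MEAN) / _IMAGENET_STD

def _imagenet_denormalize(img_hwc):
    """Inverse of _imagenet_normalize: x * std + mean."""
    return img_hwc * _IMAGENET_STD + _IMAGENET_MEAN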

# Create the dataset
full_dataset = ColonPolypsDataset(
    img_dir='synthetic_colon_data/images',
    mask_dir='synthetic_colon_data/masks',
    transform=transform
)

# Split into train and validation sets
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
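
# Sketch (hypothetical helper): an index-level NumPy analogue of the 80/20
# random_split above, not torch's implementation. Shuffling indices with a
# dedicated seeded generator keeps the split reproducible without touching the
# global NumPy random state; torch.utils.data.random_split similarly accepts a
# generator= argument for this purpose.
import numpy as np

def _split_indices(n_items, train_frac=0.8, seed=42):
    """Return disjoint (train_idx, val_idx) arrays covering range(n_items)."""
    perm = np.random.default_rng(seed).permutation(n_items)
    n_train = int(train_frac * n_items)
    return perm[:n_train], perm[n_train:]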

# Part 3: Model Definition (U-Net for segmentation)
# -----------------------------------------------

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super(UNet, self).__init__()
        self.encoder1 = DoubleConv(n_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        self.middle = DoubleConv(512, 1024)

        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder4 = DoubleConv(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = DoubleConv(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = DoubleConv(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = DoubleConv(128, 64)

        self.out = nn.Conv2d(64, n_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Encoder
        enc1 = self.encoder1(x)
        x = self.pool1(enc1)

        enc2 = self.encoder2(x)
        x = self.pool2(enc2)

        enc3 = self.encoder3(x)
        x = self.pool3(enc3)

        enc4 = self.encoder4(x)
        x = self.pool4(enc4)

        # Middle
        x = self.middle(x)

        # Decoder
        x = self.upconv4(x)
        # Ensure spatial dimensions match for the skip connection
        diff_h = enc4.size()[2] - x.size()[2]
        diff_w = enc4.size()[3] - x.size()[3]
        x = nn.functional.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x, enc4], dim=1)
        x = self.decoder4(x)

        x = self.upconv3(x)
        # Ensure spatial dimensions match
        diff_h = enc3.size()[2] - x.size()[2]
        diff_w = enc3.size()[3] - x.size()[3]
        x = nn.functional.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x, enc3], dim=1)
        x = self.decoder3(x)

        x = self.upconv2(x)
        # Ensure spatial dimensions match
        diff_h = enc2.size()[2] - x.size()[2]
        diff_w = enc2.size()[3] - x.size()[3]
        x = nn.functional.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x, enc2], dim=1)
        x = self.decoder2(x)

        x = self.upconv1(x)
        # Ensure spatial dimensions match
        diff_h = enc1.size()[2] - x.size()[2]
        diff_w = enc1.size()[3] - x.size()[3]
        x = nn.functional.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x, enc1], dim=1)
        x = self.decoder1(x)

        x = self.out(x)
        return self.sigmoid(x)
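
# Sketch (hypothetical helper): why the padding in UNet.forward() matters.
# The 3x3 convolutions use padding=1 and keep the size unchanged, each
# MaxPool2d(2) floors the size, and each ConvTranspose2d(kernel_size=2,
# stride=2) doubles it. Inputs not divisible by 16 therefore come back one
# pixel short at some levels. This pure-Python tracer returns the padding
# that forward() applies at decoder levels 4, 3, 2, 1 (deepest first).
def _unet_pad_diffs(size, depth=4):
    encoder_sizes = []
    s = size
    for _ in range(depth):
        encoder_sizes.append(s)   # spatial size of enc1..enc4 (before pooling)
        s //= 2                   # MaxPool2d(2) floors odd sizes
    diffs = []
    for skip in reversed(encoder_sizes):
        up = s * 2                # ConvTranspose2d doubles the size
        diffs.append(skip - up)   # amount padded before torch.cat
        s = skip                  # after padding, the decoder works at skip size
    return diffs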

# Initialize the model
model = UNet().to(device)

# Loss function and optimizer
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
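
# Sketch (hypothetical helper): what nn.BCELoss computes on the sigmoid outputs,
# written in NumPy: the mean over pixels of -[y*log(p) + (1-y)*log(1-p)].
# This sketch clips p to avoid log(0); PyTorch instead clamps the log terms.
# With mostly-background masks, predicting "no polyp" almost everywhere already
# gives a low loss, one reason Dice is monitored too. (In PyTorch,
# nn.BCEWithLogitsLoss on raw logits is the more numerically stable option.)
import numpy as np

def _bce_numpy(p, y, eps=1e-7):
    p = np.clip(np.asarray(p, dtype=float), eps, 1 - eps)  # keep log() finite
    y = np.asarray(y, dtype=float)
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))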

# Part 4: Training Loop
# -------------------

def dice_coefficient(y_pred, y_true, smooth=1e-6):
    """Calculate Dice coefficient for evaluation."""
    y_pred = y_pred.view(-1)
    y_true = y_true.view(-1)
    intersection = (y_pred * y_true).sum()
    return (2. * intersection + smooth) / (y_pred.sum() + y_true.sum() + smooth)
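
# Sketch (hypothetical helper): dice_coefficient above is a soft Dice on
# probabilities. For an image with no polyp, y_true.sum() is 0, so the score is
# smooth / (y_pred.sum() + smooth), which is near 0 unless every prediction is
# almost exactly 0; roughly 30% of the synthetic images are such negatives.
# A thresholded Dice that scores "both masks empty" as 1.0 is one common
# reporting convention (the 0.5 threshold is an assumption):
import numpy as np

def _hard_dice(pred_probs, target, threshold=0.5):
    pred = np.asarray(pred_probs) > threshold
    true = np.asarray(target) > 0.5
    denom = pred.sum() + true.sum()
    if denom == 0:
        return 1.0  # both masks empty: count as a correct negative
    return float(2.0 * np.logical_and(pred, true).sum() / denom)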

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    """Train the segmentation model."""
    best_val_dice = 0.0
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_dice': [],
        'val_dice': []
    }

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_dice = 0.0

        for inputs, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
            inputs = inputs.to(device)
            masks = masks.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, masks)

            # Calculate Dice coefficient
            dice = dice_coefficient(outputs, masks)
            train_dice += dice.item()

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Average training statistics
        train_loss /= len(train_loader)
        train_dice /= len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_dice = 0.0

        with torch.no_grad():
            for inputs, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
                inputs = inputs.to(device)
                masks = masks.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, masks)

                # Calculate Dice coefficient
                dice = dice_coefficient(outputs, masks)
                val_dice += dice.item()

                val_loss += loss.item()

        # Average validation statistics
        val_loss /= len(val_loader)
        val_dice /= len(val_loader)

        # Save statistics
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_dice'].append(train_dice)
        history['val_dice'].append(val_dice)

        # Print statistics
        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"Train Loss: {train_loss:.4f} | Train Dice: {train_dice:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}")

        # Save best model
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            torch.save(model.state_dict(), 'best_polyp_segmentation_model.pth')
            print(f"Saved new best model with Dice: {best_val_dice:.4f}")

    return history

# Train the model
history = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5)

# Plot training history
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Loss Over Time')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history['train_dice'], label='Train Dice')
plt.plot(history['val_dice'], label='Validation Dice')
plt.title('Dice Coefficient Over Time')
plt.xlabel('Epochs')
plt.ylabel('Dice Coefficient')
plt.legend()

plt.tight_layout()
plt.show()

# Part 5: Model Evaluation and Visualization
# ----------------------------------------

def visualize_predictions(model, dataloader, num_samples=3):
    """Visualize some predictions from the model."""
    model.eval()

    # Get a batch from the dataloader
    dataiter = iter(dataloader)
    inputs, masks = next(dataiter)

    inputs = inputs.to(device)

    # Make predictions
    with torch.no_grad():
        predictions = model(inputs)

    # Convert to numpy arrays
    inputs = inputs.cpu().numpy()
    masks = masks.cpu().numpy()
    predictions = predictions.cpu().numpy()

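    # Visualize a few examples (capped at the batch size; squeeze=False keeps axes 2-D)
    num_samples = min(num_samples, inputs.shape[0])
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5*num_samples), squeeze=False)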

    for i in range(num_samples):
        # Original image (denormalize)
        img = inputs[i].transpose(1, 2, 0)
        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img = np.clip(img, 0, 1)

        # Ground truth mask
        mask = masks[i, 0]

        # Predicted mask
        pred = predictions[i, 0]

        # Display
        axes[i, 0].imshow(img)
        axes[i, 0].set_title('Original Image')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(mask, cmap='gray')
        axes[i, 1].set_title('Ground Truth Mask')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(pred, cmap='gray')
        axes[i, 2].set_title('Predicted Mask')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()

# Load the best model
model.load_state_dict(torch.load('best_polyp_segmentation_model.pth', map_location=device))

# Visualize some predictions
visualize_predictions(model, val_loader, num_samples=5)

# Part 6: Performance Evaluation
# ----------------------------

def evaluate_model(model, dataloader):
    """Evaluate the model on the given dataloader."""
    model.eval()
    dice_scores = []

    with torch.no_grad():
        for inputs, masks in tqdm(dataloader, desc="Evaluating model"):
            inputs = inputs.to(device)
            masks = masks.to(device)

            outputs = model(inputs)

            # Calculate Dice coefficient for each image
            for i in range(outputs.size(0)):
                dice = dice_coefficient(outputs[i], masks[i])
                dice_scores.append(dice.item())

    # Calculate statistics
    avg_dice = np.mean(dice_scores)
    std_dice = np.std(dice_scores)

    print(f"Evaluation Results:")
    print(f"Average Dice Coefficient: {avg_dice:.4f} ± {std_dice:.4f}")

    # Plot histogram of dice scores
    plt.figure(figsize=(10, 6))
    plt.hist(dice_scores, bins=20, alpha=0.7)
    plt.axvline(avg_dice, color='r', linestyle='dashed', linewidth=2)
    plt.text(avg_dice+0.02, plt.ylim()[1]*0.9, f'Mean: {avg_dice:.4f}', color='r')
    plt.title('Distribution of Dice Scores')
    plt.xlabel('Dice Coefficient')
    plt.ylabel('Count')
    plt.grid(alpha=0.3)
    plt.show()

    return dice_scores

# Evaluate the model
dice_scores = evaluate_model(model, val_loader)

# Part 7: Function for Making Predictions on New Images
# --------------------------------------------------

def predict_polyps(model, image_path, output_path=None):
    """Predict polyps on a new image."""
    # Load image
    image = Image.open(image_path).convert("RGB")
    original_size = image.size

    # Preprocess
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    input_tensor = transform(image).unsqueeze(0).to(device)

    # Make prediction
    model.eval()
    with torch.no_grad():
        prediction = model(input_tensor)
        prediction = prediction.cpu().squeeze().numpy()

    # Resize prediction back to original size
    prediction_img = Image.fromarray((prediction * 255).astype(np.uint8))
    prediction_img = prediction_img.resize(original_size, Image.NEAREST)

    # Convert original image to numpy for visualization
    image_np = np.array(image)

    # Create overlay
    prediction_np = np.array(prediction_img)

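    # Create RGB mask with red color for visualization
    polyp_region = prediction_np > 127
    overlay = np.zeros_like(image_np)
    overlay[polyp_region, 0] = 255  # Red channel

    # Blend only inside the predicted region so the rest of the image keeps its brightness
    alpha = 0.5
    blended = image_np.copy()
    blended[polyp_region] = (image_np[polyp_region] * (1 - alpha) + overlay[polyp_region] * alpha).astype(np.uint8)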

    # Visualize
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.imshow(image_np)
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(prediction_np, cmap='gray')
    plt.title('Predicted Mask')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.imshow(blended)
    plt.title('Overlay')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    # Save results if requested
    if output_path:
        # Save the blended image
        Image.fromarray(blended).save(output_path)
        print(f"Saved prediction to {output_path}")

    return prediction_np

# Generate a new test image
test_img = generate_colon_background()
test_img = add_colon_folds(test_img)

# Add a polyp
mask, polyp_texture = generate_polyp()
test_img = add_polyp_to_image(test_img, mask, polyp_texture)

# Save the test image
test_img_path = 'test_colon_image.png'
Image.fromarray((test_img * 255).astype(np.uint8)).save(test_img_path)

# Predict on the test image
prediction = predict_polyps(model, test_img_path, 'test_prediction.png')
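
# Sketch (hypothetical helper, beyond the original notebook): the model outputs
# a segmentation mask, not a list of detections. Labeling 4-connected blobs in
# the thresholded mask is one simple post-processing step that yields one
# bounding box per predicted polyp. Pure Python/NumPy, so it is slow on large
# images but fine for 256x256 masks; min_pixels is an arbitrary speck filter.
# Usage in this notebook: _mask_to_boxes(prediction > 127)
from collections import deque
import numpy as np

def _mask_to_boxes(binary_mask, min_pixels=20):
    """Return (row_min, col_min, row_max, col_max) for each blob >= min_pixels."""
    mask = np.asarray(binary_mask, dtype=bool)
    seen = np.zeros_like(mask)
    h, w = mask.shape
    boxes = []
    for r in range(h):
        for c in range(w):
            if not mask[r, c] or seen[r, c]:
                continue
            # Breadth-first flood fill from an unvisited foreground pixel
            queue = deque([(r, c)])
            seen[r, c] = True
            rows, cols = [], []
            while queue:
                y, x = queue.popleft()
                rows.append(y)
                cols.append(x)
                for dy, dx in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                    ny, nx = y + dy, x + dx
                    if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and not seen[ny, nx]:
                        seen[ny, nx] = True
                        queue.append((ny, nx))
            if len(rows) >= min_pixels:
                boxes.append((min(rows), min(cols), max(rows), max(cols)))
    return boxes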

# Conclusion
print("Training and Evaluation Complete!")
print("The notebook demonstrated:")
print("1. Generation of synthetic colon images with polyps")
print("2. Building and training a U-Net model for polyp detection")
print("3. Evaluating model performance")
print("4. Making predictions on new images")
print("\nKey takeaways:")
print("- The model was trained on synthetic data, which is helpful for initial development")
print("- For real-world applications, this should be adapted to use real medical images")
print("- U-Net is a reasonable baseline for polyp segmentation; judge this run by the Dice scores above")
print("- Data augmentation could further improve performance")

# Improvements for real-world application (for reference).
# If adapting this notebook for real medical images, consider:
# 1. Using transfer learning from a pretrained model
# 2. Implementing more extensive data augmentation
# 3. Adding more evaluation metrics (precision, recall, etc.)
# 4. Implementing post-processing for better segmentation
# 5. Using more sophisticated architectures (DeepLabV3+, etc.)
