<a href="https://colab.research.google.com/github/PrathamSetia/Enhancing-Medical-Image-Segmentation-with-GANs-Transfer-Learning-and-Data-Augmentation/blob/main/brain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download the LGG MRI segmentation dataset from Kaggle.
# `!pip` is an IPython shell escape (not plain Python); it installs the
# `opendatasets` helper imported on the next line.
!pip install opendatasets
import opendatasets as od
od.download("https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation")
# It will ask for your username and key (type them at the prompt; never
# hard-code the key in the notebook).
# Then you can skip the unzip step because this library unzips it for you!
# NOTE(review): the next cell expects the data under
# /content/lgg-mri-segmentation/lgg-mri-segmentation/kaggle_3m - confirm that
# is where od.download() extracts it in your runtime.

Collecting opendatasets
  Downloading opendatasets-0.1.22-py3-none-any.whl.metadata (9.2 kB)
Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)
Installing collected packages: opendatasets
Successfully installed opendatasets-0.1.22
Please provide your Kaggle credentials to download this dataset. Learn more: http://bit.ly/kaggle-creds
Your Kaggle username: Prathamthegreat
Your Kaggle Key:


Abort: 

In [None]:
import glob
import os

import pandas as pd

# Folder that the download cell extracts the Kaggle archive into.
DATA_ROOT = '/content/lgg-mri-segmentation/lgg-mri-segmentation/kaggle_3m'
MASK_SUFFIX = '_mask.tif'

# Find all files ending in '_mask.tif' (these are the answers).
# glob returns files in arbitrary, filesystem-dependent order; sorting makes the
# row order - and therefore the seeded train/val split later on - reproducible.
mask_files = sorted(glob.glob(os.path.join(DATA_ROOT, '*', '*' + MASK_SUFFIX)))

# The image path is the mask path with only the trailing '_mask' removed.
# (str.replace('_mask', '') would also rewrite a '_mask' appearing anywhere
# else in the path, e.g. in a folder name.)
data = [
    {"image_path": mask_path[:-len(MASK_SUFFIX)] + '.tif', "mask_path": mask_path}
    for mask_path in mask_files
]

# Convert to a DataFrame (table). Explicit columns keep the schema even when no
# files were found, so df['image_path'] still exists on an empty table.
df = pd.DataFrame(data, columns=["image_path", "mask_path"])

if df.empty:
    print(f"No '*{MASK_SUFFIX}' files found under {DATA_ROOT} - did the download cell succeed?")

print(f"Total MRI Scans found: {len(df)}")
print(df.head()) # Show the first 5 rows

In [None]:
import cv2
import matplotlib.pyplot as plt
import random

# Pick a random row from our dataset.
# random.randrange(n) returns 0..n-1; random.randint(0, n) includes n itself and
# would occasionally index one row past the end of the table.
if len(df) == 0:
    raise ValueError("No MRI scans indexed - run the download and indexing cells first.")
random_index = random.randrange(len(df))
row = df.iloc[random_index]

# Read the image and the mask (cv2.imread returns None instead of raising on a bad path).
image = cv2.imread(row['image_path'])
mask = cv2.imread(row['mask_path'], cv2.IMREAD_GRAYSCALE)
if image is None or mask is None:
    raise FileNotFoundError(f"Could not read {row['image_path']} / {row['mask_path']}")

# OpenCV loads colour images in BGR order; matplotlib expects RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Plot them
fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# MRI Image
ax[0].imshow(image)
ax[0].set_title("Brain MRI Scan")
ax[0].axis("off")

# Segmentation Mask
ax[1].imshow(mask, cmap='gray')
ax[1].set_title("Tumor Mask (Ground Truth)")
ax[1].axis("off")

plt.show()

In [None]:
import os

from sklearn.model_selection import GroupShuffleSplit, train_test_split

# Split by *patient*, not by slice: ~80% of patients for training, ~20% for validation.
# Each kaggle_3m sub-folder holds the slices of one patient's scan; a row-level
# split would put near-identical neighbouring slices of the same patient into
# both sets and overstate validation scores.
patient_ids = df['image_path'].map(lambda p: os.path.basename(os.path.dirname(p)))

splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, val_idx = next(splitter.split(df, groups=patient_ids))
train_df = df.iloc[train_idx]
val_df = df.iloc[val_idx]

print(f"Training Images: {len(train_df)}")
print(f"Validation Images: {len(val_df)}")
print(f"Patients - train: {patient_ids.iloc[train_idx].nunique()}, val: {patient_ids.iloc[val_idx].nunique()}")

In [None]:
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2

class BrainTumorDataset(Dataset):
    """Paired MRI slices and tumour masks read from disk.

    Args:
        dataframe: table with 'image_path' and 'mask_path' columns, one row per slice.
        transform: optional albumentations pipeline, called as
            transform(image=..., mask=...). If it is None, or does not end in
            ToTensorV2, the NumPy arrays are converted to tensors here using the
            same channel-first layout ToTensorV2 produces.

    Each item is (image, mask): image as a (3, H, W) tensor and mask as a float
    (1, H, W) tensor whose 0/255 values are scaled to 0/1.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        # 1. Get the file paths
        row = self.dataframe.iloc[idx]
        image_path = row['image_path']
        mask_path = row['mask_path']

        # 2. Read the image and mask using OpenCV.
        # cv2.imread returns None (instead of raising) for missing/unreadable files,
        # so fail loudly here rather than with a cryptic cvtColor error.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # Images are read in BGR format by default, convert to RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Read mask in grayscale mode (0 for black, 255 for white)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # 3. Apply transformations (Resize, Normalize, etc.)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Without ToTensorV2 (or without any transform) we still hold NumPy arrays;
        # convert them like ToTensorV2 would: image HWC -> CHW, mask stays HW.
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        # 4. Preprocess mask: scale 0/255 -> 0/1 and add a channel dimension,
        # (H, W) -> (1, H, W) as PyTorch segmentation losses expect.
        mask = mask / 255.0
        mask = mask.unsqueeze(0)

        return image, mask.float()

# Define the "Recipe" (Transforms): every image/mask pair is resized to
# 256x256, the image is normalised with ImageNet per-channel statistics
# (the encoder is ImageNet-pretrained), and both become PyTorch tensors.
transform_recipe = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
        ToTensorV2(),
    ]
)

print("Dataset Class defined!")

In [None]:
from torch.utils.data import DataLoader

# Wrap both splits in Datasets that share the same resize + normalise recipe.
train_dataset, val_dataset = (
    BrainTumorDataset(frame, transform=transform_recipe) for frame in (train_df, val_df)
)

# Loaders feed 16 slices per step; only the training order is shuffled.
train_loader, val_loader = (
    DataLoader(dataset, batch_size=16, shuffle=is_train)
    for dataset, is_train in ((train_dataset, True), (val_dataset, False))
)

# Sanity check: pull one training batch and report its tensor shapes.
images, masks = next(iter(train_loader))
print(f"Batch Image Shape: {images.shape}") # Should be [16, 3, 256, 256]
print(f"Batch Mask Shape: {masks.shape}")   # Should be [16, 1, 256, 256]

In [None]:
# Install segmentation_models_pytorch (imported as `smp` in the next cell).
# `!pip` is an IPython shell escape, not plain Python.
!pip install segmentation-models-pytorch

In [None]:
import segmentation_models_pytorch as smp

# 1. Define the Model: a U-Net whose ResNet34 encoder starts from ImageNet
#    weights (transfer learning). No activation is set, so the single output
#    channel holds raw logits; sigmoid is applied by the loss / at inference.
model = smp.Unet(
    encoder_name="resnet34",        # Use ResNet34 as the "backbone"
    encoder_weights="imagenet",     # Use pre-trained weights (Transfer Learning)
    in_channels=3,                  # Our images have 3 channels (RGB)
    classes=1                       # Output is 1 channel (Mask)
)

# 2. Move model to the GPU when one is available (falls back to CPU otherwise).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

# Report where the model actually lives instead of always claiming "GPU".
if device == 'cuda':
    print("Model loaded and moved to GPU!")
else:
    print("Model loaded on CPU (no GPU available) - training will be slow.")

In [None]:
# Dice loss computed directly on the model's raw logits (from_logits=True),
# a standard choice for binary segmentation.
loss_fn = smp.losses.DiceLoss(
    mode='binary',
    from_logits=True,
)

# Adam with a small step size, suitable for fine-tuning a pretrained encoder.
LEARNING_RATE = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("Training tools are ready.")

In [None]:
from tqdm import tqdm # This gives us a cool progress bar

def train_one_epoch(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over `loader` and return the mean batch loss.

    The model is put in training mode; for every (images, masks) batch it does
    forward -> loss -> backward -> optimizer step, and the current batch loss
    is shown on a tqdm progress bar. The mean is the sum of batch losses
    divided by len(loader).
    """
    model.train()
    batch_losses = []
    progress = tqdm(loader)

    for images, masks in progress:
        images, masks = images.to(device), masks.to(device)

        loss = loss_fn(model(images), masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        value = loss.item()
        batch_losses.append(value)
        progress.set_postfix(loss=value)

    return sum(batch_losses) / len(loader)

# --- RUN THE TRAINING ---
print("Starting Training...")
EPOCHS = 5

for epoch_number in range(1, EPOCHS + 1):
    avg_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
    print(f"Epoch {epoch_number}/{EPOCHS} - Average Loss: {avg_loss:.4f}")

print("Training Complete!")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Per-channel ImageNet statistics used by A.Normalize in the data pipeline.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

# 1. Set model to evaluation mode (BatchNorm/Dropout switch to inference behaviour)
model.eval()

# 2. Get a batch of validation data
images, masks = next(iter(val_loader))
images = images.to(device)

# 3. Predict! (No need to calculate gradients here)
with torch.no_grad():
    predictions = model(images)
    # Convert raw scores (logits) to probabilities (0 to 1)
    predictions = torch.sigmoid(predictions)
    # Convert probabilities to binary mask (0 or 1)
    predictions = (predictions > 0.5).float()

# 4. Helpers to convert tensors for plotting
def tensor_to_image(tensor):
    """Convert a (C, H, W) tensor to an (H, W, C) NumPy array on the CPU."""
    return tensor.cpu().numpy().transpose(1, 2, 0)

def tensor_to_mask(tensor):
    """Return the first channel of a (C, H, W) tensor as a 2-D NumPy array."""
    return tensor.cpu().numpy()[0, :, :]

def denormalize(image, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Undo A.Normalize for an (H, W, 3) image and clip to imshow's [0, 1] range.

    Each channel was normalised with its own mean/std, so each must be restored
    with its own statistics (reusing channel 0's values tints the image).
    """
    return np.clip(image * std + mean, 0.0, 1.0)

# Show up to 3 results (the last validation batch can be smaller).
num_show = min(3, images.shape[0])
fig, axes = plt.subplots(num_show, 3, figsize=(12, 10 / 3 * num_show), squeeze=False)

for i in range(num_show):
    # Original Image
    axes[i, 0].imshow(denormalize(tensor_to_image(images[i])))
    axes[i, 0].set_title("MRI Scan")
    axes[i, 0].axis('off')

    # Ground Truth
    axes[i, 1].imshow(tensor_to_mask(masks[i]), cmap='gray')
    axes[i, 1].set_title("Doctor's Annotation")
    axes[i, 1].axis('off')

    # AI Prediction
    axes[i, 2].imshow(tensor_to_mask(predictions[i]), cmap='gray')
    axes[i, 2].set_title("AI Prediction")
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# 1. Training Transforms (The "Hard Mode")
train_transform = A.Compose([
    A.Resize(256, 256),

    # Geometric Transforms (Positioning)
    A.HorizontalFlip(p=0.5),              # 50% chance to flip L/R
    A.ShiftScaleRotate(                   # Shift, Zoom, or Rotate
        shift_limit=0.0625,
        scale_limit=0.1,
        rotate_limit=15,                  # Rotate +/- 15 degrees
        p=0.5
    ),

    # Elastic Transforms (Simulating soft tissue)
    # `alpha` scales a random displacement field and `sigma` is the Gaussian
    # that smooths it. alpha=1 with sigma=50 moves pixels by a tiny fraction of
    # a pixel, i.e. it is practically a no-op, and the former `alpha_affine`
    # argument (which produced the visible warp) is deprecated/removed in
    # recent albumentations releases - ShiftScaleRotate above already covers
    # the affine part. alpha=120, sigma=120*0.05 is a commonly used setting
    # from the albumentations examples; tune if the warp looks too strong.
    A.ElasticTransform(
        alpha=120,
        sigma=120 * 0.05,
        p=0.2
    ),

    # Pixel Transforms (Scanner differences)
    A.RandomBrightnessContrast(p=0.2),

    # Finalize
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

# 2. Validation Transforms (The "Clean Mode"): no randomness, only resize + normalise.
val_transform = A.Compose([
    A.Resize(256, 256),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

print("Augmentation Pipelines Created!")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def show_aug(dataset, idx=0, n_views=5):
    """Plot `n_views` independent draws of sample `idx` with its mask overlaid.

    Each call to dataset[idx] re-runs the dataset's random augmentation, so the
    panels show different augmentations of the same slice. Images are
    de-normalised per channel with the ImageNet statistics used by
    A.Normalize and clipped to [0, 1] for display.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    fig, ax = plt.subplots(1, n_views, figsize=(4 * n_views, 5), squeeze=False)

    for i in range(n_views):
        # Get image/mask pair from dataset
        # Each time we call this, the random augmentation runs again!
        image, mask = dataset[idx]

        # Un-normalize for display: every channel has its own mean/std.
        image = image.cpu().numpy().transpose(1, 2, 0)
        image = np.clip(image * std + mean, 0.0, 1.0)

        ax[0, i].imshow(image)
        ax[0, i].imshow(mask[0].cpu().numpy(), cmap='jet', alpha=0.3) # Overlay mask in color
        ax[0, i].set_title(f"Augmentation {i+1}")
        ax[0, i].axis('off')
    plt.show()

# Create a temporary dataset just to test the visuals
debug_dataset = BrainTumorDataset(train_df, transform=train_transform)

# Show the result
print("Visualizing Augmentations on a single patient...")
show_aug(debug_dataset, idx=10) # Change idx to see different patients

In [None]:
# 1. Rebuild the datasets: random augmentation for training only; validation
#    always uses the deterministic resize + normalise pipeline.
train_dataset = BrainTumorDataset(train_df, transform=train_transform)
val_dataset = BrainTumorDataset(val_df, transform=val_transform)

# 2. Fresh loaders over the rebuilt datasets (same batch size as before).
loader_kwargs = dict(batch_size=16)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# 3. Start again from an ImageNet-initialised U-Net so that the comparison with
#    the baseline run isolates the effect of augmentation.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# 4. Same number of epochs as the baseline run.
ROBUST_EPOCHS = 5
print("Starting Robust Training...")
for epoch_idx in range(ROBUST_EPOCHS):
    avg_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
    print(f"Epoch {epoch_idx + 1}/{ROBUST_EPOCHS} - Average Loss: {avg_loss:.4f}")

In [None]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """DCGAN generator: noise (N, z_dim, 1, 1) -> image (N, channels, 64, 64) in [-1, 1].

    A stride-1 transposed conv turns the 1x1 noise into a 4x4 map; every later
    stride-2 stage doubles the spatial size (4 -> 8 -> 16 -> 32 -> 64) while the
    width shrinks from feature_g*16 down to feature_g*2 and finally `channels`.
    """

    def __init__(self, z_dim=100, channels=3, feature_g=64):
        super().__init__()
        widths = [feature_g * factor for factor in (16, 8, 4, 2)]

        # Seed stage: 1x1 -> 4x4 (kernel 4, stride 1, no padding).
        stages = [self._block(z_dim, widths[0], 4, 1, 0)]
        # Upsampling stages: each doubles H and W (kernel 4, stride 2, padding 1).
        stages.extend(
            self._block(c_in, c_out, 4, 2, 1) for c_in, c_out in zip(widths, widths[1:])
        )
        # Output stage: 32x32 -> 64x64, then squash pixel values to [-1, 1].
        stages.append(
            nn.ConvTranspose2d(widths[-1], channels, kernel_size=4, stride=2, padding=1)
        )
        stages.append(nn.Tanh())

        self.gen = nn.Sequential(*stages)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Bias-free transposed conv (BatchNorm follows) -> BatchNorm -> ReLU."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        return self.gen(x)

class Discriminator(nn.Module):
    """DCGAN discriminator: image (N, channels, 64, 64) -> probability (N, 1, 1, 1) of being real.

    Each stride-2 conv halves the spatial size (64 -> 32 -> 16 -> 8 -> 4) while
    the width grows from feature_d to feature_d*8; a final 4x4 conv collapses
    the map to one score, which a Sigmoid maps to (0, 1).
    """

    def __init__(self, channels=3, feature_d=64):
        super().__init__()
        widths = [feature_d * factor for factor in (1, 2, 4, 8)]

        # Stem: 64x64 -> 32x32, no BatchNorm on the first layer.
        layers = [nn.Conv2d(channels, widths[0], 4, 2, 1), nn.LeakyReLU(0.2)]
        # Downsampling stages: 32 -> 16 -> 8 -> 4.
        layers.extend(
            self._block(c_in, c_out, 4, 2, 1) for c_in, c_out in zip(widths, widths[1:])
        )
        # Head: 4x4 -> 1x1 single score, squashed to a probability.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
        layers.append(nn.Sigmoid())

        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Bias-free strided conv (BatchNorm follows) -> BatchNorm -> LeakyReLU(0.2)."""
        downsample = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        return nn.Sequential(downsample, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2))

    def forward(self, x):
        return self.disc(x)

In [None]:
# 1. Settings
Z_DIM = 100   # Size of the random noise vector
CHANNELS = 3  # RGB
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 2. Create the Networks
gen = Generator(Z_DIM, CHANNELS).to(device)
disc = Discriminator(CHANNELS).to(device)

# 3. Initialize weights (Important for GAN stability)
def initialize_weights(model):
    """DCGAN-style initialisation of every conv / BatchNorm layer inside `model`.

    Conv and transposed-conv weights ~ N(0, 0.02). BatchNorm scales ~ N(1, 0.02)
    with zero shift, as in the PyTorch DCGAN reference; a scale drawn around 0
    would multiply every normalised activation by roughly 0.

    The function walks model.modules() itself, so call it once on the root
    network. Passing it to model.apply() would call it again for every
    sub-module and re-initialise each layer many times.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0.0)

initialize_weights(gen)
initialize_weights(disc)

# 4. Test with dummy data (a shape check needs no gradients)
with torch.no_grad():
    # Create 8 random noise vectors
    noise = torch.randn(8, Z_DIM, 1, 1).to(device)
    fake_images = gen(noise)

    # Pass fake images to discriminator
    predictions = disc(fake_images)

print(f"Generator Output Shape: {fake_images.shape}")
# Should be: [8, 3, 64, 64]

print(f"Discriminator Output Shape: {predictions.shape}")
# Should be: [8, 1, 1, 1] (A single probability score for each image)


In [None]:
# 1. GAN-specific preprocessing: 64x64 inputs with every channel mapped from
#    [0, 255] to [-1, 1], the same range as the generator's Tanh output.
GAN_IMAGE_SIZE = 64
HALF = (0.5, 0.5, 0.5)
gan_transform = A.Compose(
    [
        A.Resize(GAN_IMAGE_SIZE, GAN_IMAGE_SIZE),
        A.Normalize(mean=HALF, std=HALF),
        ToTensorV2(),
    ]
)

# 2. Dataset/loader over the training split (the GAN only uses the images).
gan_dataset = BrainTumorDataset(train_df, transform=gan_transform)
gan_loader = DataLoader(
    gan_dataset,
    batch_size=32,  # larger batches help stabilise GAN training
    shuffle=True,
)

print("GAN Data Loader ready (64x64 images).")

In [None]:
import matplotlib.pyplot as plt
import torchvision

# --- Hyperparameters ---
LR = 0.0002             # Learning Rate (Low is better for stability)
BETA1 = 0.5             # Momentum term for Adam
EPOCHS = 20             # How long to train (Increase to 100+ for good results)
real_label = 1.
fake_label = 0.

# --- Optimizers ---
opt_gen = torch.optim.Adam(gen.parameters(), lr=LR, betas=(BETA1, 0.999))
opt_disc = torch.optim.Adam(disc.parameters(), lr=LR, betas=(BETA1, 0.999))

# --- Loss Function ---
# Binary cross-entropy on probabilities (the Discriminator ends in a Sigmoid).
criterion = nn.BCELoss()

# --- Visualization Fixed Noise ---
# We use this to track the SAME 16 random vectors over time
fixed_noise = torch.randn(16, Z_DIM, 1, 1).to(device)

# Train both networks with batch statistics. The generation cells further down
# call gen.eval(); setting the mode explicitly keeps this cell correct if it is
# re-run after them (otherwise BatchNorm would silently stay frozen).
gen.train()
disc.train()

print("Starting GAN Training...")

for epoch in range(EPOCHS):
    for batch_idx, (real, _) in enumerate(gan_loader):
        real = real.to(device)
        batch_size = real.shape[0]

        ### 1. TRAIN DISCRIMINATOR: max log(D(x)) + log(1 - D(G(z)))
        disc.zero_grad()

        # 1a. Real images should be scored as real.
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        output = disc(real).reshape(-1)
        loss_real = criterion(output, label)
        loss_real.backward()

        # 1b. Generated images should be scored as fake.
        noise = torch.randn(batch_size, Z_DIM, 1, 1).to(device)
        fake = gen(noise)
        label.fill_(fake_label)
        # .detach() keeps this loss from back-propagating into the Generator
        output = disc(fake.detach()).reshape(-1)
        loss_fake = criterion(output, label)
        loss_fake.backward()

        # Logged only; gradients were already accumulated by the two backward() calls.
        loss_disc = loss_real + loss_fake
        opt_disc.step()

        ### 2. TRAIN GENERATOR: max log(D(G(z)))
        gen.zero_grad()
        label.fill_(real_label) # The "Lie": We want Disc to think these are real
        # Re-score the same fakes, this time letting gradients flow into the Generator
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, label)
        loss_gen.backward()
        opt_gen.step()

    # --- Visualization Block (End of Epoch) ---
    print(f"Epoch [{epoch+1}/{EPOCHS}] Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}")

    if (epoch + 1) % 5 == 0: # Every 5 epochs, show us the result
        with torch.no_grad():
            fake = gen(fixed_noise)
            # Make a grid of images
            img_grid = torchvision.utils.make_grid(fake, normalize=True)
            plt.figure(figsize=(8,8))
            plt.axis("off")
            plt.title(f"Generated Images at Epoch {epoch+1}")
            plt.imshow(img_grid.cpu().permute(1, 2, 0)) # Move channels to end
            plt.show()

In [None]:
import torch
import numpy as np
import cv2
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2

# --- CONFIGURATION ---
OUTPUT_DIR = "/content/synthetic_data"
DEBUG_DIR = "/content/synthetic_debug_rejects" # New folder for bad images
NUM_IMAGES_TO_GENERATE = 500
CONFIDENCE_THRESHOLD = 0.2    # LOWERED: Let's be less strict (was 0.5)
MIN_TUMOR_PIXELS = 10         # LOWERED: Accept smaller spots (was 50)
MAX_ATTEMPTS = 1000           # Max number of 16-image generator batches (not images)
Z_DIM = 100

# Create directories
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(DEBUG_DIR, exist_ok=True)

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# U-Net input normalisation for float images already scaled to [0, 1].
# A.Normalize computes (img - mean * max_pixel_value) / (std * max_pixel_value)
# with max_pixel_value=255 by default. That is right for the uint8 training
# images, but it squashes a [0, 1] float image to an almost constant value
# (about -1.8 to -2.1 per channel), so the U-Net sees a blank input and predicts
# no tumour. max_pixel_value=1.0 reproduces the training normalisation exactly.
unet_transform = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=1.0),
    ToTensorV2()
])

def generate_data():
    """Sample images from `gen`, label them with the U-Net `model`, and save accepted pairs.

    Each 64x64 generator output is mapped from [-1, 1] to [0, 1], upscaled to
    256x256 and segmented by the U-Net. If more than MIN_TUMOR_PIXELS pixels
    exceed CONFIDENCE_THRESHOLD, the image and its binary mask are written to
    OUTPUT_DIR as syn_<n>.png / syn_<n>_mask.png. Up to 20 rejected samples are
    written to DEBUG_DIR for inspection. Stops after NUM_IMAGES_TO_GENERATE
    accepted images or MAX_ATTEMPTS batches, whichever comes first.
    """
    # Set models to evaluation mode
    gen.eval()
    model.eval()

    generated_count = 0
    attempts = 0
    debug_saved = 0

    print(f"üöÄ Starting DEBUG Generation...")
    print(f"Target: {NUM_IMAGES_TO_GENERATE} images.")
    print(f"Thresholds: Confidence > {CONFIDENCE_THRESHOLD}, Pixels > {MIN_TUMOR_PIXELS}")

    with torch.no_grad():
        while generated_count < NUM_IMAGES_TO_GENERATE:
            attempts += 1
            if attempts > MAX_ATTEMPTS:
                print(f"\n‚ö†Ô∏è Max attempts reached. Stopping early.")
                break

            # 1. Generate a batch of raw images (64x64) from Noise
            noise = torch.randn(16, Z_DIM, 1, 1).to(device)
            fake_lowres = gen(noise)

            for i in range(fake_lowres.size(0)):
                if generated_count >= NUM_IMAGES_TO_GENERATE:
                    break

                # 2. Post-process Image: Tanh range [-1, 1] -> [0, 1], channels last
                img_t = fake_lowres[i].cpu()
                img_np = (img_t.permute(1, 2, 0).numpy() * 0.5) + 0.5
                img_np = np.clip(img_np, 0, 1)

                # 3. Upscale to 256x256 (the U-Net's training resolution)
                img_highres = cv2.resize(img_np, (256, 256), interpolation=cv2.INTER_CUBIC)

                # 4. Prepare for U-Net (float [0, 1] input, see unet_transform above)
                aug = unet_transform(image=img_highres.astype(np.float32))
                input_tensor = aug['image'].unsqueeze(0).to(device)

                # 5. Generate Mask
                pred_mask = model(input_tensor)
                pred_prob = torch.sigmoid(pred_mask)
                mask_binary = (pred_prob > CONFIDENCE_THRESHOLD).float().cpu().numpy()[0, 0]

                # 6. Check Result
                tumor_size = np.sum(mask_binary)

                # Save integers for disk
                save_img = (img_highres * 255).astype(np.uint8)
                save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
                save_mask = (mask_binary * 255).astype(np.uint8)

                if tumor_size > MIN_TUMOR_PIXELS:
                    # --- SUCCESS ---
                    base_name = f"syn_{generated_count}"
                    cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}.png"), save_img)
                    cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_mask.png"), save_mask)
                    generated_count += 1
                    if generated_count % 10 == 0:
                        print(f"Generated {generated_count}...")

                elif debug_saved < 20:
                    # --- FAILURE: SAVE FOR INSPECTION ---
                    # We save the rejected image so you can see what's wrong
                    base_name = f"REJECTED_{attempts}_{i}"
                    cv2.imwrite(os.path.join(DEBUG_DIR, f"{base_name}.png"), save_img)
                    cv2.imwrite(os.path.join(DEBUG_DIR, f"{base_name}_predicted_mask.png"), save_mask)
                    debug_saved += 1

    print(f"\n‚úÖ Process Complete!")
    print(f"Total Generated: {generated_count}")
    print(f"Check '{DEBUG_DIR}' to see why images were rejected.")

generate_data()

In [None]:
import glob
import matplotlib.pyplot as plt
import cv2
import os

def view_rejects(reject_dir="/content/synthetic_debug_rejects", max_show=5):
    """Plot up to `max_show` rejected GAN samples next to the U-Net mask they received.

    Looks for `*_predicted_mask.png` files in `reject_dir` (sorted by name) and
    pairs each with the image of the same name without that suffix.
    """
    files = sorted(glob.glob(os.path.join(reject_dir, "*_predicted_mask.png")))

    if len(files) == 0:
        print("No rejected images found. Did the generation script run?")
        return

    num_show = min(max_show, len(files))
    print(f"Found {len(files)} rejected samples. Showing the first {num_show}...")

    # squeeze=False always returns a 2-D (rows, cols) array, so axes[i, 0] works
    # even for a single row. (Wrapping a 1-row result in a list left axes[0] as
    # a 1-D array of two axes, and .imshow on it failed.)
    fig, axes = plt.subplots(num_show, 2, figsize=(10, 4 * num_show), squeeze=False)

    for i, mask_path in enumerate(files[:num_show]):
        # The mask file is "REJECTED_X_Y_predicted_mask.png";
        # the image file is "REJECTED_X_Y.png".
        img_path = mask_path.replace("_predicted_mask.png", ".png")

        img = cv2.imread(img_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if img is None or mask is None:
            print(f"Skipping unreadable pair: {img_path}")
            axes[i, 0].axis('off')
            axes[i, 1].axis('off')
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Plot Image
        axes[i, 0].imshow(img)
        axes[i, 0].set_title(f"GAN Output {i+1}")
        axes[i, 0].axis('off')

        # Plot What U-Net Saw (The Mask)
        axes[i, 1].imshow(mask, cmap='gray')
        axes[i, 1].set_title(f"U-Net Prediction (Empty?)")
        axes[i, 1].axis('off')

    plt.tight_layout()
    plt.show()

view_rejects()

In [None]:
# Resume training for 30 more epochs
print("Resuming GAN Training for 30 more epochs...")
resume_epochs = 30

# generate_data() above switched `gen` to eval mode. Put both networks back in
# training mode so BatchNorm uses (and updates) batch statistics; otherwise the
# generator would be trained with frozen running statistics.
gen.train()
disc.train()

for epoch in range(resume_epochs):
    for batch_idx, (real, _) in enumerate(gan_loader):
        real = real.to(device)
        batch_size = real.shape[0]

        # Train Discriminator: real -> 1, fake -> 0
        disc.zero_grad()
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        output = disc(real).reshape(-1)
        loss_real = criterion(output, label)
        loss_real.backward()

        noise = torch.randn(batch_size, Z_DIM, 1, 1).to(device)
        fake = gen(noise)
        label.fill_(fake_label)
        # detach(): the discriminator update must not reach the generator
        output = disc(fake.detach()).reshape(-1)
        loss_fake = criterion(output, label)
        loss_fake.backward()
        opt_disc.step()

        # Train Generator: push the discriminator to score fakes as real
        gen.zero_grad()
        label.fill_(real_label)
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, label)
        loss_gen.backward()
        opt_gen.step()

    print(f"Extra Epoch {epoch+1}/{resume_epochs} Complete")

print("Extra Training Done. Try Generating again!")

In [None]:
import matplotlib.pyplot as plt
import torch
import torchvision
import numpy as np

def check_system_health():
    """Print a quick diagnostic of the U-Net (`model`) and show raw samples from `gen`.

    Test 1 runs the U-Net on the first validation slice whose ground-truth mask
    is non-empty and reports how many pixels it predicts as tumour (p > 0.5).
    Test 2 plots a 4x4 grid of generator samples from fresh noise.
    Both networks run in eval mode for the checks and are returned to their
    previous train/eval mode afterwards.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_was_training = model.training
    gen_was_training = gen.training
    model.eval()
    gen.eval()

    try:
        print("--- DIAGNOSTIC REPORT ---")

        # TEST 1: IS THE U-NET ALIVE?
        # Most MRI slices contain no tumour, so an empty prediction on an
        # arbitrary slice is usually the *correct* answer and says nothing about
        # the model. Test on the first validation slice that really has tumour.
        print(f"1. U-Net Health Check:")
        try:
            sample = None
            for batch_images, batch_masks in val_loader:
                has_tumor = batch_masks.flatten(start_dim=1).sum(dim=1) > 0
                if has_tumor.any():
                    k = has_tumor.nonzero()[0].item()
                    sample = (batch_images[k], batch_masks[k])
                    break

            if sample is None:
                print("   SKIPPED: no validation slice with a non-empty ground-truth mask was found.")
            else:
                real_image, real_mask = sample
                with torch.no_grad():
                    pred = model(real_image.unsqueeze(0).to(device))
                    pred_mask = (torch.sigmoid(pred) > 0.5).float()

                tumor_pixels = pred_mask.sum().item()
                true_pixels = real_mask.sum().item()

                if tumor_pixels == 0:
                    print("   ‚ùå FAILED: U-Net predicted EMPTY mask for a real patient.")
                    print("   -> DIAGNOSIS: You likely restarted Colab and forgot to run the U-Net training loop (Step 10).")
                else:
                    print(f"   ‚úÖ PASSED: U-Net found {tumor_pixels} tumor pixels in a real patient.")
                print(f"   (Ground truth for this slice: {true_pixels:.0f} tumor pixels.)")

        except NameError:
            print("   ‚ö†Ô∏è SKIPPED: 'val_loader' not found. Ensure you ran the data loading steps.")

        # TEST 2: IS THE GAN GENERATING SHAPES?
        print("\n2. GAN output (Raw):")
        with torch.no_grad():
            noise = torch.randn(16, 100, 1, 1).to(device)
            fake = gen(noise)

            # Visualize
            img_grid = torchvision.utils.make_grid(fake, normalize=True, nrow=4)
            plt.figure(figsize=(8,8))
            plt.axis("off")
            plt.title("Raw GAN Output (Do these look like brains?)")
            plt.imshow(img_grid.cpu().permute(1, 2, 0))
            plt.show()
    finally:
        # Leave both networks in the mode they were in before the diagnostics.
        model.train(model_was_training)
        gen.train(gen_was_training)

check_system_health()

In [None]:
import torch
import numpy as np
import cv2
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2

# --- CONFIGURATION ---
OUTPUT_DIR = "/content/synthetic_data"
NUM_IMAGES_TO_GENERATE = 500
CONFIDENCE_THRESHOLD = 0.1    # Very permissive
MIN_TUMOR_PIXELS = 10
MAX_ATTEMPTS = 5000           # Max number of 16-image generator batches (not images)
Z_DIM = 100

os.makedirs(OUTPUT_DIR, exist_ok=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The U-Net receives float images in [0, 1]. A.Normalize defaults to
# max_pixel_value=255 (correct only for uint8 input) and would turn such an
# image into an almost constant tensor; max_pixel_value=1.0 matches the
# normalisation the U-Net saw during training.
unet_transform = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=1.0),
    ToTensorV2()
])

def generate_data_verbose():
    """Generate GAN images, label them with the U-Net and save accepted pairs to OUTPUT_DIR.

    Generator outputs are mapped to [0, 1], upscaled to 256x256 and collapsed
    to grayscale replicated over 3 channels (removing colour noise) before
    segmentation. An image is saved as syn_<n>.png with syn_<n>_mask.png when
    more than MIN_TUMOR_PIXELS pixels exceed CONFIDENCE_THRESHOLD. Stops after
    NUM_IMAGES_TO_GENERATE images or MAX_ATTEMPTS batches.
    """
    gen.eval()
    model.eval()

    generated_count = 0
    attempts = 0

    print(f"üöÄ Starting VERBOSE Generation...")

    with torch.no_grad():
        while generated_count < NUM_IMAGES_TO_GENERATE:
            attempts += 1
            if attempts > MAX_ATTEMPTS:
                print(f"\n‚ö†Ô∏è Max attempts reached ({MAX_ATTEMPTS}). Stopping.")
                break

            # 1. Generate Batch
            noise = torch.randn(16, Z_DIM, 1, 1).to(device)
            fake_lowres = gen(noise)

            for i in range(fake_lowres.size(0)):
                if generated_count >= NUM_IMAGES_TO_GENERATE: break

                # 2. Process: Tanh range [-1, 1] -> [0, 1], channels last, upscale
                img_t = fake_lowres[i].cpu()
                img_np = (img_t.permute(1, 2, 0).numpy() * 0.5) + 0.5
                img_np = np.clip(img_np, 0, 1)
                img_highres = cv2.resize(img_np, (256, 256), interpolation=cv2.INTER_CUBIC)

                # Rainbow Fix: collapse to grayscale, then replicate to 3 channels
                img_uint8 = (img_highres * 255).astype(np.uint8)
                img_gray = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)
                img_clean = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2RGB)
                img_clean_float = img_clean.astype(np.float32) / 255.0

                # 3. U-Net Check
                aug = unet_transform(image=img_clean_float)
                input_tensor = aug['image'].unsqueeze(0).to(device)
                pred_mask = model(input_tensor)
                pred_prob = torch.sigmoid(pred_mask)
                mask_binary = (pred_prob > CONFIDENCE_THRESHOLD).float().cpu().numpy()[0, 0]
                tumor_size = np.sum(mask_binary)

                # 4. Save or Reject
                if tumor_size > MIN_TUMOR_PIXELS:
                    base_name = f"syn_{generated_count}"
                    save_img = cv2.cvtColor(img_clean, cv2.COLOR_RGB2BGR)
                    save_mask = (mask_binary * 255).astype(np.uint8)
                    cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}.png"), save_img)
                    cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_mask.png"), save_mask)

                    generated_count += 1
                    print(f"‚úÖ Saved Image {generated_count}/{NUM_IMAGES_TO_GENERATE}")

            # Heartbeat Message
            if attempts % 50 == 0:
                print(f"   ... Checked {attempts*16} candidate images so far ...")

    print(f"\nProcess Complete. Total Saved: {generated_count}")

generate_data_verbose()

In [None]:
import torch
import numpy as np
import cv2
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2

# --- CONFIGURATION ---
OUTPUT_DIR = "/content/synthetic_data"
# We will temporarily bypass the count to ensure we get debug info
NUM_IMAGES_TO_GENERATE = 500
CONFIDENCE_THRESHOLD = 0.1
MIN_TUMOR_PIXELS = 10
MAX_ATTEMPTS = 1000           # Lowered for debug run (counts 16-image batches)
Z_DIM = 100

os.makedirs(OUTPUT_DIR, exist_ok=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The U-Net receives float images in [0, 1]. A.Normalize defaults to
# max_pixel_value=255 (correct only for uint8 input) and would turn such an
# image into an almost constant tensor - which is why max confidence stayed
# near zero. max_pixel_value=1.0 matches the training normalisation.
unet_transform = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=1.0),
    ToTensorV2()
])

def generate_data_verbose():
    """Diagnostic variant of the generation loop.

    Same pipeline as the earlier generate_data_verbose (GAN sample -> [0, 1]
    -> 256x256 -> grayscale x3 -> U-Net), plus: the first 10 images of the
    first batch are force-saved to OUTPUT_DIR as FORCE_DEBUG_<i>.png with the
    raw probability map (probability * 255) as FORCE_DEBUG_<i>_raw_prob.png,
    and every 50 batches the maximum probability of the latest image is
    printed.
    """
    gen.eval()
    model.eval()

    generated_count = 0
    attempts = 0
    max_conf = 0.0  # defined up front so the heartbeat can always print it

    print(f"üöÄ Starting DIAGNOSTIC Generation...")

    with torch.no_grad():
        while generated_count < NUM_IMAGES_TO_GENERATE:
            attempts += 1
            if attempts > MAX_ATTEMPTS:
                print(f"\n‚ö†Ô∏è Max attempts reached. Stopping.")
                break

            # 1. Generate Batch
            noise = torch.randn(16, Z_DIM, 1, 1).to(device)
            fake_lowres = gen(noise)

            for i in range(fake_lowres.size(0)):
                if generated_count >= NUM_IMAGES_TO_GENERATE: break

                # 2. Process: Tanh range [-1, 1] -> [0, 1], channels last, upscale
                img_t = fake_lowres[i].cpu()
                img_np = (img_t.permute(1, 2, 0).numpy() * 0.5) + 0.5
                img_np = np.clip(img_np, 0, 1)
                img_highres = cv2.resize(img_np, (256, 256), interpolation=cv2.INTER_CUBIC)

                # Rainbow Fix: collapse to grayscale, then replicate to 3 channels
                img_uint8 = (img_highres * 255).astype(np.uint8)
                img_gray = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)
                img_clean = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2RGB)
                img_clean_float = img_clean.astype(np.float32) / 255.0

                # 3. U-Net Check
                aug = unet_transform(image=img_clean_float)
                input_tensor = aug['image'].unsqueeze(0).to(device)
                pred_mask = model(input_tensor)
                pred_prob = torch.sigmoid(pred_mask)

                # --- DEBUGGING STATS ---
                max_conf = pred_prob.max().item()

                # Force Save the first 10 images just to see them
                if attempts == 1 and i < 10:
                    base_name = f"FORCE_DEBUG_{i}"
                    save_img = cv2.cvtColor(img_clean, cv2.COLOR_RGB2BGR)
                    # Save the raw probability map (x255, truncated to uint8) to see faint detections
                    save_mask = (pred_prob[0,0].cpu().numpy() * 255).astype(np.uint8)

                    cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}.png"), save_img)
                    cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_raw_prob.png"), save_mask)
                    if i == 0: print(f"üì∏ Saved FORCE_DEBUG examples to {OUTPUT_DIR}")

                mask_binary = (pred_prob > CONFIDENCE_THRESHOLD).float().cpu().numpy()[0, 0]
                tumor_size = np.sum(mask_binary)

                # 4. Save or Reject
                if tumor_size > MIN_TUMOR_PIXELS:
                    base_name = f"syn_{generated_count}"
                    save_img = cv2.cvtColor(img_clean, cv2.COLOR_RGB2BGR)
                    save_mask = (mask_binary * 255).astype(np.uint8)
                    cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}.png"), save_img)
                    cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_mask.png"), save_mask)
                    generated_count += 1
                    print(f"‚úÖ Saved Image {generated_count}/{NUM_IMAGES_TO_GENERATE}")

            # Heartbeat Message with DEBUG INFO
            if attempts % 50 == 0:
                print(f"   ... Checked {attempts*16} images. Last Max Conf: {max_conf:.5f} (Need > {CONFIDENCE_THRESHOLD})")

    print(f"\nProcess Complete. Total Saved: {generated_count}")

generate_data_verbose()

In [None]:
import glob
import matplotlib.pyplot as plt
import cv2
import os

def inspect_force_debug(debug_dir="/content/synthetic_data", max_samples=3):
    """Plot the FORCE_DEBUG samples written by the verbose generation pass.

    For each ``FORCE_DEBUG_<i>_raw_prob.png`` probability map found in
    ``debug_dir`` (sorted by filename, at most ``max_samples`` of them) show
    three panels: the generator image handed to the U-Net, the raw sigmoid
    probability map (stored as 0-255), and the image's pixel-intensity
    histogram. Also prints each map's maximum probability.

    Args:
        debug_dir: Folder to search for the debug files.
        max_samples: Maximum number of samples to display.

    Returns:
        None. Output is the plots and printed statistics.
    """
    prob_paths = sorted(glob.glob(os.path.join(debug_dir, "FORCE_DEBUG_*_raw_prob.png")))

    if len(prob_paths) == 0:
        print("‚ùå No FORCE_DEBUG files found.")
        return

    n_show = max(0, min(max_samples, len(prob_paths)))
    print(f"Found {len(prob_paths)} debug samples. Showing top {n_show}...")

    for i, prob_path in enumerate(prob_paths[:n_show]):
        img_path = prob_path.replace("_raw_prob.png", ".png")

        # cv2.imread returns None (it does not raise) for missing/unreadable files.
        img_bgr = cv2.imread(img_path)
        prob_map = cv2.imread(prob_path, cv2.IMREAD_GRAYSCALE)
        if img_bgr is None or prob_map is None:
            print(f"Skipping sample {i}: could not read {img_path} or {prob_path}")
            continue
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        fig, ax = plt.subplots(1, 3, figsize=(15, 5))

        # 1. The image the U-Net was given (before normalisation)
        ax[0].imshow(img)
        ax[0].set_title("What U-Net Saw (Input)")
        ax[0].axis("off")

        # 2. Raw probability map. imshow autoscales the colormap to the map's
        # own min/max, so even very faint responses fill the whole 'jet'
        # range; the colorbar shows the real scale (0-255 == probability 0-1).
        heat = ax[1].imshow(prob_map, cmap='jet')
        fig.colorbar(heat, ax=ax[1], fraction=0.046)
        ax[1].set_title("Raw Probability Heatmap")
        ax[1].axis("off")

        # 3. Histogram (Is the image too dark?)
        ax[2].hist(img.ravel(), bins=256, range=[0, 256])
        ax[2].set_title("Pixel Intensity Distribution")
        ax[2].set_xlabel("Pixel value")

        plt.show()

        print(f"Sample {i}: Max Probability in heatmap = {prob_map.max()/255:.5f}")

inspect_force_debug()

In [None]:
import torch
import numpy as np
import cv2
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2

# --- CONFIGURATION ---
# Where results go: accepted image/mask pairs, and contrast-boosted previews
# of the first generator batch.
OUTPUT_DIR = "/content/synthetic_data"
DEBUG_DIR = "/content/synthetic_debug_view"

# Generation budget and acceptance rules.
NUM_IMAGES_TO_GENERATE = 500   # stop after this many pairs are saved
MAX_ATTEMPTS = 2000            # hard cap on the number of generator batches
CONFIDENCE_THRESHOLD = 0.1     # deliberately low cut-off on the U-Net sigmoid output
MIN_TUMOR_PIXELS = 10          # a mask needs MORE than this many foreground pixels
Z_DIM = 100                    # latent vector length fed to the generator

os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(DEBUG_DIR, exist_ok=True)

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# ImageNet mean/std normalisation followed by HWC -> CHW tensor conversion.
# A.Normalize's default max_pixel_value is 255.0, i.e. it computes
# (img - mean*255) / (std*255), so it expects 0-255 pixel values.
unet_transform = A.Compose(
    [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

def force_dynamic_range(image):
    """Linearly stretch ``image`` so its values span the full 0-255 range.

    A single global min/max is used across all channels: the minimum maps
    to 0 and the maximum to 255. The uint8 cast truncates rather than rounds.

    A (near-)constant image (max - min <= 1e-5) cannot be stretched. It is
    assumed to be on the [0, 1] scale, clipped to that range, and scaled by
    255, so out-of-range values saturate instead of wrapping in the uint8 cast.

    Args:
        image: Array of pixel values, typically float in [0, 1].

    Returns:
        ``np.uint8`` array with the same shape as ``image``.
    """
    img_min = image.min()
    img_max = image.max()
    if img_max - img_min > 1e-5:
        stretched = (image - img_min) / (img_max - img_min)
        return (stretched * 255).astype(np.uint8)
    # Flat image: nothing to stretch. Clip first so values outside [0, 1]
    # saturate at 0/255 instead of overflowing the uint8 conversion.
    return (np.clip(image, 0, 1) * 255).astype(np.uint8)

def apply_clahe(image):
    """Boost local contrast of an RGB image with CLAHE on its grey level.

    The RGB input is collapsed to one grey channel and equalised with CLAHE
    (clip limit 3.0, 8x8 tiles). The result is then replicated into three
    identical channels, so colour information is discarded.

    Args:
        image: 8-bit RGB image (e.g. the output of ``force_dynamic_range``).

    Returns:
        RGB image whose three channels are identical.
    """
    grey = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    equaliser = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    return cv2.cvtColor(equaliser.apply(grey), cv2.COLOR_GRAY2RGB)

def heuristic_mask(image, threshold=240):
    """Fallback pseudo-label: mark the brightest regions of an RGB image.

    The image is converted to grey and smoothed with a 5x5 Gaussian to
    suppress isolated bright pixels. Every pixel whose blurred value is
    strictly greater than ``threshold`` becomes foreground.

    This is a fixed intensity cut-off, not a percentile. How many pixels are
    selected depends entirely on the image's brightness distribution.

    Args:
        image: 8-bit RGB image.
        threshold: Grey-level cut-off on the 0-255 scale (default 240).

    Returns:
        Float array of 0.0 / 1.0 with the image's height and width.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # Blur to remove single-pixel noise before thresholding
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    # THRESH_BINARY: pixels > threshold become 255, all others 0
    _, mask = cv2.threshold(blurred, threshold, 255, cv2.THRESH_BINARY)
    return mask / 255.0  # 0/255 -> 0.0/1.0

def _fake_to_rgb(fake_chw):
    """Turn one generator sample into a 256x256 contrast-boosted uint8 RGB image.

    Steps: rescale with x * 0.5 + 0.5 and clip to [0, 1], upsample to
    256x256 (bicubic), stretch to 0-255 (force_dynamic_range), then apply
    CLAHE (apply_clahe).

    NOTE(review): assumes ``fake_chw`` is a (C, H, W) tensor with C == 3 and
    a tanh-style [-1, 1] range, as implied by the rescale and the RGB
    conversions. Confirm this against the generator definition.
    """
    img_np = fake_chw.detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
    img_np = np.clip(img_np, 0, 1)
    img_highres = cv2.resize(img_np, (256, 256), interpolation=cv2.INTER_CUBIC)
    return apply_clahe(force_dynamic_range(img_highres))


def _unet_mask(img_rgb):
    """Binary U-Net mask for one uint8 RGB image.

    Returns an (H, W) float array of 0.0/1.0, where 1.0 means the sigmoid
    probability is > CONFIDENCE_THRESHOLD. Reads the globals ``model``,
    ``unet_transform`` and ``device``.
    """
    # Feed raw 0-255 values. A.Normalize computes
    # (img - mean * max_pixel_value) / (std * max_pixel_value) with
    # max_pixel_value=255 by default. An image already divided by 255 would
    # therefore be squashed to roughly -mean/std at every pixel: an almost
    # constant input that the U-Net cannot segment.
    aug = unet_transform(image=img_rgb)
    input_tensor = aug['image'].unsqueeze(0).to(device)
    pred_prob = torch.sigmoid(model(input_tensor))
    return (pred_prob > CONFIDENCE_THRESHOLD).float().cpu().numpy()[0, 0]


def generate_data_final():
    """Generate synthetic image/mask pairs with the GAN, labelled by U-Net or a heuristic.

    Runs in batches of 16 latent vectors (``Z_DIM`` x 1 x 1) under no_grad.
    For each generated image:
      * post-process it into a 256x256 uint8 RGB image (``_fake_to_rgb``);
      * on the first batch only, save the first 10 previews to DEBUG_DIR;
      * predict a U-Net mask at CONFIDENCE_THRESHOLD; if it does not have more
        than MIN_TUMOR_PIXELS foreground pixels, use ``heuristic_mask`` instead;
      * if the chosen mask has more than MIN_TUMOR_PIXELS foreground pixels,
        write ``syn_<k>.png`` and ``syn_<k>_mask.png`` to OUTPUT_DIR.

    Stops after NUM_IMAGES_TO_GENERATE pairs or MAX_ATTEMPTS batches.
    Relies on the globals ``gen``, ``model``, ``device`` and ``unet_transform``
    plus the configuration constants.
    """
    gen.eval()
    model.eval()

    batch_size = 16
    n_debug_previews = 10
    generated_count = 0
    attempts = 0

    print(f"üöÄ Starting HYBRID Generation (U-Net + Heuristic Fallback)...")

    with torch.no_grad():
        while generated_count < NUM_IMAGES_TO_GENERATE:
            attempts += 1
            if attempts > MAX_ATTEMPTS:
                print(f"\n‚ö†Ô∏è Max attempts reached. Stopping.")
                break

            # 1. Generate a batch of low-resolution fakes
            noise = torch.randn(batch_size, Z_DIM, 1, 1).to(device)
            fake_lowres = gen(noise)

            for i in range(fake_lowres.size(0)):
                if generated_count >= NUM_IMAGES_TO_GENERATE:
                    break

                # 2. Rescale, upsample and contrast-boost
                img_final = _fake_to_rgb(fake_lowres[i])

                # Debug previews: first `n_debug_previews` images of the first batch only
                if attempts == 1 and i < n_debug_previews:
                    cv2.imwrite(os.path.join(DEBUG_DIR, f"DEBUG_{i}.png"),
                                cv2.cvtColor(img_final, cv2.COLOR_RGB2BGR))

                # 3. U-Net first; fall back to the heuristic whenever the U-Net
                # mask would fail the acceptance test below (same boundary).
                mask_binary = _unet_mask(img_final)
                method = "UNet"
                if np.sum(mask_binary) <= MIN_TUMOR_PIXELS:
                    mask_binary = heuristic_mask(img_final)
                    method = "Heuristic"

                # 4. Save if either method found a large enough region
                if np.sum(mask_binary) > MIN_TUMOR_PIXELS:
                    base_name = f"syn_{generated_count}"
                    save_img = cv2.cvtColor(img_final, cv2.COLOR_RGB2BGR)
                    save_mask = (mask_binary * 255).astype(np.uint8)

                    cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}.png"), save_img)
                    cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_mask.png"), save_mask)

                    generated_count += 1
                    if generated_count % 50 == 0:
                        print(f"‚úÖ Saved {generated_count}/{NUM_IMAGES_TO_GENERATE} (Latest: {method})")

            # Heartbeat while nothing has been accepted yet
            if attempts % 50 == 0 and generated_count == 0:
                print(f"   ... Checked {attempts*batch_size} images...")

    print(f"\nProcess Complete. Total Saved: {generated_count}")
    print(f"Images located in: {OUTPUT_DIR}")

generate_data_final()

In [None]:
import pandas as pd
import glob
import torch
import torch.optim as optim
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader
from tqdm import tqdm  # For progress bar

# --- 1. GATHER SYNTHETIC DATA ---
# Each accepted synthetic sample is a "<name>.png" image plus a
# "<name>_mask.png" mask in the synthetic folder.
# NOTE(review): every *_mask.png in the folder is picked up, including files
# left over from earlier generation runs. Clear the folder first if that is
# not intended.
import os

# glob order is filesystem-dependent. Sort it so the row order of the
# combined table (and therefore any seeded shuffle) is reproducible.
syn_mask_paths = sorted(glob.glob("/content/synthetic_data/*_mask.png"))
synthetic_data = []

for mask_path in syn_mask_paths:
    # Strip only the trailing "_mask.png" (glob guarantees it is there)
    image_path = mask_path[:-len("_mask.png")] + ".png"
    if not os.path.exists(image_path):
        print(f"Skipping {mask_path}: image {image_path} not found.")
        continue
    synthetic_data.append({"image_path": image_path, "mask_path": mask_path})

print(f"Found {len(synthetic_data)} synthetic examples.")

# --- 2. MERGE WITH REAL DATA ---
syn_df = pd.DataFrame(synthetic_data)

# Requires `train_df` from the earlier data-splitting step (Step 4); re-run
# that cell if it is missing.
if len(synthetic_data) > 0:
    combined_train_df = pd.concat([train_df, syn_df], ignore_index=True)
    print(f"Original Real Data: {len(train_df)}")
    print(f"Synthetic GAN Data: {len(syn_df)}")
    print(f"TOTAL HYBRID DATASET: {len(combined_train_df)}")
else:
    print("‚ö†Ô∏è No synthetic data found! Training on real data only.")
    combined_train_df = train_df

# --- 3. SETUP HYBRID TRAINING ---
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Real and synthetic rows go through the same augmentation pipeline
# ('train_transform', defined in an earlier cell).
hybrid_dataset = BrainTumorDataset(combined_train_df, transform=train_transform)
hybrid_loader = DataLoader(hybrid_dataset, batch_size=16, shuffle=True)

# A fresh U-Net for the final comparison. Only the decoder/head start from
# random weights; the ResNet34 encoder is initialised from ImageNet.
final_model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
)
final_model = final_model.to(device)

optimizer = torch.optim.Adam(final_model.parameters(), lr=1e-4)
# from_logits=True: the loss applies the sigmoid itself, so the model outputs raw logits
loss_fn = smp.losses.DiceLoss(mode='binary', from_logits=True)

# --- 4. TRAINING LOOP FUNCTION ---
def train_one_epoch(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Switches ``model`` to training mode. For every ``(images, masks)`` batch it
    runs forward -> loss -> zero_grad -> backward -> step, and shows the
    current batch loss on a tqdm progress bar.

    Returns:
        Sum of per-batch losses divided by ``len(loader)``.
        Raises ZeroDivisionError if ``loader`` has length 0.
    """
    model.train()
    batch_losses = []
    progress = tqdm(loader)

    for images, masks in progress:
        logits = model(images.to(device))
        loss = loss_fn(logits, masks.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        batch_losses.append(loss_value)
        progress.set_postfix(loss=loss_value)

    return sum(batch_losses) / len(loader)

# --- 5. RUN TRAINING ---
# 5 epochs is a quick run; increase to 20-50 for best results.
EPOCHS = 5
print("\nüöÄ Starting Hybrid Training (Real + Synthetic)...")

for epoch in range(EPOCHS):
    avg_loss = train_one_epoch(final_model, hybrid_loader, optimizer, loss_fn, device)
    epoch_label = f"Epoch {epoch+1}/{EPOCHS}"
    print(f"{epoch_label} - Average Loss: {avg_loss:.4f}")

print("\n‚úÖ Final Model Trained! You can now run the visualization code.")

In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np

def visualize_predictions(model, loader, device):
    # 1. Set model to evaluation mode
    model.eval()

    # 2. Get a batch of validation data
    # We use 'next(iter())' to grab the first batch
    images, masks = next(iter(loader))
    images = images.to(device)

    # 3. Predict!
    with torch.no_grad():
        predictions = model(images)
        # Convert raw scores to probabilities (0 to 1)
        predictions = torch.sigmoid(predictions)
        # Threshold: If prob > 0.5, it's a tumor
        pred_masks = (predictions > 0.5).float()

    # 4. Helper to convert tensors for plotting
    def tensor_to_image(tensor):
        # Move to CPU and numpy
        img = tensor.cpu().numpy()
        # Change from [Channels, Height, Width] to [Height, Width, Channels]
        img = img.transpose(1, 2, 0)
        # Undo Normalization roughly for display
        img = (img * 0.229 + 0.485)
        # Clip to 0-1 range just in case
        return np.clip(img, 0, 1)

    def tensor_to_mask(tensor):
        return tensor.cpu().numpy()[0, :, :]

    # 5. Plot the first 3 patients
    fig, axes = plt.subplots(3, 3, figsize=(12, 12))

    print(f"--- Results on Unseen Validation Data ---")

    for i in range(3):
        # A. Original MRI
        ax = axes[i, 0]
        ax.imshow(tensor_to_image(images[i]))
        ax.set_title("MRI Scan (Input)")
        ax.axis('off')

        # B. Doctor's Label (Ground Truth)
        ax = axes[i, 1]
        ax.imshow(tensor_to_mask(masks[i]), cmap='gray')
        ax.set_title("Doctor's Annotation")
        ax.axis('off')

        # C. AI Prediction
        ax = axes[i, 2]
        ax.imshow(tensor_to_mask(pred_masks[i]), cmap='gray')
        ax.set_title("Hybrid Model Prediction")
        ax.axis('off')

        # Add a green box if prediction is correct, red if wrong (simple check)
        # (This is just a visual aid, not a strict metric)

    plt.tight_layout()
    plt.show()

# Plot input / ground truth / prediction for the first samples of the first
# validation batch, using the hybrid-trained model.
visualize_predictions(final_model, val_loader, device)

In [None]:
import torch
import numpy as np

def calculate_dice_score(model, loader, device, threshold=0.5):
    """Mean per-image Dice coefficient of ``model`` on ``loader``; prints and returns it.

    For every image, sigmoid(logits) is binarised at ``threshold`` (P) and
    compared with the ground-truth mask (T):

        Dice = (2 * |P ∩ T| + eps) / (|P| + |T| + eps)

    The per-image scores are then averaged. Because eps appears in the
    numerator too, an image whose prediction and mask are both empty scores
    1.0 (perfect agreement).

    Args:
        model: Segmentation model returning one logit channel per pixel.
        loader: Iterable of ``(images, masks)`` batches, masks with values 0/1.
        device: Device the model lives on.
        threshold: Probability cut-off for the predicted mask (default 0.5).

    Returns:
        The mean Dice as a Python float, or ``nan`` if ``loader`` yields no images.
    """
    model.eval()
    dice_scores = []
    epsilon = 1e-7  # keeps the ratio defined when prediction and mask are both empty

    print("--- Starting Evaluation on Validation Set ---")

    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device)

            # 1. Predict and binarise
            preds = (torch.sigmoid(model(images)) > threshold).float()

            # 2. Flatten to [batch, pixels]; reshape also works on non-contiguous tensors
            preds_flat = preds.reshape(preds.size(0), -1)
            masks_flat = masks.reshape(masks.size(0), -1).float()

            intersection = (preds_flat * masks_flat).sum(dim=1)
            area_sum = preds_flat.sum(dim=1) + masks_flat.sum(dim=1)  # |P| + |T|

            dice = (2. * intersection + epsilon) / (area_sum + epsilon)
            dice_scores.extend(dice.cpu().tolist())

    # 3. Report
    if not dice_scores:
        print("No images in loader; Dice score is undefined.")
        return float("nan")

    final_score = float(np.mean(dice_scores))
    print(f"\n‚úÖ Final Average Dice Score: {final_score:.4f}")

    if final_score > 0.8:
        print("üåü Result: Excellent! The model is highly accurate.")
    elif final_score > 0.6:
        print("üëç Result: Good. The model detects tumors well but might miss edges.")
    else:
        print("‚ö†Ô∏è Result: Needs Improvement. Try training for more Epochs.")

    return final_score

# Score the hybrid-trained model with mean per-image Dice.
# Evaluate on 'val_loader' (from an earlier cell; presumably held-out REAL
# scans) rather than on the hybrid training set, so synthetic samples cannot
# inflate the score.
calculate_dice_score(final_model, val_loader, device)

In [None]:
import torch
import os
from google.colab import files

# 1. Persist only the learned weights (state_dict), not the whole pickled module.
save_path = "brain_tumor_unet_hybrid.pth"
torch.save(final_model.state_dict(), save_path)

size_mb = os.path.getsize(save_path) / 1e6
print(f"‚úÖ Model saved to {save_path}")
print(f"   File size: {size_mb:.2f} MB")

# 2. Trigger a browser download (google.colab.files; only works inside Colab).
files.download(save_path)

In [None]:
import torch
import segmentation_models_pytorch as smp
import cv2
import numpy as np
import matplotlib.pyplot as plt

# 1. DEFINE THE ARCHITECTURE
# A state_dict holds only weights, so the architecture must be rebuilt exactly
# as trained: ResNet34-encoder U-Net, 3 input channels, 1 output class.
# NOTE(review): this rebinds the global `model`, which generate_data_final()
# also uses. Re-run the earlier cells before regenerating synthetic data.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,  # no ImageNet download needed: all weights come from the .pth file
    in_channels=3,
    classes=1
).to(device)

# 2. LOAD THE WEIGHTS
save_path = "brain_tumor_unet_hybrid.pth"

weights_loaded = False
try:
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs and is safer for files from elsewhere.
    state_dict = torch.load(save_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    weights_loaded = True
    print("‚úÖ Model loaded successfully!")
except FileNotFoundError:
    print(f"‚ùå Error: Could not find {save_path}. Make sure it's in the same folder.")

# 3. USE IT FOR PREDICTION
model.eval()

if weights_loaded:
    # Stand-in "patient" image: an all-black 256x256 RGB uint8 array.
    dummy_image = np.zeros((256, 256, 3), dtype=np.uint8)

    # Preprocess like unet_transform: scale to [0, 1], then subtract the
    # ImageNet mean and divide by the std per channel. This equals A.Normalize
    # with its default max_pixel_value=255.
    # NOTE(review): assumes the training transform used these same stats.
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
    img_tensor = torch.from_numpy(dummy_image).permute(2, 0, 1).float().unsqueeze(0).to(device)
    img_tensor = (img_tensor / 255.0 - imagenet_mean) / imagenet_std

    with torch.no_grad():
        prediction = model(img_tensor)
        prob_map = torch.sigmoid(prediction)
        # If > 50% confidence, it's a tumor
        binary_mask = (prob_map > 0.5).float()

    print("Prediction complete. The model is ready to use.")
else:
    # Without trained weights the U-Net is randomly initialised; predictions would be meaningless.
    print("Skipping prediction: no trained weights were loaded.")