<a href="https://colab.research.google.com/github/arjonnill07/AI-ML-experiment-Notebooks/blob/main/Change_Identification_of_spatial_images.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# @title 2. Mount Google Drive & Set Dataset Path
# Make the user's Google Drive reachable under /content/drive so the dataset
# stored there can be read by the following cells.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# -*- coding: utf-8 -*-
"""
LEVIR-CD Change Detection (Optimized for Colab Pro)

Trains a change-detection model on LEVIR-CD-style image pairs: the T1 and T2
images are concatenated into one 6-channel input for a UNet++ segmentation
network with a pretrained encoder, trained at a higher input resolution
(IMG_SIZE). Predicted change masks are visualized with bounding boxes.
"""

# @title 1. Setup: Install Libraries, Import Modules, Check GPU
# The `!` lines are IPython shell escapes, so this cell only runs in Jupyter/Colab.
# NOTE(review): installing torch from the cu118 wheel index may replace the
# runtime's preinstalled torch build - confirm the CUDA version matches and
# restart the runtime after installing. Package versions are unpinned.
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # Adapt cuXXX if needed
!pip install -q segmentation-models-pytorch albumentations opencv-python-headless matplotlib torchinfo

import os
import sys
import random
import time
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm # Progress bar

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from torchinfo import summary # For model inspection

# Reproducibility: seed every random number generator the pipeline touches.
seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)
# Full cuDNN determinism costs some speed; enable if exact repeatability matters:
#   torch.backends.cudnn.deterministic = True
# With fixed input sizes, letting cuDNN autotune is usually faster instead:
#   torch.backends.cudnn.benchmark = True

# Check GPU and Colab environment
print("--- System Information ---")
if torch.cuda.is_available():
    # Ask torch for the device name/memory instead of shelling out to
    # `!nvidia-smi`: same information, no subprocess, and the cell stays valid
    # Python (and cannot IndexError on an empty nvidia-smi output).
    _gpu_props = torch.cuda.get_device_properties(0)
    _gpu_mem_gib = _gpu_props.total_memory / (1024**3)
    print(f"GPU Detected: {_gpu_props.name}, {_gpu_mem_gib:.1f} GiB")
    DEVICE = torch.device("cuda")
else:
    print("WARNING: No GPU detected. Training will be very slow. Ensure GPU is enabled in Runtime settings.")
    DEVICE = torch.device("cpu")
    _gpu_mem_gib = 0.0

# Rough Colab Pro heuristic (not reliable): more than 16 GiB of GPU memory.
# NOTE(review): 16 GB cards (e.g. V100/P100) may report slightly under 16 GiB
# and are then classed as "standard" here.
if 'google.colab' in sys.modules:
    if _gpu_mem_gib > 16:
        print("High-RAM GPU detected, likely Colab Pro environment.")
    else:
        print("Standard Colab GPU or CPU detected. Performance might be limited.")
print(f"Using device: {DEVICE}")
print("------------------------")

In [None]:



# --- !!! IMPORTANT: SET YOUR DATASET PATH !!! ---
# Adjust this path to where you unzipped the LEVIR-CD dataset folder in your Google Drive.
# It should contain subfolders like 'train', 'val', 'test';
# inside each, there should be 'A', 'B', 'label' subfolders.
DRIVE_DATASET_PATH = "/content/drive/MyDrive/LEVIR-CD+"  # <--- CHANGE THIS

if not os.path.isdir(DRIVE_DATASET_PATH):
    # Raise instead of a bare sys.exit(): sys.exit() with no argument exits with
    # status 0 (success), hiding the failure. The exception carries the message.
    raise FileNotFoundError(
        f"ERROR: Dataset path not found: {DRIVE_DATASET_PATH}\n"
        "Please ensure you have uploaded the dataset to Google Drive and updated the path."
    )
print(f"Dataset path confirmed: {DRIVE_DATASET_PATH}")
# Optional: list contents to verify
# !ls -l $DRIVE_DATASET_PATH

# @title 3. Configuration (Enhanced for Pro)

# --- Model & Training Hyperparameters ---
IMG_SIZE = 512  # Side length every image/mask is resized to (more detail, more memory)
# SMP models with the default 5-stage encoder downsample by 32, so the side
# length must be a multiple of 32 for the decoder skips to line up.
if IMG_SIZE % 32 != 0:
    raise ValueError(f"IMG_SIZE must be divisible by 32, got {IMG_SIZE}")
# Adjust batch size to your GPU: start with 16 for 512x512 and drop to 8 or 12
# if you get out-of-memory errors (V100/A100 can likely handle more).
BATCH_SIZE = 16
# Conservative fine-tuning LR for a pretrained encoder (torch's AdamW default is 1e-3).
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50  # Upper bound; early stopping may end training sooner
# Encoder backbone. Options include 'efficientnet-b4', 'efficientnet-b5',
# 'resnext50_32x4d', 'timm-regnety_032', ...
ENCODER_NAME = 'efficientnet-b4'
ENCODER_WEIGHTS = 'imagenet'
# UNet++ decoder (nested skip connections), which can capture finer details
MODEL_ARCHITECTURE = smp.UnetPlusPlus
# Saved to the local runtime disk (not Drive); the filename records encoder and resolution.
MODEL_SAVE_PATH = f'/content/best_cd_model_{ENCODER_NAME}_{IMG_SIZE}.pth'
PATIENCE = 7  # Early-stopping patience in epochs; the LR scheduler uses PATIENCE // 2

# Dataset paths (derived from DRIVE_DATASET_PATH); one path component per join argument
TRAIN_IMG_T1_DIR = os.path.join(DRIVE_DATASET_PATH, 'train', 'A')
TRAIN_IMG_T2_DIR = os.path.join(DRIVE_DATASET_PATH, 'train', 'B')
TRAIN_MASK_DIR = os.path.join(DRIVE_DATASET_PATH, 'train', 'label')

# NOTE(review): confirm this copy of the dataset really has a 'val' split;
# some LEVIR-CD+ releases ship only 'train' and 'test'.
VAL_IMG_T1_DIR = os.path.join(DRIVE_DATASET_PATH, 'val', 'A')
VAL_IMG_T2_DIR = os.path.join(DRIVE_DATASET_PATH, 'val', 'B')
VAL_MASK_DIR = os.path.join(DRIVE_DATASET_PATH, 'val', 'label')

# @title 4. Dataset and DataLoader (with More Augmentations)

class LevirCDDataset(Dataset):
    """Paired-image change-detection dataset with separate T1 / T2 / mask folders.

    The sample list is the sorted ``.png`` file names in ``image_dir_t1``; the
    same file name is looked up in ``image_dir_t2`` and ``mask_dir``.

    Each item is ``(image_t1, image_t2, mask)``:
      * images are RGB ``uint8`` arrays of shape (H, W, 3) when no transform is
        given, otherwise whatever ``transform`` returns for them;
      * ``mask`` is a float32 numpy array of shape (1, H, W), scaled from the
        0-255 grayscale file to 0-1.

    Args:
        image_dir_t1: Directory of "before" images (defines the sample list).
        image_dir_t2: Directory of "after" images with matching file names.
        mask_dir: Directory of grayscale change masks with matching file names.
        transform: Optional callable invoked as
            ``transform(image=t1, image1=t2, mask=mask)`` returning a dict with
            the same keys (e.g. an Albumentations ``Compose`` declaring
            ``additional_targets={'image1': 'image'}``).

    Raises (from ``__getitem__``):
        FileNotFoundError: if any of the three files is missing or unreadable.
    """

    def __init__(self, image_dir_t1, image_dir_t2, mask_dir, transform=None):
        self.image_dir_t1 = image_dir_t1
        self.image_dir_t2 = image_dir_t2
        self.mask_dir = mask_dir
        self.transform = transform
        # LEVIR-CD uses .png; T1's files define the samples, T2/masks are matched by name.
        self.image_files = sorted(f for f in os.listdir(image_dir_t1) if f.endswith('.png'))

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _read_rgb(path):
        """Load ``path`` as an RGB uint8 array, raising if it cannot be read."""
        image = cv2.imread(path)
        # cv2.imread signals a missing/unreadable file by returning None, not by
        # raising; without this check the failure shows up as an opaque cvtColor error.
        if image is None:
            raise FileNotFoundError(f"Image not found or invalid: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        image_t1 = self._read_rgb(os.path.join(self.image_dir_t1, img_name))
        image_t2 = self._read_rgb(os.path.join(self.image_dir_t2, img_name))

        mask_path = os.path.join(self.mask_dir, img_name)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Mask not found or invalid: {mask_path}")
        mask = (mask / 255.0).astype(np.float32)  # scale 0-255 -> 0-1

        # Explicit None check: a transform object may define __len__ and so be falsy.
        if self.transform is not None:
            augmented = self.transform(image=image_t1, image1=image_t2, mask=mask)
            image_t1 = augmented['image']
            image_t2 = augmented['image1']
            mask = augmented['mask']

        # np.asarray also accepts a CPU torch tensor (e.g. after ToTensorV2), so the
        # mask is always returned as a numpy array with a leading channel dim: (1, H, W).
        mask = np.expand_dims(np.asarray(mask), axis=0)

        return image_t1, image_t2, mask

# Define Transforms using Albumentations
mean = (0.485, 0.456, 0.406)  # ImageNet RGB mean (matches the 'imagenet' encoder weights)
std = (0.229, 0.224, 0.225)   # ImageNet RGB std

# Training augmentations. Each transform samples its random parameters once per
# call and applies them to T1 ('image'), T2 ('image1') and the mask alike:
# identical geometry keeps the pair co-registered with the label, and the
# photometric step below is ALSO shared - T1 and T2 are not jittered independently.
# NOTE(review): `value=` (ShiftScaleRotate) and `var_limit=` (GaussNoise) are
# Albumentations 1.x argument names; newer releases renamed/deprecated them -
# pin the version or confirm they are still honoured.
train_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE, interpolation=cv2.INTER_LINEAR),  # bilinear for the images
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    # Small shift/scale/rotation; uncovered borders are constant-filled
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.5,
                       border_mode=cv2.BORDER_CONSTANT, value=0),
    # 70% of the time apply exactly one photometric transform (same one, same
    # parameters, for T1 and T2). Inside OneOf the children's p values only act
    # as relative selection weights.
    A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
    ], p=0.7),

    # Normalize after the photometric transforms, before tensor conversion
    A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
    ToTensorV2()  # HWC numpy -> CHW torch tensor (values already standardized, not [0, 1])
], additional_targets={'image1': 'image'})  # treat 'image1' (T2) as a second image target

# Validation: deterministic resize + normalization only
val_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE, interpolation=cv2.INTER_LINEAR),
    A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
    ToTensorV2()
], additional_targets={'image1': 'image'})

# Create Datasets
train_dataset = LevirCDDataset(TRAIN_IMG_T1_DIR, TRAIN_IMG_T2_DIR, TRAIN_MASK_DIR, transform=train_transform)
val_dataset = LevirCDDataset(VAL_IMG_T1_DIR, VAL_IMG_T2_DIR, VAL_MASK_DIR, transform=val_transform)

# Create DataLoaders
# Up to 4 worker processes, capped by the CPUs actually available: Colab
# runtimes can have as few as 2, and PyTorch warns that more workers than CPUs
# may make loading slow or even freeze.
NUM_WORKERS = min(4, os.cpu_count() or 1)
# Pinned host memory only speeds up host->GPU copies, so use it only on CUDA.
PIN_MEMORY = DEVICE.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print("--- Dataset Info ---")
print(f"Image Size: {IMG_SIZE}x{IMG_SIZE}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
print("--------------------")



In [None]:
# @title 5. Model Architecture (Siamese UNet++ Concatenation)

class SiamUnetPlusPlusConcatenate(nn.Module):
    """Early-fusion change-detection network built on an SMP segmentation model.

    T1 and T2 are concatenated along the channel axis (3 + 3 = 6 channels) and
    passed once through a single segmentation model whose encoder takes 6
    input channels. Despite the "Siam" name there are no separate
    weight-shared branches - both images share one encoder pass.

    Args:
        encoder_name: SMP encoder identifier (e.g. 'efficientnet-b4', 'resnet34').
        encoder_weights: Pretrained encoder weights, or None for random init.
        classes: Number of output channels.
        activation: Optional output activation; None returns raw logits.
        architecture: Model class to build. Defaults to the module-level
            ``MODEL_ARCHITECTURE`` (e.g. ``smp.UnetPlusPlus``). It must accept
            the keywords ``encoder_name``, ``encoder_weights``, ``in_channels``,
            ``classes`` and ``activation``.
    """

    def __init__(self, encoder_name='efficientnet-b4', encoder_weights='imagenet', classes=1,
                 activation=None, architecture=None):
        super().__init__()
        if architecture is None:
            architecture = MODEL_ARCHITECTURE
        # Let SMP build the 6-channel stem itself (in_channels=6) instead of
        # swapping the first conv for a plain nn.Conv2d afterwards. SMP adapts the
        # existing first convolution in place, keeping encoder-specific behaviour.
        # One example is the static "same" padding of EfficientNet's _conv_stem;
        # a plain replacement conv dropped that padding, so at 512 px the stride-2
        # feature map was 255 px and the UNet++ skip concatenation failed. With
        # pretrained weights SMP also repeats the RGB filters over the extra
        # channels and rescales them, rather than duplicating them unscaled.
        # The state-dict keys (e.g. model.encoder._conv_stem.weight) are unchanged.
        self.model = architecture(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=6,
            classes=classes,
            activation=activation,
        )

    def forward(self, x1, x2):
        """Stack ``x1`` (T1) and ``x2`` (T2) on dim 1 and return the model output.

        Both inputs are expected to be (Batch, 3, H, W); the fused tensor is
        (Batch, 6, H, W).
        """
        x = torch.cat([x1, x2], dim=1)
        return self.model(x)

# Instantiate the model
model = SiamUnetPlusPlusConcatenate(
    encoder_name=ENCODER_NAME,
    encoder_weights=ENCODER_WEIGHTS,
    classes=1,
    activation=None  # Output logits for BCEWithLogitsLoss
).to(DEVICE)

# Inspect model structure and parameter count (optional, best-effort: failures are only reported)
print("\n--- Model Summary ---")
# One sample per input is enough for the layer table and parameter counts. A
# BATCH_SIZE forward pass at IMG_SIZE resolution can use a lot of GPU memory
# before training has even started. Activation-size estimates are per sample.
summary_input_shape = (1, 3, IMG_SIZE, IMG_SIZE)  # (Batch, Channels, Height, Width)
try:
    summary(model, input_size=[summary_input_shape, summary_input_shape], device=str(DEVICE))
except Exception as e:
    print(f"Could not generate model summary: {e}")
print("---------------------\n")

# @title 6. Loss Function, Optimizer, Metrics, Scheduler

# Loss: weighted sum of pixel-wise BCE and soft Dice, both computed from raw logits
bce_loss = nn.BCEWithLogitsLoss()
dice_loss = smp.losses.DiceLoss(mode='binary', from_logits=True)

def combined_loss(logits, targets, bce_weight=0.6, dice_weight=0.4):
    """Return ``bce_weight * BCE(logits, targets) + dice_weight * Dice(logits, targets)``.

    Args:
        logits: Raw (pre-sigmoid) model outputs.
        targets: Binary float targets with the same shape as ``logits``.
        bce_weight: Weight of the BCE-with-logits term.
        dice_weight: Weight of the binary Dice term.
    """
    bce_term = bce_weight * bce_loss(logits, targets)
    dice_term = dice_weight * dice_loss(logits, targets)
    return bce_term + dice_term

# Optimizer: AdamW with a small decoupled weight decay
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

# Learning-rate scheduler: validation IoU is maximised, hence mode='max'.
# `verbose=` is not passed: it is deprecated in recent PyTorch releases (and
# rejected by the newest ones). The training loop already prints the LR every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=PATIENCE // 2)

# Metrics - IoU and F1 Score (common for segmentation)
def calculate_metrics(logits, targets, threshold=0.5, smooth=1e-6):
    """Return the batch-mean IoU and F1 score for binary change logits.

    Each image's prediction is ``sigmoid(logits) > threshold``. TP/FP/FN are
    counted per image over all non-batch dimensions and turned into smoothed
    scores, IoU = (TP + s) / (TP + FP + FN + s) and
    F1 = (2TP + s) / (2TP + FP + FN + s), which are then averaged over the batch.
    An image with no true and no predicted change therefore scores 1.0.

    Args:
        logits: Raw model outputs with the batch on dim 0, e.g. (B, 1, H, W).
        targets: Binary ground truth with the same per-image element count.
        threshold: Probability cut-off for a positive pixel.
        smooth: Additive smoothing that avoids 0/0.

    Returns:
        Tuple ``(iou, f1)`` of 0-dim tensors.
    """
    with torch.no_grad():
        hard = (torch.sigmoid(logits) > threshold).float()
        hard = hard.view(hard.size(0), -1)
        truth = targets.view(targets.size(0), -1)

        true_pos = (hard * truth).sum(dim=1)
        false_pos = (hard * (1 - truth)).sum(dim=1)
        false_neg = ((1 - hard) * truth).sum(dim=1)

        iou_per_image = (true_pos + smooth) / (true_pos + false_pos + false_neg + smooth)
        f1_per_image = (2 * true_pos + smooth) / (2 * true_pos + false_pos + false_neg + smooth)
        return iou_per_image.mean(), f1_per_image.mean()




In [None]:
# @title 7. Training Loop with Metrics and Early Stopping

def train_one_epoch(model, loader, optimizer, device, epoch=None, num_epochs=None):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Network called as ``model(img_t1, img_t2)``; its output is scored
            with ``combined_loss`` against ``mask``.
        loader: Iterable of ``(img_t1, img_t2, mask)`` batches supporting ``len()``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device every batch is moved to.
        epoch: Optional 0-based epoch index, used only for the progress-bar label.
        num_epochs: Optional total epoch count, used only for the progress-bar label.

    Returns:
        float: sum of the per-batch losses divided by ``len(loader)``.

    Raises:
        ValueError: if ``loader`` has no batches (e.g. fewer samples than the
            batch size with ``drop_last=True``), instead of a ZeroDivisionError.
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("train_one_epoch: loader yielded no batches "
                         "(dataset smaller than batch_size with drop_last=True?)")

    # The label comes from the arguments, not from a caller's global `epoch`,
    # so the function also works outside the training loop.
    if epoch is None:
        desc = "Training"
    else:
        total = num_epochs if num_epochs is not None else "?"
        desc = f"Training Epoch {epoch + 1}/{total}"

    model.train()
    running_loss = 0.0
    pbar = tqdm(loader, desc=desc, leave=True, dynamic_ncols=True)
    for img_t1, img_t2, mask in pbar:
        img_t1, img_t2, mask = img_t1.to(device), img_t2.to(device), mask.to(device)

        optimizer.zero_grad()
        outputs = model(img_t1, img_t2)
        loss = combined_loss(outputs, mask)
        loss.backward()
        # Optional: gradient clipping can sometimes help stabilise training
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_loss = loss.item()  # one host sync per step, reused for the display
        running_loss += batch_loss
        pbar.set_postfix(loss=f"{batch_loss:.4f}")

    return running_loss / num_batches

def validate_one_epoch(model, loader, device, epoch=None, num_epochs=None):
    """Evaluate ``model`` on ``loader`` and return per-sample averaged ``(loss, iou, f1)``.

    Each batch's loss, IoU and F1 is weighted by its batch size before
    averaging. An unweighted mean of per-batch values lets a smaller final
    batch (the validation loader keeps it) count as much as a full one. Because
    ``calculate_metrics`` returns per-image means, the IoU and F1 returned here
    are the exact mean over all images; the loss is the batch-size-weighted
    mean of ``combined_loss``.

    Args:
        model: Network called as ``model(img_t1, img_t2)`` returning logits.
        loader: Iterable of ``(img_t1, img_t2, mask)`` batches.
        device: Device every batch is moved to.
        epoch: Optional 0-based epoch index, used only for the progress-bar label.
        num_epochs: Optional total epoch count, used only for the progress-bar label.

    Returns:
        Tuple of floats ``(avg_loss, avg_iou, avg_f1)``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if epoch is None:
        desc = "Validating"
    else:
        total = num_epochs if num_epochs is not None else "?"
        desc = f"Validating Epoch {epoch + 1}/{total}"

    model.eval()
    total_loss = total_iou = total_f1 = 0.0
    num_samples = 0
    pbar = tqdm(loader, desc=desc, leave=True, dynamic_ncols=True)
    with torch.no_grad():
        for img_t1, img_t2, mask in pbar:
            img_t1, img_t2, mask = img_t1.to(device), img_t2.to(device), mask.to(device)

            outputs = model(img_t1, img_t2)
            loss = combined_loss(outputs, mask)
            iou, f1 = calculate_metrics(outputs, mask)

            batch_size = mask.size(0)
            loss_val, iou_val, f1_val = loss.item(), iou.item(), f1.item()
            total_loss += loss_val * batch_size
            total_iou += iou_val * batch_size
            total_f1 += f1_val * batch_size
            num_samples += batch_size
            pbar.set_postfix(loss=f"{loss_val:.4f}", iou=f"{iou_val:.4f}", f1=f"{f1_val:.4f}")

    if num_samples == 0:
        raise ValueError("validate_one_epoch: loader yielded no samples")
    return total_loss / num_samples, total_iou / num_samples, total_f1 / num_samples

# --- Training Initialization ---
train_losses, val_losses, val_ious, val_f1s = [], [], [], []  # per-epoch history for the plots
best_val_iou = 0.0
epochs_no_improve = 0  # consecutive epochs without a new best IoU (early-stopping counter)
start_time = time.time()

print("🚀 Starting Training...")
# Keep the loop variable named `epoch`: the train/validate helpers' progress-bar
# labels may read it as a global.
for epoch in range(NUM_EPOCHS):
    epoch_start_time = time.time()
    print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")

    train_loss = train_one_epoch(model, train_loader, optimizer, DEVICE)
    val_loss, val_iou, val_f1 = validate_one_epoch(model, val_loader, DEVICE)
    for history, value in zip((train_losses, val_losses, val_ious, val_f1s),
                              (train_loss, val_loss, val_iou, val_f1)):
        history.append(value)

    epoch_duration = time.time() - epoch_start_time
    # Read the LR before scheduler.step() so the report shows the rate this epoch used.
    current_lr = optimizer.param_groups[0]['lr']
    report_lines = [
        f"Epoch {epoch+1} Summary:",
        f"  Train Loss: {train_loss:.4f}",
        f"  Val Loss  : {val_loss:.4f}",
        f"  Val IoU   : {val_iou:.4f}",
        f"  Val F1    : {val_f1:.4f}",
        f"  LR        : {current_lr:.6f}",
        f"  Duration  : {epoch_duration:.2f}s",
    ]
    print("\n".join(report_lines))

    # LR scheduling on validation IoU (higher is better)
    scheduler.step(val_iou)

    # Checkpoint on a new best IoU, otherwise count towards early stopping
    improved = val_iou > best_val_iou
    if not improved:
        epochs_no_improve += 1
        print(f"📉 Validation IoU did not improve for {epochs_no_improve} epoch(s). Best IoU: {best_val_iou:.4f}")
    else:
        print(f"✅ Validation IoU improved ({best_val_iou:.4f} --> {val_iou:.4f}). Saving model...")
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        best_val_iou = val_iou
        epochs_no_improve = 0
        print(f"   Model saved to {MODEL_SAVE_PATH}")

    if epochs_no_improve >= PATIENCE:
        print(f"\n🚫 Early stopping triggered after {PATIENCE} epochs without improvement.")
        break

total_training_time = time.time() - start_time
train_minutes, train_seconds = divmod(total_training_time, 60)
print("\n🏁 Training Finished!")
print(f"Best Validation IoU: {best_val_iou:.4f}")
print(f"Total Training Time: {train_minutes:.0f}m {train_seconds:.0f}s")

# @title 8. Plot Training History
plt.figure(figsize=(18, 6))

# One entry per panel: (curves as (values, label, color-or-None), title, y-label, draw best-IoU line?)
history_panels = [
    ([(train_losses, 'Train Loss', None), (val_losses, 'Validation Loss', None)],
     'Loss History', 'Loss', False),
    ([(val_ious, 'Validation IoU', 'green')],
     'Validation IoU History', 'IoU', True),
    ([(val_f1s, 'Validation F1-Score', 'orange')],
     'Validation F1-Score History', 'F1 Score', False),
]

for panel_idx, (curves, panel_title, y_label, mark_best) in enumerate(history_panels, start=1):
    plt.subplot(1, 3, panel_idx)
    for values, curve_label, curve_color in curves:
        line_style = {'color': curve_color} if curve_color is not None else {}
        plt.plot(range(1, len(values) + 1), values, label=curve_label, **line_style)
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    if mark_best:
        plt.axhline(y=best_val_iou, color='r', linestyle='--', label=f'Best IoU: {best_val_iou:.4f}')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()




In [None]:
# @title 9. Bounding Box Generation Functions (Unchanged)

def mask_to_bboxes(mask_np, min_area_threshold=50):
    """Convert a binary mask into axis-aligned bounding boxes of its outer contours.

    Every non-zero pixel counts as foreground, whatever the dtype (bool, integer
    0/1 or 0/255, float 0.0/1.0). Only external contours are traced, so holes
    inside a region do not produce extra boxes.

    Args:
        mask_np: 2-D (H, W) array-like mask.
        min_area_threshold: A contour is kept only if its ``cv2.contourArea`` is
            strictly greater than this value.

    Returns:
        List of ``(x, y, w, h)`` tuples, with (x, y) the top-left corner in
        OpenCV point order (column, row).

    Raises:
        ValueError: if the mask is not 2-D.
    """
    mask_np = np.asarray(mask_np)
    if mask_np.ndim != 2:
        raise ValueError(f"Input mask must be 2D (H, W). Got shape {mask_np.shape}")

    # Threshold on > 0 instead of multiplying by 255 and casting: the cast
    # overflowed for integer masks with values > 1 (e.g. 256 * 255 -> 0 in uint8)
    # and turned small positive floats (< 1/255) into background.
    mask_uint8 = np.where(mask_np > 0, 255, 0).astype(np.uint8)

    # [-2] selects the contour list under both OpenCV 3 (image, contours,
    # hierarchy) and OpenCV 4 (contours, hierarchy) return conventions.
    contours = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    return [
        tuple(cv2.boundingRect(contour))
        for contour in contours
        if cv2.contourArea(contour) > min_area_threshold
    ]

def draw_bboxes_on_image(image_np_rgb, bboxes, color=(255, 0, 0), thickness=2):
    """Return a copy of an RGB image with a rectangle drawn for every bounding box.

    Args:
        image_np_rgb: (H, W, 3) RGB numpy array. Non-uint8 input is clipped to
            [0, 255] and cast to uint8 (a warning is printed). Float images in
            [0, 1] must be scaled to [0, 255] by the caller first.
        bboxes: Iterable of ``(x, y, w, h)`` boxes, e.g. from ``mask_to_bboxes``.
        color: RGB colour of the rectangles.
        thickness: Line thickness in pixels.

    Returns:
        New uint8 (H, W, 3) RGB array; the input array is not modified.

    Raises:
        ValueError: if the image is not (H, W, 3).
    """
    output_image = image_np_rgb.copy()
    if output_image.ndim != 3 or output_image.shape[2] != 3:
        raise ValueError(f"Input image must be 3D RGB (H, W, 3). Got shape {output_image.shape}")

    # Ensure image is uint8 for drawing
    if output_image.dtype != np.uint8:
        print("Warning: Converting input image to uint8 for drawing.")
        # Clip first: a bare astype wraps out-of-range values (e.g. 300 -> 44).
        output_image = np.clip(output_image, 0, 255).astype(np.uint8)

    # Draw the RGB colour directly on the RGB array: this gives the same pixels
    # as converting to BGR, drawing the reversed colour and converting back,
    # without two full-image colour conversions.
    output_image = np.ascontiguousarray(output_image)
    rgb_color = tuple(int(c) for c in color)
    for x, y, w, h in bboxes:
        x, y, w, h = int(x), int(y), int(w), int(h)  # OpenCV expects plain ints for points
        cv2.rectangle(output_image, (x, y), (x + w, y + h), rgb_color, thickness)
    return output_image


# @title 10. Inference and Visualization (using Best Model)

# Load the best saved model weights
print(f"Loading best model from: {MODEL_SAVE_PATH}")
# Re-initialize model architecture (must match the trained architecture)
inference_model = SiamUnetPlusPlusConcatenate(
    encoder_name=ENCODER_NAME,
    encoder_weights=None,  # nothing to download: every weight comes from the checkpoint
    classes=1,
    activation=None,
)

# Raise instead of sys.exit(): a bare sys.exit() reports exit status 0, and the
# broad except used to hide the real cause (e.g. a key/shape mismatch).
if not os.path.exists(MODEL_SAVE_PATH):
    raise FileNotFoundError(
        f"No checkpoint at '{MODEL_SAVE_PATH}'. It is only written when validation IoU "
        "improves during training - run the training cell first or fix the path."
    )
# weights_only=True: the file is a plain state_dict, so refuse to unpickle arbitrary objects.
state_dict = torch.load(MODEL_SAVE_PATH, map_location=DEVICE, weights_only=True)
inference_model.load_state_dict(state_dict)  # strict: raises on any missing/unexpected key
inference_model.to(DEVICE)
print("Model weights loaded successfully to", DEVICE)

inference_model.eval()  # Set model to evaluation mode

# --- Get a sample from the validation set for inference ---
# Resize only (no augmentation, no normalization) so the arrays can be shown as-is.
# additional_targets is required: without it Albumentations does not treat
# 'image1' as an image, so T2 would not be resized and would no longer match
# T1, the mask or the prediction.
vis_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE, interpolation=cv2.INTER_LINEAR)
], additional_targets={'image1': 'image'})
vis_dataset = LevirCDDataset(VAL_IMG_T1_DIR, VAL_IMG_T2_DIR, VAL_MASK_DIR, transform=vis_transform)

if len(vis_dataset) == 0:
    raise RuntimeError("Validation dataset is empty. Cannot perform inference visualization.")

# Pick one random validation pair by indexing the dataset directly; a single
# sample does not need a DataLoader round trip through tensors.
vis_index = random.randrange(len(vis_dataset))
img_t1_vis_np, img_t2_vis_np, mask_vis_np = vis_dataset[vis_index]

# Images as (H, W, 3) uint8 RGB; mask from (1, H, W) to (H, W)
img_t1_vis_np = np.asarray(img_t1_vis_np).astype(np.uint8)
img_t2_vis_np = np.asarray(img_t2_vis_np).astype(np.uint8)
mask_vis_np = np.asarray(mask_vis_np)[0]

# --- Preprocess the visualization sample *for the model* (normalization, tensor) ---
# additional_targets is required here too: without it 'image1' is not treated as
# an image, so T2 would not be normalized or turned into a tensor, and the
# `.unsqueeze` call below would fail on a numpy array.
inference_preprocess = A.Compose([
    A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
    ToTensorV2()
], additional_targets={'image1': 'image'})
processed = inference_preprocess(image=img_t1_vis_np, image1=img_t2_vis_np)
img_t1_tensor = processed['image'].unsqueeze(0).to(DEVICE)  # add batch dim, send to device
img_t2_tensor = processed['image1'].unsqueeze(0).to(DEVICE)  # add batch dim, send to device

# --- Perform Inference ---
print("Running inference...")
with torch.no_grad():
    pred_logits = inference_model(img_t1_tensor, img_t2_tensor)
    pred_probs = torch.sigmoid(pred_logits)
    # Take batch 0 / channel 0 explicitly -> (H, W) probability map as numpy on the CPU
    pred_mask_np = pred_probs[0, 0].cpu().numpy()
print("Inference complete.")

# --- Post-process: binarize the probabilities, then box each connected change region ---
binary_mask_np = np.where(pred_mask_np > 0.5, 1.0, 0.0).astype(np.float32)
# 50 px equals mask_to_bboxes' default minimum contour area
bboxes = mask_to_bboxes(binary_mask_np, min_area_threshold=50)

# --- Overlay the boxes (green) on the T2 image used for visualization ---
img_t2_with_boxes = draw_bboxes_on_image(img_t2_vis_np, bboxes, color=(0, 255, 0), thickness=2)

# --- Display Results ---
print("Displaying results...")
plt.figure(figsize=(25, 10))  # one wide row of five panels

# (array to show, panel title, colormap or None for RGB)
result_panels = (
    (img_t1_vis_np, f'Image T1 ({IMG_SIZE}x{IMG_SIZE})', None),
    (img_t2_vis_np, f'Image T2 ({IMG_SIZE}x{IMG_SIZE})', None),
    (mask_vis_np, 'Ground Truth Mask', 'gray'),
    (binary_mask_np, 'Predicted Mask', 'gray'),
    (img_t2_with_boxes, f'T2 w/ BBoxes ({len(bboxes)} detected)', None),
)
for position, (picture, caption, cmap) in enumerate(result_panels, start=1):
    plt.subplot(1, 5, position)
    if cmap is None:
        plt.imshow(picture)
    else:
        plt.imshow(picture, cmap=cmap)
    plt.title(caption)
    plt.axis('off')

plt.tight_layout()
plt.show()



In [None]:
# @title 11. Download the Trained Model
from google.colab import files

if os.path.exists(MODEL_SAVE_PATH):
    print(f"\nAttempting to download the best model: {MODEL_SAVE_PATH}")
    files.download(MODEL_SAVE_PATH)
else:
    print(f"\nModel file not found at {MODEL_SAVE_PATH}. Skipping download.")
    # Early stopping does not prevent saving: the best checkpoint is written as
    # soon as validation IoU improves, before the early-stop check runs.
    print("This happens if training did not run (or failed) in this runtime, "
          "or the local /content storage was reset.")