# In [None]:
# Change Detection Using U-Net with ResNet34 Backbone
# Author: Dhrumil Prajapati
# Description: Satellite image change detection using semantic segmentation with U-Net architecture

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from glob import glob
from tqdm import tqdm
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from sklearn.model_selection import train_test_split
from albumentations import (
    Compose, RandomRotate90, RandomCrop, HorizontalFlip, VerticalFlip, Resize,
    RandomBrightnessContrast, HueSaturationValue, Normalize, ShiftScaleRotate,
    RandomGamma, ElasticTransform, GaussNoise, OneOf
)

# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds NumPy, Python's ``random`` module and PyTorch (CPU and CUDA
    generators), sets the cuDNN ``deterministic``/``benchmark`` flags,
    exports ``PYTHONHASHSEED`` and prints a confirmation message.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seed_fn in seeders:
        seed_fn(seed)

    # Request deterministic cuDNN behaviour and switch off its benchmark mode.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # NOTE(review): set after interpreter start-up, so this presumably only
    # reaches child processes - confirm that is the intent.
    os.environ['PYTHONHASHSEED'] = str(seed)
    print(f"Random seed set as {seed}")


# Seed once, before any data splitting or model construction happens.
set_seed(42)

# Compute device: use CUDA when PyTorch reports it as available, else the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Dataset root directory and destination file for the best checkpoint
data_dir = "/content/drive/MyDrive/dataset_onera"
model_save_path = "best_model.pth"

# Settings read by the augmentation, model and training functions in this file
config = dict(
    img_size=256,                # height and width passed to Resize
    batch_size=16,               # batch size of the train/validation loaders
    epochs=50,                   # number of training epochs
    learning_rate=1e-4,          # learning rate handed to Adam
    encoder="resnet34",          # encoder_name for smp.Unet
    encoder_weights="imagenet",  # encoder_weights for smp.Unet
    classes=1,                   # number of output channels
    activation=None,             # We'll apply sigmoid manually for more flexibility
    val_size=0.1,                # fraction of samples held out for validation
    test_size=0.1,               # fraction of samples held out for testing
)

# Data preparation
def read_img(path):
    """Read a colour image from ``path`` and return it in RGB channel order.

    Args:
        path: Filesystem path of the image to load.

    Returns:
        The decoded image as returned by OpenCV, converted from BGR to RGB.

    Raises:
        FileNotFoundError: If ``path`` does not point to an existing file.
        ValueError: If the file exists but OpenCV could not decode it.
    """
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with the offending path rather than inside cvtColor.
    if img is None:
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Image file not found: {path!r}")
        raise ValueError(f"Could not decode image file: {path!r}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def read_mask(path):
    """Read a grayscale mask from ``path`` and scale it by 1/255.

    Args:
        path: Filesystem path of the mask image to load.

    Returns:
        The mask as a floating-point array (pixel values divided by 255.0).

    Raises:
        FileNotFoundError: If ``path`` does not point to an existing file.
        ValueError: If the file exists but OpenCV could not decode it.
    """
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with the offending path rather than on the division below.
    if mask is None:
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Mask file not found: {path!r}")
        raise ValueError(f"Could not decode mask file: {path!r}")
    return mask / 255.0

def get_data_paths(root=None):
    """Collect the image-pair and mask paths of the dataset into a DataFrame.

    The three file lists are globbed and sorted independently, so row ``i``
    pairs the i-th entry of each sorted list.
    NOTE(review): this assumes matching file names across the three folders -
    confirm against the dataset layout.

    Args:
        root: Dataset root directory containing ``imgs/im1``, ``imgs/im2``
            and ``gt``. Defaults to the module-level ``data_dir``.

    Returns:
        A ``pandas.DataFrame`` with the columns ``img1_path``, ``img2_path``
        and ``mask_path`` (empty when no ``*.png`` files are found).

    Raises:
        ValueError: If the three folders contain different numbers of
            ``*.png`` files, since the rows could not be paired up.
    """
    root = data_dir if root is None else root
    patterns = {
        "img1_path": "imgs/im1/*.png",
        "img2_path": "imgs/im2/*.png",
        "mask_path": "gt/*.png",
    }
    columns = {
        column: sorted(glob(os.path.join(root, pattern)))
        for column, pattern in patterns.items()
    }

    # Report a count mismatch explicitly instead of relying on the generic
    # length error raised by the DataFrame constructor.
    counts = {column: len(paths) for column, paths in columns.items()}
    if len(set(counts.values())) > 1:
        raise ValueError(
            f"Mismatched number of files under {root!r}: {counts}. "
            "Every image pair needs exactly one mask."
        )

    # Create DataFrame for easier data handling
    return pd.DataFrame(columns)

# Augmentations
def get_training_augmentation():
    """Build the albumentations pipeline applied to training samples.

    The pipeline chains random geometric transforms, random colour
    transforms, one randomly chosen distortion (elastic / Gaussian noise /
    gamma), a resize to ``config["img_size"]`` and a normalisation with
    3-element mean and std.

    Returns:
        The composed augmentation pipeline.
    """
    side = config["img_size"]
    elastic_alpha = 120

    geometric = [
        RandomRotate90(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(scale_limit=0.2, rotate_limit=45, p=0.5),
    ]
    colour = [
        RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
    ]
    # Exactly one of these distortions is picked whenever the OneOf fires.
    distortion = OneOf(
        [
            ElasticTransform(
                alpha=elastic_alpha,
                sigma=elastic_alpha * 0.05,
                alpha_affine=elastic_alpha * 0.03,
                p=0.5,
            ),
            GaussNoise(var_limit=(10, 50), p=0.5),
            RandomGamma(gamma_limit=(80, 120), p=0.5),
        ],
        p=0.5,
    )
    # Always-on tail: fixed output size, then mean/std normalisation.
    finalize = [
        Resize(height=side, width=side, always_apply=True),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), always_apply=True),
    ]
    return Compose(geometric + colour + [distortion] + finalize)

def get_validation_augmentation():
    """Build the deterministic pipeline used for validation and test samples.

    Only resizes to ``config["img_size"]`` and normalises with the same
    3-element mean and std as the training pipeline; no random transforms.

    Returns:
        The composed preprocessing pipeline.
    """
    side = config["img_size"]
    resize = Resize(height=side, width=side, always_apply=True)
    normalize = Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        always_apply=True,
    )
    return Compose([resize, normalize])

# Dataset
class ChangeDetectionDataset(Dataset):
    """Dataset yielding ``(stacked image pair, change mask)`` tensors.

    Each item reads the two images and the mask referenced by one row of
    ``df``, optionally augments them, and returns the two images stacked
    along the channel axis (img1 channels first, then img2) together with
    the mask carrying a leading channel axis.

    The augmentation pipelines in this file use 3-element Normalize
    statistics and colour transforms, i.e. they are written for 3-channel
    images and cannot be applied to the 6-channel stack. When the pipeline
    exposes ``add_targets`` (albumentations ``Compose`` does), the second
    image is therefore registered as an additional image target, both
    images are augmented together as separate 3-channel inputs, and they are
    stacked only afterwards.

    Args:
        df: DataFrame with ``img1_path``, ``img2_path`` and ``mask_path``
            columns.
        augmentation: Optional callable invoked with keyword arguments and
            returning a dict. Pipelines with ``add_targets`` are called as
            ``augmentation(image=, image2=, mask=)``; any other callable
            receives the stacked image as ``augmentation(image=, mask=)``.
            Note that a pipeline with ``add_targets`` is modified in place.
    """

    # Keyword under which the second image is passed to the pipeline.
    _SECOND_IMAGE_KEY = "image2"

    def __init__(self, df, augmentation=None):
        self.df = df
        self.augmentation = augmentation
        # Paired mode: treat the second image as another "image" target.
        self._paired = hasattr(augmentation, "add_targets")
        if self._paired:
            augmentation.add_targets({self._SECOND_IMAGE_KEY: "image"})

    def __len__(self):
        """Return the number of samples (rows of ``df``)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, augment and tensorise the sample at positional index ``idx``.

        Returns:
            A ``(image, mask)`` tuple of float tensors; the image is laid out
            channels-first and the mask has a leading axis of size 1.
        """
        row = self.df.iloc[idx]
        img1 = read_img(row.img1_path)
        img2 = read_img(row.img2_path)
        mask = read_mask(row.mask_path)

        if self.augmentation is None:
            img = np.concatenate([img1, img2], axis=2)
        elif self._paired:
            sample = self.augmentation(
                image=img1, mask=mask, **{self._SECOND_IMAGE_KEY: img2}
            )
            # Stack images to create a 6-channel input
            img = np.concatenate(
                [sample["image"], sample[self._SECOND_IMAGE_KEY]], axis=2
            )
            mask = sample["mask"]
        else:
            # Plain callable without target registration: keep handing it the
            # stacked image, as before.
            sample = self.augmentation(
                image=np.concatenate([img1, img2], axis=2), mask=mask
            )
            img, mask = sample["image"], sample["mask"]

        # Reshape mask for model input
        mask = np.expand_dims(mask, axis=0)
        img = np.ascontiguousarray(img.transpose(2, 0, 1))

        return torch.from_numpy(img).float(), torch.from_numpy(mask).float()

# Model
def build_model():
    """Create the U-Net described by ``config`` and move it to ``device``.

    Returns:
        An ``smp.Unet`` taking 6 input channels, placed on ``device``.
    """
    unet_kwargs = {
        "encoder_name": config["encoder"],
        "encoder_weights": config["encoder_weights"],
        # 3 channels from img1 + 3 channels from img2
        "in_channels": 6,
        "classes": config["classes"],
        "activation": config["activation"],
    }
    network = smp.Unet(**unet_kwargs)
    return network.to(device)

# Combined Loss Function
class CombinedLoss(nn.Module):
    """Weighted sum of binary cross-entropy and Dice loss on raw logits.

    Both component losses receive the model's raw logits:
    ``BCEWithLogitsLoss`` and ``DiceLoss(from_logits=True)`` each convert
    logits to probabilities themselves, so no sigmoid is applied here.
    Applying one beforehand would squash the Dice input twice.

    Args:
        weights: ``(bce_weight, dice_weight)`` multipliers of the two terms.
    """

    def __init__(self, weights=(0.5, 0.5)):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        # from_logits spelled out so the contract of forward() is visible.
        self.dice_loss = smp.losses.DiceLoss(mode='binary', from_logits=True)
        self.weights = weights

    def forward(self, y_pred, y_true):
        """Return the weighted BCE + Dice loss.

        Args:
            y_pred: Raw (pre-sigmoid) model outputs.
            y_true: Target mask with the same shape as ``y_pred``.
        """
        bce_loss = self.bce(y_pred, y_true)
        dice_loss = self.dice_loss(y_pred, y_true)
        return self.weights[0] * bce_loss + self.weights[1] * dice_loss

# Metrics
def iou_score(y_pred, y_true, threshold=0.5):
    """Compute the intersection-over-union of thresholded predictions.

    The logits are passed through a sigmoid, binarised with ``threshold``
    and compared with ``y_true`` over all elements at once.

    Args:
        y_pred: Raw (pre-sigmoid) model outputs.
        y_true: Target mask with the same shape as ``y_pred``.
        threshold: Probability strictly above which an element counts as
            positive.

    Returns:
        A scalar tensor; a small epsilon in numerator and denominator makes
        the score 1 when both prediction and target are empty.
    """
    eps = 1e-7
    binary = torch.sigmoid(y_pred).gt(threshold).float()
    overlap = torch.sum(binary * y_true)
    combined = torch.sum(binary) + torch.sum(y_true) - overlap
    return (overlap + eps) / (combined + eps)

def f1_score(y_pred, y_true, threshold=0.5):
    """Compute the F1 score of thresholded predictions.

    Uses the count form ``2*TP / (2*TP + FP + FN)``, which equals the
    harmonic mean of precision and recall. The epsilon is added to both
    numerator and denominator so that a batch with no positives in either
    prediction or target scores 1, matching ``iou_score``, instead of 0.

    Args:
        y_pred: Raw (pre-sigmoid) model outputs.
        y_true: Target mask with the same shape as ``y_pred``.
        threshold: Probability strictly above which an element counts as
            positive.

    Returns:
        A scalar tensor with the F1 score over all elements.
    """
    eps = 1e-7
    pred = (torch.sigmoid(y_pred) > threshold).float()
    tp = (y_true * pred).sum()
    fp = ((1 - y_true) * pred).sum()
    fn = (y_true * (1 - pred)).sum()
    return (2 * tp + eps) / (2 * tp + fp + fn + eps)

# Training functions
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(images, masks)`` batches supporting ``len()``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, masks)``.
        device: Device the batches are moved to.

    Returns:
        ``(loss, iou, f1)``, each the mean of the per-batch values.
    """
    model.train()
    running = {"loss": 0, "iou": 0, "f1": 0}

    with tqdm(loader, desc="Training", leave=False) as pbar:
        for images, masks in pbar:
            images, masks = images.to(device), masks.to(device)

            # One optimisation step: clear gradients, forward, backward, update
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            # Per-batch numbers, shown on the progress bar and accumulated
            batch = {
                "loss": loss.item(),
                "iou": iou_score(outputs, masks).item(),
                "f1": f1_score(outputs, masks).item(),
            }
            for key, value in batch.items():
                running[key] += value
            pbar.set_postfix(**batch)

    # Averages are taken over the number of batches, not samples
    n_batches = len(loader)
    return tuple(running[key] / n_batches for key in ("loss", "iou", "f1"))

def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable of ``(images, masks)`` batches supporting ``len()``.
        criterion: Loss called as ``criterion(outputs, masks)``.
        device: Device the batches are moved to.

    Returns:
        ``(loss, iou, f1)``, each the mean of the per-batch values.
    """
    model.eval()
    totals = [0, 0, 0]  # running loss, IoU and F1 sums

    with torch.no_grad(), tqdm(loader, desc="Validation", leave=False) as pbar:
        for images, masks in pbar:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            batch_loss = criterion(outputs, masks).item()
            batch_iou = iou_score(outputs, masks).item()
            batch_f1 = f1_score(outputs, masks).item()

            totals = [
                total + value
                for total, value in zip(totals, (batch_loss, batch_iou, batch_f1))
            ]
            pbar.set_postfix(loss=batch_loss, iou=batch_iou, f1=batch_f1)

    # Averages are taken over the number of batches, not samples
    n_batches = len(loader)
    mean_loss, mean_iou, mean_f1 = (total / n_batches for total in totals)
    return mean_loss, mean_iou, mean_f1

# Visualization
def _denormalize_rgb(chw, mean, std):
    """Undo mean/std normalisation of a channels-first image for display."""
    hwc = chw.transpose(1, 2, 0)
    return np.clip(std * hwc + mean, 0, 1)


def visualize_predictions(model, test_loader, device, num_samples=5):
    """Plot image pairs, ground truth and thresholded predictions.

    Takes the first sample of each of the first ``num_samples`` batches of
    ``test_loader`` and draws one row per sample with four panels:
    image 1, image 2, ground-truth mask and the prediction binarised at 0.5.
    The grid has as many rows as samples were actually collected, so a
    loader shorter than ``num_samples`` and ``num_samples == 1`` both work.

    Args:
        model: Network producing logits; switched to evaluation mode.
        test_loader: Iterable of ``(images, masks)`` batches whose images
            hold the two RGB images stacked along the channel axis.
        device: Device the images are moved to for inference.
        num_samples: Maximum number of rows to draw.

    Returns:
        The matplotlib figure containing the grid.

    Raises:
        ValueError: If no sample could be collected (empty loader or
            ``num_samples`` below 1).
    """
    from itertools import islice

    model.eval()

    # Same statistics as the Normalize step of the augmentation pipelines.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    samples = []
    with torch.no_grad():
        # islice stops after num_samples batches without fetching an extra one.
        for img, mask in islice(test_loader, max(num_samples, 0)):
            img = img.to(device)
            prob = torch.sigmoid(model(img))[0, 0].cpu().numpy()

            first = img[0].cpu().numpy()
            samples.append((
                _denormalize_rgb(first[:3], mean, std),
                _denormalize_rgb(first[3:], mean, std),
                mask[0, 0].cpu().numpy(),
                prob,
            ))

    if not samples:
        raise ValueError("No samples available to visualize")

    # squeeze=False keeps axes two-dimensional even for a single row.
    n_rows = len(samples)
    fig, axes = plt.subplots(n_rows, 4, figsize=(20, 4 * n_rows), squeeze=False)

    titles = ('Image 1', 'Image 2', 'Ground Truth', 'Prediction')
    for row_axes, (img1, img2, truth, prob) in zip(axes, samples):
        panels = (
            (img1, None),
            (img2, None),
            (truth, 'gray'),
            ((prob > 0.5).astype(np.uint8), 'gray'),
        )
        for ax, (panel, cmap), title in zip(row_axes, panels, titles):
            ax.imshow(panel, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    return fig

# Main training loop
def _plot_training_history(history, out_path='training_history.png'):
    """Plot train/validation loss and IoU curves and save them to ``out_path``."""
    plt.figure(figsize=(12, 5))

    panels = (
        (1, 'Loss', ('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')),
        (2, 'IoU Score', ('train_iou', 'Train IoU'), ('val_iou', 'Val IoU')),
    )
    for position, title, *curves in panels:
        plt.subplot(1, 2, position)
        for key, label in curves:
            plt.plot(history[key], label=label)
        plt.title(title)
        plt.legend()

    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def train_model():
    """Run the full pipeline: split data, train, evaluate and plot.

    Steps:
        1. Collect the sample paths and split them into train/validation/
           test sets according to ``config["val_size"]`` and
           ``config["test_size"]``.
        2. Train for ``config["epochs"]`` epochs with Adam, halving the
           learning rate when validation IoU plateaus, and save the weights
           to ``model_save_path`` whenever validation IoU improves.
        3. Reload the best weights saved during this run (if any), evaluate
           on the test set, build the prediction figure and write
           ``training_history.png``.

    Returns:
        ``(model, history, test_iou, test_f1)`` where ``history`` maps
        ``train_/val_`` + ``loss/iou/f1`` to per-epoch lists.
    """
    # Prepare data
    print("Preparing data...")
    df = get_data_paths()
    print(f"Total samples: {len(df)}")

    # Split data: carve out val+test first, then divide that remainder
    holdout_fraction = config["val_size"] + config["test_size"]
    train_df, temp_df = train_test_split(df, test_size=holdout_fraction, random_state=42)
    val_df, test_df = train_test_split(
        temp_df, test_size=config["test_size"] / holdout_fraction, random_state=42
    )

    print(f"Train samples: {len(train_df)}")
    print(f"Validation samples: {len(val_df)}")
    print(f"Test samples: {len(test_df)}")

    # Create datasets and dataloaders
    train_dataset = ChangeDetectionDataset(train_df, augmentation=get_training_augmentation())
    val_dataset = ChangeDetectionDataset(val_df, augmentation=get_validation_augmentation())
    test_dataset = ChangeDetectionDataset(test_df, augmentation=get_validation_augmentation())

    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

    # Build model
    print("Building model...")
    model = build_model()
    print(f"Model: U-Net with {config['encoder']} backbone")

    # Set up training. The scheduler's `verbose` flag is deprecated in
    # PyTorch, so learning-rate changes are reported manually below.
    criterion = CombinedLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5
    )

    # Training loop
    print(f"Starting training for {config['epochs']} epochs...")
    history = {
        'train_loss': [], 'train_iou': [], 'train_f1': [],
        'val_loss': [], 'val_iou': [], 'val_f1': []
    }

    best_iou = float("-inf")
    # Tracks whether THIS run wrote a checkpoint, so a stale file left over
    # from an earlier run is never loaded by accident.
    checkpoint_saved = False

    for epoch in range(config["epochs"]):
        print(f"Epoch {epoch+1}/{config['epochs']}")

        # Train and validate
        train_loss, train_iou, train_f1 = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_iou, val_f1 = validate_epoch(model, val_loader, criterion, device)

        # Update learning rate and report any reduction
        previous_lr = optimizer.param_groups[0]["lr"]
        scheduler.step(val_iou)
        current_lr = optimizer.param_groups[0]["lr"]
        if current_lr != previous_lr:
            print(f"Learning rate reduced from {previous_lr:.4e} to {current_lr:.4e}")

        # Save history
        for key, value in (
            ('train_loss', train_loss), ('train_iou', train_iou), ('train_f1', train_f1),
            ('val_loss', val_loss), ('val_iou', val_iou), ('val_f1', val_f1),
        ):
            history[key].append(value)

        # Print epoch results
        print(f"Train Loss: {train_loss:.4f}, IoU: {train_iou:.4f}, F1: {train_f1:.4f}")
        print(f"Val Loss: {val_loss:.4f}, IoU: {val_iou:.4f}, F1: {val_f1:.4f}")

        # Save best model
        if val_iou > best_iou:
            best_iou = val_iou
            torch.save(model.state_dict(), model_save_path)
            checkpoint_saved = True
            print(f"Saved best model with IoU: {best_iou:.4f}")

    # Load best model (map_location keeps the tensors on the active device)
    if checkpoint_saved:
        model.load_state_dict(torch.load(model_save_path, map_location=device))
    else:
        print("No checkpoint was saved during this run; evaluating the current weights.")

    # Final evaluation
    print("Evaluating on test set...")
    test_loss, test_iou, test_f1 = validate_epoch(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}, IoU: {test_iou:.4f}, F1: {test_f1:.4f}")

    # Visualize results
    # NOTE(review): the returned figure is neither saved nor closed here -
    # presumably it is left open for inline notebook display; confirm.
    visualize_predictions(model, test_loader, device)

    # Plot training history
    _plot_training_history(history)

    return model, history, test_iou, test_f1

# Run the training process
if __name__ == "__main__":
    # Train, then report the held-out test metrics returned by train_model()
    model, history, test_iou, test_f1 = train_model()
    print(
        "Training completed!",
        f"Final Test IoU: {test_iou:.4f}",
        f"Final Test F1 Score: {test_f1:.4f}",
        sep="\n",
    )