In [None]:
import os
import re
import glob
import numpy as np
import matplotlib.pyplot as plt
import random
import cv2
import rasterio

def lee_filter(img, size=7):
    """Apply a Lee-style speckle filter to an image array.

    Each pixel is blended between its local box-window mean and its original
    value. The blend weight grows with the ratio of the local variance to the
    variance of the whole image, so flat regions are smoothed and
    high-variance regions are left closer to the input.

    Args:
        img: NumPy image array. Non-floating arrays are promoted to float32
            first so that ``img**2`` cannot wrap around in an integer dtype.
        size: Side length, in pixels, of the square window used for the local
            mean and local variance.

    Returns:
        The filtered array, same shape as ``img``.
    """
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float32)
    window = (size, size)
    img_mean = cv2.blur(img, window)
    img_sqr_mean = cv2.blur(img**2, window)
    # E[x^2] - E[x]^2 is non-negative mathematically, but floating-point
    # cancellation can push it slightly below zero in flat regions; clamp so
    # the weights below always stay within [0, 1).
    img_variance = np.maximum(img_sqr_mean - img_mean**2, 0)
    overall_variance = np.var(img)
    # The epsilon keeps the division defined for a perfectly constant image.
    img_weights = img_variance / (img_variance + overall_variance + 1e-8)
    return img_mean + img_weights * (img - img_mean)

def extract_key(filename):
    """Derive the tile key used to pair before/after/mask files.

    Only the lower-cased base name of ``filename`` is inspected. The rules are
    tried in this order and the first one that applies wins:

    1. Names containing 'gorakhpur', 'chennai' or 'brahmaputra' yield
       ``divit_<number>``; the number comes from after ``_mask_`` for mask
       files, or from after the 8-digit field otherwise. If the expected
       pattern is absent the result is ``None`` (no further rules are tried).
    2. Names starting with ``sen_gj_<4 digits>_`` yield ``<state>_<tile>``.
    3. ``sen_<2 chars>_<6 digits>_<4 digits>`` yields ``<state>_<last 4 digits>``.
    4. ``sen_<2 chars>_<4 digits>_<4 digits>`` yields ``<state>_<first 4 digits>``.
    5. ``sen_<2 chars>_<8 digits>_<4 digits>`` yields ``<state>_<last 4 digits>``.

    Returns:
        The key string, or ``None`` when no rule matches.
    """
    stem = os.path.basename(filename).lower()

    # Rule 1: city-named tiles.
    for city in ('gorakhpur', 'chennai', 'brahmaputra'):
        if city not in stem:
            continue
        if 'mask' in stem:
            hit = re.search(r'sen_\w+_mask_(\d+)', stem)
            number_group = 1
        else:
            hit = re.search(r'sen_\w+_(\d{8})_(\d+)', stem)
            number_group = 2
        if hit is None:
            return None
        return f'divit_{hit.group(number_group)}'

    # Rule 2: "sen_gj_<tile>_" prefix.
    if re.match(r'sen_gj_\d{4}_', stem):
        hit = re.search(r'sen_(\w{2})_(\d{4})', stem)
        if hit:
            return "_".join(hit.groups())

    # Rules 3-5: (pattern, index of the capture group holding the tile id).
    # Group 1 is always the two-character state code.
    fallback_rules = (
        (r'sen_(\w{2})_(\d{6})_(\d{4})', 3),
        (r'sen_(\w{2})_(\d{4})_\d{4}', 2),
        (r'sen_(\w{2})_\d{8}_(\d{4})', 2),
    )
    for pattern, tile_group in fallback_rules:
        hit = re.search(pattern, stem)
        if hit:
            return f"{hit.group(1)}_{hit.group(tile_group)}"
    return None

def load_dataset(base_dir, target_size=(256, 256)):
    """Pair before/after/mask GeoTIFFs that share the same tile key.

    The ``before_geotiff``, ``after_geotiff`` and ``masked_geotiff``
    sub-folders of ``base_dir`` are scanned for ``*.tif`` files. Each file is
    keyed with ``extract_key``; files without a key are ignored. A triplet is
    produced for every key present in all three folders.

    Args:
        base_dir: Directory containing the three sub-folders.
        target_size: Unused here; kept so existing callers keep working.

    Returns:
        List of ``(before_path, after_path, mask_path)`` tuples.
    """
    def _index_by_key(folder_name):
        """Map tile key -> file path for one sub-folder of ``base_dir``."""
        index = {}
        for path in sorted(glob.glob(os.path.join(base_dir, folder_name, '*.tif'))):
            key = extract_key(path)  # evaluated once per file
            if not key:
                continue
            if key in index:
                # Two files collapsing onto one key means one of them is
                # dropped from the dataset; surface that instead of hiding it.
                # The later file (in sorted order) still wins, as before.
                print(f"Warning: duplicate key '{key}' in {folder_name}: "
                      f"{index[key]} replaced by {path}")
            index[key] = path
        return index

    before_dict = _index_by_key('before_geotiff')
    after_dict = _index_by_key('after_geotiff')
    mask_dict = _index_by_key('masked_geotiff')
    paired_data = [
        (before_dict[key], after_dict[key], mask_dict[key])
        for key in before_dict
        if key in after_dict and key in mask_dict
    ]
    print(f"✅ Total Paired Samples: {len(paired_data)}")
    return paired_data

def preprocess_sar_image_tiff(img_path, target_size=(256, 256), apply_lee=True):
    """Read band 1 of a GeoTIFF and turn it into a normalized 3-channel image.

    Steps: read the first band as float32, min-max scale it, optionally run
    ``lee_filter`` followed by a second min-max scaling, resize to
    ``target_size`` with bilinear interpolation, and repeat the single band
    three times along a new last axis.

    Args:
        img_path: Path of the raster file to open with rasterio.
        target_size: Size tuple handed to ``cv2.resize``.
        apply_lee: Whether to apply the Lee filter after the first scaling.

    Returns:
        Array whose last axis has length 3 (three copies of the same band).
    """
    def _rescale(arr):
        # Min-max scaling; the epsilon guards against a constant image.
        return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)

    with rasterio.open(img_path) as src:
        band = src.read(1).astype(np.float32)

    band = _rescale(band)
    if apply_lee:
        band = _rescale(lee_filter(band))
    band = cv2.resize(band, target_size, interpolation=cv2.INTER_LINEAR)
    return np.stack([band, band, band], axis=-1)

def preprocess_and_save_npz(
    paired_data, output_path, target_size=(256, 256),
    apply_lee=True, mask_threshold=50, dilate_mask=True
):
    """Preprocess every (before, after, mask) triplet and save one .npz archive.

    Before/after images go through ``preprocess_sar_image_tiff``. Each mask is
    read from band 1, resized with nearest-neighbour interpolation, binarized
    with ``mask > mask_threshold`` and optionally dilated with a 3x3 kernel.
    The stacked results are written under the keys ``before``, ``after`` and
    ``mask``.

    Args:
        paired_data: Iterable of ``(before_path, after_path, mask_path)``.
        output_path: Destination passed to ``np.savez_compressed``.
        target_size: Size tuple handed to ``cv2.resize``.
        apply_lee: Forwarded to ``preprocess_sar_image_tiff``.
        mask_threshold: Mask values strictly above this become foreground (1).
        dilate_mask: Whether to dilate the binary mask once.

    Raises:
        ValueError: If ``paired_data`` contains no triplets.
    """
    paired_data = list(paired_data)
    if not paired_data:
        # Without this guard the failure only shows up later as an opaque
        # np.stack error; fail early and say what is actually wrong.
        raise ValueError(
            "paired_data is empty: no before/after/mask triplets to preprocess"
        )

    dilation_kernel = np.ones((3, 3), np.uint8)
    before_list, after_list, mask_list = [], [], []
    all_zero_count = 0
    for before_path, after_path, mask_path in paired_data:
        before_img = preprocess_sar_image_tiff(before_path, target_size, apply_lee)
        after_img = preprocess_sar_image_tiff(after_path, target_size, apply_lee)
        with rasterio.open(mask_path) as src:
            mask_arr = src.read(1)
        # Nearest-neighbour keeps the original mask values (no blended labels).
        mask_arr = cv2.resize(mask_arr, target_size, interpolation=cv2.INTER_NEAREST)
        bin_mask = (mask_arr > mask_threshold).astype(np.uint8)
        if dilate_mask:
            bin_mask = cv2.dilate(bin_mask, dilation_kernel, iterations=1)
        if np.all(bin_mask == 0):
            all_zero_count += 1
            print(f"Warning: All-zero mask for {mask_path}")
        before_list.append(before_img)
        after_list.append(after_img)
        mask_list.append(bin_mask)
    before_np = np.stack(before_list)
    after_np = np.stack(after_list)
    mask_np = np.stack(mask_list)
    np.savez_compressed(output_path, before=before_np, after=after_np, mask=mask_np)
    print(f"Saved .npz file to {output_path}")
    print(f"Total all-zero masks: {all_zero_count} out of {len(mask_list)}")

# Root folder expected to contain the before_geotiff/, after_geotiff/ and
# masked_geotiff/ sub-directories scanned by load_dataset().
base_dir = "/kaggle/working"
paired_data = load_dataset(base_dir)
# Build the dataset archive; the relative path resolves against the current
# working directory.
preprocess_and_save_npz(
    paired_data,
    "flood_data.npz",
    target_size=(256, 256),
    apply_lee=True,
    mask_threshold=50,      # mask pixels with a value above 50 become foreground
    dilate_mask=True        # one 3x3 dilation pass to help preserve thin regions
)

def visualize_masks(npz_path, num_samples=5):
    """Plot randomly chosen masks from an .npz archive with their foreground share.

    Indices are drawn independently, so the same mask can be shown more than
    once.

    Args:
        npz_path: Path of an archive holding a ``mask`` array.
        num_samples: Number of random draws to display.
    """
    # The context manager closes the archive's file handle; the array itself
    # is fully read into memory by the lookup.
    with np.load(npz_path) as data:
        masks = data['mask']
    total = len(masks)
    if total == 0:
        # random.randint(0, -1) would raise; report the empty archive instead.
        print(f"No masks found in {npz_path}; nothing to visualize.")
        return
    all_zero_count = 0
    for _ in range(num_samples):
        idx = random.randint(0, total-1)
        mask = masks[idx]
        percent_fg = 100 * (mask > 0).sum() / mask.size
        if np.all(mask == 0):
            all_zero_count += 1
        print(f"Sample {idx}: Foreground pixels = {percent_fg:.2f}%")
        plt.imshow(mask, cmap='gray')
        plt.title(f"Mask {idx} - Foreground: {percent_fg:.2f}%")
        plt.show()
    print(f"Total all-zero masks in visualization: {all_zero_count} out of {num_samples}")

# Spot-check five randomly drawn masks from the archive written above.
visualize_masks("flood_data.npz", num_samples=5)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.cuda.amp import GradScaler, autocast
import torchvision.transforms.functional as TF
import os
import random

# Read the 'before', 'after' and 'mask' arrays from the archive in the current
# working directory (the file name matches the one written by the
# preprocessing cell).
data = np.load("flood_data.npz")
before_dataset = data['before']
after_dataset = data['after']
mask_dataset = data['mask']

def to_tensor_channel_first(img):
    """Return ``img`` as float32 with a trailing channel axis moved to the front.

    Arrays with exactly three dimensions are reordered from axes (0, 1, 2) to
    (2, 0, 1); arrays of any other rank keep their axis order and are only
    cast.
    """
    reordered = img.transpose(2, 0, 1) if img.ndim == 3 else img
    return reordered.astype(np.float32)

# Convert each image to float32 channel-first layout; masks are only cast to
# float32 (their axes are left as stored).
before_dataset = np.array([to_tensor_channel_first(img) for img in before_dataset])
after_dataset = np.array([to_tensor_channel_first(img) for img in after_dataset])
mask_dataset = np.array([img.astype(np.float32) for img in mask_dataset])

# Join the before and after images along axis 1 (the channel axis after the
# conversion above) so every sample carries both acquisitions. With the
# 3-channel images written by the preprocessing cell this gives 6 channels.
combined_dataset = np.concatenate([before_dataset, after_dataset], axis=1)

class FloodDataset(Dataset):
    """Dataset of (image, mask) tensor pairs with optional random augmentation.

    ``X[idx]`` and ``y[idx]`` are converted to float32 tensors on access; the
    mask gains a leading axis via ``unsqueeze(0)``. When ``augment`` is true,
    roughly 80% of accesses apply a random rotation within
    ``±rotation_degrees`` (bilinear for the image, nearest for the mask),
    followed by independent 50% chances of a horizontal and a vertical flip.
    Image and mask always receive the same geometric transform.
    """

    def __init__(self, X, y, augment=False, rotation_degrees=10):
        self.X, self.y = X, y
        self.augment = augment
        self.rotation_degrees = rotation_degrees

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        image = torch.tensor(self.X[idx], dtype=torch.float32)
        target = torch.tensor(self.y[idx], dtype=torch.float32).unsqueeze(0)

        # No augmentation requested, or this draw falls in the untouched 20%.
        if not self.augment or random.random() >= 0.8:
            return image, target

        limit = self.rotation_degrees
        angle = random.uniform(-limit, limit)
        image = TF.rotate(image, angle, interpolation=TF.InterpolationMode.BILINEAR)
        # Nearest-neighbour keeps mask values from being blended.
        target = TF.rotate(target, angle, interpolation=TF.InterpolationMode.NEAREST)
        # Horizontal first, then vertical; each decided by its own draw.
        for flip in (TF.hflip, TF.vflip):
            if random.random() > 0.5:
                image, target = flip(image), flip(target)
        return image, target

# Two views over the same arrays. Only the training split may see random
# augmentation; validation and test samples must be returned unmodified so
# their metrics are deterministic and comparable from epoch to epoch.
dataset = FloodDataset(combined_dataset, mask_dataset, augment=True, rotation_degrees=10)
eval_dataset = FloodDataset(combined_dataset, mask_dataset, augment=False)
total = len(dataset)
train_size = int(0.7 * total)
val_size = int(0.15 * total)
test_size = total - train_size - val_size  # remainder, so the three sizes sum to total

# Fixed seed so the same samples land in each split on every run.
generator = torch.Generator().manual_seed(42)
train_ds, val_split, test_split = random_split(dataset, [train_size, val_size, test_size], generator=generator)
# random_split returns Subsets of the augmenting dataset; re-point the
# validation/test indices at the non-augmenting view of the same data.
val_ds = torch.utils.data.Subset(eval_dataset, val_split.indices)
test_ds = torch.utils.data.Subset(eval_dataset, test_split.indices)
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=8, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=8, pin_memory=True)

# --- Focal Tversky Loss ---
class FocalTverskyLoss(nn.Module):
    """Focal Tversky loss for binary segmentation, computed from raw logits.

    The Tversky index is evaluated over the whole batch at once:

        TI = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)

    and the loss is ``(1 - TI) ** gamma``.

    Args:
        alpha: Weight applied to the soft false-positive mass.
        beta: Weight applied to the soft false-negative mass.
        gamma: Focal exponent applied to ``1 - TI``.
        smooth: Additive constant that keeps the ratio defined when there is
            no foreground at all.
        eps: Lower bound applied to ``1 - TI`` before exponentiation.

    NOTE(review): here ``alpha`` scales false positives and ``beta`` false
    negatives. Some references use the opposite assignment, so confirm that
    the 0.7 / 0.3 defaults penalize the intended error type.
    """
    def __init__(self, alpha=0.7, beta=0.3, gamma=0.75, smooth=1e-6, eps=1e-7):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth
        self.eps = eps

    def forward(self, pred, target):
        """Return the scalar loss for logits ``pred`` against binary ``target``."""
        pred = torch.sigmoid(pred)
        tp = (pred * target).sum()
        fp = ((1 - target) * pred).sum()
        fn = (target * (1 - pred)).sum()
        tversky = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)
        # For gamma < 1 the derivative of x**gamma is unbounded at x == 0, so a
        # saturated perfect prediction (TI == 1) would back-propagate inf/NaN.
        # Clamping bounds the base away from zero and zeroes the gradient there.
        return torch.pow(torch.clamp(1 - tversky, min=self.eps), self.gamma)

# Loss instance shared by the training, validation and test passes below.
loss_fn = FocalTverskyLoss(alpha=0.7, beta=0.3, gamma=0.75)

def dice_score(pred, target, smooth=1e-6):
    """Mean per-sample Dice coefficient of thresholded logits.

    ``pred`` is passed through a sigmoid and binarized at 0.5. Dice is then
    computed separately for each sample by reducing over dims 1, 2 and 3, and
    the per-sample values are averaged.

    Args:
        pred: 4-D tensor of logits.
        target: 4-D tensor of the same shape as ``pred``.
        smooth: Added to numerator and denominator; makes an empty prediction
            on an empty target score 1.

    Returns:
        The mean Dice value as a Python float.
    """
    binary_pred = (torch.sigmoid(pred) > 0.5).float()
    truth = target.float()
    per_sample_dims = (1, 2, 3)
    overlap = (binary_pred * truth).sum(dim=per_sample_dims)
    total_area = binary_pred.sum(dim=per_sample_dims) + truth.sum(dim=per_sample_dims)
    per_sample_dice = (2 * overlap + smooth) / (total_area + smooth)
    return per_sample_dice.mean().item()

# --- UNet++ Model (unchanged, but see below for deeper model option) ---
class UNetPlusPlus(nn.Module):
    """Nested U-Net with a five-level encoder and densely connected skip paths.

    Node ``x{i}_{j}`` sits at depth ``i`` (each level halves the resolution via
    2x2 max pooling) and at position ``j`` along the skip pathway. Every node
    with ``j > 0`` consumes all earlier nodes of the same depth plus the
    upsampled node from one level below.

    The forward pass returns raw logits (no activation is applied) with
    ``out_channels`` channels at the input resolution. Because features are
    pooled four times and concatenated with their 2x-upsampled counterparts,
    the input height and width should be divisible by 16.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 convolution.
        deep_supervision: Stored on the instance; ``forward`` does not read it
            and only the ``x0_4`` head is returned.
    """
    def __init__(self, in_channels=6, out_channels=1, deep_supervision=True):
        super(UNetPlusPlus, self).__init__()
        self.deep_supervision = deep_supervision
        def conv_block(in_c, out_c):
            # Two 3x3 conv + BatchNorm + ReLU stages; padding=1 keeps H and W.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )
        self.pool = nn.MaxPool2d(2, 2)
        # Encoder column (j = 0): 64 -> 128 -> 256 -> 512 -> 1024 channels.
        self.conv0_0 = conv_block(in_channels, 64)
        self.conv1_0 = conv_block(64, 128)
        self.conv2_0 = conv_block(128, 256)
        self.conv3_0 = conv_block(256, 512)
        self.conv4_0 = conv_block(512, 1024)
        # Column j = 1: upsamplers halve the channel count and double H and W;
        # each conv block takes the same-depth node plus the upsampled one.
        self.up1_0 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up2_0 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up3_0 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up4_0 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv0_1 = conv_block(64+64, 64)
        self.conv1_1 = conv_block(128+128, 128)
        self.conv2_1 = conv_block(256+256, 256)
        self.conv3_1 = conv_block(512+512, 512)
        # Column j = 2: inputs are two same-depth nodes plus one upsampled node.
        self.up1_1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up2_1 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up3_1 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv0_2 = conv_block(64*3, 64)
        self.conv1_2 = conv_block(128*3, 128)
        self.conv2_2 = conv_block(256*3, 256)
        # Column j = 3.
        self.up1_2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up2_2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv0_3 = conv_block(64*4, 64)
        self.conv1_3 = conv_block(128*4, 128)
        # Column j = 4: the single full-resolution node feeding the output head.
        self.up1_3 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv0_4 = conv_block(64*5, 64)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)
    def forward(self, x):
        """Compute logits for ``x``; nodes are built as soon as their inputs exist."""
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up1_0(x1_0)], 1))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up2_0(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up1_1(x1_1)], 1))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up3_0(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up2_1(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up1_2(x1_2)], 1))
        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up4_0(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up3_1(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up2_2(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up1_3(x1_3)], 1))
        # Only the deepest nested head is used; no sigmoid is applied here.
        output = self.final(x0_4)
        return output

# --- For a deeper model with ResNet encoder, use segmentation_models_pytorch ---
# import segmentation_models_pytorch as smp
# model = smp.UnetPlusPlus(encoder_name="resnet34", in_channels=6, classes=1, encoder_weights=None).to(device)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 6 input channels = before + after images concatenated along the channel axis.
model = UNetPlusPlus(in_channels=6, out_channels=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
# Gradient scaler paired with the autocast() block inside train_epoch().
scaler = GradScaler()
# File that receives the best checkpoint during training.
model_path = "flood_segmentation_unetplusplus_train.pth"

def save_model(model, optimizer, path):
    """Write the model and optimizer state dicts to ``path`` as one checkpoint.

    The saved object is a dict with the keys ``model_state_dict`` and
    ``optimizer_state_dict``. A confirmation line is printed afterwards.
    """
    checkpoint = {}
    checkpoint['model_state_dict'] = model.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(checkpoint, path)
    print(f"Model saved to {path}")

# Validation Dice that a checkpoint has to exceed before it is saved.
# NOTE(review): starting at 0.500 means no checkpoint is written unless the
# validation Dice rises above 0.5 — confirm this is intended for a run that
# trains from scratch.
best_dice = 0.500
num_epochs = 40  # number of full passes over the training loader

def train_epoch(model, loader, optimizer, loss_fn, scaler):
    """Run one mixed-precision training pass over ``loader``.

    For every batch the forward pass and loss run inside ``autocast()``; the
    scaled loss is back-propagated and the optimizer is stepped through
    ``scaler``. Batches are moved to the module-level ``device``.

    Returns:
        ``(mean_loss, mean_dice)``, each averaged over the number of batches.
    """
    model.train()
    batch_losses, batch_dices = [], []
    for images, masks in loader:
        images = images.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        with autocast():
            logits = model(images)
            batch_loss = loss_fn(logits, masks)
        # Scale -> backward -> step -> update is the GradScaler protocol.
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        batch_losses.append(batch_loss.item())
        batch_dices.append(dice_score(logits, masks))
    num_batches = len(loader)
    return sum(batch_losses) / num_batches, sum(batch_dices) / num_batches

def validate(model, loader, loss_fn):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    The model is switched to eval mode and batches are moved to the
    module-level ``device``. No autocast is used here.

    Returns:
        ``(mean_loss, mean_dice)``, each averaged over the number of batches.
    """
    model.eval()
    loss_total = 0.0
    dice_total = 0.0
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device)
            logits = model(images)
            loss_total += loss_fn(logits, masks).item()
            dice_total += dice_score(logits, masks)
    num_batches = len(loader)
    return loss_total / num_batches, dice_total / num_batches

print("Starting training from scratch...")
# Per-epoch history, used for the curves plotted after the loop.
train_losses, val_losses = [], []
train_dices, val_dices = [], []
for epoch in range(num_epochs):
    train_loss, train_dice = train_epoch(model, train_loader, optimizer, loss_fn, scaler)
    val_loss, val_dice = validate(model, val_loader, loss_fn)
    train_losses.append(train_loss)
    train_dices.append(train_dice)
    val_losses.append(val_loss)
    val_dices.append(val_dice)
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f} | Train Dice: {train_dice:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}")
    # Checkpoint only when validation Dice beats the best value seen so far.
    if val_dice > best_dice:
        best_dice = val_dice
        save_model(model, optimizer, model_path)
        print(f"New best model saved with Dice: {best_dice:.4f}")
    print("-" * 50)
# Training curves: loss on the left panel, Dice on the right.
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.legend()
plt.title('Loss over epochs')
plt.subplot(1, 2, 2)
plt.plot(train_dices, label='Train Dice')
plt.plot(val_dices, label='Val Dice')
plt.legend()
plt.title('Dice Score over epochs')
plt.show()

# Final evaluation on the held-out test split.
# NOTE(review): this scores the in-memory weights from the last epoch; the
# best checkpoint written to model_path is not reloaded first.
test_loss, test_dice = validate(model, test_loader, loss_fn)
print(f"\nFinal Test Results:")
print(f"Test Loss: {test_loss:.4f} | Test Dice: {test_dice:.4f}")

def visualize_predictions(model, loader, num_samples=5):
    """Plot inputs, ground truth and prediction for the first sample of each batch.

    At most ``num_samples`` batches are visualized; from every batch only
    element 0 is shown. Channels 0-2 of the input are drawn as the "before"
    image and channels 3-5 as the "after" image. Predictions are the sigmoid
    of the model output thresholded at 0.5.
    """
    model.eval()
    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(loader):
            if batch_idx >= num_samples:
                break
            images, masks = images.to(device), masks.to(device)
            probs = torch.sigmoid(model(images))
            binary = (probs > 0.5).float()

            # Channel-first tensors -> channel-last arrays for imshow.
            before_rgb = images[0, :3].cpu().numpy().transpose(1, 2, 0)
            after_rgb = images[0, 3:6].cpu().numpy().transpose(1, 2, 0)
            truth = masks[0, 0].cpu().numpy()
            predicted = binary[0, 0].cpu().numpy()

            fg_pct = 100 * predicted.sum() / predicted.size
            print(f"Predicted mask {batch_idx}: {fg_pct:.2f}% foreground")

            # (title, data, extra imshow keyword arguments), left to right.
            panels = (
                ('Before RGB', before_rgb, {}),
                ('After RGB', after_rgb, {}),
                ('Ground Truth', truth, {'cmap': 'gray'}),
                ('Prediction', predicted, {'cmap': 'gray'}),
            )
            plt.figure(figsize=(18, 5))
            for position, (title, panel, style) in enumerate(panels, start=1):
                plt.subplot(1, 4, position)
                plt.imshow(panel, **style)
                plt.title(title)
            plt.show()

# Show the first sample of up to five test batches.
visualize_predictions(model, test_loader, num_samples=5)

# ---- Tiny Set Overfitting (disable augmentation for tiny set) ----
# Sanity check: train a fresh model on the first 10 samples only, without
# augmentation, and watch whether loss falls and Dice rises.
# NOTE(review): range(10) assumes the dataset holds at least 10 samples.
tiny_ds = torch.utils.data.Subset(FloodDataset(combined_dataset, mask_dataset, augment=False), range(10))
tiny_loader = DataLoader(tiny_ds, batch_size=2, shuffle=True)
# Separate model, optimizer, scaler and loss so the main run is left untouched.
tiny_model = UNetPlusPlus(in_channels=6, out_channels=1).to(device)
tiny_optimizer = torch.optim.Adam(tiny_model.parameters(), lr=1e-3)
tiny_scaler = GradScaler()
tiny_loss_fn = FocalTverskyLoss(alpha=0.7, beta=0.3, gamma=0.75)
for epoch in range(20):
    train_loss, train_dice = train_epoch(tiny_model, tiny_loader, tiny_optimizer, tiny_loss_fn, tiny_scaler)
    print(f"Tiny set Epoch {epoch+1}: Loss {train_loss:.4f}, Dice {train_dice:.4f}")
