In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/post-images/Hanumannagar_Postflood_Orthomosaic.tif
/kaggle/input/pre-images/Hanumannagar_Preflood_Orthomosaic.tif


In [11]:
import torch 
import torch.nn as nn 
import torch.optim as optim 
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt 
from torch.utils.data import Dataset, DataLoader 


In [2]:
!pip install rasterio

Collecting rasterio
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m22.2/22.2 MB[0m [31m72.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading affine-2.4.0-py3-none-any.whl (15 kB)
Installing collected packages: affine, rasterio
Successfully installed affine-2.4.0 rasterio-1.4.3


# Explore

In [None]:
preflood_tiles = []

with rasterio.open("/kaggle/input/pre-images/Hanumannagar_Preflood_Orthomosaic.tif") as src:
    for y in range(0, src.height, tile_size):
        for x in range(0, src.width, tile_size):
            window = Window(x, y, tile_size, tile_size)
            img = src.read(window=window)  
            preflood_tiles.append(img)



In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Function to normalize each tile safely
def normalize_tile(tile):
    tile = tile.astype(float)
    min_val, max_val = np.nanmin(tile), np.nanmax(tile)
    if max_val > min_val:
        tile = (tile - min_val) / (max_val - min_val)
    else:
        tile = np.zeros_like(tile)
    return tile

# Parameters
start_idx = 200       # first tile index
end_idx = 400       # last tile index
tiles_to_show = tiles[start_idx:end_idx]
n_tiles = len(tiles_to_show)

cols = 10             # number of columns in grid
rows = n_tiles // cols + (n_tiles % cols > 0)

fig, axes = plt.subplots(rows, cols, figsize=(20, 2*rows))
axes = axes.flatten()  # flatten to 1D for easy indexing

for i, tile in enumerate(tiles_to_show):
    rgb = tile[:3, :, :]                      # take R,G,B
    rgb = np.transpose(rgb, (1, 2, 0))       # (C,H,W) -> (H,W,C)
    rgb = normalize_tile(rgb)
    
    axes[i].imshow(rgb)
    axes[i].axis('off')
    axes[i].set_title(f"Tile {start_idx+i}", fontsize=8)

# Hide any empty subplots
for j in range(i+1, len(axes)):
    axes[j].axis('off')

plt.tight_layout()
plt.show()


# Setup

In [19]:
torch.manual_seed(42)

<torch._C.Generator at 0x7c6b92db9a10>

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
print("Device: ", device)

Device:  cuda


# pesudo mask generation (try)

In [None]:
import os
import numpy as np
import rasterio
from rasterio.windows import Window, from_bounds
from rasterio.coords import disjoint_bounds
import cv2
from PIL import Image

# ---- NDWI calculation ----
def compute_ndwi(tile):
    green = tile[..., 1].astype('float32')
    nir = tile[..., 3].astype('float32')
    ndwi = (green - nir) / (green + nir + 1e-8)
    return ndwi

# ---- if RGB only ----
def compute_rgb_diff(pre, post):
    gray_pre = cv2.cvtColor(pre[..., :3], cv2.COLOR_RGB2GRAY).astype('float32')
    gray_post = cv2.cvtColor(post[..., :3], cv2.COLOR_RGB2GRAY).astype('float32')
    diff = cv2.absdiff(gray_post, gray_pre)
    diff = cv2.GaussianBlur(diff, (5,5), 0)
    diff_norm = (diff - diff.min()) / (diff.max() - diff.min() + 1e-8)
    return diff_norm

# ---- threshold function ----
def binarize(image, method='otsu', thresh=None):
    img8 = np.uint8(image * 255)
    if method == 'otsu':
        _, mask = cv2.threshold(img8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    else:
        _, mask = cv2.threshold(img8, thresh, 255, cv2.THRESH_BINARY)
    return mask // 255

# ---- align and crop to overlap ----
def align_and_crop(pre_src, post_src):
    if pre_src.crs != post_src.crs:
        raise ValueError(" CRS mismatch! Please reproject before alignment.")

    if disjoint_bounds(pre_src.bounds, post_src.bounds):
        raise ValueError(" No overlapping region between pre and post rasters.")

    # Compute overlap bounds
    left = max(pre_src.bounds.left, post_src.bounds.left)
    right = min(pre_src.bounds.right, post_src.bounds.right)
    bottom = max(pre_src.bounds.bottom, post_src.bounds.bottom)
    top = min(pre_src.bounds.top, post_src.bounds.top)

    # Create windows for each
    win_pre = from_bounds(left, bottom, right, top, pre_src.transform)
    win_post = from_bounds(left, bottom, right, top, post_src.transform)

    # Read overlapping data
    pre_data = pre_src.read(window=win_pre)
    post_data = post_src.read(window=win_post)

    # Convert to HWC
    pre_data = np.transpose(pre_data, (1,2,0))
    post_data = np.transpose(post_data, (1,2,0))

    print(f"‚úÖ Cropped to overlap area: {pre_data.shape}")
    return pre_data, post_data

# ---- main function ----
def generate_pseudo_masks(pre_tif, post_tif, out_dir, tile_size=512, min_area_px=50):
    os.makedirs(out_dir, exist_ok=True)
    with rasterio.open(pre_tif) as pre_src, rasterio.open(post_tif) as post_src:
        pre, post = align_and_crop(pre_src, post_src)

        h, w = pre.shape[:2]
        print("Aligned image shape:", h, w)
        tid = 0

        for y in range(0, h, tile_size):
            for x in range(0, w, tile_size):
                wsize = min(tile_size, w - x)
                hsize = min(tile_size, h - y)
                pre_tile = pre[y:y+hsize, x:x+wsize]
                post_tile = post[y:y+hsize, x:x+wsize]

                # ----- Compute change mask -----
                if pre_tile.shape[-1] >= 4:  # NIR available
                    ndwi_pre = compute_ndwi(pre_tile)
                    ndwi_post = compute_ndwi(post_tile)
                    water_pre = binarize((ndwi_pre - ndwi_pre.min())/(ndwi_pre.max()-ndwi_pre.min()))
                    water_post = binarize((ndwi_post - ndwi_post.min())/(ndwi_post.max()-ndwi_post.min()))
                    mask = (water_post == 1) & (water_pre == 0)
                else:
                    diff = compute_rgb_diff(pre_tile, post_tile)
                    mask = binarize(diff, 'otsu')

                mask = mask.astype('uint8')
                mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3,3),np.uint8))
                if np.sum(mask) < min_area_px:
                    continue

                # ---- Save tile ----
                pre_rgb = pre_tile[..., :3]
                post_rgb = post_tile[..., :3]
                pre_png = os.path.join(out_dir, f"tile_{tid}_pre.png")
                post_png = os.path.join(out_dir, f"tile_{tid}_post.png")
                mask_png = os.path.join(out_dir, f"tile_{tid}_mask.png")

                def save_png(img, path):
                    img = np.clip(img, 0, np.percentile(img, 99))
                    img = ((img - img.min()) / (img.max() - img.min() + 1e-8) * 255).astype(np.uint8)
                    Image.fromarray(img).save(path)

                save_png(pre_rgb, pre_png)
                save_png(post_rgb, post_png)
                Image.fromarray(mask*255).save(mask_png)

                tid += 1

        print(f"‚úÖ Saved {tid} aligned tiles with pseudo masks in {out_dir}")

In [None]:
generate_pseudo_masks(
    "/kaggle/input/pre-images/Hanumannagar_Preflood_Orthomosaic.tif",
    "/kaggle/input/post-images/Hanumannagar_Postflood_Orthomosaic.tif",
    "/kaggle/working/masks"
)

In [None]:
import rasterio
with rasterio.open("/kaggle/input/pre-images/Hanumannagar_Preflood_Orthomosaic.tif") as pre_src, rasterio.open("/kaggle/input/post-images/Hanumannagar_Postflood_Orthomosaic.tif") as post_src:
    print(pre_src.width, pre_src.height)
    print(pre_src.crs)
    print(post_src.width, post_src.height)
    print(post_src.crs)

# Align and Crop 

In [3]:
import os
import rasterio
from rasterio.windows import from_bounds, Window
import math

def align_and_crop_to_overlap(pre_path, post_path, out_pre, out_post, block_size=1024):
    # Ensure output directories exist
    os.makedirs(os.path.dirname(out_pre), exist_ok=True)
    os.makedirs(os.path.dirname(out_post), exist_ok=True)

    with rasterio.open(pre_path) as pre, rasterio.open(post_path) as post:
        # Step 1: CRS check
        if pre.crs != post.crs:
            raise ValueError("‚ùå CRS mismatch! Please reproject first.")

        # Step 2: Compute overlap
        left = max(pre.bounds.left, post.bounds.left)
        right = min(pre.bounds.right, post.bounds.right)
        bottom = max(pre.bounds.bottom, post.bounds.bottom)
        top = min(pre.bounds.top, post.bounds.top)
        print(f"‚úÖ Overlap bounds: {left:.2f}, {bottom:.2f}, {right:.2f}, {top:.2f}")

        # Step 3: Create read windows
        win_pre = from_bounds(left, bottom, right, top, pre.transform)
        win_post = from_bounds(left, bottom, right, top, post.transform)

        # Step 4: Output transform and shape
        transform = pre.window_transform(win_pre)
        height = int(win_pre.height)
        width = int(win_pre.width)
        print(f"‚úÖ Overlap size: {width} x {height}")

        # Step 5: Update output profile
        profile = pre.profile
        profile.update({
            'height': height,
            'width': width,
            'transform': transform,
            'BIGTIFF': 'YES'  # ‚úÖ Enable BigTIFF output
        })

        # Step 6: Write tile-by-tile
        with rasterio.open(out_pre, 'w', **profile) as dst_pre, \
             rasterio.open(out_post, 'w', **profile) as dst_post:

            num_tiles_x = math.ceil(width / block_size)
            num_tiles_y = math.ceil(height / block_size)
            print(f"üß© Processing in {num_tiles_x} x {num_tiles_y} tiles...")

            for ty in range(num_tiles_y):
                for tx in range(num_tiles_x):
                    x_off = int(tx * block_size)
                    y_off = int(ty * block_size)
                    w = min(block_size, width - x_off)
                    h = min(block_size, height - y_off)

                    window_pre = Window(win_pre.col_off + x_off, win_pre.row_off + y_off, w, h)
                    window_post = Window(win_post.col_off + x_off, win_post.row_off + y_off, w, h)

                    pre_block = pre.read(window=window_pre)
                    post_block = post.read(window=window_post)

                    dst_pre.write(pre_block, window=Window(x_off, y_off, w, h))
                    dst_post.write(post_block, window=Window(x_off, y_off, w, h))

            print(f"‚úÖ Finished writing aligned cropped TIFFs:\n - {out_pre}\n - {out_post}")


In [4]:
align_and_crop_to_overlap(
    "/kaggle/input/pre-images/Hanumannagar_Preflood_Orthomosaic.tif",
    "/kaggle/input/post-images/Hanumannagar_Postflood_Orthomosaic.tif",
    "/kaggle/working/aligned/pre_aligned.tif",
    "/kaggle/working/aligned/post_aligned.tif",
    block_size=1024  # tune for memory vs speed
)


‚úÖ Overlap bounds: 489425.93, 2935667.56, 491944.07, 2939394.83
‚úÖ Overlap size: 58395 x 86434
üß© Processing in 58 x 85 tiles...
‚úÖ Finished writing aligned cropped TIFFs:
 - /kaggle/working/aligned/pre_aligned.tif
 - /kaggle/working/aligned/post_aligned.tif


# Dataset

In [5]:
import torch
from torch.utils.data import Dataset
import rasterio
import numpy as np
import cv2

class FloodDataset(Dataset):
    def __init__(self, pre_path, post_path, tile_size=512, augment=None, use_ndwi=True):
        self.pre_path = pre_path
        self.post_path = post_path
        self.tile_size = tile_size
        self.augment = augment
        self.use_ndwi = use_ndwi

        with rasterio.open(pre_path) as src:
            self.width = src.width
            self.height = src.height
            self.geo_transform = src.transform
            self.crs = src.crs

        self.tiles_x = (self.width + tile_size - 1) // tile_size
        self.tiles_y = (self.height + tile_size - 1) // tile_size
        self.total_tiles = self.tiles_x * self.tiles_y

    def __len__(self):
        return self.total_tiles

    def _compute_ndwi(self, tile):
        green = tile[..., 1].astype('float32')
        nir = tile[..., 3].astype('float32')
        ndwi = (green - nir) / (green + nir + 1e-8)
        return ndwi

    def _compute_rgb_diff(self, pre, post):
        gray_pre = cv2.cvtColor(pre[..., :3], cv2.COLOR_RGB2GRAY).astype('float32')
        gray_post = cv2.cvtColor(post[..., :3], cv2.COLOR_RGB2GRAY).astype('float32')
        diff = cv2.absdiff(gray_post, gray_pre)
        diff = cv2.GaussianBlur(diff, (5, 5), 0)
        diff_norm = (diff - diff.min()) / (diff.max() - diff.min() + 1e-8)
        return diff_norm

    def _binarize(self, image):
        img8 = np.uint8(image * 255)
        _, mask = cv2.threshold(img8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        return mask // 255

    def __getitem__(self, idx):
        x_idx = idx % self.tiles_x
        y_idx = idx // self.tiles_x
        x = x_idx * self.tile_size
        y = y_idx * self.tile_size
        window = rasterio.windows.Window(x, y, self.tile_size, self.tile_size)
    
        with rasterio.open(self.pre_path) as pre_src, rasterio.open(self.post_path) as post_src:
            pre_tile = np.transpose(pre_src.read(window=window), (1,2,0))
            post_tile = np.transpose(post_src.read(window=window), (1,2,0))
    
        # --- pad to tile_size if needed ---
        def pad_to_tile(tile):
            h, w, c = tile.shape
            pad_h = max(0, self.tile_size - h)
            pad_w = max(0, self.tile_size - w)
            if pad_h > 0 or pad_w > 0:
                tile = np.pad(tile, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
            return tile[:self.tile_size, :self.tile_size, :]
    
        pre_tile = pad_to_tile(pre_tile)
        post_tile = pad_to_tile(post_tile)
    
        # Normalize
        pre_tile = pre_tile.astype("float32")
        post_tile = post_tile.astype("float32")
        pre_tile /= pre_tile.max() if pre_tile.max() > 1 else 1
        post_tile /= post_tile.max() if post_tile.max() > 1 else 1
    
        # Compute flood mask
        if self.use_ndwi and pre_tile.shape[-1] >= 4:
            ndwi_pre = self._compute_ndwi(pre_tile)
            ndwi_post = self._compute_ndwi(post_tile)
            mask = (self._binarize(ndwi_post) == 1) & (self._binarize(ndwi_pre) == 0)
        else:
            diff = self._compute_rgb_diff(pre_tile, post_tile)
            mask = self._binarize(diff)
    
        mask = mask.astype("float32")[None, :, :]
    
        pre_tile = torch.from_numpy(np.transpose(pre_tile[..., :3], (2, 0, 1)))
        post_tile = torch.from_numpy(np.transpose(post_tile[..., :3], (2, 0, 1)))
        mask = torch.from_numpy(mask)
    
        if self.augment:
            pre_tile, post_tile, mask = self.augment(pre_tile, post_tile, mask)
    
        center_x = x + self.tile_size // 2
        center_y = y + self.tile_size // 2
    
        return pre_tile, post_tile, mask, (center_x, center_y), (x_idx, y_idx)


In [6]:
dataset = FloodDataset(
    "/kaggle/working/aligned/pre_aligned.tif",
    "/kaggle/working/aligned/post_aligned.tif",
    512
)

In [7]:
len(dataset)

19435

## Train and Val Split

In [8]:
from torch.utils.data import Subset

def split_flood_dataset(dataset, val_fraction=0.2):
    
    tiles_x, tiles_y = dataset.tiles_x, dataset.tiles_y
    
    # Number of rows for validation
    val_rows = int(tiles_y * val_fraction)
    
    # Assign tiles: last val_rows for validation, remaining for training
    train_indices = []
    val_indices = []
    
    for y in range(tiles_y):
        for x in range(tiles_x):
            idx = y * tiles_x + x
            if y >= tiles_y - val_rows:
                val_indices.append(idx)
            else:
                train_indices.append(idx)
    
    return train_indices, val_indices

train_idx, val_idx = split_flood_dataset(dataset, val_fraction=0.2)

# split
train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(dataset, val_idx)


In [9]:
# size
print("size:\ntrain: ",len(train_dataset),", val:",len(val_dataset))

size:
train:  15640 , val: 3795


In [12]:
# dataloader 
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2, pin_memory=True)


# Siamese U-Net++ Model

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class SiameseUNetPP(nn.Module):
   
    def __init__(self, n_channels=3, n_classes=1, base_ch=64):
        super().__init__()
        self.base_ch = base_ch

        # ---------- Encoder (shared weights) ----------
        self.enc_conv0 = DoubleConv(n_channels, base_ch)
        self.enc_conv1 = DoubleConv(base_ch, base_ch * 2)
        self.enc_conv2 = DoubleConv(base_ch * 2, base_ch * 4)
        self.enc_conv3 = DoubleConv(base_ch * 4, base_ch * 8)
        self.enc_conv4 = DoubleConv(base_ch * 8, base_ch * 16)
        self.pool = nn.MaxPool2d(2)

        # ---------- Decoder (nested dense connections) ----------
        # channels[i] = base_ch * (2 ** i)
        ch = [base_ch, base_ch*2, base_ch*4, base_ch*8, base_ch*16]

        # For Siamese fusion, encoder outputs are doubled in channels
        # Define decoder convolutions following full U-Net++ dense pattern
        self.dconv3_0 = DoubleConv(ch[3]*2 + ch[4]*2, ch[3])               # X_3,0
        self.dconv2_0 = DoubleConv(ch[2]*2 + ch[3], ch[2])                 # X_2,0
        self.dconv2_1 = DoubleConv(ch[2] + ch[3], ch[2])                   # X_2,1
        self.dconv1_0 = DoubleConv(ch[1]*2 + ch[2], ch[1])                 # X_1,0
        self.dconv1_1 = DoubleConv(ch[1] + ch[2], ch[1])                   # X_1,1
        self.dconv1_2 = DoubleConv(ch[1] + ch[2], ch[1])                   # X_1,2
        self.dconv0_0 = DoubleConv(ch[0]*2 + ch[1], ch[0])                 # X_0,0
        self.dconv0_1 = DoubleConv(ch[0] + ch[1], ch[0])                   # X_0,1
        self.dconv0_2 = DoubleConv(ch[0] + ch[1], ch[0])                   # X_0,2
        self.dconv0_3 = DoubleConv(ch[0] + ch[1], ch[0])                   # X_0,3

        # ---------- Output layers (deep supervision) ----------
        self.out_convs = nn.ModuleList([
            nn.Conv2d(ch[0], n_classes, kernel_size=1),
            nn.Conv2d(ch[0], n_classes, kernel_size=1),
            nn.Conv2d(ch[0], n_classes, kernel_size=1),
            nn.Conv2d(ch[0], n_classes, kernel_size=1)
        ])

    # Utility upsampling
    def _upsample(self, x, ref):
        """Upsample x to match spatial size of ref using bilinear interpolation"""
        return F.interpolate(x, size=ref.shape[2:], mode='bilinear', align_corners=False)

    def forward(self, x1, x2, return_deep=False):
        # ---------- Encoder (Siamese shared weights) ----------
        e0_1 = self.enc_conv0(x1)
        e1_1 = self.enc_conv1(self.pool(e0_1))
        e2_1 = self.enc_conv2(self.pool(e1_1))
        e3_1 = self.enc_conv3(self.pool(e2_1))
        e4_1 = self.enc_conv4(self.pool(e3_1))

        e0_2 = self.enc_conv0(x2)
        e1_2 = self.enc_conv1(self.pool(e0_2))
        e2_2 = self.enc_conv2(self.pool(e1_2))
        e3_2 = self.enc_conv3(self.pool(e2_2))
        e4_2 = self.enc_conv4(self.pool(e3_2))

        # Siamese fusion (concat pre & post)
        e0 = torch.cat([e0_1, e0_2], dim=1)
        e1 = torch.cat([e1_1, e1_2], dim=1)
        e2 = torch.cat([e2_1, e2_2], dim=1)
        e3 = torch.cat([e3_1, e3_2], dim=1)
        e4 = torch.cat([e4_1, e4_2], dim=1)

        # ---------- Decoder (full dense path) ----------
        # Level 3
        X_3_0 = self.dconv3_0(torch.cat([e3, self._upsample(e4, e3)], dim=1))

        # Level 2
        X_2_0 = self.dconv2_0(torch.cat([e2, self._upsample(X_3_0, e2)], dim=1))
        X_2_1 = self.dconv2_1(torch.cat([X_2_0, self._upsample(X_3_0, X_2_0)], dim=1))

        # Level 1
        X_1_0 = self.dconv1_0(torch.cat([e1, self._upsample(X_2_0, e1)], dim=1))
        X_1_1 = self.dconv1_1(torch.cat([X_1_0, self._upsample(X_2_1, X_1_0)], dim=1))
        X_1_2 = self.dconv1_2(torch.cat([X_1_1, self._upsample(X_2_1, X_1_1)], dim=1))

        # Level 0
        X_0_0 = self.dconv0_0(torch.cat([e0, self._upsample(X_1_0, e0)], dim=1))
        X_0_1 = self.dconv0_1(torch.cat([X_0_0, self._upsample(X_1_1, X_0_0)], dim=1))
        X_0_2 = self.dconv0_2(torch.cat([X_0_1, self._upsample(X_1_2, X_0_1)], dim=1))
        X_0_3 = self.dconv0_3(torch.cat([X_0_2, self._upsample(X_1_2, X_0_2)], dim=1))

        # ---------- Deep supervision outputs ----------
        out0 = self.out_convs[0](X_0_0)
        out1 = self.out_convs[1](X_0_1)
        out2 = self.out_convs[2](X_0_2)
        out3 = self.out_convs[3](X_0_3)

        deep_outs = [out0, out1, out2, out3]

        if return_deep:
            return deep_outs
        else:
            # Average the deep supervision maps
            return torch.sigmoid(torch.mean(torch.stack(deep_outs), dim=0))


# # # ---------- Example usage ----------
# if __name__ == "__main__":
#     model = SiameseUNetPP(n_channels=3, n_classes=1, base_ch=64)
#     x1 = torch.randn(1, 3, 256, 256)
#     x2 = torch.randn(1, 3, 256, 256)
#     y = model(x1, x2)
#     print("Output shape:", y.shape)


In [None]:
# model visualization 
!pip install torchviz

In [None]:
from torchviz import make_dot
import torch

# Create two random sample inputs (batch, channel, height, width)
sample_input1 = torch.rand(1, 3, 512, 512)
sample_input2 = torch.rand(1, 3, 512, 512)

# Forward pass through the model with both inputs
output = model(sample_input1, sample_input2)

# Visualize model graph
dot = make_dot(output, params=dict(model.named_parameters()))
dot.format = 'png'
dot.render('model_architecture')

from IPython.display import Image
Image("model_architecture.png")


# Trainig 

In [14]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import numpy as np

# ---------------------------
# 1Ô∏è‚É£ Loss Function
# ---------------------------
class DiceBCELoss(nn.Module):
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, preds, targets):
        preds = preds.contiguous()
        targets = targets.contiguous()

        bce_loss = self.bce(preds, targets)

        preds = torch.sigmoid(preds)
        intersection = (preds * targets).sum(dim=(2,3))
        dice = (2. * intersection + self.smooth) / (
            preds.sum(dim=(2,3)) + targets.sum(dim=(2,3)) + self.smooth
        )
        dice_loss = 1 - dice.mean()

        return bce_loss + dice_loss


# ---------------------------
# 2Ô∏è‚É£ Dice Metric
# ---------------------------
def dice_score(preds, targets, threshold=0.5):
    preds = torch.sigmoid(preds)
    preds = (preds > threshold).float()
    intersection = (preds * targets).sum()
    union = preds.sum() + targets.sum()
    dice = (2.0 * intersection) / (union + 1e-8)
    return dice.item()


# ---------------------------
# 3Ô∏è‚É£ Train & Validate Epoch
# ---------------------------
def train_one_epoch(model, dataloader, optimizer, loss_fn, device, scaler=None):
    model.train()
    epoch_loss = 0

    loop = tqdm(dataloader, desc="Training", leave=False)
    for pre_img, post_img, mask, _, _ in loop:  # <-- added mask + coords
        pre_img, post_img, mask = pre_img.to(device), post_img.to(device), mask.to(device)

        optimizer.zero_grad()

        with torch.amp.autocast('cuda', enabled=scaler is not None):
            outputs = model(pre_img, post_img)
            loss = loss_fn(outputs, mask)

        if scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        epoch_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    return epoch_loss / len(dataloader)


def validate_one_epoch(model, dataloader, loss_fn, device):
    model.eval()
    epoch_loss = 0
    epoch_dice = 0

    with torch.no_grad():
        loop = tqdm(dataloader, desc="Validating", leave=False)
        for pre_img, post_img, mask, _, _ in loop:
            pre_img, post_img, mask = pre_img.to(device), post_img.to(device), mask.to(device)

            outputs = model(pre_img, post_img)
            loss = loss_fn(outputs, mask)
            epoch_loss += loss.item()
            epoch_dice += dice_score(outputs, mask)

    return epoch_loss / len(dataloader), epoch_dice / len(dataloader)


# ---------------------------
# 4Ô∏è‚É£ Early Stopping
# ---------------------------
class EarlyStopping:
    def __init__(self, patience=10, delta=0.0, path="best_model.pth"):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.best_loss = float("inf")
        self.counter = 0

    def step(self, val_loss, model):
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.path)
            print(f"‚úÖ Model improved ‚Äî saved to {self.path}")
        else:
            self.counter += 1
            print(f"‚ö†Ô∏è EarlyStopping counter: {self.counter}/{self.patience}")
        return self.counter >= self.patience


# ---------------------------
# 5Ô∏è‚É£ Training Loop
# ---------------------------
def train_model(model, train_loader, val_loader, device, epochs=50, lr=1e-4, ckpt_path="best_model.pth"):
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    loss_fn = DiceBCELoss()
    scaler = torch.amp.GradScaler('cuda', enabled=(device.type == "cuda"))
    early_stopper = EarlyStopping(patience=3, path=ckpt_path)

    for epoch in range(epochs):
        print(f"\nüåä Epoch [{epoch+1}/{epochs}]")
        train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device, scaler)
        val_loss, val_dice = validate_one_epoch(model, val_loader, loss_fn, device)

        print(f"üìä Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}")

        if early_stopper.step(val_loss, model):
            print("üõë Early stopping triggered.")
            break

    model.load_state_dict(torch.load(ckpt_path))
    return model


In [15]:
model = SiameseUNetPP(n_channels=3, n_classes=1, base_ch=32)
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
print(device)
train_model(model,train_loader, val_loader, device, 5)

cuda

üåä Epoch [1/5]


                                                                        

üìä Train Loss: 1.4839 | Val Loss: 1.5541 | Val Dice: 0.1495
‚úÖ Model improved ‚Äî saved to best_model.pth

üåä Epoch [2/5]


                                                                        

üìä Train Loss: 1.4422 | Val Loss: 1.5339 | Val Dice: 0.1495
‚úÖ Model improved ‚Äî saved to best_model.pth

üåä Epoch [3/5]


                                                                        

üìä Train Loss: 1.4375 | Val Loss: 1.5381 | Val Dice: 0.1495
‚ö†Ô∏è EarlyStopping counter: 1/3

üåä Epoch [4/5]


                                                                         

üìä Train Loss: 1.4347 | Val Loss: 1.5407 | Val Dice: 0.1524
‚ö†Ô∏è EarlyStopping counter: 2/3

üåä Epoch [5/5]


                                                                         

üìä Train Loss: 1.4324 | Val Loss: 1.5389 | Val Dice: 0.1498
‚ö†Ô∏è EarlyStopping counter: 3/3
üõë Early stopping triggered.


SiameseUNetPP(
  (enc_conv0): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (enc_conv1): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (enc_conv2): DoubleCon

# Inference

In [21]:
import os
import torch
import numpy as np
import rasterio
from rasterio.windows import Window
from skimage.measure import label, regionprops
import cv2
import pandas as pd
from tqdm import tqdm


# ---------------------------------------------------------
# LOAD MODEL
# ---------------------------------------------------------
def load_model(model_class, checkpoint, device):
    model = model_class(n_channels=3, n_classes=1, base_ch=32)
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.to(device)
    model.eval()
    return model


# ---------------------------------------------------------
# MODEL PREDICT (single tile)
# ---------------------------------------------------------
def model_predict(model, pre_tile, post_tile, device):

    pre_tile = pre_tile[:, :, :3]
    post_tile = post_tile[:, :, :3]

    pre = torch.tensor(pre_tile / 255., dtype=torch.float32).permute(2,0,1).unsqueeze(0).to(device)
    post = torch.tensor(post_tile / 255., dtype=torch.float32).permute(2,0,1).unsqueeze(0).to(device)

    with torch.no_grad():
        out = model(pre, post)
        pred = torch.sigmoid(out).squeeze().cpu().numpy()

    del pre, post, out
    torch.cuda.empty_cache()
    return pred


# ---------------------------------------------------------
# EXTRACT REGION CENTROIDS + AREA
# ---------------------------------------------------------
def extract_regions(mask, full_transform, pixel_size):

    labelled = label(mask)
    props = regionprops(labelled)

    results = []

    for region in props:
        if region.area < 20:
            continue

        cy, cx = region.centroid
        lon, lat = rasterio.transform.xy(full_transform, cy, cx)
        area_m2 = region.area * (pixel_size ** 2)

        results.append({
            "centroid_lon": lon,
            "centroid_lat": lat,
            "area_m2": area_m2,
        })

    return results


# ---------------------------------------------------------
# SAFE IMAGE WRITER (JPG ONLY ‚Äî PREVENTS libpng ERRORS)
# ---------------------------------------------------------
def safe_write_image(path, image):

    # force JPG only
    path = path.replace(".png", ".jpg")

    # ensure valid uint8
    if image.dtype != np.uint8:
        image = np.nan_to_num(image)
        image = np.clip(image, 0, 255).astype(np.uint8)

    # skip too small tiles
    if image.shape[0] < 4 or image.shape[1] < 4:
        return False

    try:
        cv2.imwrite(path, image, [cv2.IMWRITE_JPEG_QUALITY, 92])
        return True
    except:
        return False



# ---------------------------------------------------------
# MAIN INFERENCE
# ---------------------------------------------------------
def run_inference(
    model,
    pre_path,
    post_path,
    output_dir,
    tile_size=512,
    threshold=0.5,
):

    os.makedirs(output_dir, exist_ok=True)
    cutout_dir = os.path.join(output_dir, "cutouts")
    os.makedirs(cutout_dir, exist_ok=True)

    device = next(model.parameters()).device
    pre = rasterio.open(pre_path)
    post = rasterio.open(post_path)

    assert pre.shape == post.shape, "ERROR: Pre/Post must be aligned!"

    H, W = pre.height, pre.width
    base_transform = pre.transform
    pixel_size = pre.res[0]

    rows = (H + tile_size - 1) // tile_size
    cols = (W + tile_size - 1) // tile_size

    results = []
    tile_id_counter = 0

    for ty in tqdm(range(rows), desc="Rows"):
        for tx in range(cols):

            x_off = tx * tile_size
            y_off = ty * tile_size

            w = min(tile_size, W - x_off)
            h = min(tile_size, H - y_off)

            window = Window(x_off, y_off, w, h)

            # Read only required area
            pre_arr = pre.read(window=window)
            post_arr = post.read(window=window)

            pre_tile = np.moveaxis(pre_arr, 0, 2).astype(np.uint8)
            post_tile = np.moveaxis(post_arr, 0, 2).astype(np.uint8)

            del pre_arr, post_arr

            # Predict mask
            pred = model_predict(model, pre_tile, post_tile, device)
            mask = (pred > threshold).astype(np.uint8)

            tile_transform = base_transform * rasterio.Affine.translation(x_off, y_off)
            regions = extract_regions(mask, tile_transform, pixel_size)

            # Skip saving tile images if NO regions detected
            if len(regions) == 0:
                continue

            tile_id = f"tile_{tile_id_counter:05d}"
            tile_id_counter += 1

            pre_file = os.path.join(cutout_dir, f"{tile_id}_pre.jpg")
            post_file = os.path.join(cutout_dir, f"{tile_id}_post.jpg")

            # Make BGR for OpenCV
            pre_img = cv2.cvtColor(pre_tile[:,:,:3], cv2.COLOR_RGB2BGR)
            post_img = cv2.cvtColor(post_tile[:,:,:3], cv2.COLOR_RGB2BGR)

            safe_write_image(pre_file, pre_img)
            safe_write_image(post_file, post_img)

            for region in regions:
                results.append({
                    "tile_id": tile_id,
                    "center_longitude": region["centroid_lon"],
                    "center_latitude": region["centroid_lat"],
                    "area_m2": region["area_m2"],
                    "area_lost_m2": region["area_m2"],
                    "pre_flood_land_image": pre_file,
                    "post_flood_land_image": post_file,
                })

            del pre_tile, post_tile, mask, pred, regions
            torch.cuda.empty_cache()

    df = pd.DataFrame(results)
    df.to_csv(os.path.join(output_dir, "affected.csv"), index=False)

    print("\n‚úÖ Inference Complete!")
    print("Output directory:", output_dir)


In [None]:
# from inference_flood_change import run_inference, load_model
# from siamese_unetpp import SiameseUNetPP  # your model file

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = load_model(
#     SiameseUNetPP,
#     "best_siamese_unetpp.pth",
#     device
# )

model = model.eval()

run_inference(
    model=model,
    pre_path="/kaggle/working/aligned/pre_aligned.tif",
    post_path="/kaggle/working/aligned/post_aligned.tif",
    output_dir="submissions/Dharay/",
    tile_size=512,
    threshold=0.45
)
