In [None]:
# Hybrid Flood Model Training Notebook
# Includes: synthetic data generation, global + high-resolution patch inputs, training loop with MSE loss, checkpointing

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
from tqdm import tqdm

# -------------------------
# Model Components
# -------------------------
class GlobalEncoder(nn.Module):
    """Coarse-grid encoder producing a low-resolution context feature map.

    Three 3x3 convolutions with stride 2 and padding 1, each followed by ReLU,
    so every stage roughly halves the spatial size (64x64 -> 8x8).

    Args:
        in_channels: channels of the coarse global input.
        out_channels: channels of the produced context features.
    """

    def __init__(self, in_channels=4, out_channels=64):
        super().__init__()
        widths = [in_channels, 32, 64, out_channels]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend([nn.Conv2d(c_in, c_out, 3, stride=2, padding=1), nn.ReLU()])
        # Keep a single Sequential named `encoder` so state_dict keys stay `encoder.<idx>.*`.
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the encoder stack to ``x`` and return the context features."""
        features = self.encoder(x)
        return features

class LocalUNet(nn.Module):
    """Shallow U-Net that refines a high-resolution patch with global context.

    The global feature map is bilinearly resized to the patch's spatial size
    and concatenated with the patch along the channel axis. The result goes
    through one downsampling level and one upsampling level with an additive
    skip connection.

    Args:
        in_channels: channels of the high-resolution patch.
        global_channels: channels of the global context feature map.
    """

    def __init__(self, in_channels=4, global_channels=64):
        super().__init__()
        input_channels = in_channels + global_channels
        self.down1 = nn.Sequential(
            nn.Conv2d(input_channels, 64, 3, padding=1), nn.ReLU()
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU()
        )
        self.pool = nn.MaxPool2d(2)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.final_conv = nn.Conv2d(64, 1, 1)

    def forward(self, x_local, global_context):
        """Predict a single-channel map at the resolution of ``x_local``.

        Args:
            x_local: (N, in_channels, H, W) patch. H and W may be odd.
            global_context: (N, global_channels, h, w) features, resized to (H, W).

        Returns:
            (N, 1, H, W) tensor.
        """
        # align_corners=False is PyTorch's default; passed explicitly for clarity.
        global_upsampled = nn.functional.interpolate(
            global_context, size=x_local.shape[2:], mode='bilinear', align_corners=False
        )
        x = torch.cat([x_local, global_upsampled], dim=1)
        x1 = self.down1(x)
        x2 = self.pool(x1)
        x3 = self.down2(x2)
        x4 = self.up1(x3)
        # MaxPool2d floors odd sizes, so up1 can come back one row/column short
        # of x1. Zero-pad bottom/right so the skip connection lines up
        # (no-op when H and W are even).
        pad_h = x1.shape[-2] - x4.shape[-2]
        pad_w = x1.shape[-1] - x4.shape[-1]
        if pad_h or pad_w:
            x4 = nn.functional.pad(x4, (0, pad_w, 0, pad_h))
        x5 = torch.relu(x1 + x4)
        return self.final_conv(x5)

class HybridFloodModel(nn.Module):
    """Flood model combining a coarse global encoder with a local U-Net.

    Args:
        in_channels: channels of the coarse global input. Also used for the
            local patch unless ``local_in_channels`` is given.
        global_channels: width of the global context. It is passed both as the
            encoder's output width and as the U-Net's context width, so the two
            always agree.
        local_in_channels: channels of the high-resolution patch. Defaults to
            ``in_channels``.
    """

    def __init__(self, in_channels=4, global_channels=64, local_in_channels=None):
        super().__init__()
        if local_in_channels is None:
            local_in_channels = in_channels
        self.global_encoder = GlobalEncoder(in_channels=in_channels, out_channels=global_channels)
        self.local_unet = LocalUNet(in_channels=local_in_channels, global_channels=global_channels)

    def forward(self, global_input, local_patch):
        """Encode ``global_input`` and use it as context for ``local_patch``.

        Returns the local U-Net's output (one channel at the patch resolution).
        """
        global_context = self.global_encoder(global_input)
        return self.local_unet(local_patch, global_context)

# -------------------------
# Synthetic Dataset
# -------------------------
class SyntheticFloodDataset(Dataset):
    """Random (global_input, local_patch, target) triples for pipeline testing.

    All values are drawn uniformly from [0, 1) and carry no physical meaning.

    Args:
        n_samples: number of items reported by ``len``. Must be non-negative.
        in_channels: channels of both the global input and the local patch.
        global_size: side length of the square coarse global input.
        patch_size: side length of the square high-resolution patch and target.
        seed: if given, item ``idx`` is generated from a generator seeded with
            ``seed + idx``, so repeated access returns identical tensors. If
            None (default), each access draws fresh values from the global RNG.
    """

    def __init__(self, n_samples=200, in_channels=4, global_size=64, patch_size=256, seed=None):
        if n_samples < 0:
            raise ValueError(f"n_samples must be non-negative, got {n_samples}")
        self.n_samples = n_samples
        self.in_channels = in_channels
        self.global_size = global_size
        self.patch_size = patch_size
        self.seed = seed

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        gen = None if self.seed is None else torch.Generator().manual_seed(self.seed + int(idx))
        c, gs, ps = self.in_channels, self.global_size, self.patch_size
        global_input = torch.rand(c, gs, gs, generator=gen)   # coarse input
        local_patch = torch.rand(c, ps, ps, generator=gen)    # hi-res patch
        target = torch.rand(1, ps, ps, generator=gen)         # flood depth
        return global_input, local_patch, target

# -------------------------
# Training Setup
# -------------------------
def save_checkpoint(model, optimizer, epoch, loss, path="checkpoint.pt"):
    """Save model/optimizer state, the epoch index and a loss value to ``path``.

    For filesystem paths the checkpoint is first written to ``<path>.tmp`` and
    then moved over ``path`` with ``os.replace``. An interruption mid-write
    therefore leaves any previous checkpoint intact instead of truncated.
    File-like objects are passed straight to ``torch.save``.

    Args:
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        epoch: index of the epoch just completed (stored as ``'epoch'``).
        loss: loss value stored as ``'loss'``.
        path: destination path or writable file-like object.
    """
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }
    if not isinstance(path, (str, os.PathLike)):
        torch.save(state, path)
        return

    tmp_path = f"{os.fspath(path)}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temp file; the previous checkpoint at `path` is untouched.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def load_checkpoint(model, optimizer, path="checkpoint.pt", map_location=None):
    """Restore model and optimizer state from ``path`` if the file exists.

    Args:
        model: module restored in place via ``load_state_dict``.
        optimizer: optimizer restored in place via ``load_state_dict``.
        path: checkpoint file with the keys written by ``save_checkpoint``.
        map_location: forwarded to ``torch.load``. When None, tensors are mapped
            to the device of the model's first parameter (CPU if the model has
            no parameters). This lets a checkpoint saved on a GPU be resumed on
            a CPU-only machine.

    Returns:
        The epoch to resume from: the saved epoch + 1, or 0 if ``path`` does
        not exist.
    """
    if not os.path.exists(path):
        print("No checkpoint found, starting from scratch.")
        return 0

    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else torch.device("cpu")

    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Loaded checkpoint from epoch {checkpoint['epoch']}.")
    return start_epoch

# -------------------------
# Training Loop
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HybridFloodModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

dataset = SyntheticFloodDataset(n_samples=200)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

num_epochs = 50
# Resumes at (saved epoch + 1) if checkpoint.pt exists, otherwise at 0.
start_epoch = load_checkpoint(model, optimizer)

for epoch in range(start_epoch, num_epochs):
    model.train()
    epoch_loss = 0.0
    for global_input, local_patch, target in tqdm(dataloader):
        global_input = global_input.to(device)
        local_patch = local_patch.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(global_input, local_patch)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Mean loss over the epoch's batches; max(..., 1) avoids ZeroDivisionError
    # if the loader yields no batches.
    avg_loss = epoch_loss / max(len(dataloader), 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    # Store the epoch-average loss (matching the log line) rather than the last
    # batch's loss. This also works when no batch ran and `loss` was never bound.
    save_checkpoint(model, optimizer, epoch, avg_loss)
    torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pt")
