## Imports and Setup
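Library imports, random seeds, and the training configuration (`PHYSICS_WEIGHT`, `SPECTRAL_WEIGHT`, `VORTICITY_WEIGHT`). Checks for GPU availability and locates the dataset file. Note: the `NOISE_STD` and `VGG_WEIGHT` settings mentioned in earlier revisions are not defined in this version of the cell.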

In [None]:
import os
import json
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader, random_split
# FIXED: Updated AMP imports
from torch.amp import autocast, GradScaler 
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm  # FIXED: Progress bar
import matplotlib.pyplot as plt

# Seed the PyTorch and NumPy RNGs so runs are repeatable.
torch.manual_seed(42)
np.random.seed(42)

# --- Configuration ---
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
EPOCHS = 50
# Weights of the auxiliary loss terms that train() adds to the MSE loss.
PHYSICS_WEIGHT = 0.3  # Divergence penalty; only applied when use_physics=True
SPECTRAL_WEIGHT = 0.3  # FFT-magnitude matching on the velocity field
VORTICITY_WEIGHT = 0.2  # Vorticity matching weight
OUTPUT_DIR = "/kaggle/working"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üöÄ Device: {device}")

# Dataset Discovery
search_path = "/kaggle/input/**/*.pt"
# glob returns matches in arbitrary order; sorting makes the choice of
# pt_files[0] deterministic when more than one .pt file is present.
pt_files = sorted(glob.glob(search_path, recursive=True))
if not pt_files:
    raise FileNotFoundError("‚ùå Dataset not found!")
DATA_FILE = pt_files[0]
DATA_DIR = os.path.dirname(DATA_FILE)
# Optional JSON file next to the dataset; FluidLoader reads 'scaling_factor'
# from it when the .pt file itself carries no 'K' entry.
STATS_FILE = os.path.join(DATA_DIR, "normalization_stats.json")

## Dataset Class 
A `noise_std` parameter for Sim-to-Real gap augmentation is mentioned for this class, but `FluidLoader.__init__` below does not currently accept it.

In [None]:
class FluidLoader(Dataset):
    """Dataset of (input, target) pairs read from a single ``.pt`` file.

    The file must deserialize to a mapping with ``'inputs'`` and ``'targets'``
    entries and may carry a scalar ``'K'`` (scaling factor). Inputs whose last
    two dimensions differ from ``target_res`` are bilinearly resized to
    ``(target_res, target_res)`` when a sample is fetched.

    Args:
        pt_file: path of the ``.pt`` file to load.
        stats_file: path of a JSON file with a ``'scaling_factor'`` key, used
            only when the ``.pt`` file has no ``'K'`` entry.
        target_res: spatial size the inputs are resized to.
    """

    def __init__(self, pt_file, stats_file, target_res=256):
        print(f"‚è≥ Loading data map...")
        try:
            # Memory-map the file so the tensors are not all read up front.
            data = torch.load(pt_file, map_location='cpu', mmap=True)
        except Exception:
            # Best-effort fallback (e.g. a torch version or file format without
            # mmap support). `Exception` rather than a bare `except` so that
            # KeyboardInterrupt / SystemExit still propagate.
            data = torch.load(pt_file, map_location='cpu')

        self.inputs = data['inputs']
        self.targets = data['targets']
        self.target_res = target_res

        # Decide once whether resizing is needed. Both spatial dimensions are
        # compared, so non-square inputs are resized as well.
        spatial = tuple(self.inputs.shape[-2:])
        self.needs_upscale = spatial != (target_res, target_res)

        # Scaling factor K: prefer the value in the data file, then the stats
        # file, and fall back to 1.0.
        if 'K' in data:
            self.K = float(data['K'])
        elif os.path.exists(stats_file):
            with open(stats_file, 'r') as f:
                stats = json.load(f)
            self.K = float(stats['scaling_factor'])
        else:
            self.K = 1.0

        print(f"‚úÖ Loaded {len(self.inputs)} samples. Upscaling: {self.needs_upscale}")

    def __len__(self):
        """Return the number of samples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return ``(lr, hr)`` for sample ``idx``; ``lr`` is resized if needed."""
        lr = self.inputs[idx]
        hr = self.targets[idx]

        if self.needs_upscale:
            # F.interpolate wants a batch dimension; add it and strip it again.
            lr = F.interpolate(
                lr.unsqueeze(0),
                size=(self.target_res, self.target_res),
                mode='bilinear',
                align_corners=False
            ).squeeze(0)

        return lr, hr

## Model Architecture (ResUNet)

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an additive shortcut.

    The shortcut is the identity (an empty ``nn.Sequential``) when the block
    keeps both the channel count and the resolution; otherwise it is a strided
    1x1 convolution followed by batch norm so the two branches can be summed.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the block.
        stride: stride of the first convolution (and of the shortcut).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Only project the input when its shape would not match the main branch.
        projection = []
        if not (stride == 1 and in_channels == out_channels):
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)
        main += self.shortcut(x)
        return self.relu(main)

class ResUNet(nn.Module):
    """U-Net whose encoder, bottleneck and decoder stages are ResidualBlocks.

    Each encoder stage is followed by 2x2 max pooling; each decoder stage
    upsamples with a transposed convolution, concatenates the matching encoder
    output and applies a ResidualBlock. A 1x1 convolution maps to the output
    channels.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        features: channel widths of the encoder stages, shallowest first. Any
            sequence is accepted; the default is a tuple so that no mutable
            object is shared between calls.
    """

    def __init__(self, in_channels=6, out_channels=2, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.encoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.input_conv = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(features[0]),
            nn.ReLU(inplace=True)
        )
        input_feat = features[0]
        for feature in features:
            self.encoder.append(ResidualBlock(input_feat, feature))
            input_feat = feature

        # Bottleneck doubles the deepest channel width.
        self.bottleneck = ResidualBlock(features[-1], features[-1] * 2)

        # Decoder mirrors the encoder, deepest stage first.
        self.upconvs = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for feature in reversed(features):
            self.upconvs.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            # Input is the upsampled tensor concatenated with the skip tensor.
            self.decoder.append(ResidualBlock(feature * 2, feature))

        # Output
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return the ``out_channels`` map."""
        skip_connections = []
        out = self.input_conv(x)

        for layer in self.encoder:
            out = layer(out)
            skip_connections.append(out)
            out = self.pool(out)

        out = self.bottleneck(out)

        stages = zip(self.upconvs, self.decoder, reversed(skip_connections))
        for upconv, block, skip in stages:
            out = upconv(out)
            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip tensor; only the spatial dimensions need to match.
            if out.shape[2:] != skip.shape[2:]:
                out = F.interpolate(out, size=skip.shape[2:], mode='bilinear', align_corners=False)
            out = block(torch.cat((skip, out), dim=1))

        return self.final_conv(out)

## Loss Functions (Spectral and Vorticity in place of VGG, Physics-Informed)

In [None]:
# Spectral and Vorticity losses replace VGG for velocity fields
class SpectralLoss(nn.Module):
    """L1 distance between the 2-D FFT magnitudes of prediction and target.

    The FFT is taken over the last two dimensions with orthonormal scaling.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _to_fft_dtype(tensor):
        """Upcast fp16/bf16 tensors to float32; leave other dtypes untouched.

        Under autocast the network output is half precision. Half-precision
        FFTs are restricted (CUDA only, power-of-two sizes) and lose accuracy,
        so the transform is always run in at least float32.
        """
        if tensor.dtype in (torch.float16, torch.bfloat16):
            return tensor.float()
        return tensor

    def forward(self, pred, target):
        """Return the mean absolute difference of the two FFT magnitudes.

        Args:
            pred: predicted field; the FFT runs over its last two dimensions.
            target: reference field with the same shape as ``pred``.
        """
        pred_fft = torch.fft.rfft2(self._to_fft_dtype(pred), norm='ortho')
        targ_fft = torch.fft.rfft2(self._to_fft_dtype(target), norm='ortho')
        return F.l1_loss(torch.abs(pred_fft), torch.abs(targ_fft))

def compute_vorticity(field):
    """Return the discrete vorticity ``dv/dx - du/dy`` of a velocity field.

    Channel 0 of ``field`` is read as ``u`` and channel 1 as ``v``; x runs
    along the last dimension and y along the one before it. Derivatives are
    unit-spacing central differences on replicate-padded components, so the
    result has the same spatial size as the input.

    Args:
        field: tensor indexed as (batch, channel, y, x).

    Returns:
        Tensor indexed as (batch, y, x).
    """
    # Pad each component by one cell on every side so the stencil fits at the edges.
    u_pad, v_pad = (
        F.pad(field[:, channel], (1, 1, 1, 1), mode='replicate')
        for channel in (0, 1)
    )
    dv_dx = 0.5 * (v_pad[:, 1:-1, 2:] - v_pad[:, 1:-1, :-2])
    du_dy = 0.5 * (u_pad[:, 2:, 1:-1] - u_pad[:, :-2, 1:-1])
    return dv_dx - du_dy

class VorticityLoss(nn.Module):
    """Vorticity mismatch: MSE plus 0.5 x L1 distance of the FFT magnitudes."""

    def __init__(self):
        super().__init__()
        self.spectral = SpectralLoss()

    @staticmethod
    def _upcast(tensor):
        """Return fp16/bf16 tensors as float32; leave other dtypes untouched."""
        if tensor.dtype in (torch.float16, torch.bfloat16):
            return tensor.float()
        return tensor

    def forward(self, pred, target):
        """Return ``mse(w_pred, w_targ) + 0.5 * spectral_l1(w_pred, w_targ)``.

        Args:
            pred: predicted velocity field, channels (u, v).
            target: reference velocity field with the same shape.
        """
        # Under autocast the prediction is half precision; upcast first so the
        # finite differences and the FFT below run in float32.
        w_pred = compute_vorticity(self._upcast(pred))
        w_targ = compute_vorticity(self._upcast(target))
        vort_l2 = F.mse_loss(w_pred, w_targ)
        # Spectral match on the vorticity: SpectralLoss takes the FFT over the
        # last two dimensions, which is the same computation that used to be
        # duplicated inline here.
        spec_l1 = self.spectral(w_pred, w_targ)
        return vort_l2 + 0.5 * spec_l1

class DivergenceLoss(nn.Module):
    """Mean squared discrete divergence of a scaled velocity field."""

    def __init__(self):
        super().__init__()

    def forward(self, output, scaling_factor, mask=None):
        """Return ``mean((du/dx + dv/dy) ** 2)``.

        Args:
            output: tensor indexed as (batch, channel, y, x); channel 0 is
                read as ``u`` and channel 1 as ``v``.
            scaling_factor: multiplier applied to both components before
                differentiating.
            mask: optional tensor with a singleton channel dimension that is
                multiplied into the divergence; the mean is still taken over
                every cell.
        """
        # Scale, then replicate-pad by one cell so central differences are
        # defined at the borders.
        u_pad, v_pad = (
            F.pad(output[:, channel] * scaling_factor, (1, 1, 1, 1), mode='replicate')
            for channel in (0, 1)
        )
        du_dx = 0.5 * (u_pad[:, 1:-1, 2:] - u_pad[:, 1:-1, :-2])
        dv_dy = 0.5 * (v_pad[:, 2:, 1:-1] - v_pad[:, :-2, 1:-1])
        divergence = du_dx + dv_dy
        if mask is None:
            return torch.mean(divergence ** 2)
        masked = divergence * mask.squeeze(1)
        return torch.mean(masked ** 2)

In [None]:
def _metrics_divergence(field, K):
    """Central-difference divergence of ``K * field`` (channels u, v)."""
    u_pad = F.pad(field[:, 0] * K, (1, 1, 1, 1), mode='replicate')
    v_pad = F.pad(field[:, 1] * K, (1, 1, 1, 1), mode='replicate')
    du_dx = (u_pad[:, 1:-1, 2:] - u_pad[:, 1:-1, :-2]) * 0.5
    dv_dy = (v_pad[:, 2:, 1:-1] - v_pad[:, :-2, 1:-1]) * 0.5
    return du_dx + dv_dy


def _metrics_full_precision(tensor):
    """Return fp16/bf16 tensors as float32; leave other dtypes untouched."""
    if tensor.dtype in (torch.float16, torch.bfloat16):
        return tensor.float()
    return tensor


def compute_metrics(pred, target, K=1.0):
    """
    Compute physics-aware metrics for evaluation.

    Half-precision inputs (e.g. a prediction produced under autocast) are
    upcast to float32 first so that the finite differences and the FFT are
    not evaluated in reduced precision.

    Args:
        pred: (B, 2, H, W) predicted velocity field
        target: (B, 2, H, W) ground truth velocity field
        K: scaling factor applied before the divergence and vorticity metrics

    Returns:
        dict of Python floats with keys 'mse', 'div_l2_pred', 'div_l2_targ',
        'div_max_pred', 'vort_mse', 'energy_pred', 'energy_targ' and
        'energy_ratio'.
    """
    pred = _metrics_full_precision(pred)
    target = _metrics_full_precision(target)
    metrics = {}

    # 1. MSE
    metrics['mse'] = F.mse_loss(pred, target).item()

    # 2. Divergence error (RMS for both fields, max magnitude for the prediction)
    div_pred = _metrics_divergence(pred, K)
    div_targ = _metrics_divergence(target, K)
    metrics['div_l2_pred'] = torch.sqrt(torch.mean(div_pred**2)).item()
    metrics['div_l2_targ'] = torch.sqrt(torch.mean(div_targ**2)).item()
    metrics['div_max_pred'] = torch.max(torch.abs(div_pred)).item()

    # 3. Vorticity error
    w_pred = compute_vorticity(pred * K)
    w_targ = compute_vorticity(target * K)
    metrics['vort_mse'] = F.mse_loss(w_pred, w_targ).item()

    # 4. Mean spectral energy of each field and their ratio (simplified).
    #    The unscaled fields are used here, as in the MSE above.
    pred_fft = torch.fft.rfft2(pred, norm='ortho')
    targ_fft = torch.fft.rfft2(target, norm='ortho')
    pred_energy = torch.abs(pred_fft).pow(2).mean().item()
    targ_energy = torch.abs(targ_fft).pow(2).mean().item()
    metrics['energy_pred'] = pred_energy
    metrics['energy_targ'] = targ_energy
    # The small epsilon guards against division by zero for an all-zero target.
    metrics['energy_ratio'] = pred_energy / (targ_energy + 1e-8)

    return metrics

def visualize_field_comparison(lr, pred, target, idx=0, save_path=None):
    """
    Visualize LR input, SR prediction, and HR ground truth with vorticity.

    Draws a 2x3 figure: velocity magnitude in the top row and vorticity in the
    bottom row, with one colour range shared by the three vorticity panels.

    Args:
        lr: (B, 6, H, W) LR input (first 2 channels are velocity)
        pred: (B, 2, H, W) predicted HR velocity
        target: (B, 2, H, W) ground truth HR velocity
        idx: batch index to visualize
        save_path: optional path to save figure

    Returns:
        The matplotlib figure.
    """
    panels = [
        ('LR Input (Velocity Mag)', 'LR Vorticity', lr[idx, :2]),
        ('SR Prediction (Velocity Mag)', 'SR Vorticity', pred[idx]),
        ('HR Target (Velocity Mag)', 'HR Vorticity', target[idx]),
    ]

    magnitudes = []
    vorticities = []
    for _, _, velocity in panels:
        # detach() lets tensors that still track gradients be converted to
        # NumPy; float() upcasts half-precision predictions before plotting.
        velocity = velocity.detach().float().cpu()
        magnitudes.append(torch.sqrt(velocity[0]**2 + velocity[1]**2).numpy())
        vorticities.append(compute_vorticity(velocity.unsqueeze(0)).squeeze(0).numpy())

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Row 1: Velocity magnitude
    for ax, (mag_title, _, _), mag in zip(axes[0], panels, magnitudes):
        image = ax.imshow(mag, cmap='viridis')
        ax.set_title(mag_title)
        plt.colorbar(image, ax=ax)

    # Row 2: Vorticity, on a shared colour range so the panels are comparable
    vmin = min(vort.min() for vort in vorticities)
    vmax = max(vort.max() for vort in vorticities)
    for ax, (_, vort_title, _), vort in zip(axes[1], panels, vorticities):
        image = ax.imshow(vort, cmap='RdBu_r', vmin=vmin, vmax=vmax)
        ax.set_title(vort_title)
        plt.colorbar(image, ax=ax)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

    return fig

## Evaluation Metrics
Compute physics-aware metrics: divergence (RMS and maximum), vorticity MSE, and the predicted-to-target spectral energy ratio.

## Training Loop

In [None]:
def train(use_physics=False):
    """Train a ResUNet on the discovered dataset and return the best model.

    The data is split 80/20 into train/validation with a fixed seed. The
    training loss is MSE + SPECTRAL_WEIGHT * spectral loss + VORTICITY_WEIGHT
    * vorticity loss, plus PHYSICS_WEIGHT * divergence loss when
    ``use_physics`` is set. The learning rate is halved when the validation
    MSE plateaus and training stops early after 7 epochs without improvement.
    The best checkpoint and a loss/learning-rate plot are written to
    OUTPUT_DIR.

    Args:
        use_physics: add the divergence penalty to the training loss.

    Returns:
        ``(model, dataset)``: the model carrying the best-validation weights
        (or the last-epoch weights if no checkpoint was saved) and the full
        FluidLoader dataset.

    Raises:
        ValueError: if the dataset is too small to give non-empty train and
            validation splits.
    """
    # 1. Data
    dataset = FluidLoader(DATA_FILE, STATS_FILE)
    train_sz = int(0.8 * len(dataset))
    val_sz = len(dataset) - train_sz
    if train_sz == 0 or val_sz == 0:
        # An empty split would otherwise surface later as a sampler error or a
        # division by zero when averaging the losses.
        raise ValueError(
            f"Dataset with {len(dataset)} samples is too small for an 80/20 train/val split."
        )
    train_ds, val_ds = random_split(dataset, [train_sz, val_sz], generator=torch.Generator().manual_seed(42))

    # Num_workers=2 is usually safe on Kaggle
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    # 2. Setup
    model = ResUNet(in_channels=6, out_channels=2).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # `verbose` is deprecated on LR schedulers; the learning rate is already
    # printed every epoch below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    # Mixed precision is only enabled on CUDA; on CPU autocast and the scaler
    # become no-ops instead of being requested for a device that is not there.
    amp_device = device.type
    use_amp = amp_device == 'cuda'
    scaler = GradScaler(amp_device, enabled=use_amp)

    mse_fn = nn.MSELoss()
    spec_fn = SpectralLoss().to(device)
    vort_fn = VorticityLoss().to(device)
    div_fn = DivergenceLoss().to(device)
    K = dataset.K

    mode_name = "PINN" if use_physics else "Baseline"
    best_path = f"{OUTPUT_DIR}/{mode_name}_best.pth"
    print(f"\nüöÄ Starting {mode_name} Training...")

    hist = {'train': [], 'val': [], 'lr': []}
    best_val_loss = float('inf')
    saved_best = False
    patience = 7
    counter = 0

    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0.0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)

        for x, y in loop:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad(set_to_none=True) # set_to_none is slightly faster

            with autocast(amp_device, enabled=use_amp):
                # The network runs in mixed precision; the losses are then
                # evaluated on a float32 copy so that the finite differences
                # and FFTs inside them are not computed in half precision.
                pred = model(x).float()

                # Losses
                loss = mse_fn(pred, y)
                loss += SPECTRAL_WEIGHT * spec_fn(pred, y)
                loss += VORTICITY_WEIGHT * vort_fn(pred, y)

                if use_physics:
                    # Physics: Divergence-free with optional mask (None for now)
                    loss += PHYSICS_WEIGHT * div_fn(pred, K, mask=None)

            scaler.scale(loss).backward()

            # Gradients must be unscaled before clipping so that max_norm
            # applies to their true magnitude.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scaler.step(optimizer)
            scaler.update()

            # One .item() call per step: each call forces a device sync.
            loss_value = loss.item()
            train_loss += loss_value
            loop.set_postfix(loss=loss_value)

        # Validation (plain MSE drives the scheduler, checkpointing and early stop)
        model.eval()
        val_loss = 0.0
        val_metrics_accum = {'div_l2_pred': 0, 'vort_mse': 0, 'energy_ratio': 0}
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                with autocast(amp_device, enabled=use_amp):
                    pred = model(x)
                    val_loss += mse_fn(pred, y).item()
                # Physics metrics are computed in float32, outside autocast.
                batch_metrics = compute_metrics(pred.float(), y.float(), K)
                for k in val_metrics_accum:
                    val_metrics_accum[k] += batch_metrics[k]

        avg_train = train_loss / len(train_loader)
        avg_val = val_loss / len(val_loader)
        for k in val_metrics_accum:
            val_metrics_accum[k] /= len(val_loader)

        # Step Scheduler
        scheduler.step(avg_val)
        current_lr = optimizer.param_groups[0]['lr']

        hist['train'].append(avg_train)
        hist['val'].append(avg_val)
        hist['lr'].append(current_lr)

        print(f"Epoch {epoch+1} | Train: {avg_train:.6f} | Val: {avg_val:.6f} | LR: {current_lr:.2e}")
        print(f"  Metrics: Div={val_metrics_accum['div_l2_pred']:.4f} | Vort={val_metrics_accum['vort_mse']:.4f} | Energy={val_metrics_accum['energy_ratio']:.3f}")

        if avg_val < best_val_loss:
            best_val_loss = avg_val
            counter = 0
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_val,
            }
            torch.save(checkpoint, best_path)
            saved_best = True
            print("  --> New Best Model Saved!")
        else:
            counter += 1
            if counter >= patience:
                print("üõë Early Stopping triggered.")
                break

    # Training curves: loss on a log axis, learning rate alongside.
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(hist['train'], label='Train')
    plt.plot(hist['val'], label='Val')
    plt.title(f'{mode_name} Loss')
    plt.yscale('log') # Log scale is often better for loss
    plt.grid(True, alpha=0.3)
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(hist['lr'], color='orange')
    plt.title('Learning Rate')
    plt.grid(True, alpha=0.3)

    plt.savefig(f"{OUTPUT_DIR}/{mode_name}_history.png")
    plt.show()

    # Reload best model for return. If the validation loss never improved
    # (e.g. it was NaN throughout) no checkpoint exists from this run, so the
    # last-epoch weights are returned instead of failing on a missing or
    # stale file.
    if saved_best:
        best_checkpoint = torch.load(best_path, map_location=device)
        model.load_state_dict(best_checkpoint['model_state_dict'])
    else:
        print("No best checkpoint was saved in this run; returning last-epoch weights.")

    return model, dataset

## Execute
Run this cell to start.

In [None]:
# --- Run Baseline Training ---
# Currently disabled. Uncomment to train without the divergence (physics)
# loss; the result is bound to `model`.
#model, dataset = train(use_physics=False)

# --- Run Physics-Informed Training ---
# Active run: trains with the divergence loss enabled. train() returns the
# model and the full dataset; the model is bound to `model_pinn`, not `model`.
model_pinn, dataset = train(use_physics=True)

In [None]:
# Sanity check: Load a validation batch and compute metrics
def sanity_check(model, dataset, device, K):
    """
    Evaluate the model on one batch, print its metrics and plot sample 0.

    The batch is the first four samples of ``dataset`` in stored order (no
    shuffling), so it is a validation batch only if the caller passes a
    validation subset.

    Args:
        model: network to evaluate; it is switched to eval mode.
        dataset: dataset yielding (input, target) pairs.
        device: device the batch is moved to before inference.
        K: scaling factor forwarded to compute_metrics.

    Returns:
        dict of metrics as returned by compute_metrics.
    """
    model.eval()
    loader = DataLoader(dataset, batch_size=4, shuffle=False)
    x, y = (tensor.to(device) for tensor in next(iter(loader)))

    with torch.no_grad():
        pred = model(x)

    metrics = compute_metrics(pred, y, K)

    rule = "=" * 60
    print(rule)
    print("SANITY CHECK METRICS")
    print(rule)
    for name, value in metrics.items():
        print(f"{name:20s}: {value:.6f}")
    print(rule)

    # Figure for the first sample of the batch, also written to OUTPUT_DIR.
    visualize_field_comparison(x, pred, y, idx=0, save_path=f"{OUTPUT_DIR}/sanity_check.png")

    return metrics

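# Run sanity check after training
# Uncomment after training completes. Pass the model that was actually trained:
# `model` exists only if the baseline run above was executed; use `model_pinn`
# for the physics-informed run.
# sanity_metrics = sanity_check(model, dataset, device, dataset.K)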

## Sanity Check & Visualization
Quick validation of a single batch to check improvements.