## Imports and Setup
Updated with NOISE_STD, VGG_WEIGHT, and library imports; also checks for GPU availability. (Note: the configuration cell below does not currently define `NOISE_STD` or `VGG_WEIGHT`.)

In [None]:
import os
import json
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm.notebook import tqdm  # Use notebook version of tqdm
import matplotlib.pyplot as plt

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# --- Configuration ---
BATCH_SIZE = 4          # Keep small (4) to prevent OOM on Kaggle P100/T4
LEARNING_RATE_G = 1e-4  # Adam learning rate for the generator
LEARNING_RATE_D = 1e-4  # Adam learning rate for the discriminator
EPOCHS = 100            # INCREASED: Train longer for better convergence
PRETRAIN_EPOCHS = 20    # INCREASED: Longer warmup (MSE only) for stability

# Loss Weights
MSE_WEIGHT = 1.0        # Content Loss
PHYSICS_WEIGHT = 0.2    # INCREASED: Stronger physics constraint (was 0.1)
ADVERSARIAL_WEIGHT = 0.001 # GAN Loss (Keep small!)

OUTPUT_DIR = "/kaggle/working"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Device: {DEVICE}")

# --- Dataset Discovery ---
search_path = "/kaggle/input/**/*.pt"
# glob returns matches in arbitrary (filesystem-dependent) order; sort them so
# the same file is selected on every run when several .pt files are present.
pt_files = sorted(glob.glob(search_path, recursive=True))

if not pt_files:
    print("⚠️  No .pt file found automatically. Please check your path.")
    # Fallback: relative path used when nothing is found under /kaggle/input
    DATA_FILE = "dataset/native_64.pt"
else:
    DATA_FILE = pt_files[0]
    print(f"📂 Dataset found at: {DATA_FILE}")
    if len(pt_files) > 1:
        # Make the implicit choice visible instead of silently ignoring the rest
        print(f"   ({len(pt_files) - 1} other .pt file(s) ignored; first in sorted order is used)")

## Dataset Class 
Added `noise_std` parameter for Sim-to-Real gap augmentation. (Note: the `FluidLoader` class below does not currently accept `noise_std`.)

In [None]:
class FluidLoader(Dataset):
    """Dataset of (low-res, high-res) sample pairs stored in a single ``.pt`` file.

    The file is loaded with ``torch.load`` and must yield a dict-like object
    with ``'inputs'`` and ``'targets'`` entries. Optional ``'K_vel'``,
    ``'K_pres'`` and ``'K_smoke'`` entries are read as floats and default
    to 1.0 when absent.
    """

    def __init__(self, pt_file: str) -> None:
        """Load the dataset at ``pt_file``, preferring a memory-mapped load.

        Raises:
            FileNotFoundError: if ``pt_file`` does not exist.
        """
        print(f"⏳ Loading data from {pt_file}...")
        try:
            data = torch.load(pt_file, map_location='cpu', mmap=True)
        except FileNotFoundError:
            # A missing file is not an mmap problem; retrying would only fail
            # again after printing a misleading message.
            raise
        except Exception:
            # Best-effort fallback (e.g. mmap unsupported for this file or
            # torch version). Unlike a bare ``except:``, this lets
            # KeyboardInterrupt / SystemExit propagate.
            print("⚠️ mmap failed. Loading entire dataset to RAM.")
            data = torch.load(pt_file, map_location='cpu')

        self.inputs = data['inputs']
        self.targets = data['targets']

        # Normalization constants stored alongside the samples (default 1.0)
        self.K_vel = float(data.get('K_vel', 1.0))
        self.K_pres = float(data.get('K_pres', 1.0))
        self.K_smoke = float(data.get('K_smoke', 1.0))

        print(f"✅ Loaded {len(self.inputs)} samples.")
        print(f"   Normalization: v={self.K_vel}, p={self.K_pres}, s={self.K_smoke}")

    def __len__(self) -> int:
        """Return the number of samples (length of ``inputs``)."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the ``(lr, hr)`` pair at ``idx`` as float32 tensors.

        An input whose last dimension is 256 is bilinearly resized to 64x64;
        every other input is returned at its stored size.
        """
        lr = self.inputs[idx].float()
        hr = self.targets[idx].float()

        if lr.shape[-1] == 256:
            # F.interpolate needs a batch dimension: add it, resize, drop it.
            lr = F.interpolate(lr.unsqueeze(0), size=(64, 64), mode='bilinear', align_corners=False).squeeze(0)

        return lr, hr

## Model Architecture (SRResNet Generator + VGG-Style Discriminator)

In [None]:
# --- Generator (SRResNet) ---
class ResidualBlock(nn.Module):
    """Conv-BN-PReLU-Conv-BN branch added back onto its own input.

    Both convolutions are 3x3 with padding 1 and no bias, so the spatial size
    and channel count of the input are preserved.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``x`` plus the output of the two-convolution branch."""
        branch = self.conv1(x)
        branch = self.bn1(branch)
        branch = self.prelu(branch)
        branch = self.conv2(branch)
        branch = self.bn2(branch)
        # Identity skip connection
        return branch + x

class UpsampleBlock(nn.Module):
    """Sub-pixel upsampling: 3x3 conv, then PixelShuffle, then PReLU.

    The convolution multiplies the channel count by ``scale_factor ** 2`` so
    that PixelShuffle returns ``in_channels`` channels at ``scale_factor``
    times the spatial resolution.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()
        expanded_channels = in_channels * (scale_factor * scale_factor)
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Upsample ``x`` spatially by the configured scale factor."""
        expanded = self.conv(x)
        shuffled = self.pixel_shuffle(expanded)
        return self.prelu(shuffled)

class Generator(nn.Module):
    """SRResNet-style generator producing a 4x spatially upsampled output.

    Pipeline: 9x9 conv + PReLU -> ``num_res_blocks`` residual blocks ->
    3x3 conv + BN -> global skip from the first conv -> two 2x upsample
    blocks -> 9x9 reconstruction conv.
    """

    def __init__(self, in_channels=4, out_channels=4, hidden_channels=64, num_res_blocks=16):
        super().__init__()

        # Initial feature extraction
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=9, padding=4),
            nn.PReLU(),
        )

        # Residual trunk
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(hidden_channels) for _ in range(num_res_blocks)]
        )

        # Post-trunk conv whose output is summed with the initial features
        self.conv2 = nn.Sequential(
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels),
        )

        # Upsampling (two 2x blocks for 4x total)
        self.upsample = nn.Sequential(
            UpsampleBlock(hidden_channels, 2),
            UpsampleBlock(hidden_channels, 2),
        )

        # Final reconstruction
        self.final_conv = nn.Conv2d(hidden_channels, out_channels, kernel_size=9, padding=4)

    def forward(self, x):
        """Map a low-resolution batch to its 4x super-resolved counterpart."""
        shallow = self.conv1(x)
        deep = self.conv2(self.res_blocks(shallow))
        # Global skip connection around the residual trunk
        fused = deep + shallow
        return self.final_conv(self.upsample(fused))

# --- Discriminator (VGG-Style) ---
class Discriminator(nn.Module):
    """VGG-style discriminator that returns one raw logit per sample.

    Eight 3x3 conv stages (stride alternating 1, 2; channel width doubling up
    to ``hidden_channels * 8``) feed a global-average-pooled two-layer MLP.
    """

    def __init__(self, in_channels=4, hidden_channels=64):
        super().__init__()

        h = hidden_channels
        # One row per stage: (in_channels, out_channels, stride, use_batchnorm)
        stage_specs = [
            (in_channels, h, 1, False),
            (h, h, 2, True),
            (h, h * 2, 1, True),
            (h * 2, h * 2, 2, True),
            (h * 2, h * 4, 1, True),
            (h * 4, h * 4, 2, True),
            (h * 4, h * 8, 1, True),
            (h * 8, h * 8, 2, True),
        ]

        layers = []
        for c_in, c_out, stride, use_bn in stage_specs:
            layers.append(nn.Conv2d(c_in, c_out, 3, stride, 1, bias=False))
            if use_bn:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(h * 8, 1024),
            nn.LeakyReLU(0.2),
            # No Sigmoid here because we use BCEWithLogitsLoss
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        """Return logits of shape (batch, 1) for the input batch ``x``."""
        feature_maps = self.features(x)
        return self.classifier(feature_maps)

## Loss Functions (VGG, Physics-Informed)

In [None]:
class PressurePoissonLoss(nn.Module):
    """MSE between a discrete Laplacian of ``p`` and a convection/forcing term.

    ``forward`` compares ``laplacian(p)`` against
    ``-div(u*grad(u) + v*grad(u), u*grad(v) + v*grad(v)) + div(0, 0.3*s)``,
    with all derivatives taken by fixed 3x3 convolution stencils
    (zero-padded at the borders).
    """

    def __init__(self):
        super().__init__()
        laplace_stencil = [[0, 1, 0], [1, -4, 1], [0, 1, 0]]
        sobel_x = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]
        sobel_y = [[-1, -2, -1], [0, 0, 0], [1, 2, 1]]
        # Registered as buffers (shape 1x1x3x3) so they follow .to(device)
        self.register_buffer('laplacian', torch.tensor([[laplace_stencil]], dtype=torch.float32))
        self.register_buffer('k_x', torch.tensor([[sobel_x]], dtype=torch.float32) / 8.0)
        self.register_buffer('k_y', torch.tensor([[sobel_y]], dtype=torch.float32) / 8.0)

    def get_grads(self, f):
        """Return ``(df/dx, df/dy)`` of a single-channel field ``f``."""
        grad_x = F.conv2d(f, self.k_x, padding=1)
        grad_y = F.conv2d(f, self.k_y, padding=1)
        return grad_x, grad_y

    def get_laplacian(self, f):
        """Return the 5-point-stencil Laplacian of a single-channel field."""
        return F.conv2d(f, self.laplacian, padding=1)

    def get_divergence(self, fx, fy):
        """Return ``d(fx)/dx + d(fy)/dy`` for the vector field ``(fx, fy)``."""
        x_part = self.get_grads(fx)[0]
        y_part = self.get_grads(fy)[1]
        return x_part + y_part

    def forward(self, pred, K_vel, K_pres, K_smoke):
        """Compute the residual loss for ``pred`` (channels 0..3 = u, v, p, s).

        Each channel is multiplied by its ``K_*`` factor before use
        (presumably undoing dataset normalization - confirm against the
        data pipeline).
        """
        # Unpack and un-normalize the four channels
        u = pred[:, 0:1] * K_vel
        v = pred[:, 1:2] * K_vel
        p = pred[:, 2:3] * K_pres
        s = pred[:, 3:4] * K_smoke

        # Left-hand side: Laplacian of p
        lhs = self.get_laplacian(p)

        # Convective term and its divergence
        du_dx, du_dy = self.get_grads(u)
        dv_dx, dv_dy = self.get_grads(v)
        conv_u = u * du_dx + v * du_dy
        conv_v = u * dv_dx + v * dv_dy
        div_convection = self.get_divergence(conv_u, conv_v)

        # Forcing acts on the y component only, proportional to s
        force_y = s * 0.3
        div_force = self.get_divergence(torch.zeros_like(s), force_y)

        rhs = -div_convection + div_force
        return F.mse_loss(lhs, rhs)

## Training Loop

In [None]:
import gc
from torch.amp import autocast, GradScaler

def train_srgan():
    """Train the SRGAN generator/discriminator pair and return the generator.

    Training runs for ``EPOCHS`` epochs in two phases:

    * Phase 1 (first ``PRETRAIN_EPOCHS`` epochs): generator only, optimized
      with the weighted content (MSE) and physics losses.
    * Phase 2 (remaining epochs): the adversarial term is added to the
      generator loss and the discriminator is trained on real vs. generated
      samples.

    Mixed precision is used only when ``DEVICE`` is a CUDA device. Generator
    weights are saved to ``OUTPUT_DIR`` every 10 epochs and after the last
    epoch; discriminator weights are saved alongside during phase 2.

    Returns:
        The trained ``Generator`` instance (left in eval mode by the final
        validation pass).
    """
    # --- 1. MEMORY CLEANUP ---
    # Clear old variables and GPU cache to prevent OOM
    gc.collect()
    on_cuda = DEVICE.type == "cuda"
    if on_cuda:
        torch.cuda.empty_cache()

    # --- 2. Dataset Setup ---
    dataset = FluidLoader(DATA_FILE)
    train_sz = int(0.9 * len(dataset))
    val_sz = len(dataset) - train_sz
    train_ds, val_ds = random_split(dataset, [train_sz, val_sz], generator=torch.Generator().manual_seed(42))

    # Use global BATCH_SIZE; pinned memory only helps host-to-GPU copies
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=on_cuda)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=on_cuda)

    # --- 3. Models Setup ---
    netG = Generator(in_channels=4, out_channels=4).to(DEVICE)
    netD = Discriminator(in_channels=4).to(DEVICE)

    optimizerG = optim.Adam(netG.parameters(), lr=LEARNING_RATE_G)
    optimizerD = optim.Adam(netD.parameters(), lr=LEARNING_RATE_D)

    # --- 4. Losses & AMP Scaler ---
    mse_fn = nn.MSELoss()
    gan_fn = nn.BCEWithLogitsLoss()
    phys_fn = PressurePoissonLoss().to(DEVICE)
    # A single scaler serves both optimizers. With enabled=False (CPU) it is a
    # transparent pass-through, so the same code path works on either device.
    scaler = GradScaler(DEVICE.type, enabled=on_cuda)

    K_vel, K_pres, K_smoke = dataset.K_vel, dataset.K_pres, dataset.K_smoke

    # Checkpoints are written to OUTPUT_DIR, so that is the directory to ensure
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(f"\n🚀 Starting SRGAN Training (AMP {'Enabled' if on_cuda else 'Disabled'}, BS={BATCH_SIZE})...")
    print(f"   Phase 1: Pre-training ({PRETRAIN_EPOCHS} epochs)")
    print(f"   Phase 2: GAN Training ({EPOCHS - PRETRAIN_EPOCHS} epochs)")

    # --- Training Loop ---
    for epoch in range(EPOCHS):
        netG.train()
        netD.train()

        is_pretraining = epoch < PRETRAIN_EPOCHS
        phase_name = "PRE-TRAIN" if is_pretraining else "GAN-TRAIN"

        loss_g_accum = 0.0
        loss_d_accum = 0.0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [{phase_name}]", leave=False)

        for lr_img, hr_img in loop:
            lr_img = lr_img.to(DEVICE)
            hr_img = hr_img.to(DEVICE)

            # --- 1. Train Generator ---
            optimizerG.zero_grad()

            with autocast(DEVICE.type, enabled=on_cuda):  # Mixed Precision Context
                sr_img = netG(lr_img)

                # Content Loss
                loss_content = MSE_WEIGHT * mse_fn(sr_img, hr_img)

                # Physics Loss
                loss_phys = PHYSICS_WEIGHT * phys_fn(sr_img, K_vel, K_pres, K_smoke)

                # Adversarial Loss (Phase 2)
                loss_adv = 0.0
                if not is_pretraining:
                    pred_fake = netD(sr_img)
                    loss_adv = ADVERSARIAL_WEIGHT * gan_fn(pred_fake, torch.ones_like(pred_fake))

                loss_G = loss_content + loss_phys + loss_adv

            # Scaled Backward Pass
            scaler.scale(loss_G).backward()
            scaler.step(optimizerG)

            g_value = loss_G.item()
            loss_g_accum += g_value

            # --- 2. Train Discriminator (Phase 2 Only) ---
            d_value = 0.0
            if not is_pretraining:
                # Also discards the gradients the generator's adversarial
                # backward pass left on the discriminator's parameters.
                optimizerD.zero_grad()

                with autocast(DEVICE.type, enabled=on_cuda):
                    # Real Loss
                    pred_real = netD(hr_img)
                    loss_d_real = gan_fn(pred_real, torch.ones_like(pred_real))

                    # Fake Loss (Detach to freeze G)
                    pred_fake = netD(sr_img.detach())
                    loss_d_fake = gan_fn(pred_fake, torch.zeros_like(pred_fake))

                    loss_D = (loss_d_real + loss_d_fake) / 2

                scaler.scale(loss_D).backward()
                scaler.step(optimizerD)

                d_value = loss_D.item()
                loss_d_accum += d_value

            # With one scaler shared by several optimizers, update() must run
            # exactly once per iteration, after every optimizer has stepped.
            scaler.update()

            # Pass plain floats so the progress bar shows numbers, not tensors
            loop.set_postfix(G=g_value, D=d_value)

        # --- Validation ---
        avg_g = loss_g_accum / len(train_loader)
        avg_d = loss_d_accum / len(train_loader)

        netG.eval()
        val_mse = 0.0
        with torch.no_grad():
            for lr, hr in val_loader:
                lr, hr = lr.to(DEVICE), hr.to(DEVICE)
                with autocast(DEVICE.type, enabled=on_cuda):
                    val_mse += mse_fn(netG(lr), hr).item()
        # A very small dataset can leave the validation split empty; report
        # NaN instead of raising ZeroDivisionError.
        val_mse = val_mse / len(val_loader) if len(val_loader) > 0 else float("nan")

        print(f"Epoch {epoch+1} | {phase_name} | G_Loss: {avg_g:.5f} | D_Loss: {avg_d:.5f} | Val MSE: {val_mse:.6f}")

        # Save Checkpoints
        if (epoch+1) % 10 == 0 or epoch == EPOCHS-1:
            torch.save(netG.state_dict(), f"{OUTPUT_DIR}/SRGAN_Gen_epoch_{epoch+1}.pth")
            if not is_pretraining:
                torch.save(netD.state_dict(), f"{OUTPUT_DIR}/SRGAN_Disc_epoch_{epoch+1}.pth")

    print("\n✅ Training Complete!")
    return netG

## Execute
Run this cell to start.

In [None]:
# Run the training pipeline and keep the generator returned by train_srgan()
# in a notebook-level variable so later cells can use it.
generator_model = train_srgan()