## Imports and Setup


In [None]:
import os
import json
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Set seeds (torch + numpy) for reproducible initialization / splits.
torch.manual_seed(42)
np.random.seed(42)

# --- Fine-Tuning Configuration ---
BATCH_SIZE = 4
LEARNING_RATE_G = 1e-5  # CRITICAL: 10x smaller than pre-training
LEARNING_RATE_D = 1e-5  # CRITICAL: 10x smaller
EPOCHS = 20             # Shorter training for adaptation
PRETRAIN_EPOCHS = 0     # START GAN IMMEDIATELY (Discriminator active from Epoch 1)

# Loss weights (physics term weighted 100x the adversarial term)
MSE_WEIGHT = 1.0
PHYSICS_WEIGHT = 0.1
ADVERSARIAL_WEIGHT = 0.001

OUTPUT_DIR = "/kaggle/working"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Device: {DEVICE}")

# --- Dataset: Native 64x64 (Noisy/Sim-to-Real) ---
# Update this path if your dataset name is different
DATA_FILE = "/kaggle/input/navier-stokes-dataset/native_64.pt"
print(f"📂 Fine-Tuning Dataset: {DATA_FILE}")

# --- Checkpoints: Pre-trained Weights ---
# Assuming you uploaded your previous output as a dataset named 'checkpoints'
PRETRAINED_G = "/kaggle/input/checkpoints/SRGAN_Gen_epoch_100.pth"
PRETRAINED_D = "/kaggle/input/checkpoints/SRGAN_Disc_epoch_100.pth"

## Dataset Class
Added `noise_std` parameter for Sim-to-Real gap augmentation.

In [None]:
class FluidLoader(Dataset):
    """Map-style dataset of (low-res input, high-res target) pairs read from one ``.pt`` file.

    The file must hold a dict with ``'inputs'`` and ``'targets'`` (indexed along
    dim 0) and may hold scalar normalization constants ``'K_vel'``, ``'K_pres'``
    and ``'K_smoke'`` (each defaults to 1.0 when absent).

    Args:
        pt_file: Path to a file written by ``torch.save``.
        noise_std: Standard deviation of zero-mean Gaussian noise added to the
            low-res input on every ``__getitem__`` call (Sim-to-Real gap
            augmentation). ``0.0`` (default) disables it.

    Raises:
        ValueError: If ``noise_std`` is negative.
        FileNotFoundError: If ``pt_file`` does not exist.
    """

    def __init__(self, pt_file, noise_std=0.0):
        if noise_std < 0:
            raise ValueError(f"noise_std must be non-negative, got {noise_std}")
        self.noise_std = float(noise_std)

        print(f"⏳ Loading data from {pt_file}...")
        try:
            # mmap=True maps tensor storage from disk instead of reading it all into RAM.
            data = torch.load(pt_file, map_location='cpu', mmap=True)
        except FileNotFoundError:
            # A missing file will not load without mmap either; fail with the real error.
            raise
        except (TypeError, RuntimeError, OSError):
            # TypeError: older torch without the ``mmap`` kwarg.
            # RuntimeError/OSError: file format or filesystem that cannot be memory-mapped
            # (e.g. legacy non-zip checkpoints).
            print("⚠️ mmap failed. Loading entire dataset to RAM.")
            data = torch.load(pt_file, map_location='cpu')

        self.inputs = data['inputs']
        self.targets = data['targets']

        self.K_vel = float(data.get('K_vel', 1.0))
        self.K_pres = float(data.get('K_pres', 1.0))
        self.K_smoke = float(data.get('K_smoke', 1.0))

        print(f"✅ Loaded {len(self.inputs)} samples.")
        print(f"   Normalization: v={self.K_vel}, p={self.K_pres}, s={self.K_smoke}")

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return ``(lr, hr)`` as float32 tensors for sample ``idx``."""
        lr = self.inputs[idx].float()
        hr = self.targets[idx].float()

        # Inputs stored pre-upscaled to 256 wide are resized back to 64x64 so the
        # 4x Generator receives its native input resolution.
        if lr.shape[-1] == 256:
            lr = F.interpolate(lr.unsqueeze(0), size=(64, 64), mode='bilinear', align_corners=False).squeeze(0)

        # Optional Sim-to-Real augmentation: perturb only the input, never the target.
        if self.noise_std > 0:
            lr = lr + torch.randn_like(lr) * self.noise_std

        return lr, hr

## Model Architecture (SRGAN)

In [None]:
# --- Generator (SRResNet) ---
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.prelu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return out + residual

class UpsampleBlock(nn.Module):
    """Sub-pixel upsampler: 3x3 conv -> PixelShuffle -> PReLU.

    The conv expands channels by ``scale_factor**2`` and PixelShuffle folds them
    back into space, so spatial size grows by ``scale_factor`` while the channel
    count stays ``in_channels``.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()
        expanded = in_channels * scale_factor * scale_factor
        self.conv = nn.Conv2d(in_channels, expanded, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.prelu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        return self.prelu(x)

class Generator(nn.Module):
    """SRResNet generator that upsamples its input 4x spatially.

    Pipeline: 9x9 conv + PReLU (shallow features) -> ``num_res_blocks``
    ResidualBlocks -> 3x3 conv + BN, added back to the shallow features (long
    skip) -> two 2x UpsampleBlocks -> 9x9 output conv to ``out_channels``.
    Attribute names and Sequential indices are state_dict keys used by the
    pretrained checkpoints, so they must stay as-is.
    """

    def __init__(self, in_channels=4, out_channels=4, hidden_channels=64, num_res_blocks=16):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=9, padding=4),
            nn.PReLU(),
        )
        self.res_blocks = nn.Sequential(
            *(ResidualBlock(hidden_channels) for _ in range(num_res_blocks))
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels),
        )
        self.upsample = nn.Sequential(
            *(UpsampleBlock(hidden_channels, 2) for _ in range(2))
        )
        self.final_conv = nn.Conv2d(hidden_channels, out_channels, kernel_size=9, padding=4)

    def forward(self, x):
        shallow = self.conv1(x)
        deep = self.conv2(self.res_blocks(shallow))
        return self.final_conv(self.upsample(deep + shallow))

# --- Discriminator ---
class Discriminator(nn.Module):
    """SRGAN discriminator producing one unnormalized real/fake logit per image.

    Eight 3x3 conv stages with widths h, h, 2h, 2h, 4h, 4h, 8h, 8h
    (h = ``hidden_channels``) alternate stride 1 / stride 2. Every stage but the
    first has BatchNorm, and each ends in LeakyReLU(0.2). Global average pooling
    makes the MLP head independent of input resolution; output shape is (batch, 1).
    """

    def __init__(self, in_channels=4, hidden_channels=64):
        super().__init__()
        widths = [hidden_channels * mult for mult in (1, 1, 2, 2, 4, 4, 8, 8)]

        layers = []
        prev = in_channels
        for stage, width in enumerate(widths):
            stride = 2 if stage % 2 else 1
            layers.append(nn.Conv2d(prev, width, 3, stride, 1, bias=False))
            if stage > 0:
                layers.append(nn.BatchNorm2d(width))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            prev = width
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(widths[-1], 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

## Loss Functions (Pressure Poisson Equation)

In [None]:
class PressurePoissonLoss(nn.Module):
    """Physics penalty: mean-squared residual of a pressure-Poisson relation on a prediction.

    Channels of ``pred`` (dim 1) are read as (u, v, p, s) — presumably x/y
    velocity, pressure and smoke density; TODO confirm against the dataset. Each
    is rescaled by its normalization constant, then the loss is::

        MSE( lap(p),  -div((u . grad) u) + div(f) ),   f = (0, buoyancy * s)

    Derivatives use fixed 3x3 stencils (5-point Laplacian, Sobel/8 gradients,
    i.e. unit grid spacing) with zero padding, so border pixels see artificial
    gradients. x runs along the last (width) axis, y along dim -2 (rows).

    Args:
        buoyancy: Coefficient on ``s`` in the y-component of the body force
            (default 0.3, the previously hard-coded value).
    """

    def __init__(self, buoyancy=0.3):
        super().__init__()
        # Plain attribute (not a buffer) so the state_dict keys stay unchanged.
        self.buoyancy = float(buoyancy)
        self.register_buffer('laplacian', torch.tensor([[[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]], dtype=torch.float32))
        self.register_buffer('k_x', torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32) / 8.0)
        self.register_buffer('k_y', torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32) / 8.0)

    def get_grads(self, f):
        """Return ``(df/dx, df/dy)`` of a single-channel map ``f`` shaped (N, 1, H, W)."""
        return F.conv2d(f, self.k_x, padding=1), F.conv2d(f, self.k_y, padding=1)

    def get_laplacian(self, f):
        """Return the 5-point Laplacian of a single-channel map ``f``."""
        return F.conv2d(f, self.laplacian, padding=1)

    def get_divergence(self, fx, fy):
        """Return ``d(fx)/dx + d(fy)/dy`` for the vector field ``(fx, fy)``."""
        return self.get_grads(fx)[0] + self.get_grads(fy)[1]

    def forward(self, pred, K_vel, K_pres, K_smoke):
        """Compute the scalar Poisson-residual loss for ``pred`` (channels u, v, p, s)."""
        # The residual multiplies finite-difference derivatives together; in fp16
        # (e.g. inside the training loop's autocast region) that loses precision
        # and can overflow, so always evaluate it in float32.
        device_type = "cuda" if pred.is_cuda else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            pred = pred.float()
            u = pred[:, 0:1] * K_vel
            v = pred[:, 1:2] * K_vel
            p = pred[:, 2:3] * K_pres
            s = pred[:, 3:4] * K_smoke

            lhs = self.get_laplacian(p)
            du_dx, du_dy = self.get_grads(u)
            dv_dx, dv_dy = self.get_grads(v)
            div_conv = self.get_divergence(u * du_dx + v * du_dy, u * dv_dx + v * dv_dy)
            # f = (0, buoyancy * s): the zero x-component adds nothing to div(f).
            div_force = self.get_grads(s * self.buoyancy)[1]
            return F.mse_loss(lhs, -div_conv + div_force)

## Training Loop

In [None]:
import gc
from torch.amp import autocast, GradScaler

def _load_pretrained(net, path):
    """Load the checkpoint at ``path`` into ``net``; return False if the file is missing.

    Accepts either a raw state_dict or a dict wrapping it under ``'model_state_dict'``.
    """
    if not os.path.exists(path):
        return False
    ckpt = torch.load(path, map_location=DEVICE)
    if 'model_state_dict' in ckpt:
        ckpt = ckpt['model_state_dict']
    net.load_state_dict(ckpt)
    return True


def _validate(netG, loader, mse_fn, use_amp):
    """Return the mean per-batch MSE of ``netG`` over ``loader`` (0.0 if it is empty)."""
    netG.eval()
    total = 0.0
    with torch.no_grad():
        for lr, hr in loader:
            lr, hr = lr.to(DEVICE), hr.to(DEVICE)
            with autocast('cuda', enabled=use_amp):
                total += mse_fn(netG(lr), hr).item()
    return total / max(len(loader), 1)


def train_finetune():
    """Fine-tune the pre-trained SRGAN generator/discriminator on ``DATA_FILE``.

    Uses the module-level configuration (batch size, learning rates, EPOCHS,
    PRETRAIN_EPOCHS, loss weights, checkpoint paths, DEVICE). The generator loss
    is MSE + physics (pressure-Poisson) + adversarial. During the first
    ``PRETRAIN_EPOCHS`` epochs the adversarial term and discriminator updates
    are skipped. Checkpoints for both networks are written to
    ``OUTPUT_DIR/checkpoints_finetuned`` every 5 epochs and after the last epoch.

    Returns:
        The fine-tuned Generator (in eval mode after the final validation pass).
    """
    gc.collect(); torch.cuda.empty_cache()
    # Mixed precision and pinned host memory only help on CUDA; on CPU they only emit warnings.
    on_cuda = DEVICE.type == 'cuda'

    # 1. Dataset: fixed-seed 90/10 split so the validation subset is stable across runs.
    dataset = FluidLoader(DATA_FILE)
    train_sz = int(0.9 * len(dataset))
    val_sz = len(dataset) - train_sz
    train_ds, val_ds = random_split(dataset, [train_sz, val_sz], generator=torch.Generator().manual_seed(42))
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=on_cuda)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=on_cuda)

    # 2. Models + pre-trained weights
    netG = Generator(in_channels=4, out_channels=4).to(DEVICE)
    netD = Discriminator(in_channels=4).to(DEVICE)

    print("\n📥 Loading Pre-trained Weights...")
    if _load_pretrained(netG, PRETRAINED_G):
        print("   ✅ Generator Loaded")
    else:
        print(f"   ⚠️ Generator weights NOT FOUND at {PRETRAINED_G}!")
    if _load_pretrained(netD, PRETRAINED_D):
        print("   ✅ Discriminator Loaded")
    else:
        print("   ⚠️ Discriminator weights not found. Re-initializing D.")

    # 3. Optimizers (low LR for fine-tuning)
    optimizerG = optim.Adam(netG.parameters(), lr=LEARNING_RATE_G)
    optimizerD = optim.Adam(netD.parameters(), lr=LEARNING_RATE_D)

    # 4. Losses. One GradScaler serves both optimizers; update() runs once per iteration.
    mse_fn = nn.MSELoss()
    gan_fn = nn.BCEWithLogitsLoss()
    phys_fn = PressurePoissonLoss().to(DEVICE)
    scaler = GradScaler('cuda', enabled=on_cuda)
    K_vel, K_pres, K_smoke = dataset.K_vel, dataset.K_pres, dataset.K_smoke

    ckpt_dir = os.path.join(OUTPUT_DIR, "checkpoints_finetuned")

    print(f"\n🚀 Starting Fine-Tuning (Epochs={EPOCHS}, LR={LEARNING_RATE_G})...")

    for epoch in range(EPOCHS):
        # Warm-up epochs train G on content + physics only, leaving D untouched.
        use_gan = epoch >= PRETRAIN_EPOCHS
        netG.train(); netD.train()
        loss_g_accum = 0.0
        loss_d_accum = 0.0

        loop = tqdm(train_loader, desc=f"Fine-Tune {epoch+1}/{EPOCHS}", leave=False)

        for lr_img, hr_img in loop:
            lr_img, hr_img = lr_img.to(DEVICE), hr_img.to(DEVICE)

            # --- Update G ---
            optimizerG.zero_grad()
            with autocast('cuda', enabled=on_cuda):
                sr_img = netG(lr_img)
                loss_content = MSE_WEIGHT * mse_fn(sr_img, hr_img)
                loss_phys = PHYSICS_WEIGHT * phys_fn(sr_img, K_vel, K_pres, K_smoke)
                loss_G = loss_content + loss_phys
                if use_gan:
                    pred_fake = netD(sr_img)
                    loss_G = loss_G + ADVERSARIAL_WEIGHT * gan_fn(pred_fake, torch.ones_like(pred_fake))

            scaler.scale(loss_G).backward()
            scaler.step(optimizerG)
            g_val = loss_G.item()
            loss_g_accum += g_val
            postfix = {"G": g_val}

            # --- Update D (zero_grad also clears grads D picked up from G's backward) ---
            if use_gan:
                optimizerD.zero_grad()
                with autocast('cuda', enabled=on_cuda):
                    pred_real = netD(hr_img)
                    loss_real = gan_fn(pred_real, torch.ones_like(pred_real))
                    pred_fake = netD(sr_img.detach())
                    loss_fake = gan_fn(pred_fake, torch.zeros_like(pred_fake))
                    loss_D = (loss_real + loss_fake) / 2

                scaler.scale(loss_D).backward()
                scaler.step(optimizerD)
                d_val = loss_D.item()
                loss_d_accum += d_val
                postfix["D"] = d_val

            scaler.update()
            loop.set_postfix(**postfix)

        # Epoch summary + validation
        n_batches = max(len(train_loader), 1)
        avg_g = loss_g_accum / n_batches
        d_msg = f" | D_Loss: {loss_d_accum / n_batches:.5f}" if use_gan else ""
        val_mse = _validate(netG, val_loader, mse_fn, on_cuda)
        print(f"Epoch {epoch+1} | G_Loss: {avg_g:.5f}{d_msg} | Val MSE: {val_mse:.6f}")

        # Save checkpoints every 5 epochs and at the end.
        if (epoch + 1) % 5 == 0 or epoch == EPOCHS - 1:
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(netG.state_dict(), os.path.join(ckpt_dir, f"SRGAN_FT_Gen_epoch_{epoch+1}.pth"))
            torch.save(netD.state_dict(), os.path.join(ckpt_dir, f"SRGAN_FT_Disc_epoch_{epoch+1}.pth"))

    print("✅ Fine-Tuning Complete.")
    return netG

## Execute
Run this cell to start.

In [None]:
# Run the training pipeline; train_finetune() returns the fine-tuned Generator,
# kept in `model` for later inference/inspection in the notebook.
model = train_finetune()