In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from data_loader import load_dataset

# Locations of the processed dataset splits, relative to the notebook's working directory.
PATH_TRAIN = "../Data/Processed/Train.nc"
PATH_TEST = "../Data/Processed/Test.nc"
PATH_VAL = "../Data/Processed/Val.nc"

# Default file used by save_checkpoint / load_checkpoint for the best-model checkpoint.
PATH_WEIGHTS = "PINN-Best.pth"

In [None]:
class DataDrivenModule(nn.Module):
    """Conv -> self-attention -> MLP -> transposed-conv block.

    The input feature map is convolved, flattened into one token per spatial
    position, passed through multi-head self-attention, layer norm, dropout
    and a three-layer MLP, then folded back into a feature map and projected
    to ``out_channels`` channels. Both convolutions use ``kernel_size=3,
    padding=1`` so the spatial size is preserved.

    Args:
        channels: Number of input channels; also the attention embedding
            size and the MLP width. Must be divisible by ``num_heads``.
        out_channels: Number of channels produced by the final transposed
            convolution.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied after the layer norm.

    The defaults reproduce the original fixed configuration
    (768 -> 45 channels, 8 heads, p=0.1), so ``DataDrivenModule()`` builds
    the same layers under the same attribute names as before.
    """

    def __init__(self, channels=768, out_channels=45, num_heads=8, dropout=0.1):
        super().__init__()
        if channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )
        # Layers are created in the same order as before so that parameter
        # initialisation under a fixed seed is unchanged for the defaults.
        self.conv = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)
        self.attn = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads)
        self.dropout = nn.Dropout(p=dropout)
        self.norm = nn.LayerNorm(channels)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels),
            nn.ReLU(),
            nn.Linear(channels, channels),
            nn.ReLU(),
            nn.Linear(channels, channels)
        )
        self.transconv = nn.ConvTranspose2d(
            in_channels=channels, out_channels=out_channels, kernel_size=3, padding=1
        )

    def forward(self, x):
        """Map a 4-D feature map ``(batch, channels, h, w)`` to ``(batch, out_channels, h, w)``."""
        x = self.conv(x)
        b, c, h, w = x.shape

        # nn.MultiheadAttention defaults to sequence-first input, so each
        # spatial position becomes one token: (h*w, batch, channels).
        tokens = x.flatten(2).permute(2, 0, 1)
        tokens, _ = self.attn(tokens, tokens, tokens)  # self-attention
        tokens = self.dropout(self.norm(tokens))
        tokens = F.relu(self.mlp(tokens))

        # Fold tokens back into a feature map. reshape() is used instead of
        # view() because the permuted tensor is not contiguous.
        x = tokens.permute(1, 2, 0).reshape(b, c, h, w)
        return self.transconv(x)  # channels are later split into u', v', SSH by the caller

In [None]:
class PhysicsInformedModule(nn.Module):
    """Derive velocity components from a sea-surface-height (SSH) field.

    Args:
        g: Constant used as the numerator of the ``g / f`` scale factor.
        f: Constant used as the denominator of the ``g / f`` scale factor.
    """

    def __init__(self, g=9.81, f=1e-4):
        super().__init__()
        self.g = g
        self.f = f

    def forward(self, ssh):
        """Return ``(u_g, v_g)`` computed from forward differences of ``ssh``.

        ``ssh`` must be 4-D. Differences are taken along the last two
        dimensions; the last slice is appended before differencing so the
        outputs keep the input shape and the final row/column difference is
        zero. No grid spacing is applied: the differences are per index.
        """
        # Forward differences of SSH along the last axis ("x") and the
        # second-to-last axis ("y").
        dssh_dx = torch.diff(ssh, dim=-1, append=ssh[:, :, :, -1:])
        dssh_dy = torch.diff(ssh, dim=-2, append=ssh[:, :, -1:, :])

        # NOTE(review): the textbook geostrophic relation is
        # u = -(g/f) dSSH/dy, v = +(g/f) dSSH/dx; the signs here are the
        # opposite. This may be intentional if the y axis is stored
        # north-to-south -- confirm against the data layout.
        scale = self.g / self.f
        return scale * dssh_dy, -scale * dssh_dx
    
class SumModule(nn.Module):
    """Add the data-driven correction to the physics-derived velocity."""

    def forward(self, u_g, v_g, u_prime, v_prime):
        """Return ``(u_g + u_prime, v_g + v_prime)``."""
        return u_g + u_prime, v_g + v_prime

In [None]:
class PICPModel(nn.Module):
    """Full model: data-driven block, physics block, and their sum.

    The data-driven module's output is split into three equal groups of
    channels along dim 1, treated as ``u'``, ``v'`` and ``SSH``. The physics
    module turns ``SSH`` into ``(u_g, v_g)``, and the sum module adds the
    primed corrections to them.
    """

    def __init__(self):
        super().__init__()
        self.data_module = DataDrivenModule()
        self.physics_module = PhysicsInformedModule()
        self.sum_module = SumModule()

    def forward(self, x):
        """Return the tuple ``(u, v)`` for input feature map ``x``."""
        # Data-driven branch: three equal channel groups -> u', v', SSH.
        u_prime, v_prime, ssh = self.data_module(x).chunk(3, dim=1)

        # Physics branch: velocities derived from the SSH channels.
        u_g, v_g = self.physics_module(ssh)

        # Final field = physics-derived part + learned correction.
        return self.sum_module(u_g, v_g, u_prime, v_prime)

In [None]:
class WeightedLoss(nn.Module):
    """Mean squared error that up-weights the strongest-current points.

    Args:
        rho: Weight applied where the target speed exceeds its 85th
            percentile; all other points have weight 1.
    """

    def __init__(self, rho=2.0):
        super().__init__()
        self.rho = rho
        self.mse = nn.MSELoss(reduction='none')  # keep the per-element loss

    def forward(self, predictions, targets):
        """Return the mean of the weighted element-wise squared error.

        Channel 0 and channel 1 of ``targets`` are read as the two velocity
        components. The percentile threshold is computed over every element
        of the speed field in the batch.
        """
        # Target speed sqrt(u^2 + v^2) at every point.
        speed = torch.sqrt(targets[:, 0, :, :]**2 + targets[:, 1, :, :]**2)

        # Points above the 85th-percentile speed get weight rho, the rest 1.
        threshold = torch.quantile(speed, 0.85)
        weight = torch.where(
            speed > threshold,
            torch.full_like(speed, self.rho),
            torch.ones_like(speed),
        )

        # Broadcast the per-point weight over the channel dimension.
        per_element = self.mse(predictions, targets)
        return (weight.unsqueeze(1) * per_element).mean()

In [None]:
def save_checkpoint(model, optimizer, epoch, best_val_loss, filename=PATH_WEIGHTS):
    """Write model and optimizer state to ``filename`` with ``torch.save``.

    The saved dict has the keys ``"epoch"``, ``"model_state"``,
    ``"optimizer_state"`` and ``"best_val_loss"``. ``epoch`` is stored as
    given; only the printed message shows it 1-based.
    """
    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "best_val_loss": best_val_loss,
        },
        filename,
    )
    print(f"Model checkpoint saved at epoch {epoch+1} with val_loss: {best_val_loss:.4f}")

In [None]:
def load_checkpoint(model, optimizer, filepath=PATH_WEIGHTS):
    """Restore model and optimizer state from a checkpoint file.

    Reads the keys written by ``save_checkpoint`` (``"model_state"``,
    ``"optimizer_state"``, ``"best_val_loss"``). The ``*_state_dict`` /
    ``"loss"`` key names this function previously looked for are accepted
    as a fallback.

    Args:
        model: Module whose ``load_state_dict`` receives the saved weights.
        optimizer: Optimizer whose ``load_state_dict`` receives the saved state.
        filepath: Path of the checkpoint file.

    Returns:
        ``(start_epoch, loss)``: the epoch index to resume from (saved epoch
        + 1) and the validation loss stored in the checkpoint.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist (from ``torch.load``).
        KeyError: If the checkpoint lacks a required entry.
    """
    # Load onto the CPU so a checkpoint written on a GPU machine can still be
    # read without CUDA; load_state_dict copies values onto whatever device
    # the model parameters currently live on.
    checkpoint = torch.load(filepath, map_location="cpu")

    def _lookup(*keys):
        # Return the first entry present under any of the given key names.
        for key in keys:
            if key in checkpoint:
                return checkpoint[key]
        raise KeyError(f"Checkpoint {filepath!r} has none of the expected keys: {keys}")

    model.load_state_dict(_lookup("model_state", "model_state_dict"))
    optimizer.load_state_dict(_lookup("optimizer_state", "optimizer_state_dict"))
    start_epoch = checkpoint["epoch"] + 1
    loss = _lookup("best_val_loss", "loss")

    print(f"Resuming training from epoch {start_epoch}")
    return start_epoch, loss

In [None]:
def validate(model, val_loader, loss_function, device="cuda"):
    """Compute the mean per-batch loss of ``model`` over ``val_loader``.

    Args:
        model: Module to evaluate. It may return a single tensor or a
            tuple/list of tensors (e.g. ``(u, v)``).
        val_loader: Iterable yielding ``(inputs, targets)`` batches.
        loss_function: Callable ``(predictions, targets) -> scalar tensor``.
        device: Device the batches are moved to.

    Returns:
        The average of the per-batch loss values as a float.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions = model(inputs)

            # A tuple has no .view(); join the components along the channel
            # axis so the reshape below lays them out component-first.
            # NOTE(review): assumes each component holds 30*100 values per
            # sample so that u maps to channel 0 and v to channel 1 -- confirm.
            if isinstance(predictions, (tuple, list)):
                predictions = torch.cat(tuple(predictions), dim=1)
            predictions = predictions.reshape(-1, 2, 30, 100)

            running_loss += loss_function(predictions, targets).item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("val_loader produced no batches; cannot compute a validation loss")

    avg_val_loss = running_loss / num_batches
    print(f"Validation Loss: {avg_val_loss:.4f}")
    return avg_val_loss

In [None]:

def train(model, train, val, loss_function, optimizer, num_epochs, start_epoch=0,
          best_val_loss=float("inf")):
    """Run the training loop, checkpointing whenever validation loss improves.

    Args:
        model: Module to train. It may return a single tensor or a
            tuple/list of tensors (e.g. ``(u, v)``).
        train: Iterable yielding ``(inputs, targets)`` training batches.
        val: Iterable of validation batches, passed to ``validate``.
        loss_function: Callable ``(predictions, targets) -> scalar tensor``.
        optimizer: Optimizer stepping ``model``'s parameters.
        num_epochs: Epoch index to stop at (exclusive upper bound of the loop).
        start_epoch: Epoch index to start from; non-zero when resuming.
        best_val_loss: Best validation loss seen so far. Pass the value
            restored from a checkpoint when resuming so that a worse epoch
            does not overwrite the saved best model.

    Raises:
        ValueError: If ``train`` yields no batches in an epoch.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    for epoch in range(start_epoch, num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0

        for inputs, targets in train:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            predictions = model(inputs)
            # A tuple has no .view(); join the components along the channel
            # axis so the reshape below lays them out component-first.
            # NOTE(review): assumes each component holds 30*100 values per
            # sample so that u maps to channel 0 and v to channel 1 -- confirm.
            if isinstance(predictions, (tuple, list)):
                predictions = torch.cat(tuple(predictions), dim=1)
            predictions = predictions.reshape(-1, 2, 30, 100)

            loss = loss_function(predictions, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train produced no batches; cannot compute a training loss")

        avg_loss = total_loss / num_batches
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
        val_loss = validate(model, val, loss_function, device)

        # Keep only the best-performing weights on disk.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model, optimizer, epoch, best_val_loss)

    print("Training complete")

In [None]:
train = load_dataset(PATH_TRAIN)
val = load_dataset(PATH_VAL)

model = PICPModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_function = WeightedLoss(rho=5.0)

In [None]:
num_epochs = 5

try:
    start_epoch, prev_loss = load_checkpoint(model, optimizer)
except FileNotFoundError:
    print("No checkpoint found, starting from scratch.")
    start_epoch = 0

train(model, train, val, loss_function, optimizer, num_epochs)