In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from data_loader import load_dataset
from model_trainer import train

# NetCDF (.nc) files for each data split.
DATA_DIR = "../Data/Processed"
PATH_TRAIN = DATA_DIR + "/Train.nc"
PATH_TEST = DATA_DIR + "/Test.nc"
PATH_VAL = DATA_DIR + "/Val.nc"

# Checkpoint files handed to `train` (best and most recent weights).
MODELS_DIR = "../Models"
PATH_WEIGHTS_BEST = MODELS_DIR + "/PINN-Best.pth"
PATH_WEIGHTS_CURRENT = MODELS_DIR + "/PINN-Current.pth"

In [None]:
class DataDrivenModule(nn.Module):
    """Conv -> self-attention -> MLP -> transposed-conv network.

    The input is encoded with a Conv2d. Every spatial position of the encoded
    map is then treated as a token for multi-head self-attention, followed by
    LayerNorm, dropout and a 3-layer MLP. Finally a ConvTranspose2d with the
    same kernel size and padding maps the features back to the input's
    spatial size with ``transconv_out_channels`` channels. ``PICPModel``
    splits those channels into three equal parts (u', v', SSH).

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the Conv2d and by the MLP. Must
            equal ``embed_dim`` because the conv features are fed directly to
            the attention layer and the MLP output is reshaped back into the
            conv feature map.
        kernel_size: Kernel size shared by the Conv2d and ConvTranspose2d.
        padding: Padding shared by the Conv2d and ConvTranspose2d.
        embed_dim: Feature size of the attention and LayerNorm layers.
        num_heads: Number of attention heads (must divide ``embed_dim``).
        dropout_p: Dropout probability applied after the LayerNorm.
        mlp_hidden_dim: Hidden width of the MLP.
        transconv_out_channels: Channels of the returned tensor.

    Raises:
        ValueError: If ``out_channels != embed_dim``.
    """

    def __init__(self, 
                 in_channels=768, out_channels=768, kernel_size=5, padding=1, 
                 embed_dim=768, num_heads=4, 
                 dropout_p=0.1, 
                 mlp_hidden_dim=768,  
                 transconv_out_channels=15):
        super(DataDrivenModule, self).__init__()

        # Fail fast: a mismatch here would otherwise only surface in forward()
        # as an opaque shape error inside the attention / LayerNorm layers.
        if out_channels != embed_dim:
            raise ValueError(
                f"out_channels ({out_channels}) must equal embed_dim ({embed_dim}): "
                "conv features are used directly as attention tokens"
            )

        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding)
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
        self.dropout = nn.Dropout(p=dropout_p)
        self.norm = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, out_channels),
            nn.ReLU()
        )
        # Same kernel/padding (stride 1) as self.conv, so the transposed conv
        # restores the spatial size the conv removed.
        self.transconv = nn.ConvTranspose2d(in_channels=out_channels, out_channels=transconv_out_channels, kernel_size=kernel_size, padding=padding)

    def forward(self, x):
        """Map ``x`` of shape (B, in_channels, H, W) to (B, transconv_out_channels, H, W)."""
        features = self.conv(x)
        batch_size, channels, height, width = features.shape

        # (B, C, H', W') -> (H'*W', B, C): sequence-first layout expected by
        # nn.MultiheadAttention, whose batch_first defaults to False.
        tokens = features.flatten(2).permute(2, 0, 1)

        tokens, _ = self.attn(tokens, tokens, tokens)  # self-attention over all positions
        tokens = self.norm(tokens)
        tokens = self.dropout(tokens)
        tokens = self.mlp(tokens)

        # (H'*W', B, C) -> (B, C, H', W')
        features = tokens.permute(1, 2, 0).reshape(batch_size, channels, height, width)

        # Channels are later split by PICPModel into u', v', SSH.
        return self.transconv(features)

In [None]:
class PhysicsInformedModule(nn.Module):
    """Derive geostrophic velocity components from a sea-surface-height field.

    Applies geostrophic balance with forward finite differences:

        u_g = -(g / f) * dSSH/dy
        v_g =  (g / f) * dSSH/dx

    x runs along the last tensor axis and y along the second-to-last axis.
    The difference at the last index of each axis is zero, because the final
    slice is appended before differencing.

    Args:
        g: Gravitational acceleration.
        f: Signed Coriolis parameter (positive in the northern hemisphere).
        dx: Grid spacing along the last axis. The default 1.0 yields per-cell
            differences.
        dy: Grid spacing along the second-to-last axis. Pass a negative value
            if row index increases southward.
    """

    def __init__(self, g=9.81, f=1e-4, dx=1.0, dy=1.0):
        super(PhysicsInformedModule, self).__init__()
        self.g = g
        self.f = f
        self.dx = dx
        self.dy = dy

    def forward(self, ssh):
        """Return ``(u_g, v_g)``, each with the same shape as ``ssh`` (rank >= 2)."""
        # Forward differences of SSH; appending the last slice keeps the shape
        # and makes the boundary difference zero.
        dssh_dx = torch.diff(ssh, dim=-1, append=ssh[..., -1:]) / self.dx
        dssh_dy = torch.diff(ssh, dim=-2, append=ssh[..., -1:, :]) / self.dy

        scale = self.g / self.f
        u_g = -scale * dssh_dy
        v_g = scale * dssh_dx
        return u_g, v_g
    
class SumModule(nn.Module):
    """Combine geostrophic and data-driven velocity components by addition."""

    def forward(self, u_g, v_g, u_prime, v_prime):
        """Return ``(u_g + u_prime, v_g + v_prime)``."""
        return u_g + u_prime, v_g + v_prime

In [None]:
class PICPModel(nn.Module):
    """Hybrid model: data-driven branch plus geostrophic physics branch.

    The data-driven module predicts u', v' and SSH. Geostrophic velocities
    are derived from that SSH and added to u', v' to give the final currents.
    """

    def __init__(self, **data_module_kwargs):
        """Build the sub-modules.

        Args:
            **data_module_kwargs: Forwarded unchanged to ``DataDrivenModule``
                (e.g. ``kernel_size``, ``num_heads``), so the training
                configuration can tune the data-driven branch. With no
                arguments the data module's own defaults are used.
        """
        super(PICPModel, self).__init__()
        self.data_module = DataDrivenModule(**data_module_kwargs)
        self.physics_module = PhysicsInformedModule()
        self.sum_module = SumModule()

    def forward(self, x):
        """Predict (u, v, SSH).

        Returns:
            Tensor of shape (B, 3, C // 3, H, W) stacking ``(u, v, ssh)`` on
            dim 1, where C is the channel count produced by the data module.

        Raises:
            ValueError: If C is not divisible by 3. ``torch.chunk`` would
                otherwise yield unequal u'/v'/SSH parts that cannot be summed.
        """
        # Step 1: data-driven module produces u', v', SSH stacked on dim 1.
        output = self.data_module(x)
        if output.shape[1] % 3 != 0:
            raise ValueError(
                f"data module output has {output.shape[1]} channels; "
                "expected a multiple of 3 (u', v', SSH)"
            )
        u_prime, v_prime, ssh = torch.chunk(output, chunks=3, dim=1)

        # Step 2: geostrophic velocities from the predicted SSH.
        u_g, v_g = self.physics_module(ssh)

        # Step 3: total velocity = geostrophic part + data-driven correction.
        u, v = self.sum_module(u_g, v_g, u_prime, v_prime)

        return torch.stack((u, v, ssh), dim=1)

In [None]:
class WeightedLoss(nn.Module):
    """Element-wise MSE that up-weights points with strong target currents.

    Current speed is ``sqrt(u^2 + v^2)``, computed from target channels 0 and
    1. Points whose speed exceeds the 85th percentile (the top 15%, "U_max15%")
    get weight ``rho``; all other points get weight 1. The weights are
    broadcast across the channel dimension and the weighted errors are
    averaged.

    Note: the percentile is taken over the whole batch (all samples and grid
    points together), not per sample.
    NOTE(review): ``torch.quantile`` rejects very large inputs; confirm that
    batch size x grid size stays within its limit.
    """

    def __init__(self, rho=2.0):
        super(WeightedLoss, self).__init__()
        self.rho = rho
        # Unreduced MSE so the per-point weights can be applied before averaging.
        self.mse = nn.MSELoss(reduction='none')

    def forward(self, predictions, targets):
        u_true = targets[:, 0, :, :]
        v_true = targets[:, 1, :, :]
        speed = torch.sqrt(u_true**2 + v_true**2)

        threshold = torch.quantile(speed, 0.85)
        weight = torch.ones_like(speed).masked_fill(speed > threshold, self.rho)

        per_point = self.mse(predictions, targets) * weight.unsqueeze(1)
        return per_point.mean()

In [None]:
# Reproducibility: seed torch's RNG (weight init, dropout) before the model
# is built. NOTE(review): load_dataset may use other RNGs (e.g. for shuffling);
# confirm whether those need seeding too.
SEED = 42
torch.manual_seed(SEED)

# Optimisation hyper-parameters.
batch_size = 32
epochs = 10
learning_rate = 0.001

# Architecture / loss hyper-parameters.
kernel_size = 5
num_heads = 4
rho = 5.0  # weight for points above the 85th-percentile current speed in WeightedLoss

model = PICPModel(kernel_size=kernel_size, num_heads=num_heads)
loss_function = WeightedLoss(rho=rho)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train_ds = load_dataset(PATH_TRAIN, batch_size=batch_size)
val_ds = load_dataset(PATH_VAL, batch_size=batch_size)

In [None]:
# Run training via model_trainer.train; the two paths receive the checkpoint
# files (presumably best-validation and latest weights — see model_trainer).
train(
    model=model,
    loss_function=loss_function,
    optimizer=optimizer,
    train=train_ds,
    val=val_ds,
    epochs=epochs,
    path_weights_best=PATH_WEIGHTS_BEST,
    path_weights_last=PATH_WEIGHTS_CURRENT,
)