In [None]:
import numpy as np
import os
import torch
import scipy.io
import random
from src.fusion_model import Fusion_NSPDE
from src.utilities import LpLoss
from evaluation.loss import ACFLoss 

# Setup Device & Seed
# Seed every RNG this notebook imports so that "Restart Kernel -> Run All"
# starts from the same random state each time. torch.manual_seed covers the
# CPU generator and all CUDA devices.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on {DEVICE}")

# Load Inference Data

In [None]:
print("Loading data...")

# Read the 'sol' array of data/initial.mat as a float32 tensor on DEVICE.
data_init = scipy.io.loadmat(r"data/initial.mat")
u0_ood = torch.tensor(data_init['sol'], dtype=torch.float32, device=DEVICE)

# Read the 'W' array of data/noise.mat as a float32 tensor on DEVICE.
data_noise = scipy.io.loadmat(r"data/noise.mat")
xi_ood = torch.tensor(data_noise['W'], dtype=torch.float32, device=DEVICE)

# Initialize and Load Model


In [None]:
# Location of the trained weights that both model instances below load.
CHECKPOINT_PATH = r'src/best_fusion_model_trained.pth'

# Architecture hyper-parameters passed to Fusion_NSPDE; they must match the
# values the checkpoint was trained with for load_state_dict to succeed.
HIDDEN_CHANNELS = 32
MODES_X = MODES_Y = 16

# Internal Evaluation (Reproducibility Test)

In [None]:
print("===(INTERNAL EVAL) ===")
data_training = scipy.io.loadmat(r"data/public_data.mat")
# Convert straight to float32 (same pattern as the inference-data loading).
# Building a default-dtype tensor first and calling .float() afterwards would
# materialise an extra full-size copy of the dataset before down-casting.
W_raw = torch.tensor(data_training['W'], dtype=torch.float32)
Sol_raw = torch.tensor(data_training['sol'], dtype=torch.float32)

# Fixed-seed permutation of N_TOTAL indices; the last 100 form the test set.
# NOTE(review): presumably this reproduces the split used at training time and
# the dataset holds exactly N_TOTAL samples -- confirm against the training code.
torch.manual_seed(42)
N_TOTAL = 1500
indices = torch.randperm(N_TOTAL)
idx_test = indices[1400:]

W_test = W_raw[idx_test].to(DEVICE)     # [100, 32, 32, 201]
Sol_test = Sol_raw[idx_test].to(DEVICE) # [100, 32, 32, 201]

# 2. MODEL for Internal (T=201)
# Coordinate grids handed to the model: 201 time points on [0, 0.020] and
# 32 points per spatial axis on [0, 1].
T_points_internal = torch.linspace(0, 0.020, 201).to(DEVICE)
X_points = torch.linspace(0, 1, 32).to(DEVICE)
Y_points = torch.linspace(0, 1, 32).to(DEVICE)

print("Model A (Internal Mode)...")
internal_model_kwargs = dict(
    dim=2,
    in_channels=1,
    noise_channels=1,
    hidden_channels=HIDDEN_CHANNELS,  # 32
    n_iter=1,
    modes1=MODES_X,  # 16
    modes2=MODES_Y,  # 16
    solver='diffeq',
    T_points=T_points_internal,  # T=201
    X_points=X_points,
    Y_points=Y_points,
    device=DEVICE,
)
model_internal = Fusion_NSPDE(**internal_model_kwargs).to(DEVICE)

# Load weights and switch to inference mode.
# NOTE(review): torch.load unpickles the checkpoint file; only load checkpoints
# from a trusted source (consider weights_only=True where the installed
# PyTorch version supports it).
model_internal.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location=DEVICE)
)
model_internal.eval()


print(" Running 100 samples test...")
BATCH_SIZE = 5
pred_list = []

# Walk the solution and noise tensors in lock-step chunks of BATCH_SIZE
# along the sample axis; gradients are not needed for evaluation.
with torch.no_grad():
    for sol_chunk, noise_chunk in zip(
        Sol_test.split(BATCH_SIZE), W_test.split(BATCH_SIZE)
    ):
        # First time slice of the reference solution, with a channel axis
        # inserted at position 1, serves as the initial condition.
        initial_state = sol_chunk[..., 0].unsqueeze(1)
        pred_list.append(model_internal(initial_state, noise_chunk))

pred_sol = torch.cat(pred_list, dim=0) # [100, 32, 32, 201]
truth_sol = Sol_test

# 4. METRICS
myloss = LpLoss(size_average=False)
ntest = len(truth_sol)

# Relative L2
# The first time slice (index 0) is dropped from both tensors, and each sample
# is flattened to one row before the loss is evaluated.
# NOTE(review): size_average=False presumably makes LpLoss return a sum over
# samples, hence the division by ntest -- confirm against src.utilities.
with torch.no_grad():
    pred_flat = pred_sol[..., 1:].reshape(ntest, -1)
    truth_flat = truth_sol[..., 1:].reshape(ntest, -1)
    loss = myloss(pred_flat, truth_flat)
    Rel_L2 = loss.item() / ntest
print(f'Relative L2 Loss: {Rel_L2:.6f}')

# ACF Score
def transform(d):
    """Move axis 3 to position 1 and flatten the remaining trailing axes.

    A 4-D input of shape [N, A, B, T] comes back as [N, T, A*B].
    """
    n_samples, n_steps = d.shape[0], d.shape[3]
    return d.permute(0, 3, 1, 2).reshape(n_samples, n_steps, -1)


# Drop the first time slice (index 0) from prediction and reference alike.
pred_sliced = pred_sol[..., 1:]
sol_sliced = truth_sol[..., 1:]

with torch.no_grad():
    # Cap the lag at 64, and never exceed the number of time steps minus one.
    max_lag = min(64, sol_sliced.shape[3] - 1)
    # ACFLoss is built from the reference data (moved to CPU) and then
    # called on the prediction (also moved to CPU).
    acf_calculator = ACFLoss(
        x_real=sol_sliced.cpu(),
        transform=transform,
        stationary=True,
        max_lag=max_lag,
        name="acf_loss",
    )
    acf_score = acf_calculator(pred_sliced.cpu()).item()
    print(f'ACF Score: {acf_score:.6f}')
print("=" * 40)

# Run Inference

In [None]:
print("\n=== SUBMISSION (OOD DATA) ===")

# Read the 'sol' array of data/initial.mat as a float32 tensor on DEVICE.
data_init = scipy.io.loadmat(r"data/initial.mat")
u0_ood = torch.tensor(
    data_init['sol'], dtype=torch.float32, device=DEVICE
)  # [500, 32, 32]

# Read the 'W' array of data/noise.mat as a float32 tensor on DEVICE.
data_noise = scipy.io.loadmat(r"data/noise.mat")
xi_ood = torch.tensor(
    data_noise['W'], dtype=torch.float32, device=DEVICE
)  # [500, 32, 32, 251]

n_submission_samples = u0_ood.shape[0]
n_submission_steps = xi_ood.shape[-1]
print(f"Submission data: {n_submission_samples} samples, Time steps: {n_submission_steps}")


# Time grid for the submission run: 251 points on [0, 0.025].
T_points_submission = torch.linspace(0, 0.025, 251).to(DEVICE)

print(" Model B (Submission Mode)...")
# Same architecture settings as the internal-evaluation model; only the time
# grid differs. X_points / Y_points are the spatial grids defined in the
# internal-evaluation cell.
model_submission = Fusion_NSPDE(
    dim=2,
    in_channels=1,
    noise_channels=1,
    hidden_channels=HIDDEN_CHANNELS,
    n_iter=1,
    modes1=MODES_X,
    modes2=MODES_Y,
    solver='diffeq',
    T_points=T_points_submission,  # T=251
    X_points=X_points,
    Y_points=Y_points,
    device=DEVICE,
)
model_submission = model_submission.to(DEVICE)

# Load weights and switch to inference mode.
# NOTE(review): torch.load unpickles the checkpoint file; only load checkpoints
# from a trusted source.
model_submission.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location=DEVICE)
)
model_submission.eval()

# 3. Predict and save file
BATCH_SIZE = 5
pred_list_ood = []
num_samples = u0_ood.shape[0]
num_batches = int(np.ceil(num_samples / BATCH_SIZE))

print(f"Running submission predictions ({num_samples} samples)...")

with torch.no_grad():
    for i in range(num_batches):
        start = i * BATCH_SIZE
        end = min((i + 1) * BATCH_SIZE, num_samples)

        u0_batch = u0_ood[start:end]
        # The internal evaluation feeds the model an initial condition with an
        # explicit channel axis ([B, 1, X, Y]). Add that axis here when the
        # loaded array lacks it so both code paths call the model the same way.
        # NOTE(review): Fusion_NSPDE is not visible from this notebook --
        # confirm the expected u0 layout against src.fusion_model.
        if u0_batch.dim() == 3:
            u0_batch = u0_batch.unsqueeze(1)
        xi_batch = xi_ood[start:end]

        u_pred = model_submission(u0_batch, xi_batch)
        # Move each batch to host memory right away so predictions do not
        # accumulate on the device.
        pred_list_ood.append(u_pred.cpu().numpy())

        # Logging
        if (i + 1) % 10 == 0:
            print(f"   Processed batch {i+1}/{num_batches}")

        # Drop the per-batch tensors and release cached device memory before
        # the next iteration.
        del u0_batch, xi_batch, u_pred
        torch.cuda.empty_cache()


pred_sol_final = np.concatenate(pred_list_ood, axis=0)
print(f"Finished. Shape: {pred_sol_final.shape}")

# Save Output

In [None]:
# Only write the submission file when the prediction array has the expected
# shape; otherwise report what was produced so the mismatch can be diagnosed.
EXPECTED_SHAPE = (500, 32, 32, 251)
if pred_sol_final.shape == EXPECTED_SHAPE:
    scipy.io.savemat('pred.mat', {'sol': pred_sol_final})
    print("File 'pred.mat' has been saved successfully!")
else:
    print(
        f"Warning: The shape is incorrect (got {pred_sol_final.shape}, "
        f"expected {EXPECTED_SHAPE}). Please check and correct it."
    )