In [1]:
import torch
from neuralop.models.fno import FNO
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, TensorDataset
import json
import os

# Load the model architecture
model = FNO(n_modes=(16, 16), hidden_channels=64, in_channels=2, out_channels=2)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load optimizer and scheduler
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=50, gamma=0.1)

# Load a specific epoch checkpoint (e.g., epoch 10)
checkpoint_path = './checkpoints/checkpoint_epoch_10.pt'
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

# Load data
data = torch.load('./Data.pt')
test_input = data["test_in"]
test_output = data["test_sol"]

test_dataset = TensorDataset(test_input, test_output)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# define and instantiate: loss and model
class L2Loss(object):
    # loss returns the sum over all the samples in the current batch
    def __init__(self,):
        super(L2Loss, self).__init__()
    
    def __call__(self, x, y):
        num_examples = x.size()[0]
        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), 2, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), 2, 1)
        return torch.sum(diff_norms / y_norms)


# Evaluate the model
model.eval()
test_loss = 0.0
criterion = L2Loss()  # Assuming you've defined L2Loss similarly

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        test_loss += loss.item()

avg_test_loss = test_loss / len(test_loader)
print(len(test_loader))
print(f"Average Test Loss: {avg_test_loss:.4f}")


Average Test Loss: 28.8515
