Test and compare against a model that only saw wind sets (WS) 2 and 3 during training. Testing is done only on WS1 (the `test_onlyWS1.pt` split loaded below).
Each gas source position was simulated with exactly one of the three wind sets, i.e., the light model is trained on only 2/3 of the gas source positions.

In [1]:
# ~~~~~~~~~~~~~~~~
# PyTorch Model

import torch
import yaml
from models.decoder import architectures

with open("models/decoder/decoder_params.yaml", encoding="utf-8") as file:
    params = yaml.safe_load(file)


def _load_decoder(state_dict_path, hparams):
    """Build a LightningNet from ``hparams`` and load the weights in ``state_dict_path``.

    The returned network is already switched to eval mode.
    """
    model = architectures.LightningNet(
        hparams["inner_dims"], hparams["seq_len"], hparams["learning_rate"]
    )
    # map_location="cpu" lets checkpoints saved from GPU tensors load on CPU-only
    # machines; load_state_dict copies the values into this CPU-built model anyway.
    model.load_state_dict(torch.load(state_dict_path, map_location="cpu"))
    model.eval()
    return model


decoder = _load_decoder("models/decoder/decoder.pth", params)
decoder_light = _load_decoder("models/decoder/decoder_lightTraining.pth", params)

In [2]:
from torch.utils import data
from data.gdm_dataset import GasDataSet

# Held-out split used for every model below.
test_data_path = "data/30x25_reducedPositions/test_onlyWS1.pt"
dataset = GasDataSet(test_data_path)

In [3]:
import gdm_metrics

# Running sums of per-batch metric values, one entry per evaluated model.
_metric_keys = ("decoder", "decoder_light")
rmse = dict.fromkeys(_metric_keys, 0)
kld = dict.fromkeys(_metric_keys, 0)

In [5]:
from tqdm import tqdm

# Names of the models to evaluate; they must match keys of `rmse` / `kld`.
list_of_models = [
    "decoder",
    "decoder_light",
]

# Name -> network lookup, filtered to the selected models, so every model goes
# through the same evaluation code (unknown names are ignored, order is fixed).
_all_models = {"decoder": decoder, "decoder_light": decoder_light}
selected_models = {
    name: model for name, model in _all_models.items() if name in list_of_models
}

# Plain DataLoader (no iter()): it can be iterated again later and len() still
# reports the number of batches used for averaging below.
dataloader = data.DataLoader(dataset, shuffle=False, drop_last=True)

# Gradients are never needed for evaluation, so enter no_grad once, not per batch.
with torch.no_grad():
    for X, y in tqdm(dataloader):
        inputs = X.squeeze(1)  # same input for every model; squeeze once per batch
        for name, model in selected_models.items():
            prediction = model(inputs)
            rmse[name] += gdm_metrics.rmse(prediction, y)
            kld[name] += gdm_metrics.kld(prediction, y)

100%|████████████████████████████████████| 10800/10800 [01:42<00:00, 104.89it/s]


In [6]:
# Print results: each accumulated metric divided by the number of batches,
# i.e. the mean per-batch value.
num_batches = len(dataloader)
for title, scores in (("RMSE", rmse), ("KL D", kld)):
    print(f"########\n# {title} #\n########")
    for name, total in scores.items():
        print(f"{name}: \t {total / num_batches}")

########
# RMSE #
########
decoder: 	 0.06879420578479767
decoder_light: 	 0.08158489316701889
########
# KL D #
########
decoder: 	 0.003121771616861224
decoder_light: 	 0.004383096471428871
