In [None]:
import logging
import sys
from pathlib import Path

import h5py
import torch
import yaml

from dinozaur.training.trainer import Trainer

# Notebook logger that writes to stdout. `getLogger` returns the same object on
# every call, so only attach the handler once: re-running this cell would
# otherwise add a second handler and print every log line twice.
logger = logging.getLogger(__name__)
if not logger.handlers:
    logger.addHandler(logging.StreamHandler(stream=sys.stdout))
logger.setLevel(logging.INFO)

In [None]:
# load config
# Read the YAML explicitly as UTF-8 instead of relying on the platform's
# default text encoding (which differs, e.g., on Windows).
with open("config/darcy/dinozaur.yml", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

In [None]:
# set num_epochs=1 for demonstration purposes, so the training cell below
# finishes after a single epoch
training_params = config["training_params"]
training_params["num_epochs"] = 1

In [None]:
# train for 1 epoch
# Seed torch's RNGs (all devices) so re-running this cell is reproducible;
# GPU kernels may still introduce some nondeterminism.
# NOTE(review): Trainer may also seed itself from `config` — confirm; if so,
# this call is redundant but harmless.
SEED = 0
torch.manual_seed(SEED)
trainer = Trainer(**config)
trainer.train()

In [None]:
# save full model
# The entire Trainer object is pickled, so loading it later requires the same
# classes to be importable. The output directory is derived from the file path
# so the two locations cannot drift apart.
MODEL_PATH = Path("logs/example/model.pt")
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(trainer, MODEL_PATH)

In [None]:
# load full model
# SECURITY: weights_only=False unpickles arbitrary Python objects, which can
# execute code. Only load checkpoints you created yourself or otherwise trust.
checkpoint_path = Path("logs/example/model.pt")
trainer = torch.load(checkpoint_path, weights_only=False)

In [None]:
# load test sample
# Read every top-level member of the file (all assumed to be datasets) into a
# tensor with a new leading axis of size 1 — presumably the batch dimension
# expected by `trainer.predict`; confirm. `dataset[()]` reads the full contents
# and, unlike `dataset[:]`, also works for scalar (0-d) datasets.
with h5py.File("data/darcy/test/data_1024.h5", "r") as h5_file:
    sample = {
        name: torch.tensor(dataset[()][None])
        for name, dataset in h5_file.items()
    }

In [None]:
# run inference
# Disable autograd: prediction needs no gradients, so this avoids building a
# graph and keeping activations alive (the result is detached anyway).
# NOTE(review): if `predict` relies on autograd internally, remove no_grad.
# `sample.copy()` passes a shallow copy, presumably so `predict` cannot mutate
# the `sample` dict itself.
with torch.no_grad():
    prediction = trainer.predict(sample.copy()).detach().cpu().numpy()