## Inspect Trained Model

In [33]:
import torch
from networks import *
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
import yaml
from torch.utils.data import Dataset
import data
from utils import mse2psnr

### Select model to inspect

In [8]:
# Run to inspect: its artifacts are read from exp/<exp_name>/<id>.
exp_name = "continual_base"
id = "222356"  # NOTE: shadows the builtin id(); name kept since later cells may use it.
ckpt_step = 1000  # which checkpoint file (ckpt_<step>.pt) to load

path = "/".join(["exp", exp_name, id])
config_path = path + "/.hydra/config.yaml"
ckpt_path = f"{path}/ckpt/ckpt_{ckpt_step}.pt"

In [28]:
# Load the Hydra config saved with the run (plain dict + DictConfig view).
with open(config_path, "r") as stream:
    config = yaml.safe_load(stream)
cfg = DictConfig(config)

# Build the network from its config and restore the trained weights.
# map_location="cpu" lets a checkpoint that was saved from GPU tensors be
# deserialized on a CPU-only machine; load_state_dict then copies the values
# into the model's existing parameters, so the model keeps the device that
# instantiate() gave it.
model = instantiate(cfg.network)
ckpt = torch.load(ckpt_path, map_location="cpu")
model_state_dict = ckpt["model_state_dict"]
model.load_state_dict(model_state_dict)
model.eval()  # inference mode for layers like dropout / batch-norm

# Load source image (dataset object built from cfg.data)
dataset = instantiate(cfg.data)

### Sanity checks

In [40]:
# Sanity check model performance: run the model on dataset.full_coords and
# compare its (first) output against dataset.full_pixels.
# no_grad() avoids building an autograd graph for this evaluation-only pass.
with torch.no_grad():
    model_output, _ = model(dataset.full_coords)  # second output unused here
model_output = model_output.detach().cpu()
mse = ((model_output - dataset.full_pixels) ** 2).mean().item()
psnr = mse2psnr(mse)
# Assignments don't display in a notebook cell, so show the results explicitly.
print(f"MSE: {mse}  PSNR: {psnr}")