## Inspect Trained Model


In [29]:
import os

import torch
import yaml
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import Dataset

import data
from networks import *
from utils import mse2psnr

### Select model to inspect


In [30]:
# exp_name = "continual_base"
exp_name = "non_continual"
exp_dir = f"exp/{exp_name}"

# Run directories of the chosen experiment (hidden entries skipped).
# os.listdir returns entries in arbitrary, filesystem-dependent order, so sort
# them (lexicographically) to make the selected run reproducible.
ids = sorted(name for name in os.listdir(exp_dir) if not name.startswith("."))
if not ids:
    raise FileNotFoundError(f"No runs found in {exp_dir}")
run_id = ids[0]  # named run_id to avoid shadowing the builtin `id`
path = f"{exp_dir}/{run_id}"

config_path = f"{path}/.hydra/config.yaml"
ckpt_path = f"{path}/ckpt/final.pt"

In [31]:
# Load the Hydra config saved alongside this run.
with open(config_path, "r", encoding="utf-8") as stream:
    config = yaml.safe_load(stream)
cfg = DictConfig(config)

# Build the network described by the config and restore its trained weights.
model = instantiate(cfg.network)
# map_location="cpu" lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine; load_state_dict then copies the tensors into the model's
# existing parameters.
ckpt = torch.load(ckpt_path, map_location="cpu")
model_state_dict = ckpt["model_state_dict"]
model.load_state_dict(model_state_dict)
model.eval()

# Load source image
dataset = instantiate(cfg.data)

### Sanity checks


In [32]:
# Sanity check model performance (on the full grid).
# Inference only: disable autograd so no graph is built over every coordinate
# of the grid. NOTE(review): assumes the network's forward does not itself
# call torch.autograd.grad — confirm against the networks module.
with torch.no_grad():
    model_output, _ = model(dataset.full_coords)
model_output = model_output.cpu()
mse = ((model_output - dataset.full_pixels) ** 2).mean().item()
psnr = mse2psnr(mse)
print(f"mse={mse}")
print(f"psnr={psnr}")

mse=0.0005577552365139127
psnr=32.535563436287575


### Weight sparsity

In [33]:
# Fraction of model parameters that are numerically zero. torch.isclose against
# zero with its default tolerances (rtol=1e-5, atol=1e-8) reduces to |p| <= 1e-8.
total_params = 0
sparse_params = 0
with torch.no_grad():
    for param in model.parameters():
        total_params += param.numel()
        # .item() keeps the running count a plain Python int on the host.
        sparse_params += torch.isclose(param, torch.zeros_like(param)).sum().item()

# Sparsity is undefined for a model without parameters; report NaN rather
# than dividing by zero.
sparsity = sparse_params / total_params if total_params else float("nan")
print(f"sparsity={sparsity}")

sparsity=5.027272891311441e-06
