# Evaluating models
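This notebook loads a trained model (its initialization parameters, saved weights and auxiliary neighborhood data) and evaluates it on the airfoil datasets, reporting metrics with bootstrapped confidence intervals.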

In [1]:
import torch
import numpy as np
import json

import torch.nn as nn

# FrEIA imports
import FrEIA.framework as Ff
import FrEIA.modules as Fm

## Load data

Set `d` below to the directory that contains the saved weights and the matching `init_parameters.json`.

In [2]:
# Directory of the trained run; every checkpoint below is loaded from here.
d = '../saved_weights/LTN_airfoil-6-circles'

# Constructor keyword arguments that were saved alongside the weights.
with open(f'{d}/init_parameters.json') as f:
    init_params = json.load(f)

Load the datasets that match the saved model; using different data can cause dimensionality mismatches when the weights are loaded.

In [3]:
# Base and cover datasets passed to the model constructor. They must match the
# saved run, or the network dimensions may not line up.
base_data = torch.load('../datasets/airfoil_base.pt')
cover_data = torch.load('../datasets/airfoil_bundle.pt')
init_params.update(base_data=base_data, cover_data=cover_data)

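## Initialize model

The following cells are **alternative** loaders for different architectures (BundleNet, CGAN_triv, CGAN, WGAN_div). Each one overwrites `model`, so run only the loader that matches the saved weights. For CGAN_triv, also run the cell directly after it, which loads the centers and neighborhoods.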

In [4]:
# Device for the whole notebook. Change it only here: `dev` (a torch.device used
# as `map_location` when loading checkpoints) is derived from this string, so the
# two cannot drift apart.
device = 'cuda:7'
dev = torch.device(device)

# The model constructors receive the device through their keyword arguments.
init_params['device'] = device

In [5]:
# Loader option: BundleNet (imported as LocalTrivNet). Overwrites `model`.
from bundlenet import BundleNet as LocalTrivNet

model = LocalTrivNet(**init_params)
model.model.load_state_dict(
    torch.load(f'{d}/final-weights.pt', map_location=dev)
)

# Auxiliary objects saved in the same run directory; each file is named after the
# attribute it populates.
model.centers, model.base_nbhds, model.cover_nbhds = (
    torch.load(f'{d}/{stem}.pt') for stem in ('centers', 'base_nbhds', 'cover_nbhds')
)

_ = model.model.to(model.device)

In [12]:
# Loader option: CGAN_triv. Overwrites `model`; also run the next cell, which
# loads the centers and neighborhoods this model uses.
from bundlenet import CGAN_triv
model = CGAN_triv(**init_params)

# map_location (as in the BundleNet loader) lets checkpoints saved on a different
# GPU, or on a machine with a different GPU layout, load onto the configured device.
model.G.load_state_dict(torch.load(d + '/final-G-weights.pt', map_location=dev))
model.D.load_state_dict(torch.load(d + '/final-D-weights.pt', map_location=dev))

_ = model.G.to(model.device)
_ = model.D.to(model.device)

In [13]:
# Auxiliary objects for CGAN_triv, loaded from the run directory; each file is
# named after the attribute it populates.
model.centers, model.base_nbhds, model.cover_nbhds = (
    torch.load(f'{d}/{stem}.pt') for stem in ('centers', 'base_nbhds', 'cover_nbhds')
)

In [4]:
# Loader option: CGAN. Overwrites `model`.
from bundlenet import CGAN
model = CGAN(**init_params)

# map_location (as in the BundleNet loader) lets checkpoints saved on a different
# GPU, or on a machine with a different GPU layout, load onto the configured device.
model.G.load_state_dict(torch.load(d + '/final-G-weights.pt', map_location=dev))
model.D.load_state_dict(torch.load(d + '/final-D-weights.pt', map_location=dev))

_ = model.G.to(model.device)
_ = model.D.to(model.device)

In [4]:
# Loader option: WGAN_div. Overwrites `model`.
from bundlenet import WGAN_div
model = WGAN_div(**init_params)

# map_location (as in the BundleNet loader) lets checkpoints saved on a different
# GPU, or on a machine with a different GPU layout, load onto the configured device.
model.G.load_state_dict(torch.load(d + '/final-G-weights.pt', map_location=dev))
model.D.load_state_dict(torch.load(d + '/final-D-weights.pt', map_location=dev))

_ = model.G.to(model.device)
_ = model.D.to(model.device)

## Evaluate
Run the evaluation and pretty-print the results, which include bootstrapped confidence intervals (CIs).

In [7]:
# Evaluate whichever `model` was loaded last. As shown in the output, `results`
# maps a scope ('global', 'fiber') to {metric_name: (value, BootstrapResult)}.
from bundlenet.evaluation import evaluate_airfoil_minimal as evaluate

results = evaluate(model)

Global Losses: 0 1 2 3 4 5 6 7 8 9 
Fiber Losses 0 1 2 3 4 5 6 7 8 9 
 {'global': {'Wass1': (1.1240829, BootstrapResult(confidence_interval=ConfidenceInterval(low=1.1174857145718278, high=1.12821770308617), standard_error=0.0026329735))}, 'fiber': {'Wass1': (3.0715075, BootstrapResult(confidence_interval=ConfidenceInterval(low=3.003512397194155, high=3.1434332595634107), standard_error=0.036108673))}}


In [8]:
# Print each metric as "mean <value>\pm<half-width>". The bootstrap CI need not be
# symmetric around the point estimate, so the larger of the two one-sided
# distances is reported. MMD values are multiplied by 1000 before printing.
for name in ['global', 'fiber']:
    for key, t in results[name].items():
        val = t[0]
        ci = t[1].confidence_interval
        lower, upper = ci.low, ci.high
        if key == 'MMD':
            val, lower, upper = val * 1000, lower * 1000, upper * 1000
        half_width = max(upper - val, val - lower)
        # Raw f-string: '\p' is an invalid escape sequence in a normal string
        # literal (SyntaxWarning since Python 3.12). The printed text is unchanged.
        print(rf'{name} {key}: mean {val:5.3f}\pm{half_width:5.3f}')

global Wass1: mean 1.124\pm0.007
fiber Wass1: mean 3.072\pm0.072
