In [1]:
from pprint import pprint

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import AdamW
import torch.nn.functional as F
import torch.utils.data as data


from src.params import ModelName, DatasetName, DeviceType, get_new_seed
from src.models import LogisticRegression
from src.datasets import load_embeddings
from src.datasets import DatasetSampler
from src.engine import DatasetTrainer, DatasetInference
from src.metrics import cross_entropy, reverse_loss, logistic_bregman
from src.measurements import EstimateKLMisfit
from src.np_utils import index_random_split

# Setup

## Models

- Weak supervisor: AlexNet
- Strong model 1: DINO ResNet-50
- Strong model 2: DINO ViT-B/8

## Datasets

1. CIFAR-10
2. ImageNet (download in progress)
3. CIFAR-100 (tentative)

For now, we do not require the validation set to be the same across models.

In [2]:
# Experiment configuration: which dataset, and which embedding models play the
# weak-supervisor and strong-student roles.
dataset_id = DatasetName.CIFAR10
weak_model_id = ModelName.ALEXNET  # weak supervisor
strong_model_id = ModelName.RESNET50_DINO  # strong student

In [3]:
# Load the weak model's embeddings (train/test features and labels) for the chosen dataset.
weak_embeddings = load_embeddings(
    dataset_name=dataset_id,
    model_name=weak_model_id,
)

In [4]:
# Sanity checks: every split needs exactly one label per feature row, and the
# train and test features must share the same embedding dimension.
assert weak_embeddings.x_train.shape[0] == weak_embeddings.y_train.shape[0]
assert weak_embeddings.x_test.shape[0] == weak_embeddings.y_test.shape[0]
assert weak_embeddings.x_train.shape[1] == weak_embeddings.x_test.shape[1]

# Print some stats on the data
print("Weak model embeddings:")
print("Train data shape:", weak_embeddings.x_train.shape)
print("Test data shape:", weak_embeddings.x_test.shape)
print("Number of classes:", weak_embeddings.num_classes)
# Dtypes
print("Train data dtype:", weak_embeddings.x_train.dtype)
print("Test data dtype:", weak_embeddings.x_test.dtype)
print("Train Labels dtype:", weak_embeddings.y_train.dtype)
print("Test Labels dtype:", weak_embeddings.y_test.dtype)


Weak model embeddings:
Train data shape: (50000, 9216)
Test data shape: (10000, 9216)
Number of classes: 10
Train data dtype: float32
Test data dtype: float32
Train Labels dtype: int64
Test Labels dtype: int64


In [5]:
# Partition the training set into four parts (index_random_split is presumably
# disjoint — checked explicitly further below):
#   weak_val / weak_train             -> fit and select the weak model (weak_split of the data)
#   st_from_wk_val / st_from_wk_train -> fit and select the strong model on weak labels
# Within each of the two portions, a validation_split fraction goes to validation.
weak_split = 0.3
validation_split = 0.2
validation_splits = [
    validation_split * weak_split,              # weak model: validation
    (1 - validation_split) * weak_split,        # weak model: training
    validation_split * (1 - weak_split),        # strong-from-weak: validation
    (1 - validation_split) * (1 - weak_split),  # strong-from-weak: training
]
weak_val_inds, weak_train_inds, st_from_wk_val_inds, st_from_wk_train_inds = index_random_split(
    weak_embeddings.x_train.shape[0],
    split_sizes=validation_splits,
    random_state=np.random.default_rng(seed=get_new_seed()),
)

In [6]:
# Weak-model train/val datasets: weak embeddings paired with one-hot float targets.
weak_train_dataset, weak_val_dataset = (
    data.TensorDataset(
        torch.tensor(weak_embeddings.x_train[inds]),
        F.one_hot(
            torch.tensor(weak_embeddings.y_train[inds]),
            num_classes=weak_embeddings.num_classes,
        ).to(torch.float32),
    )
    for inds in (weak_train_inds, weak_val_inds)
)

In [8]:
import math

# Sanity checks on the four-way split of the training set.
num_train_examples = weak_embeddings.x_train.shape[0]

# The four index sets must be pairwise disjoint and in range. In particular, no example used
# to fit or select the weak model may reappear in the strong-from-weak train/val sets;
# otherwise the weak labels on those examples would be optimistically accurate.
# (Range is checked explicitly because numpy would silently wrap negative indices.)
all_split_inds = np.concatenate(
    [
        np.asarray(inds)
        for inds in (weak_val_inds, weak_train_inds, st_from_wk_val_inds, st_from_wk_train_inds)
    ]
)
assert np.unique(all_split_inds).size == all_split_inds.size, "split index sets overlap"
assert np.all((all_split_inds >= 0) & (all_split_inds < num_train_examples)), "split index out of range"

# The weak datasets should have the sizes implied by the requested fractions.
assert len(weak_val_dataset) == int(validation_splits[0] * num_train_examples)
assert len(weak_train_dataset) == int(validation_splits[1] * num_train_examples)

In [9]:
# Instantiate the weak model: a LogisticRegression from weak embeddings to class logits.
num_inputs = weak_embeddings.x_train.shape[1]  # embedding dimension
weak_model = LogisticRegression(
    num_inputs,
    weak_embeddings.num_classes,
)


In [10]:
# Train the weak model on its own labelled split, with weak_val_dataset used for
# model selection. The train/val split itself was fixed by index_random_split above;
# DatasetTrainer receives an explicit val_dataset, so no split fraction is needed here.
weight_decay = 0.0
lr = 1e-3
num_epochs = 20
batch_size = 256

# Length of the cosine schedule in optimizer steps: epochs × ceil(len(train) / batch_size).
# NOTE(review): assumes DatasetTrainer steps the scheduler once per mini-batch — confirm.
n_iter = num_epochs * len(range(0, len(weak_train_dataset), batch_size))

optimizer = optim.Adam(weak_model.parameters(), weight_decay=weight_decay, lr=lr)
schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_iter)

weak_trainer = DatasetTrainer(
    model=weak_model,
    optimizer=optimizer,
    loss_fn=cross_entropy,
    dataset=weak_train_dataset,
    val_dataset=weak_val_dataset,
    scheduler=schedule,
)
weak_trainer.train(num_epochs=num_epochs, batch_size=batch_size)

Using model device for training: cpu
Using model device for training: cpu
Epoch 1/20


100%|██████████| 47/47 [00:00<00:00, 164.00it/s, loss=0.665]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 312.55it/s]


Validation loss: 0.6511116971572241
Epoch 2/20


100%|██████████| 47/47 [00:00<00:00, 177.84it/s, loss=0.448]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 347.88it/s]


Validation loss: 0.5919397870699564
Epoch 3/20


100%|██████████| 47/47 [00:00<00:00, 175.77it/s, loss=0.338]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 350.21it/s]


Validation loss: 0.5788721243540446
Epoch 4/20


100%|██████████| 47/47 [00:00<00:00, 178.82it/s, loss=0.261]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 363.93it/s]


Validation loss: 0.5685970534880956
Epoch 5/20


100%|██████████| 47/47 [00:00<00:00, 180.10it/s, loss=0.214]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 228.18it/s]


Validation loss: 0.5671936323245367
Epoch 6/20


100%|██████████| 47/47 [00:00<00:00, 182.27it/s, loss=0.153]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 329.09it/s]


Validation loss: 0.5854220340649287
Epoch 7/20


100%|██████████| 47/47 [00:00<00:00, 185.95it/s, loss=0.138]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 306.63it/s]


Validation loss: 0.5850677515069643
Epoch 8/20


100%|██████████| 47/47 [00:00<00:00, 193.49it/s, loss=0.101]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 355.24it/s]


Validation loss: 0.591073289513588
Epoch 9/20


100%|██████████| 47/47 [00:00<00:00, 204.12it/s, loss=0.0952]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 369.79it/s]


Validation loss: 0.5916304091612498
Epoch 10/20


100%|██████████| 47/47 [00:00<00:00, 199.35it/s, loss=0.0803]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 127.83it/s]


Validation loss: 0.6021973838408788
Epoch 11/20


100%|██████████| 47/47 [00:00<00:00, 204.48it/s, loss=0.0729]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 368.76it/s]


Validation loss: 0.6034549375375112
Epoch 12/20


100%|██████████| 47/47 [00:00<00:00, 185.09it/s, loss=0.0698]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 227.08it/s]


Validation loss: 0.6048173010349274
Epoch 13/20


100%|██████████| 47/47 [00:00<00:00, 199.73it/s, loss=0.0595]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 340.70it/s]


Validation loss: 0.6065603246291479
Epoch 14/20


100%|██████████| 47/47 [00:00<00:00, 192.29it/s, loss=0.0588]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 286.15it/s]


Validation loss: 0.6083969473838806
Epoch 15/20


100%|██████████| 47/47 [00:00<00:00, 149.69it/s, loss=0.0587]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 243.25it/s]


Validation loss: 0.610507975021998
Epoch 16/20


100%|██████████| 47/47 [00:00<00:00, 172.69it/s, loss=0.0527]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 320.43it/s]


Validation loss: 0.6122093945741653
Epoch 17/20


100%|██████████| 47/47 [00:00<00:00, 183.43it/s, loss=0.0525]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 302.97it/s]


Validation loss: 0.612433577577273
Epoch 18/20


100%|██████████| 47/47 [00:00<00:00, 173.03it/s, loss=0.0522]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 304.30it/s]


Validation loss: 0.6134383231401443
Epoch 19/20


100%|██████████| 47/47 [00:00<00:00, 136.39it/s, loss=0.0531]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 327.16it/s]


Validation loss: 0.6131944706042608
Epoch 20/20


100%|██████████| 47/47 [00:00<00:00, 182.98it/s, loss=0.0507]


Validating model...


100%|██████████| 12/12 [00:00<00:00, 267.61it/s]

Validation loss: 0.6133474260568619
Best model found at epoch 4 with val loss 0.5671936323245367



  self.model.load_state_dict(torch.load(self.validation_tmp_path))


In [11]:
# Run the trained weak model over the entire training set to obtain weak logits, and
# convert them to soft weak labels (per-class probabilities). They are later indexed
# with the strong-from-weak split indices.
weak_full_dataset = data.TensorDataset(
    torch.tensor(weak_embeddings.x_train),
    F.one_hot(
        torch.tensor(weak_embeddings.y_train),
        num_classes=weak_embeddings.num_classes,
    ).to(torch.float32),
)
weak_inference = DatasetInference(weak_model, weak_full_dataset)
weak_logits = weak_inference.inference(batch_size=batch_size)
weak_labels = weak_logits.softmax(dim=1)

Using model device for inference: cpu


100%|██████████| 196/196 [00:00<00:00, 300.72it/s]


In [12]:
# Drop references to the weak-side data so it can be reclaimed before the strong
# embeddings are loaded. Downstream cells only need weak_model, weak_logits and weak_labels.
# weak_full_dataset holds its own full copy of x_train (torch.tensor copies its input),
# and weak_inference presumably keeps a reference to it, so both must be released too;
# otherwise that copy of the training embeddings stays resident.
weak_embeddings = None
weak_test_dataset = None
weak_train_dataset = None
weak_val_dataset = None
weak_trainer = None
weak_full_dataset = None
weak_inference = None

In [13]:
# Load the strong model's embeddings. The split indices were drawn against the weak
# embeddings' training set and are reused here, which assumes both embedding files
# list the training images in the same order.
strong_embeddings = load_embeddings(model_name=strong_model_id, dataset_name=dataset_id)

# Ground-truth-labelled datasets on the strong-from-weak split (one-hot float targets).
strong_train_dataset, strong_val_dataset = (
    data.TensorDataset(
        torch.tensor(strong_embeddings.x_train[inds]),
        F.one_hot(
            torch.tensor(strong_embeddings.y_train[inds]),
            num_classes=strong_embeddings.num_classes,
        ).to(torch.float32),
    )
    for inds in (st_from_wk_train_inds, st_from_wk_val_inds)
)

In [14]:
# Instantiate the strong ground-truth model: a LogisticRegression over strong embeddings.
num_inputs = strong_embeddings.x_train.shape[1]  # embedding dimension
strong_gt_model = LogisticRegression(
    num_inputs,
    strong_embeddings.num_classes,
)

In [15]:
# Train the strong ground-truth model on the true (one-hot) labels of the
# strong-from-weak split, using the strong-side validation set for model selection.
num_epochs, batch_size = 200, 256
lr, weight_decay = 1e-3, 0.0

# Length of the cosine schedule in optimizer steps: epochs × ceil(len(train) / batch_size).
# NOTE(review): assumes DatasetTrainer steps the scheduler once per mini-batch — confirm.
n_iter = num_epochs * len(range(0, len(strong_train_dataset), batch_size))

optimizer = optim.Adam(strong_gt_model.parameters(), lr=lr, weight_decay=weight_decay)
schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_iter)

strong_trainer = DatasetTrainer(
    model=strong_gt_model,
    dataset=strong_train_dataset,
    val_dataset=strong_val_dataset,
    loss_fn=cross_entropy,
    optimizer=optimizer,
    scheduler=schedule,
)
strong_trainer.train(num_epochs=num_epochs, batch_size=batch_size, log_every=10)

Using model device for training: cpu
Using model device for training: cpu
Epoch 1/200


100%|██████████| 110/110 [00:00<00:00, 337.75it/s, loss=1.08]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 863.50it/s]


Validation loss: 1.0642236002853938
Epoch 11/200


100%|██████████| 110/110 [00:00<00:00, 373.51it/s, loss=0.375]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 836.71it/s]


Validation loss: 0.4101451901452882
Epoch 21/200


100%|██████████| 110/110 [00:00<00:00, 293.42it/s, loss=0.291]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 452.03it/s]


Validation loss: 0.34208696122680393
Epoch 31/200


100%|██████████| 110/110 [00:00<00:00, 296.54it/s, loss=0.285]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 747.51it/s]


Validation loss: 0.3149549067020416
Epoch 41/200


100%|██████████| 110/110 [00:00<00:00, 335.24it/s, loss=0.241]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 766.72it/s]


Validation loss: 0.3011037659432207
Epoch 51/200


100%|██████████| 110/110 [00:00<00:00, 311.07it/s, loss=0.199]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 837.74it/s]


Validation loss: 0.2923602568251746
Epoch 61/200


100%|██████████| 110/110 [00:00<00:00, 316.89it/s, loss=0.197]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 854.67it/s]


Validation loss: 0.2880891114473343
Epoch 71/200


100%|██████████| 110/110 [00:00<00:00, 373.59it/s, loss=0.19]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 822.00it/s]


Validation loss: 0.28373211569019724
Epoch 81/200


100%|██████████| 110/110 [00:00<00:00, 355.44it/s, loss=0.165]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 814.92it/s]


Validation loss: 0.2810156600815909
Epoch 91/200


100%|██████████| 110/110 [00:00<00:00, 366.03it/s, loss=0.15]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 802.39it/s]


Validation loss: 0.2809166721999645
Epoch 101/200


100%|██████████| 110/110 [00:00<00:00, 364.41it/s, loss=0.137]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 836.23it/s]


Validation loss: 0.28089579939842224
Epoch 111/200


100%|██████████| 110/110 [00:00<00:00, 275.72it/s, loss=0.151]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 606.40it/s]


Validation loss: 0.2805795499256679
Epoch 121/200


100%|██████████| 110/110 [00:00<00:00, 328.32it/s, loss=0.141]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 664.05it/s]


Validation loss: 0.2801506370306015
Epoch 131/200


100%|██████████| 110/110 [00:00<00:00, 321.77it/s, loss=0.147]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 672.89it/s]


Validation loss: 0.2796568460762501
Epoch 141/200


100%|██████████| 110/110 [00:00<00:00, 338.91it/s, loss=0.149]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 759.75it/s]


Validation loss: 0.27978487259575296
Epoch 151/200


100%|██████████| 110/110 [00:00<00:00, 313.72it/s, loss=0.138]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 772.42it/s]


Validation loss: 0.27967248750584467
Epoch 161/200


100%|██████████| 110/110 [00:00<00:00, 335.50it/s, loss=0.136]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 814.46it/s]


Validation loss: 0.279873682984284
Epoch 171/200


100%|██████████| 110/110 [00:00<00:00, 349.60it/s, loss=0.142]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 765.45it/s]


Validation loss: 0.27996173075267244
Epoch 181/200


100%|██████████| 110/110 [00:00<00:00, 347.59it/s, loss=0.141]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 509.86it/s]


Validation loss: 0.27999495023063253
Epoch 191/200


100%|██████████| 110/110 [00:00<00:00, 371.03it/s, loss=0.147]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 541.74it/s]


Validation loss: 0.27998711221984457
Best model found at epoch 127 with val loss 0.27938142738171984


In [16]:
# Weakly-labelled datasets on the strong-from-weak split: strong embeddings paired with
# the weak model's soft labels (probabilities) and its raw logits for the same examples.
weakly_labelled_train_dataset, weakly_labelled_val_dataset = (
    data.TensorDataset(
        torch.tensor(strong_embeddings.x_train[inds]),
        weak_labels[inds],
        weak_logits[inds],
    )
    for inds in (st_from_wk_train_inds, st_from_wk_val_inds)
)

In [17]:
# Instantiate the strong-from-weak model: a LogisticRegression over strong embeddings,
# to be trained on the weak model's labels.
num_inputs = strong_embeddings.x_train.shape[1]  # embedding dimension
strong_from_wk_model = LogisticRegression(
    num_inputs,
    strong_embeddings.num_classes,
)

In [18]:
# Train the strong model from the weak model's labels. Model selection uses
# weakly_labelled_val_dataset, i.e. weak labels, not ground truth.
num_epochs, batch_size = 200, 256
lr, weight_decay = 1e-3, 0.0

# Length of the cosine schedule in optimizer steps: epochs × ceil(len(train) / batch_size).
# NOTE(review): assumes DatasetTrainer steps the scheduler once per mini-batch — confirm.
n_iter = num_epochs * len(range(0, len(weakly_labelled_train_dataset), batch_size))

optimizer = optim.Adam(strong_from_wk_model.parameters(), lr=lr, weight_decay=weight_decay)
schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_iter)

# The datasets carry (embedding, weak probabilities, weak logits); use_label_logits=True
# presumably makes the trainer pass the weak logits to the loss — confirm in DatasetTrainer.
strong_from_wk_trainer = DatasetTrainer(
    model=strong_from_wk_model,
    dataset=weakly_labelled_train_dataset,
    val_dataset=weakly_labelled_val_dataset,
    loss_fn=reverse_loss(logistic_bregman),
    optimizer=optimizer,
    scheduler=schedule,
    use_label_logits=True,
)
strong_from_wk_trainer.train(num_epochs=num_epochs, batch_size=batch_size, log_every=10)

Using model device for training: cpu
Using model device for training: cpu
Epoch 1/200


100%|██████████| 110/110 [00:00<00:00, 253.34it/s, loss=2.03]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 261.30it/s]


Validation loss: 1.8740744037287576
Epoch 11/200


100%|██████████| 110/110 [00:00<00:00, 210.94it/s, loss=0.413]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 679.53it/s]


Validation loss: 0.4409416786261967
Epoch 21/200


100%|██████████| 110/110 [00:00<00:00, 288.08it/s, loss=0.339]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 580.88it/s]


Validation loss: 0.3806814860020365
Epoch 31/200


100%|██████████| 110/110 [00:00<00:00, 292.95it/s, loss=0.319]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 588.43it/s]


Validation loss: 0.3634028413466045
Epoch 41/200


100%|██████████| 110/110 [00:00<00:00, 289.21it/s, loss=0.302]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 579.70it/s]


Validation loss: 0.35642346101147787
Epoch 51/200


100%|██████████| 110/110 [00:00<00:00, 316.55it/s, loss=0.267]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 624.88it/s]


Validation loss: 0.352672387446676
Epoch 61/200


100%|██████████| 110/110 [00:00<00:00, 302.36it/s, loss=0.262]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 447.15it/s]


Validation loss: 0.351225702890328
Epoch 71/200


100%|██████████| 110/110 [00:00<00:00, 301.22it/s, loss=0.258]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 670.61it/s]


Validation loss: 0.3508391646402223
Epoch 81/200


100%|██████████| 110/110 [00:00<00:00, 296.00it/s, loss=0.252]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 631.05it/s]


Validation loss: 0.3505018172519548
Epoch 91/200


100%|██████████| 110/110 [00:00<00:00, 289.43it/s, loss=0.235]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 601.46it/s]


Validation loss: 0.3509417080453464
Epoch 101/200


100%|██████████| 110/110 [00:00<00:00, 254.98it/s, loss=0.234]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 661.42it/s]


Validation loss: 0.3522466591426304
Epoch 111/200


100%|██████████| 110/110 [00:00<00:00, 299.67it/s, loss=0.232]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 673.95it/s]


Validation loss: 0.35135123133659363
Epoch 121/200


100%|██████████| 110/110 [00:00<00:00, 283.98it/s, loss=0.221]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 669.18it/s]


Validation loss: 0.35259755487952915
Epoch 131/200


100%|██████████| 110/110 [00:00<00:00, 279.81it/s, loss=0.237]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 668.14it/s]


Validation loss: 0.35254899518830435
Epoch 141/200


100%|██████████| 110/110 [00:00<00:00, 309.41it/s, loss=0.218]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 681.48it/s]


Validation loss: 0.3529810405203274
Epoch 151/200


100%|██████████| 110/110 [00:00<00:00, 294.16it/s, loss=0.232]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 642.82it/s]


Validation loss: 0.3534585641963141
Epoch 161/200


100%|██████████| 110/110 [00:00<00:00, 291.31it/s, loss=0.227]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 638.10it/s]


Validation loss: 0.3535072239381926
Epoch 171/200


100%|██████████| 110/110 [00:00<00:00, 291.29it/s, loss=0.217]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 603.32it/s]


Validation loss: 0.35359590713466915
Epoch 181/200


100%|██████████| 110/110 [00:00<00:00, 291.08it/s, loss=0.223]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 422.76it/s]


Validation loss: 0.35366862054382053
Epoch 191/200


100%|██████████| 110/110 [00:00<00:00, 228.86it/s, loss=0.222]


Validating model...


100%|██████████| 28/28 [00:00<00:00, 631.29it/s]


Validation loss: 0.35368819002594265
Best model found at epoch 73 with val loss 0.3496598643915994


In [19]:
# Reload both embeddings (the weak ones were released earlier) to build the held-out
# test set for the misfit estimate. Each row pairs the strong and the weak embedding of
# the same test image, so both embedding files must list the test images in the same
# order. Equal labels and sizes are a necessary (not sufficient) check of that alignment.
weak_embeddings = load_embeddings(model_name=weak_model_id, dataset_name=dataset_id)
strong_embeddings = load_embeddings(model_name=strong_model_id, dataset_name=dataset_id)
assert weak_embeddings.num_classes == strong_embeddings.num_classes
assert strong_embeddings.x_test.shape[0] == weak_embeddings.x_test.shape[0]
assert np.array_equal(strong_embeddings.y_test, weak_embeddings.y_test), "weak/strong test sets are not aligned"
estimate_misfit_dataset = data.TensorDataset(
    torch.tensor(strong_embeddings.x_test),
    torch.tensor(weak_embeddings.x_test),
    F.one_hot(torch.tensor(weak_embeddings.y_test), num_classes=weak_embeddings.num_classes).to(torch.float32),
)

In [20]:
# Estimate the KL-based misfit / gain terms on the held-out test set, comparing the
# strong-from-weak model with the weak supervisor and the ground-truth strong model.
misfit_estimator = EstimateKLMisfit(
    dataset=estimate_misfit_dataset,
    weak_model=weak_model,
    strong_model=strong_from_wk_model,
    strong_gt_model=strong_gt_model,
)
result = misfit_estimator.estimate(device=torch.device(DeviceType.CPU.value), batch_size=256)
pprint(result)

100%|██████████| 40/40 [00:00<00:00, 128.24it/s, gt_to_st_xe=0.424, gt_to_wk_xe=0.571, stgt_to_st_xe=0.43, stgt_to_wk_xe=0.735, stgt_to_st=0.144, stgt_to_wk=0.449, st_to_wk=0.349, wk_to_st=0.396, gain_xe=0.148, misfit_xe_error=-0.202, stgt_gain_xe=0.306, stgt_misfit_xe_error=-0.0434, stgt_gain=0.306, stgt_misfit_error=-0.0434, st_ent=0.444, wk_ent=0.483, stgt_ent=0.286] 

{'gain_xe__err': 0.017217382788658142,
 'gain_xe__mean': 0.14752988517284393,
 'gain_xe__std': 0.8608691096305847,
 'gt_to_st_xe__err': 0.016732661053538322,
 'gt_to_st_xe__mean': 0.42387655377388,
 'gt_to_st_xe__std': 0.836633026599884,
 'gt_to_wk_xe__err': 0.021549470722675323,
 'gt_to_wk_xe__mean': 0.5714066624641418,
 'gt_to_wk_xe__std': 1.077473521232605,
 'misfit_xe_error__err': 0.01751331053674221,
 'misfit_xe_error__mean': -0.20164009928703308,
 'misfit_xe_error__std': 0.8756654858589172,
 'st_ent__err': 0.008803303353488445,
 'st_ent__mean': 0.4440767168998718,
 'st_ent__std': 0.4401651620864868,
 'st_to_wk__err': 0.011510785669088364,
 'st_to_wk__mean': 0.34917014837265015,
 'st_to_wk__std': 0.5755392909049988,
 'stgt_ent__err': 0.0075825853273272514,
 'stgt_ent__mean': 0.2859863340854645,
 'stgt_ent__std': 0.3791292607784271,
 'stgt_gain__err': 0.013827413320541382,
 'stgt_gain__mean': 0.3057228922843933,
 'stgt_gain__std': 0.6913706660270691,
 'stgt_gain_xe__err': 0.0138274


