In [1]:
from pprint import pprint

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import AdamW
import torch.nn.functional as F
import torch.utils.data as data
from torchvision.datasets import MNIST, CIFAR10
import torchvision.transforms as T

from src.params import DATA_DIR
from src.models import MLP, TruncatedMLP
from src.datasets import SyntheticDataSampler
from src.engine import DatasetTrainer, SamplingTrainer
from src.measurements import EstimateKLMisfit
from src.metrics import binary_cross_entropy, logistic_bregman_binary, reverse_loss

# Initialize the hidden representations

In [2]:
# Dataset class used to pre-train the hidden representations.
dataset = CIFAR10

# Images are converted to tensors, flattened to 1-D and standardized per
# sample; integer labels become float32 one-hot vectors of length 10.
init_dataset = dataset(
    root=DATA_DIR,
    download=True,
    transform=T.Compose(
        [
            T.ToTensor(),
            T.Lambda(torch.flatten),
            T.Lambda(lambda img: (img - img.mean()) / img.std()),
        ]
    ),
    target_transform=lambda label: F.one_hot(
        torch.tensor(label), num_classes=10
    ).to(torch.float32),
)

Files already downloaded and verified


In [3]:
# Print basic statistics of the initialization dataset (`dataset`, CIFAR10
# above) and record the flattened input dimension used to size the models.
print(f"Number of samples: {len(init_dataset)}")
print(f"Number of classes: {len(init_dataset.classes)}")

# One batch of size 1 is enough to inspect shapes. Shuffling is disabled so
# the inspected sample (and the printed statistics) are the same on every run.
tmp_loader = data.DataLoader(init_dataset, batch_size=1, shuffle=False)
sample = next(iter(tmp_loader))

# sample[0] holds the inputs and sample[1] the targets; axis 1 of the inputs
# is the per-sample feature dimension after the flatten transform.
data_sample_dim = sample[0].shape[1]
print(f"Sample dim: {data_sample_dim}")
print(f"Target shape: {sample[1].shape}")
# Per-sample standardization in the transform should give mean ~0 and std ~1.
print(f"Sample Mean: {sample[0].mean()}")
print(f"Sample Std: {sample[0].std()}")

Number of samples: 50000
Number of classes: 10
Sample dim: 3072
Target shape: torch.Size([1, 10])
Sample Mean: -1.514951435410694e-07
Sample Std: 1.0


In [4]:
# Seed the global PyTorch RNG before the first model is built so that weight
# initialization (and any later sampling that draws from the global generator)
# is repeatable under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Network that is pre-trained on the labeled dataset. Its weights are reused
# further down to build the ground-truth and strong models.
input_dim = data_sample_dim
hidden_dim = 50
target_num_layers = 2
init_model = MLP(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=len(init_dataset.classes),
    num_layers=target_num_layers,
)



In [None]:
# Weak network that is pre-trained on the same dataset. It has the same depth
# and width as the target network, but is a TruncatedMLP built with
# truncation_factor=0.5.
truncation_factor = 0.5
weak_num_layers = 2
weak_hidden_dim = 50
wk_init_model = TruncatedMLP(
    input_dim=input_dim,
    hidden_dim=weak_hidden_dim,
    output_dim=len(init_dataset.classes),
    num_layers=weak_num_layers,
    truncation_factor=truncation_factor,
)

In [5]:
# Optimizer and trainer used to pre-train `init_model` on the labeled dataset.
gt_optimizer = AdamW(init_model.parameters(), lr=1e-3)
gt_trainer = DatasetTrainer(
    init_model,
    optimizer=gt_optimizer,
    loss_fn=binary_cross_entropy,
    dataset=init_dataset,
)

Using model device for training: cpu
Using model device for training: cpu
Using model device for training: cpu
Using model device for training: cpu


In [None]:
# Pre-train the target network for two epochs on the labeled dataset.
print(f"Initializing the ground truth model on {dataset.__name__}")
gt_trainer.train(
    num_epochs=2,
    batch_size=64,
    average_window=10,
)

In [None]:
# Optimizer and trainer used to pre-train `wk_init_model` on the labeled dataset.
wk_optimizer = AdamW(wk_init_model.parameters(), lr=1e-3)
wk_trainer = DatasetTrainer(
    wk_init_model,
    optimizer=wk_optimizer,
    loss_fn=binary_cross_entropy,
    dataset=init_dataset,
)

In [6]:
# Pre-train the weak network for two epochs on the labeled dataset.
print(f"Initializing the weak model on {dataset.__name__}")
wk_trainer.train(
    num_epochs=2,
    batch_size=64,
    average_window=10,
)

Initializing the ground truth model on CIFAR10
Epoch 1/2


100%|██████████| 782/782 [00:06<00:00, 124.01it/s, loss=0.236]


Epoch 2/2


100%|██████████| 782/782 [00:06<00:00, 122.69it/s, loss=0.235]


Initializing the weak model on CIFAR10
Epoch 1/2


100%|██████████| 782/782 [00:06<00:00, 123.74it/s, loss=0.254]


Epoch 2/2


100%|██████████| 782/782 [00:06<00:00, 126.15it/s, loss=0.245]


# Do the weak-to-strong transfer

In [None]:
# Ground-truth model: same architecture as the pre-trained network, with all
# of its weights copied over. finetune_scale=10.0 raises the model's
# confidence, which lowers the entropy of the ground-truth labels.
task_output_dim = len(init_dataset.classes)
gt_model = MLP(
    input_dim=input_dim,
    output_dim=task_output_dim,
    hidden_dim=hidden_dim,
    num_layers=target_num_layers,
    finetune_scale=10.0,
)
gt_model.load_state_dict(init_model.state_dict())

In [17]:
# Strong model: same architecture as the pre-trained network. Only the
# representation weights of `init_model` are handed over.
st_model = MLP(
    input_dim=input_dim,
    output_dim=task_output_dim,
    hidden_dim=hidden_dim,
    num_layers=target_num_layers,
    representation_state_dict=init_model.representation.state_dict(),
)

In [18]:
# Sanity checks: the ground-truth and strong models must share identical
# representation weights, while their finetune heads must differ.
gt_repr = gt_model.representation.state_dict()
st_repr = st_model.representation.state_dict()

assert gt_repr.keys() == st_repr.keys()
assert all(bool(torch.all(gt_repr[name] == st_repr[name])) for name in gt_repr)
assert torch.any(gt_model.finetune.weight != st_model.finetune.weight)

In [12]:
# Weak model for the transfer experiment. It reuses the representation weights
# and the truncation factor of the pre-trained weak network.
wk_model = TruncatedMLP(
    input_dim=input_dim,
    output_dim=task_output_dim,
    hidden_dim=weak_hidden_dim,
    num_layers=weak_num_layers,
    truncation_factor=wk_init_model.truncation_factor,
    representation_state_dict=wk_init_model.representation.state_dict(),
)

In [13]:
# Sampler that produces synthetic (x, y) batches labeled by the ground-truth
# model. A small batch is drawn as a smoke test and displayed.
var = 1.0
gt_data_sampler = SyntheticDataSampler(
    model=gt_model,
    input_dim=input_dim,
    output_dim=task_output_dim,
    var=var,
)
gt_data_sampler.sample(10)

Using model device for synthetic data generation: cpu


DataBatch(x=tensor([[ 1.1158, -0.7585,  0.2394,  ...,  1.0146,  0.0405,  0.9326],
        [-0.2947, -1.6133,  0.7281,  ...,  1.4950,  0.4876,  1.7895],
        [ 0.3170, -0.2431,  1.2137,  ..., -0.9399, -0.0925, -0.3383],
        ...,
        [-0.8848, -0.8457,  1.9605,  ..., -1.4294, -0.3669,  0.3277],
        [ 1.6444, -1.0170,  2.4005,  ..., -0.9120, -1.2872,  0.3513],
        [-1.1767,  0.1469,  0.0938,  ..., -0.2114,  1.1745, -0.0937]]), y=tensor([[9.8536e-01, 4.2823e-03, 1.0059e-02, 3.1852e-07, 1.3077e-04, 7.5306e-05,
         1.2698e-10, 3.9368e-05, 3.3150e-06, 5.0069e-05],
        [5.9405e-01, 2.3071e-02, 1.6340e-02, 3.3485e-02, 1.3879e-02, 2.0578e-01,
         4.2260e-02, 7.1515e-04, 6.6332e-02, 4.0903e-03],
        [5.4928e-04, 2.5970e-06, 5.6081e-04, 3.3222e-05, 1.4451e-04, 1.8800e-05,
         5.5667e-06, 7.1527e-05, 9.9846e-01, 1.5498e-04],
        [1.1737e-02, 3.9084e-01, 1.1168e-05, 6.8631e-03, 9.0081e-06, 1.5769e-02,
         5.0055e-04, 1.3715e-08, 5.6841e-01, 5.8623e-

## Train the weak model

In [14]:
# Fit only the weak model's finetune head on samples labeled by the
# ground-truth model.
wk_train_num_samples = 1_000_000
optimizer = AdamW(wk_model.finetune.parameters(), lr=1e-3)
gt_to_wk_trainer = SamplingTrainer(
    model=wk_model,
    optimizer=optimizer,
    loss_fn=binary_cross_entropy,
    data_sampler=gt_data_sampler,
)
gt_to_wk_trainer.train(
    num_samples=wk_train_num_samples,
    batch_size=64,
    average_window=10,
)

Using model device for training: cpu


100%|██████████| 15625/15625 [00:40<00:00, 386.16it/s, loss=0.28] 


## Train the strong model using the weak model

In [None]:
# Sampler that produces synthetic batches labeled by the (now trained) weak
# model. A small batch is drawn as a smoke test and displayed.
wk_data_sampler = SyntheticDataSampler(
    model=wk_model,
    input_dim=input_dim,
    output_dim=task_output_dim,
    var=var,
)
wk_data_sampler.sample(10)

In [19]:


# Fit only the strong model's finetune head on samples labeled by the weak
# model. reverse_loss(logistic_bregman_binary) is the loss that matches our
# misfit inequality.
st_train_num_samples = 2_000_000
optimizer = AdamW(st_model.finetune.parameters(), lr=1e-3)
st_to_wk_trainer = SamplingTrainer(
    model=st_model,
    optimizer=optimizer,
    loss_fn=reverse_loss(logistic_bregman_binary),
    data_sampler=wk_data_sampler,
    use_label_logits=True,
)
st_to_wk_trainer.train(
    num_samples=st_train_num_samples,
    batch_size=64,
    average_window=10,
)

# Alternative kept for reference: train the strong model with cross-entropy.
# It does not obey our misfit inequality, but still seems to work. Why?
# st_train_num_samples = 2_000_000
# optimizer = AdamW(st_model.finetune.parameters(), lr=1e-3)
# st_to_wk_trainer = SamplingTrainer(model=st_model, optimizer=optimizer, loss_fn=binary_cross_entropy, data_sampler=wk_data_sampler)
# st_to_wk_trainer.train(num_samples=st_train_num_samples, batch_size=64, average_window=10)

Using model device for training: cpu


100%|██████████| 31250/31250 [01:29<00:00, 350.91it/s, loss=0.027] 


# Estimate misfit

In [20]:
# Build a fresh ground-truth sampler for the misfit estimate, with the same
# settings as the one used for training. A small batch is drawn and displayed.
gt_data_sampler = SyntheticDataSampler(
    model=gt_model,
    input_dim=input_dim,
    output_dim=task_output_dim,
    var=var,
)
gt_data_sampler.sample(10)

Using model device for synthetic data generation: cpu


DataBatch(x=tensor([[ 1.0738,  0.3794,  0.6472,  ...,  1.0327, -0.6142,  0.8089],
        [-1.1780, -0.4867,  0.8217,  ..., -0.2647,  0.0819, -1.3179],
        [-0.5802,  0.2627, -1.2674,  ...,  0.5013,  1.5323, -0.6210],
        ...,
        [ 0.2662, -1.5649,  1.1858,  ...,  0.3926,  0.5386, -0.6818],
        [ 0.6524, -0.5471, -0.8080,  ..., -0.5471, -0.9330, -0.4172],
        [-0.0048,  0.6548, -0.2040,  ..., -0.6937, -0.5767,  0.7211]]), y=tensor([[6.5902e-05, 3.0100e-01, 4.6389e-05, 9.9193e-05, 8.0595e-06, 1.7871e-04,
         9.9318e-05, 1.5810e-04, 5.9491e-04, 6.9775e-01],
        [6.6296e-04, 1.0963e-03, 1.9430e-04, 1.9886e-05, 2.6151e-04, 2.9914e-03,
         1.5754e-06, 9.6555e-01, 1.6108e-04, 2.9058e-02],
        [6.2568e-07, 5.7398e-01, 2.0274e-07, 6.0367e-07, 1.3372e-06, 5.4663e-06,
         2.8296e-10, 8.1377e-08, 8.7887e-04, 4.2513e-01],
        [1.2365e-03, 3.7693e-03, 6.1748e-01, 9.8452e-03, 4.8323e-02, 1.4312e-02,
         3.1745e-04, 8.2711e-02, 1.0323e-01, 1.1878e-

In [21]:
# Estimate the misfit between the strong and weak models on data drawn from
# the ground-truth sampler, then pretty-print the resulting statistics.
# max_samples caps the number of samples drawn; tolerance presumably sets the
# target estimation error (confirm against EstimateKLMisfit.estimate).
misfit_estimator = EstimateKLMisfit(
    strong_model=st_model,
    weak_model=wk_model,
    data_sampler=gt_data_sampler,
)
result = misfit_estimator.estimate(
    tolerance=1e-2,
    batch_size=64,
    max_samples=1_000_000,
)
pprint(result)

  0%|          | 1000000/5093263872 [01:20<113:58:24, 12410.93it/s, gt_to_st_xe=1.68, gt_to_wk_xe=1.92, st_to_wk=0.235, wk_to_st=0.218, misfit_xe=-0.00133, gt_to_st=0.954, gt_to_wk=1.19, gt_ent=0.731, st_ent=1.99, wk_ent=1.81, misfit=-0.00133]   

Reached max samples 1000000.
Errors: {'gt_to_st_xe': 0.0010214232606813312, 'gt_to_wk_xe': 0.0014456016942858696, 'st_to_wk': 0.00034318119287490845, 'wk_to_st': 0.0003097380104009062, 'misfit_xe': 0.000950468354858458, 'gt_to_st': 0.0008139380370266736, 'gt_to_wk': 0.0013162209652364254, 'gt_ent': 0.001011607819236815, 'st_ent': 0.00032130704494193196, 'wk_ent': 0.0005803710082545877, 'misfit': 0.000950468354858458}
{'gt_ent__err': 0.001011607819236815,
 'gt_ent__mean': 0.7306965589523315,
 'gt_ent__std': 0.5058038830757141,
 'gt_to_st__err': 0.0008139380370266736,
 'gt_to_st__mean': 0.9535191059112549,
 'gt_to_st__std': 0.40696901082992554,
 'gt_to_st_xe__err': 0.0010214232606813312,
 'gt_to_st_xe__mean': 1.6842186450958252,
 'gt_to_st_xe__std': 0.5107116103172302,
 'gt_to_wk__err': 0.0013162209652364254,
 'gt_to_wk__mean': 1.1871635913848877,
 'gt_to_wk__std': 0.6581104397773743,
 'gt_to_wk_xe__err': 0.0014456016942858696,
 'gt_to_wk_xe__mean': 1.9178581237792969,
 'gt_to_wk_xe__std


