In [None]:
from pprint import pprint

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Adam, SGD
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T

from src.params import DatasetName, DEFAULT_DEVICE, CPU_GENERATOR
from src.datasets import load_dataset, TRANSFORM_RGB, TRANSFORM_BASE
from src.models import MLP, TruncatedMLP, PrimalThreshold, HomogenousMixtureModel, BagOfDecisionBoundaries
from src.datasets import SyntheticDataSampler
from src.engine import DatasetTrainer, SamplingTrainer, DatasetInference
from src.metrics import cross_entropy_with_logits, logistic_bregman, reverse_loss, accuracy_with_logits, kl_divergence

print(DEFAULT_DEVICE)

cuda


# Initialize the hidden representations

In [2]:
dataset_name = DatasetName.MNIST

transform = T.Compose([T.ToTensor(), T.Lambda(lambda x: torch.flatten(x)), T.Lambda(lambda x: (x - x.mean()) / x.std()),])
num_classes = 10
target_transform = lambda x: F.one_hot(torch.tensor(x), num_classes=10).to(torch.float32)
# num_classes = 2
# target_transform = lambda x: torch.tensor([x <= 4, x > 4]).to(torch.float32)
train_dataset = load_dataset(dataset_name, split="train", transform=transform, target_transform=target_transform)

In [3]:
# Get some stats on the mnist dataset 
print(f"Number of samples: {len(train_dataset)}")
print(f"Number of classes: {len(train_dataset.classes)}")

tmp_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
sample = next(iter(tmp_loader))

data_sample_dim = sample[0].shape[1:]
print(f"Sample dim: {data_sample_dim}")
print(f"Target shape: {sample[1].shape[1:]}")
print(f"Sample Mean: {sample[0].mean()}")
print(f"Sample Std: {sample[0].std()}")
# sample[0]

Number of samples: 60000
Number of classes: 10
Sample dim: torch.Size([784])
Target shape: torch.Size([10])
Sample Mean: -8.089201486427555e-08
Sample Std: 0.9999999403953552


In [4]:
input_dim = data_sample_dim[0]
hidden_dim = 50
target_num_layers = 2
gt_model = MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim= num_classes, num_layers=target_num_layers)
gt_representation = gt_model.representation

In [None]:
gt_features_optimizer = Adam(gt_model.parameters(), lr=1e-3)
gt_features_trainer = DatasetTrainer(gt_model, optimizer=gt_features_optimizer, loss_fn=cross_entropy_with_logits, metrics=[accuracy_with_logits], dataset=train_dataset)

print(f"Initializing the ground truth model representation on {dataset_name}")
gt_features_trainer.train(num_epochs=5, batch_size=256, average_window=10)

Using model device for training: cpu
Using model device for training: cpu
Initializing the ground truth model representation on mnist


100%|██████████| 5/5 [00:44<00:00,  8.97s/it, loss=0.105, up_norm=0.0491, acc=0.967] 


In [None]:
var = 1.0
gt_model.finetune_scale = 1.
gt_data_sampler = SyntheticDataSampler(model=gt_model, input_dim=input_dim, output_dim=num_classes, var=var)
gt_data_sampler.sample(10)

Using model device for synthetic data generation: cpu


(tensor([[-0.2233,  0.5832,  1.9655,  ..., -0.1632,  0.4949, -1.5997],
         [ 0.6048,  1.5956,  0.1239,  ..., -0.4499, -0.4875,  0.5903],
         [ 0.0948,  1.1714,  0.2398,  ...,  0.8020, -0.3249,  0.4773],
         ...,
         [ 1.2762, -0.1854,  0.1299,  ...,  0.5097,  0.3059, -0.3169],
         [ 1.2841,  0.4792, -0.2305,  ..., -1.3562, -0.3889, -1.1014],
         [ 0.5049,  0.0780, -1.0979,  ..., -1.1325,  1.3271, -0.0626]]),
 tensor([[0.0449, 0.1271, 0.1839, 0.1913, 0.0346, 0.1404, 0.0289, 0.1373, 0.0579,
          0.0537],
         [0.0863, 0.0697, 0.1057, 0.3563, 0.0461, 0.1863, 0.0244, 0.0871, 0.0170,
          0.0211],
         [0.0027, 0.5544, 0.0182, 0.0245, 0.0402, 0.2186, 0.0192, 0.0811, 0.0113,
          0.0298],
         [0.0034, 0.1765, 0.0232, 0.3806, 0.0054, 0.2754, 0.0109, 0.0057, 0.1050,
          0.0140],
         [0.0171, 0.1994, 0.3994, 0.0759, 0.0533, 0.1144, 0.0329, 0.0232, 0.0532,
          0.0313],
         [0.0469, 0.0659, 0.1420, 0.1261, 0.0543, 0.3

In [21]:
# Now train primal model
features_model = BagOfDecisionBoundaries(input_dim=hidden_dim, output_dim=hidden_dim * 100, num_states=num_classes) \
    .prepend(gt_representation, input_dim=input_dim, output_dim=hidden_dim)
# features_model = PrimalThreshold(thresholds).prepend(gt_representation, input_dim=input_dim, output_dim=hidden_dim)
gt_primal_model = HomogenousMixtureModel(
    features_model=features_model,
    output_dim=num_classes,
    no_features_grad=True
)


In [None]:
# Test training the primal model on the actual data
gt_primal_optimizer = SGD(gt_primal_model.mixture_layer.parameters(), lr=1)
gt_primal_trainer = DatasetTrainer(gt_primal_model, optimizer=gt_primal_optimizer, loss_fn=cross_entropy_with_logits, metrics=[accuracy_with_logits], dataset=train_dataset)

gt_primal_trainer.train(num_epochs=10, batch_size=256, average_window=10)

Using model device for training: cpu
Using model device for training: cpu


 30%|███       | 3/10 [04:02<09:25, 80.83s/it, loss=1.98, up_norm=0.038, acc=0.473] 


KeyboardInterrupt: 

In [None]:
inference = DatasetInference(gt_primal_model, train_dataset)
res = inference.inference(batch_size=256)
accuracy_with_logits(res, train_dataset.targets)

In [None]:
# Train the primal model to match the ground truth model
gt_primal_optimizer = SGD(gt_primal_model.mixture_layer.parameters(), lr=1)
gt_primal_trainer = SamplingTrainer(gt_primal_model, optimizer=gt_primal_optimizer, loss_fn=reverse_loss(kl_divergence), metrics=[accuracy_with_logits], data_sampler=gt_data_sampler)

print(f"Initializing the ground truth mixture parameters on {dataset_name}")
gt_primal_trainer.train(num_samples=1000000, batch_size=256, average_window=10)

Using model device for training: cpu
Initializing the ground truth mixture parameters on mnist


  5%|▌         | 203/3907 [00:54<16:38,  3.71it/s, loss=0.206, up_norm=0.108] 


RuntimeError: index -1 is out of bounds for dimension 0 with size 5000

In [12]:
truncation_factor = 0.5
weak_num_layers = 2
weak_hidden_dim = 50
wk_init_model = TruncatedMLP(input_dim=input_dim, hidden_dim=weak_hidden_dim, output_dim=len(train_dataset.classes), num_layers=weak_num_layers, truncation_factor=truncation_factor)

In [None]:
wk_optimizer = Adam(wk_init_model.parameters(), lr=1e-3)
wk_trainer = DatasetTrainer(wk_init_model, optimizer=wk_optimizer, loss_fn=cross_entropy_with_logits, dataset=train_dataset, metrics=[accuracy_with_logits])

Using model device for training: cpu
Using model device for training: cpu


In [15]:
print(f"Initializing the weak model on {dataset_name}")
wk_trainer.train(num_epochs=2, batch_size=64, average_window=10)

Initializing the weak model on mnist


100%|██████████| 2/2 [00:08<00:00,  4.45s/it, loss=0.279, grad_norm=1.54, acc=0.925]


# Do the weak to strong transfer

In [None]:
# Setup the ground truth model. We increase confidence via finetune_scale to lower ground-truth entropy
with torch.no_grad():
    gt_primal_model.mixture_layer.conditional_weights.data *= 1e3

In [27]:
torch.softmax(gt_primal_model.mixture_layer.mixture_weights.data, dim=0)

tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [17]:
st_primal_model = HomogenousMixtureModel(
    features_model=features_model,
    output_dim=num_classes,
    no_features_grad=True
)

In [18]:
# Initialize the weak model to use the trained weak representations
wk_model = TruncatedMLP(
    input_dim=input_dim,
    hidden_dim=weak_hidden_dim,
    output_dim=num_classes,
    num_layers=weak_num_layers,
    representation_state_dict=wk_init_model.representation.state_dict(),
    truncation_factor=wk_init_model.truncation_factor,
)

## Train the weak model

In [14]:
# Train the weak model
wk_train_num_samples = 1000000
optimizer = AdamW(wk_model.finetune.parameters(), lr=1e-3)
gt_to_wk_trainer = SamplingTrainer(model=wk_model, optimizer=optimizer, loss_fn=binary_cross_entropy, data_sampler=gt_data_sampler)
gt_to_wk_trainer.train(num_samples=wk_train_num_samples, batch_size=64, average_window=10)

Using model device for training: cpu


100%|██████████| 15625/15625 [00:40<00:00, 386.16it/s, loss=0.28] 


## Train the strong model using the weak

In [None]:
# Wk data sampler
wk_data_sampler = SyntheticDataSampler(model=wk_model, input_dim=input_dim, output_dim=task_output_dim, var=var)
wk_data_sampler.sample(10)

In [19]:


# Train the strong model - this uses the proper loss function for our misfit inequality
st_train_num_samples=2000000
optimizer = AdamW(st_model.finetune.parameters(), lr=1e-3)
st_to_wk_trainer = SamplingTrainer(model=st_model, optimizer=optimizer, loss_fn=reverse_loss(logistic_bregman_binary), data_sampler=wk_data_sampler, use_label_logits=True)
st_to_wk_trainer.train(num_samples=st_train_num_samples, batch_size=64, average_window=10)

# Train the strong model - this uses cross-entropy. It does not obey our misfit inequality, but still seems to work. Why?
# st_train_num_samples=2000000
# optimizer = AdamW(st_model.finetune.parameters(), lr=1e-3)
# st_to_wk_trainer = SamplingTrainer(model=st_model, optimizer=optimizer, loss_fn=binary_cross_entropy, data_sampler=wk_data_sampler)
# st_to_wk_trainer.train(num_samples=st_train_num_samples, batch_size=64, average_window=10)

Using model device for training: cpu


100%|██████████| 31250/31250 [01:29<00:00, 350.91it/s, loss=0.027] 


# Estimate misfit

In [20]:
gt_data_sampler = SyntheticDataSampler(model=gt_model, input_dim=input_dim, output_dim=task_output_dim, var=var)
gt_data_sampler.sample(10)

Using model device for synthetic data generation: cpu


DataBatch(x=tensor([[ 1.0738,  0.3794,  0.6472,  ...,  1.0327, -0.6142,  0.8089],
        [-1.1780, -0.4867,  0.8217,  ..., -0.2647,  0.0819, -1.3179],
        [-0.5802,  0.2627, -1.2674,  ...,  0.5013,  1.5323, -0.6210],
        ...,
        [ 0.2662, -1.5649,  1.1858,  ...,  0.3926,  0.5386, -0.6818],
        [ 0.6524, -0.5471, -0.8080,  ..., -0.5471, -0.9330, -0.4172],
        [-0.0048,  0.6548, -0.2040,  ..., -0.6937, -0.5767,  0.7211]]), y=tensor([[6.5902e-05, 3.0100e-01, 4.6389e-05, 9.9193e-05, 8.0595e-06, 1.7871e-04,
         9.9318e-05, 1.5810e-04, 5.9491e-04, 6.9775e-01],
        [6.6296e-04, 1.0963e-03, 1.9430e-04, 1.9886e-05, 2.6151e-04, 2.9914e-03,
         1.5754e-06, 9.6555e-01, 1.6108e-04, 2.9058e-02],
        [6.2568e-07, 5.7398e-01, 2.0274e-07, 6.0367e-07, 1.3372e-06, 5.4663e-06,
         2.8296e-10, 8.1377e-08, 8.7887e-04, 4.2513e-01],
        [1.2365e-03, 3.7693e-03, 6.1748e-01, 9.8452e-03, 4.8323e-02, 1.4312e-02,
         3.1745e-04, 8.2711e-02, 1.0323e-01, 1.1878e-

In [21]:
from src.measurements import EstimateKLMisfit

misfit_estimator = EstimateKLMisfit(strong_model=st_model, weak_model=wk_model, data_sampler=gt_data_sampler)
result = misfit_estimator.estimate(tolerance=1e-2, batch_size=64, max_samples=1000000)
pprint(result)

  0%|          | 1000000/5093263872 [01:20<113:58:24, 12410.93it/s, gt_to_st_xe=1.68, gt_to_wk_xe=1.92, st_to_wk=0.235, wk_to_st=0.218, misfit_xe=-0.00133, gt_to_st=0.954, gt_to_wk=1.19, gt_ent=0.731, st_ent=1.99, wk_ent=1.81, misfit=-0.00133]   

Reached max samples 1000000.
Errors: {'gt_to_st_xe': 0.0010214232606813312, 'gt_to_wk_xe': 0.0014456016942858696, 'st_to_wk': 0.00034318119287490845, 'wk_to_st': 0.0003097380104009062, 'misfit_xe': 0.000950468354858458, 'gt_to_st': 0.0008139380370266736, 'gt_to_wk': 0.0013162209652364254, 'gt_ent': 0.001011607819236815, 'st_ent': 0.00032130704494193196, 'wk_ent': 0.0005803710082545877, 'misfit': 0.000950468354858458}
{'gt_ent__err': 0.001011607819236815,
 'gt_ent__mean': 0.7306965589523315,
 'gt_ent__std': 0.5058038830757141,
 'gt_to_st__err': 0.0008139380370266736,
 'gt_to_st__mean': 0.9535191059112549,
 'gt_to_st__std': 0.40696901082992554,
 'gt_to_st_xe__err': 0.0010214232606813312,
 'gt_to_st_xe__mean': 1.6842186450958252,
 'gt_to_st_xe__std': 0.5107116103172302,
 'gt_to_wk__err': 0.0013162209652364254,
 'gt_to_wk__mean': 1.1871635913848877,
 'gt_to_wk__std': 0.6581104397773743,
 'gt_to_wk_xe__err': 0.0014456016942858696,
 'gt_to_wk_xe__mean': 1.9178581237792969,
 'gt_to_wk_xe__std


