In [1]:
from pprint import pprint

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Adam, SGD
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as T

from src.params import DatasetName, DEFAULT_DEVICE, CPU_GENERATOR
from src.datasets import load_dataset, TRANSFORM_RGB, TRANSFORM_BASE, ModelEmbeddingsLabeledDataset
from src.models import MLP, TruncatedMLP, PrimalThreshold, HomogenousMixtureModel, BagOfDecisionBoundaries, MeanThreshold
from src.datasets import SyntheticNormalDataSampler, SyntheticDatasetDataSampler
from src.engine import DatasetClassificationTrainer, SamplingClassificationTrainer, DatasetInference
from src.metrics import Accuracy, KLDivergence, CrossEntropy
from src.measurements import EstimatedLabeledModelLosses, EstimateWeakToStrong, LossSpec, StatSpec

print(DEFAULT_DEVICE)

cuda


In [2]:
dataset_name = DatasetName.CIFAR10

transform = T.Compose([T.ToTensor(), T.Lambda(lambda x: torch.flatten(x)), T.Lambda(lambda x: (x - x.mean()) / x.std()),])

num_classes = 10
target_transform = lambda x: F.one_hot(torch.tensor(x), num_classes=10).to(torch.float32)

train_dataset = load_dataset(dataset_name, split="train", transform=transform, target_transform=target_transform)

Files already downloaded and verified


In [3]:
# Get some stats on the mnist dataset 
print(f"Number of samples: {len(train_dataset)}")
print(f"Number of classes in dataset: {len(train_dataset.classes)}")

tmp_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
sample = next(iter(tmp_loader))

data_sample_dim = sample[0].shape[1:]
print(f"Sample dim: {data_sample_dim}")
print(f"Target shape: {sample[1].shape[1:]}")
print(f"Sample Mean: {sample[0].mean()}")
print(f"Sample Std: {sample[0].std()}")

Number of samples: 50000
Number of classes in dataset: 10
Sample dim: torch.Size([3072])
Target shape: torch.Size([10])
Sample Mean: -2.60770320892334e-08
Sample Std: 1.0


In [4]:
input_dim = data_sample_dim[0]
hidden_dim = 50
target_num_layers = 2
gt_init_model = MLP(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=num_classes,
    num_layers=target_num_layers,
)
representation = gt_init_model.representation

In [5]:
features_optimizer = Adam(gt_init_model.parameters(), lr=1e-3)
loss_fn = CrossEntropy(output_logits=True, label_logits=False)
accuracy_fn = Accuracy(output_logits=True, label_logits=False, hard=True)
features_trainer = DatasetClassificationTrainer(
    gt_init_model,
    optimizer=features_optimizer,
    loss_fn=loss_fn,
    metrics=[accuracy_fn],
    dataset=train_dataset,
)

print(f"Initializing the strong representation on {dataset_name}")
features_trainer.train(num_epochs=5, batch_size=256, average_window=10)

Using model device for training: cpu
Using model device for training: cpu
Initializing the strong representation on cifar10


100%|██████████| 5/5 [00:22<00:00,  4.42s/it, loss=1.28, up_norm=0.0937, grad_norm=1.47, accuracy=0.536]


In [12]:
from math import comb
num_combinations = 3
# Now train primal model
# features_model = BagOfDecisionBoundaries(
#     input_dim=hidden_dim, 
#     output_dim=comb(hidden_dim, num_combinations), 
#     num_states=2 ** num_combinations
# ).prepend(representation, input_dim=input_dim, output_dim=hidden_dim)
features_model = MeanThreshold(
    input_dim=hidden_dim,
    num_combinations=3
).prepend(representation, input_dim=input_dim, output_dim=hidden_dim)

In [13]:
# features_model = PrimalThreshold(thresholds).prepend(gt_representation, input_dim=input_dim, output_dim=hidden_dim)
primal_gt_model = HomogenousMixtureModel(
    features_model=features_model,
    output_dim=num_classes,
    no_features_grad=True,
    use_dual_weights=True
)

In [14]:
# Test training the primal model on the actual data
primal_gt_optimizer = Adam(primal_gt_model.mixture_layer.parameters(), lr=1e-1)
loss_fn = CrossEntropy(output_logits=False, label_logits=False)
accuracy_fn = Accuracy(output_logits=False, label_logits=False)
primal_trainer = DatasetClassificationTrainer(
    primal_gt_model,
    optimizer=primal_gt_optimizer,
    loss_fn=loss_fn,
    metrics=[accuracy_fn],
    dataset=train_dataset,
)

primal_trainer.train(
    num_epochs=2, batch_size=256, average_window=10, update_pbar_every=1
)

Using model device for training: cpu
Using model device for training: cpu


100%|██████████| 2/2 [01:38<00:00, 49.32s/it, loss=1.58, up_norm=1.72, grad_norm=0.0108, accuracy=0.464] 


# Do the weak to strong transfer

1. Train the weak model on the ground truth data
2. Train a primal strong model on the weak model's predictions
3. Also train a dual strong model for comparison

In [15]:
truncation_factor = 0.5
weak_num_layers = 2
weak_hidden_dim = 50
wk_model = TruncatedMLP(
    input_dim=input_dim,
    hidden_dim=weak_hidden_dim,
    output_dim=num_classes,
    num_layers=weak_num_layers,
    truncation_factor=truncation_factor,
)

In [16]:
wk_optimizer = Adam(wk_model.parameters(), lr=1e-3)
loss_fn = CrossEntropy(output_logits=True, label_logits=False)
accuracy_fn = Accuracy(output_logits=True, label_logits=False, hard=True)
wk_trainer = DatasetClassificationTrainer(
    wk_model,
    optimizer=wk_optimizer,
    loss_fn=loss_fn,
    metrics=[accuracy_fn],
    dataset=train_dataset,
)

print(f"Training the weak model on {dataset_name}")
wk_trainer.train(num_epochs=5, batch_size=256, average_window=10)

Using model device for training: cpu
Using model device for training: cpu
Training the weak model on cifar10


100%|██████████| 5/5 [00:20<00:00,  4.19s/it, loss=1.47, up_norm=0.0673, grad_norm=1.46, accuracy=0.475]


In [17]:
primal_st_model = HomogenousMixtureModel(
    features_model=features_model,
    output_dim=num_classes,
    no_features_grad=True,
    use_dual_weights=True,
)

dual_st_model = MLP(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=num_classes,
    num_layers=target_num_layers,
    representation_state_dict=gt_init_model.representation.state_dict(),
)

In [18]:
# Some sanity-checks
assert (
    primal_gt_model.features_model.state_dict().keys()
    == primal_st_model.features_model.state_dict().keys()
)
assert all(
    [
        torch.all(
            primal_gt_model.features_model.state_dict()[k]
            == primal_st_model.features_model.state_dict()[k]
        )
        for k in primal_st_model.features_model.state_dict().keys()
    ]
)

assert (
    dual_st_model.state_dict().keys()
    == gt_init_model.state_dict().keys()
)
assert all(
    [
        torch.all(
            dual_st_model.representation.state_dict()[k]
            == gt_init_model.representation.state_dict()[k]
        )
        for k in dual_st_model.representation.state_dict().keys()
    ]
)

In [19]:
# Make a data sampler for the weak model
wk_model_data_sampler = SyntheticDatasetDataSampler(model=wk_model, dataset=train_dataset, input_dim=input_dim, output_dim=num_classes)

Using model device for synthetic data generation: cpu


In [20]:
# Train the dual strong model on the weak model's outputs
dual_st_optimizer = Adam(dual_st_model.finetune.parameters(), lr=1e-3)
loss_fn = CrossEntropy(output_logits=True, label_logits=True)
accuracy_fn = Accuracy(output_logits=True, label_logits=True, hard=True)
dual_st_trainer = SamplingClassificationTrainer(
    dual_st_model,
    optimizer=dual_st_optimizer,
    loss_fn=loss_fn,
    metrics=[accuracy_fn],
    data_sampler=wk_model_data_sampler,
)

batch_size = 2 ** 10
dual_st_trainer.train(
    num_samples=1000 * batch_size, batch_size=batch_size, average_window=10, update_pbar_every=1
)

Using model device for training: cpu


100%|██████████| 1000/1000 [01:22<00:00, 12.17it/s, loss=1.63, up_norm=0.00283, grad_norm=0.207, accuracy=0.645]


In [21]:
# Train the primal strong model on the weak model's outputs
primal_st_optimizer = Adam(primal_st_model.mixture_layer.parameters(), lr=1e-1)
loss_fn = KLDivergence(output_logits=False, label_logits=True)
accuracy_fn = Accuracy(output_logits=False, label_logits=True, hard=True)
primal_st_trainer = SamplingClassificationTrainer(
    primal_st_model,
    optimizer=primal_st_optimizer,
    loss_fn=loss_fn,
    metrics=[accuracy_fn],
    data_sampler=wk_model_data_sampler,
)

batch_size = 2 ** 10
primal_st_trainer.train(
    num_samples=175 * batch_size, batch_size=batch_size, average_window=10, update_pbar_every=1
)

Using model device for training: cpu


100%|██████████| 175/175 [02:20<00:00,  1.24it/s, loss=0.428, up_norm=1.6, grad_norm=0.00206, accuracy=0.538] 


In [22]:
x, y, y_logits = wk_model_data_sampler.sample(2)
print(torch.sum(primal_gt_model(x),dim=1))
print(primal_gt_model.mixture_layer.num_states)

tensor([1.0000, 1.0000], grad_fn=<SumBackward1>)
8


In [24]:
a = torch.softmax(
    primal_gt_model.mixture_layer.conditional_weights, dim=0).view(
    primal_gt_model.mixture_layer.output_dim, primal_gt_model.mixture_layer.input_dim, primal_gt_model.mixture_layer.num_states
)
print(a.sum(dim=[0,1]))

tensor([1.0006, 1.0007, 1.0006, 1.0009, 1.0008, 1.0006, 1.0009, 1.0008],
       grad_fn=<SumBackward1>)


# Check the misfit amount

In [23]:
strong_models_dict = {
    "dual_gt": gt_init_model,
    "dual_st": dual_st_model,
    "primal_gt": primal_gt_model,
    "primal_st": primal_st_model,
}
weak_models_dict = {
    "wk": wk_model,
}
models_dict = {**strong_models_dict, **weak_models_dict}

losses_list = [
    # WRT the ground truth labels
    LossSpec(
        name1="dual_gt",
        name2=EstimateWeakToStrong.GT,
        loss_fn=CrossEntropy(output_logits=True, label_logits=False),
    ),
    LossSpec(
        name1="dual_st",
        name2=EstimateWeakToStrong.GT,
        loss_fn=CrossEntropy(output_logits=True, label_logits=False),
    ),
    LossSpec(
        name1="primal_gt",
        name2=EstimateWeakToStrong.GT,
        loss_fn=CrossEntropy(output_logits=False, label_logits=False),
    ),
    LossSpec(
        name1="primal_st",
        name2=EstimateWeakToStrong.GT,
        loss_fn=CrossEntropy(output_logits=False, label_logits=False),
    ),
    LossSpec(
        name1="wk",
        name2=EstimateWeakToStrong.GT,
        loss_fn=CrossEntropy(output_logits=True, label_logits=False),
    ),
    LossSpec(
        name1="dual_gt",
        name2=EstimateWeakToStrong.GT,
        loss_fn=Accuracy(output_logits=True, label_logits=False, hard=True),
    ),
    LossSpec(
        name1="dual_st",
        name2=EstimateWeakToStrong.GT,
        loss_fn=Accuracy(output_logits=True, label_logits=False, hard=True),
    ),
    LossSpec(
        name1="primal_gt",
        name2=EstimateWeakToStrong.GT,
        loss_fn=Accuracy(output_logits=False, label_logits=False, hard=True),
    ),
    LossSpec(
        name1="primal_st",
        name2=EstimateWeakToStrong.GT,
        loss_fn=Accuracy(output_logits=False, label_logits=False, hard=True),
    ),
    LossSpec(
        name1="wk",
        name2=EstimateWeakToStrong.GT,
        loss_fn=Accuracy(output_logits=True, label_logits=False, hard=True),
    ),
    # WRT the primal_gt model
    LossSpec( # A sanity check -- should be 0
        name1="primal_gt",
        name2="primal_gt",
        loss_fn=KLDivergence(output_logits=False, label_logits=False),
    ),
    LossSpec(
        name1="primal_gt",
        name2="primal_st",
        loss_fn=KLDivergence(output_logits=False, label_logits=False),
    ),
    LossSpec(
        name1="primal_gt",
        name2="wk",
        loss_fn=KLDivergence(output_logits=False, label_logits=True),
    ),
    LossSpec( # A sanity check -- should be 1
        name1="primal_gt",
        name2="primal_gt",
        loss_fn=Accuracy(output_logits=False, label_logits=False, hard=True),
    ),
    LossSpec(
        name1="primal_gt",
        name2="primal_st",
        loss_fn=Accuracy(output_logits=False, label_logits=False, hard=True),
    ),
    LossSpec(
        name1="primal_gt",
        name2="wk",
        loss_fn=Accuracy(output_logits=False, label_logits=True, hard=True),
    ),
    # WRT the dual_gt model
    LossSpec( # A sanity check -- should be 0
        name1="dual_gt",
        name2="dual_gt",
        loss_fn=KLDivergence(output_logits=True, label_logits=True),
    ),
    LossSpec(
        name1="dual_gt",
        name2="dual_st",
        loss_fn=KLDivergence(output_logits=True, label_logits=True),
    ),
    LossSpec(
        name1="dual_gt",
        name2="wk",
        loss_fn=KLDivergence(output_logits=True, label_logits=True),
    ),
    LossSpec( # A sanity check -- should be 1
        name1="dual_gt",
        name2="dual_gt",
        loss_fn=Accuracy(output_logits=True, label_logits=True, hard=True),
    ),
    LossSpec(
        name1="dual_gt",
        name2="dual_st",
        loss_fn=Accuracy(output_logits=True, label_logits=True, hard=True),
    ),
    LossSpec(
        name1="dual_gt",
        name2="wk",
        loss_fn=Accuracy(output_logits=True, label_logits=True, hard=True),
    ),
    # WRT the primal_st model
    LossSpec( # A sanity check -- should be 0
        name1="primal_st",
        name2="primal_st",
        loss_fn=KLDivergence(output_logits=False, label_logits=False),
    ),
    LossSpec(
        name1="primal_st",
        name2="wk",
        loss_fn=KLDivergence(output_logits=False, label_logits=True),
    ),
    LossSpec( # A sanity check -- should be 1
        name1="primal_st",
        name2="primal_st",
        loss_fn=Accuracy(output_logits=False, label_logits=False, hard=True),
    ),
    LossSpec(
        name1="primal_st",
        name2="wk",
        loss_fn=Accuracy(output_logits=False, label_logits=True, hard=True),
    ),
    # WRT the dual_st model
    LossSpec( # A sanity check -- should be 0
        name1="dual_st",
        name2="dual_st",
        loss_fn=KLDivergence(output_logits=True, label_logits=True),
    ),
    LossSpec(
        name1="dual_st",
        name2="wk",
        loss_fn=KLDivergence(output_logits=True, label_logits=True),
    ),
    LossSpec( # A sanity check -- should be 1
        name1="dual_st",
        name2="dual_st",
        loss_fn=Accuracy(output_logits=True, label_logits=True, hard=True),
    ),
    LossSpec(
        name1="dual_st",
        name2="wk",
        loss_fn=Accuracy(output_logits=True, label_logits=True, hard=True),
    ),    
]

estimator = EstimatedLabeledModelLosses(
    dataset=train_dataset,
    models=models_dict,
    losses=losses_list,
)

results = estimator.estimate(batch_size=2 ** 10)

pprint(results)

Using device: cpu for estimating losses between models


100%|██████████| 49/49 [00:46<00:00,  1.05it/s, dual_gt<-gt_cross_entropy=1.26, dual_st<-gt_cross_entropy=1.4, primal_gt<-gt_cross_entropy=1.58, primal_st<-gt_cross_entropy=1.63, wk<-gt_cross_entropy=1.41, dual_gt<-gt_accuracy=0.554, dual_st<-gt_accuracy=0.517, primal_gt<-gt_accuracy=0.472, primal_st<-gt_accuracy=0.436, wk<-gt_accuracy=0.501, primal_gt<-primal_gt_kl_divergence=9.27e-8, primal_gt<-primal_st_kl_divergence=0.0624, primal_gt<-wk_kl_divergence=0.496, primal_gt<-primal_gt_accuracy=1, primal_gt<-primal_st_accuracy=0.715, primal_gt<-wk_accuracy=0.524, dual_gt<-dual_gt_kl_divergence=2.98e-10, dual_gt<-dual_st_kl_divergence=0.144, dual_gt<-wk_kl_divergence=0.334, dual_gt<-dual_gt_accuracy=1, dual_gt<-dual_st_accuracy=0.75, dual_gt<-wk_accuracy=0.589, primal_st<-primal_st_kl_divergence=2.01e-7, primal_st<-wk_kl_divergence=0.43, primal_st<-primal_st_accuracy=1, primal_st<-wk_accuracy=0.54, dual_st<-dual_st_kl_divergence=-1.99e-10, dual_st<-wk_kl_divergence=0.23, dual_st<-dual_st_a

{'dual_gt<-dual_gt_accuracy__err': 0.0,
 'dual_gt<-dual_gt_accuracy__mean': 1.0,
 'dual_gt<-dual_gt_accuracy__std': 0.0,
 'dual_gt<-dual_gt_kl_divergence__err': 2.1212414047511174e-09,
 'dual_gt<-dual_gt_kl_divergence__mean': 2.983678304424586e-10,
 'dual_gt<-dual_gt_kl_divergence__std': 2.3716199848422548e-07,
 'dual_gt<-dual_st_accuracy__err': 0.003871781751513481,
 'dual_gt<-dual_st_accuracy__mean': 0.7502401471138,
 'dual_gt<-dual_st_accuracy__std': 0.4328783452510834,
 'dual_gt<-dual_st_kl_divergence__err': 0.0009704749099910259,
 'dual_gt<-dual_st_kl_divergence__mean': 0.1443948596715927,
 'dual_gt<-dual_st_kl_divergence__std': 0.10850239545106888,
 'dual_gt<-gt_accuracy__err': 0.004445788450539112,
 'dual_gt<-gt_accuracy__mean': 0.5542400479316711,
 'dual_gt<-gt_accuracy__std': 0.4970542788505554,
 'dual_gt<-gt_cross_entropy__err': 0.009310591965913773,
 'dual_gt<-gt_cross_entropy__mean': 1.2554808855056763,
 'dual_gt<-gt_cross_entropy__std': 1.0409557819366455,
 'dual_gt<-wk_ac




'primal_gt<-primal_st_kl_divergence__mean': 0.10186149179935455,
'primal_gt<-wk_kl_divergence__mean': 0.41966670751571655,
 'primal_st<-wk_kl_divergence__mean': 0.3058554232120514,

In [18]:
0.41966670751571655 - 0.3058554232120514

0.11381128430366516