In [1]:
from pprint import pprint

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Adam, SGD
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T

from src.params import DatasetName, DEFAULT_DEVICE, CPU_GENERATOR
from src.datasets import load_dataset, TRANSFORM_RGB, TRANSFORM_BASE
from src.models import MLP, TruncatedMLP, PrimalThreshold, HomogenousMixtureModel, BagOfDecisionBoundaries
from src.datasets import SyntheticNormalDataSampler, SyntheticDatasetDataSampler
from src.engine import DatasetClassificationTrainer, SamplingClassificationTrainer, DatasetInference
from src.metrics import Accuracy, KLDivergence, CrossEntropy
from src.measurements import EstimatedLabeledModelLosses, LossSpec

# Record in the cell output which device src.params exposes as DEFAULT_DEVICE.
print(DEFAULT_DEVICE)

cpu


  return torch._C._cuda_getDeviceCount() > 0


In [2]:
dataset_name = DatasetName.CIFAR10

# Input pipeline: image -> tensor -> flat vector -> per-sample standardisation.
# Each sample is shifted and scaled by its OWN mean/std, not by dataset-wide statistics.
transform = T.Compose([
    T.ToTensor(),
    T.Lambda(lambda x: torch.flatten(x)),
    T.Lambda(lambda x: (x - x.mean()) / x.std()),
])

# Use below to test multi-classification problems
num_classes = 10
# One-hot float targets. The width is taken from `num_classes` (instead of a second
# hard-coded 10) so the targets and the model's output_dim cannot drift apart.
target_transform = lambda x: F.one_hot(torch.tensor(x), num_classes=num_classes).to(torch.float32)

# Use below to test binary-classification problems
# num_classes = 2
# target_transform = lambda x: torch.tensor([x <= 4, x > 4]).to(torch.float32)

train_dataset = load_dataset(dataset_name, split="train", transform=transform, target_transform=target_transform)

Files already downloaded and verified


In [3]:
# Get some stats on the loaded dataset (CIFAR10 in this run, see `dataset_name`)
print(f"Number of samples: {len(train_dataset)}")
print(f"Number of classes in dataset: {len(train_dataset.classes)}")

# Draw one random sample through a throwaway loader; with batch_size=1 the leading
# axis is the batch axis, which the `[1:]` slices below strip off.
tmp_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
sample = next(iter(tmp_loader))

# `data_sample_dim` is reused by later cells to size the model input.
data_sample_dim = sample[0].shape[1:]
print(f"Sample dim: {data_sample_dim}")
print(f"Target shape: {sample[1].shape[1:]}")
# Mean/std of this one sample only — a check on the per-sample standardisation in `transform`.
print(f"Sample Mean: {sample[0].mean()}")
print(f"Sample Std: {sample[0].std()}")

Number of samples: 50000
Number of classes in dataset: 10
Sample dim: torch.Size([3072])
Target shape: torch.Size([10])
Sample Mean: -1.738468853318409e-07
Sample Std: 1.0


In [4]:
# Ground-truth network (registered as "dual" in the evaluation cell): an MLP whose
# input width is the flattened sample size and whose output width is the class count.
input_dim = data_sample_dim[0]
hidden_dim = 50
target_num_layers = 2
model = MLP(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=num_classes,
    num_layers=target_num_layers,
)
# Handle on the MLP's representation; a later cell prepends it to the primal feature model.
representation = model.representation

In [5]:
# Fit the MLP on the labelled dataset.
# The metric flags declare the model output as logits and the (one-hot) labels as non-logits.
features_optimizer = Adam(model.parameters(), lr=1e-3)
loss_fn = CrossEntropy(output_logits=True, label_logits=False)
accuracy_fn = Accuracy(output_logits=True, label_logits=False, hard=True)
features_trainer = DatasetClassificationTrainer(
    model,
    optimizer=features_optimizer,
    loss_fn=loss_fn,
    metrics=[accuracy_fn],
    dataset=train_dataset,
)

print(f"Initializing the ground truth model representation on {dataset_name}")
features_trainer.train(
    num_epochs=5,
    batch_size=256,
    average_window=10,
)

Using model device for training: cpu
Using model device for training: cpu
Initializing the ground truth model representation on cifar10


  0%|          | 0/5 [00:04<?, ?it/s, loss=1.61, up_norm=0.0956, accuracy=0.429]


KeyboardInterrupt: 

# Test the Primal Model by minimizing Cross-Entropy

In [6]:
# Build the feature extractor for the primal model (nothing is trained in this cell):
# a BagOfDecisionBoundaries stage with the MLP's `representation` prepended to it,
# so the combined module takes `input_dim`-wide inputs.
# NOTE(review): the 1000x multiplier on output_dim and num_states=10 are unexplained
# magic numbers — consider naming them in a config cell.
features_model = BagOfDecisionBoundaries(
    input_dim=hidden_dim, 
    output_dim=hidden_dim * 1000, 
    num_states=10
).prepend(representation, input_dim=input_dim, output_dim=hidden_dim)

In [7]:

# Primal model: a HomogenousMixtureModel over `features_model` with one output per class.
# no_features_grad=True presumably keeps the feature extractor frozen — confirm in src.models.
#
# Alternative feature model kept for reference only; `thresholds` and `gt_representation`
# are not defined in any cell above, so this line cannot be re-enabled as-is:
# features_model = PrimalThreshold(thresholds).prepend(gt_representation, input_dim=input_dim, output_dim=hidden_dim)
primal_model = HomogenousMixtureModel(
    features_model=features_model,
    output_dim=num_classes,
    no_features_grad=True,
    use_dual_weights=True
)


In [8]:
# Test training the primal model on the actual data.
# Only the mixture layer's parameters are handed to the optimizer.
primal_optimizer = Adam(primal_model.mixture_layer.parameters(), lr=1e-1)
# Both the primal model's outputs and the one-hot targets are declared as non-logits.
loss_fn = CrossEntropy(output_logits=False, label_logits=False)
# hard=True is passed explicitly, matching every other Accuracy metric in this notebook,
# so the training-time accuracy is computed the same way as the later evaluation of
# this model against the ground-truth labels.
accuracy_fn = Accuracy(output_logits=False, label_logits=False, hard=True)
primal_trainer = DatasetClassificationTrainer(
    primal_model,
    optimizer=primal_optimizer,
    loss_fn=loss_fn,
    metrics=[accuracy_fn],
    dataset=train_dataset,
)

primal_trainer.train(num_epochs=1, batch_size=256, average_window=10, update_pbar_every=1)

Using model device for training: cpu
Using model device for training: cpu


  0%|          | 0/1 [00:04<?, ?it/s, loss=2.23, up_norm=136, accuracy=0.301]


KeyboardInterrupt: 

# Test the M-Projection of the Model onto the Primal Model

In [9]:
# var = 1.0
# NOTE(review): this mutates `model` in place; the meaning of `finetune_scale` is defined
# in src.models and is not visible here — confirm 1.0 is the intended value.
model.finetune_scale = 1.
# Sampler built from the trained MLP and the training dataset; presumably it pairs dataset
# inputs with the model's predictions for them — verify against src.datasets.
model_data_sampler = SyntheticDatasetDataSampler(model=model, dataset=train_dataset, input_dim=input_dim, output_dim=num_classes)
# Draw a small batch and let the notebook display it as a quick sanity check.
model_data_sampler.sample(10)

Using model device for synthetic data generation: cpu


(tensor([[ 0.3920,  0.5539,  0.5308,  ...,  1.1438,  1.1091,  1.1785],
         [ 1.3750,  1.3224,  1.2961,  ..., -1.6500, -1.8604, -2.0708],
         [-0.5653, -0.7325, -0.6025,  ...,  0.8095,  0.8653,  1.1997],
         ...,
         [-0.6156, -0.6317, -0.6479,  ...,  1.9398,  1.6325,  1.3899],
         [ 0.4133,  0.4395,  0.4395,  ...,  0.3870,  0.3870,  0.0979],
         [ 1.9600,  1.9749,  1.9155,  ..., -1.7519, -1.7519, -1.7519]]),
 tensor([[2.2103e-03, 1.7470e-01, 6.2018e-03, 6.5452e-02, 3.0843e-03, 3.9486e-02,
          9.0126e-02, 1.0146e-02, 4.3374e-03, 6.0425e-01],
         [5.3553e-03, 1.4979e-03, 1.5126e-01, 2.9942e-01, 2.1223e-01, 5.5480e-02,
          2.6282e-01, 7.6951e-03, 7.0772e-04, 3.5404e-03],
         [5.2640e-02, 3.8634e-01, 1.9198e-02, 5.2352e-02, 3.6389e-02, 2.8441e-02,
          5.3117e-02, 1.1704e-02, 1.2137e-01, 2.3845e-01],
         [2.3949e-03, 1.3427e-04, 9.3040e-02, 5.8745e-03, 6.9912e-01, 2.5159e-02,
          1.6857e-02, 1.5727e-01, 2.3971e-06, 1.5518e

In [10]:
# M-projection: a fresh primal model (same architecture and shared `features_model` as
# `primal_model`) that is trained to mimic `model` instead of the ground-truth labels.
primal_mimic_model = HomogenousMixtureModel(
    features_model=features_model,
    output_dim=num_classes,
    no_features_grad=True,
    use_dual_weights=True
)

# Train only the mixture layer, on batches drawn from `model_data_sampler` (not on the
# labelled dataset), minimising a KL divergence. The flags declare the mimic model's
# output as non-logits and the sampler's targets as logits.
primal_mimic_optimizer = Adam(primal_mimic_model.mixture_layer.parameters(), lr=1e-1)
loss_fn = KLDivergence(output_logits=False, label_logits=True)
accuracy_fn = Accuracy(output_logits=False, label_logits=True, hard=True)
primal_mimic_trainer = SamplingClassificationTrainer(primal_mimic_model, optimizer=primal_mimic_optimizer, loss_fn=loss_fn, metrics=[accuracy_fn], data_sampler=model_data_sampler)

# 158720 samples = 155 batches of 1024.
primal_mimic_trainer.train(num_samples=158720, batch_size=1024, average_window=10, update_pbar_every=1)


Using model device for training: cpu


  2%|▏         | 3/155 [00:08<07:04,  2.79s/it, loss=1.05, up_norm=160, accuracy=0.209] 


KeyboardInterrupt: 

In [11]:

# Second mimic model, built with use_dual_weights=False. Its freshly created mixture layer
# is immediately replaced by the `to_primal()` conversion of the first mimic model's
# mixture layer, so this cell depends on `primal_mimic_model` from the previous cell.
primal_mimic_model_2 = HomogenousMixtureModel(
    features_model=features_model,
    output_dim=num_classes,
    no_features_grad=True,
    use_dual_weights=False
)
primal_mimic_model_2.mixture_layer = primal_mimic_model.mixture_layer.to_primal()

# Continue the mimic training on sampler batches (not on the labelled dataset), this time
# with plain SGD at a smaller learning rate and the same KL / accuracy configuration.
primal_mimic_optimizer_2 = SGD(primal_mimic_model_2.mixture_layer.parameters(), lr=1e-2)
loss_fn = KLDivergence(output_logits=False, label_logits=True)
accuracy_fn = Accuracy(output_logits=False, label_logits=True, hard=True)
primal_mimic_trainer_2 = SamplingClassificationTrainer(primal_mimic_model_2, optimizer=primal_mimic_optimizer_2, loss_fn=loss_fn, metrics=[accuracy_fn], data_sampler=model_data_sampler)

# 411648 samples = 402 batches of 1024.
primal_mimic_trainer_2.train(num_samples=411648, batch_size=1024, average_window=10, update_pbar_every=1)

Using model device for training: cpu


  0%|          | 0/402 [00:00<?, ?it/s]

  0%|          | 2/402 [00:10<34:34,  5.19s/it, loss=0.649, up_norm=0.00531, accuracy=0.313]


KeyboardInterrupt: 

In [11]:
# Models under comparison, keyed by the names the LossSpec entries refer to.
# NOTE(review): `primal_mimic_model_2` is not registered here — confirm that is intended.
models_dict = {
    "dual": model,
    "primal": primal_model,
    "primal_proj": primal_mimic_model,
}
# Each LossSpec scores `name1` against `name2` (reported as "name1<-name2_<metric>").
# The *_logits flags follow the convention used throughout this notebook: the MLP ("dual")
# is declared as logits, the primal models and the ground-truth labels as non-logits.
losses_list = [
    # Cross-entropy of each model against the ground-truth labels.
    LossSpec(name1="dual", name2=EstimatedLabeledModelLosses.GT, loss_fn=CrossEntropy(output_logits=True, label_logits=False)),
    LossSpec(name1="primal", name2=EstimatedLabeledModelLosses.GT, loss_fn=CrossEntropy(output_logits=False, label_logits=False)),
    LossSpec(name1="primal_proj", name2=EstimatedLabeledModelLosses.GT, loss_fn=CrossEntropy(output_logits=False, label_logits=False)),
    # Accuracy of each model against the ground-truth labels.
    LossSpec(name1="dual", name2=EstimatedLabeledModelLosses.GT, loss_fn=Accuracy(output_logits=True, label_logits=False, hard=True)),
    LossSpec(name1="primal", name2=EstimatedLabeledModelLosses.GT, loss_fn=Accuracy(output_logits=False, label_logits=False, hard=True)),
    LossSpec(name1="primal_proj", name2=EstimatedLabeledModelLosses.GT, loss_fn=Accuracy(output_logits=False, label_logits=False, hard=True)),
    # Pairwise KL divergence between models.
    LossSpec(name1="primal_proj", name2="dual", loss_fn=KLDivergence(output_logits=False, label_logits=True)),
    LossSpec(name1="primal", name2="primal_proj", loss_fn=KLDivergence(output_logits=False, label_logits=False)),
    LossSpec(name1="primal", name2="dual", loss_fn=KLDivergence(output_logits=False, label_logits=True)),
    # Pairwise agreement (accuracy of one model against another model's predictions).
    LossSpec(name1="primal", name2="dual", loss_fn=Accuracy(output_logits=False, label_logits=True, hard=True)),
    LossSpec(name1="primal", name2="primal_proj", loss_fn=Accuracy(output_logits=False, label_logits=False, hard=True)),
    LossSpec(name1="primal_proj", name2="dual", loss_fn=Accuracy(output_logits=False, label_logits=True, hard=True)),
]

# Evaluate every spec over the training dataset in batches of 1024.
estimator = EstimatedLabeledModelLosses(
    dataset=train_dataset,
    models=models_dict,
    losses=losses_list,
)

results = estimator.estimate(batch_size=1024)

pprint(results)

Using device: cpu for estimating losses between models


 29%|██▊       | 14/49 [00:40<01:42,  2.92s/it, dual<-gt_cross_entropy=1.52, primal<-gt_cross_entropy=2.11, primal_proj<-gt_cross_entropy=2.2, dual<-gt_accuracy=0.46, primal<-gt_accuracy=0.39, primal_proj<-gt_accuracy=0.239, primal_proj<-dual_kl_divergence=0.922, primal<-primal_proj_kl_divergence=0.00644, primal<-dual_kl_divergence=0.833, primal<-dual_accuracy=0.602, primal<-primal_proj_accuracy=0.373, primal_proj<-dual_accuracy=0.317]  


KeyboardInterrupt: 

In [25]:
# Side-by-side check of the three models on a single sampled input (`a[0]` holds the inputs).
# NOTE(review): the recorded outputs are not directly comparable — the MLP prints raw
# scores while the two primal models print rows that look like probabilities; apply a
# softmax to the MLP output when comparing by eye.
a = model_data_sampler.sample(1)
print(model(a[0]))
print(primal_model(a[0]))
print(primal_mimic_model(a[0]))

tensor([[-1.0523, -0.7977, -0.1137,  0.6376,  0.2728,  1.7905,  0.2125, -0.0516,
         -1.6128, -2.2975]], grad_fn=<MulBackward0>)
tensor([[0.0851, 0.0958, 0.1121, 0.1139, 0.1042, 0.1082, 0.0997, 0.1009, 0.0909,
         0.0892]], grad_fn=<ViewBackward0>)
tensor([[0.0950, 0.0927, 0.1072, 0.1080, 0.1057, 0.1062, 0.1014, 0.1010, 0.0913,
         0.0916]], grad_fn=<ViewBackward0>)


In [28]:
# Manual check of the mimic objective on a fresh batch of 10 samples.
# Note: this rebinds the notebook-global `loss_fn` used by the training cells.
loss_fn = KLDivergence(output_logits=False, label_logits=True)

data = model_data_sampler.sample(10)
# The sampler returns a tuple; its length and the inputs (element 0) are printed for inspection.
print(len(data))
print(data[0])

# KL of the mimic model's output against the MLP's output on the same inputs;
# as the last expression, the result is displayed by the notebook.
loss_fn(primal_mimic_model(data[0]), model(data[0]))

3
tensor([[-0.4317, -0.4317, -0.4317,  ..., -0.4317, -0.4317, -0.4317],
        [-0.3781, -0.3781, -0.3781,  ..., -0.3781, -0.3781, -0.3781],
        [-0.5074, -0.5074, -0.5074,  ..., -0.5074, -0.5074, -0.5074],
        ...,
        [-0.3869, -0.3869, -0.3869,  ..., -0.3869, -0.3869, -0.3869],
        [-0.4537, -0.4537, -0.4537,  ..., -0.4537, -0.4537, -0.4537],
        [-0.2878, -0.2878, -0.2878,  ..., -0.2878, -0.2878, -0.2878]])


tensor([5.9647, 4.2375, 5.4469, 7.0639, 9.5008, 7.3328, 5.4672, 8.2371, 7.4951,
        7.1459], grad_fn=<SubBackward0>)

In [29]:
# Display the mimic model's output for the batch sampled in the previous cell (`data`).
primal_mimic_model(data[0])

tensor([[0.1004, 0.0890, 0.1463, 0.1173, 0.1151, 0.0707, 0.1007, 0.0834, 0.0971,
         0.0801],
        [0.0933, 0.1242, 0.1263, 0.0854, 0.0869, 0.0933, 0.0928, 0.1153, 0.1019,
         0.0806],
        [0.0880, 0.0825, 0.0952, 0.1465, 0.1193, 0.1032, 0.0990, 0.0855, 0.1182,
         0.0627],
        [0.0903, 0.0868, 0.0914, 0.0894, 0.0786, 0.0990, 0.0982, 0.0912, 0.1875,
         0.0877],
        [0.0843, 0.0798, 0.0828, 0.0755, 0.0656, 0.0824, 0.0893, 0.2350, 0.1017,
         0.1036],
        [0.0848, 0.0861, 0.0909, 0.0955, 0.1256, 0.0894, 0.0857, 0.0914, 0.0891,
         0.1614],
        [0.0849, 0.2301, 0.0873, 0.0879, 0.0827, 0.0889, 0.0869, 0.0785, 0.0856,
         0.0870],
        [0.0831, 0.1210, 0.1709, 0.1054, 0.0873, 0.0788, 0.0866, 0.0834, 0.0983,
         0.0852],
        [0.0870, 0.0862, 0.0878, 0.0827, 0.0994, 0.0912, 0.0886, 0.0897, 0.0994,
         0.1881],
        [0.0849, 0.2301, 0.0873, 0.0879, 0.0827, 0.0889, 0.0869, 0.0785, 0.0856,
         0.0870]], grad_fn=<

In [31]:
# Display the MLP's output for the same batch, softmaxed over the class axis (dim=1)
# so it can be compared with the mimic model's output above.
torch.softmax(model(data[0]),dim=1)

tensor([[9.7405e-04, 1.9393e-06, 9.8800e-01, 8.1596e-03, 2.5603e-04, 9.4962e-06,
         2.5499e-03, 5.5805e-07, 4.7720e-05, 1.1988e-07],
        [1.2151e-06, 8.9325e-01, 2.4465e-02, 2.6820e-03, 3.2318e-07, 1.4375e-02,
         1.4313e-02, 1.3506e-04, 5.0749e-02, 2.6700e-05],
        [9.7844e-03, 6.6299e-06, 2.6526e-02, 9.6238e-01, 3.2791e-06, 8.9500e-04,
         2.0263e-04, 1.3326e-06, 1.9317e-04, 7.1826e-06],
        [5.7568e-07, 3.9869e-05, 8.0569e-05, 5.7558e-05, 1.4264e-06, 6.2862e-05,
         1.5019e-04, 1.2253e-07, 9.9960e-01, 4.0663e-06],
        [2.9142e-06, 3.7642e-08, 3.4367e-07, 1.3106e-06, 7.9239e-09, 1.5432e-07,
         5.3949e-11, 9.9942e-01, 1.1298e-07, 5.7806e-04],
        [8.7001e-07, 3.8244e-09, 1.1557e-07, 1.6726e-04, 1.2015e-03, 8.0729e-05,
         1.9151e-07, 6.0092e-04, 1.5028e-04, 9.9780e-01],
        [8.8856e-07, 9.9739e-01, 5.8819e-04, 7.5598e-05, 4.0313e-06, 2.8456e-05,
         3.5206e-04, 4.2377e-05, 1.5094e-03, 6.9105e-06],
        [7.9036e-08, 5.2223