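# Poison Certified Training on UCI Datasets

This notebook trains a small fully connected regression network (one hidden layer of 64 neurons) on the UCI `houseelectric` dataset with Abstract Gradient Training's poison certified training. It then reports the nominal test MSE together with certified upper and lower bounds on it.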

In [1]:
%load_ext autoreload
%autoreload 2
import torch
import abstract_gradient_training as agt
from models.fully_connected import FullyConnected 
from datasets import uci

In [2]:
# configure the training parameters
batchsize = 20000

# "cuda:1" (the second GPU) used to be hard-coded, which breaks Restart & Run All on
# machines with fewer than two GPUs. Keep it when available, otherwise fall back.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

config = agt.AGTConfig(
    fragsize=20000,  # NOTE(review): currently equal to batchsize; confirm the intended relation before tuning either
    learning_rate=0.005,
    epsilon=0.01,
    k_poison=200,
    n_epochs=1,
    device=device,
    forward_bound="interval",
    backward_bound="interval",
    loss="mse",
    # NOTE(review): the saved training output below shows [INFO] log lines, suggesting it was
    # produced with a more verbose level than "WARNING" -- re-run to refresh the saved output.
    log_level="WARNING",
)
# Trailing semicolon suppresses the <torch._C.Generator> repr returned by manual_seed.
torch.manual_seed(0);

<torch._C.Generator at 0x7ab644bcfcb0>

In [3]:
# initialize the model and dataset
dataset_name = "houseelectric"
# batchsize is passed twice -- presumably the train and test loader batch sizes (TODO confirm)
dl_train, dl_test = uci.get_dataloaders(batchsize, batchsize, dataset_name)

# Input width must match the dataset's feature count (the loader reports d=11 for houseelectric).
n_features = 11
# Positional args presumably (in_dim, out_dim, hidden_width, n_hidden_layers), i.e. a single
# hidden layer of 64 neurons with one output -- TODO confirm against models.fully_connected.
model = FullyConnected(n_features, 1, 64, 1)

houseelectric dataset, N=2049280, d=11


In [4]:
# train the model
# Three parameter sets come back; going by the _l/_n/_u suffixes these are presumably the
# lower-bound, nominal, and upper-bound parameters -- TODO confirm against the AGT docs.
param_l, param_n, param_u = agt.poison_certified_training(
    model,
    config,
    dl_train,
    dl_test,
)

[AGT] [INFO    ] [17:47:00] Starting Poison Certified Training
[AGT] [INFO    ] [17:47:01] Training batch 0: Network eval bounds=(0.18, 0.18, 0.18), W0 Bound=0.0 
[AGT] [INFO    ] [17:47:01] Training batch 1: Network eval bounds=(0.17, 0.17, 0.17), W0 Bound=3.63e-05 
[AGT] [INFO    ] [17:47:02] Training batch 2: Network eval bounds=(0.17, 0.17, 0.17), W0 Bound=7.25e-05 
[AGT] [INFO    ] [17:47:02] Training batch 3: Network eval bounds=(0.16, 0.16, 0.16), W0 Bound=0.000108 
[AGT] [INFO    ] [17:47:03] Training batch 4: Network eval bounds=(0.15, 0.15, 0.15), W0 Bound=0.000143 
[AGT] [INFO    ] [17:47:03] Training batch 5: Network eval bounds=(0.15, 0.15, 0.15), W0 Bound=0.000178 
[AGT] [INFO    ] [17:47:03] Training batch 6: Network eval bounds=(0.14, 0.14, 0.14), W0 Bound=0.000212 
[AGT] [INFO    ] [17:47:04] Training batch 7: Network eval bounds=(0.13, 0.13, 0.13), W0 Bound=0.000246 
[AGT] [INFO    ] [17:47:04] Training batch 8: Network eval bounds=(0.13, 0.13, 0.13), W0 Bound=0.00027

In [5]:
# evaluate the trained model
# Note the argument order: nominal parameters first, then the lower and upper bounds.
mse = agt.test_metrics.test_mse(param_n, param_l, param_u, dl_test)

# Result layout as consumed here: index 0 -> certified upper bound, 1 -> nominal, 2 -> certified lower bound.
mse_upper, mse_nominal, mse_lower = mse[0], mse[1], mse[2]
print(
    f"Test MSE: nominal = {mse_nominal:.4g}, "
    f"certified upper bound = {mse_upper:.4g}, "
    f"certified lower bound = {mse_lower:.4g}"
)

Test MSE: nominal = 0.03772, certified upper bound = 0.04172, certified lower bound = 0.03398
