In [1]:

%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import sys

sys.path.append("..")

from src.dvip import DVIP_Base
from src.layers import VIPLayer
from src.likelihood import Gaussian
from src.generative_functions import BayesianConvNN, BayesianNN, BayesLinear
from src.likelihood import BroadcastedLikelihood, MultiClass
from utils.dataset import Test_Dataset, Training_Dataset, MNIST_Dataset
from utils.metrics import MetricsRegression, MetricsClassification
from utils.process_flags import manage_experiment_configuration
from utils.pytorch_learning import fit, fit_with_metrics, score
from scripts.filename import create_file_name


In [2]:

mnist = MNIST_Dataset()
train_dataset, train_test_dataset, test_dataset = mnist.get_split()

Number of samples:  60000
Input dimension:  784
Label dimension:  1
Labels mean value:  0
Labels standard deviation:  1


In [3]:
batch_size = 100
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle = True)
train_test_loader = DataLoader(train_test_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [4]:
gen_f = BayesianConvNN(num_samples = 20, input_dim=(28, 28), output_dim = 1, activation = torch.nn.functional.relu)
layer = VIPLayer(gen_f, 20, 28*28, 1)

In [5]:
gen_f2 = BayesianNN([], num_samples = 20, input_dim=25, output_dim = 10, activation = torch.nn.functional.relu, layer_model=BayesLinear)
layer2 = VIPLayer(gen_f2, 20, 25, 10)

In [5]:
dvip = DVIP_Base(
    BroadcastedLikelihood(MultiClass(10, torch.float64, "cpu")),
    [layer],
    len(train_dataset),
    bb_alpha=0.0,
    num_samples=1,
    dtype=torch.float64,
    device="cpu",
)

In [6]:
dvip.print_variables()


---- MODEL PARAMETERS ----
 VIP_LAYERS
	 0
		 q_mu
			 torch.Size([20, 1])
			 [0. 0. ... 0. 0.]
		 q_sqrt_tri
			 torch.Size([210, 1])
			 [1. 0. ... 0. 1.]
		 GENERATIVE_FUNCTION
			 CONV1
				 W_mu
					 torch.Size([3, 3])
					 [0. 0. ... 0. 0.]
				 W_log_std
					 torch.Size([3, 3])
					 [0. 0. ... 0. 0.]
				 bias_mu
					 torch.Size([1])
					 [0.]
				 bias_log_std
					 torch.Size([1])
					 [0.]
			 CONV2
				 W_mu
					 torch.Size([3, 3])
					 [0. 0. ... 0. 0.]
				 W_log_std
					 torch.Size([3, 3])
					 [0. 0. ... 0. 0.]
				 bias_mu
					 torch.Size([1])
					 [0.]
				 bias_log_std
					 torch.Size([1])
					 [0.]
			 CONV3
				 W_mu
					 torch.Size([5, 5])
					 [0. 0. ... 0. 0.]
				 W_log_std
					 torch.Size([5, 5])
					 [0. 0. ... 0. 0.]
				 bias_mu
					 torch.Size([1])
					 [0.]
				 bias_log_std
					 torch.Size([1])
					 [0.]

---------------------------




In [12]:
opt = torch.optim.Adam(dvip.parameters(), lr=0.01)
metrics = MetricsClassification

# Perform training
train_hist, val_hist = fit_with_metrics(
    dvip,
    train_loader,
    opt,
    metrics,
    #val_generator=test_loader,
    epochs=20,
    device="cpu",
)



Training : 100%|██████████| 20/20 [17:40<00:00, 53.01s/epoch, loss_train=2.73e+5, nll_train=0.694, acc_train=0.0987]


ValueError: too many values to unpack (expected 2)

In [None]:
val_hist

[{'LOSS': 17164.55078125,
  'NLL': 5.4976372718811035,
  'ACC': 0.10589999705553055},
 {'LOSS': 15817.701171875,
  'NLL': 5.999475002288818,
  'ACC': 0.10289999842643738},
 {'LOSS': 15536.6572265625,
  'NLL': 6.026484489440918,
  'ACC': 0.10279999673366547},
 {'LOSS': 15454.3740234375,
  'NLL': 6.040866374969482,
  'ACC': 0.10279999673366547},
 {'LOSS': 15414.98828125,
  'NLL': 5.9806599617004395,
  'ACC': 0.10279999673366547},
 {'LOSS': 15377.451171875,
  'NLL': 5.988156318664551,
  'ACC': 0.10279999673366547},
 {'LOSS': 15356.697265625,
  'NLL': 5.973512172698975,
  'ACC': 0.10279999673366547},
 {'LOSS': 15362.892578125,
  'NLL': 5.985396385192871,
  'ACC': 0.10279999673366547},
 {'LOSS': 15326.65625, 'NLL': 5.955053806304932, 'ACC': 0.10279999673366547},
 {'LOSS': 15322.0498046875,
  'NLL': 5.951890468597412,
  'ACC': 0.10279999673366547},
 {'LOSS': 15309.2421875, 'NLL': 5.927739143371582, 'ACC': 0.10279999673366547},
 {'LOSS': 15304.8984375, 'NLL': 5.935015678405762, 'ACC': 0.10279