In [1]:
import torch
import math
from torch import nn
import pandas as pd
import numpy as np
from sklearn import preprocessing
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data.sampler import SubsetRandomSampler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# define class for scale mixture gaussian prior
class ScaleMixtureGaussian:                               
    def __init__(self, mixture_weight, stddev_1, stddev_2):
        super().__init__()
        # mixture_weight is the weight for the first gaussian
        self.mixture_weight = mixture_weight
        # stddev_1 and stddev_2 are the standard deviations for the two gaussians
        self.stddev_1 = stddev_1
        self.stddev_2 = stddev_2
        # create two normal distributions with the specified standard deviations
        self.gaussian1 = torch.distributions.Normal(0,stddev_1)
        self.gaussian2 = torch.distributions.Normal(0,stddev_2)


    def log_prob(self, x):
        prob1 = torch.exp(self.gaussian1.log_prob(x))
        prob2 = torch.exp(self.gaussian2.log_prob(x))
        return (torch.log(self.mixture_weight * prob1 + (1-self.mixture_weight) * prob2)).sum()
    
# define class for gaussian node
class GaussianNode:
    def __init__(self, mean, rho_param):
        super().__init__()
        self.mean = mean
        self.rho_param = rho_param
        self.normal = torch.distributions.Normal(0,1)
    
    # Calculate the standard deviation from the rho parameter
    def sigma(self):
        return torch.log1p(torch.exp(self.rho_param))

    # Sample from the Gaussian node
    def sample(self):
        epsilon = self.normal.sample(self.rho_param.size()).cuda()
        return self.mean + self.sigma() * epsilon
    
    # Calculate the KL divergence between the prior and the variational posterior
    def log_prob(self, x):
        return (-math.log(math.sqrt(2 * math.pi)) - torch.log(self.sigma()) - ((x - self.mean) ** 2) / (2 * self.sigma() ** 2)).sum()

class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features, mu_init, rho_init, prior_init):
        super().__init__()

        # Initialize the parameters for the weights and biases
        self.weight_mean = nn.Parameter(torch.empty(out_features, in_features).uniform_(*mu_init))
        self.weight_rho_param = nn.Parameter(torch.empty(out_features, in_features).uniform_(*rho_init))
        self.weight = GaussianNode(self.weight_mean, self.weight_rho_param)

        self.bias_mean = nn.Parameter(torch.empty(out_features).uniform_(*mu_init))
        self.bias_rho_param = nn.Parameter(torch.empty(out_features).uniform_(*rho_init))
        self.bias = GaussianNode(self.bias_mean, self.bias_rho_param)
        
        self.weight_prior = ScaleMixtureGaussian(prior_init[0], math.exp(prior_init[1]), math.exp(prior_init[2]))
        self.bias_prior = ScaleMixtureGaussian(prior_init[0], math.exp(prior_init[1]), math.exp(prior_init[2]))

        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, x):
        weight = self.weight.sample()
        bias = self.bias.sample()

        return nn.functional.linear(x, weight, bias)

class BayesianNetwork(nn.Module):
    def __init__(self, model_params):
        super().__init__()
        self.input_shape = model_params['input_shape']
        self.classes = model_params['classes']
        self.batch_size = model_params['batch_size']
        self.hidden_units = model_params['hidden_units']
        self.experiment = model_params['experiment']
        self.mu_init = model_params['mu_init']
        self.rho_init = model_params['rho_init']
        self.prior_init = model_params['prior_init']

        self.fc1 = BayesianLinear(self.input_shape, self.hidden_units, self.mu_init, self.rho_init, self.prior_init)
        self.fc1_activation = nn.ReLU()
        self.fc2 = BayesianLinear(self.hidden_units, self.hidden_units, self.mu_init, self.rho_init, self.prior_init)
        self.fc2_activation = nn.ReLU()
        self.fc3 = BayesianLinear(self.hidden_units, self.classes, self.mu_init, self.rho_init, self.prior_init)
    
    def forward(self, x):
        if self.experiment == 'classification':
            x = x.view(-1, self.input_shape) # Flatten images
        x = self.fc1_activation(self.fc1(x))
        x = self.fc2_activation(self.fc2(x))
        x = self.fc3(x)
        return x

    def log_prior(self):
        return self.fc1.log_prior + self.fc2.log_prior + self.fc3.log_prior
    
    def log_variational_posterior(self):
        return self.fc1.log_variational_posterior + self.fc2.log_variational_posterior + self.fc3.log_variational_posterior


    def get_nll(self, outputs, target, sigma=1.):
        if self.experiment == 'regression': #  -(.5 * (target - outputs) ** 2).sum()
            nll = -torch.distributions.Normal(outputs, sigma).log_prob(target).sum()
        elif self.experiment == 'classification':
            nll = nn.CrossEntropyLoss(reduction='sum')(outputs, target)
        return nll

    def sample_elbo(self, x, target, beta, samples, sigma=1.):
        log_prior = torch.zeros(1).to(device)
        log_variational_posterior = torch.zeros(1).to(device)
        negative_log_likelihood = torch.zeros(1).to(device)

        for i in range(samples):
            output = self.forward(x)
            log_prior += self.log_prior()
            log_variational_posterior += self.log_variational_posterior()
            negative_log_likelihood += self.get_nll(output, target, sigma)

        log_prior = beta*(log_prior / samples)
        log_variational_posterior = beta*(log_variational_posterior / samples) 
        negative_log_likelihood = negative_log_likelihood / samples
        loss = log_variational_posterior - log_prior + negative_log_likelihood
        return loss, log_prior, log_variational_posterior, negative_log_likelihood    


In [3]:
class MLP(nn.Module):
    def __init__(self, model_params):
        super().__init__()
        self.input_shape = model_params['input_shape']
        self.classes = model_params['classes']
        self.batch_size = model_params['batch_size']
        self.hidden_units = model_params['hidden_units']
        self.experiment = model_params['experiment']

        self.net = nn.Sequential(
            nn.Linear(self.input_shape, self.hidden_units),
            nn.ReLU(),
            nn.Linear(self.hidden_units, self.hidden_units),
            nn.ReLU(),
            nn.Linear(self.hidden_units, self.classes))
    
    def forward(self, x):
        if self.experiment == 'classification':
            x = x.view(-1, self.input_shape) # Flatten images
        
        x = self.net(x)
        return x

# class MLP_Dropout(nn.Module):
#     def __init__(self, model_params):
#         super().__init__()
#         self.input_shape = model_params['input_shape']
#         self.classes = model_params['classes']
#         self.batch_size = model_params['batch_size']
#         self.hidden_units = model_params['hidden_units']
#         self.experiment = model_params['experiment']

#         self.net = nn.Sequential(
#             nn.Linear(self.input_shape, self.hidden_units),
#             nn.ReLU(),
#             nn.Dropout(0.5),
#             nn.Linear(self.hidden_units, self.hidden_units),
#             nn.ReLU(),
#             nn.Dropout(0.5),
#             nn.Linear(self.hidden_units, self.classes))
    
#     def forward(self, x):
#         if self.experiment == 'classification':
#             x = x.view(-1, self.input_shape) # Flatten images
       
#         x = self.net(x)
#         return x

#     def enable_dropout(self):
#         ''' Enable the dropout layers during test-time '''
#         for m in self.modules():
#             if m.__class__.__name__.startswith('Dropout'):
#                 m.train()

In [4]:
class RegConfig:
    # save_dir = './saved_models'
    train_size = 1024
    batch_size = 128
    lr = 1e-3
    epochs = 100 #1000
    train_samples = 5                   # number of train samples for MC gradients
    test_samples = 10                   # number of test samples for MC averaging
    num_test_points = 400               # number of test points
    experiment = 'regression'
    hidden_units = 400                  # number of hidden units
    noise_tolerance = .1                # log likelihood sigma
    mu_init = [-0.2, 0.2]               # range for mean 
    rho_init = [-5, -4]                 # range for rho_param
    prior_init = [0.5, -0, -6]        # mixture weight, log(stddev_1), log(stddev_2)
   

class RLConfig:
    data_dir = '/kaggle/input/mushroom/agaricus-lepiota.data' 
    batch_size = 64
    num_batches = 64
    buffer_size = batch_size * num_batches  # buffer to track latest batch of mushrooms
    lr = 1e-4
    training_steps = 5000 # 50000
    experiment = 'regression'
    hidden_units = 100                      # number of hidden units
    mu_init = [-0.2, 0.2]                   # range for mean 
    rho_init = [-5, -4]                     # range for rho_param
    prior_init = [0.5, -0, -6]              # mixture weight, log(stddev_1), log(stddev_2)

class ClassConfig:
    batch_size = 128
    lr = 1e-3 # 1e-5 fa schifo, 1e-4 parte da 8% errore, 1e-3 parte da 5%
    epochs = 1 #600
    hidden_units = 1200
    experiment = 'classification'
    dropout = False
    train_samples = 1 # 10 è troppo lento
    test_samples = 10
    x_shape = 28 * 28                       # x shape
    classes = 10                            # number of output classes
    mu_init = [-0.2, 0.2]                   # range for mean 
    rho_init = [-5, -4]                     # range for rho_param
    prior_init = [0.5, -0, -8]             # mixture weight, log(stddev_1), log(stddev_2)

In [5]:
class PrepareData(Dataset):
    def __init__(self, X, y):
        if not torch.is_tensor(X):
            self.X = torch.from_numpy(X)
        else:
            self.X = X
        if not torch.is_tensor(y):
            self.y = torch.from_numpy(y)
        else:
            self.y = y # vedere

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

def read_data_rl(data_dir):
    df = pd.read_csv(data_dir, sep=',', header=None)
    df.columns = ['class','cap-shape','cap-surface','cap-color','bruises','odor','gill-attachment',
         'gill-spacing','gill-size','gill-color','stalk-shape','stalk-root',
         'stalk-surf-above-ring','stalk-surf-below-ring','stalk-color-above-ring','stalk-color-below-ring',
         'veil-type','veil-color','ring-number','ring-type','spore-color','population','habitat']
    X = pd.DataFrame(df, columns=df.columns[1:len(df.columns)], index=df.index)
    Y = df['class']

    # transform to one-hot encoding
    label_encoder = preprocessing.LabelEncoder()
    label_encoder.fit(Y)
    Y_encoded = label_encoder.transform(Y)
    oh_X = X.copy()
    for feature in X.columns:
        label_encoder.fit(X[feature])
        oh_X[feature] = label_encoder.transform(X[feature])

    oh_encoder = preprocessing.OneHotEncoder()
    oh_encoder.fit(oh_X)
    oh_X = oh_encoder.transform(oh_X).toarray()

    return oh_X, Y_encoded

def create_data_reg(train_size):
    np.random.seed(0)
    xs = np.random.uniform(low=0., high=0.6, size=train_size)
    
    eps = np.random.normal(loc=0., scale=0.02, size=[train_size])

    ys = xs + 0.3 * np.sin(2*np.pi * (xs + eps)) + 0.3 * np.sin(4*np.pi * (xs + eps)) + eps

    xs = torch.from_numpy(xs).reshape(-1,1).float()
    ys = torch.from_numpy(ys).reshape(-1,1).float()

    return xs, ys

In [6]:
# def load_bnn_class_model(saved_model):
#     config = ClassConfig

#     model_params = {
#         'input_shape': config.x_shape,
#         'classes': config.classes,
#         'batch_size': config.batch_size,
#         'hidden_units': config.hidden_units,
#         'experiment': config.experiment,
#         'mu_init': config.mu_init,
#         'rho_init': config.rho_init,
#         'prior_init': config.prior_init
#     }
#     model = BayesianNetwork(model_params)
#     model.load_state_dict(torch.load(saved_model))

#     return model.eval()

# def load_mlp_class_model(saved_model):
#     config = ClassConfig
#     model_params = {
#         'input_shape': config.x_shape,
#         'classes': config.classes,
#         'batch_size': config.batch_size,
#         'hidden_units': config.hidden_units,
#         'experiment': config.experiment,
#     }
#     model = MLP(model_params)
#     model.load_state_dict(torch.load(saved_model))

#     return model.eval()

# def load_dropout_class_model(saved_model):
#     config = ClassConfig
#     model_params = {
#         'input_shape': config.x_shape,
#         'classes': config.classes,
#         'batch_size': config.batch_size,
#         'hidden_units': config.hidden_units,
#         'experiment': config.experiment,
#         'dropout': True
#     }
#     model = MLP_Dropout(model_params)
#     model.load_state_dict(torch.load(saved_model))

#     return model.eval()

In [7]:
def create_regression_plot(X_test, y_test, train_ds):
    fig = plt.figure(figsize=(9, 6))
    plt.plot(X_test, np.median(y_test, axis=0), label='Median Posterior Predictive')
    
    # Range
    plt.fill_between(
        X_test.reshape(-1), 
        np.percentile(y_test, 0, axis=0), 
        np.percentile(y_test, 100, axis=0), 
        alpha = 0.2, color='orange', label='Range') #color='blue',
    
    # interquartile range
    plt.fill_between(
        X_test.reshape(-1), 
        np.percentile(y_test, 25, axis=0), 
        np.percentile(y_test, 75, axis=0), 
        alpha = 0.4,  label='Interquartile Range') #color='red',
    
    plt.scatter(train_ds.dataset.X, train_ds.dataset.y, label='Training data', marker='x', alpha=0.5, color='k', s=2)
    plt.yticks(fontsize=20)
    plt.xticks(fontsize=20)
    plt.ylim([-1.5, 1.5])
    plt.xlim([-0.6, 1.4])

   

In [8]:
class BNN_Classification():
    def __init__(self, label, parameters):
        super().__init__()
        self.label = label
        self.lr = parameters['lr']
        self.hidden_units = parameters['hidden_units']
        self.experiment = parameters['experiment']
        self.batch_size = parameters['batch_size']
        self.num_batches = parameters['num_batches']
        self.n_samples = parameters['train_samples']
        self.test_samples = parameters['test_samples']
        self.x_shape = parameters['x_shape']
        self.classes = parameters['classes']
        self.mu_init = parameters['mu_init']
        self.rho_init = parameters['rho_init']
        self.prior_init = parameters['prior_init']
        self.best_acc = 0.
        self.init_net(parameters)
    
    def init_net(self, parameters):
        model_params = {
            'input_shape': self.x_shape,
            'classes': self.classes,
            'batch_size': self.batch_size,
            'hidden_units': self.hidden_units,
            'experiment': self.experiment,
            'mu_init': self.mu_init,
            'rho_init': self.rho_init,
            'prior_init': self.prior_init,
        }
        self.net = BayesianNetwork(model_params).to(device)
        self.optimiser = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimiser, step_size=100, gamma=0.5)
        # print(f'Classification Task {self.label} Parameters: ')
        # print(f'number of samples: {self.n_samples}')
        # print("BNN Parameters: ")
        # print(f'batch size: {self.batch_size}, x shape: {model_params["input_shape"]}, hidden units: {model_params["hidden_units"]}, output shape: {model_params["classes"]}, lr: {self.lr}')

    def train_step(self, train_data):
        self.net.train()
        for idx, (x, y) in enumerate(tqdm(train_data)):
            beta = 2 ** (self.num_batches - (idx + 1)) / (2 ** self.num_batches - 1) 
            x, y = x.to(device), y.to(device)
            self.net.zero_grad()
            self.loss_info = self.net.sample_elbo(x, y, beta, self.n_samples)            
            net_loss = self.loss_info[0]
            net_loss.backward()
            self.optimiser.step()

    def predict(self, X):
        probs = torch.zeros(size=[self.batch_size, self.classes]).to(device)
        for _ in torch.arange(self.test_samples):
            out = torch.nn.Softmax(dim=1)(self.net(X))
            probs = probs + out / self.test_samples
        preds = torch.argmax(probs, dim=1)
        return preds, probs

    def evaluate(self, test_loader):
        self.net.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for data in tqdm(test_loader):
                X, y = data
                X, y = X.to(device), y.to(device)
                preds, _ = self.predict(X)
                total += self.batch_size
                correct += (preds == y).sum().item()
        self.acc = correct / total
        print(f'validation accuracy: {self.acc}')  
        return self.acc

In [9]:
class MLP_Classification():
    def __init__(self, label, parameters):
        super().__init__()
        self.label = label
        self.lr = parameters['lr']
        self.hidden_units = parameters['hidden_units']
        self.experiment = parameters['experiment']
        self.batch_size = parameters['batch_size']
        self.num_batches = parameters['num_batches']
        self.x_shape = parameters['x_shape']
        self.classes = parameters['classes']
        self.best_acc = 0.
        self.dropout = parameters['dropout']
        self.init_net(parameters)
    
    def init_net(self, parameters):
        model_params = {
            'input_shape': self.x_shape,
            'classes': self.classes,
            'batch_size': self.batch_size,
            'hidden_units': self.hidden_units,
            'experiment': self.experiment,
            'dropout': self.dropout,
        }
        if self.dropout:
            self.net = MLP_Dropout(model_params).to(device)
            print('MLP Dropout Parameters: ')
            print(f'batch size: {self.batch_size}, input shape: {model_params["input_shape"]}, hidden units: {model_params["hidden_units"]}, output shape: {model_params["classes"]}, lr: {self.lr}')
        else:
            self.net = MLP(model_params).to(device)
            print('MLP Parameters: ')
            print(f'batch size: {self.batch_size}, input shape: {model_params["input_shape"]}, hidden units: {model_params["hidden_units"]}, output shape: {model_params["classes"]}, lr: {self.lr}')
        self.optimiser = torch.optim.SGD(self.net.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimiser, step_size=100, gamma=0.5)

    def train_step(self, train_data):
        self.net.train()
        for _, (x, y) in enumerate(tqdm(train_data)):
            x, y = x.to(device), y.to(device)
            self.net.zero_grad()
            self.loss_info = torch.nn.functional.cross_entropy(self.net(x), y, reduction='sum')
            self.loss_info.backward()
            self.optimiser.step()

    def predict(self, X):
        probs = torch.nn.Softmax(dim=1)(self.net(X))
        preds = torch.argmax(probs, dim=1)
        return preds, probs

    def evaluate(self, test_loader):
        self.net.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for data in tqdm(test_loader):
                X, y = data
                X, y = X.to(device), y.to(device)
                preds, _ = self.predict(X)
                total += self.batch_size
                correct += (preds == y).sum().item()
        self.acc = correct / total
        print(f'{self.label} validation accuracy: {self.acc}') 
        return self.acc

In [10]:
# def class_trainer():
#     config = ClassConfig
    
#     transform = transforms.Compose([
#             transforms.ToTensor(),
#             transforms.Lambda(lambda x: x * 255 / 126.),  # divide as in paper, * 255 gives better results
#         ])

#     train_data = datasets.MNIST(
#             root='data',
#             train=True,
#             download=True,
#             transform=transform)
#     test_data = datasets.MNIST(
#             root='data',
#             train=False,
#             download=True,
#             transform=transform)

#     valid_size = 1 / 6

#     num_train = len(train_data)
#     indices = list(range(num_train))
#     split = int(valid_size * num_train)
#     train_idx, valid_idx = indices[split:], indices[:split]

#     train_sampler = SubsetRandomSampler(train_idx)
#     valid_sampler = SubsetRandomSampler(valid_idx)


#     train_loader = torch.utils.data.DataLoader(
#             train_data,
#             batch_size=config.batch_size,
#             sampler=train_sampler,
#             drop_last=True)
#     valid_loader = torch.utils.data.DataLoader(
#             train_data,
#             batch_size=config.batch_size,
#             sampler=valid_sampler,
#             drop_last=True)
#     test_loader = torch.utils.data.DataLoader(
#             test_data,
#             batch_size=config.batch_size,
#             shuffle=False,
#             drop_last=True)

#     params = {
#         'lr': config.lr,
#         'hidden_units': config.hidden_units,
#         'experiment': config.experiment,
#         'dropout': config.dropout,
#         'batch_size': config.batch_size,
#         'epochs': config.epochs,
#         'x_shape': config.x_shape,
#         'classes': config.classes,
#         'num_batches': len(train_loader),
#         'train_samples': config.train_samples,
#         'test_samples': config.test_samples,
#         'mu_init': config.mu_init,
#         'rho_init': config.rho_init,
#         'prior_init': config.prior_init,
#     }

#     model = BNN_Classification('bnn_classification', {**params})
#     #model = MLP_Classification('mlp_classification', {**params})
    
#     epochs = config.epochs
#     for epoch in range(epochs):
#             print(f'Epoch {epoch+1}/{epochs}')
#             model.train_step(train_loader)
#             valid_acc = model.evaluate(valid_loader)
#             # test_acc = model.evaluate(test_loader)
#             print('Valid Error', round(100 * (1 - valid_acc), 3), '%',)
#             model.scheduler.step()
#             if model.acc > model.best_acc:
#                 model.best_acc = model.acc
                
# #class_trainer()




In [11]:
search_config = {
    'batch_size': [128],
    'lr': [1e-3, 1e-4],
    'epochs': [10],
    'hidden_units': [1200],
    'experiment': ['classification'],
    'dropout': [False],
    'train_samples': [1, 2, 5],
    'test_samples': [10],
    'x_shape': [28 * 28],
    'classes': [10],
    'mu_init': [[-0.2, 0.2]],
    'rho_init': [[-5, -4]],
    'prior_init': [
        [0.25, -0, -6],
        [0.25, -0, -7], 
        [0.25, -1, -6], 
        [0.25, -1, -7], 
        [0.75, -0, -6],
        [0.75, -0, -7], 
        [0.75, -1, -6], 
        [0.75, -1, -7],       
    ]
}

# search_config = {
#     'batch_size': [128],
#     'lr': [1e-3, 1e-4],
#     'epochs': [1], #10
#     'hidden_units': [1200],
#     'experiment': ['classification'],
#     'dropout': [False],
#     'train_samples': [1],
#     'test_samples': [10],
#     'x_shape': [28 * 28],
#     'classes': [10],
#     'mu_init': [[-0.2, 0.2]],
#     'rho_init': [[-5, -4]],
#     'prior_init': [
#         [0.25, -0, -6],        
#     ]
# }


import itertools
from copy import deepcopy

def generate_param_combinations(param_grid):
    keys = list(param_grid.keys())
    values = list(param_grid.values())
    for combo in itertools.product(*values):
        yield dict(zip(keys, combo))


def class_trainer(config):
    
    transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * 255 / 126.),  # divide as in paper, * 255 gives better results
        ])

    train_data = datasets.MNIST(
            root='data',
            train=True,
            download=True,
            transform=transform)
    # test_data = datasets.MNIST(
    #         root='data',
    #         train=False,
    #         download=True,
    #         transform=transform)

    valid_size = 1 / 6

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(valid_size * num_train)
    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)


    train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=config['batch_size'],
            sampler=train_sampler,
            drop_last=True)
    valid_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=config['batch_size'],
            sampler=valid_sampler,
            drop_last=True)
    # test_loader = torch.utils.data.DataLoader(
    #         test_data,
    #         batch_size=config.batch_size,
    #         shuffle=False,
    #         drop_last=True)

    params = deepcopy(config)
    params['num_batches'] = len(train_loader)

    model = BNN_Classification('bnn_classification', {**params})
    #model = MLP_Classification('mlp_classification', {**params})
    
    epochs = config['epochs']
    for epoch in range(epochs):
            print(f'Epoch {epoch+1}/{epochs}')
            model.train_step(train_loader)
            valid_acc = model.evaluate(valid_loader)
            print('Valid Error', round(100 * (1 - valid_acc), 3), '%',)
            model.scheduler.step()
            if model.acc > model.best_acc:
                model.best_acc = model.acc
                

    return model.best_acc, model


In [12]:
best_val_acc = 0.0
best_config = None
best_model = None

for config in generate_param_combinations(search_config):
    print(f"Trying config: {config}")
    val_acc, model = class_trainer(config)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_config = deepcopy(config)
        best_model = model

print("Best Config:")
print(best_config)
print(f"Best Validation Accuracy: {best_val_acc:.4f}")

Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, 0, -6]}


100%|██████████| 9.91M/9.91M [00:00<00:00, 55.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.63MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.3MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.6MB/s]


Epoch 1/10


100%|██████████| 390/390 [00:14<00:00, 26.20it/s]
100%|██████████| 78/78 [00:17<00:00,  4.39it/s]


validation accuracy: 0.9503205128205128
Valid Error 4.968 %
Epoch 2/10


100%|██████████| 390/390 [00:15<00:00, 24.58it/s]
100%|██████████| 78/78 [00:17<00:00,  4.40it/s]


validation accuracy: 0.9624399038461539
Valid Error 3.756 %
Epoch 3/10


100%|██████████| 390/390 [00:14<00:00, 26.72it/s]
100%|██████████| 78/78 [00:16<00:00,  4.76it/s]


validation accuracy: 0.9658453525641025
Valid Error 3.415 %
Epoch 4/10


100%|██████████| 390/390 [00:15<00:00, 24.50it/s]
100%|██████████| 78/78 [00:16<00:00,  4.68it/s]


validation accuracy: 0.9629407051282052
Valid Error 3.706 %
Epoch 5/10


100%|██████████| 390/390 [00:14<00:00, 27.15it/s]
100%|██████████| 78/78 [00:17<00:00,  4.54it/s]


validation accuracy: 0.9673477564102564
Valid Error 3.265 %
Epoch 6/10


100%|██████████| 390/390 [00:14<00:00, 27.46it/s]
100%|██████████| 78/78 [00:17<00:00,  4.39it/s]


validation accuracy: 0.9678485576923077
Valid Error 3.215 %
Epoch 7/10


100%|██████████| 390/390 [00:14<00:00, 26.16it/s]
100%|██████████| 78/78 [00:16<00:00,  4.66it/s]


validation accuracy: 0.9690504807692307
Valid Error 3.095 %
Epoch 8/10


100%|██████████| 390/390 [00:14<00:00, 27.14it/s]
100%|██████████| 78/78 [00:17<00:00,  4.43it/s]


validation accuracy: 0.9710536858974359
Valid Error 2.895 %
Epoch 9/10


100%|██████████| 390/390 [00:14<00:00, 27.01it/s]
100%|██████████| 78/78 [00:17<00:00,  4.58it/s]


validation accuracy: 0.9670472756410257
Valid Error 3.295 %
Epoch 10/10


100%|██████████| 390/390 [00:14<00:00, 26.83it/s]
100%|██████████| 78/78 [00:16<00:00,  4.74it/s]


validation accuracy: 0.97265625
Valid Error 2.734 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, 0, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:13<00:00, 27.96it/s]
100%|██████████| 78/78 [00:16<00:00,  4.79it/s]


validation accuracy: 0.9498197115384616
Valid Error 5.018 %
Epoch 2/10


100%|██████████| 390/390 [00:14<00:00, 26.05it/s]
100%|██████████| 78/78 [00:17<00:00,  4.52it/s]


validation accuracy: 0.9623397435897436
Valid Error 3.766 %
Epoch 3/10


100%|██████████| 390/390 [00:15<00:00, 26.00it/s]
100%|██████████| 78/78 [00:16<00:00,  4.72it/s]


validation accuracy: 0.9688501602564102
Valid Error 3.115 %
Epoch 4/10


100%|██████████| 390/390 [00:14<00:00, 26.19it/s]
100%|██████████| 78/78 [00:16<00:00,  4.75it/s]


validation accuracy: 0.9708533653846154
Valid Error 2.915 %
Epoch 5/10


100%|██████████| 390/390 [00:14<00:00, 27.23it/s]
100%|██████████| 78/78 [00:16<00:00,  4.77it/s]


validation accuracy: 0.9723557692307693
Valid Error 2.764 %
Epoch 6/10


100%|██████████| 390/390 [00:14<00:00, 26.40it/s]
100%|██████████| 78/78 [00:16<00:00,  4.72it/s]


validation accuracy: 0.9666466346153846
Valid Error 3.335 %
Epoch 7/10


100%|██████████| 390/390 [00:14<00:00, 27.81it/s]
100%|██████████| 78/78 [00:16<00:00,  4.76it/s]


validation accuracy: 0.9684495192307693
Valid Error 3.155 %
Epoch 8/10


100%|██████████| 390/390 [00:13<00:00, 27.92it/s]
100%|██████████| 78/78 [00:16<00:00,  4.78it/s]


validation accuracy: 0.9729567307692307
Valid Error 2.704 %
Epoch 9/10


100%|██████████| 390/390 [00:14<00:00, 27.46it/s]
100%|██████████| 78/78 [00:16<00:00,  4.87it/s]


validation accuracy: 0.9735576923076923
Valid Error 2.644 %
Epoch 10/10


100%|██████████| 390/390 [00:14<00:00, 27.86it/s]
100%|██████████| 78/78 [00:16<00:00,  4.76it/s]


validation accuracy: 0.9735576923076923
Valid Error 2.644 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, -1, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:14<00:00, 26.79it/s]
100%|██████████| 78/78 [00:16<00:00,  4.60it/s]


validation accuracy: 0.9416065705128205
Valid Error 5.839 %
Epoch 2/10


100%|██████████| 390/390 [00:14<00:00, 26.94it/s]
100%|██████████| 78/78 [00:16<00:00,  4.61it/s]


validation accuracy: 0.964042467948718
Valid Error 3.596 %
Epoch 3/10


100%|██████████| 390/390 [00:14<00:00, 26.45it/s]
100%|██████████| 78/78 [00:16<00:00,  4.63it/s]


validation accuracy: 0.9654447115384616
Valid Error 3.456 %
Epoch 4/10


100%|██████████| 390/390 [00:14<00:00, 26.70it/s]
100%|██████████| 78/78 [00:17<00:00,  4.59it/s]


validation accuracy: 0.9707532051282052
Valid Error 2.925 %
Epoch 5/10


100%|██████████| 390/390 [00:14<00:00, 26.35it/s]
100%|██████████| 78/78 [00:16<00:00,  4.63it/s]


validation accuracy: 0.9685496794871795
Valid Error 3.145 %
Epoch 6/10


100%|██████████| 390/390 [00:14<00:00, 26.64it/s]
100%|██████████| 78/78 [00:16<00:00,  4.64it/s]


validation accuracy: 0.9696514423076923
Valid Error 3.035 %
Epoch 7/10


100%|██████████| 390/390 [00:14<00:00, 26.88it/s]
100%|██████████| 78/78 [00:16<00:00,  4.65it/s]


validation accuracy: 0.9688501602564102
Valid Error 3.115 %
Epoch 8/10


100%|██████████| 390/390 [00:14<00:00, 27.51it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.965645032051282
Valid Error 3.435 %
Epoch 9/10


100%|██████████| 390/390 [00:14<00:00, 26.87it/s]
100%|██████████| 78/78 [00:16<00:00,  4.62it/s]


validation accuracy: 0.9719551282051282
Valid Error 2.804 %
Epoch 10/10


100%|██████████| 390/390 [00:14<00:00, 26.81it/s]
100%|██████████| 78/78 [00:17<00:00,  4.53it/s]


validation accuracy: 0.9680488782051282
Valid Error 3.195 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, -1, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:14<00:00, 27.45it/s]
100%|██████████| 78/78 [00:17<00:00,  4.47it/s]


validation accuracy: 0.9537259615384616
Valid Error 4.627 %
Epoch 2/10


100%|██████████| 390/390 [00:14<00:00, 27.00it/s]
100%|██████████| 78/78 [00:17<00:00,  4.58it/s]


validation accuracy: 0.9631410256410257
Valid Error 3.686 %
Epoch 3/10


100%|██████████| 390/390 [00:14<00:00, 27.37it/s]
100%|██████████| 78/78 [00:17<00:00,  4.58it/s]


validation accuracy: 0.9630408653846154
Valid Error 3.696 %
Epoch 4/10


100%|██████████| 390/390 [00:14<00:00, 27.21it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.969551282051282
Valid Error 3.045 %
Epoch 5/10


100%|██████████| 390/390 [00:14<00:00, 27.39it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.9683493589743589
Valid Error 3.165 %
Epoch 6/10


100%|██████████| 390/390 [00:14<00:00, 27.30it/s]
100%|██████████| 78/78 [00:17<00:00,  4.53it/s]


validation accuracy: 0.9688501602564102
Valid Error 3.115 %
Epoch 7/10


100%|██████████| 390/390 [00:14<00:00, 27.45it/s]
100%|██████████| 78/78 [00:16<00:00,  4.68it/s]


validation accuracy: 0.9663461538461539
Valid Error 3.365 %
Epoch 8/10


100%|██████████| 390/390 [00:14<00:00, 27.10it/s]
100%|██████████| 78/78 [00:17<00:00,  4.42it/s]


validation accuracy: 0.9671474358974359
Valid Error 3.285 %
Epoch 9/10


100%|██████████| 390/390 [00:14<00:00, 26.63it/s]
100%|██████████| 78/78 [00:17<00:00,  4.36it/s]


validation accuracy: 0.9717548076923077
Valid Error 2.825 %
Epoch 10/10


100%|██████████| 390/390 [00:14<00:00, 26.44it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.969551282051282
Valid Error 3.045 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, 0, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:14<00:00, 26.74it/s]
100%|██████████| 78/78 [00:17<00:00,  4.42it/s]


validation accuracy: 0.9583333333333334
Valid Error 4.167 %
Epoch 2/10


100%|██████████| 390/390 [00:14<00:00, 26.97it/s]
100%|██████████| 78/78 [00:17<00:00,  4.39it/s]


validation accuracy: 0.9586338141025641
Valid Error 4.137 %
Epoch 3/10


100%|██████████| 390/390 [00:14<00:00, 26.42it/s]
100%|██████████| 78/78 [00:17<00:00,  4.39it/s]


validation accuracy: 0.9685496794871795
Valid Error 3.145 %
Epoch 4/10


100%|██████████| 390/390 [00:14<00:00, 26.28it/s]
100%|██████████| 78/78 [00:17<00:00,  4.53it/s]


validation accuracy: 0.9714543269230769
Valid Error 2.855 %
Epoch 5/10


100%|██████████| 390/390 [00:14<00:00, 26.46it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.9705528846153846
Valid Error 2.945 %
Epoch 6/10


100%|██████████| 390/390 [00:14<00:00, 26.23it/s]
100%|██████████| 78/78 [00:17<00:00,  4.45it/s]


validation accuracy: 0.9683493589743589
Valid Error 3.165 %
Epoch 7/10


100%|██████████| 390/390 [00:14<00:00, 26.45it/s]
100%|██████████| 78/78 [00:17<00:00,  4.54it/s]


validation accuracy: 0.9719551282051282
Valid Error 2.804 %
Epoch 8/10


100%|██████████| 390/390 [00:15<00:00, 25.89it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9691506410256411
Valid Error 3.085 %
Epoch 9/10


100%|██████████| 390/390 [00:14<00:00, 27.22it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.9701522435897436
Valid Error 2.985 %
Epoch 10/10


100%|██████████| 390/390 [00:14<00:00, 26.43it/s]
100%|██████████| 78/78 [00:17<00:00,  4.48it/s]


validation accuracy: 0.9708533653846154
Valid Error 2.915 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, 0, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:15<00:00, 25.56it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.9400040064102564
Valid Error 6.0 %
Epoch 2/10


100%|██████████| 390/390 [00:15<00:00, 25.85it/s]
100%|██████████| 78/78 [00:17<00:00,  4.42it/s]


validation accuracy: 0.9606370192307693
Valid Error 3.936 %
Epoch 3/10


100%|██████████| 390/390 [00:15<00:00, 25.42it/s]
100%|██████████| 78/78 [00:18<00:00,  4.32it/s]


validation accuracy: 0.9574318910256411
Valid Error 4.257 %
Epoch 4/10


100%|██████████| 390/390 [00:15<00:00, 25.49it/s]
100%|██████████| 78/78 [00:17<00:00,  4.40it/s]


validation accuracy: 0.9700520833333334
Valid Error 2.995 %
Epoch 5/10


100%|██████████| 390/390 [00:15<00:00, 25.52it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.9665464743589743
Valid Error 3.345 %
Epoch 6/10


100%|██████████| 390/390 [00:15<00:00, 25.89it/s]
100%|██████████| 78/78 [00:17<00:00,  4.37it/s]


validation accuracy: 0.9698517628205128
Valid Error 3.015 %
Epoch 7/10


100%|██████████| 390/390 [00:15<00:00, 24.85it/s]
100%|██████████| 78/78 [00:17<00:00,  4.34it/s]


validation accuracy: 0.9699519230769231
Valid Error 3.005 %
Epoch 8/10


100%|██████████| 390/390 [00:15<00:00, 24.96it/s]
100%|██████████| 78/78 [00:17<00:00,  4.43it/s]


validation accuracy: 0.9766626602564102
Valid Error 2.334 %
Epoch 9/10


100%|██████████| 390/390 [00:15<00:00, 24.62it/s]
100%|██████████| 78/78 [00:17<00:00,  4.40it/s]


validation accuracy: 0.9713541666666666
Valid Error 2.865 %
Epoch 10/10


100%|██████████| 390/390 [00:15<00:00, 25.26it/s]
100%|██████████| 78/78 [00:17<00:00,  4.39it/s]


validation accuracy: 0.9727564102564102
Valid Error 2.724 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, -1, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:15<00:00, 25.12it/s]
100%|██████████| 78/78 [00:17<00:00,  4.36it/s]


validation accuracy: 0.9425080128205128
Valid Error 5.749 %
Epoch 2/10


100%|██████████| 390/390 [00:15<00:00, 25.39it/s]
100%|██████████| 78/78 [00:18<00:00,  4.29it/s]


validation accuracy: 0.960136217948718
Valid Error 3.986 %
Epoch 3/10


100%|██████████| 390/390 [00:15<00:00, 25.54it/s]
100%|██████████| 78/78 [00:18<00:00,  4.31it/s]


validation accuracy: 0.961738782051282
Valid Error 3.826 %
Epoch 4/10


100%|██████████| 390/390 [00:15<00:00, 25.54it/s]
100%|██████████| 78/78 [00:17<00:00,  4.34it/s]


validation accuracy: 0.9716546474358975
Valid Error 2.835 %
Epoch 5/10


100%|██████████| 390/390 [00:15<00:00, 25.75it/s]
100%|██████████| 78/78 [00:17<00:00,  4.39it/s]


validation accuracy: 0.9684495192307693
Valid Error 3.155 %
Epoch 6/10


100%|██████████| 390/390 [00:15<00:00, 25.47it/s]
100%|██████████| 78/78 [00:17<00:00,  4.34it/s]


validation accuracy: 0.9665464743589743
Valid Error 3.345 %
Epoch 7/10


100%|██████████| 390/390 [00:15<00:00, 25.36it/s]
100%|██████████| 78/78 [00:17<00:00,  4.35it/s]


validation accuracy: 0.9698517628205128
Valid Error 3.015 %
Epoch 8/10


100%|██████████| 390/390 [00:14<00:00, 26.35it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9688501602564102
Valid Error 3.115 %
Epoch 9/10


100%|██████████| 390/390 [00:14<00:00, 26.79it/s]
100%|██████████| 78/78 [00:18<00:00,  4.27it/s]


validation accuracy: 0.9671474358974359
Valid Error 3.285 %
Epoch 10/10


100%|██████████| 390/390 [00:15<00:00, 25.35it/s]
100%|██████████| 78/78 [00:18<00:00,  4.33it/s]


validation accuracy: 0.97265625
Valid Error 2.734 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, -1, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:15<00:00, 25.31it/s]
100%|██████████| 78/78 [00:17<00:00,  4.35it/s]


validation accuracy: 0.9515224358974359
Valid Error 4.848 %
Epoch 2/10


100%|██████████| 390/390 [00:15<00:00, 24.84it/s]
100%|██████████| 78/78 [00:18<00:00,  4.28it/s]


validation accuracy: 0.9596354166666666
Valid Error 4.036 %
Epoch 3/10


100%|██████████| 390/390 [00:15<00:00, 25.24it/s]
100%|██████████| 78/78 [00:17<00:00,  4.35it/s]


validation accuracy: 0.9654447115384616
Valid Error 3.456 %
Epoch 4/10


100%|██████████| 390/390 [00:15<00:00, 24.58it/s]
100%|██████████| 78/78 [00:17<00:00,  4.42it/s]


validation accuracy: 0.9650440705128205
Valid Error 3.496 %
Epoch 5/10


100%|██████████| 390/390 [00:15<00:00, 25.18it/s]
100%|██████████| 78/78 [00:17<00:00,  4.42it/s]


validation accuracy: 0.9711538461538461
Valid Error 2.885 %
Epoch 6/10


100%|██████████| 390/390 [00:15<00:00, 25.16it/s]
100%|██████████| 78/78 [00:17<00:00,  4.34it/s]


validation accuracy: 0.9665464743589743
Valid Error 3.345 %
Epoch 7/10


100%|██████████| 390/390 [00:15<00:00, 24.99it/s]
100%|██████████| 78/78 [00:17<00:00,  4.37it/s]


validation accuracy: 0.9736578525641025
Valid Error 2.634 %
Epoch 8/10


100%|██████████| 390/390 [00:15<00:00, 25.30it/s]
100%|██████████| 78/78 [00:17<00:00,  4.34it/s]


validation accuracy: 0.9747596153846154
Valid Error 2.524 %
Epoch 9/10


100%|██████████| 390/390 [00:15<00:00, 24.99it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9664463141025641
Valid Error 3.355 %
Epoch 10/10


100%|██████████| 390/390 [00:15<00:00, 25.12it/s]
100%|██████████| 78/78 [00:18<00:00,  4.27it/s]


validation accuracy: 0.9691506410256411
Valid Error 3.085 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, 0, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:23<00:00, 16.72it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9497195512820513
Valid Error 5.028 %
Epoch 2/10


100%|██████████| 390/390 [00:23<00:00, 16.43it/s]
100%|██████████| 78/78 [00:17<00:00,  4.39it/s]


validation accuracy: 0.9599358974358975
Valid Error 4.006 %
Epoch 3/10


100%|██████████| 390/390 [00:24<00:00, 16.15it/s]
100%|██████████| 78/78 [00:17<00:00,  4.37it/s]


validation accuracy: 0.9652443910256411
Valid Error 3.476 %
Epoch 4/10


100%|██████████| 390/390 [00:24<00:00, 15.76it/s]
100%|██████████| 78/78 [00:18<00:00,  4.23it/s]


validation accuracy: 0.9729567307692307
Valid Error 2.704 %
Epoch 5/10


100%|██████████| 390/390 [00:24<00:00, 16.08it/s]
100%|██████████| 78/78 [00:18<00:00,  4.29it/s]


validation accuracy: 0.9671474358974359
Valid Error 3.285 %
Epoch 6/10


100%|██████████| 390/390 [00:24<00:00, 16.13it/s]
100%|██████████| 78/78 [00:17<00:00,  4.34it/s]


validation accuracy: 0.9703525641025641
Valid Error 2.965 %
Epoch 7/10


100%|██████████| 390/390 [00:24<00:00, 15.82it/s]
100%|██████████| 78/78 [00:18<00:00,  4.33it/s]


validation accuracy: 0.9647435897435898
Valid Error 3.526 %
Epoch 8/10


100%|██████████| 390/390 [00:24<00:00, 16.13it/s]
100%|██████████| 78/78 [00:18<00:00,  4.33it/s]


validation accuracy: 0.9710536858974359
Valid Error 2.895 %
Epoch 9/10


100%|██████████| 390/390 [00:24<00:00, 15.94it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.9696514423076923
Valid Error 3.035 %
Epoch 10/10


100%|██████████| 390/390 [00:24<00:00, 15.97it/s]
100%|██████████| 78/78 [00:17<00:00,  4.43it/s]


validation accuracy: 0.9688501602564102
Valid Error 3.115 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, 0, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:23<00:00, 16.74it/s]
100%|██████████| 78/78 [00:17<00:00,  4.47it/s]


validation accuracy: 0.9540264423076923
Valid Error 4.597 %
Epoch 2/10


100%|██████████| 390/390 [00:24<00:00, 16.06it/s]
100%|██████████| 78/78 [00:17<00:00,  4.40it/s]


validation accuracy: 0.9628405448717948
Valid Error 3.716 %
Epoch 3/10


100%|██████████| 390/390 [00:24<00:00, 15.88it/s]
100%|██████████| 78/78 [00:17<00:00,  4.36it/s]


validation accuracy: 0.9697516025641025
Valid Error 3.025 %
Epoch 4/10


100%|██████████| 390/390 [00:24<00:00, 16.17it/s]
100%|██████████| 78/78 [00:17<00:00,  4.38it/s]


validation accuracy: 0.9652443910256411
Valid Error 3.476 %
Epoch 5/10


100%|██████████| 390/390 [00:24<00:00, 16.22it/s]
100%|██████████| 78/78 [00:17<00:00,  4.44it/s]


validation accuracy: 0.9645432692307693
Valid Error 3.546 %
Epoch 6/10


100%|██████████| 390/390 [00:24<00:00, 16.03it/s]
100%|██████████| 78/78 [00:18<00:00,  4.26it/s]


validation accuracy: 0.9697516025641025
Valid Error 3.025 %
Epoch 7/10


100%|██████████| 390/390 [00:23<00:00, 16.48it/s]
100%|██████████| 78/78 [00:18<00:00,  4.30it/s]


validation accuracy: 0.9720552884615384
Valid Error 2.794 %
Epoch 8/10


100%|██████████| 390/390 [00:24<00:00, 15.77it/s]
100%|██████████| 78/78 [00:17<00:00,  4.42it/s]


validation accuracy: 0.9677483974358975
Valid Error 3.225 %
Epoch 9/10


100%|██████████| 390/390 [00:24<00:00, 16.24it/s]
100%|██████████| 78/78 [00:18<00:00,  4.28it/s]


validation accuracy: 0.9698517628205128
Valid Error 3.015 %
Epoch 10/10


100%|██████████| 390/390 [00:23<00:00, 16.43it/s]
100%|██████████| 78/78 [00:17<00:00,  4.34it/s]


validation accuracy: 0.9728565705128205
Valid Error 2.714 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, -1, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:23<00:00, 16.54it/s]
100%|██████████| 78/78 [00:17<00:00,  4.59it/s]


validation accuracy: 0.9572315705128205
Valid Error 4.277 %
Epoch 2/10


100%|██████████| 390/390 [00:24<00:00, 16.16it/s]
100%|██████████| 78/78 [00:18<00:00,  4.26it/s]


validation accuracy: 0.9630408653846154
Valid Error 3.696 %
Epoch 3/10


100%|██████████| 390/390 [00:23<00:00, 16.44it/s]
100%|██████████| 78/78 [00:18<00:00,  4.26it/s]


validation accuracy: 0.9614383012820513
Valid Error 3.856 %
Epoch 4/10


100%|██████████| 390/390 [00:24<00:00, 15.92it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9699519230769231
Valid Error 3.005 %
Epoch 5/10


100%|██████████| 390/390 [00:24<00:00, 15.79it/s]
100%|██████████| 78/78 [00:18<00:00,  4.26it/s]


validation accuracy: 0.9704527243589743
Valid Error 2.955 %
Epoch 6/10


100%|██████████| 390/390 [00:23<00:00, 16.46it/s]
100%|██████████| 78/78 [00:18<00:00,  4.29it/s]


validation accuracy: 0.9668469551282052
Valid Error 3.315 %
Epoch 7/10


100%|██████████| 390/390 [00:24<00:00, 15.67it/s]
100%|██████████| 78/78 [00:17<00:00,  4.44it/s]


validation accuracy: 0.9719551282051282
Valid Error 2.804 %
Epoch 8/10


100%|██████████| 390/390 [00:24<00:00, 16.01it/s]
100%|██████████| 78/78 [00:18<00:00,  4.25it/s]


validation accuracy: 0.9661458333333334
Valid Error 3.385 %
Epoch 9/10


100%|██████████| 390/390 [00:24<00:00, 16.20it/s]
100%|██████████| 78/78 [00:18<00:00,  4.25it/s]


validation accuracy: 0.9712540064102564
Valid Error 2.875 %
Epoch 10/10


100%|██████████| 390/390 [00:24<00:00, 16.02it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9710536858974359
Valid Error 2.895 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, -1, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:23<00:00, 16.54it/s]
100%|██████████| 78/78 [00:17<00:00,  4.52it/s]


validation accuracy: 0.953125
Valid Error 4.688 %
Epoch 2/10


100%|██████████| 390/390 [00:23<00:00, 16.44it/s]
100%|██████████| 78/78 [00:18<00:00,  4.30it/s]


validation accuracy: 0.9646434294871795
Valid Error 3.536 %
Epoch 3/10


100%|██████████| 390/390 [00:24<00:00, 15.81it/s]
100%|██████████| 78/78 [00:17<00:00,  4.39it/s]


validation accuracy: 0.9686498397435898
Valid Error 3.135 %
Epoch 4/10


100%|██████████| 390/390 [00:24<00:00, 16.23it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9665464743589743
Valid Error 3.345 %
Epoch 5/10


100%|██████████| 390/390 [00:24<00:00, 16.13it/s]
100%|██████████| 78/78 [00:18<00:00,  4.22it/s]


validation accuracy: 0.9714543269230769
Valid Error 2.855 %
Epoch 6/10


100%|██████████| 390/390 [00:24<00:00, 16.00it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.9657451923076923
Valid Error 3.425 %
Epoch 7/10


100%|██████████| 390/390 [00:23<00:00, 16.28it/s]
100%|██████████| 78/78 [00:17<00:00,  4.42it/s]


validation accuracy: 0.9691506410256411
Valid Error 3.085 %
Epoch 8/10


100%|██████████| 390/390 [00:24<00:00, 16.21it/s]
100%|██████████| 78/78 [00:18<00:00,  4.27it/s]


validation accuracy: 0.97265625
Valid Error 2.734 %
Epoch 9/10


100%|██████████| 390/390 [00:24<00:00, 16.20it/s]
100%|██████████| 78/78 [00:17<00:00,  4.52it/s]


validation accuracy: 0.9724559294871795
Valid Error 2.754 %
Epoch 10/10


100%|██████████| 390/390 [00:24<00:00, 15.91it/s]
100%|██████████| 78/78 [00:17<00:00,  4.38it/s]


validation accuracy: 0.9647435897435898
Valid Error 3.526 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, 0, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:23<00:00, 16.51it/s]
100%|██████████| 78/78 [00:17<00:00,  4.42it/s]


validation accuracy: 0.9503205128205128
Valid Error 4.968 %
Epoch 2/10


100%|██████████| 390/390 [00:23<00:00, 16.28it/s]
100%|██████████| 78/78 [00:17<00:00,  4.35it/s]


validation accuracy: 0.9671474358974359
Valid Error 3.285 %
Epoch 3/10


100%|██████████| 390/390 [00:24<00:00, 16.04it/s]
100%|██████████| 78/78 [00:17<00:00,  4.37it/s]


validation accuracy: 0.9664463141025641
Valid Error 3.355 %
Epoch 4/10


100%|██████████| 390/390 [00:24<00:00, 15.82it/s]
100%|██████████| 78/78 [00:18<00:00,  4.24it/s]


validation accuracy: 0.9707532051282052
Valid Error 2.925 %
Epoch 5/10


100%|██████████| 390/390 [00:23<00:00, 16.39it/s]
100%|██████████| 78/78 [00:18<00:00,  4.30it/s]


validation accuracy: 0.9682491987179487
Valid Error 3.175 %
Epoch 6/10


100%|██████████| 390/390 [00:24<00:00, 15.80it/s]
100%|██████████| 78/78 [00:18<00:00,  4.29it/s]


validation accuracy: 0.9692508012820513
Valid Error 3.075 %
Epoch 7/10


100%|██████████| 390/390 [00:24<00:00, 15.96it/s]
100%|██████████| 78/78 [00:18<00:00,  4.16it/s]


validation accuracy: 0.9704527243589743
Valid Error 2.955 %
Epoch 8/10


100%|██████████| 390/390 [00:24<00:00, 16.04it/s]
100%|██████████| 78/78 [00:17<00:00,  4.36it/s]


validation accuracy: 0.9712540064102564
Valid Error 2.875 %
Epoch 9/10


100%|██████████| 390/390 [00:24<00:00, 15.68it/s]
100%|██████████| 78/78 [00:17<00:00,  4.40it/s]


validation accuracy: 0.9742588141025641
Valid Error 2.574 %
Epoch 10/10


100%|██████████| 390/390 [00:24<00:00, 16.16it/s]
100%|██████████| 78/78 [00:18<00:00,  4.19it/s]


validation accuracy: 0.9683493589743589
Valid Error 3.165 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, 0, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:24<00:00, 16.21it/s]
100%|██████████| 78/78 [00:16<00:00,  4.63it/s]


validation accuracy: 0.9466145833333334
Valid Error 5.339 %
Epoch 2/10


100%|██████████| 390/390 [00:24<00:00, 16.10it/s]
100%|██████████| 78/78 [00:17<00:00,  4.54it/s]


validation accuracy: 0.9631410256410257
Valid Error 3.686 %
Epoch 3/10


100%|██████████| 390/390 [00:24<00:00, 16.17it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.9637419871794872
Valid Error 3.626 %
Epoch 4/10


100%|██████████| 390/390 [00:23<00:00, 16.44it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.9634415064102564
Valid Error 3.656 %
Epoch 5/10


100%|██████████| 390/390 [00:24<00:00, 15.90it/s]
100%|██████████| 78/78 [00:17<00:00,  4.52it/s]


validation accuracy: 0.9727564102564102
Valid Error 2.724 %
Epoch 6/10


100%|██████████| 390/390 [00:24<00:00, 15.89it/s]
100%|██████████| 78/78 [00:17<00:00,  4.37it/s]


validation accuracy: 0.9722556089743589
Valid Error 2.774 %
Epoch 7/10


100%|██████████| 390/390 [00:23<00:00, 16.37it/s]
100%|██████████| 78/78 [00:17<00:00,  4.37it/s]


validation accuracy: 0.9689503205128205
Valid Error 3.105 %
Epoch 8/10


100%|██████████| 390/390 [00:24<00:00, 16.09it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.9710536858974359
Valid Error 2.895 %
Epoch 9/10


100%|██████████| 390/390 [00:23<00:00, 16.30it/s]
100%|██████████| 78/78 [00:17<00:00,  4.35it/s]


validation accuracy: 0.9723557692307693
Valid Error 2.764 %
Epoch 10/10


100%|██████████| 390/390 [00:23<00:00, 16.42it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.9702524038461539
Valid Error 2.975 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, -1, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:22<00:00, 17.01it/s]
100%|██████████| 78/78 [00:16<00:00,  4.68it/s]


validation accuracy: 0.9525240384615384
Valid Error 4.748 %
Epoch 2/10


100%|██████████| 390/390 [00:23<00:00, 16.29it/s]
100%|██████████| 78/78 [00:17<00:00,  4.37it/s]


validation accuracy: 0.9602363782051282
Valid Error 3.976 %
Epoch 3/10


100%|██████████| 390/390 [00:23<00:00, 16.27it/s]
100%|██████████| 78/78 [00:18<00:00,  4.32it/s]


validation accuracy: 0.9691506410256411
Valid Error 3.085 %
Epoch 4/10


100%|██████████| 390/390 [00:24<00:00, 16.02it/s]
100%|██████████| 78/78 [00:17<00:00,  4.55it/s]


validation accuracy: 0.9654447115384616
Valid Error 3.456 %
Epoch 5/10


100%|██████████| 390/390 [00:23<00:00, 16.40it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.9698517628205128
Valid Error 3.015 %
Epoch 6/10


100%|██████████| 390/390 [00:24<00:00, 15.86it/s]
100%|██████████| 78/78 [00:18<00:00,  4.32it/s]


validation accuracy: 0.9645432692307693
Valid Error 3.546 %
Epoch 7/10


100%|██████████| 390/390 [00:23<00:00, 16.37it/s]
100%|██████████| 78/78 [00:18<00:00,  4.33it/s]


validation accuracy: 0.973457532051282
Valid Error 2.654 %
Epoch 8/10


100%|██████████| 390/390 [00:24<00:00, 15.82it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.9691506410256411
Valid Error 3.085 %
Epoch 9/10


100%|██████████| 390/390 [00:24<00:00, 16.02it/s]
100%|██████████| 78/78 [00:17<00:00,  4.34it/s]


validation accuracy: 0.9728565705128205
Valid Error 2.714 %
Epoch 10/10


100%|██████████| 390/390 [00:23<00:00, 16.38it/s]
100%|██████████| 78/78 [00:18<00:00,  4.32it/s]


validation accuracy: 0.9710536858974359
Valid Error 2.895 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, -1, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:23<00:00, 16.78it/s]
100%|██████████| 78/78 [00:16<00:00,  4.65it/s]


validation accuracy: 0.9454126602564102
Valid Error 5.459 %
Epoch 2/10


100%|██████████| 390/390 [00:23<00:00, 16.30it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.9589342948717948
Valid Error 4.107 %
Epoch 3/10


100%|██████████| 390/390 [00:23<00:00, 16.29it/s]
100%|██████████| 78/78 [00:18<00:00,  4.31it/s]


validation accuracy: 0.9669471153846154
Valid Error 3.305 %
Epoch 4/10


100%|██████████| 390/390 [00:24<00:00, 16.24it/s]
100%|██████████| 78/78 [00:17<00:00,  4.55it/s]


validation accuracy: 0.9728565705128205
Valid Error 2.714 %
Epoch 5/10


100%|██████████| 390/390 [00:24<00:00, 15.97it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.9672475961538461
Valid Error 3.275 %
Epoch 6/10


100%|██████████| 390/390 [00:23<00:00, 16.27it/s]
100%|██████████| 78/78 [00:17<00:00,  4.36it/s]


validation accuracy: 0.9724559294871795
Valid Error 2.754 %
Epoch 7/10


100%|██████████| 390/390 [00:24<00:00, 16.00it/s]
100%|██████████| 78/78 [00:17<00:00,  4.53it/s]


validation accuracy: 0.967948717948718
Valid Error 3.205 %
Epoch 8/10


100%|██████████| 390/390 [00:24<00:00, 16.10it/s]
100%|██████████| 78/78 [00:17<00:00,  4.43it/s]


validation accuracy: 0.9708533653846154
Valid Error 2.915 %
Epoch 9/10


100%|██████████| 390/390 [00:24<00:00, 16.19it/s]
100%|██████████| 78/78 [00:18<00:00,  4.24it/s]


validation accuracy: 0.9725560897435898
Valid Error 2.744 %
Epoch 10/10


100%|██████████| 390/390 [00:23<00:00, 16.54it/s]
100%|██████████| 78/78 [00:17<00:00,  4.48it/s]


validation accuracy: 0.9697516025641025
Valid Error 3.025 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, 0, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:47<00:00,  8.13it/s]
100%|██████████| 78/78 [00:17<00:00,  4.58it/s]


validation accuracy: 0.9504206730769231
Valid Error 4.958 %
Epoch 2/10


100%|██████████| 390/390 [00:49<00:00,  7.88it/s]
100%|██████████| 78/78 [00:17<00:00,  4.56it/s]


validation accuracy: 0.9638421474358975
Valid Error 3.616 %
Epoch 3/10


100%|██████████| 390/390 [00:49<00:00,  7.83it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.9697516025641025
Valid Error 3.025 %
Epoch 4/10


100%|██████████| 390/390 [00:49<00:00,  7.94it/s]
100%|██████████| 78/78 [00:16<00:00,  4.59it/s]


validation accuracy: 0.9716546474358975
Valid Error 2.835 %
Epoch 5/10


100%|██████████| 390/390 [00:50<00:00,  7.75it/s]
100%|██████████| 78/78 [00:17<00:00,  4.40it/s]


validation accuracy: 0.9677483974358975
Valid Error 3.225 %
Epoch 6/10


100%|██████████| 390/390 [00:49<00:00,  7.93it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9709535256410257
Valid Error 2.905 %
Epoch 7/10


100%|██████████| 390/390 [00:48<00:00,  8.01it/s]
100%|██████████| 78/78 [00:17<00:00,  4.55it/s]


validation accuracy: 0.9673477564102564
Valid Error 3.265 %
Epoch 8/10


100%|██████████| 390/390 [00:49<00:00,  7.95it/s]
100%|██████████| 78/78 [00:17<00:00,  4.44it/s]


validation accuracy: 0.9667467948717948
Valid Error 3.325 %
Epoch 9/10


100%|██████████| 390/390 [00:49<00:00,  7.82it/s]
100%|██████████| 78/78 [00:17<00:00,  4.43it/s]


validation accuracy: 0.9746594551282052
Valid Error 2.534 %
Epoch 10/10


100%|██████████| 390/390 [00:49<00:00,  7.83it/s]
100%|██████████| 78/78 [00:17<00:00,  4.52it/s]


validation accuracy: 0.9742588141025641
Valid Error 2.574 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, 0, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:50<00:00,  7.76it/s]
100%|██████████| 78/78 [00:18<00:00,  4.27it/s]


validation accuracy: 0.9576322115384616
Valid Error 4.237 %
Epoch 2/10


100%|██████████| 390/390 [00:49<00:00,  7.82it/s]
100%|██████████| 78/78 [00:18<00:00,  4.22it/s]


validation accuracy: 0.9603365384615384
Valid Error 3.966 %
Epoch 3/10


100%|██████████| 390/390 [00:49<00:00,  7.82it/s]
100%|██████████| 78/78 [00:16<00:00,  4.67it/s]


validation accuracy: 0.9639423076923077
Valid Error 3.606 %
Epoch 4/10


100%|██████████| 390/390 [00:49<00:00,  7.85it/s]
100%|██████████| 78/78 [00:18<00:00,  4.15it/s]


validation accuracy: 0.9686498397435898
Valid Error 3.135 %
Epoch 5/10


100%|██████████| 390/390 [00:50<00:00,  7.75it/s]
100%|██████████| 78/78 [00:18<00:00,  4.23it/s]


validation accuracy: 0.96484375
Valid Error 3.516 %
Epoch 6/10


100%|██████████| 390/390 [00:50<00:00,  7.69it/s]
100%|██████████| 78/78 [00:17<00:00,  4.42it/s]


validation accuracy: 0.9682491987179487
Valid Error 3.175 %
Epoch 7/10


100%|██████████| 390/390 [00:51<00:00,  7.64it/s]
100%|██████████| 78/78 [00:17<00:00,  4.54it/s]


validation accuracy: 0.9706530448717948
Valid Error 2.935 %
Epoch 8/10


100%|██████████| 390/390 [00:51<00:00,  7.51it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.9651442307692307
Valid Error 3.486 %
Epoch 9/10


100%|██████████| 390/390 [00:51<00:00,  7.58it/s]
100%|██████████| 78/78 [00:16<00:00,  4.66it/s]


validation accuracy: 0.9706530448717948
Valid Error 2.935 %
Epoch 10/10


100%|██████████| 390/390 [00:50<00:00,  7.73it/s]
100%|██████████| 78/78 [00:17<00:00,  4.40it/s]


validation accuracy: 0.9758613782051282
Valid Error 2.414 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, -1, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:50<00:00,  7.77it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.9501201923076923
Valid Error 4.988 %
Epoch 2/10


100%|██████████| 390/390 [00:50<00:00,  7.72it/s]
100%|██████████| 78/78 [00:18<00:00,  4.30it/s]


validation accuracy: 0.9608373397435898
Valid Error 3.916 %
Epoch 3/10


100%|██████████| 390/390 [00:49<00:00,  7.82it/s]
100%|██████████| 78/78 [00:18<00:00,  4.25it/s]


validation accuracy: 0.9650440705128205
Valid Error 3.496 %
Epoch 4/10


100%|██████████| 390/390 [00:49<00:00,  7.80it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.9668469551282052
Valid Error 3.315 %
Epoch 5/10


100%|██████████| 390/390 [00:50<00:00,  7.70it/s]
100%|██████████| 78/78 [00:17<00:00,  4.48it/s]


validation accuracy: 0.9689503205128205
Valid Error 3.105 %
Epoch 6/10


100%|██████████| 390/390 [00:49<00:00,  7.92it/s]
100%|██████████| 78/78 [00:17<00:00,  4.53it/s]


validation accuracy: 0.9657451923076923
Valid Error 3.425 %
Epoch 7/10


100%|██████████| 390/390 [00:50<00:00,  7.78it/s]
100%|██████████| 78/78 [00:18<00:00,  4.23it/s]


validation accuracy: 0.9719551282051282
Valid Error 2.804 %
Epoch 8/10


100%|██████████| 390/390 [00:49<00:00,  7.81it/s]
100%|██████████| 78/78 [00:18<00:00,  4.22it/s]


validation accuracy: 0.9720552884615384
Valid Error 2.794 %
Epoch 9/10


100%|██████████| 390/390 [00:50<00:00,  7.68it/s]
100%|██████████| 78/78 [00:17<00:00,  4.36it/s]


validation accuracy: 0.971854967948718
Valid Error 2.815 %
Epoch 10/10


100%|██████████| 390/390 [00:50<00:00,  7.69it/s]
100%|██████████| 78/78 [00:17<00:00,  4.39it/s]


validation accuracy: 0.9749599358974359
Valid Error 2.504 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, -1, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:51<00:00,  7.59it/s]
100%|██████████| 78/78 [00:17<00:00,  4.48it/s]


validation accuracy: 0.9547275641025641
Valid Error 4.527 %
Epoch 2/10


100%|██████████| 390/390 [00:49<00:00,  7.86it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.9632411858974359
Valid Error 3.676 %
Epoch 3/10


100%|██████████| 390/390 [00:50<00:00,  7.77it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.9691506410256411
Valid Error 3.085 %
Epoch 4/10


100%|██████████| 390/390 [00:50<00:00,  7.77it/s]
100%|██████████| 78/78 [00:17<00:00,  4.37it/s]


validation accuracy: 0.96875
Valid Error 3.125 %
Epoch 5/10


100%|██████████| 390/390 [00:49<00:00,  7.82it/s]
100%|██████████| 78/78 [00:17<00:00,  4.42it/s]


validation accuracy: 0.9698517628205128
Valid Error 3.015 %
Epoch 6/10


100%|██████████| 390/390 [00:50<00:00,  7.80it/s]
100%|██████████| 78/78 [00:17<00:00,  4.38it/s]


validation accuracy: 0.9688501602564102
Valid Error 3.115 %
Epoch 7/10


100%|██████████| 390/390 [00:50<00:00,  7.77it/s]
100%|██████████| 78/78 [00:18<00:00,  4.30it/s]


validation accuracy: 0.9719551282051282
Valid Error 2.804 %
Epoch 8/10


100%|██████████| 390/390 [00:49<00:00,  7.96it/s]
100%|██████████| 78/78 [00:17<00:00,  4.44it/s]


validation accuracy: 0.9699519230769231
Valid Error 3.005 %
Epoch 9/10


100%|██████████| 390/390 [00:50<00:00,  7.72it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.9692508012820513
Valid Error 3.075 %
Epoch 10/10


100%|██████████| 390/390 [00:49<00:00,  7.83it/s]
100%|██████████| 78/78 [00:17<00:00,  4.48it/s]


validation accuracy: 0.9759615384615384
Valid Error 2.404 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, 0, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:50<00:00,  7.79it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.9543269230769231
Valid Error 4.567 %
Epoch 2/10


100%|██████████| 390/390 [00:49<00:00,  7.88it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.9647435897435898
Valid Error 3.526 %
Epoch 3/10


100%|██████████| 390/390 [00:50<00:00,  7.75it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.9686498397435898
Valid Error 3.135 %
Epoch 4/10


100%|██████████| 390/390 [00:49<00:00,  7.92it/s]
100%|██████████| 78/78 [00:17<00:00,  4.52it/s]


validation accuracy: 0.9666466346153846
Valid Error 3.335 %
Epoch 5/10


100%|██████████| 390/390 [00:50<00:00,  7.72it/s]
100%|██████████| 78/78 [00:17<00:00,  4.43it/s]


validation accuracy: 0.9721554487179487
Valid Error 2.784 %
Epoch 6/10


100%|██████████| 390/390 [00:49<00:00,  7.82it/s]
100%|██████████| 78/78 [00:17<00:00,  4.35it/s]


validation accuracy: 0.9677483974358975
Valid Error 3.225 %
Epoch 7/10


100%|██████████| 390/390 [00:50<00:00,  7.79it/s]
100%|██████████| 78/78 [00:19<00:00,  4.07it/s]


validation accuracy: 0.9735576923076923
Valid Error 2.644 %
Epoch 8/10


100%|██████████| 390/390 [00:49<00:00,  7.83it/s]
100%|██████████| 78/78 [00:17<00:00,  4.45it/s]


validation accuracy: 0.9722556089743589
Valid Error 2.774 %
Epoch 9/10


100%|██████████| 390/390 [00:49<00:00,  7.91it/s]
100%|██████████| 78/78 [00:18<00:00,  4.27it/s]


validation accuracy: 0.9667467948717948
Valid Error 3.325 %
Epoch 10/10


100%|██████████| 390/390 [00:49<00:00,  7.82it/s]
100%|██████████| 78/78 [00:17<00:00,  4.54it/s]


validation accuracy: 0.9710536858974359
Valid Error 2.895 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, 0, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:49<00:00,  7.93it/s]
100%|██████████| 78/78 [00:17<00:00,  4.43it/s]


validation accuracy: 0.9546274038461539
Valid Error 4.537 %
Epoch 2/10


100%|██████████| 390/390 [00:49<00:00,  7.82it/s]
100%|██████████| 78/78 [00:18<00:00,  4.22it/s]


validation accuracy: 0.9671474358974359
Valid Error 3.285 %
Epoch 3/10


100%|██████████| 390/390 [00:49<00:00,  7.88it/s]
100%|██████████| 78/78 [00:17<00:00,  4.45it/s]


validation accuracy: 0.9693509615384616
Valid Error 3.065 %
Epoch 4/10


100%|██████████| 390/390 [00:50<00:00,  7.79it/s]
100%|██████████| 78/78 [00:17<00:00,  4.43it/s]


validation accuracy: 0.9688501602564102
Valid Error 3.115 %
Epoch 5/10


100%|██████████| 390/390 [00:50<00:00,  7.78it/s]
100%|██████████| 78/78 [00:17<00:00,  4.43it/s]


validation accuracy: 0.9717548076923077
Valid Error 2.825 %
Epoch 6/10


100%|██████████| 390/390 [00:50<00:00,  7.74it/s]
100%|██████████| 78/78 [00:18<00:00,  4.29it/s]


validation accuracy: 0.9728565705128205
Valid Error 2.714 %
Epoch 7/10


100%|██████████| 390/390 [00:50<00:00,  7.80it/s]
100%|██████████| 78/78 [00:18<00:00,  4.17it/s]


validation accuracy: 0.9678485576923077
Valid Error 3.215 %
Epoch 8/10


100%|██████████| 390/390 [00:49<00:00,  7.85it/s]
100%|██████████| 78/78 [00:18<00:00,  4.16it/s]


validation accuracy: 0.9651442307692307
Valid Error 3.486 %
Epoch 9/10


100%|██████████| 390/390 [00:51<00:00,  7.64it/s]
100%|██████████| 78/78 [00:17<00:00,  4.40it/s]


validation accuracy: 0.9748597756410257
Valid Error 2.514 %
Epoch 10/10


100%|██████████| 390/390 [00:50<00:00,  7.71it/s]
100%|██████████| 78/78 [00:18<00:00,  4.30it/s]


validation accuracy: 0.971854967948718
Valid Error 2.815 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, -1, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:50<00:00,  7.76it/s]
100%|██████████| 78/78 [00:18<00:00,  4.27it/s]


validation accuracy: 0.9557291666666666
Valid Error 4.427 %
Epoch 2/10


100%|██████████| 390/390 [00:50<00:00,  7.78it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.9704527243589743
Valid Error 2.955 %
Epoch 3/10


100%|██████████| 390/390 [00:49<00:00,  7.94it/s]
100%|██████████| 78/78 [00:18<00:00,  4.33it/s]


validation accuracy: 0.9649439102564102
Valid Error 3.506 %
Epoch 4/10


100%|██████████| 390/390 [00:50<00:00,  7.77it/s]
100%|██████████| 78/78 [00:18<00:00,  4.21it/s]


validation accuracy: 0.9661458333333334
Valid Error 3.385 %
Epoch 5/10


100%|██████████| 390/390 [00:50<00:00,  7.77it/s]
100%|██████████| 78/78 [00:17<00:00,  4.34it/s]


validation accuracy: 0.9691506410256411
Valid Error 3.085 %
Epoch 6/10


100%|██████████| 390/390 [00:49<00:00,  7.82it/s]
100%|██████████| 78/78 [00:17<00:00,  4.54it/s]


validation accuracy: 0.9702524038461539
Valid Error 2.975 %
Epoch 7/10


100%|██████████| 390/390 [00:49<00:00,  7.82it/s]
100%|██████████| 78/78 [00:18<00:00,  4.32it/s]


validation accuracy: 0.9669471153846154
Valid Error 3.305 %
Epoch 8/10


100%|██████████| 390/390 [00:49<00:00,  7.84it/s]
100%|██████████| 78/78 [00:19<00:00,  4.04it/s]


validation accuracy: 0.969551282051282
Valid Error 3.045 %
Epoch 9/10


100%|██████████| 390/390 [00:48<00:00,  7.98it/s]
100%|██████████| 78/78 [00:18<00:00,  4.23it/s]


validation accuracy: 0.9762620192307693
Valid Error 2.374 %
Epoch 10/10


100%|██████████| 390/390 [00:49<00:00,  7.88it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9736578525641025
Valid Error 2.634 %
Trying config: {'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, -1, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:49<00:00,  7.95it/s]
100%|██████████| 78/78 [00:17<00:00,  4.38it/s]


validation accuracy: 0.950020032051282
Valid Error 4.998 %
Epoch 2/10


100%|██████████| 390/390 [00:50<00:00,  7.75it/s]
100%|██████████| 78/78 [00:17<00:00,  4.44it/s]


validation accuracy: 0.9644431089743589
Valid Error 3.556 %
Epoch 3/10


100%|██████████| 390/390 [00:50<00:00,  7.80it/s]
100%|██████████| 78/78 [00:17<00:00,  4.52it/s]


validation accuracy: 0.9704527243589743
Valid Error 2.955 %
Epoch 4/10


100%|██████████| 390/390 [00:50<00:00,  7.77it/s]
100%|██████████| 78/78 [00:16<00:00,  4.59it/s]


validation accuracy: 0.9691506410256411
Valid Error 3.085 %
Epoch 5/10


100%|██████████| 390/390 [00:49<00:00,  7.94it/s]
100%|██████████| 78/78 [00:17<00:00,  4.42it/s]


validation accuracy: 0.9675480769230769
Valid Error 3.245 %
Epoch 6/10


100%|██████████| 390/390 [00:50<00:00,  7.79it/s]
100%|██████████| 78/78 [00:17<00:00,  4.54it/s]


validation accuracy: 0.9691506410256411
Valid Error 3.085 %
Epoch 7/10


100%|██████████| 390/390 [00:49<00:00,  7.93it/s]
100%|██████████| 78/78 [00:17<00:00,  4.39it/s]


validation accuracy: 0.9732572115384616
Valid Error 2.674 %
Epoch 8/10


100%|██████████| 390/390 [00:49<00:00,  7.81it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.9647435897435898
Valid Error 3.526 %
Epoch 9/10


100%|██████████| 390/390 [00:49<00:00,  7.91it/s]
100%|██████████| 78/78 [00:17<00:00,  4.38it/s]


validation accuracy: 0.9708533653846154
Valid Error 2.915 %
Epoch 10/10


100%|██████████| 390/390 [00:50<00:00,  7.78it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9729567307692307
Valid Error 2.704 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, 0, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:15<00:00, 25.54it/s]
100%|██████████| 78/78 [00:17<00:00,  4.56it/s]


validation accuracy: 0.9122596153846154
Valid Error 8.774 %
Epoch 2/10


100%|██████████| 390/390 [00:14<00:00, 26.43it/s]
100%|██████████| 78/78 [00:17<00:00,  4.58it/s]


validation accuracy: 0.9290865384615384
Valid Error 7.091 %
Epoch 3/10


100%|██████████| 390/390 [00:15<00:00, 24.58it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9459134615384616
Valid Error 5.409 %
Epoch 4/10


100%|██████████| 390/390 [00:15<00:00, 25.47it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.9525240384615384
Valid Error 4.748 %
Epoch 5/10


100%|██████████| 390/390 [00:15<00:00, 24.65it/s]
100%|██████████| 78/78 [00:17<00:00,  4.59it/s]


validation accuracy: 0.9565304487179487
Valid Error 4.347 %
Epoch 6/10


100%|██████████| 390/390 [00:15<00:00, 25.57it/s]
100%|██████████| 78/78 [00:16<00:00,  4.64it/s]


validation accuracy: 0.9595352564102564
Valid Error 4.046 %
Epoch 7/10


100%|██████████| 390/390 [00:16<00:00, 24.37it/s]
100%|██████████| 78/78 [00:17<00:00,  4.58it/s]


validation accuracy: 0.9633413461538461
Valid Error 3.666 %
Epoch 8/10


100%|██████████| 390/390 [00:16<00:00, 24.13it/s]
100%|██████████| 78/78 [00:17<00:00,  4.58it/s]


validation accuracy: 0.9636418269230769
Valid Error 3.636 %
Epoch 9/10


100%|██████████| 390/390 [00:16<00:00, 24.31it/s]
100%|██████████| 78/78 [00:17<00:00,  4.59it/s]


validation accuracy: 0.9614383012820513
Valid Error 3.856 %
Epoch 10/10


100%|██████████| 390/390 [00:15<00:00, 25.20it/s]
100%|██████████| 78/78 [00:17<00:00,  4.58it/s]


validation accuracy: 0.9659455128205128
Valid Error 3.405 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, 0, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:15<00:00, 25.11it/s]
100%|██████████| 78/78 [00:17<00:00,  4.54it/s]


validation accuracy: 0.9163661858974359
Valid Error 8.363 %
Epoch 2/10


100%|██████████| 390/390 [00:15<00:00, 25.55it/s]
100%|██████████| 78/78 [00:17<00:00,  4.48it/s]


validation accuracy: 0.9356971153846154
Valid Error 6.43 %
Epoch 3/10


100%|██████████| 390/390 [00:15<00:00, 24.88it/s]
100%|██████████| 78/78 [00:18<00:00,  4.26it/s]


validation accuracy: 0.9428084935897436
Valid Error 5.719 %
Epoch 4/10


100%|██████████| 390/390 [00:15<00:00, 25.76it/s]
100%|██████████| 78/78 [00:17<00:00,  4.48it/s]


validation accuracy: 0.9522235576923077
Valid Error 4.778 %
Epoch 5/10


100%|██████████| 390/390 [00:15<00:00, 25.92it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.9585336538461539
Valid Error 4.147 %
Epoch 6/10


100%|██████████| 390/390 [00:15<00:00, 25.12it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.9611378205128205
Valid Error 3.886 %
Epoch 7/10


100%|██████████| 390/390 [00:15<00:00, 25.18it/s]
100%|██████████| 78/78 [00:17<00:00,  4.34it/s]


validation accuracy: 0.9621394230769231
Valid Error 3.786 %
Epoch 8/10


100%|██████████| 390/390 [00:15<00:00, 25.61it/s]
100%|██████████| 78/78 [00:18<00:00,  4.14it/s]


validation accuracy: 0.9630408653846154
Valid Error 3.696 %
Epoch 9/10


100%|██████████| 390/390 [00:15<00:00, 25.76it/s]
100%|██████████| 78/78 [00:18<00:00,  4.23it/s]


validation accuracy: 0.9646434294871795
Valid Error 3.536 %
Epoch 10/10


100%|██████████| 390/390 [00:15<00:00, 25.60it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.9659455128205128
Valid Error 3.405 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, -1, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:15<00:00, 25.84it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.9081530448717948
Valid Error 9.185 %
Epoch 2/10


100%|██████████| 390/390 [00:15<00:00, 25.49it/s]
100%|██████████| 78/78 [00:17<00:00,  4.53it/s]


validation accuracy: 0.9336939102564102
Valid Error 6.631 %
Epoch 3/10


100%|██████████| 390/390 [00:15<00:00, 24.50it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9417067307692307
Valid Error 5.829 %
Epoch 4/10


100%|██████████| 390/390 [00:15<00:00, 24.39it/s]
100%|██████████| 78/78 [00:17<00:00,  4.45it/s]


validation accuracy: 0.9476161858974359
Valid Error 5.238 %
Epoch 5/10


100%|██████████| 390/390 [00:15<00:00, 25.90it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.953125
Valid Error 4.688 %
Epoch 6/10


100%|██████████| 390/390 [00:15<00:00, 24.47it/s]
100%|██████████| 78/78 [00:17<00:00,  4.55it/s]


validation accuracy: 0.9557291666666666
Valid Error 4.427 %
Epoch 7/10


100%|██████████| 390/390 [00:14<00:00, 26.30it/s]
100%|██████████| 78/78 [00:17<00:00,  4.53it/s]


validation accuracy: 0.9565304487179487
Valid Error 4.347 %
Epoch 8/10


100%|██████████| 390/390 [00:14<00:00, 26.38it/s]
100%|██████████| 78/78 [00:17<00:00,  4.47it/s]


validation accuracy: 0.961738782051282
Valid Error 3.826 %
Epoch 9/10


100%|██████████| 390/390 [00:15<00:00, 25.52it/s]
100%|██████████| 78/78 [00:17<00:00,  4.48it/s]


validation accuracy: 0.9622395833333334
Valid Error 3.776 %
Epoch 10/10


100%|██████████| 390/390 [00:15<00:00, 25.66it/s]
100%|██████████| 78/78 [00:17<00:00,  4.47it/s]


validation accuracy: 0.9644431089743589
Valid Error 3.556 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, -1, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:15<00:00, 25.51it/s]
100%|██████████| 78/78 [00:17<00:00,  4.38it/s]


validation accuracy: 0.9156650641025641
Valid Error 8.433 %
Epoch 2/10


100%|██████████| 390/390 [00:15<00:00, 25.74it/s]
100%|██████████| 78/78 [00:18<00:00,  4.21it/s]


validation accuracy: 0.9378004807692307
Valid Error 6.22 %
Epoch 3/10


100%|██████████| 390/390 [00:15<00:00, 25.63it/s]
100%|██████████| 78/78 [00:17<00:00,  4.55it/s]


validation accuracy: 0.9449118589743589
Valid Error 5.509 %
Epoch 4/10


100%|██████████| 390/390 [00:15<00:00, 25.38it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.953926282051282
Valid Error 4.607 %
Epoch 5/10


100%|██████████| 390/390 [00:15<00:00, 25.21it/s]
100%|██████████| 78/78 [00:18<00:00,  4.26it/s]


validation accuracy: 0.9576322115384616
Valid Error 4.237 %
Epoch 6/10


100%|██████████| 390/390 [00:15<00:00, 25.62it/s]
100%|██████████| 78/78 [00:18<00:00,  4.23it/s]


validation accuracy: 0.9605368589743589
Valid Error 3.946 %
Epoch 7/10


100%|██████████| 390/390 [00:15<00:00, 25.07it/s]
100%|██████████| 78/78 [00:18<00:00,  4.30it/s]


validation accuracy: 0.960136217948718
Valid Error 3.986 %
Epoch 8/10


100%|██████████| 390/390 [00:15<00:00, 25.19it/s]
100%|██████████| 78/78 [00:18<00:00,  4.19it/s]


validation accuracy: 0.9639423076923077
Valid Error 3.606 %
Epoch 9/10


100%|██████████| 390/390 [00:15<00:00, 24.97it/s]
100%|██████████| 78/78 [00:16<00:00,  4.68it/s]


validation accuracy: 0.9644431089743589
Valid Error 3.556 %
Epoch 10/10


100%|██████████| 390/390 [00:15<00:00, 25.55it/s]
100%|██████████| 78/78 [00:18<00:00,  4.29it/s]


validation accuracy: 0.9657451923076923
Valid Error 3.425 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, 0, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:15<00:00, 24.59it/s]
100%|██████████| 78/78 [00:18<00:00,  4.29it/s]


validation accuracy: 0.9142628205128205
Valid Error 8.574 %
Epoch 2/10


100%|██████████| 390/390 [00:15<00:00, 25.14it/s]
100%|██████████| 78/78 [00:17<00:00,  4.37it/s]


validation accuracy: 0.9390024038461539
Valid Error 6.1 %
Epoch 3/10


100%|██████████| 390/390 [00:15<00:00, 24.95it/s]
100%|██████████| 78/78 [00:18<00:00,  4.29it/s]


validation accuracy: 0.9474158653846154
Valid Error 5.258 %
Epoch 4/10


100%|██████████| 390/390 [00:15<00:00, 24.55it/s]
100%|██████████| 78/78 [00:17<00:00,  4.48it/s]


validation accuracy: 0.9516225961538461
Valid Error 4.838 %
Epoch 5/10


100%|██████████| 390/390 [00:15<00:00, 25.08it/s]
100%|██████████| 78/78 [00:17<00:00,  4.40it/s]


validation accuracy: 0.9573317307692307
Valid Error 4.267 %
Epoch 6/10


100%|██████████| 390/390 [00:15<00:00, 25.37it/s]
100%|██████████| 78/78 [00:17<00:00,  4.44it/s]


validation accuracy: 0.9568309294871795
Valid Error 4.317 %
Epoch 7/10


100%|██████████| 390/390 [00:16<00:00, 24.20it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.9575320512820513
Valid Error 4.247 %
Epoch 8/10


100%|██████████| 390/390 [00:15<00:00, 25.00it/s]
100%|██████████| 78/78 [00:17<00:00,  4.47it/s]


validation accuracy: 0.9616386217948718
Valid Error 3.836 %
Epoch 9/10


100%|██████████| 390/390 [00:15<00:00, 25.31it/s]
100%|██████████| 78/78 [00:17<00:00,  4.39it/s]


validation accuracy: 0.9646434294871795
Valid Error 3.536 %
Epoch 10/10


100%|██████████| 390/390 [00:16<00:00, 23.74it/s]
100%|██████████| 78/78 [00:17<00:00,  4.44it/s]


validation accuracy: 0.9620392628205128
Valid Error 3.796 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, 0, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:15<00:00, 25.19it/s]
100%|██████████| 78/78 [00:17<00:00,  4.53it/s]


validation accuracy: 0.9156650641025641
Valid Error 8.433 %
Epoch 2/10


100%|██████████| 390/390 [00:15<00:00, 24.41it/s]
100%|██████████| 78/78 [00:16<00:00,  4.60it/s]


validation accuracy: 0.9351963141025641
Valid Error 6.48 %
Epoch 3/10


100%|██████████| 390/390 [00:15<00:00, 24.53it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.9437099358974359
Valid Error 5.629 %
Epoch 4/10


100%|██████████| 390/390 [00:16<00:00, 24.19it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9509214743589743
Valid Error 4.908 %
Epoch 5/10


100%|██████████| 390/390 [00:15<00:00, 24.61it/s]
100%|██████████| 78/78 [00:17<00:00,  4.44it/s]


validation accuracy: 0.9532251602564102
Valid Error 4.677 %
Epoch 6/10


100%|██████████| 390/390 [00:15<00:00, 24.92it/s]
100%|██████████| 78/78 [00:17<00:00,  4.43it/s]


validation accuracy: 0.9575320512820513
Valid Error 4.247 %
Epoch 7/10


100%|██████████| 390/390 [00:15<00:00, 24.80it/s]
100%|██████████| 78/78 [00:17<00:00,  4.39it/s]


validation accuracy: 0.9591346153846154
Valid Error 4.087 %
Epoch 8/10


100%|██████████| 390/390 [00:15<00:00, 24.99it/s]
100%|██████████| 78/78 [00:17<00:00,  4.35it/s]


validation accuracy: 0.9595352564102564
Valid Error 4.046 %
Epoch 9/10


100%|██████████| 390/390 [00:15<00:00, 25.15it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.9607371794871795
Valid Error 3.926 %
Epoch 10/10


100%|██████████| 390/390 [00:15<00:00, 24.94it/s]
100%|██████████| 78/78 [00:17<00:00,  4.42it/s]


validation accuracy: 0.9646434294871795
Valid Error 3.536 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, -1, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:15<00:00, 24.76it/s]
100%|██████████| 78/78 [00:17<00:00,  4.38it/s]


validation accuracy: 0.9139623397435898
Valid Error 8.604 %
Epoch 2/10


100%|██████████| 390/390 [00:15<00:00, 25.43it/s]
100%|██████████| 78/78 [00:18<00:00,  4.27it/s]


validation accuracy: 0.9316907051282052
Valid Error 6.831 %
Epoch 3/10


100%|██████████| 390/390 [00:15<00:00, 24.92it/s]
100%|██████████| 78/78 [00:18<00:00,  4.11it/s]


validation accuracy: 0.9437099358974359
Valid Error 5.629 %
Epoch 4/10


100%|██████████| 390/390 [00:15<00:00, 25.36it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.94921875
Valid Error 5.078 %
Epoch 5/10


100%|██████████| 390/390 [00:15<00:00, 25.88it/s]
100%|██████████| 78/78 [00:18<00:00,  4.23it/s]


validation accuracy: 0.9561298076923077
Valid Error 4.387 %
Epoch 6/10


100%|██████████| 390/390 [00:15<00:00, 25.32it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.9573317307692307
Valid Error 4.267 %
Epoch 7/10


100%|██████████| 390/390 [00:15<00:00, 25.18it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.961738782051282
Valid Error 3.826 %
Epoch 8/10


100%|██████████| 390/390 [00:16<00:00, 24.15it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.9618389423076923
Valid Error 3.816 %
Epoch 9/10


100%|██████████| 390/390 [00:16<00:00, 23.93it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.9613381410256411
Valid Error 3.866 %
Epoch 10/10


100%|██████████| 390/390 [00:16<00:00, 24.31it/s]
100%|██████████| 78/78 [00:17<00:00,  4.53it/s]


validation accuracy: 0.9657451923076923
Valid Error 3.425 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, -1, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:16<00:00, 24.13it/s]
100%|██████████| 78/78 [00:16<00:00,  4.60it/s]


validation accuracy: 0.9138621794871795
Valid Error 8.614 %
Epoch 2/10


100%|██████████| 390/390 [00:16<00:00, 24.36it/s]
100%|██████████| 78/78 [00:17<00:00,  4.52it/s]


validation accuracy: 0.9318910256410257
Valid Error 6.811 %
Epoch 3/10


100%|██████████| 390/390 [00:16<00:00, 23.75it/s]
100%|██████████| 78/78 [00:17<00:00,  4.56it/s]


validation accuracy: 0.9442107371794872
Valid Error 5.579 %
Epoch 4/10


100%|██████████| 390/390 [00:16<00:00, 24.11it/s]
100%|██████████| 78/78 [00:17<00:00,  4.53it/s]


validation accuracy: 0.9491185897435898
Valid Error 5.088 %
Epoch 5/10


100%|██████████| 390/390 [00:15<00:00, 25.06it/s]
100%|██████████| 78/78 [00:17<00:00,  4.57it/s]


validation accuracy: 0.9529246794871795
Valid Error 4.708 %
Epoch 6/10


100%|██████████| 390/390 [00:16<00:00, 24.26it/s]
100%|██████████| 78/78 [00:17<00:00,  4.52it/s]


validation accuracy: 0.953926282051282
Valid Error 4.607 %
Epoch 7/10


100%|██████████| 390/390 [00:15<00:00, 25.97it/s]
100%|██████████| 78/78 [00:17<00:00,  4.40it/s]


validation accuracy: 0.9547275641025641
Valid Error 4.527 %
Epoch 8/10


100%|██████████| 390/390 [00:15<00:00, 25.09it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.9606370192307693
Valid Error 3.936 %
Epoch 9/10


100%|██████████| 390/390 [00:15<00:00, 24.95it/s]
100%|██████████| 78/78 [00:18<00:00,  4.23it/s]


validation accuracy: 0.9638421474358975
Valid Error 3.616 %
Epoch 10/10


100%|██████████| 390/390 [00:15<00:00, 25.02it/s]
100%|██████████| 78/78 [00:18<00:00,  4.22it/s]


validation accuracy: 0.9644431089743589
Valid Error 3.556 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, 0, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:23<00:00, 16.62it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.9092548076923077
Valid Error 9.075 %
Epoch 2/10


100%|██████████| 390/390 [00:24<00:00, 15.65it/s]
100%|██████████| 78/78 [00:17<00:00,  4.54it/s]


validation accuracy: 0.9308894230769231
Valid Error 6.911 %
Epoch 3/10


100%|██████████| 390/390 [00:24<00:00, 16.04it/s]
100%|██████████| 78/78 [00:17<00:00,  4.56it/s]


validation accuracy: 0.9442107371794872
Valid Error 5.579 %
Epoch 4/10


100%|██████████| 390/390 [00:24<00:00, 15.80it/s]
100%|██████████| 78/78 [00:17<00:00,  4.52it/s]


validation accuracy: 0.9485176282051282
Valid Error 5.148 %
Epoch 5/10


100%|██████████| 390/390 [00:24<00:00, 16.21it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.956229967948718
Valid Error 4.377 %
Epoch 6/10


100%|██████████| 390/390 [00:23<00:00, 16.43it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.9588341346153846
Valid Error 4.117 %
Epoch 7/10


100%|██████████| 390/390 [00:24<00:00, 15.61it/s]
100%|██████████| 78/78 [00:17<00:00,  4.54it/s]


validation accuracy: 0.9621394230769231
Valid Error 3.786 %
Epoch 8/10


100%|██████████| 390/390 [00:23<00:00, 16.91it/s]
100%|██████████| 78/78 [00:17<00:00,  4.58it/s]


validation accuracy: 0.9620392628205128
Valid Error 3.796 %
Epoch 9/10


100%|██████████| 390/390 [00:23<00:00, 16.74it/s]
100%|██████████| 78/78 [00:16<00:00,  4.74it/s]


validation accuracy: 0.9649439102564102
Valid Error 3.506 %
Epoch 10/10


100%|██████████| 390/390 [00:23<00:00, 16.26it/s]
100%|██████████| 78/78 [00:16<00:00,  4.64it/s]


validation accuracy: 0.9650440705128205
Valid Error 3.496 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, 0, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:22<00:00, 17.29it/s]
100%|██████████| 78/78 [00:16<00:00,  4.73it/s]


validation accuracy: 0.9185697115384616
Valid Error 8.143 %
Epoch 2/10


100%|██████████| 390/390 [00:22<00:00, 17.68it/s]
100%|██████████| 78/78 [00:17<00:00,  4.46it/s]


validation accuracy: 0.9359975961538461
Valid Error 6.4 %
Epoch 3/10


100%|██████████| 390/390 [00:22<00:00, 17.36it/s]
100%|██████████| 78/78 [00:16<00:00,  4.80it/s]


validation accuracy: 0.9444110576923077
Valid Error 5.559 %
Epoch 4/10


100%|██████████| 390/390 [00:22<00:00, 17.08it/s]
100%|██████████| 78/78 [00:17<00:00,  4.58it/s]


validation accuracy: 0.9526241987179487
Valid Error 4.738 %
Epoch 5/10


100%|██████████| 390/390 [00:22<00:00, 17.33it/s]
100%|██████████| 78/78 [00:16<00:00,  4.82it/s]


validation accuracy: 0.9566306089743589
Valid Error 4.337 %
Epoch 6/10


100%|██████████| 390/390 [00:23<00:00, 16.92it/s]
100%|██████████| 78/78 [00:17<00:00,  4.34it/s]


validation accuracy: 0.9592347756410257
Valid Error 4.077 %
Epoch 7/10


100%|██████████| 390/390 [00:24<00:00, 15.77it/s]
100%|██████████| 78/78 [00:18<00:00,  4.19it/s]


validation accuracy: 0.9631410256410257
Valid Error 3.686 %
Epoch 8/10


100%|██████████| 390/390 [00:26<00:00, 14.78it/s]
100%|██████████| 78/78 [00:19<00:00,  4.07it/s]


validation accuracy: 0.9631410256410257
Valid Error 3.686 %
Epoch 9/10


100%|██████████| 390/390 [00:24<00:00, 16.16it/s]
100%|██████████| 78/78 [00:18<00:00,  4.30it/s]


validation accuracy: 0.9660456730769231
Valid Error 3.395 %
Epoch 10/10


100%|██████████| 390/390 [00:25<00:00, 15.46it/s]
100%|██████████| 78/78 [00:16<00:00,  4.67it/s]


validation accuracy: 0.96484375
Valid Error 3.516 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, -1, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:24<00:00, 15.72it/s]
100%|██████████| 78/78 [00:17<00:00,  4.56it/s]


validation accuracy: 0.9199719551282052
Valid Error 8.003 %
Epoch 2/10


100%|██████████| 390/390 [00:23<00:00, 16.56it/s]
100%|██████████| 78/78 [00:16<00:00,  4.60it/s]


validation accuracy: 0.9397035256410257
Valid Error 6.03 %
Epoch 3/10


100%|██████████| 390/390 [00:23<00:00, 16.30it/s]
100%|██████████| 78/78 [00:16<00:00,  4.67it/s]


validation accuracy: 0.9495192307692307
Valid Error 5.048 %
Epoch 4/10


100%|██████████| 390/390 [00:23<00:00, 16.68it/s]
100%|██████████| 78/78 [00:16<00:00,  4.63it/s]


validation accuracy: 0.9545272435897436
Valid Error 4.547 %
Epoch 5/10


100%|██████████| 390/390 [00:22<00:00, 16.99it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.95703125
Valid Error 4.297 %
Epoch 6/10


100%|██████████| 390/390 [00:23<00:00, 16.38it/s]
100%|██████████| 78/78 [00:16<00:00,  4.71it/s]


validation accuracy: 0.9644431089743589
Valid Error 3.556 %
Epoch 7/10


100%|██████████| 390/390 [00:22<00:00, 17.08it/s]
100%|██████████| 78/78 [00:16<00:00,  4.73it/s]


validation accuracy: 0.9655448717948718
Valid Error 3.446 %
Epoch 8/10


100%|██████████| 390/390 [00:23<00:00, 16.48it/s]
100%|██████████| 78/78 [00:16<00:00,  4.76it/s]


validation accuracy: 0.9662459935897436
Valid Error 3.375 %
Epoch 9/10


100%|██████████| 390/390 [00:23<00:00, 16.65it/s]
100%|██████████| 78/78 [00:16<00:00,  4.67it/s]


validation accuracy: 0.9678485576923077
Valid Error 3.215 %
Epoch 10/10


100%|██████████| 390/390 [00:23<00:00, 16.57it/s]
100%|██████████| 78/78 [00:16<00:00,  4.61it/s]


validation accuracy: 0.9664463141025641
Valid Error 3.355 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, -1, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:22<00:00, 17.03it/s]
100%|██████████| 78/78 [00:19<00:00,  3.94it/s]


validation accuracy: 0.9164663461538461
Valid Error 8.353 %
Epoch 2/10


100%|██████████| 390/390 [00:23<00:00, 16.91it/s]
100%|██████████| 78/78 [00:19<00:00,  4.04it/s]


validation accuracy: 0.9326923076923077
Valid Error 6.731 %
Epoch 3/10


100%|██████████| 390/390 [00:25<00:00, 15.52it/s]
100%|██████████| 78/78 [00:17<00:00,  4.57it/s]


validation accuracy: 0.9466145833333334
Valid Error 5.339 %
Epoch 4/10


100%|██████████| 390/390 [00:23<00:00, 16.41it/s]
100%|██████████| 78/78 [00:17<00:00,  4.52it/s]


validation accuracy: 0.9524238782051282
Valid Error 4.758 %
Epoch 5/10


100%|██████████| 390/390 [00:24<00:00, 16.19it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.9566306089743589
Valid Error 4.337 %
Epoch 6/10


100%|██████████| 390/390 [00:23<00:00, 16.68it/s]
100%|██████████| 78/78 [00:16<00:00,  4.64it/s]


validation accuracy: 0.9593349358974359
Valid Error 4.067 %
Epoch 7/10


100%|██████████| 390/390 [00:23<00:00, 16.87it/s]
100%|██████████| 78/78 [00:16<00:00,  4.64it/s]


validation accuracy: 0.9589342948717948
Valid Error 4.107 %
Epoch 8/10


100%|██████████| 390/390 [00:22<00:00, 16.99it/s]
100%|██████████| 78/78 [00:19<00:00,  3.95it/s]


validation accuracy: 0.9613381410256411
Valid Error 3.866 %
Epoch 9/10


100%|██████████| 390/390 [00:22<00:00, 16.99it/s]
100%|██████████| 78/78 [00:16<00:00,  4.63it/s]


validation accuracy: 0.9626402243589743
Valid Error 3.736 %
Epoch 10/10


100%|██████████| 390/390 [00:24<00:00, 16.05it/s]
100%|██████████| 78/78 [00:16<00:00,  4.64it/s]


validation accuracy: 0.964042467948718
Valid Error 3.596 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, 0, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:23<00:00, 16.88it/s]
100%|██████████| 78/78 [00:16<00:00,  4.72it/s]


validation accuracy: 0.9137620192307693
Valid Error 8.624 %
Epoch 2/10


100%|██████████| 390/390 [00:23<00:00, 16.95it/s]
100%|██████████| 78/78 [00:16<00:00,  4.69it/s]


validation accuracy: 0.9328926282051282
Valid Error 6.711 %
Epoch 3/10


100%|██████████| 390/390 [00:22<00:00, 17.05it/s]
100%|██████████| 78/78 [00:16<00:00,  4.71it/s]


validation accuracy: 0.9444110576923077
Valid Error 5.559 %
Epoch 4/10


100%|██████████| 390/390 [00:22<00:00, 17.02it/s]
100%|██████████| 78/78 [00:16<00:00,  4.68it/s]


validation accuracy: 0.9471153846153846
Valid Error 5.288 %
Epoch 5/10


100%|██████████| 390/390 [00:24<00:00, 16.17it/s]
100%|██████████| 78/78 [00:16<00:00,  4.74it/s]


validation accuracy: 0.9548277243589743
Valid Error 4.517 %
Epoch 6/10


100%|██████████| 390/390 [00:22<00:00, 16.98it/s]
100%|██████████| 78/78 [00:16<00:00,  4.72it/s]


validation accuracy: 0.9581330128205128
Valid Error 4.187 %
Epoch 7/10


100%|██████████| 390/390 [00:23<00:00, 16.64it/s]
100%|██████████| 78/78 [00:16<00:00,  4.67it/s]


validation accuracy: 0.9604366987179487
Valid Error 3.956 %
Epoch 8/10


100%|██████████| 390/390 [00:24<00:00, 16.20it/s]
100%|██████████| 78/78 [00:18<00:00,  4.19it/s]


validation accuracy: 0.9600360576923077
Valid Error 3.996 %
Epoch 9/10


100%|██████████| 390/390 [00:24<00:00, 16.21it/s]
100%|██████████| 78/78 [00:17<00:00,  4.44it/s]


validation accuracy: 0.9647435897435898
Valid Error 3.526 %
Epoch 10/10


100%|██████████| 390/390 [00:23<00:00, 16.39it/s]
100%|██████████| 78/78 [00:16<00:00,  4.60it/s]


validation accuracy: 0.9629407051282052
Valid Error 3.706 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, 0, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:23<00:00, 16.54it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9160657051282052
Valid Error 8.393 %
Epoch 2/10


100%|██████████| 390/390 [00:22<00:00, 16.99it/s]
100%|██████████| 78/78 [00:16<00:00,  4.62it/s]


validation accuracy: 0.9340945512820513
Valid Error 6.591 %
Epoch 3/10


100%|██████████| 390/390 [00:24<00:00, 16.23it/s]
100%|██████████| 78/78 [00:16<00:00,  4.80it/s]


validation accuracy: 0.9468149038461539
Valid Error 5.319 %
Epoch 4/10


100%|██████████| 390/390 [00:23<00:00, 16.77it/s]
100%|██████████| 78/78 [00:16<00:00,  4.76it/s]


validation accuracy: 0.9501201923076923
Valid Error 4.988 %
Epoch 5/10


100%|██████████| 390/390 [00:23<00:00, 16.91it/s]
100%|██████████| 78/78 [00:16<00:00,  4.63it/s]


validation accuracy: 0.9563301282051282
Valid Error 4.367 %
Epoch 6/10


100%|██████████| 390/390 [00:22<00:00, 17.14it/s]
100%|██████████| 78/78 [00:18<00:00,  4.32it/s]


validation accuracy: 0.9597355769230769
Valid Error 4.026 %
Epoch 7/10


100%|██████████| 390/390 [00:22<00:00, 17.22it/s]
100%|██████████| 78/78 [00:16<00:00,  4.76it/s]


validation accuracy: 0.9628405448717948
Valid Error 3.716 %
Epoch 8/10


100%|██████████| 390/390 [00:23<00:00, 16.84it/s]
100%|██████████| 78/78 [00:16<00:00,  4.74it/s]


validation accuracy: 0.9635416666666666
Valid Error 3.646 %
Epoch 9/10


100%|██████████| 390/390 [00:22<00:00, 17.10it/s]
100%|██████████| 78/78 [00:16<00:00,  4.63it/s]


validation accuracy: 0.9642427884615384
Valid Error 3.576 %
Epoch 10/10


100%|██████████| 390/390 [00:23<00:00, 16.78it/s]
100%|██████████| 78/78 [00:17<00:00,  4.58it/s]


validation accuracy: 0.96484375
Valid Error 3.516 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, -1, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:22<00:00, 17.06it/s]
100%|██████████| 78/78 [00:17<00:00,  4.40it/s]


validation accuracy: 0.9202724358974359
Valid Error 7.973 %
Epoch 2/10


100%|██████████| 390/390 [00:23<00:00, 16.68it/s]
100%|██████████| 78/78 [00:16<00:00,  4.72it/s]


validation accuracy: 0.9390024038461539
Valid Error 6.1 %
Epoch 3/10


100%|██████████| 390/390 [00:24<00:00, 15.88it/s]
100%|██████████| 78/78 [00:17<00:00,  4.57it/s]


validation accuracy: 0.9501201923076923
Valid Error 4.988 %
Epoch 4/10


100%|██████████| 390/390 [00:23<00:00, 16.91it/s]
100%|██████████| 78/78 [00:17<00:00,  4.40it/s]


validation accuracy: 0.9518229166666666
Valid Error 4.818 %
Epoch 5/10


100%|██████████| 390/390 [00:23<00:00, 16.85it/s]
100%|██████████| 78/78 [00:16<00:00,  4.70it/s]


validation accuracy: 0.9576322115384616
Valid Error 4.237 %
Epoch 6/10


100%|██████████| 390/390 [00:22<00:00, 17.06it/s]
100%|██████████| 78/78 [00:17<00:00,  4.56it/s]


validation accuracy: 0.9597355769230769
Valid Error 4.026 %
Epoch 7/10


100%|██████████| 390/390 [00:22<00:00, 16.97it/s]
100%|██████████| 78/78 [00:17<00:00,  4.44it/s]


validation accuracy: 0.9598357371794872
Valid Error 4.016 %
Epoch 8/10


100%|██████████| 390/390 [00:24<00:00, 16.14it/s]
100%|██████████| 78/78 [00:17<00:00,  4.52it/s]


validation accuracy: 0.9624399038461539
Valid Error 3.756 %
Epoch 9/10


100%|██████████| 390/390 [00:23<00:00, 16.95it/s]
100%|██████████| 78/78 [00:17<00:00,  4.35it/s]


validation accuracy: 0.9636418269230769
Valid Error 3.636 %
Epoch 10/10


100%|██████████| 390/390 [00:22<00:00, 17.23it/s]
100%|██████████| 78/78 [00:17<00:00,  4.58it/s]


validation accuracy: 0.9664463141025641
Valid Error 3.355 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 2, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, -1, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:22<00:00, 17.04it/s]
100%|██████████| 78/78 [00:16<00:00,  4.63it/s]


validation accuracy: 0.9160657051282052
Valid Error 8.393 %
Epoch 2/10


100%|██████████| 390/390 [00:23<00:00, 16.55it/s]
100%|██████████| 78/78 [00:16<00:00,  4.70it/s]


validation accuracy: 0.9387019230769231
Valid Error 6.13 %
Epoch 3/10


100%|██████████| 390/390 [00:23<00:00, 16.68it/s]
100%|██████████| 78/78 [00:16<00:00,  4.68it/s]


validation accuracy: 0.9464142628205128
Valid Error 5.359 %
Epoch 4/10


100%|██████████| 390/390 [00:23<00:00, 16.76it/s]
100%|██████████| 78/78 [00:18<00:00,  4.26it/s]


validation accuracy: 0.9521233974358975
Valid Error 4.788 %
Epoch 5/10


100%|██████████| 390/390 [00:23<00:00, 16.84it/s]
100%|██████████| 78/78 [00:16<00:00,  4.60it/s]


validation accuracy: 0.9556290064102564
Valid Error 4.437 %
Epoch 6/10


100%|██████████| 390/390 [00:24<00:00, 16.11it/s]
100%|██████████| 78/78 [00:16<00:00,  4.70it/s]


validation accuracy: 0.9603365384615384
Valid Error 3.966 %
Epoch 7/10


100%|██████████| 390/390 [00:23<00:00, 16.84it/s]
100%|██████████| 78/78 [00:16<00:00,  4.62it/s]


validation accuracy: 0.9637419871794872
Valid Error 3.626 %
Epoch 8/10


100%|██████████| 390/390 [00:24<00:00, 15.99it/s]
100%|██████████| 78/78 [00:17<00:00,  4.45it/s]


validation accuracy: 0.9619391025641025
Valid Error 3.806 %
Epoch 9/10


100%|██████████| 390/390 [00:24<00:00, 16.15it/s]
100%|██████████| 78/78 [00:17<00:00,  4.41it/s]


validation accuracy: 0.9620392628205128
Valid Error 3.796 %
Epoch 10/10


100%|██████████| 390/390 [00:23<00:00, 16.67it/s]
100%|██████████| 78/78 [00:16<00:00,  4.65it/s]


validation accuracy: 0.9655448717948718
Valid Error 3.446 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, 0, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:49<00:00,  7.92it/s]
100%|██████████| 78/78 [00:16<00:00,  4.80it/s]


validation accuracy: 0.9092548076923077
Valid Error 9.075 %
Epoch 2/10


100%|██████████| 390/390 [00:48<00:00,  8.05it/s]
100%|██████████| 78/78 [00:17<00:00,  4.52it/s]


validation accuracy: 0.9326923076923077
Valid Error 6.731 %
Epoch 3/10


100%|██████████| 390/390 [00:47<00:00,  8.26it/s]
100%|██████████| 78/78 [00:16<00:00,  4.76it/s]


validation accuracy: 0.9469150641025641
Valid Error 5.308 %
Epoch 4/10


100%|██████████| 390/390 [00:47<00:00,  8.27it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.9528245192307693
Valid Error 4.718 %
Epoch 5/10


100%|██████████| 390/390 [00:47<00:00,  8.21it/s]
100%|██████████| 78/78 [00:16<00:00,  4.71it/s]


validation accuracy: 0.9584334935897436
Valid Error 4.157 %
Epoch 6/10


100%|██████████| 390/390 [00:48<00:00,  8.02it/s]
100%|██████████| 78/78 [00:16<00:00,  4.72it/s]


validation accuracy: 0.9608373397435898
Valid Error 3.916 %
Epoch 7/10


100%|██████████| 390/390 [00:47<00:00,  8.22it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.9650440705128205
Valid Error 3.496 %
Epoch 8/10


100%|██████████| 390/390 [00:47<00:00,  8.15it/s]
100%|██████████| 78/78 [00:16<00:00,  4.71it/s]


validation accuracy: 0.9657451923076923
Valid Error 3.425 %
Epoch 9/10


100%|██████████| 390/390 [00:47<00:00,  8.18it/s]
100%|██████████| 78/78 [00:16<00:00,  4.63it/s]


validation accuracy: 0.9669471153846154
Valid Error 3.305 %
Epoch 10/10


100%|██████████| 390/390 [00:48<00:00,  8.08it/s]
100%|██████████| 78/78 [00:16<00:00,  4.67it/s]


validation accuracy: 0.9688501602564102
Valid Error 3.115 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, 0, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:47<00:00,  8.25it/s]
100%|██████████| 78/78 [00:16<00:00,  4.67it/s]


validation accuracy: 0.9172676282051282
Valid Error 8.273 %
Epoch 2/10


100%|██████████| 390/390 [00:47<00:00,  8.18it/s]
100%|██████████| 78/78 [00:16<00:00,  4.65it/s]


validation accuracy: 0.9413060897435898
Valid Error 5.869 %
Epoch 3/10


100%|██████████| 390/390 [00:49<00:00,  7.88it/s]
100%|██████████| 78/78 [00:17<00:00,  4.59it/s]


validation accuracy: 0.9508213141025641
Valid Error 4.918 %
Epoch 4/10


100%|██████████| 390/390 [00:48<00:00,  8.09it/s]
100%|██████████| 78/78 [00:16<00:00,  4.74it/s]


validation accuracy: 0.9554286858974359
Valid Error 4.457 %
Epoch 5/10


100%|██████████| 390/390 [00:49<00:00,  7.89it/s]
100%|██████████| 78/78 [00:16<00:00,  4.76it/s]


validation accuracy: 0.9588341346153846
Valid Error 4.117 %
Epoch 6/10


100%|██████████| 390/390 [00:47<00:00,  8.14it/s]
100%|██████████| 78/78 [00:16<00:00,  4.67it/s]


validation accuracy: 0.9628405448717948
Valid Error 3.716 %
Epoch 7/10


100%|██████████| 390/390 [00:48<00:00,  8.12it/s]
100%|██████████| 78/78 [00:16<00:00,  4.73it/s]


validation accuracy: 0.9638421474358975
Valid Error 3.616 %
Epoch 8/10


100%|██████████| 390/390 [00:48<00:00,  8.12it/s]
100%|██████████| 78/78 [00:16<00:00,  4.77it/s]


validation accuracy: 0.9665464743589743
Valid Error 3.345 %
Epoch 9/10


100%|██████████| 390/390 [00:48<00:00,  8.08it/s]
100%|██████████| 78/78 [00:16<00:00,  4.60it/s]


validation accuracy: 0.9668469551282052
Valid Error 3.315 %
Epoch 10/10


100%|██████████| 390/390 [00:48<00:00,  8.08it/s]
100%|██████████| 78/78 [00:16<00:00,  4.78it/s]


validation accuracy: 0.9689503205128205
Valid Error 3.105 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, -1, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:49<00:00,  7.93it/s]
100%|██████████| 78/78 [00:17<00:00,  4.57it/s]


validation accuracy: 0.9173677884615384
Valid Error 8.263 %
Epoch 2/10


100%|██████████| 390/390 [00:48<00:00,  8.12it/s]
100%|██████████| 78/78 [00:16<00:00,  4.65it/s]


validation accuracy: 0.9415064102564102
Valid Error 5.849 %
Epoch 3/10


100%|██████████| 390/390 [00:47<00:00,  8.15it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.9504206730769231
Valid Error 4.958 %
Epoch 4/10


100%|██████████| 390/390 [00:47<00:00,  8.21it/s]
100%|██████████| 78/78 [00:18<00:00,  4.27it/s]


validation accuracy: 0.9571314102564102
Valid Error 4.287 %
Epoch 5/10


100%|██████████| 390/390 [00:47<00:00,  8.13it/s]
100%|██████████| 78/78 [00:17<00:00,  4.49it/s]


validation accuracy: 0.957832532051282
Valid Error 4.217 %
Epoch 6/10


100%|██████████| 390/390 [00:48<00:00,  7.98it/s]
100%|██████████| 78/78 [00:17<00:00,  4.40it/s]


validation accuracy: 0.9610376602564102
Valid Error 3.896 %
Epoch 7/10


100%|██████████| 390/390 [00:48<00:00,  8.09it/s]
100%|██████████| 78/78 [00:19<00:00,  3.97it/s]


validation accuracy: 0.9623397435897436
Valid Error 3.766 %
Epoch 8/10


100%|██████████| 390/390 [00:47<00:00,  8.25it/s]
100%|██████████| 78/78 [00:18<00:00,  4.29it/s]


validation accuracy: 0.96484375
Valid Error 3.516 %
Epoch 9/10


100%|██████████| 390/390 [00:47<00:00,  8.19it/s]
100%|██████████| 78/78 [00:17<00:00,  4.51it/s]


validation accuracy: 0.9650440705128205
Valid Error 3.496 %
Epoch 10/10


100%|██████████| 390/390 [00:47<00:00,  8.22it/s]
100%|██████████| 78/78 [00:16<00:00,  4.63it/s]


validation accuracy: 0.9676482371794872
Valid Error 3.235 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.25, -1, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:47<00:00,  8.20it/s]
100%|██████████| 78/78 [00:18<00:00,  4.19it/s]


validation accuracy: 0.9157652243589743
Valid Error 8.423 %
Epoch 2/10


100%|██████████| 390/390 [00:47<00:00,  8.14it/s]
100%|██████████| 78/78 [00:16<00:00,  4.68it/s]


validation accuracy: 0.9368990384615384
Valid Error 6.31 %
Epoch 3/10


100%|██████████| 390/390 [00:47<00:00,  8.17it/s]
100%|██████████| 78/78 [00:16<00:00,  4.71it/s]


validation accuracy: 0.9490184294871795
Valid Error 5.098 %
Epoch 4/10


100%|██████████| 390/390 [00:47<00:00,  8.14it/s]
100%|██████████| 78/78 [00:16<00:00,  4.64it/s]


validation accuracy: 0.9540264423076923
Valid Error 4.597 %
Epoch 5/10


100%|██████████| 390/390 [00:48<00:00,  8.08it/s]
100%|██████████| 78/78 [00:16<00:00,  4.71it/s]


validation accuracy: 0.9594350961538461
Valid Error 4.056 %
Epoch 6/10


100%|██████████| 390/390 [00:48<00:00,  8.04it/s]
100%|██████████| 78/78 [00:16<00:00,  4.78it/s]


validation accuracy: 0.9606370192307693
Valid Error 3.936 %
Epoch 7/10


100%|██████████| 390/390 [00:49<00:00,  7.87it/s]
100%|██████████| 78/78 [00:16<00:00,  4.78it/s]


validation accuracy: 0.9619391025641025
Valid Error 3.806 %
Epoch 8/10


100%|██████████| 390/390 [00:48<00:00,  8.07it/s]
100%|██████████| 78/78 [00:16<00:00,  4.70it/s]


validation accuracy: 0.9658453525641025
Valid Error 3.415 %
Epoch 9/10


100%|██████████| 390/390 [00:48<00:00,  8.01it/s]
100%|██████████| 78/78 [00:16<00:00,  4.69it/s]


validation accuracy: 0.9658453525641025
Valid Error 3.415 %
Epoch 10/10


100%|██████████| 390/390 [00:49<00:00,  7.84it/s]
100%|██████████| 78/78 [00:16<00:00,  4.68it/s]


validation accuracy: 0.9693509615384616
Valid Error 3.065 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, 0, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:49<00:00,  7.85it/s]
100%|██████████| 78/78 [00:16<00:00,  4.70it/s]


validation accuracy: 0.9114583333333334
Valid Error 8.854 %
Epoch 2/10


100%|██████████| 390/390 [00:49<00:00,  7.92it/s]
100%|██████████| 78/78 [00:16<00:00,  4.59it/s]


validation accuracy: 0.9372996794871795
Valid Error 6.27 %
Epoch 3/10


100%|██████████| 390/390 [00:49<00:00,  7.90it/s]
100%|██████████| 78/78 [00:16<00:00,  4.60it/s]


validation accuracy: 0.9478165064102564
Valid Error 5.218 %
Epoch 4/10


100%|██████████| 390/390 [00:49<00:00,  7.95it/s]
100%|██████████| 78/78 [00:16<00:00,  4.64it/s]


validation accuracy: 0.9527243589743589
Valid Error 4.728 %
Epoch 5/10


100%|██████████| 390/390 [00:48<00:00,  8.02it/s]
100%|██████████| 78/78 [00:17<00:00,  4.54it/s]


validation accuracy: 0.960136217948718
Valid Error 3.986 %
Epoch 6/10


100%|██████████| 390/390 [00:48<00:00,  8.05it/s]
100%|██████████| 78/78 [00:16<00:00,  4.63it/s]


validation accuracy: 0.9589342948717948
Valid Error 4.107 %
Epoch 7/10


100%|██████████| 390/390 [00:48<00:00,  8.10it/s]
100%|██████████| 78/78 [00:16<00:00,  4.64it/s]


validation accuracy: 0.9626402243589743
Valid Error 3.736 %
Epoch 8/10


100%|██████████| 390/390 [00:48<00:00,  8.02it/s]
100%|██████████| 78/78 [00:16<00:00,  4.70it/s]


validation accuracy: 0.9664463141025641
Valid Error 3.355 %
Epoch 9/10


100%|██████████| 390/390 [00:49<00:00,  7.82it/s]
100%|██████████| 78/78 [00:17<00:00,  4.50it/s]


validation accuracy: 0.9666466346153846
Valid Error 3.335 %
Epoch 10/10


100%|██████████| 390/390 [00:49<00:00,  7.82it/s]
100%|██████████| 78/78 [00:17<00:00,  4.55it/s]


validation accuracy: 0.9681490384615384
Valid Error 3.185 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, 0, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:49<00:00,  7.95it/s]
100%|██████████| 78/78 [00:16<00:00,  4.74it/s]


validation accuracy: 0.9173677884615384
Valid Error 8.263 %
Epoch 2/10


100%|██████████| 390/390 [00:48<00:00,  8.05it/s]
100%|██████████| 78/78 [00:16<00:00,  4.70it/s]


validation accuracy: 0.9388020833333334
Valid Error 6.12 %
Epoch 3/10


100%|██████████| 390/390 [00:49<00:00,  7.92it/s]
100%|██████████| 78/78 [00:16<00:00,  4.69it/s]


validation accuracy: 0.9450120192307693
Valid Error 5.499 %
Epoch 4/10


100%|██████████| 390/390 [00:48<00:00,  7.97it/s]
100%|██████████| 78/78 [00:16<00:00,  4.63it/s]


validation accuracy: 0.9511217948717948
Valid Error 4.888 %
Epoch 5/10


100%|██████████| 390/390 [00:47<00:00,  8.21it/s]
100%|██████████| 78/78 [00:16<00:00,  4.69it/s]


validation accuracy: 0.9559294871794872
Valid Error 4.407 %
Epoch 6/10


100%|██████████| 390/390 [00:48<00:00,  8.05it/s]
100%|██████████| 78/78 [00:16<00:00,  4.77it/s]


validation accuracy: 0.9600360576923077
Valid Error 3.996 %
Epoch 7/10


100%|██████████| 390/390 [00:47<00:00,  8.16it/s]
100%|██████████| 78/78 [00:16<00:00,  4.70it/s]


validation accuracy: 0.9609375
Valid Error 3.906 %
Epoch 8/10


100%|██████████| 390/390 [00:46<00:00,  8.44it/s]
100%|██████████| 78/78 [00:16<00:00,  4.74it/s]


validation accuracy: 0.9637419871794872
Valid Error 3.626 %
Epoch 9/10


100%|██████████| 390/390 [00:46<00:00,  8.33it/s]
100%|██████████| 78/78 [00:16<00:00,  4.67it/s]


validation accuracy: 0.9614383012820513
Valid Error 3.856 %
Epoch 10/10


100%|██████████| 390/390 [00:46<00:00,  8.34it/s]
100%|██████████| 78/78 [00:17<00:00,  4.44it/s]


validation accuracy: 0.9646434294871795
Valid Error 3.536 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, -1, -6]}
Epoch 1/10


100%|██████████| 390/390 [00:46<00:00,  8.31it/s]
100%|██████████| 78/78 [00:16<00:00,  4.62it/s]


validation accuracy: 0.9197716346153846
Valid Error 8.023 %
Epoch 2/10


100%|██████████| 390/390 [00:46<00:00,  8.39it/s]
100%|██████████| 78/78 [00:16<00:00,  4.68it/s]


validation accuracy: 0.9318910256410257
Valid Error 6.811 %
Epoch 3/10


100%|██████████| 390/390 [00:47<00:00,  8.27it/s]
100%|██████████| 78/78 [00:16<00:00,  4.73it/s]


validation accuracy: 0.950020032051282
Valid Error 4.998 %
Epoch 4/10


100%|██████████| 390/390 [00:48<00:00,  8.07it/s]
100%|██████████| 78/78 [00:16<00:00,  4.70it/s]


validation accuracy: 0.9525240384615384
Valid Error 4.748 %
Epoch 5/10


100%|██████████| 390/390 [00:47<00:00,  8.24it/s]
100%|██████████| 78/78 [00:16<00:00,  4.70it/s]


validation accuracy: 0.9577323717948718
Valid Error 4.227 %
Epoch 6/10


100%|██████████| 390/390 [00:48<00:00,  8.01it/s]
100%|██████████| 78/78 [00:16<00:00,  4.67it/s]


validation accuracy: 0.9603365384615384
Valid Error 3.966 %
Epoch 7/10


100%|██████████| 390/390 [00:48<00:00,  8.01it/s]
100%|██████████| 78/78 [00:16<00:00,  4.68it/s]


validation accuracy: 0.9629407051282052
Valid Error 3.706 %
Epoch 8/10


100%|██████████| 390/390 [00:47<00:00,  8.21it/s]
100%|██████████| 78/78 [00:16<00:00,  4.75it/s]


validation accuracy: 0.9662459935897436
Valid Error 3.375 %
Epoch 9/10


100%|██████████| 390/390 [00:48<00:00,  8.03it/s]
100%|██████████| 78/78 [00:16<00:00,  4.72it/s]


validation accuracy: 0.9638421474358975
Valid Error 3.616 %
Epoch 10/10


100%|██████████| 390/390 [00:48<00:00,  7.99it/s]
100%|██████████| 78/78 [00:16<00:00,  4.81it/s]


validation accuracy: 0.9659455128205128
Valid Error 3.405 %
Trying config: {'batch_size': 128, 'lr': 0.0001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 5, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, -1, -7]}
Epoch 1/10


100%|██████████| 390/390 [00:46<00:00,  8.31it/s]
100%|██████████| 78/78 [00:16<00:00,  4.78it/s]


validation accuracy: 0.9110576923076923
Valid Error 8.894 %
Epoch 2/10


100%|██████████| 390/390 [00:47<00:00,  8.28it/s]
100%|██████████| 78/78 [00:16<00:00,  4.81it/s]


validation accuracy: 0.9385016025641025
Valid Error 6.15 %
Epoch 3/10


100%|██████████| 390/390 [00:47<00:00,  8.14it/s]
100%|██████████| 78/78 [00:16<00:00,  4.75it/s]


validation accuracy: 0.9459134615384616
Valid Error 5.409 %
Epoch 4/10


100%|██████████| 390/390 [00:48<00:00,  7.96it/s]
100%|██████████| 78/78 [00:16<00:00,  4.77it/s]


validation accuracy: 0.9521233974358975
Valid Error 4.788 %
Epoch 5/10


100%|██████████| 390/390 [00:47<00:00,  8.16it/s]
100%|██████████| 78/78 [00:16<00:00,  4.83it/s]


validation accuracy: 0.9598357371794872
Valid Error 4.016 %
Epoch 6/10


100%|██████████| 390/390 [00:48<00:00,  8.12it/s]
100%|██████████| 78/78 [00:16<00:00,  4.76it/s]


validation accuracy: 0.9588341346153846
Valid Error 4.117 %
Epoch 7/10


100%|██████████| 390/390 [00:47<00:00,  8.14it/s]
100%|██████████| 78/78 [00:16<00:00,  4.75it/s]


validation accuracy: 0.9635416666666666
Valid Error 3.646 %
Epoch 8/10


100%|██████████| 390/390 [00:47<00:00,  8.21it/s]
100%|██████████| 78/78 [00:19<00:00,  3.92it/s]


validation accuracy: 0.9660456730769231
Valid Error 3.395 %
Epoch 9/10


100%|██████████| 390/390 [00:46<00:00,  8.33it/s]
100%|██████████| 78/78 [00:17<00:00,  4.47it/s]


validation accuracy: 0.9646434294871795
Valid Error 3.536 %
Epoch 10/10


100%|██████████| 390/390 [00:47<00:00,  8.18it/s]
100%|██████████| 78/78 [00:16<00:00,  4.74it/s]

validation accuracy: 0.9691506410256411
Valid Error 3.085 %
Best Config:
{'batch_size': 128, 'lr': 0.001, 'epochs': 10, 'hidden_units': 1200, 'experiment': 'classification', 'dropout': False, 'train_samples': 1, 'test_samples': 10, 'x_shape': 784, 'classes': 10, 'mu_init': [-0.2, 0.2], 'rho_init': [-5, -4], 'prior_init': [0.75, 0, -7]}
Best Validation Accuracy: 0.9767





In [13]:
# from tqdm import tqdm
# import numpy as np
# import torch
# from torch.utils.data import SubsetRandomSampler
# from torchvision import datasets, transforms
# from itertools import product

# def class_trainer(config):

#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Lambda(lambda x: x * 255. / 126.),
#     ])

#     train_data = datasets.MNIST(
#         root='data',
#         train=True,
#         download=True,
#         transform=transform)
#     test_data = datasets.MNIST(
#         root='data',
#         train=False,
#         download=True,
#         transform=transform)

#     valid_size = 1 / 6
#     num_train = len(train_data)
#     indices = list(range(num_train))
#     split = int(valid_size * num_train)
#     train_idx, valid_idx = indices[split:], indices[:split]

#     train_sampler = SubsetRandomSampler(train_idx)
#     valid_sampler = SubsetRandomSampler(valid_idx)

#     train_loader = torch.utils.data.DataLoader(
#         train_data,
#         batch_size=config['batch_size'],
#         sampler=train_sampler,
#         drop_last=True)
#     valid_loader = torch.utils.data.DataLoader(
#         train_data,
#         batch_size=config['batch_size'],
#         sampler=valid_sampler,
#         drop_last=True)
#     test_loader = torch.utils.data.DataLoader(
#         test_data,
#         batch_size=config['batch_size'],
#         shuffle=False,
#         drop_last=True)

#     params = {
#         'lr': config['lr'],
#         'hidden_units': config['hidden_units'],
#         'experiment': config['experiment'],
#         'batch_size': config['batch_size'],
#         'epochs': config['epochs'],
#         'x_shape': config['x_shape'],
#         'classes': config['classes'],
#         'num_batches': len(train_loader),
#         'train_samples': config['train_samples'],
#         'test_samples': config['test_samples'],
#         'mu_init': config['mu_init'],
#         'rho_init': config['rho_init'],
#         'prior_init': config['prior_init'],
#     }

#     model = BNN_Classification('bnn_classification', {**params, 'dropout': False})

#     best_val_acc = 0
#     for epoch in range(config['epochs']):
#         print(f'Epoch {epoch + 1}/{config["epochs"]}')
#         model.train_step(train_loader)
#         valid_acc = model.evaluate(valid_loader)
#         print('Valid Error', round(100 * (1 - valid_acc), 3), '%')
#         model.scheduler.step()
#         if model.acc > model.best_acc:
#             model.best_acc = model.acc
#             best_val_acc = valid_acc
#             # torch.save(model.net.state_dict(), "best_model.pt")

#     return best_val_acc


In [14]:
# def run_grid_search():
#     # Define the grid
#     param_grid = {
#         'lr': [1e-3, 1e-4],
#         'batch_size': [64, 128],
#         'hidden_units': [128, 256],
#     }

#     # Fixed config options
#     base_config = {
#         'experiment': 'grid_search',
#         'epochs': 5,
#         'x_shape': (1, 28, 28),
#         'classes': 10,
#         'train_samples': 60000,
#         'test_samples': 10000,
#         'mu_init': 0,
#         'rho_init': -3,
#         'prior_init': 0,
#     }

#     best_acc = 0
#     best_params = None

#     # Iterate over all combinations
#     for lr, batch_size, hidden_units in product(param_grid['lr'], param_grid['batch_size'], param_grid['hidden_units']):
#         config = {
#             **base_config,
#             'lr': lr,
#             'batch_size': batch_size,
#             'hidden_units': hidden_units
#         }

#         print(f"\nRunning with config: {config}")
#         acc = class_trainer(config)

#         if acc > best_acc:
#             best_acc = acc
#             best_params = config

#     print(f"Best Validation Accuracy: {best_acc:.4f}")
#     print(f"Best Config: {best_params}")

# run_grid_search()

In [15]:
class BNN_Regression():
    def __init__(self, label, parameters):
        super().__init__()
        self.label = label
        self.batch_size = parameters['batch_size']
        self.num_batches = parameters['num_batches']
        self.n_samples = parameters['train_samples']
        self.test_samples = parameters['test_samples']
        self.x_shape = parameters['x_shape']
        self.y_shape = parameters['y_shape']
        self.noise_tol = parameters['noise_tolerance']
        self.lr = parameters['lr']
        self.best_loss = np.inf
        self.init_net(parameters)
    
    def init_net(self, parameters):
        model_params = {
            'input_shape': self.x_shape,
            'classes': self.y_shape,
            'batch_size': self.batch_size,
            'hidden_units': parameters['hidden_units'],
            'experiment': parameters['experiment'],
            'mu_init': parameters['mu_init'],
            'rho_init': parameters['rho_init'],
            'prior_init': parameters['prior_init']
        }
        self.net = BayesianNetwork(model_params).to(device)
        self.optimiser = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimiser, step_size=500, gamma=0.5)
        # print(f'Regression Task {self.label} Parameters: ')
        # print(f'number of samples: {self.n_samples}, noise tolerance: {self.noise_tol}')
        print("BNN Parameters: ")
        print(f'batch size: {self.batch_size}, x shape: {model_params["input_shape"]}, hidden units: {model_params["hidden_units"]}, output shape: {model_params["classes"]}, mu_init: {parameters["mu_init"]}, rho_init: {parameters["rho_init"]}, prior_init: {parameters["prior_init"]}, lr: {self.lr}')

    def train_step(self, train_data):
        self.net.train()
        for idx, (x, y) in enumerate(train_data):
            beta = 2 ** (self.num_batches - (idx + 1)) / (2 ** self.num_batches - 1) 
            x, y = x.to(device), y.to(device)
            self.net.zero_grad()
            self.loss_info = self.net.sample_elbo(x, y, beta, self.n_samples, sigma=self.noise_tol)
            net_loss = self.loss_info[0]
            net_loss.backward()
            self.optimiser.step()
        self.epoch_loss = net_loss.item()

    def evaluate(self, x_test):
        self.net.eval()
        with torch.no_grad():
            y_test = np.zeros((self.test_samples, x_test.shape[0]))
            for s in range(self.test_samples):
                tmp = self.net(x_test.to(device)).detach().cpu().numpy()
                y_test[s,:] = tmp.reshape(-1)
            return y_test

In [16]:
class MLP_Regression():
    def __init__(self, label, parameters):
        super().__init__()
        self.label = label
        self.lr = parameters['lr']
        self.hidden_units = parameters['hidden_units']
        self.experiment = parameters['experiment']
        self.batch_size = parameters['batch_size']
        self.num_batches = parameters['num_batches']
        self.x_shape = parameters['x_shape']
        self.y_shape = parameters['y_shape']
        self.best_loss = np.inf
        self.init_net(parameters)
    
    def init_net(self, parameters):
        model_params = {
            'input_shape': self.x_shape,
            'classes': self.y_shape,
            'batch_size': self.batch_size,
            'hidden_units': self.hidden_units,
            'experiment': self.experiment
        }
        self.net = MLP(model_params).to(device)
        self.optimiser = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimiser, step_size=5000, gamma=0.5)
        print("MLP Parameters: ")
        print(f'batch size: {self.batch_size}, input shape: {model_params["input_shape"]}, hidden units: {model_params["hidden_units"]}, output shape: {model_params["classes"]}, lr: {self.lr}')

    def train_step(self, train_data):
        self.net.train()
        for _, (x, y) in enumerate(train_data):
            x, y = x.to(device), y.to(device)
            self.net.zero_grad()
            self.loss_info = torch.nn.functional.mse_loss(self.net(x), y, reduction='sum')
            self.loss_info.backward()
            self.optimiser.step()

        self.epoch_loss = self.loss_info.item()

    def evaluate(self, x_test):
        self.net.eval()
        with torch.no_grad():
            y_test = self.net(x_test.to(device)).detach().cpu().numpy()
            return y_test

In [17]:
def reg_trainer():
    config = RegConfig
    X, Y = create_data_reg(train_size=config.train_size)
    train_loader = PrepareData(X, Y)
    train_loader = DataLoader(train_loader, batch_size=config.batch_size, shuffle=True)

    params = {
        'lr': config.lr,
        'hidden_units': config.hidden_units,
        'experiment': config.experiment,
        'batch_size': config.batch_size,
        'num_batches': len(train_loader),
        'x_shape': X.shape[1],
        'y_shape': Y.shape[1],
        'train_samples': config.train_samples,
        'test_samples': config.test_samples,
        'noise_tolerance': config.noise_tolerance,
        'mu_init': config.mu_init,
        'rho_init': config.rho_init,
        'prior_init': config.prior_init,
    }

    model = BNN_Regression('bnn_regression', {**params})
    #model = MLP_Regression('mlp_regression', {**params})

    epochs = config.epochs
    print(f"Initialising training on {device}...")

    # training loop
    for epoch in tqdm(range(epochs)):
    
        model.train_step(train_loader)
        model.scheduler.step()
        # save best model
        if model.epoch_loss < model.best_loss:
            model.best_loss = model.epoch_loss
            # torch.save(model.net.state_dict(), model.save_model_path)

    # evaluate
    print("Evaluating and generating plots...")
    x_test = torch.linspace(-2., 2, config.num_test_points).reshape(-1, 1)
    
    # model.net.load_state_dict(torch.load(model.save_model_path, map_location=torch.device(device)))
    y_test = model.evaluate(x_test)
    
    #create_regression_plot(x_test.cpu().numpy(), y_test.reshape(1, -1), train_loader) #per mlp regression
   
    create_regression_plot(x_test.cpu().numpy(), y_test, train_loader)


#reg_trainer()

In [18]:
class Bandit():
    def __init__(self, label, bandit_params, x, y):
        self.n_samples = bandit_params['n_samples']
        self.buffer_size = bandit_params['buffer_size']
        self.batch_size = bandit_params['batch_size']
        self.num_batches = bandit_params['num_batches']
        self.lr = bandit_params['lr']
        self.epsilon = bandit_params['epsilon']
        self.cumulative_regrets = [0]
        self.buffer_x, self.buffer_y = [], []
        self.x, self.y = x, y
        self.label = label
        self.init_net(bandit_params)
        self.tp, self.tn, self.fp, self.fn = 0, 0, 0, 0

    def get_agent_reward(self, eaten, edible):
        if not eaten:
            return 0
        if eaten and edible:
            return 5
        elif eaten and not edible:
            return 5 if np.random.rand() > 0.5 else -35

    def get_oracle_reward(self, edible):
        return 5*edible 

    def take_action(self, mushroom):
        context, edible = self.x[mushroom], self.y[mushroom]
        eat_tuple = torch.FloatTensor(np.concatenate((context, [1, 0]))).unsqueeze(0).to(device)
        reject_tuple = torch.FloatTensor(np.concatenate((context, [0, 1]))).unsqueeze(0).to(device)

        # evaluate reward for actions
        with torch.no_grad():
            self.net.eval()
            reward_eat = sum([self.net(eat_tuple) for _ in range(self.n_samples)]).item()
            reward_reject = sum([self.net(reject_tuple) for _ in range(self.n_samples)]).item()

        eat = reward_eat > reward_reject
        # epsilon-greedy agent
        if np.random.rand() < self.epsilon:
            eat = (np.random.rand() < 0.5)
        agent_reward = self.get_agent_reward(eat, edible)

        # record bandit action
        if edible and eat:
            self.tp += 1
        elif edible and not eat:
            self.fn += 1
        elif not edible and eat:
            self.fp += 1
        else:
            self.tn += 1

        # record context, action, reward
        action = torch.Tensor([1, 0] if eat else [0, 1])
        self.buffer_x.append(np.concatenate((context, action)))
        self.buffer_y.append(agent_reward)

        # calculate regret
        regret = self.get_oracle_reward(edible) - agent_reward
        self.cumulative_regrets.append(self.cumulative_regrets[-1]+regret)

    def update(self, mushroom):
        self.take_action(mushroom)
        l = len(self.buffer_x)

        if l <= self.batch_size:
            idx_pool = int(self.batch_size//l + 1)*list(range(l))
            idx_pool = np.random.permutation(idx_pool[-self.batch_size:])
        elif l > self.batch_size and l < self.buffer_size:
            idx_pool = int(l//self.batch_size)*self.batch_size
            idx_pool = np.random.permutation(list(range(l))[-idx_pool:])
        else:
            idx_pool = np.random.permutation(list(range(l))[-self.buffer_size:])

        context_pool = torch.Tensor([self.buffer_x[i] for i in idx_pool]).to(device)
        value_pool = torch.Tensor([self.buffer_y[i] for i in idx_pool]).to(device)
        
        for i in range(0, len(idx_pool), self.batch_size):
            self.loss_info = self.loss_step(context_pool[i:i+self.batch_size], value_pool[i:i+self.batch_size], i//self.batch_size)

In [19]:
class BNN_Bandit(Bandit):
    def __init__(self, label, *args):
        super().__init__(label, *args)
    
    def init_net(self, parameters):
        model_params = {
            'input_shape': self.x.shape[1]+2,
            'classes': 1 if len(self.y.shape)==1 else self.y.shape[1],
            'batch_size': self.batch_size,
            'hidden_units': parameters['hidden_units'],
            'experiment': parameters['experiment'],
            'mu_init': parameters['mu_init'],
            'rho_init': parameters['rho_init'],
            'prior_init': parameters['prior_init']
        }
        self.net = BayesianNetwork(model_params).to(device)
        self.optimiser = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimiser, step_size=5000, gamma=0.5)
        print("BNN Parameters: ")
        print(f'x shape: {model_params["input_shape"]}, hidden units: {model_params["hidden_units"]}, output shape: {model_params["classes"]}, lr: {self.lr}')

    def loss_step(self, x, y, batch_id):
        beta = 2 ** (self.num_batches - (batch_id + 1)) / (2 ** self.num_batches - 1) 
        self.net.train()
        self.net.zero_grad()
        loss_info = self.net.sample_elbo(x, y, beta, self.n_samples)
        net_loss = loss_info[0]
        net_loss.backward()
        self.optimiser.step()
        return loss_info

class Greedy_Bandit(Bandit):
    def __init__(self, label, *args):
        super().__init__(label, *args)
        self.writer = SummaryWriter(comment=f"_{label}_training"),
    
    def init_net(self, parameters):
        model_params = {
            'input_shape': self.x.shape[1]+2,
            'classes': 1 if len(self.y.shape)==1 else self.y.shape[1],
            'batch_size': self.batch_size,
            'hidden_units': parameters['hidden_units'],
            'experiment': parameters['experiment']
        }
        self.net = MLP(model_params).to(device)
        self.optimiser = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimiser, step_size=5000, gamma=0.5)
        print(f'Bandit {self.label} Parameters: ')
        print(f'buffer_size: {self.buffer_size}, batch size: {self.batch_size}, number of samples: {self.n_samples}, epsilon: {self.epsilon}')
        print("MLP Parameters: ")
        print(f'x shape: {model_params["input_shape"]}, hidden units: {model_params["hidden_units"]}, output shape: {model_params["classes"]}, lr: {self.lr}')

    def loss_step(self, x, y, batch_id):
        self.net.train()
        self.net.zero_grad()
        net_loss = torch.nn.functional.mse_loss(self.net(x).squeeze(), y, reduction='sum')
        net_loss.backward()
        self.optimiser.step()
        return net_loss

In [20]:
import matplotlib.pyplot as plt

def rl_trainer():
    config = RLConfig
    X, Y = read_data_rl(config.data_dir)

    params = {
        'buffer_size': config.buffer_size,
        'batch_size': config.batch_size,
        'num_batches': config.num_batches,
        'lr': config.lr,
        'hidden_units': config.hidden_units,
        'experiment': config.experiment,
        'mu_init': config.mu_init,
        'rho_init': config.rho_init,
        'prior_init': config.prior_init
    }

    bandit = BNN_Bandit('bnn_bandit', {**params, 'n_samples':2, 'epsilon':0}, X, Y)
    
    training_steps = config.training_steps
    print(f"Initialising training on {device}...")
    training_data_len = len(X)
    for step in tqdm(range(training_steps)):
        mushroom = np.random.randint(training_data_len)
        bandit.update(mushroom)
        bandit.scheduler.step()

    # Plot cumulative regret
    plt.figure(figsize=(10, 6))
    plt.plot(bandit.cumulative_regrets, label='Cumulative Regret')
    plt.xlabel('Steps')
    plt.ylabel('Cumulative Regret')
    plt.title('Cumulative Regret over Time')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


#rl_trainer()