# Matrix-Scaling: A Bayesian approach

In [34]:
import os
import sys
import time
import importlib
import collections
sys.path.append('..')

import numpy as np
import pandas as pd
from scipy.special import softmax

from utils.data import get_cifar10, load_logits
from utils.ops import onehot_encode
from utils.metrics import neg_log_likelihood, expected_calibration_error

In [35]:
data_path = '../cifar-10'

In [36]:
cifar10, ix2label = get_cifar10(data_path, test=True)
print("Number of samples in the test set: {:d}".format(cifar10["test_labels"].shape[0]))

Number of samples in the test set: 10000


In [37]:
target = onehot_encode(cifar10['test_labels'])

# val/test split
random_split = np.random.permutation(10000)

val_target = target[random_split[:5000], :]
test_target = target[random_split[5000:], :]

In [38]:
resnet_path = '../pretrained-models'
net = 'resnet56_v2'

In [39]:
_, logits = load_logits(os.path.join(resnet_path, net))

# val/test split
val_logits = logits[random_split[:5000], :]
test_logits = logits[random_split[5000:], :]

Initialize params

In [97]:
w_mean = np.random.randn(10, 10)
w_log_var = np.random.randn(10, 10)

b_mean = np.random.randn(10)
b_log_var = np.random.randn(10)

params = {
    'w_mean': w_mean,
    'w_log_var': w_log_var,
    'b_mean': b_mean,
    'b_log_var': b_log_var
}

## PRIOR
prior_w_mean = np.eye(10)
prior_b_mean = np.zeros(10)

In [98]:
def sample(params):
    Z_w = np.random.randn(10, 10)
    Z_b = np.random.randn(10)
    
    w = Z_w*np.exp(0.5 * params['w_log_var']) + params['w_mean']
    b = Z_b*np.exp(0.5 * params['b_log_var']) + params['b_mean']

    return w, b, Z_w, Z_b

In [99]:
def predict(X, params, K=100):
    y = np.zeros(X.shape)
    for i in range(K):
        w, b, _, _ = sample(params)
        y += X @ w + b
        
    y /= K
    return softmax(y, axis=1)

In [100]:
def ELBO(pred, target, params):
    LL = -neg_log_likelihood(pred, target)
    mean_w_KL = 0.5 * np.mean(np.exp(params['w_log_var'])**2 + (prior_w_mean - params['w_mean'])**2 - 1 - 2*params['w_log_var'])
    mean_b_KL = 0.5 * np.mean(np.exp(params['b_log_var'])**2 + (prior_b_mean - params['b_mean'])**2 - 1 - 2*params['b_log_var'])
    
    ELBO = LL - (mean_w_KL + mean_b_KL)/2.
    
    return ELBO

In [106]:
def fit(X, y, params, batch_size=10, epochs=1000, lambd=0.5, learning_rate=0.0001):
    n_steps = X.shape[0]//batch_size + (X.shape[0]%batch_size > 0)
    for e in range(epochs):
        for j in range(n_steps):
            
            log_batch = X[j*batch_size:min((j+1)*batch_size, X.shape[0]), :]
            y_batch = y[j*batch_size:min((j+1)*batch_size, X.shape[0]), :]
            
            # Forward pass
            w, b, Z_w, Z_b = sample(params)
            logits = log_batch @ w + b
            
            probs = softmax(logits, axis=1)
            
            # Likelihood backpropagation
            dW = np.mean((probs-y_batch).reshape([-1, 10, 1])
                         @ log_batch.reshape([-1, 1, 10]), axis=0).T
            dw_var = dW * 2*(-(Z_w + params['w_mean'])/np.exp(params['w_log_var']))
            dw_log_var = -0.5 * dw_var
            dw_mean = 1./np.exp(0.5*params['w_log_var'])
            
            db = np.mean((probs-y_batch), axis=0)
            db_var = db * 2*(-(Z_b + params['b_mean'])/np.exp(params['b_log_var']))
            db_log_var = -0.5 * db_var
            db_mean = 1./np.exp(0.5*params['b_log_var'])
            
            # KL backpropagation
            dw_mean_KL = -(prior_w_mean - params['w_mean'])
            dw_log_var_KL = 0.5*np.exp(params['w_log_var']) - 1
            
            db_mean_KL = -(prior_b_mean - params['b_mean'])
            db_log_var_KL = 0.5*np.exp(params['b_log_var']) - 1
            """
            print('dw_mean: {}'.format(dw_mean))
            print('dw_log_var: {}'.format(dw_log_var))
            print('db_mean: {}'.format(db_mean))
            print('db_log_var: {}'.format(db_log_var))
            
            print('dw_mean_KL: {}'.format(dw_mean_KL))
            print('dw_log_var_KL: {}'.format(dw_log_var_KL))
            print('db_mean_KL: {}'.format(db_mean_KL))
            print('db_log_var_KL: {}'.format(db_log_var_KL))
            
            # Update parameters
            params['w_mean'] -= learning_rate*(dw_mean - lambd*dw_mean_KL)
            params['w_log_var'] -= learning_rate*(dw_log_var - lambd*dw_log_var_KL)
            params['b_mean'] -= learning_rate*(db_mean - lambd*db_mean_KL)
            params['b_log_var'] -= learning_rate*(db_log_var - lambd*db_log_var_KL)
            
            print(params)
            """
        if e%100 == 0:
            elbo = ELBO(predict(X, params), y, params)
            print("End of epoch {:d}, ELBO: {:.3e}".format(e, elbo))
    return params 
    

In [107]:
params = fit(val_logits, val_target, params, epochs=1000)

End of epoch 0, ELBO: -1.661e+01
End of epoch 100, ELBO: -1.630e+01
End of epoch 200, ELBO: -1.624e+01
End of epoch 300, ELBO: -1.646e+01
End of epoch 400, ELBO: -1.633e+01
End of epoch 500, ELBO: -1.648e+01
End of epoch 600, ELBO: -1.649e+01
End of epoch 700, ELBO: -1.650e+01
End of epoch 800, ELBO: -1.635e+01
End of epoch 900, ELBO: -1.638e+01


KeyboardInterrupt: 