In [8]:
import torch
import torch.nn as nn
import os
# import matplotlib.pyplot as plt
# import seaborn as sns
# sns.set_theme(context='talk', style='whitegrid', palette='colorblind')
# from classes.layers import BayesLinearMixture
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

from torch.nn import Module, Parameter
import torch.nn.init as init
import torch.nn.functional as F

import math


In [9]:
batch_size = 128

# Training / validation pipeline, kept for reference but disabled in this notebook:
#   training_data = datasets.MNIST(root='data', train=True, download=True, transform=ToTensor())
#   training_set, validation_set = torch.utils.data.random_split(training_data, [50000, 10000])
#   training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
#   validation_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=False)

# MNIST test split (train=False), downloaded into ./data on first use.
test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor(),
)

# No shuffling: the evaluation order is kept fixed.
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

# print(type(test_data[0][0][0][0][0]))

In [10]:
class BayesLinearMixture(Module):
    r"""
    Bayesian linear layer (without bias) whose weights are sampled on every call.

    Each weight has a learnable mean ``weight_mu`` and a learnable ``weight_rho``;
    the standard deviation is ``softplus(weight_rho)``.  A forward pass draws
    ``weight = weight_mu + softplus(weight_rho) * eps`` with ``eps ~ N(0, 1)``
    and returns ``F.linear(input, weight)``.

    The prior hyper-parameters are only stored on the layer; nothing inside
    this class evaluates the prior, so any loss term built from them has to be
    computed by the caller.

    Arguments:
        prior_mu1 (Float): mean of the first prior normal distribution.
        prior_sigma1 (Float): sigma of the first prior normal distribution (must be > 0).
        prior_mu2 (Float): mean of the second prior normal distribution.
        prior_sigma2 (Float): sigma of the second prior normal distribution (must be > 0).
        pi (Float): stored as-is; presumably the mixing weight of the first
            prior component -- confirm against the training code.
        in_features (Int): size of each input sample.
        out_features (Int): size of each output sample.

    .. note:: other arguments are following linear of pytorch 1.2.0.
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py

    """
    # Only names of attributes that actually exist on the instance.
    __constants__ = ['prior_mu1', 'prior_sigma1', 'prior_mu2', 'prior_sigma2', 'pi',
                     'in_features', 'out_features']

    def __init__(self, prior_mu1, prior_sigma1, prior_mu2, prior_sigma2, pi, in_features, out_features):
        super(BayesLinearMixture, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.prior_mu1 = prior_mu1
        self.prior_sigma1 = prior_sigma1
        self.prior_log_sigma1 = math.log(prior_sigma1)

        self.prior_mu2 = prior_mu2
        self.prior_sigma2 = prior_sigma2
        self.prior_log_sigma2 = math.log(prior_sigma2)

        self.pi = pi

        # Variational parameters; values are set by reset_parameters().
        self.weight_mu = Parameter(torch.empty(out_features, in_features))
        self.weight_rho = Parameter(torch.empty(out_features, in_features))
        # Holds the fixed noise sample while the layer is frozen, None otherwise.
        self.register_buffer('weight_eps', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Re-initialise ``weight_mu`` uniformly in ``[-1/sqrt(in_features), 1/sqrt(in_features)]``
        and fill ``weight_rho`` with ``log(prior_sigma1)``.
        """
        # Initialization method of Adv-BNN
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        with torch.no_grad():
            self.weight_mu.uniform_(-stdv, stdv)
            # NOTE(review): rho is passed through softplus in forward(), so the
            # initial sigma is softplus(log(prior_sigma1)), not prior_sigma1.
            self.weight_rho.fill_(self.prior_log_sigma1)

    def freeze(self):
        r"""
        Draw one noise sample and reuse it for every following forward pass,
        making the layer deterministic until :meth:`unfreeze` is called.

        While frozen, ``weight_eps`` is a registered buffer holding a tensor,
        so it is included in ``state_dict()``.
        """
        self.weight_eps = torch.randn_like(self.weight_rho)

    def unfreeze(self):
        r"""
        Drop the stored noise sample so that every forward pass samples again.
        """
        self.weight_eps = None

    def forward(self, input):
        r"""
        Overriden.

        Samples a weight matrix (or reuses the frozen noise), stores it in
        ``self.weight`` and applies it to ``input`` without a bias term.
        """
        if self.weight_eps is None:
            eps = torch.randn_like(self.weight_rho)
        else:
            eps = self.weight_eps
        # softplus(rho) == log(1 + exp(rho)), computed without overflowing for large rho.
        self.weight = self.weight_mu + F.softplus(self.weight_rho) * eps

        return F.linear(input, self.weight, bias=None)

    def extra_repr(self):
        r"""
        Overriden.
        """
        return 'prior_mu1={}, prior_sigma1={}, prior_mu2={}, prior_sigma2={}, in_features={}, out_features={}'.format(self.prior_mu1, self.prior_sigma1, self.prior_mu2, self.prior_sigma2, self.in_features, self.out_features)
  

In [11]:
# Network dimensions.
input_dim = 784
hidden_dim = 1200
num_classes = 10

# Prior hyper-parameters; the sigmas are given as base-10 exponents.
log_sigma1 = -1
log_sigma2 = -6
pi = 0.75
sigma1 = 10 ** log_sigma1
sigma2 = 10 ** log_sigma2


def _mixture_layer(n_in, n_out):
    """Build one BayesLinearMixture layer with the shared prior settings above."""
    return BayesLinearMixture(
        prior_mu1=0,
        prior_sigma1=sigma1,
        prior_mu2=0,
        prior_sigma2=sigma2,
        pi=pi,
        in_features=n_in,
        out_features=n_out,
    )


# Two hidden Bayesian layers with ReLU, followed by a Bayesian output layer.
model = nn.Sequential(
    _mixture_layer(input_dim, hidden_dim),
    nn.ReLU(),
    _mixture_layer(hidden_dim, hidden_dim),
    nn.ReLU(),
    _mixture_layer(hidden_dim, num_classes),
)

In [27]:
# Load the trained checkpoint and collect log(|mu| / sigma) for every weight.
directory = "./models/"
# filename = os.path.join(directory, "gaussian-mixture-prior", "averaging2", "pi0.5", "sigma21e-6", "model_20230408_123457_300")
filename = os.path.join(directory, "gaussian-mixture-prior", "model_20230411_090805_300")

# map_location="cpu": the tensors are converted to NumPy below, which requires
# CPU storage, and it lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
model.load_state_dict(torch.load(filename, map_location="cpu"))
model.eval()

# Same offset as the pruning cell uses when it builds its mask, so that the
# threshold and the mask come from one formula; it also keeps log() finite
# for weights whose mean is exactly zero.
s2n_eps = 1e-6

per_layer_s2n = []
for m in model.modules():
    if isinstance(m, BayesLinearMixture):
        mu = m.weight_mu.detach().cpu().numpy()
        rho = m.weight_rho.detach().cpu().numpy()
        sigma = np.log1p(np.exp(rho))  # softplus(rho), as in the layer's forward()
        per_layer_s2n.append(np.log(np.abs(mu) / sigma + s2n_eps).ravel())

# One flat array over all Bayesian layers (built without going through Python floats).
signal_to_noise = np.concatenate(per_layer_s2n) if per_layer_s2n else np.array([])
print(len(signal_to_noise))

2392800


In [28]:
# Prune the weights with the lowest signal-to-noise ratio across all Bayesian layers.
prune_fraction = 0.98   # fraction of weights to remove
s2n_eps = 1e-6          # offset inside the log when the mask is computed
pruned_rho = -1e10      # softplus(-1e10) == 0, so a pruned weight gets sigma == 0

# Sorted in place, as before, so `signal_to_noise` stays sorted for later cells.
signal_to_noise.sort()
# Clamp so that prune_fraction == 1.0 does not index past the end.
index = min(int(len(signal_to_noise) * prune_fraction), len(signal_to_noise) - 1)
threshold_val = signal_to_noise[index]

num_weights = 0
model.eval()

with torch.no_grad():
    for m in model.modules():
        if not isinstance(m, BayesLinearMixture):
            continue

        m.weight_mu.requires_grad_(False)
        m.weight_rho.requires_grad_(False)

        mu = m.weight_mu.detach().cpu().numpy()
        rho = m.weight_rho.detach().cpu().numpy()
        sigma = np.log1p(np.exp(rho))
        s2n = np.log(np.abs(mu) / sigma + s2n_eps)

        # Boolean torch mask: True for weights that survive pruning.
        keep = torch.from_numpy(s2n >= threshold_val).to(m.weight_mu.device)
        num_weights += int(keep.sum())

        # Masked in place on the parameters themselves, so their dtype never
        # changes and no cast back to float32 is needed afterwards.
        m.weight_mu.masked_fill_(~keep, 0.0)
        m.weight_rho.masked_fill_(~keep, pruned_rho)

print(num_weights)

47856


In [29]:
# Evaluate the model on the MNIST test set.
# The Bayesian layers draw new weights on every forward pass, so the result is
# stochastic; `num_iters` averages the accuracy over several passes.
model.eval()

num_iters = 1
accuracy = 0.0

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for _ in range(num_iters):
        correct = 0

        for tinputs, tlabels in test_loader:
            # Flatten each image to a vector before the first linear layer.
            toutputs = model(tinputs.flatten(start_dim=1))
            predicted = toutputs.argmax(dim=1)
            correct += (predicted == tlabels).sum().item()

        accuracy += correct / len(test_data) / num_iters

# `correct` is the count from the last pass only; `accuracy` is the average.
print("Number of correct predictions {}".format(correct))

print("Error: {:.2f}".format(100 - 100 * accuracy))


Number of correct predictions 4534
Error: 54.66
