In [1]:
########## NN load in #############
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import copy

#learning_rate = 0
n_epochs = 200
batch_size_train = 64 # 64 for stochastic, 50000 for full gradients
batch_size_test = 1000
learning_rate = 0.001
log_interval = 10

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
test_losses = []

### define train_loader and test_loader ###
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

### define a network ###
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return x

network = nn.Sequential(
            Flatten(),
            nn.Linear(784, 200),
            nn.ReLU(),
            nn.Linear(200, 10),
            nn.ReLU(),
            nn.LogSoftmax()
        )


### define optimizer (SGD momentum=0, lr = step size in HMC) ###
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=0)

### import nn config ####
network_state_dict = torch.load("./results/model.pth")
network.load_state_dict(network_state_dict)


#### define test() to calculate test accuracy ###
def test(network):
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      test_loss += F.nll_loss(output, target, size_average=False).item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))
  return correct

test(network)
x_net = copy.deepcopy(network)
train_loader_iter = train_loader.__iter__()

  input = module(input)



Test set: Avg. loss: 2.2936, Accuracy: 1268/10000 (12%)



In [2]:
########## HMC part ############
##### parameters ####
logm0 = 1;
sigmam = 1;
learning_rate = 0.3 # step size in HMC
L = 5
mu_m = 8e-4
T = 1e-12
d = 784*200+200+200*10+10
epoch = 1000 
zeta = 0.
mu = 1; lamb = 0.0001;
mag = 1;
mu *= mag; lamb *= mag;

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return x


q_net = nn.Sequential(
            Flatten(),
            nn.Linear(784, 200),
            nn.ReLU(),
            nn.Linear(200, 10),
            nn.ReLU(),
            nn.LogSoftmax()
        )


### x <- x + \eps M^{-1} q ###

def update_x_net(x_net, q_net, m):
    temp_x = copy.deepcopy(x_net)
    temp_q = copy.deepcopy(q_net)
    params1 = temp_x.state_dict()
    dict_params1 = dict(params1)
    params2 = temp_q.state_dict()
    dict_params2 = dict(params2)
    for name1 in params1:
        dict_params2[name1].data.copy_(dict_params1[name1].data + 1/m*dict_params2[name1].data*learning_rate)
    temp_x.load_state_dict(dict_params2)
    return temp_x

### q <- q - \eps \nabla U(x)  - \eps zeta q ###

def update_q_net(x_net, q_net, zeta, one):
    global train_loader_iter
    temp_q = copy.deepcopy(q_net)
    temp_x = copy.deepcopy(x_net)
    if one==False:
        opt_q = optim.SGD(temp_q.parameters(), lr=learning_rate/2,
                      momentum=0)
        opt_x = optim.SGD(temp_x.parameters(), lr=learning_rate/2,
                      momentum=0)
    else:
        opt_q = optim.SGD(temp_q.parameters(), lr=learning_rate,
                      momentum=0)
        opt_x = optim.SGD(temp_x.parameters(), lr=learning_rate,
                      momentum=0)
    try:
        data, target = train_loader_iter.__next__()
    except StopIteration:
        train_loader_iter = iter(train_loader)
        data, target = train_loader_iter.__next__()
    opt_x.zero_grad()
    output = temp_x(data)
    err_loss = mu*F.nll_loss(output, target)
    regularization_loss = 0
    for param in temp_x.parameters():
        regularization_loss += lamb*torch.sum((torch.abs(param)+1e-8)**0.5)
    loss = err_loss + regularization_loss;
    loss.backward()
    paras_x = temp_x.state_dict()
    paras_q = temp_q.state_dict()
    dict_params_x = dict(paras_x)
    dict_params_q = dict(paras_q)
    params_x = list(temp_x.parameters())
    i = 0
    for name1 in paras_x:
        dict_params_q[name1].data.copy_(dict_params_q[name1].data - zeta*dict_params_q[name1].data*learning_rate - params_x[i].grad*learning_rate)
        i = i + 1
    temp_q.load_state_dict(dict_params_q)
    x = list(temp_q.parameters())[0].data.flatten()
    y = list(q_net.parameters())[0].data.flatten()
    return temp_q


def leap(x_net, q_net, zeta):
    q_net1 = update_q_net(x_net, q_net, zeta, False)
    x_net1 = update_x_net(x_net, q_net1, m)
    for i in range(L-1):
        q_net1 = update_q_net(x_net1, q_net1, zeta, True)
        x_net1 = update_x_net(x_net1, q_net1, m)
    q_net1 = update_q_net(x_net1, q_net1, zeta, False)
    p2 = 0
    for param in q_net.parameters():
        p2 += torch.sum(torch.pow(param,2))
    p2 = p2.detach().numpy()
    zeta1 = zeta + learning_rate * 1/mu_m * (p2/m - d*T) # update zeta
    return x_net1,q_net1,zeta1


# calculate energy

def H(x_net, q_net):
    global train_loader_iter
    p2 = 0
    for param in q_net.parameters():
        p2 += torch.sum(torch.pow(param,2))
    p2 = p2.detach().numpy()/(2*m)
    try:
        data, target = train_loader_iter.__next__()
    except StopIteration:
        train_loader_iter = iter(train_loader)
        data, target = train_loader_iter.__next__()
    output = x_net(data)
    err_loss = mu*F.nll_loss(output, target)
    regularization_loss = 0
    for param in x_net.parameters():
        regularization_loss += lamb*torch.sum(torch.abs(param))
    loss = err_loss + regularization_loss;
    H = loss.item() + p2;
    #print(loss.item(),p2)
    return H

# Metropolis-Hastings (optional)

def mh(x_net, q_net, x_net1, q_net1):
    H0 = H(x_net, q_net)
    H1 = H(x_net1, q_net1)
    prob = np.e**(-(H0-H1)/T)
    print(H0,H1)
    acc = np.random.rand() < prob
    return acc


In [3]:
# count number of neurons smaller than wc in a network

def prune(x_net, wc):
    prune_num = 0
    temp_x = copy.deepcopy(x_net)
    for para in temp_x.parameters():
        t = para.data
        small_id = (t<wc)*(t>-wc)
        prune_num += torch.sum(small_id.flatten())
        t[small_id] = 0
    return prune_num, test(temp_x)


In [4]:
# Use SGNHT/QSGNHT to sample neural network

x_net = update_x_net(x_net, q_net, 1)
accs = []
prunnums = []
x_nets = []
q_nets = []
wc = 3e-2;

for i in range(epoch):
    m = pow(10,np.random.randn()*sigmam+logm0)
    if m<1:
        m = 1
    def init_weights(mat):
        if type(mat) == nn.Linear:
            torch.nn.init.normal_(mat.weight,mean=0.,std=np.sqrt(m*T))
            torch.nn.init.normal_(mat.bias,mean=0.,std=np.sqrt(m*T))
    q_net.apply(init_weights)
    [x_net1, q_net1, zeta1] = leap(x_net, q_net, zeta)
    acc = mh(x_net, q_net, x_net1, q_net1)
    if acc==1:
        x_net = x_net1
        q_net = q_net1
    prune_num, test_acc = prune(x_net, wc)
    prunnums.append(159010./(159010-prune_num.detach().numpy()))
    accs.append(test_acc)
    x_nets.append(x_net)
    q_nets.append(q_net)
    zeta = zeta1



2.593827327008454 3.083819209217722

Test set: Avg. loss: 1.6689, Accuracy: 5385/10000 (53%)

1.901601990116159 1.806777399815798

Test set: Avg. loss: 1.6689, Accuracy: 5385/10000 (53%)

1.78495164684067 2.315702568113223

Test set: Avg. loss: 1.3871, Accuracy: 6144/10000 (61%)

1.6109907229727283 1.3009442225652839

Test set: Avg. loss: 1.3871, Accuracy: 6144/10000 (61%)

1.6022579271167976 1.9631226894301494


KeyboardInterrupt: 