# Tensorflow.keras inference implementation of one-bit-per-weight CNN for CIFAR 10 
##  https://arxiv.org/abs/1802.08530
## M. D. McDonnell, 
## Training wide residual networks for deployment using a single bit for each weight
## ICLR, 2018

In [1]:
# select a GPU
# The CUDA_* variables are assigned before `import tensorflow` further down, so
# they are already in the environment when TensorFlow is loaded.
import os
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # expose only device 0 to this process

import numpy as np
import matplotlib.pyplot as plt  # NOTE(review): not referenced in the cells shown here
from scipy.io import loadmat  # reads the stored one-bit parameters (.mat file)
import warnings
# Hide FutureWarning messages; presumably aimed at those raised while importing TensorFlow -- confirm.
warnings.filterwarnings('ignore',category=FutureWarning)

import tensorflow
print('Tensorflow version = ',tensorflow.__version__)
from tensorflow.keras.optimizers import SGD
from ResNetModel import resnet  # project-local model builder

Tensorflow version =  1.13.1


In [2]:
# Notebook configuration: which dataset to evaluate, the network size handed to
# resnet(), and the directory holding the trained parameter files.
ModelsPath = "TrainedModels/Tensorflow.keras/"
WhichDataSet = "CIFAR10"  # set to "CIFAR100" to evaluate the CIFAR-100 models instead
resnet_depth = 20  # forwarded to resnet(..., depth=...)
resnet_width = 10  # forwarded to resnet(..., width=...)

In [3]:
# Load and prepare the test split; the training split returned by load_data()
# is discarded because this notebook only runs inference.
if WhichDataSet == 'CIFAR10':
    _, (x_test, y_test) = tensorflow.keras.datasets.cifar10.load_data()
elif WhichDataSet == 'CIFAR100':
    _, (x_test, y_test) = tensorflow.keras.datasets.cifar100.load_data()
else:
    # Reject typos explicitly instead of silently evaluating on CIFAR-100.
    raise ValueError("WhichDataSet must be 'CIFAR10' or 'CIFAR100', got %r" % (WhichDataSet,))
num_classes = np.unique(y_test).shape[0]  # number of distinct labels present in the test set
input_shape = x_test.shape[1:]  # per-sample shape: everything after the leading axis
# Pixel values are only cast to float32 here; no rescaling is applied in this cell.
x_test = x_test.astype('float32')
# Convert the integer labels to one-hot rows of length num_classes.
y_test = tensorflow.keras.utils.to_categorical(y_test, num_classes)


In [4]:
# Build the inference network out of regular conv2d layers (UseBinaryWeights=False).
inference_model = resnet(
    UseBinaryWeights=False,
    input_shape=input_shape,
    depth=resnet_depth,
    num_classes=num_classes,
    wd=0.0,
    width=resnet_width,
)
# The model needs to be compiled, which forces us to pick a loss and an optimizer
# even though the optimizer will never be used for inference.
inference_model.compile(
    loss='categorical_crossentropy',
    optimizer=SGD(lr=0.0, decay=0.0, momentum=0.9, nesterov=False),
    metrics=['accuracy'],
)


Instructions for updating:
Colocations handled automatically by placer.


In [5]:
# Load the stored one-bit parameters into inference_model, verify that every conv
# layer really holds at most two distinct weight values, then report test accuracy.
AllParamsDict_loaded = loadmat(ModelsPath + WhichDataSet + '_allparams.mat')
# Stored entries are paired with model layers purely by position.
# NOTE(review): this relies on loadmat returning the keys in the same order as
# inference_model.layers -- confirm against the script that wrote the .mat file.
conv_names = [key for key in AllParamsDict_loaded if 'conv2d' in key]
bn_names = [key for key in AllParamsDict_loaded if 'batch' in key]

conv_layers = [layer for layer in inference_model.layers if 'conv2d' in layer.name]
bn_layers = [layer for layer in inference_model.layers
             if 'conv2d' not in layer.name and 'batch_normalization' in layer.name]

# Fail before touching any weights if the file and the model disagree; otherwise a
# mismatch would surface as a bare IndexError mid-way or be silently ignored.
if len(conv_layers) != len(conv_names) or len(bn_layers) != len(bn_names):
    raise ValueError(
        'Stored parameters do not match the model: file has %d conv / %d batch-norm '
        'entries, model has %d conv / %d batch-norm layers'
        % (len(conv_names), len(bn_names), len(conv_layers), len(bn_layers)))

for conv_idx, (layer, key) in enumerate(zip(conv_layers, conv_names)):
    # Map the stored values v to 2*v - 1 (presumably {0, 1} -> {-1, +1}).
    signs = AllParamsDict_loaded[key].astype('float32') * 2.0 - 1.0
    # One scalar per layer: sqrt(2 / product of the first three kernel dimensions).
    ww = signs * np.sqrt(2.0 / np.prod(signs[0].shape[0:3]))
    uw = np.unique(ww)
    # Actually enforce the "only two values" claim instead of just printing the count.
    if len(uw) > 2:
        raise ValueError('conv entry %r holds %d distinct values; expected at most 2'
                         % (key, len(uw)))
    # The loaded array carries a leading axis; element 0 is the kernel itself.
    layer.set_weights([ww[0]])
    print('conv layer ',conv_idx,' has ', len(uw),' unique weights')

for layer, key in zip(bn_layers, bn_names):
    # Passed through unchanged; presumably one row per batch-norm parameter array -- confirm.
    layer.set_weights(AllParamsDict_loaded[key])

#get accuracy:
y_pred = inference_model.predict(x_test)
# Vectorised count of matching argmax labels (same arithmetic as before, so the
# printed percentage is unchanged).
num_correct = np.sum(np.argmax(y_pred, -1) == np.argmax(y_test, -1))
print('One-bit-per-weight Test accuracy (%):', 100*num_correct/y_test.shape[0])

conv layer  0  has  2  unique weights
conv layer  1  has  2  unique weights
conv layer  2  has  2  unique weights
conv layer  3  has  2  unique weights
conv layer  4  has  2  unique weights
conv layer  5  has  2  unique weights
conv layer  6  has  2  unique weights
conv layer  7  has  2  unique weights
conv layer  8  has  2  unique weights
conv layer  9  has  2  unique weights
conv layer  10  has  2  unique weights
conv layer  11  has  2  unique weights
conv layer  12  has  2  unique weights
conv layer  13  has  2  unique weights
conv layer  14  has  2  unique weights
conv layer  15  has  2  unique weights
conv layer  16  has  2  unique weights
conv layer  17  has  2  unique weights
conv layer  18  has  2  unique weights
conv layer  19  has  2  unique weights
One-bit-per-weight Test accuracy (%): 96.25


In [6]:
# Baseline: rebuild the same architecture, load the 32-bit trained weights,
# and measure test accuracy for comparison with the one-bit result.
inference_model = resnet(
    UseBinaryWeights=False,
    input_shape=input_shape,
    depth=resnet_depth,
    num_classes=num_classes,
    wd=0.0,
    width=resnet_width,
)
inference_model.compile(
    loss='categorical_crossentropy',
    optimizer=SGD(lr=0.0, decay=0.0, momentum=0.9, nesterov=False),
    metrics=['accuracy'],
)
inference_model.load_weights(ModelsPath + 'Final_weights_' + WhichDataSet + '_32bit_model_v2.h5')
y_pred = inference_model.predict(x_test)
# Percentage of test samples whose predicted class equals the one-hot label's class.
print(
    'Full Precision Test accuracy (%):',
    100 * np.sum(np.argmax(y_pred, -1) == np.argmax(y_test, -1)) / y_test.shape[0],
)

Full Precision Test accuracy (%): 96.65
