In [1]:
import os

import numpy as np

import npnet as tn
import ocr_data

In [2]:
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt

# Notebook-wide plot styling: a large default figure, inward-pointing ticks
# mirrored onto all four axes with minor ticks visible, big serif text, and
# end caps on error bars. Keys are applied in the order listed.
matplotlib.rcParams.update({
    'figure.figsize': [10, 7],
    # Tick layout: draw ticks on top/right too, pointing into the plot.
    'xtick.top': True,
    'xtick.direction': 'in',
    'xtick.minor.visible': True,
    'ytick.right': True,
    'ytick.direction': 'in',
    'ytick.minor.visible': True,
    # Typography.
    'font.size': 19,
    'font.family': 'DejaVu Serif',
    'mathtext.default': 'regular',
    # Error bars.
    'errorbar.capsize': 3,
})

In [3]:
# Network architecture. Each entry of the three conv_* lists describes one
# convolutional layer; the lists are consumed together with zip() when the
# network is built, so they must all have the same length.
input_shape = ocr_data.in_2d_shape
# Deeper alternative configuration (kept for reference):
#   conv_kernels [(3,3),(2,2),(3,3),(2,2),(3,3),(2,2),(3,3),(2,2),(2,2),(2,2)]
#   conv_stride  [(1,1),(2,2),(1,1),(2,2),(1,1),(2,2),(1,1),(2,2),(1,1),(2,2)]
#   conv_outs    [(4,),(8,),(8,),(12,),(12,),(16,),(16,),(24,),(24,),(32,)]
conv_kernels = [(5,5)]   # kernel shape per conv layer
conv_stride = [(1,1)]    # kernel stride per conv layer
conv_outs = [(5,)]       # output channel count per conv layer, as a shape tuple
conv_layers = len(conv_kernels)
# zip() truncates to the shortest list, so a mismatch here would silently
# drop conv layers instead of raising an error.
assert len(conv_stride) == conv_layers and len(conv_outs) == conv_layers, \
    'conv_kernels, conv_stride and conv_outs must have the same length'
hidden_shapes = []       # shapes of hidden Dense layers placed after the convs
hidden_layers = len(hidden_shapes)
output_shape = ocr_data.out_shape
print(input_shape,output_shape)

(32, 32) (26,)


In [4]:
# Build the graph Input -> Conv* -> Dense* (hidden) -> Dense (output).
# Each layer object is called on the previous layer to chain it, and every
# layer is printed so its input -> output shape signature is visible.
in_layer = tn.Input(input_shape)()
last_layer = in_layer
print(last_layer)

for k_shape, n_out, stride in zip(conv_kernels, conv_outs, conv_stride):
    conv = tn.Conv(k_shape, out_shape=n_out, pad=True, kernel_stride=stride,
                   activation=tn.ReLU())
    last_layer = conv(last_layer)
    print(last_layer)

for h_shape in hidden_shapes:
    dense = tn.Dense(h_shape, activation=tn.ReLU())
    last_layer = dense(last_layer)
    print(last_layer)

out_layer = tn.Dense(output_shape, activation=tn.ReLU())(last_layer)
print(out_layer)

s = tn.System(inputs=[in_layer], outputs=[out_layer])

Input :: [] -> [(32, 32)]
Conv :: [(32, 32)] -> [(32, 32, 5)]
Dense :: [(32, 32, 5)] -> [(26,)]
Conv :: [(32, 32)] -> [(32, 32, 5)] idx: 1 => Dense :: [(32, 32, 5)] -> [(26,)] idx: 0
Input :: [] -> [(32, 32)] idx: 2 => Conv :: [(32, 32)] -> [(32, 32, 5)] idx: 1


In [5]:
import multiprocessing
import functools

def batch(length=1000, scale=1e-2, loss='quad'):
    """Run one online-training pass of the global network ``s``.

    For every ``(true_out, sample)`` pair yielded by
    ``ocr_data.tagged_2d_data(length)``, the network first guesses. The guess
    is counted as a failure when its argmax differs from the label's argmax.
    ``s.learn`` is then called with the state of that same guess. Accuracy is
    therefore measured on each sample *before* the network trains on it.

    Args:
        length: number of samples requested from ``ocr_data.tagged_2d_data``.
        scale: passed as ``scale`` to ``s.learn`` (presumably the learning
            rate -- confirm against npnet).
        loss: loss name passed to ``s.learn``.

    Returns:
        ``(total, failures)``: samples processed and how many were
        misclassified.
    """
    total, failures = 0, 0
    # `sample` rather than `input`, to avoid shadowing the builtin.
    for true_out, sample in ocr_data.tagged_2d_data(length):
        guess_out, state = s.guess([sample], return_state=True)
        if np.argmax(guess_out[0]) != np.argmax(true_out):
            failures += 1
        # Train on the forward-pass state that was just scored.
        s.learn(state, [true_out], scale=scale, loss=loss)
        total += 1
    return total, failures

In [6]:
# Set to True to resume from the checkpoint written by the training loop;
# otherwise the network keeps its freshly initialised weights.
load_saved_weights = False
if load_saved_weights:
    print('loading weights')
    s.load_weights('OCR_conv_network.h5')
else:
    print('not loading weights; starting from initial weights')

loading weights


In [7]:
# Profile 100 training samples. This is not a dry run: batch() calls s.learn,
# so running this cell also updates the network's weights.
%prun batch(100)

 

In [None]:
# Train batch after batch, checkpointing after each one, until a batch's
# accuracy exceeds target_accuracy. Interrupt the kernel to stop early.
weights_path = 'OCR_conv_network.h5'
partial_weights_path = 'OCR_conv_network.partial.h5'
target_accuracy = 0.999
try:
    batch_size = 1000
    while True:
        cases,failures = batch(batch_size)
        if cases == 0:
            # Without this guard the accuracy check below divides by zero.
            raise RuntimeError('batch() processed no samples; is ocr_data.tagged_2d_data empty?')
        print('saving weights')
        # Save to a scratch file, then swap it over the real checkpoint with
        # os.replace (atomic on POSIX). Interrupting mid-save -- the intended
        # way to stop this loop -- then cannot leave a truncated checkpoint.
        s.save_weights(partial_weights_path)
        os.replace(partial_weights_path, weights_path)
        print('batch accuracy',cases-failures,'/',cases)
        if (cases-failures)/cases > target_accuracy:
            break
except KeyboardInterrupt:
    print('Stopped by user')

saving weights
batch accuracy 283 / 1000
saving weights
batch accuracy 560 / 1000
saving weights
batch accuracy 637 / 1000


In [None]:
# Visualise the output Dense layer's weights: one image per (letter, conv
# channel) pair. s.parts[0] is assumed to be the output Dense layer (idx 0 in
# the System printout). Reshaping to input_shape + conv_outs[-1] matches the
# Dense input only because the conv keeps the spatial size (pad=True, stride 1).
dense_weights = s.parts[0].layer[0].weights
n_channels = np.prod(conv_outs[-1])
# ord('Z') + 1 so that 'Z' is included: the output layer has 26 neurons.
for i, c in enumerate(range(ord('A'), ord('Z') + 1)):
    letter_weights = dense_weights[i].reshape(input_shape + conv_outs[-1])
    for j in range(n_channels):
        print('neuron',chr(c),'layer',j)
        fig, ax = plt.subplots()
        im = ax.imshow(letter_weights[:, :, j])
        ax.set_title('output {}, conv channel {}'.format(chr(c), j))
        fig.colorbar(im, ax=ax)
        plt.show()
        plt.close(fig)

In [None]:
# Visualise the learned convolution kernels. The System printout lists parts
# from the output (idx 0) back to the Input (last idx), so conv layer i
# (counted from the input) appears to be s.parts[-i-2].
# Only the first conv layer is plotted: its weights are reshaped to
# (out_channels, *kernel_shape), which presumably does not fit deeper layers
# that also have input channels -- confirm against npnet before extending.
for i in [0]:
    conv_inst = s.parts[-i-2]
    print(conv_inst)
    conv_net = conv_inst.layer[0]
    print(conv_net.weights.shape)
    # Index conv_outs with i (not -i) so it describes the same layer as
    # conv_kernels[i]; the two only coincided for i == 0.
    shape = conv_outs[i]+conv_kernels[i]
    weights = conv_net.weights.reshape(shape)
    for j in range(conv_outs[i][0]):
        print('conv kernel',j)
        fig, ax = plt.subplots()
        im = ax.imshow(weights[j])
        ax.set_title('conv layer {}, kernel {}'.format(i, j))
        fig.colorbar(im, ax=ax)
        plt.show()
        plt.close(fig)