In [None]:
from google.colab import files

# Opens Colab's browser upload dialog. mnist2_train.txt and mnist10_train.txt must be
# uploaded into the working directory because train() opens them by name.
# Bound to `uploaded` rather than `_`: assigning `_` yourself stops IPython from
# refreshing `_` with the previous cell's output for the rest of the session.
uploaded = files.upload()

In [None]:
!pip install larq

In [4]:
import numpy as np
from tqdm import tqdm
import tensorflow as tf
import larq as lq

In [35]:
def _load_dataset(path, bits):
    """Load integer samples from a whitespace-separated text file.

    Each non-blank line must contain ``bits`` integer features followed by one
    integer class label.

    Returns:
        (X, y): integer numpy arrays of shape (n_samples, bits) and (n_samples,).

    Raises:
        ValueError: if a line has the wrong number of values or the file has no samples.
    """
    features, labels = [], []
    with open(path, 'r') as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:  # skip blank lines (e.g. a trailing empty line) instead of crashing
                continue
            if len(fields) != bits + 1:
                raise ValueError(
                    f'{path}:{lineno}: expected {bits + 1} values ({bits} features + label), '
                    f'got {len(fields)}'
                )
            *row, label = map(int, fields)
            features.append(row)
            labels.append(label)
    if not features:
        raise ValueError(f'{path}: no samples found')
    return np.array(features), np.array(labels)


def _weights_snippet(kernel_signs):
    """Format a binarized kernel as a one-line Python snippet that prints it.

    Args:
        kernel_signs: array-like of shape (bits, units) holding +1/-1 values.

    Returns:
        ``"for r in((...),(...)):print(*r)"`` with one tuple of ``bits`` signs per unit,
        i.e. the kernel transposed so each printed row is one neuron's weights.
    """
    rows = np.asarray(kernel_signs).astype(int).T.tolist()
    nested = str(rows).replace(' ', '').replace('[', '(').replace(']', ')')
    return f"for r in{nested}:print(*r)"


def train(BITS, CLASSES, validation_split=0.3, epochs=25, learning_rate=1e-4, patience=3,
          neurons_per_class=15, monitor='accuracy', restore_best_weights=False, **kwargs):
    """Train a binarized one-layer classifier and print its sign weights as Python code.

    Reads ``mnist{CLASSES}_train.txt`` from the working directory. Each non-blank
    line holds ``BITS`` integer features followed by an integer class label.

    Model: a larq ``QuantDense`` layer with sign-binarized inputs and kernel and
    ``neurons_per_class * CLASSES`` units. It is followed by a frozen Dense layer
    that sums each class's group of units, then BatchNormalization (no scale) and
    softmax.

    After training, the binarized kernel is printed as a one-line snippet
    ``for r in(...):print(*r)`` containing one tuple of ``BITS`` signs per unit.

    Args:
        BITS: number of input features per sample.
        CLASSES: number of classes; also selects the training file.
        validation_split: fraction of the data held out by ``model.fit``.
        epochs: maximum number of training epochs.
        learning_rate: Adam learning rate.
        patience: EarlyStopping patience, in epochs without improvement of ``monitor``.
        neurons_per_class: binarized units summed per class (default 15).
        monitor: metric watched by EarlyStopping. The default, training
            'accuracy', is the only option when ``validation_split=0``; consider
            'val_accuracy' when a validation split is used.
        restore_best_weights: passed to EarlyStopping. When True, the best
            epoch's weights are restored before the kernel is exported. When
            False (the default), the final epoch's weights are exported.
        **kwargs: forwarded to ``model.fit`` (e.g. ``steps_per_epoch``).

    Raises:
        ValueError: if the data file has no samples or a row does not have ``BITS + 1`` values.
    """
    X, y = _load_dataset(f'mnist{CLASSES}_train.txt', BITS)

    def kinit(shape, dtype=None):
        # Fixed 0/1 matrix: unit i feeds class i // neurons_per_class, so the frozen
        # Dense layer below sums each class's group of binarized activations.
        arr = np.zeros(shape)
        for i in range(shape[0]):
            arr[i][i // neurons_per_class] = 1
        return tf.convert_to_tensor(arr, dtype=dtype)

    model = tf.keras.models.Sequential()
    model.add(
        lq.layers.QuantDense(
            neurons_per_class * CLASSES,
            input_quantizer=lq.quantizers.SteSign(clip_value=1.0),
            kernel_quantizer=lq.quantizers.SteSign(clip_value=1.0),
            kernel_constraint=lq.constraints.WeightClip(clip_value=1),
            input_shape=(BITS,),
            use_bias=False,
        )
    )
    model.add(tf.keras.layers.Dense(CLASSES, kernel_initializer=kinit, use_bias=False, trainable=False))
    model.add(tf.keras.layers.BatchNormalization(scale=False))
    model.add(tf.keras.layers.Activation("softmax"))
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor=monitor, patience=patience, restore_best_weights=restore_best_weights)
    model.fit(X, y, validation_split=validation_split, epochs=epochs, callbacks=[early_stop], **kwargs)

    # model.weights[0] is the QuantDense layer's latent real-valued kernel, shape (BITS, units).
    # lq.math.sign binarizes it to +1/-1 (larq's sign maps 0 to +1).
    print(_weights_snippet(lq.math.sign(model.weights[0])))

In [41]:
# MNIST 10 class: 51-bit inputs, no validation split, early stopping on training accuracy.
train(
    BITS=51,
    CLASSES=10,
    validation_split=0.0,
    epochs=250,
    learning_rate=9e-4,
    patience=17,
    steps_per_epoch=800,  # not a train() parameter; forwarded to model.fit via **kwargs
)

Epoch 1/250
Epoch 2/250
Epoch 3/250
Epoch 4/250
Epoch 5/250
Epoch 6/250
Epoch 7/250
Epoch 8/250
Epoch 9/250
Epoch 10/250
Epoch 11/250
Epoch 12/250
Epoch 13/250
Epoch 14/250
Epoch 15/250
Epoch 16/250
Epoch 17/250
Epoch 18/250
Epoch 19/250
Epoch 20/250
Epoch 21/250
Epoch 22/250
Epoch 23/250
Epoch 24/250
Epoch 25/250
Epoch 26/250
Epoch 27/250
Epoch 28/250
Epoch 29/250
for r in((1,1,1,-1,1,-1,-1,1,1,1,1,1,-1,-1,1,-1,-1,1,1,-1,1,1,1,1,-1,1,-1,1,1,-1,1,-1,1,-1,1,-1,-1,1,1,-1,1,-1,-1,-1,-1,1,-1,-1,1,1,-1),(1,1,1,-1,1,-1,-1,-1,1,1,1,1,-1,-1,1,1,-1,1,1,-1,1,1,1,-1,1,1,1,1,-1,-1,1,-1,1,1,1,-1,1,1,-1,1,1,-1,-1,1,1,-1,-1,1,-1,-1,-1),(-1,-1,1,-1,1,-1,-1,-1,1,1,-1,-1,1,1,1,1,1,-1,1,-1,1,1,1,-1,-1,1,1,1,-1,-1,-1,1,1,1,-1,1,1,1,1,-1,1,1,1,1,1,-1,-1,-1,1,1,1),(1,1,-1,-1,1,-1,-1,-1,1,-1,1,1,-1,-1,-1,-1,1,-1,1,-1,1,1,1,1,-1,1,1,-1,-1,-1,-1,1,1,1,1,1,1,1,-1,-1,1,1,1,1,1,-1,-1,-1,1,1,-1),(1,1,1,-1,-1,-1,-1,-1,1,1,1,1,-1,-1,1,-1,1,-1,1,-1,1,1,-1,-1,1,1,-1,-1,1,-1,1,-1,1,1,-1,-1,1,1,-1,-1,1,1,1,-1,-1,1,-1,-1

In [25]:
# MNIST 2 class: 51-bit inputs, no validation split, early stopping on training accuracy.
train(
    BITS=51,
    CLASSES=2,
    validation_split=0.0,
    epochs=250,
    learning_rate=1e-4,
    patience=10,
)

Epoch 1/250
Epoch 2/250
Epoch 3/250
Epoch 4/250
Epoch 5/250
Epoch 6/250
Epoch 7/250
Epoch 8/250
Epoch 9/250
Epoch 10/250
Epoch 11/250
Epoch 12/250
Epoch 13/250
Epoch 14/250
Epoch 15/250
Epoch 16/250
Epoch 17/250
Epoch 18/250
Epoch 19/250
Epoch 20/250
Epoch 21/250
Epoch 22/250
Epoch 23/250
Epoch 24/250
Epoch 25/250
Epoch 26/250
Epoch 27/250
Epoch 28/250
Epoch 29/250
Epoch 30/250
Epoch 31/250
Epoch 32/250
Epoch 33/250
Epoch 34/250
Epoch 35/250
for r in((1,-1,1,-1,1,-1,-1,-1,1,-1,-1,1,1,-1,-1,1,-1,1,1,-1,-1,-1,-1,1,1,-1,1,1,-1,-1,1,-1,-1,1,1,1,-1,1,1,-1,1,-1,-1,-1,1,1,1,1,-1,-1,-1),(1,-1,-1,-1,1,-1,-1,-1,1,1,1,1,1,-1,1,1,-1,1,1,-1,1,-1,-1,1,1,-1,1,-1,-1,-1,-1,1,-1,-1,-1,-1,1,1,1,1,-1,1,1,-1,1,-1,-1,-1,1,-1,-1),(1,1,1,-1,1,-1,-1,1,1,-1,-1,-1,1,1,-1,1,1,1,1,-1,1,1,1,-1,-1,1,-1,1,-1,1,-1,-1,1,1,1,-1,1,1,-1,-1,1,-1,-1,1,1,1,-1,-1,-1,1,-1),(1,1,1,-1,1,-1,-1,1,1,-1,-1,-1,-1,-1,1,-1,1,-1,1,-1,-1,1,1,-1,1,1,-1,1,1,1,1,-1,1,-1,-1,1,-1,-1,1,-1,1,-1,-1,1,1,-1,-1,1,-1,1,1),(-1,1,-1,-1,1,-1,-1,-1,-1,1