In [1]:
import os
import sys

# Make ../src (resolved against the notebook's current working directory)
# importable, presumably so the project-local `bitwise` package imported in a
# later cell can be found -- TODO confirm that is where it lives.
sys.path.append(os.path.abspath("../src"))

In [2]:
import numpy as np
from pathlib import Path
import torch


def read_mnist_images(filename: Path, device: str = "cpu") -> torch.Tensor:
    """Load an MNIST image file (IDX3 format) into a ``uint8`` tensor.

    The file starts with a 16-byte header of four big-endian uint32 words:
    magic number (2051 == 0x00000803), image count, rows, columns. The header
    is followed by one unsigned byte per pixel, image after image, row-major.

    Args:
        filename: Path to an ``*-images.idx3-ubyte`` file.
        device: Torch device for the returned tensor.

    Returns:
        ``uint8`` tensor of shape ``(num_images, rows, cols)``.

    Raises:
        ValueError: if the header is truncated, the magic number is not the
            IDX3 unsigned-byte image magic, or the pixel payload size does not
            match the dimensions declared in the header.
    """
    with open(filename, "rb") as f:
        header = f.read(16)
        if len(header) != 16:
            raise ValueError(f"{filename}: truncated IDX3 header ({len(header)} bytes)")
        # Header words are big-endian; convert to Python ints so later size
        # arithmetic cannot wrap around in uint32.
        magic, num_images, rows, cols = (
            int(word) for word in np.frombuffer(header, dtype=">u4")
        )
        if magic != 2051:
            raise ValueError(
                f"{filename}: magic number {magic:#010x} is not an IDX3 image file "
                "(expected 0x00000803)"
            )
        pixels = np.frombuffer(f.read(), dtype=np.uint8)
    expected = num_images * rows * cols
    if pixels.size != expected:
        raise ValueError(
            f"{filename}: expected {expected} pixel bytes, found {pixels.size}"
        )
    # torch.tensor copies, so the returned tensor does not alias the read-only
    # buffer produced by np.frombuffer.
    return torch.tensor(pixels.reshape(num_images, rows, cols), device=device)


def read_mnist_labels(filename: Path, device: str = "cpu") -> torch.Tensor:
    """Load an MNIST label file (IDX1 format) into a ``uint8`` tensor.

    The file starts with an 8-byte header of two big-endian uint32 words:
    magic number (2049 == 0x00000801) and label count. The header is followed
    by one unsigned byte per label.

    Args:
        filename: Path to an ``*-labels.idx1-ubyte`` file.
        device: Torch device for the returned tensor. The default was
            previously the invalid device string ``"str"``; it now matches
            ``read_mnist_images``.

    Returns:
        ``uint8`` tensor of shape ``(num_labels,)``.

    Raises:
        ValueError: if the header is truncated, the magic number is not the
            IDX1 unsigned-byte label magic, or the number of label bytes does
            not match the header's count.
    """
    with open(filename, "rb") as f:
        header = f.read(8)
        if len(header) != 8:
            raise ValueError(f"{filename}: truncated IDX1 header ({len(header)} bytes)")
        magic, num_labels = (int(word) for word in np.frombuffer(header, dtype=">u4"))
        if magic != 2049:
            raise ValueError(
                f"{filename}: magic number {magic:#010x} is not an IDX1 label file "
                "(expected 0x00000801)"
            )
        labels = np.frombuffer(f.read(), dtype=np.uint8)  # one byte per label
    if labels.size != num_labels:
        raise ValueError(
            f"{filename}: header declares {num_labels} labels, found {labels.size}"
        )
    return torch.tensor(labels, device=device)


# Directory holding the raw MNIST IDX files. Override it with the MNIST_DIR
# environment variable instead of editing this hardcoded local path.
mnist = Path(os.environ.get("MNIST_DIR", Path.home() / "Downloads" / "mnist"))
# Prefer MPS, but fall back to CPU so the notebook still runs where the MPS
# backend is unavailable (the original hardcoded "mps" and crashed there).
device = "mps" if torch.backends.mps.is_available() else "cpu"
train_images = read_mnist_images(mnist / "train-images.idx3-ubyte", device=device)
train_labels = read_mnist_labels(mnist / "train-labels.idx1-ubyte", device=device)
test_images = read_mnist_images(mnist / "t10k-images.idx3-ubyte", device=device)
test_labels = read_mnist_labels(mnist / "t10k-labels.idx1-ubyte", device=device)


In [3]:
import bitwise
from bitwise import bp
import torch
from typing import List


class DigitRecognizer:
    """A stack of ``bp.Layer`` objects run in order forward and in reverse for updates.

    Images are fed in as packed int32 words (4 uint8 pixels per word). Each
    layer's ``eval`` output becomes the next layer's input.
    """

    _layers: List[bp.Layer]

    def __init__(self, layers: List[bp.Layer], device="cpu"):
        # `device` is accepted for call-site symmetry with
        # random_digit_recognizer but is currently unused.
        self._layers = layers

    def eval(self, images: torch.Tensor) -> bitwise.Tensor:
        """Run a batch of images through every layer.

        Args:
            images: ``uint8`` tensor of shape ``(batch, rows, cols)``. The last
                dimension must be contiguous and divisible by 4 (required by
                ``Tensor.view(torch.int32)``).

        Returns:
            Output of the last layer, or the packed input of shape
            ``(batch, 1, rows * cols // 4)`` when there are no layers.
        """
        batch_size = images.shape[0]
        # Reinterpret every 4 pixel bytes as one int32 word (native byte order),
        # then flatten each image into a single row of words. For 28x28 MNIST
        # images this is 196 words, the count the original code hardcoded.
        # The per-image size is taken from the shape rather than inferred
        # with -1, so an empty batch still reshapes cleanly.
        words = images.view(torch.int32)
        outputs = words.reshape(batch_size, 1, words.shape[1:].numel())
        for layer in self._layers:
            outputs = layer.eval(outputs)
        return outputs

    def update(self, errors: bitwise.Tensor):
        """Propagate errors backwards through the layers, last layer first.

        Each layer's ``update`` receives the errors returned by the layer
        after it. The errors returned by the first layer are discarded.
        """
        for layer in reversed(self._layers):
            errors = layer.update(errors)


def random_digit_recognizer(
    device="cpu", dims=(196, 128, 64, 32, 10), generator=None
) -> DigitRecognizer:
    """Build a DigitRecognizer whose layers have uniformly random bit weights.

    Args:
        device: Torch device for the weight and bias tensors.
        dims: Layer boundary sizes in int32 words. Consecutive pairs give each
            layer's ``(ins, outs)``. The first entry must equal the number of
            words ``DigitRecognizer.eval`` produces per image (196 for 28x28
            uint8 images). The last entry is the number of output classes.
        generator: Optional ``torch.Generator`` (on ``device``) for
            reproducible weights. ``None`` uses the global RNG as before.

    Returns:
        A DigitRecognizer with one trainable ``bp.Layer`` per pair in ``dims``.
    """
    layers = []
    for ins, outs in zip(dims, dims[1:]):
        # randint draws int64 values in [0, 2**32). Casting to int32 wraps values
        # >= 2**31 to negatives, giving arbitrary 32-bit patterns.
        # Weights have shape (outs * 32, ins); presumably one row of `ins` words
        # per output bit -- TODO confirm against bp.Layer.
        weights = torch.randint(
            0, 2**32, (outs * 32, ins), device=device, generator=generator
        ).to(dtype=torch.int32)
        bias = torch.randint(
            0, 2**32, (1, outs), device=device, generator=generator
        ).to(dtype=torch.int32)
        layers.append(bp.Layer(weights, bias, train=True))

    return DigitRecognizer(layers, device=device)


# Build a freshly randomized network on the device used for the loaded data.
# No seed is set anywhere in this notebook, so the initial weights (and
# therefore the loss printed below) differ on every run.
digit_recognizer = random_digit_recognizer(device)

In [4]:
import bitwise
import torch
from torch.nn import functional as F


def class_probs(outputs: bitwise.Tensor) -> torch.Tensor:
    """Fraction of set bits in each output word, as a float32 score in [0, 1].

    Selects ``outputs[:, 0]`` and counts the set bits of every int32 element,
    dividing by 32. The input is presumably shaped ``(batch, 1, num_classes)``
    -- TODO confirm. The scores are not normalized across classes and need not
    sum to 1.
    """
    words = outputs[:, 0]
    # Bit-parallel (Hamming weight) popcount. Fold the counts into 2-bit, then
    # 4-bit, then 8-bit fields, then sum the four byte counts into the low byte.
    # Arithmetic right shifts of negative int32 are harmless: the masks in the
    # first two steps discard the sign-filled bits, and every value is
    # non-negative after them.
    counts = words - ((words >> 1) & 0x55555555)
    counts = ((counts >> 2) & 0x33333333) + (counts & 0x33333333)
    counts = (counts + (counts >> 4)) & 0x0F0F0F0F
    for shift in (8, 16):
        counts = counts + (counts >> shift)
    # At most 32 bits can be set, which fits in the low 6 bits.
    set_bits = counts & 0x3F
    return set_bits.to(torch.float32) / 32.0


def class_loss(outputs: bitwise.Tensor, labels: torch.Tensor) -> float:
    """Mean cross-entropy between the popcount class scores and the labels.

    ``class_probs`` yields per-class set-bit fractions that do not sum to 1.
    ``F.cross_entropy`` applies log-softmax to its input, so feeding it
    ``log(p)`` normalizes them: softmax(log p) == p / p.sum().

    A true class with no set bits gives an infinite loss, and a sample whose
    classes all have zero set bits gives NaN (every log is -inf).

    Args:
        outputs: Network outputs as accepted by ``class_probs``.
        labels: Class indices of shape ``(batch,)``, in any integer dtype.

    Returns:
        The batch-mean loss as a Python float.
    """
    scores = class_probs(outputs).log()
    # Class-index targets should be int64. read_mnist_labels produces uint8,
    # so cast here, as output_errors already does.
    return F.cross_entropy(scores, labels.to(dtype=torch.long)).item()


def output_errors(
    outputs: bitwise.Tensor, labels: torch.Tensor, num_classes: int = 10
) -> bitwise.Tensor:
    """Bitwise error between the network outputs and one-hot target words.

    Each sample's target has all 32 bits set (-1) in the word of its true class
    and all bits clear (0) elsewhere. XOR with ``outputs`` sets every bit that
    disagrees with that target.

    Args:
        outputs: int32 output words, broadcast-compatible with
            ``(batch, 1, num_classes)`` (assumed layout -- TODO confirm).
        labels: Class indices of shape ``(batch,)``, in any integer dtype.
        num_classes: Number of output words per sample. The default of 10
            matches the MNIST digits and the previously hardcoded value.

    Returns:
        int32 tensor of shape ``(batch, 1, num_classes)``. Nonzero bits are errors.
    """
    targets = F.one_hot(labels.to(dtype=torch.long), num_classes=num_classes)
    # Map 0/1 to 0/-1. In two's complement, -1 is an int32 with every bit set.
    # The in-place ops are safe because `targets` is a fresh tensor.
    targets = targets.to(dtype=torch.int32).unsqueeze_(1).neg_()
    return targets.bitwise_xor_(outputs)

In [30]:
# One training step on the first mini-batch: forward pass, print the loss,
# derive per-bit output errors, then propagate them back through the layers.
# NOTE(review): the execution count In [30] suggests this cell was re-run out of
# order. update() presumably mutates the layers, so the state of
# `digit_recognizer` (and the printed loss) depends on how many times it ran.
batch_size = 16
outputs = digit_recognizer.eval(train_images[:batch_size])
print(class_loss(outputs, train_labels[:batch_size]))
errors = output_errors(outputs, train_labels[:batch_size])
digit_recognizer.update(errors)

2.3084163665771484
