In [47]:
import numpy as np
from pathlib import Path
import torch


def read_mnist_images(filename: Path, device: str = "cpu") -> torch.Tensor:
    """Load an MNIST IDX image file into a ``(num_images, rows, cols)`` uint8 tensor.

    The file starts with a 16-byte header of four big-endian uint32 values:
    magic number (2051 for image files), image count, row count, column count.
    The pixel bytes follow, one uint8 per pixel.

    Args:
        filename: Path to an ``*-images.idx3-ubyte`` file.
        device: Torch device the returned tensor is placed on.

    Returns:
        A uint8 tensor of shape ``(num_images, rows, cols)`` on ``device``.

    Raises:
        ValueError: If the header is truncated, the magic number is not 2051,
            or the file holds fewer pixel bytes than the header announces.
    """
    with open(filename, "rb") as f:
        header = f.read(16)
        if len(header) != 16:
            raise ValueError(f"{filename}: truncated IDX header ({len(header)} of 16 bytes)")
        # .tolist() yields Python ints, so the size product below cannot overflow uint32.
        magic, num_images, rows, cols = np.frombuffer(header, dtype=">u4").tolist()  # Big-endian
        if magic != 2051:
            raise ValueError(f"{filename}: expected IDX image magic 2051, got {magic}")
        count = num_images * rows * cols
        pixels = f.read(count)
    if len(pixels) != count:
        raise ValueError(f"{filename}: expected {count} pixel bytes, found {len(pixels)}")
    images = np.frombuffer(pixels, dtype=np.uint8).reshape(num_images, rows, cols)
    # torch.tensor copies, so the read-only frombuffer view is never exposed to torch.
    return torch.tensor(images, device=device)


def read_mnist_labels(filename: Path, device: str = "cpu") -> torch.Tensor:
    """Load an MNIST IDX label file into a 1-D uint8 tensor.

    The file starts with an 8-byte header of two big-endian uint32 values:
    magic number (2049 for label files) and label count. The labels follow,
    one uint8 each.

    Args:
        filename: Path to a ``*-labels.idx1-ubyte`` file.
        device: Torch device the returned tensor is placed on. The default was
            previously the invalid string ``"str"``; it is now ``"cpu"`` to
            match ``read_mnist_images``.

    Returns:
        A uint8 tensor of shape ``(num_labels,)`` on ``device``.

    Raises:
        ValueError: If the header is truncated, the magic number is not 2049,
            or the file holds fewer label bytes than the header announces.
    """
    with open(filename, "rb") as f:
        header = f.read(8)
        if len(header) != 8:
            raise ValueError(f"{filename}: truncated IDX header ({len(header)} of 8 bytes)")
        magic, num_labels = np.frombuffer(header, dtype=">u4").tolist()  # Big-endian
        if magic != 2049:
            raise ValueError(f"{filename}: expected IDX label magic 2049, got {magic}")
        payload = f.read(num_labels)
    if len(payload) != num_labels:
        raise ValueError(f"{filename}: expected {num_labels} label bytes, found {len(payload)}")
    labels = np.frombuffer(payload, dtype=np.uint8)  # Labels are 1-byte each
    return torch.tensor(labels, device=device)


# Directory holding the four raw MNIST IDX files loaded below.
mnist = Path.home() / "Downloads" / "mnist"
# Prefer Apple's Metal (MPS) backend when it is available; otherwise fall back
# to the CPU so the notebook still runs on machines without MPS support.
device = "mps" if torch.backends.mps.is_available() else "cpu"
train_images = read_mnist_images(mnist / "train-images.idx3-ubyte", device=device)
train_labels = read_mnist_labels(mnist / "train-labels.idx1-ubyte", device=device)
test_images = read_mnist_images(mnist / "t10k-images.idx3-ubyte", device=device)
test_labels = read_mnist_labels(mnist / "t10k-labels.idx1-ubyte", device=device)


In [48]:
import os
import sys

sys.path.append(os.path.abspath("../src"))

import bitwise
from bitwise import bp
import torch
from typing import List


class DigitRecognizer:
    """Applies a sequence of ``bp.Layer`` objects to bit-packed images.

    The raw bytes of each image are reinterpreted (not converted) as int32
    words. For example, 784 pixel bytes become 196 int32 words. Each image is
    passed to the first layer as a ``(batch, 1, words)`` tensor.
    """

    _layers: List[bp.Layer]  # applied in order by eval()

    def __init__(self, layers: List[bp.Layer], device="cpu"):
        """Store the layer stack.

        Args:
            layers: Layers applied in order; each must expose ``eval(tensor)``.
            device: Accepted for symmetry with ``random_digit_recognizer`` but
                not used here. NOTE(review): nothing is moved to this device.
        """
        self._layers = layers

    def eval(self, images: torch.Tensor) -> bitwise.Tensor:
        """Run a batch of images through every layer and return the last output.

        Args:
            images: Tensor of shape ``(batch, ...)``. Presumably the uint8
                tensors produced by ``read_mnist_images``. The size in bytes
                of its last dimension must be divisible by 4, because the bytes
                are reinterpreted as int32.

        Returns:
            Whatever the final layer's ``eval`` returns.
        """
        batch_size = images.shape[0]
        # Reinterpret every 4 bytes as one int32 word (native byte order).
        words = images.view(torch.int32)
        # Derive the words-per-image count from the input instead of hard-coding
        # 196. Computing it from the trailing dims (not -1) keeps an empty batch working.
        words_per_image = words.shape[1:].numel()
        outputs = words.reshape(batch_size, 1, words_per_image)
        for layer in self._layers:
            outputs = layer.eval(outputs)
        return outputs


def random_digit_recognizer(
    device="cpu", dims=(196, 128, 64, 32, 10)
) -> DigitRecognizer:
    """Build a DigitRecognizer whose weights and biases are random 32-bit patterns.

    Args:
        device: Torch device on which the weight and bias tensors are allocated.
        dims: Layer widths in int32 words. The first entry must match the
            words-per-image that ``DigitRecognizer.eval`` feeds in (196 for
            784-byte images). Each consecutive pair ``(ins, outs)`` creates one
            layer with weights of shape ``(outs * 32, ins)`` and a bias of shape
            ``(1, outs)``.

    Returns:
        A DigitRecognizer over ``len(dims) - 1`` layers constructed with
        ``train=True``.
    """
    low, high = -(2**31), 2**31  # full int32 range; randint's upper bound is exclusive
    layers = []
    for ins, outs in zip(dims, dims[1:]):
        # Sample directly as int32 so every bit pattern is equally likely. The
        # original sampled int64 values and relied on an out-of-range cast to wrap.
        weights = torch.randint(low, high, (outs * 32, ins), dtype=torch.int32, device=device)
        bias = torch.randint(low, high, (1, outs), dtype=torch.int32, device=device)
        layers.append(bp.Layer(weights, bias, train=True))

    return DigitRecognizer(layers, device=device)


# Untrained recognizer with random weights, allocated on the notebook's device.
digit_recognizer = random_digit_recognizer(device=device)

In [53]:
# Smoke test: run one batch of training images through the untrained network and
# check the output shape matches the saved cell output below.
sample_outputs = digit_recognizer.eval(train_images[:64])
assert sample_outputs.shape == (64, 1, 10), sample_outputs.shape
sample_outputs.shape

torch.Size([64, 1, 10])