In [None]:
import os
import sys

# Add "../src" (resolved to an absolute path against the current working
# directory) to the module search path.
# NOTE(review): presumably this is where the `bitwise` package imported by
# later cells lives — confirm against the repository layout.
sys.path.append(os.path.abspath("../src"))

In [None]:
import bitwise
from bitwise import bp
import torch
from typing import List


class Model:
    """A sequence of ``bp.Layer`` objects run front-to-back and trained back-to-front."""

    # Layers in evaluation order: index 0 receives the model inputs.
    _layers: List[bp.Layer]

    def __init__(self, layers: List[bp.Layer]):
        """Keep a reference to ``layers`` (the list is stored as-is, not copied)."""
        self._layers = layers

    def eval(self, inputs: torch.Tensor) -> bitwise.Tensor:
        """Feed ``inputs`` through every layer in order.

        Each layer's ``eval`` result becomes the next layer's input; the
        result of the last layer is returned. With no layers, ``inputs`` is
        returned unchanged.
        """
        activations = inputs
        position = 0
        while position < len(self._layers):
            activations = self._layers[position].eval(activations)
            position += 1
        return activations

    def update(self, errors: bitwise.Tensor):
        """Run ``update`` on the layers from last to first.

        Each layer's ``update`` receives the current ``errors`` and its return
        value is handed to the layer before it. If the returned errors are all
        zero before the first layer (index 0) has been reached, a warning is
        printed and the remaining layers are left untouched.
        """
        for i in reversed(range(len(self._layers))):
            errors = self._layers[i].update(errors)
            # The zero check is evaluated for every layer, index 0 included;
            # only the warning and early exit are limited to i > 0.
            if torch.all(errors == 0) and i > 0:
                print(f"warning: no error propagated to layer {i}")
                break


def untrained_model(layer_widths: List[int], device="cpu") -> Model:
    """Build a trainable ``Model`` with one layer per consecutive pair of widths.

    Args:
        layer_widths: Widths ``[w0, w1, ..., wn]``; layer ``k`` is built with
            ``ins = w_k`` and ``outs = w_(k+1)``. Fewer than two widths yields
            a model with no layers.
        device: Torch device on which the weights and biases are placed.

    Returns:
        A ``Model`` whose layers are created with ``train=True``.
    """

    def build_layer(n_in: int, n_out: int):
        # Weights come from bitwise.identity_matrix(outs, ins); see that
        # helper for the exact layout it produces.
        weights = bitwise.identity_matrix(n_out, n_in).to(device=device)
        # One random int32 word per 32 outputs (rounded up), drawn from the
        # full int32 range.
        n_words = (n_out + 31) // 32
        bias = torch.randint(
            -(2**31), 2**31, (1, n_words), device=device, dtype=torch.int32
        )
        return bp.Layer(weights, bias, train=True)

    # Layers are built strictly in order so the random draws happen in the
    # same sequence as the widths are listed.
    width_pairs = zip(layer_widths[:-1], layer_widths[1:])
    return Model([build_layer(n_in, n_out) for n_in, n_out in width_pairs])

In [None]:
import numpy as np
from pathlib import Path
import torch


def read_mnist_images(filename: Path, device: str = "cpu") -> torch.Tensor:
    """Load an IDX3 image file into a uint8 tensor.

    The file starts with four big-endian unsigned 32-bit integers; the first
    is ignored and the other three are the image count, row count and column
    count. Every remaining byte is one pixel.

    Args:
        filename: Path of the IDX3 file.
        device: Torch device for the returned tensor.

    Returns:
        A ``torch.uint8`` tensor of shape ``(num_images, rows, cols)``.
    """
    header_size = 16
    with open(filename, "rb") as stream:
        raw = stream.read()

    # Header words are stored big-endian (">u4").
    header = np.frombuffer(raw[:header_size], dtype=">u4")
    _, count, height, width = header.tolist()

    pixels = np.frombuffer(raw[header_size:], dtype=np.uint8)
    return torch.tensor(pixels.reshape(count, height, width), device=device)


def read_mnist_labels(filename: Path, device: str = "cpu") -> torch.Tensor:
    """Load an IDX1 label file into a 1-D uint8 tensor.

    The file starts with two big-endian unsigned 32-bit integers, which are
    read and discarded; every remaining byte is one label.

    Args:
        filename: Path of the IDX1 file.
        device: Torch device for the returned tensor. Defaults to ``"cpu"``
            (the previous default, ``"str"``, is not a valid torch device and
            made every call that relied on the default fail).

    Returns:
        A ``torch.uint8`` tensor with one entry per label.

    Raises:
        ValueError: If the file is too short to contain the 8-byte header.
    """
    with open(filename, "rb") as f:
        # Unpacking into exactly two names also rejects a truncated header.
        _, _ = np.frombuffer(f.read(8), dtype=">u4")  # Big-endian
        labels = np.frombuffer(f.read(), dtype=np.uint8)  # Labels are 1-byte each
    return torch.tensor(labels, device=device)


# `mnist` is the directory holding the four MNIST IDX files (under the current
# user's home directory); `device` is the torch device used for the data here
# and passed to the model constructor in a later cell.
mnist, device = Path.home() / "Downloads" / "mnist", "cpu"
# Images load as uint8 tensors of shape (num_images, rows, cols) and labels as
# 1-D uint8 tensors, for both the training and the test split.
train_images = read_mnist_images(mnist / "train-images.idx3-ubyte", device=device)
train_labels = read_mnist_labels(mnist / "train-labels.idx1-ubyte", device=device)
test_images = read_mnist_images(mnist / "t10k-images.idx3-ubyte", device=device)
test_labels = read_mnist_labels(mnist / "t10k-labels.idx1-ubyte", device=device)


In [None]:
import bitwise
from torch.nn import functional as F


def class_probs(outputs: bitwise.Tensor) -> torch.Tensor:
    """Turn model outputs into per-class probabilities.

    ``bitwise.bit_count_map`` is applied to ``outputs``, the second-to-last
    axis is squeezed away, and the result is cast to float32 and normalized
    with a softmax over dimension 1.

    NOTE(review): assumes ``outputs`` is (batch, 1, classes) so that the
    squeezed scores are (batch, classes) — confirm against ``bp.Layer``.
    """
    counts = bitwise.bit_count_map(outputs)
    scores = counts.squeeze(dim=-2).to(dtype=torch.float32)
    return F.softmax(scores, dim=1)


def class_loss(outputs: bitwise.Tensor, labels: torch.Tensor) -> float:
    """Cross-entropy between the bit-count scores of ``outputs`` and ``labels``.

    The scores are ``bitwise.bit_count_map(outputs)`` with the second-to-last
    axis squeezed away, cast to float32 and used as logits.

    Args:
        outputs: Model outputs.
        labels: Class indices (any integer dtype) or, if floating point,
            class probabilities as accepted by ``F.cross_entropy``.

    Returns:
        The loss as a Python float.
    """
    bit_counts = bitwise.bit_count_map(outputs).squeeze(dim=-2)
    logits = bit_counts.to(dtype=torch.float32)
    # Labels loaded from the IDX files are uint8. Class-index targets are cast
    # to int64, as output_errors already does, because not every PyTorch
    # version accepts narrower integer targets. Floating-point targets are
    # probabilities and must be left untouched.
    if not labels.is_floating_point():
        labels = labels.to(dtype=torch.long)
    return F.cross_entropy(logits, labels).item()


def output_errors(outputs: bitwise.Tensor, labels: torch.Tensor) -> bitwise.Tensor:
    """XOR the model outputs against the expected bit pattern for each label.

    The expected pattern has one int32 word per class (10 classes): -1, i.e.
    all 32 bits set, for the labelled class and 0 for every other class. The
    result therefore has a bit set wherever ``outputs`` differs from it.

    Args:
        outputs: Model outputs; must broadcast against shape (batch, 1, 10)
            without enlarging it, because the XOR is done in place.
        labels: Class indices in ``[0, 10)``, any integer dtype.

    Returns:
        A new int32 tensor of shape (batch, 1, 10); ``outputs`` is not modified.
    """
    class_indices = labels.to(dtype=torch.long)
    one_hot = F.one_hot(class_indices, num_classes=10).to(dtype=torch.int32)
    # Negating the 0/1 mask gives 0 / -1, and -1 in int32 has every bit set.
    expected = torch.neg(one_hot.unsqueeze(1))
    # In place on the freshly built tensor, so no caller data is touched.
    return expected.bitwise_xor_(outputs)

In [None]:
import torch
from torch.utils.data import TensorDataset


def create_dataset(
    images: torch.Tensor, labels: torch.Tensor, batch_size: int
) -> TensorDataset:
    """Pair images with labels, truncated to a whole number of batches.

    Args:
        images: Tensor whose first dimension indexes samples.
        labels: Tensor with the same first-dimension length as ``images``.
        batch_size: Positive batch size; trailing samples that do not fill a
            complete batch are dropped.

    Returns:
        A ``TensorDataset`` whose length is the largest multiple of
        ``batch_size`` not exceeding ``len(images)`` (possibly zero).

    Raises:
        AssertionError: If ``images`` and ``labels`` differ in length.
        ValueError: If ``batch_size`` is not positive.
    """
    assert len(images) == len(labels)
    # Zero would divide by zero, and a negative value makes the floor
    # arithmetic below round *up*, silently returning the untruncated data.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    size = len(images) // batch_size * batch_size
    return TensorDataset(images[:size], labels[:size])

In [None]:
from torch.utils.data import TensorDataset, DataLoader


def compute_accuracy(model: Model, dataset: TensorDataset, batch_size: int) -> float:
    """Fraction of samples in ``dataset`` that ``model`` classifies correctly.

    Args:
        model: Model whose ``eval`` output is scored with ``class_probs``.
        dataset: ``(images, labels)`` pairs; each image must hold
            ``28 * 28`` uint8 pixels.
        batch_size: Number of samples evaluated per step. The dataset length
            does not have to be a multiple of it.

    Returns:
        Correct predictions divided by ``len(dataset)``, as a Python float.

    Raises:
        ZeroDivisionError: If ``dataset`` is empty.
    """
    correct = 0
    dataloader = DataLoader(dataset, batch_size=batch_size)
    for images, labels in dataloader:
        # Size the view from the batch itself: the final batch is shorter than
        # batch_size whenever len(dataset) is not a multiple of it.
        # NOTE(review): assumes model.eval accepts any leading batch size — confirm.
        # The second view reinterprets each row of 784 uint8 pixels as int32 words.
        inputs = images.view(len(images), 1, 28 * 28).view(torch.int32)
        outputs = model.eval(inputs)
        predicted = class_probs(outputs).argmax(dim=-1)
        correct += (predicted == labels).sum()
    return (correct / len(dataset)).item()

In [None]:
from bitwise import bp

# Build the network with the library's `bp.untrained_model`.
# NOTE(review): this is not the `untrained_model` function defined in an
# earlier cell of this notebook — confirm which of the two is intended.
# NOTE(review): widths look like bit counts: 28 * 28 * 8 matches one 28x28
# image of 8-bit pixels and 10 * 32 matches ten int32 words, one per class —
# confirm against bp.untrained_model.
model = bp.untrained_model(
    [28 * 28 * 8, 4096, 4096, 4096, 10 * 32], device=device
)

In [None]:
# One pass over the training set in shuffled mini-batches.
batch_size = 8
# create_dataset drops trailing samples so every batch has exactly batch_size
# items, which the fixed-size view inside the loop relies on.
train_dataset = create_dataset(train_images, train_labels, batch_size)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
for i, (images, labels) in enumerate(train_dataloader):
    # Flatten each image to 784 uint8 pixels, then reinterpret those bytes as
    # int32 words for the model.
    inputs = images.view(batch_size, 1, 28 * 28).view(torch.int32)
    outputs = model.eval(inputs)
    loss_before = class_loss(outputs, labels)
    # Derive the output-layer errors from the labels and let the model update
    # its layers from them.
    errors = output_errors(outputs, labels)
    model.update(errors)
    # Re-evaluate the same batch to see how the update changed its loss.
    outputs = model.eval(inputs)
    loss_after = class_loss(outputs, labels)
    # Percentage of batches processed so far.
    progress = ((i + 1) / len(train_dataloader)) * 100
    print(f"{progress:.3f}%: {loss_before} -> {loss_after}")

In [None]:
# Accuracy on the test split, truncated by create_dataset to a whole number of
# batches. The value of the last expression is shown as the cell output.
test_dataset = create_dataset(test_images, test_labels, batch_size)
compute_accuracy(model, test_dataset, batch_size)

In [None]:
# Accuracy on the (truncated) training set the model was just updated on, for
# comparison with the test-set accuracy.
compute_accuracy(model, train_dataset, batch_size)