In [2]:
import os
import sys

# Put ../src (resolved against the current working directory) on the import path.
# NOTE(review): presumably this is where the local `bitwise` module imported
# further down lives — confirm.
sys.path.append(os.path.abspath("../src"))

In [None]:
import torch


def bool_matmul(a, b):
    """Boolean matrix product: ``c[i, j] = OR_k (a[i, k] AND b[k, j])``.

    The product is computed as an integer matmul and then reduced to
    "nonzero" by the final cast to bool.

    Args:
        a: Left operand; intended to be a bool tensor (it is cast to int32).
        b: Right operand; intended to be a bool tensor (it is cast to int32).

    Returns:
        A bool tensor holding ``a @ b`` interpreted over (OR, AND).
    """
    # Accumulate in int32 rather than uint8: a uint8 accumulator wraps modulo
    # 256, so exactly 256 (or 512, ...) matching True entries would sum to 0
    # and be reported as False.
    a = a.to(torch.int32)
    b = b.to(torch.int32)
    c = a @ b
    return c.to(torch.bool)

The sign of the derivative can be recovered from the weight value itself, so a boolean gradient is enough.

In [None]:
def nor_and_eval(weights, inputs):
    """Evaluate a NOR-of-ANDs layer.

    Output ``[i, j]`` is True exactly when no index ``k`` has both
    ``weights[i, k]`` and ``inputs[k, j]`` set.

    Args:
        weights: Weight tensor, cast to int32 before the matmul.
        inputs: Input tensor, cast to int32 before the matmul.

    Returns:
        A bool tensor with the shape of ``weights @ inputs``.
    """
    # Each entry counts the (weight AND input) pairs that are both set.
    hit_counts = torch.matmul(weights.to(torch.int32), inputs.to(torch.int32))
    # NOR: the output fires only where that count is zero.
    return torch.logical_not(hit_counts.to(torch.bool))


def nor_and_grad(weights, inputs):
    """Boolean sensitivity of the NOR-of-ANDs output to each weight.

    Entry ``[..., i, j]`` is True when ``inputs[j]`` is set and no *other*
    product ``weights[i, k] AND inputs[k]`` (``k != j``) is set, i.e. when
    the term for weight ``[i, j]`` is the only one that can drive output
    ``i``.

    Args:
        weights: Tensor of shape ``(..., m, n)``; cast to int32.
        inputs: Column tensor of shape ``(..., n, 1)``. It is transposed to a
            row and broadcast against ``weights``, so more than one input
            column only works where that broadcast happens to be valid.

    Returns:
        A bool tensor with the broadcast shape of ``weights`` and the
        transposed ``inputs``.
    """
    weights = weights.to(torch.int32)
    # Turn the input column into a row so it lines up with each weight row.
    inputs = inputs.transpose(-1, -2)
    prods = weights * inputs.to(torch.int32)
    # Sum of all *other* products in the same weight row. Reduce over the last
    # axis (the contraction axis) so leading batch dimensions are handled the
    # same way the transpose above handles them; for 2-D weights this is the
    # same axis as dim=1.
    residues = prods.sum(dim=-1, keepdim=True) - prods
    # The final cast keeps the result boolean even if `inputs` is not bool.
    return (~residues.to(torch.bool) & inputs).to(torch.bool)

In [None]:
# Smoke test on a 2x2 example: anti-diagonal weights and an all-True input column.
_weights = torch.tensor([[False, True], [True, False]], dtype=torch.bool)
inputs = torch.tensor([[True], [True]], dtype=torch.bool)
# Not the last expression of the cell, so this result is computed but not displayed.
nor_and_eval(_weights, inputs)
# Last expression of the cell: this is the tensor shown in the cell output.
nor_and_grad(_weights, inputs)

tensor([[False,  True],
        [ True, False]])

In [None]:
class NorAndLinear:
    """NOR-of-ANDs layer holding a fixed weight tensor.

    ``eval`` computes ``NOT OR_k (weights[i, k] AND inputs[k, j])`` and
    ``grad`` reports, per weight, whether its term is the only one that can
    drive the corresponding output.
    """

    def __init__(self, weights):
        """Store the weight tensor.

        Args:
            weights: Tensor of shape ``(..., m, n)``; it is cast to int32 on
                every call rather than at construction time.
        """
        self._weights = weights

    def eval(self, inputs):
        """Evaluate the layer.

        Args:
            inputs: Tensor that can be right-multiplied with the weights;
                cast to int32 before the matmul.

        Returns:
            A bool tensor that is True where no (weight AND input) pair along
            the contraction axis is set.
        """
        weights = self._weights.to(torch.int32)
        inputs = inputs.to(torch.int32)
        # Nonzero count means at least one AND term fired.
        or_and = (weights @ inputs).to(torch.bool)
        return ~or_and

    def grad(self, inputs):
        """Boolean sensitivity of the output to each weight.

        Entry ``[..., i, j]`` is True when ``inputs[j]`` is set and no other
        product ``weights[i, k] AND inputs[k]`` (``k != j``) is set.

        Args:
            inputs: Column tensor of shape ``(..., n, 1)``; it is transposed
                to a row and broadcast against the weights.

        Returns:
            A bool tensor with the broadcast shape of the weights and the
            transposed ``inputs``.
        """
        weights = self._weights.to(torch.int32)
        # Turn the input column into a row so it lines up with each weight row.
        inputs = inputs.transpose(-1, -2)
        prods = weights * inputs.to(torch.int32)
        # Sum of the other products in the same weight row. Reduce over the
        # last axis so batched weights work; for 2-D weights this equals dim=1.
        residues = prods.sum(dim=-1, keepdim=True) - prods
        # The final cast keeps the result boolean even if `inputs` is not bool.
        return (~residues.to(torch.bool) & inputs).to(torch.bool)

In [3]:
import torch
import bitwise


def transpose_bitpacked_32x1(matrix):
    """Transpose a 32x32 bit matrix stored as 32 bit-packed int32 words.

    Bit ``i`` of ``result[j]`` equals bit ``31 - j`` of ``matrix[31 - i]``.
    Equivalently, if every word is written MSB-first as a row of 32 bits,
    rows and columns are swapped.

    Args:
        matrix: 1-D tensor with exactly 32 elements; cast to int32.

    Returns:
        A 1-D int32 tensor of 32 words holding the transposed bits.

    Raises:
        AssertionError: If ``matrix`` does not have shape ``(32,)``.
    """
    assert matrix.shape == (32,), "Input must be a 32-element tensor of int32"

    words = matrix.to(torch.int32)

    # One mask per bit position. The shifts are done in int64 so that 1 << 31
    # is representable; narrowing to int32 afterwards turns it into the
    # sign-bit pattern.
    shifts = torch.arange(32, dtype=torch.int64, device=words.device)
    masks = (torch.ones_like(shifts) << shifts).to(torch.int32)

    # unpacked[r, b] == 1 iff bit b of words[r] is set.
    unpacked = words.unsqueeze(1).bitwise_and(masks).ne(0).to(torch.int32)

    # Swap rows and columns, then reverse both axes: bit position 0 is the
    # least-significant bit, whereas the matrix view is read MSB-first.
    swapped = unpacked.t().flip(dims=(0, 1))

    # Re-pack each row of bits into one word.
    return (swapped * masks).sum(dim=1).to(torch.int32)



# Demo: transpose a random bit-packed matrix and print both versions for visual comparison.
# NOTE(review): no random seed is set, so the printed values differ on every run.
matrix = torch.randint(0, 2**31, (32,), dtype=torch.int32)  # upper bound is exclusive: values stay non-negative, bit 31 is never set
transposed = transpose_bitpacked_32x1(matrix)
# `bitwise.to_str` is a project-local helper; judging by the cell output it
# renders each int32 as a 32-character bit string — confirm against its source.
print(bitwise.to_str(matrix.reshape(32,1,1)))
print(bitwise.to_str(transposed.reshape(32,1,1)))

[['01000011001101110100100010011000']
 ['01110000101011011000000110001000']
 ['00011100010000011100100110001011']
 ['01100000000011010111010010011000']
 ['01000101011100000111001011110110']
 ['00100000011010111001110100101011']
 ['00010100100110101010111101000101']
 ['01001010001000011001001101001000']
 ['00111110101101110100100100101100']
 ['00111000100101101001100111011111']
 ['01111010110001111011101010110100']
 ['01010111010010011001000001000111']
 ['00000000101101000011111000011001']
 ['01101100100101011010011011010001']
 ['01000110000001101011101010110111']
 ['01101101001100010100101100100000']
 ['00000110110000011111000000010000']
 ['00001000011111100000101111001101']
 ['00001001100000100001100010010001']
 ['01111000111011001001000000000101']
 ['01000100100111110001110011011000']
 ['01001001111100000000101001000000']
 ['01100100111001010100111110010000']
 ['00110000000011110011010111100101']
 ['00001010100101010000010101101011']
 ['00101011000010011101001101110100']
 ['011111110