## Simple Matrix Bitwise XOR

In [2]:
import numpy as np

def bitwise_xor(a, b):
    """
    Return a XOR b via the identity a ^ b == a + b - 2 * (a & b).

    Bits set in both operands (a & b) are counted twice by a + b; removing
    them twice leaves exactly the bits that are set in only one operand.
    """
    shared_bits = a & b
    return a + b - 2 * shared_bits

def matrix_xor(M_a, M_b):
    """
    Computes element-wise M_a XOR M_b using the arithmetic rearrangement:
    M_a ^ M_b = M_a + M_b - 2 * (M_a & M_b)

    Args:
        M_a, M_b: integer numpy arrays (or array-likes such as nested lists)
            with broadcast-compatible shapes; non-negative integers are the
            intended use.
    Returns:
        Array (or numpy scalar for scalar inputs) equal to np.bitwise_xor(M_a, M_b).
    """
    # np.asarray makes plain Python lists work: without it `M_a + M_b` would
    # concatenate the lists instead of adding them element-wise.
    M_a = np.asarray(M_a)
    M_b = np.asarray(M_b)
    return M_a + M_b - 2 * np.bitwise_and(M_a, M_b)

def x_times_matrix_xor(x, M_a, M_b):
    """
    Computes x @ (M_a XOR M_b), where XOR is element-wise and @ is matrix multiplication.

    Args:
        x: numpy array (matrix) on the left of the product.
        M_a, M_b: numpy arrays of non-negative integers with the same shape;
            for 2-D inputs the product requires x.shape[1] == M_a.shape[0].
    Returns:
        The matrix product of x with matrix_xor(M_a, M_b).
    """
    xor_ab = matrix_xor(M_a, M_b)
    return x @ xor_ab

In [4]:
# Check the scalar XOR identity against hand-computed values and NumPy's builtin.
tests_scalar = [
    (1, 2, 3),
    (3, 1, 2),
    (0, 0, 0),
    (7, 7, 0),
    (12345, 67890, 80139),
]
for a, b, expected in tests_scalar:
    impl = bitwise_xor(a, b)
    builtin = np.bitwise_xor(a, b)
    print(f"a={a}, b={b}: impl={impl} == expected={expected} == builtin={builtin}")
    # The "==" in the printed line is only a label; this assert enforces it.
    assert impl == expected == builtin, f"XOR mismatch for a={a}, b={b}"

a=1, b=2: impl=3 == expected=3 == builtin=3
a=3, b=1: impl=2 == expected=2 == builtin=2
a=0, b=0: impl=0 == expected=0 == builtin=0
a=7, b=7: impl=0 == expected=0 == builtin=0
a=12345, b=67890: impl=80139 == expected=80139 == builtin=80139


In [5]:
a_vec = np.array([1, 3])
b_vec = np.array([2, 1])
print(matrix_xor(a_vec, b_vec))  # [3 2]
print(np.bitwise_xor(a_vec, b_vec))  # [3 2]

a_zero = np.array([0, 0])
print(matrix_xor(a_zero, a_zero))  # [0 0]

a_mixed = np.array([5, 7])
b_mixed = np.array([3, 7])
print(matrix_xor(a_mixed, b_mixed))  # [6 0]

# The comments above only document the expected values; enforce them against
# NumPy's builtin so a regression fails loudly instead of printing silently.
for lhs, rhs in [(a_vec, b_vec), (a_zero, a_zero), (a_mixed, b_mixed)]:
    np.testing.assert_array_equal(matrix_xor(lhs, rhs), np.bitwise_xor(lhs, rhs))

[3 2]
[3 2]
[0 0]
[6 0]


## Transformer XOR Operation

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

def bitwise_xor(a, b):
    """Scalar XOR computed arithmetically: a + b minus twice the common bits (a & b)."""
    total = a + b
    overlap = a & b
    return total - overlap * 2

def matrix_xor(M_a, M_b):
    """
    Computes element-wise M_a XOR M_b using the arithmetic rearrangement:
    M_a ^ M_b = M_a + M_b - 2 * (M_a & M_b)

    Args:
        M_a, M_b: integer numpy arrays (or array-likes such as nested lists)
            with broadcast-compatible shapes; non-negative integers are the
            intended use.
    Returns:
        Array (or numpy scalar for scalar inputs) equal to np.bitwise_xor(M_a, M_b).
    """
    # np.asarray makes plain Python lists work: without it `M_a + M_b` would
    # concatenate the lists instead of adding them element-wise.
    M_a = np.asarray(M_a)
    M_b = np.asarray(M_b)
    return M_a + M_b - 2 * np.bitwise_and(M_a, M_b)

def x_times_matrix_xor(x, M_a, M_b):
    """
    Computes x @ (M_a XOR M_b), where XOR is element-wise and @ is matrix multiplication.

    Args:
        x: numpy array (matrix) on the left of the product.
        M_a, M_b: numpy arrays of non-negative integers with the same shape;
            for 2-D inputs the product requires x.shape[1] == M_a.shape[0].
    Returns:
        The matrix product of x with matrix_xor(M_a, M_b).
    """
    xor_ab = matrix_xor(M_a, M_b)
    return x @ xor_ab

def _xor_transform_weight(weight, scale_factor):
    """Return the XOR-transformed copy of one weight tensor (same dtype and device)."""
    weight_np = weight.detach().cpu().numpy()
    # Convert to integer range for bitwise ops
    scaled_weight = np.round(weight_np * scale_factor).astype(int)
    # Self-XOR as the example reference matrix: x ^ x == 0, so this is all zeros.
    xor_weight = matrix_xor(scaled_weight, scaled_weight)
    # Scale back and clip to the original range; a range that excludes 0
    # pushes the zeros to the nearest bound instead.
    restored = (xor_weight / scale_factor).clip(min=weight_np.min(), max=weight_np.max())
    return torch.as_tensor(restored, dtype=weight.dtype, device=weight.device)


def apply_xor_to_weights_and_evaluate(transformer, input_data, target_data, scale_factor=1000):
    """
    Applies bitwise XOR transformation to transformer weight matrices and evaluates performance.

    Every parameter whose name contains 'weight' is scaled by `scale_factor`,
    rounded to integers, and XOR-ed with itself via `matrix_xor` (which yields
    zeros). It is then divided by `scale_factor` and clipped to that tensor's
    original [min, max]. Because of the clip, a tensor whose range does not
    contain 0 (e.g. one whose entries are all equal and non-zero) ends up at
    the nearest bound rather than 0.

    The original weights are restored even if the modified forward pass raises,
    and the module's train/eval mode is restored before returning. Both forward
    passes run in eval mode under torch.no_grad().

    Args:
        transformer: TinyTransformerLayer instance (called as transformer(input_data))
        input_data: Input tensor (batch_size, seq_len, d_model)
        target_data: Currently unused; kept for interface compatibility
        scale_factor: Factor to convert float weights to integers for bitwise ops
    Returns:
        original_output: Output with original weights
        modified_output: Output with XOR-modified weights
        mse_loss: Mean squared error (Python float) between modified and original outputs
    """
    weight_params = [
        (name, param) for name, param in transformer.named_parameters() if 'weight' in name
    ]
    # Store original weights so they can be put back after the experiment.
    original_weights = {name: param.detach().clone() for name, param in weight_params}

    was_training = transformer.training
    transformer.eval()
    try:
        with torch.no_grad():
            try:
                # Apply XOR transformation, then forward pass with modified weights
                for _, param in weight_params:
                    param.copy_(_xor_transform_weight(param, scale_factor))
                modified_output = transformer(input_data)
            finally:
                # Restore original weights even if the modified forward pass fails.
                for name, param in weight_params:
                    param.copy_(original_weights[name])
            # Forward pass with original weights
            original_output = transformer(input_data)
    finally:
        transformer.train(was_training)

    mse_loss = F.mse_loss(modified_output, original_output).item()
    return original_output, modified_output, mse_loss

In [2]:
# TinyTransformerLayer: the small transformer block the XOR experiment runs on (as provided).
class TinyTransformerLayer(nn.Module):
    """
    Small pre-LayerNorm block: multi-head self-attention, then a two-layer GELU
    feed-forward network, each added back to its input as a residual.

    Args:
        d_model: token embedding width (input and output last dimension).
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability, used only inside the attention module.
    """
    def __init__(self, d_model=64, nhead=4, dim_feedforward=128, dropout=0.0):
        super().__init__()
        # Submodules are created in this exact order so parameter initialisation
        # consumes the RNG identically and the state_dict keys stay the same.
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, batch_first=True, dropout=dropout
        )
        self.ln2 = nn.LayerNorm(d_model)
        expand = nn.Linear(d_model, dim_feedforward)
        activation = nn.GELU()
        project = nn.Linear(dim_feedforward, d_model)
        self.ff = nn.Sequential(expand, activation, project)

    def forward(self, x, attn_mask=None):
        """Run attention then feed-forward, each on a LayerNorm-ed copy, with residual adds."""
        normed = self.ln1(x)
        attn_out = self.self_attn(normed, normed, normed, attn_mask=attn_mask)[0]
        residual = x + attn_out
        return residual + self.ff(self.ln2(residual))

# Example usage: run the XOR-weight experiment on a freshly initialised layer.
if __name__ == "__main__":
    # Seed so the random initialisation, inputs, and printed MSE are reproducible.
    SEED = 0
    torch.manual_seed(SEED)

    batch_size, seq_len, d_model = 2, 10, 64

    # Initialize transformer
    transformer = TinyTransformerLayer(d_model=d_model, nhead=4, dim_feedforward=128)
    transformer.eval()

    # Dummy input and target data (target_data is accepted but not used by the function)
    input_data = torch.randn(batch_size, seq_len, d_model)
    target_data = torch.randn(batch_size, seq_len, d_model)

    # Apply XOR and evaluate
    orig_out, mod_out, mse = apply_xor_to_weights_and_evaluate(transformer, input_data, target_data, scale_factor=1000)
    print(f"Original output shape: {orig_out.shape}, Modified output shape: {mod_out.shape}")
    print(f"MSE between original and modified outputs: {mse}")

Original output shape: torch.Size([2, 10, 64]), Modified output shape: torch.Size([2, 10, 64])
MSE between original and modified outputs: 0.06542350351810455
