In [None]:
import numpy as np


# Toy inputs for a fixed-point encoding experiment: a random "image" with
# values in [0, 1) and standard-normal "weights", both of shape (1, 3, 3).
img = np.random.rand(1, 3, 3)

weights = np.random.randn(1, 3, 3)
# Fixed-point scale factor (2**16). Nothing visible in this cell reads it;
# NOTE(review): presumably `encode` reads it as a global — confirm.
scale = 2 ** 16


# NOTE(review): `encode` is not defined anywhere in this notebook, so this cell
# raises NameError on a fresh kernel (Restart & Run All). TODO: define it here
# (presumably "multiply by `scale` and cast to an integer dtype" — confirm).
# The int16 array displayed after this cell cannot come from these two
# assignments (assignments display nothing); it is stale output from an
# earlier run.
encoded_weights = encode(weights)
encoded_img = encode(img)

array([[[  9235, -12964,   4885],
        [  -860, -14927,  13023],
        [   361, -12094, -19212]]], dtype=int16)

In [31]:
import numpy as np

# ----------------------------
# Utility helpers
# ----------------------------
def split_share(x, rng=None):
    """Additively split an integer value into two int64 shares.

    Returns ``(x0, x1)`` with ``x0 + x1 == x``. ``x0`` is a random mask drawn
    from ``[int64.min // 4, int64.max // 4)`` and ``x1 = x - x0``. The
    subtraction cannot overflow int64 as long as ``|x| <= 2**62``.

    Parameters
    ----------
    x : array_like
        Value to split: an ndarray, a list, or a Python/NumPy scalar (anything
        ``np.asarray`` accepts). It is converted with ``astype(np.int64)``, so
        float inputs are truncated toward zero.
    rng : object with a ``randint`` method, optional
        Source of randomness, e.g. ``np.random.RandomState``. Defaults to the
        global ``np.random`` module, which matches the original behaviour.
        Passing a seeded RandomState makes the split reproducible.

    Returns
    -------
    tuple of (np.ndarray, np.ndarray)
        Two int64 arrays with the same shape as ``x``.
    """
    if rng is None:
        rng = np.random
    x = np.asarray(x).astype(np.int64)
    int64_info = np.iinfo(np.int64)
    # Keep masks within a quarter of the int64 range so that x - r (and the
    # later sum of two shares) stays clear of int64 overflow for moderate x.
    r = rng.randint(low=int64_info.min // 4, high=int64_info.max // 4,
                    size=x.shape, dtype=np.int64)
    return r, (x - r).astype(np.int64)

def reconstruct(share_pair):
    """Recombine a two-party additive sharing.

    ``share_pair`` is a 2-tuple ``(a, b)`` of integer arrays. Both shares are
    widened to int64 before they are added, and the int64 sum is returned.
    """
    lhs, rhs = share_pair
    total = np.add(lhs.astype(np.int64), rhs.astype(np.int64))
    return total.astype(np.int64)

# ----------------------------
# Plaintext conv2d (reference implementation)
# IN: (H, W, Cin)
# W: (k, k, Cin, Cout)
# returns O: (out_h, out_w, Cout); with the default padding (k - 1) // 2,
# stride=1 and odd k this is (H, W, Cout)
# ----------------------------
def plain_conv2d(IN, W, padding=None, stride=1):
    """Integer 2-D cross-correlation (no kernel flip) with zero padding.

    O[i, j, m] = sum_{u, v, c} pad(IN)[i*stride + u, j*stride + v, c] * W[u, v, c, m]

    Parameters
    ----------
    IN : array_like, shape (H, W_in, Cin)
    W : array_like, shape (k, k, Cin, Cout)
        The kernel must be square, and its Cin must equal the input's
        channel count.
    padding : int, optional
        Number of zero rows/columns added on every spatial side. Defaults to
        ``(k - 1) // 2``.
    stride : int, default 1

    Returns
    -------
    np.ndarray of int64, shape (out_h, out_w, Cout)
        ``out_h = (H + 2*padding - k) // stride + 1``, and likewise for
        ``out_w``. If the padded input is smaller than the kernel, the result
        is empty along that axis.

    Raises
    ------
    ValueError
        On wrong ranks, a non-square kernel, an input/kernel channel mismatch,
        ``stride < 1`` or negative ``padding``. Without these checks, such
        inputs would silently broadcast in the products and give wrong results.
    """
    IN = np.asarray(IN).astype(np.int64)
    W = np.asarray(W).astype(np.int64)
    if IN.ndim != 3:
        raise ValueError(f"IN must have shape (H, W, Cin), got {IN.shape}")
    if W.ndim != 4:
        raise ValueError(f"W must have shape (k, k, Cin, Cout), got {W.shape}")
    k, k_w, Cin_w, Cout = W.shape
    _, _, Cin = IN.shape
    if k != k_w:
        raise ValueError(f"kernel must be square, got {k}x{k_w}")
    if Cin != Cin_w:
        raise ValueError(f"input has {Cin} channels but kernel expects {Cin_w}")
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    if padding is None:
        padding = (k - 1) // 2
    if padding < 0:
        raise ValueError(f"padding must be >= 0, got {padding}")

    inp = np.pad(IN, ((padding, padding), (padding, padding), (0, 0)),
                 mode='constant', constant_values=0)
    out_h = (inp.shape[0] - k) // stride + 1
    out_w = (inp.shape[1] - k) // stride + 1
    if out_h <= 0 or out_w <= 0:
        # The kernel does not fit inside the padded input along some axis.
        return np.zeros((max(out_h, 0), max(out_w, 0), Cout), dtype=np.int64)

    # windows[i, j, c, u, v] == inp[i + u, j + v, c] (read-only strided view, no copy).
    windows = np.lib.stride_tricks.sliding_window_view(inp, (k, k), axis=(0, 1))
    windows = windows[::stride, ::stride]          # (out_h, out_w, Cin, k, k)
    # Contract over kernel rows/cols and input channels in one vectorized call.
    # This uses int64 arithmetic, with the same wraparound as the elementwise form.
    O = np.einsum('ijcuv,uvcm->ijm', windows, W)
    return O.astype(np.int64, copy=False)

# ----------------------------
# Beaver triple generator for convolution
# Inputs:
#   IN_shape = (H, W, Cin)
#   W_shape  = (k, k, Cin, Cout)
# Returns:
#   A_sh = (A0, A1) where A shape = IN_shape
#   B_sh = (B0, B1) where B shape = W_shape
#   C_sh = (C0, C1) where C = conv(A, B) (plain_conv2d defaults: same padding,
#          stride 1, so shape (H, W, Cout) for odd k)
# ----------------------------
def generate_beaver_triple_conv(IN_shape, W_shape, rng=None):
    """Dealer-side generation of a convolution Beaver triple (A, B, C = conv(A, B)).

    Parameters
    ----------
    IN_shape : tuple (H, W, Cin)
    W_shape : tuple (k, k, Cin, Cout)
    rng : object with a ``randint`` method, optional
        For example ``np.random.RandomState``. It is used for A and B and for
        every share mask, so a seeded RandomState reproduces the whole triple.
        Defaults to the global ``np.random`` module. In that case the random
        draws happen in exactly the same order as before.

    Returns
    -------
    (A_sh, B_sh, C_sh)
        Each is a pair of int64 arrays whose sum is A, B or C respectively.
        C uses plain_conv2d's default padding/stride, so the online phase
        must use the same settings.

    Raises
    ------
    ValueError
        If the shapes have the wrong rank, the kernel is not square, or the
        channel counts disagree.

    Security note: A and B come from tiny ranges ([-4, 4] and [-2, 2]) so that
    plain int64 arithmetic stays small. As a consequence, the values opened
    online (d = IN - A, e = W - B) reveal IN to within ±4 and W to within ±2.
    A real protocol samples A and B uniformly from the whole ring.
    """
    if rng is None:
        rng = np.random
    if len(IN_shape) != 3 or len(W_shape) != 4:
        raise ValueError(f"expected IN_shape (H, W, Cin) and W_shape (k, k, Cin, Cout), "
                         f"got {IN_shape} and {W_shape}")
    _, _, Cin = IN_shape
    k, k2, Cin_w, _ = W_shape
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if k != k2 or Cin != Cin_w:
        raise ValueError("kernel dims must match input channels")

    int64_info = np.iinfo(np.int64)

    def _split(x):
        # Same mask range as split_share, but drawn from `rng` for reproducibility.
        r = rng.randint(low=int64_info.min // 4, high=int64_info.max // 4,
                        size=x.shape, dtype=np.int64)
        return r, (x - r).astype(np.int64)

    # Sample random A and B with the same shapes as the real input and weights.
    A = rng.randint(low=-4, high=5, size=IN_shape).astype(np.int64)
    B = rng.randint(low=-2, high=3, size=W_shape).astype(np.int64)
    C = plain_conv2d(A, B)
    # Split A, then B, then C: the same draw order as the original split_share calls.
    A_sh = _split(A)
    B_sh = _split(B)
    C_sh = _split(C)
    return A_sh, B_sh, C_sh

# ----------------------------
# Two-party conv2d online protocol (simulated)
# Inputs:
#   IN_sh = (IN0, IN1)   each of shape (H,W,Cin)
#   W_sh  = (W0, W1)     each of shape (k,k,Cin,Cout)
#   triple_sh = (A_sh, B_sh, C_sh) as returned by generate_beaver_triple_conv
# Returns:
#   out_sh = (O0, O1) additive shares of the convolution output
#
# Implementation details:
#   - Each party i computes d_i = IN_i - A_i, e_i = W_i - B_i
#   - They "open" d = d0 + d1 and e = e0 + e1 (the only values ever recombined)
#   - Each party then computes its output share locally, using the bilinearity
#     of conv:
#       Z_i = C_i + conv(d, B_i) + conv(A_i, e)   (+ conv(d, e) for party 0 only)
#     so Z0 + Z1 = C + conv(d, B) + conv(A, e) + conv(d, e) = conv(IN, W).
# ----------------------------
def two_party_conv2d(IN_sh, W_sh, triple_sh):
    """Evaluate conv(IN, W) on additive shares with a convolution Beaver triple.

    Both parties are simulated in one process. Only the masked differences d
    and e are ever recombined; A, B, C, IN and W are never opened.

    All arithmetic is int64. Products of the large share masks with d or e may
    wrap around. Wraparound is arithmetic modulo 2**64, so ``O0 + O1`` (also
    evaluated in int64) equals the true convolution whenever that result fits
    in int64.

    Returns
    -------
    (O0, O1) : int64 arrays, additive shares of ``plain_conv2d(IN, W)``.
    """
    IN0, IN1 = IN_sh
    W0, W1 = W_sh
    (A0, A1), (B0, B1), (C0, C1) = triple_sh

    # Local masked differences (party 0 holds index 0, party 1 holds index 1).
    d0 = IN0.astype(np.int64) - A0.astype(np.int64)
    d1 = IN1.astype(np.int64) - A1.astype(np.int64)
    e0 = W0.astype(np.int64) - B0.astype(np.int64)
    e1 = W1.astype(np.int64) - B1.astype(np.int64)

    # Opening step: in a real deployment the parties exchange d_i and e_i.
    d_open = reconstruct((d0, d1))
    e_open = reconstruct((e0, e1))

    def _local_share(A_i, B_i, C_i, add_public_term):
        # What party i can compute from its own triple shares plus public d, e.
        Z_i = (C_i.astype(np.int64)
               + plain_conv2d(d_open, B_i)
               + plain_conv2d(A_i, e_open))
        if add_public_term:
            # conv(d, e) is public; exactly one party adds it.
            Z_i = Z_i + plain_conv2d(d_open, e_open)
        return Z_i.astype(np.int64)

    Z0 = _local_share(A0, B0, C0, add_public_term=True)
    Z1 = _local_share(A1, B1, C1, add_public_term=False)
    return Z0, Z1

# ----------------------------
# Test driver
# ----------------------------
if __name__ == "__main__":
    np.random.seed(123)

    # Problem size: a 6x7 input with 3 channels; a 3x3 kernel producing 4 channels.
    H, W_in = 6, 7
    Cin = 3
    Cout = 4
    k = 3

    # Small integer ranges keep the plaintext convolution values small.
    IN = np.random.randint(-5, 6, size=(H, W_in, Cin)).astype(np.int64)
    W = np.random.randint(-3, 4, size=(k, k, Cin, Cout)).astype(np.int64)

    # Reference result computed in the clear.
    plain_out = plain_conv2d(IN, W)

    # Secret-share both operands, then draw a triple matching their shapes.
    IN0, IN1 = split_share(IN)
    W0, W1 = split_share(W)
    A_sh, B_sh, C_sh = generate_beaver_triple_conv(IN.shape, W.shape)

    # Online phase, then recombine the two output shares.
    O0, O1 = two_party_conv2d((IN0, IN1), (W0, W1), (A_sh, B_sh, C_sh))
    rec = reconstruct((O0, O1))

    for label, arr in (("plain_out shape:", plain_out), ("rec shape      :", rec)):
        print(label, arr.shape)
    diff = plain_out - rec
    max_abs_err = np.max(np.abs(diff))
    print("max abs error:", max_abs_err)
    print("SUCCESS: reconstructed output equals plaintext convolution."
          if max_abs_err == 0 else "Mismatch detected.")


plain_out shape: (6, 7, 4)
rec shape      : (6, 7, 4)
max abs error: 0
SUCCESS: reconstructed output equals plaintext convolution.


In [38]:
import torchvision
import torchvision.transforms as transforms

# Load the CIFAR-10 training split into ./data, downloading it if it is not
# already there. ToTensor turns each image into a float32 (C, H, W) tensor
# with values in [0, 1].
data = torchvision.datasets.CIFAR10("./data", train=True, transform=transforms.ToTensor(), download=True)
# Recover the first image's 0-255 pixel values as an int64 array.
# Round before the integer cast: in float32, (x / 255) * 255 is not guaranteed
# to equal x exactly, and .long() truncates, so a value like 58.99999 would
# otherwise become 58.
(data[0][0] * 255).round().long().numpy()

array([[[ 59,  43,  50, ..., 158, 152, 148],
        [ 16,   0,  18, ..., 123, 119, 122],
        [ 25,  16,  49, ..., 118, 120, 109],
        ...,
        [208, 201, 198, ..., 160,  56,  53],
        [180, 173, 186, ..., 184,  97,  83],
        [177, 168, 179, ..., 216, 151, 123]],

       [[ 62,  46,  48, ..., 132, 125, 124],
        [ 20,   0,   8, ...,  88,  83,  87],
        [ 24,   7,  27, ...,  84,  84,  73],
        ...,
        [170, 153, 161, ..., 133,  31,  34],
        [139, 123, 144, ..., 148,  62,  53],
        [144, 129, 142, ..., 184, 118,  92]],

       [[ 63,  45,  43, ..., 108, 102, 103],
        [ 20,   0,   0, ...,  55,  50,  57],
        [ 21,   0,   8, ...,  50,  50,  42],
        ...,
        [ 96,  34,  26, ...,  70,   7,  20],
        [ 96,  42,  30, ...,  94,  34,  34],
        [116,  94,  87, ..., 140,  84,  72]]], shape=(3, 32, 32))

ValueError: operands could not be broadcast together with shapes (10,100) (100,10) 

IndexError: index 1 is out of bounds for axis 2 with size 1