In [797]:
import warnings
warnings.filterwarnings("ignore")
from abc import ABC, abstractmethod
import numpy as np
import reedsolo
from sklearn.decomposition import PCA
import time

class baseLinLayer(ABC):
    """Abstract interface for a linear layer used in the bit-flip experiments.

    Concrete layers hold a weight matrix, can corrupt their own stored
    representation with a single bit flip, and compute a matrix product
    with an input.
    """

    def __init__(self, weights):
        """Store a reference to ``weights`` (no copy is made)."""
        self.weights = weights

    @abstractmethod
    def info(self):
        """Report information about the layer (left to subclasses)."""
        pass

    @abstractmethod
    def bitflip(self, k: int):
        """Flip bit ``k`` of one randomly chosen stored element.

        The valid range of ``k`` depends on how the subclass stores its data.
        """
        pass

    @abstractmethod
    def forward(self, X):
        """Return the product of ``X`` with the layer's weights.

        The signature takes ``X`` so that it matches every concrete
        implementation and every call site (``layer.forward(X)``).
        """
        pass

class regLinLayer(baseLinLayer):
    """Unprotected linear layer: weights are used exactly as stored.

    ``time_`` collects the wall-clock duration (seconds) of every
    ``forward`` call.
    """

    # Unsigned integer type with the same width as the weight dtype; used to
    # reinterpret one weight as raw bits.
    _UINT_BY_SIZE = {1: np.uint8, 2: np.uint16, 4: np.uint32, 8: np.uint64}

    def __init__(self, weights):
        super().__init__(weights=weights)
        self.time_ = []

    def info(self):
        """Not implemented yet; present to satisfy the abstract interface."""
        pass

    def bitflip(self, k: int):
        """Flip bit ``k`` of one randomly chosen weight, in place.

        The weight is reinterpreted as an unsigned integer of the same width
        as ``self.weights.dtype``, so ``k`` ranges over ``[0, 8 * itemsize)``
        (``[0, 15]`` for float16, ``[0, 31]`` for float32). Bit 0 is the least
        significant bit; a native byte order is assumed.

        Raises:
            TypeError: if the weight dtype is not 1, 2, 4 or 8 bytes wide.
            ValueError: if ``k`` is out of range or the weights are empty.
        """
        weights = self.weights
        itemsize = weights.dtype.itemsize
        uint_type = self._UINT_BY_SIZE.get(itemsize)
        if uint_type is None:
            raise TypeError(
                "cannot flip bits of dtype {} ({} bytes per element)".format(weights.dtype, itemsize)
            )
        n_bits = 8 * itemsize
        if not 0 <= k < n_bits:
            raise ValueError("k must be in [0, {}], got {}".format(n_bits - 1, k))
        if weights.size == 0:
            raise ValueError("cannot flip a bit in an empty weight array")

        rand_ind = np.random.randint(0, weights.size)
        # `.flat` always addresses the original buffer, whereas `ravel()`
        # returns a copy for non-contiguous arrays and the write would be lost.
        source_number = weights.flat[rand_ind]
        bits = np.array([source_number], dtype=weights.dtype).view(uint_type)
        bits ^= uint_type(1 << k)
        corr_number = bits.view(weights.dtype)[0]
        weights.flat[rand_ind] = corr_number
        print("source number: {}, corr number: {}".format(source_number, corr_number))

    def forward(self, X):
        """Return ``X @ self.weights`` and record how long the product took."""
        # perf_counter is the monotonic, high-resolution clock meant for
        # measuring short intervals.
        start = time.perf_counter()
        res = X @ self.weights
        self.time_.append(time.perf_counter() - start)
        return res
    
class safeLinLayer(baseLinLayer):
    """Linear layer whose weights are stored in a protected, encoded form.

    At construction the weights are projected with PCA, each row of scores is
    cast to float16, and the row's bytes are Reed-Solomon encoded with ``nsym``
    parity symbols. ``forward`` decodes the rows, inverts the PCA projection
    and multiplies the input with the restored weights.

    Attributes:
        weights_pca_k: 2-D uint8 array, one Reed-Solomon codeword per row.
        original_len: number of float16 values in one decoded row.
        time_: wall-clock duration (seconds) of every ``forward`` call.
    """

    def __init__(self, weights, rate, nsym: int = 2):
        super().__init__(weights=weights)
        self.nsym = nsym
        self.rs = reedsolo.RSCodec(self.nsym)
        self.pca = PCA()
        weights_pca = self.pca_encode(weights, rate)
        # A decoded row has as many values as the PCA score matrix has
        # columns. That can differ from weights.shape[1] (e.g. when there are
        # fewer rows than columns), so take it from the scores themselves.
        self.original_len = weights_pca.shape[1]
        self.weights_pca_k = self.encode_RS_rows(weights_pca)
        self.time_ = []

    def info(self):
        """Not implemented yet; present to satisfy the abstract interface."""
        pass

    def encode_RS_rows(self, matrix: np.ndarray):
        """Cast each row to float16 and Reed-Solomon encode its bytes.

        Returns a 2-D uint8 array with one encoded row per input row.
        """
        encoded = []
        for row in matrix:
            bytes_row = row.astype(np.float16).tobytes()
            encoded_bytes = self.rs.encode(bytes_row)
            encoded.append(np.frombuffer(encoded_bytes, dtype=np.uint8))
        return np.array(encoded)

    def decode_RS_rows(self, matrix_bytes: np.ndarray):
        """Decode every encoded row back into float16 values.

        A row that the codec cannot repair is replaced by a row of NaN of
        length ``original_len`` so the failure stays visible downstream.
        """
        decoded = []
        for row in matrix_bytes:
            try:
                decoded_bytes = self.rs.decode(row.tobytes())[0]
                decoded.append(np.frombuffer(decoded_bytes, dtype=np.float16))
            except reedsolo.ReedSolomonError:
                # Same length and dtype as a successfully decoded row, so
                # np.stack neither fails nor changes the result dtype.
                decoded.append(np.full(self.original_len, np.nan, dtype=np.float16))
        return np.stack(decoded)

    def pca_encode(self, weights, rate):
        """Fit the PCA on ``weights`` and return the transformed scores.

        NOTE(review): ``rate`` is accepted but not used; all components are
        kept. Presumably it was meant as an explained-variance threshold.
        Confirm before wiring it in, since that would make the encoding lossy.
        """
        weights_pca_k = self.pca.fit_transform(weights)
        return weights_pca_k

    def bitflip(self, k: int):
        """Flip bit ``k`` (0..7) of one randomly chosen encoded byte, in place.

        The byte is drawn from the whole encoded matrix, so it may be a data
        byte or a parity byte.

        Raises:
            ValueError: if ``k`` is outside ``[0, 7]``.
        """
        if not 0 <= k <= 7:
            raise ValueError("k must be in [0, 7], got {}".format(k))
        # The encoded bytes live in `weights_pca_k`; the original code read a
        # non-existent attribute here and raised AttributeError.
        rand_ind = np.random.randint(0, self.weights_pca_k.size)
        source_number = self.weights_pca_k.flat[rand_ind]
        corr_number = np.uint8(int(source_number) ^ (1 << k))
        self.weights_pca_k.flat[rand_ind] = corr_number
        print("source number: {}, corr number: {}".format(source_number, corr_number))

    def forward(self, X):
        """Decode the stored weights, undo the PCA and return ``X @ restored``.

        The recorded time includes decoding and the inverse transform.
        """
        start = time.perf_counter()
        weights_uint = self.decode_RS_rows(self.weights_pca_k)
        restored = self.pca.inverse_transform(weights_uint)
        res = X @ restored
        self.time_.append(time.perf_counter() - start)
        return res

In [798]:
from tqdm import tqdm

n = 128
m = 128

# Seed NumPy's global generator so the random inputs below, and the random
# positions picked later by the layers' bitflip methods, are reproducible
# across "Restart & Run All".
SEED = 0
np.random.seed(SEED)

X = np.random.randn(4, 5).astype(np.float32)
W = np.random.randn(5, 4).astype(np.float32)

# Larger float16 configuration (this is what n and m are for); uncomment to use it.
# X = np.random.randn(n, m).astype(np.float16)
# W = np.random.randn(m, n).astype(np.float16) #* 1e-5

def test_time(layer_reg: baseLinLayer, layer_safe: baseLinLayer, max_time: int):
    """Run both layers ``max_time`` times on X and print the relative slowdown.

    The slowdown is computed from the mean of each layer's ``time_`` list, so
    it also includes any forward passes recorded before this call.
    """
    progress = tqdm(range(max_time), desc="Processing")
    for _step in progress:
        layer_reg.forward(X)
        layer_safe.forward(X)

    mean_safe = np.mean(layer_safe.time_)
    mean_reg = np.mean(layer_reg.time_)
    slowdown_pct = ((mean_safe / mean_reg) - 1) * 100
    print("[TIME TEST] degradation: {}%".format(slowdown_pct))

def test_accuracy(layer_reg: baseLinLayer, layer_safe: baseLinLayer):
    """Print the norm of the difference between both layers' outputs on X."""
    outputs = [layer.forward(X) for layer in (layer_reg, layer_safe)]
    gap = np.linalg.norm(outputs[0] - outputs[1])
    print("Diff between reg matmul and safe matmul: {}".format(gap))

def test_bitflip(layer_reg: baseLinLayer, layer_safe: baseLinLayer, reg_pos: int, safe_pos: int):
    reg_out = layer_reg.forward(X)
    layer_reg.bitflip(reg_pos)
    reg_bitflip = layer_reg.forward(X)
    print("🔴 Without protection: {}".format(np.linalg.norm(reg_out - reg_bitflip)))

    safe_out = layer_safe.forward(X)
    layer_safe.bitflip(safe_pos)
    safe_bitflip = layer_safe.forward(X)
    print("🟢 With RS protection:".format(np.linalg.norm(safe_out - safe_bitflip)))

In [799]:
# Each layer gets its own copy of W: the layers keep a reference to the array
# they are given, and regLinLayer.bitflip overwrites one of its weights in
# place. Copying keeps W and the other layer's stored weights untouched.
reg = regLinLayer(W.copy())
safe = safeLinLayer(W.copy(), 0.99, 2)

In [800]:
#  -------------------------------------- Accuracy testing --------------------------------------
# Baseline check before any bit is flipped: prints the norm of the difference
# between the two layers' outputs. A small non-zero value is expected because
# the protected layer stores its PCA scores as float16.
test_accuracy(reg, safe)

Diff between reg matmul and safe matmul: 0.0011076020309701562


In [801]:
# #  -------------------------------------- Info about memory --------------------------------------
# reg.info()
# safe.info()

In [802]:
# #  -------------------------------------- Performance testing --------------------------------------
# test_time(reg, safe, 1)

In [803]:
# #  -------------------------------------- Bitflip testing --------------------------------------
# test_bitflip(reg, safe, 13, 7)