In [149]:
import warnings
warnings.filterwarnings("ignore")
from abc import ABC, abstractmethod
import numpy as np
import reedsolo
import time

class baseLinLayer(ABC):
    """Abstract base for the linear layers in this notebook.

    Holds the weight matrix and declares the interface every layer provides:
    ``info`` (description), ``bitflip`` (fault injection) and ``forward``
    (apply the layer to an input batch).
    """

    def __init__(self, weights):
        # Stored as-is (no copy): subclasses may read or mutate the caller's array.
        self.weights = weights

    @abstractmethod
    def info(self):
        """Describe the layer (subclass-defined)."""

    @abstractmethod
    def bitflip(self, k: int):
        """Inject a single-bit fault at bit position ``k`` (subclass-defined)."""

    @abstractmethod
    def forward(self, X):
        """Apply the layer to the input batch ``X`` and return the result."""


class regLinLayer(baseLinLayer):
    """Unprotected linear layer: ``forward`` computes ``X @ weights`` directly.

    ``bitflip`` corrupts one randomly chosen weight in place. The weights are
    not copied, so the corruption is also visible to any other holder of the
    same array.
    """

    def __init__(self, weights):
        super().__init__(weights=weights)
        # Duration (seconds) of each forward call, appended in call order.
        self.time_ = []

    def info(self):
        """Not implemented; returns None."""
        pass

    def bitflip(self, k: int):
        """Flip bit ``k`` of one randomly chosen weight, in place.

        The weight's bit pattern is reinterpreted as an unsigned integer of the
        same width (uint16 for float16, uint32 for float32, ...). Bit ``k`` is
        XOR-ed, and the result is written back with the original dtype. The
        index is drawn from the global ``np.random`` state.

        Args:
            k: bit position, ``0 <= k < 8 * itemsize`` (0..15 for float16).

        Raises:
            ValueError: if ``k`` is outside the dtype's bit range.
        """
        dtype = self.weights.dtype
        uint_dtype = np.dtype("u{}".format(dtype.itemsize))
        n_bits = 8 * dtype.itemsize
        if not 0 <= k < n_bits:
            raise ValueError("k must be in [0, {}), got {}".format(n_bits, k))

        rand_ind = np.random.randint(0, self.weights.size)
        source_number = self.weights.flat[rand_ind]
        # One-element buffer viewed as unsigned ints: flips bits without any
        # numeric conversion of the value.
        buf = np.array([source_number], dtype=dtype).view(uint_dtype)
        buf ^= uint_dtype.type(1 << k)
        corr_number = buf.view(dtype)[0]
        # `.flat` writes through even for non-contiguous arrays, whereas
        # `ravel()` would silently return (and modify) a copy in that case.
        self.weights.flat[rand_ind] = corr_number
        print("source number: {}, corr number: {}".format(source_number, corr_number))

    def forward(self, X):
        """Return ``X @ self.weights`` and record the elapsed time in ``time_``."""
        start = time.perf_counter()
        res = X @ self.weights
        self.time_.append(time.perf_counter() - start)
        return res
    

class safeLinLayer(baseLinLayer):
    """Linear layer whose weights are kept as a Reed-Solomon-protected truncated SVD.

    At construction the weights are factored as ``US_k @ Vt_k``, keeping the
    top ``k`` singular components as float16. Every row of ``US_k`` is
    Reed-Solomon encoded with ``nsym`` parity symbols; ``Vt_k`` is stored
    without protection. ``forward`` decodes ``US_k`` on every call, rebuilds
    the approximate weights and multiplies. ``bitflip`` is a no-op.

    Args:
        weights: 2-D weight matrix.
        k: number of singular components to keep.
        nsym: parity symbols per codeword, passed to ``reedsolo.RSCodec``.
    """

    def __init__(self, weights, k=8, nsym=2):
        super().__init__(weights=weights)
        self.nsym = nsym
        self.rs = reedsolo.RSCodec(nsym)
        self.k = k

        # Compression: keep only the top-k SVD factors, then protect the US_k rows.
        US_k, Vt_k = self.svd_compress(weights, k=k)
        self.Vt_k = Vt_k
        self.US_k_encoded = self.encode_RS_rows(US_k)
        # Duration (seconds) of each forward call, RS decoding included.
        self.time_ = []

    def info(self):
        """Not implemented; returns None."""
        pass

    def bitflip(self, k: int):
        """No-op: fault injection is not implemented for this layer."""
        pass

    def svd_compress(self, weights: np.ndarray, k: int):
        """Return the truncated SVD factors ``(U_k @ diag(S_k), Vt_k)`` as float16.

        The decomposition is computed in float32. For an (m, n) input the
        factors have shapes (m, r) and (r, n), where ``r = min(k, m, n)``.
        """
        weights_f32 = weights.astype(np.float32)
        U, S, Vt = np.linalg.svd(weights_f32, full_matrices=False)
        U_k = U[:, :k]
        S_k = S[:k]
        Vt_k = Vt[:k, :]
        # Column-wise scaling is the same as U_k @ np.diag(S_k), without
        # materialising the k x k diagonal matrix.
        US_k = U_k * S_k
        return US_k.astype(np.float16), Vt_k.astype(np.float16)

    def encode_RS_rows(self, matrix: np.ndarray):
        """Reed-Solomon encode the raw bytes of each row; return a 2-D uint8 array.

        All rows of a 2-D matrix have the same byte length, so all encoded rows
        have equal length and stack into a single array.
        """
        encoded = [
            np.frombuffer(self.rs.encode(row.tobytes()), dtype=np.uint8)
            for row in matrix
        ]
        return np.array(encoded)

    def decode_RS_rows(self, matrix_bytes: np.ndarray):
        """Decode RS-encoded rows of ``US_k`` back into a float16 matrix.

        A row that cannot be corrected is replaced by a row of NaNs, so the
        failure shows up in the output instead of raising.

        Args:
            matrix_bytes: 2-D uint8 array as produced by ``encode_RS_rows``.

        Returns:
            2-D float16 array with one column per retained singular component
            (``self.Vt_k.shape[0]``).
        """
        # The NaN fallback must match the decoded row length. The encoded row
        # size cannot be used for this, because it also counts the parity bytes,
        # and a mismatched row would make np.stack raise.
        n_values = self.Vt_k.shape[0]
        decoded = []
        for row in matrix_bytes:
            try:
                decoded_bytes = self.rs.decode(row.tobytes())[0]
                decoded.append(np.frombuffer(decoded_bytes, dtype=np.float16))
            except reedsolo.ReedSolomonError:
                decoded.append(np.full(n_values, np.nan, dtype=np.float16))
        return np.stack(decoded)

    def forward(self, X):
        """Decode ``US_k``, rebuild ``US_k @ Vt_k`` and return ``X @ weights_approx``.

        The elapsed time, decoding included, is appended to ``time_``. Rows that
        failed RS decoding are NaN, and the NaNs propagate into the output.
        """
        start = time.perf_counter()
        US_k = self.decode_RS_rows(self.US_k_encoded)
        weights_approx = US_k @ self.Vt_k
        res = X @ weights_approx
        self.time_.append(time.perf_counter() - start)
        return res

In [150]:
from tqdm import tqdm

n = 128
m = 128
SEED = 0

# Seed the global NumPy RNG: makes X, W and the random index used by
# regLinLayer.bitflip reproducible across Restart & Run All.
np.random.seed(SEED)

# Input batch (n x m) and weight matrix (m x n), both float16.
X = np.random.randn(n, m).astype(np.float16)
W = np.random.randn(m, n).astype(np.float16)


def test_accuracy(layer_reg: baseLinLayer, layer_safe: baseLinLayer):
    reg_out = layer_reg.forward(X)
    safe_out = layer_safe.forward(X)
    print("Diff between reg matmul and safe matmul: {}".format(np.linalg.norm(reg_out - safe_out)))

In [151]:
# regLinLayer.bitflip mutates its weights in place, so give the unprotected
# layer its own copy rather than aliasing the notebook-level W.
reg = regLinLayer(W.copy())
# Keep k=128 singular components (with n = m = 128 this keeps all of them)
# and use 32 Reed-Solomon parity symbols per codeword.
safe = safeLinLayer(W, k=128, nsym=32)

In [152]:
# ------------------------------ Accuracy testing ------------------------------
# Compare the unprotected and the RS-protected layer on the same input batch.
test_accuracy(layer_reg=reg, layer_safe=safe);

Diff between reg matmul and safe matmul: 0.583984375
