In [150]:
import numpy as np
import reedsolo

def encode_RS_rows(matrix: np.ndarray, nsym: int):
    rs = reedsolo.RSCodec(nsym)
    encoded = []
    for row in matrix:
        bytes_row = row.astype(np.float16).tobytes()
        encoded_bytes = rs.encode(bytes_row)
        encoded.append(np.frombuffer(encoded_bytes, dtype=np.uint8))
    return np.array(encoded)

def decode_RS_rows(matrix_bytes: np.ndarray, nsym: int, original_len: int):
    rs = reedsolo.RSCodec(nsym)
    decoded = []
    for row in matrix_bytes:
        try:
            decoded_bytes = rs.decode(row.tobytes())[0]
            decoded.append(np.frombuffer(decoded_bytes, dtype=np.float16))
        except reedsolo.ReedSolomonError:
            decoded.append(np.full(original_len, np.nan))  # ошибка — заполним NaN
    return np.stack(decoded)

In [151]:
def flip_random_bit(arr, k):
    arr_copy = arr.copy()
    arr_uint32 = arr_copy.view(np.int8)
    if arr_uint32.size == 0:
        return arr_copy  # Пустой массив
    random_index = np.random.randint(0, arr_uint32.size)
    number = arr_uint32.flat[random_index]
    arr_uint32.flat[random_index] ^= np.int8(1 << k)
    print("[DECODED] source number: {}, flipped number: {}".format(number, arr_uint32.flat[random_index]))
    return arr_copy

def float_flip_random_bit(arr, k):
    arr_copy = arr.copy()
    arr_uint32 = arr_copy.view(np.int16)
    if arr_uint32.size == 0:
        return arr_copy  # Пустой массив
    random_index = np.random.randint(0, arr_uint32.size)
    number = arr_uint32.flat[random_index]
    arr_uint32.flat[random_index] ^= np.uint16(1 << k)
    print("[NO DECODED] source number: {}, flipped number: {}".format(number, arr_uint32.flat[random_index]))
    return arr_copy

In [152]:
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def plot_hist(arr, name, min, max):
    plt.figure(figsize=(10, 5))
    plt.hist(arr.flatten(), bins=range(min, max), edgecolor='black')
    plt.title(name)
    plt.xlabel("Значение")
    plt.ylabel("Частота")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def pca_method(array, rate, info):
    pca = PCA()
    pca.fit(array)
    cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
    k = np.argmax(cumulative_variance >= rate) + 1
    print("[{}] for rate {} k = {}".format(info, rate, k))
    # return pca.transform(array)[:, :k], pca.components_[:k]
    C = pca.components_[:k]
    mu = pca.mean_
    Z = (array - mu) @ C.T
    return Z, C, mu

def regular_matmul(X, W, bit):
    # bit = 13
    Y_correct = X @ W
    X_raw_corrupted = X.copy()
    X_raw_corrupted = float_flip_random_bit(X_raw_corrupted, 13)
    Y_broken = X_raw_corrupted @ W
    print("🔴 Без защиты:", np.linalg.norm(Y_broken - Y_correct))

def rs_matmul(X, W, bit, nsym, info):
    # bit = 7
    # plt.plot(np.sort(X.flatten()))
    X_encoded = encode_RS_rows(X, nsym=nsym)
    print("[INFO] array: {}".format(X_encoded))
    # plot_hist(X_encoded, "Encoded X", min=0, max=300)
    Y_correct = X @ W
    X_encoded_corrupted = X_encoded.copy()
    X_encoded_corrupted = flip_random_bit(X_encoded_corrupted, bit)
    X_restored = decode_RS_rows(X_encoded_corrupted, nsym=nsym, original_len=X.shape[1])
    Y_rs = X_restored @ W
    print("🟢 С RS-восстановлением:", np.linalg.norm(Y_rs - Y_correct))
    if info:
        print("type: {}, size: {}, bytes: {}".format(X.dtype, X.size, 16 * X.size))
        print("type: {}, size: {}, bytes: {}".format(X_encoded.dtype, X_encoded.size, 8 * X_encoded.size))
        print("type: {}, size: {}, bytes: {}".format(X_restored.dtype, X_restored.size, 16 * X_restored.size))
        print("rate of bytes: {}".format((8 * X_encoded.size) / (16 * X.size)))


In [153]:
import time

nsym = 2
n = 256
m = 256

# X = np.random.randn(n, m).astype(np.float16)
# W = np.random.randn(m, n).astype(np.float16)

X = np.random.randn(4, 5).astype(np.float16)
W = np.random.randn(5, 4).astype(np.float16)

start = time.time()
regular_matmul(X, W, 13)
total_regular = time.time() - start
print("-----------------[regular_matmul] Time: {} -----------------", total_regular)

start = time.time()
rs_matmul(X, W, 6, nsym, True)
total_rs = time.time() - start
print("-----------------[rs_matmul] Time: {} -----------------", total_rs)
print("Degradation: {}".format(total_rs / total_regular))

[NO DECODED] source number: -18423, flipped number: -26615
🔴 Без защиты: 1.008
-----------------[regular_matmul] Time: {} ----------------- 0.0002593994140625
[INFO] array: [[ 52  59   9 184 164  52 135  48 210  60 137 254]
 [238 180 148 187 109 183 228 184 141  61 230 165]
 [ 14  42 240 189  79 191 191  55 179 182 136 156]
 [132 190   6  51 247 176  23  57  16  60 241 187]]
[DECODED] source number: -120, flipped number: -56
🟢 С RS-восстановлением: 0.0
type: float16, size: 20, bytes: 320
type: uint8, size: 48, bytes: 384
type: float16, size: 20, bytes: 320
rate of bytes: 1.2
-----------------[rs_matmul] Time: {} ----------------- 0.0008800029754638672
Degradation: 3.3924632352941178
