In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import ast

In [2]:
def load_hex_weights(file_path):
    """Load signed 8-bit weights from a text file of hexadecimal tokens.

    The file is read as whitespace-separated tokens (spaces and line breaks
    are equivalent separators). Each token is parsed as a base-16 integer;
    values above 127 are reinterpreted as two's-complement negatives by
    subtracting 256, so ``FF`` becomes -1 and ``80`` becomes -128.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A 1-D ``np.int8`` array holding the weights in file order. An empty
        file yields an empty array.

    Raises:
        ValueError: If a token is not valid hexadecimal, or if its signed
            value does not fit in 8 bits (for example ``1FF``).
        OSError: If the file cannot be opened.
    """
    with open(file_path, 'r') as f:
        tokens = f.read().split()

    weights = []
    for token in tokens:
        value = int(token, 16)
        if value > 127:
            # Upper half of the byte range encodes negatives in two's complement.
            value -= 256
        if not -128 <= value <= 127:
            # Without this check an oversized token would either wrap silently
            # or fail inside NumPy with no indication of which token was bad.
            raise ValueError(
                f"hex token {token!r} in {file_path!r} does not fit in a signed 8-bit weight"
            )
        weights.append(value)

    return np.array(weights, dtype=np.int8)

In [3]:
# Load each layer's flat int8 weight vector and reshape it in a single step.
# NOTE(review): the 4-D shapes look like a (out_channels, in_channels, kH, kW)
# convolution layout and the 2-D shapes like (out_features, in_features) —
# confirm against the model definition.
conv1_weight = load_hex_weights('data/conv1_weight.txt').reshape(10, 1, 3, 3)  # 90 values
conv2_weight = load_hex_weights('data/conv2_weight.txt').reshape(1, 10, 3, 3)  # 90 values
fc1_weight = load_hex_weights('data/fc1_weight.txt').reshape(10, 132)  # 1320 values
fc2_weight = load_hex_weights('data/fc2_weight.txt').reshape(1, 10)  # 10 values

In [4]:
# Print the conv1 weights, reshaped above to (10, 1, 3, 3), for inspection.
print(conv1_weight)

[[[[-22  -1  16]
   [-10  -5 -18]
   [ -7  17  -9]]]


 [[[  3  10  -5]
   [  3  -6   2]
   [-21   6  -3]]]


 [[[-18  22   8]
   [ -6  14  -1]
   [ -5  -3  -8]]]


 [[[  5  -3   2]
   [-15  14   0]
   [-10  -3   6]]]


 [[[ -5  -1  -7]
   [  1   4  -9]
   [ -6  -8   4]]]


 [[[  0   3  19]
   [  4  -9  96]
   [ -5  -3  -4]]]


 [[[  7  -2  -4]
   [ 14   4 -13]
   [-20  10  -3]]]


 [[[  9  -2   7]
   [  7  -7  11]
   [  8   5   3]]]


 [[[-16   3   2]
   [ 13  22  14]
   [ -6 -13   2]]]


 [[[  3 -12  -2]
   [  1  -2   1]
   [ -6  14 -12]]]]


In [11]:
# 1. Read the sample file as whitespace-separated tokens; runs of spaces and
#    line breaks are all treated as separators.
with open("data/sample_input.txt", "r") as sample_file:
    tokens = sample_file.read().split()

# 2. Parse every token as a decimal integer and store the result as int32.
data = np.array([int(token) for token in tokens], dtype=np.int32)

# 3. Report the element count; 5*16*15 = 1200 values are expected.
print("总数据量:", len(data))

# 4. Arrange the flat array into shape (5, 1, 16, 15).
sample_input = data.reshape(5, 1, 16, 15)

# 5. No torch tensor is created in this cell; sample_input stays a NumPy array.

# 6. Lift NumPy's print truncation threshold so the whole array is shown.
#    This setting is global and also applies to arrays printed afterwards.
np.set_printoptions(threshold=np.inf)
print("Sample Input:", sample_input)


总数据量: 1200
Sample Input: [[[[ 28  70 206 147 230  53  57 187 179 178  20 125 222  60 191]
   [198  61  34   3 204 220 170 169  45 169 245 248  38 199 104]
   [148   0 250 248 157  40 118 131 241 182  38 245 241   5 167]
   [ 90  10  98 244 221 191  86 239 166 174 199  93 192  89 100]
   [253  74  24  42   9 172  62  27  42  54 137 225 208  28  17]
   [142  42  63   2 120 240  53 210  18 146 114 229 253 106 239]
   [ 54 110  49  17 139 133 201  44 128  39 132 144  73 128  61]
   [208 197  74 165 236 107 189 220 192  57 227 126   9  72  76]
   [197   2 196  66  52  14 182 254  61 149  45 197 110  85 176]
   [179  42  85  13 207 156  70 183  28  37 252 253  37  52 174]
   [ 67  88 173 146  98  48 201 219 173 229 126 194 232 147 146]
   [138 204 154  99  43 202 121 125 154 188 203  13 222  96 198]
   [ 53 182   3 193 206 183 135 206 123 127 104 187 225 121  72]
   [193 236 168 127  27  99 123 148  44 189  15 240 199 207  23]
   [212 225  62   4 193  72 201 152  58  78   0  40  91  45 196]
