In [None]:
import numpy as np
from scipy.sparse import csr_matrix

# Generate a random binary 3-D input tensor of shape (C, H, W). Each entry is 1
# with probability 0.2, so the tensor is roughly 80% zeros (i.e. sparse).
SEED = 0  # fixed seed so the printed tensor and CSR arrays are reproducible
C, H, W = 2, 4, 4
rng = np.random.default_rng(SEED)
input_matrix = rng.choice([0, 1], size=(C, H, W), p=[0.8, 0.2])
print(input_matrix)

# Flatten the 3-D tensor into a 2-D matrix: one row per channel, each row is
# that channel's H x W plane laid out row-major, giving shape (C, H*W).
flattened_matrix = input_matrix.reshape(C, H * W)
print(flattened_matrix)

# Convert the 2-D matrix to CSR (compressed sparse row) format.
csr_matrix_form = csr_matrix(flattened_matrix)

# Collect the raw CSR arrays so they can be inspected.
csr_info = {
    'data': csr_matrix_form.data,
    'indices': csr_matrix_form.indices,
    'indptr': csr_matrix_form.indptr,
    'shape': csr_matrix_form.shape
}
csr_info

In [None]:
def csr_img2col(csr_matrix, kernel_size, stride=1):
    """Unfold a 2-D sparse image into a sparse "img2col" matrix.

    Row ``i * output_width + j`` of the result holds the
    ``kernel_size x kernel_size`` window whose top-left corner is at
    ``(i * stride, j * stride)``, flattened row-major, so column
    ``ki * kernel_size + kj`` is window offset ``(ki, kj)``. Only windows that
    lie fully inside the image are produced (no padding).

    Rather than probing every window position with per-element sparse
    indexing, this walks the input's nonzeros once and scatters each one into
    every window that covers it, so the work scales with the number of
    nonzeros instead of the image area.

    Parameters
    ----------
    csr_matrix : 2-D scipy sparse matrix or array-like
        Input image of shape (H, W); anything ``scipy.sparse.coo_matrix``
        accepts (a CSR matrix or a dense 2-D ndarray both work).
    kernel_size : int
        Side length of the square window; must be >= 1.
    stride : int, default 1
        Step between consecutive windows; must be >= 1.

    Returns
    -------
    scipy.sparse.csr_matrix
        Shape ``(output_height * output_width, kernel_size ** 2)`` with sorted
        column indices and no stored zeros. It has zero rows when the kernel
        does not fit inside the image.

    Raises
    ------
    ValueError
        If ``kernel_size`` or ``stride`` is less than 1.
    """
    # Local imports: the parameter name shadows scipy.sparse.csr_matrix.
    import numpy as np
    from scipy.sparse import coo_matrix

    if kernel_size < 1 or stride < 1:
        raise ValueError("kernel_size and stride must both be >= 1")

    H, W = csr_matrix.shape
    # Clamp at 0 so a kernel larger than the image yields an empty result
    # instead of a negative (invalid) output shape.
    output_height = max(0, (H - kernel_size) // stride + 1)
    output_width = max(0, (W - kernel_size) // stride + 1)

    # Canonical COO copy: one entry per (row, col), duplicates summed, and the
    # caller's matrix is never modified.
    coo = coo_matrix(csr_matrix, copy=True)
    coo.sum_duplicates()

    out_rows, out_cols, out_data = [], [], []
    for r, c, value in zip(coo.row.tolist(), coo.col.tolist(), coo.data):
        if value == 0:  # skip explicitly stored zeros
            continue
        # Window i covers input row r iff i*stride <= r < i*stride + kernel_size,
        # i.e. ceil((r - kernel_size + 1) / stride) <= i <= floor(r / stride);
        # the ceiling is computed as -((kernel_size - 1 - r) // stride).
        i_first = max(0, -((kernel_size - 1 - r) // stride))
        i_last = min(output_height - 1, r // stride)
        j_first = max(0, -((kernel_size - 1 - c) // stride))
        j_last = min(output_width - 1, c // stride)
        for i in range(i_first, i_last + 1):
            ki = r - i * stride  # row offset of (r, c) inside window row i
            for j in range(j_first, j_last + 1):
                kj = c - j * stride  # column offset inside window column j
                out_rows.append(i * output_width + j)
                out_cols.append(ki * kernel_size + kj)
                out_data.append(value)

    col_matrix = coo_matrix(
        (
            np.asarray(out_data, dtype=coo.data.dtype),
            (np.asarray(out_rows, dtype=np.intp), np.asarray(out_cols, dtype=np.intp)),
        ),
        shape=(output_height * output_width, kernel_size ** 2),
    ).tocsr()
    # Guarantee row-wise sorted column indices (window offsets in row-major order).
    col_matrix.sort_indices()
    return col_matrix


In [None]:
import numpy as np
from scipy.sparse import random

# Parameters for a synthetic sparse 28x28 image.
H, W = 28, 28  # image height and width
kernel_size = 3  # convolution kernel size
stride = 1  # convolution stride
SEED = 0  # fixed seed so the sparse image (and every result below) is reproducible

# Generate a 28x28 sparse CSR matrix with density 0.1 (about 10% nonzeros).
sparse_matrix = random(H, W, density=0.1, format='csr', dtype=np.float32, random_state=SEED)

# Run the CSR img2col transform.
csr_col_matrix = csr_img2col(sparse_matrix, kernel_size, stride)

# Cross-check against a dense reference: each row is one flattened kernel window.
dense_image = sparse_matrix.toarray()
out_h = (H - kernel_size) // stride + 1
out_w = (W - kernel_size) // stride + 1
reference = np.array([
    dense_image[i * stride:i * stride + kernel_size, j * stride:j * stride + kernel_size].ravel()
    for i in range(out_h)
    for j in range(out_w)
])
assert csr_col_matrix.shape == (out_h * out_w, kernel_size ** 2)
assert np.array_equal(csr_col_matrix.toarray(), reference)

# Show the result's shape and number of stored nonzeros.
csr_col_matrix.shape, csr_col_matrix.nnz
