将按一列展平存储的 RSA 数据还原为对称矩阵，再绘制成热图并保存（矩阵保存为 CSV，热图保存为 PNG）

In [16]:
import math
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [17]:
def reconstruct_matrix(A, skip=1):
    """Rebuild a symmetric matrix from its flattened strict lower triangle.

    The first ``skip`` entries of ``A`` are discarded (by default one --
    presumably a header cell that ``pd.read_csv(header=None)`` reads as data;
    NOTE(review): confirm against the input files). The remaining
    ``num * (num - 1) / 2`` values are written into the strict lower triangle
    in the order produced by ``np.tril_indices(num, k=-1)``, i.e.
    (1,0), (2,0), (2,1), (3,0), ... and then mirrored into the upper triangle.
    The diagonal stays 0.

    NOTE(review): this ordering differs from
    ``scipy.spatial.distance.squareform`` (upper triangle, row-major);
    confirm it matches how the RSA vectors were flattened.

    Args:
        A: 1-D sequence; every entry after the first ``skip`` must be
            convertible to float.
        skip: number of leading entries to drop before reconstruction.

    Returns:
        A ``(num, num)`` float ndarray, symmetric with a zero diagonal.

    Raises:
        ValueError: if the number of remaining values is not of the form
            ``num * (num - 1) / 2``.
    """
    # Slice before converting: the skipped entry may be non-numeric header text.
    values = np.asarray(A[skip:], dtype=float)
    colnum = len(values)
    # Solve num * (num - 1) / 2 == colnum, i.e. num = (1 + sqrt(1 + 8 * colnum)) / 2,
    # in exact integer arithmetic (no float rounding for large inputs).
    num = (1 + math.isqrt(1 + 8 * colnum)) // 2
    if num * (num - 1) // 2 != colnum:
        raise ValueError(f"输入向量维度 {colnum} 无法满足 num*(num-1)/2 公式。")
    B = np.zeros((num, num))
    rows, cols = np.tril_indices(num, k=-1)
    B[rows, cols] = values
    # Only k=-1 cells were filled, so the diagonal of B + B.T is already 0.
    return B + B.T

def import_data_from_csv(file_path):
    """Load every cell of a CSV file into a flat 1-D array.

    ``header=None`` makes pandas treat the first row as data, so any header
    text present in the file becomes the leading element(s) of the result.
    Multi-column files are flattened row by row.

    Args:
        file_path: path to the CSV file.

    Returns:
        A 1-D numpy array (a copy) of the cell values in row-major order.
    """
    frame = pd.read_csv(file_path, header=None)
    table = frame.values
    return table.flatten()

def save_matrix_to_csv(matrix, file_path):
    """Write a 2-D matrix to ``file_path`` as comma-separated text.

    Every value is formatted with four decimal places; no header row is written.
    """
    np.savetxt(
        file_path,
        matrix,
        fmt='%.4f',
        delimiter=',',
    )

def plot_heatmap(matrix, file_path):
    """Render ``matrix`` as a heatmap and save it to ``file_path`` as a PNG.

    The title is "Heatmap of <name>", where <name> is the output file's base
    name without its extension. Axis ticks are hidden and a colorbar is added.

    The figure is closed in a ``finally`` block, so a failing ``savefig``
    (e.g. a missing output directory) does not leave an open figure behind
    in a batch loop or notebook session.

    Args:
        matrix: 2-D array-like to display.
        file_path: destination path for the PNG image.
    """
    file_name = os.path.splitext(os.path.basename(file_path))[0]
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        image = ax.imshow(matrix, cmap='coolwarm', interpolation='nearest')
        fig.colorbar(image, ax=ax)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f'Heatmap of {file_name}')
        fig.savefig(file_path, format='png')
    finally:
        plt.close(fig)


def main(csv_input_file, output_matrix_file, output_heatmap_file):
    """Run the pipeline for one file: load, rebuild matrix, save CSV, save heatmap.

    Args:
        csv_input_file: path of the flattened input CSV.
        output_matrix_file: destination for the reconstructed matrix (CSV text).
        output_heatmap_file: destination for the heatmap image (PNG).
    """
    flat_values = import_data_from_csv(csv_input_file)
    square = reconstruct_matrix(flat_values)
    # Write the matrix first, then the image -- same order as before.
    for writer, destination in (
        (save_matrix_to_csv, output_matrix_file),
        (plot_heatmap, output_heatmap_file),
    ):
        writer(square, destination)

In [12]:
# Single-file run (no loop).
csv_input_file = 'RSA_GPT_MASQ-DS.csv'
output_matrix_file = 'output_matrix.csv'
output_heatmap_file = 'heatmap_output.png'
main(
    csv_input_file=csv_input_file,
    output_matrix_file=output_matrix_file,
    output_heatmap_file=output_heatmap_file,
)

In [None]:
# Batch run: process every CSV file in folder_path and write results to result_path.
folder_path = 'E:/personal/Documents/RSA-num/'
result_path = 'E:/personal/Documents/result/'
# np.savetxt / savefig do not create missing directories, so ensure the output folder exists.
os.makedirs(result_path, exist_ok=True)
# os.listdir returns entries in arbitrary order; sort for a reproducible run.
for filename in sorted(os.listdir(folder_path)):
    if not filename.lower().endswith('.csv'):
        continue
    csv_input_file = os.path.join(folder_path, filename)
    print(csv_input_file)
    base_filename = os.path.splitext(filename)[0]
    output_matrix_file = os.path.join(result_path, f'{base_filename}_matrix.csv')
    output_heatmap_file = os.path.join(result_path, f'{base_filename}_heatmap.png')
    try:
        main(csv_input_file, output_matrix_file, output_heatmap_file)
    except ValueError as err:
        # A malformed file (e.g. a non-triangular vector length) should not abort the batch.
        print(f'Skipped {filename}: {err}')