In [None]:
import os
import json
import numpy as np
from Bio.PDB import MMCIFParser
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt

# Update the global matplotlib style used by every figure created below
plt.rcParams.update({
    'font.family': 'Arial',  # font family
    'axes.titlesize': 42,  # title font size
    'axes.labelsize': 42,  # axis label font size
    'xtick.labelsize': 42,  # x-axis tick label font size
    'ytick.labelsize': 42,  # y-axis tick label font size
    'legend.fontsize': 36,  # legend font size
    'figure.figsize': (12, 6),  # default figure size
    'axes.linewidth': 4,  # axis (spine) line width
    'xtick.major.size': 10, # x-axis major tick length
    'ytick.major.size': 10, # y-axis major tick length
    'xtick.minor.size': 5, # x-axis minor tick length
    'ytick.minor.size': 5, # y-axis minor tick length
    'xtick.major.width': 4, # x-axis major tick line width
    'ytick.major.width': 4, # y-axis major tick line width
    'xtick.minor.width': 4, # x-axis minor tick line width
    'ytick.minor.width': 4, # y-axis minor tick line width
    'axes.grid': False,  # turn the background grid off
})

# Load the PAE matrix from a JSON file
def load_pae_matrix(json_file):
    """Read the ``'pae'`` entry of a JSON file and return it as a NumPy array.

    Args:
        json_file: Path of the JSON file to read. Its top-level object must
            contain a ``'pae'`` key.

    Returns:
        ``np.ndarray`` built from the value stored under ``'pae'``.

    Raises:
        KeyError: If the JSON object has no ``'pae'`` key.
    """
    with open(json_file, 'r') as handle:
        payload = json.load(handle)
    return np.array(payload['pae'])

# Extract the coordinates of CA and P atoms
def extract_calpha_and_p_coords_from_cif(cif_file):
    """Collect the coordinates of every atom named ``CA`` or ``P`` in a CIF file.

    The structure is walked model by model, chain by chain, residue by
    residue, and atoms are kept in that traversal order.

    Args:
        cif_file: Path of the mmCIF file to parse.

    Returns:
        A tuple ``(coords, atom_types)``: ``coords`` is an ``np.ndarray`` of
        the selected atoms' coordinates and ``atom_types`` is the list of the
        matching atom names (``"CA"`` or ``"P"``), in the same order.

    Raises:
        ValueError: If the file contains no atom named ``CA`` or ``P``.
    """
    wanted_names = ("CA", "P")
    structure = MMCIFParser(QUIET=True).get_structure("model", cif_file)

    # Selection is by atom name only; the residue type is not inspected.
    selected = [
        atom
        for model in structure
        for chain in model
        for residue in chain
        for atom in residue
        if atom.name in wanted_names
    ]

    if len(selected) == 0:
        raise ValueError(f"No valid C-alpha or P atoms found in the CIF file: {cif_file}")

    coords = np.array([atom.coord for atom in selected])
    atom_types = [atom.name for atom in selected]
    return coords, atom_types

# Compute the distance matrix
def calculate_distance_matrix(coords1, coords2):
    """Return the pairwise Euclidean distances between two coordinate sets.

    Args:
        coords1: Array-like of points, one point per row.
        coords2: Array-like of points, one point per row.

    Returns:
        ``np.ndarray`` whose entry ``[i, j]`` is the distance between
        ``coords1[i]`` and ``coords2[j]``.
    """
    # "euclidean" is cdist's default metric; it is only spelled out here.
    pairwise = cdist(coords1, coords2, metric="euclidean")
    return pairwise

# Compute the mean q value for a range of rows
def calculate_average_q_for_specific_rows(pae_matrix, dist_matrix, start_row, end_row, max_col=None, threshold=13):
    """Average ``q = pae / distance`` over close contacts, one value per row.

    For each row ``i`` in ``start_row..end_row`` (inclusive), the columns
    ``j`` whose distance ``dist_matrix[i, j]`` is positive and below
    ``threshold`` are selected, and ``pae_matrix[i, j] / dist_matrix[i, j]``
    is averaged over them.

    Args:
        pae_matrix: 2-D array-like of PAE values.
        dist_matrix: 2-D array-like of pairwise distances, indexed like
            ``pae_matrix``.
        start_row: First row index to process (inclusive).
        end_row: Last row index to process (inclusive).
        max_col: If given, only columns ``0..max_col`` (inclusive) are
            considered; ``None`` means all columns.
        threshold: Exclusive upper bound on the distance of a contact.

    Returns:
        List with one entry per processed row: the mean q value, or ``0``
        when no column satisfies the distance criterion.
    """
    pae_matrix = np.asarray(pae_matrix)
    dist_matrix = np.asarray(dist_matrix)

    # Loop-invariant column bound; slicing with None keeps every column.
    col_stop = None if max_col is None else max_col + 1

    row_means = []
    for i in range(start_row, end_row + 1):
        row_dist = dist_matrix[i, :col_stop]
        # Zero distances (the row's own column, or coincident atoms) are
        # excluded: dividing by them would yield inf/nan and poison the mean.
        valid_indices = np.flatnonzero((row_dist > 0) & (row_dist < threshold))
        if valid_indices.size > 0:
            q_values = pae_matrix[i, valid_indices] / row_dist[valid_indices]
            row_means.append(np.mean(q_values))
        else:
            row_means.append(0)  # no qualifying contact: the mean is defined as 0
    return row_means

# Average the per-row results over several folders
def calculate_average_across_folders(base_path, start_row, end_row, max_col, threshold):
    """Average the per-row q values over every sub-folder of ``base_path``.

    Each sub-folder must contain one file ending in ``_full_data_0.json``
    (PAE matrix) and one ending in ``_model_0.cif`` (structure). Folders
    missing either file are reported on stdout and skipped.

    Args:
        base_path: Directory whose immediate sub-directories are processed.
        start_row: First row index to evaluate (inclusive).
        end_row: Last row index to evaluate (inclusive).
        max_col: Largest column index considered, or ``None`` for all columns.
        threshold: Exclusive upper bound on the contact distance.

    Returns:
        ``np.ndarray`` with one value per row in ``start_row..end_row``: the
        mean over all processed folders.

    Raises:
        ValueError: If no sub-folder provided both required files, since an
            average over zero folders is undefined.
    """
    all_row_means = []

    # sorted() makes the processing (and log) order reproducible; os.listdir
    # returns entries in arbitrary order.
    for folder_name in sorted(os.listdir(base_path)):
        folder_path = os.path.join(base_path, folder_name)
        if not os.path.isdir(folder_path):
            continue

        # Locate the JSON and CIF files
        try:
            json_file = find_files_with_suffix(folder_path, "_full_data_0.json")
            cif_file = find_files_with_suffix(folder_path, "_model_0.cif")
        except FileNotFoundError:
            print(f"Skipping folder {folder_path} due to missing files.")
            continue

        # Load the PAE matrix and build the matching distance matrix
        # NOTE(review): this assumes the CA/P atoms extracted from the CIF line
        # up one-to-one with the PAE matrix rows/columns - confirm for inputs
        # that contain ligands, ions or modified residues.
        pae_matrix = load_pae_matrix(json_file)
        coords, _ = extract_calpha_and_p_coords_from_cif(cif_file)
        dist_matrix = calculate_distance_matrix(coords, coords)

        # Mean q value of every requested row for this folder
        row_means = calculate_average_q_for_specific_rows(pae_matrix, dist_matrix, start_row, end_row, max_col, threshold)
        all_row_means.append(row_means)

    if not all_row_means:
        # np.mean over an empty array would silently return NaN.
        raise ValueError(f"No folder with both a JSON and a CIF file found in {base_path}")

    # Average over folders (axis 0), keeping one value per row
    return np.mean(np.array(all_row_means), axis=0)

# Draw the grouped bar chart of the averaged results
def plot_average_comparison(avg_row_means1, avg_row_means2, start_row, end_row, label1, label2,
                            *, residue_offset=494, title="13 Å",
                            output_file="Si_USBS_ri_13A.png", tick_step=10):
    """Plot two series of per-row averages as side-by-side bars and save the figure.

    Args:
        avg_row_means1: Values of the first series, one per row in
            ``start_row..end_row``; drawn in red.
        avg_row_means2: Values of the second series, same length; drawn in blue.
        start_row: First row index covered by the series (inclusive).
        end_row: Last row index covered by the series (inclusive).
        label1: Legend label of the first series.
        label2: Legend label of the second series.
        residue_offset: Value subtracted from a row index to obtain the number
            shown on the x axis.
        title: Figure title.
        output_file: Path the figure is saved to (300 dpi).
        tick_step: An x tick label is shown every ``tick_step`` bars.

    Returns:
        None. The figure is written to ``output_file`` and then shown.
    """
    plt.figure(figsize=(12, 6))
    x_labels = range(start_row - residue_offset, end_row + 1 - residue_offset)
    bar_width = 0.45
    x_positions1 = np.arange(len(x_labels))
    x_positions2 = x_positions1 + bar_width  # second series sits right of the first

    # Draw the two groups of bars
    plt.bar(x_positions1, avg_row_means1, width=bar_width, label=label1, color='tab:red')
    plt.bar(x_positions2, avg_row_means2, width=bar_width, label=label2, color='tab:blue')

    # Label only every tick_step-th bar; slicing positions and labels with the
    # same step keeps them aligned.
    plt.xticks(x_positions1[::tick_step], x_labels[::tick_step], rotation=0)
    plt.yticks([0, 1, 2])

    # Tick direction: outward on x, inward on y
    plt.tick_params(axis='x', which='both', direction='out')
    plt.tick_params(axis='y', which='both', direction='in')

    plt.xlabel("Residue index of Brn1")
    plt.ylabel(r"$\text{ES}_\text{i}$")
    plt.title(title)
    plt.legend(frameon=False)
    plt.tight_layout()
    # Save before show(): show() may leave an empty figure behind on some backends.
    plt.savefig(output_file, dpi=300)
    plt.show()

# Find a file with a given suffix
def find_files_with_suffix(base_path, suffix):
    """Return the path of a file in ``base_path`` whose name ends with ``suffix``.

    When several entries match, the lexicographically smallest name is
    chosen so that the result does not depend on the arbitrary order in
    which ``os.listdir`` reports directory entries.

    Args:
        base_path: Directory to search (not recursive).
        suffix: Required ending of the file name.

    Returns:
        ``base_path`` joined with the selected file name.

    Raises:
        FileNotFoundError: If no entry of ``base_path`` ends with ``suffix``.
    """
    first_match = min(
        (name for name in os.listdir(base_path) if name.endswith(suffix)),
        default=None,
    )
    if first_match is None:
        raise FileNotFoundError(f"No file with suffix '{suffix}' found in {base_path}")
    return os.path.join(base_path, first_match)

# Main program
def main():
    """Average the q values of both data sets, print them and plot the comparison.

    Rows 877-920 are evaluated against columns 0-870 with a distance
    threshold of 13; the "withoutDNA" set is plotted as "US" and the
    "withDNA" set as "BS".
    """
    # Shared analysis parameters
    start_row, end_row = 877, 920
    max_col = 870  # restrict the columns to j <= max_col
    threshold = 13

    # The two top-level result directories, keyed by data-set label
    paths = {
        "withoutDNA": "/Users/chenjinyu/Desktop/SMC/Results/PAE/withoutDNA",
        "withDNA": "/Users/chenjinyu/Desktop/SMC/Results/PAE/withDNA"
    }

    # Averaged per-row result of each data set
    avg_results = {}
    for label, base_path in paths.items():
        avg_results[label] = calculate_average_across_folders(
            base_path, start_row, end_row, max_col, threshold
        )
        print(f"Average results for {label}: {avg_results[label]}")

    # Side-by-side bar chart of the two averaged series
    plot_average_comparison(
        avg_results["withoutDNA"],
        avg_results["withDNA"],
        start_row,
        end_row,
        "US",
        "BS",
    )
# Run the main program only when this file is executed as a script
if __name__ == "__main__":
    main()