In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
from Bio.PDB import MMCIFParser
from scipy.spatial.distance import cdist
from matplotlib import gridspec

# Global plotting defaults, applied through matplotlib's rcParams.
plt.rcParams.update({
    'font.family': 'Arial',  # font family
    'axes.titlesize': 36,  # axes title font size
    'axes.labelsize': 36,  # axis label font size
    'xtick.labelsize': 36,  # x-axis tick label font size
    'ytick.labelsize': 36,  # y-axis tick label font size
    'legend.fontsize': 36,  # legend font size
    'figure.figsize': (10, 8),  # default figure size
    'axes.linewidth': 4,  # axes edge line width
    'xtick.major.size': 10,  # x-axis major tick length
    'ytick.major.size': 10,  # y-axis major tick length
    'xtick.minor.size': 5,  # x-axis minor tick length
    'ytick.minor.size': 5,  # y-axis minor tick length
    'xtick.major.width': 4,  # x-axis major tick line width
    'ytick.major.width': 4,  # y-axis major tick line width
    'xtick.minor.width': 4,  # x-axis minor tick line width
    'ytick.minor.width': 4,  # y-axis minor tick line width
    'axes.grid': False,  # turn off the background grid
})

def extract_calpha_and_p_coords_from_cif(cif_file):
    """Collect the coordinates of every atom named "CA" or "P" in an mmCIF file.

    Atoms are visited in model -> chain -> residue -> atom iteration order,
    across every model the parser returns.

    Args:
        cif_file: The mmCIF input passed to ``MMCIFParser.get_structure``.

    Returns:
        A ``(coords, atom_types)`` tuple. ``coords`` is a NumPy array built
        from the ``coord`` attribute of each selected atom, and ``atom_types``
        is the parallel list of atom names ("CA" or "P").

    Raises:
        ValueError: If the structure contains no atom named "CA" or "P".
    """
    wanted_names = {"CA", "P"}
    structure = MMCIFParser(QUIET=True).get_structure("model", cif_file)

    # Flatten the structure hierarchy, keeping only the atoms of interest.
    selected = [
        atom
        for model in structure
        for chain in model
        for residue in chain
        for atom in residue
        if atom.name in wanted_names
    ]
    if not selected:
        raise ValueError(f"No valid C-alpha or P atoms found in the CIF file: {cif_file}")

    coords = np.array([atom.coord for atom in selected])
    atom_types = [atom.name for atom in selected]
    return coords, atom_types

def calculate_distance_matrix(coords1, coords2):
    """Return the pairwise Euclidean distances between two sets of points.

    Args:
        coords1: Array-like of points, one point per row.
        coords2: Array-like of points, one point per row.

    Returns:
        The array produced by ``scipy.spatial.distance.cdist``: entry
        ``[i, j]`` is the distance between ``coords1[i]`` and ``coords2[j]``.
    """
    pairwise = cdist(coords1, coords2, metric="euclidean")
    return pairwise

def plot_matrix(ax, matrix, cmap, norm, title, show_x_labels, show_y_labels,
                y_ticks=None, y_tick_offset=383):
    """Draw ``matrix`` as an image on ``ax`` and annotate the panel.

    Args:
        ax: Axes to draw on.
        matrix: 2-D data passed to ``ax.imshow``.
        cmap: Colormap passed to ``imshow``.
        norm: Normalization passed to ``imshow``.
        title: Text written inside the panel near its top-right corner, in
            axes coordinates. The axes title itself is not set.
        show_x_labels: If true, set the x-axis label; otherwise blank the
            x tick labels.
        show_y_labels: If true, set the y-axis label and apply ``y_ticks``;
            otherwise blank the y tick labels.
        y_ticks: Optional row indices at which to place y ticks. Only used
            when ``show_y_labels`` is true.
        y_tick_offset: Value added to each entry of ``y_ticks`` to form its
            tick label. The default (383) is the offset that was previously
            hard-coded in this function.

    Returns:
        None.
    """
    ax.imshow(matrix, cmap=cmap, norm=norm)

    if show_x_labels:
        ax.set_xlabel('Scored residue of Ycg1')
    else:
        ax.set_xticklabels([])  # hide the x tick labels

    if show_y_labels:
        ax.set_ylabel('Aligned residue of Brn1')
        if y_ticks is not None:
            # Ticks sit at matrix row indices; labels are shifted by the offset.
            ax.set_yticks(y_ticks)
            ax.set_yticklabels([str(i + y_tick_offset) for i in y_ticks])
    else:
        ax.set_yticklabels([])  # hide the y tick labels

    # Draw tick marks pointing into the plot area.
    ax.tick_params(axis='both', direction='in')

    # Panel caption inside the axes, anchored near the top-right corner.
    ax.text(0.95, 0.85, title, transform=ax.transAxes,
            fontsize=36, color='black', ha='right', va='top')

def process_pae_data(json_file, cif_file, range1, range2):
    """Compute the PAE-to-distance ratio for a sub-block of a predicted model.

    Args:
        json_file: Path to a JSON file whose top-level ``'pae'`` entry holds
            the PAE matrix.
        cif_file: Path to the mmCIF model from which CA/P coordinates are read.
        range1: Row index (e.g. a ``slice``) applied to both the PAE matrix
            and the distance matrix.
        range2: Column index applied to both matrices.

    Returns:
        A float array with the shape of the selected PAE sub-block holding
        ``pae / distance`` element-wise, with 0 wherever the distance is 0.

    Raises:
        ValueError: If the model contains no CA/P atoms, or if the selected
            PAE and distance sub-blocks do not have the same shape.
    """
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Force a float array: with an integer-valued PAE matrix the zero-filled
    # `out` buffer below would be integer and np.divide would refuse to cast
    # its floating-point result into it.
    pae_matrix = np.array(data['pae'], dtype=float)
    pae_submatrix = pae_matrix[range1, range2]

    coords, _ = extract_calpha_and_p_coords_from_cif(cif_file)
    dist_matrix = calculate_distance_matrix(coords, coords)
    dist_submatrix = dist_matrix[range1, range2]

    # The ratio is only meaningful when both sub-blocks cover the same
    # positions; fail clearly instead of broadcasting or erroring obscurely.
    if pae_submatrix.shape != dist_submatrix.shape:
        raise ValueError(
            f"PAE sub-matrix shape {pae_submatrix.shape} does not match "
            f"distance sub-matrix shape {dist_submatrix.shape} "
            f"({json_file} vs {cif_file})"
        )

    # Entries with zero distance keep the 0 pre-filled in `out`.
    Si = np.divide(
        pae_submatrix,
        dist_submatrix,
        out=np.zeros_like(pae_submatrix),
        where=dist_submatrix != 0,
    )
    return Si

# Index ranges applied to the PAE and distance matrices:
# range1 selects rows, range2 selects columns.
# NOTE(review): going by the axis labels used in plot_matrix, rows are
# presumably Brn1 positions and columns Ycg1 positions — confirm against the
# input sequences.
range1 = slice(876, 920)
range2 = slice(0, 871)

# Compute the PAE/distance ratio sub-matrix for each prediction.
# NOTE(review): judging by the directory names, Si1 comes from the run
# without DNA and Si2 from the run with DNA — confirm.
Si1 = process_pae_data(
    '/Users/chenjinyu/Desktop/SMC/Results/PAE/withoutDNA/ycg1_brn1_pdb9/fold_ycg1_brn1_pdb9_full_data_0.json',
    '/Users/chenjinyu/Desktop/SMC/Results/PAE/withoutDNA/ycg1_brn1_pdb9/fold_ycg1_brn1_pdb9_model_0.cif',
    range1, range2
)
Si2 = process_pae_data(
    '/Users/chenjinyu/Desktop/SMC/Results/PAE/withDNA/ycg1_brn1_dnapdb3/fold_ycg1_brn1_dnapdb3_full_data_0.json',
    '/Users/chenjinyu/Desktop/SMC/Results/PAE/withDNA/ycg1_brn1_dnapdb3/fold_ycg1_brn1_dnapdb3_model_0.cif',
    range1, range2
)
# Element-wise difference between the two conditions.
Si = Si1 - Si2

# Global minimum and maximum over the three matrices.
# NOTE(review): vmin/vmax are not referenced again in this section; the color
# scale is fixed by `levels` below.
vmin = min(np.min(Si1), np.min(Si2), np.min(Si))
vmax = max(np.max(Si1), np.max(Si2), np.max(Si))

# Colormap and discrete color levels.
levels = np.arange(-1.5, 1.6, 0.5)  # level boundaries spaced 0.5 apart, starting at -1.5
cmap = plt.get_cmap('RdBu_r')  # colormap to use
norm = BoundaryNorm(levels, ncolors=cmap.N)  # build the norm from the `levels` boundaries

# Create the gridspec layout.
fig = plt.figure()
gs = gridspec.GridSpec(3, 1, height_ratios=[1, 1, 1], hspace=0.3)  # 3 rows, 1 column

# Create one subplot per grid row.
ax1 = plt.subplot(gs[0])
ax2 = plt.subplot(gs[1])
ax3 = plt.subplot(gs[2])

# Shared y-axis ticks and range.
y_ticks = np.arange(0, 51, 20)  # one tick every 20 rows: 0, 20, 40
y_lim = (0, 43)  # shared y-axis range; only referenced by the commented-out set_ylim calls below

# Draw each matrix; only the bottom panel gets x-axis labels.
plot_matrix(ax1, Si1, cmap, norm, r'$\text{ES}_\text{ij}\text{(US)}$', show_x_labels=False, show_y_labels=True, y_ticks=y_ticks)
plot_matrix(ax2, Si2, cmap, norm, r'$\text{ES}_\text{ij}\text{(BS)}$', show_x_labels=False, show_y_labels=True, y_ticks=y_ticks)
plot_matrix(ax3, Si, cmap, norm, r'$\text{ES}_\text{ij}\text{(US-BS)}$', show_x_labels=True, show_y_labels=True, y_ticks=y_ticks)

# Apply the shared y-axis range (currently disabled).
#ax1.set_ylim(y_lim)
#ax2.set_ylim(y_lim)
#ax3.set_ylim(y_lim)

# Clear the y-axis label on the first and third panels so that only the
# middle panel keeps the label set by plot_matrix.
ax1.set_ylabel('')  # clear the ylabel of the first panel
ax3.set_ylabel('')  # clear the ylabel of the third panel

# Use an automatic aspect ratio so each image fills its axes.
ax1.set_aspect('auto', adjustable='box')
ax2.set_aspect('auto', adjustable='box')
ax3.set_aspect('auto', adjustable='box')

# Add a single colorbar for all three panels, based on the first panel's image.
cbar = fig.colorbar(ax1.images[0], ax=[ax1, ax2, ax3], orientation='vertical', pad=0.02)
cbar.set_label(r'$\text{ES}_\text{ij}$')

# Lay out, save and show the figure.
# NOTE(review): the colorbar is attached to a list of axes; tight_layout may
# warn about incompatible axes and shift the panels — confirm the saved
# layout looks as intended.
plt.tight_layout()
plt.savefig("Si_USBS.png", dpi=300, bbox_inches='tight')
plt.show()