In [4]:
import numpy as np
import sys
sys.path.append("/yezhirui/evo_probe")
from src.find_contact import *
from src.gmm import *
from src.contact_space import *

In [5]:
print("=== 初始化数据和模型 ===")

# 1. 获取关键接触点
chemokine_pdb = "1j8i"  # 趋化因子折叠
alternate_pdb = "2jp1"  # 替代折叠
msa_seq = 'EVSDKRT-CVSLTTQRLPVSRIKTYTIT---EGSLRAVIFITKRGLKVCADPQATWVRDVVRSMDRKSNT'
results = calculate_contact_difference_msa_id(chemokine_pdb, alternate_pdb,msa_seq,threshold=10.0,remove_diag=5)
critical_contacts = results['critical_contacts']
print(f"关键接触点数量: {len(critical_contacts)}")

# 2. 加载MJ矩阵
mj_dict = load_mj_matrix("/yezhirui/evo_probe/data/mj_matrix.txt")

# 3. 生成contact embedding

contact_space = ContactSpace(critical_contacts, mj_dict)

# 批量添加节点
path_dir = "/yezhirui/evo_probe/data/sample"
node_configs = [("ANC0", f"{path_dir}/node499_anc0_samples.fasta"),("ANC1", f"{path_dir}/node500_anc1_samples.fasta"),
("ANC2", f"{path_dir}/node501_anc2_samples.fasta"), ("ANC3", f"{path_dir}/node502_anc3_samples.fasta"), ("ANC4", f"{path_dir}/node507_anc4_samples.fasta")]
for node_id, fasta_path in node_configs:
    contact_space.add_node_from_fasta(node_id, fasta_path)

contact_space.build_embeddings()

# anc0_contact_embedding = contact_space.get_node_embeddings("ANC0")
# anc1_contact_embedding = contact_space.get_node_embeddings("ANC1")
anc2_contact_embedding = contact_space.get_node_embeddings("ANC2")
anc3_contact_embedding = contact_space.get_node_embeddings("ANC3")
anc4_contact_embedding = contact_space.get_node_embeddings("ANC4")

=== 初始化数据和模型 ===
PDB坐标提取完成: 93 个残基 (链 A)
PDB坐标提取完成: 60 个残基 (链 A)
距离矩阵计算完成: 93 x 93
距离矩阵计算完成: 60 x 60
共同残基数量: 60 个 (原始: 93 vs 60)
从PDB提取序列完成: 93 个残基 (链 A)
从PDB提取序列完成: 60 个残基 (链 A)
自动检测到PDB起始偏移量: 3
警告: PDB序列比MSA序列长，有24个PDB残基未映射
PDB到MSA映射完成: 66 个残基 (匹配: 66, 不匹配: 0)
PDB序列长度: 93, MSA序列长度: 70
PDB起始偏移: 3 (PDB第4个残基对应MSA第0个位置)
自动检测到PDB起始偏移量: 3
警告: MSA序列比PDB序列长，MSA位置61(R)无对应PDB残基
PDB到MSA映射完成: 57 个残基 (匹配: 57, 不匹配: 0)
PDB序列长度: 60, MSA序列长度: 70
PDB起始偏移: 3 (PDB第4个残基对应MSA第0个位置)
接触对转换完成: 109 个有效接触对, 2 个跳过
接触对转换完成: 133 个有效接触对, 0 个跳过
接触对转换完成: 1354 个有效接触对, 172 个跳过
关键接触点数量: 242


In [6]:
from sklearn.mixture import GaussianMixture


# 联合聚类方法
def joint_clustering_analysis(datasets, k=2):
    """对所有数据集进行联合聚类分析"""
    # 合并所有数据
    all_data = []
    dataset_labels = []
    
    for name, data in datasets.items():
        all_data.append(data)
        dataset_labels.extend([name] * len(data))
    
    combined_data = np.vstack(all_data)
    
    # 在合并数据上进行聚类
    gmm = GaussianMixture(n_components=k, random_state=42)
    cluster_labels = gmm.fit_predict(combined_data)
    
    # 分析每个数据集在各cluster中的分布
    results = {}
    start_idx = 0
    for name, data in datasets.items():
        end_idx = start_idx + len(data)
        dataset_clusters = cluster_labels[start_idx:end_idx]
        results[name] = {
            'cluster_labels': dataset_clusters,
            'cluster_distribution': np.bincount(dataset_clusters)
        }
        start_idx = end_idx
    
    return results, gmm



# 比较不同数据集的聚类中心
def compare_cluster_centers(results_dict):
    """比较不同数据集的聚类中心相似性"""
    centers = {}
    for dataset_name, result in results_dict.items():
        gmm = result['gmm']
        centers[dataset_name] = gmm.means_
    
    # 计算中心间距离矩阵
    from scipy.spatial.distance import cdist
    
    for name1, center1 in centers.items():
        for name2, center2 in centers.items():
            if name1 != name2:
                # 计算所有可能的cluster配对距离
                distances = cdist(center1, center2)
                print(f"{name1} vs {name2}:")
                print(f"  最小距离配对: {np.min(distances)}")
                print(f"  距离矩阵:\n{distances}")


def analyze_conformational_similarity(datasets, all_results):
    """分析不同数据集cluster间的构象倾向性相似性"""
    
    # 对于每对数据集，比较其cluster的构象特征
    dataset_names = list(datasets.keys())
    
    for i, name1 in enumerate(dataset_names):
        for j, name2 in enumerate(dataset_names[i+1:], i+1):
            print(f"\n比较 {name1} 和 {name2}:")
            
            # 获取各cluster的平均构象特征
            data1 = datasets[name1]
            data2 = datasets[name2]
            labels1 = all_results[name1]['cluster_labels']
            labels2 = all_results[name2]['cluster_labels']
            
            # 计算每个cluster的中心
            center1_0 = data1[labels1 == 0].mean(axis=0)
            center1_1 = data1[labels1 == 1].mean(axis=0)
            center2_0 = data2[labels2 == 0].mean(axis=0)
            center2_1 = data2[labels2 == 1].mean(axis=0)
            
            # 计算交叉相似性
            sim_00 = np.corrcoef(center1_0, center2_0)[0,1]
            sim_01 = np.corrcoef(center1_0, center2_1)[0,1]
            sim_10 = np.corrcoef(center1_1, center2_0)[0,1]
            sim_11 = np.corrcoef(center1_1, center2_1)[0,1]
            
            print(f"  {name1}_cluster0 vs {name2}_cluster0: {sim_00:.3f}")
            print(f"  {name1}_cluster0 vs {name2}_cluster1: {sim_01:.3f}")
            print(f"  {name1}_cluster1 vs {name2}_cluster0: {sim_10:.3f}")
            print(f"  {name1}_cluster1 vs {name2}_cluster1: {sim_11:.3f}")
            
            # 判断最佳匹配
            if max(sim_00, sim_11) > max(sim_01, sim_10):
                print(f"  最佳匹配: 标签一致 (0-0, 1-1)")
            else:
                print(f"  最佳匹配: 标签交换 (0-1, 1-0)")

In [8]:
datasets = {
    'ANC2': anc2_contact_embedding,
    'ANC3': anc3_contact_embedding,
    'ANC4': anc4_contact_embedding
}


joint_results, joint_gmm = joint_clustering_analysis(datasets, k=2)

# 查看联合聚类结果（安全版本）
print("\n联合聚类结果:")
for dataset_name, result in joint_results.items():
    cluster_dist = result['cluster_distribution']
    total_samples = sum(cluster_dist)
    
    print(f"{dataset_name}:")
    print(f"  cluster分布: {cluster_dist}")
    print(f"  总样本数: {total_samples}")
    
    # 安全地访问cluster比例
    for i in range(len(cluster_dist)):
        ratio = cluster_dist[i] / total_samples
        print(f"  cluster{i}占比: {ratio:.3f} ({cluster_dist[i]}个样本)")
    
    # 检查是否所有数据都在一个cluster中
    if len(cluster_dist) == 1:
        print(f"  ⚠️  警告: {dataset_name}的所有数据都被分到了同一个cluster!")
    elif len(cluster_dist) < 2:
        print(f"  ⚠️  警告: {dataset_name}只有{len(cluster_dist)}个cluster，少于预期的2个")


联合聚类结果:
ANC2:
  cluster分布: [1000]
  总样本数: 1000
  cluster0占比: 1.000 (1000个样本)
  ⚠️  警告: ANC2的所有数据都被分到了同一个cluster!
ANC3:
  cluster分布: [ 21 979]
  总样本数: 1000
  cluster0占比: 0.021 (21个样本)
  cluster1占比: 0.979 (979个样本)
ANC4:
  cluster分布: [   0 1000]
  总样本数: 1000
  cluster0占比: 0.000 (0个样本)
  cluster1占比: 1.000 (1000个样本)
