In [None]:
import re
import os
import pandas as pd
from collections import defaultdict

# Background stoichiometry distribution. The CSV is expected to have at least the
# `stoichiometry` and `count` columns, which later cells read.
# Set the STOICH_BG_FILE environment variable to point at another copy; otherwise the
# original project path is used.
bg_file = os.environ.get(
    "STOICH_BG_FILE",
    "/home2/s439906/project/CASP16/oligomer/stoich_bg_distribution.csv",
)


In [None]:
# Load the background stoichiometry distribution into a DataFrame.
df_corrected = pd.read_csv(bg_file)

# chain_count = number of comma-separated components in the stoichiometry string
# (always one more than the number of commas, i.e. len(s.split(','))).
df_corrected['chain_count'] = df_corrected['stoichiometry'].map(lambda s: s.count(',') + 1)

def extract_and_sort_numbers(stoichiometry):
    """Return every integer found in ``stoichiometry`` as a tuple, largest first.

    Letters and separators are ignored, e.g. ``"A2,B10"`` -> ``(10, 2)``.
    A tuple is returned so the result is hashable and usable as a groupby/dict key.
    """
    counts = list(map(int, re.findall(r'\d+', stoichiometry)))
    counts.sort(reverse=True)
    return tuple(counts)

# Add 'sorted_numbers': the stoichiometry's integers as a descending tuple,
# used below as the key of the background distribution.
df_corrected['sorted_numbers'] = df_corrected['stoichiometry'].map(extract_and_sort_numbers)

# Show the processed background table.
df_corrected

In [None]:
# One sub-frame of the background table per chain_count value
# (keys in order of first appearance).
unique_chain_counts = {
    n: df_corrected.loc[df_corrected['chain_count'].eq(n)]
    for n in df_corrected['chain_count'].unique()
}
unique_chain_counts

In [None]:
# Raw background counts: chain_count -> {sorted_numbers: summed 'count'}.
chain_count_analysis = defaultdict(dict)
for n_chains, subset in df_corrected.groupby('chain_count'):
    per_stoich = subset.groupby('sorted_numbers')['count'].sum()
    chain_count_analysis[n_chains] = per_stoich.to_dict()
chain_count_analysis

In [None]:
# Normalised background distribution: within each chain_count, each sorted_numbers
# tuple maps to its share of that chain_count's total 'count'.
chain_count_analysis = defaultdict(dict)
for n_chains, subset in df_corrected.groupby('chain_count'):
    per_stoich = subset.groupby('sorted_numbers')['count'].sum()
    chain_count_analysis[n_chains] = per_stoich.div(per_stoich.sum()).to_dict()

chain_count_analysis

In [None]:
# Directory holding one tab-separated "<target>.stoichiometry" table per target.
# Set the STOICHIOMETRY_DIR environment variable to override the original project path.
stoichiometry_dir = os.environ.get(
    "STOICHIOMETRY_DIR",
    "/data/data1/conglab/jzhan6/CASP16/CASP16_scores/oligo_20240910/firstmodels/",
)
# os.listdir returns entries in arbitrary order; sort so processing (and the resulting
# column order) is reproducible across machines and runs.
stoichiometry_files = sorted(
    file for file in os.listdir(stoichiometry_dir) if file.endswith('.stoichiometry')
)
len(stoichiometry_files)

In [None]:
def count_unique_chains(stoichiometry):
    """Return the number of distinct alphabetic characters in ``stoichiometry``.

    Letters stand for chain/entity labels; digits and separators are ignored,
    e.g. ``"A2B2"`` -> 2 and ``"A6"`` -> 1.
    """
    return len({ch for ch in stoichiometry if ch.isalpha()})

In [None]:
def _extract_group_number(model_name):
    """Return the first 'TS' + three-digit token in ``model_name`` (e.g. 'TS123'), or None."""
    match = re.search(r'TS\d{3}', model_name)
    return match.group(0) if match else None


# Build one single-column score frame per target (index = group number), then join
# them side by side once at the end (concatenating inside the loop is quadratic).
score_frames = []
for file in stoichiometry_files:
    target = file.split('.')[0]
    df = pd.read_csv(os.path.join(stoichiometry_dir, file), sep='\t')
    df['group_number'] = df['model'].apply(_extract_group_number)
    df['sorted_numbers'] = df['truth'].apply(extract_and_sort_numbers)
    df['chain_count'] = df['truth'].apply(count_unique_chains)
    # Assumes every row of a target file carries the same 'truth' stoichiometry,
    # so the first row is used for the background lookup.
    first_sorted_numbers = df['sorted_numbers'].iloc[0]
    first_chain_count = df['chain_count'].iloc[0]
    # Use .get(): indexing the defaultdict would silently insert an empty entry on a miss.
    bg_probs = chain_count_analysis.get(first_chain_count, {})
    if first_sorted_numbers not in bg_probs:
        raise KeyError(
            f"{target}: stoichiometry {first_sorted_numbers} "
            f"(chain_count={first_chain_count}) not found in the background distribution"
        )
    bg_prob = bg_probs[first_sorted_numbers]
    score = 1 - bg_prob  # if the probability is low, then the score awarded should be high
    # Only models whose stoichiometry matches the truth earn the score.
    df['score'] = df['match_status'].apply(lambda x: score if x == "yes" else 0)

    # Single column named after the target, indexed by group number.
    score_df = df.set_index('group_number')[['score']].rename(columns={'score': target})
    score_frames.append(score_df)

# Outer-join on group number; groups missing for a target become NaN.
# Stays None (as before) when there are no stoichiometry files.
score_df_all = pd.concat(score_frames, axis=1) if score_frames else None
score_df_all

In [None]:
# Remember which (group, target) cells were missing before imputation;
# the heatmap later greys these out.
mask = score_df_all.isna()
score_df_all = (
    score_df_all
    .reindex(sorted(score_df_all.columns), axis=1)  # targets in alphabetical order
    .fillna(0)                                      # a missing prediction scores 0
)
score_df_all['total_score'] = score_df_all.sum(axis=1)
score_df_all = score_df_all.sort_values('total_score', ascending=False)
score_df_all

In [None]:
# Bar plot of each group's total stoichiometry score, in the row order of score_df_all.
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(16, 6), dpi=300)
ax.bar(score_df_all.index, score_df_all['total_score'])
ax.set_xlabel('Group Number', fontsize=16)
ax.set_ylabel('Total Score', fontsize=16)
# Tilt group labels by 45 degrees, right-aligned so each label ends under its bar.
plt.setp(ax.get_xticklabels(), rotation=45, fontsize=10, ha='right')
plt.setp(ax.get_yticklabels(), fontsize=10)
ax.set_title('Total Scores for stoichiometry for each Group', fontsize=20)
plt.show()



In [None]:
import seaborn as sns
from matplotlib.gridspec import GridSpec
from matplotlib.colors import ListedColormap
import numpy as np
# Drop the summary column so only per-target scores feed the heatmap.
# errors='ignore' keeps this cell idempotent: re-running it no longer raises KeyError.
score_df_all = score_df_all.drop(columns='total_score', errors='ignore')


In [None]:
# Per-group total over all targets. Named `row_sum`, not `sum`, so the builtin sum()
# stays usable in the rest of the notebook.
row_sum = score_df_all.sum(axis=1)
# Ascending order; the shared y-axis is inverted at the end, so the highest-scoring
# group ends up at the top of both panels.
sorted_indices = row_sum.sort_values(ascending=True).index
sorted_heatmap_data = score_df_all.loc[sorted_indices].reset_index(drop=True)
sorted_sum = row_sum.loc[sorted_indices].reset_index(drop=True)
# `mask` marks (group, target) cells that were missing before imputation. Reorder it
# explicitly to the heatmap's row and column order.
sorted_mask = mask.loc[sorted_indices, sorted_heatmap_data.columns].reset_index(drop=True)
# Masked cells become NaN, which the colormap's "bad" colour renders grey.
masked_data = sorted_heatmap_data.copy()
masked_data[sorted_mask] = np.nan
tick_labels = [f'{i}' for i in sorted_indices]

# Copy YlGn into a new colormap so set_bad does not touch the shared registered one.
cmap = ListedColormap(plt.cm.YlGn(np.linspace(0, 1, 256)))
cmap.set_bad(color='gray')
# Wide heatmap panel on the left, narrow row-sum bar panel on the right.
fig = plt.figure(figsize=(24, 18), dpi=300)
gs = GridSpec(1, 2, width_ratios=[4, 1], wspace=0.3)

ax0 = fig.add_subplot(gs[0])
# Pass the labels explicitly. With the default 'auto', seaborn may draw only every k-th
# tick, and relabelling those afterwards would attach the wrong names.
sns.heatmap(masked_data, cmap=cmap, cbar=True, ax=ax0,
            xticklabels=list(sorted_heatmap_data.columns), yticklabels=tick_labels)
ax0.tick_params(axis='x', labelsize=16, labelrotation=90)
ax0.tick_params(axis='y', labelsize=16, labelrotation=0)
cbar = ax0.collections[0].colorbar
cbar.ax.tick_params(labelsize=16)
ax0.set_title("Heatmap for scores of stoichiometry", fontsize=20)

# Row-sum bars share the heatmap's y-axis (limits and tick locator/formatter).
ax1 = fig.add_subplot(gs[1], sharey=ax0)
# Heatmap row i spans [i, i+1), so bars and ticks are centred at i + 0.5.
y_pos = [i + 0.5 for i in range(len(sorted_sum))]
ax1.barh(y_pos, sorted_sum, color='tan')
# Ticks go at the row centres (previously at integer row boundaries, which shifted
# every label by half a row). Because the ticker is shared, this also labels ax0.
ax1.set_yticks(y_pos)
ax1.set_yticklabels(tick_labels, rotation=0)
ax1.tick_params(axis='x', labelsize=16)
ax1.tick_params(axis='y', labelsize=16)
ax1.invert_yaxis()  # flip the shared y-axis so the highest sums are on top
ax1.set_xlabel("Sum", fontsize=16)
ax1.set_title("Group sum scores", fontsize=20)
plt.show()