In [None]:
# TO FIX agg_A_fit (no more)

def extract_2statemelt_global(fits, sites, base, ss_site, disp_name, agg_kadd_fit, label = None):
    """Fit a two-state melting curve at each site and tabulate Tm and dH.

    For every entry of ``sites`` this calls
    ``fit_melting_curve(fits, disp_name, site, agg_kadd_fit)`` and reads the
    fitted parameters ``'g'`` (stored as ``tm``) and ``'f'`` (enthalpy in
    kJ/mol, converted to kcal/mol) together with their standard errors.

    Parameters
    ----------
    fits, agg_kadd_fit
        Passed through unchanged to ``fit_melting_curve``.
    sites : iterable
        Site identifiers; one fit and one output row per entry.
    base, ss_site, label
        Currently unused; kept so existing callers keep working.
    disp_name
        Construct name; forwarded to ``fit_melting_curve`` and copied into
        the ``disp_name`` column of every row.

    Returns
    -------
    pandas.DataFrame
        Columns ``site``, ``disp_name``, ``tm``, ``tm_err``, ``dH_kcal``,
        ``dH_kcal_err``. A standard error the fit did not report
        (``stderr is None``) is stored as NaN.
    """
    kj_per_kcal = 4.184
    nan = float('nan')

    rows = []
    for site in sites:
        # Only the fit result (second return value) is needed here; the
        # remaining outputs of fit_melting_curve are ignored.
        _, melt_result, *_ = fit_melting_curve(fits, disp_name, site, agg_kadd_fit)

        tm_param = melt_result.params['g']
        dH_param = melt_result.params['f']  # value in kJ/mol

        # NOTE(review): melt_result looks like an lmfit result, whose
        # parameters carry stderr=None when no uncertainty could be
        # estimated. Dividing None would raise TypeError, so use NaN.
        tm_err = nan if tm_param.stderr is None else tm_param.stderr
        dH_err = nan if dH_param.stderr is None else dH_param.stderr

        rows.append([
            site,
            disp_name,
            tm_param.value,
            tm_err,
            dH_param.value / kj_per_kcal,  # convert to kcal/mol
            dH_err / kj_per_kcal,
        ])

    return pd.DataFrame(
        rows,
        columns=['site', 'disp_name', 'tm', 'tm_err', 'dH_kcal', 'dH_kcal_err'],
    )


def plot_2statemelt_global(result_df, disp_name, base, label=None, xticklabels=None):
    """Draw a grouped bar chart of melting temperature per construct and site.

    One group of bars is drawn per unique ``disp_name`` in ``result_df`` and
    one bar per unique ``site`` inside each group, using the ``tm`` column
    for heights and ``tm_err`` for error bars. The figure is saved as
    ``{disp_name}_{base}_2statemelt_globalTM[_{label}].pdf`` in the current
    working directory and then shown.

    Parameters
    ----------
    result_df : pandas.DataFrame
        Must contain the columns ``disp_name``, ``site``, ``tm``, ``tm_err``.
    disp_name, base
        Used only to build the output filename.
    label : optional
        When truthy, appended to the filename as ``_{label}``.
    xticklabels : sequence of str, optional
        Tick labels, one per construct in order of first appearance. When
        omitted, each label is the part of the construct name after its last
        underscore, upper-cased (``'4U_wt'`` -> ``'WT'``).

    Returns
    -------
    int
        Always 0.

    Raises
    ------
    ValueError
        If ``result_df`` contains no sites.
    """
    constructs = result_df['disp_name'].unique()
    sites = result_df['site'].unique()
    if len(sites) == 0:
        # Previously surfaced as a ZeroDivisionError from the bar width.
        raise ValueError('result_df contains no sites to plot')

    if xticklabels is None:
        # Derive labels from the data so they cannot be attached to the
        # wrong group or disagree with the number of groups.
        xticklabels = [str(name).rsplit('_', 1)[-1].upper() for name in constructs]

    fig, ax = plt.subplots(1, 1, figsize=(2.25, 3))

    # Grouped bar setup: each group spans 0.8 x-units shared by all sites.
    bar_width = 0.8 / len(sites)
    x = np.arange(len(constructs))

    # Plot each site's bars
    for i, site in enumerate(sites):
        # Reindex so rows follow the construct order used for the x positions;
        # a construct missing for this site becomes NaN.
        subset = (
            result_df[result_df['site'] == site]
            .set_index('disp_name')
            .reindex(constructs)
            .reset_index()
        )

        ax.bar(
            x + i * bar_width,
            subset['tm'],
            yerr=subset['tm_err'],
            width=bar_width,
            label=str(site),
            capsize=2,
            linewidth=0.7,
            edgecolor='white',
        )

    # Centre each tick under its group of bars.
    ax.set_xticks(x + bar_width * (len(sites) - 1) / 2)
    ax.set_xticklabels(xticklabels)
    ax.set_ylabel(r'Melting temp., $T_m$ (°C)')
    ax.set_xlabel('Construct')
    # No legend is drawn on this (narrow) figure.

    filename = f'{disp_name}_{base}_2statemelt_globalTM'
    if label:
        filename += f'_{label}'

    fig.tight_layout()
    fig.savefig(f'{filename}.pdf')

    plt.show()

    return 0


def plot_2statemelt_global_dH(result_df, disp_name, base, label=None, xticklabels=None):
    """Draw a grouped bar chart of melting enthalpy per construct and site.

    One group of bars is drawn per unique ``disp_name`` in ``result_df`` and
    one bar per unique ``site`` inside each group, using the ``dH_kcal``
    column for heights and ``dH_kcal_err`` for error bars. The figure is
    saved as ``{disp_name}_{base}_2statemelt_globaldH[_{label}].pdf`` in the
    current working directory and then shown.

    Parameters
    ----------
    result_df : pandas.DataFrame
        Must contain the columns ``disp_name``, ``site``, ``dH_kcal``,
        ``dH_kcal_err``.
    disp_name, base
        Used only to build the output filename.
    label : optional
        When truthy, appended to the filename as ``_{label}``.
    xticklabels : sequence of str, optional
        Tick labels, one per construct in order of first appearance. When
        omitted, each label is the part of the construct name after its last
        underscore, upper-cased (``'4U_wt'`` -> ``'WT'``).

    Returns
    -------
    int
        Always 0.

    Raises
    ------
    ValueError
        If ``result_df`` contains no sites.
    """
    constructs = result_df['disp_name'].unique()
    sites = result_df['site'].unique()
    if len(sites) == 0:
        # Previously surfaced as a ZeroDivisionError from the bar width.
        raise ValueError('result_df contains no sites to plot')

    if xticklabels is None:
        # Derive labels from the data so they cannot be attached to the
        # wrong group or disagree with the number of groups.
        xticklabels = [str(name).rsplit('_', 1)[-1].upper() for name in constructs]

    fig, ax = plt.subplots(1, 1, figsize=(4, 3))

    # Grouped bar setup: each group spans 0.8 x-units shared by all sites.
    bar_width = 0.8 / len(sites)
    x = np.arange(len(constructs))

    # Plot each site's bars
    for i, site in enumerate(sites):
        # Reindex so rows follow the construct order used for the x positions;
        # a construct missing for this site becomes NaN.
        subset = (
            result_df[result_df['site'] == site]
            .set_index('disp_name')
            .reindex(constructs)
            .reset_index()
        )

        ax.bar(
            x + i * bar_width,
            subset['dH_kcal'],
            yerr=subset['dH_kcal_err'],
            width=bar_width,
            label=str(site),
            capsize=3,
            linewidth=1,
            edgecolor='white',
        )

    # Centre each tick under its group of bars.
    ax.set_xticks(x + bar_width * (len(sites) - 1) / 2)
    ax.set_xticklabels(xticklabels)
    # The bars show dH_kcal, so label the axis as enthalpy (the previous
    # label described the melting temperature instead).
    ax.set_ylabel(r'Enthalpy, $\Delta H$ (kcal/mol)')
    ax.set_xlabel('Construct')
    ax.legend(title='Site')
    fig.tight_layout()

    filename = f'{disp_name}_{base}_2statemelt_globaldH'
    if label:
        filename += f'_{label}'
    fig.savefig(f'{filename}.pdf')

    plt.show()

    return 0


# Per-site two-state melt fits for the WT and A8C 4U constructs. Both runs
# use the same sites and the same aggregate '4U_A' entry of arrhenius_aggfits.
wt_tms = extract_2statemelt_global(
    fits=fits,
    sites=[7, 8, 15, 16, 26, 29],
    base='A',
    ss_site=18,
    disp_name='4U_wt',
    agg_kadd_fit=arrhenius_aggfits['4U_A'],
)
a8c_tms = extract_2statemelt_global(
    fits=fits,
    sites=[7, 8, 15, 16, 26, 29],
    base='A',
    ss_site=18,
    disp_name='4U_a8c',
    agg_kadd_fit=arrhenius_aggfits['4U_A'],
)

# Grouped bar plots of Tm and dH; WT rows are stacked above the A8C rows.
plot_2statemelt_global(
    result_df=pd.concat([wt_tms, a8c_tms], axis=0),
    disp_name='4U',
    base='A',
)
plot_2statemelt_global_dH(
    result_df=pd.concat([wt_tms, a8c_tms], axis=0),
    disp_name='4U',
    base='A',
)

# Across-site averages of dH and of its per-site standard error.
# NOTE(review): the *_err values are plain means of the per-site errors, not
# a propagated uncertainty of the mean -- confirm this is what is intended.
mean_wt_dH = wt_tms.loc[:, 'dH_kcal'].mean()
mean_wt_dH_err = wt_tms.loc[:, 'dH_kcal_err'].mean()
mean_a8c_dH = a8c_tms.loc[:, 'dH_kcal'].mean()
mean_a8c_dH_err = a8c_tms.loc[:, 'dH_kcal_err'].mean()
