In [None]:
import os
import sys

import pandas as pd

# Make the parent directory importable so the project-local `data` and
# `helpers` packages below can be found.
sys.path.append('../')
# NOTE(review): these star imports presumably provide `plt`, `rc`, `r_colors`,
# `race_dict` and `cohort_dict`, which are used below but never defined in
# this file — confirm against data/dicts/col_dict.py and helpers/plots.py.
from data.dicts.col_dict import *
from helpers.plots import *
# Apply the project-wide matplotlib style settings held in `rc`.
plt.rcParams.update(rc)

def plot(df, target, cohort, group=None, max_vars=None, race_col='index', cov_col='cov', val='fraction', save_path=None):
    """Plot how much of ``target`` each covariate explains, one line per race.

    Keeps the rows of ``df`` whose ``"target"`` column equals ``target``
    (optionally also filtering ``"group"``). It then draws one marker line
    per value of ``race_col``, with covariates on the x-axis and ``val`` on
    the y-axis. The figure is shown, optionally saved, and then closed.

    Args:
        df: Long-format table. Must contain ``"target"``, ``race_col``,
            ``cov_col`` and ``val`` columns. It also needs ``"group"`` when
            ``group`` is given and ``"std_err"`` when ``val == 'coef'``.
        target: Outcome to plot; it appears upper-cased in the y-axis label.
        cohort: Key into ``cohort_dict``. The mapped text is written
            vertically to the right of the axes.
        group: Optional group name, or list of names, to keep.
        max_vars: If given, only the first ``max_vars`` rows of each race
            are plotted.
        race_col: Column whose values define the lines. Every value must be
            a key of ``r_colors``. Legend labels are ``str(value).title()``.
        cov_col: Column holding the x-axis (covariate) values.
        val: Column plotted on the y-axis. ``'coef'`` also draws error bars
            from ``"std_err"`` and labels the axis in mL. Any other column
            is labelled in %.
        save_path: If given, the figure is saved here. The file format is
            taken from the path's extension; paths without one are saved
            as PDF.

    Returns:
        None.
    """
    # NOTE: plt, r_colors, race_dict and cohort_dict come from the module's
    # star imports, not from this function.
    d = df[df["target"] == target]
    if group:
        # Accept a single group name as well as a list of names
        # (Series.isin rejects a bare string).
        groups = [group] if isinstance(group, str) else group
        d = d[d["group"].isin(groups)]

    fig, ax = plt.subplots(figsize=(6, 3), dpi=300)
    plotted_x = []
    for race, g in d.groupby(race_col, sort=False):
        if max_vars is not None:
            g = g.head(max_vars)
        x, y = g[cov_col], g[val]
        ax.plot(x, y, marker="o", linewidth=2, color=r_colors[race], label=str(race).title())
        if val == 'coef':
            # Thin light-gray error bars drawn beneath the data lines.
            ax.errorbar(x, y, yerr=g['std_err'], fmt='none', color=r_colors[race],
                        ecolor='lightgray', elinewidth=1, capsize=0,
                        zorder=1)
        plotted_x.append(x)

    # Set ticks once, from every covariate any race plotted (in first-seen
    # order), instead of keeping only the last race's covariates.
    if plotted_x:
        ticks = pd.concat(plotted_x).drop_duplicates().tolist()
        ax.set_xticks(ticks)
        # ha="center" keeps 90-degree rotated labels centred on their ticks.
        ax.set_xticklabels(ticks, rotation=90, ha="center", fontsize=7)

    unit = 'mL' if val == 'coef' else '%'
    ax.set_ylabel(f"{target.upper()} Explained ({unit})", fontsize=7)
    ax.legend(ncols=len(race_dict), loc="upper center", bbox_to_anchor=(0.5, 1.35))
    ax.text(1.02, 0.5, cohort_dict[cohort],
        transform=ax.transAxes, rotation=-90,
        va="center", ha="left",
    )

    fig.tight_layout()
    if save_path:
        # Let matplotlib infer the format from the extension, so the file
        # contents match its name. A hard-coded format would write PDF bytes
        # into *.png files. Paths without an extension are saved as PDF.
        has_ext = bool(os.path.splitext(save_path)[1])
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0.1,
                    format=None if has_ext else 'pdf', dpi=300)
    plt.show()
    # Release the figure so repeated calls in a loop do not accumulate
    # open figures.
    plt.close(fig)

# Render one figure per cohort x target x value column from the processed
# "*_cont.csv" explain tables.
cohorts = ['nh3', 'nh4', 'nh', 'ukb']
targets = ["fev1", "fvc"]
# Value columns passed to plot(val=...); the loop below iterates this list.
units = ['coef', 'fraction']
# NOTE(review): not used below — only the *_cont.csv tables are read.
output_types = ['reset', 'cont']
output_dir = '../results/explain/figures'
data_dir = '../results/explain/tables/processed'

# savefig fails if the target directory is missing, so create it up front.
os.makedirs(output_dir, exist_ok=True)
for cohort in cohorts:
    df = pd.read_csv(f"{data_dir}/{cohort}_cont.csv")
    for t in targets:
        for unit in units:
            # .pdf matches the format plot() writes these figures in.
            save_path = f"{output_dir}/{t}_{cohort}_{unit}.pdf"
            plot(df, t, cohort=cohort, val=unit, save_path=save_path)
