In [None]:
import pandas as pd
import re
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import math
from pathlib import Path


In [None]:
# ----------------------------- CONFIG --------------------------------
# Input directory holding the evaluation CSVs, and output directory for figures.
ROOT = Path("/root/workspace/Evaluation/data/outputs")
OUTDIR = ROOT.parent / "fig_suplementry"
# parents=True: do not fail when the parent directory is missing as well.
OUTDIR.mkdir(parents=True, exist_ok=True)

# Datasets, seeds and missing-value percentages iterated by the main cell.
# DATASETS = ["adultsample", "contraceptive", "credit", "imdb"]
DATASETS = ["australian"]
# NOTE: SEEDS is empty. The main loop builds file names from SEEDS, so no
# CSV is looked up (and nothing is plotted) until seeds are listed here.
SEEDS = []
PCTS = [5, 10, 20, 40]

# File-name templates per run type; {dat} and {seed} are filled by the main cell.
PATTERNS = {
    "IPM_fixed":   "IPM_evaluations_fixed_{dat}_{seed}.csv",
    "SENT-I":      "SENT-I_evaluations_{dat}_{seed}.csv",
    "IPM_retrain": "IPM_evaluations_Retraining_{dat}_{seed}.csv",
}

# Method key (derived from the file name) -> label shown in legends.
LABEL_MAP = {
    "IPM":     "IPM",
    "IPM_f":   r"$\mathsf{IPM}_{f}$",
    "SENTI":   "SENTI",
    "unknown": "unknown",
}
# Legend label -> marker / line colour. Only the labels in LEGEND_ORDER are
# plotted, so "unknown" has no entry here.
MARKER_MAP = {
    "IPM":                 "o",
    r"$\mathsf{IPM}_{f}$": "s",
    "SENTI":               "X",
}
COLOR_MAP = {
    "IPM":                 "green",
    r"$\mathsf{IPM}_{f}$": "#ff7f0e",
    "SENTI":               "Blue",
}
LEGEND_ORDER = ["SENTI", r"$\mathsf{IPM}_{f}$", "IPM"]

FONT_SIZE = 25
LEGEND_FONT_SIZE = 22

# Global matplotlib styling.
# NOTE(review): plot_time_and_similarity calls sns.set_context / sns.set_style,
# which may override some of these rcParams - confirm the final look.
plt.rcParams.update({
    "text.usetex": False,
    "mathtext.fontset": "cm",
    "font.family": "serif",
    "font.style": "normal",
    "font.size": FONT_SIZE,
    "axes.labelsize": FONT_SIZE,
    "axes.titlesize": FONT_SIZE,
    "xtick.labelsize": FONT_SIZE,
    "ytick.labelsize": FONT_SIZE,
    "legend.fontsize": LEGEND_FONT_SIZE,
    "axes.edgecolor": "black",
    "axes.linewidth": 1.5,
    "grid.color": "darkgray",
    "grid.linestyle": "-",
    "grid.linewidth": 1.0,
    "grid.alpha": 1.0,
})


In [None]:

def _method_from_filename(file_name):
    """Return the method key ("SENTI", "IPM", "IPM_f" or "unknown") for a CSV file name."""
    if "SENT-I" in file_name:
        return "SENTI"
    if "Retraining" in file_name:
        return "IPM"
    if "fixed" in file_name:
        return "IPM_f"
    return "unknown"


def _load_evaluation_frames(csv_paths, drop_smallest_chunk=False):
    """Read every CSV and add the size / chunk / method / seed columns used for plotting."""
    frames = []
    for path in csv_paths:
        path = Path(path)  # accept plain strings as well as Path objects
        df = pd.read_csv(path)
        if df.empty:
            continue

        df["dataset_file"] = path.stem
        df["size"] = df["end_index"] + 1
        df["chunk"] = df["end_index"] - df["start_index"] + 1
        # Round so that the later equality filter on pct_nulls matches.
        df["pct_nulls"] = df["pct_nulls"].astype(float).round()

        # The seed is the trailing "_<digits>.csv" part of the file name; fall
        # back to "unknown" instead of crashing when the name has no such part.
        seed_match = re.search(r"_(\d+)\.csv$", path.name)
        df["method"] = _method_from_filename(path.name)
        df["seed"] = seed_match.group(1) if seed_match else "unknown"

        if drop_smallest_chunk:
            # Opt-in filter from the former "other datasets" variant of this loader.
            df = df[df["chunk"] > df["chunk"].min()]

        frames.append(df)
    return frames


def _style_tau_axis(ax, ticks, labels, x_min, x_max):
    """Apply the shared x-axis ticks, labels, label alignment and limits."""
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels)
    ax.tick_params(labelsize=FONT_SIZE)
    # Align the outer labels inwards so they stay inside the axes.
    for tick_label in ax.get_xticklabels():
        if tick_label.get_text() == "τ=0":
            tick_label.set_ha('left')
        if tick_label.get_text() == "τ=20":
            tick_label.set_ha('right')
    ax.set_xlim(x_min, x_max)


def _time_axis_upper_bound(raw_max):
    """Pick the upper y-limit of the run-time plot from the largest plotted value."""
    # NOTE(review): a maximum slightly above a bound (e.g. 100-130) still maps
    # to the lower bound - presumably hand-tuned headroom; confirm.
    for threshold, bound in ((210, 250), (150, 200), (130, 150), (50, 100), (30, 50)):
        if raw_max > threshold:
            return bound
    return math.ceil(raw_max)


def plot_time_and_similarity(csv_paths, pct_nulls, dataset_name, outdir=OUTDIR,
                             drop_smallest_chunk=False):
    """Plot similarity and run time versus data size for one dataset and null percentage.

    Two PNG files are written to ``outdir``:
    ``{dataset_name}_{pct_nulls}_sim.png`` (all ``avg_semantic_sim_*`` columns)
    and ``{dataset_name}_{pct_nulls}_time.png`` (``total_time_*`` columns; for
    the ``IPM_f`` method every size except the first uses ``imputation_time_*``).

    Parameters
    ----------
    csv_paths : iterable of path-like
        Evaluation CSVs. Each must provide ``start_index``, ``end_index`` and
        ``pct_nulls``. The method is derived from the file name and the seed
        from its trailing ``_<digits>.csv``.
    pct_nulls : int or float
        Only rows whose rounded ``pct_nulls`` equals this value are plotted.
    dataset_name : str
        Used as the figure title and in the output file names.
    outdir : path-like
        Directory the figures are written to.
    drop_smallest_chunk : bool, default False
        When True, rows whose chunk length equals the smallest chunk length of
        their file are dropped before plotting.

    Returns
    -------
    None
        A message is printed and the function returns early when there is
        nothing to plot.
    """
    sns.set_context("talk")
    sns.set_style("whitegrid")
    outdir = Path(outdir)

    # ---------------------------- Load data ----------------------------
    frames = _load_evaluation_frames(csv_paths, drop_smallest_chunk)
    if not frames:
        print(f"[INFO] No data for {dataset_name} — {pct_nulls}%")
        return

    all_df = pd.concat(frames, ignore_index=True)
    all_df = all_df[all_df["pct_nulls"] == pct_nulls]
    if all_df.empty:
        print(f"[WARN] Empty after pct filter for {dataset_name} — {pct_nulls}%")
        return

    # Three x ticks: smallest size, midpoint and largest size, with fixed labels.
    sizes = sorted(all_df["size"].unique())
    ticks = [sizes[0], (sizes[0] + sizes[-1]) // 2, sizes[-1]]
    labels = ["τ=0", "τ=10", "τ=20"]
    id_cols = ["size", "method", "seed"]

    # ----------------------- Semantic similarity -----------------------
    sim_long = (
        all_df.melt(
            id_vars=id_cols,
            value_vars=[c for c in all_df if c.startswith("avg_semantic_sim_")],
            var_name="metric", value_name="avg_semantic_sim"
        )
        .dropna(subset=["avg_semantic_sim"])
    )
    sim_long["method_label"] = sim_long["method"].map(LABEL_MAP)

    fig_sim, ax_sim = plt.subplots(figsize=(6, 4))
    ax_sim.set_title(dataset_name, fontsize=FONT_SIZE, pad=12)

    # `ci` is deprecated since seaborn 0.12 in favour of `errorbar`; it is kept
    # here so the notebook also runs on older seaborn releases.
    sns.lineplot(
        data=sim_long,
        x="size", y="avg_semantic_sim",
        hue="method_label", style="method_label",
        hue_order=LEGEND_ORDER, style_order=LEGEND_ORDER,
        palette=COLOR_MAP,
        markers=MARKER_MAP, dashes=False, alpha=0.5, ci="sd",
        linewidth=3, markersize=10, ax=ax_sim
    )
    ax_sim.set_xlabel("", fontsize=FONT_SIZE)
    # Only the "adultsample" panel carries a y-axis label.
    if dataset_name == "adultsample":
        ax_sim.set_ylabel("Accuracy", fontsize=FONT_SIZE)
    else:
        ax_sim.set_ylabel("")
    ax_sim.set_ylim(0.7, 1)
    ax_sim.set_yticks([0.7, 0.8, 0.9, 1.0])
    ax_sim.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.1f'))
    _style_tau_axis(ax_sim, ticks, labels, sizes[0], sizes[-1])

    ax_sim.legend(
        loc='lower center', bbox_to_anchor=(0.5, 0.02),
        ncol=3, frameon=True,
        borderpad=0.1, columnspacing=1.0,
        handlelength=2, handletextpad=0.3,
        labelspacing=0.1, fontsize=LEGEND_FONT_SIZE
    )
    fig_sim.subplots_adjust(left=0.10, right=0.95, bottom=0.15, top=0.85)
    fig_sim.savefig(outdir / f"{dataset_name}_{pct_nulls}_sim.png", dpi=300, bbox_inches='tight')
    plt.close(fig_sim)

    # ---------------------------- Run time -----------------------------
    tot_cols = [c for c in all_df if c.startswith("total_time_")]
    imp_cols = [c for c in all_df if c.startswith("imputation_time_")]

    total_df = (
        all_df.melt(
            id_vars=id_cols,
            value_vars=tot_cols,
            var_name="metric", value_name="total_time"
        )
        .dropna(subset=["total_time"])
    )
    imp_df = (
        all_df.melt(
            id_vars=id_cols,
            value_vars=imp_cols,
            var_name="metric", value_name="imputation_time"
        )
        .dropna(subset=["imputation_time"])
    )

    # NOTE(review): when several total_time_* / imputation_time_* columns exist,
    # this merge pairs every total-time row with every imputation-time row that
    # shares the same (size, method, seed) - confirm that this is intended.
    time_df = pd.merge(total_df, imp_df, on=id_cols, how="left")
    time_df["method_label"] = time_df["method"].map(LABEL_MAP)

    # IPM_f: total time for the first size, imputation time afterwards.
    # Every other method always uses the total time.
    first_size = sizes[0]
    use_imputation = (time_df["method"] == "IPM_f") & (time_df["size"] != first_size)
    time_df["plot_time"] = time_df["total_time"].where(
        ~use_imputation, time_df["imputation_time"]
    )

    raw_max = time_df["plot_time"].max()
    if pd.isna(raw_max):
        # No timing values at all: math.ceil(NaN) below would raise.
        print(f"[WARN] No timing data for {dataset_name} — {pct_nulls}%")
        return
    y_max = _time_axis_upper_bound(raw_max)
    mid = math.ceil(y_max / 2)

    fig_time, ax_time = plt.subplots(figsize=(6, 4))
    ax_time.set_title(dataset_name, fontsize=FONT_SIZE, pad=12)

    # Higher z-order is drawn on top, so SENTI stays visible above the others.
    zorders = {"SENTI": 3, r"$\mathsf{IPM}_{f}$": 2, "IPM": 1}
    for method_label in LEGEND_ORDER:
        subset = time_df[time_df["method_label"] == method_label]
        if subset.empty:
            # This method has no rows for the dataset / percentage
            # (e.g. its CSV is missing): nothing to draw.
            continue
        sns.lineplot(
            data=subset,
            x="size", y="plot_time",
            label=method_label,
            linestyle="solid",
            hue=None, style=None,
            color=COLOR_MAP[method_label],
            marker=MARKER_MAP[method_label],
            alpha=0.5, ci="sd",
            linewidth=3, markersize=10, ax=ax_time,
            zorder=zorders.get(method_label, 1)
        )

    ax_time.set_xlabel("", fontsize=FONT_SIZE)
    if dataset_name == "adultsample":
        ax_time.set_ylabel("Run time (sec)", fontsize=FONT_SIZE)
    else:
        ax_time.set_ylabel("")
    # The formatter renders the three integer ticks, so no explicit tick
    # labels are set.
    ax_time.set_yticks([0, mid, y_max])
    ax_time.set_ylim(-5, y_max)
    ax_time.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.0f'))
    _style_tau_axis(ax_time, ticks, labels, sizes[0], sizes[-1])

    ax_time.legend(
        loc='upper left', bbox_to_anchor=(0.01, 0.99),
        ncol=1, frameon=True,
        borderpad=0.2, handlelength=2,
        handletextpad=0.2, labelspacing=0.2,
        fontsize=LEGEND_FONT_SIZE
    )

    fig_time.subplots_adjust(left=0.10, right=0.95, bottom=0.10, top=0.85)
    fig_time.savefig(outdir / f"{dataset_name}_{pct_nulls}_time.png", dpi=300, bbox_inches='tight')
    plt.close(fig_time)

# ----------------------------- MAIN ---------------------------------
if __name__=="__main__":
    if not SEEDS:
        # File names are built from SEEDS, so an empty list means no input at all.
        print("[WARN] SEEDS is empty — no CSV files will be looked up and nothing will be plotted.")
    for dat in DATASETS:
        # The input files depend only on dataset and seed, not on the null
        # percentage, so they are collected once per dataset.
        candidates = [
            ROOT / pat.format(dat=dat, seed=seed)
            for seed in SEEDS
            for pat in PATTERNS.values()
        ]
        csvs = [p for p in candidates if p.exists()]
        if not csvs:
            print(f"[WARN] No evaluation CSVs found for {dat} under {ROOT}")
            continue
        for pct in PCTS:
            plot_time_and_similarity(csvs, pct, dat)
