In [1]:
import os
import pandas as pd
import re
os.chdir("../../") # set to llm_nicu_vitalsigns/script/
os.getcwd()


'/sfs/gpfs/tardis/home/jq2uw/llm_nicu_vitalsigns/script'

In [2]:
def bold_best_by_arrow(col, colname):
    try:
        numeric_vals = pd.to_numeric(col, errors='coerce')

        # Identify infinite values and mark them
        is_inf = numeric_vals == float('inf')
        numeric_vals[is_inf] = float('nan')

        # Skip if all are non-numeric
        if numeric_vals.isnull().all():
            return ['-' for _ in col]

        # Find best value depending on arrow direction
        if '↓' in colname:
            best_val = numeric_vals.min()
        elif '↑' in colname:
            best_val = numeric_vals.max()
        else:
            return ['-' if x == float('inf') else str(x) for x in numeric_vals]

        # Find index of the *first* best value
        first_idx = numeric_vals[numeric_vals == best_val].index[0]

        # Build string output with bolding
        out = []
        for i in col.index:
            val = col[i]
            if pd.isna(val) or val == float('inf'):
                out.append('-')
            elif i == first_idx:
                out.append(f"**{val:.2f}**")  # new → always 2-dp
            else:
                out.append(f"{val:.2f}")      # new → always 2-dp
        return out

    except Exception:
        return [str(v) for v in col]  # fallback to string values

    
def comparison_table(attr_suffix='_at', w=0.9, single_score=True):
    dataset_names = {
        'syn_gt': 'Synthetic (with ground truth)', 
        'syn': 'Synthetic',
        'air': 'Air Quality', 
        'nicu': 'NICU Heart Rate'
    }

    res_df_all = []

    for dataset_key, dataset_label in dataset_names.items():
        try:
            # isntructtime
            model_name = f"{dataset_key}{attr_suffix}_self"
            model_name_open = f"{dataset_key}{attr_suffix}_open"
            res_df_it = pd.read_csv(os.path.join("./VITAL/results", model_name, f'res_df_iqr{w}.csv'))
            res_df_it_open = pd.read_csv(os.path.join("./VITAL/results", model_name_open, f'res_df_iqr{w}.csv'))
            # tedit + tweaver
            if attr_suffix == '_at':
                te_folder = "tedit_lite"
            else:
                te_folder = "tedit_lite_tx"
            res_df_te = pd.read_csv(os.path.join(f"./{te_folder}/tedit_save/te", dataset_key, 'res_df_iqr.csv'))
            res_df_tw = pd.read_csv(os.path.join(f"./{te_folder}/tedit_save/tw", dataset_key, 'res_df_iqr.csv'))

            res_df_it['Model'] = 'InstructTime'
            res_df_it_open['Model'] = 'InstructTime (open-vocab)'#r'\makecell{InstructTime\\\footnotesize(open-vocab)}'
            res_df_te['Model'] = 'TEdit'
            res_df_tw['Model'] = 'Time Weaver'

            if attr_suffix == '_at':
                res_df = pd.concat([res_df_it, res_df_te, res_df_tw], axis=0, ignore_index=True)
            else:
                res_df = pd.concat([res_df_it, res_df_it_open, res_df_te, res_df_tw], axis=0, ignore_index=True)

            if single_score:
                for col in res_df.columns:
                    res_df[col] = res_df[col].map(lambda x: re.sub(r'\s*\[.*?\]', '', str(x)))

            cols = ['Model'] + [c for c in res_df.columns if c != 'Model']
            if 'LCSS similarity increase ↑' in cols:
                cols.remove('LCSS similarity increase ↑')
            res_df = res_df[cols]
            for col in [c for c in res_df.columns if c != 'Model']:
                res_df[col] = pd.to_numeric(res_df[col], errors='coerce')
            res_df.loc[res_df["RaTS ↑"] <= 0, "|RaTS (preserved)|↓"] = float('inf')

            res_df.set_index('Model', inplace=True)
            res_df.columns = pd.MultiIndex.from_product([[dataset_label], res_df.columns])

            res_df_all.append(res_df)

        except FileNotFoundError as e:
            print(f"Skipping {dataset_key}: {e}")
            continue

    if not res_df_all:
        raise ValueError("No valid datasets found.")

    res_df_all = pd.concat(res_df_all, axis=1)

    row_tag = r"\makecell{Attribute\\-based}" if attr_suffix == '_at' else r"\makecell{Instruction\\-based}"
    res_df_all.index = pd.MultiIndex.from_product([[row_tag], res_df_all.index])

    for col_group in res_df_all.columns:
        dataset_name, metric_name = col_group
        res_df_all[col_group] = bold_best_by_arrow(res_df_all[col_group], metric_name)

    return res_df_all


In [3]:
w = 0.9
df_attr = comparison_table(attr_suffix='_at', w=w)
df_inst = comparison_table(attr_suffix='', w=w)
df_all = pd.concat([df_inst, df_attr], axis=0)

Skipping nicu: [Errno 2] No such file or directory: './VITAL/results/nicu_at_open/res_df_iqr0.9.csv'
Skipping nicu: [Errno 2] No such file or directory: './tedit_lite_tx/tedit_save/te/nicu/res_df_iqr.csv'


In [4]:
df_all

Unnamed: 0_level_0,Unnamed: 1_level_0,Synthetic (with ground truth),Synthetic (with ground truth),Synthetic (with ground truth),Synthetic (with ground truth),Synthetic (with ground truth),Synthetic,Synthetic,Synthetic,Air Quality,Air Quality,Air Quality
Unnamed: 0_level_1,Unnamed: 1_level_1,Point-wise MSE ↓,Point-wise MAE ↓,DTW distance decrease ↓,RaTS ↑,|RaTS (preserved)|↓,DTW distance decrease ↓,RaTS ↑,|RaTS (preserved)|↓,DTW distance decrease ↓,RaTS ↑,|RaTS (preserved)|↓
Unnamed: 0_level_2,Model,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
\makecell{Instruction\\-based},InstructTime,**2.44**,**1.24**,**-14.61**,**6.40**,**0.17**,**-13.04**,**5.68**,**0.04**,**-1.63**,**0.34**,**0.56**
\makecell{Instruction\\-based},InstructTime (open-vocab),3.41,1.50,-12.25,5.62,0.29,-11.96,4.78,0.10,-1.62,0.22,0.58
\makecell{Instruction\\-based},TEdit,7.05,2.21,9.52,3.90,4.91,-2.49,3.82,3.00,0.61,0.00,-
\makecell{Instruction\\-based},Time Weaver,6.09,2.04,4.28,3.85,4.98,-3.33,3.78,3.01,0.57,0.11,0.63
\makecell{Attribute\\-based},InstructTime,**2.42**,**1.24**,**-14.70**,**6.10**,0.06,**-13.13**,**5.96**,0.07,**-1.83**,**0.41**,**0.62**
\makecell{Attribute\\-based},TEdit,3.58,1.56,-11.85,5.96,**0.04**,-12.98,5.91,**0.00**,-1.63,0.26,0.64
\makecell{Attribute\\-based},Time Weaver,3.70,1.57,-8.79,5.94,0.09,-13.00,5.89,0.00,-1.69,0.35,0.68


In [7]:
import re
from collections import Counter

# Optional: rename long column labels to shorter ones
rename_map = {
    "Point-wise MSE ↓": "MSE ↓",
    "Point-wise MAE ↓": "MAE ↓",
    "DTW distance decrease ↓": "DTW ↓",
    "RaTS ↑": "RaTS ↑",
    "|RaTS (preserved)|↓": "|RaTS| ↓",
    "Synthetic (with ground truth)": "Synthetic w/ ground truth",
    "Synthetic": "Synthetic",
    "Air Quality": "Air quality",
    "NICU Heart Rate": "NICU heart rate",
}

df_all_clean = df_all.copy()
df_all_clean.columns = pd.MultiIndex.from_tuples([
    (rename_map.get(a, a), rename_map.get(b, b)) for a, b in df_all_clean.columns
])

# Convert markdown bold to LaTeX bold
df_all_latex = df_all_clean.map(lambda x: re.sub(r'\*\*(.*?)\*\*', r'\\textbf{\1}', str(x)))

# Build column format string with vertical bars
group_counts = Counter([a for a, b in df_all_latex.columns])
column_format = 'll' + ''.join(f"|{'c'*count}" for count in group_counts.values()) #+ '|'

# Generate LaTeX
latex_code = df_all_latex.to_latex(
    multicolumn=True,
    multirow=True,
    escape=False,
    column_format=column_format
)
latex_code = re.sub(
    r'\\multicolumn\{(\d+)\}\{c\}',      # old pattern
    r'\\multicolumn{\1}{|c}',            # new: {|c}  ← no trailing ‘|’
    latex_code
)
latex_code = re.sub(r'\\cline\{1-\d+\}', r'\\midrule', latex_code)
latex_code = re.sub(r'\\midrule\n\\bottomrule', r'\\bottomrule', latex_code)

# 1. Center align multicolumn headers
latex_code = re.sub(r'\\multicolumn\{(\d+)\}\{r\}', r'\\multicolumn{\1}{c}', latex_code)
latex_code = re.sub(r'\\multirow\[t\]', r'\\multirow[c]', latex_code)

# 2. Remove Model row if it appears
latex_code = re.sub(r'^\s*&\s*Model\s*&.*?\\\\\n', '', latex_code, flags=re.MULTILINE)

# 3. Add only a *leading* vertical bar in multicolumn headers
latex_code = re.sub(
    r'\\multicolumn\{(\d+)\}\{c\}',   # old pattern
    r'\\multicolumn{\1}{|c}',         # ← no trailing ‘|’
    latex_code
)

# 4. Bold dataset names in multicolumns (first header row)
latex_code = re.sub(
    r'(\\multicolumn\{\d+\}\{\|c\})\{(.*?)\}',
    lambda m: f"{m.group(1)}{{\\textbf{{{m.group(2)}}}}}",
    latex_code,
    count=len(group_counts)              # still only the first header row
)

# 5. Shrink metric row (second header row)
latex_code = re.sub(
    r'(?<=\\\\\n)(.*?&.*?)\\\\',
    lambda m: re.sub(r'([^&]+)', r'\\footnotesize \1', m.group(1)) + r'\\',
    latex_code,
    count=1  # Only apply to second header row
)

# Wrap in LaTeX table
wrapped_latex = rf"""
\begin{{table}}[htbp]
\centering
\resizebox{{\textwidth}}{{!}}{{%
{latex_code}}}
\caption{{Comparison of editing performance across datasets and models.}}
\label{{tab:comparison_table}}
\end{{table}}
"""

print(wrapped_latex)



\begin{table}[htbp]
\centering
\resizebox{\textwidth}{!}{%
\begin{tabular}{ll|ccccc|ccc|ccc}
\toprule
 &  & \multicolumn{5}{|c}{\textbf{Synthetic w/ ground truth}} & \multicolumn{3}{|c}{\textbf{Synthetic}} & \multicolumn{3}{|c}{\textbf{Air quality}} \\
\footnotesize  &\footnotesize   &\footnotesize  MSE ↓ &\footnotesize  MAE ↓ &\footnotesize  DTW ↓ &\footnotesize  RaTS ↑ &\footnotesize  |RaTS| ↓ &\footnotesize  DTW ↓ &\footnotesize  RaTS ↑ &\footnotesize  |RaTS| ↓ &\footnotesize  DTW ↓ &\footnotesize  RaTS ↑ &\footnotesize  |RaTS| ↓ \\
\midrule
\multirow[c]{4}{*}{\makecell{Instruction\\-based}} & InstructTime & \textbf{2.44} & \textbf{1.24} & \textbf{-14.61} & \textbf{6.40} & \textbf{0.17} & \textbf{-13.04} & \textbf{5.68} & \textbf{0.04} & \textbf{-1.63} & \textbf{0.34} & \textbf{0.56} \\
 & InstructTime (open-vocab) & 3.41 & 1.50 & -12.25 & 5.62 & 0.29 & -11.96 & 4.78 & 0.10 & -1.62 & 0.22 & 0.58 \\
 & TEdit & 7.05 & 2.21 & 9.52 & 3.90 & 4.91 & -2.49 & 3.82 & 3.00 & 0.61 & 0.00 & - 