# In [2]:
import yaml


# In [3]:
file = "results.yaml"

# Load the nested evaluation results; downstream code indexes them as
# results[method][encoder][dataset][...].
# Explicit UTF-8 so decoding does not depend on the platform's locale default.
# NOTE(review): FullLoader is kept in case the file contains Python-specific
# tags; if it is plain YAML, yaml.safe_load is the safer choice.
with open(file, "r", encoding="utf-8") as f:
    results = yaml.load(f, Loader=yaml.FullLoader)

# In [102]:
def multicol(name, space=2, bars=False):
    r"""Build a bold, centred LaTeX ``\multicolumn`` cell.

    Args:
        name: Text placed (in bold) inside the cell; must be a string.
        space: Number of columns the cell spans.
        bars: When true, draw vertical rules on both sides (``|c|``).

    Returns:
        The ``\multicolumn{space}{align}{\textbf{name}}`` snippet.
    """
    alignment = "|c|" if bars else "c"
    pieces = (r"\multicolumn{", str(space), "}{", alignment, r"}{\textbf{", name, "}}")
    return "".join(pieces)

def get_names(data):
    """Collect the OOD dataset names listed under "farood" and "nearood".

    Args:
        data: Mapping whose "farood"/"nearood" values are lists of dicts that
            each carry a "dataset" key; any other keys are ignored.

    Returns:
        A ``(far_names, near_names)`` tuple of lists, in input order.
    """
    buckets = {"farood": [], "nearood": []}
    for split, entries in data.items():
        if split in buckets:
            buckets[split].extend(entry["dataset"] for entry in entries)
    return buckets["farood"], buckets["nearood"]

def _latex_escape(text):
    """Escape the LaTeX special characters ``& % $ # _`` in plain *text*."""
    for char in "&%$#_":
        text = text.replace(char, "\\" + char)
    return text


def _format_percent(value):
    """Format *value* multiplied by 100 with two decimals (metrics are presumably fractions)."""
    return f"{value * 100:.2f}"


def _protect(cell):
    """Wrap *cell* in braces so siunitx ``S`` columns treat it as text, not a number."""
    return "{" + cell + "}"


def _encoder_label(encoder):
    """Return the row label for *encoder*, shortening the OpenOOD ResNet names."""
    if "resnet18_" in encoder:
        return r"resnet18\_open\_ood"
    if "resnet50_" in encoder:
        return r"resnet50\_open\_ood"
    return _latex_escape(encoder)


def create_latex_table(result, method, dataset, encoders):
    """Render a booktabs/siunitx LaTeX table of AUROC and FPR95 for one method and dataset.

    Args:
        result: Nested mapping indexed as ``result[method][encoder][dataset]``;
            each leaf maps "nearood"/"farood" to lists of
            ``{"dataset": name, "metrics": {"AUC": ..., "FPR_95": ...}}``.
        method: Key of the OOD method in ``result``.
        dataset: In-distribution dataset key.
        encoders: Encoders to tabulate. Rows are emitted in sorted order and
            encoders without results for ``method``/``dataset`` are skipped.

    Returns:
        The LaTeX source of the table (ending in a blank line), or ``""`` when
        no listed encoder has results for ``method``/``dataset``.
    """
    method_results = result.get(method, {})
    # Encoders, in the caller's order, that actually have results here.
    available = [e for e in encoders if dataset in method_results.get(e, {})]
    if not available:
        return ""

    # Column layout comes from the first encoder that has results (the
    # original used encoders[0] unconditionally and raised KeyError).
    far_names, near_names = get_names(method_results[available[0]][dataset])
    # Near-OOD columns first, then far-OOD, matching the group header order.
    column_names = near_names + far_names
    total_items = len(column_names)

    # Per encoder: OOD dataset name -> metrics dict. Every available encoder
    # gets an entry (possibly empty) so row generation never hits KeyError.
    lookup = {}
    for encoder in available:
        metrics_by_name = {}
        for split, entries in method_results[encoder][dataset].items():
            if split in ("nearood", "farood"):
                metrics_by_name.update({entry["dataset"]: entry["metrics"] for entry in entries})
        lookup[encoder] = metrics_by_name

    # Best *displayed* values per column, so cells that look tied are all bolded.
    best = {}
    for name in column_names:
        present = [lookup[e][name] for e in available if name in lookup[e]]
        if present:
            best[name] = (
                _format_percent(max(m["AUC"] for m in present)),
                _format_percent(min(m["FPR_95"] for m in present)),
            )

    lines = [
        r"\begin{table}[ht]",
        r"\caption{" + f"Result for {_latex_escape(method)} using {_latex_escape(dataset)}" + "}",
        # The label must follow the caption; inside the tabular before
        # \bottomrule it causes a "Misplaced \noalign" error.
        r"\label{tab:" + f"{method}_{dataset}" + "}",
        r"\centering",
        r"\resizebox{\textwidth}{!}{% Resize table to fit within \textwidth horizontally",
        r"\begin{tabular}{@{}l*{" + str(total_items) + r"}{SS}@{}}",
        r"\toprule",
    ]

    # Group header (Near OOD | Far OOD) above the per-dataset header; empty
    # groups are omitted because \multicolumn{0} is invalid.
    group_cells = []
    if near_names:
        group_cells.append(multicol("Near OOD", len(near_names) * 2, True))
    if far_names:
        group_cells.append(multicol("Far OOD", len(far_names) * 2, True))
    if group_cells:
        lines.append(" & " + " & ".join(group_cells) + r" \\")

    name_cells = " & ".join(multicol(_latex_escape(name)) for name in column_names)
    lines.append(r"\textbf{Encoder} & " + name_cells + r" \\")
    lines.append(
        r" & {\footnotesize AUROC} $\uparrow$ & {\footnotesize FPR95} $\downarrow$ " * total_items
        + r" \\"
    )
    lines.append(r"\midrule")

    for encoder in sorted(available):
        row = [_encoder_label(encoder)]
        for name in column_names:
            metrics = lookup[encoder].get(name)
            if metrics is None:
                row += [_protect("-"), _protect("-")]
                continue
            best_auc, best_fpr = best[name]
            auc = _format_percent(metrics["AUC"])
            fpr = _format_percent(metrics["FPR_95"])
            # Non-numeric content (bold markup) must be brace-protected in S columns.
            row.append(_protect(r"\textbf{" + auc + "}") if auc == best_auc else auc)
            row.append(_protect(r"\textbf{" + fpr + "}") if fpr == best_fpr else fpr)
        lines.append(" & ".join(row) + r" \\")

    # The lone "}" closes \resizebox; the trailing empty items leave a blank
    # line after \end{table} so concatenated tables stay separated.
    lines += [r"\bottomrule", r"\end{tabular}", "}", r"\end{table}", "", ""]
    return "\n".join(lines)

# Generate one LaTeX table per (method, in-distribution dataset) pair and
# write them all to out.txt.
datasets = ['imagenet200', 'cifar10', 'cifar100', 'covid', 'mnist']
encoders = [
    'repvgg', 'resnet50d', 'swin', 'deit', 'dino', 'dinov2', 'vit', 'clip',
    'resnet18_32x32_cifar10_open_ood', 'resnet18_32x32_cifar100_open_ood',
    'resnet18_224x224_imagenet200_open_ood', 'resnet50_224x224_imagenet_open_ood',
]
with open("out.txt", "w", encoding="utf-8") as f:
    for method, per_encoder in results.items():
        for dataset in datasets:
            # Skip (method, dataset) pairs with no results from any encoder
            # instead of failing with a KeyError inside the table builder.
            if not any(dataset in per_encoder.get(encoder, {}) for encoder in encoders):
                continue
            f.write(create_latex_table(results, method, dataset, encoders))

