In [4]:
import os
import torch
from crossval_result_loader import get_results, Result

# Dataset and experiment project whose cross-validation results are tabulated.
dataset = "MNIST"
project_name = "MTVAEs_05.01-cross-val"
# Raw results are read from assets/results/raw/<project_name>/<dataset>/
# (the trailing slash is kept on purpose).
directory = "/".join(("assets/results/raw", project_name, dataset)) + "/"

def group_by_result_lambda(results, lambda_function):
    """Bucket ``results`` by the key ``lambda_function`` returns for each item.

    Returns a plain dict mapping every key to the list of items that produced
    it. Keys appear in first-seen order and each list keeps the input order.
    """
    buckets = {}
    for item in results:
        buckets.setdefault(lambda_function(item), []).append(item)
    return buckets

def find_min_loss(results: "list[Result]", num=0):
    """Return the smallest average of the ``num``-th unconditioned loss.

    Args:
        results: Results whose ``get_unconditioned_losses()[num].average``
            values are compared.
        num: Index of the loss to compare (default 0, the first loss).

    Returns:
        The minimum average, or ``float("inf")`` when ``results`` is empty.
        Averages that never compare below the running minimum (e.g. NaN)
        are skipped rather than propagated.
    """
    min_loss = float("inf")
    for result in results:
        loss = result.get_unconditioned_losses()[num].average
        # Strict "<" keeps NaN out of the running minimum.
        if loss < min_loss:
            min_loss = loss
    return min_loss

def _format_loss_cell(average, std, bold):
    """Format one ``mean ± std`` LaTeX cell, wrapped in ``\\textbf`` if ``bold``."""
    cell = "{:.4f}".format(average) + " $\\pm$ " + "{:.1e}".format(std)
    return "\\textbf{" + cell + "}" if bold else cell


def get_results_as_latex_to_file(directory: str, file_path: str):
    """Write the cross-validation results found in ``directory`` as a LaTeX table.

    Rows are grouped with nested ``\\multirow`` cells by model name
    (``get_model_name``), then configuration number (``get_config_number``),
    then method (``get_method``). Within a configuration, rows are sorted by
    method and then parameters. Each loss column shows ``mean ± std``. The
    lowest mean in that column among all rows of the same configuration is
    set in bold (ties are all bolded).

    The generated LaTeX needs the ``multirow`` and ``graphicx`` (for
    ``\\rotatebox``) packages in the document preamble.

    Args:
        directory: Folder handed to ``get_results`` to load the results.
        file_path: Destination ``.tex`` file; overwritten if it exists.
    """
    results: list[Result] = get_results(directory)

    # Carried over from the original: losses of every model other than
    # "VQ-VAE" are divided by 128 for display only (bolding is decided on the
    # raw values). NOTE(review): meaning of 128 not visible here — confirm.
    non_vq_scale = 128

    with open(file_path, "w", encoding="utf-8") as file:
        file.write("\\begin{table}\n")
        # \centering inside the float: a center environment wrapped around
        # the float has no effect on the float's contents.
        file.write("\\centering\n")
        file.write("\\tiny\n")
        # Six fixed columns: the header (and the \cline ranges below) assume
        # each result has exactly two unconditioned losses.
        file.write("\\begin{tabular}{||c|c|c|c|c|c||}\n")
        file.write("\\hline\n")
        file.write(" & Config. & Method & Parameters & Reconstruction loss & VQ loss \\\\\n")
        file.write("\\hline\n")

        # NOTE(review): names/parameters are inserted unescaped; LaTeX special
        # characters (e.g. "_", "%", "&") in them would break the table.
        by_model = group_by_result_lambda(results, lambda x: x.get_model_name())

        for model_name, model_rows in by_model.items():
            display_model = (
                "\\multirow{" + str(len(model_rows)) + "}{*}"
                "{\\rotatebox[origin=c]{90}{" + model_name + "}}"
            )

            by_config = group_by_result_lambda(model_rows, lambda x: x.get_config_number())

            for config, config_rows in by_config.items():
                display_config = "\\multirow{" + str(len(config_rows)) + "}{*}{" + str(config) + "}"

                # Sort by method, then parameters, so each method's rows are adjacent.
                config_rows = sorted(config_rows, key=lambda x: (x.get_method(), x.get_parameters()))

                # Best mean per loss column for this configuration, computed once
                # per column instead of once per cell.
                best_loss = {}

                by_method = group_by_result_lambda(config_rows, lambda x: x.get_method())

                for method, method_rows in by_method.items():
                    display_method = "\\multirow{" + str(len(method_rows)) + "}{*}{" + str(method) + "}"

                    for result in method_rows:
                        cells = [display_model, display_config, display_method, str(result.get_parameters())]

                        for col, loss in enumerate(result.get_unconditioned_losses()):
                            if col not in best_loss:
                                best_loss[col] = find_min_loss(config_rows, col)
                            is_min = best_loss[col] == loss.average

                            average = loss.average
                            std = loss.std
                            if model_name != "VQ-VAE":
                                average = average / non_vq_scale
                                std = std / non_vq_scale

                            cells.append(_format_loss_cell(average, std, is_min))

                        file.write(" & ".join(cells) + " \\\\\n")

                        # Multirow cells are emitted only on their group's first row.
                        display_model = ""
                        display_config = ""
                        display_method = ""
                        file.write("\\cline{4-6}\n")
                    file.write("\\cline{3-6}\n")
                file.write("\\cline{2-6}\n")
            file.write("\\hline\n")

        file.write("\\hline\n")
        file.write("\\end{tabular}\n")
        file.write("\\end{table}\n")
        
        
# Write the table into the paper's table folder. The file name follows
# `dataset`, so switching datasets does not overwrite another dataset's table.
file_path = f"../paper/figures/tables/{dataset}.tex"
get_results_as_latex_to_file(directory, file_path)

Loading results...
Loaded:  45
Done.
