In [1]:
import os
from requests import get
import torch
from crossval_result_loader import get_results, Result, get_reduction_factor

# Dataset whose results this cell exports; it is also used below as the name of the generated .tex file.
dataset = "CelebA"
# Run identifier; it is one path component of the results folder built on the next line.
project_name = "MTVAEs_05.01-cross-val"
# Relative folder passed to get_results_as_latex_to_file() at the end of this cell.
directory = f"assets/results/raw/{project_name}/{dataset}/"

def group_by_result_lambda(results: list[Result], lambda_function) -> dict[str, list[Result]]:
    """Bucket ``results`` by the key that ``lambda_function`` yields for each item.

    Args:
        results: Items to group; they are visited in the given order.
        lambda_function: Callable applied to every item; its return value is
            used as the dictionary key and therefore has to be hashable.

    Returns:
        A plain ``dict`` mapping each key to the list of items that produced
        it. Keys appear in first-seen order and every list keeps the input
        order of its items.
    """
    buckets = {}
    for item in results:
        # setdefault creates the bucket the first time a key is encountered.
        buckets.setdefault(lambda_function(item), []).append(item)
    return buckets

def find_min_loss(results: list[Result], num = 0):
    """Return the smallest ``average`` of the ``num``-th unconditioned loss.

    Args:
        results: Objects exposing ``get_unconditioned_losses()``; element
            ``num`` of the returned sequence must have an ``average`` attribute.
        num: Index of the loss to inspect in each result (defaults to the first).

    Returns:
        The lowest ``average`` found, or ``float("inf")`` when ``results`` is
        empty or no average compares strictly below infinity.
    """
    averages = (entry.get_unconditioned_losses()[num].average for entry in results)
    # Seeding with +inf keeps the "strictly smaller than the running minimum"
    # rule and supplies the value returned for an empty input.
    return min([float("inf"), *averages])

def get_results_as_latex_to_file(directory: str, file_path: str):
    """Load the results stored under ``directory`` and write them to
    ``file_path`` as a six-column LaTeX ``tabular``.

    Rows are grouped by model name, then by configuration number, then by
    method. Each group label is emitted once, as a multirow cell on the first
    row that is written after the label was built, and left empty on the
    following rows. Every loss is printed as ``average +- std`` after dividing
    both values by the factor returned by ``get_reduction_factor``; the cell is
    set in bold when the (unscaled) average equals the minimum over all results
    sharing the same model and configuration.

    Args:
        directory: Folder passed to ``get_results``.
        file_path: Destination file; it is opened in write mode, so an
            existing file is overwritten.

    Note:
        The header declares two loss columns, so each result presumably
        carries exactly two unconditioned losses -- TODO confirm against
        ``Result.get_unconditioned_losses``.
    """
    results: list[Result] = get_results(directory)

    with open(file_path, "w") as file:
        file.write("\\centering\n")
        file.write("\\scriptsize\n")
        file.write("\\begin{tabular}{||c|c|c|c|c|c||}\n")
        file.write("\\hline\n")
        file.write(" & Config. & Method & Parameters & Reconstruction loss & VQ/KL loss \\\\\n")
        file.write("\\hline\n")

        grouped_by_model = group_by_result_lambda(results, lambda x: x.get_model_name())

        for model_name, model_results in grouped_by_model.items():
            # The backslash is escaped explicitly: "\m" is not a valid escape
            # sequence and triggers a SyntaxWarning on recent Python versions.
            display_model_name = (
                "\\multirow{" + str(len(model_results))
                + "}{*}{\\rotatebox[origin=c]{90}{" + model_name + "}}"
            )

            model_results = sorted(model_results, key=lambda x: x.get_config_number())
            grouped_by_config = group_by_result_lambda(model_results, lambda x: x.get_config_number())

            for config, config_results in grouped_by_config.items():
                display_config = "\\multirow{" + str(len(config_results)) + "}{*}{" + str(config) + "}"

                # Order rows by method first and by parameters second.
                config_results = sorted(config_results, key=lambda x: (x.get_method(), x.get_parameters()))
                grouped_by_method = group_by_result_lambda(config_results, lambda x: x.get_method())

                for method, method_results in grouped_by_method.items():
                    display_method = "\\multirow{" + str(len(method_results)) + "}{*}{" + str(method) + "}"

                    for result in method_results:
                        losses = result.get_unconditioned_losses()

                        output_line = (
                            display_model_name + " & " + display_config + " & "
                            + display_method + " & " + result.get_parameters() + " & "
                        )

                        for loss_index, loss in enumerate(losses):
                            # Best value is searched across every method of this model/config.
                            is_min = find_min_loss(config_results, loss_index) == loss.average

                            # The second argument is True only for the first loss of a result.
                            reduction_factor = get_reduction_factor(result.filename, loss_index == 0)
                            average = loss.average / reduction_factor
                            std = loss.std / reduction_factor

                            cell = "{:.4f}".format(average) + " +- " + "{:.1e}".format(std)
                            if is_min:
                                cell = "\\textbf{" + cell + "}"
                            output_line += cell + " & "

                        # Drop the trailing "& " before terminating the row.
                        file.write(output_line[:-2] + "\\\\\n")

                        # Group labels are printed on one row only; blank them for the rest.
                        display_model_name = ""
                        display_config = ""
                        display_method = ""
                        file.write("\\cline{4-6}\n")
                    file.write("\\cline{3-6}\n")
                file.write("\\cline{2-6}\n")
            file.write("\\hline\n")

        file.write("\\hline\n")
        file.write("\\end{tabular}\n")
        
        
# Write the table to a .tex file named after the dataset, under the paper's tables folder.
# The target folder is not created here, so open() inside the call fails if it is missing.
file_path = f"../paper/figures/tables/{dataset}.tex"
get_results_as_latex_to_file(directory, file_path)

Loading results...
Loaded:  42
Done.


In [1]:
import os
import re
from requests import get
import torch
from crossval_result_loader import get_results, Result, get_reduction_factor

# Dataset handled by this cell; it selects the sub-folder of the run that is loaded below.
dataset = "CIFAR10"
# Run identifier; it is one path component of the results folder built on the next line.
project_name = "MTVAEs_05.01-cross-val"
# Relative folder that get_results() reads.
directory = f"assets/results/raw/{project_name}/{dataset}/"

# Loaded once at cell level and shared by all four table-export calls at the end of the cell.
results: list[Result] = get_results(directory)

def output_results_table(results: list[Result], output_path: str):
    """Write ``results`` to ``output_path`` as a four-column LaTeX ``tabular``.

    The results are sorted by ``(get_method(), get_parameters())`` and one row
    is written per result. A method name containing ``"-"`` is displayed as an
    italic "Baseline" label. Each loss is printed as ``average +- std`` after
    dividing both values by the factor returned by ``get_reduction_factor``,
    followed by an up or down arrow when its unscaled average is respectively
    greater or smaller than the same loss of the first sorted row.

    Args:
        results: Results to tabulate; the caller's list is not modified.
        output_path: Destination file; it is opened in write mode, so an
            existing file is overwritten.

    Note:
        The arrows are only meaningful if the baseline sorts first --
        presumably its method name (containing "-") orders before the other
        method names; verify against the data.
    """
    results = sorted(results, key=lambda x: (x.get_method(), x.get_parameters()))

    # Reference row for the arrows. Looked up once instead of on every loss
    # cell; the guard keeps an empty input from raising, as before.
    reference_losses = results[0].get_unconditioned_losses() if results else []

    with open(output_path, "w") as file:
        file.write("\\centering\n")
        file.write("\\scriptsize\n")
        file.write("\\begin{tabular}{||c|c|c|c||}\n")
        file.write("\\hline\n")
        file.write(" Method & Parameters & Reconstruction loss & KL loss \\\\\n")
        file.write("\\hline\n")

        for result in results:
            method = result.get_method()
            if "-" in method:
                method = "\\textit{Baseline}"
            output_line = method + " & " + result.get_parameters() + " & "

            for i, loss in enumerate(result.get_unconditioned_losses()):
                # The second argument is True only for the first loss of a result.
                reduction_factor = get_reduction_factor(result.filename, i == 0)
                average = loss.average / reduction_factor
                std = loss.std / reduction_factor
                output_line += "{:.4f}".format(average) + " +- " + "{:.1e}".format(std) + " & "

                reference_average = reference_losses[i].average
                if reference_average < loss.average:
                    # Replace the trailing "& " so the arrow stays inside the same cell.
                    output_line = output_line[:-2] + " $\\uparrow$ & "
                elif reference_average > loss.average:
                    output_line = output_line[:-2] + " $\\downarrow$ & "

            # Drop the trailing "& " before terminating the row.
            file.write(output_line[:-2] + "\\\\\n")
            file.write("\\hline\n")

        file.write("\\end{tabular}\n")



def _write_filtered_results_table(results: list[Result], output_path: str,
                                  is_vqvae: bool, config_number: int, decoder_keyword: str):
    """Select a subset of ``results`` and hand it to ``output_results_table``.

    A result is kept when all of the following hold:
      * ``"VQ-VAE" in get_model_name()`` equals ``is_vqvae``;
      * ``get_config_number()`` equals ``config_number``;
      * ``get_method()`` contains ``decoder_keyword`` or contains ``"-"``.

    Args:
        results: Candidate results; the list itself is left untouched.
        output_path: File that ``output_results_table`` writes.
        is_vqvae: Whether to keep VQ-VAE models (True) or the others (False).
        config_number: Configuration number a result must have.
        decoder_keyword: Substring identifying the wanted method family.
    """
    selected = []
    for result in results:
        if ("VQ-VAE" in result.get_model_name()) != is_vqvae:
            continue
        if result.get_config_number() != config_number:
            continue
        method = result.get_method()
        # Methods containing "-" are always kept next to the requested family.
        if decoder_keyword in method or "-" in method:
            selected.append(result)
    output_results_table(selected, output_path)


def print_multidecoder_gaussian_results_table(results: list[Result]):
    """Write scvae2d.tex: non-VQ-VAE models, configuration 2, "Multi" methods plus "-" methods."""
    _write_filtered_results_table(
        results, "../paper/figures/tables/scvae2d.tex",
        is_vqvae=False, config_number=2, decoder_keyword="Multi",
    )


def print_multidecoder_vqvae_results_table(results: list[Result]):
    """Write scvqvae2d.tex: VQ-VAE models, configuration 3, "Multi" methods plus "-" methods."""
    _write_filtered_results_table(
        results, "../paper/figures/tables/scvqvae2d.tex",
        is_vqvae=True, config_number=3, decoder_keyword="Multi",
    )


def print_singledecoder_gaussian_results_table(results: list[Result]):
    """Write scvae1d.tex: non-VQ-VAE models, configuration 1, "Single" methods plus "-" methods."""
    _write_filtered_results_table(
        results, "../paper/figures/tables/scvae1d.tex",
        is_vqvae=False, config_number=1, decoder_keyword="Single",
    )


def print_singledecoder_vqvae_results_table(results: list[Result]):
    """Write scvqvae1d.tex: VQ-VAE models, configuration 2, "Single" methods plus "-" methods."""
    _write_filtered_results_table(
        results, "../paper/figures/tables/scvqvae1d.tex",
        is_vqvae=True, config_number=2, decoder_keyword="Single",
    )

# Despite the "print_" prefix, each call writes one .tex file under ../paper/figures/tables/
# and shows nothing in the notebook. Multi-decoder tables: scvae2d.tex and scvqvae2d.tex.
print_multidecoder_gaussian_results_table(results)
print_multidecoder_vqvae_results_table(results)

# Single-decoder tables: scvae1d.tex and scvqvae1d.tex.
print_singledecoder_gaussian_results_table(results)
print_singledecoder_vqvae_results_table(results)




Loading results...
Loaded:  45
Done.
