In [1]:
import os
from requests import get
import torch
from crossval_result_loader import get_results, Result, get_reduction_factor

# Cross-validation project and dataset whose raw results feed the LaTeX table.
project_name = "MTVAEs_05.01-cross-val"
dataset = "CelebA"
# Raw results are read from assets/results/raw/<project>/<dataset>/.
directory = f"assets/results/raw/{project_name}/{dataset}/"

def group_by_result_lambda(results, lambda_function):
    """Bucket ``results`` by the key ``lambda_function`` computes for each item.

    Returns a dict mapping every key to the list of items that produced it.
    Keys appear in first-seen order and each bucket keeps the input order.
    The key function is called exactly once per item.
    """
    buckets = {}
    for item in results:
        buckets.setdefault(lambda_function(item), []).append(item)
    return buckets

def find_min_loss(results: "list[Result]", num=0):
    """Return the smallest ``average`` of the ``num``-th unconditioned loss.

    Args:
        results: Results to scan; each must provide ``get_unconditioned_losses()``
            returning an indexable sequence whose items have an ``average``.
        num: Index of the loss to compare (the loss column in the table).

    Returns:
        The minimum average found, or ``float("inf")`` when ``results`` is empty.
        Values for which ``value < current_min`` is false (e.g. NaN) are never
        selected.
    """
    min_loss = float("inf")
    for result in results:
        loss = result.get_unconditioned_losses()[num].average
        if loss < min_loss:
            min_loss = loss
    return min_loss

def get_results_as_latex_to_file(directory: str, file_path: str):
    """Write the results loaded from ``directory`` as a LaTeX tabular to ``file_path``.

    Rows are grouped by model name (a rotated \\multirow cell), then by config
    number, then by method; within a method, rows are ordered by parameters.
    Each row holds one "average +- std" cell per unconditioned loss. Within a
    config group, the lowest average of each loss column is set in bold.

    Args:
        directory: Directory passed to ``get_results`` to load the results.
        file_path: Output ``.tex`` path (overwritten). It is also the argument
            given to ``get_reduction_factor``, whose value divides the losses of
            every model other than "VQ-VAE".
    """
    results: list[Result] = get_results(directory)

    with open(file_path, "w") as file:
        file.write("\\centering\n")
        file.write("\\scriptsize\n")
        file.write("\\begin{tabular}{||c|c|c|c|c|c||}\n")
        file.write("\\hline\n")
        file.write(" & Config. & Method & Parameters & Reconstruction loss & VQ/KL loss \\\\\n")
        file.write("\\hline\n")

        grouped_by_model = group_by_result_lambda(results, lambda x: x.get_model_name())

        for model_name, model_res in grouped_by_model.items():
            # Backslashes are escaped explicitly: "\m" is an invalid escape
            # sequence (SyntaxWarning on Python 3.12+).
            display_model_name = "\\multirow{" + str(len(model_res)) + "}{*}{\\rotatebox[origin=c]{90}{" + model_name + "}}"

            model_res = sorted(model_res, key=lambda x: x.get_config_number())
            grouped_by_config = group_by_result_lambda(model_res, lambda x: x.get_config_number())

            for config, config_res in grouped_by_config.items():
                display_config = "\\multirow{" + str(len(config_res)) + "}{*}{" + str(config) + "}"

                # sort by get_method and then by get_parameters
                config_res = sorted(config_res, key=lambda x: (x.get_method(), x.get_parameters()))
                grouped_by_method = group_by_result_lambda(config_res, lambda x: x.get_method())

                # Column minimum over this config group, computed lazily once per
                # loss index instead of once per table cell.
                min_loss_by_index = {}

                for method, method_results in grouped_by_method.items():
                    display_method = "\\multirow{" + str(len(method_results)) + "}{*}{" + str(method) + "}"

                    for result in method_results:
                        output_line = display_model_name + " & " + display_config + " & " + display_method + " & " + result.get_parameters() + " & "

                        for loss_index, loss in enumerate(result.get_unconditioned_losses()):
                            if loss_index not in min_loss_by_index:
                                min_loss_by_index[loss_index] = find_min_loss(config_res, loss_index)
                            # Compared on the unscaled average; scaling below is uniform
                            # within a model, so the bold choice is unaffected.
                            is_min = min_loss_by_index[loss_index] == loss.average

                            average = loss.average
                            std = loss.std
                            if model_name != "VQ-VAE":
                                # NOTE(review): the output .tex path is what gets passed
                                # here; confirm this is what get_reduction_factor expects.
                                reduction_factor = get_reduction_factor(file_path)
                                average = average / reduction_factor
                                std = std / reduction_factor

                            # NOTE(review): "+-" is emitted literally; "$\pm$" would typeset a ± sign.
                            cell = "{:.4f}".format(average) + " +- " + "{:.1e}".format(std)
                            if is_min:
                                cell = "\\textbf{" + cell + "}"
                            output_line += cell + " & "

                        # Drop the trailing "& " separator and terminate the row.
                        file.write(output_line[:-2] + "\\\\\n")

                        # Only the first row of each group prints the \multirow cells.
                        display_model_name = ""
                        display_config = ""
                        display_method = ""
                        file.write("\\cline{4-6}\n")
                    file.write("\\cline{3-6}\n")
                file.write("\\cline{2-6}\n")
            file.write("\\hline\n")

        file.write("\\hline\n")
        file.write("\\end{tabular}\n")
        
        
# One generated .tex table per dataset, written into the paper's tables folder.
file_path = f"../paper/figures/tables/{dataset}.tex"
get_results_as_latex_to_file(directory=directory, file_path=file_path)

Loading results...
Loaded:  34
Done.


In [5]:
import os
from requests import get
import torch
from crossval_result_loader import get_results, Result, get_reduction_factor

# Same cross-validation project as before, now inspecting the CIFAR10 runs.
project_name = "MTVAEs_05.01-cross-val"
dataset = "CIFAR10"
directory = f"assets/results/raw/{project_name}/{dataset}/"

# Load every result stored under `directory`.
results: list[Result] = get_results(directory)

def print_multidecoder_gaussian_results_table(results: "list[Result]", config_number: int = 2):
    """Print the raw unconditioned losses of selected results, one result per paragraph.

    A result is printed when its model name does not contain "VQ-VAE", its
    config number equals ``config_number``, and its method contains "Multi"
    or "-". Despite the function name, no check for a Gaussian model is made.
    For each match, the model name, config number, method and parameters are
    printed on separate lines, then the average and std of every
    unconditioned loss, then a blank line.

    NOTE(review): the saved output of this cell lists "Single Decoder" rows,
    which this filter excludes, so that output looks stale; re-run the cell.

    Args:
        results: Results as returned by ``get_results``.
        config_number: Config to report. Defaults to 2, the previously
            hard-coded value.
    """
    for result in results:
        if "VQ-VAE" in result.get_model_name() or result.get_config_number() != config_number:
            continue
        method = result.get_method()
        if "Multi" not in method and "-" not in method:
            continue

        print(result.get_model_name())
        print(result.get_config_number())
        print(method)
        print(result.get_parameters())
        for loss in result.get_unconditioned_losses():
            print(loss.average)
            print(loss.std)
        print("")
    

# Dump the config-2 losses of the results loaded in this cell.
print_multidecoder_gaussian_results_table(results=results)



Loading results...
Loaded:  45
Done.
Gaussian VAE
2
Multi Decoder
Exact sampling, SoftAdapt
5389.62431640625
293.91963315963744
4506.19443359375
4468.218908824921

Gaussian VAE
2
Single Decoder
Uniform sampling, Exponent=60
6820.4607421875
1015.3607862091064
2160.360302734375
515.2576882457732

Gaussian VAE
2
Multi Decoder
Uniform sampling
5557.90830078125
1029.1149960899352
4166.536181640625
2621.225898370743

Gaussian VAE
2
-
-
6673.1736328125
1843.2074376678465
2796.838330078125
985.1821710920334

Gaussian VAE
2
Single Decoder
Gaussian sampling, Exponent=60
6767.97978515625
1311.2787030792235
2228.211376953125
544.1192553377152

Gaussian VAE
2
Multi Decoder
Uniform sampling, SoftAdapt
5358.391796875
11509.937326202391
4498.68779296875
31821.090268001553

Gaussian VAE
2
Single Decoder
Uniform sampling, Exponent=40
6790.29794921875
1586.920136871338
2166.042626953125
242.6016093873978

Gaussian VAE
2
Single Decoder
Gaussian sampling, Exponent=40
6761.11611328125
2394.94
2215.037207031