In [1]:
import os
import torch
from crossval_result_loader import get_results, Result

dataset = "MNIST"
project_name = "MTVAEs_05.01-cross-val"
directory = f"assets/results/raw/{project_name}/{dataset}/"

def print_results(directory: str):
    results: list[Result] = get_results(directory)  
    for result in results: 
        print("*"*80)
        
        print (result.filename)
        
        print(result.get_display_model_name())
        
        losses = result.get_unconditioned_losses()
        
        for key, value in losses.items():
            print (f"{key}: {value.average} +- {value.std}")
            
print_results(directory)

Loading results...
Loaded:  45
Done.
********************************************************************************
VQVAE(128_16_[32, 64]_0)?dataset=MNIST&batch_size=128&max_epochs=100_crossval_results.pt
VQ-VAE(Conf. Nr.1)
Reconstruction loss: 0.0027908134274184705 +- 8.83935007784005e-09
VQ objective loss: 0.010065646283328534 +- 4.669239887607202e-07
********************************************************************************
VQVAE(128_32_[128, 256]_2)?dataset=MNIST&batch_size=128&max_epochs=100_crossval_results.pt
VQ-VAE(Conf. Nr.2)
Reconstruction loss: 0.0017001190455630422 +- 1.5770437576201424e-09
VQ objective loss: 0.002214624173939228 +- 1.9928815230617174e-08
********************************************************************************
VQVAE(256_64_[128, 256]_4)?dataset=MNIST&batch_size=128&max_epochs=100_crossval_results.pt
VQ-VAE(Conf. Nr.3)
Reconstruction loss: 0.001903897407464683 +- 1.3200110901672714e-08
VQ objective loss: 0.002362433634698391 +- 4.271572894880

In [18]:
def len_vq_vae(results: list[Result]) -> int:
    return len([result for result in results if "VQVAE" in result.filename])

def len_vae(results: list[Result]) -> int:
    return len([result for result in results if "VAE" in result.filename or "SCVAE1D" in result.filename or "SCVAE2D" in result.filename])

def len_config_vq(results: list[Result], config: int) -> int:
    return len([result for result in results if result.get_config_number() == config and "VQVAE" in result.filename])

def conf_output(results: list[Result], conf: int, previous_conf: int | None) -> str:
    if previous_conf == None:
        return "\multirow{"+ str(len_config_vq(results, conf)) +"}{*}{" + str(conf) + "}"
    elif previous_conf != conf:
        return "\multirow{"+ str(len_config_vq(results, conf)) +"}{*}{" + str(conf) + "}"
    else:
        return ""
    

def output_latex_table_of_results(directory: str):
    results : list[Result] = get_results(directory)
    print("*"*80)
    print("\\begin{center}")
    print("\\begin{tabular}{||c|c|c|c|c||}")
    print("\\hline")
    print(" Model name & Configuration & Method & Reconstruction loss & VQ loss \\\\")
    print("\\hline")
    
    
    # sort by config number
    results = sorted(results, key = lambda x: x.get_config_number())
    
    first_vq = "\multirow{"+ str(len_vq_vae(results)) +"}{*}{\\rotatebox[origin=c]{90}{VQVAE}}" 
    
    previous_conf = None
    
    for i, result in enumerate(results):
        if "VQVAE" in result.filename:
            losses = result.get_unconditioned_losses()
            conf = result.get_config_number()
            output_line = first_vq + " & " + conf_output(results, conf, previous_conf) + " & " + result.get_display_model_name() + " & "
            
            previous_conf = conf
            
            if first_vq != "":
                first_vq = ""
                
            
            for key, value in losses.items():
                output_line += str(value.average) + " +- " + str(value.std) + " & "
                
            print(output_line[:-2] + "\\\\")
            
            if i < len(results)-1:
                print("\\cline{3-4}")
            else:
                print("\\hline")
    
    print ("\\end{tabular}")
    print ("\\end{center}")
    print("*"*80)
    
dataset = "MNIST"
project_name = "MTVAEs_05.01-cross-val"
directory = f"assets/results/raw/{project_name}/{dataset}/"
    
output_latex_table_of_results(directory)


Loading results...
Loaded:  45
Done.
********************************************************************************
\begin{center}
\begin{tabular}{||c|c|c|c|c||}
\hline
 Model name & Configuration & Method & Reconstruction loss & VQ loss \\
\hline
\multirow{27}{*}{\rotatebox[origin=c]{90}{VQVAE}} & \multirow{9}{*}{1} & VQ-VAE(Conf. Nr.1) & 0.0027908134274184705 +- 8.83935007784005e-09 & 0.010065646283328534 +- 4.669239887607202e-07 \\
\cline{3-4}
 &  & VQ-VAE(Conf. Nr.1) with Multi Decoder method & 0.012305403221398592 +- 0.00038526916826302573 & 0.003475161537062377 +- 2.3461766616996344e-06 \\
\cline{3-4}
 &  & VQ-VAE(Conf. Nr.1) with Multi Decoder method & 0.00245369290933013 +- 6.4834449898945955e-09 & 0.008168273605406284 +- 1.8744889561730926e-07 \\
\cline{3-4}
 &  & VQ-VAE(Conf. Nr.1) with Multi Decoder method & 0.0024290796369314193 +- 3.5535142889724095e-08 & 0.004208790836855769 +- 2.38779463963357e-07 \\
\cline{3-4}
 &  & VQ-VAE(Conf. Nr.1) with Multi Decoder method & 0.00