In [1]:
import os
import torch

import pyvdirs.dirs as dirs
import sys
sys.path.insert(0, dirs.SYSTEM_HOME)
sys.path.insert(0, os.path.join(dirs.SYSTEM_HOME, "ToyExample"))
from socket import gethostname

from ToyExample.toy_example import do_test
import pyvtools.text as vtext

In [2]:
# Evaluation settings: test_batch_size is passed to do_test both as the number of
# samples and as the batch size (a single test batch); test_seed is passed as its seed.
test_batch_size = 2**14
test_seed = 7

# Names of the experiment series to evaluate; the full list is kept for reference.
# series = ["18_Statistics", "19_ACIDParams", "21_Repetitions", "23_NormalizedLogits"]
series = ["23_NormalizedLogits"]


def get_path(series):
    """Return the directory that holds the runs of one experiment series.

    Parameters
    ----------
    series : str
        Name of the series sub-directory, e.g. "23_NormalizedLogits".

    Returns
    -------
    str
        ``<dirs.MODELS_HOME>/ToyExample/<series>``, joined with ``os.path.join``.
    """
    # The parameter keeps its original name (shadowing the module-level list inside
    # the function only) so keyword calls like get_path(series=...) keep working.
    return os.path.join(dirs.MODELS_HOME, "ToyExample", series)

In [3]:
# Name of the machine this notebook is running on.
host_id = gethostname()

# Entries of the directories file other than this host and the "else" entry.
# NOTE(review): presumably must=False makes the filter drop keys containing any of
# the given strings — confirm against pyvtools.text.filter_by_string_must.
other_hosts = vtext.filter_by_string_must(
    [*dirs.check_directories_file().keys()],
    [host_id, "else"],
    must=False,
)

In [4]:
# Run do_test on every run folder of every requested series and store what it returns
# as test_results[series_name][folder_name].
test_results = {}

# Loop-invariant lookup: fetch the directories table once instead of once per host
# for every folder.
directories = dirs.check_directories_file()

# The log header is scanned up to and including this 0-based line index + 1
# (the scan stops right after processing the first line whose index exceeds it).
max_log_line_index = 70

for s in series:

    series_path = get_path(s)
    contents = os.listdir(series_path)
    folders = [c for c in contents if os.path.isdir(os.path.join(series_path, c))]
    # NOTE(review): presumably must=False drops folders whose name contains
    # "Failed" or "Old" — confirm against pyvtools.text.filter_by_string_must.
    folders = vtext.filter_by_string_must(folders, ["Failed", "Old"], must=False)
    # Every run folder must have a sibling log file named log_<folder>.txt
    log_files = ["log_"+f+".txt" for f in folders]
    assert all(os.path.isfile(os.path.join(series_path, f)) for f in log_files), "Some logs have not been found"

    test_results[s] = {}
    for folder, log_file in zip(folders, log_files):

        log_filepath = os.path.join(series_path, log_file)

        # The first file whose name matches "learner" is the learner checkpoint; the
        # EMA checkpoint has the same file name with "learner" removed.
        files = os.listdir(os.path.join(series_path, folder))
        net_file = vtext.filter_by_string_must(files, "learner")[0]
        EMA_file = net_file.replace("learner", "")

        net_filepath = os.path.join(series_path, folder, net_file)
        EMA_filepath = os.path.join(series_path, folder, EMA_file)

        # Both values are reset for every folder, so an empty or guide-less log can
        # never inherit the ACID flag or guide path parsed for a previous folder.
        acid = False
        guide_filepath = None
        with open(log_filepath, "r") as f:
            for i, line in enumerate(f):
                if "ACID = True" in line:
                    acid = True
                if "Guide model loaded from" in line:
                    # Keep what follows the marker, without the trailing newline
                    guide_filepath = line.split("Guide model loaded from ")[-1].split("\n")[0]
                    break
                if i > max_log_line_index:
                    break

        if guide_filepath is not None:
            # The log may have been written on another machine: rewrite that
            # machine's models directory into the local one.
            for h in other_hosts:
                guide_filepath = guide_filepath.replace(directories[h]["models_home"], dirs.MODELS_HOME)

        folder_results = do_test(
            net_filepath, ema_path=EMA_filepath, guide_path=guide_filepath, acid=acid, 
            classes='A', P_mean=-2.3, P_std=1.5, sigma_max=5, depth_sep=5,
            n_samples=test_batch_size, batch_size=test_batch_size, test_outer=True,
            guidance_weight=3,
            seed=test_seed, generator=None,
            device=torch.device('cuda'))
        
        test_results[s][folder] = folder_results

2025-05-14 14:16:11 | [32mINFO    [0m | Seed = 7[0m
2025-05-14 14:16:11 | [32mINFO    [0m | Number of test epochs = 1[0m
2025-05-14 14:16:11 | [32mINFO    [0m | Test batch size = 16384[0m
2025-05-14 14:16:11 | [32mINFO    [0m | Number of test samples = 16384[0m
2025-05-14 14:16:11 | [32mINFO    [0m | Guidance weight = 3[0m
100%|██████████| 1/1 [00:07<00:00,  7.88s/it]
2025-05-14 14:16:19 | [32mINFO    [0m | Average Test Learner Loss = 0.011694928631186485[0m
2025-05-14 14:16:19 | [32mINFO    [0m | Average Test EMA Loss = 0.010441781021654606[0m
2025-05-14 14:16:19 | [32mINFO    [0m | Average Test Guide Loss = 0.050932493060827255[0m
2025-05-14 14:16:19 | [32mINFO    [0m | Average Test Ref Loss = 0.050932493060827255[0m
2025-05-14 14:16:19 | [32mINFO    [0m | Average Outer Test Learner Loss = 0.22515948116779327[0m
2025-05-14 14:16:19 | [32mINFO    [0m | Average Outer Test EMA Loss = 0.22126927971839905[0m
2025-05-14 14:16:19 | [32mINFO    [0m | Averag

In [20]:
# Union of the metric names reported for any folder of any series. The keys are read
# from test_results itself, not from the `folders` list left over by the test loop,
# which only describes the last series that was processed.
test_keys = set()
for series_results in test_results.values():
    for folder_results in series_results.values():
        test_keys.update(folder_results.keys())

# For every series, print each metric followed by one "<folder> <value>" line per run.
# Metrics are sorted so the output order is the same on every run.
for s, series_results in test_results.items():
    if len(test_results) > 1:
        # Label the blocks only when there is more than one series to tell apart
        print(s)
    for test_key in sorted(test_keys):
        print(test_key)
        for folder_key, vals in series_results.items():
            print(folder_key, "\t", vals[test_key])
        print("")

learner_loss
ACID 	 0.011694928631186485
ACIDNonInverted 	 0.012531672604382038
LateACID 	 0.011993270367383957
LateACIDNonInverted 	 0.013276759535074234

ema_guided_out_L2_metric
ACID 	 0.4991648197174072
ACIDNonInverted 	 0.5023188591003418
LateACID 	 0.5064122080802917
LateACIDNonInverted 	 0.5105265378952026

ema_L2_metric
ACID 	 0.01700439676642418
ACIDNonInverted 	 0.014938662759959698
LateACID 	 0.01851225271821022
LateACIDNonInverted 	 0.016909707337617874

learner_out_L2_metric
ACID 	 0.47990307211875916
ACIDNonInverted 	 0.4901620149612427
LateACID 	 0.4804249107837677
LateACIDNonInverted 	 0.47617873549461365

learner_out_loss
ACID 	 0.22515948116779327
ACIDNonInverted 	 0.22926221787929535
LateACID 	 0.22476956248283386
LateACIDNonInverted 	 0.22652536630630493

ema_loss
ACID 	 0.010441781021654606
ACIDNonInverted 	 0.011781025677919388
LateACID 	 0.010537711903452873
LateACIDNonInverted 	 0.011399950832128525

learner_guided_L2_metric
ACID 	 0.09331624209880829
ACIDNonInv

In [32]:
# Union of the metric names reported for any folder of any series. The folders are
# taken from test_results[s] itself, so this does not depend on a `folders` variable
# left over from another cell (which only matches the last series processed).
test_keys = set()
for s in test_results:
    for folder_results in test_results[s].values():
        test_keys.update(folder_results.keys())

# Re-index the results metric-first:
# sorted_results[metric][series_name][folder_name] = value
sorted_results = {
    test_key: {
        series_name: {
            folder_name: folder_values[test_key]
            for folder_name, folder_values in series_results.items()
        }
        for series_name, series_results in test_results.items()
    }
    for test_key in test_keys
}

In [34]:
# Print one tab-separated table per series: a header row with the folder names, then
# one row per metric. The series is iterated explicitly instead of relying on a loop
# variable `s` left over from an earlier cell, and metrics are sorted so the row
# order is the same on every run.
for s in test_results:
    if len(test_results) > 1:
        # Label the tables only when there is more than one series to tell apart
        print(s)
    folders = list(test_results[s].keys())
    print("\t", "\t", "\t".join(folders))
    for test_key in sorted(test_keys):
        print(test_key, "\t", "\t".join([str(sorted_results[test_key][s][folder]) for folder in folders]))

	 	 ACID	ACIDNonInverted	LateACID	LateACIDNonInverted
learner_loss 	 0.011694928631186485	0.012531672604382038	0.011993270367383957	0.013276759535074234
ema_guided_out_L2_metric 	 0.4991648197174072	0.5023188591003418	0.5064122080802917	0.5105265378952026
ema_L2_metric 	 0.01700439676642418	0.014938662759959698	0.01851225271821022	0.016909707337617874
learner_out_L2_metric 	 0.47990307211875916	0.4901620149612427	0.4804249107837677	0.47617873549461365
learner_out_loss 	 0.22515948116779327	0.22926221787929535	0.22476956248283386	0.22652536630630493
ema_loss 	 0.010441781021654606	0.011781025677919388	0.010537711903452873	0.011399950832128525
learner_guided_L2_metric 	 0.09331624209880829	0.09843237698078156	0.10282757878303528	0.11313240975141525
ema_out_loss 	 0.22126927971839905	0.22625857591629028	0.22273573279380798	0.2249699831008911
ema_out_L2_metric 	 0.4804104268550873	0.4805939793586731	0.481892853975296	0.4827585220336914
guide_out_loss 	 0.27971529960632324	0.279715299606323