# Test models

## Initialization

In [29]:
import os
import pyvdirs.dirs as dirs
import sys
sys.path.insert(0, dirs.SYSTEM_HOME)
sys.path.insert(0, os.path.join(dirs.SYSTEM_HOME, "ToyExample"))
from socket import gethostname

import torch
import json

from ToyExample.toy_example import do_test
import pyvtools.text as vtext

### Parameters

In [32]:
test_batch_size = 2**14
test_seed = 7

series = ["18_Statistics", "19_ACIDParams", "21_Repetitions", "23_NormalizedLogits"]
# series = ["23_NormalizedLogits"]

results_filename = "TestResults.json"

### Auxiliary definitions

In [33]:
get_path = lambda series : os.path.join(dirs.MODELS_HOME, "ToyExample", series)

host_id = gethostname()
other_hosts = vtext.filter_by_string_must(list(dirs.check_directories_file().keys()), [host_id,"else"], must=False)

results_filepath = os.path.join(dirs.RESULTS_HOME, results_filename)

## Run test and collect results

### If needed, run

In [None]:
test_results = {}
series_folders = {}

for s in series:

    series_path = get_path(s)
    contents = os.listdir(series_path)
    folders = [c for c in contents if os.path.isdir(os.path.join(series_path, c))]
    folders = vtext.filter_by_string_must(folders, ["Failed", "Old"], must=False)
    series_folders[s] = folders

    log_files = ["log_"+f+".txt" for f in folders]
    assert all([os.path.isfile(os.path.join(series_path, f)) for f in log_files]), "Some logs have not been found"

    test_results[s] = {}
    for folder, log_file in zip(folders, log_files):

        log_filepath = os.path.join(series_path, log_file)

        files = os.listdir(os.path.join(series_path, folder))
        net_file = vtext.filter_by_string_must(files, "learner")[0]
        EMA_file = "".join(net_file.split("learner"))

        net_filepath = os.path.join(series_path, folder, net_file)
        EMA_filepath = os.path.join(series_path, folder, EMA_file)

        with open(log_filepath, "r") as f:
            acid = False
            for i, line in enumerate(f):
                if "ACID = True" in line:
                    acid = True
                if "Guide model loaded from" in line or i>70: 
                    break
        if "Guide model loaded from" in line:
            guide_line = line
            guide_filepath = guide_line.split("Guide model loaded from ")[-1].split("\n")[0]
            for h in other_hosts:
                guide_filepath = guide_filepath.replace(dirs.check_directories_file()[h]["models_home"], dirs.MODELS_HOME)
        else:
            guide_filepath = None

        folder_results = do_test(
            net_filepath, ema_path=EMA_filepath, guide_path=guide_filepath, acid=acid, 
            classes='A', P_mean=-2.3, P_std=1.5, sigma_max=5, depth_sep=5,
            n_samples=test_batch_size, batch_size=test_batch_size, 
            test_outer=True, test_mandala=True,
            guidance_weight=3,
            seed=test_seed, generator=None,
            log_filename=log_filepath,
            device=torch.device('cuda'))
        
        test_results[s][folder] = folder_results

2025-05-14 19:10 | INFO     | Seed = 7
2025-05-14 19:10 | INFO     | Number of test epochs = 1
2025-05-14 19:10 | INFO     | Test batch size = 16384
2025-05-14 19:10 | INFO     | Number of test samples = 16384
100%|██████████| 1/1 [00:07<00:00,  7.99s/it]
2025-05-14 19:10 | INFO     | Seed = 7
2025-05-14 19:10 | INFO     | Number of test epochs = 1
2025-05-14 19:10 | INFO     | Test batch size = 16384
2025-05-14 19:10 | INFO     | Number of test samples = 16384
100%|██████████| 1/1 [00:07<00:00,  7.94s/it]
2025-05-14 19:10 | INFO     | Seed = 7
2025-05-14 19:10 | INFO     | Number of test epochs = 1
2025-05-14 19:10 | INFO     | Test batch size = 16384
2025-05-14 19:10 | INFO     | Number of test samples = 16384
100%|██████████| 1/1 [00:08<00:00,  8.02s/it]
2025-05-14 19:10 | INFO     | Seed = 7
2025-05-14 19:10 | INFO     | Number of test epochs = 1
2025-05-14 19:10 | INFO     | Test batch size = 16384
2025-05-14 19:10 | INFO     | Number of test samples = 16384
100%|██████████| 1/1 [

In [37]:
with open(results_filepath, "w") as file:
    json.dump({"test_batch_size":test_batch_size,
               "test_seed":test_seed,
               **test_results}, 
               file)

### If not needed, then load

In [None]:
with open(results_filepath, "r") as file:
    test_results = json.load(file)
    test_batch_size = test_results.pop("test_batch_size")
    test_seed = test_results.pop("test_seed")

## Process data

### Get keys and fill gaps

In [24]:
test_keys = set()
for s in series:
    for folder in test_results[s].keys():
        test_keys.update( list(test_results[s][folder].keys()) )

In [25]:
for test_key in test_keys:
    for s in series:
        for folder_key, vals in test_results[s].items():
            if test_key not in vals.keys():
                test_results[s][folder_key][test_key] = None

In [26]:
sorted_results = {test_key: {s: {f: test_results[s][f][test_key] for f in test_results[s].keys()} for s in test_results.keys()} for test_key in test_keys}

In [45]:
test_names = {
    "ema_loss":"EMA's Average Loss",
    "ema_out_loss":"EMA's Outer Average Loss",
    "learner_loss":"Learner's Average Loss",
    "learner_out_loss":"Learner's Outer Average Loss",
    "ref_loss":"Reference's Average Loss",
    "ref_out_loss":"Reference's Outer Average Loss",
    "guide_loss":"Guide's Average Loss",
    "guide_out_loss":"Guide's Outer Average Loss",
    "ema_L2_metric":"EMA's Average L2 Distance",
    "ema_out_L2_metric":"EMA's Outer Average L2 Distance",
    "learner_L2_metric":"Learner's Average L2 Distance",
    "learner_out_L2_metric":"Learner's Outer Average L2 Distance",
    "ema_guided_L2_metric":"Guided EMA's Average L2 Distance",
    "ema_guided_out_L2_metric":"Guided EMA's Outer Average L2 Distance",
    "learner_guided_L2_metric":"Guided Learner's Average L2 Distance",
    "learner_guided_out_L2_metric":"Guided Learner's Outer Average L2 Distance",
    "ema_mandala_score":"EMA's Mandala Score",
    "learner_mandala_score":"Learner's Mandala Score",
    "ema_guided_mandala_score":"Guided EMA's Mandala Score",
    "learner_guided_mandala_score":"Guided Learner's Mandala Score",
    "ema_classification_score":"EMA's Classification Score",
    "learner_classification_score":"Learner's Classification Score",
    "ema_guided_classification_score":"Guided EMA's Classification Score",
    "learner_guided_classification_score":"Guided Learner's Classification Score",
}
assert all([test_key in test_names.keys() for test_key in test_keys]), "Missing key"

In [47]:
test_keys = list(test_names.keys())
sorted_results = {test_key:sorted_results[test_key] for test_key in test_keys}

In [68]:
modified_results = dict(test_results)
for s in series:
    modified_results[s] = {}
    for f in series_folders[s]:
        guided_results = {}
        guided_results["ema_L2_metric"] = test_results[s][f]["ema_guided_L2_metric"]
        guided_results["ema_out_L2_metric"] = test_results[s][f]["ema_guided_out_L2_metric"]
        guided_results["learner_L2_metric"] = test_results[s][f]["learner_guided_L2_metric"]
        guided_results["learner_out_L2_metric"] = test_results[s][f]["learner_guided_out_L2_metric"]
        guided_results["ema_mandala_score"] = test_results[s][f]["ema_guided_mandala_score"]
        guided_results["ema_classification_score"] = test_results[s][f]["ema_guided_classification_score"]
        guided_results["learner_mandala_score"] = test_results[s][f]["learner_guided_mandala_score"]
        guided_results["learner_classification_score"] = test_results[s][f]["learner_guided_classification_score"]
        modified_results[s]["Guided"+f] = guided_results

for test_key in test_keys:
    for s in series:
        for folder_key, vals in modified_results[s].items():
            if test_key not in vals.keys():
                modified_results[s][folder_key][test_key] = None

sorted_modified_results = {test_key: {s: {f: modified_results[s][f][test_key] for f in modified_results[s].keys()} for s in modified_results.keys()} for test_key in test_keys}

## Visualize data

In [63]:
# As a list
for s in series:
    print(">>>>", s, "<<<<<<<<<<<<<\n")
    for test_key in test_keys:
        print("######", test_names[test_key])
        for folder_key, vals in test_results[s].items():
            print(folder_key, "\t", vals[test_key])
        print("")

>>>> 18_Statistics <<<<<<<<<<<<<

###### EMA's Average Loss
NoACID_seed_073 	 0.008855066262185574
NoACID_seed_000 	 0.008971166796982288
NoACID_seed_172 	 0.008900010958313942
NoACID_seed_231 	 0.013386810198426247
NoACID_seed_357 	 0.015524894930422306
ACIDInterpol_seed_000 	 0.008971166796982288
ACIDInterpol_seed_073 	 0.008855066262185574
ACIDInterpol_seed_172 	 0.008900010958313942
ACIDInterpol_seed_231 	 0.013386810198426247
ACIDInterpol_seed_357 	 0.015524894930422306
ACIDTrickInterpol_seed_000 	 0.009761476889252663
ACIDTrick_seed_000 	 0.008490141481161118
ACIDTrick_seed_073 	 0.009044168516993523
ACIDTrick_seed_172 	 0.009004823863506317
ACIDTrick_seed_231 	 0.009113093838095665
ACIDTrick_seed_357 	 0.00889365840703249
ACID_seed_000 	 0.008678608573973179
ACID_seed_073 	 0.008925304748117924
ACID_seed_172 	 0.008490169420838356
ACID_seed_231 	 0.010667999275028706
ACID_seed_357 	 0.009201300330460072

###### EMA's Outer Average Loss
NoACID_seed_073 	 0.21643033623695374
NoACI

In [50]:
for s in series:
    print(">>>>", s, "<<<<<<<<<<<<<\n")
    folders = list(test_results[s].keys())
    print("\t", "\t".join(folders))
    for test_key in test_keys:
        print(test_names[test_key]+"\t", 
              "\t".join([str(sorted_results[test_key][s][folder]) for folder in folders]))
    print("")

>>>> 18_Statistics <<<<<<<<<<<<<

	 NoACID_seed_073	NoACID_seed_000	NoACID_seed_172	NoACID_seed_231	NoACID_seed_357	ACIDInterpol_seed_000	ACIDInterpol_seed_073	ACIDInterpol_seed_172	ACIDInterpol_seed_231	ACIDInterpol_seed_357	ACIDTrickInterpol_seed_000	ACIDTrick_seed_000	ACIDTrick_seed_073	ACIDTrick_seed_172	ACIDTrick_seed_231	ACIDTrick_seed_357	ACID_seed_000	ACID_seed_073	ACID_seed_172	ACID_seed_231	ACID_seed_357
EMA's Average Loss	 0.008855066262185574	0.008971166796982288	0.008900010958313942	0.013386810198426247	0.015524894930422306	0.008971166796982288	0.008855066262185574	0.008900010958313942	0.013386810198426247	0.015524894930422306	0.009761476889252663	0.008490141481161118	0.009044168516993523	0.009004823863506317	0.009113093838095665	0.00889365840703249	0.008678608573973179	0.008925304748117924	0.008490169420838356	0.010667999275028706	0.009201300330460072
EMA's Outer Average Loss	 0.21643033623695374	0.21810638904571533	0.21674653887748718	0.22926414012908936	0.23327693343162

In [None]:
# As a list
for s in series:
    print(">>>>", s, "<<<<<<<<<<<<<\n")
    for test_key in test_keys:
        print("######", test_names[test_key])
        for folder_key, vals in modified_results[s].items():
            print(folder_key, "\t", vals[test_key])
        print("")

>>>> 18_Statistics <<<<<<<<<<<<<

###### EMA's Average Loss
GuidedNoACID_seed_073 	 None
GuidedNoACID_seed_000 	 None
GuidedNoACID_seed_172 	 None
GuidedNoACID_seed_231 	 None
GuidedNoACID_seed_357 	 None
GuidedACIDInterpol_seed_000 	 None
GuidedACIDInterpol_seed_073 	 None
GuidedACIDInterpol_seed_172 	 None
GuidedACIDInterpol_seed_231 	 None
GuidedACIDInterpol_seed_357 	 None
GuidedACIDTrickInterpol_seed_000 	 None
GuidedACIDTrick_seed_000 	 None
GuidedACIDTrick_seed_073 	 None
GuidedACIDTrick_seed_172 	 None
GuidedACIDTrick_seed_231 	 None
GuidedACIDTrick_seed_357 	 None
GuidedACID_seed_000 	 None
GuidedACID_seed_073 	 None
GuidedACID_seed_172 	 None
GuidedACID_seed_231 	 None
GuidedACID_seed_357 	 None

###### EMA's Outer Average Loss
GuidedNoACID_seed_073 	 None
GuidedNoACID_seed_000 	 None
GuidedNoACID_seed_172 	 None
GuidedNoACID_seed_231 	 None
GuidedNoACID_seed_357 	 None
GuidedACIDInterpol_seed_000 	 None
GuidedACIDInterpol_seed_073 	 None
GuidedACIDInterpol_seed_172 	 None
Gu

In [73]:
for s in series:
    print(">>>>", s, "<<<<<<<<<<<<<\n")
    print("\t", "\t".join(series_folders[s]))
    for test_key in test_keys:
        print(test_names[test_key]+"\t", 
              "\t".join([str(sorted_modified_results[test_key][s][folder]) for folder in series_folders[s]]))
    print("")

>>>> 18_Statistics <<<<<<<<<<<<<

	 NoACID_seed_073	NoACID_seed_000	NoACID_seed_172	NoACID_seed_231	NoACID_seed_357	ACIDInterpol_seed_000	ACIDInterpol_seed_073	ACIDInterpol_seed_172	ACIDInterpol_seed_231	ACIDInterpol_seed_357	ACIDTrickInterpol_seed_000	ACIDTrick_seed_000	ACIDTrick_seed_073	ACIDTrick_seed_172	ACIDTrick_seed_231	ACIDTrick_seed_357	ACID_seed_000	ACID_seed_073	ACID_seed_172	ACID_seed_231	ACID_seed_357


KeyError: 'NoACID_seed_073'