In [2]:
import matplotlib.pyplot as plt
import re
import os
import numpy as np
from plots import plot_train_compare2, plot_train_compare

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def get_hp_args_from_txt(path):
    """Pull four hyper-parameter values out of a text file.

    The file is read as plain text and searched for entries of the form
    ``'<key>': <number>`` for the keys ``poison_lr``, ``iterations``,
    ``epsilon`` and ``poison_start_epoch``. Only the first occurrence of
    each key is used.

    Args:
        path: Path of the text file to read.

    Returns:
        A tuple ``(poison_lr, iterations, epsilon, poison_start_epoch)``.
        Every element is the matched text (``str``); no numeric conversion
        is performed.

    Raises:
        AttributeError: If any of the four keys is not found (``re.search``
            returns ``None`` in that case).
    """
    with open(path) as f:
        data_string = f.read()

    # Unsigned decimal with an optional exponent. The exponent part matters:
    # without it a value written in scientific notation such as ``1e-05``
    # would be cut at the "e" and silently come back as "1".
    float_group = r"([0-9.]+(?:[eE][-+]?[0-9]+)?)"
    int_group = r"(\d+)"

    poison_lr_pattern = r"'poison_lr'\s*:\s*" + float_group
    iterations_pattern = r"'iterations'\s*:\s*" + int_group
    epsilon_pattern = r"'epsilon'\s*:\s*" + float_group
    poison_start_epoch_pattern = r"'poison_start_epoch'\s*:\s*" + int_group

    poison_lr = re.search(poison_lr_pattern, data_string).group(1)
    iterations = re.search(iterations_pattern, data_string).group(1)
    epsilon = re.search(epsilon_pattern, data_string).group(1)
    poison_start_epoch = re.search(poison_start_epoch_pattern, data_string).group(1)

    return (poison_lr, iterations, epsilon, poison_start_epoch)

def get_tests_vals_from_txt(path, split_epoch=100, max_epoch=200):
    """Parse per-line validation/test accuracies and split them into two phases.

    Every non-blank line of the file must contain ``Val_accuracy:<float>``
    and ``Test_accuracy:<float>``. Lines are counted starting from 1:

    * lines ``1 .. split_epoch - 1`` form the first phase,
    * lines ``split_epoch .. max_epoch - 1`` form the second phase (the
      ``*_p`` results),
    * line ``max_epoch`` and everything after it is ignored and not parsed.

    For each phase the function tracks the highest validation accuracy and
    the test accuracy recorded on that same line (the first line wins on
    ties).

    NOTE(review): with the defaults the first phase receives 99 lines and the
    second 100, and line 200 is dropped. If each phase is meant to span 100
    lines this is an off-by-one; the boundaries are kept as they were so the
    existing results do not change. Confirm the intended split.

    Args:
        path: Path of the metrics text file.
        split_epoch: 1-based line number at which the second phase starts.
        max_epoch: 1-based line number at which reading stops (exclusive).

    Returns:
        ``(vals, vals_p, highest_val, highest_val_p,
        tests, tests_p, highest_test, highest_test_p)`` where the lists hold
        floats in file order and the ``highest_*`` entries are floats, or
        ``-1`` when the corresponding phase has no lines.

    Raises:
        IndexError: If a non-blank line inside the parsed range lacks one of
            the two accuracy fields.
    """
    vals, tests = [], []
    vals_p, tests_p = [], []
    highest_val = highest_test = -1
    highest_val_p = highest_test_p = -1

    epoch = 0
    with open(path) as f:
        # Iterate lazily instead of loading every line with readlines().
        for line in f:
            # Blank lines carry no measurement; skip them without counting.
            if not line.strip():
                continue
            epoch += 1
            # Stop before parsing so the cutoff line is never inspected.
            if epoch >= max_epoch:
                break

            val_acc = float(re.findall(r"Val_accuracy:(\d+\.\d+)", line)[0])
            test_acc = float(re.findall(r"Test_accuracy:(\d+\.\d+)", line)[0])

            if epoch >= split_epoch:
                if val_acc > highest_val_p:
                    highest_val_p = val_acc
                    highest_test_p = test_acc
                vals_p.append(val_acc)
                tests_p.append(test_acc)
            else:
                if val_acc > highest_val:
                    highest_val = val_acc
                    highest_test = test_acc
                vals.append(val_acc)
                tests.append(test_acc)

    return (vals, vals_p, highest_val, highest_val_p, tests, tests_p, highest_test, highest_test_p)
    
def construct_run_dicts(base_path):
    """Build a summary dict for every run directory directly under ``base_path``.

    Each immediate sub-directory name is treated as a run id, and
    ``<base_path>/<run id>/metrics.txt`` is parsed with
    ``get_tests_vals_from_txt``.

    Args:
        base_path: Directory whose immediate sub-directories are the runs.

    Returns:
        A dict mapping run id to a dict with the keys ``vals``, ``vals_p``,
        ``highest_val``, ``highest_val_p``, ``tests``, ``tests_p``,
        ``highest_test``, ``highest_test_p`` and ``ratio``
        (``highest_val_p / highest_val``).
    """
    # First item yielded by os.walk describes base_path itself; only its
    # list of sub-directory names is needed.
    _, run_ids, _ = next(os.walk(base_path))

    runs = {}
    for run_id in run_ids:
        metrics_file = f"{base_path}/{run_id}/metrics.txt"
        (
            val_series,
            val_series_p,
            best_val,
            best_val_p,
            test_series,
            test_series_p,
            best_test,
            best_test_p,
        ) = get_tests_vals_from_txt(metrics_file)

        summary = {
            "vals": val_series,
            "vals_p": val_series_p,
            "highest_val": best_val,
            "highest_val_p": best_val_p,
            "tests": test_series,
            "tests_p": test_series_p,
            "highest_test": best_test,
            "highest_test_p": best_test_p,
        }
        # Second-phase best validation accuracy relative to the first phase.
        summary["ratio"] = summary["highest_val_p"] / summary["highest_val"]

        runs[run_id] = summary
    return runs

# Load the per-run summaries for the two experiment folders; both paths are
# relative to the notebook's working directory.
runs_poison = construct_run_dicts(
    'experiment_results_from_eddie/fmnist_res_net_18/poison_final'
)
runs_baseline = construct_run_dicts(
    'experiment_results_from_eddie/fmnist_res_net_18/baseline_final'
)

In [17]:
# Per-run test accuracy taken at the line with the best validation accuracy:
# first phase and second ("_p") phase of the poison runs, plus the first
# phase of the baseline runs.
highest_tests_sam_list = [run["highest_test"] for run in runs_poison.values()]
highest_tests_p_list = [run["highest_test_p"] for run in runs_poison.values()]
highest_tests_baseline_list = [run["highest_test"] for run in runs_baseline.values()]

# Report mean and standard deviation across runs (np.std default, ddof=0).
for label, accuracies in (
    ("SAM mean test:", highest_tests_sam_list),
    ("Poison mean test:", highest_tests_p_list),
    ("Baseline mean test:", highest_tests_baseline_list),
):
    print(label, np.mean(accuracies), "+-", np.std(accuracies))

SAM mean test: 91.98999999999998 +- 0.16451950239004176
Poison mean test: 91.99333333333333 +- 0.15003703246569375
Baseline mean test: 91.64333333333335 +- 0.44629337635436456
