In [1]:
import matplotlib.pyplot as plt
from deap import gp
from deap.tools import Logbook
from typing import Callable, TypeVar, Iterable
import pickle, os, numpy as np
from run_gp import RunInfo
from shared_tools.make_datasets import x_train, y_train, x_validation, y_validation, x_test, y_test
from simple_pred.function_set import create_pset
from shared_tools.toolbox import create_toolbox
from functools import partial
from dataclasses import dataclass
from main import parser
from tree import Tree
from IPython.display import Image
import matplotlib

In [2]:
# Parse an explicit empty argument list rather than whatever the kernel was
# launched with; presumably this yields the parser's default values.
parameters = parser.parse_args([])

# The primitive set is built from the shape of the first training sample.
pset = create_pset(*x_train[0].shape)

# The three data splits, keyed by split name, as handed to create_toolbox.
datasets = dict(
    train=(x_train, y_train),
    validation=(x_validation, y_validation),
    test=(x_test, y_test),
)

toolbox = create_toolbox(datasets, pset, parameters)

In [3]:
def plot(logbook: Logbook) -> None:
    """Show a 2x2 summary figure for one logbook.

    The first three panels plot the "min", "max" and "avg" columns of the
    logbook's "fitness" chapter against the "gen" column; the fourth panel
    plots the "avg" column of the "size" chapter.

    Args:
        logbook: Logbook providing a "gen" column and "fitness" / "size"
            chapters with the columns named above.
    """
    gen = logbook.select("gen")
    fitness = logbook.chapters["fitness"]
    size_avgs = logbook.chapters["size"].select("avg")

    fig, axs = plt.subplots(2, 2)

    # zip stops after the three fitness settings, leaving the last axes free
    # for the size panel.
    panels = zip(axs.flat, ("min", "max", "avg"), ("red", "green", "blue"))
    for ax, setting, color in panels:
        ax.plot(gen, fitness.select(setting),
                color=color, label=f"fitness_{setting}")
        ax.set(title=f"Fitness {setting}", xlabel="generations", ylabel="fitness")

    size_ax = axs[1, 1]
    size_ax.plot(gen, size_avgs, color="orange")
    size_ax.set(title="Average Size", xlabel="generations", ylabel="size")

    # Every panel carries its own title and axis labels, so let matplotlib
    # add padding to keep neighbouring panels from overlapping.
    fig.tight_layout()
    plt.show()


In [4]:
T = TypeVar('T')


def retrieve_from_files(extractor: Callable[[RunInfo], T], files: Iterable[str],  exclude_zero: bool=True) -> list[T]:
    """Unpickle each file in ``files`` and collect ``extractor(run_info)`` for it.

    Args:
        extractor: Called with every unpickled object that is not skipped; its
            return value is appended to the result.
        files: Paths of the pickle files to read, processed in iteration order.
        exclude_zero: When true, objects whose ``parameters.seed`` equals 0 are
            skipped.

    Returns:
        The extractor results, in the order the files were read.

    Security:
        ``pickle.load`` can execute arbitrary code embedded in the file; only
        pass paths to files that come from a trusted source.
    """
    # NOTE(review): none of the values built here is read below. Presumably
    # they are rebuilt for side effects that unpickling depends on (e.g.
    # registering classes/primitives) -- confirm before removing.
    default_args = parser.parse_args([])
    splits = {
        "train": (x_train, y_train),
        "validation": (x_validation, y_validation),
        "test": (x_test, y_test)
    }
    primitive_set = create_pset(*x_train[0].shape)
    create_toolbox(splits, primitive_set, default_args)

    collected = []
    for path in files:
        with open(path, 'rb') as handle:
            run_info = pickle.load(handle)
            # The seed is only inspected when exclusion is requested.
            skip = exclude_zero and run_info.parameters.seed == 0
            if not skip:
                collected.append(extractor(run_info))
    return collected

def get_files(model: str) -> list[str]:
    """Return the paths of the regular files in ``<model>/data``, sorted by name.

    ``os.listdir`` returns entries in arbitrary order, so the names are sorted
    to make the run order reproducible between executions. Sub-directories
    (for example a ``.ipynb_checkpoints`` folder) are left out because callers
    open every returned path as a file.

    Args:
        model: Directory of the model; its ``data`` sub-directory is listed.

    Returns:
        Paths of the form ``"<model>/data/<name>"``.

    Raises:
        OSError: If ``<model>/data`` cannot be listed (e.g. it does not exist).
    """
    data_dir = f"{model}/data"
    paths = (f"{data_dir}/{name}" for name in sorted(os.listdir(data_dir)))
    return [path for path in paths if os.path.isfile(path)]
    
def plot_across_all_models(model: str, *extractors: tuple[Callable[[RunInfo], T], str], title: str, exclude_zero: bool=True) -> None:
    """Plot one curve per stored run plus the across-run average for each series.

    For every ``(extractor, name)`` pair the extractor is applied to each run
    file of ``model``; each result is drawn as a faint line and the
    element-wise mean over all runs as a thick line labelled
    ``"<name> average"``.

    Args:
        model: Model directory whose ``data`` sub-directory holds the run files.
        *extractors: ``(extractor, name)`` pairs; ``extractor`` maps an
            unpickled run to the sequence to plot, ``name`` is the legend
            label prefix. Results are stacked with ``np.array``, so the
            sequences are presumably expected to share one length.
        title: Title of the figure.
        exclude_zero: Forwarded to ``retrieve_from_files``.
    """
    # List the directory once and reuse the same files for every series.
    files = get_files(model)
    for index, (extractor, name) in enumerate(extractors):
        results = np.array(retrieve_from_files(extractor, files, exclude_zero=exclude_zero))
        # The individual runs are hidden from the legend, so colour is the only
        # thing tying them to their average: use one cycle colour per series.
        color = f"C{index}"
        for result in results:
            plt.plot(result, color=color, alpha=0.3, label='_nolegend_')
        plt.plot(results.mean(axis=0), color=color, linewidth=3, label=f"{name} average")
    plt.title(title)
    plt.legend()
    plt.show()



In [5]:

# Overlay the per-run "fit_min" and "val_min" log columns of every stored
# simple_pred run, together with their across-run averages. Each series is
# passed as an (extractor, legend label) pair.
# NOTE(review): the title says "Test" while the first series is "fit_min",
# labelled "fitness" -- confirm the intended wording.
plot_across_all_models(
    'simple_pred',
    (lambda i: i.log.select("fit_min"), "fitness"),
    (lambda i: i.log.select("val_min"), "validation"),
    title="Minimum Test and Validation error over every run",
)


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


AttributeError: Can't get attribute 'X' on <module 'deap.gp' from '/home/pork/.venv/lib/python3.13/site-packages/deap/gp.py'>

In [None]:
# Load every stored simple_pred run (identity extractor) and keep the
# individual of the run with the lowest test_error.
simple_pred_runs = retrieve_from_files(lambda r: r, get_files('simple_pred'))
best_run = min(simple_pred_runs, key=lambda r: r.test_error)
best_individual = best_run.best_individual
print(best_individual)

In [None]:
def show(a_v_pairs: list[tuple[float, float]]) -> None:
    """Scatter-plot two-element points on a fixed [-1, 1] x [-1, 1] canvas.

    All points are drawn as small blue markers and their arithmetic mean as a
    single larger red marker.

    Args:
        a_v_pairs: Two-element items; the first element goes on the x axis and
            the second on the y axis (presumably arousal/valence pairs, judging
            by the names -- confirm).

    Raises:
        ZeroDivisionError: If ``a_v_pairs`` is empty.
    """
    xs = [first for first, _ in a_v_pairs]
    ys = [second for _, second in a_v_pairs]
    count = len(a_v_pairs)
    mean_x = sum(xs) / count
    mean_y = sum(ys) / count

    plt.scatter(xs, ys, s=20, color=(0.1, 0.1, 1, 0.7))
    plt.scatter([mean_x], [mean_y], s=100, color=(1, 0, 0, 0.7))
    plt.xlim((-1, 1))
    plt.ylim((-1, 1))
    plt.show()


In [None]:
# Scatter the validation targets themselves and mark their mean point.
show(y_validation)

In [None]:
# Compile the best individual into a callable, apply it to every validation
# input and scatter the resulting predictions.
predictor = toolbox.compile(best_individual)
predictions = list(map(predictor, x_validation))
show(predictions)

In [None]:
# Save a graph of the tree built from the best individual for each of the
# first five training samples, then display the saved PNG.
f = Tree.of(best_individual, pset)
# Make sure the output directory exists so the cell also runs on a fresh checkout.
os.makedirs("models", exist_ok=True)
for i, img in enumerate(x_train[:5]):
    # Build the path once so the file that is saved is the file that is shown.
    graph_path = f"models/best_model{i}.png"
    f.save_graph(graph_path, img)
    display(Image(graph_path))

In [None]:
# Mean and maximum of the per-image standard deviations over the training set.
# The array is built once instead of once per statistic.
img_stds = np.array([img.std() for img in x_train])
img_stds.mean(), img_stds.max()

In [None]:
# Histogram of the per-image standard deviations after subtracting their mean
# and scaling by 2 / max.
img_stds = np.array([img.std() for img in x_train])
shifted_scaled_stds = (img_stds - img_stds.mean()) * 2 / img_stds.max()
plt.hist(shifted_scaled_stds)

In [None]:
# Print the shift (mean) and the scale factor (2 / max) that the histogram cell
# above applies to the per-image standard deviations.
print(img_stds.mean(), 2 / img_stds.max())