In [None]:
from functools import partial
from typing import Callable, Any, Final

In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from pprint import pprint

In [None]:
from dataset_subset import sample_dataset_random, sample_dataset_jls_kmeans
from train_dataset import train_cfar, train_svhn, train_mnist

In [None]:
# Seed of the first training run; train_model gives run i the seed BASE_SEED + i.
BASE_SEED: Final[int] = 206_783_441

In [None]:
# Per-dataset dimension d_ID used to derive candidate jls_dim values below.
# NOTE(review): presumably externally estimated intrinsic dimensions — TODO cite the source.
# Insertion order matters: it fixes the row order of dims_df (and hence the training order).
DATASET_IMPLICIT_DIMS: dict[str, int] = dict([
    ("MNIST", 11),
    ("SVHN", 14),
    ("CIFAR-10", 21),
])

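Compute candidate values of `jls_dim` as functions of each dataset's dimension $d = d_{ID}$: $d$, $2d$, $5d$, $d\ln d$, $d\log_2 d$, $d^{1.5}\ln d$ and $d^2$.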

In [None]:
# One row per dataset, one column per candidate dimension formula.
# assign() evaluates its callables in order, so 'd^1.5*ln(d)' can reuse the 'd*ln(d)' column.
dims_df = (
    pd.DataFrame.from_dict(DATASET_IMPLICIT_DIMS, orient='index', columns=['d'])
    .assign(**{
        '2d': lambda frame: 2 * frame['d'],
        '5d': lambda frame: 5 * frame['d'],
        'd*ln(d)': lambda frame: frame['d'] * np.log(frame['d']),
        'd*log2(d)': lambda frame: frame['d'] * np.log2(frame['d']),
        'd^1.5*ln(d)': lambda frame: frame['d*ln(d)'] * np.sqrt(frame['d']),
        'd^2': lambda frame: frame['d'] ** 2,
    })
)
dims_df

Round each candidate dimension to the nearest integer, since `jls_dim` must be a positive integer.

In [None]:
# jls_dim must be a positive integer: round every candidate dimension to the nearest integer.
dims_df = dims_df.round().astype(int)
# Enforce the invariant instead of just stating it; a zero/negative jls_dim would reach the sampler.
assert (dims_df > 0).to_numpy().all(), 'every candidate dimension must be a positive integer'
dims_df

In [None]:
def train_model(model_name: str, sample_func: Callable[[np.ndarray[float]], list[int]], base_seed: int, num_runs: int = 10) -> list[float]:
    """Train the model for ``model_name`` ``num_runs`` times and return the best score of each run.

    Run ``i`` uses the seed ``base_seed + i`` twice. It is bound into ``sample_func`` as the
    ``random_seed`` keyword, and it is passed to the training function as ``seed``. Runs are
    therefore reproducible and distinct from each other.

    Args:
        model_name: Dataset name, case-insensitive. ``'MNIST'`` and ``'SVHN'`` select their
            trainers; any name starting with ``'CIFAR'`` (e.g. ``'CIFAR-10'``) selects the
            CIFAR trainer.
        sample_func: Subset-selection function handed to the trainer. It must accept a
            ``random_seed`` keyword argument.
        base_seed: Seed of the first run.
        num_runs: Number of independent runs (0 yields an empty list).

    Returns:
        One entry per run: ``max`` of the second element of the trainer's return value.

    Raises:
        ValueError: If ``model_name`` does not name a supported dataset.
    """
    # NOTE(review): element [1] of the trainer's result is reduced with max(), so it is presumably
    # a per-epoch sequence of scores rather than a single float — TODO confirm against train_dataset.
    train_func: Callable[..., tuple[Any, Any, Any]]
    normalized_name = model_name.upper()
    if normalized_name == 'MNIST':
        train_func = train_mnist
    elif normalized_name == 'SVHN':
        train_func = train_svhn
    elif normalized_name.startswith('CIFAR'):
        train_func = train_cfar
    else:
        # Fail loudly rather than silently training the CIFAR model for an unknown name.
        raise ValueError(f"Unknown model_name {model_name!r}; expected 'MNIST', 'SVHN' or 'CIFAR-10'")

    best_scores: list[float] = []
    for run in tqdm(range(num_runs), desc=model_name):
        seed = base_seed + run
        run_result = train_func(partial(sample_func, random_seed=seed), seed=seed)
        best_scores.append(max(run_result[1]))
    return best_scores


In [None]:
# dataset name -> sampling strategy ('baseline' or a dims_df column) -> best score of each run
train_results: dict[str, dict[str, list[float]]] = {}

In [None]:
# Subset size drawn by every sampling strategy (baseline and JLS k-means alike).
N_SUBSET_SAMPLES = 100

# For each dataset (in reverse dims_df row order), train once with random subset selection as the
# baseline, then once per candidate jls_dim column of dims_df, storing every run's best score.
for model in dims_df.index[::-1]:
    print('='*50)
    print(model.center(50, '='))
    print('='*50)
    print(f'{model} baseline:')
    train_results[model] = {'baseline': train_model(model, partial(sample_dataset_random, n_samples=N_SUBSET_SAMPLES), base_seed=BASE_SEED)}
    for jls_dim in dims_df.columns:
        print(f'{model} with {jls_dim=}:')
        # .loc yields a NumPy integer; pass a plain int in case the sampler type-checks it.
        train_results[model][jls_dim] = train_model(
            model,
            partial(sample_dataset_jls_kmeans, n_samples=N_SUBSET_SAMPLES, jls_dim=int(dims_df.loc[model, jls_dim])),
            base_seed=BASE_SEED,
        )
    print(f"Results for {model=}:")
    # Print this dataset's collected scores (the original printed the pprint function itself).
    pprint(train_results[model], indent=4)
    print()

In [None]:
# Full nested results: dataset -> sampling strategy -> list of per-run best scores.
train_results