In [1]:
from functools import partial
from typing import Callable, Any, Final

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

In [3]:
from dataset_subset import sample_dataset_random, sample_dataset_jls_kmeans
from train_dataset import train_cfar, train_svhn, train_mnist

In [None]:
# Seed of the first run of every experiment configuration; run i uses BASE_SEED + i.
BASE_SEED: Final[int] = 206_783_441

In [4]:
# Dimension d_ID of each dataset, keyed by dataset name. These keys become the row index of
# dims_df and are the names passed to train_model.
# NOTE(review): the source of these estimates is not recorded in the notebook — TODO cite it.
DATASET_IMPLICIT_DIMS: dict[str, int] = dict(
    [
        ("MNIST", 11),
        ("SVHN", 14),
        ("CIFAR-10", 21),
    ]
)

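Compute several candidate target dimensions as functions of each dataset's dimension $d_{ID}$ (from `DATASET_IMPLICIT_DIMS`): $d$, $2d$, $5d$, $d\ln d$, $d\log_2 d$, $d^{1.5}\ln d$ and $d^2$. Each candidate is later used as the `jls_dim` of the JL-based sampler.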

In [5]:
dims_df = pd.DataFrame.from_dict(DATASET_IMPLICIT_DIMS, orient='index', columns=['d'])
dims_df['2d'] = 2 * dims_df['d']
dims_df['5d'] = 5 * dims_df['d']
dims_df['d*ln(d)'] = dims_df['d'] * np.log(dims_df['d'])
dims_df['d*log2(d)'] = dims_df['d'] * np.log2(dims_df['d'])
dims_df['d^1.5*ln(d)'] = dims_df['d*ln(d)'] * np.sqrt(dims_df['d'])
dims_df['d^2'] = np.power(dims_df['d'], 2)
dims_df

Unnamed: 0,d,2d,5d,d*ln(d),d*log2(d),d^1.5*ln(d),d^2
MNIST,11,22,55,26.376848,38.053748,87.482108,121
SVHN,14,28,70,36.946803,53.302969,138.242277,196
CIFAR-10,21,42,105,63.934971,92.238666,292.986845,441


Round each candidate to the nearest integer, because a target dimension must be a positive integer.

In [6]:
# Round to the nearest integer (pandas/NumPy round exact halves to even) and cast,
# since a target dimension must be an integer.
dims_df = dims_df.round().astype(int)
# Enforce the positivity requirement instead of only stating it: a zero or negative
# dimension cannot be used as `jls_dim`.
assert (dims_df > 0).all().all(), 'all candidate dimensions must be positive integers'
dims_df

Unnamed: 0,d,2d,5d,d*ln(d),d*log2(d),d^1.5*ln(d),d^2
MNIST,11,22,55,26,38,87,121
SVHN,14,28,70,37,53,138,196
CIFAR-10,21,42,105,64,92,293,441


In [7]:
def train_model(model_name: str, sample_func: Callable[[np.ndarray[float]], np.ndarray[int]], base_seed: int, num_runs: int = 10) -> list[float]:
    """Train the network for ``model_name`` ``num_runs`` times and return each run's best accuracy.

    Args:
        model_name: Dataset name, case-insensitive. Accepts ``'MNIST'``, ``'SVHN'``, or any name
            starting with ``'CIFAR'`` (e.g. ``'CIFAR-10'``).
        sample_func: Subset-selection function. It is passed to the training function with the
            keyword ``random_seed`` bound to the run's seed.
        base_seed: Seed of the first run. Run ``i`` uses ``base_seed + i``.
        num_runs: Number of independent runs. ``0`` returns an empty list.

    Returns:
        One value per run: ``max`` over element ``[1]`` of the training function's result.
        Element ``[1]`` is presumably the test accuracies recorded during training — TODO confirm
        against ``train_dataset``.

    Raises:
        ValueError: If ``model_name`` does not name a supported dataset.
    """
    normalized = model_name.upper()
    train_func: Callable[..., Any]
    if normalized == 'MNIST':
        train_func = train_mnist
    elif normalized == 'SVHN':
        train_func = train_svhn
    elif normalized.startswith('CIFAR'):
        train_func = train_cfar
    else:
        # Fail loudly instead of silently training CIFAR on a misspelled or unsupported name.
        raise ValueError(f"Unknown model name {model_name!r}; expected 'MNIST', 'SVHN' or 'CIFAR-10'.")

    best_accuracies: list[float] = []
    for run in tqdm(range(num_runs)):
        seed = base_seed + run
        # The same run-specific seed is given to both the subset sampler and the training function.
        result = train_func(partial(sample_func, random_seed=seed), seed=seed)
        best_accuracies.append(max(result[1]))
    return best_accuracies


In [8]:
# Per-run best accuracies, keyed first by dataset name, then by sampling strategy
# ('baseline' or a dims_df column label).
train_results: dict[str, dict[str, list[float]]] = {}

In [None]:
# Subset size drawn by every sampling strategy, so the baseline and JL runs stay comparable.
N_SAMPLES = 100

for model in dims_df.index:
    print('=' * 50)
    print(model.center(50, '='))
    print('=' * 50)

    # Baseline: random subset, seeded from BASE_SEED like the JL runs (previously a duplicated literal).
    print(f'{model} baseline:')
    train_results[model] = {
        'baseline': train_model(model, partial(sample_dataset_random, n_samples=N_SAMPLES), base_seed=BASE_SEED)
    }

    # One configuration per candidate target dimension (each column of dims_df).
    for dim_label in dims_df.columns:
        print(f'{model} with jls_dim={dim_label!r}:')
        # Plain int rather than np.int64, for consumers that type-check the dimension.
        jls_dim = int(dims_df.loc[model, dim_label])
        train_results[model][dim_label] = train_model(
            model,
            partial(sample_dataset_jls_kmeans, n_samples=N_SAMPLES, jls_dim=jls_dim),
            base_seed=BASE_SEED,
        )
    print()

Building MNIST data loader with 1 workers
Sequential(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (relu1): ReLU()
  (drop1): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=256, out_features=256, bias=True)
  (relu2): ReLU()
  (drop2): Dropout(p=0.2, inplace=False)
  (out): Linear(in_features=256, out_features=10, bias=True)
)
Epoch #1 Elapsed 8.10s, 8.10 s/epoch, 1.16 s/batch, ets 801.70s
	Test set: Average loss: 2.0641, Accuracy: 3345/10000 (33%)
Epoch #2 Elapsed 28.96s, 14.48 s/epoch, 2.07 s/batch, ets 1418.82s
Epoch #3 Elapsed 49.05s, 16.35 s/epoch, 2.34 s/batch, ets 1586.04s
Epoch #4 Elapsed 68.92s, 17.23 s/epoch, 2.46 s/batch, ets 1653.98s
Epoch #5 Elapsed 91.48s, 18.30 s/epoch, 2.61 s/batch, ets 1738.20s
Epoch #6 Elapsed 116.17s, 19.36 s/epoch, 2.77 s/batch, ets 1820.03s
	Test set: Average loss: 1.0321, Accuracy: 6367/10000 (64%)
Epoch #7 Elapsed 139.50s, 19.93 s/epoch, 2.85 s/batch, ets 1853.34s
Epoch #8 Elapsed 159.95s, 19.99 s/epoch, 2.86 s/batch

In [None]:
# Long-form table: one row per (dataset, sampling strategy, run); the raw values stay available here.
results_long = pd.DataFrame(
    [
        (dataset, strategy, run, accuracy)
        for dataset, strategies in train_results.items()
        for strategy, accuracies in strategies.items()
        for run, accuracy in enumerate(accuracies)
    ],
    columns=['dataset', 'sampling', 'run', 'best_accuracy'],
).astype({'best_accuracy': float})

# Mean / std / number of runs per configuration, kept in the order the experiments ran.
results_summary = (
    results_long
    .groupby(['dataset', 'sampling'], sort=False)['best_accuracy']
    .agg(['mean', 'std', 'count'])
)
results_summary