In [208]:
# Optional setup: uncomment to install the third-party dependencies into the kernel's environment.
# %pip install datasets
# %pip install gymnasium

In [209]:
# imports
from abc import (
    ABC,
    abstractmethod
)
from collections import defaultdict
from collections.abc import (
    Callable,
    Iterable
)
import itertools
import os
import torch
import tqdm
import datasets

from functions import (
    pbt_init,
    pbt_update,
    get_dataloader_random_reshuffle,
    to_ensembled,
    DictReLU,
    evaluate_model,
    normalize_features,
    Conv,
    get_accuracy,
    get_cross_entropy,
    Linear,
    Pool,
    LayerNorm,
    train_supervised,
    AdamW,
    Dropout
)


In [210]:
# Single source of truth for tensor placement and precision. Every
# hyperparameter distribution below is built from these two values, so they
# cannot drift apart from config["device"] / config["float_dtype"].
_DEVICE = "cpu"
_FLOAT_DTYPE = torch.float32


def _scalar(value):
    """Return `value` as a 0-d tensor on `_DEVICE` with dtype `_FLOAT_DTYPE`."""
    return torch.tensor(value, device=_DEVICE, dtype=_FLOAT_DTYPE)


def _clip_unit(p):
    """Clip a raw dropout probability into [0, 1]."""
    return p.clip(0, 1)


def _pow10(log10):
    """Map a raw log10 exponent to its value, 10 ** log10."""
    return 10 ** log10


def _one_minus_pow10(x):
    """Map a raw exponent x to 1 - 10 ** x, clamped into [0, 1]."""
    return (1 - 10 ** x).clamp(0, 1)


# (low, high) bounds of the Uniform distribution each *raw* hyperparameter is
# initialised from. Raw values are turned into the actual hyperparameters by
# the matching entry of "hyperparameter_transforms".
_RAW_INIT_BOUNDS = {
    "dropout_p": (0, .5),
    "epsilon": (-10, -5),
    "first_moment_decay": (-3, 0),
    "learning_rate": (-5, -1),
    "second_moment_decay": (-5, -1),
    "weight_decay": (-5, -1),
}

# Standard deviation of the zero-mean Normal distribution associated with each
# raw hyperparameter under "hyperparameter_raw_perturb".
_RAW_PERTURB_STD = {
    "dropout_p": .1,
    "epsilon": 1,
    "first_moment_decay": 1,
    "learning_rate": 1,
    "second_moment_decay": 1,
    "weight_decay": 1,
}

config = {
    "dataset_path": "uoft-cs/cifar10",
    "dataset_preprocessed_path": "data/cifar10.pt",
    "device": _DEVICE,
    "ensemble_shape": (16,),
    "float_dtype": _FLOAT_DTYPE,
    "hyperparameter_raw_init_distributions": {
        name: torch.distributions.Uniform(_scalar(low), _scalar(high))
        for name, (low, high) in _RAW_INIT_BOUNDS.items()
    },
    "hyperparameter_raw_perturb": {
        name: torch.distributions.Normal(_scalar(0), _scalar(std))
        for name, std in _RAW_PERTURB_STD.items()
    },
    "hyperparameter_transforms": {
        "dropout_p": _clip_unit,
        "epsilon": _pow10,
        "first_moment_decay": _one_minus_pow10,
        "learning_rate": _pow10,
        "second_moment_decay": _one_minus_pow10,
        "weight_decay": _pow10,
    },
    "improvement_threshold": 1e-4,
    "minibatch_size": 32,
    "minibatch_size_eval": 32,
    "pbt": True,
    "seed": 0,
    "steps_num": 10_001,
    "steps_without_improvement": 10_000,
    "valid_interval": 1000,
    "welch_confidence_level": .95,
    "welch_sample_size": 10,
}

In [211]:
# Seed torch's global RNG from the config so runs are repeatable. The trailing
# semicolon suppresses the returned Generator's repr, whose memory address
# changes on every run and only adds noise to the saved notebook.
torch.manual_seed(config["seed"]);

<torch._C.Generator at 0x7d9de761beb0>

In [212]:
# Build the preprocessed tensor cache once; later runs skip straight to the
# torch.load below.
if not os.path.exists(config["dataset_preprocessed_path"]):
    dataset = datasets.load_dataset(
        config["dataset_path"]
    ).with_format(
        "torch",
        device=config["device"]
    )
    train, test = (
        dataset[key]
        for key in ["train", "test"]
    )
    # Carve a validation split out of the training data, sized like the test
    # split and seeded so the partition is reproducible.
    train_valid = train.train_test_split(
        seed=config["seed"],
        test_size=len(test),
    )
    train, valid = (
        train_valid[key]
        for key in ["train", "test"]
    )

    (
        train_features,
        valid_features,
        test_features
    ) = (
        split["img"].to(config["float_dtype"])
        for split in (train, valid, test)
    )

    # Sanity check: standard deviation of the training features before and
    # after normalization.
    print(train_features.std())

    # NOTE(review): normalize_features is a project helper; it presumably
    # normalizes all three tensors in place using training statistics (its
    # return value is not used here) — confirm against its definition.
    normalize_features(
        train_features,
        (valid_features, test_features)
    )

    print(train_features.std())

    print(train["label"].dtype)

    # torch.save does not create missing parent directories, so make sure the
    # cache directory exists before writing; otherwise the preprocessing above
    # is wasted on a fresh checkout. A bare filename has no directory part.
    cache_dir = os.path.dirname(config["dataset_preprocessed_path"])
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    torch.save(
        {
            "train_features": train_features,
            "train_labels": train["label"],
            "valid_features": valid_features,
            "valid_labels": valid["label"],
            "test_features": test_features,
            "test_labels": test["label"],
        },
        config["dataset_preprocessed_path"]
    )

# Always read the tensors back from the cache so both code paths (fresh build
# and cache hit) end up with identically produced variables.
loaded = torch.load(
    config["dataset_preprocessed_path"],
    weights_only=True,
    map_location=config['device']
)
(
    train_features,
    train_labels,
    valid_features,
    valid_labels,
    test_features,
    test_labels
) = (
    loaded[f"{split}_{kind}"]
    for split in ("train", "valid", "test")
    for kind in ("features", "labels")
)

In [213]:
class ResLayer(torch.nn.Module): # note: worked with Hao on this but wrote up own solutions
    """Residual block: a two-convolution branch F plus a cropped, channel-tiled skip path P.

    The branch F is LayerNorm -> Dropout -> Conv -> DictReLU -> LayerNorm ->
    Dropout -> Conv. The block outputs ``P(x) + F(x)``, where P crops the
    trailing dimensions of x and repeats its channels cyclically so that it
    has ``out_dims`` channels.

    NOTE(review): the cropping in ``Q`` assumes each Conv shrinks every
    trailing dimension by ``kernel - 1`` (stride 1, no padding) — confirm
    against the project's ``Conv``.

    Args:
        config: Project configuration dict, forwarded to every sub-layer.
        conv_in: Number of input channels.
        conv_out: Number of channels produced by the first convolution.
        out_dims: Number of channels produced by the second convolution, and
            therefore by the block.
        kernel_shape: Kernel size per trailing dimension, used both for the
            two convolutions and for cropping the skip path.
    """

    def __init__(
        self, config: dict,
        conv_in: int,
        conv_out: int,
        out_dims: int,
        kernel_shape: tuple[int, ...],

    ):
        super().__init__()

        self.config = config
        self.n = conv_in
        self.m = conv_out
        self.F_out = out_dims
        self.kernel_shape = kernel_shape

        # Both convolutions use self.kernel_shape so the crop computed in Q
        # always matches the amount the branch actually shrinks its input.
        self.F = torch.nn.Sequential(
            LayerNorm(
                config,
                self.n,
                normalized_offset=2,
            ),
            Dropout(config),
            Conv(
                config,
                self.n,
                self.kernel_shape,
                self.m,
                init_multiplier=2 ** .5
            ),
            DictReLU(),
            LayerNorm(
                config,
                self.m,
                normalized_offset=2,
            ),
            Dropout(config),
            # The second convolution is left on Conv's default init_multiplier.
            Conv(
                config,
                self.m,
                self.kernel_shape,
                self.F_out,
            )
        )

    def Q(self, x):
        """Crop the trailing ``len(kernel_shape)`` dimensions of ``x``.

        Each trailing dimension loses ``2 * (k - 1)`` elements in total, where
        ``k`` is the kernel size for that dimension (``k - 1`` per
        convolution, two convolutions). Per convolution the crop is split as
        ``floor((k - 1) / 2)`` at the front and ``ceil((k - 1) / 2)`` at the
        back, so even kernel sizes are handled as well as odd ones.

        Args:
            x: Tensor whose last ``len(self.kernel_shape)`` dimensions are the
                ones to crop; all leading dimensions are kept whole.

        Returns:
            The cropped view of ``x``.
        """
        leading = (slice(None),) * (x.dim() - len(self.kernel_shape))

        crops = []
        for kern_dim in self.kernel_shape:
            shrink = kern_dim - 1  # elements removed by a single convolution
            front = 2 * (shrink // 2)
            back = 2 * (shrink // 2 + shrink % 2)
            # A stop of -0 would produce an empty slice, so use None when
            # nothing is cropped from the back (kernel size 1).
            crops.append(slice(front, -back if back else None))

        return x[leading + tuple(crops)]

    def P(self, x, fx):
        """Project ``x`` onto the shape of the branch output for the skip path.

        ``x`` is cropped with ``Q`` and its channels are then repeated
        cyclically (output channel ``i`` is input channel ``i % conv_in``)
        until there are ``out_dims`` of them. The channel dimension is the one
        immediately before the trailing ``len(kernel_shape)`` dimensions.

        Args:
            x: Block input features.
            fx: Branch output; accepted for interface compatibility but not
                used.

        Returns:
            Tensor with ``out_dims`` channels and cropped trailing dimensions.
        """
        Qx = self.Q(x)

        # Build the indices on the input's own device so index_select works
        # wherever the features live.
        indices = torch.arange(self.F_out, device=Qx.device) % self.n

        feature_dim = -(len(self.kernel_shape) + 1)

        return torch.index_select(
            Qx,
            dim=feature_dim,
            index=indices
        )

    def forward(self, batch: dict) -> dict:
        """Apply the residual block to ``batch["features"]``.

        Args:
            batch: Dict holding the input tensor under ``"features"``; it is
                passed as a whole through the branch ``F``.

        Returns:
            A new dict equal to ``batch`` with ``"features"`` replaced by
            ``P(x) + F(x)``.
        """
        x = batch["features"]

        Fx = self.F(batch)["features"]

        Px = self.P(x, Fx)

        features = Px + Fx

        return batch | {"features": features}

In [None]:
# Network: two residual convolutional stages with pooling, then a two-layer
# classifier head with 10 outputs. Layers are constructed in the same order as
# they are applied.
model = torch.nn.Sequential(
    ResLayer(config, 3, 16, 32, (3, 3)),
    Pool(config, kernel_shape=(2, 2), stride=2),
    DictReLU(),
    ResLayer(config, 32, 64, 128, (3, 3)),
    # NOTE(review): presumably reduces over both remaining spatial
    # dimensions — confirm against the project's Pool.
    Pool(config, sequence_dim_num=2),
    LayerNorm(config, 128),
    Dropout(config),
    Linear(config, 128, 128, init_multiplier=2 ** .5),
    DictReLU(),
    LayerNorm(config, 128),
    Dropout(config),
    Linear(config, 128, 10),
)

optimizer = AdamW(model.parameters())

# Datasets are plain dicts; "label" matches the target_key passed below.
dataset_train = dict(features=train_features, label=train_labels)
dataset_valid = dict(features=valid_features, label=valid_labels)

# Cross-entropy and accuracy are the two metric callables handed to the
# project's training loop; its return value is kept as the run log.
log = train_supervised(
    config,
    dataset_train, dataset_valid,
    get_cross_entropy, get_accuracy,
    model, optimizer,
    target_key="label",
)

  0%|          | 0/10001 [00:00<?, ?it/s]
  0%|          | 0/313 [00:00<?, ?it/s][A
  0%|          | 1/313 [00:01<07:52,  1.51s/it][A
  1%|          | 2/313 [00:03<08:03,  1.55s/it][A
  1%|          | 3/313 [00:04<07:55,  1.53s/it][A
  1%|▏         | 4/313 [00:05<06:55,  1.34s/it][A
  2%|▏         | 5/313 [00:06<06:20,  1.24s/it][A
  2%|▏         | 6/313 [00:07<06:00,  1.17s/it][A
  2%|▏         | 7/313 [00:08<05:48,  1.14s/it][A
  3%|▎         | 8/313 [00:09<05:40,  1.11s/it][A
  3%|▎         | 9/313 [00:10<05:33,  1.10s/it][A
  3%|▎         | 10/313 [00:12<05:28,  1.09s/it][A
  4%|▎         | 11/313 [00:13<05:25,  1.08s/it][A
  4%|▍         | 12/313 [00:14<05:23,  1.08s/it][A
  4%|▍         | 13/313 [00:15<05:51,  1.17s/it][A
  4%|▍         | 14/313 [00:17<06:20,  1.27s/it][A
  5%|▍         | 15/313 [00:18<06:42,  1.35s/it][A
  5%|▌         | 16/313 [00:19<06:22,  1.29s/it][A
  5%|▌         | 17/313 [00:20<06:00,  1.22s/it][A
  6%|▌         | 18/313 [00:21<05:45,  1

validation metric 0.1465
Best last metric 0.15
New best metric



  0%|          | 22/10001 [07:56<13:19:39,  4.81s/it]