In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell runs, so edits to the project's `src` package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Move the working directory up one level. The recorded output shows this lands
# in the repository root; presumably this is what makes the `src.*` imports in
# the next cell resolve — confirm if the notebook is moved.
%cd ..

/mnt/xfs/home/bencw/workspace/pretraining-distribution-shift-robustness


In [2]:
import torch as ch
from pathlib import Path
from torchvision.datasets import CIFAR10
from torchvision.transforms.functional import adjust_brightness
from ffcv.fields import IntField, RGBImageField
from ffcv.writer import DatasetWriter
from src.experiment_manager.base import ExperimentManager
import src.dataset_utils as dataset_utils
import src.modeling as modeling
from src.experiment_manager import model_manager

In [3]:
# Root directory of the CIFAR-10 data. It is passed as the first argument to
# every torchvision `CIFAR10(...)` call in this notebook.
# TODO: replace `None` with a real directory before running the cells below.
# NOTE(review): none of the `CIFAR10(...)` calls pass `download=True`, so the
# data presumably has to exist at this path already — confirm.
CIFAR10_ROOT = None # TODO

In [4]:
# Define a config to train a model.
# It is passed unchanged to `model_manager.SimpleModelManager` in
# `ExampleExperimentManager._make_model_managers`; how each key is interpreted
# is decided by that project code, not by this notebook.
MODEL_CONFIG = {
    "training": {
        "optimizer": "sgd",
        "lr": 0.5,
        "lr_schedule": "triangle",
        "weight_decay": 5e-4,
        "momentum": 0.9,
        "warmup_epochs": 5,
        "epochs": 24,
        # Same value as the `batch_size=512` literal in `ExampleExperimentManager.get_loader`.
        "batch_size": 512,
        "label_smoothing": 0.1,
        "use_scaler": True,
        # NOTE(review): `grad_clip_norm` is presumably ignored while `clip_grad` is False — confirm.
        "clip_grad": False,
        "grad_clip_norm": 1.0,
        "image_dtype": "float16",
        # Same value as the `decoder="simple"` literal in `ExampleExperimentManager.get_loader`.
        "decoder": "simple",
        "augmentation": "flip_translate_cutout",
        "num_workers": 10,
    },
    "evaluation": {
        "lr_tta": False,
    },
    "model": {
        "model_name": "timm_resnet18",
        # NOTE(review): this is the string "None", not the `None` singleton — confirm
        # that the model-building code expects the string form.
        "pretrained": "None",
        "resize": 224,
    },
}

In [5]:
# Helper function to write and get FFCV datasets
def get_ffcv_datasets(path, datasets, overwrite=False):
    """Write each dataset to an FFCV ``.beton`` file if needed, then wrap it.

    For every ``name -> dataset`` entry, the file ``<path>/<name>.beton`` is
    written with an ``image`` (``RGBImageField``) and a ``label`` (``IntField``)
    field when it does not exist yet or when ``overwrite`` is truthy. An
    existing file is otherwise reused without being checked against ``dataset``.

    Args:
        path: Directory holding the ``.beton`` files; it is created (with
            parents) whenever a file has to be written.
        datasets: Mapping from split name to a dataset, each handed to
            ``DatasetWriter.from_indexed_dataset``.
        overwrite: Rewrite a ``.beton`` file even when it already exists.

    Returns:
        Dict mapping each split name to a ``dataset_utils.FFCVDataset`` built
        from the corresponding ``.beton`` path.
    """
    root = Path(path)
    wrapped = {}
    for split_name, split_data in datasets.items():
        beton_file = root / f"{split_name}.beton"
        must_write = not beton_file.exists() or overwrite
        if must_write:
            beton_file.parent.mkdir(parents=True, exist_ok=True)
            fields = {
                "image": RGBImageField(),
                "label": IntField(),
            }
            DatasetWriter(beton_file, fields).from_indexed_dataset(split_data)
        # NOTE(review): 10 is presumably the number of CIFAR-10 classes —
        # confirm against `dataset_utils.FFCVDataset`.
        wrapped[split_name] = dataset_utils.FFCVDataset(beton_file, 10)
    return wrapped

In [6]:
# Class for shifted version of CIFAR-10 with adjusted (by default, decreased) brightness
class LowBrightnessCIFAR10(ch.utils.data.Dataset):
    """CIFAR-10 wrapper that rescales the brightness of every image.

    Each item is read from a torchvision ``CIFAR10`` dataset rooted at
    ``CIFAR10_ROOT`` and its image is passed through ``adjust_brightness``
    before being returned; labels are returned unchanged.
    """

    def __init__(self, train, brightness_factor=0.5):
        """Build the wrapped dataset.

        Args:
            train: Forwarded to ``CIFAR10(..., train=train)`` to select the
                training or the test split.
            brightness_factor: Factor passed to ``adjust_brightness``. Per the
                torchvision documentation it must be non-negative: 0 gives a
                black image and 1 the original image. Defaults to 0.5, the
                value this class previously hard-coded.

        Raises:
            ValueError: If ``brightness_factor`` is negative.
        """
        # Fail fast here instead of on the first item access.
        if brightness_factor < 0:
            raise ValueError(
                f"brightness_factor must be non-negative, got {brightness_factor}"
            )
        self.brightness_factor = brightness_factor
        self.dataset = CIFAR10(CIFAR10_ROOT, train=train)

    def __getitem__(self, index):
        """Return ``(image, label)`` at ``index`` with the image's brightness adjusted."""
        image, label = self.dataset[index]
        image = adjust_brightness(image, brightness_factor=self.brightness_factor)
        return image, label

    def __len__(self):
        """Return the number of items in the wrapped CIFAR-10 split."""
        return len(self.dataset)

In [7]:
class ExampleExperimentManager(ExperimentManager):
    """Experiment manager for the CIFAR-10 low-brightness shift example.

    Splits used in this example:
        - ``source_train``: reference training dataset (``CIFAR10`` with ``train=True``)
        - ``source_val``: reference validation dataset (``CIFAR10`` with ``train=False``)
        - ``target_val``: shifted validation dataset (``LowBrightnessCIFAR10(False)``)
    """

    def __init__(self, path, overwrite=False):
        """Build the raw datasets, write/load their FFCV versions, then init the base.

        Args:
            path: Experiment directory; FFCV files go under ``<path>/datasets``
                and ``path`` is also forwarded to the base-class initializer.
            overwrite: Forwarded to ``get_ffcv_datasets`` to force rewriting
                the ``.beton`` files.
        """
        source_train = CIFAR10(CIFAR10_ROOT, train=True)
        source_val = CIFAR10(CIFAR10_ROOT, train=False)
        target_val = LowBrightnessCIFAR10(False)
        self.datasets = {
            "source_train": source_train,
            "source_val": source_val,
            "target_val": target_val,
        }

        ffcv_dir = Path(path) / "datasets"
        self.ffcv_datasets = get_ffcv_datasets(ffcv_dir, self.datasets, overwrite=overwrite)

        # The base initializer deliberately runs last, once the datasets exist.
        # NOTE(review): presumably it may call `_make_model_managers`, which
        # reads `self.ffcv_datasets` — confirm against `ExperimentManager`.
        super().__init__(path)

    def get_ffcv_dataset(self, split):
        """Return the FFCV dataset registered under the split name ``split``."""
        return self.ffcv_datasets[split]

    def get_indices(self, split):
        """Return the subset indices for ``split``; always ``None`` here.

        Indices are used to specify a subset of a dataset, but in this example
        every split corresponds to an entire dataset.
        """
        return None

    def get_loader(self, split):
        """Build a loader for ``split`` via ``modeling.make_loader``.

        The loader is created in training mode exactly when the split name
        contains the substring ``"train"``.
        """
        ffcv_dataset = self.get_ffcv_dataset(split)
        subset_indices = self.get_indices(split)
        is_train_split = "train" in split
        return modeling.make_loader(
            ffcv_dataset,
            indices=subset_indices,
            decoder="simple",
            train=is_train_split,
            batch_size=512,
            normalization_params=(0.0, 1.0),
        )

    def _make_model_managers(self):
        """Return the model managers of this experiment, keyed by model name.

        Only one model is defined: ``"baseline"``, trained on the full
        ``source_train`` split with ``MODEL_CONFIG``.
        """
        train_split = "source_train"
        baseline_manager = model_manager.SimpleModelManager(
            self.get_ffcv_dataset(train_split),
            self.get_indices(train_split),
            MODEL_CONFIG,
            group="Baseline",
        )
        return {"baseline": baseline_manager}

In [8]:
# Directory for this experiment. It is passed to `ExampleExperimentManager`,
# whose constructor places the FFCV datasets under `<path>/datasets`.
# The path is relative, so it resolves against the current working directory
# (changed by `%cd ..` in the first cell).
path = Path("example_experiment")
manager = ExampleExperimentManager(path)

In [9]:
# Get the baseline model (will train and save the model if not already trained).
# "baseline" is the key registered in `ExampleExperimentManager._make_model_managers`.
# `model` is not used again in the cells shown here.
model = manager.get_model("baseline")

In [10]:
# Get the predictions of the baseline model on the target (shifted) validation dataset.
# The call is the cell's last expression, so the notebook displays the result
# (a tensor on `cuda:0` in the recorded output) instead of storing it in a variable.
manager.get_preds("baseline", split_name="target_val")

tensor([[-0.5193, -0.4019, -0.2739,  ..., -0.6208, -0.2327, -0.5738],
        [ 0.0335,  0.0234, -0.5594,  ..., -0.4075,  3.6742, -0.8820],
        [-0.6811,  0.7197, -0.6793,  ..., -0.5406,  3.6793, -0.6452],
        ...,
        [-0.3106, -0.2750, -0.4471,  ..., -0.6251, -0.3018, -0.3060],
        [-0.7448,  3.8384, -0.3713,  ..., -0.3818, -0.4826, -0.7281],
        [-0.4956, -0.4333, -0.3835,  ...,  4.2603, -0.3436, -0.4699]],
       device='cuda:0')

In [11]:
# Get the metrics (by default, just accuracy) of models on the source (reference) validation set.
# The recorded output is a dict keyed by model name.
# NOTE(review): `ignore_unpredicted=False` presumably means models without stored
# predictions are evaluated rather than skipped — confirm against `ExperimentManager.get_metrics`.
manager.get_metrics("source_val", ignore_unpredicted=False)

{'baseline': 0.9388999938964844}

In [12]:
# Get the metrics (by default, just accuracy) of models on the target (shifted) validation set.
# The recorded output is a dict keyed by model name.
# NOTE(review): `ignore_unpredicted=False` presumably means models without stored
# predictions are evaluated rather than skipped — confirm against `ExperimentManager.get_metrics`.
manager.get_metrics("target_val", ignore_unpredicted=False)

{'baseline': 0.9059999585151672}