# SANE Exploration

This notebook is an easy starting point for exploring SANE: its datasets, models, training and inference.
It uses a sample dataset. To prepare it, navigate to `SANE/data`, run `bash download_cifar10_cnn_sample.sh` to download the sample dataset, and then run `python3 preprocess_dataset_cnn_cifar10_sample.py` to preprocess it.

In [1]:
# imports
import logging

logging.basicConfig(level=logging.INFO)

import os

import torch

from SANE.evaluation.ray_fine_tuning_callback import CheckpointSamplingCallback
from SANE.evaluation.ray_fine_tuning_callback_subsampled import (
    CheckpointSamplingCallbackSubsampled,
)
from SANE.evaluation.ray_fine_tuning_callback_bootstrapped import (
    CheckpointSamplingCallbackBootstrapped,
)

import json

from pathlib import Path


from SANE.models.ae_trainer import AE_trainer
from SANE.datasets.dataset_sampling_preprocessed import PreprocessedSamplingDataset

INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmp1cr6k324
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmp1cr6k324/_remote_module_non_scriptable.py


In [2]:
# Root directory for experiment outputs, relative to the current working directory.
PATH_ROOT = Path(".")

In [13]:
# Configure SANE pretraining
#
# Builds the flat `config` dict that is handed to the trainer. Every key is
# assigned exactly once so that the value a reader sees is the value in effect.

### configure experiment #########
experiment_name = "sane_cifar10_cnn_standalone"
# set module parameters
config = {}
config["seed"] = 32
config["device"] = "cuda" if torch.cuda.is_available() else "cpu"
config["device_no"] = 1

# autoencoder architecture
config["ae:transformer_type"] = "gpt2"
config["ae:i_dim"] = 289
config["ae:lat_dim"] = 128
config["ae:max_positions"] = [100, 10, 40]
config["ae:d_model"] = 1024
config["ae:nhead"] = 8
config["ae:num_layers"] = 8

# permutation specs
# NOTE(review): "training::view_1_canon" is never set here, although the testing
# section sets both view_1_canon and view_2_canon. This cell previously assigned
# "training::view_2_canon" twice (False, then True); only the effective value is
# kept. Confirm whether one of the two was meant to be "training::view_1_canon".
config["training::permutation_number"] = 5
config["training::view_2_canon"] = True
config["testing::permutation_number"] = 5
config["testing::view_1_canon"] = True
config["testing::view_2_canon"] = False
### Augmentations
config["trainset::add_noise_view_1"] = 0.1
config["trainset::add_noise_view_2"] = 0.1
config["trainset::noise_multiplicative"] = True
config["trainset::erase_augment_view_1"] = None
config["trainset::erase_augment_view_2"] = None

config["training::windowsize"] = 64
config["trainset::batchsize"] = 32

# configure optimizer
config["optim::optimizer"] = "adamw"
config["optim::lr"] = 1e-4
config["optim::wd"] = 3e-9
config["optim::scheduler"] = "OneCycleLR"

# Task config
config["training::temperature"] = 0.1
config["training::gamma"] = 0.05
config["training::reduction"] = "mean"
config["training::contrast"] = "simclr"

# training duration
config["training::epochs_train"] = 5
config["training::output_epoch"] = 5
config["training::test_epochs"] = 1

# training optimization
config["model::compile"] = True
config["training::precision"] = "amp"
config["monitor_memory"] = True
config["trainloader::workers"] = 6

# configure output path
experiment_dir = PATH_ROOT.joinpath("sane_pretraining")
# exist_ok=True makes re-running this cell safe when the directory already exists
experiment_dir.mkdir(parents=True, exist_ok=True)
config['experiment_dir'] = experiment_dir
###### Datasets ###########################################################################
# location of the preprocessed sample dataset
data_path = Path("../data/dataset_cnn_cifar10_sample_ep21-25_std/")
# path to dataset for training
config["dataset::dump"] = data_path.joinpath("dataset.pt")
config["downstreamtask::dataset"] = None
# only a log message is emitted here; no dataset preparation runs in this cell
logging.info("prepare data")


config["callbacks"] = []
# bare expression so the notebook displays the final configuration
config

INFO:root:prepare data


{'seed': 32,
 'device': 'cpu',
 'device_no': 1,
 'ae:transformer_type': 'gpt2',
 'ae:i_dim': 289,
 'ae:lat_dim': 128,
 'ae:max_positions': [100, 10, 40],
 'ae:d_model': 1024,
 'ae:nhead': 8,
 'ae:num_layers': 8,
 'training::permutation_number': 5,
 'training::view_2_canon': True,
 'testing::permutation_number': 5,
 'testing::view_1_canon': True,
 'testing::view_2_canon': False,
 'trainset::add_noise_view_1': 0.1,
 'trainset::add_noise_view_2': 0.1,
 'trainset::noise_multiplicative': True,
 'trainset::erase_augment_view_1': None,
 'trainset::erase_augment_view_2': None,
 'training::windowsize': 64,
 'trainset::batchsize': 32,
 'optim::optimizer': 'adamw',
 'optim::lr': 0.0001,
 'optim::wd': 3e-09,
 'optim::scheduler': 'OneCycleLR',
 'training::temperature': 0.1,
 'training::gamma': 0.05,
 'training::reduction': 'mean',
 'training::contrast': 'simclr',
 'training::epochs_train': 50,
 'training::output_epoch': 25,
 'training::test_epochs': 1,
 'model::compile': True,
 'training::precision

In [14]:
# init ae_trainer - this sets up the model, loads the dataset and configures the training loop.
# All settings come from the `config` dict built in the configuration cell.
ae_trainer = AE_trainer(config)

INFO:root:Set up AE Trainable
INFO:root:get datasets
INFO:root:Load Data
INFO:root:set up dataloaders
INFO:root:corrected batchsize to 32
INFO:root:set downstream tasks
INFO:root:No properties found in dataset - skip downstream tasks.
INFO:root:instanciate model
INFO:root:Initialize Model
INFO: Global seed set to 32
INFO:lightning.fabric.utilities.seed:Global seed set to 32
INFO:root:device: cpu
INFO:root:compiling the model... (takes a ~minute)
INFO:root:compiled successfully
INFO:root:set transformations
INFO:root:set callbacks
INFO:root:module setup done


model: use simclr NT_Xent loss
Running single-gpu. send model to device: cpu
num decayed parameter tensors: 73, with 202,504,068 parameters
num non-decayed parameter tensors: 40, with 35,353 parameters
using fused AdamW: False
++++++ USE AUTOMATIC MIXED PRECISION +++++++


In [None]:
# run training loop
# presumably runs for config["training::epochs_train"] epochs - confirm against AE_trainer
ae_trainer.train()

In [29]:
# get data for further exploration: pull a single batch from the training loader
batch = next(iter(ae_trainer.trainloader))
tokens, mask, positions, properties = batch
print(f'tokens: {tokens.shape} - mask: {mask.shape} - positions: {positions.shape} - properties: {properties.shape}')
# tokens are of shape [batch_size, no_permutations, sequence_length, token_size]
# masks are of shape [batch_size, sequence_length, token_size]
# positions are of shape [batch_size, sequence_length, 3]
# properties are of shape [batch_size, 3]

# choose one permutation: integer indexing already removes the permutation axis,
# leaving [batch_size, sequence_length, token_size]. No squeeze() is applied, because
# a bare squeeze() would also drop any other size-1 axis (e.g. a batch of one).
tokens = tokens[:, 0, :, :]

tokens: torch.Size([32, 201, 64, 289]) - mask: torch.Size([32, 64, 289]) - positions: torch.Size([32, 64, 3]) - properties: torch.Size([32, 3])


In [31]:
# Encode the token batch into latent embeddings; no gradients are tracked.
with torch.no_grad():
    pos_int = positions.to(torch.int)  # positions are cast to int before encoding
    z = ae_trainer.module.forward_encoder(tokens, pos_int)
print(f'z: {z.shape}')

z: torch.Size([32, 64, 128])


In [33]:
# Decode the latent embeddings back into weight tokens; no gradients are tracked.
with torch.no_grad():
    pos_int = positions.to(torch.int)  # same integer cast as used for the encoder
    tokens_recon = ae_trainer.module.forward_decoder(z, pos_int)
print(f'tokens_recon: {tokens_recon.shape}')

tokens_recon: torch.Size([32, 64, 289])


In [34]:
# load CNN model: build an NNmodule from the params.json stored with a sample checkpoint

from SANE.models.def_net import NNmodule


config_cnn_path = Path('../data/cifar10_cnn_sample_ep21-25/NN_tune_trainable_da045_00000_0_seed=1_2021-09-25_11-43-53/params.json')
# read the model configuration; the context manager closes the file handle afterwards
with config_cnn_path.open('r') as config_file:
    config_cnn = json.load(config_file)

cnn = NNmodule(config_cnn)

In [37]:
# get cnn checkpoint: the state dict of the CNN wrapped by `cnn`

check = cnn.model.state_dict()
# bare expression: displays the parameter names contained in the state dict
check.keys()

odict_keys(['module_list.0.weight', 'module_list.0.bias', 'module_list.4.weight', 'module_list.4.bias', 'module_list.8.weight', 'module_list.8.bias', 'module_list.13.weight', 'module_list.13.bias', 'module_list.16.weight', 'module_list.16.bias'])

In [38]:
# Convert the CNN state dict into tokens; the result is unpacked as tokens, masks and positions.
from SANE.datasets.dataset_auxiliaries import tokenize_checkpoint

toks, masks, pos = tokenize_checkpoint(
    check,
    tokensize=0,
    return_mask=True,
)
toks.shape

In [40]:
# Rebuild a checkpoint from the tokens, using the original state dict as the reference.
from SANE.datasets.dataset_auxiliaries import tokens_to_checkpoint

check_recon = tokens_to_checkpoint(
    tokens=toks,
    pos=pos,
    reference_checkpoint=check,
    ignore_bn=True,
)

In [41]:
# Check that the first layer's weights survive the tokenize -> de-tokenize round trip.
# The bare expression displays the boolean result; nothing is raised on mismatch.
layer_key = 'module_list.0.weight'
torch.allclose(check[layer_key], check_recon[layer_key])

True