# SKAL! Model Training
A brief example notebook showing how to train a SKAL! model.

## Input variables

In [None]:
# Notebook parameters: edit these to train on another dataset, use another
# configuration file, or write results somewhere else.

# Dataset location and the folder layout used to read it; both are handed to
# FolderFactory.get_folder in the setup cell.
dataset_format = "nanotwice"
training_dir = "datasets/nanotwice/"

# NOTE(review): this value is never read. The model-loading cell rebinds `model`
# to the loaded model object, and the architecture is picked from
# config.model['name'] in the YAML file below, not from this variable.
model = "bigan"

# YAML experiment configuration (model, preprocessor, augmenter, batch size,
# validation split, seed).
config_path = "config/bigan/nanotwice.yaml"

# Root directory handed to Workspace(root_dir=...).
experiment_dir = "experiments"


In [None]:
import os

# Select TensorFlow's asynchronous CUDA allocator. The variable is set here,
# before the skal imports in the next cell, so that it is already in the
# environment when (presumably) TensorFlow gets imported and initialises the GPU.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

# Workaround for running the notebook from a subdirectory of the repository:
# if no `skal` entry is visible in the current working directory, move up one
# level so that `import skal` can resolve. With a properly built and installed
# package this step would be unnecessary.
if "skal" not in os.listdir(os.getcwd()):
    os.chdir("..")

In [None]:
from skal.utils import utils
from skal.experiment.config import Config
from skal.experiment.workspace import Workspace
from skal.data.folders import FolderFactory
from skal.data.augmenters import AugmenterBuilder
from skal.data.preprocessors import PreprocessorBuilder
from skal.data.dataset_builder import AnomalyDatasetBuilder
from skal.models.model_choices import LoaderFactory

## Setting up the environment

In [None]:
# Project GPU setup helper (behaviour defined in skal.utils), then load the
# YAML experiment configuration into a Config object.
utils.set_gpu()
exp_params = utils.load_yaml_file(config_path)
config = Config(**exp_params)

# Workspace rooted at `experiment_dir`; its directories are only created in the
# training cell, via exp_ws.make_experiment_dirs().
exp_ws = Workspace(root_dir=experiment_dir)

# Resolve the dataset folder for the chosen layout and report how many training
# files were found. This listing is only used for the sanity-check print; the
# dataset builder below reads from `folder` itself.
folder = FolderFactory.get_folder(dataset_format, training_dir)
training_paths = folder.get_training_paths(shuffle=True, seed=config.seed)
n_training_paths = len(training_paths)
print(f"Found {n_training_paths} training paths")

# Input pipeline: preprocessing and augmentation both come from the config, and
# the builder produces the training and validation datasets (split controlled by
# config.val_split).
preprocessor = PreprocessorBuilder.get_preprocessor(config.preprocessor)
augmenter = AugmenterBuilder.augmenter_from_config(config.augmenter)
dataset_builder = AnomalyDatasetBuilder(
    folder,
    preprocessor,
    augmenter=augmenter,
    seed=config.seed,
)
train_ds, val_ds = dataset_builder.train_val_ds_from_folder(
    shuffle=True,
    batch_size=config.batch_size,
    val_split=config.val_split,
)


## Loading the model

In [None]:
# The architecture is selected by the configuration file (config.model["name"]),
# not by the `model` input variable. Note that this cell rebinds `model` from that
# input string to the loaded model object.
model_cfg = config.model
loader = LoaderFactory.get_loader(model_cfg["name"])
model = loader.load_model_from_config(model_cfg, seed=config.seed)
trainer = loader.load_trainer()

## Model training

In [None]:
# Create the workspace directories before training starts (presumably the
# trainer writes its outputs into the workspace it receives).
exp_ws.make_experiment_dirs()

print("Everything is ready. Starting training...")
trainer.train_model(model, train_ds, val_ds, config, exp_ws)

# Persist the trained weights into the workspace's save directory.
# NOTE(review): whether save_weights expects a directory or a file prefix depends
# on the model class returned by the loader; confirm exp_ws.save_dir is the
# intended target.
weights_target = exp_ws.save_dir
model.save_weights(weights_target)
print("Job done!")