# Residual network

This notebook trains the ResNet-20 based on


After training, the model is serialized and uploaded to the W&B project.

In [None]:
import wandb
import tensorflow as tf
import pathlib
import shutil
import utils
from train import evaluate_model, evaluate_diacritics_performance

In [None]:
# Default hyperparameters, handed to W&B as the run config.
defaults = {
    "batch_size": 32 * 4,
    "epochs": 100,
    "optimizer": "sgd",
    "learning_rate": 0.01,
    "momentum": 0.9,
}

# Block layout of the network; the run name is derived from it
# ("resnet-18" for the values below).
RESNET_DEPTHS = [3, 4, 6, 3]
MODEL_NAME = f"resnet-{sum(RESNET_DEPTHS) + 2}"

run = wandb.init(
    project="master-thesis",
    job_type="training",
    name=MODEL_NAME,
    config=defaults,
)
split_paths = utils.load_data(run=run)

# Hyperparameters are read from the W&B config rather than from `defaults`.
bs = wandb.config.batch_size
epochs = wandb.config.epochs
opt_name = wandb.config.optimizer
lr = wandb.config.learning_rate
momentum = wandb.config.momentum

In [None]:
# `split_paths` must hold exactly three entries, consumed in the order
# train, test, validation.
train_path, test_path, val_path = split_paths
ds_train = utils.create_tf_dataset(train_path, batch_size=bs)
ds_test = utils.create_tf_dataset(test_path, batch_size=bs)
ds_val = utils.create_tf_dataset(val_path, batch_size=bs)

# Read from the raw training dataset, before it is replaced by its
# preprocessed version below.
num_classes = len(ds_train.class_names)

print(f"There are {num_classes} classes")
print(f"Training set has {len(ds_train)} batches")
print(f"Test set has {len(ds_test)} batches")
print(f"Validation set has {len(ds_val)} batches")

# The test split is the only one preprocessed with cache=False.
ds_train = utils.preprocess_dataset(ds_train)
ds_val = utils.preprocess_dataset(ds_val)
ds_test = utils.preprocess_dataset(ds_test, cache=False)

In [None]:
from resnet import get_resnet_model

# Build the network for 32x32x1 inputs using the block layout in RESNET_DEPTHS.
model = get_resnet_model(input_shape=[32, 32, 1], block_design=RESNET_DEPTHS, num_classes=num_classes)

# Only these built-in Keras optimizers accept a `momentum` argument; passing it
# to any other (e.g. Adam) raises on construction, so it is added conditionally.
MOMENTUM_OPTIMIZERS = {"sgd", "rmsprop"}

opt_config = {"learning_rate": lr}
if str(opt_name).lower() in MOMENTUM_OPTIMIZERS:
    opt_config["momentum"] = momentum

# Resolve the optimizer by its configured name (`opt_name` was read from the
# W&B config together with the other hyperparameters).
opt = tf.keras.optimizers.get({
    "class_name": opt_name,
    "config": opt_config,
})

# from_logits=True: the loss applies softmax itself.
# NOTE(review): assumes the model's final layer outputs raw logits — confirm in resnet.py.
model.compile(
    optimizer=opt,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Keep only the full model with the highest validation accuracy on disk.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=f"./artifacts/{MODEL_NAME}.h5",
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
    save_weights_only=False,
)

# W&B callback with its own model saving turned off; the checkpoint callback
# above is the one that writes the model file.
wandb_callback = wandb.keras.WandbCallback(compute_flops=True, save_model=False)

training_callbacks = [wandb_callback, checkpoint_callback]
history = model.fit(
    ds_train,
    validation_data=ds_val,
    epochs=epochs,
    callbacks=training_callbacks,
)

In [None]:
# Presumably plots the training curves from the `History` object returned by `model.fit`.
# NOTE(review): `plot_history` is not imported or defined in any visible cell —
# confirm where it lives (e.g. `utils` or `train`) and import it, otherwise this
# cell raises NameError.
plot_history(history)

In [None]:
# evaluate model then log to wandb
# NOTE(review): `model` is the in-memory model as left by `model.fit` (Keras does
# not restore the best checkpoint on its own), while the artifact uploaded at the
# end of the notebook is the best-`val_accuracy` checkpoint file. Confirm that
# evaluating the final-epoch weights is intended.

evaluate_model(model, ds_test, MODEL_NAME)
evaluate_diacritics_performance(model, ds_test)

In [None]:
# Publish the best checkpoint as a W&B model artifact, then close the run.
best_model_path = f"./artifacts/{MODEL_NAME}.h5"

artifact = wandb.Artifact(name=MODEL_NAME, type="model")
artifact.add_file(best_model_path)  # file written by the checkpoint callback during training

run.log_artifact(artifact)
run.finish()