# Baseline network + pruning + quantization

This notebook trains a baseline network based on the architecture from the paper [Recognition of handwritten Latin characters with diacritics using CNN](https://journals.pan.pl/dlibra/publication/136210/edition/119099/content/bulletin-of-the-polish-academy-of-sciences-technical-sciences-recognition-of-handwritten-latin-characters-with-diacritics-using-cnn-lukasik-edyta-charytanowicz-malgorzata-milosz-marek-tokovarov-michail-kaczorowska-monika-czerwinski-dariusz-zientarski-tomasz-2021-69-no-1?language=en). Note that the model defined in `train_model` below deviates from the paper's description (e.g. 16/32 convolution filters, batch normalization, and a 512-unit GELU dense layer).

Model architecture description:

"The architecture of the concrete CNN is shown in Fig. 2.
The input is a 32x32 binarized matrix. The input is then prop-
agated through 12 adaptable layers. First come two convolu-
tional layers having 32 filters with the size of 3x3 and stride 1.
Secondly, the output of the convolutional layer is fed to the
ReLU function. The output is down-sampled using a max-pool-
ing operation with a 2x2 stride. Next, the dropout technique is
used with the coefficient 0.25. The four operations (two con-
volutions, nonlinearity, max-pooling, and dropout) are repeated,
using 64 filters for the convolutional layers. The output of the
last layer is then flattened and fed through a fully connected
layer with 256 neurons and ReLU nonlinearities, dropped out
with the 0.25 coefficient, and a final output layer is fully con-
nected with a Softmax activation function. The Adam optimizer
and the cross-entropy loss function were used in the network. 
The output is a probability distribution over 89 classes."


Additionally, this network was pruned and quantized after training.



After training, the model is serialized and uploaded to the W&B project.

In [None]:
import pathlib
import shutil
import os
import time
import datetime
import numpy as np
import wandb
from wandb.keras import WandbCallback
import tensorflow as tf
import matplotlib.pyplot as plt
import zipfile

from typing import List

def load_data(run, artifact_name = "phcd_paper_splits_tfds") -> List[tf.data.Dataset]:
    """
    Download (or reuse a local copy of) a W&B dataset artifact and load its splits.

    Args:
        run: active W&B run; the artifact is declared as an input of this run.
        artifact_name: artifact name inside the "master-thesis" project. The
            ":latest" alias is always requested.

    Returns:
        A list ``[train, test, val]`` of tf.data.Datasets, in that order, each
        loaded from the GZIP-compressed sub-directory of the same name.
    """
    artifact = run.use_artifact(f"master-thesis/{artifact_name}:latest")
    artifact_dir = pathlib.Path(
        f"./artifacts/{artifact.name.replace(':', '-')}"
    ).resolve()
    if not artifact_dir.exists():
        # Download into exactly the directory checked above so that the next
        # call finds the local copy instead of downloading again.
        artifact_dir = pathlib.Path(artifact.download(root=str(artifact_dir))).resolve()

    # tf.data.Dataset.load only exists in newer TF releases (2.10+); older ones
    # provide tf.data.experimental.load. Feature detection avoids parsing the
    # version string (which previously ignored the major version).
    load_function = getattr(tf.data.Dataset, "load", None) or tf.data.experimental.load

    return [
        load_function(str(artifact_dir / split), compression="GZIP")
        for split in ("train", "test", "val")
    ]

def get_readable_class_labels(subset = 'phcd_paper'):
    """
    Return the human-readable class labels of a dataset subset.

    The list is ordered so that position ``i`` is the label of integer class ``i``.
    A new list is returned on every call, so callers may mutate it freely.

    Args:
        subset: one of 'phcd_paper', 'uppercase', 'lowercase', 'numbers',
            'uppercase_no_diacritics', 'lowercase_no_diacritics'.

    Raises:
        ValueError: if ``subset`` is not one of the names above (previously the
            function silently returned None, which failed later and far away).
    """
    digits = "0123456789"
    latin_lower = "abcdefghijklmnopqrstuvwxyz"
    latin_upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    polish_lower = "ąćęłńóśźż"
    polish_upper = "ĄĆĘŁŃÓŚŹŻ"
    punctuation = "+-:;$!?@."

    # Every character above is a single code point, so list() yields one label per character.
    subsets = {
        'phcd_paper': digits + latin_lower + latin_upper + polish_lower + polish_upper + punctuation,
        'uppercase': latin_upper + polish_upper,
        'lowercase': latin_lower + polish_lower,
        'numbers': digits,
        'uppercase_no_diacritics': latin_upper,
        'lowercase_no_diacritics': latin_lower,
    }
    try:
        return list(subsets[subset])
    except KeyError:
        raise ValueError(
            f"Unknown label subset {subset!r}; expected one of {sorted(subsets)}"
        ) from None

def calculate_accuracy_per_class(model, test_dataset, test_dataset_name):
    '''
    Calculate the accuracy of ``model`` on each class of ``test_dataset``.

    Args:
        model: Keras model whose ``predict`` output is argmax-ed along axis 1.
        test_dataset: dataset of (features, integer label) pairs. It is iterated
            twice (once by ``predict``, once for the labels), so it must yield
            elements in the same order each time (i.e. must not reshuffle).
        test_dataset_name: label-subset name passed to get_readable_class_labels.

    Returns:
        Dict mapping each readable class label to its accuracy. A class with no
        samples in ``test_dataset`` gets NaN (explicitly, without a 0/0 warning).

    Raises:
        ValueError: if the number of predictions differs from the number of labels.
    '''
    y_pred = np.argmax(model.predict(test_dataset), axis=1)
    # ravel() tolerates labels stored as (batch, 1) as well as (batch,).
    y_true = np.concatenate(
        list(test_dataset.map(lambda x, y: y).as_numpy_iterator())
    ).ravel()
    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError(
            f"Got {y_pred.shape[0]} predictions but {y_true.shape[0]} labels"
        )

    labels = get_readable_class_labels(test_dataset_name)
    num_labels = len(labels)
    # Count samples and correct predictions per class; slicing ignores any
    # label index beyond the readable label list, as the original loop did.
    support = np.bincount(y_true, minlength=num_labels)[:num_labels]
    correct = np.bincount(y_true[y_pred == y_true], minlength=num_labels)[:num_labels]
    with np.errstate(divide="ignore", invalid="ignore"):
        class_accuracy = np.where(support > 0, correct / support, np.nan)
    return dict(zip(labels, class_accuracy))
    

def plot_accuracy_per_class(class_accuracy_dict):
    """Display a bar chart of per-class accuracy from a ``{label: accuracy}`` mapping."""
    class_names = [*class_accuracy_dict]
    scores = [*class_accuracy_dict.values()]

    plt.figure(figsize=(10, 5))
    plt.bar(class_names, scores)
    # One tick per class so every label is shown on the x axis.
    plt.xticks(class_names)
    plt.xlabel("Class")
    plt.ylabel("Accuracy")
    plt.title("Accuracy per class")
    plt.show()


def accuracy_table(class_accuracy_dict):
    """Build a two-column W&B table ("Class", "Accuracy") from a ``{label: accuracy}`` mapping."""
    # items() yields (label, accuracy) tuples in insertion order: one table row each.
    rows = list(class_accuracy_dict.items())
    return wandb.Table(columns=["Class", "Accuracy"], data=rows)

def get_number_of_classes(ds: tf.data.Dataset) -> int:
    """
    Count the distinct label values in a dataset of (features, label) pairs.

    The whole dataset is iterated once and all labels are held in memory.
    Note this counts the labels actually present, not the highest label + 1.
    """
    label_batches = list(ds.map(lambda x, y: y).as_numpy_iterator())
    distinct_labels = np.unique(np.concatenate(label_batches))
    return len(distinct_labels)

def get_number_of_examples(ds: tf.data.Dataset) -> int:
    """
    Count the elements of a dataset by iterating over it completely.

    Each iteration step counts as one, so for a batched dataset the result is
    the number of batches rather than the number of individual samples.
    """
    count = 0
    for _ in ds:
        count += 1
    return count

def preprocess_dataset(ds: tf.data.Dataset, batch_size: int, cache: bool = True) -> tf.data.Dataset:
    """
    Scale image values to floats and re-batch a dataset of (image, label) pairs.

    Args:
        ds: dataset whose elements are batches (``unbatch()`` is applied to it);
            images are divided by 255 — presumably 8-bit pixel data, confirm.
        batch_size: size of the output batches.
        cache: whether to cache the preprocessed batches in memory.

    Returns:
        The preprocessed dataset, always prefetched with AUTOTUNE.
    """
    ds = ds.map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y))  # normalize
    # Drop the incoming batching and re-batch with the requested size.
    ds = ds.unbatch().batch(batch_size)
    if cache:
        ds = ds.cache()
    # Prefetch regardless of caching: previously the uncached (test) pipeline
    # was not prefetched at all.
    return ds.prefetch(buffer_size=tf.data.AUTOTUNE)

def calculate_model_compressed_size_on_disk(path: str) -> int:
    """
    Return the size in bytes of ``path`` after DEFLATE-compressing it into a ZIP.

    The archive is written next to the input as ``<path>.zip`` and left on disk.
    Only the file's base name is stored inside the archive, so the measured size
    does not depend on the length of the directory part of ``path``.

    Args:
        path: file to compress (``str`` or ``os.PathLike``).

    Returns:
        Size in bytes of the written ``.zip`` file.
    """
    source = os.fspath(path)
    compressed_path = source + ".zip"
    with zipfile.ZipFile(compressed_path, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
        # arcname drops the leading directories from the stored entry name.
        archive.write(source, arcname=os.path.basename(source))
    return pathlib.Path(compressed_path).stat().st_size

def calculate_model_num_parameters(model: tf.keras.Model) -> int:
    """Return the total parameter count of a Keras model, as reported by ``count_params``."""
    total = model.count_params()
    return total

def calculate_model_flops(summary) -> float:
    """
    Look up the FLOP count stored in a W&B run summary.

    The "GFLOPs" key is preferred, "GFLOPS" is the fallback, and 0 is returned
    when neither key is present.
    """
    available = summary.keys()
    for key in ("GFLOPs", "GFLOPS"):
        if key in available:
            return summary.get(key)
    return 0

def plot_history(history, title):
    """
    Plot training vs. validation accuracy (left) and loss (right) per epoch.

    Args:
        history: Keras ``History``; its ``history`` dict must contain
            'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
        title: figure-level title.
    """
    plt.figure(figsize=(15,7))
    plt.suptitle(title)

    # Same layout for both panels: training curve, validation curve, labels, legend.
    for position, metric in ((121, 'accuracy'), (122, 'loss')):
        plt.subplot(position)
        plt.plot(history.history[metric], label='train')
        plt.plot(history.history[f'val_{metric}'], label='val')
        plt.ylabel(metric)
        plt.xlabel('epoch')
        plt.legend()

In [None]:
# Default run configuration handed to wandb.init(config=...) in train_model.
defaults = {
    "batch_size": 64,  # 32 * 2
    "epochs": 50,
    "optimizer": "adam",
}

def _build_model(num_classes):
    """
    Build the (uncompiled) CNN trained by ``train_model``.

    Two convolution blocks (16 then 32 filters of 3x3 with ReLU, batch
    normalization after the first conv, 2x2 max pooling, dropout) feed a
    512-unit GELU dense layer. The final Dense layer has no activation, so the
    model outputs logits.
    """
    return tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(32, 32, 1)),

            tf.keras.layers.Conv2D(16, kernel_size=(3, 3), activation="relu"),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Conv2D(16, kernel_size=(3, 3), activation="relu"),
            tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
            tf.keras.layers.Dropout(0.2),

            tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
            tf.keras.layers.Dropout(0.25),

            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(512, activation="gelu"),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dropout(0.1),
            # No activation: logits, matched by from_logits=True in the loss.
            tf.keras.layers.Dense(num_classes),
        ]
    )


def train_model(model_name, artifact_name, defaults):
    """
    Train, evaluate and log one model run in the "master-thesis" W&B project.

    Args:
        model_name: W&B run name.
        artifact_name: dataset artifact to train on (e.g. "phcd_paper_splits_tfds");
            also used as the run tag. The part before "_splits_tfds" selects
            the readable class labels used for the per-class accuracy report.
        defaults: run config; read keys are "epochs", "batch_size" and
            (optionally, default "adam") "optimizer".

    Side effects: writes ``model_baseline.h5`` and ``model_baseline.h5.zip`` to
    the working directory, shows plots, and logs metrics and the saved model to W&B.
    """
    model_path = 'model_baseline.h5'

    # Derive the label subset from the argument. Previously this read the global
    # ``artifact_base_name`` set by the driver loop, so the function failed
    # (or used the wrong labels) when called any other way.
    suffix = "_splits_tfds"
    if artifact_name.endswith(suffix):
        label_subset = artifact_name[: -len(suffix)]
    else:
        label_subset = artifact_name

    with wandb.init(project="master-thesis", job_type="training", name=model_name, config=defaults, tags=[artifact_name]) as run:

        # hyperparameters
        epochs = run.config.epochs
        bs = run.config.batch_size
        optimizer = run.config.get("optimizer", "adam")

        ds_train, ds_test, ds_val = load_data(run, artifact_name=artifact_name)

        # NOTE(review): the class count is the number of distinct labels in the
        # validation split; a class absent there would shrink the output layer.
        num_classes = get_number_of_classes(ds_val)

        ds_train = preprocess_dataset(ds_train, batch_size=bs)
        ds_val = preprocess_dataset(ds_val, batch_size=bs)
        ds_test = preprocess_dataset(ds_test, batch_size=bs, cache=False)

        model = _build_model(num_classes)
        model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=["accuracy"],
        )
        history = model.fit(
            ds_train,
            epochs=epochs,
            validation_data=ds_val,
            callbacks=[
                WandbCallback(
                    compute_flops=True,
                    save_model=False,
                    log_weights=False,
                    log_gradients=False,
                )
            ],
        )

        plot_history(history, "Baseline")
        tf.keras.models.save_model(model, model_path, include_optimizer=False)

        # evaluate model on ds_test and log to wandb
        test_loss, test_acc = model.evaluate(ds_test)

        # WandbCallback(compute_flops=True) stores the FLOP count in the run summary.
        flops = calculate_model_flops(run.summary)
        num_parameters = calculate_model_num_parameters(model)
        compressed_disk_size = calculate_model_compressed_size_on_disk(model_path)
        class_acc = calculate_accuracy_per_class(model, ds_test, label_subset)

        data_to_log = {
            "test_loss": test_loss,
            "test_accuracy": test_acc,
            "num_parameters": num_parameters,
            "compressed_disk_size": compressed_disk_size,
            "model_flops": flops,
            "accuracy_per_class": accuracy_table(class_acc),
        }
        plot_accuracy_per_class(class_acc)
        print(data_to_log)
        run.log(data_to_log)

        # upload model to wandb
        run.save(model_path)

# Dataset subsets to train on; each maps to the "<subset>_splits_tfds" artifact.
artifact_base_names = ["lowercase", "lowercase_no_diacritics", "uppercase", "phcd_paper"]
# Number of repeated runs per subset (presumably to gauge run-to-run variance).
NUM_RUNS_PER_SUBSET = 20
model_name = "architecture-5"

for artifact_base_name in artifact_base_names:
    artifact_name = f"{artifact_base_name}_splits_tfds"  # e.g. "phcd_paper_splits_tfds"

    for _ in range(NUM_RUNS_PER_SUBSET):
        train_model(model_name, artifact_name, defaults)