In [None]:
import wandb
import tensorflow as tf
import numpy as np
import pathlib
import shutil

def load_data(run=None) -> "list[pathlib.Path]":
    """
    Unpacks the split archives of the ``letters_splits`` artifact and returns
    the folders they are extracted into.

    Args:
        run: W&B run used to fetch the artifact. When omitted, a run is started
            with ``wandb.init(project="master-thesis", job_type="preprocessing")``.

    Returns:
        The paths ``[train, test, val]`` (in that order) inside the downloaded
        artifact directory. A path is returned even when no matching archive
        was found, so it is not guaranteed to exist.
    """
    if run is None:
        # Created lazily: a wandb.init(...) call placed in the signature is
        # evaluated once, when the function is defined, not when it is called.
        run = wandb.init(project="master-thesis", job_type="preprocessing")

    artifact_name = "letters_splits"
    archive_suffix = ".tar.gz"

    artifact = run.use_artifact(f"master-thesis/{artifact_name}:latest")
    artifact_dir = pathlib.Path(artifact.download()).resolve()

    # Snapshot the listing first, because extraction adds entries to this directory.
    for split_file in sorted(artifact_dir.iterdir()):
        if split_file.name.endswith(archive_suffix):
            # "train.tar.gz" is extracted into "<artifact_dir>/train".
            split = split_file.name[: -len(archive_suffix)]
            shutil.unpack_archive(split_file, artifact_dir / split, format="gztar")

    return [artifact_dir / split for split in ("train", "test", "val")]

def get_number_of_classes(ds: tf.data.Dataset) -> int:
    """
    Counts the class labels attached to a dataset.

    Reads the ``class_names`` attribute of ``ds`` and returns its length, so
    ``ds`` must be an object that carries that attribute.
    """
    class_names = ds.class_names
    return len(class_names)

def create_tf_dataset(
    split_path: pathlib.Path,
    batch_size: int = 32,
    image_size: tuple = (32, 32),
    color_mode: str = "grayscale",
    shuffle: bool = True,
    seed=None,
) -> tf.data.Dataset:
    """
    Creates a tf dataset from path containing a folder for each class.

    Args:
        split_path: Directory passed to ``image_dataset_from_directory``.
        batch_size: Number of images per batch.
        image_size: ``(height, width)`` the images are resized to.
        color_mode: Colour mode forwarded to Keras (e.g. ``"grayscale"``, ``"rgb"``).
        shuffle: Whether Keras shuffles the data. Pass ``False`` for a fixed
            order, e.g. when evaluating.
        seed: Optional random seed forwarded to Keras for shuffling.

    Returns:
        The dataset built by ``tf.keras.utils.image_dataset_from_directory``.
    """
    # The defaults reproduce the previously hard-coded configuration
    # (32x32 grayscale, Keras' default shuffling, no seed).
    ds = tf.keras.utils.image_dataset_from_directory(
        split_path,
        image_size=image_size,
        batch_size=batch_size,
        color_mode=color_mode,
        shuffle=shuffle,
        seed=seed,
    )
    return ds

def preprocess_dataset(ds: tf.data.Dataset) -> tf.data.Dataset :
    """
    Normalizes a dataset of ``(x, y)`` pairs and prepares it for iteration.

    ``x`` is cast to float32 and divided by 255; ``y`` is passed through
    unchanged. The mapped dataset is then cached and prefetched.
    """
    def _normalize(x, y):
        scaled = tf.cast(x, tf.float32) / 255.0
        return scaled, y

    normalized = ds.map(_normalize)
    cached = normalized.cache()
    return cached.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
run = wandb.init(project="master-thesis", job_type="training")
BATCH_SIZE = 32

# load_data returns the split folders in the order train, test, val.
ds_train, ds_test, ds_val = [
    create_tf_dataset(split_path, batch_size=BATCH_SIZE) for split_path in load_data(run=run)
    ]

num_classes = get_number_of_classes(ds_train)

# len() of a batched dataset is its number of batches, not of single examples.
print(f"In training set there are {len(ds_train)} batches of up to {BATCH_SIZE} examples")
print(f"In validation set there are {len(ds_val)} batches of up to {BATCH_SIZE} examples")
print(f"In test set there are {len(ds_test)} batches of up to {BATCH_SIZE} examples")

print(f"There are {num_classes} classes")

In [None]:
# Normalize every split. NOTE: the names are rebound, so running this cell a
# second time would divide the pixel values by 255 again; re-create the
# datasets before re-running it.
ds_train = preprocess_dataset(ds_train)
ds_val = preprocess_dataset(ds_val)
ds_test = preprocess_dataset(ds_test)

# Small CNN: two conv/pool stages followed by a dropout-regularized classifier
# that outputs one probability per class.
model = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(input_shape=(32, 32, 1)),
        tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ]
)

model.compile(
    optimizer="adam",
    # The last layer already applies softmax, so the loss receives
    # probabilities; from_logits=True would apply softmax a second time.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["accuracy"],
)

history = model.fit(
    ds_train,
    epochs=10,
    validation_data=ds_val,
    #callbacks=[wandb.keras.WandbCallback()],
)

In [None]:
# plot history
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 1, figsize=(15, 10))
epochs = range(1, len(history.history["loss"]) + 1)

# One curve per metric; the history key doubles as the legend label.
for metric in ("accuracy", "val_accuracy"):
    ax.plot(epochs, history.history[metric], label=metric)

ax.set(xlabel="Epoch", ylabel="Accuracy")
ax.legend(loc="lower right")

plt.show()