In [19]:
import pathlib
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import wandb
from sklearn.preprocessing import OneHotEncoder

def load_label_dict(run=None) -> Dict[int, str]:
    """
    Load the label-names artifact and map each class index to its label.

    Args:
        run: W&B run used to fetch the artifact. When omitted, the currently
            active run (``wandb.run``) is reused, or a new "preprocessing"
            run is started if none is active.

    Returns:
        A dict ``{index: label}`` where ``index`` is the label's position in
        ``labels.npy``.
    """
    # Resolve the run at call time. A ``wandb.init(...)`` default argument is
    # evaluated once at definition time, which starts a run as a side effect
    # of merely defining this function.
    if run is None:
        run = wandb.run if wandb.run is not None else wandb.init(project="master-thesis", job_type="preprocessing")

    artifact = run.use_artifact("master-thesis/letters_labels:latest")
    artifact_dir = pathlib.Path(artifact.download())
    # NOTE(review): allow_pickle=True can execute pickled code on load; this is
    # acceptable only because the artifact comes from our own pipeline.
    labels = np.load(artifact_dir / "labels.npy", allow_pickle=True)
    return dict(enumerate(labels))


def load_data(split: str = "train", run=None) -> Tuple[np.ndarray, List[int]]:
    """
    Download one split of the letters dataset and return its examples and labels.

    The split artifact holds one ``<class_index>.npz`` file per class (the file
    name is the integer class index). Each file stores that class's examples
    under the ``arr_0`` key.

    Args:
        split: Which split to load; one of ``"train"``, ``"test"`` or ``"val"``.
        run: W&B run used to fetch the artifact. When omitted, the currently
            active run (``wandb.run``) is reused, or a new "preprocessing"
            run is started if none is active.

    Returns:
        ``(data, labels)``. ``data`` is all examples concatenated along axis 0.
        ``labels`` is a list with ``labels[i]`` being the class index of
        ``data[i]``, one entry per example.

    Raises:
        ValueError: If ``split`` is not recognised, or the artifact contains no
            ``.npz`` files.
    """
    if split not in ["train", "test", "val"]:
        raise ValueError("Split must be either train, test or val")

    # Resolve the run at call time instead of via a definition-time default.
    if run is None:
        run = wandb.run if wandb.run is not None else wandb.init(project="master-thesis", job_type="preprocessing")

    artifact = run.use_artifact(f"master-thesis/letters_{split}:latest")
    artifact_dir = pathlib.Path(artifact.download()).resolve()
    return _load_class_files(artifact_dir)


def _load_class_files(directory: pathlib.Path) -> Tuple[np.ndarray, List[int]]:
    """Read every ``<class_index>.npz`` in ``directory`` into (data, per-example labels)."""
    # Sort numerically so the output order is deterministic; iterdir()/glob()
    # order depends on the filesystem.
    class_files = sorted(directory.glob("*.npz"), key=lambda p: int(p.stem))
    if not class_files:
        raise ValueError(f"No .npz files found in {directory}")

    chunks = []
    labels: List[int] = []
    for path in class_files:
        class_index = int(path.stem)
        # The context manager closes the NpzFile handle. Indexing reads the
        # array fully into memory, so it stays valid after the file is closed.
        # NOTE(review): allow_pickle=True is only safe for trusted artifacts.
        with np.load(path, allow_pickle=True) as archive:
            examples = archive["arr_0"]
        chunks.append(examples)
        # One label per example (not per file), so labels line up with data rows.
        labels.extend([class_index] * len(examples))
    return np.concatenate(chunks), labels

def create_tf_dataset(data, labels, number_of_classes, batch_size: int = 32, shuffle: bool = True, normalize: bool = True, shuffle_buffer_size: int = 10000):
    """
    Build a batched ``tf.data`` pipeline from in-memory examples and labels.

    Args:
        data: Array-like of examples; row ``i`` pairs with ``labels[i]``.
        labels: Array-like of labels with the same length as ``data``.
        number_of_classes: Currently unused; kept so existing callers still work.
        batch_size: Number of examples per batch.
        shuffle: If True, shuffle examples (reshuffled every epoch).
        normalize: If True, cast examples to float32 and divide by 255.
        shuffle_buffer_size: Shuffle buffer size. If the examples are stored
            grouped by class, a buffer close to the dataset size is needed to
            mix classes within a batch.

    Returns:
        A ``tf.data.Dataset`` yielding ``(x_batch, y_batch)`` tuples.

    Raises:
        ValueError: If ``data`` and ``labels`` have different lengths.
    """
    # Zipping datasets silently truncates to the shorter input, which would
    # hide a data/label mismatch, so fail loudly instead.
    if len(data) != len(labels):
        raise ValueError(f"data has {len(data)} examples but labels has {len(labels)}")

    df = tf.data.Dataset.from_tensor_slices((data, labels))
    if normalize:
        df = df.map(
            lambda x, y: (tf.cast(x, tf.float32) / 255.0, y),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )
    # Cache before shuffling. Caching after shuffle replays the first epoch's
    # order on every epoch, and caching after prefetch defeats the prefetch.
    df = df.cache()
    if shuffle:
        df = df.shuffle(buffer_size=shuffle_buffer_size, reshuffle_each_iteration=True)
    return df.batch(batch_size).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

def calculate_number_of_classes(labels: List[str]) -> int:
    """
    Count how many distinct values appear in ``labels``.

    Any hashable elements work; repeated values are counted once.
    """
    seen = set()
    for value in labels:
        seen.add(value)
    return len(seen)

In [21]:
run = wandb.init(project="master-thesis", job_type="training")

labels_dict = load_label_dict(run=run)
train_data, train_labels = load_data("train", run=run)
val_data, val_labels = load_data("val", run=run)
number_of_classes = calculate_number_of_classes(train_labels)
print(f"Number of classes: {number_of_classes}")

print(f"In training set there are {len(train_data)} examples")
print(f"In validation set there are {len(val_data)} examples")

# Keep targets as integer class indices. The model is trained with
# SparseCategoricalCrossentropy, which expects indices rather than one-hot rows.
# This also treats train and val identically: one-hot encoding label *names*
# for train only left val (raw ints) unmatched by the fitted encoder.
train_labels = np.asarray(train_labels, dtype=np.int64)
val_labels = np.asarray(val_labels, dtype=np.int64)

# Sanity checks: one label per example, every index names a known class, and
# every index fits the model's output layer of size number_of_classes.
assert len(train_labels) == len(train_data), "train labels/data length mismatch"
assert len(val_labels) == len(val_data), "val labels/data length mismatch"
assert set(train_labels.tolist()) | set(val_labels.tolist()) <= set(labels_dict), "label index missing from labels_dict"
assert max(train_labels.max(), val_labels.max()) < number_of_classes, "label index exceeds number_of_classes"

[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   35 of 35 files downloaded.  
[34m[1mwandb[0m:   35 of 35 files downloaded.  


Number of classes: 35
In training set there are 176116 examples
In validation set there are 22001 examples


In [None]:
batch_size = 128
EPOCHS = 10
SEED = 42
# Seed TF's global RNG so weight initialisation and shuffling are reproducible.
tf.random.set_seed(SEED)

df_train = create_tf_dataset(train_data, train_labels, number_of_classes, batch_size=batch_size)
# Shuffling does not change validation metrics, so skip it for the val set.
df_val = create_tf_dataset(val_data, val_labels, number_of_classes, batch_size=batch_size, shuffle=False)


# NOTE(review): assumes each example is 32x32 with one channel. If the .npz
# arrays are (N, 32, 32) without a channel axis, add one (e.g. Reshape) — confirm.
model = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(input_shape=(32, 32, 1)),
        tf.keras.layers.Conv2D(32, (3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(number_of_classes),
    ]
)

# Match the loss to the label encoding. Integer class indices (1-D) need the
# sparse loss; one-hot rows (2-D) need the dense categorical loss.
if np.ndim(train_labels) == 2:
    loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
else:
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(
    optimizer="adam",
    loss=loss,
    metrics=["accuracy"],
)

history = model.fit(
    df_train,
    epochs=EPOCHS,
    validation_data=df_val,
    callbacks=[wandb.keras.WandbCallback()],
)

In [None]:
# Plot training vs. validation accuracy per epoch.
train_acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]

fig, ax = plt.subplots()
ax.plot(train_acc, label="accuracy")
ax.plot(val_acc, label="val_accuracy")
ax.set(title="Training vs. validation accuracy", xlabel="Epoch", ylabel="Accuracy")
# Keep the 0.5 floor for readability, but lower it if it would hide observed values.
ax.set_ylim([min(0.5, min(train_acc + val_acc)), 1])
ax.legend(loc="lower right")

plt.show()