In [15]:
import wandb
import tensorflow as tf
import numpy as np
import pathlib
import shutil
from typing import List

def load_data(run: wandb.sdk.wandb_run.Run) -> List[tf.data.Dataset]:
    """
    Downloads datasets from a wandb artifact and loads them into a list of tf.data.Datasets.
    """

    artifact_name = f"letters_splits_tfds"
    artifact = run.use_artifact(f"master-thesis/{artifact_name}:latest")
    artifact_dir = pathlib.Path(
        f"./artifacts/{artifact.name.replace(':', '-')}"
    ).resolve()
    if not artifact_dir.exists():
        artifact_dir = artifact.download()
        artifact_dir = pathlib.Path(artifact_dir).resolve()
    
    output_list = []
    for split in ["train", "test", "val"]:
        ds = tf.data.Dataset.load(str(artifact_dir / split), compression="GZIP")
        output_list.append(ds)
    
    return output_list

def get_number_of_classes(ds: tf.data.Dataset) -> int:
    """
    Returns the number of classes in a dataset.
    """
    labels_iterator= ds.map(lambda x, y: y).as_numpy_iterator()
    labels = np.concatenate(list(labels_iterator))
    return len(np.unique(labels))

def preprocess_dataset(ds: tf.data.Dataset) -> tf.data.Dataset :
    ds = ds.map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y)) # normalize
    return ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE)

In [16]:
run = wandb.init(project="master-thesis", job_type="training")

[34m[1mwandb[0m:   3 of 3 files downloaded.  


In [17]:
BATCH_SIZE = 32
ds_train, ds_test, ds_val = load_data(run)

num_classes = get_number_of_classes(ds_val)

print(f"There are {num_classes} classes")
print(f"Training set has {len(ds_train)} batches")
print(f"Test set has {len(ds_test)} batches")
print(f"Validation set has {len(ds_val)} batches")

Found 176116 files belonging to 35 classes.
Found 22001 files belonging to 35 classes.
Found 22001 files belonging to 35 classes.
There are 35 classes
Training set has 176116 batches
Test set has 22001 batches
Validation set has 22001 batches


In [18]:
ds_train = preprocess_dataset(ds_train, batch_size=bs)
ds_val = preprocess_dataset(ds_val, batch_size=bs)
ds_test = preprocess_dataset(ds_test, batch_size=bs)

model = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(input_shape=(32, 32, 1)),
        tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ]
)

model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

history = model.fit(
    ds_train,
    epochs=50,
    validation_data=ds_val,
    #callbacks=[wandb.keras.WandbCallback()],
)

SyntaxError: 'break' outside loop (3134801564.py, line 5)

In [None]:
# plot history
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 1, figsize=(15, 10))
epochs = range(1, len(history.history["loss"]) + 1)
ax.plot(epochs, history.history["accuracy"], label="accuracy")
ax.plot(epochs, history.history["val_accuracy"], label="val_accuracy")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.legend(loc="lower right")

plt.show()