# Dataset overview

In this notebook we review the class counts in the train and validation splits.

In [None]:
! pip install matplotlib wandb tensorflow

In [None]:
import wandb
import pathlib
import shutil
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


def load_data(run: "wandb.sdk.wandb_run.Run") -> "list[pathlib.Path]":
    """
    Fetches the ``letters_splits`` artifact and returns its split folders.

    The artifact is downloaded into ``./artifacts/<artifact name>`` (with
    ``:`` replaced by ``-``) and every ``*.tar.gz`` archive found there is
    unpacked into a folder named after the archive. When that folder already
    exists, downloading and unpacking are skipped.

    Args:
        run: Run on which the artifact is declared via ``run.use_artifact``.

    Returns:
        The paths ``[train, test, val]`` inside the artifact folder, in that
        order. The folders are not checked for existence.
    """
    archive_suffix = ".tar.gz"
    artifact_name = "letters_splits"
    artifact = run.use_artifact(f"master-thesis/{artifact_name}:latest")
    artifact_dir = pathlib.Path(
        f"./artifacts/{artifact.name.replace(':', '-')}"
    ).resolve()
    if not artifact_dir.exists():
        # Download into the very folder tested above; otherwise the existence
        # check and the download location can disagree and the cache never hits.
        artifact_dir = pathlib.Path(
            artifact.download(root=str(artifact_dir))
        ).resolve()
        # sorted() snapshots the listing, since unpacking adds new entries.
        for split_file in sorted(artifact_dir.iterdir()):
            if split_file.name.endswith(archive_suffix):
                # Strip only the trailing suffix to get the split name.
                split = split_file.name[: -len(archive_suffix)]
                shutil.unpack_archive(split_file, artifact_dir / split, format="gztar")

    return [artifact_dir / split for split in ["train", "test", "val"]]


In [None]:
run = wandb.init(project="master-thesis", job_type="preprocessing")
split_paths = load_data(run=run)

# load_data returns the folders in the order train, test, val;
# only train and val are inspected here.
train_dir, val_dir = split_paths[0], split_paths[2]

# Both splits are read with the same settings: 32x32 grayscale images.
loader_options = {"image_size": (32, 32), "color_mode": "grayscale"}

ds_train = tf.keras.utils.image_dataset_from_directory(train_dir, **loader_options)
ds_val = tf.keras.utils.image_dataset_from_directory(val_dir, **loader_options)

# The class list of the train split defines the number of classes.
number_of_classes = len(ds_train.class_names)

In [None]:
# calculate class count for each split
def _count_labels(dataset, num_classes):
    """
    Sums per-class label occurrences over all batches of ``dataset``.

    Iterates ``(images, labels)`` batches and accumulates
    ``tf.math.bincount`` of the labels into an array of length
    ``num_classes``.
    Assumes labels are integer class indices below ``num_classes`` -- TODO confirm.
    """
    counts = np.zeros(num_classes)
    for _, labels in dataset:
        counts += tf.math.bincount(labels, minlength=num_classes).numpy()
    return counts


train_class_count = _count_labels(ds_train, number_of_classes)
val_class_count = _count_labels(ds_val, number_of_classes)

# plot class count for each split
plt.bar(ds_train.class_names, train_class_count)
plt.title("Train")
plt.show()

# Give the validation plot its own title and show call, like the train plot.
plt.bar(ds_val.class_names, val_class_count)
plt.title("Validation")
plt.show()

In [None]:
# log class count for each split to wandb
# Each split is logged in its own wandb.log call, train first, then val.
# NOTE(review): wandb.Histogram given a plain sequence presumably bins the
# count values themselves rather than drawing one bar per class -- confirm
# this is the intended view.
for metric_name, class_counts in (
    ("train_class_count", train_class_count),
    ("val_class_count", val_class_count),
):
    wandb.log({metric_name: wandb.Histogram(class_counts)})
