In [None]:
import os
import sys
from pathlib import Path
from dotenv import load_dotenv

# Repository root = grandparent of the working directory (the notebook is
# presumably launched from two levels below the repo root — confirm).
root = Path.cwd().parents[1].resolve()
# Make project packages (e.g. `core`) importable. Only prepend when not already
# first, so re-running the cell does not keep stacking duplicate entries.
if sys.path[:1] != [str(root)]:
    sys.path.insert(0, str(root))

# Not wrapped in `assert`: asserts are stripped under `python -O`, which would
# silently skip loading the environment. load_dotenv returns False when no
# variable was set (e.g. the file is missing).
# NOTE(review): ".env" is resolved relative to the working directory, not `root` — confirm intended.
if not load_dotenv(".env"):
    raise RuntimeError("could not load environment variables from .env")

In [None]:
from tensorflow import dtypes, one_hot, io, image, data
from core.utils.config import BaselineConfig, emotion
from core.models.baseline import DenseNet
from core.utils.config import BaselineConfig

In [None]:
# Baseline model hyper-parameters from the project config.
densenet_config = BaselineConfig()

# Input frame geometry as (height, width, channels); process_image resizes to
# the first two entries.
frame_shape = tuple(
    getattr(densenet_config, attr_name)
    for attr_name in ("image_height", "image_width", "image_channels")
)

In [None]:
model = DenseNet(densenet_config)

# str(): older TF/Keras versions inspect the filepath with string methods
# (e.g. `.endswith(".h5")`) to detect HDF5 and fail on pathlib.Path objects.
model.load_weights(str(root.joinpath("models/baseline/densenet121_rgb.h5")))

# A bare `compile()` (unless DenseNet overrides it with defaults) leaves
# `evaluate` with no loss and no metrics, so the evaluation reports nothing useful.
# NOTE(review): categorical cross-entropy assumes the model outputs softmax
# probabilities over the 7 one-hot classes built by generate_tf_dataset — confirm.
model.compile(loss="categorical_crossentropy", metrics=["accuracy"])

In [None]:
def process_image(path):
    """Load one JPEG from disk as a float32 image tensor scaled to [0, 1].

    Intended to be mapped over a ``tf.data.Dataset`` of file paths.

    Args:
        path (Tensor): scalar string tensor holding the image file path.

    Returns:
        Tensor: float32 image of shape ``(frame_shape[0], frame_shape[1], 3)``
        with values in ``[0, 1]``.
    """
    img = io.read_file(path)
    # channels=3 always yields a 3-channel (RGB) uint8 tensor.
    img = io.decode_jpeg(img, channels=3)
    # convert_image_dtype already rescales uint8 [0, 255] -> float32 [0, 1].
    # The former extra `img / 255.0` squeezed every pixel into [0, 1/255].
    # NOTE(review): this must match the preprocessing the weights were trained
    # with — if the training pipeline also divided twice, fix it there too.
    img = image.convert_image_dtype(img, dtypes.float32)
    # Bilinear resize (the default) keeps values inside [0, 1].
    img = image.resize(img, frame_shape[:2])

    return img


def generate_tf_dataset(path: os.PathLike, num_classes: int = 7):
    """Build an ``(image, one-hot label)`` dataset from a folder tree of JPEGs.

    Every ``*.jpg`` under ``path`` (searched recursively) is one example; its
    label is looked up in ``emotion`` using the name of the image's parent folder.

    Args:
        path: root directory to search; ``str`` and ``pathlib.Path`` both work.
        num_classes: depth of the one-hot label vectors (default 7).

    Returns:
        tf.data.Dataset of ``(image, label)`` pairs: ``image`` comes from
        ``process_image`` and ``label`` is a uint8 one-hot vector.

    Raises:
        ValueError: if no ``*.jpg`` file exists under ``path``.
    """
    search_root = Path(path)  # accept any path-like, not only pathlib.Path
    # sorted() makes example order independent of filesystem listing order.
    image_paths = sorted(search_root.rglob("*.jpg"))
    if not image_paths:
        # An empty list would become a float32 tensor and fail obscurely
        # inside io.read_file when the map function is traced.
        raise ValueError(f"no *.jpg images found under {search_root}")

    filenames = [str(img_path) for img_path in image_paths]
    labels = [emotion[img_path.parent.name] for img_path in image_paths]

    def _to_example(filename, label):
        # Decode the image and one-hot encode its integer label.
        encoded = one_hot(indices=label, depth=num_classes, dtype=dtypes.uint8)
        return process_image(filename), encoded

    return data.Dataset.from_tensor_slices((filenames, labels)).map(_to_example)

In [None]:
# Evaluation split: every *.jpg under test/faces, labelled by its parent folder name.
data_dir = root / "data" / "affectnet"
assert data_dir.exists()

validation_dataset = generate_tf_dataset(data_dir / "test" / "faces")
validation_dataset = validation_dataset.batch(batch_size=32).prefetch(buffer_size=2)

In [None]:
# Sanity check on a single batch: image tensor shape, then label tensor shape.
for batch, label in validation_dataset.take(1):
    print(batch.shape, label.shape, sep="\n")

In [None]:
# return_dict=True labels each value with its metric name instead of a bare list/scalar.
model.evaluate(validation_dataset, return_dict=True)