In [1]:
# basic imports
import tensorflow as tf 
import torch
import io
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

from keras.utils import dataset_utils
import matplotlib.pyplot as plt

import baseline_config

In [None]:
def paths_and_labels_to_dataset(image_paths,labels,num_classes):
    """Build a dataset of (raw file bytes, label) pairs.

    Args:
        image_paths: Paths of the files to read, one per sample.
        labels: Class index of each file, in the same order as `image_paths`.
        num_classes: Number of classes; forwarded to
            `dataset_utils.labels_to_dataset`.

    Returns:
        A `tf.data.Dataset` that zips the contents of each file (as returned
        by `tf.io.read_file`, not decoded here) with its label.
    """
    path_ds = tf.data.Dataset.from_tensor_slices(image_paths)
    # Only the raw bytes are read at this stage; no decoding is attempted.
    bytes_ds = path_ds.map(tf.io.read_file, num_parallel_calls=tf.data.AUTOTUNE)
    # Spell the label mode out explicitly. "int" keeps the labels as plain
    # integer class indices (no one-hot / binary reshaping), which is what a
    # boolean in this position fell through to.
    label_ds = dataset_utils.labels_to_dataset(labels, "int", num_classes)
    return tf.data.Dataset.zip((bytes_ds, label_ds))

def create_dataset(subset):
    """Index one split directory and wrap it as a (file bytes, label) dataset.

    Args:
        subset: Sub-directory name appended to `baseline_config.dataset_path`
            (for example 'TRAIN/').

    Returns:
        A `(dataset, class_names)` tuple: the dataset produced by
        `paths_and_labels_to_dataset` and the class names reported by
        `dataset_utils.index_directory` for this directory.
    """
    image_paths, labels, class_names = dataset_utils.index_directory(
            baseline_config.dataset_path + subset,
            labels="inferred",
            # One-element tuple: the trailing comma matters, because
            # `('.pt')` is just the string '.pt'.
            formats=('.pt',),
            class_names=None,
            shuffle=False,
            seed=42,
            follow_links=True)

    dataset = paths_and_labels_to_dataset(
        image_paths=image_paths,
        labels=labels,
        num_classes=len(class_names))

    return dataset, class_names

# Load the three splits. Class indices are inferred per directory, so keep the
# class names of every split instead of discarding them.
train_dataset, class_names = create_dataset('TRAIN/')
test_dataset, test_class_names = create_dataset('TEST/')
validation_dataset, validation_class_names = create_dataset('VALIDATION/')

# If a split had a different set of class folders, the same integer label
# would mean different classes across splits; fail loudly instead.
assert test_class_names == class_names, (
    f"TEST classes {test_class_names} differ from TRAIN classes {class_names}")
assert validation_class_names == class_names, (
    f"VALIDATION classes {validation_class_names} differ from TRAIN classes {class_names}")

In [None]:
def dataset_tranforms(image,label):
    """Decode one serialized sample and prepare it for the model.

    Intended to be called eagerly (it uses `image.numpy()`), e.g. through
    `tf.py_function`.

    Args:
        image: Eager tensor whose `.numpy()` value is the raw bytes of a file
            written with `torch.save` (presumably a 2-D spectrogram tensor;
            TODO confirm against the data-preparation code).
        label: Class index of the sample.

    Returns:
        `(image, label)` where `image` is a float32 tensor with a trailing
        axis of size 1 appended, and `label` is cast to int64.
    """
    # NOTE(security): torch.load unpickles its input and can execute arbitrary
    # code. Only use it on files you produced or trust (consider
    # `weights_only=True` if the installed torch version supports it).
    # map_location keeps the load working for tensors saved from a GPU.
    loaded = torch.load(io.BytesIO(image.numpy()), map_location="cpu")
    # Hand TensorFlow a NumPy array rather than a torch.Tensor, so the
    # conversion does not depend on implicit cross-framework coercion and
    # also works for tensors that still carry autograd state.
    if isinstance(loaded, torch.Tensor):
        loaded = loaded.detach().cpu().numpy()
    image = tf.expand_dims(loaded, -1)
    # The py_function wrapper declares float32 output; cast explicitly so
    # files stored in another floating-point precision do not trip it.
    image = tf.cast(image, tf.float32)
    label = tf.cast(label, tf.int64)
    return image,label

def py_func_wrapper(x, y):
    """Run `dataset_tranforms` on one (x, y) element via `tf.py_function`.

    The wrapped call is declared to return a float32 tensor and an int64
    tensor, in that order.
    """
    return tf.py_function(
        func=dataset_tranforms,
        inp=[x, y],
        Tout=[tf.float32, tf.int64],
    )

def _decode_and_batch(dataset):
    """Map `py_func_wrapper` over `dataset`, then batch it with the configured size."""
    decoded = dataset.map(py_func_wrapper)
    return decoded.batch(baseline_config.batch_size)


# NOTE: each name is rebound to its transformed dataset, so running this cell
# twice would apply the mapping to already-decoded elements.
train_dataset = _decode_and_batch(train_dataset)
validation_dataset = _decode_and_batch(validation_dataset)
test_dataset = _decode_and_batch(test_dataset)

# Quick look at the first two training batches.
for melspec, label in train_dataset.take(2):
    print(melspec.shape, melspec, label)

In [None]:
# Small CNN classifier: two conv/pool stages followed by a softmax layer with
# one output per class.
model = keras.Sequential()

# Expected sample shape; presumably (time frames, mel bins, channel).
# TODO confirm against the feature-extraction settings.
model.add(keras.Input(shape=(313,128,1)))

# Convolutional feature extractor.
model.add(layers.Conv2D(32, kernel_size=(3, 3), activation="relu"))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
model.add(layers.Conv2D(64, kernel_size=(3, 3), activation="relu"))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))

# Classification head.
model.add(layers.Flatten())
model.add(layers.Dropout(0.5))
model.add(layers.Dense(len(class_names), activation="softmax"))

model.summary()

In [None]:
# NOTE(review): `epochs` is not used by the fit call below, which takes its
# epoch count from baseline_config.max_epoch. Kept in case other cells read it.
epochs = 15

# Labels are integer class indices and the last layer applies softmax, hence
# the sparse categorical cross-entropy on probabilities.
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(), 
              optimizer="adam", 
              metrics=["accuracy"], 
              run_eagerly=True)

# The datasets are already batched with baseline_config.batch_size, so no
# batch_size is passed here: Keras documents that it must not be specified
# when the input is a dataset.
model.fit(train_dataset, 
          validation_data=validation_dataset, 
          epochs=baseline_config.max_epoch)