In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import mixed_precision
import tensorflow.keras.backend as K
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping,  ReduceLROnPlateau,  TensorBoard
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from configs import Configs as C
from data_processing import create_dataset, calculate_spectrogram_dimensions
from model import build_CNN
import datetime
# Instantiate the project configuration object. Later cells read the batch
# sizes, spectrogram settings, target rate, class count, learning rate and
# epoch count from it.
c = C()

In [None]:
# Clear the Keras session first, in case VRAM is still held for some reason.
K.clear_session()

# Switch the global dtype policy to mixed precision for more efficient compute.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

# Enable dynamic (on-demand) VRAM allocation on every visible GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    # With no GPUs the list is empty and the loop body simply never runs.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Report the failure and carry on instead of aborting the notebook.
    print(e)

In [3]:
# functions for getting .wav file paths for different datasets
def txt_file_path_loader(file_path, base_dir):
    """Read a split-list text file and return one prefixed path per entry.

    Args:
        file_path: Path of a text file holding one relative path per line.
        base_dir: String prepended verbatim to every entry (plain string
            concatenation, so it must already end with a path separator).

    Returns:
        A list of ``base_dir + entry`` strings in file order. Surrounding
        whitespace is stripped from each entry, and blank or whitespace-only
        lines are skipped so they cannot produce a bare ``base_dir`` entry.
    """
    with open(file_path, 'r') as f:
        # Iterate the file lazily; there is no need to materialise all lines.
        stripped_lines = (line.strip() for line in f)
        return [base_dir + entry for entry in stripped_lines if entry]

def get_all_wav_file_paths(data_dir):
    """Recursively collect every file under ``data_dir`` whose name ends in ".wav".

    Args:
        data_dir: Directory to walk.

    Returns:
        A list of path strings, each built as ``<directory> + '/' + <file name>``
        in the order ``os.walk`` visits them. The extension match is
        case-sensitive.
    """
    return [
        folder + '/' + name
        for folder, _, names in os.walk(data_dir)
        for name in names
        if name.endswith(".wav")
    ]

In [4]:
# Build the testing and validation path lists from their split files.
test_file_path = './data/testing_list.txt'
val_file_path = './data/validation_list.txt'
base_dir = './data/'

test_paths = txt_file_path_loader(test_file_path, base_dir)
val_paths = txt_file_path_loader(val_file_path, base_dir)
all_paths = get_all_wav_file_paths(base_dir)

# Training files are every .wav found under base_dir that appears in neither
# the validation list nor the testing list.
held_out_paths = set(val_paths) | set(test_paths)
# sorted() instead of list(): iteration order of a set of strings can differ
# between interpreter runs (string hash randomization), which would make the
# order of the training files non-reproducible.
train_paths = sorted(set(all_paths) - held_out_paths)

In [5]:
# batch sizes
TRAIN_BATCH_SIZE = c.train_batch_size
TEST_BATCH_SIZE = c.test_batch_size
VAL_BATCH_SIZE = c.val_batch_size

# short aliases for readability
FRAME_LENGTH = c.spectrogram_configs['frame_length']
FRAME_STEP = c.spectrogram_configs['frame_step']
TARGET_RATE = c.target_rate

# creating datasets
train_dataset = create_dataset(train_paths, TRAIN_BATCH_SIZE, TARGET_RATE, FRAME_LENGTH, FRAME_STEP)
# Validation must be built from the validation list, not the testing list:
# val_loss drives checkpointing and early stopping, so feeding it the test
# files would leak the held-out split into model selection.
val_dataset = create_dataset(val_paths, VAL_BATCH_SIZE, TARGET_RATE, FRAME_LENGTH, FRAME_STEP)
# testing dataset may be added later

In [6]:
# Disabled debugging aid: show every input of the first training batch as an image.
# counter = 1
# for t, l in train_dataset.take(1):
#     for z in t:
#         print(z.shape)
#         plt.figure()
#         plt.imshow(z.numpy())
#     print('end of batch', counter)
#     counter+= 1

In [None]:
# model details
ACTIVATION = 'relu' #maybe try leaky relu
NUM_CLASSES = c.num_classes
# NOTE(review): the leading 1 is presumably the clip duration in seconds —
# confirm against calculate_spectrogram_dimensions.
height, width = calculate_spectrogram_dimensions(1, TARGET_RATE, FRAME_LENGTH, FRAME_STEP)
INPUT_SHAPE = (height, width, 1)
print(INPUT_SHAPE)
LR = c.learning_rate

# Callback patience, in epochs without val_loss improvement.
# The learning-rate reduction must trigger BEFORE early stopping; with the
# previous values (stop after 3, reduce after 5) training was stopped before
# the learning rate could ever be lowered.
EARLY_STOPPING_PATIENCE = 3
REDUCE_LR_PATIENCE = 2

# build model
model = build_CNN(INPUT_SHAPE, NUM_CLASSES, ACTIVATION)
model.summary()
# Compile the model
model.compile(
    optimizer=Adam(learning_rate=LR, clipnorm=1.0),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Callbacks for selecting the best model and early stopping if more training does nothing
# NOTE(review): the checkpoint path has no file extension; newer Keras releases
# reportedly require `.keras` or `.h5` — confirm against the installed version.
checkpoint = ModelCheckpoint('OCR model', monitor='val_loss', save_best_only=True, verbose=1)
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=EARLY_STOPPING_PATIENCE,
    restore_best_weights=True,
    verbose=1,
)
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=REDUCE_LR_PATIENCE,
    min_lr=1e-6,
)
# One timestamped TensorBoard log directory per run.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)

callbacks = [
    checkpoint,
    early_stopping,
    reduce_lr,
    tensorboard_callback
]

In [None]:
# Train for at most EPOCHS epochs, validating on val_dataset after each one;
# the returned History object is kept in `history`.
EPOCHS = c.epochs
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1,
)

In [None]:
# Save the trained model under the name 'command classification model'.
# NOTE(review): the path has no file extension; newer Keras releases reportedly
# require `.keras` or `.h5` here — confirm against the installed TensorFlow version.
model.save('command classification model')