In [None]:
from transformer_hyperparam import get_compiled_transformer

In [None]:
import os
import random
import warnings
from copy import deepcopy

import numpy as np
import pandas as pd
from tensorflow import keras

class ASLDataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that serves batches built from pickled sample files.

    Every entry of ``dataset_path`` is loaded with ``pandas.read_pickle`` and
    treated as one sample: the ``phrase`` value of the first row is the target
    token sequence, and all columns except ``unwanted_columns`` form the
    feature matrix (presumably one row per video frame -- confirm against the
    preprocessing step).

    Phrases are right-padded with ``pad_token`` to ``max_phrase_length`` and
    feature matrices are edge-padded along axis 0 to ``max_sign_length``.
    Samples exceeding either limit are truncated and a warning is emitted, so
    every batch stacks into rectangular arrays.

    Args:
        dataset_path: Directory containing one pickle file per sample.
        max_phrase_length: Length every target phrase is padded/truncated to.
        max_sign_length: Number of rows every feature matrix is padded/truncated to.
        batch_size: Number of samples per batch; must be positive.
        train: When True the file order is reshuffled after every epoch.
        pad_token: Token id appended to short phrases (default 59).
        start_token: Token id prepended to build the decoder context (default 60).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, dataset_path: str, max_phrase_length: int, max_sign_length: int, batch_size: int, train: bool = True,
                 pad_token: int = 59, start_token: int = 60):
        super().__init__()
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}')
        self.ds_path = dataset_path
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # validation/test batches reproducible between runs.
        self.files = sorted(os.listdir(dataset_path))
        self.max_phrase_length = max_phrase_length
        self.max_sign_length = max_sign_length
        self.unwanted_columns = ['sequence_id', 'frame', 'participant_id', 'phrase']
        self.train = train
        self.batch_size = batch_size
        self.pad_token = pad_token
        self.start_token = start_token

    def __len__(self):
        """Number of full batches; a trailing partial batch is dropped."""
        return len(self.files)//self.batch_size

    def shuffle(self):
        """Shuffle the file order in place using the global ``random`` state."""
        random.shuffle(self.files)

    def on_epoch_end(self):
        """Reshuffle the files between epochs, but only in training mode."""
        if self.train:
            self.shuffle()

    def __getitem__(self, idx: int):
        """Return batch number ``idx`` as ``([signs, contexts], phrases)``.

        Batch ``idx`` covers files ``idx * batch_size`` up to (excluding)
        ``(idx + 1) * batch_size``.

        Raises:
            IndexError: If ``idx`` is outside ``range(len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f'batch index {idx} out of range for {len(self)} batches')

        first = idx * self.batch_size
        batch_signs = []
        batch_contexts = []
        batch_phrases = []
        for filename in self.files[first:first + self.batch_size]:
            signs, context, phrase = self._load_example(filename)
            batch_signs.append(signs)
            batch_contexts.append(context)
            batch_phrases.append(phrase)

        batch_signs = np.array(batch_signs)
        batch_contexts = np.array(batch_contexts)
        batch_phrases = np.array(batch_phrases)

        return [batch_signs, batch_contexts], batch_phrases

    def _load_example(self, filename: str):
        """Load one pickle file and return ``(signs, context, phrase)`` at fixed sizes."""
        # NOTE: unpickling can execute arbitrary code -- only point the
        # generator at dataset directories you trust.
        sample = pd.read_pickle(os.path.join(self.ds_path, filename))

        # Copy into a fresh list so padding never mutates the loaded frame.
        phrase = self._fit_phrase(list(sample.iloc[0].phrase), filename)
        # Decoder input: the target shifted right by one, led by the start token.
        context = ([self.start_token] + phrase)[:-1]

        signs = sample.drop(self.unwanted_columns, axis=1).to_numpy(copy=True)
        signs = self._fit_signs(signs, filename)
        return signs, context, phrase

    def _fit_phrase(self, phrase: list, filename: str) -> list:
        """Pad ``phrase`` with ``pad_token`` or truncate it to ``max_phrase_length``."""
        if len(phrase) > self.max_phrase_length:
            warnings.warn(
                f'{filename}: phrase of length {len(phrase)} exceeds '
                f'max_phrase_length={self.max_phrase_length}; truncating.'
            )
            return phrase[:self.max_phrase_length]
        return phrase + [self.pad_token] * (self.max_phrase_length - len(phrase))

    def _fit_signs(self, signs: np.ndarray, filename: str) -> np.ndarray:
        """Edge-pad ``signs`` along axis 0 or truncate it to ``max_sign_length`` rows."""
        n_rows = signs.shape[0]
        if n_rows > self.max_sign_length:
            # np.pad rejects negative widths, so over-long samples must be cut.
            warnings.warn(
                f'{filename}: {n_rows} rows exceed '
                f'max_sign_length={self.max_sign_length}; truncating.'
            )
            return signs[:self.max_sign_length]
        # 'edge' repeats the last row rather than inserting zeros.
        return np.pad(signs, [(0, self.max_sign_length - n_rows), (0, 0)], 'edge')

In [None]:
# Fixed sequence sizes shared by all three generators, plus the batch size.
MAX_PHRASE_LENGTH = 100
MAX_SIGN_LENGTH = 900
BATCH_SIZE = 32

# Training split: reshuffled after every epoch, and once up front so the
# first epoch does not see the files in directory order.
data_gen_train = ASLDataGenerator(
    './dataset complete/preprocessed_files_data_generator/train_ones',
    MAX_PHRASE_LENGTH,
    MAX_SIGN_LENGTH,
    batch_size=BATCH_SIZE,
)
data_gen_train.shuffle()

# Held-out splits: train=False keeps their file order fixed between epochs.
data_gen_test = ASLDataGenerator(
    './dataset complete/preprocessed_files_data_generator/test_ones',
    MAX_PHRASE_LENGTH,
    MAX_SIGN_LENGTH,
    batch_size=BATCH_SIZE,
    train=False,
)

data_gen_val = ASLDataGenerator(
    './dataset complete/preprocessed_files_data_generator/val_ones',
    MAX_PHRASE_LENGTH,
    MAX_SIGN_LENGTH,
    batch_size=BATCH_SIZE,
    train=False,
)

In [None]:
import optuna
import json

# Number of training epochs per Optuna trial.
epochs = 5

def objective(trial):
    """Optuna objective: train one transformer configuration and return its loss.

    Samples the architecture hyperparameters from ``trial``, builds the model
    with ``get_compiled_transformer``, trains it on ``data_gen_train`` while
    reporting ``val_loss`` to Optuna for pruning, and returns the loss measured
    on ``data_gen_test``.

    Args:
        trial: The ``optuna.trial.Trial`` supplying hyperparameter suggestions.

    Returns:
        The evaluation loss on ``data_gen_test`` as a float (lower is better).
    """
    d_model = trial.suggest_int('d_model', 20, 256)
    num_layers = trial.suggest_int('num_layers', 1, 10)
    num_heads = trial.suggest_int('num_heads', 2, 10)
    ff_dim = trial.suggest_int('ff_dim', 32, 1024)
    # NOTE(review): the parameter key is misspelled ('droupout_rate'); it is
    # kept unchanged so reported params stay comparable with earlier runs.
    dropout_rate = trial.suggest_float('droupout_rate', 0., 0.6)

    with open("./dataset complete/character_to_prediction_index.json", "r") as f:
        characters = json.load(f)

    # +2 leaves room for two extra token ids beyond the character table
    # (presumably the padding and start tokens used by the generator -- confirm).
    output_vocab_size = len(characters) + 2

    # The compiled model must be kept: it is what gets trained and evaluated.
    transformer = get_compiled_transformer(d_model, num_layers, num_heads, ff_dim, dropout_rate, output_vocab_size)

    callbacks = [
        # Reports val_loss after each epoch and raises TrialPruned when the
        # pruner decides the trial is not promising.
        optuna.integration.TFKerasPruningCallback(trial, 'val_loss'),
    ]

    # The generators already yield complete batches, so batch_size is not
    # passed (Keras documents that it must not be given for Sequence inputs).
    transformer.fit(data_gen_train, epochs=epochs, validation_data=data_gen_val, callbacks=callbacks)

    eval_results = transformer.evaluate(data_gen_test)
    # evaluate() returns a bare scalar when the model has no extra metrics and
    # a list with the loss first otherwise.
    if isinstance(eval_results, (list, tuple)):
        test_loss = float(eval_results[0])
    else:
        test_loss = float(eval_results)
    print(f'Result:{test_loss }')

    # NOTE(review): optimising on the test split leaks it into model selection;
    # consider returning the validation loss instead.
    return test_loss

In [None]:
# Run the hyperparameter search. Both the pruner and the sampler must be
# instances; passing the TPESampler class itself fails once sampling starts.
study = optuna.create_study(
    direction='minimize',
    pruner=optuna.pruners.SuccessiveHalvingPruner(),
    sampler=optuna.samplers.TPESampler(),
)

study.optimize(objective, n_trials=100)

pruned_trials = study.get_trials(states=[optuna.trial.TrialState.PRUNED])
complete_trials = study.get_trials(states=[optuna.trial.TrialState.COMPLETE])

print('Study statistics:')
print('   Number of finished Trials: ', len(study.trials))
print('   Number of pruned Trials: ', len(pruned_trials))
print('   Number of complete Trials: ', len(complete_trials))

# study.best_trial raises ValueError when no trial completed (e.g. every
# trial was pruned or failed), so only report it when one exists.
if complete_trials:
    print('Best Trial: ')
    trial = study.best_trial

    # NOTE(review): the value is whatever loss the model was compiled with;
    # the 'MSE' label may not match -- confirm against get_compiled_transformer.
    print('  MSE: ', trial.value)

    print('   Params: ')
    for key, value in trial.params.items():
        print('   {}: {}'.format(key, value))
else:
    print('No trial completed; there is no best trial to report.')