In [None]:
import sys, os
from models import FCN, WarpingLayer, input_layer
from preprocessing import ConstantLengthDataGenerator
import numpy as np
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import sklearn
from sklearn.model_selection import train_test_split

In [None]:
# Load the feature array X and the label array y from the data directory.
data_path = "./data"
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects stored
# in the file, which can execute code - only load X.npy from a trusted source.
X = np.load(f"{data_path}/X.npy", allow_pickle=True)
y = np.load(f"{data_path}/y.npy")
# Show both shapes as the cell output.
y.shape, X.shape

In [None]:
# Keep only the samples whose label string starts with "GunPoint".
# The element-wise prefix test is flattened to a 1-D boolean mask; the
# `y[mask, :]` indexing below requires y to be 2-D.
mask = np.char.startswith(y, "GunPoint").ravel()
# Apply the same row mask to features and labels so they stay aligned.
X, y = X[mask], y[mask, :]

In [None]:
# One-hot encode the labels; the fitted encoder is kept in `y_encoder`.
y_encoder = sklearn.preprocessing.OneHotEncoder(categories="auto")
# The encoder expects a single column, so y is reshaped to (n, 1) first;
# `.toarray()` converts the encoder's sparse output to a dense array.
y = y_encoder.fit_transform(np.reshape(y, (-1, 1))).toarray()
# Show both shapes as the cell output.
y.shape, X.shape

In [None]:
# Training configuration.
initial_learning_rate = 0.0001
# One column per class in the one-hot encoded label matrix.
number_of_classes = y.shape[1]
# Never request a batch larger than the number of available samples.
batch_size = min(64, len(X))

# Directory for model outputs; created up front, no error if it already exists.
output_directory = f"{data_path}/models/fcn_warping/outputs"
os.makedirs(output_directory, exist_ok=True)

In [None]:
# Build the network: input -> WarpingLayer(256) -> FCN classifier.
# NOTE: this rebinds `input_layer`, shadowing the name imported from `models`.
# The input has an unspecified size along its first non-batch axis and size 1
# along the last one.
input_layer = keras.layers.Input(shape=(None, 1))
warping_layer = WarpingLayer(256)(input_layer)
# Use `FCN` (the name actually imported from `models`); `FCN_model` was never
# defined and raised NameError on a fresh kernel.
# Assumes FCN(n) returns a callable that is applied to a tensor, as the
# original call shape implies - TODO confirm against models.py.
fcn_model = FCN(number_of_classes)(warping_layer)

# `model` was previously never created before being compiled, summarised and
# fitted; wrap the graph from the input to the FCN output in a Keras model.
model = keras.Model(inputs=input_layer, outputs=fcn_model)

# ExponentialDecay computes initial_learning_rate * decay_rate ** (step / decay_steps);
# with decay_rate=1 this is a constant learning rate. Lower decay_rate to enable decay.
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate, decay_steps=3, decay_rate=1
)
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.Adam(lr_schedule),
    metrics=["accuracy"],
    # Debugging aid: pass run_eagerly=True here to step through the model eagerly.
)

In [None]:
# Print the Keras summary of `model` for a quick sanity check.
model.summary()

In [None]:
# Hold out 25% of the samples for validation. A fixed random_state makes the
# split (and therefore the reported validation metrics) reproducible across
# Restart-and-Run-All executions.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.25, random_state=42
)

In [None]:
# Both generators get identical min and max lengths of 256, the same value
# passed to WarpingLayer in the model cell.
# NOTE(review): presumably this makes every generated series 256 steps long -
# confirm against ConstantLengthDataGenerator.
kwargs = dict(min_length=256, max_length=256)

# Training batches use the configured batch size.
data_generator_train = ConstantLengthDataGenerator(
    X_train,
    y_train,
    batch_size=batch_size,
    **kwargs,
)
# The validation generator's batch size equals the number of validation samples.
data_generator_val = ConstantLengthDataGenerator(
    X_val,
    y_val,
    batch_size=len(y_val),
    **kwargs,
)

In [None]:
# Train for 30 epochs. Validation uses a single batch drawn from the validation
# generator, whose batch size was set to the number of validation samples.
history = model.fit(
    x=data_generator_train,
    epochs=30,
    validation_data=next(data_generator_val),
)

In [None]:
# Plot training vs. validation accuracy per epoch.
figure = plt.figure()
ax = figure.gca()
ax.plot(history.history["accuracy"])
ax.plot(history.history["val_accuracy"])
ax.set_title("model accuracy")
ax.set_ylabel("accuracy")
ax.set_xlabel("epoch")
ax.legend(["train", "validation"], loc="upper left")
# The bare expression makes the figure the cell's displayed output.
figure

In [None]:
# Plot training vs. validation loss per epoch.
figure = plt.figure()
ax = figure.gca()
ax.plot(history.history["loss"])
ax.plot(history.history["val_loss"])
ax.set_title("model loss")
ax.set_ylabel("loss")
ax.set_xlabel("epoch")
ax.legend(["train", "validation"], loc="upper left")
# The bare expression makes the figure the cell's displayed output.
figure