In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt

from Constants.Paths import *
from Constants.Labels import *
from SpectrogramLoading import *
from Models.TrainingHistory import TrainingHistory
from Models.CnnModel import CnnModel, WandbDetails
from Models.InputPadding import pad_to_same_size
from Models.HistoryPlots import plot_loss_history, plot_accuracy_history

In [None]:
# Fixed seed so the shuffled order (and any later draws from the seeded RNGs)
# is reproducible across Restart & Run All.
SEED = 42

train_paths, val_paths = get_divided_paths_with_labels()

# Seed Python's global `random` before shuffling. It is kept global rather than
# a local random.Random instance in case downstream code also draws from it.
# numpy's legacy global RNG is seeded as well; previously it was left unseeded.
# NOTE(review): if CnnModel trains with TensorFlow/PyTorch, seed that framework too.
random.seed(SEED)
np.random.seed(SEED)

# random.shuffle reorders the lists in place. Re-running this cell should be
# idempotent, assuming get_divided_paths_with_labels returns fresh lists on
# each call (the seed is re-applied before shuffling).
random.shuffle(train_paths)
random.shuffle(val_paths)

In [None]:
# Optimiser hyperparameters forwarded to CnnModel.
optimizer_kwargs = dict(
    learning_rate=1e-3,
    lr_decay=1e-4,
    beta_1=0.9,
    beta_2=0.999,
    eps=1e-8,
)

# All six dropout rates (extractor_dropout_1..3, classifier_dropout_1..3) are 0.0 for this run.
dropout_kwargs = {
    f"{stage}_dropout_{idx}": 0.0
    for stage in ("extractor", "classifier")
    for idx in (1, 2, 3)
}

# Weights & Biases identifiers for this experiment run.
tracking = WandbDetails(
    project="test-project",
    experiment_name="integration-test-5",
    config_name="first-cnn",
    artifact_name="test-model",
)

model = CnnModel(
    classes=labels,
    **optimizer_kwargs,
    **dropout_kwargs,
    wandb_details=tracking,
)

In [None]:
# Train for 25 epochs with mini-batches of 32; val_paths is handed to
# CnnModel.train for validation (how it is used is defined inside CnnModel).
model.train(
    train_paths,
    val_paths,
    batch_size=32,
    epochs=25,
)

In [None]:
# Two stacked panels: loss history on top, accuracy history below.
_fig, (loss_ax, acc_ax) = plt.subplots(nrows=2, ncols=1, figsize=(10, 11))
plot_loss_history(model.get_history(), loss_ax)
plot_accuracy_history(model.get_history(), acc_ax)
plt.show()