In [None]:
import torch
from torch.utils.data import DataLoader
from utils.train_audio import train_model
from utils.dataset_audio import AudioDataset
from utils.dataset_audio import collate_fn
from models.model_audio import RNN
from utils.test_audio import test_model
from utils.predict_audio import log_results

if __name__ == "__main__":
    # End-to-end run: load the three dataset splits, train the RNN,
    # evaluate it on the test split and write the outcome to the log file.

    # --- Run configuration --------------------------------------------------
    root_data = "data"
    sample_rate = 16000
    batch_size = 32
    epochs = 20
    learning_rate = 1e-3
    bidirectional = True
    save_path = "best_model.pth"
    log_file = "log.txt"

    # Use CUDA when torch reports it as available, otherwise fall back to CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Usando il dispositivo: {device}")

    # --- Data ---------------------------------------------------------------
    print("Caricamento dei dataset...")
    # Splits are instantiated in the order train -> val -> test.
    train_dataset, val_dataset, test_dataset = (
        AudioDataset(root_data, split=split_name, sample_rate=sample_rate)
        for split_name in ("train", "val", "test")
    )

    # The number of output classes is the size of the training split's
    # label mapping.
    label_mapping = train_dataset.get_label_mapping()
    num_classes = len(label_mapping)

    # Only the training loader shuffles; all loaders share the same batch
    # size and collate function.
    train_loader, val_loader, test_loader = (
        DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=reshuffle,
            collate_fn=collate_fn,
        )
        for dataset, reshuffle in (
            (train_dataset, True),
            (val_dataset, False),
            (test_dataset, False),
        )
    )

    # --- Model --------------------------------------------------------------
    rnn_config = {
        "input_size": 1,
        "hidden_size": 128,
        "num_layers": 3,
        "num_classes": num_classes,
        "bidirectional": bidirectional,
    }
    model = RNN(**rnn_config)

    # --- Training -----------------------------------------------------------
    print("Inizio training...")
    train_model(
        model,
        train_loader,
        val_loader,
        num_epochs=epochs,
        lr=learning_rate,
        device=device,
        save_path=save_path,
    )

    # --- Evaluation ---------------------------------------------------------
    print("Test finale sul dataset di test...")
    # test_model receives the same path that was handed to train_model as
    # save_path.
    test_acc = test_model(
        model,
        test_loader,
        device=device,
        model_path=save_path,
    )

    # --- Logging ------------------------------------------------------------
    print("Salvataggio log...")
    log_results(
        log_file,
        final_test_accuracy=test_acc,
        epochs=epochs,
        learning_rate=learning_rate,
    )

    # Optional single-file prediction, currently disabled:
    # predict_single_file(model, "esempio.wav", label_mapping, device=device, sample_rate=sample_rate, model_path=save_path)



Usando il dispositivo: cpu
Caricamento dei dataset...
Inizio training...

Epoch 1/20


Training:   0%|          | 0/14 [00:00<?, ?it/s]

: 