In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from datetime import datetime
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
from helper.label import classes
from helper.audio_extraction.get_file_list import get_file_list
from helper.audio_extraction.padded_and_windowed import extract_windowed_features
from FeedForward import ChordAI

In [3]:
TRAINING_DATA_PATH = './audio'
train_list = get_file_list(TRAINING_DATA_PATH)
data, train_data = extract_windowed_features(train_list, classes, test_rate=0.2)



windowed_features:  936
test_windowed_features:  240


In [4]:
INPUT_SIZE = data[0][0].shape[0]
OUTPUT_SIZE = len(classes)
BATCH_SIZE = 32

In [5]:

model = ChordAI(INPUT_SIZE, OUTPUT_SIZE).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

In [6]:
def create_data_loader(data, batch_size):
    dataloader = DataLoader(data, batch_size=batch_size)
    return dataloader

In [8]:
train_dataloader = create_data_loader(data, BATCH_SIZE)
test_dataloader = create_data_loader(train_data, BATCH_SIZE)

In [9]:
def train_one_epoch(model, data_loader, loss_function, optimizer, device, mode='train'):
    acc = 0
    for inputs, targets in data_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        # change to tensor
        inputs = torch.tensor(inputs, dtype=torch.float32)
        targets = torch.tensor(targets, dtype=torch.long)
        
        # calculate loss
        predictions = model(inputs)
        loss = loss_function(predictions, targets)
        if mode == 'train':
            # backpropagate error and update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # calculate accuracy
        acc += (predictions.argmax(1) == targets).sum().item()
    return acc / len(data_loader.dataset)

In [10]:
def train(model, train_data, test_data, loss_function, optimizer, device):
    tac, tsc = [], []
    i = 0
    patience = 0
    scr = 0
    start_time = datetime.now()
    while True:
        i+=1
        print(f"\nEpoch : {i:4} | ", end=" ")
        train_acc = train_one_epoch(model, train_data, loss_function, optimizer, device)
        tac.append(train_acc)

        with torch.no_grad():
            test_acc = train_one_epoch(model, test_data, loss_function, optimizer, device, mode='test')
            tsc.append(test_acc)

        print(f"train acc : {train_acc:.4f} | test acc : {test_acc:.4f} | patience : {patience} | best acc : {scr:.4f}", end=" ") 
        if test_acc > scr:
            scr = test_acc
            patience = 0
            torch.save(model.state_dict(), "models/chord_model.pth")
            log = {
                "train_acc": tac,
                "test_acc": tsc
            }
            torch.save(log, "models/logs.pth")
        else:
            patience +=1

        if patience >= 5:
            break
    end_time = datetime.now()
    print(f"\nTraining completed in {(end_time-start_time).seconds} seconds")
    torch.save({
        'INPUT_SIZE': INPUT_SIZE,
        'OUTPUT_SIZE': OUTPUT_SIZE,
    }, "models/config.pth")

In [12]:
train(model, train_dataloader, test_dataloader, loss_fn, optimizer, device)


Epoch :    1 |  

  inputs = torch.tensor(inputs, dtype=torch.float32)
  targets = torch.tensor(targets, dtype=torch.long)


train acc : 0.0085 | test acc : 0.2042 | patience : 0 | best acc : 0.0000 
Epoch :    2 |  train acc : 0.1677 | test acc : 0.3708 | patience : 0 | best acc : 0.2042 
Epoch :    3 |  train acc : 0.3237 | test acc : 0.4417 | patience : 0 | best acc : 0.3708 
Epoch :    4 |  train acc : 0.4605 | test acc : 0.5792 | patience : 0 | best acc : 0.4417 
Epoch :    5 |  train acc : 0.6218 | test acc : 0.6458 | patience : 0 | best acc : 0.5792 
Epoch :    6 |  train acc : 0.6699 | test acc : 0.6375 | patience : 0 | best acc : 0.6458 
Epoch :    7 |  train acc : 0.7297 | test acc : 0.7375 | patience : 1 | best acc : 0.6458 
Epoch :    8 |  train acc : 0.7767 | test acc : 0.7125 | patience : 0 | best acc : 0.7375 
Epoch :    9 |  train acc : 0.8697 | test acc : 0.8167 | patience : 1 | best acc : 0.7375 
Epoch :   10 |  train acc : 0.8974 | test acc : 0.8042 | patience : 0 | best acc : 0.8167 
Epoch :   11 |  train acc : 0.9028 | test acc : 0.8167 | patience : 1 | best acc : 0.8167 
Epoch :   12 | 