# Parameter extraction

In [49]:
import torch
from torch import nn
import torchaudio
import os
import numpy as np
import matplotlib.pyplot as plt

### Load dataset

In [50]:
from src.gtfxdataset import GtFxDataset

# --- Dataset location --------------------------------------------------------
AUDIO_DIR = "_assets/DATASET/GT-FX-ALL/"
ANNOTATIONS_FILE = os.path.join(AUDIO_DIR, "annotation.csv")

# --- Audio framing -----------------------------------------------------------
SAMPLE_RATE = 22050
NUM_SAMPLES = SAMPLE_RATE * 3  # 3 x SAMPLE_RATE samples, i.e. three seconds at SAMPLE_RATE

# --- Target effect -----------------------------------------------------------
# EFFECT is an index into EFFECT_MAP; 4 selects "reverb".
EFFECT_MAP = ["distortion", "chorus", "tremolo", "delay", "reverb"]
EFFECT = 4

# --- Compute device ----------------------------------------------------------
# Use CUDA when torch can see a GPU, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device {device}")

# Candidate front-end transforms. Only `mfcc` is passed to GtFxDataset further
# down in this cell; the other two are built but not used in the code shown here.

# Mel spectrogram: 64 mel bands, 512-point FFT, hop of 1050 samples.
# `power` is not set, so the torchaudio default applies.
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE, n_fft=512, hop_length=1050, n_mels=64
)

# Normalised power spectrogram: 127-point FFT and window, hop of 1040 samples.
spectrogram = torchaudio.transforms.Spectrogram(
    n_fft=127, win_length=127, hop_length=1040, power=2, normalized=True
)

# MFCC: 64 coefficients computed from a 64-band mel spectrogram.
# With center=False, a 66150-sample clip yields
# floor((66150 - 1024) / 1030) + 1 = 64 frames, which matches the
# [1, 64, 64] shape printed by this cell -- assuming the dataset crops/pads
# every clip to NUM_SAMPLES (TODO confirm in GtFxDataset).
mfcc = torchaudio.transforms.MFCC(
    sample_rate=SAMPLE_RATE,
    n_mfcc=64,
    melkwargs=dict(n_fft=1024, hop_length=1030, n_mels=64, center=False),
)

# Build the dataset for the selected effect, using the MFCC transform as the
# feature extractor.
fxData = GtFxDataset(
    ANNOTATIONS_FILE,
    AUDIO_DIR,
    mfcc,
    SAMPLE_RATE,
    NUM_SAMPLES,
    device,
    EFFECT_MAP[EFFECT],
)

# Sanity check: every item unpacks into five values; only the first (the
# transformed signal) is inspected here.
first_item = fxData[0]
signal, _, _, _, _ = first_item
print(f"There are {len(fxData)} samples in the dataset.")
print(f"Shape of signal: {signal.shape}")
    

Using device cpu
There are 16896 samples in the dataset.
Shape of signal: torch.Size([1, 64, 64])


#### Split dataset into train, test and validation sets

In [51]:
from src.extrector import train

BATCH_SIZE = 64

# Fixed seed for the partition. Without it every kernel restart draws a
# different split, and since the evaluation cell reloads weights from disk,
# re-running split + evaluation alone could score the model on samples it was
# trained on.
SPLIT_SEED = 0

# Fractions for (train, test, val), in the order they are unpacked below.
# NOTE(review): `test_dataloader` is what gets passed to train.train() and
# `val_dataloader` is what the evaluation cell passes to train.test(), so the
# two names are the reverse of the usual convention. They are kept as-is
# because other cells refer to them.
split_ratio = [0.8, 0.1, 0.1]
train_set, test_set, val_set = torch.utils.data.random_split(
    fxData,
    lengths=split_ratio,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dataloader = train.create_data_loader(train_set, BATCH_SIZE)
test_dataloader = train.create_data_loader(test_set, BATCH_SIZE)
val_dataloader = train.create_data_loader(val_set, BATCH_SIZE)


#### Model training

In [52]:
from src.extrector import model

# Optimisation hyper-parameters.
LEARNING_RATE = 0.0003
EPOCHS = 5

# Trained weights are written to one file per effect index.
WEIGHTS_DIR = "_weights/"
WEIGHTS_FILE = os.path.join(WEIGHTS_DIR, f"c55_parameter_{EFFECT}.pth")

# Create the output directory before training starts, so a bad path fails
# before the expensive work rather than after it. exist_ok=True avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(WEIGHTS_DIR, exist_ok=True)

# construct model and assign it to device
cnn = model.Extractor().to(device)

# initialise loss function + optimiser
loss_fn = nn.MSELoss(reduction='mean')
optimiser = torch.optim.Adam(cnn.parameters(), lr=LEARNING_RATE)

# train model
train.train(cnn, train_dataloader, test_dataloader, loss_fn, optimiser, device, EPOCHS, effect=EFFECT)

# save model (state_dict only; the evaluation cell rebuilds the architecture
# and loads these weights back in)
torch.save(cnn.state_dict(), WEIGHTS_FILE)
print("Trained feed forward net saved at %s" %(WEIGHTS_FILE))

Epoch 1
loss: 0.082270  [  0/13568]
loss: 0.052145  [1280/13568]
loss: 0.049215  [2560/13568]
loss: 0.039952  [3840/13568]
loss: 0.043157  [5120/13568]
loss: 0.044093  [6400/13568]
loss: 0.045471  [7680/13568]
loss: 0.027066  [8960/13568]
loss: 0.031229  [10240/13568]
loss: 0.030402  [11520/13568]
loss: 0.042724  [12800/13568]
avg MSE: 0.035681
---------------------------
Epoch 2
loss: 0.026128  [  0/13568]
loss: 0.029098  [1280/13568]
loss: 0.031910  [2560/13568]
loss: 0.029922  [3840/13568]
loss: 0.029470  [5120/13568]
loss: 0.021193  [6400/13568]
loss: 0.040507  [7680/13568]
loss: 0.025790  [8960/13568]
loss: 0.022867  [10240/13568]
loss: 0.021869  [11520/13568]
loss: 0.022812  [12800/13568]
avg MSE: 0.025411
---------------------------
Epoch 3
loss: 0.036826  [  0/13568]
loss: 0.025664  [1280/13568]
loss: 0.026879  [2560/13568]
loss: 0.015964  [3840/13568]
loss: 0.017132  [5120/13568]
loss: 0.027433  [6400/13568]
loss: 0.018404  [7680/13568]
loss: 0.019370  [8960/13568]
loss: 0.022

#### Evaluation

In [53]:
import csv

# Rebuild the architecture and restore the weights saved by the training cell.
cnn = model.Extractor().to(device)

# map_location remaps the stored tensors onto the current device, so weights
# saved on a CUDA machine still load on a CPU-only one (and vice versa).
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict holds.
state_dict = torch.load(WEIGHTS_FILE, map_location=device, weights_only=True)
cnn.load_state_dict(state_dict)

# Evaluate on the held-out split that was not passed to train.train().
log = train.test(cnn, val_dataloader, device, effect=EFFECT)

# Show at most the first ten rows; slicing avoids an IndexError when fewer
# than ten rows come back.
# NOTE(review): each row prints as [sample name, value, value] -- presumably
# target and prediction; confirm the column order in train.test.
for row in log[:10]:
    print(row)

# Optional: dump the full evaluation log to CSV.
# with open('report.csv', 'w+', newline='') as file:
#     csv.writer(file).writerows(log)


avg MSE: 0.019669
['c55_7_od_0_cs_1_tr_1_dl_0_rv_0', 0.37, 0.18]
['c52_20_od_1_rv_1', 0.59, 0.51]
['c54_21_cs_2_tr_0_dl_0_rv_2', 0.77, 0.81]
['c54_8_od_0_cs_1_tr_2_rv_1', 0.48, 0.53]
['c54_16_od_2_cs_2_tr_0_rv_1', 0.56, 0.56]
['c53_15_od_1_tr_1_rv_1', 0.48, 0.56]
['c54_8_od_2_tr_1_dl_1_rv_0', 0.42, 0.15]
['c54_4_od_1_tr_2_dl_1_rv_0', 0.13, 0.14]
['c54_12_cs_1_tr_1_dl_0_rv_1', 0.58, 0.55]
['c52_2_od_2_rv_2', 0.89, 0.79]
