## Generate the required graphs from the wave files
- melspectrogram_db: uses the `librosa` package.
- spectogram: depending on the selected package, the spectrogram is generated with the `pyplot` package or the `librosa` package. The accepted package options are
  'pyplot', 'librosa-log' and 'librosa-linear'.
- cochleagram: uses the `pycochleagram` package to generate the cochleagrams.

In all cases, we need to specify the output path for the generated graphs, the signal, and the sampling frequency (rate) of the wave file.

### Generating mel-spectrograms from the `Watkins Marine Mammal Sound Database`

In [None]:

from helper import signal_utils
import soundfile as sf
import os
import shutil

import matplotlib
matplotlib.use('Agg')
# Imported after selecting the non-interactive Agg backend; used to release
# each figure once it has been written to disk.
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

input_wavefiles_path = 'dataset/wavefiles'
output_graphs_path = 'dataset/graphs'

from timeit import default_timer as timer

s_time = timer()

N_FFT = 1024

# Rebuild the output folder from scratch so graphs from a previous run do not
# mix with the new ones. makedirs also creates a missing 'dataset' parent.
if os.path.exists(output_graphs_path):
    shutil.rmtree(output_graphs_path)
os.makedirs(output_graphs_path)

# One mel-spectrogram image (<stem>.jpg) per wave file in the input folder.
for file in sorted(os.listdir(input_wavefiles_path)):
    wave_path = os.path.join(input_wavefiles_path, file)
    # Skip sub-directories and other non-file entries that sf.read cannot open.
    if not os.path.isfile(wave_path):
        continue
    # splitext only strips the final extension, so names containing extra dots
    # (e.g. 'a.1.wav' / 'a.2.wav') no longer collide on the same output image.
    stem = os.path.splitext(file)[0]
    signal_array, sampling_rate = sf.read(wave_path)
    # NOTE: the keyword is spelled `samplig_rate` in signal_utils' signature.
    fig, _ = signal_utils.melspectrogram_db(signal_array=signal_array, samplig_rate=sampling_rate,
                                            fmax=sampling_rate // 2,
                                            out_spec_path=os.path.join(output_graphs_path, f'{stem}.jpg'),
                                            n_mels=90, hop_length=512, n_fft=N_FFT,
                                            cmap='jet', axis_off=True)
    # Release the figure; otherwise pyplot keeps every figure alive for the
    # whole loop. Guarded because the helper's return type is not visible here.
    if isinstance(fig, Figure):
        plt.close(fig)

# Total wall-clock time in seconds.
print(timer() - s_time)


# Preparing the dataset
### Copy (or move) the data into the corresponding class folders (according to the labels in the CSV file). If the data is already divided into class folders, this step is not required.

In [None]:
from torchvision import transforms
import torch
from torch import optim
from torch.utils.data import DataLoader
import torch.nn as nn

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

import numpy as np

import matplotlib.pyplot as plt

from helper import train_utils


label_csv_file = 'dataset/labels.csv'
# Alternative source folder: 'dataset/challenge_wav/afew_output/X_train_plots'
path_all_graphs = 'dataset/graphs'

batch_size = 256
epochs = 1

# Every graph image is resized to 224x224 and converted to a tensor.
# NOTE(review): no mean/std Normalize is applied; pretrained backbones are
# usually fed normalized inputs — confirm this is intended.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


src_data_dir = path_all_graphs
dest_data_dir = 'dataset/graph_classes'
# Sort the graphs into per-label folders (copy, since move=False); the return
# value is used as the number of classes for the model below.
n_classes = train_utils.prepare_data_classes_from_train(src_data_dir, dest_data_dir, label_csv_file, move=False)

# Split the src trainset into train (80%) and test (20%)
X_train, X_test = train_utils.split_train(dest_data_dir, 0.8, transform)
train_dataloader = DataLoader(X_train, batch_size=batch_size, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps test runs reproducible.
test_dataloader = DataLoader(X_test, batch_size=batch_size, shuffle=False)

### Define a model and train it

In [None]:
# Build a pretrained VGG16 adapted to `n_classes` outputs.
model = train_utils.get_pretrained_model('vgg16', n_classes)

# NOTE(review): NLLLoss expects log-probabilities, so this assumes the model
# from get_pretrained_model ends with a LogSoftmax layer — confirm.
criterion = nn.NLLLoss()
# Adam with its default hyper-parameters over every model parameter.
optimizer = optim.Adam(model.parameters())

# Run label; it also names the saved history files in the next cells.
experience_name = f'Melspectogram Experiment_N_FFT_{N_FFT}_{batch_size}'
print(f"running on {experience_name}")

training_kwargs = dict(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_dataloader,
    test_loader=test_dataloader,
    n_epochs=epochs,
)
model, history = train_utils.train(**training_kwargs)



### Test and save the results

In [None]:
# The held-out split is 20% because split_train was called with 0.8 above.
print('Testing on 20% of trainSet...')
acc_test_train = train_utils.test(model, test_dataloader)
# Serialized history, reloaded later by the plotting cell.
torch.save(history, f'history_{experience_name}')

separator = '##############################################\n\n\n'
print(separator)
# Human-readable copy of the history plus the final test accuracy.
# The `with` block closes the file; no explicit close() is needed.
with open(f'history_{experience_name}.txt', 'w', encoding='utf-8') as fi:
    fi.write(str(history))
    # Start the separator on its own line instead of gluing it to the history.
    fi.write('\n' + separator)
    fi.write(f'acc_test_train = {acc_test_train}')

### Saving the trained model

In [None]:
# Disabled: uncomment to persist the whole trained model object to disk.
# NOTE(review): the PyTorch serialization docs recommend saving
# `model.state_dict()` instead of pickling the full model — consider switching.
# torch.save(model, f'model_{experience_name}') 

### Loading the saved results and plotting graphs

In [22]:
# %matplotlib
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

plt.ion()

# Set the global font size *before* creating any artist: a Text object reads
# rcParams when it is created, so updating them after plotting (as before)
# did not affect the annotation drawn below.
matplotlib.rcParams.update({'font.size': 50})

# NOTE(review): since PyTorch 2.6 torch.load defaults to weights_only=True,
# which may refuse arbitrary pickled objects; if loading fails, pass
# weights_only=False for this trusted, locally produced file.
history = torch.load(f'history_{experience_name}')
# The first element of each history entry is plotted as accuracy, scaled to %.
acc = [h[0] * 100 for h in history]

plt.figure(figsize=(20, 10))
# Use the number of recorded epochs, not the configured `epochs`, so x and y
# always have the same length (e.g. if training stopped early).
plt.plot(np.arange(len(acc)), acc, linewidth=7.0)
if acc:
    best_epoch = int(np.argmax(acc))
    # Axes coordinates (fractions of the plot area) keep the annotation visible
    # regardless of the number of epochs or the accuracy range.
    plt.text(x=0.05, y=0.9, s=f'max: {max(acc):.2f}% @ epoch {best_epoch}',
             transform=plt.gca().transAxes)