# Murcielago System V1 - Training & testing

In [15]:
import torch
import torchaudio
from torch.utils.data import DataLoader
from custom_dataset import GunShotsNoisesDataset, split_dataset
import numpy as np
from cnn import ShotDetectionNetwork
from torch import nn
from train_test_functions import train_step, test_step, eval_model, accuracy_fn

## 1) Data preparing: Dataset and Dataloaders

In [16]:
metadata_file = "./dataset/metadata.xlsx"
audios_dir = "./dataset"
fs = 48000
scales = np.arange(1, 129)
transformation_dict = {"wavelet": "cmor","scales":scales}

GSN_visualize = GunShotsNoisesDataset(metadata_file, audios_dir, transformation_dict, fs, 0.01) #this returns audio waveform to analyse samples

GSN_dataset = GunShotsNoisesDataset(metadata_file, audios_dir, transformation_dict, fs)

In [17]:
batch_size = 4
train_data, test_data = split_dataset(GSN_dataset, 0.04)

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

## 2) Model building

In [18]:
torch.manual_seed(42)
mur_cnn_v1 = ShotDetectionNetwork(64)
print(mur_cnn_v1)

ShotDetectionNetwork(
  (conv1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): Conv2d(64, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 1, kernel_size=(2, 2), stride=(2, 2), padding=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): LazyLinear(in_features=0, out_features=1, bias=True)
  )
)


## 3) Train and Test

In [19]:
learn_rate = 0.02
loss_fn = nn.BCEWithLogitsLoss()
optim = torch.optim.Adam(mur_cnn_v1.parameters(), lr=learn_rate)

In [20]:
epochs = 100

for epoch in range(epochs):
    print("-----------------------------------------------------------")
    print(f"Epoch: {epoch}")
    train_step(mur_cnn_v1, train_dataloader, loss_fn, optim, accuracy_fn=accuracy_fn)
    test_step(mur_cnn_v1, test_dataloader, loss_fn, accuracy_fn=accuracy_fn)
    

-----------------------------------------------------------
Epoch: 0


  wavelet = DiscreteContinuousWavelet(wavelet)


Train loss: 43.98751 | Train accuracy: 93.54167%
Test loss: 391.80859 | Test accuracy: 50.00000%
-----------------------------------------------------------
Epoch: 1
Train loss: 40.80391 | Train accuracy: 92.91667%
Test loss: 36.17511 | Test accuracy: 50.00000%
-----------------------------------------------------------
Epoch: 2
Train loss: 11.00864 | Train accuracy: 95.83333%
Test loss: 4927.61035 | Test accuracy: 50.00000%
-----------------------------------------------------------
Epoch: 3
Train loss: 374.76306 | Train accuracy: 88.95833%
Test loss: 6000.67383 | Test accuracy: 50.00000%
-----------------------------------------------------------
Epoch: 4
Train loss: 650.47595 | Train accuracy: 91.04167%
Test loss: 8708.41699 | Test accuracy: 50.00000%
-----------------------------------------------------------
Epoch: 5
Train loss: 477.19217 | Train accuracy: 85.83333%
Test loss: 3549.85229 | Test accuracy: 50.00000%
-----------------------------------------------------------
Epoch: 

In [21]:
for p in mur_cnn_v1.parameters():
    print(p.grad)

tensor([[[[-9.9889e-05, -1.5561e-04, -1.5939e-04, -1.6885e-04],
          [-2.4373e-04, -2.7829e-04, -2.7474e-04, -2.6236e-04],
          [-2.3959e-04, -2.5580e-04, -2.4904e-04, -2.5041e-04],
          [-3.9101e-04, -3.9032e-04, -3.7250e-04, -3.4199e-04]]],


        [[[ 1.4977e-04,  1.8734e-04,  1.7478e-04,  1.5097e-04],
          [ 1.4052e-04,  1.8685e-04,  1.5654e-04,  1.2611e-04],
          [ 2.4559e-04,  2.7943e-04,  2.6440e-04,  2.1128e-04],
          [ 2.2559e-04,  2.7059e-04,  2.3894e-04,  1.8271e-04]]],


        [[[-7.9087e-05, -1.4223e-04, -1.4970e-04, -1.6148e-04],
          [-2.4970e-04, -2.8834e-04, -2.8630e-04, -2.7522e-04],
          [-2.3419e-04, -2.5148e-04, -2.4533e-04, -2.5061e-04],
          [-4.1436e-04, -4.1162e-04, -3.9314e-04, -3.6290e-04]]],


        ...,


        [[[-1.3957e-04, -1.7910e-04, -1.7112e-04, -1.5208e-04],
          [-1.9183e-04, -2.3021e-04, -2.0425e-04, -1.7540e-04],
          [-2.7002e-04, -2.9064e-04, -2.7305e-04, -2.2849e-04],
          [-3

## 4) Results