In [1]:
from os.path import join

def setup_file_system(in_colab):
    if in_colab:
        from google.colab import drive

        # Set the base and mount path
        MOUNT_PATH_DRIVE = '/content/drive'
        BASE_PATH = join(
            MOUNT_PATH_DRIVE, 
            "MyDrive/project_asr"
        )

        # Mount the google drive
        drive.mount(MOUNT_PATH_DRIVE)

        return BASE_PATH

    else:
        return "/workspaces/project_automated_sound_recognition"

In [2]:
import sys
from os import chdir
from os.path import join

# Method to check if the notebook is running in colab or local
IN_COLAB = 'google.colab' in sys.modules

# Set the base path of the project
BASE_PATH = setup_file_system(IN_COLAB)

# Set the base path of the project
chdir(join(BASE_PATH, "src/"))

In [3]:
%load_ext autoreload
%autoreload 2

# Imports
# Utils
import matplotlib as plt
import numpy as np
import wandb
import sys
import importlib
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import datetime
import json
from sklearn.metrics import accuracy_score, confusion_matrix


# DL libraries
import torch
import torch.optim as optim
from torch import nn
import torch.utils.data 
from torch.utils.data import DataLoader

# User libraries
from dataset.audio_sample_dataset import AudioSampleDataset
from model.baseline_model import BaselineModel
from trainer.trainer import train_classification_model
from validator.validator import validate_classification_model
from util import config, util_functions, model_management

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
augmentations = {
    'pitch_shift': {
        'enabled': False,
    },
    'noise': {
        'enabled': True,
        'p': 0.25,
        'min_amplitude': 0.001,
        'max_amplitude': 0.015,
    },
    'mixup': {
        'enabled': True,
        'p': 0.25,
        'alpha': 0.2,
    },
    'freq_mask': {
        'enabled': True,
        'p': 0.25,
        'freq_mask_param': 5,
    },
    'time_mask': {
        'enabled': True,
        'p': 0.25,
        'time_mask_param': 10,
    }
}

test_augmentations = {
    'pitch_shift': {
        'enabled': False,
    },
    'noise': {
        'enabled': False,
    },
    'mixup': {
        'enabled': False,
    },
    'freq_mask': {
        'enabled': False,
    },
    'time_mask': {
        'enabled': False,
    }
}

In [5]:
# Get the train and test data
train_dataset = AudioSampleDataset(
        join(BASE_PATH, config.TRAIN_DATA_PATH),
        augmentations
    )
test_dataset = AudioSampleDataset(
        join(BASE_PATH, config.TEST_DATA_PATH),
        test_augmentations
    )

# Place in dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1)

## Setup

In [6]:
# Clear gpu cache
torch.cuda.empty_cache()

# Get the model
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
model.fc = nn.Sequential(
    nn.Linear(in_features=512, out_features= 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(in_features=256, out_features=len(config.LABELS)),
    nn.Softmax(dim= 1)
)
model.to(config.DEVICE)

# Set the optimizer
optimizer = optim.Adam(model.parameters(), lr=config.LR)

# Set the loss fn
criteria = nn.CrossEntropyLoss()

# Set the gradient scaler
grad_scaler = torch.cuda.amp.grad_scaler.GradScaler()

# Setup weights and biasses
wandb.login()

# Get the current time for the checkpoint name
now = datetime.datetime.now()

# Set the wandb experiment name
experiment_name = util_functions.generate_run_name_from_config(augmentations)

# Start wandb
wandb.init(
    settings=wandb.Settings(start_method="fork"),
    project="project_asr", 
    name=experiment_name, 
    config={
        "learning_rate": config.LR,
        "batch_size": config.BATCH_SIZE,
        "epochs": config.EPOCHS,
        "augmentations": json.dumps(augmentations),
    }
)


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrobberdg[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Training

In [7]:
# Set the variables to keep track of the best model
best_validation_loss = 10000
best_model_state = model.state_dict()

for epoch in range(0,50):
  # Set the model in training mode
  model.train()
  
  # Train the model
  total_train_loss_this_epoch = train_classification_model(
      model,
      optimizer,
      criteria,
      grad_scaler,
      train_dataloader
  )
  
  # Set the model in evaluation mode
  model.eval()

  # Validate the model
  total_val_loss_this_epoch, pred_classes, true_classes = validate_classification_model(
      model,
      criteria,
      test_dataloader,
  )

  # Calculate the loss values
  train_loss_this_epoch = total_train_loss_this_epoch/len(train_dataloader.dataset)
  val_loss_this_epoch = total_val_loss_this_epoch/len(test_dataloader.dataset)

  # Calculate the accuracy
  acc_avg = accuracy_score(true_classes, pred_classes)

  # Calculate acc per class
  matrix = confusion_matrix(true_classes, pred_classes)
  acc_per_class = matrix.diagonal()/matrix.sum(axis=1)


  # Log the train loss this epoch
  wandb.log({
      'train_loss': train_loss_this_epoch,
      'val_loss': val_loss_this_epoch,
      'acc': acc_avg,
      'acc_airport': acc_per_class[0],
      'acc_shopping_mall': acc_per_class[1],
      'acc_metro_station': acc_per_class[3],
      'acc_street_pedestrian': acc_per_class[3],
      'acc_public_square': acc_per_class[4],
      'acc_street_traffic': acc_per_class[5],
      'acc_tram': acc_per_class[6],
      'acc_bus': acc_per_class[7],
      'acc_metro': acc_per_class[8],
      'acc_park': acc_per_class[9],
  })

  print(f'epoch: {epoch}, train_loss: {train_loss_this_epoch}, val_loss: {val_loss_this_epoch}, acc: {acc_avg}')

  # If this is the best performing model yet, save it
  if val_loss_this_epoch < best_validation_loss:
    # Update the score
    best_validation_loss = val_loss_this_epoch

    now = datetime.datetime.now()

    # Save the model
    checkpoint_path = join(
      BASE_PATH, 
      config.MODEL_CHECKPOINT_PATH, 
      f'{experiment_name}.pth'
    )
    best_model_state = model_management.save_model(model, checkpoint_path, False, '')

100%|██████████| 79/79 [02:53<00:00,  2.20s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.78it/s]


epoch: 0, train_loss: 0.018155316305160523, val_loss: 2.292484007358551, acc: 0.164


100%|██████████| 79/79 [02:33<00:00,  1.94s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.99it/s]


epoch: 1, train_loss: 0.01798256778717041, val_loss: 2.2670140978693962, acc: 0.199


100%|██████████| 79/79 [02:13<00:00,  1.70s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.40it/s]


epoch: 2, train_loss: 0.017682585239410402, val_loss: 2.2345531691312788, acc: 0.227


100%|██████████| 79/79 [02:11<00:00,  1.67s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.10it/s]


epoch: 3, train_loss: 0.01736512367725372, val_loss: 2.2085016718506814, acc: 0.281


100%|██████████| 79/79 [02:11<00:00,  1.66s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.14it/s]


epoch: 4, train_loss: 0.017091221761703492, val_loss: 2.1824832959771157, acc: 0.305


100%|██████████| 79/79 [02:14<00:00,  1.70s/it]
100%|██████████| 2000/2000 [00:28<00:00, 71.03it/s]


epoch: 5, train_loss: 0.01688775269985199, val_loss: 2.1678123722076417, acc: 0.307


100%|██████████| 79/79 [02:14<00:00,  1.70s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.88it/s]


epoch: 6, train_loss: 0.01665236759185791, val_loss: 2.1563625737428667, acc: 0.3065


100%|██████████| 79/79 [02:14<00:00,  1.70s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.17it/s]


epoch: 7, train_loss: 0.016483910799026488, val_loss: 2.1423856572508813, acc: 0.3205


100%|██████████| 79/79 [02:13<00:00,  1.69s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.62it/s]


epoch: 8, train_loss: 0.016330776154994963, val_loss: 2.1295120162963865, acc: 0.3285


100%|██████████| 79/79 [02:15<00:00,  1.72s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.60it/s]


epoch: 9, train_loss: 0.016198806941509245, val_loss: 2.1184252191782, acc: 0.342


100%|██████████| 79/79 [02:12<00:00,  1.67s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.03it/s]


epoch: 10, train_loss: 0.016064546251296996, val_loss: 2.115439921498299, acc: 0.341


100%|██████████| 79/79 [02:13<00:00,  1.70s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.61it/s]


epoch: 11, train_loss: 0.015982945072650908, val_loss: 2.1109033881425856, acc: 0.3405


100%|██████████| 79/79 [02:14<00:00,  1.70s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.97it/s]


epoch: 12, train_loss: 0.015891064858436586, val_loss: 2.103274468123913, acc: 0.3475


100%|██████████| 79/79 [02:14<00:00,  1.71s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.10it/s]


epoch: 13, train_loss: 0.015761960029602052, val_loss: 2.104612149834633, acc: 0.3455


100%|██████████| 79/79 [02:11<00:00,  1.67s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.42it/s]


epoch: 14, train_loss: 0.01570152521133423, val_loss: 2.1000625976920126, acc: 0.3515


100%|██████████| 79/79 [02:14<00:00,  1.70s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.19it/s]


epoch: 15, train_loss: 0.015642944264411925, val_loss: 2.0940523594617844, acc: 0.36


100%|██████████| 79/79 [02:16<00:00,  1.73s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.55it/s]


epoch: 16, train_loss: 0.015561930894851684, val_loss: 2.0925350080728533, acc: 0.361


100%|██████████| 79/79 [02:17<00:00,  1.74s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.98it/s]


epoch: 17, train_loss: 0.015513268077373505, val_loss: 2.0908341484069823, acc: 0.364


100%|██████████| 79/79 [02:14<00:00,  1.70s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.16it/s]


epoch: 18, train_loss: 0.015421316301822661, val_loss: 2.0870236140489578, acc: 0.365


100%|██████████| 79/79 [02:11<00:00,  1.67s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.55it/s]


epoch: 19, train_loss: 0.015370866346359252, val_loss: 2.086170552551746, acc: 0.368


100%|██████████| 79/79 [02:12<00:00,  1.68s/it]
100%|██████████| 2000/2000 [00:27<00:00, 71.75it/s]


epoch: 20, train_loss: 0.015313419604301453, val_loss: 2.081849751472473, acc: 0.3695


100%|██████████| 79/79 [02:12<00:00,  1.68s/it]
100%|██████████| 2000/2000 [00:27<00:00, 74.07it/s]


epoch: 21, train_loss: 0.015279885601997376, val_loss: 2.0799266087412835, acc: 0.3765


100%|██████████| 79/79 [02:14<00:00,  1.70s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.31it/s]


epoch: 22, train_loss: 0.015215869700908661, val_loss: 2.0784304540753364, acc: 0.371


100%|██████████| 79/79 [02:12<00:00,  1.68s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.47it/s]


epoch: 23, train_loss: 0.015147317230701446, val_loss: 2.073857323408127, acc: 0.379


100%|██████████| 79/79 [02:14<00:00,  1.70s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.71it/s]


epoch: 24, train_loss: 0.01504314056634903, val_loss: 2.0750020802617075, acc: 0.379


100%|██████████| 79/79 [02:12<00:00,  1.68s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.12it/s]


epoch: 25, train_loss: 0.01496931174993515, val_loss: 2.0713273344635965, acc: 0.3795


100%|██████████| 79/79 [02:12<00:00,  1.68s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.20it/s]


epoch: 26, train_loss: 0.014999572956562042, val_loss: 2.074025146007538, acc: 0.375


100%|██████████| 79/79 [02:11<00:00,  1.66s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.42it/s]


epoch: 27, train_loss: 0.01489180783033371, val_loss: 2.07477451056242, acc: 0.381


100%|██████████| 79/79 [02:12<00:00,  1.68s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.22it/s]


epoch: 28, train_loss: 0.014838017821311951, val_loss: 2.0740725243687628, acc: 0.3825


100%|██████████| 79/79 [02:12<00:00,  1.67s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.70it/s]


epoch: 29, train_loss: 0.014762418854236603, val_loss: 2.0759464539289474, acc: 0.3745


100%|██████████| 79/79 [02:10<00:00,  1.66s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.26it/s]


epoch: 30, train_loss: 0.01472879090309143, val_loss: 2.073307465434074, acc: 0.3815


100%|██████████| 79/79 [02:11<00:00,  1.67s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.34it/s]


epoch: 31, train_loss: 0.01469206292629242, val_loss: 2.077247202217579, acc: 0.3715


100%|██████████| 79/79 [02:11<00:00,  1.67s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.96it/s]


epoch: 32, train_loss: 0.014716549909114838, val_loss: 2.071376914560795, acc: 0.3875


100%|██████████| 79/79 [02:12<00:00,  1.68s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.18it/s]


epoch: 33, train_loss: 0.014650077784061431, val_loss: 2.0700378680825233, acc: 0.3815


100%|██████████| 79/79 [02:11<00:00,  1.67s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.11it/s]


epoch: 34, train_loss: 0.01461037802696228, val_loss: 2.0665433453917506, acc: 0.388


100%|██████████| 79/79 [02:14<00:00,  1.71s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.88it/s]


epoch: 35, train_loss: 0.014541225516796112, val_loss: 2.064856257021427, acc: 0.3865


100%|██████████| 79/79 [02:13<00:00,  1.68s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.54it/s]


epoch: 36, train_loss: 0.014484809827804566, val_loss: 2.0673522891402243, acc: 0.387


100%|██████████| 79/79 [02:12<00:00,  1.67s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.91it/s]


epoch: 37, train_loss: 0.014475904870033264, val_loss: 2.0676102433204653, acc: 0.391


100%|██████████| 79/79 [02:10<00:00,  1.66s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.56it/s]


epoch: 38, train_loss: 0.014465565633773804, val_loss: 2.0661358842849733, acc: 0.39


100%|██████████| 79/79 [02:11<00:00,  1.67s/it]
100%|██████████| 2000/2000 [00:27<00:00, 71.48it/s]


epoch: 39, train_loss: 0.014431856679916382, val_loss: 2.0695834946632385, acc: 0.3825


100%|██████████| 79/79 [02:10<00:00,  1.65s/it]
100%|██████████| 2000/2000 [00:27<00:00, 72.80it/s]


epoch: 40, train_loss: 0.014378969359397888, val_loss: 2.068142130851746, acc: 0.388


100%|██████████| 79/79 [02:10<00:00,  1.65s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.30it/s]


epoch: 41, train_loss: 0.01435872061252594, val_loss: 2.0682525297403336, acc: 0.386


100%|██████████| 79/79 [02:13<00:00,  1.69s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.11it/s]


epoch: 42, train_loss: 0.014345350635051728, val_loss: 2.0649502193331717, acc: 0.3875


100%|██████████| 79/79 [02:12<00:00,  1.68s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.42it/s]


epoch: 43, train_loss: 0.014312987625598907, val_loss: 2.0640909324884413, acc: 0.385


100%|██████████| 79/79 [02:13<00:00,  1.69s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.16it/s]


epoch: 44, train_loss: 0.014250141632556916, val_loss: 2.064260559260845, acc: 0.3915


100%|██████████| 79/79 [02:11<00:00,  1.66s/it]
100%|██████████| 2000/2000 [00:27<00:00, 74.00it/s]


epoch: 45, train_loss: 0.014261416387557983, val_loss: 2.05777489477396, acc: 0.397


100%|██████████| 79/79 [02:13<00:00,  1.69s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.31it/s]


epoch: 46, train_loss: 0.014255046546459198, val_loss: 2.0580518522262574, acc: 0.3965


100%|██████████| 79/79 [02:11<00:00,  1.67s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.39it/s]


epoch: 47, train_loss: 0.014194830524921417, val_loss: 2.069194474697113, acc: 0.3905


100%|██████████| 79/79 [02:11<00:00,  1.67s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.60it/s]


epoch: 48, train_loss: 0.014163315749168397, val_loss: 2.065612910568714, acc: 0.392


100%|██████████| 79/79 [02:11<00:00,  1.66s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.21it/s]

epoch: 49, train_loss: 0.014133450412750244, val_loss: 2.0686042371988296, acc: 0.3855





In [8]:
now = datetime.datetime.now()

# Save the final model
checkpoint_path = join(
    BASE_PATH, 
    config.MODEL_CHECKPOINT_PATH, 
    f'{experiment_name}.pth'
)
best_model_state = model_management.save_model(model, checkpoint_path, True, f'model_{experiment_name}')

In [9]:
# Mark the run as finished
wandb.finish()

0,1
acc,▁▂▃▅▅▅▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇███▇██████████████
acc_airport,▁▁▃▄▅▆▇▆▆▅▅▅▅▅▅▆▆▅▅▆▅▆▆▇▆▅▆▅▇▆█▆▅▆▆▇█▆▆▇
acc_bus,▁▁▄▇▅▅▄▃▃▂▃▃▃▃▃▄▄▅▄▅▆▆▆▄▅▆▆▆▇▆▇▆█▇▇▇█▇▇▆
acc_metro,▂▂▁▁▁▂▄▅████▇▇▇▇▇▇▇▇▆▇▆▇▆▆▇▆▆▆▇▇▆▇▇▇▆▆▇▇
acc_metro_station,▁▁▁▂▁▁▁▂▃▅▅▆▇▇▆▅▆▇▆▅▆█▇▇██▆▇▆█▇▇▇▇▆▆▅▇▇▇
acc_park,▄▇▆▆▄▁▃▄▃▄▅▃▅▄▆▅▅▆▅▇▇▅▆▅▄▅█▇█▇▆█▇▆█▇██▆▅
acc_public_square,▄▂▁▁▁▁▁▁▁▁▁▁▁▂▂▃▄▄▄▅▄▄▅▅▅▅▅▆▅▆▆▆▆▇▇▆▇▇██
acc_shopping_mall,▁▁▁▄██▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▅▅▆▆▆▆▆▅▅▅▆▆▆▆▆▅▆▅▅
acc_street_pedestrian,▁▁▁▂▁▁▁▂▃▅▅▆▇▇▆▅▆▇▆▅▆█▇▇██▆▇▆█▇▇▇▇▆▆▅▇▇▇
acc_street_traffic,▁▅██▇▆▆▆▇▇▇▆▇▆▇▇▆▆▆▆▇▆▆▆▆▆▆▆▆▆▆▆▆▆▅▆▆▆▆▆

0,1
acc,0.3855
acc_airport,0.3578
acc_bus,0.26699
acc_metro,0.40678
acc_metro_station,0.27111
acc_park,0.65385
acc_public_square,0.29319
acc_shopping_mall,0.41089
acc_street_pedestrian,0.27111
acc_street_traffic,0.56311
