In [1]:
from os.path import join

def setup_file_system(in_colab):
    if in_colab:
        from google.colab import drive

        # Set the base and mount path
        MOUNT_PATH_DRIVE = '/content/drive'
        BASE_PATH = join(
            MOUNT_PATH_DRIVE, 
            "MyDrive/project_asr"
        )

        # Mount the google drive
        drive.mount(MOUNT_PATH_DRIVE)

        return BASE_PATH

    else:
        return "/workspaces/project_automated_sound_recognition"

In [2]:
import sys
from os import chdir
from os.path import join

# Method to check if the notebook is running in colab or local
IN_COLAB = 'google.colab' in sys.modules

# Set the base path of the project
BASE_PATH = setup_file_system(IN_COLAB)

# Set the base path of the project
chdir(join(BASE_PATH, "src/"))

In [3]:
%load_ext autoreload
%autoreload 2

# Imports
# Utils
import matplotlib as plt
import numpy as np
import wandb
import sys
import importlib
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import datetime
import json
from sklearn.metrics import accuracy_score, confusion_matrix


# DL libraries
import torch
import torch.optim as optim
from torch import nn
import torch.utils.data 
from torch.utils.data import DataLoader

# User libraries
from dataset.audio_sample_dataset import AudioSampleDataset
from model.baseline_model import BaselineModel
from trainer.trainer import train_classification_model
from validator.validator import validate_classification_model
from util import config, util_functions, model_management

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
augmentations = {
    'pitch_shift': {
        'enabled': False,
    },
    'noise': {
        'enabled': False,
    },
    'mixup': {
        'enabled': False,
    },
    'freq_mask': {
        'enabled': False,
    },
    'time_mask': {
        'enabled': False,
    }
}

test_augmentations = {
    'pitch_shift': {
        'enabled': False,
    },
    'noise': {
        'enabled': False,
    },
    'mixup': {
        'enabled': False,
    },
    'freq_mask': {
        'enabled': False,
    },
    'time_mask': {
        'enabled': False,
    }
}

In [5]:
# Get the train and test data
train_dataset = AudioSampleDataset(
        join(BASE_PATH, config.TRAIN_DATA_PATH),
        augmentations
    )
test_dataset = AudioSampleDataset(
        join(BASE_PATH, config.TEST_DATA_PATH),
        test_augmentations
    )

# Place in dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1)

## Setup

In [6]:
# Clear gpu cache
torch.cuda.empty_cache()

# Get the model
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
model.fc = nn.Sequential(
    nn.Linear(in_features=512, out_features= 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(in_features=256, out_features=len(config.LABELS)),
    nn.Softmax(dim= 1)
)
model.to(config.DEVICE)

# Set the optimizer
optimizer = optim.Adam(model.parameters(), lr=config.LR)

# Set the loss fn
criteria = nn.CrossEntropyLoss()

# Set the gradient scaler
grad_scaler = torch.cuda.amp.grad_scaler.GradScaler()

# Setup weights and biasses
wandb.login()

# Get the current time for the checkpoint name
now = datetime.datetime.now()

# Set the wandb experiment name
experiment_name = util_functions.generate_run_name_from_config(augmentations)

# Start wandb
wandb.init(
    settings=wandb.Settings(start_method="fork"),
    project="project_asr", 
    name=experiment_name, 
    config={
        "learning_rate": config.LR,
        "batch_size": config.BATCH_SIZE,
        "epochs": config.EPOCHS,
        "augmentations": json.dumps(augmentations),
    }
)


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrobberdg[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Training

In [7]:
# Set the variables to keep track of the best model
best_validation_loss = 10000
best_model_state = model.state_dict()

for epoch in range(config.EPOCHS):
  # Set the model in training mode
  model.train()
  
  # Train the model
  total_train_loss_this_epoch = train_classification_model(
      model,
      optimizer,
      criteria,
      grad_scaler,
      train_dataloader
  )
  
  # Set the model in evaluation mode
  model.eval()

  # Validate the model
  total_val_loss_this_epoch, pred_classes, true_classes = validate_classification_model(
      model,
      criteria,
      test_dataloader,
  )

  # Calculate the loss values
  train_loss_this_epoch = total_train_loss_this_epoch/len(train_dataloader.dataset)
  val_loss_this_epoch = total_val_loss_this_epoch/len(test_dataloader.dataset)

  # Calculate the accuracy
  acc_avg = accuracy_score(true_classes, pred_classes)

  # Calculate acc per class
  matrix = confusion_matrix(true_classes, pred_classes)
  acc_per_class = matrix.diagonal()/matrix.sum(axis=1)


  # Log the train loss this epoch
  wandb.log({
      'train_loss': train_loss_this_epoch,
      'val_loss': val_loss_this_epoch,
      'acc': acc_avg,
      'acc_airport': acc_per_class[0],
      'acc_shopping_mall': acc_per_class[1],
      'acc_metro_station': acc_per_class[3],
      'acc_street_pedestrian': acc_per_class[3],
      'acc_public_square': acc_per_class[4],
      'acc_street_traffic': acc_per_class[5],
      'acc_tram': acc_per_class[6],
      'acc_bus': acc_per_class[7],
      'acc_metro': acc_per_class[8],
      'acc_park': acc_per_class[9],
  })

  print(f'epoch: {epoch}, train_loss: {train_loss_this_epoch}, val_loss: {val_loss_this_epoch}, acc: {acc_avg}')

  # If this is the best performing model yet, save it
  if val_loss_this_epoch < best_validation_loss:
    # Update the score
    best_validation_loss = val_loss_this_epoch

    now = datetime.datetime.now()

    # Save the model
    checkpoint_path = join(
      BASE_PATH, 
      config.MODEL_CHECKPOINT_PATH, 
      f'{experiment_name}.pth'
    )
    best_model_state = model_management.save_model(model, checkpoint_path, False, '')

100%|██████████| 79/79 [02:54<00:00,  2.21s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.30it/s]


epoch: 0, train_loss: 0.018138077425956727, val_loss: 2.2919150342941284, acc: 0.13


100%|██████████| 79/79 [02:05<00:00,  1.59s/it]
100%|██████████| 2000/2000 [00:26<00:00, 76.02it/s]


epoch: 1, train_loss: 0.01778640995025635, val_loss: 2.2548409333229067, acc: 0.2165


100%|██████████| 79/79 [01:40<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 76.31it/s]


epoch: 2, train_loss: 0.017303900814056396, val_loss: 2.2191661040186883, acc: 0.257


100%|██████████| 79/79 [01:39<00:00,  1.26s/it]
100%|██████████| 2000/2000 [00:26<00:00, 76.54it/s]


epoch: 3, train_loss: 0.016973605370521547, val_loss: 2.1987413313388826, acc: 0.2815


100%|██████████| 79/79 [01:42<00:00,  1.29s/it]
100%|██████████| 2000/2000 [00:26<00:00, 76.49it/s]


epoch: 4, train_loss: 0.016703722763061524, val_loss: 2.1759437415003777, acc: 0.294


100%|██████████| 79/79 [01:39<00:00,  1.26s/it]
100%|██████████| 2000/2000 [00:26<00:00, 76.57it/s]


epoch: 5, train_loss: 0.016454197239875793, val_loss: 2.1643058316707613, acc: 0.317


100%|██████████| 79/79 [01:47<00:00,  1.36s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.47it/s]


epoch: 6, train_loss: 0.016196745574474335, val_loss: 2.15030503898859, acc: 0.322


100%|██████████| 79/79 [01:40<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.16it/s]


epoch: 7, train_loss: 0.015966684663295747, val_loss: 2.139086374461651, acc: 0.3325


100%|██████████| 79/79 [01:41<00:00,  1.29s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.83it/s]


epoch: 8, train_loss: 0.01574509252309799, val_loss: 2.1325585387945174, acc: 0.3335


100%|██████████| 79/79 [01:48<00:00,  1.37s/it]
100%|██████████| 2000/2000 [00:36<00:00, 55.00it/s]


epoch: 9, train_loss: 0.015465640866756438, val_loss: 2.123395310044289, acc: 0.345


100%|██████████| 79/79 [02:08<00:00,  1.63s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.14it/s]


epoch: 10, train_loss: 0.015228671884536743, val_loss: 2.1165751574635507, acc: 0.3495


100%|██████████| 79/79 [01:42<00:00,  1.29s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.21it/s]


epoch: 11, train_loss: 0.01498547579050064, val_loss: 2.117575108408928, acc: 0.352


100%|██████████| 79/79 [01:40<00:00,  1.28s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.86it/s]


epoch: 12, train_loss: 0.014791315865516663, val_loss: 2.1073262475728987, acc: 0.36


100%|██████████| 79/79 [01:42<00:00,  1.30s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.43it/s]


epoch: 13, train_loss: 0.01454827845096588, val_loss: 2.1080139955282213, acc: 0.3565


100%|██████████| 79/79 [01:41<00:00,  1.28s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.77it/s]


epoch: 14, train_loss: 0.014316615235805512, val_loss: 2.105123335957527, acc: 0.358


100%|██████████| 79/79 [01:40<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.63it/s]


epoch: 15, train_loss: 0.014122598075866698, val_loss: 2.105411574959755, acc: 0.3525


100%|██████████| 79/79 [01:40<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.75it/s]


epoch: 16, train_loss: 0.013920034885406494, val_loss: 2.1031957151293756, acc: 0.351


100%|██████████| 79/79 [01:40<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.48it/s]


epoch: 17, train_loss: 0.013733689975738525, val_loss: 2.1067328009605406, acc: 0.3435


100%|██████████| 79/79 [01:39<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.94it/s]


epoch: 18, train_loss: 0.0135571941614151, val_loss: 2.1030418565273283, acc: 0.355


100%|██████████| 79/79 [01:39<00:00,  1.26s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.88it/s]


epoch: 19, train_loss: 0.01337366783618927, val_loss: 2.101785964310169, acc: 0.3525


100%|██████████| 79/79 [01:40<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.75it/s]


epoch: 20, train_loss: 0.013211547541618348, val_loss: 2.1060860375165937, acc: 0.3445


100%|██████████| 79/79 [01:39<00:00,  1.26s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.79it/s]


epoch: 21, train_loss: 0.013107618772983552, val_loss: 2.1061928858160974, acc: 0.345


100%|██████████| 79/79 [01:40<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.97it/s]


epoch: 22, train_loss: 0.012993450713157655, val_loss: 2.1055525609850885, acc: 0.347


100%|██████████| 79/79 [01:40<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.73it/s]


epoch: 23, train_loss: 0.012861967492103577, val_loss: 2.105654460787773, acc: 0.3475


100%|██████████| 79/79 [01:40<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 76.10it/s]


epoch: 24, train_loss: 0.012822169256210327, val_loss: 2.1089450293183325, acc: 0.342


100%|██████████| 79/79 [01:40<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.37it/s]


epoch: 25, train_loss: 0.012713881468772888, val_loss: 2.106264257669449, acc: 0.3515


100%|██████████| 79/79 [01:39<00:00,  1.26s/it]
100%|██████████| 2000/2000 [00:26<00:00, 76.03it/s]


epoch: 26, train_loss: 0.012619020962715149, val_loss: 2.107154433012009, acc: 0.3445


100%|██████████| 79/79 [01:41<00:00,  1.29s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.55it/s]


epoch: 27, train_loss: 0.012547046637535095, val_loss: 2.111755841612816, acc: 0.3425


100%|██████████| 79/79 [01:39<00:00,  1.26s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.86it/s]


epoch: 28, train_loss: 0.012489748680591583, val_loss: 2.107474035024643, acc: 0.3465


100%|██████████| 79/79 [01:39<00:00,  1.26s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.93it/s]


epoch: 29, train_loss: 0.012438191294670106, val_loss: 2.111159262239933, acc: 0.342


100%|██████████| 79/79 [01:40<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.85it/s]


epoch: 30, train_loss: 0.012401812648773193, val_loss: 2.108568811237812, acc: 0.3465


100%|██████████| 79/79 [01:39<00:00,  1.26s/it]
100%|██████████| 2000/2000 [00:26<00:00, 76.39it/s]


epoch: 31, train_loss: 0.01235711017847061, val_loss: 2.1123085062503817, acc: 0.347


100%|██████████| 79/79 [01:41<00:00,  1.29s/it]
100%|██████████| 2000/2000 [00:26<00:00, 76.40it/s]


epoch: 32, train_loss: 0.012314575386047364, val_loss: 2.1115340009331702, acc: 0.341


100%|██████████| 79/79 [01:39<00:00,  1.26s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.25it/s]


epoch: 33, train_loss: 0.012277133643627166, val_loss: 2.1140459545850754, acc: 0.344


100%|██████████| 79/79 [01:39<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.71it/s]


epoch: 34, train_loss: 0.012243996131420135, val_loss: 2.114928304553032, acc: 0.3355


100%|██████████| 79/79 [01:39<00:00,  1.26s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.83it/s]


epoch: 35, train_loss: 0.012220220422744751, val_loss: 2.112961592853069, acc: 0.3445


100%|██████████| 79/79 [01:39<00:00,  1.26s/it]
100%|██████████| 2000/2000 [00:27<00:00, 72.78it/s]


epoch: 36, train_loss: 0.012215866780281068, val_loss: 2.109927894592285, acc: 0.3425


100%|██████████| 79/79 [01:40<00:00,  1.27s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.84it/s]


epoch: 37, train_loss: 0.012172927677631378, val_loss: 2.1132511412501334, acc: 0.341


100%|██████████| 79/79 [01:39<00:00,  1.26s/it]
100%|██████████| 2000/2000 [00:26<00:00, 76.22it/s]


epoch: 38, train_loss: 0.012143469536304474, val_loss: 2.1111253851652148, acc: 0.3385


100%|██████████| 79/79 [01:40<00:00,  1.28s/it]
100%|██████████| 2000/2000 [00:27<00:00, 73.51it/s]

epoch: 39, train_loss: 0.01211595047712326, val_loss: 2.1157547003626824, acc: 0.341





In [8]:
now = datetime.datetime.now()

# Save the final model
checkpoint_path = join(
    BASE_PATH, 
    config.MODEL_CHECKPOINT_PATH, 
    f'{experiment_name}.pth'
)
best_model_state = model_management.save_model(model, checkpoint_path, True, f'model_{experiment_name}')

In [9]:
# Mark the run as finished
wandb.finish()

0,1
acc,▁▄▅▆▆▇▇▇▇████████▇██████▇██▇█▇██▇█▇█▇▇▇▇
acc_airport,▁▁▁▁▁▄▅▆▇▇▇██▇▇▇▇▇▇▇▆▆▇▆▆▆▆▆▇▆▇▇▇▇▆▇▇▇▇▇
acc_bus,▁▃▅▇███▇▅▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
acc_metro,▁▁▁▁▁▁▁▁▁▁▁▂▄▆▇▇▇▇▆▇▇▆▇▇▇█▇█▇█▆▇▇▇▇▇▇▇▇▇
acc_metro_station,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▃▄▅▆▆▆▆▆▇▇▇█▇▇▇█▇▇▇▇▆▇▇▇
acc_park,▁▄▆▆▇▇▇█▇▇█▇█▇██▇▇█████████▇▇██▇█▇██▇▇▇▇
acc_public_square,▄▃▁▁▁▃▄▄▆▅▆█▇▇▇▇▇▇▇▆▆▆▇▆▆▆▇▇▇▇▇▇▇▇▆▇▇▇▆▆
acc_shopping_mall,████▇▆▆▅▄▄▅▄▃▃▃▃▃▂▂▂▂▂▁▂▁▂▂▁▁▁▁▁▂▁▂▂▂▂▁▂
acc_street_pedestrian,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▃▄▅▆▆▆▆▆▇▇▇█▇▇▇█▇▇▇▇▆▇▇▇
acc_street_traffic,▁▆▇███▇█▇▇▇▆▇▇▇▆▆▅▆▆▅▆▆▆▅▅▅▅▆▅▅▅▅▆▅▅▅▅▅▅

0,1
acc,0.341
acc_airport,0.34404
acc_bus,0.2767
acc_metro,0.34463
acc_metro_station,0.26222
acc_park,0.54808
acc_public_square,0.18325
acc_shopping_mall,0.36634
acc_street_pedestrian,0.26222
acc_street_traffic,0.49515
