In [1]:
from os.path import join

def setup_file_system(in_colab):
    if in_colab:
        from google.colab import drive

        # Set the base and mount path
        MOUNT_PATH_DRIVE = '/content/drive'
        BASE_PATH = join(
            MOUNT_PATH_DRIVE, 
            "MyDrive/project_asr"
        )

        # Mount the google drive
        drive.mount(MOUNT_PATH_DRIVE)

        return BASE_PATH

    else:
        return "/workspaces/project_automated_sound_recognition"

In [2]:
import sys
from os import chdir
from os.path import join

# Method to check if the notebook is running in colab or local
IN_COLAB = 'google.colab' in sys.modules

# Set the base path of the project
BASE_PATH = setup_file_system(IN_COLAB)

# Set the base path of the project
chdir(join(BASE_PATH, "src/"))

In [3]:
%load_ext autoreload
%autoreload 2

# Imports
# Utils
import matplotlib as plt
import numpy as np
import wandb
import sys
import importlib
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import datetime
import json
from sklearn.metrics import accuracy_score, confusion_matrix


# DL libraries
import torch
import torch.optim as optim
from torch import nn
import torch.utils.data 
from torch.utils.data import DataLoader

# User libraries
from dataset.audio_sample_dataset import AudioSampleDataset
from model.baseline_model import BaselineModel
from trainer.trainer import train_classification_model
from validator.validator import validate_classification_model
from util import config, util_functions, model_management

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
augmentations = {
    'pitch_shift': {
        'enabled': True,
        'p': 0.25,
        'min_semitones': -4, 
        "max_semitones": 4,
    },
    'noise': {
        'enabled': True,
        'p': 0.25,
        'min_amplitude': 0.001,
        'max_amplitude': 0.015,
    },
    'mixup': {
        'enabled': True,
        'p': 0.25,
        'alpha': 0.2,
    },
    'freq_mask': {
        'enabled': True,
        'p': 0.25,
        'freq_mask_param': 5,
    },
    'time_mask': {
        'enabled': True,
        'p': 0.25,
        'time_mask_param': 10,
    }
}

test_augmentations = {
    'pitch_shift': {
        'enabled': False,
    },
    'noise': {
        'enabled': False,
    },
    'mixup': {
        'enabled': False,
    },
    'freq_mask': {
        'enabled': False,
    },
    'time_mask': {
        'enabled': False,
    }
}

In [5]:
# Get the train and test data
train_dataset = AudioSampleDataset(
        join(BASE_PATH, config.TRAIN_DATA_PATH),
        augmentations
    )
test_dataset = AudioSampleDataset(
        join(BASE_PATH, config.TEST_DATA_PATH),
        test_augmentations
    )

# Place in dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, batch_size=1)

## Setup

In [6]:
# Clear gpu cache
torch.cuda.empty_cache()

# Get the model
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
model.fc = nn.Sequential(
    nn.Linear(in_features=512, out_features= 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(in_features=256, out_features=len(config.LABELS)),
    nn.Softmax(dim= 1)
)
model.to(config.DEVICE)

# Set the optimizer
optimizer = optim.Adam(model.parameters(), lr=config.LR)

# Set the loss fn
criteria = nn.CrossEntropyLoss()

# Set the gradient scaler
grad_scaler = torch.cuda.amp.grad_scaler.GradScaler()

# Setup weights and biasses
wandb.login()

# Get the current time for the checkpoint name
now = datetime.datetime.now()

# Set the wandb experiment name
experiment_name = util_functions.generate_run_name_from_config(augmentations)

# Start wandb
wandb.init(
    settings=wandb.Settings(start_method="fork"),
    project="project_asr", 
    name=experiment_name, 
    config={
        "learning_rate": config.LR,
        "batch_size": config.BATCH_SIZE,
        "epochs": config.EPOCHS,
        "augmentations": json.dumps(augmentations),
    }
)


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrobberdg[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Training

In [7]:
# Set the variables to keep track of the best model
best_validation_loss = 10000
best_model_state = model.state_dict()

for epoch in range(config.EPOCHS):
  # Set the model in training mode
  model.train()

  # Train the model
  total_train_loss_this_epoch = train_classification_model(
      model,
      optimizer,
      criteria,
      grad_scaler,
      train_dataloader
  )
  
  # Set the model in evaluation mode
  model.eval()

  # Validate the model
  total_val_loss_this_epoch, pred_classes, true_classes = validate_classification_model(
      model,
      criteria,
      test_dataloader,
  )

  # Calculate the loss values
  train_loss_this_epoch = total_train_loss_this_epoch/len(train_dataloader.dataset)
  val_loss_this_epoch = total_val_loss_this_epoch/len(test_dataloader.dataset)

  # Calculate the accuracy
  acc_avg = accuracy_score(true_classes, pred_classes)

  # Calculate acc per class
  matrix = confusion_matrix(true_classes, pred_classes)
  acc_per_class = matrix.diagonal()/matrix.sum(axis=1)


  # Log the train loss this epoch
  wandb.log({
      'train_loss': train_loss_this_epoch,
      'val_loss': val_loss_this_epoch,
      'acc': acc_avg,
      'acc_per_class': acc_per_class

  })

  print(f'epoch: {epoch}, train_loss: {train_loss_this_epoch}, val_loss: {val_loss_this_epoch}, acc: {acc_avg}')
  print(acc_per_class)

  # If this is the best performing model yet, save it
  if val_loss_this_epoch < best_validation_loss:
    # Update the score
    best_validation_loss = val_loss_this_epoch

    now = datetime.datetime.now()

    # Save the model
    checkpoint_path = join(
      BASE_PATH, 
      config.MODEL_CHECKPOINT_PATH, 
      f'{experiment_name}.pth'
    )
    best_model_state = model_management.save_model(model, checkpoint_path, False, '')

100%|██████████| 79/79 [05:47<00:00,  4.40s/it]
100%|██████████| 2000/2000 [00:27<00:00, 72.48it/s]


epoch: 0, train_loss: 0.018155230593681335, val_loss: 2.29294194740057, acc: 0.1705
[0.00458716 0.02475248 0.06842105 0.         0.05235602 0.33495146
 0.68926554 0.16504854 0.06779661 0.36057692]


100%|██████████| 79/79 [05:43<00:00,  4.35s/it]
100%|██████████| 2000/2000 [00:27<00:00, 72.76it/s]


epoch: 1, train_loss: 0.01797297444343567, val_loss: 2.260622967362404, acc: 0.2165
[0.         0.02970297 0.01578947 0.00444444 0.02617801 0.7038835
 0.55932203 0.14563107 0.08474576 0.62019231]


100%|██████████| 79/79 [05:28<00:00,  4.16s/it]
100%|██████████| 2000/2000 [00:26<00:00, 76.50it/s]


epoch: 2, train_loss: 0.01765333662033081, val_loss: 2.226066596210003, acc: 0.2335
[0.00458716 0.2029703  0.00526316 0.         0.0104712  0.74757282
 0.49152542 0.16990291 0.09039548 0.625     ]


100%|██████████| 79/79 [05:30<00:00,  4.18s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.78it/s]


epoch: 3, train_loss: 0.01736545548439026, val_loss: 2.207856089115143, acc: 0.262
[0.         0.52475248 0.         0.         0.         0.73300971
 0.46892655 0.19902913 0.10734463 0.59615385]


100%|██████████| 79/79 [05:25<00:00,  4.11s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.77it/s]


epoch: 4, train_loss: 0.017156323444843292, val_loss: 2.188377488195896, acc: 0.2835
[0.         0.73762376 0.         0.         0.         0.67961165
 0.49717514 0.20873786 0.14124294 0.58653846]


100%|██████████| 79/79 [05:28<00:00,  4.16s/it]
100%|██████████| 2000/2000 [00:26<00:00, 74.94it/s]


epoch: 5, train_loss: 0.01689266753196716, val_loss: 2.1716101933717726, acc: 0.287
[0.         0.83168317 0.         0.         0.         0.6407767
 0.48587571 0.18932039 0.23163842 0.51923077]


100%|██████████| 79/79 [05:29<00:00,  4.18s/it]
100%|██████████| 2000/2000 [00:26<00:00, 76.32it/s]


epoch: 6, train_loss: 0.01673360159397125, val_loss: 2.158861037969589, acc: 0.2935
[0.         0.8019802  0.         0.         0.         0.69417476
 0.46892655 0.12621359 0.3559322  0.52884615]


100%|██████████| 79/79 [05:27<00:00,  4.15s/it]
100%|██████████| 2000/2000 [00:26<00:00, 76.32it/s]


epoch: 7, train_loss: 0.016575013053417204, val_loss: 2.1525671423077584, acc: 0.297
[0.         0.84653465 0.         0.         0.         0.65533981
 0.42372881 0.09223301 0.49717514 0.50961538]


100%|██████████| 79/79 [05:29<00:00,  4.17s/it]
100%|██████████| 2000/2000 [00:26<00:00, 75.03it/s]


epoch: 8, train_loss: 0.01645995932817459, val_loss: 2.1468028410077094, acc: 0.297
[0.         0.84653465 0.         0.         0.         0.6407767
 0.35028249 0.13106796 0.51977401 0.52884615]


100%|██████████| 79/79 [05:51<00:00,  4.45s/it]
100%|██████████| 2000/2000 [00:28<00:00, 69.93it/s]


epoch: 9, train_loss: 0.016411147558689117, val_loss: 2.1393866870999334, acc: 0.305
[0.         0.81683168 0.         0.00888889 0.         0.68446602
 0.38983051 0.13592233 0.51412429 0.54807692]


 42%|████▏     | 33/79 [02:36<03:36,  4.72s/it]

In [None]:
now = datetime.datetime.now()

# Save the final model
checkpoint_path = join(
    BASE_PATH, 
    config.MODEL_CHECKPOINT_PATH, 
    f'{experiment_name}.pth'
)
best_model_state = model_management.save_model(model, checkpoint_path, True, f'model_{experiment_name}')

In [None]:
# Mark the run as finished
wandb.finish()