In [1]:
import warnings
warnings.filterwarnings("ignore")

# some basic libraries
import sys  
import os
import seaborn as sn
import numpy as np
import tensorflow as tf
from numpy.lib.function_base import average
import math
import bisect
import pickle
import random
from platform import python_version

# ipython display
from IPython.core.display import display

# pytorch lightning
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# pytorch
import torch
# import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
import torch.optim as optim
import torch.distributed as dist

from torch.nn.parameter import Parameter
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split, sampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset

from torchmetrics import Accuracy
from torchvision import transforms
from torchlibrosa.stft import STFT, ISTFT, magphase

# librosa audio processing
# import librosa

# sound file
import soundfile as sf

# sklearn machine learning library
from sklearn import metrics
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
from sklearn import preprocessing

# tensorboard for monitoring training progress
# import tensorboard

from htsat_utils import do_mixup, get_mix_lambda, do_mixup_label
from htsat_utils import get_loss_func, d_prime, float32_to_int16

# import the HTSAT model
from htsat_model import HTSAT_Swin_Transformer 

# import project echo modules
import baseline_config
import echo_dataset_melspec
import echo_data_module
import echo_module

# print system information
print('Python Version     : ', python_version())
print('TensorFlow Version : ', tf.__version__)
print('Pytorch Version    : ', torch.__version__)

Python Version     :  3.9.15
TensorFlow Version :  2.10.1
Pytorch Version    :  1.13.0+cu117


In [2]:
# the complete dataset
complete_dataset = echo_dataset_melspec.EchoDatasetMelspec(baseline_config.dataset_path)

# check the dataset looks right
for item in complete_dataset:
    print(item['filename'], item['waveform'].shape, item['melspec'].shape, item['target'])

c:/birdclef2022/sheowl\XC658728.ogg (160000,) (1, 1778, 128) tensor(2)
c:/birdclef2022/brant\XC149188.ogg (160000,) (1, 1778, 128) tensor(0)
c:/birdclef2022/brant\XC437760.ogg (160000,) (1, 1778, 128) tensor(0)
c:/birdclef2022/spodov\XC203404.ogg (160000,) (1, 1778, 128) tensor(3)
c:/birdclef2022/sheowl\XC607790.ogg (160000,) (1, 1778, 128) tensor(2)
c:/birdclef2022/sheowl\XC618443.ogg (160000,) (1, 1778, 128) tensor(2)
c:/birdclef2022/brant\XC163349.ogg (160000,) (1, 1778, 128) tensor(0)
c:/birdclef2022/spodov\XC124409.ogg (160000,) (1, 1778, 128) tensor(3)
c:/birdclef2022/spodov\XC359363.ogg (160000,) (1, 1778, 128) tensor(3)
c:/birdclef2022/jabwar\XC475525.ogg (160000,) (1, 1778, 128) tensor(1)
c:/birdclef2022/sheowl\XC607792.ogg (160000,) (1, 1778, 128) tensor(2)
c:/birdclef2022/brant\XC408265.ogg (160000,) (1, 1778, 128) tensor(0)
c:/birdclef2022/wiltur\XC598621.ogg (160000,) (1, 1778, 128) tensor(4)
c:/birdclef2022/jabwar\XC486788.ogg (160000,) (1, 1778, 128) tensor(1)
c:/birdcle

In [3]:
# Create an instance of the HTSAT model
model = HTSAT_Swin_Transformer(
    spec_size=baseline_config.htsat_spec_size,
    patch_size=baseline_config.htsat_patch_size,
    in_chans=1,
    num_classes=baseline_config.classes_num,
    window_size=baseline_config.htsat_window_size,
    config = baseline_config,
    depths = baseline_config.htsat_depth,
    embed_dim = baseline_config.htsat_dim,
    patch_stride=baseline_config.htsat_stride,
    num_heads=baseline_config.htsat_num_head)

In [4]:
# show model detailed structure
model

HTSAT_Swin_Transformer(
  (spectrogram_extractor): Spectrogram(
    (stft): STFT(
      (conv_real): Conv1d(1, 1025, kernel_size=(2048,), stride=(90,), bias=False)
      (conv_imag): Conv1d(1, 1025, kernel_size=(2048,), stride=(90,), bias=False)
    )
  )
  (logmel_extractor): LogmelFilterBank()
  (spec_augmenter): SpecAugmentation(
    (time_dropper): DropStripes()
    (freq_dropper): DropStripes()
  )
  (bn0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 48, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): BasicLayer(
      dim=48, input_resolution=(128, 128), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=48, input_resolution=(128, 128), num_heads=4, window_size=8, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((48,), ep

In [5]:
def create_echo_datasets():
    
    if not os.path.exists('datasets/'):
        os.makedirs('datasets/')
            
    # this only needs to be run once to avoid data leakage when re-training from checkpoint
    if not os.path.exists('datasets/train_dataset.pkl'):
        # the complete dataset
        complete_dataset = echo_dataset_melspec.EchoDatasetMelspec(htsat_config.dataset_path)

        # split the dataset train/validation/test 70%/15%/15%
        split_datasets = torch.utils.data.random_split(complete_dataset, [0.80,0.20]) # ,0.15])

        # assign the split datasets
        train_dataset = split_datasets[0]
        eval_dataset  = split_datasets[1]   
        
        torch.save(train_dataset,'datasets/train_dataset.pkl') 
        torch.save(eval_dataset,'datasets/eval_dataset.pkl') 

def load_echo_datasets():
    train_dataset = torch.load('datasets/train_dataset.pkl')
    eval_dataset = torch.load('datasets/eval_dataset.pkl')
    return train_dataset, eval_dataset

# create the datasets if they don't exist
create_echo_datasets()

# load the same dataset every time
train_dataset, eval_dataset = load_echo_datasets()

In [6]:
# get the number of available GPUs
device_num = torch.cuda.device_count()

# create the audio data set pipeline
audio_pipeline = echo_data_module.EchoDataModule(train_dataset, eval_dataset, device_num)

# checkpoint to record snapshots during training
checkpoint_callback = ModelCheckpoint(
    dirpath = 'checkpoints/',
    monitor = "acc",
    filename='l-{epoch:d}-{acc:.3f}',
    save_top_k = 5,
    mode = "max",
)

# resume from here in the training
checkpoint_resume = None # 'checkpoints/l-epoch=212-acc=0.596.ckpt'

# construct the model trainer
trainer = pl.Trainer(
        deterministic=False,
        default_root_dir = baseline_config.workspace,
        gpus = device_num, 
        val_check_interval = 1.0,
        max_epochs = baseline_config.max_epoch,
        auto_lr_find = True,    
        sync_batchnorm = True,
        callbacks = [checkpoint_callback],
        accelerator = "ddp" if device_num > 1 else None,
        num_sanity_val_steps = 0,
        resume_from_checkpoint = checkpoint_resume, 
        replace_sampler_ddp = False,
        gradient_clip_val=1.0
    )

# construct the model
sed_model = HTSAT_Swin_Transformer(
        spec_size=baseline_config.htsat_spec_size,
        patch_size=baseline_config.htsat_patch_size,
        in_chans=1,
        num_classes=baseline_config.classes_num,
        window_size=baseline_config.htsat_window_size,
        config = baseline_config,
        depths = baseline_config.htsat_depth,
        embed_dim = baseline_config.htsat_dim,
        patch_stride=baseline_config.htsat_stride,
        num_heads=baseline_config.htsat_num_head
    )

# wrapper to track metrics during training 
model = echo_module.EchoModule(
        sed_model = sed_model, 
        config = baseline_config,
        dataset = eval_dataset
    )

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


In [7]:
def pre_train(model):
    ckpt = torch.load('swin_tiny_patch4_window7_224.pth', map_location="cpu")
    # load pretrain model
    ckpt = ckpt["model"]
    found_parameters = []
    unfound_parameters = []
    model_params = dict(model.state_dict())

    for key in model_params:
        m_key = key.replace("sed_model.", "")
        if m_key in ckpt:
            if m_key == "patch_embed.proj.weight":
                ckpt[m_key] = torch.mean(ckpt[m_key], dim = 1, keepdim = True)
            if m_key == "head.weight" or m_key == "head.bias":
                ckpt.pop(m_key)
                unfound_parameters.append(key)
                continue
            assert model_params[key].shape==ckpt[m_key].shape, "%s is not match, %s vs. %s" %(key, str(model_params[key].shape), str(ckpt[m_key].shape))
            found_parameters.append(key)
            ckpt[key] = ckpt.pop(m_key)
        else:
            unfound_parameters.append(key)
    print("pretrain param num: %d \t wrapper param num: %d"%(len(found_parameters), len(ckpt.keys())))
    print("unfound parameters: ", unfound_parameters)
    model.load_state_dict(ckpt, strict = False)
    model_params = dict(model.named_parameters())

# can't seem to this to work...
# pre_train(sed_model)

In [8]:
# train the model
trainer.fit(model, audio_pipeline)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                   | Params
-----------------------------------------------------
0 | sed_model | HTSAT_Swin_Transformer | 11.3 M
-----------------------------------------------------
6.9 M     Trainable params
4.3 M     Non-trainable params
11.3 M    Total params
45.052    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]