In [1]:
# some basic libraries
import sys  
import os
import seaborn as sn
import numpy as np
from numpy.lib.function_base import average
import math
import bisect
import pickle
import random
from platform import python_version

# ipython display
from IPython.core.display import display

# pytorch lightning
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger

# pytorch
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
import torch.optim as optim
import torch.distributed as dist

from torch.nn.parameter import Parameter
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split, sampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset

from torchmetrics import Accuracy
from torchvision import transforms
from torchlibrosa.stft import STFT, ISTFT, magphase

# librosa audio processing
import librosa

# sound file
import soundfile as sf

# sklearn machine learning library
from sklearn import metrics
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
from sklearn import preprocessing

# tensorboard for monitoring training progress
import tensorboard

from htsat_utils import do_mixup, get_mix_lambda, do_mixup_label
from htsat_utils import get_loss_func, d_prime, float32_to_int16

import htsat_config
from htsat_model import HTSAT_Swin_Transformer 

import warnings
warnings.filterwarnings("ignore")

# print system information
print('Python Version     : ', python_version())

  from IPython.core.display import display


Python Version     :  3.9.13


In [2]:
# Instantiate the HTSAT Swin-Transformer model, taking every size/shape setting from htsat_config
model = HTSAT_Swin_Transformer(
    config=htsat_config,
    # spectrogram patching
    spec_size=htsat_config.htsat_spec_size,
    patch_size=htsat_config.htsat_patch_size,
    patch_stride=htsat_config.htsat_stride,
    in_chans=1,
    # transformer layout
    embed_dim=htsat_config.htsat_dim,
    depths=htsat_config.htsat_depth,
    num_heads=htsat_config.htsat_num_head,
    window_size=htsat_config.htsat_window_size,
    # number of output classes
    num_classes=htsat_config.classes_num,
)

In [3]:
# Show the model's module hierarchy (the nn.Module repr) as this cell's output
model

HTSAT_Swin_Transformer(
  (spectrogram_extractor): Spectrogram(
    (stft): STFT(
      (conv_real): Conv1d(1, 513, kernel_size=(1024,), stride=(256,), bias=False)
      (conv_imag): Conv1d(1, 513, kernel_size=(1024,), stride=(256,), bias=False)
    )
  )
  (logmel_extractor): LogmelFilterBank()
  (spec_augmenter): SpecAugmentation(
    (time_dropper): DropStripes()
    (freq_dropper): DropStripes()
  )
  (bn0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 8, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): BasicLayer(
      dim=8, input_resolution=(64, 64), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=8, input_resolution=(64, 64), num_heads=4, window_size=8, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((8,), eps=1e-05, 

In [4]:
class EchoDataset(Dataset):
    """Audio classification dataset built from a ``<root>/<class_name>/<file>`` tree.

    Every immediate sub-directory of ``dataset_root`` is one class and every file
    inside it is one sample. Class names are mapped to integer ids with a
    ``LabelEncoder``. Only paths and metadata are held in memory; audio is decoded
    lazily in ``__getitem__``.
    """

    def __init__(self, dataset_root):

        # track where this dataset is from
        self.dataset_root = dataset_root

        # track total number of samples (set by from_dir_structure)
        self.total_size = 0

        # load the dataset file records (not the audio itself) into RAM
        self.audio_dataset = self.from_dir_structure(dataset_root)

        # shuffle the dataset entries
        self.shuffle_dataset()

    def from_dir_structure(self, dataset_path):
        """Scan ``dataset_path`` and build the list of sample records.

        Args:
            dataset_path: root directory whose sub-directories are the classes.

        Returns:
            list of ``(path, class_name, duration_seconds, target)`` tuples, where
            ``target`` is a scalar ``torch.int64`` tensor with the class id.
            Also stores the number of records in ``self.total_size``.
        """
        # each sub directory represents a data class
        with os.scandir(dataset_path) as entries:
            subfolders = [f for f in entries if f.is_dir()]

        # sparse (integer) encoding of the class names
        le = preprocessing.LabelEncoder()
        le.fit([folder.name for folder in subfolders])

        audio_dataset = []
        for subfolder in subfolders:

            # the class id only depends on the folder, so encode it once per class
            target_id = le.transform([subfolder.name])[0]

            with os.scandir(subfolder.path) as entries:
                audiofiles = [f for f in entries if f.is_file()]

            for audiofile in audiofiles:
                # get metadata (duration in seconds) without decoding the clip
                duration = librosa.get_duration(filename=audiofile.path)

                # int64 scalar tensor, the type the cross-entropy loss expects
                target = torch.tensor(target_id, dtype=torch.int64)

                # record the path and target
                audio_dataset.append((audiofile.path, subfolder.name, duration, target))

        # store the total length
        self.total_size = len(audio_dataset)

        return audio_dataset

    def shuffle_dataset(self):
        """Shuffle the in-memory list of sample records in place."""
        random.shuffle(self.audio_dataset)

    @staticmethod
    def _random_offset(duration, clip_length):
        """Pick a random start time (seconds) so a ``clip_length`` window fits in the file.

        Files no longer than one clip are read from the start (offset 0); this
        guarantees the offset is never negative.
        """
        if duration > clip_length:
            return random.uniform(0, duration - clip_length)
        return 0

    @staticmethod
    def _fix_length(waveform, target_len):
        """Return ``waveform`` with exactly ``target_len`` samples on its last axis.

        Shorter inputs are zero-padded symmetrically (extra sample on the right),
        matching ``librosa.util.pad_center``; longer inputs are truncated.
        """
        n = waveform.shape[-1]
        if n < target_len:
            left = (target_len - n) // 2
            pad_width = [(0, 0)] * (waveform.ndim - 1) + [(left, target_len - n - left)]
            waveform = np.pad(waveform, pad_width, mode='constant')
        elif n > target_len:
            waveform = waveform[..., :target_len]
        return waveform

    # get sample at location 'index'
    def __getitem__(self, index):
        """Load waveform and target of an audio clip.
        Args:
            index: the index number
        Return: {
            "filename": str,
            "waveform": (clip_samples,) mono array, clip_samples = CLIP_LENGTH * sample_rate,
            "target": scalar torch.int64 tensor holding the class id
        }
        """
        path, _, duration, target = self.audio_dataset[index]
        clip_length = htsat_config.CLIP_LENGTH
        clip_samples = int(round(clip_length * htsat_config.sample_rate))

        # random window start inside files longer than one clip
        offset = self._random_offset(duration, clip_length)

        # load the (mono, resampled) waveform
        waveform, sr = librosa.load(path,
                                    sr=htsat_config.sample_rate,
                                    duration=min(duration, clip_length),
                                    offset=offset,
                                    mono=True)

        # always force the exact clip length: short files are centre-padded, and
        # resampling can leave full-length reads a few samples short or long,
        # which would otherwise break batch collation
        waveform = self._fix_length(waveform, clip_samples)

        # return a dictionary with the sample data
        return {
            "filename": path,
            "waveform": waveform,
            "target": target,
        }

    def __len__(self):
        return self.total_size

In [5]:
class SEDWrapper(pl.LightningModule):
    """Lightning wrapper around an HTSAT sound-event-detection model.

    Runs the train/validation/test loops, reports accuracy (gathering predictions
    across ranks when a distributed process group is active) and builds the
    optimiser and learning-rate schedule.

    Args:
        sed_model: underlying model; called as ``sed_model(x, mix_lambda[, infer_mode])``
            and indexed for ``"clipwise_output"`` / ``"framewise_output"``.
        config: configuration object providing ``loss_type``, ``learning_rate``,
            ``lr_rate``, ``lr_scheduler_epoch``, ``dataset_type``, ``fl_local`` and,
            for localisation output, ``heatmap_dir`` / ``test_file``.
        dataset: dataset reference kept for a (currently disabled) per-epoch reshuffle.
    """

    def __init__(self, sed_model, config, dataset):
        super().__init__()
        self.sed_model = sed_model
        self.config = config
        self.dataset = dataset
        self.loss_func = get_loss_func(config.loss_type)

    def evaluate_metric(self, pred, ans):
        """Top-1 accuracy of score matrix ``pred`` (argmax over axis 1) vs class ids ``ans``."""
        acc = accuracy_score(ans, np.argmax(pred, 1))
        return {"acc": acc}

    def forward(self, x, mix_lambda = None):
        """Run the model and return ``(clipwise_output, framewise_output)``."""
        output_dict = self.sed_model(x, mix_lambda)
        return output_dict["clipwise_output"], output_dict["framewise_output"]

    def inference(self, x):
        """Evaluate numpy batch ``x`` in eval mode; return every model output as numpy."""
        self.device_type = next(self.parameters()).device
        self.eval()
        x = torch.from_numpy(x).float().to(self.device_type)
        # no autograd graph is needed for pure inference
        with torch.no_grad():
            # third positional argument presumably switches the model to inference mode — TODO confirm
            output_dict = self.sed_model(x, None, True)
        for key in output_dict.keys():
            output_dict[key] = output_dict[key].detach().cpu().numpy()
        return output_dict

    def training_step(self, batch, batch_idx):
        """Supervised step on ``batch["waveform"]`` / ``batch["target"]`` (mixup disabled)."""
        self.device_type = next(self.parameters()).device
        mix_lambda = None
        pred, _ = self(batch["waveform"], mix_lambda)
        loss = self.loss_func(pred, batch["target"])
        self.log("loss", loss, on_epoch= True, prog_bar=True)
        return loss

    def training_epoch_end(self, outputs):
        print("end epoch")
        # re-shuffle the audio dataset
        # self.dataset.generate_queue()

    def validation_step(self, batch, batch_idx):
        """Return ``[clip predictions, targets]`` for aggregation in ``validation_epoch_end``."""
        pred, _ = self(batch["waveform"])
        return [pred.detach(), batch["target"].detach()]

    @staticmethod
    def _is_distributed():
        """True only when a torch.distributed process group with more than one rank is active."""
        return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1

    def _metric_from_tensors(self, pred, target):
        """Move ``pred`` / ``target`` to numpy and compute the accuracy dict.

        For ``dataset_type == "scv2"`` the targets are one-hot rows and are
        converted to class ids first (applied on every code path).
        """
        pred = pred.cpu().numpy()
        target = target.cpu().numpy()
        if self.config.dataset_type == "scv2":
            target = np.argmax(target, 1)
        return self.evaluate_metric(pred, target)

    def _log_accuracy(self, pred, target):
        """Compute accuracy over all ranks and log it under ``"acc"``.

        Expects ``self.device_type`` to have been set by the caller (used in the
        rank-0 debug print).
        """
        if not self._is_distributed():
            metric_dict = self._metric_from_tensors(pred, target)
            self.log("acc", metric_dict["acc"], on_epoch = True, prog_bar=True, sync_dist=False)
            return

        world_size = dist.get_world_size()
        gather_pred = [torch.zeros_like(pred) for _ in range(world_size)]
        gather_target = [torch.zeros_like(target) for _ in range(world_size)]
        dist.barrier()

        metric_dict = {"acc": 0.}
        dist.all_gather(gather_pred, pred)
        dist.all_gather(gather_target, target)
        if dist.get_rank() == 0:
            metric_dict = self._metric_from_tensors(
                torch.cat(gather_pred, dim = 0), torch.cat(gather_target, dim = 0))
            print(self.device_type, world_size, metric_dict, flush = True)

        # only rank 0 holds the real value (others log 0); presumably sync_dist
        # mean-reduces over ranks (Lightning default), so scaling by world_size
        # makes the logged value equal rank 0's accuracy
        self.log("acc", metric_dict["acc"] * float(world_size), on_epoch = True, prog_bar=True, sync_dist=True)
        dist.barrier()

    def validation_epoch_end(self, validation_step_outputs):
        """Concatenate all validation step outputs and log accuracy."""
        self.device_type = next(self.parameters()).device
        if not validation_step_outputs:
            return
        pred = torch.cat([d[0] for d in validation_step_outputs], dim = 0)
        target = torch.cat([d[1] for d in validation_step_outputs], dim = 0)
        self._log_accuracy(pred, target)

    def time_shifting(self, x, shift_len):
        """Circularly shift ``x`` left by ``shift_len`` samples along dim 1."""
        shift_len = int(shift_len)
        new_sample = torch.cat([x[:, shift_len:], x[:, :shift_len]], dim = 1)
        return new_sample

    def test_step(self, batch, batch_idx):
        """Predict with optional time-shift test-time augmentation.

        Returns numpy localisation outputs when ``config.fl_local`` is set,
        otherwise ``[averaged predictions, targets]`` tensors.
        """
        self.device_type = next(self.parameters()).device
        preds = []
        # time shifting optimization
        if self.config.fl_local or self.config.dataset_type != "audioset":
            shift_num = 1 # framewise localization cannot allow the time shifting
        else:
            shift_num = 10
        for i in range(shift_num):
            pred, pred_map = self(batch["waveform"])
            preds.append(pred.unsqueeze(0))
            batch["waveform"] = self.time_shifting(batch["waveform"], shift_len = 100 * (i + 1))
        preds = torch.cat(preds, dim=0)
        pred = preds.mean(dim = 0)
        if self.config.fl_local:
            return [
                pred.detach().cpu().numpy(),
                pred_map.detach().cpu().numpy(),
                batch["audio_name"],
                batch["real_len"].cpu().numpy()
            ]
        else:
            return [pred.detach(), batch["target"].detach()]

    def test_epoch_end(self, test_step_outputs):
        """Save localisation heatmaps (``fl_local``) or log test accuracy.

        Works both single-device and distributed (the accuracy path previously
        required an initialised process group).
        """
        self.device_type = next(self.parameters()).device
        if self.config.fl_local:
            pred = np.concatenate([d[0] for d in test_step_outputs], axis = 0)
            pred_map = np.concatenate([d[1] for d in test_step_outputs], axis = 0)
            audio_name = np.concatenate([d[2] for d in test_step_outputs], axis = 0)
            real_len = np.concatenate([d[3] for d in test_step_outputs], axis = 0)
            heatmap_file = os.path.join(self.config.heatmap_dir, self.config.test_file + "_" + str(self.device_type) + ".npy")
            save_npy = [
                {
                    "audio_name": audio_name[i],
                    "heatmap": pred_map[i],
                    "pred": pred[i],
                    "real_len":real_len[i]
                }
                for i in range(len(pred))
            ]
            np.save(heatmap_file, save_npy)
        else:
            if not test_step_outputs:
                return
            pred = torch.cat([d[0] for d in test_step_outputs], dim = 0)
            target = torch.cat([d[1] for d in test_step_outputs], dim = 0)
            self._log_accuracy(pred, target)

    def configure_optimizers(self):
        """AdamW over trainable parameters plus a per-epoch LambdaLR multiplier."""

        # network optimiser
        optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr = self.config.learning_rate,
            betas = (0.9, 0.999), eps = 1e-08, weight_decay = 0.05,
        )

        # learning rate scheduler: returns a multiplier applied to config.learning_rate
        def lr_foo(epoch):
            if epoch < 3:
                # warm up: lr_rate[0], lr_rate[1], lr_rate[2] for epochs 0..2
                lr_scale = self.config.lr_rate[epoch]
            else:
                # after warm-up: count the milestones in lr_scheduler_epoch that are
                # below `epoch` and walk lr_rate backwards (lr_rate[-1], [-2], [-3]);
                # past the third milestone decay exponentially with a 0.03 floor
                lr_pos = int(-1 - bisect.bisect_left(self.config.lr_scheduler_epoch, epoch))
                if lr_pos < -3:
                    lr_scale = max(self.config.lr_rate[0] * (0.98 ** epoch), 0.03 )
                else:
                    lr_scale = self.config.lr_rate[lr_pos]
            return lr_scale

        # construct the learning rate scheduler
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lr_foo
        )

        return [optimizer], [scheduler]

In [6]:
class EchoDataModule(pl.LightningDataModule):
    """Lightning data module serving the train and eval splits of an audio dataset.

    Args:
        train_dataset: dataset served by ``train_dataloader``.
        eval_dataset: dataset served by both ``val_dataloader`` and ``test_dataloader``.
        device_num: number of devices. ``htsat_config.batch_size`` is divided
            across them, and a ``DistributedSampler`` shards the data when it is > 1.
            0 (a CPU-only machine) is treated as a single device.
        shuffle_train: reshuffle the training set each epoch (eval/test never shuffle).
    """

    def __init__(self, train_dataset, eval_dataset, device_num, shuffle_train=True):
        super().__init__()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.device_num = device_num
        self.shuffle_train = shuffle_train

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader for ``dataset`` with the per-device batch size."""
        # torch.cuda.device_count() is 0 on CPU-only machines; avoid dividing by zero
        num_devices = max(self.device_num, 1)
        if num_devices > 1:
            # shard across ranks explicitly; the sampler owns shuffling because
            # DataLoader rejects shuffle=True together with a sampler.
            # NOTE: the permutation only changes between epochs if sampler.set_epoch()
            # is called each epoch (Lightning normally does this) — verify for your version
            sampler = DistributedSampler(dataset, shuffle = shuffle)
            loader_shuffle = False
        else:
            sampler = None
            loader_shuffle = shuffle
        return DataLoader(
            dataset = dataset,
            num_workers = htsat_config.num_workers,
            batch_size = max(htsat_config.batch_size // num_devices, 1),
            shuffle = loader_shuffle,
            sampler = sampler
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle = self.shuffle_train)

    def val_dataloader(self):
        return self._make_loader(self.eval_dataset, shuffle = False)

    def test_dataloader(self):
        return self._make_loader(self.eval_dataset, shuffle = False)

In [7]:
# Sanity-check the dataset: decode a few clips and print path, waveform shape and target.
# __getitem__ decodes audio, so previewing every file would be slow and flood the output.
N_PREVIEW = 10  # raise (or set to len(complete_dataset)) to inspect more samples

# the complete dataset
complete_dataset = EchoDataset(htsat_config.dataset_path)

for idx in range(min(N_PREVIEW, len(complete_dataset))):
    item = complete_dataset[idx]
    print(item['filename'], item['waveform'].shape, item['target'])

C:/Users/Andrew/OneDrive - Deakin University/DataSets/birdclef2022/jabwar\XC385329.ogg (256000,) tensor(1)
C:/Users/Andrew/OneDrive - Deakin University/DataSets/birdclef2022/spodov\XC304725.ogg (256000,) tensor(3)
C:/Users/Andrew/OneDrive - Deakin University/DataSets/birdclef2022/brant\XC594893.ogg (256000,) tensor(0)
C:/Users/Andrew/OneDrive - Deakin University/DataSets/birdclef2022/brant\XC552893.ogg (256000,) tensor(0)
C:/Users/Andrew/OneDrive - Deakin University/DataSets/birdclef2022/sheowl\XC474682.ogg (256000,) tensor(2)
C:/Users/Andrew/OneDrive - Deakin University/DataSets/birdclef2022/brant\XC610562.ogg (256000,) tensor(0)
C:/Users/Andrew/OneDrive - Deakin University/DataSets/birdclef2022/wiltur\XC153736.ogg (256000,) tensor(4)
C:/Users/Andrew/OneDrive - Deakin University/DataSets/birdclef2022/spodov\XC665481.ogg (256000,) tensor(3)
C:/Users/Andrew/OneDrive - Deakin University/DataSets/birdclef2022/sheowl\XC519559.ogg (256000,) tensor(2)
C:/Users/Andrew/OneDrive - Deakin Univer

In [8]:
# seed python `random`, numpy and torch so the shuffle inside EchoDataset and the
# train/eval split below are reproducible from run to run
SEED = 42
pl.seed_everything(SEED)

# get the number of available GPUs
device_num = torch.cuda.device_count()

# the complete dataset
complete_dataset = EchoDataset(htsat_config.dataset_path)

# split the dataset train/validation 80%/20% (fractional lengths need torch >= 1.13)
split_datasets = torch.utils.data.random_split(
    complete_dataset, [0.80, 0.20],
    generator=torch.Generator().manual_seed(SEED))

# assign the split datasets; there is no separate test split, so
# EchoDataModule.test_dataloader evaluates on eval_dataset as well
train_dataset = split_datasets[0]
eval_dataset  = split_datasets[1]

# create the audio data set pipeline
audio_pipeline = EchoDataModule(train_dataset, eval_dataset, device_num)

# construct the model trainer
trainer = pl.Trainer(
        deterministic=False,
        default_root_dir = htsat_config.workspace,
        gpus = device_num,
        val_check_interval = 0.1,
        max_epochs = htsat_config.max_epoch,
        auto_lr_find = True,    # only used by trainer.tune(); fit() ignores it
        sync_batchnorm = True,
        #callbacks = [checkpoint_callback],
        accelerator = "ddp" if device_num > 1 else None,
        num_sanity_val_steps = 0,
        resume_from_checkpoint = None,
        replace_sampler_ddp = False,   # EchoDataModule builds its own DistributedSampler
        gradient_clip_val=1.0
    )

# construct the model
sed_model = HTSAT_Swin_Transformer(
        spec_size=htsat_config.htsat_spec_size,
        patch_size=htsat_config.htsat_patch_size,
        in_chans=1,
        num_classes=htsat_config.classes_num,
        window_size=htsat_config.htsat_window_size,
        config = htsat_config,
        depths = htsat_config.htsat_depth,
        embed_dim = htsat_config.htsat_dim,
        patch_stride=htsat_config.htsat_stride,
        num_heads=htsat_config.htsat_num_head
    )

# wrapper to track metrics during training
model = SEDWrapper(
        sed_model = sed_model,
        config = htsat_config,
        dataset = eval_dataset
    )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [9]:
def pre_train(model, ckpt_path='swin_tiny_patch4_window7_224.pth'):
    """Initialise ``model`` in place from a Swin-Transformer checkpoint.

    Parameters are matched by name after stripping ``"sed_model."``, so both a bare
    HTSAT model and an ``SEDWrapper`` can be passed. The patch-embedding kernel is
    averaged over dim 1 (presumably the image input-channel axis) to fit a
    single-channel input, and ``head.weight`` / ``head.bias`` are never loaded.
    Loading is non-strict: model parameters absent from the checkpoint keep
    their current values.

    Args:
        model: ``torch.nn.Module`` to initialise.
        ckpt_path: checkpoint file holding ``{"model": state_dict}`` or a bare state dict.

    Raises:
        AssertionError: if a matched parameter has a different shape in the
            checkpoint (e.g. the model was built with a different embedding size).
    """
    # NOTE: torch.load unpickles the file — only load checkpoints from a trusted source
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # accept the {"model": state_dict} layout as well as a bare state dict
    if isinstance(ckpt, dict) and "model" in ckpt:
        ckpt = ckpt["model"]

    found_parameters = []
    unfound_parameters = []
    model_params = model.state_dict()

    for key, value in model_params.items():
        m_key = key.replace("sed_model.", "")
        if m_key not in ckpt:
            unfound_parameters.append(key)
            continue
        if m_key in ("head.weight", "head.bias"):
            # the classification head is not transferred
            ckpt.pop(m_key)
            unfound_parameters.append(key)
            continue
        if m_key == "patch_embed.proj.weight":
            ckpt[m_key] = torch.mean(ckpt[m_key], dim = 1, keepdim = True)
        assert value.shape == ckpt[m_key].shape, "%s is not match, %s vs. %s" %(key, str(value.shape), str(ckpt[m_key].shape))
        found_parameters.append(key)
        # rename to the model's own key so load_state_dict matches it
        ckpt[key] = ckpt.pop(m_key)

    print("pretrain param num: %d \t wrapper param num: %d"%(len(found_parameters), len(ckpt.keys())))
    print("unfound parameters: ", unfound_parameters)
    model.load_state_dict(ckpt, strict = False)
    
# pre_train(sed_model)

In [10]:
# Train: Lightning pulls the train and validation loaders from the data module
trainer.fit(model=model, datamodule=audio_pipeline)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                   | Params
-----------------------------------------------------
0 | sed_model | HTSAT_Swin_Transformer | 1.4 M 
-----------------------------------------------------
237 K     Trainable params
1.1 M     Non-trainable params
1.4 M     Total params
5.417     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 8/8 [00:19<00:00,  2.45s/it, loss=1.6, v_num=27, loss_step=1.600, acc=0.260] end epoch
Epoch 1: 100%|██████████| 8/8 [00:15<00:00,  1.89s/it, loss=1.6, v_num=27, loss_step=1.590, acc=0.202, loss_epoch=1.600]end epoch
Epoch 2: 100%|██████████| 8/8 [00:15<00:00,  1.89s/it, loss=1.59, v_num=27, loss_step=1.570, acc=0.202, loss_epoch=1.600]end epoch
Epoch 3: 100%|██████████| 8/8 [00:15<00:00,  1.98s/it, loss=1.59, v_num=27, loss_step=1.560, acc=0.250, loss_epoch=1.580]end epoch
Epoch 4: 100%|██████████| 8/8 [00:14<00:00,  1.87s/it, loss=1.58, v_num=27, loss_step=1.560, acc=0.269, loss_epoch=1.570]end epoch
Epoch 5: 100%|██████████| 8/8 [00:15<00:00,  1.94s/it, loss=1.57, v_num=27, loss_step=1.550, acc=0.269, loss_epoch=1.560]end epoch
Epoch 6: 100%|██████████| 8/8 [00:15<00:00,  1.88s/it, loss=1.57, v_num=27, loss_step=1.550, acc=0.269, loss_epoch=1.560]end epoch
Epoch 7: 100%|██████████| 8/8 [00:15<00:00,  1.93s/it, loss=1.56, v_num=27, loss_step=1.540, acc=0.250