In [None]:
import os, gc, random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

import librosa

from tqdm.notebook import tqdm
from glob import glob

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
import torch.nn.functional as F

import transformers
from transformers import ASTConfig, ASTFeatureExtractor, ASTModel

from sklearn.model_selection import StratifiedKFold, KFold
from sklearn.metrics import roc_auc_score

from time import time

import wandb

In [None]:
def get_logger(log_file='log.txt'):
    """Configure the root logger to emit bare messages to `log_file` and stdout.

    Handlers installed by a previous call are removed and closed first, so
    re-running the cell does not duplicate every log line or leak open file
    handles. Handlers added by anything else are left untouched.

    Args:
        log_file: path of the log file (FileHandler opens it in append mode).

    Returns:
        The root `logging.Logger`, set to level INFO.
    """
    import logging
    import sys

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Drop the handlers an earlier call of this function installed.
    for handler in list(logger.handlers):
        if getattr(handler, '_added_by_get_logger', False):
            logger.removeHandler(handler)
            handler.close()

    formatter = logging.Formatter('%(message)s')
    file_handler = logging.FileHandler(log_file)        # logging to file
    console_handler = logging.StreamHandler(sys.stdout)  # logging to console
    for handler in (file_handler, console_handler):
        handler.setFormatter(formatter)
        # Tag so a later call can recognise and replace this handler.
        handler._added_by_get_logger = True
        logger.addHandler(handler)

    return logger

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python interpreters launched after this call; the current
    # process fixed its hash seed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick algorithms that may differ
    # between runs, which defeats `deterministic = True`.
    torch.backends.cudnn.benchmark = False
    
def wandb_init(project_name, run_name, config):
    """Start a Weights & Biases run and return it.

    The run config is built from the attributes stored directly on `config`,
    skipping names with a leading underscore, callables and `copy`.
    """
    config_dict = {}
    for key, value in vars(config).items():
        if key.startswith('_') or callable(value) or key == 'copy':
            continue
        config_dict[key] = value
    return wandb.init(project=project_name, name=run_name, config=config_dict)

In [None]:
# Base folder for data and logs (on Colab: "/content/drive/MyDrive/Colab Notebooks").
DRIVE_FOLDER = "."
# Metadata columns of interest (not referenced by the cells shown here).
KEEP_COLS = [
    'category_number', 'common_name', 'audio_length', 'type', 'remarks',
    'quality', 'scientific_name', 'mp3_link', 'region',
]

class Config:
    """Experiment settings stored as class attributes.

    `Config.copy()` returns a subclass that carries copies of these attributes,
    so per-run changes (such as `n_classes`) leave `Config` itself untouched.
    """

    # --- paths ---
    dataset_dir = f"{DRIVE_FOLDER}/Audio_XenoCanto"
    labels_list = f"{DRIVE_FOLDER}/xeno_labels.csv"
    model_name = "ast_baseline"
    backbone_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
    # Placeholder class count; the metadata cell overwrites it on the run config.
    n_classes = 397
    # --- audio ---
    audio_sr = 16000  # Hz
    segment_length = 10  # s
    fft_window = 0.025  # s
    hop_window_length = 0.01  # s
    n_mels = 128
    low_cut = 1000  # Hz
    high_cut = 8000  # Hz
    top_db = 100
    # --- training ---
    batch_size = 4
    num_workers = 0
    n_splits = 5
    log_dir = f"{DRIVE_FOLDER}/training_logs"
    max_lr = 1e-5
    epochs = 10
    weight_decay = 0.01
    lr_final_div = 1000
    amp = True
    grad_accum_steps = 1
    max_grad_norm = 1e7
    print_epoch_freq = 1
    print_freq = 200
    # --- decoder ---
    n_decoder_layers = 6
    n_decoder_heads = 6
    ff_dim_decoder = 2048
    # --- reproducibility ---
    random_seed = 2046

    @classmethod
    def copy(cls):
        """Return a 'CustomConfig' subclass holding this class's non-dunder, non-callable attributes."""
        attributes = {}
        for key, value in cls.__dict__.items():
            if key.startswith('__') or callable(value):
                continue
            attributes[key] = value
        return type('CustomConfig', (cls,), attributes)
    
# Mutable per-run copy, so later cells (e.g. setting n_classes) do not modify Config.
config = Config.copy()

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(config.log_dir, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

seed_everything(config.random_seed)

In [None]:
# Load the clip metadata.
# NOTE(review): dropna() without `subset` discards a row if ANY column is missing
# (free-text columns included); confirm that is intended.
df_audio_meta = (
    pd.read_csv(f"{config.dataset_dir}/metadata.csv", nrows=None)
    .dropna()
    .reset_index(drop=True)
)

# Keep only clips whose audio file is present on disk.
df_audio_meta['file_exists'] = df_audio_meta['file_name'].map(
    lambda name: os.path.exists(f"{config.dataset_dir}/{name}")
)
df_audio_meta = df_audio_meta.loc[df_audio_meta['file_exists']].reset_index(drop=True)

# Normalise species names: every space becomes an underscore.
df_audio_meta['scientific_name'] = df_audio_meta['scientific_name'].map(lambda name: name.replace(" ", "_"))

# Report, then drop, species with fewer than 2 recordings.
class_counts = df_audio_meta['scientific_name'].value_counts()
print(f"Number of classes with less than 2 samples: {len(class_counts[class_counts < 2])}")

df_audio_meta = (
    df_audio_meta[df_audio_meta['scientific_name'].isin(class_counts[class_counts > 1].index)]
    .copy()
    .reset_index(drop=True)
)

# Integer label per species, assigned in alphabetical order of the names.
label_ids_list = sorted(df_audio_meta['scientific_name'].unique().tolist())
label_to_id = {name: idx for idx, name in enumerate(label_ids_list)}
df_audio_meta['species_id'] = df_audio_meta['scientific_name'].map(label_to_id)

# Discard rows without a label id and store the ids as ints.
df_audio_meta = df_audio_meta.dropna(subset=['species_id']).reset_index(drop=True)
df_audio_meta['species_id'] = df_audio_meta['species_id'].astype(int)

print(f"Number of classes in dataset: {df_audio_meta['species_id'].nunique()}")
print(f'Number of samples:', len(df_audio_meta))

# Record the final class count on the run config.
config.n_classes = df_audio_meta['species_id'].nunique()

df_audio_meta.head(5)

In [None]:
class BirdDecoder(nn.Module):
    """Transformer decoder that reconstructs tokens with some positions masked out.

    The `tgt` tokens are projected to the decoder width and a learned embedding is
    added. The positions listed in `mask_indices` are replaced by a learned mask
    token in both `tgt` and `memory`, and the same positions are excluded as
    attention keys through key-padding masks.

    Args:
        enc_dim: width of the incoming `tgt` tokens (projected by `decoder_embed`).
        d_dim: decoder model width. `memory` is passed to `nn.TransformerDecoder`
            without projection, so it must already have width `d_dim`.
        n_layers, n_heads, ff_dim, dropout: `nn.TransformerDecoderLayer` settings.
    """

    def __init__(self, enc_dim, d_dim=768, n_layers=2, n_heads=6, ff_dim=3072, dropout=0.0):
        super().__init__()

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_dim,
            nhead=n_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )

        self.decoder_embed = nn.Linear(enc_dim, d_dim, bias=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        # Keep these as registered Parameters. Calling `.to(device)` on a Parameter
        # returns a plain tensor when the device changes. That tensor is missing from
        # parameters()/state_dict(), so it would never be trained or saved.
        # `model.to(...)` moves registered parameters.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, d_dim))

        # A single random-initialised vector broadcast to every position.
        # NOTE(review): it carries no per-position information; confirm intended.
        self.pos_embedding = nn.Parameter(torch.randn(1, 1, d_dim))

    def forward(self, tgt, memory, mask_indices):
        """Decode `tgt` against `memory` with the positions in `mask_indices` masked.

        Args:
            tgt: (batch_size, seq_len, enc_dim) input tokens.
            memory: (batch_size, seq_len, d_dim) encoder output; not modified.
            mask_indices: (batch_size, n_mask) integer token positions to mask.

        Returns:
            (batch_size, seq_len, d_dim) decoded tokens.
        """
        batch_size, seq_len, _ = tgt.size()

        # Project to the decoder width and add the learned embedding.
        tgt = self.decoder_embed(tgt) + self.pos_embedding.expand(batch_size, seq_len, -1)

        # Replace masked positions with the mask token. Scatter width follows each
        # tensor's own last dim; the source dtype must match for scatter.
        n_mask = mask_indices.size(1)
        tgt_index = mask_indices.unsqueeze(-1).expand(-1, -1, tgt.size(-1))
        tgt = tgt.scatter(1, tgt_index, self.mask_token.to(tgt.dtype).expand(batch_size, n_mask, -1))
        # Out-of-place so the caller's encoder output is not mutated.
        mem_index = mask_indices.unsqueeze(-1).expand(-1, -1, memory.size(-1))
        memory = memory.scatter(1, mem_index, self.mask_token.to(memory.dtype).expand(batch_size, n_mask, -1))

        # True marks a masked position that must not be attended to as a key;
        # the same mask applies to the target and the memory sequence.
        key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=tgt.device)
        key_padding_mask.scatter_(1, mask_indices, True)

        return self.decoder(
            tgt,
            memory,
            tgt_key_padding_mask=key_padding_mask,
            memory_key_padding_mask=key_padding_mask,
        )
    
    
class BirdEncoder(nn.Module):
    """Pretrained AST backbone exposing both its patch embeddings and its encoder output."""

    def __init__(self, backbone_name):
        super().__init__()

        self.encoder_model = ASTModel.from_pretrained(
            backbone_name, config=ASTConfig.from_pretrained(backbone_name)
        )
        self.hidden_size = self.encoder_model.config.hidden_size

    def forward(self, x):
        """Run the AST embedding layer and encoder stack.

        Args:
            x: input spectrogram features; presumably (batch_size, t_len, f_len)
                as produced by ASTFeatureExtractor — TODO confirm.

        Returns:
            (last_hidden_state, embeddings): the encoder stack's output and the
            embedding-layer output it was computed from.

        NOTE(review): this calls `embeddings` and `encoder` directly, bypassing
        ASTModel.forward (and any final layer norm it applies); confirm intended.
        """
        embeddings = self.encoder_model.embeddings(x)
        hidden = self.encoder_model.encoder(embeddings).last_hidden_state
        return hidden, embeddings
    
    
class BirdASTForPretrain(nn.Module):
    """Masked-token reconstruction model: AST encoder followed by a transformer decoder.

    The encoder runs on the full spectrogram. A random subset of token positions is
    then masked in the decoder, which produces one output token per position.

    NOTE(review): masking happens after the encoder has attended over every patch,
    so the unmasked memory tokens already carry information about the masked ones.
    """

    def __init__(self, backbone_name, mask_ratio=0.75, d_dim=768, n_layers=2, n_heads=6, ff_dim=3072, dropout=0.0):
        super().__init__()

        self.encoder = BirdEncoder(backbone_name)

        enc_dim = self.encoder.hidden_size
        self.decoder = BirdDecoder(enc_dim, d_dim, n_layers, n_heads, ff_dim, dropout)

        self.mask_ratio = mask_ratio

        # Backbone patch grid: (input - patch) // stride + 1 along each axis.
        ast_config = self.encoder.encoder_model.config
        self.freq_dim = (ast_config.num_mel_bins - ast_config.patch_size) // ast_config.frequency_stride + 1
        self.time_dim = (ast_config.max_length - ast_config.patch_size) // ast_config.time_stride + 1

    def forward(self, x):
        """Encode, mask and decode a batch of spectrograms.

        Args:
            x: (batch_size, t_len, f_len) input spectrograms.

        Returns:
            (batch_size, seq_len, d_dim) decoded tokens. Leading special tokens keep
            their positions. The patch tokens are reordered from the backbone's
            frequency-major order to time-major order (time outer, frequency inner).
            That is the order `patchify_spectrogram` produces, so token k lines up
            with patch k.
        """
        batch_size = x.size(0)
        memory, spec_embeddings = self.encoder(x)

        # Distinct positions per row (prefix of a random permutation), so exactly
        # int(mask_ratio * seq_len) tokens are masked; randint could repeat indices.
        seq_len = memory.size(1)
        n_mask = int(self.mask_ratio * seq_len)
        mask_indices = torch.rand(batch_size, seq_len, device=x.device).argsort(dim=1)[:, :n_mask]

        decoded_tokens = self.decoder(spec_embeddings, memory, mask_indices)

        return self._patches_to_time_major(decoded_tokens)

    def _patches_to_time_major(self, tokens):
        """Reorder patch tokens from (freq, time) to (time, freq) order; leading special tokens stay first."""
        batch_size, seq_len, width = tokens.size()
        n_patches = self.freq_dim * self.time_dim
        n_special = seq_len - n_patches
        patches = tokens[:, n_special:].reshape(batch_size, self.freq_dim, self.time_dim, width)
        patches = patches.transpose(1, 2).reshape(batch_size, n_patches, width)
        return torch.cat([tokens[:, :n_special], patches], dim=1)
    

def get_shape(config):
    """Return the AST patch-grid size as (n_frequency_patches, n_time_patches).

    Each axis uses the standard convolution output size,
    (input - kernel) // stride + 1; see https://cs231n.github.io/convolutional-networks/#conv
    """
    def _n_windows(length, stride):
        return (length - config.patch_size) // stride + 1

    return (
        _n_windows(config.num_mel_bins, config.frequency_stride),
        _n_windows(config.max_length, config.time_stride),
    )


def patchify_spectrogram(spectrograms, config):
    """Cut spectrograms into overlapping square patches and flatten each patch.

    Args:
        spectrograms: (batch_size, max_length, num_mel_bins) tensor.
        config: object with `patch_size`, `time_stride` and `frequency_stride`.

    Returns:
        (batch_size, n_time * n_freq, 3 * patch_size * patch_size) tensor. Patches
        are ordered time-major (time window outer, frequency window inner). Each
        patch is repeated over 3 identical channels, then flattened in the order
        (channel, time, frequency).
    """
    n_batch = spectrograms.shape[0]
    size = config.patch_size

    windows = spectrograms.unfold(1, size, config.time_stride)
    windows = windows.unfold(2, size, config.frequency_stride)
    # windows: (batch, n_time, n_freq, patch_time, patch_freq)
    n_time, n_freq = windows.shape[1], windows.shape[2]

    three_channel = windows.unsqueeze(3).expand(-1, -1, -1, 3, -1, -1)
    return three_channel.contiguous().view(n_batch, n_time * n_freq, -1)

In [None]:
class BirdSongDataset(Dataset):
    """Map-style dataset over the metadata rows.

    Each item is (AST feature-extractor `input_values` for one clip, the row's
    `species_id`). Audio is loaded at `config.audio_sr`, and the extractor is
    called with padding="max_length".
    """

    def __init__(self, df_audio_meta, config):
        self.df_audio_meta = df_audio_meta
        self.feature_extractor = ASTFeatureExtractor()
        self.config = config

    def __len__(self):
        return len(self.df_audio_meta)

    def __getitem__(self, idx):
        record = self.df_audio_meta.iloc[idx]
        waveform, sample_rate = self.get_audio(f"{self.config.dataset_dir}/{record['file_name']}")
        features = self.feature_extractor(
            waveform, sampling_rate=sample_rate, padding="max_length", return_tensors="pt"
        )
        # Remove the leading batch dimension added by return_tensors="pt".
        return features['input_values'].squeeze(0), record['species_id']

    def get_audio(self, audio_path):
        """Load a clip resampled to `config.audio_sr`; returns (samples, sample_rate)."""
        return librosa.load(audio_path, sr=self.config.audio_sr)

def collate_fn(batch):
    """Collate (spectrogram, label) samples into a batch dict.

    Spectrograms are zero-padded along their first axis to a common length and
    stacked under "input_ids"; labels become a tensor under "labels".
    """
    spectrograms, labels = [], []
    for sample in batch:
        spectrograms.append(sample[0])
        labels.append(sample[1])

    return {
        "input_ids": torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True, padding_value=0),
        "labels": torch.tensor(labels),
    }

In [None]:
# # test the dataset, dataloader and patchify function
# backbone_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
# backbone_config = ASTConfig.from_pretrained(backbone_name)

# bs_dataset = BirdSongDataset(df_audio_meta, config)
# bs_dataloader = DataLoader(bs_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

# for batch in bs_dataloader:
#     spectrograms = batch['input_ids']
#     labels = batch['labels']
#     break

# spectrogram_patches = patchify_spectrogram(spectrograms, backbone_config)

# print(spectrograms.size())
# print(spectrogram_patches.size())

# fig, ax = plt.subplots(1, 1, figsize=(6, 12))
# ax.imshow(spectrograms[0].cpu().numpy(), aspect='auto', vmax=1, vmin=-1)
# ax.set_title("Spectrogram")
# plt.show()

In [None]:
## Plot patches
# fig, axes = plt.subplots(time_dim, freq_dim, figsize=(freq_dim * 2, time_dim * 2))
# patch_size = 16
# for i in range(time_dim):
#     for j in range(freq_dim):
#         patch_index = i * freq_dim + j
#         patch = spectrogram_patches[0, patch_index].view(3, patch_size, patch_size).permute(1, 2, 0).numpy()
#         ax = axes[i, j]
#         ax.imshow(patch[:, :, 0], aspect='auto',  vmax=1, vmin=-1)
#         ax.set_title(f'Patch {patch_index + 1}')
#         ax.axis('off') 

# plt.tight_layout()
# plt.show()

In [None]:
class AverageMeter(object):
    """Keep the latest value of a metric and its running, sample-weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running mean, running sum and sample count."""
        self.value = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest value, weighted by `n` samples, and refresh the mean."""
        self.value = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

In [None]:
# Training loop: self-supervised pretraining of BirdASTForPretrain.
# The decoder output (minus its first two tokens) is regressed onto the
# patchified input spectrogram with an MSE loss.

backbone_config = ASTConfig.from_pretrained(config.backbone_name)

bs_dataset = BirdSongDataset(df_audio_meta, config)
# NOTE(review): shuffle=False feeds batches in metadata order every epoch; confirm intended.
bs_dataloader = DataLoader(
    bs_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    collate_fn=collate_fn,
)

model = BirdASTForPretrain(
    config.backbone_name,
    mask_ratio=0.75,
    d_dim=768,
    n_layers=config.n_decoder_layers,
    n_heads=config.n_decoder_heads,
    ff_dim=config.ff_dim_decoder,
    dropout=0.0
    )
model.to(DEVICE)

optimizer = AdamW(model.parameters(), lr=config.max_lr, weight_decay=config.weight_decay)
scheduler = OneCycleLR(
    optimizer,
    max_lr=config.max_lr,
    final_div_factor=config.lr_final_div,
    steps_per_epoch=len(bs_dataloader),
    epochs=config.epochs
    )

scaler = GradScaler(enabled=config.amp)
loss_fn = nn.MSELoss(reduction='mean')

loss_records = defaultdict(list)

run = wandb_init("BirdAST_Pretrain", "BirdAST_Pretrain_Large", config)
logger = get_logger(f"{config.log_dir}/BirdAST_Pretrain.log")

best_loss = np.inf

model.train()
for epoch in range(config.epochs):

    loss_meter = AverageMeter()

    for i, batch in tqdm(enumerate(bs_dataloader), total=len(bs_dataloader)):

        # Self-supervised objective: the batch labels are not used here.
        spectrograms = batch['input_ids'].to(DEVICE)

        optimizer.zero_grad()

        with autocast(enabled=config.amp):
            reconstructed_spectrograms = model(spectrograms)
            spectrogram_patches = patchify_spectrogram(spectrograms, backbone_config)
            # Drop the first two decoder tokens so the sequence lines up with the patch targets.
            loss = loss_fn(reconstructed_spectrograms[:, 2:, :], spectrogram_patches)

        scaler.scale(loss).backward()
        # Clip to config.max_grad_norm; with AMP the grads must be unscaled first
        # so the threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()

        loss_meter.update(loss.item())

        wandb.log({"Loss": loss.item(), "Learning Rate": scheduler.get_last_lr()[0], "Epoch": epoch, "Batch": i})

        if i % config.print_freq == 0:
            logger.info(f"Epoch: {epoch} | Batch: {i}, Loss: {loss_meter.avg}")

    wandb.log({"Epoch_Loss": loss_meter.avg, "Epoch": epoch})
    # Keyed by epoch number (not by the last batch index).
    loss_records[f'epoch_{epoch + 1}'].append(loss_meter.avg)
    logger.info('-'*10 + "\n" + f"Epoch: {epoch} | Loss: {loss_meter.avg}" + "\n" + '-'*10)

    # Save a checkpoint whenever the epoch loss improves.
    model_path = f"{config.log_dir}/BirdAST_Pretrain_epoch_{epoch}.pth"
    if loss_meter.avg < best_loss:
        best_loss = loss_meter.avg
        torch.save(model.state_dict(), model_path)
        logger.info(f"Best Model loss = {loss_meter.avg} | saved at: {model_path}")

    # Release the last batch's tensors before the next epoch.
    del spectrograms, reconstructed_spectrograms, spectrogram_patches
    gc.collect()
    torch.cuda.empty_cache()

# Close the W&B run so its data is finalised.
run.finish()
    
    

In [None]:
# Take one batch and run a single forward pass for the plots below.
# no_grad: nothing here is back-propagated, so don't build or keep an autograd graph.
batch = next(iter(bs_dataloader))
spectrograms = batch['input_ids'].to(DEVICE)
labels = batch['labels'].to(DEVICE)

with torch.no_grad():
    reconstructed_spectrograms = model(spectrograms)

In [None]:
# Ground-truth patches of this batch, in the same layout the training loss uses.
spectrogram_patches = patchify_spectrogram(spectrograms=spectrograms, config=backbone_config)

In [None]:
def unpathify(spec_patches, time_dim=101, freq_dim=12, patch_size=16, n_channels=3):
    """Reshape flat patch tokens into a patch grid (inverse of patchify_spectrogram's flattening).

    Args:
        spec_patches: 3-D tensor (batch_size, time_dim * freq_dim, n_channels * patch_size**2).
        time_dim, freq_dim: patch-grid size; the defaults match the 12 x 101 AST grid.
        patch_size: side length of a square patch.
        n_channels: channels stored per patch.

    Returns:
        A view of shape (batch_size, time_dim, freq_dim, n_channels, patch_size, patch_size).
    """
    # The unpack also rejects inputs that are not 3-D (ValueError).
    batch_size, _, _ = spec_patches.size()
    return spec_patches.view(batch_size, time_dim, freq_dim, n_channels, patch_size, patch_size)

In [None]:
# Regroup the flat patches into a (batch, time, freq, channel, patch, patch) grid.
spec_unpatched = unpathify(spectrogram_patches)

spec_unpatched.shape

In [None]:
# Plot every patch of the first spectrogram in the batch on a time x frequency grid.
# The grid size is derived from the backbone config rather than hard-coded (101 x 12).
freq_dim, time_dim = get_shape(backbone_config)

# A single device->host copy instead of one per patch inside the loop.
patch_grid = spec_unpatched[0].cpu().numpy()

fig, axes = plt.subplots(time_dim, freq_dim, figsize=(freq_dim * 2, time_dim * 2))
for i in range(time_dim):
    for j in range(freq_dim):
        ax = axes[i, j]
        # The 3 channels are identical copies (see patchify_spectrogram); show the first.
        ax.imshow(patch_grid[i, j, 0], aspect='auto', vmax=1, vmin=-1)
        ax.set_title(f'Patch {i}, {j}')
        ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Same grid for the model's reconstruction of the first sample; the first two
# decoder tokens are dropped, as in the training loss.
reconstructed_specs_unpatched = unpathify(reconstructed_spectrograms[:, 2:, :])

print(reconstructed_specs_unpatched.size())

# Detach and copy to host once, outside the plotting loop.
reconstructed_grid = reconstructed_specs_unpatched[0].detach().cpu().numpy()

fig, axes = plt.subplots(time_dim, freq_dim, figsize=(freq_dim * 2, time_dim * 2))
for i in range(time_dim):
    for j in range(freq_dim):
        ax = axes[i, j]
        ax.imshow(reconstructed_grid[i, j, 0], aspect='auto', vmax=1, vmin=-1)
        ax.set_title(f'Patch {i}, {j}')
        ax.axis('off')

plt.tight_layout()
plt.show()