In [1]:
import torch
import torch.nn as nn
import torchaudio
from torch.cuda.amp import autocast
from torchaudio.transforms import Resample
from torch.utils.data import Dataset, DataLoader

import timm
from timm.models.layers import to_2tuple, trunc_normal_

import IPython

import numpy as np
import pandas as pd

import os
import wget
import time
import datetime
import pickle

In [2]:
# --- Data locations (UrbanSound8K directory layout) ---
ANNOTATIONS_FILE = "UrbanSound8K/metadata/UrbanSound8K.csv"
AUDIO_DIR = "UrbanSound8K/audio/"

# --- Feature extraction ---
MEL_BINS = 128                # mel filterbank bins per frame
TARGET_LENGTH = 1_024         # frames per example after zero-padding / truncation
TARGET_SAMPLE_RATE = 16_000   # Hz; clips at other rates are resampled to this

# --- Training ---
BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 1e-3          # NOTE(review): not currently passed to train(), which uses its own value
NUM_CLASSES = 10              # size of the classifier output layer

In [3]:
# AudioSet-pretrained AST checkpoint: download it once and reuse the local copy on later runs.
# NOTE(review): this file is not loaded anywhere in the visible notebook — confirm it is needed.
_AUDIOSET_CKPT_PATH = 'audioset_10_10_0.4593.pth'
if not os.path.exists(_AUDIOSET_CKPT_PATH):
    audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1'
    wget.download(audioset_mdl_url, out=_AUDIOSET_CKPT_PATH)

In [4]:
class PatchEmbed(nn.Module):
    """Convolutional patch embedding that, unlike timm's, does not assert a fixed input size.

    :param img_size: nominal input size (int or (H, W)); only used to compute ``num_patches``
    :param patch_size: patch size (int or (H, W)); also used as the convolution stride
    :param in_chans: number of input channels
    :param embed_dim: embedding dimension of each patch token
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()

        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        patches_per_row = self.img_size[1] // self.patch_size[1]
        patches_per_col = self.img_size[0] // self.patch_size[0]
        self.num_patches = patches_per_row * patches_per_col

        # Non-overlapping patches: kernel size == stride == patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x):
        """Map (B, C, H, W) to a token sequence (B, num_tokens, embed_dim)."""
        feature_map = self.proj(x)
        tokens = feature_map.flatten(2)
        return tokens.transpose(1, 2)

In [5]:
class ASTModel(nn.Module):
    """
    The AST (Audio Spectrogram Transformer) model built on a distilled DeiT backbone from timm.

    The backbone's RGB patch projection is collapsed to one input channel, and its positional
    embedding is center-cropped or bilinearly interpolated to the patch grid produced by the
    spectrogram size and the patch strides.

    :param label_dim: the label dimension, i.e., the number of total classes, e.g. 527 for AudioSet, 50 for ESC-50, and 35 for Speech Commands v2-35
    :param fstride: the stride of patch splitting on the frequency dimension; for 16*16 patches, fstride=16 means no overlap, fstride=10 means an overlap of 6
    :param tstride: the stride of patch splitting on the time dimension; for 16*16 patches, tstride=16 means no overlap, tstride=10 means an overlap of 6
    :param input_fdim: the number of frequency bins of the input spectrogram
    :param input_tdim: the number of time frames of the input spectrogram
    :param model_name: timm model name of the backbone. The code reads ``dist_token`` and treats the
        first two ``pos_embed`` entries as cls/distillation tokens, so this must be a distilled DeiT
        variant with 16*16 patches.
    :param imagenet_pretrain: forwarded to ``timm.create_model(pretrained=...)``
    """

    def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024,
                 model_name='vit_deit_tiny_distilled_patch16_224', imagenet_pretrain=True):

        super(ASTModel, self).__init__()

        # Replace timm's PatchEmbed (which asserts a fixed input size) with the module-level
        # PatchEmbed before the backbone is built.
        # NOTE(review): this monkey-patch only takes effect if the installed timm looks up
        # vision_transformer.PatchEmbed at model-construction time — confirm for your timm version.
        timm.models.vision_transformer.PatchEmbed = PatchEmbed

        # Only timm's pretrained weights are used here; no AudioSet checkpoint is loaded by this class.
        self.v = timm.create_model(model_name, pretrained=imagenet_pretrain)

        self.original_num_patches = self.v.patch_embed.num_patches
        # Side length of the backbone's square patch grid (attribute name kept for compatibility).
        self.oringal_hw = int(self.original_num_patches ** 0.5)
        self.original_embedding_dim = self.v.pos_embed.shape[2]
        self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

        # Patch-grid shape (frequency, time) produced by the spectrogram size and strides.
        f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
        num_patches = f_dim * t_dim
        self.v.patch_embed.num_patches = num_patches

        # Single-channel linear projection: sum the pretrained RGB kernels over the input-channel axis.
        new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
        new_proj.bias = self.v.patch_embed.proj.bias
        self.v.patch_embed.proj = new_proj

        # Positional embedding: skip the first two tokens (cls + distillation), reshape the rest to the
        # backbone's square grid (oringal_hw x oringal_hw, e.g. 14x14 for a 224px input with 16px
        # patches), then fit it to (f_dim, t_dim) — time axis first, then frequency axis.
        pos_grid = (self.v.pos_embed[:, 2:, :].detach()
                    .reshape(1, self.original_num_patches, self.original_embedding_dim)
                    .transpose(1, 2)
                    .reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw))
        pos_grid = self._crop_or_interpolate(pos_grid, t_dim, dim=3)
        pos_grid = self._crop_or_interpolate(pos_grid, f_dim, dim=2)
        # Flatten back to a token sequence: (1, num_patches, embed_dim).
        new_pos_embed = pos_grid.reshape(1, self.original_embedding_dim, num_patches).transpose(1, 2)
        # Re-attach the backbone's cls and distillation token embeddings in front.
        self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))

    @staticmethod
    def _crop_or_interpolate(grid, size, dim):
        """Fit a (1, C, H, W) grid to `size` along axis `dim` (2 = frequency/H, 3 = time/W).

        If `size` does not exceed the current extent, a centered slice is taken; otherwise the
        grid is bilinearly interpolated along that axis, keeping the other spatial axis unchanged.
        """
        current = grid.shape[dim]
        if size <= current:
            start = current // 2 - size // 2
            return grid.narrow(dim, start, size)
        target_hw = list(grid.shape[2:])
        target_hw[dim - 2] = size
        return torch.nn.functional.interpolate(grid, size=tuple(target_hw), mode='bilinear')

    def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024):
        """Return the (frequency, time) patch-grid size of a 16x16 conv with stride (fstride, tstride)."""
        test_input = torch.randn(1, 1, input_fdim, input_tdim)
        test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        # Only the output shape is needed, so skip building an autograd graph.
        with torch.no_grad():
            test_out = test_proj(test_input)
        return test_out.shape[2], test_out.shape[3]

    @autocast()
    def forward(self, x):
        """
        :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        :return: class logits of shape (batch_size, label_dim)
        """
        # (B, T, F) -> (B, 1, F, T): a one-channel "image" with frequency as height and time as width.
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)

        B = x.shape[0]
        x = self.v.patch_embed(x)
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)
        for blk in self.v.blocks:
            x = blk(x)
        x = self.v.norm(x)
        # Average the cls and distillation token outputs as the clip-level representation.
        x = (x[:, 0] + x[:, 1]) / 2

        x = self.mlp_head(x)
        return x

In [6]:
class UrbanSoundDataset(Dataset):
    """
    UrbanSound8K clips as fixed-size, normalized Kaldi-style log-mel filterbank features.

    Each item is ``(features, label)``. ``features`` has shape ``(target_length, mel_bins)``: it is
    zero-padded at the end or truncated along the frame axis, then normalized as
    ``(fbank - norm_mean) / (2 * norm_std)``.

    :param annotation_file: path to the metadata CSV (column 0 = file name, 5 = fold, 6 = label)
    :param audio_dir: root directory containing the ``fold<N>`` sub-directories
    :param device: stored on the instance; not used by the data pipeline itself
    :param mel_bins: number of mel filterbank bins per frame
    :param target_length: number of frames every example is padded/truncated to
    :param target_sample_rate: sample rate audio is resampled to before feature extraction
    :param norm_mean: value subtracted from the filterbank features
    :param norm_std: features are divided by ``2 * norm_std``
    """

    def __init__(self, annotation_file, audio_dir, device, mel_bins, target_length, target_sample_rate,
                 norm_mean=-4.2677393, norm_std=4.5689974):
        self.annotations = pd.read_csv(annotation_file)
        self.audio_dir = audio_dir
        self.device = device
        self.mel_bins = mel_bins
        self.target_length = target_length
        self.target_sample_rate = target_sample_rate
        # NOTE(review): the defaults look like the AST AudioSet statistics rather than values
        # computed on UrbanSound8K — verify before relying on them.
        self.norm_mean = norm_mean
        self.norm_std = norm_std
        # Resample transforms keyed by (source sample rate, dtype), built lazily in __getitem__.
        self._resamplers = {}

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, item):
        audio_sample_path = self._get_audio_sample_path(item)
        label = self._get_audio_sample_label(item)

        waveform, sr = torchaudio.load(audio_sample_path)
        # Use the instance configuration (the original read undefined lower-case globals here).
        if sr != self.target_sample_rate:
            waveform = self._get_resampler(sr, waveform.dtype)(waveform)

        # NOTE(review): as far as I can tell kaldi.fbank reads a single channel of multi-channel
        # input with its default `channel` argument — consider mixing stereo clips to mono first.
        fbank = torchaudio.compliance.kaldi.fbank(
            waveform,
            htk_compat=True,
            sample_frequency=self.target_sample_rate,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=self.mel_bins,
            dither=0.0,
            frame_shift=10,
        )

        fbank = self._pad_or_truncate(fbank)
        fbank = (fbank - self.norm_mean) / (self.norm_std * 2)
        return fbank, label

    def _get_resampler(self, orig_sr, dtype):
        """Return a cached Resample transform from `orig_sr` to the target sample rate."""
        key = (orig_sr, dtype)
        if key not in self._resamplers:
            self._resamplers[key] = Resample(orig_sr, self.target_sample_rate, dtype=dtype)
        return self._resamplers[key]

    def _pad_or_truncate(self, fbank):
        """Zero-pad (at the end) or truncate a (frames, bins) tensor to exactly `target_length` frames."""
        missing = self.target_length - fbank.shape[0]
        if missing > 0:
            # Pad spec is (last-dim left, last-dim right, frame-dim top, frame-dim bottom).
            return torch.nn.functional.pad(fbank, (0, 0, 0, missing))
        if missing < 0:
            return fbank[:self.target_length, :]
        return fbank

    def _get_audio_sample_path(self, item):
        fold = f"fold{self.annotations.iloc[item, 5]}"
        path = os.path.join(self.audio_dir, fold, self.annotations.iloc[item, 0])
        return path

    def _get_audio_sample_label(self, item):
        return self.annotations.iloc[item, 6]

def create_data_loader(train_data, batch_size, shuffle=True, num_workers=0):
    """
    Wrap a training dataset in a DataLoader.

    :param train_data: the dataset to iterate over
    :param batch_size: number of examples per batch
    :param shuffle: reshuffle the examples every epoch (default True; a training loader that
        always yields the same order is usually a bug)
    :param num_workers: number of worker processes for loading (0 = load in the main process)
    :return: a ``torch.utils.data.DataLoader``
    """
    return DataLoader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

In [7]:
def train(ast_model, train_loader, n_epochs, learning_rate=0.0001, save_path='ast_model.pth'):
    """
    Train `ast_model` with cross-entropy for `n_epochs` epochs, then save its state dict.

    The model is wrapped in ``nn.DataParallel`` (unless it already is) and moved to CUDA when
    available. Mixed precision (autocast + GradScaler) is enabled only when running on CUDA.

    :param ast_model: a module mapping a batch of inputs to class logits
    :param train_loader: iterable of ``(inputs, integer_labels)`` batches that supports ``len()``
    :param n_epochs: number of passes over `train_loader`
    :param learning_rate: initial Adam learning rate; halved at the end of every epoch from epoch 2 on
    :param save_path: file the final ``state_dict`` is written to (keys carry the ``module.`` prefix
        added by ``nn.DataParallel``)
    :return: the trained, ``DataParallel``-wrapped model
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"
    torch.set_grad_enabled(True)

    if not isinstance(ast_model, nn.DataParallel):
        ast_model = nn.DataParallel(ast_model)
    ast_model = ast_model.to(device)

    trainables = [p for p in ast_model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainables, learning_rate, weight_decay=5e-7, betas=(0.95, 0.999))
    # Milestones 2..999 with gamma 0.5: the LR halves after every epoch starting with epoch 2.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, list(range(2, 1000, 1)), gamma=0.5)
    # Loss scaling keeps fp16 gradients from underflowing under autocast; a no-op when disabled.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    loss_fn = nn.CrossEntropyLoss()

    global_step = 0
    print("current #steps=%s, #epochs=%s" % (global_step, 1))
    print("start training...")
    start_time = time.time()

    # The original `while` loop never incremented `epoch` and therefore never terminated.
    for epoch in range(1, n_epochs + 1):
        ast_model.train()

        print('---------------')
        print(datetime.datetime.now())
        print("current #epochs=%s, #steps=%s" % (epoch, global_step))

        for i, (audio_input, labels) in enumerate(train_loader):
            audio_input = audio_input.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with autocast(enabled=use_amp):
                audio_output = ast_model(audio_input)
                loss = loss_fn(audio_output, labels)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if global_step % 5 == 0 and global_step != 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Train Loss {loss:.4f}\t'.format(
                          epoch, i, len(train_loader), loss=loss.item()), flush=True)

            global_step += 1

        # Advance the LR schedule once per epoch (it was constructed but never stepped before).
        scheduler.step()

    print("training finished in %.1f s" % (time.time() - start_time))
    torch.save(ast_model.state_dict(), save_path)
    return ast_model

In [None]:
# Pick the accelerator if one is present; the dataset stores this value on itself.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the feature dataset and its batching loader from the configuration cell.
usd = UrbanSoundDataset(
    annotation_file=ANNOTATIONS_FILE,
    audio_dir=AUDIO_DIR,
    device=device,
    mel_bins=MEL_BINS,
    target_length=TARGET_LENGTH,
    target_sample_rate=TARGET_SAMPLE_RATE,
)
train_loader = create_data_loader(usd, batch_size=BATCH_SIZE)

# One output unit per class, then fit.
ast_model = ASTModel(label_dim=NUM_CLASSES)
train(ast_model, train_loader, n_epochs=EPOCHS)