In [1]:
%cd ..

/home/sazerlife/projects/courses/itmo/semester-2/event_detection/lab4-kaggle-audioset


In [2]:
import json
import random
from pathlib import Path
from typing import Dict, List, Set, Tuple

import librosa
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import soundfile as sf
import torch
import torch.nn as nn
import torch.optim as opt
from scipy import stats as st
from torch.nn.modules.loss import _Loss
from torch.utils.data import TensorDataset, DataLoader
import torchaudio.transforms as T
from torchlibrosa.augmentation import SpecAugmentation
from torchlibrosa.stft import LogmelFilterBank, Spectrogram
from torchmetrics.classification import MultilabelAccuracy, MultilabelF1Score, Accuracy, F1Score
from torchvision.transforms import Compose
from tqdm import tqdm
from transformers import ASTConfig, ASTFeatureExtractor, ASTModel

from sklearn.utils.class_weight import compute_class_weight
from src.utils.train_val_split import train_val_split
from torchvision.models import resnet34


tqdm.pandas()


SEED=12345

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

DEVICE = "cuda:0"
SAMPLE_RATE = 16000
DATA_PATH = Path("data/raw/")

train_csv_path = DATA_PATH / "train.csv"
train_audio_path = DATA_PATH / "audio_train"

test_csv_path = DATA_PATH / "test.csv"
test_audio_path = DATA_PATH /  "audio_test"


EXPERIMENTS_PATH = Path("experiments/resnet34/")
submission_csv_path = EXPERIMENTS_PATH / "submission.csv"

In [11]:
class InferWrapper:
    def __get_resnet34(self) -> nn.Module:
        resnet_model = resnet34(pretrained=False)
        resnet_model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        resnet_model.fc = nn.Linear(512, 41)

        state_dict = torch.load(EXPERIMENTS_PATH / "checkpoint-55-epoch.pt")
        resnet_model.load_state_dict(state_dict)
        resnet_model = resnet_model.to(DEVICE)

        return resnet_model

    def __init__(self, sample_rate=16000, n_fft = 1024, win_length = None, hop_length = 512, n_mels = 128) -> None:
        self.melspec_transform = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=True,
            pad_mode="reflect",
            power=2.0,
            norm="slaney",
            onesided=True,
            n_mels=n_mels,
            mel_scale="htk",
            normalized=True,
        )
        self.model = self.__get_resnet34()

        labels = [
            'Acoustic_guitar', 'Applause', 'Bark', 'Bass_drum', 'Burping_or_eructation', 'Bus', 'Cello', 
            'Chime', 'Clarinet', 'Computer_keyboard', 'Cough', 'Cowbell', 'Double_bass', 'Drawer_open_or_close', 
            'Electric_piano', 'Fart', 'Finger_snapping', 'Fireworks', 'Flute', 'Glockenspiel', 'Gong',
            'Gunshot_or_gunfire', 'Harmonica', 'Hi-hat', 'Keys_jangling', 'Knock', 'Laughter', 'Meow', 'Microwave_oven', 
            'Oboe', 'Saxophone', 'Scissors', 'Shatter', 'Snare_drum', 'Squeak', 'Tambourine', 'Tearing', 'Telephone',
            'Trumpet', 'Violin_or_fiddle', 'Writing'
        ]
        self.labels = {k: v for k, v in enumerate(labels)}

    @torch.no_grad()
    def infer_frame(self, audio: torch.FloatTensor):
        melspectrogram = self.melspec_transform(audio).unsqueeze(0)
        prediction = self.model(melspectrogram.to(DEVICE))
        return torch.softmax(prediction[0], dim=-1).cpu()

    @torch.no_grad()
    def __call__(self, audio_path: str) -> str:
        audio, sr = sf.read(audio_path, dtype="float32", always_2d=True)
        assert sr == SAMPLE_RATE
        audio = torch.from_numpy(audio.T)
        
        if audio.shape[-1] / sr <= 2.0:
            audio_padded = torch.zeros(1, sr * 5, dtype=audio.dtype)
            audio_padded[:, :audio.shape[-1]] = audio[:, :]
            prediction = self.infer_frame(audio_padded)
        
        elif audio.shape[-1] / sr < 10.0:
            prediction = self.infer_frame(audio)
        
        else:
            win_len = 5 * sr
            hop_len = sr

            padding_size = win_len - (audio.shape[-1] % hop_len)
            audio_padded = torch.zeros(1, audio.shape[-1] + padding_size, dtype=audio.dtype)
            audio_padded[:, :audio.shape[-1]] = audio[:, :]
            audio = audio_padded
            
            predictions = list()
            for idx in range(0, audio.shape[-1], hop_len):
                prediction = self.infer_frame(audio[:, idx : idx + win_len])
                # predictions.append(torch.round(prediction, decimals=2))
                if prediction.max() > 0.15:
                    predictions.append(prediction)

            # for pred in predictions:
            #     print(pred.max(), end=", ")
            # print()
            # print(predictions)
            prediction = torch.vstack(predictions).mean(0)
        
        label_idx  = prediction.argmax(-1).item()
        return self.labels[label_idx]

In [12]:
infer_wrapper = InferWrapper()



In [13]:
train_csv = pd.read_csv(train_csv_path)
labels = list()

for fname, _ in tqdm(train_csv.values):
    label = infer_wrapper(train_audio_path / fname)
    labels.append(label)

100%|██████████| 5683/5683 [02:07<00:00, 44.70it/s]


In [14]:
train_csv['predicted'] = labels

In [15]:
from sklearn.metrics import accuracy_score, f1_score

In [16]:
accuracy_score(train_csv["label"].values, train_csv["predicted"].values), f1_score(train_csv["label"].values, train_csv["predicted"].values, average="weighted")

(0.05067745908850959, 0.01746065986170221)

In [24]:
accuracy_score(train_csv["label"].values, train_csv["predicted"].values), f1_score(train_csv["label"].values, train_csv["predicted"].values, average="weighted")

(0.051909202885799755, 0.01796734257852576)

In [27]:
train_csv['predicted'].value_counts()

Squeak              2629
Gong                2177
Cello                322
Trumpet              239
Violin_or_fiddle     239
Fart                  68
Knock                  7
Flute                  2
Name: predicted, dtype: int64

In [None]:
test_csv = pd.read_csv(test_csv_path)
labels = list()

for fname, in tqdm(test_csv.values):
    label = infer_wrapper(test_audio_path / fname)
    labels.append(label)