In [1]:
import torch
from datasets import load_dataset, Audio
from transformers import (
    ASTFeatureExtractor,
    ASTForAudioClassification,
    Trainer,
    TrainingArguments,
)
import evaluate
import numpy as np

torch.cuda.empty_cache()

In [2]:
class AudioPreprocessor:
    def __init__(self):
        """Init AST feature extractor."""
        self.feature_extractor = ASTFeatureExtractor.from_pretrained(
            "MIT/ast-finetuned-audioset-10-10-0.4593"
        )

    def extract_features(self, batch):
        """Extract audio feats."""
        features = self.feature_extractor(
            raw_speech=[x["array"] for x in batch["audio"]],
            sampling_rate=16000,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        batch["input_values"] = features["input_values"]
        return batch

    def preprocess(self, dataset):
        """Prep dataset for AST."""
        dataset = dataset.rename_column("label", "labels")
        dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
        return dataset.map(
            self.extract_features,
            batched=True,
            batch_size=32,
            remove_columns=["audio"],
        )

    def __call__(self, dataset):
        "Prep when called as a function."
        return self.preprocess(dataset)

In [3]:
class Data:
    def __init__(self, audio_preprocessor):
        """Init preprocessing and loaders."""
        self.audio_preprocessor = audio_preprocessor
        self.train, self.val = self.preprocess_splits(*self.get_data())

    def get_data(self):
        """Load train & val splits."""
        train = load_dataset("confit/esc50-parquet", "fold1", split="train")
        val = load_dataset("confit/esc50-parquet", "fold1", split="test")
        return train, val

    def preprocess_splits(self, train, val):
        """Prep train & val."""
        train = self.audio_preprocessor(train)
        val = self.audio_preprocessor(val)
        return train, val

In [4]:
class Metrics:
    def __init__(self):
        """Init evaluation metrics."""
        self.accuracy = evaluate.load("accuracy")
        self.f1 = evaluate.load("f1")
        self.precision = evaluate.load("precision")
        self.recall = evaluate.load("recall")
        self.average = 'micro'

    def eval_accuracy(self, predictions, labels):
        """Compute accuracy."""
        return self.accuracy.compute(predictions=predictions, references=labels)

    def eval_f1(self, predictions, labels):
        """Compute F1 score."""
        return self.f1.compute(
            predictions=predictions, references=labels, average=self.average
        )

    def eval_precision(self, predictions, labels):
        """Compute precision."""
        return self.precision.compute(
            predictions=predictions, references=labels, average=self.average
        )

    def eval_recall(self, predictions, labels):
        """Compute recall."""
        return self.recall.compute(
            predictions=predictions, references=labels, average=self.average
        )
    
    def compute_metrics(self, eval_pred):
        """Compute all metrics."""
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        metrics = {}
        metrics.update(self.eval_accuracy(predictions, labels))
        metrics.update(self.eval_f1(predictions, labels))
        metrics.update(self.eval_precision(predictions, labels))
        metrics.update(self.eval_recall(predictions, labels))
        return metrics
    
    def __call__(self, eval_pred):
        "Compute all metrics when called as a function."
        return self.compute_metrics(eval_pred)


In [5]:
class Train:
    def __init__(self, data):
        """Init training setup."""
        self.data = data
        self.model = ASTForAudioClassification.from_pretrained(
            "MIT/ast-finetuned-audioset-10-10-0.4593"
        )
        self.metrics = Metrics()
        self.args = TrainingArguments(
            output_dir="./checkpoints",
            eval_strategy="epoch",
            per_device_train_batch_size=32,
            per_device_eval_batch_size=32,
            learning_rate=2e-5,
            num_train_epochs=5,
            logging_strategy="no",
            save_strategy="best",
            load_best_model_at_end=True,
        )
        self.trainer = Trainer(
            model=self.model,
            args=self.args,
            train_dataset=data.train,
            eval_dataset=data.val,
            compute_metrics=self.metrics,
        )

    def train(self):
        """Run training."""
        self.trainer.train(resume_from_checkpoint=True)

    def val(self):
        """Run validation."""
        self.trainer.validate()

In [6]:
audio_preprocessor = AudioPreprocessor()
data = Data(audio_preprocessor)



In [7]:
train = Train(data)
train.train()

Epoch,Training Loss,Validation Loss


In [8]:
train.trainer.evaluate()

{'eval_loss': 0.31918656826019287,
 'eval_accuracy': 0.93,
 'eval_f1': 0.93,
 'eval_precision': 0.93,
 'eval_recall': 0.93,
 'eval_runtime': 29.301,
 'eval_samples_per_second': 13.651,
 'eval_steps_per_second': 0.444,
 'epoch': 5.0}