<a href="https://colab.research.google.com/github/Vaarun-C/DnD-map-visualizer/blob/main/AST/Inference_with_the_Audio_Spectogram_Transformer_to_classify_audio.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Set-up environment

First we install 🤗 Transformers.

In [1]:
!pip install -q git+https://github.com/huggingface/transformers.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone


## Load audio

Let's load some audio on which we'd like to test the model.

In [47]:
from huggingface_hub import hf_hub_download
import IPython
import librosa
import soundfile as sf
import torchaudio
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer, ASTFeatureExtractor
import torch
import os
import numpy as np

In [37]:
filepath = '/content/recordingFROMCSV.wav'

IPython.display.Audio(filepath)

In [38]:
# Input and output file paths
output_path = "output_resampled.wav"

# Load the WAV file
data, sample_rate = librosa.load(filepath, sr=None)  # sr=None keeps original sample rate

# Define the new sample rate
new_sample_rate = 16000  # Target sample rate in Hz

# Resample the audio
resampled_data = librosa.resample(data, orig_sr=sample_rate, target_sr=new_sample_rate)

# Save the resampled audio
sf.write(output_path, resampled_data, new_sample_rate)

print(f"Resampled .wav file saved at: {output_path}")

Resampled .wav file saved at: output_resampled.wav


In [39]:
filepath = '/content/output_resampled.wav'

IPython.display.Audio(filepath)

## Prepare audio for the model (using feature extractor)

We can prepare the audio using ASTFeatureExtractor, which turns it into a tensor of shape (batch_size, time_dimension, frequency_dimension). This is also known as a spectrogram.

In [40]:
feature_extractor = ASTFeatureExtractor()

In [41]:
waveform, sampling_rate = torchaudio.load(filepath)
waveform = waveform.squeeze().numpy()

waveform.shape

(160000,)

In [42]:
inputs = feature_extractor(waveform, sampling_rate=sampling_rate, padding="max_length", return_tensors="pt")
input_values = inputs.input_values
print(input_values.shape)

torch.Size([1, 1024, 128])


## Load model

Next we load one of the models that the AST authors released from the [hub](https://huggingface.co/models?other=audio-spectrogram-transformer).

This one was fine-tuned on AudioSet, an important benchmark for audio classification.

In [43]:
model = AutoModelForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

## Forward pass

Next let's forward the audio through the model! We perform an argmax on the model's logits to get the predicted class index. We use model.config.id2label to turn that back into text.

In [44]:
with torch.no_grad():
  outputs = model(input_values)

In [45]:
predicted_class_idx = outputs.logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

Predicted class: Jet engine


# Finetuning

In [46]:
def prepare_dataset(data_dir):
    dataset_dict = {
        "audio": [],
        "label": []
    }

    # Label mapping
    label2id = {"no_drone": 0, "drone": 1}

    # Process each class directory
    for label in ["no_drone", "drone"]:
        class_dir = os.path.join(data_dir, label)
        for audio_file in os.listdir(class_dir):
            if audio_file.endswith(('.wav')):
                dataset_dict["audio"].append(os.path.join(class_dir, audio_file))
                dataset_dict["label"].append(label2id[label])

    return dataset_dict

# Load your dataset
dataset_dict = prepare_dataset("/content/dataset")
dataset = Dataset.from_dict(dataset_dict)

# Split dataset
dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)

ModuleNotFoundError: No module named 'datasets'

In [None]:
feature_extractor = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

def preprocess_function(examples):
    audio_arrays = [x["array"] for x in examples["audio"]]
    inputs = feature_extractor(
        audio_arrays,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=16000 * 10,  # 10 seconds max length
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    return inputs


encoded_dataset = dataset.map(preprocess_function, remove_columns=dataset.column_names, batched=True)

In [None]:
model = AutoModelForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    num_labels=2,  # Binary classification
    label2id={"no_signal": 0, "signal": 1},
    id2label={0: "no_signal", 1: "signal"}
)

In [None]:
training_args = TrainingArguments(
    output_dir="./ast-finetuned-signal-detection",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    push_to_hub=False,
    logging_dir='./logs',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    tokenizer=feature_extractor,
)

In [None]:
trainer.train()

In [None]:
model.save_pretrained("./ast-finetuned-signal-detection")
feature_extractor.save_pretrained("./ast-finetuned-signal-detection")

# Inference

In [None]:
def predict_audio(audio_path, model, feature_extractor):
    waveform, sampling_rate = torchaudio.load(audio_path)
    waveform = waveform.squeeze().numpy()

    inputs = feature_extractor(
        waveform,
        sampling_rate=sampling_rate,
        padding="max_length",
        return_tensors="pt"
    )

    with torch.no_grad():
        outputs = model(inputs.input_values)

    predicted_class_idx = outputs.logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_idx]

    return predicted_label

# Example usage
model = AutoModelForAudioClassification.from_pretrained("./ast-finetuned-signal-detection")
feature_extractor = ASTFeatureExtractor.from_pretrained("./ast-finetuned-signal-detection")

result = predict_audio("path/to/audio/file.wav", model, feature_extractor)
print(f"Predicted class: {result}")