# Pre-trained models and datasets for audio classification

## Keyword spotting (KWS)

### Minds-14 (Dataset)

In [13]:
from datasets import load_dataset

# MINDS-14 needs remote loading code. Passing trust_remote_code=True opts in
# explicitly (only do this for dataset repos you trust) and silences the
# warning that the flag becomes mandatory in the next major `datasets` release.
minds = load_dataset(
    "PolyAI/minds14", name="en-US", split="train", trust_remote_code=True
)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [14]:
from transformers import pipeline

# Intent classifier for MINDS-14 (per the repo name, an XLS-R 300M checkpoint
# fine-tuned on the XTREME-S MINDS-14 task).
MINDS14_MODEL_ID = "anton-l/xtreme_s_xlsr_300m_minds14"
classifier = pipeline("audio-classification", model=MINDS14_MODEL_ID)

In [15]:
# Classify the first MINDS-14 recording; the pipeline returns the top-scoring intents.
classifier(minds[0]["audio"])

[{'score': 0.9984146356582642, 'label': 'joint_account'},
 {'score': 0.00043745661969296634, 'label': 'business_loan'},
 {'score': 0.0004359595768619329, 'label': 'cash_deposit'},
 {'score': 0.0001263972808374092, 'label': 'atm_limit'},
 {'score': 0.00010039484914159402, 'label': 'balance'}]

In [16]:
# ClassLabel.int2str: maps an integer intent id to its human-readable name.
id2label_fn = minds.features["intent_class"].int2str

In [17]:
# Ground-truth intent of the same example, to compare with the prediction above.
id2label_fn(minds[0]["intent_class"])

'joint_account'

### Speech Commands (Dataset)

In [18]:
from datasets import load_dataset

# Stream the validation split so the whole split does not have to be
# downloaded before we can look at its first example.
speech_commands = load_dataset(
    "speech_commands",
    "v0.02",
    split="validation",
    streaming=True,
)
speech_commands_iter = iter(speech_commands)
sample = next(speech_commands_iter)

In [19]:
classifier = pipeline(
    task="audio-classification",
    model="MIT/ast-finetuned-speech-commands-v2",
)
# Pass a copy so the pipeline cannot mutate sample["audio"]
# (some transformers versions pop keys from the dict they are given).
classifier(sample["audio"].copy())

config.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/342M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/295 [00:00<?, ?B/s]

[{'score': 0.9999892711639404, 'label': 'backward'},
 {'score': 1.7504938796264469e-06, 'label': 'happy'},
 {'score': 6.703033363919531e-07, 'label': 'follow'},
 {'score': 5.805901537314639e-07, 'label': 'stop'},
 {'score': 5.614546694232558e-07, 'label': 'up'}]

In [24]:
# ClassLabel.int2str for Speech Commands: integer label -> keyword name.
speech_commands_id2label_fn = speech_commands.features["label"].int2str

In [26]:
# Ground-truth keyword of the streamed sample, to compare with the prediction above.
speech_commands_id2label_fn(sample["label"])

'backward'

## Language identification (LID)

### FLEURS (Few-shot Learning Evaluation of Universal Representations of Speech) (Dataset)

In [27]:
from datasets import load_dataset

# FLEURS is built by a dataset loading script (see "Downloading builder script"
# in the output). trust_remote_code=True opts in explicitly (only do this for
# repos you trust) and silences the warning that the flag becomes mandatory in
# the next major `datasets` release.
fleurs = load_dataset(
    "google/fleurs", "all", split="validation", streaming=True, trust_remote_code=True
)
sample = next(iter(fleurs))

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/12.6k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/13.3k [00:00<?, ?B/s]

In [28]:
from transformers import pipeline

# Spoken language-ID classifier (per the repo name, Whisper-medium fine-tuned on FLEURS).
LID_MODEL_ID = "sanchit-gandhi/whisper-medium-fleurs-lang-id"
classifier = pipeline("audio-classification", model=LID_MODEL_ID)

config.json:   0%|          | 0.00/6.64k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/615M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/339 [00:00<?, ?B/s]

In [29]:
# Pass a copy, as in the Speech Commands cell: some transformers versions pop
# keys from the dict they are given, which would otherwise mutate `sample`.
classifier(sample["audio"].copy())

[{'score': 0.9999330043792725, 'label': 'Afrikaans'},
 {'score': 7.093016847647959e-06, 'label': 'Northern-Sotho'},
 {'score': 4.269149485480739e-06, 'label': 'Icelandic'},
 {'score': 3.266117346356623e-06, 'label': 'Danish'},
 {'score': 3.2580724109720904e-06, 'label': 'Cantonese Chinese'}]

In [35]:
# Ground-truth language of the FLEURS sample, to compare with the prediction above.
sample["language"]

'Afrikaans'

## Zero-Shot Audio Classification

### ESC-50: Environmental Sound Classification (Dataset)

In [36]:
from datasets import load_dataset

esc50 = load_dataset("ashraq/esc50", split="train", streaming=True)
first_clip = next(iter(esc50))
# Keep only the raw waveform; the zero-shot pipeline below is given a bare array.
audio_sample = first_clip["audio"]["array"]

Downloading readme:   0%|          | 0.00/345 [00:00<?, ?B/s]



Downloading metadata:   0%|          | 0.00/1.61k [00:00<?, ?B/s]

In [37]:
# Free-text labels the zero-shot model scores the clip against; any wording works.
candidate_labels = ["Sound of a dog", "Sound of vacuum cleaner"]

In [38]:
from transformers import pipeline

# CLAP compares audio with text prompts, so the label set is chosen at call time
# (via candidate_labels) instead of being fixed by a classification head.
classifier = pipeline(
    model="laion/clap-htsat-unfused",
    task="zero-shot-audio-classification",
)


config.json:   0%|          | 0.00/5.39k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/615M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/384 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

In [39]:
# Zero-shot classification: one score per candidate label.
classifier(audio_sample, candidate_labels=candidate_labels)

[{'score': 0.9997242093086243, 'label': 'Sound of a dog'},
 {'score': 0.0002758292539510876, 'label': 'Sound of vacuum cleaner'}]

In [42]:
from IPython.display import Audio

# Play the clip back at its own sampling rate, read from the dataset, instead
# of a hard-coded 16 kHz. `audio_sample` is only the raw array, so its rate was
# dropped when it was extracted. ESC-50 recordings are 44.1 kHz, and playing
# them at 16 kHz slows them down and lowers the pitch.
esc50_sampling_rate = next(iter(esc50))["audio"]["sampling_rate"]
Audio(audio_sample, rate=esc50_sampling_rate)

# Fine-tuning a model for music classification

## The Dataset

In [1]:
from datasets import load_dataset

# GTZAN needs remote loading code. trust_remote_code=True opts in explicitly
# (only for repos you trust) and silences the warning that the flag becomes
# mandatory in the next major `datasets` release.
gtzan = load_dataset("marsyas/gtzan", "all", trust_remote_code=True)
gtzan

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'genre'],
        num_rows: 999
    })
})

In [2]:
# GTZAN ships a single "train" split; hold out 10% of it as a test set.
# The guard makes the cell idempotent: re-running an unguarded version would
# split the already-reduced train set again and silently shrink it each time.
if "test" not in gtzan:
    gtzan = gtzan["train"].train_test_split(seed=42, shuffle=True, test_size=0.1)
gtzan

DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'genre'],
        num_rows: 899
    })
    test: Dataset({
        features: ['file', 'audio', 'genre'],
        num_rows: 100
    })
})

In [3]:
# Inspect one training example as stored (audio at its original sampling rate).
gtzan["train"][0]

{'file': '/data/huggingface/datasets/downloads/extracted/1897c15c7fe11286b37ffd532433625629e79c7d549e8bf7bcc1514aa46dc237/genres/pop/pop.00098.wav',
 'audio': {'path': '/data/huggingface/datasets/downloads/extracted/1897c15c7fe11286b37ffd532433625629e79c7d549e8bf7bcc1514aa46dc237/genres/pop/pop.00098.wav',
  'array': array([ 0.10720825,  0.16122437,  0.28585815, ..., -0.22924805,
         -0.20629883, -0.11334229]),
  'sampling_rate': 22050},
 'genre': 7}

In [4]:
# ClassLabel.int2str for GTZAN genres; also used by the demo and id2label cells below.
id2label_fn = gtzan["train"].features["genre"].int2str
id2label_fn(gtzan["train"][0]["genre"])

'pop'

In [5]:
# The integer class id behind the genre name shown above.
gtzan["train"][0]["genre"]

7

In [None]:
import random

import gradio as gr


def generate_audio():
    """Pick one random GTZAN training example for the listening demo.

    Returns:
        ``((sampling_rate, waveform), genre_name)``. The first element is the
        ``(rate, array)`` tuple that ``gr.Audio`` accepts as a value. The second
        is the human-readable genre, used as the widget label.
    """
    train_split = gtzan["train"]
    # Index one random row directly. The previous `train_split.shuffle()[0]`
    # built a permutation of the whole split on every call to read one example.
    example = train_split[random.randrange(len(train_split))]
    audio = example["audio"]
    return (audio["sampling_rate"], audio["array"]), id2label_fn(example["genre"])


with gr.Blocks() as demo:
    with gr.Column():
        for _ in range(4):
            audio, label = generate_audio()
            # Creating the component inside the Blocks context adds it to the layout.
            gr.Audio(audio, label=label)

# NOTE(review): debug=True keeps this cell running (blocking the kernel) so that
# errors are printed; stop the cell before running the rest of the notebook.
demo.launch(debug=True)

## Picking a pretrained model for audio classification

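I'll use DistilHuBERT (`ntu-spml/distilhubert`). It is a distilled, much smaller version of HuBERT, so fine-tuning it on a small dataset like GTZAN is fast.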

## Preprocessing the data

In [7]:
from transformers import AutoFeatureExtractor

model_id = "ntu-spml/distilhubert"
# do_normalize: scale each waveform to zero mean / unit variance (verified below).
# return_attention_mask: lets the model ignore padded positions in a batch.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    model_id,
    return_attention_mask=True,
    do_normalize=True,
)



In [8]:
# Sampling rate the pretrained model expects; the dataset is resampled to it below.
sampling_rate = feature_extractor.sampling_rate
sampling_rate

16000

In [9]:
from datasets import Audio

# Casting the column makes `datasets` resample each clip lazily, on access,
# from its stored rate to the feature extractor's rate.
target_audio_feature = Audio(sampling_rate=sampling_rate)
gtzan = gtzan.cast_column("audio", target_audio_feature)

In [10]:
# Same example after the cast: the array is now resampled to `sampling_rate`.
gtzan["train"][0]

{'file': '/data/huggingface/datasets/downloads/extracted/1897c15c7fe11286b37ffd532433625629e79c7d549e8bf7bcc1514aa46dc237/genres/pop/pop.00098.wav',
 'audio': {'path': '/data/huggingface/datasets/downloads/extracted/1897c15c7fe11286b37ffd532433625629e79c7d549e8bf7bcc1514aa46dc237/genres/pop/pop.00098.wav',
  'array': array([ 0.0873509 ,  0.20183384,  0.4790867 , ..., -0.18743178,
         -0.23294401, -0.13517427]),
  'sampling_rate': 16000},
 'genre': 7}

In [11]:
import numpy as np

# Statistics of the raw (resampled but un-normalised) waveform, for comparison
# with the feature extractor's output in the next cell.
sample = gtzan["train"][0]["audio"]
raw_mean = np.mean(sample["array"])
raw_var = np.var(sample["array"])
print(f"Mean: {raw_mean:.3}, Variance: {raw_var:.3}")

Mean: 0.000185, Variance: 0.0493


In [12]:
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
input_values = inputs["input_values"]

# With do_normalize=True the output should be roughly zero-mean, unit-variance.
print(f"inputs keys: {list(inputs.keys())}")
print(f"Mean: {np.mean(input_values):.3}, Variance: {np.var(input_values):.3}")

inputs keys: ['input_values', 'attention_mask']
Mean: -7.45e-09, Variance: 1.0


In [13]:
max_duration = 30.0  # seconds; longer clips are truncated to this length


def preprocess_function(examples):
    """Turn a batch of decoded audio rows into model inputs.

    ``examples["audio"]`` is a list of dicts with an ``"array"`` waveform
    (already resampled by the column cast). Each waveform is normalised and
    truncated to ``max_duration`` seconds. The returned mapping holds
    ``input_values`` and ``attention_mask``.
    """
    target_rate = feature_extractor.sampling_rate
    waveforms = [clip["array"] for clip in examples["audio"]]
    return feature_extractor(
        waveforms,
        sampling_rate=target_rate,
        max_length=int(target_rate * max_duration),
        truncation=True,
        return_attention_mask=True,
    )

In [14]:
# Replace the raw audio/file columns with model inputs; only "genre" survives.
map_kwargs = dict(
    remove_columns=["audio", "file"],
    batched=True,
    batch_size=100,
    num_proc=8,
)
gtzan_encoded = gtzan.map(preprocess_function, **map_kwargs)
gtzan_encoded

DatasetDict({
    train: Dataset({
        features: ['genre', 'input_values', 'attention_mask'],
        num_rows: 899
    })
    test: Dataset({
        features: ['genre', 'input_values', 'attention_mask'],
        num_rows: 100
    })
})

In [15]:
# Later cells read the target column as "label". The guard makes the cell safe
# to re-run: a second rename would raise because "genre" no longer exists.
if "genre" in gtzan_encoded["train"].column_names:
    gtzan_encoded = gtzan_encoded.rename_column("genre", "label")

In [16]:
# Label maps for the model config. Keys of id2label are the stringified class indices.
genre_names = gtzan_encoded["train"].features["label"].names
id2label = {str(idx): id2label_fn(idx) for idx in range(len(genre_names))}
label2id = {name: idx for idx, name in id2label.items()}

id2label["7"]

'pop'

## Fine-tuning the model

In [35]:
from transformers import AutoModelForAudioClassification

num_labels = len(id2label)

# The encoder weights come from the checkpoint. The projector/classifier head is
# newly initialised and sized to the genres, which is what the warning below reports.
model = AutoModelForAudioClassification.from_pretrained(
    model_id,
    id2label=id2label,
    label2id=label2id,
    num_labels=num_labels,
)

Some weights of HubertForSequenceClassification were not initialized from the model checkpoint at ntu-spml/distilhubert and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [45]:
# Authenticate with the Hugging Face Hub; required because TrainingArguments
# below sets push_to_hub=True.
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [36]:
from transformers import TrainingArguments

model_name = model_id.split("/")[-1]
batch_size = 4
gradient_accumulation_steps = 1
num_train_epochs = 10

training_args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-gtzan",
    # Optimisation schedule.
    learning_rate=5e-5,
    warmup_ratio=0.1,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    # Mixed precision; requires a GPU that supports fp16.
    fp16=True,
    # Evaluate and checkpoint once per epoch. The two strategies must match for
    # load_best_model_at_end to restore the checkpoint with the best accuracy.
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy`; keep this in sync with the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Logging and Hub upload.
    logging_steps=5,
    report_to="none",
    push_to_hub=True,
)


In [37]:
import evaluate
import numpy as np

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Accuracy of the arg-max class, as called by the Trainer after each evaluation.

    ``eval_pred.predictions`` is presumably a (num_examples, num_labels) logits
    array; ``eval_pred.label_ids`` holds the true class ids.
    """
    logits = eval_pred.predictions
    predicted_classes = np.argmax(logits, axis=1)
    return metric.compute(
        predictions=predicted_classes, references=eval_pred.label_ids
    )

In [38]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=gtzan_encoded["train"],
    eval_dataset=gtzan_encoded["test"],
    compute_metrics=compute_metrics,
    # The feature extractor is saved and pushed alongside the model, so the Hub
    # repo gets a preprocessor_config.json.
    # NOTE(review): newer transformers releases prefer `processing_class=` here.
    tokenizer=feature_extractor,
)

trainer.train()

  0%|          | 0/2250 [00:00<?, ?it/s]

{'loss': 2.2945, 'grad_norm': 2.555192232131958, 'learning_rate': 1.1111111111111112e-06, 'epoch': 0.02}
{'loss': 2.2987, 'grad_norm': 1.6317881345748901, 'learning_rate': 2.2222222222222225e-06, 'epoch': 0.04}
{'loss': 2.3089, 'grad_norm': 1.9010518789291382, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.07}
{'loss': 2.2986, 'grad_norm': 2.6657118797302246, 'learning_rate': 4.444444444444445e-06, 'epoch': 0.09}
{'loss': 2.2924, 'grad_norm': 2.6440532207489014, 'learning_rate': 5.555555555555556e-06, 'epoch': 0.11}
{'loss': 2.2818, 'grad_norm': 1.5505942106246948, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.13}
{'loss': 2.3066, 'grad_norm': 1.6080923080444336, 'learning_rate': 7.777777777777777e-06, 'epoch': 0.16}
{'loss': 2.3129, 'grad_norm': 2.1232919692993164, 'learning_rate': 8.88888888888889e-06, 'epoch': 0.18}
{'loss': 2.2744, 'grad_norm': 1.5004475116729736, 'learning_rate': 1e-05, 'epoch': 0.2}
{'loss': 2.2717, 'grad_norm': 1.6821411848068237, 'learning_rate': 1.11

  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 1.6439989805221558, 'eval_accuracy': 0.45, 'eval_runtime': 42.4564, 'eval_samples_per_second': 2.355, 'eval_steps_per_second': 0.589, 'epoch': 1.0}
{'loss': 1.5598, 'grad_norm': 5.525517463684082, 'learning_rate': 4.995061728395062e-05, 'epoch': 1.02}
{'loss': 1.6146, 'grad_norm': 5.517798900604248, 'learning_rate': 4.982716049382716e-05, 'epoch': 1.04}
{'loss': 1.6052, 'grad_norm': 5.101614952087402, 'learning_rate': 4.970370370370371e-05, 'epoch': 1.07}
{'loss': 1.698, 'grad_norm': 5.980613708496094, 'learning_rate': 4.958024691358025e-05, 'epoch': 1.09}
{'loss': 1.5045, 'grad_norm': 5.465908050537109, 'learning_rate': 4.945679012345679e-05, 'epoch': 1.11}
{'loss': 1.5075, 'grad_norm': 6.653104305267334, 'learning_rate': 4.933333333333334e-05, 'epoch': 1.13}
{'loss': 1.4782, 'grad_norm': 10.446157455444336, 'learning_rate': 4.920987654320988e-05, 'epoch': 1.16}
{'loss': 1.4852, 'grad_norm': 11.197892189025879, 'learning_rate': 4.908641975308642e-05, 'epoch': 1.18}
{'los

  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 1.170931339263916, 'eval_accuracy': 0.6, 'eval_runtime': 40.5399, 'eval_samples_per_second': 2.467, 'eval_steps_per_second': 0.617, 'epoch': 2.0}
{'loss': 1.253, 'grad_norm': 6.6608052253723145, 'learning_rate': 4.4395061728395064e-05, 'epoch': 2.02}
{'loss': 1.0851, 'grad_norm': 12.8822021484375, 'learning_rate': 4.4271604938271605e-05, 'epoch': 2.04}
{'loss': 0.8183, 'grad_norm': 16.306474685668945, 'learning_rate': 4.414814814814815e-05, 'epoch': 2.07}
{'loss': 1.0194, 'grad_norm': 15.784733772277832, 'learning_rate': 4.4024691358024693e-05, 'epoch': 2.09}
{'loss': 1.1411, 'grad_norm': 21.09313201904297, 'learning_rate': 4.390123456790124e-05, 'epoch': 2.11}
{'loss': 1.2831, 'grad_norm': 8.765751838684082, 'learning_rate': 4.377777777777778e-05, 'epoch': 2.13}
{'loss': 1.0603, 'grad_norm': 19.15607452392578, 'learning_rate': 4.365432098765432e-05, 'epoch': 2.16}
{'loss': 0.9585, 'grad_norm': 7.3682780265808105, 'learning_rate': 4.353086419753087e-05, 'epoch': 2.18}
{'l

  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 0.7768688797950745, 'eval_accuracy': 0.77, 'eval_runtime': 40.6094, 'eval_samples_per_second': 2.462, 'eval_steps_per_second': 0.616, 'epoch': 3.0}
{'loss': 0.6098, 'grad_norm': 7.360509872436523, 'learning_rate': 3.8864197530864196e-05, 'epoch': 3.02}
{'loss': 0.6901, 'grad_norm': 1.9910117387771606, 'learning_rate': 3.8740740740740744e-05, 'epoch': 3.04}
{'loss': 0.5222, 'grad_norm': 10.589713096618652, 'learning_rate': 3.8617283950617285e-05, 'epoch': 3.07}
{'loss': 0.442, 'grad_norm': 1.3917627334594727, 'learning_rate': 3.8493827160493825e-05, 'epoch': 3.09}
{'loss': 0.9647, 'grad_norm': 18.35091781616211, 'learning_rate': 3.837037037037037e-05, 'epoch': 3.11}
{'loss': 0.4523, 'grad_norm': 6.137357234954834, 'learning_rate': 3.824691358024692e-05, 'epoch': 3.13}
{'loss': 0.7109, 'grad_norm': 13.303772926330566, 'learning_rate': 3.8123456790123455e-05, 'epoch': 3.16}
{'loss': 0.7764, 'grad_norm': 10.645620346069336, 'learning_rate': 3.8e-05, 'epoch': 3.18}
{'loss': 0.

  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 0.5280104279518127, 'eval_accuracy': 0.84, 'eval_runtime': 40.6152, 'eval_samples_per_second': 2.462, 'eval_steps_per_second': 0.616, 'epoch': 4.0}
{'loss': 0.6316, 'grad_norm': 20.527040481567383, 'learning_rate': 3.3333333333333335e-05, 'epoch': 4.02}
{'loss': 0.441, 'grad_norm': 1.8936585187911987, 'learning_rate': 3.3209876543209876e-05, 'epoch': 4.04}
{'loss': 0.2337, 'grad_norm': 1.9657421112060547, 'learning_rate': 3.308641975308642e-05, 'epoch': 4.07}
{'loss': 0.3726, 'grad_norm': 6.5667219161987305, 'learning_rate': 3.2962962962962964e-05, 'epoch': 4.09}
{'loss': 0.2074, 'grad_norm': 0.9367420673370361, 'learning_rate': 3.2839506172839505e-05, 'epoch': 4.11}
{'loss': 0.5724, 'grad_norm': 7.120716571807861, 'learning_rate': 3.271604938271605e-05, 'epoch': 4.13}
{'loss': 0.359, 'grad_norm': 7.269446849822998, 'learning_rate': 3.25925925925926e-05, 'epoch': 4.16}
{'loss': 0.4966, 'grad_norm': 83.38056945800781, 'learning_rate': 3.2469135802469134e-05, 'epoch': 4.18}

  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 0.6280108094215393, 'eval_accuracy': 0.84, 'eval_runtime': 40.6323, 'eval_samples_per_second': 2.461, 'eval_steps_per_second': 0.615, 'epoch': 5.0}
{'loss': 0.2152, 'grad_norm': 11.81860637664795, 'learning_rate': 2.777777777777778e-05, 'epoch': 5.02}
{'loss': 0.2551, 'grad_norm': 11.776322364807129, 'learning_rate': 2.765432098765432e-05, 'epoch': 5.04}
{'loss': 0.2342, 'grad_norm': 8.60040283203125, 'learning_rate': 2.7530864197530864e-05, 'epoch': 5.07}
{'loss': 0.3058, 'grad_norm': 7.859400749206543, 'learning_rate': 2.7407407407407408e-05, 'epoch': 5.09}
{'loss': 0.602, 'grad_norm': 29.731151580810547, 'learning_rate': 2.7283950617283956e-05, 'epoch': 5.11}
{'loss': 0.294, 'grad_norm': 4.471146106719971, 'learning_rate': 2.7160493827160493e-05, 'epoch': 5.13}
{'loss': 0.5057, 'grad_norm': 22.64497184753418, 'learning_rate': 2.7037037037037037e-05, 'epoch': 5.16}
{'loss': 0.171, 'grad_norm': 2.8287971019744873, 'learning_rate': 2.6913580246913585e-05, 'epoch': 5.18}
{

  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 0.6823034882545471, 'eval_accuracy': 0.82, 'eval_runtime': 40.6789, 'eval_samples_per_second': 2.458, 'eval_steps_per_second': 0.615, 'epoch': 6.0}
{'loss': 0.508, 'grad_norm': 0.7080361247062683, 'learning_rate': 2.2222222222222223e-05, 'epoch': 6.02}
{'loss': 0.0906, 'grad_norm': 7.651828765869141, 'learning_rate': 2.2098765432098767e-05, 'epoch': 6.04}
{'loss': 0.1186, 'grad_norm': 1.851569414138794, 'learning_rate': 2.1975308641975308e-05, 'epoch': 6.07}
{'loss': 0.1509, 'grad_norm': 33.96539306640625, 'learning_rate': 2.1851851851851852e-05, 'epoch': 6.09}
{'loss': 0.184, 'grad_norm': 0.3653853237628937, 'learning_rate': 2.1728395061728397e-05, 'epoch': 6.11}
{'loss': 0.1266, 'grad_norm': 28.073711395263672, 'learning_rate': 2.1604938271604937e-05, 'epoch': 6.13}
{'loss': 0.15, 'grad_norm': 18.80849838256836, 'learning_rate': 2.148148148148148e-05, 'epoch': 6.16}
{'loss': 0.0551, 'grad_norm': 0.35124313831329346, 'learning_rate': 2.1358024691358026e-05, 'epoch': 6.18

  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 0.6526587009429932, 'eval_accuracy': 0.85, 'eval_runtime': 40.8574, 'eval_samples_per_second': 2.448, 'eval_steps_per_second': 0.612, 'epoch': 7.0}
{'loss': 0.1131, 'grad_norm': 29.281578063964844, 'learning_rate': 1.669135802469136e-05, 'epoch': 7.02}
{'loss': 0.019, 'grad_norm': 0.5338394045829773, 'learning_rate': 1.65679012345679e-05, 'epoch': 7.04}
{'loss': 0.0348, 'grad_norm': 0.613318145275116, 'learning_rate': 1.6444444444444447e-05, 'epoch': 7.07}
{'loss': 0.124, 'grad_norm': 0.5588547587394714, 'learning_rate': 1.6320987654320988e-05, 'epoch': 7.09}
{'loss': 0.1829, 'grad_norm': 33.121238708496094, 'learning_rate': 1.6197530864197532e-05, 'epoch': 7.11}
{'loss': 0.131, 'grad_norm': 0.5189381241798401, 'learning_rate': 1.6074074074074076e-05, 'epoch': 7.13}
{'loss': 0.0654, 'grad_norm': 0.3069867491722107, 'learning_rate': 1.5950617283950617e-05, 'epoch': 7.16}
{'loss': 0.0249, 'grad_norm': 0.9599993228912354, 'learning_rate': 1.582716049382716e-05, 'epoch': 7.18

  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 0.5111324191093445, 'eval_accuracy': 0.86, 'eval_runtime': 40.7519, 'eval_samples_per_second': 2.454, 'eval_steps_per_second': 0.613, 'epoch': 8.0}
{'loss': 0.0335, 'grad_norm': 0.4945843517780304, 'learning_rate': 1.1135802469135803e-05, 'epoch': 8.02}
{'loss': 0.0139, 'grad_norm': 0.1748279333114624, 'learning_rate': 1.1012345679012347e-05, 'epoch': 8.04}
{'loss': 0.0319, 'grad_norm': 17.4390869140625, 'learning_rate': 1.088888888888889e-05, 'epoch': 8.07}
{'loss': 0.0161, 'grad_norm': 1.2346147298812866, 'learning_rate': 1.0765432098765432e-05, 'epoch': 8.09}
{'loss': 0.0233, 'grad_norm': 0.7376428246498108, 'learning_rate': 1.0641975308641976e-05, 'epoch': 8.11}
{'loss': 0.0112, 'grad_norm': 0.2320098578929901, 'learning_rate': 1.0518518518518519e-05, 'epoch': 8.13}
{'loss': 0.0337, 'grad_norm': 0.13235072791576385, 'learning_rate': 1.0395061728395063e-05, 'epoch': 8.16}
{'loss': 0.0157, 'grad_norm': 3.1309571266174316, 'learning_rate': 1.0271604938271605e-05, 'epoch'

  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 0.5714808702468872, 'eval_accuracy': 0.86, 'eval_runtime': 40.8118, 'eval_samples_per_second': 2.45, 'eval_steps_per_second': 0.613, 'epoch': 9.0}
{'loss': 0.1933, 'grad_norm': 0.36465415358543396, 'learning_rate': 5.580246913580247e-06, 'epoch': 9.02}
{'loss': 0.2074, 'grad_norm': 23.540019989013672, 'learning_rate': 5.45679012345679e-06, 'epoch': 9.04}
{'loss': 0.0091, 'grad_norm': 0.11922164261341095, 'learning_rate': 5.333333333333334e-06, 'epoch': 9.07}
{'loss': 0.0419, 'grad_norm': 0.4064576327800751, 'learning_rate': 5.209876543209877e-06, 'epoch': 9.09}
{'loss': 0.0091, 'grad_norm': 0.8454622626304626, 'learning_rate': 5.08641975308642e-06, 'epoch': 9.11}
{'loss': 0.1894, 'grad_norm': 0.27551811933517456, 'learning_rate': 4.962962962962963e-06, 'epoch': 9.13}
{'loss': 0.0144, 'grad_norm': 0.12078557163476944, 'learning_rate': 4.839506172839506e-06, 'epoch': 9.16}
{'loss': 0.0105, 'grad_norm': 0.09724433720111847, 'learning_rate': 4.7160493827160495e-06, 'epoch': 9

  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 0.6375682353973389, 'eval_accuracy': 0.85, 'eval_runtime': 40.6308, 'eval_samples_per_second': 2.461, 'eval_steps_per_second': 0.615, 'epoch': 10.0}
{'train_runtime': 5115.6974, 'train_samples_per_second': 1.757, 'train_steps_per_second': 0.44, 'train_loss': 0.6060285901973644, 'epoch': 10.0}


TrainOutput(global_step=2250, training_loss=0.6060285901973644, metrics={'train_runtime': 5115.6974, 'train_samples_per_second': 1.757, 'train_steps_per_second': 0.44, 'train_loss': 0.6060285901973644, 'epoch': 10.0})

In [57]:
# Model-card metadata for trainer.push_to_hub below.
# NOTE(review): adding "dataset_tags": "marsyas/gtzan" raised an error in this
# environment (cause not yet investigated), so it is left out.
kwargs = dict(
    dataset="GTZAN",
    model_name=f"{model_name}-finetuned-gtzan",
    finetuned_from=model_id,
    tasks="audio-classification",
)

In [None]:
# Upload the final model (the best checkpoint, since load_best_model_at_end=True)
# together with a model card built from `kwargs`.
trainer.push_to_hub(**kwargs)

## Using the shared model from the Hub

In [54]:
from transformers import pipeline

# Load the fine-tuned checkpoint back from the Hub, exactly as any other user would.
FINETUNED_MODEL_ID = "BanUrsus/distilhubert-finetuned-gtzan"
pipe = pipeline(task="audio-classification", model=FINETUNED_MODEL_ID)

config.json:   0%|          | 0.00/1.85k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/94.8M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/212 [00:00<?, ?B/s]

In [56]:
# `sample` is gtzan["train"][0]["audio"], already resampled to 16 kHz by the
# cast, so the bare array matches the model's expected rate.
# NOTE(review): this clip was part of the training set, so this is a smoke test,
# not an accuracy check; use a gtzan["test"] clip to gauge generalisation.
pipe(sample["array"])

[{'score': 0.9954374432563782, 'label': 'pop'},
 {'score': 0.001072408864274621, 'label': 'classical'},
 {'score': 0.0010694654192775488, 'label': 'country'},
 {'score': 0.0008234226843342185, 'label': 'disco'},
 {'score': 0.0006110203685238957, 'label': 'jazz'}]