In [None]:
from datasets import load_dataset, Audio
from transformers import ASTForAudioClassification, AutoFeatureExtractor, Trainer, TrainingArguments # noqa: F401
import matplotlib.pyplot as plt # noqa: F401

In [None]:
# Fetch AudioMNIST and cast its audio column to 16 kHz in one chained expression.
ds = load_dataset("gilkeyio/AudioMNIST").cast_column("audio", Audio(sampling_rate=16000))
# Read the rate back from the dataset schema so later cells use the declared value.
sampling_rate = ds["train"].features["audio"].sampling_rate

In [None]:
# Preview the non-audio columns of the train split.
# `to_pandas()` materialises the whole split in memory. The "audio" column carries the
# clip data for every row (presumably the bulk of the split's size) and is not readable in
# a DataFrame preview anyway, so drop it first. `remove_columns` returns a new Dataset and
# leaves `ds['train']` untouched.
df = ds['train'].remove_columns("audio").to_pandas()
df.head()

In [None]:
# Feature extractor from the same AST checkpoint that the Trainer cell loads the model from.
# It turns a raw waveform into the `input_values` the model consumes.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path="MIT/ast-finetuned-audioset-10-10-0.4593",
)

def preprocessing(input):
    """Convert one dataset row into model-ready inputs.

    Args:
        input: a single example whose ``"audio"`` entry exposes an ``"array"`` waveform,
            plus a ``"digit"`` target.

    Returns:
        dict with ``"input_values"`` (the extractor's features for this one clip) and
        ``"labels"`` (the digit, under the keyword HF models take as the training target).
    """
    waveform = input["audio"]["array"]
    features = feature_extractor(waveform, sampling_rate=sampling_rate)
    # The extractor returns a batch; unwrap its single element so each row holds one clip.
    first_clip = features["input_values"][0]
    return {"input_values": first_clip, "labels": input["digit"]}

In [None]:
# Apply `preprocessing` to every split. Rows gain `input_values` and `labels` alongside the
# original columns. NOTE(review): this rebinds `ds`, so re-running the cell maps the
# already-mapped dataset a second time.
ds = ds.map(function=preprocessing)

In [None]:
# Fine-tuning configuration.
# Checkpointing is aligned with evaluation (once per epoch) and capped. The default
# step-based saving (every 500 steps, no retention limit) would write a full
# model + optimizer checkpoint many times over a 5-epoch run and can fill the disk.
training_args = TrainingArguments(
    do_train=True,            # informational only: Trainer itself does not read this flag
    output_dir="./model",
    learning_rate=2e-5,
    eval_strategy="epoch",
    save_strategy="epoch",    # checkpoint at the same cadence as evaluation
    save_total_limit=2,       # keep only the two most recent checkpoints on disk
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=42,                  # explicit for reproducibility (same as the library default)
)

In [None]:
def compute_metrics(eval_pred):
    """Return classification accuracy from a Trainer ``EvalPrediction``.

    ``eval_pred`` unpacks to (logits, label_ids) as NumPy arrays. The predicted class is the
    arg-max over the last (class) axis.
    """
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)
    return {"accuracy": float((predictions == labels).mean())}


# The checkpoint's classifier head was fine-tuned on AudioSet's label set. Replace it with a
# fresh 10-way head for the spoken digits (assumes `digit` values are 0-9 — confirm).
# `ignore_mismatched_sizes=True` lets the new, randomly initialised head replace the old
# one instead of raising on the weight-shape mismatch.
digit_id2label = {i: str(i) for i in range(10)}
model = ASTForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    num_labels=len(digit_id2label),
    id2label=digit_id2label,
    label2id={label: i for i, label in digit_id2label.items()},
    ignore_mismatched_sizes=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    # Must be the processor object, not a string. Trainer calls `save_pretrained` on it
    # whenever it writes a checkpoint, and also uses it to choose the batch collator.
    processing_class=feature_extractor,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()