In [None]:
#!pip install transformers[torch]


In [None]:
#!pip uninstall datasets -y
#!pip install datasets
# Sanity check that Hugging Face transformers is installed
#from transformers import pipeline
#classifier = pipeline("sentiment-analysis")
#print(classifier("This is amazing!"))


In [None]:
# --- Data locations: each split folder holds one sub-folder per class ---
train_dir = "./train-val-independent/train/"
validation_dir = "./train-val-independent/validation/"

# Keep only every Nth file of each split (1 = keep all); used to reduce memory use.
everyNth = 1

# --- Training schedule ---
epoch_num = 40
batch_size = 4  # per-device batch size for both training and evaluation

# Suffix appended to the results / logs / saved-classifier directory names.
postfix = "-AST_multiclass_mosquito_indep_"

# --- Fine-tuning scope ---
b_train_just_last = False  # True: only the classification head is trainable
trainNlayers = 3  # number of final encoder layers to train; 0 = train the full network



In [None]:
import os
from datasets import Dataset, DatasetDict
import torch
import librosa
import numpy as np




In [None]:
def load_audio_dataset_from_folders(train_dir, validation_dir, every_nth=None):
    """
    Load data from folders and convert to Dataset format.

    Each split directory is expected to contain one sub-directory per class.
    Every ``.wav`` file (extension matched case-insensitively) inside a class
    sub-directory becomes one example whose label is that sub-directory's name.
    Files are visited in sorted order, so the every-Nth subsample is
    reproducible across machines.

    Args:
        train_dir (str): Path to the 'train' folder.
        validation_dir (str): Path to the 'validation' folder.
        every_nth (int, optional): Keep only every Nth file of each split
            (1 keeps everything). Defaults to the notebook-level ``everyNth``.

    Returns:
        DatasetDict: A DatasetDict with 'train' and 'validation' data, each
        with string columns 'file_path' and 'label'.

    Raises:
        ValueError: If the effective every-Nth step is smaller than 1.
    """
    step = everyNth if every_nth is None else every_nth
    if step < 1:
        raise ValueError(f"every_nth must be >= 1, got {step}")

    def get_audio_files_with_labels(directory):
        """Return [{'file_path', 'label'}, ...] for all WAV files under class folders."""
        data = []
        # os.listdir order is arbitrary; sorting keeps the subsample and the
        # example order identical from run to run and machine to machine.
        for class_name in sorted(os.listdir(directory)):
            class_path = os.path.join(directory, class_name)
            if os.path.isdir(class_path):  # skip stray files at the class level
                for file_name in sorted(os.listdir(class_path)):
                    if file_name.lower().endswith(".wav"):  # WAV files, any extension case
                        file_path = os.path.join(class_path, file_name)
                        data.append({"file_path": file_path, "label": class_name})
        return data

    def to_dataset(records):
        """Keep every `step`-th record and build a column-oriented Dataset."""
        kept = records[::step]  # same as keeping indices where idx % step == 0
        return Dataset.from_dict({
            "file_path": [d["file_path"] for d in kept],
            "label": [d["label"] for d in kept],
        })

    return DatasetDict({
        "train": to_dataset(get_audio_files_with_labels(train_dir)),
        "validation": to_dataset(get_audio_files_with_labels(validation_dir)),
    })


In [None]:
# Build the train/validation DatasetDict from the class-per-folder layout.
dataset = load_audio_dataset_from_folders(
    train_dir=train_dir,
    validation_dir=validation_dir,
)
print(dataset)


In [None]:
# Class names are the sub-directory names of the training folder.
# - Only directories count, matching how the loader discovers classes; stray
#   files (e.g. .DS_Store) would otherwise become bogus classes.
# - Sorting makes label -> id reproducible: iterating a set of strings follows
#   hash order, which is randomised per Python process.
label_list = sorted(
    entry
    for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
)
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for label, idx in label2id.items()}
cls_num = len(id2label)
print(id2label)


In [None]:
# Unpack both splits so they can be preprocessed separately below.
train_dataset, validation_dataset = dataset["train"], dataset["validation"]
print(validation_dataset)


In [None]:
print("Class labels:", id2label)  # id -> class-name mapping used by the model config

In [None]:
# Load the AST checkpoint (fine-tuned on AudioSet) and its processor, then
# choose which parameters are trainable.

from transformers import AutoProcessor, AutoModelForAudioClassification
processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

# ignore_mismatched_sizes=True lets the checkpoint's classification head be
# replaced by a freshly initialised one sized for our `cls_num` classes.
model = AutoModelForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    num_labels=cls_num,  # Number of classes
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

if b_train_just_last:
    # Freeze everything, then train only the classification head.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.classifier.parameters():
        param.requires_grad = True
elif trainNlayers > 0:
    # Freeze everything, then unfreeze the last `trainNlayers` encoder blocks.
    for param in model.parameters():
        param.requires_grad = False
    for layer in model.audio_spectrogram_transformer.encoder.layer[-trainNlayers:]:
        for param in layer.parameters():
            param.requires_grad = True
    # The classification head is newly initialised (see ignore_mismatched_sizes
    # above); if it stayed frozen its random weights would never be trained.
    for param in model.classifier.parameters():
        param.requires_grad = True
else:
    # Fine-tune the full network.
    for param in model.parameters():
        param.requires_grad = True

# Quick sanity check of the chosen fine-tuning scope.
print(
    f"Trainable parameters: "
    f"{sum(p.numel() for p in model.parameters() if p.requires_grad):,}"
    f" / {sum(p.numel() for p in model.parameters()):,}"
)

model

In [None]:
# Use CUDA when PyTorch can see a GPU; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Move the model's parameters to the selected device.
model = model.to(device)
print(device)


In [None]:
def preprocess_audio_function(examples, target_length=16000, sampling_rate=16000):
    """
    Load, length-normalise and featurise a batch of audio files.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples (dict): Batch with parallel lists 'file_path' (paths to audio
            files) and 'label' (class-name strings that must be keys of the
            notebook-level ``label2id``).
        target_length (int): Clip length in samples. Longer clips are truncated
            and shorter ones are zero-padded at the end. The default of 16000
            is one second at the default sampling rate.
        sampling_rate (int): Rate (Hz) the audio is resampled to on load and
            passed to the processor. NOTE(review): should match the rate the
            processor expects — confirm against processor.feature_extractor.sampling_rate.

    Returns:
        dict: 'input_values' (one feature array per file, as produced by the
        notebook-level ``processor`` with its leading batch axis removed) and
        'label' (integer class ids from ``label2id``).
    """
    input_values = []
    labels = []

    for audio_fn, label in zip(examples["file_path"], examples["label"]):
        # librosa resamples to `sampling_rate` while loading.
        audio, _ = librosa.load(audio_fn, sr=sampling_rate)

        # Truncate, then zero-pad at the end, so every clip has exactly target_length samples.
        audio = audio[:target_length]
        if len(audio) < target_length:
            audio = np.pad(audio, (0, target_length - len(audio)), mode="constant")

        inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt", padding=True)

        # Drop only the batch axis added by return_tensors="pt" (one clip per call).
        input_values.append(inputs["input_values"].squeeze(0).numpy())
        labels.append(label2id[label])

    return {"input_values": input_values, "label": labels}


In [None]:
# Featurise both splits with the same batched preprocessing function.
encoded_train_dataset, encoded_validation_dataset = (
    split.map(preprocess_audio_function, batched=True, batch_size=4)
    for split in (train_dataset, validation_dataset)
)


In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, balanced_accuracy_score

def compute_metrics(eval_pred):
    """
    Compute classification metrics for the Hugging Face Trainer.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by the Trainer; the
            predicted class is the argmax over the last axis of ``logits``.

    Returns:
        dict: 'accuracy', support-weighted 'precision' / 'recall' / 'f1',
        and 'balanced_accuracy'.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # zero_division=0: a class that is never predicted (common in early epochs)
    # scores 0, which is the same value sklearn's default gives, but it no longer
    # emits an UndefinedMetricWarning on every evaluation.
    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision_score(labels, predictions, average="weighted", zero_division=0),
        "recall": recall_score(labels, predictions, average="weighted", zero_division=0),
        "f1": f1_score(labels, predictions, average="weighted", zero_division=0),
        "balanced_accuracy": balanced_accuracy_score(labels, predictions),
    }



In [None]:
from transformers import TrainingArguments, Trainer

# Training parameters
training_args = TrainingArguments(
    output_dir="./results" + postfix,
    # `evaluation_strategy` is deprecated and removed in recent transformers
    # releases; `eval_strategy` is the current name.
    eval_strategy="epoch",  # Evaluate at the end of each epoch
    learning_rate=5e-5,  # Can be reduced to 2e-5 if training is unstable
    lr_scheduler_type="cosine",  # Gradually decrease the learning rate during training
    warmup_steps=3,  # Number of warmup steps for stability
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epoch_num,
    save_strategy="epoch",  # Save a checkpoint at the end of each epoch
    save_total_limit=2,  # Keep at most 2 checkpoints on disk
    logging_dir='./logs' + postfix,
    logging_steps=10,  # Log every 10 optimisation steps
    report_to="all",  # Report to every installed logging integration (e.g. TensorBoard, W&B)
    #load_best_model_at_end=True,  # Automatically load the best model at the end of training
    metric_for_best_model="balanced_accuracy",  # Metric used to pick the best checkpoint
    greater_is_better=True,  # Higher metric values are better
    gradient_accumulation_steps=8,  # Effective batch size = batch_size * 8 per device
    # Mixed precision needs a CUDA GPU; TrainingArguments raises if fp16=True on CPU.
    fp16=torch.cuda.is_available(),
    dataloader_num_workers=4,  # Worker processes for data loading
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_train_dataset,
    eval_dataset=encoded_validation_dataset,
    processing_class=processor,  # Replaces the deprecated `tokenizer=` argument
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

# Evaluate the final model on the validation split
metrics = trainer.evaluate()
print(metrics)

# Save the model and its processor to the SAME directory so both can be
# reloaded with `from_pretrained(save_dir)`.
save_dir = "./classifier" + postfix
model.save_pretrained(save_dir)
processor.save_pretrained(save_dir)