In [None]:
# Optional environment setup (uncomment on a fresh machine):
#!pip install librosa
#!pip install safetensors
#!pip install "numpy<=1.21"
import numpy as np

# Show the numpy version in use (previously observed: 1.23.4).
numpy_version = np.__version__
print(numpy_version)


In [None]:
# ---- Training schedule ------------------------------------------------------
epoch_num = 50
batch_size = 1
# One batch size per epoch; every epoch uses `batch_size`.
batch_sizes = [batch_size for _ in range(epoch_num)]
random_seed = 42

# ---- Input / output locations ----------------------------------------------
SSL_pretrained_path = "./SSL_pretrained/AST-SSL-results_SSL_v2_db2_indepval/checkpoint-33297/"
train_dir = "../dataset/train-val-random/train/"
validation_dir = "../dataset/train-val-random/validation/"
output_path = "ensemble/"

# ---- Run tag and fine-tuning mode ------------------------------------------
postfix = "_AST_from_SSL_25_02_12-3layers"
b_train_just_last = False  # True: train only the classification layer
trainNlayers = 3           # train last N layers; 0 means train the full network

# Smoke-test switch (shorter schedule and subsampled data when True).
b_just_test = False


In [None]:
# Smoke-test mode: collapse the schedule to a single epoch.
if b_just_test:
    epoch_num = 1
    batch_sizes = [batch_size] * epoch_num
    

In [None]:
import os
import shutil
from datasets import Dataset, DatasetDict
import torch
import librosa
import numpy as np

#from sklearn.model_selection import train_test_split
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from transformers import AutoProcessor, ASTModel
import torch
import torch.nn as nn
#from torch.utils.data import Dataset
import torch.nn.functional as F
from safetensors.torch import load_file
import random


In [None]:
# Ensure the output directory exists (no error if it is already there).
os.makedirs(name=output_path, exist_ok=True)

    

In [None]:
# Seed every random number generator this notebook relies on.
for seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed_all,  # safe to call without a GPU
    np.random.seed,
    random.seed,
):
    seed_fn(random_seed)

# Prefer deterministic cuDNN kernels over autotuned (possibly nondeterministic) ones.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False


In [None]:
def load_audio_dataset_from_folders(train_dir, validation_dir, b_just_test=False, every_nth=None):
    """
    Load ``<split>/<class_name>/*.wav`` folders into a DatasetDict.

    Args:
        train_dir (str): Path to the 'train' folder (one sub-folder per class).
        validation_dir (str): Path to the 'validation' folder (same layout).
        b_just_test (bool): Smoke-test mode; keeps only every 10th file per split.
        every_nth (int | None): Explicit subsampling stride. When given it
            overrides the ``b_just_test`` default (10 if True, else 1).

    Returns:
        DatasetDict: 'train' and 'validation' splits, each with a ``file_path``
        column and a ``label`` column holding the class (folder) name.

    Raises:
        ValueError: if ``every_nth`` is smaller than 1.
    """
    if every_nth is None:
        every_nth = 10 if b_just_test else 1
    if every_nth < 1:
        raise ValueError(f"every_nth must be >= 1, got {every_nth}")

    def get_audio_files_with_labels(directory):
        # os.listdir returns entries in arbitrary order; sorting makes the file
        # order -- and therefore the every_nth subsample -- reproducible.
        data = []
        for class_name in sorted(os.listdir(directory)):
            class_path = os.path.join(directory, class_name)
            if not os.path.isdir(class_path):
                continue  # stray files next to the class folders are ignored
            for file_name in sorted(os.listdir(class_path)):
                if file_name.endswith(".wav"):  # only WAV files
                    data.append({
                        "file_path": os.path.join(class_path, file_name),
                        "label": class_name,
                    })
        return data

    def to_dataset(records):
        # Keep records at positions 0, every_nth, 2*every_nth, ...
        kept = records[::every_nth]
        return Dataset.from_dict({
            "file_path": [r["file_path"] for r in kept],
            "label": [r["label"] for r in kept],
        })

    return DatasetDict({
        "train": to_dataset(get_audio_files_with_labels(train_dir)),
        "validation": to_dataset(get_audio_files_with_labels(validation_dir)),
    })



In [None]:
# Build the train/validation DatasetDict from the class-per-folder layout.
dataset = load_audio_dataset_from_folders(
    train_dir,
    validation_dir,
    b_just_test=b_just_test,
)
print(dataset)


In [None]:
# Class names are the sub-directory names of the training folder, sorted so the
# label <-> id mapping is deterministic. Only directories count as classes:
# stray files (e.g. `.DS_Store`, a README) would otherwise shift every id and
# inflate `cls_num`, while the dataset loader ignores them.
label_list = sorted(
    entry for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
)

label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for label, idx in label2id.items()}

# Fail early with a clear message if validation has classes unseen in training;
# otherwise preprocessing would later raise a bare KeyError on label2id.
_unknown_val_classes = sorted(
    entry for entry in os.listdir(validation_dir)
    if os.path.isdir(os.path.join(validation_dir, entry)) and entry not in label2id
)
if _unknown_val_classes:
    raise ValueError(f"Validation classes missing from training data: {_unknown_val_classes}")

cls_num = len(id2label)
print(id2label)


In [None]:
# Unpack the two partitions of the DatasetDict.
train_dataset, validation_dataset = dataset["train"], dataset["validation"]
print(validation_dataset)

In [None]:
# Sanity check: paths of the first 10 training files.
first_paths = train_dataset['file_path'][:10]
print(first_paths)


In [None]:
# Show the id -> class-name mapping.
print("Class labels:", id2label)


In [None]:


# Base AST checkpoint: provides both the feature extractor (processor) and the
# backbone encoder that the SSL wrapper builds on.
AST_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
processor = AutoProcessor.from_pretrained(AST_CHECKPOINT)
base_model = ASTModel.from_pretrained(AST_CHECKPOINT)

# Convert for SSL
class AST_SSL(nn.Module):
    """Self-supervised reconstruction wrapper around an AST encoder.

    The encoder's token sequence is projected to ``output_dim[-1]`` channels and
    refined by a small temporal Conv1d decoder. When ``labels`` (reconstruction
    targets) are passed, the prediction is linearly resampled along the token
    axis to ``labels.shape[1]`` and ``(mse_loss, prediction)`` is returned;
    otherwise only the prediction is returned.

    The submodule names ``encoder``, ``projector`` and ``decoder`` (and the
    Sequential layout of ``decoder``) define the keys of the saved SSL
    ``state_dict`` and must not change.
    """

    def __init__(self, base_model, output_dim):
        super().__init__()
        hidden = base_model.config.hidden_size
        channels = output_dim[-1]

        self.encoder = base_model
        self.encoder_output_dim = hidden
        self.output_dim = output_dim

        # Per-token linear map from the encoder hidden size to `channels`.
        self.projector = nn.Linear(hidden, channels)

        # Two length-preserving convolutions over the token axis.
        self.decoder = nn.Sequential(
            nn.Conv1d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, input_values, labels=None):
        tokens = self.encoder(input_values=input_values).last_hidden_state
        # Conv1d expects channels before time: (B, T, C) -> (B, C, T).
        channels_first = self.projector(tokens).transpose(1, 2)
        decoded = self.decoder(channels_first)

        if labels is None:
            return decoded.transpose(1, 2)  # back to (B, T, C)

        # Resample along time so the prediction length matches the target's.
        resampled = F.interpolate(decoded, size=labels.shape[1], mode="linear", align_corners=True)
        prediction = resampled.transpose(1, 2)
        loss = nn.MSELoss()(prediction, labels)
        return loss, prediction


In [None]:

# Rebuild the SSL architecture (AST_SSL only uses output_dim[-1] = 128) and
# restore the self-supervised weights from the safetensors checkpoint.
ssl_model = AST_SSL(base_model, output_dim=[1024, 128])

# load_state_dict is strict by default: missing or unexpected keys raise, so
# reaching the print below means every submodule was restored.
state_dict = load_file(os.path.join(SSL_pretrained_path, "model.safetensors"))
ssl_model.load_state_dict(state_dict)
print("Model loaded from SSL successfully.")

ssl_model


In [None]:

# Remove the SSL-only heads; only the encoder is reused for classification.
# Guarded with hasattr so re-running this cell does not raise AttributeError.
for _ssl_head in ("projector", "decoder"):
    if hasattr(ssl_model, _ssl_head):
        delattr(ssl_model, _ssl_head)

# **✔ Add new classification layer**
import torch.nn.functional as F

class AST_Classifier(nn.Module):
    """Classifier head on top of the (SSL-pretrained) AST encoder.

    Args:
        ssl_model: object exposing ``.encoder``. The encoder is called with the
            input values, must return an object with ``.last_hidden_state``, and
            must have a ``.layernorm`` module.
        num_classes: number of output classes.

    ``forward`` returns ``{"logits": (B, num_classes)}``. When ``labels`` are
    given it also returns ``"loss"``, the cross-entropy loss.
    """

    def __init__(self, ssl_model, num_classes):
        super().__init__()
        self.encoder = ssl_model.encoder  # keep the pretrained AST encoder
        # Size the head from the encoder config (as AST_SSL does) rather than a
        # hard-coded 768; fall back to 768 when no config is exposed.
        hidden_size = getattr(getattr(self.encoder, "config", None), "hidden_size", 768)
        self.classifier = nn.Linear(hidden_size, num_classes)  # new classification layer

    def forward(self, input_values, labels=None):
        x = self.encoder(input_values).last_hidden_state
        # NOTE(review): the HF ASTModel may already apply its final layernorm to
        # last_hidden_state, making this a second normalization -- confirm.
        x = self.encoder.layernorm(x)
        x = x.mean(dim=1)  # average over the token axis -> (B, D)
        logits = self.classifier(x)  # (B, num_classes)

        if labels is None:
            return {"logits": logits}
        return {"loss": F.cross_entropy(logits, labels), "logits": logits}

# Wrap the SSL encoder with a fresh head sized to the number of classes.
classifier_model = AST_Classifier(ssl_model=ssl_model, num_classes=cls_num)
print("New classifier model created:", classifier_model)



In [None]:
# Select which parameters are trained:
#   b_train_just_last -> only the classification layer
#   trainNlayers > 0  -> the last N encoder blocks + the classification layer
#   otherwise         -> the whole network
model = classifier_model

if b_train_just_last or trainNlayers > 0:
    # Start from a fully frozen network ...
    for param in model.parameters():
        param.requires_grad = False

    # ... unfreeze the last N transformer blocks (partial fine-tuning only) ...
    if not b_train_just_last:
        for layer in model.encoder.encoder.layer[-trainNlayers:]:
            for param in layer.parameters():
                param.requires_grad = True

    # ... and always the classification layer: it is newly (randomly)
    # initialised, so leaving it frozen would mean it is never trained.
    for param in model.classifier.parameters():
        param.requires_grad = True
else:
    for param in model.parameters():
        param.requires_grad = True

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {n_trainable:,} / {n_total:,}")




In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Move the model to the selected device.
model = model.to(device)


In [None]:
from transformers import BatchFeature
def preprocess_audio_function(examples, target_length=16000, sampling_rate=16000):
    """Load, length-normalise and featurise a batch of audio files.

    Intended for ``Dataset.map(..., batched=True)``, but a single (unbatched)
    example is accepted too.

    Args:
        examples: mapping with ``"file_path"`` and ``"label"`` (class name)
            entries -- lists when batched, scalars for a single example.
        target_length: number of samples each clip is zero-padded or cut to
            (16000 = 1 s at the default sampling rate).
        sampling_rate: rate librosa resamples to; also passed to the processor.

    Returns:
        dict with ``"input_values"`` (one row of processor features per readable
        file) and ``"labels"`` (int64 class ids from ``label2id``). Paths that
        are not files are reported and skipped, so the batch may shrink; if all
        are skipped, both entries are empty lists.
    """
    # Unbatched call: wrap each scalar value in a one-element list.
    if isinstance(examples["file_path"], str):
        examples = {key: [value] for key, value in examples.items()}

    features = []
    label_ids = []
    for audio_fn, label in zip(examples["file_path"], examples["label"]):
        if not os.path.isfile(audio_fn):
            print(f"Error: '{audio_fn}' is not a file!")
            continue  # skip paths that are not real files

        audio, _ = librosa.load(audio_fn, sr=sampling_rate)  # resample on load

        # Zero-pad short clips / truncate long ones to exactly target_length samples.
        if len(audio) < target_length:
            audio = np.pad(audio, (0, target_length - len(audio)), mode="constant")
        else:
            audio = audio[:target_length]

        inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
        features.append(inputs["input_values"].squeeze().numpy())
        label_ids.append(label2id[label])

    if not features:
        return {"input_values": [], "labels": []}

    # Stack into a single ndarray first: building a tensor from a list of
    # ndarrays element by element is very slow.
    return {
        "input_values": torch.tensor(np.stack(features)),
        "labels": torch.tensor(label_ids, dtype=torch.long),
    }



In [None]:
# Decode and featurise every file once, up front. Each split drops its own raw
# columns (file_path/label) in favour of the returned input_values/labels.
encoded_train_dataset = train_dataset.map(
    preprocess_audio_function,
    batched=True,
    remove_columns=train_dataset.column_names,
)
encoded_validation_dataset = validation_dataset.map(
    preprocess_audio_function,
    batched=True,
    remove_columns=validation_dataset.column_names,
)



In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, balanced_accuracy_score

def compute_metrics(eval_pred):
    """Classification metrics for ``Trainer(compute_metrics=...)``.

    Args:
        eval_pred: ``(logits, labels)`` pair. ``logits`` is ``(N, num_classes)``,
            or a tuple/list whose first element is; ``labels`` holds N class ids.

    Returns:
        dict with accuracy, weighted precision/recall/F1 and balanced accuracy.
    """
    logits, labels = eval_pred
    # With several model outputs the predictions arrive as a tuple; the logits
    # are assumed to come first.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)

    # zero_division=0: a class that is never predicted scores 0 without an
    # UndefinedMetricWarning on every evaluation.
    weighted = {"average": "weighted", "zero_division": 0}
    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision_score(labels, predictions, **weighted),
        "recall": recall_score(labels, predictions, **weighted),
        "f1": f1_score(labels, predictions, **weighted),
        "balanced_accuracy": balanced_accuracy_score(labels, predictions),
    }



In [None]:
from transformers import TrainingArguments, Trainer
from transformers import TrainerCallback

LOG_FILE = "train_log" + postfix + ".txt"

# Start each training run with a clean log: remove any file left by a previous run.
log_file_exists = os.path.exists(LOG_FILE)
if log_file_exists:
    os.remove(LOG_FILE)

   
class DynamicBatchSizeCallback(TrainerCallback):
    """Applies the per-epoch batch size from ``batch_sizes`` and mirrors
    Trainer log records into ``LOG_FILE``.

    Epoch ``e`` (0-based) uses ``batch_sizes[e]``; epochs past the end of the
    list reuse its last entry.

    NOTE(review): the new size is only written into ``args``; whether an
    already-built train dataloader picks it up is not visible here -- verify
    before relying on a non-constant ``batch_sizes``.
    """

    @staticmethod
    def _append_log(text):
        # Append one chunk of text to the run's log file.
        with open(LOG_FILE, "a") as log_file:
            log_file.write(text)

    def on_epoch_begin(self, args, state, control, **kwargs):
        epoch_idx = int(state.epoch or 0)
        # Clamp so resumed / extended runs cannot index past the schedule.
        new_batch_size = batch_sizes[min(epoch_idx, len(batch_sizes) - 1)]

        log_text = f"Epoch {epoch_idx + 1}/{epoch_num} - New batch size: {new_batch_size}\n"
        print(f"\n[INFO] {log_text}")
        self._append_log(log_text)

        args.per_device_train_batch_size = new_batch_size
        args.per_device_eval_batch_size = new_batch_size

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log loss and other metrics to a file."""
        if logs is None:
            return
        # state.epoch already counts completed (fractional) epochs, e.g. 1.0 after
        # the first epoch, so it is written as-is instead of +1. It may be None
        # outside of training.
        epoch_label = "n/a" if state.epoch is None else f"{state.epoch:.2f}"
        self._append_log(f"Epoch {epoch_label} logs: {logs}\n")

                

In [None]:

# Training parameters. logging_steps / save_steps are omitted: with the "epoch"
# logging/saving strategies they are ignored.
training_args = TrainingArguments(
    output_dir=os.path.join(output_path, "results" + postfix),
    evaluation_strategy="epoch",  # newer transformers versions call this `eval_strategy`
    learning_rate=5e-5,  # can be reduced to 2e-5 if training is unstable
    lr_scheduler_type="cosine",  # gradual decrease during training
    warmup_steps=3,  # warmup steps for stability
    per_device_train_batch_size=batch_sizes[0],  # initial batch size
    per_device_eval_batch_size=batch_sizes[0],  # initial batch size
    num_train_epochs=epoch_num,
    save_strategy="epoch",  # checkpoint at the end of each epoch
    save_total_limit=2,  # keep only 2 checkpoints
    logging_dir=os.path.join(output_path, "logs" + postfix),
    logging_strategy="epoch",  # log after every epoch
    logging_first_step=True,  # also log the first step
    report_to="none",  # no TensorBoard / external trackers
    load_best_model_at_end=True,  # reload the best checkpoint when training ends
    metric_for_best_model="accuracy",
    greater_is_better=True,
    gradient_accumulation_steps=4,  # effective batch = 4 x per-device batch size
    # Mixed precision needs a CUDA GPU; this notebook also supports a CPU fallback.
    fp16=torch.cuda.is_available(),
    dataloader_num_workers=4,
    seed=random_seed,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_train_dataset,
    eval_dataset=encoded_validation_dataset,
    tokenizer=processor,  # newer transformers versions: processing_class=processor
    compute_metrics=compute_metrics,
    callbacks=[DynamicBatchSizeCallback()],
)

trainer.train()




In [None]:
# Save the best model under the name "best"
best_model_path = os.path.join(output_path, "best", postfix)
# Target folder: <output_path>/best/<postfix>. Because load_best_model_at_end is
# enabled, the trainer's model at this point is the best checkpoint.

# Replace any earlier copy instead of mixing old and new files.
if os.path.exists(best_model_path):
    shutil.rmtree(best_model_path)

trainer.save_model(best_model_path)
print(f"✅ Best model saved to: {best_model_path}")
