In [None]:

# ---- Run configuration ----

# Data location and run identity
# Directory scanned for *.wav clips: the entire ABUZZ audio plus the independent training part.
data_dir = "./dataset_forSSL_indep/"
# Suffix appended to the output, log and saved-model directory names of this run.
str_id = "_SSL_v3_indep"
# None trains from scratch; otherwise a directory that contains a saved "model.pth".
resume_from_checkpoint = None

# Optimisation hyperparameters
learning_rate = 1e-4  # kept small for more stable SSL training
batch_size = 8        # per-device batch size (the original setting was 4)
epoch_num = 150       # upper bound on epochs; early stopping may end training sooner




In [None]:

import os
import random
from sklearn.model_selection import train_test_split
from transformers import TrainingArguments, Trainer
from transformers import EarlyStoppingCallback
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import AutoProcessor, ASTModel
import librosa
import torch.nn.functional as F
import numpy as np

from transformers.trainer_utils import set_seed
# Seed once, right after the imports, so every cell below starts from the same
# RNG state on each Restart & Run All. The train/validation split further down
# passes its own random_state and does not depend on this call.
set_seed(42)


In [None]:
# The feature extractor and the encoder backbone come from the same pretrained
# AST checkpoint, so the identifier is spelled out only once.
AST_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"

processor = AutoProcessor.from_pretrained(AST_CHECKPOINT)
base_model = ASTModel.from_pretrained(AST_CHECKPOINT)

# Adaptation for SSL
class AST_SSL(nn.Module):
    """Encoder plus a light reconstruction head for self-supervised training.

    The encoder's hidden states are linearly projected to the feature width of
    the target, refined by two 1-D convolutions along the token axis and, when
    targets are supplied, linearly resampled to the target length so that an
    MSE reconstruction loss can be computed.

    Args:
        base_model: Encoder module. It must expose ``config.hidden_size`` and,
            when called as ``base_model(input_values=...)``, return an object
            with a ``last_hidden_state`` tensor of shape (B, T_enc, hidden_size).
        output_dim: Either the shape of one target item (only its last entry,
            the per-frame feature width, is used) or that width as a plain int.
    """

    def __init__(self, base_model, output_dim):
        super().__init__()
        self.encoder = base_model
        self.encoder_output_dim = base_model.config.hidden_size

        # Stored exactly as passed in, for callers that inspect it later.
        self.output_dim = output_dim

        # Accept a bare width as well as a full shape (tuple / list / torch.Size).
        if hasattr(output_dim, "__len__"):
            feature_dim = int(output_dim[-1])
        else:
            feature_dim = int(output_dim)

        # NOTE: submodules are created in the same order as before (projector,
        # then the two convolutions) so seeded weight initialisation and the
        # state_dict keys ("projector.*", "decoder.0.*", "decoder.2.*") do not change.

        # Projector: linear transformation of the encoder hidden representation
        self.projector = nn.Linear(self.encoder_output_dim, feature_dim)

        # Convolutional decoder operating along the token axis
        self.decoder = nn.Sequential(
            nn.Conv1d(feature_dim, feature_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(feature_dim, feature_dim, kernel_size=3, padding=1),
        )

    def forward(self, input_values, labels=None):
        """Reconstruct target features from (possibly masked) inputs.

        Args:
            input_values: Batch passed straight to the encoder.
            labels: Optional target tensor of shape (B, T_target, C). When
                given, the reconstruction is resampled to ``T_target`` frames
                and the MSE against ``labels`` is returned as ``"loss"``.

        Returns:
            ``{"logits": (B, T_enc, C)}`` without labels, otherwise
            ``{"loss": scalar, "logits": (B, T_target, C)}``.
        """
        # NOTE(review): the encoder's token axis is treated as a time axis here.
        # If the encoder prepends special tokens or flattens a 2-D patch grid,
        # that axis is not purely temporal — confirm this is intended.
        hidden = self.encoder(input_values=input_values).last_hidden_state  # (B, T_enc, H)
        projected = self.projector(hidden)                                   # (B, T_enc, C)

        # Conv1d expects channels first, so the decoder works on (B, C, T_enc).
        decoded = self.decoder(projected.transpose(1, 2))                    # (B, C, T_enc)

        if labels is None:
            return {"logits": decoded.transpose(1, 2)}                       # (B, T_enc, C)

        # Align the token axis with the target: linear resampling to T_target.
        rec = F.interpolate(
            decoded,
            size=labels.shape[1],
            mode="linear",
            align_corners=True,
        ).transpose(1, 2)                                                    # (B, T_target, C)
        loss = F.mse_loss(rec, labels)
        return {"loss": loss, "logits": rec}
    

In [None]:
import torchaudio.transforms as T

class AudioDataset(Dataset):
    """Fixed-length audio clips paired with their clean features as target.

    Each file is loaded with librosa, truncated or zero-padded to a fixed
    duration and run through ``processor``. The unmodified features are stored
    under ``"output_values"``; ``"input_values"`` holds the same features,
    additionally time/frequency masked when ``is_train`` is True.

    Args:
        file_paths: Sequence of audio file paths.
        processor: Feature extractor, called as
            ``processor(audio, sampling_rate=..., return_tensors="pt", padding=True)``;
            must return a mapping whose ``"input_values"`` tensor has a leading
            batch axis of size 1.
        is_train: Whether to mask the model input (the target stays clean).
        sampling_rate: Rate in Hz the audio is resampled to on load.
        clip_seconds: Duration every clip is padded/truncated to.
        time_mask_param: Maximum width of the time mask, in frames.
        freq_mask_param: Maximum width of the frequency mask, in bins.

    Raises:
        ValueError: If ``sampling_rate * clip_seconds`` is not at least one sample.
    """

    def __init__(self, file_paths, processor, is_train=True, *,
                 sampling_rate=16000, clip_seconds=1.0,
                 time_mask_param=80, freq_mask_param=30):
        self.file_paths = file_paths
        self.processor = processor
        self.is_train = is_train
        self.sampling_rate = sampling_rate
        self.target_length = int(round(sampling_rate * clip_seconds))
        if self.target_length <= 0:
            raise ValueError("sampling_rate * clip_seconds must be at least one sample")
        self.time_masking = T.TimeMasking(time_mask_param=time_mask_param)
        self.freq_masking = T.FrequencyMasking(freq_mask_param=freq_mask_param)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load one clip and return the processor output extended with the target."""
        file_path = self.file_paths[idx]
        audio, _ = librosa.load(file_path, sr=self.sampling_rate)
        audio = self._fit_length(audio, self.target_length)

        # Processor preprocessing
        inputs = self.processor(
            audio, sampling_rate=self.sampling_rate, return_tensors="pt", padding=True
        )

        # Drop only the batch axis. A bare squeeze() would also collapse any
        # other axis that happens to have size 1 and break the (T, M) layout.
        clean_input = inputs["input_values"].squeeze(0)
        inputs["output_values"] = clean_input.clone()  # target: the clean features

        # NOTE(review): if the processor pads its output to a fixed number of
        # frames much longer than the clip, a randomly placed time mask will
        # mostly fall on padding — confirm the processor length suits clip_seconds.
        if self.is_train:
            inputs["input_values"] = self.augment_spectrogram(clean_input)
        else:
            inputs["input_values"] = clean_input

        # The processor's own return object is extended and handed back as-is.
        # NOTE(review): downstream collation may depend on its concrete mapping
        # type; do not convert it to a plain dict without checking the Trainer.
        return inputs

    @staticmethod
    def _fit_length(audio, target_length):
        """Zero-pad at the end, or truncate, so ``audio`` has ``target_length`` samples."""
        missing = target_length - len(audio)
        if missing > 0:
            return np.pad(audio, (0, missing), mode="constant")
        return audio[:target_length]

    def augment_spectrogram(self, input_values: torch.Tensor):
        """Apply one time mask and one frequency mask to a (time, mel) tensor.

        The masking transforms are applied on a (freq, time) layout, hence the
        transpose before and after. The input tensor is not modified; a new
        (time, mel) tensor is returned.
        """
        spec = input_values.transpose(0, 1).contiguous()   # (M, T) -> (freq, time)
        spec = self.time_masking(spec)
        spec = self.freq_masking(spec)
        return spec.transpose(0, 1).contiguous()           # back to (T, M)


In [None]:
# 90/10 data split
# os.listdir returns entries in arbitrary, filesystem-dependent order. Sorting
# makes the seeded split below pick the same files on every machine and run.
# NOTE: this can change which files land in validation compared with runs made
# on an unsorted listing.
all_files = sorted(
    os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".wav")
)
if not all_files:
    # Fail here with a clear message instead of deep inside train_test_split.
    raise FileNotFoundError(f"No .wav files found in {data_dir!r}")
# all_files = all_files[:300]  # uncomment for a quick smoke test on a small subset

train_files, val_files = train_test_split(all_files, test_size=0.1, random_state=42)

# Datasets: masking augmentation is applied to the training inputs only
train_dataset = AudioDataset(train_files, processor, is_train=True)
val_dataset = AudioDataset(val_files, processor, is_train=False)

# Check and store the dimensions of the first element; the model head is sized
# from output_shape further down.
sample_data = train_dataset[0]  # First element from the dataset
input_shape = sample_data["input_values"].shape
output_shape = sample_data["output_values"].shape

print(f"Input values shape: {input_shape}")
print(f"Output values shape: {output_shape}")


In [None]:
input_dim = input_shape  # shape of one model input item
output_dim = output_shape  # shape of one target item; its last entry sizes the model head

ssl_model = AST_SSL(base_model, output_dim=output_dim)

if resume_from_checkpoint is not None:
    # Only the model weights are restored here; optimizer and scheduler state
    # are not, because trainer.train() below is called without a resume argument.
    checkpoint_file = os.path.join(resume_from_checkpoint, "model.pth")
    # map_location="cpu" lets weights saved during a GPU run load on any machine;
    # load_state_dict then copies them into ssl_model's own parameters.
    # torch.load unpickles the file, so only point this at checkpoints you trust.
    state_dict = torch.load(checkpoint_file, map_location="cpu")
    ssl_model.load_state_dict(state_dict)
    print("Model loaded from checkpoint.")


def compute_metrics(eval_pred):
    """Return the mean squared reconstruction error over the evaluation set.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. If ``predictions`` is a
            tuple of arrays (a model with several outputs), its first element
            is taken as the reconstruction.

    Returns:
        ``{"eval_mse": float}``; the key must match ``metric_for_best_model``.

    Raises:
        ValueError: If predictions and labels differ in shape. Without this
            check NumPy would broadcast them and report a meaningless MSE.
    """
    preds, labels = eval_pred
    if isinstance(preds, tuple) and len(preds) > 0 and isinstance(preds[0], np.ndarray):
        preds = preds[0]
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    if preds.shape != labels.shape:
        raise ValueError(
            f"predictions shape {preds.shape} does not match labels shape {labels.shape}"
        )
    # Accumulate in float64 so the mean over many elements stays accurate.
    mse = ((preds - labels) ** 2).mean(dtype=np.float64)
    return {"eval_mse": float(mse)}

def custom_collate_fn(batch):
    """Stack per-item tensors into a single batch dict for the Trainer.

    Every element of ``batch`` must provide "input_values" and "output_values"
    tensors whose shapes agree across items. The targets are emitted under the
    key "labels".
    """
    model_inputs, targets = [], []
    for item in batch:
        model_inputs.append(item["input_values"])
        targets.append(item["output_values"])
    return {
        "input_values": torch.stack(model_inputs),
        "labels": torch.stack(targets),
    }

# Training parameters
training_args = TrainingArguments(
    output_dir="./AST-SSL-results" + str_id,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` — confirm against the installed version.
    evaluation_strategy="epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epoch_num,
    save_strategy="epoch",  # must match the evaluation strategy for load_best_model_at_end
    save_total_limit=2,
    logging_dir='./AST-SSL-logs' + str_id,
    logging_steps=10,
    report_to="tensorboard",
    load_best_model_at_end=True,
    metric_for_best_model="eval_mse",  # Must match the return key of compute_metrics
    greater_is_better=False,
    # Mixed precision only when a CUDA GPU is present; a hard-coded True makes
    # TrainingArguments fail on CPU-only machines.
    fp16=torch.cuda.is_available(),
    # Each optimizer step sees batch_size * 2 samples per device.
    gradient_accumulation_steps=2,  # If memory is low
    dataloader_num_workers=4,  # Speed up data loading
    warmup_ratio=0.05,
    max_grad_norm=1.0,
    # Dataset items carry an "output_values" entry that is not an argument of
    # the model's forward(). Keep the Trainer from pruning such entries before
    # custom_collate_fn turns them into "labels".
    remove_unused_columns=False,
)

# Early stopping: end training once eval_mse has failed to improve by at least
# 1e-4 (absolute) for 10 consecutive evaluations, i.e. 10 epochs.
early_stopping = EarlyStoppingCallback(
    early_stopping_threshold=1e-4,
    early_stopping_patience=10,
)
callbacks = [early_stopping]

# Wire model, data, collation and metrics into the Trainer
trainer = Trainer(
    args=training_args,
    model=ssl_model,
    data_collator=custom_collate_fn,  # builds {"input_values", "labels"} batches
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    callbacks=callbacks,
)

# Run training
trainer.train()


In [None]:
# Persist the processor and the trained weights side by side
model_save_path = "./AST-SSL-indep" + str_id
weights_file = os.path.join(model_save_path, "model.pth")

# Processor first, using the Hugging Face save format.
# NOTE(review): presumably this call also creates model_save_path; the
# torch.save below relies on the directory existing — confirm.
processor.save_pretrained(model_save_path)

# AST_SSL is a plain nn.Module without save_pretrained, so only its state_dict is stored.
torch.save(ssl_model.state_dict(), weights_file)

# Final evaluation on the validation set
metrics = trainer.evaluate()
print(metrics)
