In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from transformers import HubertModel, Wav2Vec2Processor
from sklearn.model_selection import train_test_split
from transformers import Wav2Vec2FeatureExtractor

class HubertTransformerClassifier(nn.Module):
    """Utterance classifier: HuBERT encoder -> Transformer encoder -> MLP head.

    The pretrained HuBERT model produces a sequence of frame embeddings, a
    small Transformer encoder stack refines them, and the frames are averaged
    into one vector that a two-layer head maps to class logits.

    Args:
        num_classes: Number of output logits.
        hubert_model: Hugging Face checkpoint name or path used for both the
            HuBERT weights and the matching feature extractor.
        n_heads: Attention heads per Transformer encoder layer.
        n_layers: Number of stacked Transformer encoder layers.
        dim_feedforward: Feed-forward width inside each encoder layer.
        dropout: Dropout probability for the encoder layers and the head.
    """

    def __init__(self, num_classes,
                 hubert_model="facebook/hubert-base-ls960",
                 n_heads=8, n_layers=3, dim_feedforward=512, dropout=0.1):
        super().__init__()
        # Kept as an attribute for callers; forward() itself never uses it.
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hubert_model)
        self.hubert = HubertModel.from_pretrained(hubert_model)
        embed_dim = self.hubert.config.hidden_size

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True  # tensors are (batch, frames, embed_dim)
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes)
        )

    def forward(self, input_values, attention_mask=None):
        """Compute class logits for a batch of waveforms.

        Args:
            input_values: Tensor forwarded to HuBERT as ``input_values``
                (presumably ``(batch, samples)`` waveforms -- confirm against
                the collate function).
            attention_mask: Optional sample-level mask forwarded to HuBERT;
                assumed to be 1 for real samples and 0 for padding. When it is
                given, padded frames are excluded from the Transformer's
                attention and from the pooled average. When it is ``None``
                every frame is treated as real.

        Returns:
            Tensor of logits with ``num_classes`` entries per batch item.
        """
        outputs = self.hubert(input_values=input_values, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        if attention_mask is None:
            encoded = self.transformer_encoder(hidden)
            return self.classifier(encoded.mean(dim=1))

        # HuBERT's convolutional front end shortens the sequence, so the
        # sample-level mask must be converted to one flag per output frame.
        # NOTE(review): underscore-prefixed transformers helper; re-check when
        # upgrading the library.
        valid = self.hubert._get_feature_vector_attention_mask(
            hidden.shape[1], attention_mask
        ).bool()

        # src_key_padding_mask uses the opposite convention: True = ignore.
        encoded = self.transformer_encoder(hidden, src_key_padding_mask=~valid)

        # Masked mean: zero the padded frames, then divide by the number of
        # real frames instead of the padded length.
        encoded = encoded.masked_fill(~valid.unsqueeze(-1), 0.0)
        lengths = valid.sum(dim=1, keepdim=True).clamp(min=1).to(encoded.dtype)
        pooled = encoded.sum(dim=1) / lengths
        return self.classifier(pooled)


from transformers import Wav2Vec2FeatureExtractor


class Wav2VecAudioDataset(Dataset):
    """Dataset yielding ``(waveform, label)`` pairs from a table of audio files.

    Args:
        df: DataFrame with a ``file_path`` column (audio file to load) and a
            ``label`` column (value convertible to ``int``). Its index is
            reset so positional lookups line up with ``__len__``.
        processor: Stored as ``self.processor``; ``__getitem__`` does not use it.
        target_sr: Sampling rate in Hz that every clip is resampled to.
    """

    def __init__(self, df, processor, target_sr=16000):
        self.df = df.reset_index(drop=True)
        self.processor = processor
        self.target_sr = target_sr

    def __len__(self):
        """Return the number of rows (clips) in the table."""
        return len(self.df)

    def _prepare_waveform(self, waveform, sample_rate):
        """Convert a loaded clip to a 1-D NumPy array at ``self.target_sr``.

        Args:
            waveform: Tensor as returned by ``torchaudio.load``; a tensor with
                more than one dimension is assumed to be channels-first.
            sample_rate: Sampling rate of ``waveform`` in Hz.

        Returns:
            1-D NumPy array of samples.
        """
        # Average the channels so stereo/multi-channel files do not reach the
        # feature extractor as a 2-D array. A single channel is returned as is.
        if waveform.dim() > 1:
            waveform = waveform.mean(dim=0)
        # Only resample when the file's rate differs from the target rate.
        if sample_rate != self.target_sr:
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.target_sr)
        return waveform.numpy()

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(waveform, label)``."""
        row = self.df.iloc[idx]
        waveform, sample_rate = torchaudio.load(row["file_path"])
        return self._prepare_waveform(waveform, sample_rate), int(row["label"])

def collate_fn(batch, feature_extractor=None):
    """Pad a list of ``(waveform, label)`` samples into batch tensors.

    Args:
        batch: Sequence of ``(waveform, label)`` pairs, where ``waveform`` is a
            1-D array of audio samples and ``label`` is an ``int``.
        feature_extractor: Optional callable with the
            ``Wav2Vec2FeatureExtractor`` call signature. When omitted, the
            extractor for ``facebook/hubert-base-ls960`` is loaded once and
            cached on this function, so it is not reloaded for every batch.

    Returns:
        Tuple ``(input_values, attention_mask, labels)``. ``labels`` is a
        ``torch.long`` tensor; the other two are whatever the extractor
        returns under those attribute names.
    """
    if feature_extractor is None:
        feature_extractor = getattr(collate_fn, "_cached_extractor", None)
        if feature_extractor is None:
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")
            collate_fn._cached_extractor = feature_extractor

    waveforms, labels = zip(*batch)
    # padding=True lets clips of different length share one tensor. The mask
    # is requested explicitly rather than relying on the checkpoint's default.
    features = feature_extractor(
        list(waveforms),
        sampling_rate=16000,
        padding=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    labels = torch.tensor(labels, dtype=torch.long)
    return features.input_values, features.attention_mask, labels

# Load and split dataset: 10% held out for test, then 10% of the remainder for
# validation (roughly 81% / 9% / 10%), both splits stratified by label.
df = pd.read_csv("labeled_data.csv")
train_val_df, test_df = train_test_split(df, test_size=0.1, stratify=df["label"], random_state=42)
train_df, val_df = train_test_split(train_val_df, test_size=0.1, stratify=train_val_df["label"], random_state=42)

# Processor
# Only waveform preprocessing is needed here, so load the feature extractor of
# the same checkpoint the model uses. `Wav2Vec2Processor` also requires
# tokenizer files.
# NOTE(review): the audio-only `facebook/wav2vec2-base` checkpoint is believed
# to ship no tokenizer, which would make the processor load fail -- confirm.
processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")

# Datasets and Loaders
train_dataset = Wav2VecAudioDataset(train_df, processor)
val_dataset = Wav2VecAudioDataset(val_df, processor)
test_dataset = Wav2VecAudioDataset(test_df, processor)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=8, collate_fn=collate_fn, num_workers=2)

# Model setup
# Seed torch so the randomly initialised layers and the shuffling order are
# repeatable across "Restart & Run All".
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HubertTransformerClassifier(num_classes=7).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Train and evaluation functions
def train_epoch(model, loader):
    """Run one optimisation pass over ``loader``.

    Relies on the module-level ``optimizer`` and ``device``.

    Args:
        model: Module called as ``model(input_values, attention_mask)``.
        loader: Iterable of ``(input_values, attention_mask, labels)`` batches.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)``: the cross-entropy loss
        averaged over batches, and the percentage of correctly classified
        samples.
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for batch in loader:
        inputs, attn, targets = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        scores = model(inputs, attn)
        batch_loss = F.cross_entropy(scores, targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        n_correct += (scores.argmax(1) == targets).sum().item()
        n_seen += targets.size(0)

    mean_loss = running_loss / len(loader)
    accuracy = n_correct / n_seen * 100
    return mean_loss, accuracy

def eval_epoch(model, loader):
    """Measure classification accuracy of ``model`` over ``loader``.

    Relies on the module-level ``device``. Gradients are disabled while the
    batches are evaluated.

    Args:
        model: Module called as ``model(input_values, attention_mask)``.
        loader: Iterable of ``(input_values, attention_mask, labels)`` batches.

    Returns:
        Percentage of samples whose arg-max prediction equals the label.
    """
    model.eval()
    tallies = []  # one (num_correct, batch_size) pair per batch

    with torch.no_grad():
        for batch in loader:
            inputs, attn, targets = (tensor.to(device) for tensor in batch)
            predicted = model(inputs, attn).argmax(1)
            tallies.append(((predicted == targets).sum().item(), targets.size(0)))

    hits = sum(num_correct for num_correct, _ in tallies)
    seen = sum(batch_size for _, batch_size in tallies)
    return hits / seen * 100




In [None]:
# Train for a fixed number of epochs and keep the weights of the epoch with
# the highest validation accuracy.
best_val_acc = 0.0
best_model_state = None

for epoch in range(10):
    tr_loss, tr_acc = train_epoch(model, train_loader)
    val_acc = eval_epoch(model, val_loader)

    # The `is None` check guarantees a snapshot exists even if validation
    # accuracy never rises above 0.
    if val_acc > best_val_acc or best_model_state is None:
        best_val_acc = val_acc
        # state_dict() returns references to the live parameters, which later
        # optimizer steps keep modifying. Clone to CPU so this snapshot stays
        # frozen at the best epoch.
        best_model_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

    print(f"Epoch {epoch+1} ▶ Loss {tr_loss:.4f} | Train Acc {tr_acc:.2f}% | Val Acc {val_acc:.2f}%")

torch.save(best_model_state, "best_hubert_transformer_model.pt")




In [None]:
# Restore the best checkpoint and report accuracy on the held-out test split.
# map_location moves the stored tensors to the current device, so a checkpoint
# written on a GPU machine still loads on a CPU-only one.
model.load_state_dict(torch.load("best_hubert_transformer_model.pt", map_location=device))
test_acc = eval_epoch(model, test_loader)
print(f"✅ Test Accuracy: {test_acc:.2f}%")