In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import AutoModel, AutoFeatureExtractor
import numpy as np
import librosa
from torch.utils.data import DataLoader
import torch.optim as optim
import pandas as pd
import os
from tqdm import tqdm
from sklearn.metrics import f1_score

from utils import load_labels_from_dataset, get_audio_paths, get_split_audio_paths

In [None]:
class AudioDepressionDataset(Dataset):
    """Dataset that yields fixed-length, peak-normalized waveforms with integer labels.

    Each item is loaded from disk with librosa, resampled to ``sample_rate``,
    peak-normalized, zero-padded or truncated to ``segment_length_seconds``, and
    passed through the pretrained model's feature extractor.

    Args:
        audio_paths: Indexable collection of audio file paths.
        labels: Indexable collection of integer class labels, aligned with ``audio_paths``.
        model_name: Identifier passed to ``AutoFeatureExtractor.from_pretrained``.
        sample_rate: Target sampling rate in Hz used when loading audio.
        segment_length_seconds: Duration every clip is padded/truncated to.
    """

    def __init__(self, audio_paths, labels, model_name, sample_rate=16_000,
                 segment_length_seconds=10):
        self.audio_paths = audio_paths
        self.labels = labels
        # do_normalize=False: amplitude scaling is handled by _prepare_waveform.
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name, do_normalize=False)
        self.sample_rate = sample_rate
        self.segment_length_seconds = segment_length_seconds
        self.segment_length_samples = int(segment_length_seconds * sample_rate)

    def __len__(self):
        """Return the number of audio clips."""
        return len(self.audio_paths)

    def _prepare_waveform(self, audio):
        """Down-mix, peak-normalize and pad/truncate a raw waveform.

        Args:
            audio: 1-D array of samples, or a 2-D array that is averaged over axis 0.

        Returns:
            A float32 array of exactly ``self.segment_length_samples`` samples.
        """
        audio = np.asarray(audio, dtype=np.float32)
        if audio.ndim > 1:
            audio = audio.mean(axis=0)

        # Guard against silent (all-zero) or empty clips: dividing by a zero
        # peak would fill the waveform with NaN and poison the loss.
        peak = np.max(np.abs(audio)) if audio.size else 0.0
        if peak > 0:
            audio = audio / peak

        target_len = self.segment_length_samples
        if len(audio) < target_len:
            # Keep the input dtype so padded and truncated clips agree.
            padded_audio = np.zeros(target_len, dtype=audio.dtype)
            padded_audio[:len(audio)] = audio
            return padded_audio
        return audio[:target_len]

    def _load_audio(self, audio_path):
        """Load ``audio_path`` at ``self.sample_rate`` and return the prepared waveform."""
        audio, _ = librosa.load(audio_path, sr=self.sample_rate)
        return self._prepare_waveform(audio)

    def __getitem__(self, idx):
        """Return a dict with ``'input_values'`` (model input) and ``'label'`` (long tensor)."""
        audio_path = self.audio_paths[idx]
        label = self.labels[idx]
        audio = self._load_audio(audio_path)
        features = self.feature_extractor(
            audio,
            sampling_rate=self.sample_rate,
            max_length=self.segment_length_samples,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            # The extractor returns a batch of one; drop the batch dimension.
            'input_values': features.input_values[0],
            'label': torch.tensor(label, dtype=torch.long)
        }

In [None]:
class AttentionPoolingLayer(nn.Module):
    """Pool a sequence of frame embeddings into one vector with learned attention.

    A single linear layer scores every frame; the scores are softmax-normalized
    over the sequence dimension and used as weights for a weighted sum.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.linear = nn.Linear(embed_dim, 1)

    def forward(self, x, mask=None):
        """
        Forward pass.
        Args:
            x: The input tensor of shape (batch_size, seq_len, embed_dim).
            mask: Optional padding mask of shape (batch_size, seq_len);
                True / non-zero marks padded positions to be ignored.
        Returns:
            The output tensor of shape (batch_size, embed_dim).
        """
        scores = self.linear(x)  # (bs, seq_len, embed_dim) -> (bs, seq_len, 1)

        # Apply the mask before softmax to ignore padding
        if mask is not None:
            # (bs, seq_len) -> (bs, seq_len, 1); accept integer/float masks too.
            padding = mask.unsqueeze(-1).to(torch.bool)
            # Use the dtype's own minimum rather than a fixed -1e9, which is not
            # representable in float16, and fill out-of-place so the linear
            # layer's output is not mutated.
            scores = scores.masked_fill(padding, torch.finfo(scores.dtype).min)

        weights = torch.softmax(scores, dim=1)  # masked frames get ~0 weight

        # Weighted sum (bs, seq_len, 1) * (bs, seq_len, embed_dim) -> (bs, embed_dim)
        return torch.sum(weights * x, dim=1)

In [None]:
class DepressionClassifier(nn.Module):
    """Classifier built on a pretrained self-supervised (SSL) speech encoder.

    Every hidden layer of the encoder is layer-normalized and blended with
    softmax-normalized learnable weights; the blended frame sequence is
    attention-pooled into a single vector and fed to a two-layer MLP head.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_classes: Number of output logits.
        dropout: Dropout probability used inside the classification head.
    """

    def __init__(self, model_name, num_classes, dropout=0.1):
        super().__init__()

        # Pretrained encoder, configured to expose all intermediate layers.
        self.ssl_model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
        encoder_config = self.ssl_model.config
        self.ssl_hidden_size = encoder_config.hidden_size  # e.g. 768
        hidden = self.ssl_hidden_size

        # One mixing weight and one LayerNorm per encoder output
        # (the transformer layers plus the initial embeddings).
        layers_to_aggregate = encoder_config.num_hidden_layers + 1
        self.layer_weights = nn.Parameter(torch.ones(layers_to_aggregate))
        per_layer_norms = []
        for _ in range(layers_to_aggregate):
            per_layer_norms.append(nn.LayerNorm(hidden))
        self.layer_norms = nn.ModuleList(per_layer_norms)
        self.softmax = nn.Softmax(dim=-1)

        # Collapses the frame axis into a single embedding per clip.
        self.frame_pooling = AttentionPoolingLayer(embed_dim=hidden)

        # Classification head.
        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

        self.init_weights()

    def init_weights(self):
        """Xavier-normal init for the head's linear weights; zeros for its biases."""
        for layer in self.classifier:
            if not isinstance(layer, nn.Linear):
                continue  # Dropout / ReLU carry no parameters
            nn.init.xavier_normal_(layer.weight)
            nn.init.constant_(layer.bias, 0)

    def forward(self, batch):
        """Compute class logits for a batch.

        Args:
            batch: Mapping with key ``'input_values'`` holding the encoder input,
                shape (bs, seq_len).

        Returns:
            Logits of shape (bs, num_classes).
        """
        encoder_output = self.ssl_model(
            input_values=batch['input_values'],
            return_dict=True,
        )
        layer_outputs = encoder_output.hidden_states  # tuple of (bs, frames, hidden_size)

        # Blend all layers with softmax-normalized learnable weights.
        mixing = self.softmax(self.layer_weights)
        fused = torch.zeros_like(layer_outputs[-1])
        for layer_idx, layer_output in enumerate(layer_outputs):
            fused += mixing[layer_idx] * self.layer_norms[layer_idx](layer_output)

        # Attention pool over frames, then classify.
        clip_embedding = self.frame_pooling(fused, mask=None)  # (bs, hidden_size)
        return self.classifier(clip_embedding)  # (bs, num_classes)

In [None]:
# Experiment configuration
model_name = "facebook/wav2vec2-base"
num_classes = 2
dataset_name = "datasets/DAIC-WOZ-Cleaned-Split"
seed = 42

# Seed the RNGs so DataLoader shuffling and the random initialization of the
# classification head are reproducible across "Restart & Run All".
torch.manual_seed(seed)
np.random.seed(seed)

# Split metadata -> per-segment audio paths and labels
train_df = pd.read_csv(os.path.join(dataset_name, 'train_split_Depression_AVEC2017.csv'))
dev_df = pd.read_csv(os.path.join(dataset_name, 'dev_split_Depression_AVEC2017.csv'))

train_paths, y_train = get_split_audio_paths(train_df, dataset_name)
dev_paths, y_dev = get_split_audio_paths(dev_df, dataset_name)

# Print class distribution for 10s segments (counts and percentages, once each)
train_counts = np.bincount(y_train)
dev_counts = np.bincount(y_dev)
print("Train 10s segment distribution:", train_counts)
print("Train 10s segment percentages:", np.round(100 * train_counts / train_counts.sum(), 2), "%")
print("Dev 10s segment distribution:", dev_counts)
print("Dev 10s segment percentages:", np.round(100 * dev_counts / dev_counts.sum(), 2), "%")

# Datasets
train_dataset = AudioDepressionDataset(
    audio_paths=train_paths,
    labels=y_train,
    model_name=model_name
)

dev_dataset = AudioDepressionDataset(
    audio_paths=dev_paths,
    labels=y_dev,
    model_name=model_name
)

# DataLoaders
batch_size = 16
num_workers = 0

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers
)

dev_dataloader = DataLoader(
    dev_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers
)

# Model
model = DepressionClassifier(
    model_name=model_name,
    num_classes=num_classes,
    dropout=0.1
)

# Optimizer and loss
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

print("\n=== Model Summary ===")
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Training loop: fine-tune for a fixed number of epochs, evaluate on the dev
# set after each one, and checkpoint the weights with the best macro-F1.
num_epochs = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Start below the minimum possible F1 so the first epoch always produces a
# checkpoint, even if its F1 is exactly 0.
best_val_f1 = -1.0
model_save_path = "depression_classifier_best.pth"

for epoch in range(num_epochs):
    # Single 1-indexed label shared by the progress bar and the log lines.
    epoch_label = f"Epoch {epoch+1}/{num_epochs}"

    # ---- Training ----
    model.train()
    total_loss = 0.0
    train_pbar = tqdm(train_dataloader,
                      total=len(train_dataloader),
                      desc=f"{epoch_label} - Training")

    for batch in train_pbar:
        batch['input_values'] = batch['input_values'].to(device)
        batch['label'] = batch['label'].to(device)

        optimizer.zero_grad()
        output = model(batch)
        loss = criterion(output, batch['label'])

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        train_pbar.set_postfix(loss=f"{batch_loss:.4f}")

    avg_loss = total_loss / len(train_dataloader)
    print(f"{epoch_label} completed. Average Loss: {avg_loss:.4f}")

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dev_dataloader:
            batch['input_values'] = batch['input_values'].to(device)
            batch['label'] = batch['label'].to(device)

            output = model(batch)
            val_loss += criterion(output, batch['label']).item()

            predictions = torch.argmax(output, dim=1)
            all_preds.extend(predictions.cpu().tolist())
            all_labels.extend(batch['label'].cpu().tolist())
            total += batch['label'].size(0)
            correct += (predictions == batch['label']).sum().item()

    val_accuracy = correct / total
    avg_val_loss = val_loss / len(dev_dataloader)
    val_f1 = f1_score(all_labels, all_preds, average='macro')
    print(f'Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}, F1: {val_f1:.4f}')

    # Keep only the checkpoint with the best dev macro-F1 seen so far.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), model_save_path)
        print(f"New best model saved to {model_save_path} with F1: {val_f1:.4f}")

In [None]:
# Held-out test set: intentionally disabled until training/validation is final.
# Kept as comments (not a bare string literal) so the cell produces no output.
#
# NOTE(review): the original draft referenced `split_dataset_name`, which is not
# defined in the visible cells; `dataset_name` is used below instead. Confirm
# that 'full_test_split.csv' lives in that directory before enabling.
#
# test_df = pd.read_csv(os.path.join(dataset_name, 'full_test_split.csv'))
# test_paths, y_test = get_split_audio_paths(test_df, dataset_name)
#
# test_dataset = AudioDepressionDataset(
#     audio_paths=test_paths,
#     labels=y_test,
#     model_name=model_name
# )
#
# test_dataloader = DataLoader(
#     test_dataset,
#     batch_size=batch_size,
#     shuffle=False,
#     num_workers=num_workers
# )