In [31]:
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import AutoModel, AutoFeatureExtractor
import numpy as np
import librosa
from torch.utils.data import DataLoader
import torch.optim as optim
import pandas as pd
import os
from utils import load_labels_from_dataset

In [32]:
class AudioDepressionDataset(Dataset):
    def __init__(self, audio_paths, labels, model_name, sample_rate=16_000, segment_length_seconds=20, max_segments=None):
        self.audio_paths = audio_paths  
        self.labels = labels            
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name, do_normalize=False)
        self.sample_rate = sample_rate
        self.segment_length_seconds = segment_length_seconds
        self.segment_length_samples = segment_length_seconds * sample_rate
        self.max_segments = max_segments

    def __len__(self):
        return len(self.audio_paths)
    
    def _load_audio(self, audio_path):
        audio, _ = librosa.load(audio_path, sr=self.sample_rate)
        if len(audio.shape) > 1:
            audio = audio.mean(axis=0)
        audio = audio / np.max(np.abs(audio))
        return audio
    
    def _segment_audio(self, audio):
        """Segmenta l'audio in chunks di lunghezza fissa"""
        segments = []
        
        # Se l'audio è più corto del segmento desiderato, pad con zeri
        if len(audio) < self.segment_length_samples:
            padded_audio = np.zeros(self.segment_length_samples)
            padded_audio[:len(audio)] = audio
            segments.append(padded_audio)
        else:
            # Dividi in segmenti
            for i in range(0, len(audio), self.segment_length_samples):
                segment = audio[i:i + self.segment_length_samples]
                
                # Se l'ultimo segmento è troppo corto, pad con zeri
                if len(segment) < self.segment_length_samples:
                    padded_segment = np.zeros(self.segment_length_samples)
                    padded_segment[:len(segment)] = segment
                    segment = padded_segment
                
                segments.append(segment)
                
                # Limita il numero di segmenti se specificato
                if self.max_segments and len(segments) >= self.max_segments:
                    break
        
        return np.array(segments)

    def __getitem__(self, idx):
        audio_path = self.audio_paths[idx]
        label = self.labels[idx]
        
        audio = self._load_audio(audio_path)
        segments = self._segment_audio(audio)
        
        segment_features = []
        for segment in segments:
            features = self.feature_extractor(
                segment, 
                sampling_rate=self.sample_rate,
                max_length=self.segment_length_samples,
                padding='max_length',
                truncation=True,
                return_tensors='pt',
                return_attention_mask=False,
            )
            segment_features.append(features.input_values[0])
        
        segment_features = torch.stack(segment_features)  # (num_segments, seq_len)
        
        return {
            'input_values': segment_features, 
            'label': torch.tensor(label, dtype=torch.long),
            'num_segments': len(segments)
        }

In [33]:
class AttentiveStatisticsPooling(nn.Module):
    """
    Implementation of Attentive Statistics Pooling based on
    "Attentive Statistics Pooling for Deep Speaker Embedding" (https://www.isca-archive.org/interspeech_2018/okabe18_interspeech.pdf)
    """
    def __init__(self, input_dim, attention_dim=64):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, attention_dim)
        # BatchNorm1d is applied on the channel dimension
        self.bn = nn.BatchNorm1d(attention_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(attention_dim, 1)

    def forward(self, x):
        """
        Forward pass.
        Args:
            x: The input tensor of shape (batch_size, seq_len, input_dim).
        Returns:
            The output tensor of shape (batch_size, input_dim * 2).
        """
        # (batch_size, seq_len, input_dim) -> (batch_size, seq_len, attention_dim)
        x_attn = self.linear1(x)

        # BatchNorm requires shape (batch_size, channels, seq_len), so we transpose
        # (batch_size, seq_len, attention_dim) -> (batch_size, attention_dim, seq_len)
        x_attn = x_attn.transpose(1, 2)
        x_attn = self.bn(x_attn)
        # Transpose back to the original dimension order
        # (batch_size, attention_dim, seq_len) -> (batch_size, seq_len, attention_dim)
        x_attn = x_attn.transpose(1, 2)

        # Apply activation and final linear layer
        x_attn = self.relu(x_attn)
        # (batch_size, seq_len, attention_dim) -> (batch_size, seq_len, 1)
        attention_scores = self.linear2(x_attn)

        attention_weights = torch.softmax(attention_scores, dim=1)

        # Equation (5): Weighted mean
        # (batch_size, seq_len, 1) * (batch_size, seq_len, input_dim) -> (batch_size, input_dim)
        mean = torch.sum(attention_weights * x, dim=1)

        # Equation (6): Weighted standard deviation
        # E[X^2] - (E[X])^2
        # (batch_size, seq_len, 1) * (batch_size, seq_len, input_dim) -> (batch_size, input_dim)
        variance = torch.sum(attention_weights * x.pow(2), dim=1) - mean.pow(2)
        std_dev = torch.sqrt(variance.clamp(min=1e-6))

        # (batch_size, input_dim), (batch_size, input_dim) -> (batch_size, input_dim * 2)
        pooled_output = torch.cat((mean, std_dev), dim=1)

        return pooled_output

In [34]:
class DepressionClassifier(nn.Module):
    def __init__(self, model_name, num_classes, dropout=0.1, 
                 sequence_model_type='bilstm', sequence_hidden_size=256):
        super(DepressionClassifier, self).__init__()
    
        self.model_name = model_name
        self.num_classes = num_classes
        self.dropout = dropout
        self.sequence_model_type = sequence_model_type

        self.ssl_model = AutoModel.from_pretrained(self.model_name, output_hidden_states=True)
        self.ssl_hidden_size = self.ssl_model.config.hidden_size
        self.head_hidden_size = self.ssl_hidden_size
        
        # +1 perchè prendiamo anche il layer che fa feature extraction
        layers_to_aggregate = self.ssl_model.config.num_hidden_layers + 1
        self.layer_weights = nn.Parameter(torch.ones(layers_to_aggregate))
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(self.ssl_hidden_size) for _ in range(layers_to_aggregate)
        ])
        self.softmax = nn.Softmax(dim=-1)

        # Segment-level pooling
        self.segment_pooling = AttentiveStatisticsPooling(input_dim=self.ssl_hidden_size)
        segment_embedding_dim = self.ssl_hidden_size * 2

        # Sequence model per aggregare i segmenti
        if sequence_model_type == 'bilstm':
            self.sequence_model = nn.LSTM(
                input_size=segment_embedding_dim,
                hidden_size=sequence_hidden_size,
                num_layers=2,
                batch_first=True,
                dropout=dropout,
                bidirectional=True
            )
            sequence_output_dim = sequence_hidden_size * 2  # bidirectional
        elif sequence_model_type == 'transformer':
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=segment_embedding_dim,
                nhead=8,
                dim_feedforward=sequence_hidden_size,
                dropout=dropout,
                batch_first=True
            )
            self.sequence_model = nn.TransformerEncoder(encoder_layer, num_layers=2)
            sequence_output_dim = segment_embedding_dim

        # Global pooling
        self.global_pooling = AttentiveStatisticsPooling(input_dim=sequence_output_dim)
        global_embedding_dim = sequence_output_dim * 2
        
        self.classifier = nn.Sequential(
            nn.Linear(global_embedding_dim, self.head_hidden_size),
            nn.Dropout(self.dropout),
            nn.ReLU(),
            nn.Linear(self.head_hidden_size, self.num_classes),
        )

        self.init_weights()
    
    def init_weights(self):
        # initialize weights of classifier
        for name, param in self.classifier.named_parameters():
            if 'weight' in name and len(param.shape) > 1:
                nn.init.xavier_normal_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0)

    def forward(self, batch):
        input_values = batch['input_values']  # (batch_size, num_segments, seq_len)
        batch_size, num_segments, seq_len = input_values.shape
        
        # IMPORTANTE: Processa ogni audio separatamente per mantenere la struttura gerarchica
        all_segment_embeddings = []
        
        for batch_idx in range(batch_size):
            # Prendi tutti i segmenti di un singolo audio
            single_audio_segments = input_values[batch_idx]  # (num_segments, seq_len)
            
            # Processa tutti i segmenti di questo audio insieme
            ssl_hidden_states = self.ssl_model(
                input_values=single_audio_segments,
                return_dict=True,
            ).hidden_states
            
            # Weighted aggregation of layers
            ssl_hidden_state = torch.zeros_like(ssl_hidden_states[-1])
            weights = self.softmax(self.layer_weights)
            for i in range(len(ssl_hidden_states)):
                ssl_hidden_state += weights[i] * self.layer_norms[i](ssl_hidden_states[i])
            
            # Attention pooling per ogni segmento di questo audio
            segment_embeddings = []
            for seg_idx in range(num_segments):
                segment_emb = self.segment_pooling(ssl_hidden_state[seg_idx])  # (hidden_size * 2)
                segment_embeddings.append(segment_emb)
            
            segment_embeddings = torch.stack(segment_embeddings)  # (num_segments, hidden_size * 2)
            all_segment_embeddings.append(segment_embeddings)
        
        # Stack embeddings di tutti gli audio nel batch
        all_segment_embeddings = torch.stack(all_segment_embeddings)  # (batch_size, num_segments, hidden_size * 2)
        
        # Sequence modeling per ogni audio nel batch
        if self.sequence_model_type == 'bilstm':
            sequence_output, _ = self.sequence_model(all_segment_embeddings)
        elif self.sequence_model_type == 'transformer':
            sequence_output = self.sequence_model(all_segment_embeddings)
        
        # Global pooling per ogni audio
        global_embeddings = []
        for batch_idx in range(batch_size):
            global_emb = self.global_pooling(sequence_output[batch_idx])
            global_embeddings.append(global_emb)
        
        global_embeddings = torch.stack(global_embeddings)
        
        # Final classification
        output = self.classifier(global_embeddings)
        
        return output

In [35]:
# Funzione di collate per gestire batch con numero variabile di segmenti
def collate_fn(batch):
    """
    Questa funzione serve perché diversi audio possono avere numero diverso di segmenti.
    Ad esempio:
    - Audio 1: 30 secondi → 3 segmenti da 10s
    - Audio 2: 50 secondi → 5 segmenti da 10s
    
    Per creare un batch uniforme, dobbiamo fare padding al numero massimo di segmenti.
    Viene chiamata automaticamente dal DataLoader quando batch_size > 1.
    """
    # Trova il numero massimo di segmenti nel batch
    max_segments = max([item['num_segments'] for item in batch])
    
    batch_input_values = []
    batch_labels = []
    
    for item in batch:
        input_values = item['input_values']
        num_segments = item['num_segments']
        
        # Pad se necessario (aggiunge segmenti di zeri)
        if num_segments < max_segments:
            padding_shape = (max_segments - num_segments, input_values.shape[1])
            padding = torch.zeros(padding_shape, dtype=input_values.dtype)
            input_values = torch.cat([input_values, padding], dim=0)
        
        batch_input_values.append(input_values)
        batch_labels.append(item['label'])
    
    return {
        'input_values': torch.stack(batch_input_values),
        'label': torch.stack(batch_labels)
    }

In [36]:
# Parametri
model_name = "facebook/wav2vec2-base"
num_classes = 2  # Binary classification (depressed vs non-depressed)
dataset_name = "datasets/DAIC-WOZ-Cleaned"

train_df = pd.read_csv(os.path.join('datasets', 'DAIC-WOZ', 'train_split_Depression_AVEC2017.csv'))
dev_df = pd.read_csv(os.path.join('datasets', 'DAIC-WOZ', 'dev_split_Depression_AVEC2017.csv'))
test_df = pd.read_csv(os.path.join('datasets', 'DAIC-WOZ', 'full_test_split.csv'))

y_train = load_labels_from_dataset(train_df)
y_dev = load_labels_from_dataset(dev_df) 
y_test = load_labels_from_dataset(test_df)

def get_audio_paths(df, dataset_name):
    audio_paths = []
    for participant_id in df['Participant_ID']:
        dir_name = f"{participant_id}_P"
        wav_path = os.path.join(dataset_name, dir_name, f"{participant_id}_AUDIO.wav")
        if os.path.isfile(wav_path):
            audio_paths.append(wav_path)
        else:
            print(f"Warning: File non trovato per {participant_id} in {wav_path}")
    return audio_paths

# Carica i path audio
train_paths = get_audio_paths(train_df, dataset_name)
dev_paths = get_audio_paths(dev_df, dataset_name)
test_paths = get_audio_paths(test_df, dataset_name)

# Crea i dataset
train_dataset = AudioDepressionDataset(
    audio_paths=train_paths,
    labels=y_train,
    model_name=model_name,
    segment_length_seconds=10,  # Segmenti da 10 secondi
    max_segments=140  # Massimo 140 segmenti (= 23.3 minuti max)
)

dev_dataset = AudioDepressionDataset(
    audio_paths=dev_paths,
    labels=y_dev,
    model_name=model_name,
    segment_length_seconds=10,
    max_segments=140
)

test_dataset = AudioDepressionDataset(
    audio_paths=test_paths,
    labels=y_test,
    model_name=model_name,
    segment_length_seconds=10,
    max_segments=140
)

# DataLoaders
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=8,  # Riduci se hai problemi di memoria
    shuffle=True, 
    collate_fn=collate_fn,
    num_workers=2
)

dev_dataloader = DataLoader(
    dev_dataset, 
    batch_size=8, 
    shuffle=False, 
    collate_fn=collate_fn,
    num_workers=2
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=2   
)

# Modello
model = DepressionClassifier(
    model_name=model_name,
    num_classes=num_classes,
    dropout=0.1,
    sequence_model_type='bilstm',  # Prova anche 'transformer'
    sequence_hidden_size=256
)

# Ottimizzatore e loss
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

print("\n=== Model Summary ===")
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Training loop
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    
    for batch_idx, batch in enumerate(train_dataloader):
        # Sposta i dati su GPU
        batch['input_values'] = batch['input_values'].to(device)
        batch['label'] = batch['label'].to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        output = model(batch)
        loss = criterion(output, batch['label'])
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        if batch_idx % 10 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
    
    avg_loss = total_loss / len(train_dataloader)
    print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')
    
    # Validation
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in dev_dataloader:
            batch['input_values'] = batch['input_values'].to(device)
            batch['label'] = batch['label'].to(device)
            
            output = model(batch)
            val_loss += criterion(output, batch['label']).item()
            
            predictions = torch.argmax(output, dim=1)
            total += batch['label'].size(0)
            correct += (predictions == batch['label']).sum().item()
    
    val_accuracy = correct / total
    avg_val_loss = val_loss / len(dev_dataloader)
    print(f'Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}')




=== Model Summary ===
Total parameters: 100,513,937
Trainable parameters: 100,513,937


Traceback (most recent call last):
  File "<string>", line 1, in <module>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/davidebonura/miniconda3/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
  File "/Users/davidebonura/miniconda3/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
    exitcode = _main(fd, parent_sentinel)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/Users/davidebonura/miniconda3/lib/python3.12/multiprocessing/spawn.py", line 132, in _main
  File "/Users/davidebonura/miniconda3/lib/python3.12/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
    self = reduction.pickle.load(from_parent)
                ^ ^ ^ ^ ^ ^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^AttributeError^^: 
Can't get attribute 'AudioDepressionDataset' on <module '__main__' (<class

RuntimeError: DataLoader worker (pid(s) 97582, 97583) exited unexpectedly