In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import AutoModel, AutoFeatureExtractor
import numpy as np
import librosa

from utils import load_labels_from_dataset

In [None]:
class AudioDepressionDataset(Dataset):
    """Map-style dataset that turns audio files into fixed-length model inputs.

    Every item is loaded with librosa at ``sample_rate``, peak-normalised and then
    padded or truncated by the feature extractor to exactly
    ``sample_rate * max_length_in_seconds`` samples.

    Args:
        audio_paths: Indexable collection of audio file paths.
        labels: Indexable collection of class labels; ``labels[i]`` belongs to
            ``audio_paths[i]`` and is converted to a ``torch.long`` tensor.
        model_name: Checkpoint name or path given to
            ``AutoFeatureExtractor.from_pretrained``.
        sample_rate: Sampling rate in Hz the audio is loaded at.
        max_length_in_seconds: Duration every clip is padded/truncated to.
            May be fractional; the resulting sample count is rounded to an int.
    """

    def __init__(self, audio_paths, labels, model_name, sample_rate=16_000, max_length_in_seconds=20):
        self.audio_paths = audio_paths
        self.labels = labels
        # The extractor's built-in normalisation is switched off; `_load_audio`
        # applies peak normalisation instead.
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name, do_normalize=False)
        self.sample_rate = sample_rate
        self.max_length_in_seconds = max_length_in_seconds

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.audio_paths)

    @staticmethod
    def _peak_normalize(audio):
        """Scale ``audio`` so its largest absolute sample is 1.

        Silent (all-zero) or empty arrays are returned unchanged, because dividing
        by a zero peak would fill the clip with NaNs and taking the max of an empty
        array raises.
        """
        peak = np.max(np.abs(audio)) if audio.size else 0.0
        if peak > 0:
            audio = audio / peak
        return audio

    def _load_audio(self, audio_path):
        """Load one file as a 1-D, peak-normalised ``float32`` tensor.

        Args:
            audio_path: Path of the audio file to read.

        Returns:
            torch.Tensor: The waveform resampled to ``self.sample_rate``.
        """
        # `sr` has to be passed by keyword: recent librosa releases make every
        # argument after the path keyword-only and reject a positional value.
        audio, _ = librosa.load(audio_path, sr=self.sample_rate)
        # Safeguard: collapse channels if a multi-channel array comes back.
        if audio.ndim > 1:
            audio = audio.mean(axis=0)
        audio = self._peak_normalize(audio).squeeze()
        return torch.tensor(audio, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return the model input and label for sample ``idx``.

        Returns:
            dict: ``'input_values'`` holds the padded/truncated waveform taken from
            the feature extractor output, ``'label'`` a scalar ``torch.long`` tensor.
        """
        waveform = self._load_audio(self.audio_paths[idx])
        # The feature extractor needs an integer sample count even when the clip
        # duration is fractional (e.g. 2.5 seconds).
        max_samples = int(round(self.sample_rate * self.max_length_in_seconds))

        # NOTE(review): the attention mask is requested here but not returned;
        # only `input_values` is forwarded to the caller.
        features = self.feature_extractor(
            waveform,
            sampling_rate=self.sample_rate,
            max_length=max_samples,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True,
        )

        return {
            # The extractor returns a batch of one; drop the batch dimension.
            'input_values': features.input_values[0],
            'label': torch.tensor(self.labels[idx], dtype=torch.long)
        }

In [None]:
class AttentiveStatisticsPooling(nn.Module):
    """
    Implementation of Attentive Statistics Pooling based on
    "Attentive Statistics Pooling for Deep Speaker Embedding" (https://www.isca-archive.org/interspeech_2018/okabe18_interspeech.pdf)

    A small scoring network (Linear -> BatchNorm -> ReLU -> Linear) assigns one
    scalar score per frame. The scores are softmax-normalised over time and used
    to compute a weighted mean and a weighted standard deviation, which are
    concatenated into a single utterance-level vector.

    Args:
        input_dim: Size of each input frame vector.
        attention_dim: Hidden size of the scoring network.
    """
    def __init__(self, input_dim, attention_dim=64):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, attention_dim)
        # BatchNorm1d is applied on the channel dimension
        self.bn = nn.BatchNorm1d(attention_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(attention_dim, 1)

    def forward(self, x, mask=None):
        """
        Forward pass.
        Args:
            x: The input tensor of shape (batch_size, seq_len, input_dim).
            mask: Optional tensor of shape (batch_size, seq_len); non-zero / True
                marks frames to keep, zero / False marks frames (e.g. padding)
                that must receive no attention weight. When omitted, every frame
                takes part in the pooling.
        Returns:
            The output tensor of shape (batch_size, input_dim * 2).
        """
        # (batch_size, seq_len, input_dim) -> (batch_size, seq_len, attention_dim)
        hidden = self.linear1(x)

        # BatchNorm1d normalises over dim 1, so move the channels there and back:
        # (batch_size, seq_len, attention_dim) <-> (batch_size, attention_dim, seq_len)
        # Note: the mask is applied after this step, so in training mode the batch
        # statistics are still computed over all frames.
        hidden = self.bn(hidden.transpose(1, 2)).transpose(1, 2)

        # (batch_size, seq_len, attention_dim) -> (batch_size, seq_len, 1)
        attention_scores = self.linear2(self.relu(hidden))

        if mask is not None:
            keep = mask.to(device=attention_scores.device, dtype=torch.bool).unsqueeze(-1)
            # The most negative finite value (instead of -inf) drives the masked
            # weights to zero without producing NaNs when a row is fully masked.
            attention_scores = attention_scores.masked_fill(
                ~keep, torch.finfo(attention_scores.dtype).min
            )

        # Normalise over the time axis so the weights of each sequence sum to 1.
        attention_weights = torch.softmax(attention_scores, dim=1)

        # Equation (5): weighted mean
        # (batch_size, seq_len, 1) * (batch_size, seq_len, input_dim) -> (batch_size, input_dim)
        mean = torch.sum(attention_weights * x, dim=1)

        # Equation (6): weighted standard deviation via E[X^2] - (E[X])^2
        variance = torch.sum(attention_weights * x.pow(2), dim=1) - mean.pow(2)
        # Clamp before the square root: rounding can make the variance slightly
        # negative, and sqrt has an unbounded gradient at zero.
        std_dev = torch.sqrt(variance.clamp(min=1e-6))

        # (batch_size, input_dim), (batch_size, input_dim) -> (batch_size, input_dim * 2)
        return torch.cat((mean, std_dev), dim=1)

In [None]:
class DepressionClassifier(nn.Module):
    """Classifier on top of a pretrained self-supervised speech encoder.

    The hidden states of every encoder layer (plus the output of the feature
    extraction stage) are layer-normalised and combined with learned,
    softmax-normalised weights. The combined sequence is pooled over time with
    ``AttentiveStatisticsPooling`` and fed to a two-layer MLP head.

    Args:
        model_name: Checkpoint name or path given to ``AutoModel.from_pretrained``.
        num_classes: Number of output logits.
        dropout: Dropout probability used inside the classification head.
    """

    def __init__(self, model_name, num_classes, dropout):
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self.dropout = dropout
        self.ssl_model = AutoModel.from_pretrained(self.model_name, output_hidden_states=True)
        self.ssl_hidden_size = self.ssl_model.config.hidden_size
        self.head_hidden_size = self.ssl_hidden_size

        # TODO: add code to freeze layers if needed

        print(f'Number of trainable parameters: {sum(p.numel() for p in self.ssl_model.parameters() if p.requires_grad) / 1e6:.2f}M')

        # +1 because the output of the feature-extraction layer is aggregated too.
        # The layer count lives on the model config, not on the model itself.
        layers_to_aggregate = self.ssl_model.config.num_hidden_layers + 1
        self.layer_weights = nn.Parameter(torch.ones(layers_to_aggregate))
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(self.ssl_hidden_size) for _ in range(layers_to_aggregate)
        ])
        self.softmax = nn.Softmax(dim=-1)

        # Size of the tensor fed into the pooling layer.
        self.pooling_embedding_dim = self.ssl_hidden_size
        # Size of the tensor produced by the pooling layer: the pooling
        # concatenates mean and standard deviation, hence twice the input size.
        self.global_embedding_dim = 2 * self.pooling_embedding_dim

        self.pooling_layer = AttentiveStatisticsPooling(input_dim=self.pooling_embedding_dim)

        self.classifier = nn.Sequential(
            nn.Linear(self.global_embedding_dim, self.head_hidden_size),
            nn.Dropout(self.dropout),
            nn.ReLU(),
            nn.Linear(self.head_hidden_size, self.num_classes),
        )

        self.init_weights()

    def init_weights(self):
        """Xavier-initialise the head's weight matrices and zero its biases."""
        for name, param in self.classifier.named_parameters():
            if 'weight' in name and len(param.shape) > 1:
                nn.init.xavier_normal_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0)

    def forward(self, batch):
        """Compute class logits for a batch.

        Args:
            batch: Mapping with key ``'input_values'`` holding the encoder input.

        Returns:
            Tensor of shape (batch_size, num_classes) with unnormalised logits.
        """
        hidden_states = self.ssl_model(
            input_values=batch['input_values'],
            return_dict=True,
        ).hidden_states

        # Learned convex combination of the layer-normalised hidden states.
        weights = self.softmax(self.layer_weights)
        mixed = torch.zeros_like(hidden_states[-1])
        for i, layer_norm in enumerate(self.layer_norms):
            mixed = mixed + weights[i] * layer_norm(hidden_states[i])

        # Attention pooling over time, then the classification head.
        features = self.pooling_layer(mixed)
        return self.classifier(features)