In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset
import torchaudio
from transformers import AutoModel

from utils import load_labels_from_dataset

In [None]:
class AudioDepressionDataset(Dataset):
    """Map-style dataset pairing audio files with integer labels.

    Each item is loaded from disk with ``torchaudio.load`` and, when
    ``num_samples`` is truthy, cropped or right-padded with zeros along
    dimension 1 so that it has exactly ``num_samples`` entries there.

    Args:
        audio_paths: Indexable collection of paths to the audio files.
        labels: Indexable collection of labels aligned with ``audio_paths``
            (the original notes describe them as 0 or 1).
        num_samples: Desired signal length in samples. ``None`` (or any
            other falsy value) leaves the loaded waveform untouched.
    """

    def __init__(self, audio_paths, labels, num_samples=None):
        self.audio_paths = audio_paths
        self.labels = labels
        self.num_samples = num_samples

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.audio_paths)

    def __getitem__(self, idx):
        """Load one example.

        Returns:
            dict with ``'input'`` (the possibly cropped/padded waveform) and
            ``'label'`` (a ``torch.long`` tensor built from the stored label).
        """
        path, target = self.audio_paths[idx], self.labels[idx]

        # The second value returned by the loader is discarded, as in the
        # original code (presumably the sample rate - no resampling is done).
        signal, _ = torchaudio.load(path)

        wanted = self.num_samples
        if wanted:
            available = signal.shape[1]
            if available > wanted:
                # Too long: keep only the leading `wanted` samples.
                signal = signal[:, :wanted]
            elif available < wanted:
                # Too short: append zeros at the end of the last dimension.
                missing = wanted - available
                signal = torch.nn.functional.pad(signal, (0, missing))

        return {
            'input': signal,
            'label': torch.tensor(target, dtype=torch.long),
        }

In [None]:
import torch
import torch.nn as nn
from transformers import Wav2Vec2Model

class LayerWeightedAttentionPooling(nn.Module):
    """Fuse per-layer hidden states, then pool over time with attention.

    Two steps are applied:

    1. A learnable, softmax-normalised weight per layer produces a weighted
       sum of the layer outputs.
    2. Additive attention (``w^T tanh(W h_t)``) scores every time step; the
       softmax of those scores over the time axis weights the final sum.

    Args:
        hidden_size: Size ``H`` of the last dimension of each hidden state.
        num_layers: Number of layers that will be passed to ``forward``.
    """

    def __init__(self, hidden_size, num_layers):
        super().__init__()
        # One scalar per layer, initialised uniformly (1 / num_layers each).
        self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        # Attention scorer: project to 128 dims, tanh, then one score per step.
        self.attn_proj = nn.Linear(hidden_size, 128)
        self.attn_score = nn.Linear(128, 1)

    def forward(self, hidden_states):
        """Pool a sequence of per-layer hidden states into one vector per item.

        Args:
            hidden_states: Sequence of ``num_layers`` tensors, each (B, T, H).

        Returns:
            Tensor of shape (B, H).
        """
        per_layer = torch.stack(hidden_states, dim=0)            # (L, B, T, H)
        layer_mix = self.layer_weights.softmax(dim=0)            # (L,)
        fused = (layer_mix.view(-1, 1, 1, 1) * per_layer).sum(dim=0)  # (B, T, H)

        # Score each time step, normalise over T, and take the weighted sum.
        hidden = torch.tanh(self.attn_proj(fused))               # (B, T, 128)
        scores = self.attn_score(hidden)                         # (B, T, 1)
        time_mix = scores.softmax(dim=1)                         # (B, T, 1)
        return (fused * time_mix).sum(dim=1)                     # (B, H)

class DepressionClassifier(nn.Module):
    """Classifier scaffold built on a pretrained self-supervised speech model.

    The constructor loads the backbone with ``AutoModel.from_pretrained`` and
    ``output_hidden_states=True``, and creates one learnable weight and one
    ``LayerNorm`` for every hidden-state tensor that will be aggregated.

    Args:
        model_name: Identifier or path passed to ``AutoModel.from_pretrained``.
        num_classes: Number of target classes (stored; not used yet).
        dropout: Dropout value (stored as given; no ``nn.Dropout`` layer is
            created here).
    """

    def __init__(self, model_name, num_classes, dropout):
        super(DepressionClassifier, self).__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self.dropout = dropout
        self.ssl_model = AutoModel.from_pretrained(self.model_name, output_hidden_states=True)

        # TODO: add code to freeze backbone layers if needed

        print(f'Number of trainable parameters: {sum(p.numel() for p in self.ssl_model.parameters() if p.requires_grad) / 1e6:.2f}M')

        # Architecture hyper-parameters live on the model's `config`, not on
        # the model itself; reading them from the module raises AttributeError.
        backbone_config = self.ssl_model.config

        # +1 because the output of the feature-extraction layer is also taken.
        layers_to_aggregate = backbone_config.num_hidden_layers + 1
        self.layer_weights = nn.Parameter(torch.ones(layers_to_aggregate))
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(backbone_config.hidden_size) for _ in range(layers_to_aggregate)
        ])
        self.softmax = nn.Softmax(dim=-1)

    def forward(self):
        """Forward pass placeholder: not implemented yet, returns ``None``."""
        pass