In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdvancedEmotionModel(nn.Module):
    """Conv1d -> bidirectional GRU -> attention pooling -> two sigmoid heads.

    Maps a sequence of feature vectors to two scalar scores, ``valence`` and
    ``energy``, each squashed into (0, 1) by a sigmoid.
    """

    def __init__(
        self,
        input_features: int,
        hidden_dim: int = 64,
        *,
        conv_channels: int = 64,
        head_dim: int = 32,
        dropout: float = 0.3,
    ) -> None:
        """
        Args:
            input_features: Number of features per time step (e.g. 20 for
                4 channels x 5 frequency bands, or ~2500 for the raw FFT
                features from Kaggle). Used as the Conv1d input-channel count.
            hidden_dim: Hidden size of each GRU direction; the bidirectional
                GRU output is ``2 * hidden_dim`` wide.
            conv_channels: Output channels of the Conv1d front end, which is
                also the GRU input size.
            head_dim: Width of the hidden layer in each regression head.
            dropout: Dropout probability applied after the convolution block.
        """
        super().__init__()

        # Convolution along the sequence axis; kernel_size=3 with padding=1
        # (stride 1) keeps the sequence length unchanged.
        self.conv1 = nn.Conv1d(
            in_channels=input_features,
            out_channels=conv_channels,
            kernel_size=3,
            padding=1,
        )
        self.bn1 = nn.BatchNorm1d(conv_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        self.gru = nn.GRU(
            input_size=conv_channels,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )

        # Produces one attention score per GRU time step.
        self.attention_fc = nn.Linear(hidden_dim * 2, 1)

        # Attribute names and the Sequential layout are kept identical so that
        # existing state_dicts (fc_valence.0.*, fc_valence.2.*, ...) still load.
        self.fc_valence = self._make_head(hidden_dim * 2, head_dim)
        self.fc_energy = self._make_head(hidden_dim * 2, head_dim)

    @staticmethod
    def _make_head(in_dim: int, head_dim: int) -> nn.Sequential:
        """Build a Linear-ReLU-Linear-Sigmoid head that outputs one value in (0, 1)."""
        return nn.Sequential(
            nn.Linear(in_dim, head_dim),
            nn.ReLU(),
            nn.Linear(head_dim, 1),
            nn.Sigmoid(),
        )

    def attention(self, gru_output: torch.Tensor) -> torch.Tensor:
        """Pool GRU outputs over time using learned attention weights.

        Args:
            gru_output: Tensor of shape ``(batch, seq_len, 2 * hidden_dim)``.

        Returns:
            Tensor of shape ``(batch, 2 * hidden_dim)``: the attention-weighted
            sum of ``gru_output`` over the time axis (weights sum to 1).
        """
        # NOTE: tanh bounds every score to [-1, 1], so the weights of any two
        # time steps differ by at most a factor of e**2; this limits how sharply
        # the model can focus. Kept as-is to preserve trained behaviour.
        scores = torch.tanh(self.attention_fc(gru_output))  # (batch, seq_len, 1)
        weights = F.softmax(scores, dim=1)  # normalise over the time axis

        return torch.sum(weights * gru_output, dim=1)

    def forward(self, x: torch.Tensor):
        """Predict valence and energy.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_features)``. A 2-D
                tensor ``(batch, input_features)`` is treated as a length-1
                sequence. Note that in training mode BatchNorm needs more than
                one value per channel, i.e. ``batch * seq_len > 1``.

        Returns:
            Tuple ``(valence, energy)``, each of shape ``(batch, 1)`` with
            values in (0, 1).
        """
        if x.dim() == 2:
            # Flat feature vectors: add a length-1 time axis.
            x = x.unsqueeze(1)

        # Conv1d expects (batch, channels, length).
        x = x.permute(0, 2, 1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)

        # Back to (batch, seq_len, conv_channels) for the batch_first GRU.
        x = x.permute(0, 2, 1)

        gru_out, _ = self.gru(x)
        context = self.attention(gru_out)

        valence = self.fc_valence(context)
        energy = self.fc_energy(context)

        return valence, energy