In [4]:
from torch import nn
import torch 
import glob

In [5]:

class Attention(nn.Module):
    """Additive attention over the time axis, conditioned on two vectors.

    Both conditioning vectors are projected to ``hidden_dim``, added to every
    time step of the input and squashed with ``tanh``; a single linear unit
    then scores each time step.  The scores are softmax-normalised over time
    and used to rescale the input.  The time axis is kept (no pooling), so
    the output has the same shape as the input.

    Args:
        hidden_dim: Feature size of the input matrix and of both projections.
        vec_a_size: Feature size of the first conditioning vector.
        vec_b_size: Feature size of the second conditioning vector.
    """

    def __init__(self, hidden_dim, vec_a_size, vec_b_size):
        super().__init__()
        self.attn_a = nn.Linear(vec_a_size, hidden_dim)  # Project vec_a to hidden_dim
        self.attn_b = nn.Linear(vec_b_size, hidden_dim)  # Project vec_b to hidden_dim
        self.attn_score = nn.Linear(hidden_dim, 1)  # One scalar score per time step

    def forward(self, matrix, vec_a, vec_b, mask=None):
        """Rescale every time step of ``matrix`` by its attention weight.

        Args:
            matrix: Tensor of shape (batch_size, time, hidden_dim).
            vec_a: Tensor of shape (batch_size, vec_a_size).
            vec_b: Tensor of shape (batch_size, vec_b_size).
            mask: Optional tensor of shape (batch_size, time); truthy entries
                mark valid time steps.  Masked steps get zero weight and the
                remaining weights are normalised over the valid steps only.
                A row with no valid step yields an all-zero output row.

        Returns:
            Tensor of shape (batch_size, time, hidden_dim).

        Raises:
            ValueError: If ``matrix`` is not 3-D, or ``mask`` does not have
                shape (batch_size, time).
        """
        if matrix.dim() != 3:
            raise ValueError(
                "matrix must have shape (batch_size, time, hidden_dim), "
                f"got {tuple(matrix.shape)}"
            )

        # Keep a singleton time axis; broadcasting repeats the projections
        # over time, so no explicit expand is needed.
        a_proj = self.attn_a(vec_a).unsqueeze(1)  # (batch_size, 1, hidden_dim)
        b_proj = self.attn_b(vec_b).unsqueeze(1)  # (batch_size, 1, hidden_dim)

        # Compute attention scores
        attn_input = torch.tanh(matrix + a_proj + b_proj)  # Combine information
        attn_scores = self.attn_score(attn_input).squeeze(-1)  # (batch_size, time)

        if mask is not None:
            if mask.shape != attn_scores.shape:
                raise ValueError(
                    f"mask must have shape {tuple(attn_scores.shape)}, "
                    f"got {tuple(mask.shape)}"
                )
            mask = mask.to(device=attn_scores.device, dtype=torch.bool)
            # Use the most negative finite value rather than -inf so that a
            # fully masked row cannot produce NaN in the softmax.
            attn_scores = attn_scores.masked_fill(
                ~mask, torch.finfo(attn_scores.dtype).min
            )

        attn_weights = torch.softmax(attn_scores, dim=-1)  # (batch_size, time)

        if mask is not None:
            # Force exact zeros on masked steps (also covers fully masked rows,
            # where the softmax above degenerates to a uniform distribution).
            attn_weights = attn_weights.masked_fill(~mask, 0.0)

        # Apply attention to the matrix: element-wise weighting per time step
        return matrix * attn_weights.unsqueeze(-1)  # (batch_size, time, hidden_dim)

# Smoke test: run the attention layer on random inputs and check the shape.
batch_size, time, hidden_dim = 32, 10, 64
vec_a_size, vec_b_size = 16, 16

# Random stand-ins for the sequence features and the two conditioning vectors.
matrix = torch.randn((batch_size, time, hidden_dim))
vec_a = torch.randn((batch_size, vec_a_size))
vec_b = torch.randn((batch_size, vec_b_size))

attn = Attention(
    hidden_dim=hidden_dim,
    vec_a_size=vec_a_size,
    vec_b_size=vec_b_size,
)
output_matrix = attn(matrix, vec_a, vec_b)
print(output_matrix.shape)  # Expected: (batch_size, time, hidden_dim)


torch.Size([32, 10, 64])


In [None]:
# Root directory of the dataset to scan.
# Alternative root used previously: './Emotion Speech Dataset/'
data_dir = '/home/dcor/niskhizov/Prosody2Vec/IEMOCAP_full_release/'

# Collect every .wav file found anywhere below data_dir.
wav_files = list(glob.iglob(data_dir + '/**/*.wav', recursive=True))

