In [19]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

In [20]:
class InputProjection(nn.Module):
    """Lift raw per-position feature vectors into the model's embedding space.

    Applies ``Linear(input_dim -> embed_dim)``, then ReLU, then LayerNorm over
    the last dimension. Leading dimensions are passed through unchanged.

    Args:
        input_dim: Size of the last dimension of the incoming tensor.
        embed_dim: Size of the last dimension of the returned tensor.
    """

    def __init__(self, input_dim=48, embed_dim=256):
        super().__init__()
        # Registration order (linear, activation, layer_norm) is preserved so the
        # state_dict keys and the printed module structure stay the same.
        self.linear = nn.Linear(input_dim, embed_dim)
        self.activation = nn.ReLU()
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Map ``(..., input_dim)`` to ``(..., embed_dim)``."""
        projected = self.activation(self.linear(x))
        return self.layer_norm(projected)

In [21]:
class PositionalEncoding(nn.Module):
    """Learned (trainable) positional embedding added to the input sequence.

    Stores one ``embed_dim``-sized vector per position in ``self.pe`` with shape
    ``(1, max_len, embed_dim)``. The forward pass adds the first ``seq_len``
    rows (``seq_len = x.size(1)``) to the input, broadcasting over the batch.

    Args:
        embed_dim: Size of the last dimension of the input.
        max_len: Longest sequence this module can encode.

    Raises:
        ValueError: (from ``forward``) if the input sequence is longer than
            ``max_len``.
    """

    def __init__(self, embed_dim=256, max_len=100):
        super().__init__()
        self.pe = nn.Parameter(torch.zeros(1, max_len, embed_dim))
        # Note: xavier_uniform_ on a 3-D tensor uses fan_in = max_len * embed_dim
        # and fan_out = embed_dim, so the init range shrinks as max_len grows.
        # Kept as-is to preserve existing initialisation behaviour.
        nn.init.xavier_uniform_(self.pe)

    def forward(self, x):
        """Add positional information to ``x`` of shape (batch, seq_len, embed_dim)."""
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        # Without this check an over-long input fails later with an opaque
        # broadcasting error instead of naming the actual limit.
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len={max_len}"
            )
        return x + self.pe[:, :seq_len, :]

In [24]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no masking, no attention dropout).

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``embed_dim // num_heads``, attended per head as
    ``softmax(Q K^T / sqrt(head_dim)) V``, merged back and passed through an
    output projection.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        num_heads: Number of attention heads; must be a positive divisor of
            ``embed_dim``.

    Raises:
        ValueError: if ``num_heads`` is not a positive divisor of ``embed_dim``.
    """

    def __init__(self, embed_dim=256, num_heads=8):
        super().__init__()
        # Fail fast: otherwise the mismatch only surfaces later as an opaque
        # `view` error (or ZeroDivisionError for num_heads == 0).
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive divisor of embed_dim ({embed_dim})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Query/key/value/output projections; each covers all heads at once.
        self.w_q = nn.Linear(embed_dim, embed_dim)
        self.w_k = nn.Linear(embed_dim, embed_dim)
        self.w_v = nn.Linear(embed_dim, embed_dim)
        self.w_o = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t, batch_size, seq_len):
        # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Self-attend over ``x`` of shape (batch_size, seq_len, embed_dim).

        Returns:
            ``(output, attention_weights)`` where ``output`` has shape
            (batch_size, seq_len, embed_dim) and ``attention_weights`` has shape
            (batch_size, num_heads, seq_len, seq_len); each row of the weights
            sums to 1.
        """
        batch_size, seq_len, _ = x.shape

        Q = self._split_heads(self.w_q(x), batch_size, seq_len)
        K = self._split_heads(self.w_k(x), batch_size, seq_len)
        V = self._split_heads(self.w_v(x), batch_size, seq_len)

        # Scaled dot-product attention, computed independently for each head.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(scores, dim=-1)
        attention_output = torch.matmul(attention_weights, V)

        # Merge heads: (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, embed_dim).
        attention_output = (
            attention_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, self.embed_dim)
        )

        return self.w_o(attention_output), attention_weights

In [25]:
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder layer: self-attention followed by a feed-forward net.

    Each sub-layer is applied as ``LayerNorm(x + Dropout(sublayer(x)))`` with
    exactly one residual-branch dropout per sub-layer.

    Args:
        embed_dim: Model width (last dimension of the input and output).
        num_heads: Number of attention heads passed to ``MultiHeadAttention``.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability used on both residual branches and after
            the feed-forward activation.
    """

    def __init__(self, embed_dim=256, num_heads=8, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Residual-branch dropout for the attention sub-layer.
        self.dropout = nn.Dropout(dropout)

        # Feed-forward network. Its trailing Dropout already serves as the
        # residual-branch dropout for this sub-layer, so forward() must not
        # apply another one on top of it.
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply the block to ``x``.

        Returns:
            ``(x, weights)``: the transformed tensor (same shape as the input)
            and the attention weights returned by ``self.attention``.
        """
        # Self-attention sub-layer with residual connection.
        attn_output, weights = self.attention(x)
        x = self.norm1(x + self.dropout(attn_output))

        # Feed-forward sub-layer with residual connection. self.ffn already ends
        # in Dropout; wrapping it in self.dropout again dropped activations
        # twice (effective rate 1 - (1 - p)**2).
        ffn_output = self.ffn(x)
        x = self.norm2(x + ffn_output)

        return x, weights

In [42]:
class TransformerTemporalEncoder(nn.Module):
    """Stack of Transformer blocks over a sequence of feature vectors.

    Pipeline: ``InputProjection -> PositionalEncoding -> Dropout ->
    num_layers x TransformerBlock``.

    Args:
        input_dim: Size of the last dimension of the input.
        embed_dim: Model width used throughout the stack.
        num_heads: Attention heads per block (must divide ``embed_dim``).
        num_layers: Number of TransformerBlocks; must be at least 1.
        ff_dim: Hidden width of each block's feed-forward network.
        dropout: Dropout probability.
        max_len: Longest input sequence supported by the positional encoding
            (previously fixed at 100, which remains the default).

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, input_dim=48, embed_dim=64, num_heads=8,
                 num_layers=3, ff_dim=512, dropout=0.1, max_len=100):
        super().__init__()
        # With zero blocks, forward() would call torch.stack([]) and crash;
        # reject the configuration up front instead.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.input_proj = InputProjection(input_dim, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, max_len=max_len)

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Encode ``x`` of shape (batch_size, seq_len, input_dim).

        Returns:
            ``(encoded, attention_weights)`` where ``encoded`` has shape
            (batch_size, seq_len, embed_dim) and ``attention_weights`` stacks
            each block's attention maps along a new leading axis:
            (num_layers, batch_size, num_heads, seq_len, seq_len).
        """
        # Project to embed_dim, add positions, then apply input dropout.
        x = self.dropout(self.pos_encoding(self.input_proj(x)))

        attention_weights = []
        for block in self.transformer_blocks:
            x, weights = block(x)
            attention_weights.append(weights)

        return x, torch.stack(attention_weights)

In [43]:
# Seed the global RNG so the encoder's weight initialisation is reproducible
# on Restart & Run All.
torch.manual_seed(0)
enc = TransformerTemporalEncoder()

In [48]:
# Dummy input: batch of 1 sequence with 64 positions of 48 features each
# (the last dimension must match the encoder's input_dim). A dedicated seeded
# generator keeps this sample reproducible without touching the global RNG.
data = torch.rand(1, 64, 48, generator=torch.Generator().manual_seed(0))

In [49]:
# Forward pass -> (encoded sequence, stacked per-layer attention maps).
# NOTE(review): unless enc.eval() was called, the encoder is in training mode,
# so dropout is active and these values vary between runs.
(temporal_encoded, attention_weights) = enc(data)

In [50]:
# Encoding at the last sequence position for every batch element:
# shape (batch_size, embed_dim).
temporal_encoded[:, -1, :]

tensor([[ 0.8123, -0.6182, -0.6583, -0.0629,  2.3469,  1.1953,  0.3435,  0.6787,
          1.0185, -0.5536,  0.7050, -1.1617, -0.8173, -0.8221,  0.2815,  0.3436,
         -0.0685, -0.4311,  0.2339,  0.7412, -0.1787, -0.1059, -0.0061, -0.4310,
         -0.6668, -1.1997, -0.7102, -1.5147,  0.0836, -0.5731, -0.4812, -1.0000,
         -0.9866, -0.3735, -0.2992, -1.0261, -1.4122, -0.6642,  1.4058, -0.4298,
          1.9099,  1.0879,  0.3377, -0.4276,  2.8044, -1.0083,  2.0285,  0.1321,
          0.2960,  0.0583,  1.5483, -0.1285, -1.6143,  0.1691, -0.3700, -1.0194,
         -0.9143,  0.7447, -1.3078, -1.4715,  1.5211,  0.0534,  1.8833,  0.7500]],
       grad_fn=<SliceBackward0>)

In [39]:
# Layout: (num_layers, batch_size, num_heads, seq_len, seq_len).
attention_weights.size()

torch.Size([3, 1, 8, 64, 64])