In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import unittest
import numpy as np

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_model: int):
        super(ScaledDotProductAttention, self).__init__()
        
        self.d_model = d_model
        
        # Linear transformations cho Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
    
    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask=None) -> torch.Tensor:
        """
        Pure Scaled Dot-Product Attention
        
        Args:
            Q: Query tensor (batch_size, seq_len_q, d_model)
            K: Key tensor (batch_size, seq_len_k, d_model)  
            V: Value tensor (batch_size, seq_len_v, d_model)
            mask: Optional mask tensor (batch_size, seq_len_q, seq_len_k) hoặc broadcastable
        
        Returns:
            output: Attention output (batch_size, seq_len_q, d_model)
        """
        # Linear transformations
        Q = self.W_q(Q)  # (batch_size, seq_len_q, d_model)
        K = self.W_k(K)  # (batch_size, seq_len_k, d_model)
        V = self.W_v(V)  # (batch_size, seq_len_v, d_model)
        
        # Scaled dot-product attention
        d_k = K.size(-1)  # d_model
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        # scores shape: (batch_size, seq_len_q, seq_len_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attention_weights = F.softmax(scores, dim=-1)
        # attention_weights shape: (batch_size, seq_len_q, seq_len_k)
        
        output = torch.matmul(attention_weights, V)
        # output shape: (batch_size, seq_len_q, d_model)
        
        return output
    
    def forward_with_weights(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask=None):
        """
        Forward pass với return cả attention weights (để debugging/visualization)
        """
        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)
        
        d_k = K.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        
        return output, attention_weights


class TestScaledDotProductAttention(unittest.TestCase):
    
    def setUp(self):
        """Thiết lập các tham số test"""
        self.batch_size = 2
        self.seq_len = 10
        self.d_model = 512
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Tạo model
        self.model = ScaledDotProductAttention(self.d_model).to(self.device)
        
        # Tạo input tensors
        self.Q = torch.randn(self.batch_size, self.seq_len, self.d_model).to(self.device)
        self.K = torch.randn(self.batch_size, self.seq_len, self.d_model).to(self.device)
        self.V = torch.randn(self.batch_size, self.seq_len, self.d_model).to(self.device)
    
    def test_initialization(self):
        """Test khởi tạo model"""
        model = ScaledDotProductAttention(512)
        self.assertEqual(model.d_model, 512)
        self.assertIsInstance(model.W_q, nn.Linear)
        self.assertIsInstance(model.W_k, nn.Linear)
        self.assertIsInstance(model.W_v, nn.Linear)
        
        # Kiểm tra dimensions của linear layers
        self.assertEqual(model.W_q.in_features, 512)
        self.assertEqual(model.W_q.out_features, 512)
    
    def test_forward_shape_self_attention(self):
        """Test shape với self-attention (Q=K=V)"""
        X = torch.randn(self.batch_size, self.seq_len, self.d_model).to(self.device)
        output = self.model(X, X, X)
        expected_shape = (self.batch_size, self.seq_len, self.d_model)
        self.assertEqual(output.shape, expected_shape)
    
    def test_forward_shape_cross_attention(self):
        """Test shape với cross-attention (K,V có seq_len khác Q)"""
        seq_len_kv = 15
        K_cross = torch.randn(self.batch_size, seq_len_kv, self.d_model).to(self.device)
        V_cross = torch.randn(self.batch_size, seq_len_kv, self.d_model).to(self.device)
        
        output = self.model(self.Q, K_cross, V_cross)
        expected_shape = (self.batch_size, self.seq_len, self.d_model)  # shape theo Q
        self.assertEqual(output.shape, expected_shape)
    
    def test_attention_weights_properties(self):
        """Test các tính chất của attention weights"""
        output, attention_weights = self.model.forward_with_weights(self.Q, self.K, self.V)
        
        # 1. Shape của attention weights
        expected_weights_shape = (self.batch_size, self.seq_len, self.seq_len)
        self.assertEqual(attention_weights.shape, expected_weights_shape)
        
        # 2. Attention weights phải sum to 1 theo dim cuối
        weights_sum = attention_weights.sum(dim=-1)
        expected_sum = torch.ones(self.batch_size, self.seq_len).to(self.device)
        torch.testing.assert_close(weights_sum, expected_sum, rtol=1e-5, atol=1e-6)
        
        # 3. Attention weights phải >= 0
        self.assertTrue((attention_weights >= 0).all())
        
        # 4. Attention weights phải <= 1
        self.assertTrue((attention_weights <= 1).all())
    
    def test_causal_mask(self):
        """Test với causal mask (cho decoder self-attention)"""
        # Tạo causal mask (lower triangular)
        mask = torch.tril(torch.ones(self.seq_len, self.seq_len))
        mask = mask.unsqueeze(0).expand(self.batch_size, -1, -1).to(self.device)
        
        output, attention_weights = self.model.forward_with_weights(self.Q, self.K, self.V, mask=mask)
        
        # Kiểm tra attention weights = 0 ở vị trí mask = 0
        masked_positions = (mask == 0)
        masked_weights = attention_weights[masked_positions]
        
        # Attention weights tại vị trí masked phải rất nhỏ (≈ 0)
        self.assertTrue((masked_weights < 1e-8).all())
    
    def test_padding_mask(self):
        """Test với padding mask"""
        # Giả sử token cuối cùng là padding
        mask = torch.ones(self.batch_size, self.seq_len, self.seq_len).to(self.device)
        mask[:, :, -1] = 0  # Mask token cuối cùng
        
        output, attention_weights = self.model.forward_with_weights(self.Q, self.K, self.V, mask=mask)
        
        # Attention weights tại cột cuối phải ≈ 0
        last_column_weights = attention_weights[:, :, -1]
        self.assertTrue((last_column_weights < 1e-8).all())
    
    def test_scaling_effect(self):
        """Test hiệu ứng của scaling factor sqrt(d_k)"""
        # So sánh với attention không scale
        Q_proj = self.model.W_q(self.Q)
        K_proj = self.model.W_k(self.K)
        V_proj = self.model.W_v(self.V)
        
        # Attention có scale
        scores_scaled = torch.matmul(Q_proj, K_proj.transpose(-2, -1)) / math.sqrt(self.d_model)
        weights_scaled = F.softmax(scores_scaled, dim=-1)
        
        # Attention không scale
        scores_unscaled = torch.matmul(Q_proj, K_proj.transpose(-2, -1))
        weights_unscaled = F.softmax(scores_unscaled, dim=-1)
        
        # Scaled attention weights nên ít concentrated hơn (entropy cao hơn)
        def entropy(weights):
            return -(weights * torch.log(weights + 1e-9)).sum(dim=-1)
        
        entropy_scaled = entropy(weights_scaled).mean()
        entropy_unscaled = entropy(weights_unscaled).mean()
        
        # Với d_model lớn, scaled attention thường có entropy cao hơn
        if self.d_model > 64:
            self.assertGreater(entropy_scaled, entropy_unscaled)
    
    def test_gradient_flow(self):
        """Test gradient flow qua attention mechanism"""
        self.Q.requires_grad_(True)
        self.K.requires_grad_(True)
        self.V.requires_grad_(True)
        
        output = self.model(self.Q, self.K, self.V)
        loss = output.sum()
        loss.backward()
        
        # Kiểm tra gradients không phải None và không phải zero
        self.assertIsNotNone(self.Q.grad)
        self.assertIsNotNone(self.K.grad)
        self.assertIsNotNone(self.V.grad)
        
        # Kiểm tra gradients có magnitude > 0
        self.assertGreater(self.Q.grad.abs().max().item(), 0)
        self.assertGreater(self.K.grad.abs().max().item(), 0)
        self.assertGreater(self.V.grad.abs().max().item(), 0)
        
        # Kiểm tra model parameters có gradients
        for param in self.model.parameters():
            self.assertIsNotNone(param.grad)
            self.assertGreater(param.grad.abs().max().item(), 0)
    
    def test_attention_pattern_symmetry(self):
        """Test tính chất đối xứng khi Q=K"""
        X = torch.randn(1, 5, self.d_model).to(self.device)
        V = torch.randn(1, 5, self.d_model).to(self.device)
        
        _, attention_weights = self.model.forward_with_weights(X, X, V)
        
        # Khi Q=K, attention matrix nên có một số tính chất đối xứng
        # (không hoàn toàn đối xứng vì có linear transformations khác nhau)
        attention_matrix = attention_weights.squeeze(0)  # (5, 5)
        
        # Ít nhất diagonal elements nên có giá trị cao (self-attention)
        diagonal_vals = torch.diag(attention_matrix)
        off_diagonal_vals = attention_matrix[~torch.eye(5, dtype=bool)]
        
        # Trung bình diagonal nên >= trung bình off-diagonal
        self.assertGreaterEqual(diagonal_vals.mean().item(), off_diagonal_vals.mean().item())
    
    def test_batch_consistency(self):
        """Test tính nhất quán giữa các batch"""
        # Tạo input giống nhau cho cả hai batch items
        X = torch.randn(1, self.seq_len, self.d_model).to(self.device)
        X_batched = X.repeat(2, 1, 1)  # (2, seq_len, d_model)
        
        output = self.model(X_batched, X_batched, X_batched)
        
        # Output của batch item 0 và 1 nên giống nhau
        torch.testing.assert_close(output[0], output[1], rtol=1e-5, atol=1e-6)
    
    def test_numerical_stability(self):
        """Test tính ổn định số học với giá trị extreme"""
        # Test với giá trị lớn
        Q_large = torch.randn(2, 5, self.d_model).to(self.device) * 10
        K_large = torch.randn(2, 5, self.d_model).to(self.device) * 10
        V_large = torch.randn(2, 5, self.d_model).to(self.device) * 10
        
        output = self.model(Q_large, K_large, V_large)
        
        # Kiểm tra không có NaN hoặc Inf
        self.assertFalse(torch.isnan(output).any())
        self.assertFalse(torch.isinf(output).any())
        
        # Test với giá trị nhỏ
        Q_small = torch.randn(2, 5, self.d_model).to(self.device) * 0.01
        K_small = torch.randn(2, 5, self.d_model).to(self.device) * 0.01
        V_small = torch.randn(2, 5, self.d_model).to(self.device) * 0.01
        
        output = self.model(Q_small, K_small, V_small)
        
        self.assertFalse(torch.isnan(output).any())
        self.assertFalse(torch.isinf(output).any())
    
    def test_deterministic_output(self):
        """Test tính deterministic"""
        torch.manual_seed(42)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(42)
        
        output1 = self.model(self.Q, self.K, self.V)
        
        torch.manual_seed(42)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(42)
        
        output2 = self.model(self.Q, self.K, self.V)
        
        torch.testing.assert_close(output1, output2)


def run_comprehensive_test():
    """Chạy tất cả tests với báo cáo chi tiết"""
    print("=" * 70)
    print("COMPREHENSIVE TEST FOR PURE SCALED DOT-PRODUCT ATTENTION")
    print("=" * 70)
    
    # Tạo test suite
    suite = unittest.TestLoader().loadTestsFromTestCase(TestScaledDotProductAttention)
    
    # Chạy tests với verbose output
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)
    
    print(f"\n{'='*70}")
    print("TEST SUMMARY")
    print(f"{'='*70}")
    print(f"Tests run: {result.testsRun}")
    print(f"Failures: {len(result.failures)}")
    print(f"Errors: {len(result.errors)}")
    print(f"Success rate: {((result.testsRun - len(result.failures) - len(result.errors))/result.testsRun)*100:.1f}%")
    
    return result.wasSuccessful()


def demo_attention_visualization():
    """Demo và visualization của attention mechanism"""
    print("\n" + "="*70)
    print("ATTENTION MECHANISM DEMONSTRATION")
    print("="*70)
    
    # Tạo model nhỏ để dễ quan sát
    model = ScaledDotProductAttention(d_model=64)
    
    # Tạo một sequence đơn giản
    batch_size, seq_len = 1, 5
    torch.manual_seed(42)
    
    X = torch.randn(batch_size, seq_len, 64)
    print(f"Input shape: {X.shape}")
    
    # Self-attention
    print(f"\n--- SELF-ATTENTION ---")
    output, attention_weights = model.forward_with_weights(X, X, X)
    print(f"Output shape: {output.shape}")
    print(f"Attention weights shape: {attention_weights.shape}")
    print(f"Attention matrix:")
    att_matrix = attention_weights.squeeze(0).detach().numpy()
    for i in range(seq_len):
        row_str = " ".join([f"{att_matrix[i, j]:.3f}" for j in range(seq_len)])
        print(f"  Position {i}: [{row_str}]")
    
    # Với causal mask
    print(f"\n--- CAUSAL MASKED ATTENTION ---")
    mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)
    output_masked, attention_weights_masked = model.forward_with_weights(X, X, X, mask=mask)
    print(f"Masked attention matrix:")
    att_matrix_masked = attention_weights_masked.squeeze(0).detach().numpy()
    for i in range(seq_len):
        row_str = " ".join([f"{att_matrix_masked[i, j]:.3f}" for j in range(seq_len)])
        print(f"  Position {i}: [{row_str}]")
    
    # Cross-attention example
    print(f"\n--- CROSS-ATTENTION ---")
    K_cross = torch.randn(batch_size, 3, 64)  # Shorter key/value sequence
    V_cross = torch.randn(batch_size, 3, 64)
    
    output_cross, attention_weights_cross = model.forward_with_weights(X, K_cross, V_cross)
    print(f"Cross-attention output shape: {output_cross.shape}")
    print(f"Cross-attention weights shape: {attention_weights_cross.shape}")
    
    print("✓ Demonstration completed successfully!")


if __name__ == "__main__":
    # Chạy comprehensive test
    success = run_comprehensive_test()
    
    # Chạy demonstration
    demo_attention_visualization()
    
    if success:
        print(f"\n🎉 All tests passed! Your Pure Scaled Dot-Product Attention is working correctly!")
    else:
        print(f"\n❌ Some tests failed. Please check the output above.")

test_attention_pattern_symmetry (__main__.TestScaledDotProductAttention)
Test tính chất đối xứng khi Q=K ... 

COMPREHENSIVE TEST FOR PURE SCALED DOT-PRODUCT ATTENTION


FAIL
test_attention_weights_properties (__main__.TestScaledDotProductAttention)
Test các tính chất của attention weights ... ok
test_batch_consistency (__main__.TestScaledDotProductAttention)
Test tính nhất quán giữa các batch ... ok
test_causal_mask (__main__.TestScaledDotProductAttention)
Test với causal mask (cho decoder self-attention) ... ok
test_deterministic_output (__main__.TestScaledDotProductAttention)
Test tính deterministic ... ok
test_forward_shape_cross_attention (__main__.TestScaledDotProductAttention)
Test shape với cross-attention (K,V có seq_len khác Q) ... ok
test_forward_shape_self_attention (__main__.TestScaledDotProductAttention)
Test shape với self-attention (Q=K=V) ... ok
test_gradient_flow (__main__.TestScaledDotProductAttention)
Test gradient flow qua attention mechanism ... ok
test_initialization (__main__.TestScaledDotProductAttention)
Test khởi tạo model ... ok
test_numerical_stability (__main__.TestScaledDotProductAttention)
Test tính ổn định số học với gi


TEST SUMMARY
Tests run: 12
Failures: 1
Errors: 0
Success rate: 91.7%

ATTENTION MECHANISM DEMONSTRATION
Input shape: torch.Size([1, 5, 64])

--- SELF-ATTENTION ---
Output shape: torch.Size([1, 5, 64])
Attention weights shape: torch.Size([1, 5, 5])
Attention matrix:
  Position 0: [0.256 0.132 0.186 0.240 0.186]
  Position 1: [0.171 0.176 0.243 0.149 0.261]
  Position 2: [0.233 0.178 0.196 0.182 0.211]
  Position 3: [0.173 0.213 0.175 0.282 0.156]
  Position 4: [0.237 0.163 0.094 0.170 0.336]

--- CAUSAL MASKED ATTENTION ---
Masked attention matrix:
  Position 0: [1.000 0.000 0.000 0.000 0.000]
  Position 1: [0.494 0.506 0.000 0.000 0.000]
  Position 2: [0.384 0.293 0.323 0.000 0.000]
  Position 3: [0.205 0.253 0.207 0.335 0.000]
  Position 4: [0.237 0.163 0.094 0.170 0.336]

--- CROSS-ATTENTION ---
Cross-attention output shape: torch.Size([1, 5, 64])
Cross-attention weights shape: torch.Size([1, 5, 3])
✓ Demonstration completed successfully!

❌ Some tests failed. Please check the outpu