# In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time

class CompressedAttention(nn.Module):
    """
    Description:
      Compressed Attention branch of the NSA (Native Sparse Attention) method.

      Keys and values are cut into (possibly overlapping) blocks of
      ``block_size`` tokens whose starts are ``stride`` tokens apart. Each
      block is flattened and squeezed into a single ``head_dim`` vector by a
      shared MLP (the compression function phi from the paper). Every query
      then attends to these compressed block representations instead of to
      individual tokens.

      Notes:
        * Attention here is non-causal: every query attends to every block,
          including blocks that lie after it in the sequence.
        * Only full blocks are formed. Trailing tokens past the last full
          block (when ``(seq_len - block_size) % stride != 0``) do not
          contribute to any compressed key/value.

    Parameters:
        hidden_size (int): Hidden state size; must be divisible by ``num_heads``.
        block_size (int): Number of tokens per block (parameter l in the paper).
        stride (int): Step between consecutive block starts (parameter d in the paper).
        num_heads (int): Number of attention heads.
        dropout (float): Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``, or if
            ``num_heads``, ``block_size`` or ``stride`` is not positive.
    """

    def __init__(self, hidden_size, block_size=32, stride=16, num_heads=4, dropout=0.1):
        super().__init__()

        # Validate early. Otherwise a bad config only surfaces later as an
        # obscure .view() failure inside forward().
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        if block_size <= 0 or stride <= 0:
            raise ValueError(
                f"block_size ({block_size}) and stride ({stride}) must be positive"
            )

        self.hidden_size = hidden_size
        self.block_size = block_size
        self.stride = stride
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Query, key and value projections
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)

        # Output projection
        self.out_proj = nn.Linear(hidden_size, hidden_size)

        # Compression function phi: an MLP that maps a flattened block
        # (block_size * head_dim) to a single head_dim vector.
        # It is shared by all heads.
        self.block_compressor = nn.Sequential(
            nn.Linear(block_size * self.head_dim, 2 * self.head_dim),
            nn.GELU(),
            nn.Linear(2 * self.head_dim, self.head_dim)
        )

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def _get_blocks(self, x, block_size, stride):
        """
        Description:
          Split a sequence into (possibly overlapping) blocks of a given size and stride.

        Args:
            x: tensor of shape (batch_size, seq_len, features).
            block_size: number of tokens per block.
            stride: step between consecutive block starts.

        Returns:
            blocks: list of tensors, each of shape (batch_size, block_size, features).
            block_indices: list of (start, end) half-open token ranges, one per block.
                Both lists are empty when seq_len < block_size.
        """
        _, seq_len, _ = x.shape  # unpacking also enforces a 3-D input
        starts = range(0, seq_len - block_size + 1, stride)
        blocks = [x[:, start:start + block_size, :] for start in starts]
        block_indices = [(start, start + block_size) for start in starts]
        return blocks, block_indices

    def compress_blocks(self, blocks, head_dim):
        """
        Description:
          Compress each block of tokens into a single vector with the MLP phi.

        Args:
            blocks: non-empty sequence of tensors, each of shape
                (batch_size, block_size, head_dim).
            head_dim: attention head dimensionality (the MLP's output size).

        Returns:
            compressed_blocks: tensor of shape (batch_size, num_blocks, head_dim).

        Raises:
            ValueError: If ``blocks`` is empty (the sequence was shorter than one block).
        """
        if len(blocks) == 0:
            raise ValueError(
                "compress_blocks() needs at least one block; "
                "the sequence is shorter than block_size"
            )

        # (batch_size, num_blocks, block_size, head_dim)
        blocks_tensor = torch.stack(list(blocks), dim=1)
        batch_size, num_blocks = blocks_tensor.shape[:2]

        # Flatten every block row-major: (batch_size * num_blocks, block_size * head_dim)
        flat_blocks = blocks_tensor.reshape(batch_size * num_blocks, -1)

        # Apply phi: (batch_size * num_blocks, head_dim)
        compressed = self.block_compressor(flat_blocks)

        return compressed.reshape(batch_size, num_blocks, head_dim)

    def _split_heads(self, x, batch_size, seq_len):
        """Reshape (batch, seq_len, hidden) to (batch, num_heads, seq_len, head_dim)."""
        return x.view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def _compress_heads(self, x):
        """
        Compress all heads at once.

        Args:
            x: tensor of shape (batch_size, num_heads, seq_len, head_dim).

        Returns:
            Tensor of shape (batch_size, num_heads, num_blocks, head_dim). The
            values match calling ``compress_blocks(_get_blocks(...))`` head by
            head, because ``block_compressor`` is shared across heads.
        """
        # unfold -> (batch, heads, num_blocks, head_dim, block_size); it yields the
        # same windows as range(0, seq_len - block_size + 1, stride).
        windows = x.unfold(2, self.block_size, self.stride)
        # Put tokens before features, so flattening matches the row-major
        # (block_size, head_dim) layout that compress_blocks feeds to phi.
        windows = windows.transpose(-1, -2)
        flat = windows.reshape(*windows.shape[:3], self.block_size * self.head_dim)
        # nn.Linear acts on the last dim, so all leading dims are processed in one call.
        return self.block_compressor(flat)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """
        Description:
          Run compressed attention over the input sequence.

        Args:
            hidden_states: tensor of shape (batch_size, seq_len, hidden_size),
                with seq_len >= block_size.
            attention_mask: optional additive mask, broadcastable to
                (batch_size, num_heads, seq_len, num_blocks). Use large negative
                values or -inf to exclude a block.
            output_attentions: if True, also return the attention weights and block ranges.

        Returns:
            If output_attentions is False: output tensor of shape
            (batch_size, seq_len, hidden_size).
            If output_attentions is True: tuple (output, attention_probs,
            block_indices), where attention_probs has shape
            (batch_size, num_heads, seq_len, num_blocks) (after dropout) and
            block_indices is a list of (start, end) token ranges, one per block.

        Raises:
            ValueError: If seq_len < block_size, so that no block can be formed.
        """
        batch_size, seq_len, _ = hidden_states.shape
        if seq_len < self.block_size:
            raise ValueError(
                f"seq_len ({seq_len}) must be at least block_size ({self.block_size}) "
                "to form a compressed block"
            )

        # Step 1: project and split into heads -> (batch, heads, seq_len, head_dim)
        q = self._split_heads(self.q_proj(hidden_states), batch_size, seq_len)
        k = self._split_heads(self.k_proj(hidden_states), batch_size, seq_len)
        v = self._split_heads(self.v_proj(hidden_states), batch_size, seq_len)

        # Step 2: compress keys/values blockwise for all heads -> (batch, heads, num_blocks, head_dim)
        compressed_k = self._compress_heads(k)
        compressed_v = self._compress_heads(v)
        # Block ranges are the same for every head.
        block_indices = [
            (start, start + self.block_size)
            for start in range(0, seq_len - self.block_size + 1, self.stride)
        ]

        # Step 3: query-to-block scores -> (batch, heads, seq_len, num_blocks)
        attention_scores = torch.matmul(q, compressed_k.transpose(-1, -2)) * self.scale
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Step 4: weighted sum of compressed values -> (batch, heads, seq_len, head_dim)
        context_layer = torch.matmul(attention_probs, compressed_v)

        # Merge heads back -> (batch, seq_len, hidden_size)
        context_layer = context_layer.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.hidden_size)

        output = self.out_proj(context_layer)

        if output_attentions:
            return output, attention_probs, block_indices
        return output


def demonstrate_compressed_attention(use_long_sequence=False):
    """
    Description:
      –î–µ–º–æ–Ω—Å—Ç—Ä–∏—Ä—É–µ—Ç —Ä–∞–±–æ—Ç—É –º–µ—Ö–∞–Ω–∏–∑–º–∞ —Å–∂–∞—Ç–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è
      –ü–æ–∫–∞–∑—ã–≤–∞–µ—Ç —Å—Ä–∞–≤–Ω–µ–Ω–∏–µ —Å –æ–±—ã—á–Ω—ã–º –ø–æ–ª–Ω—ã–º –≤–Ω–∏–º–∞–Ω–∏–µ–º

    –ê—Ä–≥—É–º–µ–Ω—Ç—ã:
        use_long_sequence: –µ—Å–ª–∏ True, –∏—Å–ø–æ–ª—å–∑—É–µ—Ç –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç—å –¥–ª–∏–Ω–æ–π 32K —Ç–æ–∫–µ–Ω–æ–≤
    """
    # –ü–∞—Ä–∞–º–µ—Ç—Ä—ã –¥–ª—è –¥–µ–º–æ–Ω—Å—Ç—Ä–∞—Ü–∏–∏
    hidden_size = 64
    num_heads = 4
    batch_size = 1

    if use_long_sequence:
        # –ü–∞—Ä–∞–º–µ—Ç—Ä—ã –¥–ª—è –¥–ª–∏–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ (32K)
        seq_len = 32000
        block_size = 256  # –£–≤–µ–ª–∏—á–∏–≤–∞–µ–º —Ä–∞–∑–º–µ—Ä –±–ª–æ–∫–∞ –¥–ª—è –±–æ–ª–µ–µ —ç—Ñ—Ñ–µ–∫—Ç–∏–≤–Ω–æ–≥–æ —Å–∂–∞—Ç–∏—è
        stride = 128      # –£–≤–µ–ª–∏—á–∏–≤–∞–µ–º —à–∞–≥ –¥–ª—è –±–æ–ª–µ–µ —ç—Ñ—Ñ–µ–∫—Ç–∏–≤–Ω–æ–≥–æ —Å–∂–∞—Ç–∏—è
    else:
        # –ü–∞—Ä–∞–º–µ—Ç—Ä—ã –¥–ª—è –∫–æ—Ä–æ—Ç–∫–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ (128)
        seq_len = 128
        block_size = 32
        stride = 16

    # –°–æ–∑–¥–∞–µ–º –º–æ–¥–µ–ª–∏
    compressed_attention = CompressedAttention(
        hidden_size=hidden_size,
        block_size=block_size,
        stride=stride,
        num_heads=num_heads
    )

    # –°–æ–∑–¥–∞–µ–º –≤—Ö–æ–¥–Ω—ã–µ –¥–∞–Ω–Ω—ã–µ —Å –æ–ø—Ä–µ–¥–µ–ª–µ–Ω–Ω—ã–º–∏ –ø–∞—Ç—Ç–µ—Ä–Ω–∞–º–∏
    # –ú—ã —Å–æ–∑–¥–∞–¥–∏–º –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç—å, –≥–¥–µ –Ω–µ–∫–æ—Ç–æ—Ä—ã–µ —Ç–æ–∫–µ–Ω—ã –±—É–¥—É—Ç "–≤–∞–∂–Ω—ã–º–∏"
    torch.manual_seed(42)  # –î–ª—è –≤–æ—Å–ø—Ä–æ–∏–∑–≤–æ–¥–∏–º–æ—Å—Ç–∏

    print(f"üìå –°–æ–∑–¥–∞–Ω–∏–µ –≤—Ö–æ–¥–Ω—ã—Ö –¥–∞–Ω–Ω—ã—Ö –¥–ª–∏–Ω–æ–π {seq_len} —Ç–æ–∫–µ–Ω–æ–≤...")

    # –ë–∞–∑–æ–≤—ã–π –≤—Ö–æ–¥–Ω–æ–π —Ç–µ–Ω–∑–æ—Ä (—Å–æ–∑–¥–∞–µ–º —ç—Ñ—Ñ–µ–∫—Ç–∏–≤–Ω–æ, –±–µ–∑ —á—Ä–µ–∑–º–µ—Ä–Ω–æ–≥–æ –∏—Å–ø–æ–ª—å–∑–æ–≤–∞–Ω–∏—è –ø–∞–º—è—Ç–∏)
    x = torch.zeros(batch_size, seq_len, hidden_size)

    # –ó–∞–ø–æ–ª–Ω—è–µ–º —à—É–º–æ–º –±–æ–ª–µ–µ —ç—Ñ—Ñ–µ–∫—Ç–∏–≤–Ω–æ (–ø–æ —á–∞—Å—Ç—è–º)
    chunk_size = 1000 if use_long_sequence else seq_len
    for i in range(0, seq_len, chunk_size):
        end = min(i + chunk_size, seq_len)
        x[:, i:end, :] = torch.randn(batch_size, end-i, hidden_size) * 0.1

    # –î–æ–±–∞–≤–ª—è–µ–º "–≤–∞–∂–Ω—ã–µ" —Ç–æ–∫–µ–Ω—ã —á–µ—Ä–µ–∑ —Ä–∞–≤–Ω—ã–µ –ø—Ä–æ–º–µ–∂—É—Ç–∫–∏
    # –î–ª—è –¥–ª–∏–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ —É–≤–µ–ª–∏—á–∏–≤–∞–µ–º –∏–Ω—Ç–µ—Ä–≤–∞–ª
    important_interval = 1000 if use_long_sequence else 10
    important_positions = list(range(0, seq_len, important_interval))
    for pos in important_positions:
        if pos < seq_len:
            x[:, pos, :] = torch.ones(hidden_size)  # –í—ã–¥–µ–ª—è–µ–º –≤–∞–∂–Ω—ã–µ —Ç–æ–∫–µ–Ω—ã –∑–Ω–∞—á–µ–Ω–∏–µ–º 1

    # –î–æ–±–∞–≤–ª—è–µ–º –Ω–µ—Å–∫–æ–ª—å–∫–æ –∫–ª–∞—Å—Ç–µ—Ä–æ–≤ "–≤–∞–∂–Ω—ã—Ö" —Ç–æ–∫–µ–Ω–æ–≤
    if use_long_sequence:
        # –°–æ–∑–¥–∞–µ–º 3 –∫–ª–∞—Å—Ç–µ—Ä–∞ –≤ –Ω–∞—á–∞–ª–µ, —Å–µ—Ä–µ–¥–∏–Ω–µ –∏ –∫–æ–Ω—Ü–µ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏
        cluster_positions = [
            (1000, 1200),      # –ù–∞—á–∞–ª–æ
            (seq_len//2-100, seq_len//2+100),  # –°–µ—Ä–µ–¥–∏–Ω–∞
            (seq_len-1200, seq_len-1000)       # –ö–æ–Ω–µ—Ü
        ]
    else:
        # –î–ª—è –∫–æ—Ä–æ—Ç–∫–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ - –æ–¥–∏–Ω –∫–ª–∞—Å—Ç–µ—Ä –≤ —Å–µ—Ä–µ–¥–∏–Ω–µ
        middle_start = seq_len // 3
        cluster_positions = [(middle_start, middle_start + 20)]

    # –î–æ–±–∞–≤–ª—è–µ–º –∫–ª–∞—Å—Ç–µ—Ä—ã –≤–∞–∂–Ω—ã—Ö —Ç–æ–∫–µ–Ω–æ–≤
    for start, end in cluster_positions:
        for pos in range(start, end):
            if pos < seq_len:
                x[:, pos, :] = torch.ones(hidden_size) * 0.8  # –ö–ª–∞—Å—Ç–µ—Ä —Å —á—É—Ç—å –º–µ–Ω—å—à–µ–π –≤–∞–∂–Ω–æ—Å—Ç—å—é

    # –í—ã—á–∏—Å–ª—è–µ–º –≤–∞–∂–Ω–æ—Å—Ç—å —Ç–æ–ª—å–∫–æ –¥–ª—è —á–∞—Å—Ç–∏ —Ç–æ–∫–µ–Ω–æ–≤ –≤ –¥–ª–∏–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏
    token_importance = x.sum(dim=2).squeeze().cpu().numpy()

    print("\n" + "="*80)
    print("–î–ï–ú–û–ù–°–¢–†–ê–¶–ò–Ø –ú–ï–•–ê–ù–ò–ó–ú–ê –°–ñ–ê–¢–û–ì–û –í–ù–ò–ú–ê–ù–ò–Ø (COMPRESSED ATTENTION)")
    print("="*80 + "\n")

    print(f"üìå –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è –º–æ–¥–µ–ª–∏ CompressedAttention —Å –ø–∞—Ä–∞–º–µ—Ç—Ä–∞–º–∏:")
    print(f"  - –†–∞–∑–º–µ—Ä —Å–∫—Ä—ã—Ç–æ–≥–æ —Å–æ—Å—Ç–æ—è–Ω–∏—è (hidden_size): {hidden_size}")
    print(f"  - –†–∞–∑–º–µ—Ä –±–ª–æ–∫–∞ (block_size): {block_size}")
    print(f"  - –®–∞–≥ (stride): {stride}")
    print(f"  - –ö–æ–ª–∏—á–µ—Å—Ç–≤–æ –≥–æ–ª–æ–≤ –≤–Ω–∏–º–∞–Ω–∏—è (num_heads): {num_heads}\n")

    print(f"üìå –°–æ–∑–¥–∞–Ω–∏–µ –≤—Ö–æ–¥–Ω—ã—Ö –¥–∞–Ω–Ω—ã—Ö:")
    print(f"  - –†–∞–∑–º–µ—Ä –ø–∞–∫–µ—Ç–∞ (batch_size): {batch_size}")
    print(f"  - –î–ª–∏–Ω–∞ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ (seq_len): {seq_len}\n")

    print(f"üìå –ü–æ–¥–≥–æ—Ç–æ–≤–∫–∞ –¥–∞–Ω–Ω—ã—Ö:")
    print(f"  - –°–æ–∑–¥–∞–ª–∏ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç—å —Å –Ω–µ—Å–∫–æ–ª—å–∫–∏–º–∏ –ø–∞—Ç—Ç–µ—Ä–Ω–∞–º–∏:")
    if use_long_sequence:
        print(f"    1. –†–∞–≤–Ω–æ–º–µ—Ä–Ω–æ —Ä–∞—Å–ø—Ä–µ–¥–µ–ª–µ–Ω–Ω—ã–µ '–≤–∞–∂–Ω—ã–µ' —Ç–æ–∫–µ–Ω—ã –∫–∞–∂–¥—ã–µ 1000 –ø–æ–∑–∏—Ü–∏–π")
        print(f"    2. –ö–ª–∞—Å—Ç–µ—Ä—ã '–≤–∞–∂–Ω—ã—Ö' —Ç–æ–∫–µ–Ω–æ–≤ –≤ –Ω–∞—á–∞–ª–µ (1000-1200), —Å–µ—Ä–µ–¥–∏–Ω–µ –∏ –∫–æ–Ω—Ü–µ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏")
    else:
        print(f"    1. –†–∞–≤–Ω–æ–º–µ—Ä–Ω–æ —Ä–∞—Å–ø—Ä–µ–¥–µ–ª–µ–Ω–Ω—ã–µ '–≤–∞–∂–Ω—ã–µ' —Ç–æ–∫–µ–Ω—ã –∫–∞–∂–¥—ã–µ 10 –ø–æ–∑–∏—Ü–∏–π")
        print(f"    2. –ö–ª–∞—Å—Ç–µ—Ä '–≤–∞–∂–Ω—ã—Ö' —Ç–æ–∫–µ–Ω–æ–≤ –≤ —Å–µ—Ä–µ–¥–∏–Ω–µ (–ø–æ–∑–∏—Ü–∏–∏ {cluster_positions[0][0]}-{cluster_positions[0][1]})")
    print(f"    3. –°–ª—É—á–∞–π–Ω—ã–π —à—É–º –¥–ª—è –æ—Å—Ç–∞–ª—å–Ω—ã—Ö —Ç–æ–∫–µ–Ω–æ–≤\n")

    print(f"üìå –í–∞–∂–Ω–æ—Å—Ç—å —Ç–æ–∫–µ–Ω–æ–≤ (—Å—É–º–º–∞ –∑–Ω–∞—á–µ–Ω–∏–π –ø–æ —Å–∫—Ä—ã—Ç–æ–º—É –∏–∑–º–µ—Ä–µ–Ω–∏—é, –ø—Ä–∏–º–µ—Ä—ã):")

    # –î–ª—è –¥–ª–∏–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –ø–æ–∫–∞–∑—ã–≤–∞–µ–º —Ç–æ–ª—å–∫–æ –æ–±—Ä–∞–∑—Ü—ã
    if use_long_sequence:
        # –ü–æ–∫–∞–∑—ã–≤–∞–µ–º –Ω–∞—á–∞–ª–æ, —Å–µ—Ä–µ–¥–∏–Ω—É –∏ –∫–æ–Ω–µ—Ü –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏
        sample_ranges = [
            (0, 32),                          # –ù–∞—á–∞–ª–æ
            (seq_len//2-16, seq_len//2+16),   # –°–µ—Ä–µ–¥–∏–Ω–∞
            (seq_len-32, seq_len)             # –ö–æ–Ω–µ—Ü
        ]

        for start, end in sample_ranges:
            print(f"  –ü–æ–∑–∏—Ü–∏–∏ {start:5d}-{end-1:5d} (–ø—Ä–∏–º–µ—Ä):")
            for i in range(start, end, 16):
                end_i = min(i + 16, end)
                values = [f"{token_importance[j]:4.1f}" for j in range(i, end_i)]
                print(f"    {i:5d}-{end_i-1:5d}: {' '.join(values)}")
    else:
        # –î–ª—è –∫–æ—Ä–æ—Ç–∫–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –ø–æ–∫–∞–∑—ã–≤–∞–µ–º –≤—Å–µ —Ç–æ–∫–µ–Ω—ã
        for i in range(0, seq_len, 16):
            end = min(i + 16, seq_len)
            values = [f"{token_importance[j]:4.1f}" for j in range(i, end)]
            print(f"  –ü–æ–∑–∏—Ü–∏–∏ {i:3d}-{end-1:3d}: {' '.join(values)}")
    print()

    # –ü—Ä–æ–µ–∫—Ü–∏–∏ –∑–∞–ø—Ä–æ—Å–æ–≤, –∫–ª—é—á–µ–π –∏ –∑–Ω–∞—á–µ–Ω–∏–π
    q = compressed_attention.q_proj(x)
    k = compressed_attention.k_proj(x)
    v = compressed_attention.v_proj(x)

    print(f"üìå –®–∞–≥ 1: –ü—Ä–æ–µ–∫—Ü–∏–∏ –∑–∞–ø—Ä–æ—Å–æ–≤, –∫–ª—é—á–µ–π –∏ –∑–Ω–∞—á–µ–Ω–∏–π")
    print(f"  - –§–æ—Ä–º–∞ –∑–∞–ø—Ä–æ—Å–æ–≤ (q): {q.shape}")
    print(f"  - –§–æ—Ä–º–∞ –∫–ª—é—á–µ–π (k): {k.shape}")
    print(f"  - –§–æ—Ä–º–∞ –∑–Ω–∞—á–µ–Ω–∏–π (v): {v.shape}\n")

    print(f"üìå –®–∞–≥ 2: –°–∂–∞—Ç–∏–µ –∫–ª—é—á–µ–π –∏ –∑–Ω–∞—á–µ–Ω–∏–π (—Å–∞–º–∞—è –≤–∞–∂–Ω–∞—è —á–∞—Å—Ç—å)")
    print(f"  - –†–∞–∑–º–µ—Ä –±–ª–æ–∫–∞ (block_size): {block_size}")
    print(f"  - –®–∞–≥ (stride): {stride}")

    # –†–∞–∑–¥–µ–ª–µ–Ω–∏–µ k –Ω–∞ –≥–æ–ª–æ–≤—ã –≤–Ω–∏–º–∞–Ω–∏—è
    k_heads = k.view(batch_size, seq_len, num_heads, hidden_size // num_heads).permute(0, 2, 1, 3)
    v_heads = v.view(batch_size, seq_len, num_heads, hidden_size // num_heads).permute(0, 2, 1, 3)

    # –ü–æ–ª—É—á–∞–µ–º –±–ª–æ–∫–∏ –¥–ª—è –ø–µ—Ä–≤–æ–π –≥–æ–ª–æ–≤—ã
    head_k = k_heads[:, 0]  # (batch_size, seq_len, head_dim)
    blocks_k, block_indices = compressed_attention._get_blocks(head_k, block_size, stride)

    num_blocks = len(blocks_k)
    print(f"  - –ö–æ–ª–∏—á–µ—Å—Ç–≤–æ –±–ª–æ–∫–æ–≤ –ø–æ—Å–ª–µ —Ä–∞–∑–±–∏–µ–Ω–∏—è: {num_blocks}")
    print(f"  - –ò–Ω–¥–µ–∫—Å—ã –±–ª–æ–∫–æ–≤: {block_indices}\n")

    # –°–∂–∏–º–∞–µ–º –±–ª–æ–∫–∏
    head_dim = hidden_size // num_heads
    compressed_k = compressed_attention.compress_blocks(blocks_k, head_dim)

    print(f"  - –§–æ—Ä–º–∞ —Å–∂–∞—Ç—ã—Ö –∫–ª—é—á–µ–π: {compressed_k.shape}")
    print(f"  - –ö–æ—ç—Ñ—Ñ–∏—Ü–∏–µ–Ω—Ç —Å–∂–∞—Ç–∏—è: {seq_len} / {compressed_k.shape[1]} = {seq_len / compressed_k.shape[1]:.1f}x\n")

    # –ó–∞–ø—É—Å–∫–∞–µ–º –ø–æ–ª–Ω–æ–µ –≤—ã—á–∏—Å–ª–µ–Ω–∏–µ
    print(f"üìå –®–∞–≥ 3: –í—ã—á–∏—Å–ª–µ–Ω–∏–µ –≤–Ω–∏–º–∞–Ω–∏—è –Ω–∞ —Å–∂–∞—Ç—ã—Ö –ø—Ä–µ–¥—Å—Ç–∞–≤–ª–µ–Ω–∏—è—Ö")
    output, attention_probs, block_indices = compressed_attention(x, output_attentions=True)

    # –í–∏–∑—É–∞–ª–∏–∑–∏—Ä—É–µ–º –º–∞—Ç—Ä–∏—Ü—É –≤–Ω–∏–º–∞–Ω–∏—è –¥–ª—è –ø–µ—Ä–≤–æ–π –≥–æ–ª–æ–≤—ã
    print(f"  - –§–æ—Ä–º–∞ –≤—ã—Ö–æ–¥–∞: {output.shape}")
    print(f"  - –§–æ—Ä–º–∞ –º–∞—Ç—Ä–∏—Ü—ã –≤–Ω–∏–º–∞–Ω–∏—è: {attention_probs.shape}\n")

    # –°—Ä–∞–≤–Ω–∏–≤–∞–µ–º –≤—ã—á–∏—Å–ª–∏—Ç–µ–ª—å–Ω—É—é —Å–ª–æ–∂–Ω–æ—Å—Ç—å
    print(f"üìå –®–∞–≥ 4: –°—Ä–∞–≤–Ω–µ–Ω–∏–µ –≤—ã—á–∏—Å–ª–∏—Ç–µ–ª—å–Ω–æ–π —Å–ª–æ–∂–Ω–æ—Å—Ç–∏")

    # –°—Ç–∞–Ω–¥–∞—Ä—Ç–Ω–æ–µ –≤–Ω–∏–º–∞–Ω–∏–µ: O(seq_len^2)
    full_attention_complexity = seq_len * seq_len

    # –°–∂–∞—Ç–æ–µ –≤–Ω–∏–º–∞–Ω–∏–µ: O(seq_len * num_blocks)
    compressed_attention_complexity = seq_len * num_blocks

    print(f"  - –°—Ç–∞–Ω–¥–∞—Ä—Ç–Ω–æ–µ –≤–Ω–∏–º–∞–Ω–∏–µ (Full Attention): O(seq_len^2) = {full_attention_complexity}")
    print(f"  - –°–∂–∞—Ç–æ–µ –≤–Ω–∏–º–∞–Ω–∏–µ (Compressed Attention): O(seq_len * num_blocks) = {compressed_attention_complexity}")
    print(f"  - –°–æ–∫—Ä–∞—â–µ–Ω–∏–µ —Å–ª–æ–∂–Ω–æ—Å—Ç–∏: {full_attention_complexity / compressed_attention_complexity:.1f}x\n")

    # –ò–∑–º–µ—Ä—è–µ–º —Ä–µ–∞–ª—å–Ω–æ–µ –≤—Ä–µ–º—è –≤—ã–ø–æ–ª–Ω–µ–Ω–∏—è
    print(f"üìå –®–∞–≥ 5: –ò–∑–º–µ—Ä–µ–Ω–∏–µ –≤—Ä–µ–º–µ–Ω–∏ –≤—ã–ø–æ–ª–Ω–µ–Ω–∏—è")

    # –†–µ–∞–ª–∏–∑–∞—Ü–∏—è —Å—Ç–∞–Ω–¥–∞—Ä—Ç–Ω–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è –¥–ª—è —Å—Ä–∞–≤–Ω–µ–Ω–∏—è
    def standard_attention(q, k, v, scale):
        attention_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
        attention_probs = F.softmax(attention_scores, dim=-1)
        context_layer = torch.matmul(attention_probs, v)
        return context_layer, attention_probs

    # –î–µ–ª–∞–µ–º –∑–∞–ø—Ä–æ—Å—ã, –∫–ª—é—á–∏ –∏ –∑–Ω–∞—á–µ–Ω–∏—è –¥–ª—è –æ–¥–Ω–æ–π –≥–æ–ª–æ–≤—ã
    q_head = q.view(batch_size, seq_len, num_heads, head_dim)[:, :, 0, :]  # (batch_size, seq_len, head_dim)
    k_head = k.view(batch_size, seq_len, num_heads, head_dim)[:, :, 0, :]
    v_head = v.view(batch_size, seq_len, num_heads, head_dim)[:, :, 0, :]

    # –ò–∑–º–µ—Ä—è–µ–º –≤—Ä–µ–º—è –¥–ª—è —Å—Ç–∞–Ω–¥–∞—Ä—Ç–Ω–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è
    if use_long_sequence:
        print("  - –î–ª—è –ø–æ–ª–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ 32K —Ç–æ–∫–µ–Ω–æ–≤ —Å—Ç–∞–Ω–¥–∞—Ä—Ç–Ω–æ–µ –≤–Ω–∏–º–∞–Ω–∏–µ —Å–ª–∏—à–∫–æ–º –∑–∞—Ç—Ä–∞—Ç–Ω–æ")
        print("  - –û—Ü–µ–Ω–∏–≤–∞–µ–º –≤—Ä–µ–º—è –Ω–∞ –ø–æ–¥–º–Ω–æ–∂–µ—Å—Ç–≤–µ –¥–∞–Ω–Ω—ã—Ö (1000 —Ç–æ–∫–µ–Ω–æ–≤) –∏ —ç–∫—Å—Ç—Ä–∞–ø–æ–ª–∏—Ä—É–µ–º")

        # –ò—Å–ø–æ–ª—å–∑—É–µ–º —Ç–æ–ª—å–∫–æ —á–∞—Å—Ç—å –¥–∞–Ω–Ω—ã—Ö –¥–ª—è –æ—Ü–µ–Ω–∫–∏ –≤—Ä–µ–º–µ–Ω–∏
        sample_size = 1000
        q_sample = q_head[:, :sample_size, :]
        k_sample = k_head[:, :sample_size, :]
        v_sample = v_head[:, :sample_size, :]

        # –ò–∑–º–µ—Ä—è–µ–º –≤—Ä–µ–º—è –Ω–∞ –ø–æ–¥–º–Ω–æ–∂–µ—Å—Ç–≤–µ
        start_time = time.time()
        for _ in range(10):  # –ú–µ–Ω—å—à–µ –∏—Ç–µ—Ä–∞—Ü–∏–π –¥–ª—è –¥–ª–∏–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏
            _, _ = standard_attention(q_sample, k_sample, v_sample, compressed_attention.scale)
        sample_std_time = (time.time() - start_time) / 10

        # –≠–∫—Å—Ç—Ä–∞–ø–æ–ª–∏—Ä—É–µ–º –≤—Ä–µ–º—è –¥–ª—è –ø–æ–ª–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ (–∫–≤–∞–¥—Ä–∞—Ç–∏—á–Ω–∞—è –∑–∞–≤–∏—Å–∏–º–æ—Å—Ç—å)
        scaling_factor = (seq_len / sample_size) ** 2
        std_time = sample_std_time * scaling_factor
        print(f"  - –ò–∑–º–µ—Ä–µ–Ω–Ω–æ–µ –≤—Ä–µ–º—è –¥–ª—è {sample_size} —Ç–æ–∫–µ–Ω–æ–≤: {sample_std_time:.6f} —Å")
        print(f"  - –≠–∫—Å—Ç—Ä–∞–ø–æ–ª–∏—Ä–æ–≤–∞–Ω–Ω–æ–µ –≤—Ä–µ–º—è –¥–ª—è {seq_len} —Ç–æ–∫–µ–Ω–æ–≤: {std_time:.6f} —Å")
    else:
        # –î–ª—è –∫–æ—Ä–æ—Ç–∫–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –∏–∑–º–µ—Ä—è–µ–º –Ω–∞–ø—Ä—è–º—É—é
        start_time = time.time()
        for _ in range(100):  # –ü–æ–≤—Ç–æ—Ä—è–µ–º –Ω–µ—Å–∫–æ–ª—å–∫–æ —Ä–∞–∑ –¥–ª—è –±–æ–ª–µ–µ —Ç–æ—á–Ω–æ–≥–æ –∏–∑–º–µ—Ä–µ–Ω–∏—è
            _, _ = standard_attention(q_head, k_head, v_head, compressed_attention.scale)
        std_time = (time.time() - start_time) / 100

    # –ò–∑–º–µ—Ä—è–µ–º –≤—Ä–µ–º—è –¥–ª—è —Å–∂–∞—Ç–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è
    repeat_count = 10 if use_long_sequence else 100  # –ú–µ–Ω—å—à–µ –∏—Ç–µ—Ä–∞—Ü–∏–π –¥–ª—è –¥–ª–∏–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏
    start_time = time.time()
    for _ in range(repeat_count):
        _, _ = compressed_attention(x, output_attentions=True)[:2]
    compressed_time = (time.time() - start_time) / repeat_count

    print(f"  - –°—Ç–∞–Ω–¥–∞—Ä—Ç–Ω–æ–µ –≤–Ω–∏–º–∞–Ω–∏–µ: {std_time:.6f} —Å")
    print(f"  - –°–∂–∞—Ç–æ–µ –≤–Ω–∏–º–∞–Ω–∏–µ: {compressed_time:.6f} —Å")
    print(f"  - –£—Å–∫–æ—Ä–µ–Ω–∏–µ: {std_time / compressed_time:.2f}x\n")

    # –í–∏–∑—É–∞–ª–∏–∑–∞—Ü–∏—è
    print(f"üìå –®–∞–≥ 6: –í–∏–∑—É–∞–ª–∏–∑–∞—Ü–∏—è –º–∞—Ç—Ä–∏—Ü—ã –≤–Ω–∏–º–∞–Ω–∏—è")

    if use_long_sequence:
        print("  - –î–ª—è 32K —Ç–æ–∫–µ–Ω–æ–≤ –≤–∏–∑—É–∞–ª–∏–∑–∏—Ä—É–µ–º —Ç–æ–ª—å–∫–æ —Ñ—Ä–∞–≥–º–µ–Ω—Ç –º–∞—Ç—Ä–∏—Ü—ã –≤–Ω–∏–º–∞–Ω–∏—è")

        # –î–ª—è –¥–ª–∏–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –≤–∏–∑—É–∞–ª–∏–∑–∏—Ä—É–µ–º —Ç–æ–ª—å–∫–æ —á–∞—Å—Ç—å –º–∞—Ç—Ä–∏—Ü—ã
        # –í—ã–±–∏—Ä–∞–µ–º –∏–Ω—Ç–µ—Ä–µ—Å–Ω—ã–µ —Ñ—Ä–∞–≥–º–µ–Ω—Ç—ã: –Ω–∞—á–∞–ª–æ, —Å–µ—Ä–µ–¥–∏–Ω–∞, –∫–æ–Ω–µ—Ü
        sample_ranges = [
            (0, 500),                           # –ù–∞—á–∞–ª–æ
            (seq_len//2-250, seq_len//2+250),   # –°–µ—Ä–µ–¥–∏–Ω–∞
            (seq_len-500, seq_len)              # –ö–æ–Ω–µ—Ü
        ]

        fig, axes = plt.subplots(3, 1, figsize=(12, 18))

        for i, (start, end) in enumerate(sample_ranges):
            # –î–ª—è –≤–∏–∑—É–∞–ª–∏–∑–∞—Ü–∏–∏ –∏—Å–ø–æ–ª—å–∑—É–µ–º –ø–µ—Ä–≤—É—é –≥–æ–ª–æ–≤—É –≤–Ω–∏–º–∞–Ω–∏—è –∏ –ø–µ—Ä–≤—ã–π –ø—Ä–∏–º–µ—Ä –≤ –±–∞—Ç—á–µ
            attention_fragment = attention_probs[0, 0, start:end].cpu().detach().numpy()

            # –í–∏–∑—É–∞–ª–∏–∑–∏—Ä—É–µ–º —Ñ—Ä–∞–≥–º–µ–Ω—Ç –º–∞—Ç—Ä–∏—Ü—ã –≤–Ω–∏–º–∞–Ω–∏—è –∫–∞–∫ —Ç–µ–ø–ª–æ–≤—É—é –∫–∞—Ä—Ç—É
            im = axes[i].imshow(attention_fragment, cmap='viridis', aspect='auto')
            fig.colorbar(im, ax=axes[i], label='–í–µ—Å –≤–Ω–∏–º–∞–Ω–∏—è')

            # –ù–∞—Å—Ç—Ä–∞–∏–≤–∞–µ–º –æ—Å–∏
            axes[i].set_xlabel('–ò–Ω–¥–µ–∫—Å –±–ª–æ–∫–∞')
            axes[i].set_ylabel(f'–ò–Ω–¥–µ–∫—Å –∑–∞–ø—Ä–æ—Å–∞ ({start}-{end})')
            axes[i].set_title(f'–§—Ä–∞–≥–º–µ–Ω—Ç –º–∞—Ç—Ä–∏—Ü—ã –≤–Ω–∏–º–∞–Ω–∏—è ({start}-{end})')

            # –£—Å—Ç–∞–Ω–∞–≤–ª–∏–≤–∞–µ–º –º–µ—Ç–∫–∏ —Ç–∏–∫–æ–≤
            axes[i].set_xticks(np.arange(len(block_indices)))
            axes[i].set_xticklabels([f"{s}-{e}" for s, e in block_indices], rotation=45, fontsize=8)

            # –ü–æ–∫–∞–∑—ã–≤–∞–µ–º —Ç–æ–ª—å–∫–æ –Ω–µ–∫–æ—Ç–æ—Ä—ã–µ —Ç–æ–∫–µ–Ω—ã –Ω–∞ y-–æ—Å–∏ –¥–ª—è —è—Å–Ω–æ—Å—Ç–∏
            fragment_len = end - start
            y_step = max(1, fragment_len // 10)
            y_ticks = np.arange(0, fragment_len, y_step)
            axes[i].set_yticks(y_ticks)
            axes[i].set_yticklabels([str(start + j) for j in y_ticks], fontsize=8)

        # –°–æ—Ö—Ä–∞–Ω—è–µ–º –ø–æ–ª–Ω—É—é –º–∞—Ç—Ä–∏—Ü—É –≤–Ω–∏–º–∞–Ω–∏—è –¥–ª—è –∞–Ω–∞–ª–∏–∑–∞
        attention_head = attention_probs[0, 0].cpu().detach().numpy()

    else:
        # –î–ª—è –∫–æ—Ä–æ—Ç–∫–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –≤–∏–∑—É–∞–ª–∏–∑–∏—Ä—É–µ–º –≤—Å—é –º–∞—Ç—Ä–∏—Ü—É
        # –î–ª—è –≤–∏–∑—É–∞–ª–∏–∑–∞—Ü–∏–∏ –∏—Å–ø–æ–ª—å–∑—É–µ–º –ø–µ—Ä–≤—É—é –≥–æ–ª–æ–≤—É –≤–Ω–∏–º–∞–Ω–∏—è –∏ –ø–µ—Ä–≤—ã–π –ø—Ä–∏–º–µ—Ä –≤ –±–∞—Ç—á–µ
        attention_head = attention_probs[0, 0].cpu().detach().numpy()

        # –°–æ–∑–¥–∞–µ–º —Å–µ—Ç–∫—É –ø–æ–∑–∏—Ü–∏–π —Ç–æ–∫–µ–Ω–æ–≤
        token_positions = np.arange(seq_len)

        # –ü—Ä–µ–æ–±—Ä–∞–∑—É–µ–º –∏–Ω–¥–µ–∫—Å—ã –±–ª–æ–∫–æ–≤ –≤ —Å—Ä–µ–¥–Ω–∏–µ –ø–æ–∑–∏—Ü–∏–∏
        block_positions = [(start + end) // 2 for start, end in block_indices]

        fig, ax = plt.subplots(figsize=(12, 8))

        # –í–∏–∑—É–∞–ª–∏–∑–∏—Ä—É–µ–º –º–∞—Ç—Ä–∏—Ü—É –≤–Ω–∏–º–∞–Ω–∏—è –∫–∞–∫ —Ç–µ–ø–ª–æ–≤—É—é –∫–∞—Ä—Ç—É
        im = ax.imshow(attention_head, cmap='viridis', aspect='auto')
        fig.colorbar(im, ax=ax, label='–í–µ—Å –≤–Ω–∏–º–∞–Ω–∏—è')

        # –ù–∞—Å—Ç—Ä–∞–∏–≤–∞–µ–º –æ—Å–∏
        ax.set_xlabel('–ò–Ω–¥–µ–∫—Å –±–ª–æ–∫–∞')
        ax.set_ylabel('–ò–Ω–¥–µ–∫—Å –∑–∞–ø—Ä–æ—Å–∞ (—Ç–æ–∫–µ–Ω–∞)')
        ax.set_title('–ú–∞—Ç—Ä–∏—Ü–∞ –≤–Ω–∏–º–∞–Ω–∏—è –¥–ª—è —Å–∂–∞—Ç–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è (Compressed Attention)')

        # –£—Å—Ç–∞–Ω–∞–≤–ª–∏–≤–∞–µ–º –º–µ—Ç–∫–∏ —Ç–∏–∫–æ–≤
        ax.set_xticks(np.arange(len(block_indices)))
        ax.set_xticklabels([f"{start}-{end}" for start, end in block_indices], rotation=45)

        # –ü–æ–∫–∞–∑—ã–≤–∞–µ–º —Ç–æ–ª—å–∫–æ –Ω–µ–∫–æ—Ç–æ—Ä—ã–µ —Ç–æ–∫–µ–Ω—ã –Ω–∞ y-–æ—Å–∏ –¥–ª—è —è—Å–Ω–æ—Å—Ç–∏
        y_ticks = np.arange(0, seq_len, 16)
        ax.set_yticks(y_ticks)
        ax.set_yticklabels([str(i) for i in y_ticks])

    plt.tight_layout()
    print("  - –ú–∞—Ç—Ä–∏—Ü–∞ –≤–Ω–∏–º–∞–Ω–∏—è –≤–∏–∑—É–∞–ª–∏–∑–∏—Ä–æ–≤–∞–Ω–∞")

    # –í–∏–∑—É–∞–ª–∏–∑–∞—Ü–∏—è –≤–∞–∂–Ω—ã—Ö –±–ª–æ–∫–æ–≤
    print("\nüìå –®–∞–≥ 7: –ê–Ω–∞–ª–∏–∑ —Å–∂–∞—Ç—ã—Ö –±–ª–æ–∫–æ–≤")

    # –ù–∞—Ö–æ–¥–∏–º –Ω–∞–∏–±–æ–ª–µ–µ –≤–∞–∂–Ω—ã–µ –±–ª–æ–∫–∏ (–ø–æ —Å—É–º–º–µ –≤–µ—Å–æ–≤ –≤–Ω–∏–º–∞–Ω–∏—è)
    if use_long_sequence:
        print("  - –î–ª—è 32K —Ç–æ–∫–µ–Ω–æ–≤ –∞–Ω–∞–ª–∏–∑–∏—Ä—É–µ–º –∞–≥—Ä–µ–≥–∏—Ä–æ–≤–∞–Ω–Ω—ã–µ –¥–∞–Ω–Ω—ã–µ")

        # –î–ª—è –¥–ª–∏–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –≤—ã—á–∏—Å–ª—è–µ–º –≤–∞–∂–Ω–æ—Å—Ç—å –±–ª–æ–∫–æ–≤ –∫–∞–∫ —Å—Ä–µ–¥–Ω–µ–µ
        # –ø–æ –Ω–µ—Å–∫–æ–ª—å–∫–∏–º –∫–ª—é—á–µ–≤—ã–º —Ñ—Ä–∞–≥–º–µ–Ω—Ç–∞–º –¥–ª—è —ç–∫–æ–Ω–æ–º–∏–∏ –≤—ã—á–∏—Å–ª–µ–Ω–∏–π
        fragment_samples = [
            0,                  # –ù–∞—á–∞–ª–æ
            seq_len // 4,       # –ü–µ—Ä–≤–∞—è —á–µ—Ç–≤–µ—Ä—Ç—å
            seq_len // 2,       # –°–µ—Ä–µ–¥–∏–Ω–∞
            3 * seq_len // 4,   # –¢—Ä–µ—Ç—å—è —á–µ—Ç–≤–µ—Ä—Ç—å
            seq_len - 1         # –ö–æ–Ω–µ—Ü
        ]

        # –°–æ–±–∏—Ä–∞–µ–º –¥–∞–Ω–Ω—ã–µ –ø–æ —Ñ—Ä–∞–≥–º–µ–Ω—Ç–∞–º
        fragment_importances = []
        for pos in fragment_samples:
            fragment_row = attention_probs[0, 0, pos].cpu().detach().numpy()
            fragment_importances.append(fragment_row)

        # –£—Å—Ä–µ–¥–Ω—è–µ–º –¥–∞–Ω–Ω—ã–µ –ø–æ –≤—Å–µ–º —Ñ—Ä–∞–≥–º–µ–Ω—Ç–∞–º
        block_importance = np.mean(fragment_importances, axis=0)
    else:
        # –î–ª—è –∫–æ—Ä–æ—Ç–∫–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –∏—Å–ø–æ–ª—å–∑—É–µ–º –ø–æ–ª–Ω—É—é –∏–Ω—Ñ–æ—Ä–º–∞—Ü–∏—é
        block_importance = attention_head.sum(axis=0)

    top_blocks_idx = np.argsort(block_importance)[-3:][::-1]

    print(f"  - –ù–∞–∏–±–æ–ª–µ–µ –≤–∞–∂–Ω—ã–µ –±–ª–æ–∫–∏:")
    for i, idx in enumerate(top_blocks_idx):
        start, end = block_indices[idx]
        importance = block_importance[idx]
        print(f"    {i+1}. –ë–ª–æ–∫ {idx} (–ø–æ–∑–∏—Ü–∏–∏ {start}-{end}): –≤–∞–∂–Ω–æ—Å—Ç—å = {importance:.3f}")

    print("\nüìå –®–∞–≥ 8: –°—Ä–∞–≤–Ω–µ–Ω–∏–µ —Å –æ–±—ã—á–Ω—ã–º –≤–Ω–∏–º–∞–Ω–∏–µ–º –¥–ª—è –∑–∞–¥–∞–Ω–Ω–æ–≥–æ –∑–∞–ø—Ä–æ—Å–∞")

    # –í—ã–±–∏—Ä–∞–µ–º –æ–ø—Ä–µ–¥–µ–ª–µ–Ω–Ω—ã–π –∑–∞–ø—Ä–æ—Å –¥–ª—è –∞–Ω–∞–ª–∏–∑–∞
    if use_long_sequence:
        # –î–ª—è –¥–ª–∏–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –≤—ã–±–∏—Ä–∞–µ–º –∑–∞–ø—Ä–æ—Å –∏–∑ —Å–µ—Ä–µ–¥–∏–Ω—ã –∫–ª–∞—Å—Ç–µ—Ä–∞
        query_idx = seq_len // 2

        print(f"  - –ê–Ω–∞–ª–∏–∑ –≤–Ω–∏–º–∞–Ω–∏—è –¥–ª—è –∑–∞–ø—Ä–æ—Å–∞ –≤ –ø–æ–∑–∏—Ü–∏–∏ {query_idx}:")
        print(f"    * –í –æ–±—ã—á–Ω–æ–º –≤–Ω–∏–º–∞–Ω–∏–∏ —ç—Ç–æ—Ç –∑–∞–ø—Ä–æ—Å —Ä–∞—Å–ø—Ä–µ–¥–µ–ª—è–ª –±—ã —Å–≤–æ—ë –≤–Ω–∏–º–∞–Ω–∏–µ –Ω–∞ –≤—Å–µ {seq_len} —Ç–æ–∫–µ–Ω–æ–≤")
        print(f"    * –í —Å–∂–∞—Ç–æ–º –≤–Ω–∏–º–∞–Ω–∏–∏ –≤–Ω–∏–º–∞–Ω–∏–µ —Ä–∞—Å–ø—Ä–µ–¥–µ–ª—è–µ—Ç—Å—è —Ç–æ–ª—å–∫–æ –Ω–∞ {len(block_indices)} –±–ª–æ–∫–æ–≤")

        # –î–ª—è –æ–±—ã—á–Ω–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è —ç–∫—Å—Ç—Ä–∞–ø–æ–ª—è—Ü–∏—è
        print(f"\n    (–û–±—ã—á–Ω–æ–µ –≤–Ω–∏–º–∞–Ω–∏–µ –¥–ª—è 32K –Ω–µ –≤—ã—á–∏—Å–ª—è–µ—Ç—Å—è –∏–∑-–∑–∞ –≤—ã—Å–æ–∫–æ–π –≤—ã—á–∏—Å–ª–∏—Ç–µ–ª—å–Ω–æ–π —Å–ª–æ–∂–Ω–æ—Å—Ç–∏)")

        # –î–ª—è —Å–∂–∞—Ç–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è
        compressed_attention_probs = attention_probs[0, 0, query_idx].cpu().detach().numpy()

        # –ü–æ–∫–∞–∑—ã–≤–∞–µ–º —Ç–æ–ø –±–ª–æ–∫–æ–≤ –¥–ª—è —Å–∂–∞—Ç–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è
        top_k = 5
        top_compressed_idx = np.argsort(compressed_attention_probs)[-top_k:][::-1]
        print(f"\n    –¢–æ–ø-{top_k} –±–ª–æ–∫–æ–≤ –≤ —Å–∂–∞—Ç–æ–º –≤–Ω–∏–º–∞–Ω–∏–∏ –¥–ª—è –∑–∞–ø—Ä–æ—Å–∞ {query_idx}:")
        for i, idx in enumerate(top_compressed_idx):
            start, end = block_indices[idx]
            print(f"      {i+1}. –ë–ª–æ–∫ {idx} (–ø–æ–∑–∏—Ü–∏–∏ {start}-{end}): –≤–µ—Å = {compressed_attention_probs[idx]:.4f}")
    else:
        # –î–ª—è –∫–æ—Ä–æ—Ç–∫–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ - –∫–∞–∫ –≤ –æ—Ä–∏–≥–∏–Ω–∞–ª–µ
        # –í—ã–±–∏—Ä–∞–µ–º –æ–ø—Ä–µ–¥–µ–ª–µ–Ω–Ω—ã–π –∑–∞–ø—Ä–æ—Å –¥–ª—è –∞–Ω–∞–ª–∏–∑–∞ (–Ω–∞–ø—Ä–∏–º–µ—Ä, —Ç–æ–∫–µ–Ω –≤ –ø–æ–∑–∏—Ü–∏–∏ –≤–∞–∂–Ω–æ–≥–æ –∫–ª–∞—Å—Ç–µ—Ä–∞)
        query_idx = cluster_positions[0][0] + 5  # –ò–Ω–¥–µ–∫—Å –∑–∞–ø—Ä–æ—Å–∞ –≤ —Å–µ—Ä–µ–¥–∏–Ω–µ –≤–∞–∂–Ω–æ–≥–æ –∫–ª–∞—Å—Ç–µ—Ä–∞

        # –í—ã—á–∏—Å–ª—è–µ–º –æ–±—ã—á–Ω–æ–µ –≤–Ω–∏–º–∞–Ω–∏–µ –¥–ª—è —ç—Ç–æ–≥–æ –∑–∞–ø—Ä–æ—Å–∞
        q_token = q_head[:, query_idx:query_idx+1, :]  # (batch_size, 1, head_dim)
        attention_scores = torch.matmul(q_token, k_head.transpose(-1, -2)) * compressed_attention.scale  # (batch_size, 1, seq_len)
        full_attention_probs = F.softmax(attention_scores, dim=-1).squeeze().cpu().detach().numpy()

        # –í—ã—á–∏—Å–ª—è–µ–º —Å–∂–∞—Ç–æ–µ –≤–Ω–∏–º–∞–Ω–∏–µ –¥–ª—è —ç—Ç–æ–≥–æ –∑–∞–ø—Ä–æ—Å–∞
        compressed_attention_probs = attention_head[query_idx]

        print(f"  - –ê–Ω–∞–ª–∏–∑ –≤–Ω–∏–º–∞–Ω–∏—è –¥–ª—è –∑–∞–ø—Ä–æ—Å–∞ –≤ –ø–æ–∑–∏—Ü–∏–∏ {query_idx}:")
        print(f"    * –í –æ–±—ã—á–Ω–æ–º –≤–Ω–∏–º–∞–Ω–∏–∏ —ç—Ç–æ—Ç –∑–∞–ø—Ä–æ—Å —Ä–∞—Å–ø—Ä–µ–¥–µ–ª—è–µ—Ç —Å–≤–æ—ë –≤–Ω–∏–º–∞–Ω–∏–µ –Ω–∞ –≤—Å–µ {seq_len} —Ç–æ–∫–µ–Ω–æ–≤")
        print(f"    * –í —Å–∂–∞—Ç–æ–º –≤–Ω–∏–º–∞–Ω–∏–∏ –≤–Ω–∏–º–∞–Ω–∏–µ —Ä–∞—Å–ø—Ä–µ–¥–µ–ª—è–µ—Ç—Å—è —Ç–æ–ª—å–∫–æ –Ω–∞ {len(block_indices)} –±–ª–æ–∫–æ–≤")

        # –°—Ä–∞–≤–Ω–∏–≤–∞–µ–º —Ä–∞—Å–ø—Ä–µ–¥–µ–ª–µ–Ω–∏–µ –≤–Ω–∏–º–∞–Ω–∏—è
        top_k = 5  # –ü–æ–∫–∞–∑—ã–≤–∞–µ–º —Ç–æ–ø-5 –Ω–∞–∏–±–æ–ª–µ–µ –≤–∞–∂–Ω—ã—Ö —Ç–æ–∫–µ–Ω–æ–≤/–±–ª–æ–∫–æ–≤

        # –î–ª—è –æ–±—ã—á–Ω–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è
        top_tokens_idx = np.argsort(full_attention_probs)[-top_k:][::-1]
        print(f"\n    –¢–æ–ø-{top_k} —Ç–æ–∫–µ–Ω–æ–≤ –≤ –æ–±—ã—á–Ω–æ–º –≤–Ω–∏–º–∞–Ω–∏–∏:")
        for i, idx in enumerate(top_tokens_idx):
            print(f"      {i+1}. –¢–æ–∫–µ–Ω {idx}: –≤–µ—Å = {full_attention_probs[idx]:.4f}")

        # –î–ª—è —Å–∂–∞—Ç–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è
        top_compressed_idx = np.argsort(compressed_attention_probs)[-top_k:][::-1]
        print(f"\n    –¢–æ–ø-{top_k} –±–ª–æ–∫–æ–≤ –≤ —Å–∂–∞—Ç–æ–º –≤–Ω–∏–º–∞–Ω–∏–∏:")
        for i, idx in enumerate(top_compressed_idx):
            start, end = block_indices[idx]
            print(f"      {i+1}. –ë–ª–æ–∫ {idx} (–ø–æ–∑–∏—Ü–∏–∏ {start}-{end}): –≤–µ—Å = {compressed_attention_probs[idx]:.4f}")

    print("\nüìå –ó–∞–∫–ª—é—á–µ–Ω–∏–µ")
    if use_long_sequence:
        print("  - –ú–µ—Ö–∞–Ω–∏–∑–º —Å–∂–∞—Ç–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è —É—Å–ø–µ—à–Ω–æ —Ä–∞–±–æ—Ç–∞–µ—Ç —Å –¥–ª–∏–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç—å—é (32K —Ç–æ–∫–µ–Ω–æ–≤)")
        print(f"  - –ö–æ–ª–∏—á–µ—Å—Ç–≤–æ –±–ª–æ–∫–æ–≤: {len(block_indices)} –≤–º–µ—Å—Ç–æ {seq_len} —Ç–æ–∫–µ–Ω–æ–≤")
        print(f"  - –ö–æ—ç—Ñ—Ñ–∏—Ü–∏–µ–Ω—Ç —Å–∂–∞—Ç–∏—è: {seq_len / len(block_indices):.1f}x")
        print(f"  - –¢–µ–æ—Ä–µ—Ç–∏—á–µ—Å–∫–æ–µ —Å–æ–∫—Ä–∞—â–µ–Ω–∏–µ –≤—ã—á–∏—Å–ª–µ–Ω–∏–π: {(seq_len**2) / (seq_len * len(block_indices)):.1f}x")
        print(f"  - –î–ª—è —Å—Ç–∞–Ω–¥–∞—Ä—Ç–Ω–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è –ø–æ—Ç—Ä–µ–±–æ–≤–∞–ª–æ—Å—å –±—ã –æ–∫–æ–ª–æ {std_time:.4f} —Å–µ–∫—É–Ω–¥ (—ç–∫—Å—Ç—Ä–∞–ø–æ–ª—è—Ü–∏—è)")
        print(f"  - –î–ª—è —Å–∂–∞—Ç–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è –ø–æ—Ç—Ä–µ–±–æ–≤–∞–ª–æ—Å—å {compressed_time:.4f} —Å–µ–∫—É–Ω–¥")
        print(f"  - –¢–µ–æ—Ä–µ—Ç–∏—á–µ—Å–∫–æ–µ —É—Å–∫–æ—Ä–µ–Ω–∏–µ: {std_time / compressed_time:.2f}x")
        print("  - –°–∂–∞—Ç–æ–µ –≤–Ω–∏–º–∞–Ω–∏–µ –æ—Å–æ–±–µ–Ω–Ω–æ —ç—Ñ—Ñ–µ–∫—Ç–∏–≤–Ω–æ –¥–ª—è –¥–ª–∏–Ω–Ω—ã—Ö –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–µ–π")
        print("  - –ü—Ä–∏ –æ–±—Ä–∞–±–æ—Ç–∫–µ –¥–ª–∏–Ω–Ω—ã—Ö –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–µ–π –ø—Ä–µ–∏–º—É—â–µ—Å—Ç–≤–æ –≤ —Å–∫–æ—Ä–æ—Å—Ç–∏ –º–Ω–æ–≥–æ–∫—Ä–∞—Ç–Ω–æ –ø–µ—Ä–µ–≤–µ—à–∏–≤–∞–µ—Ç")
        print("    –Ω–∞–∫–ª–∞–¥–Ω—ã–µ —Ä–∞—Å—Ö–æ–¥—ã –Ω–∞ —Å–∂–∞—Ç–∏–µ –±–ª–æ–∫–æ–≤")
    else:
        print("  - –ú–µ—Ö–∞–Ω–∏–∑–º —Å–∂–∞—Ç–æ–≥–æ –≤–Ω–∏–º–∞–Ω–∏—è —É—Å–ø–µ—à–Ω–æ —Å–Ω–∏–∂–∞–µ—Ç –≤—ã—á–∏—Å–ª–∏—Ç–µ–ª—å–Ω—É—é —Å–ª–æ–∂–Ω–æ—Å—Ç—å")
        print(f"  - –°–æ–∫—Ä–∞—â–µ–Ω–∏–µ –≤—ã—á–∏—Å–ª–µ–Ω–∏–π: {full_attention_complexity / compressed_attention_complexity:.1f}x")
        print(f"  - –£—Å–∫–æ—Ä–µ–Ω–∏–µ: {std_time / compressed_time:.2f}x")
        print("  - –ü—Ä–∏ —ç—Ç–æ–º —Å–æ—Ö—Ä–∞–Ω—è–µ—Ç—Å—è —Å–ø–æ—Å–æ–±–Ω–æ—Å—Ç—å –º–æ–¥–µ–ª–∏ —Ñ–æ–∫—É—Å–∏—Ä–æ–≤–∞—Ç—å—Å—è –Ω–∞ –≤–∞–∂–Ω—ã—Ö —á–∞—Å—Ç—è—Ö –∫–æ–Ω—Ç–µ–∫—Å—Ç–∞")
        print("  - –°–∂–∞—Ç–æ–µ –≤–Ω–∏–º–∞–Ω–∏–µ –æ—Å–æ–±–µ–Ω–Ω–æ —ç—Ñ—Ñ–µ–∫—Ç–∏–≤–Ω–æ –¥–ª—è –¥–ª–∏–Ω–Ω—ã—Ö –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–µ–π")

    return fig

if __name__ == "__main__":
    # –ü–æ —É–º–æ–ª—á–∞–Ω–∏—é –∑–∞–ø—É—Å–∫–∞–µ–º –¥–µ–º–æ–Ω—Å—Ç—Ä–∞—Ü–∏—é –Ω–∞ –∫–æ—Ä–æ—Ç–∫–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏
    print("\n=== –î–ï–ú–û–ù–°–¢–†–ê–¶–ò–Ø –ù–ê –ö–û–†–û–¢–ö–û–ô –ü–û–°–õ–ï–î–û–í–ê–¢–ï–õ–¨–ù–û–°–¢–ò (128 —Ç–æ–∫–µ–Ω–æ–≤) ===")
    fig_short = demonstrate_compressed_attention(use_long_sequence=False)

    # –°–æ—Ö—Ä–∞–Ω—è–µ–º –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–µ
    plt.savefig('compressed_attention_short.png')
    plt.close(fig_short)
    print("\n–ò–∑–æ–±—Ä–∞–∂–µ–Ω–∏–µ —Å–æ—Ö—Ä–∞–Ω–µ–Ω–æ –≤ —Ñ–∞–π–ª 'compressed_attention_short.png'")

    # –°–ø—Ä–∞—à–∏–≤–∞–µ–º –ø–æ–ª—å–∑–æ–≤–∞—Ç–µ–ª—è, —Ö–æ—á–µ—Ç –ª–∏ –æ–Ω –∑–∞–ø—É—Å—Ç–∏—Ç—å —Ç–µ—Å—Ç –Ω–∞ –¥–ª–∏–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏
    run_long_test = input("\n–•–æ—Ç–∏—Ç–µ –∑–∞–ø—É—Å—Ç–∏—Ç—å —Ç–µ—Å—Ç –Ω–∞ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –¥–ª–∏–Ω–æ–π 32K —Ç–æ–∫–µ–Ω–æ–≤? (y/n): ")

    if run_long_test.lower() == 'y':
        print("\n=== –î–ï–ú–û–ù–°–¢–†–ê–¶–ò–Ø –ù–ê –î–õ–ò–ù–ù–û–ô –ü–û–°–õ–ï–î–û–í–ê–¢–ï–õ–¨–ù–û–°–¢–ò (32K —Ç–æ–∫–µ–Ω–æ–≤) ===")
        print("–û–±—Ä–∞—Ç–∏—Ç–µ –≤–Ω–∏–º–∞–Ω–∏–µ: —ç—Ç–æ—Ç —Ç–µ—Å—Ç –º–æ–∂–µ—Ç –∑–∞–Ω—è—Ç—å –∑–Ω–∞—á–∏—Ç–µ–ª—å–Ω–æ–µ –≤—Ä–µ–º—è –∏ –ø–æ—Ç—Ä–µ–±–æ–≤–∞—Ç—å –º–Ω–æ–≥–æ –ø–∞–º—è—Ç–∏")

        try:
            # –ü—Ä–æ–≤–µ—Ä—è–µ–º –¥–æ—Å—Ç—É–ø–Ω–æ—Å—Ç—å GPU –∏ —Å–≤–æ–±–æ–¥–Ω—É—é –ø–∞–º—è—Ç—å
            if torch.cuda.is_available():
                free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)
                print(f"–î–æ—Å—Ç—É–ø–Ω–∞—è –ø–∞–º—è—Ç—å GPU: {free_memory / 1024**3:.2f} –ì–ë")

                if free_memory < 4 * 1024**3:  # –ú–µ–Ω—å—à–µ 4 –ì–ë —Å–≤–æ–±–æ–¥–Ω–æ–π –ø–∞–º—è—Ç–∏
                    print("–ü—Ä–µ–¥—É–ø—Ä–µ–∂–¥–µ–Ω–∏–µ: –º–∞–ª–æ —Å–≤–æ–±–æ–¥–Ω–æ–π –ø–∞–º—è—Ç–∏ –Ω–∞ GPU, –≤–æ–∑–º–æ–∂–Ω—ã –æ—à–∏–±–∫–∏ OOM")

            # –ó–∞–ø—É—Å–∫–∞–µ–º —Ç–µ—Å—Ç –Ω–∞ –¥–ª–∏–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏
            fig_long = demonstrate_compressed_attention(use_long_sequence=True)

            # –°–æ—Ö—Ä–∞–Ω—è–µ–º –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–µ
            plt.savefig('compressed_attention_long.png')
            plt.close(fig_long)
            print("\n–ò–∑–æ–±—Ä–∞–∂–µ–Ω–∏–µ —Å–æ—Ö—Ä–∞–Ω–µ–Ω–æ –≤ —Ñ–∞–π–ª 'compressed_attention_long.png'")

        except RuntimeError as e:
            print(f"\n–ü—Ä–æ–∏–∑–æ—à–ª–∞ –æ—à–∏–±–∫–∞: {e}")
            print("–í–æ–∑–º–æ–∂–Ω–æ, –Ω–µ —Ö–≤–∞—Ç–∞–µ—Ç –ø–∞–º—è—Ç–∏ –¥–ª—è –æ–±—Ä–∞–±–æ—Ç–∫–∏ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –¥–ª–∏–Ω–æ–π 32K —Ç–æ–∫–µ–Ω–æ–≤.")
            print("–ü–æ–ø—Ä–æ–±—É–π—Ç–µ –∑–∞–ø—É—Å—Ç–∏—Ç—å —Ç–µ—Å—Ç –Ω–∞ –∫–æ–º–ø—å—é—Ç–µ—Ä–µ —Å –±–æ–ª—å—à–∏–º –æ–±—ä–µ–º–æ–º –ø–∞–º—è—Ç–∏ –∏–ª–∏ —É–º–µ–Ω—å—à–∏—Ç—å –¥–ª–∏–Ω—É –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏.")
    else:
        print("\n–¢–µ—Å—Ç –Ω–∞ –¥–ª–∏–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –ø—Ä–æ–ø—É—â–µ–Ω")


=== ДЕМОНСТРАЦИЯ НА КОРОТКОЙ ПОСЛЕДОВАТЕЛЬНОСТИ (128 токенов) ===
📌 Создание входных данных длиной 128 токенов...

ДЕМОНСТРАЦИЯ МЕХАНИЗМА СЖАТОГО ВНИМАНИЯ (COMPRESSED ATTENTION)

📌 Инициализация модели CompressedAttention с параметрами:
  - Размер скрытого состояния (hidden_size): 64
  - Размер блока (block_size): 32
  - Шаг (stride): 16
  - Количество голов внимания (num_heads): 4

📌 Создание входных данных:
  - Размер пакета (batch_size): 1
  - Длина последовательности (seq_len): 128

📌 Подготовка данных:
  - Создали последовательность с несколькими паттернами:
    1. –†–∞–≤–Ω–æ–º–µ—Ä–Ω–æ —Ä–∞—Å–ø—Ä–µ–¥–µ–ª–µ–Ω–Ω—ã–µ '–≤–∞–∂–Ω—ã–µ' —Ç–æ–∫–µ–Ω—ã –∫–∞–