In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FlashAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Despite the name, this computes standard (non-fused) attention: the full
    ``(seq_q, seq_k)`` score matrix is materialized for every head.

    Args:
        embed_dim: Size of the last dimension of the query/key/value inputs
            and of the output.
        num_heads: Number of attention heads; must divide ``num_features``.
        num_features: Total projected size of Q/K/V across all heads; each
            head receives ``num_features // num_heads`` features.
        memory_efficient: Stored on the module but not consulted by
            ``forward``. NOTE(review): presumably a placeholder for a
            memory-efficient attention kernel — confirm intended use.
    """

    def __init__(self, embed_dim, num_heads, num_features, memory_efficient=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_features = num_features
        self.memory_efficient = memory_efficient

        # Heads partition the *projected* features (the output of the Q/K/V
        # linears), so the per-head size must come from num_features.
        self.head_dim = num_features // num_heads
        assert self.head_dim * num_heads == num_features, (
            "num_features must be divisible by num_heads"
        )

        self.query_linear = nn.Linear(embed_dim, num_features)
        self.key_linear = nn.Linear(embed_dim, num_features)
        self.value_linear = nn.Linear(embed_dim, num_features)

        self.out_linear = nn.Linear(num_features, embed_dim)

    def forward(self, query, key, value, attention_mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape ``(batch, seq_q, embed_dim)``.
            key: Tensor of shape ``(batch, seq_k, embed_dim)``.
            value: Tensor of shape ``(batch, seq_k, embed_dim)``; must have
                the same sequence length as ``key``.
            attention_mask: Optional tensor broadcastable to
                ``(batch, num_heads, seq_q, seq_k)``. Positions where it
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(batch, seq_q, embed_dim)``.
        """
        batch_size, seq_length = query.size(0), query.size(1)

        # Project, then split into heads: (batch, num_heads, seq, head_dim).
        # key/value use their own sequence length so cross-attention works.
        q = self._split_heads(self.query_linear(query))
        k = self._split_heads(self.key_linear(key))
        v = self._split_heads(self.value_linear(value))

        # Scaled dot-product scores: (batch, num_heads, seq_q, seq_k).
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if attention_mask is not None:
            # NOTE: a query row whose keys are all masked becomes all -inf,
            # and softmax turns that row into NaN.
            scores = scores.masked_fill(attention_mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)

        # Merge heads: (batch, num_heads, seq_q, head_dim) -> (batch, seq_q, num_features).
        output = (
            torch.matmul(weights, v)
            .transpose(1, 2)
            .reshape(batch_size, seq_length, self.num_features)
        )
        return self.out_linear(output)

    def _split_heads(self, x):
        """Reshape ``(batch, seq, num_features)`` to ``(batch, num_heads, seq, head_dim)``."""
        return x.view(x.size(0), x.size(1), self.num_heads, self.head_dim).transpose(1, 2)
