In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FlashAttention2(nn.Module):
    def __init__(self, sequence_length, head_dimension, block_size):
        super(FlashAttention2, self).__init__()
        # Additional configurations can be added here
        self.block_size = block_size
        # Ensure that sequence_length is divisible by block_size for simplicity
        assert sequence_length % block_size == 0

    def forward(self, Q, K, V):
        # Partitioning of inputs
        Q_blocks, K_blocks, V_blocks = self.partition_inputs(Q, K, V)

        # Efficient computation of the attention mechanism
        outputs = []
        for i, Q_block in enumerate(Q_blocks):
            output_block = self.process_block(Q_block, K_blocks, V_blocks)
            outputs.append(output_block)

        # Concatenating the processed blocks
        output = torch.cat(outputs, dim=0)
        return output

    def partition_inputs(self, Q, K, V):
        # The actual partitioning scheme should be based on sequence length, head dimension, and block size
        Q_blocks = Q.chunk(chunks=Q.size(0) // self.block_size, dim=0)
        K_blocks = K.chunk(chunks=K.size(0) // self.block_size, dim=0)
        V_blocks = V.chunk(chunks=V.size(0) // self.block_size, dim=0)
        return Q_blocks, K_blocks, V_blocks

    def process_block(self, Q_block, K_blocks, V_blocks):
        # Process each block efficiently as per FLASH2's optimized method
        # This includes computing QK^T, applying online softmax, and multiplying with V
        output_blocks = []
        for K_block, V_block in zip(K_blocks, V_blocks):
            attention_scores = torch.matmul(Q_block, K_block.transpose(-2, -1))
            attention_scores = self.online_softmax(attention_scores)
            output_block = torch.matmul(attention_scores, V_block)
            output_blocks.append(output_block)

        # Summing up the results from each block
        output_block_sum = sum(output_blocks)
        return output_block_sum

    def online_softmax(self, scores):
        # Placeholder for the actual online softmax logic
        # In practice, this method should implement the online softmax technique as detailed in FLASH2
        return F.softmax(scores, dim=-1)

# Example usage
sequence_length = 1024  # Example sequence length
head_dimension = 64    # Example head dimension
block_size = 256       # Example block size

flash_attention = FlashAttention2(sequence_length, head_dimension, block_size)
Q, K, V = [torch.randn(sequence_length, head_dimension) for _ in range(3)]
output = flash_attention(Q, K, V)


In [2]:
import time
import torch

# Testing
def test_flash_attention():
    # Define test cases
    test_cases = [
        (1024, 64, 256),
        # Add more test cases here
    ]

    for seq_len, head_dim, blk_size in test_cases:
        model = FlashAttention2(seq_len, head_dim, blk_size)
        Q, K, V = [torch.randn(seq_len, head_dim) for _ in range(3)]

        # Forward pass
        output = model(Q, K, V)
        assert output.shape == (seq_len, head_dim)  # Check output shape

# Benchmarking
def benchmark_flash_attention():
    seq_len, head_dim, blk_size = 1024, 64, 256
    model = FlashAttention2(seq_len, head_dim, blk_size)
    Q, K, V = [torch.randn(seq_len, head_dim) for _ in range(3)]

    start_time = time.time()
    output = model(Q, K, V)
    end_time = time.time()
    print(f"Time taken for forward pass: {end_time - start_time} seconds")

test_flash_attention()
benchmark_flash_attention()


Time taken for forward pass: 0.008143901824951172 seconds


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time


class SparseFlash2Attention(nn.Module):
    def __init__(self, seq_len, head_dim, blk_size, sparsity_factor):
        super().__init__()
        self.flash_attention = FlashAttention2(seq_len, head_dim, blk_size)
        self.head_dim = head_dim 
        self.seq_len = seq_len
        self.sparsity_factor = sparsity_factor

    def generate_sparsity_mask(self):
        # Adjusting mask generation to match the output dimensions of FlashAttention2
        mask = torch.zeros(self.seq_len, self.seq_len)
        step = self.sparsity_factor
        for i in range(0, self.seq_len, step):
            mask[i:i + step, :] = 1  # Assuming the mask needs to be applied across all head dimensions
        return mask.bool()

    def forward(self, Q, K, V):
        output = self.flash_attention(Q, K, V)
        sparsity_mask = self.generate_sparsity_mask()

        # Apply the sparsity mask to each head dimension individually
        for i in range(self.head_dim):
            output[:, i] = output[:, i].masked_fill(sparsity_mask == 0, 0)

        return output

# Testing and Benchmarking
def test_sparse_flash2_attention():
    seq_len, head_dim, blk_size, sparsity_factor = 1024, 64, 256, 4
    model = SparseFlash2Attention(seq_len, head_dim, blk_size, sparsity_factor)
    Q, K, V = [torch.randn(seq_len, head_dim) for _ in range(3)]

    # Forward pass
    output = model(Q, K, V)
    assert output.shape == (seq_len, head_dim)  # Check output shape

def benchmark_sparse_flash2_attention():
    seq_len, head_dim, blk_size, sparsity_factor = 1024, 64, 256, 4
    model = SparseFlash2Attention(seq_len, head_dim, blk_size, sparsity_factor)
    Q, K, V = [torch.randn(seq_len, head_dim) for _ in range(3)]

    start_time = time.time()
    output = model(Q, K, V)
    end_time = time.time()
    print(f"Time taken for forward pass: {end_time - start_time} seconds")

test_sparse_flash2_attention()
benchmark_sparse_flash2_attention()

RuntimeError: expand(torch.FloatTensor{[1024, 1024]}, size=[1024]): the number of sizes provided (1) must be greater or equal to the number of dimensions in the tensor (2)

# v2

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

class FlashAttention2(nn.Module):
    def __init__(self, sequence_length, head_dimension, block_size):
        super(FlashAttention2, self).__init__()
        self.block_size = block_size
        # Ensure that sequence_length is divisible by block_size for simplicity
        assert sequence_length % block_size == 0

    def forward(self, Q, K, V):
        # Partitioning of inputs
        Q_blocks, K_blocks, V_blocks = self.partition_inputs(Q, K, V)

        # Efficient computation of the attention mechanism
        outputs = []
        for i, Q_block in enumerate(Q_blocks):
            output_block = self.process_block(Q_block, K_blocks, V_blocks)
            outputs.append(output_block)

        # Concatenating the processed blocks
        output = torch.cat(outputs, dim=0)
        return output

    def partition_inputs(self, Q, K, V):
        # The actual partitioning scheme should be based on sequence length, head dimension, and block size
        Q_blocks = Q.chunk(chunks=Q.size(0) // self.block_size, dim=0)
        K_blocks = K.chunk(chunks=K.size(0) // self.block_size, dim=0)
        V_blocks = V.chunk(chunks=V.size(0) // self.block_size, dim=0)
        return Q_blocks, K_blocks, V_blocks

    def process_block(self, Q_block, K_blocks, V_blocks):
        # Process each block efficiently as per FLASH2's optimized method
        # This includes computing QK^T, applying online softmax, and multiplying with V
        output_blocks = []
        for K_block, V_block in zip(K_blocks, V_blocks):
            attention_scores = torch.matmul(Q_block, K_block.transpose(-2, -1))
            attention_scores = self.online_softmax(attention_scores)
            output_block = torch.matmul(attention_scores, V_block)
            output_blocks.append(output_block)

        # Summing up the results from each block
        output_block_sum = sum(output_blocks)
        return output_block_sum

    def online_softmax(self, scores, chunk_size=128):
        # Apply softmax in chunks for large sequences
        softmaxed_scores = []
        for i in range(0, scores.size(0), chunk_size):
            chunk = scores[i:i + chunk_size, :]
            softmaxed_chunk = F.softmax(chunk, dim=1)
            softmaxed_scores.append(softmaxed_chunk)
        return torch.cat(softmaxed_scores, dim=0)


# Example usage of FlashAttention2
sequence_length = 1024  # Example sequence length
head_dimension = 64    # Example head dimension
block_size = 256       # Example block size

flash_attention = FlashAttention2(sequence_length, head_dimension, block_size)
Q, K, V = [torch.randn(sequence_length, head_dimension) for _ in range(3)]
output = flash_attention(Q, K, V)

class SparseFlash2Attention(nn.Module):
    def __init__(self, seq_len, head_dim, blk_size, sparsity_factor):
        super().__init__()
        self.flash_attention = FlashAttention2(seq_len, head_dim, blk_size)
        self.seq_len = seq_len
        self.head_dim = head_dim
        self.block_size = blk_size  # Storing block_size as an instance variable
        self.sparsity_factor = sparsity_factor

    def generate_sparsity_mask(self):
        mask = torch.zeros(self.seq_len, self.seq_len)
        step = self.sparsity_factor
        for i in range(0, self.seq_len, step):
            mask[i:i + step, :] = 1
        return mask.bool()

    def forward(self, Q, K, V):
        output = self.flash_attention(Q, K, V)
        sparsity_mask = self.generate_sparsity_mask()

        # Apply the sparsity mask to each block of the output
        output_blocks = output.chunk(chunks=output.size(0) // self.block_size, dim=0)
        masked_blocks = []
        for i, block in enumerate(output_blocks):
            # Select the relevant part of the sparsity mask for each block
            mask_chunk = sparsity_mask[i * self.block_size: (i + 1) * self.block_size]
            mask_chunk = mask_chunk[:, i * self.block_size: (i + 1) * self.block_size]  # Adjust mask chunk size
            mask_chunk = mask_chunk.unsqueeze(-1).expand(-1, -1, self.head_dim)
            masked_block = block.masked_fill(mask_chunk == 0, 0)
            masked_blocks.append(masked_block)

        output = torch.cat(masked_blocks, dim=0)

        return output


    
# Testing and Benchmarking for SparseFlash2Attention
def test_sparse_flash2_attention():
    seq_len, head_dim, blk_size, sparsity_factor = 1024, 64, 256, 4
    model = SparseFlash2Attention(seq_len, head_dim, blk_size, sparsity_factor)
    Q, K, V = [torch.randn(seq_len, head_dim) for _ in range(3)]

    output = model(Q, K, V)
    print("Output shape:", output.shape)  # Add this line to check the output shape
    assert output.shape == (seq_len, head_dim)

def benchmark_sparse_flash2_attention():
    seq_len, head_dim, blk_size, sparsity_factor = 1024, 64, 256, 4
    model = SparseFlash2Attention(seq_len, head_dim, blk_size, sparsity_factor)
    Q, K, V = [torch.randn(seq_len, head_dim) for _ in range(3)]

    start_time = time.time()
    output = model(Q, K, V)
    end_time = time.time()
    print(f"Time taken for forward pass: {end_time - start_time} seconds")

test_sparse_flash2_attention()
benchmark_sparse_flash2_attention()


Output shape: torch.Size([1024, 256, 64])


AssertionError: 

# v3

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

class FlashAttention2(nn.Module):
    def __init__(self, sequence_length, head_dimension, block_size):
        super(FlashAttention2, self).__init__()
        self.block_size = block_size
        # Ensure that sequence_length is divisible by block_size for simplicity
        assert sequence_length % block_size == 0

    def forward(self, Q, K, V):
        # Partitioning of inputs
        Q_blocks, K_blocks, V_blocks = self.partition_inputs(Q, K, V)

        # Efficient computation of the attention mechanism
        outputs = []
        for i, Q_block in enumerate(Q_blocks):
            output_block = self.process_block(Q_block, K_blocks, V_blocks)
            outputs.append(output_block)

        # Concatenating the processed blocks
        output = torch.cat(outputs, dim=0)
        return output

    def partition_inputs(self, Q, K, V):
        # The actual partitioning scheme should be based on sequence length, head dimension, and block size
        Q_blocks = Q.chunk(chunks=Q.size(0) // self.block_size, dim=0)
        K_blocks = K.chunk(chunks=K.size(0) // self.block_size, dim=0)
        V_blocks = V.chunk(chunks=V.size(0) // self.block_size, dim=0)
        return Q_blocks, K_blocks, V_blocks

    def process_block(self, Q_block, K_blocks, V_blocks):
        # Process each block efficiently as per FLASH2's optimized method
        # This includes computing QK^T, applying online softmax, and multiplying with V
        output_blocks = []
        for K_block, V_block in zip(K_blocks, V_blocks):
            attention_scores = torch.matmul(Q_block, K_block.transpose(-2, -1))
            attention_scores = self.online_softmax(attention_scores)
            output_block = torch.matmul(attention_scores, V_block)
            output_blocks.append(output_block)

        # Summing up the results from each block
        output_block_sum = sum(output_blocks)
        return output_block_sum

    def online_softmax(self, scores, chunk_size=128):
        # Apply softmax in chunks for large sequences
        softmaxed_scores = []
        for i in range(0, scores.size(0), chunk_size):
            chunk = scores[i:i + chunk_size, :]
            softmaxed_chunk = F.softmax(chunk, dim=1)
            softmaxed_scores.append(softmaxed_chunk)
        return torch.cat(softmaxed_scores, dim=0)


# Example usage of FlashAttention2
sequence_length = 1024  # Example sequence length
head_dimension = 64    # Example head dimension
block_size = 256       # Example block size

flash_attention = FlashAttention2(sequence_length, head_dimension, block_size)
Q, K, V = [torch.randn(sequence_length, head_dimension) for _ in range(3)]
output = flash_attention(Q, K, V)

class SparseFlash2Attention(nn.Module):
    def __init__(self, seq_len, head_dim, blk_size, sparsity_factor):
        super().__init__()
        self.flash_attention = FlashAttention2(seq_len, head_dim, blk_size)
        self.seq_len = seq_len
        self.head_dim = head_dim
        self.block_size = blk_size  # Storing block_size as an instance variable
        self.sparsity_factor = sparsity_factor

    def generate_sparsity_mask(self):
        mask = torch.zeros(self.seq_len, self.seq_len)
        step = self.sparsity_factor
        for i in range(0, self.seq_len, step):
            mask[i:i + step, :] = 1
        return mask.bool()

    def forward(self, Q, K, V):
        output = self.flash_attention(Q, K, V)  # output shape: [sequence_length, head_dimension]

        # Reshape output to be 3D for batch matrix multiplication
        output = output.unsqueeze(0)  # New shape: [1, sequence_length, head_dimension]

        sparsity_mask = self.generate_sparsity_mask()  # shape: [sequence_length, sequence_length]

        # Apply the sparsity mask to the output
        sparsity_mask = sparsity_mask.unsqueeze(0)  # New shape: [1, sequence_length, sequence_length]
        output = torch.bmm(sparsity_mask.float(), output.float())  # Perform batch matrix multiplication

        # Reshape the output back to 2D
        output = output.squeeze(0)  # New shape: [sequence_length, head_dimension]

        return output


    
# Testing and Benchmarking for SparseFlash2Attention
def test_sparse_flash2_attention():
    seq_len, head_dim, blk_size, sparsity_factor = 1024, 64, 256, 4
    model = SparseFlash2Attention(seq_len, head_dim, blk_size, sparsity_factor)
    Q, K, V = [torch.randn(seq_len, head_dim) for _ in range(3)]

    output = model(Q, K, V)
    print("Output shape:", output.shape)  # Add this line to check the output shape
    assert output.shape == (seq_len, head_dim)

def benchmark_sparse_flash2_attention():
    seq_len, head_dim, blk_size, sparsity_factor = 1024, 64, 256, 4
    model = SparseFlash2Attention(seq_len, head_dim, blk_size, sparsity_factor)
    Q, K, V = [torch.randn(seq_len, head_dim) for _ in range(3)]

    start_time = time.time()
    output = model(Q, K, V)
    end_time = time.time()
    print(f"Time taken for forward pass: {end_time - start_time} seconds")

test_sparse_flash2_attention()
benchmark_sparse_flash2_attention()


Output shape: torch.Size([1024, 64])
Time taken for forward pass: 0.010993480682373047 seconds
