In [None]:
!git clone https://github.com/kyegomez/longnet.git

# `git clone` names the checkout after the URL ("longnet"); paths are case-sensitive on Linux/Colab.
%cd longnet

# %pip (not !pip) installs into the environment of the running kernel.
%pip install -r requirements.txt

%cd test

!python attention.py


In [11]:

%pip install torch einops torchscale

# FlashMHA (flash_attn.flash_attention) only exists in FlashAttention 1.x; 2.x removed it,
# so build a pinned 1.x tag rather than the default branch. Skip the clone when the
# checkout already exists so re-running this cell does not nest clones inside each other.
![ -d flash-attention ] || git clone --branch v1.0.9 https://github.com/HazyResearch/flash-attention.git
%cd flash-attention
!python setup.py install
# Go back to the parent directory after the installation. This comment is on its own line
# because IPython passes the rest of the line (comment included) to the %cd magic.
%cd ..


import torch
import torch.nn as nn
import torch.nn.functional as F

from flash_attn.flash_attention import FlashMHA

# torchscale's top-level package does not re-export these; import from their modules.
from torchscale.component.xpos_relative_position import XPOS
from torchscale.component.relative_position_bias import RelativePositionBias

# Replace this with your correct GPU device
device = "cuda:0"
dtype = torch.float16

# TODO: add ALiBi, QK layer norm, one write head, multiway.
class DilatedAttention(nn.Module):
    """
    Dilated Attention Module (one segment size, one dilation rate).

    The input sequence is split into non-overlapping segments of ``segment_size``
    tokens. Inside each segment only every ``dilation_rate``-th token is kept, and
    multi-head self-attention (FlashMHA) runs independently on each sparsified
    segment. The attention outputs are scattered back to the positions they were
    taken from; positions skipped by the dilation are filled with zeros, so the
    output has the same shape as the input.

    Arguments:
        d_model: The dimension of the attention layers; must be divisible by num_heads.
        num_heads: The number of attention heads.
        dilation_rate: The dilation rate for dilated attention (>= 1).
        segment_size: The segment size for dilated attention (>= 1).
        dropout (optional): Dropout probability applied to the attention output. Default: 0.0
        casual (optional): If True, attention is causal within each segment. The
            (misspelled) name is kept for backward compatibility. Default: False
        use_xpos (optional): If True, xPos positional encoding is applied per head
            to the input before attention. Default: False
        use_rel_pos_bias (optional): Relative position bias has to be added to the
            attention scores, which FlashMHA does not accept; forward() raises
            NotImplementedError when this is True. Default: False

    Shapes:
        input:  (batch, seq_len, d_model) with seq_len a multiple of segment_size.
                FlashMHA is built on the notebook-level `device` with `dtype`
                (float16), so inputs should use the same device and dtype.
        output: (batch, seq_len, d_model)

    Example:
        attention = DilatedAttention(d_model=512, num_heads=8, dilation_rate=2, segment_size=64, use_xpos=True).to(device)
        output = attention(input_tensor)
    """
    def __init__(self, d_model, num_heads, dilation_rate, segment_size, dropout=0.0, casual=False, use_xpos=False, use_rel_pos_bias=False):
        super(DilatedAttention, self).__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
        if dilation_rate < 1 or segment_size < 1:
            raise ValueError("dilation_rate and segment_size must both be >= 1")

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.dilation_rate = dilation_rate
        self.segment_size = segment_size

        # Causality must be enforced on the attention scores, so it is delegated to
        # FlashMHA's own `causal` flag rather than masking the attention output.
        self.attention = FlashMHA(embed_dim=d_model, num_heads=num_heads, causal=casual, device=device, dtype=dtype)
        self.dropout = nn.Dropout(dropout)
        self.casual = casual

        self.use_xpos = use_xpos
        self.use_rel_pos_bias = use_rel_pos_bias

        if use_xpos:
            self.xpos = XPOS(head_dim=self.head_dim)
        if use_rel_pos_bias:
            # Still constructed so existing configs / state dicts load; see forward().
            self.relative_bias = RelativePositionBias(num_buckets=32, max_distance=128, n_heads=num_heads)

    def get_mask(self, i, j):
        """
        Boolean causal mask of shape (i, j); True marks keys a query must NOT attend to.

        Queries are aligned with the last i keys, so query q may see keys 0 .. j - i + q.
        forward() does not use this (FlashMHA applies causality itself); it is kept
        for callers.
        """
        # Offset j - i + 1 masks strictly-future keys; the previous +2 let every
        # query see one future token.
        return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)

    def _apply_xpos(self, x):
        """Apply xPos per head; XPOS operates on 3-D (N, seq_len, head_dim) tensors."""
        batch_size, seq_len, _ = x.shape
        per_head = (
            x.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .reshape(batch_size * self.num_heads, seq_len, self.head_dim)
        )
        # NOTE(review): xPos is normally applied to queries/keys, which FlashMHA
        # computes internally, so it is applied to the input here instead.
        # Cast back in case XPOS's float32 buffers promoted the dtype.
        per_head = self.xpos(per_head).to(x.dtype)
        return (
            per_head.reshape(batch_size, self.num_heads, seq_len, self.head_dim)
            .transpose(1, 2)
            .reshape(batch_size, seq_len, self.d_model)
        )

    def forward(self, x):
        """
        Apply dilated attention.

        Args:
            x: (batch, seq_len, d_model) tensor; seq_len must be a multiple of segment_size.

        Returns:
            (batch, seq_len, d_model) tensor. Positions dropped by the dilation are zero.

        Raises:
            ValueError: if seq_len is not a multiple of segment_size.
            NotImplementedError: if the module was built with use_rel_pos_bias=True.
        """
        batch_size, seq_len, _ = x.shape
        if seq_len % self.segment_size != 0:
            raise ValueError(
                f"seq_len ({seq_len}) must be a multiple of segment_size ({self.segment_size})"
            )
        if self.use_rel_pos_bias:
            raise NotImplementedError(
                "relative position bias must be added to the attention scores, "
                "which FlashMHA does not expose"
            )

        if self.use_xpos:
            x = self._apply_xpos(x)

        num_segments = seq_len // self.segment_size

        # Split into segments, then sparsify: keep every dilation_rate-th token per segment.
        segments = x.reshape(batch_size, num_segments, self.segment_size, self.d_model)
        sparse = segments[:, :, :: self.dilation_rate, :]
        kept = sparse.size(2)

        # FlashMHA takes a single (batch, seqlen, d_model) input and returns
        # (output, attn_weights). Folding segments into the batch dimension keeps
        # attention local to each segment.
        attn_input = sparse.reshape(batch_size * num_segments, kept, self.d_model)
        attn_output, _ = self.attention(attn_input)

        # apply dropout
        attn_output = self.dropout(attn_output)

        # Scatter the outputs back to the positions they were gathered from.
        output = attn_output.new_zeros(batch_size, num_segments, self.segment_size, self.d_model)
        output[:, :, :: self.dilation_rate, :] = attn_output.reshape(
            batch_size, num_segments, kept, self.d_model
        )
        return output.reshape(batch_size, seq_len, self.d_model)




Cloning into 'flash-attention'...
remote: Enumerating objects: 2743, done.[K
remote: Counting objects: 100% (1170/1170), done.[K
remote: Compressing objects: 100% (200/200), done.[K
remote: Total 2743 (delta 1016), reused 1005 (delta 968), pack-reused 1573[K
Receiving objects: 100% (2743/2743), 3.16 MiB | 14.71 MiB/s, done.
Resolving deltas: 100% (1817/1817), done.
/content/flash-attention/flash-attention/flash-attention/flash-attention
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'

 If your intention is to cross-compile, this is not an error.
By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),
Volta (compute capability 7.0), Turing (compute capability 7.5),
and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).
If you wish to cross-compile for a single specific architecture,
export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.



torch.__version__  = 2.0.1+cu118


Submodule 'csrc/flash_attn/cutlass'

ModuleNotFoundError: ignored

In [None]:
import time
import unittest
import torch

# from LongNet import DilatedAttention, MultiModalDilationAttention

@unittest.skipUnless(torch.cuda.is_available(), "FlashMHA requires a CUDA device")
class TestDilatedAttention(unittest.TestCase):
    """
    Unit tests for DilatedAttention.

    FlashMHA is built on the notebook-level `device` in `dtype` (float16), so every
    test moves its module to `device` and creates inputs with `dtype`.
    Run from the notebook with: unittest.main(argv=[""], exit=False)
    """

    D_MODEL = 512
    NUM_HEADS = 8
    DILATION_RATE = 2
    SEGMENT_SIZE = 64

    def _make_module(self, **kwargs):
        """Build a DilatedAttention with the shared test config on `device`."""
        return DilatedAttention(
            self.D_MODEL, self.NUM_HEADS, self.DILATION_RATE, self.SEGMENT_SIZE, **kwargs
        ).to(device)

    def _make_input(self, seq_len=128, batch_size=2, requires_grad=False):
        """Random (batch_size, seq_len, D_MODEL) input on `device` in `dtype`."""
        return torch.randn(
            batch_size, seq_len, self.D_MODEL,
            device=device, dtype=dtype, requires_grad=requires_grad,
        )

    def _time_forward(self, module, input_tensor):
        """
        Wall-clock seconds for one forward pass, measured after a warm-up call.

        CUDA kernels run asynchronously, so the device is synchronized before each
        clock read; otherwise only the kernel launches would be timed.
        """
        with torch.no_grad():
            module(input_tensor)  # warm-up: CUDA context / kernel setup
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            module(input_tensor)
            torch.cuda.synchronize()
        return time.perf_counter() - start_time

    def test_output_shape(self):
        output = self._make_module()(self._make_input())
        self.assertEqual(output.shape, (2, 128, 512))

    def test_xpos(self):
        output = self._make_module(use_xpos=True)(self._make_input())
        self.assertEqual(output.shape, (2, 128, 512))

    @unittest.skip("relative position bias needs an additive score bias, which FlashMHA does not accept")
    def test_relative_position_bias(self):
        output = self._make_module(use_rel_pos_bias=True)(self._make_input())
        self.assertEqual(output.shape, (2, 128, 512))

    def test_attention_consistency(self):
        output = self._make_module()(self._make_input())
        # Only positions kept by the dilation carry attention output (the rest are
        # zero-filled); SEGMENT_SIZE is a multiple of DILATION_RATE, so those are
        # exactly the positions at stride DILATION_RATE.
        attended = output[:, :: self.DILATION_RATE, :].float()
        self.assertTrue((attended.std(dim=-1) > 0).all())

    def test_speed(self):
        elapsed = self._time_forward(self._make_module(), self._make_input(seq_len=1024))
        self.assertLess(elapsed, 1)

    def test_gradient_flow(self):
        input_tensor = self._make_input(requires_grad=True)
        output = self._make_module()(input_tensor)
        # Reduce in float32 to avoid float16 overflow in the sum / norm.
        output.float().sum().backward()
        grad_norm = input_tensor.grad.float().norm().item()

        self.assertLess(grad_norm, 1e6)
        self.assertGreater(grad_norm, 1e-6)

    def test_scaling(self):
        # Segments are attended independently, so cost should grow ~linearly with seq_len.
        module = self._make_module()
        time_for_1024 = self._time_forward(module, self._make_input(seq_len=1024))
        time_for_2048 = self._time_forward(module, self._make_input(seq_len=2048))
        self.assertLessEqual(time_for_2048 / time_for_1024, 2)

    def test_reproducibility(self):
        outputs = []
        for _ in range(2):
            torch.manual_seed(0)
            input_tensor = self._make_input()
            module = self._make_module()
            with torch.no_grad():
                outputs.append(module(input_tensor))
        self.assertTrue(torch.allclose(outputs[0], outputs[1]))

    @unittest.skip("forward() returns only the output tensor; attention weights are not exposed by FlashMHA")
    def test_attention_distribution(self):
        input_tensor = self._make_input()
        dilated_attention = self._make_module()
        _, attn_weights = dilated_attention(input_tensor)

        self.assertTrue(torch.allclose(attn_weights.sum(dim=-1), torch.tensor(1.)))






In [None]:
import timeit
import torch

# Model config
d_model = 512
num_heads = 8
dilation_rate = 2
segment_size = 64

device = "cuda:0"
dtype = torch.float16

# Input data
batch_size = 32
seq_len = 1024  # must be a multiple of segment_size


# Create model and data
model = DilatedAttention(d_model, num_heads, dilation_rate, segment_size).to(device)
x = torch.randn((batch_size, seq_len, d_model), device=device, dtype=dtype)


# Test forward pass
with torch.no_grad():
    output = model(x)
print(f"Output shape: {output.shape}")  # expected (batch_size, seq_len, d_model)
assert output.shape == (batch_size, seq_len, d_model)


# Benchmark model (inference). Warm up first so one-time CUDA setup is not timed,
# and synchronize before reading the clock because CUDA kernels run asynchronously.
num_warmup = 10
num_runs = 1000
with torch.no_grad():
    for _ in range(num_warmup):
        model(x)
    torch.cuda.synchronize(device)

    start_time = timeit.default_timer()
    for _ in range(num_runs):
        model(x)
    torch.cuda.synchronize(device)
    elapsed_time = timeit.default_timer() - start_time

print(f"Average forward pass time: {elapsed_time / num_runs:.6f} seconds")