In [7]:
import os 
# set workdir to parent directory
os.chdir("../")

In [8]:
import torch
from model import CausalSelfAttention, GPTConfig

import torch.nn as nn
import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
from einops import rearrange
# import xformers.ops as xops
from einops import rearrange
from torch import Tensor, nn

In [None]:
class DilatedAttention(nn.Module):
    """Implement dilated, scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        segment_lengths,
        dilation_rates,
        softmax_scale=None,
        attention_dropout=0.0,
        op = None
    ):
        super().__init__()
        if len(segment_lengths) != len(dilation_rates):
            raise ValueError(
                "segment_lengths and dilation_rates must have the same length"
            )

        self.segment_lengths = segment_lengths
        self.dilation_rates = dilation_rates
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        self.op = op

    def forward(
        self, query, key, value, is_causal = False
    ) :
        # Notation:
        #   b - batch size
        #   n - sequence length
        #   h - number of heads
        #   d - embedding dimension
        #   s - segment length
        #   r - dilation rate
        #   g - group size (i.e. number of heads per segment length)
        #
        # Input shape of query, key, value: (b, n, h, d)
        b, _, h, _ = query.shape
        out = torch.zeros_like(query)

        # *** NOTE ***
        # The original paper does not describe how to handle the case where
        #   h % len(self.segment_lengths) != 0
        #
        # In my first implementation, I naively assumed (and asserted) that
        # 'h % len(self.segment_lengths) == 0', so that I could evenly distribute
        # the heads between the different segment lengths. However, it was not
        # possible to reproduce the LongNet hyperparameters with that restriction:
        #   h=12, segment_lengths=[2048, 4096, 8192, 16384, 32768]
        #   h % len(segment_lengths) == 2
        #
        # For that reason, I have removed the assertion, and instead grouped the heads
        # into (potentially) unequally sized groups.  If not perfectly divisible, then
        # the first few groups will have an extraattention head.
        num_groups = len(self.dilation_rates)
        group_sizes = [h // num_groups] * num_groups
        for i in range(h % num_groups):
            group_sizes[i] += 1

        for i, (g, r, s) in enumerate(
            zip(group_sizes, self.dilation_rates, self.segment_lengths)
        ):
            # Split the input sequences into segments of length 'self.segment_length'
            q = rearrange(query, "b (n s) h d -> b n s h d", s=s)
            k = rearrange(key, "b (n s) h d -> b n s h d", s=s)
            v = rearrange(value, "b (n s) h d -> b n s h d", s=s)
            # Apply dilation and segment offset
            offset = i % r
            hmin = i * g
            hmax = (i + 1) * g
            q = q[:, :, offset::r, hmin:hmax, :]
            k = k[:, :, offset::r, hmin:hmax, :]
            v = v[:, :, offset::r, hmin:hmax, :]
            # Fold all 'n' segments into the batch dimension
            q = rearrange(q, "b n s h d -> (b n) s h d")
            k = rearrange(k, "b n s h d -> (b n) s h d")
            v = rearrange(v, "b n s h d -> (b n) s h d")

            # Apply memory efficient attention
            # NOTE: If flash attention is correctly installed, then this will also
            # automatically use the flash attention implementation.
            attn_bias = xops.LowerTriangularMask() if is_causal else None
            x = xops.memory_efficient_attention(
                query=q, key=k, value=v, op=self.op, attn_bias=attn_bias
            )
            # Unfold 'n' segments back out of the batch dimension.
            x = rearrange(x, "(b n) s h d -> b n s h d", b=b)
            # Normalize attention outputs across the sequence length dimension. This
            # is necessary because the attention outputs from each dilation rate /
            # segment length are summed together.
            x = x / x.sum(dim=(1, 2), keepdim=True)

            # Gather the attention outputs from each dilation rate / segment length.
            out = rearrange(out, "b (n s) h d -> b n s h d", s=s)
            out[:, :, offset::r, hmin:hmax, :] += x
            out = rearrange(out, "b n s h d -> b (n s) h d", s=s)

        # We have already normalized each attention output across the sequence length.
        # Now, normalize across all attention outputs by dividing by the number of
        # attention groups.  See: https://arxiv.org/pdf/2307.02486.pdf, Eq. 10
        return out / num_groups

class MultiheadDilatedAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(
            config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.attention = DilatedAttention(
            segment_lengths=[2048, 4096, 8192, 16384, 32768],
            dilation_rates=[1, 2, 4, 6, 12],
            attention_dropout=config.dropout)

    def forward(self, x):
        b, t, d = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(b, t, self.n_head, d //
                   self.n_head)  # (b, t, h, d)
        q = q.view(b, t, self.n_head, d //
                   self.n_head)  # (b, t, h, d)
        v = v.view(b, t, self.n_head, d //
                   self.n_head)  # (b, t, h, d)
        y = self.attention(q, k, v, is_causal=True)
        y = rearrange(x, 'b t h d -> b t (h d)')
        y = self.resid_dropout(self.c_proj(y))
        return y

In [11]:
config = GPTConfig(
    block_size=16,
    vocab_size=50304,
    n_layer=12,
    n_head=4,
    n_embd=16,
    dropout=0.1,
    bias=True,

)
b, t, d = 4, 4096, 16
x = torch.randn(b, t, d)
m = MultiheadDilatedAttention(config)
y = m(x)

NameError: name 'xops' is not defined