# In [5]:
import torch
import torch.nn as nn 
from torch.nn import functional as F
import math


# Feature width used for the q/k/v projections, and the sequence length the
# attention mask is built for.
d_model, block_size = 128, 16

class SlidingWindowAttention(nn.Module):
    """Single-head causal self-attention restricted to a sliding window.

    Query position ``i`` may attend to key positions
    ``max(0, i - window_size + 1)`` through ``i`` inclusive, i.e. itself plus
    the ``window_size - 1`` positions immediately before it.

    Args:
        window_size: Number of positions (including the current one) each
            query may attend to. Must be >= 1.
        block_size: Maximum supported sequence length; the attention mask is
            precomputed at ``(block_size, block_size)``.
        embed_dim: Feature width of the input and output. When omitted, the
            module-level ``d_model`` is used (backward-compatible default).

    Raises:
        ValueError: If ``window_size`` is < 1 (every query would have no
            allowed key and the softmax would produce NaN).
    """

    def __init__(self, window_size, block_size, embed_dim=None):
        super().__init__()
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        if embed_dim is None:
            # Fall back to the module-level width for existing callers.
            embed_dim = d_model
        self.window_size = window_size
        self.block_size = block_size
        self.key = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.register_buffer('mask', self.sliding_window_mask(window_size, block_size))

    def forward(self, x):
        """Apply sliding-window attention to ``x``.

        Args:
            x: Tensor whose last dimension is ``embed_dim`` and whose
                second-to-last dimension is the sequence length, e.g.
                ``(batch, seq_len, embed_dim)``; ``seq_len`` must not exceed
                ``block_size``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If the sequence length exceeds ``block_size``.
        """
        seq_len = x.size(-2)
        if seq_len > self.block_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.block_size}"
            )
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Scaled dot-product scores: divide by sqrt(feature dim).
        wei = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        # Crop the precomputed mask to the actual length so shorter inputs work.
        allowed = self.mask[:seq_len, :seq_len]
        wei = wei.masked_fill(allowed == 0, float('-inf'))
        return torch.softmax(wei, dim=-1) @ v

    @staticmethod
    def sliding_window_mask(window_size, size=None):
        """Build a float ``(size, size)`` mask with 1 where attention is allowed.

        Row ``i`` holds ones in columns ``max(0, i - window_size + 1)``
        through ``i`` inclusive and zeros elsewhere.

        Args:
            window_size: Number of allowed positions per row, counting the
                diagonal.
            size: Side length of the mask. Defaults to the module-level
                ``block_size`` for backward compatibility.

        Returns:
            Tensor of ones and zeros in the default floating dtype.
        """
        if size is None:
            size = block_size
        ones = torch.ones(size, size)
        # tril keeps columns j <= i (causal); triu with diagonal 1 - window_size
        # keeps columns j >= i - window_size + 1 (the window's left edge).
        return torch.triu(torch.tril(ones), diagonal=1 - window_size)
    
# Smoke check: one full-length sequence must come back with its shape unchanged.
x = torch.randn((1, block_size, d_model))
attention = SlidingWindowAttention(block_size=block_size, window_size=4)
assert x.shape == attention(x).shape, "shape mismatch"