In [20]:
import torch.nn as nn
import torch.nn.functional as F
import torch

In [118]:
window_size = 6
sequence = torch.rand((5, 8, 3, 12))  # Tensor: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

windows = sequence.unfold(1, window_size, 2)
windows.shape

torch.Size([5, 2, 3, 12, 6])

In [228]:
batch_size = 1
seq_len = 3
num_heads = 2
head_dim = 4
hidden_size = num_heads * head_dim

window_size = 2
padding = window_size // 2

queries = torch.rand((batch_size, num_heads, seq_len, head_dim))
keys = torch.rand((batch_size, num_heads, seq_len, head_dim))
values = torch.rand((batch_size, num_heads, seq_len, head_dim))

print(queries.shape, keys.shape, values.shape)
values

queries_ = queries.clone().detach()
keys_ = keys.clone().detach()
values_ = values.clone().detach()

torch.Size([1, 2, 3, 4]) torch.Size([1, 2, 3, 4]) torch.Size([1, 2, 3, 4])


In [229]:
keys.reshape((batch_size, num_heads, seq_len, head_dim // 2 , 2))

tensor([[[[[0.7016, 0.8530],
           [0.8243, 0.5465]],

          [[0.0369, 0.3384],
           [0.4608, 0.1522]],

          [[0.2582, 0.2979],
           [0.0515, 0.4759]]],


         [[[0.5817, 0.8130],
           [0.1351, 0.8070]],

          [[0.7353, 0.8495],
           [0.3326, 0.1671]],

          [[0.2150, 0.2406],
           [0.0405, 0.5801]]]]])

In [216]:

keys = F.pad(keys, (0, 0, padding, padding), "constant", 0)
values = F.pad(values, (0, 0, padding, padding), "constant", 0)

# Impl 1
# Permute to get (batch_size, num_heads, seq_len, window_size, head_dim)
# # This aligns window_size as the second-to-last dimension for einsum

keys_windowed = keys.unfold(2, window_size, 1)[:, :, :seq_len, :, :].permute(0, 1, 2, 4, 3)
values_windowed = values.unfold(2, window_size, 1)[:, :, :seq_len, :, :].permute(0, 1, 2, 4, 3)
print(keys_windowed.shape, values_windowed.shape)

scores = torch.einsum('bnsh,bnswh->bnsw', queries, keys_windowed)
print(scores.shape)

scores = scores / (head_dim ** 0.5)
print(scores.shape)

attention = F.softmax(scores, dim=-1)
print(attention.shape)

context = torch.einsum('bnsw,bnswh->bsnh', attention, values_windowed)
print(context.shape)

context = context.reshape(batch_size, seq_len, hidden_size)
print(context.shape)

context_1 = context.clone().detach()
context_1


torch.Size([1, 2, 3, 2, 4]) torch.Size([1, 2, 3, 2, 4])
torch.Size([1, 2, 3, 2])
torch.Size([1, 2, 3, 2])
torch.Size([1, 2, 3, 2])
torch.Size([1, 3, 2, 4])
torch.Size([1, 3, 8])


tensor([[[0.4173, 0.0515, 0.2000, 0.2083, 0.4375, 0.0428, 0.7168, 0.5457],
         [0.4546, 0.3947, 0.7048, 0.4549, 0.3701, 0.1302, 0.9216, 0.4152],
         [0.5426, 0.7919, 0.8697, 0.3831, 0.1972, 0.5028, 0.6689, 0.4262]]])

In [217]:
# Impl 2
# Initialize context tensors
keys_ = F.pad(keys_, (0, 0, padding, padding), "constant", 0)
values_ = F.pad(values_, (0, 0, padding, padding), "constant", 0)

context = torch.zeros_like(queries_)

# Compute attention for each sliding window
for i in range(seq_len):
    # Determine the start and end of the window
    start = i
    end = i + window_size
    
    # Compute scores
    # (batch_size, num_heads, seq_length, head_dim)
    scores = torch.matmul(queries_[:, :, i:i+1, :], keys_[:, :, start:end, :].transpose(-2, -1))
    scores = scores / (head_dim ** 0.5)
    attention = F.softmax(scores, dim=-1)
    # print(attention)
    
    
    # Apply attention to values and add to context
    context[:, :, i:i+1, :] += torch.matmul(attention, values_[:, :, start:end, :])

# Reshape context to (batch_size, seq_length, num_heads * head_dim)
context = context.permute(0, 2, 1, 3).reshape(batch_size, seq_len, hidden_size)
print(context.shape)

context_2 = context.clone().detach()
context_2

torch.Size([1, 3, 8])


tensor([[[0.4173, 0.0515, 0.2000, 0.2083, 0.4375, 0.0428, 0.7168, 0.5457],
         [0.4546, 0.3947, 0.7048, 0.4549, 0.3701, 0.1302, 0.9216, 0.4152],
         [0.5426, 0.7919, 0.8697, 0.3831, 0.1972, 0.5028, 0.6689, 0.4262]]])

In [211]:
torch.allclose(context_1, context_2, 1e-6)

True

In [225]:
import torch
import torch.nn.functional as F

def sliding_window_attention_einsum(queries, keys, values, window_size, padding, head_dim, hidden_size):
    batch_size = queries.shape[0]
    num_heads = queries.shape[1]
    seq_len = queries.shape[2]
    
    padded_keys = F.pad(keys, (0, 0, padding, padding), "constant", 0) # Pads last two dims: (0,0 for head_dim, padding,padding for seq_dim)
    padded_values = F.pad(values, (0, 0, padding, padding), "constant", 0)

    unfolded_keys = padded_keys.unfold(2, window_size, 1)
    unfolded_values = padded_values.unfold(2, window_size, 1)

    K_windows = unfolded_keys[:, :, :seq_len, :, :]  # Shape: (B, N, S, Dk, W)
    V_windows = unfolded_values[:, :, :seq_len, :, :]  # Shape: (B, N, S, Dk, W)

    # K_windows = K_windows.permute(0, 1, 2, 4, 3)
    # V_windows = V_windows.permute(0, 1, 2, 4, 3)

    scores = torch.einsum('bnsd,bnsdw->bnsw', queries, K_windows)

    scores = scores / (head_dim ** 0.5)
    attention_weights = F.softmax(scores, dim=-1) # Softmax over W (window_size dimension)
                                               # Shape: (B, N, S, W)

    context = torch.einsum('bnsw,bnsdw->bnsd', attention_weights, V_windows)

    context = context.permute(0, 2, 1, 3).reshape(batch_size, seq_len, hidden_size)
    
    return context

if __name__ == '__main__':
    # Example Usage (matching the variable names from the problem if possible)
    batch_size = 2
    seq_len = 5
    num_heads = 4
    head_dim = 8
    hidden_size = num_heads * head_dim # 32
    window_size = 3
    
    # For the padding to match the original example's implied structure:
    # The original code's loop is:
    # for i in range(seq_len):
    #   start = i
    #   end = i + window_size
    #   scores = torch.matmul(queries[:, :, i:i+1, :], keys_padded[:, :, start:end, :].transpose(-2, -1))
    # This means the window is purely forward or centered around the query 'i' depending on how 'padding' relates to 'window_size'.
    # If padding = window_size -1 (as it seems to be in many such implementations for causal/local attention)
    # this creates a "look-ahead" style window of size `window_size` starting at `i` in the padded sequence.
    
    # Let's use the specific padding value from the original snippet if it was fixed,
    # or define it based on common practice if it was a variable `padding`.
    # The prompt used a variable `padding`. Let's assume it's `window_size - 1` for this example,
    # which is a common setup for ensuring windows can be formed at sequence boundaries.
    # Or, for centered windows, padding = window_size // 2 (if window_size is odd)
    # The original problem statement had `padding` as a variable. So we should accept it.
    # For this test, let's use padding = window_size // 2 for a somewhat centered window.
    # If window_size = 3, padding = 1.
    # If we use the same padding as the prompt `padding = window_size - 1 = 2` for `window_size = 3`
    example_padding = window_size -1 # This seems to be a common pattern for such sliding windows.
                                  # If window_size = 3, then padding = 2.
                                  # Padded seq_len = 5 + 2*2 = 9.
                                  # Keys: p p k0 k1 k2 k3 k4 p p
                                  # q0 att to (p,p,k0) (using start=0, end=3 in original loop)
                                  # This interpretation for start=0 in padded sequence might be off.
                                  # Original loop: start=i, end=i+window_size.
                                  # If padding is on *original keys* not *padded_keys_indices*
                                  # This needs to be consistent with original for comparison.

    # The provided snippet's `padding` is an input to F.pad.
    # Let's test with values that make sense.
    # If window_size = 3, a common padding is 1 on each side (total window covers q_i-1, q_i, q_i+1)
    padding_amount_for_test = window_size // 2 # e.g., 1 for window_size=3

    queries_tensor = torch.randn(batch_size, num_heads, seq_len, head_dim)
    keys_tensor = torch.randn(batch_size, num_heads, seq_len, head_dim)
    values_tensor = torch.randn(batch_size, num_heads, seq_len, head_dim)

    print("--- Einsum Implementation ---")
    context_einsum = sliding_window_attention_einsum(
        queries_tensor, keys_tensor, values_tensor,
        window_size, padding_amount_for_test, head_dim, hidden_size
    )
    print("Context shape (einsum):", context_einsum.shape)


    # For comparison, let's try to replicate the original loop logic
    # The original code snippet used a variable `padding`.
    # We must use the SAME `padding_amount_for_test` for a fair comparison.
    
    print("\n--- Original Loop Implementation (for comparison) ---")
    # Original Impl (adapted for standalone test)
    # Note: The 'padding' variable in the original snippet is the amount of padding on each side.
    
    # Initialize context tensors
    _keys = F.pad(keys_tensor, (0, 0, padding_amount_for_test, padding_amount_for_test), "constant", 0)
    _values = F.pad(values_tensor, (0, 0, padding_amount_for_test, padding_amount_for_test), "constant", 0)
    
    _context_loop = torch.zeros_like(queries_tensor) # (B, N, S, Dk)
    
    # Compute attention for each sliding window
    for i in range(seq_len):
        # Determine the start and end of the window in the padded keys/values
        start = i 
        end = i + window_size
        
        # Slice the query
        q_i = queries_tensor[:, :, i:i+1, :] # (B, N, 1, Dk)
        
        # Slice the window from padded keys and values
        k_window = _keys[:, :, start:end, :]     # (B, N, W, Dk)
        v_window = _values[:, :, start:end, :]   # (B, N, W, Dk)
        
        # Compute scores
        # (B, N, 1, Dk) @ (B, N, Dk, W) -> (B, N, 1, W)
        scores_loop = torch.matmul(q_i, k_window.transpose(-2, -1))
        scores_loop = scores_loop / (head_dim ** 0.5)
        attention_loop = F.softmax(scores_loop, dim=-1)
        
        # Apply attention to values and add to context
        # (B, N, 1, W) @ (B, N, W, Dk) -> (B, N, 1, Dk)
        attended_values_loop = torch.matmul(attention_loop, v_window)
        _context_loop[:, :, i:i+1, :] = attended_values_loop # Use '=' since it's 0-initialized, original had += but context_i was new

    # Reshape context to (batch_size, seq_len, num_heads * head_dim)
    _context_loop_final = _context_loop.permute(0, 2, 1, 3).reshape(batch_size, seq_len, hidden_size)
    print("Context shape (loop):", _context_loop_final.shape)

    # Check if results are close (they should be identical if logic is same)
    if torch.allclose(context_einsum, _context_loop_final, atol=1e-6):
        print("\nResults from einsum and loop implementations are close!")
    else:
        print("\nResults from einsum and loop implementations DIFFER!")
        # print("Einsum context head:", context_einsum.view(-1)[:10])
        # print("Loop context head:", _context_loop_final.view(-1)[:10])
        # print("Difference:", torch.abs(context_einsum - _context_loop_final).max())

    # Recreate the specific output of the original problem for context_2
    # context = context_einsum # Assuming this is the primary output
    # print(context.shape) # Already printed
    # context_2 = context.clone().detach()
    # print(context_2) # This would print the tensor values



--- Einsum Implementation ---
Context shape (einsum): torch.Size([2, 5, 32])

--- Original Loop Implementation (for comparison) ---
Context shape (loop): torch.Size([2, 5, 32])

Results from einsum and loop implementations are close!


In [244]:
dim=10
period=100
context_size=5

exps = (1./dim) * torch.arange(0, (dim-1), 2)
print(exps.shape, exps)
freqs = (1./torch.pow(period, exps))
print(freqs)

token_indexes = torch.arange(0, context_size)
    
# TODO: compute the matrix thetas
thetas = torch.outer(token_indexes, freqs)
print(thetas.shape)

# TODO: create the rotation matrix
rotation_matrix = torch.polar(torch.ones_like(thetas), thetas)
rotation_matrix.shape

torch.Size([5]) tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000])
tensor([1.0000, 0.3981, 0.1585, 0.0631, 0.0251])
torch.Size([5, 5])


torch.Size([5, 5])

In [285]:
x = torch.tensor([[[-0.8772, -0.6420, -0.1764,  0.4584, -0.8136, -0.6509, -0.5871,
           0.6453,  0.0991, -0.3950],
         [ 0.3687, -0.9497, -0.9017,  0.1877,  0.9783, -0.1219,  0.8722,
          -0.3780,  0.3587,  0.3795],
         [ 0.4127,  1.5541,  0.7300, -0.1090, -0.2371,  0.0809,  0.1396,
          -0.6884,  0.0311,  0.8754],
         [ 0.2966, -0.3919, -0.5396, -0.6623, -0.0381,  0.8342,  1.2686,
          -1.3032,  0.4697,  0.4133],
         [ 0.2843, -0.2664, -0.3312, -0.4801,  0.3916,  1.0380,  1.3515,
          -0.5189, -0.1152, -0.0456],
         [-0.1839, -0.3232, -0.6245,  0.4363, -0.3063,  0.4891, -0.0958,
          -0.2570,  0.8761, -0.4627],
         [-0.1320, -0.3170, -0.6218,  0.4533, -0.3246,  0.4891, -0.0982,
          -0.2880,  0.8828, -0.4444],
         [-0.1295, -0.3561, -0.6142,  0.4811, -0.3424,  0.4857, -0.1472,
          -0.2484,  0.8728, -0.4723],
         [-0.1699, -0.3207, -0.5689,  0.4708, -0.3790,  0.4455, -0.2404,
          -0.2106,  0.8703, -0.4348],
         [-0.1862, -0.3174, -0.5680,  0.4516, -0.3676,  0.4528, -0.2310,
          -0.2155,  0.8513, -0.4210]]])

topk_indices = torch.tensor([[[7, 3],
         [4, 6],
         [1, 9],
         [6, 5],
         [6, 5],
         [8, 5],
         [8, 5],
         [8, 5],
         [8, 3],
         [8, 5]]])


out = torch.where(topk_indices == 3, True, False)
x[out]

# x[torch.where(topk_indices == 3)]
# batch_szie x seq_len x experts


IndexError: The shape of the mask [1, 10, 2] at index 2 does not match the shape of the indexed tensor [1, 10, 10] at index 2

In [286]:
x = torch.tensor([1,2,3,4,5])
torch.where(x%2==0)

(tensor([1, 3]),)