<a href="https://colab.research.google.com/github/SpencerFonbuena/MentorCruise/blob/main/time_series_selfattention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

## Below is an example of attention, applied to my use case and annotated with my understanding of each line. I've put each step on its own line in forward() to keep things clear. If a step is marked with a "Q", it means I have a question about that step.

### Comprehension Check and Thoughts
1. qkv = self.qkv(data) # (64, 120, 1536) => The linear layer has projected the original 512 features into 1536 features (3 × 512: one block each for the query, key, and value).

2. qkv = qkv.reshape(B, S, 3, -1) => Instead of each token being described by one large vector, we now have 3 separate representations of each token,
    which will serve as its query, key, and value. Notice the last dimension is back to the original number of features (512).

3. qkv.reshape(B, S, 3, self.num_heads, -1) => As discussed in our session, instead of one large vector per query, key, and value, we
    have split each of them into 8 smaller per-head representations for every token. Doing it this way also provides intuition
    into why the per-head feature dimension is features/heads (512/8 = 64). The total cost stays roughly the same as a single head working on all 512 features.

4. qkv.permute(2, 0, 3, 1, 4) => I've decided to do a toy example you can find below under the Markdown cell titled "How does Permute Keep Track of Information"

5. qkv.reshape(3, B*self.num_heads, 120, -1).unbind(0) => I have a continued toy example under the Markdown cell "What Effect does Combining Batch and Heads have?"

6. attn = (q * self.scale) @ k.transpose(-2, -1) => (512, 120, 120) (512 = 64 examples × 8 heads, 120 query tokens, 120 scores per token). Within each of the 512 (example, head)
    groups, row i holds the attention scores of token i against every one of the 120 tokens in the sequence, itself included. I have a further explanation under "Why do we have to combine, and how does this help us with matmul"

7. attn = attn.softmax(dim=-1) => turns each row of scores into a probability distribution (each row sums to 1). The dimensions stay the same.

8. x = (attn @ v) # (512, 120, 64) => Multiply each of the attention scores by the value of each token in each example, and sum. This will give each token its "new identity"

Q9. x = x.view(B, self.num_heads, S, -1) => This is just re-separating the batches from the heads. Q: Why did we use view here instead of reshape?
    The articles I saw say that view always returns a "view" of the tensor and never a "copy," while reshape returns a view when possible and otherwise copies the data into
    contiguous memory. Maybe using view here will save some space? (Note: splitting one axis into two can always be expressed as a view, so here view and reshape behave identically and neither copies; view just makes the no-copy guarantee explicit by raising an error instead of silently copying.)

10. x = x.permute(0, 2, 1, 3) => This moves the heads axis back next to the per-head features, essentially undoing the heads/sequence swap made by the permute in step 4, so that step 11 can merge them.

11. x = x.reshape(B, S, -1) => Merges the 8 heads × 64 features back into 512 features per token. Again, the next step after attention is a fully connected linear layer, and this prepares for it. It is also now the same shape
    as it was when it came in; the only difference is that each token now carries information gathered from the other tokens through the attention operations.


In [None]:
# Mock input standing in for real data, laid out as
# (batch, sequence/timestep, features) = (64, 120, 512).
data = torch.randn((64, 120, 512))
class Attention(nn.Module):
    """Multi-head self-attention over a (batch, sequence, features) tensor.

    Args:
        dim: number of features per token; must be divisible by ``num_heads``.
        num_heads: number of attention heads; each head works on
            ``dim // num_heads`` features.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,) -> None:
        super().__init__()
        # forward() splits the features evenly across the heads, which only
        # works when dim is a multiple of num_heads; fail early and clearly.
        if dim % num_heads != 0:
            raise ValueError(f'dim ({dim}) must be divisible by num_heads ({num_heads})')
        self.num_heads = num_heads
        # One projection produces 3x the features: query + key + value.
        self.qkv = nn.Linear(dim, dim * 3)
        head_dim = dim // num_heads
        # 1/sqrt(head_dim) keeps the dot products from growing with head size,
        # so the softmax does not saturate.
        self.scale = head_dim**-0.5

    def forward(self, data):
        """Apply self-attention to ``data`` of shape (B, S, dim); returns (B, S, dim)."""
        B, S, _ = data.shape
        qkv = self.qkv(data)                              # (B, S, 3*dim), e.g. (64, 120, 1536)
        qkv = qkv.reshape(B, S, 3, -1)                    # (B, S, 3, dim): q/k/v per token
        qkv = qkv.reshape(B, S, 3, self.num_heads, -1)    # (B, S, 3, H, head_dim), e.g. (64, 120, 3, 8, 64)
        qkv = qkv.permute(2, 0, 3, 1, 4)                  # (3, B, H, S, head_dim)
        # Fold batch and heads together so one batched matmul handles every
        # (example, head) pair. Use S, not a fixed length, so any sequence works.
        q, k, v = qkv.reshape(3, B*self.num_heads, S, -1).unbind(0)  # each (B*H, S, head_dim)

        attn = (q * self.scale) @ k.transpose(-2, -1)     # (B*H, S, S): score of token i vs token j

        attn = attn.softmax(dim=-1)                       # each row becomes a probability distribution

        x = attn @ v                                      # (B*H, S, head_dim)
        # Splitting one axis into two is always expressible as a view, so
        # view() never needs to copy here.
        x = x.view(B, self.num_heads, S, -1)              # (B, H, S, head_dim)
        x = x.permute(0, 2, 1, 3)                         # (B, S, H, head_dim)
        x = x.reshape(B, S, -1)                           # (B, S, dim): heads merged back together
        return x


# Build an attention layer whose width matches the mock data's feature axis,
# then run one forward pass over the data.
dim = data.size(2)
attention = Attention(dim=dim)
attention(data)

## How Does Permute Keep Track of Information

In [None]:
toy = torch.arange(54).reshape(2, 3, 3, 3)  # (2 tokens, 3 qkv, 3 heads, 3 features)

# raw dataset
print(toy)

# Representation of each element: toy[token, role, head] is one head's feature
# vector for the query (role 0), key (role 1) or value (role 2) of a token.
# Looping over the indices (instead of hand-writing each print) keeps every
# label in sync with the slice it prints.
roles = ('Query', 'Key', 'Value')
for token in range(toy.shape[0]):
    if token > 0:
        print('\n')
    for role_idx, role in enumerate(roles):
        print(f'Token {token + 1}: ')
        for head in range(toy.shape[2]):
            print(f' {role} {token + 1} & Head {head + 1} {toy[token, role_idx, head]}')

# Divided representation: the same role gathered across both tokens
for role_idx, role in enumerate(('queries', 'keys', 'values')):
    print(f'\n\nThe {role} for tokens 1 and 2 are \n{toy[0, role_idx]} \n {toy[1, role_idx]}')

# Permuting to (qkv, tokens, heads, features) simply groups together the queries,
# keys, and values of every token. permute only reorders the axes (it returns a
# view); the last two dimensions are still (heads, features).
toy = toy.permute(1, 0, 2, 3)
print(f'\n \n We can see, this permute simply grouped them together \n{toy}')


tensor([[[[ 0,  1,  2],
          [ 3,  4,  5],
          [ 6,  7,  8]],

         [[ 9, 10, 11],
          [12, 13, 14],
          [15, 16, 17]],

         [[18, 19, 20],
          [21, 22, 23],
          [24, 25, 26]]],


        [[[27, 28, 29],
          [30, 31, 32],
          [33, 34, 35]],

         [[36, 37, 38],
          [39, 40, 41],
          [42, 43, 44]],

         [[45, 46, 47],
          [48, 49, 50],
          [51, 52, 53]]]])
Token 1: 
 Query 1 & Head 1 tensor([0, 1, 2])  
 Query 1 & Head 2 tensor([3, 4, 5])  
 Query 1 & Head 3 tensor([6, 7, 8])
Token 1: 
 Key 1 & Head 1 tensor([ 9, 10, 11])  
 Key 1 & Head 2 tensor([12, 13, 14])  
 Key 1 & Head 3 tensor([15, 16, 17])
Token 1: 
 Value 1 & Head 1 tensor([18, 19, 20])  
 Value 1 & Head 2 tensor([21, 22, 23])  
 Value 1 & Head 3 tensor([24, 25, 26])

 
 Token 2: 
 Query 2 & Head 1 tensor([27, 28, 29])  
 Query 2 & Head 2 tensor([30, 31, 32])  
 Query 2 & Head 3 tensor([33, 34, 35])
Token 2: 
 Key 2 & Head 1 tensor([27, 28

## What Effect does Combining Batch and Heads have?

In [None]:
# Merge the (token, head) axes of the permuted (3, 2, 3, 3) tensor into a single
# axis of 2 * 3 = 6 rows. This mirrors how batch and heads are merged into
# B * num_heads in Attention.forward: each row is one (token, head) feature vector.
toy = toy.reshape(3, -1, 3)
for role_idx, role in enumerate(('Queries', 'Keys', 'Values')):
    print(f'Now we have our combined: {role} \n {toy[role_idx, :]}')



Now we have our combined: Queries 
 tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [27, 28, 29],
        [30, 31, 32],
        [33, 34, 35]])
Now we have our combined: Keys 
 tensor([[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17],
        [36, 37, 38],
        [39, 40, 41],
        [42, 43, 44]])
Now we have our combined: Values 
 tensor([[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26],
        [45, 46, 47],
        [48, 49, 50],
        [51, 52, 53]])


## Why do we have to combine, and how does this help us with matmul

In [None]:
# Combining batch and heads into one leading axis lets a single batched matmul
# score every query against every key inside each (example, head) group at once:
# row i of q[g] is dotted with column j of k[g] to give score (i, j).

q = torch.arange(54).reshape(3, 6, 3)  # (3 groups, 6 tokens, 3 features)
k = q.transpose(-2, -1)  # (3 groups, 3 features, 6 tokens): keys transposed for the matmul
for group in range(q.shape[0]):
    print(f'Group {group}: query rows \n {q[group]} \n dotted with key columns \n {k[group]}')

# The actual batched matmul: scores[g, i, j] = q[g, i] . k[g, :, j]
scores = q @ k  # (3 groups, 6 tokens, 6 scores)
print(f'Attention scores for group 0 \n {scores[0]}')

Queries dot product 
 tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]) 
 tensor([[ 0,  3,  6,  9, 12, 15],
        [ 1,  4,  7, 10, 13, 16],
        [ 2,  5,  8, 11, 14, 17]])
Keys dot product 
 tensor([[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26],
        [27, 28, 29],
        [30, 31, 32],
        [33, 34, 35]]) 
 tensor([[18, 21, 24, 27, 30, 33],
        [19, 22, 25, 28, 31, 34],
        [20, 23, 26, 29, 32, 35]])
Values dot product 
 tensor([[36, 37, 38],
        [39, 40, 41],
        [42, 43, 44],
        [45, 46, 47],
        [48, 49, 50],
        [51, 52, 53]]) 
 tensor([[36, 39, 42, 45, 48, 51],
        [37, 40, 43, 46, 49, 52],
        [38, 41, 44, 47, 50, 53]])
