# Developing GPT

## Setup

In [29]:
import torch
import torch.nn as nn
from torch.nn import functional as F
# Seed PyTorch's global RNG so the random tensors generated in this notebook are reproducible
torch.manual_seed(1337)

<torch._C.Generator at 0x107b9fed0>

In [30]:
"""Download Shakespeare training dataset"""
# Uncomment and run once to fetch the corpus; wget saves it as input.txt in the working directory
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

'Download Shakespeare training dataset'

In [31]:
"""Read Shakespeare file"""
with open('input.txt', 'r', encoding='utf-8') as file:
	text = file.read()

## Mathematical Trick in Self-Attention
- We want the 8 tokens along the time dimension (T) of each batch element to communicate with each other in a specific way
- Specifically, we want each token to communicate with the tokens that come before it, and not those that come after it
- This way, information only flows from previous context to the current time step, and not the other way around
- For every t-th token, we want to calculate the average of the channel vectors of all previous tokens and the current token
- Unfortunately, computing this with explicit loops is very inefficient. The trick is to compute the same averages efficiently using matrix multiplication

In [32]:
# Toy input: 4 batch elements, 8 time steps each, 2 channels per token
batch = 4
time = 8
channels = 2
x = torch.randn((batch, time, channels))
x.shape

torch.Size([4, 8, 2])

In [33]:
"""We want x[batch, time] = mean_{idx <= t} x [batch, idx]"""
"""There's a word stored at each of the 8 time locations. Bag-of-words is an expression for averaging"""
x_bag_of_words = torch.zeros((batch, time, channels))

for b in range(batch):
	for t in range(time):
		x_prev = x[b, :t+1]
		x_bag_of_words[b, t] = torch.mean(x_prev, 0)

In [34]:
"""0th batch element"""
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [35]:
"""First batch element is the same as the above because it's just the average of the first element,
but the second element is the average of elements one and two, and so on
"""
x_bag_of_words[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])