In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Preprocessing

In [2]:
batch_size = 16
num_steps = 1000
lr = 1e-3
context_length = 32
embedding_size = 32 # change this in future and test if it works 
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
with open('input.txt', 'r', encoding='utf-8' ) as f:
    data = f.read().lower()
len(data)

1115394

In [4]:
chars = sorted(list(set(data)))
chars.remove('3')
chars.append('<UNK>')
print(chars)
len(chars)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', ':', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '<UNK>']


39

In [5]:
vocab = {}
for t in range(len(chars)):
    vocab[t] = chars[t]

rev_vocab = {v:k for k,v in vocab.items()}
rev_vocab

{'\n': 0,
 ' ': 1,
 '!': 2,
 '$': 3,
 '&': 4,
 "'": 5,
 ',': 6,
 '-': 7,
 '.': 8,
 ':': 9,
 ';': 10,
 '?': 11,
 'a': 12,
 'b': 13,
 'c': 14,
 'd': 15,
 'e': 16,
 'f': 17,
 'g': 18,
 'h': 19,
 'i': 20,
 'j': 21,
 'k': 22,
 'l': 23,
 'm': 24,
 'n': 25,
 'o': 26,
 'p': 27,
 'q': 28,
 'r': 29,
 's': 30,
 't': 31,
 'u': 32,
 'v': 33,
 'w': 34,
 'x': 35,
 'y': 36,
 'z': 37,
 '<UNK>': 38}

In [6]:
def encode(text):
    out = []
    for t in text:
        out.append(rev_vocab.get(t, 38)) # 38 is the <UNK>
    return out


encode('hello1')

[19, 16, 23, 23, 26, 38]

In [7]:
def decode(nums):
    out = ''
    for i in nums:
        out += vocab.get(i, '<UNK>')
    return out


decode([34,  7, 31, 29, 2])

'w-tr!'

In [8]:
encoded_data = torch.tensor(encode(data), dtype=torch.long)
train_data = encoded_data[:int(0.9*len(encoded_data))]
test_data = encoded_data[int(0.9*len(encoded_data)):]

len(train_data), len(test_data)

(1003854, 111540)

In [9]:
def get_batch(split='train'):
    data = train_data if split == 'train' else test_data
    start = torch.randint(0,len(data)-context_length-1,(batch_size,))
    outx = torch.stack([data[i:i+context_length] for i in start]).to(device)
    if split == 'test': return outx
    outy = torch.stack([data[i+1:i+context_length+1] for i in start]).to(device)
    if split == 'train': return outx, outy

get_batch('train')[0].shape

torch.Size([16, 32])

# Model

In [42]:
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.Q = nn.Linear(embedding_size, head_size, bias=False)
        self.K = nn.Linear(embedding_size, head_size, bias=False)
        self.V = nn.Linear(embedding_size, head_size, bias=False)

    def forward(self, x):
        B, T, C = x.shape # BatchSize Time ContextSize
        k = self.K(x)
        q = self.Q(x)
        v = self.V(x)

        out = q@k.transpose(-2,-1)
        out = out*embedding_size**-0.5
        
        out = out*torch.tril(torch.ones(32,32)) # probably can be improved. No need for element wise mul here
        out = out.masked_fill(out==0, float('-inf'))# masking

        out = F.softmax(out, dim=-1)
        out = out@v
        return out

h = Head(100)
h(torch.rand(16, 32, 32))

tensor([[[-6.0916e-01,  1.5004e-02,  3.7631e-02,  ..., -4.4515e-01,
          -1.2958e-01, -6.5326e-02],
         [-5.0434e-01, -7.0856e-04, -1.3698e-01,  ..., -3.6357e-01,
          -1.0415e-01, -8.9459e-02],
         [-3.8976e-01, -7.3190e-02, -3.9952e-02,  ..., -3.5339e-01,
          -1.0293e-01, -6.0476e-02],
         ...,
         [-3.8749e-01, -8.6466e-02,  2.9432e-02,  ..., -3.2784e-01,
          -1.1670e-01,  8.9313e-02],
         [-3.8047e-01, -8.9410e-02,  3.0156e-02,  ..., -3.3121e-01,
          -1.2248e-01,  8.6099e-02],
         [-3.8272e-01, -9.7038e-02,  2.3258e-02,  ..., -3.2776e-01,
          -1.1160e-01,  8.2213e-02]],

        [[-3.3604e-01,  7.7048e-03,  1.2669e-01,  ..., -2.4513e-01,
          -1.5715e-02,  2.1925e-01],
         [-3.0163e-01, -7.1671e-02,  9.0318e-02,  ..., -2.7427e-01,
           1.2674e-02,  1.9670e-01],
         [-2.7150e-01, -2.2134e-02,  1.2078e-01,  ..., -2.4565e-01,
          -2.4494e-02,  1.6907e-01],
         ...,
         [-4.2085e-01, -4

In [40]:
torch.tril(torch.ones(5,5))

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])

# Training

In [38]:
T =10
torch.tril(torch.ones(T, T))

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

# Inference