<a href="https://colab.research.google.com/github/archyyu/GPT-from-MLP-to-RNN-to-Transformer/blob/main/GPT_by_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [143]:
import requests
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import math

In [144]:
# Data I/O

url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
response = requests.get(url)
data = response.text

chars = list(set(data))
data_size, vocab_size = len(data), len(chars)
print(f'data has {data_size} characters, {vocab_size} unique.')

char_to_ix = {ch: i for i, ch in enumerate(chars)}
ix_to_char = {i: ch for i, ch in enumerate(chars)}

data has 1115394 characters, 65 unique.


In [145]:
class Head(nn.Module):
  def __init__(self, input_size, sequence_length, head_size):
    super(Head, self).__init__()
    self.C = input_size
    self.L = sequence_length
    self.head_size = head_size
    self.q = nn.Linear(self.C, head_size, bias=False)
    self.k = nn.Linear(self.C, head_size, bias=False)
    self.v = nn.Linear(self.C, head_size, bias=False)

  def forward(self, x):
    q = self.q(x)
    k = self.k(x)
    v = self.v(x)

    wei = q @ k.transpose(-2, -1) * (self.head_size ** -0.5)
    tril = torch.tril(torch.ones(self.L, self.L))
    wei = wei.masked_fill(tril == 0, float('-inf'))
    wei = F.softmax(wei, dim=-1)

    out = wei @ v
    return out


In [146]:
class PositionalEncoding(nn.Module):
  def __init__(self, seq_len, embed=512):
    super(PositionalEncoding, self).__init__()
    self.encoding = torch.zeros(seq_len, embed)
    position = torch.arange(0, seq_len).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, embed, 2).float() * -(math.log(10000.0) / embed))
    self.encoding[:, 0::2] = torch.sin(position * div_term)
    self.encoding[:, 1::2] = torch.cos(position * div_term)
    self.encoding = self.encoding.unsqueeze(0)

  def forward(self, x):
    return x + self.encoding[:, :x.size(1)].detach()

In [147]:
class MultiHeadAttention(nn.Module):
  def __init__(self, num_heads, input_size, sequence_length, head_size):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads

    self.heads = nn.ModuleList([
        Head(input_size, sequence_length, head_size) for _ in range(num_heads)
    ])

    self.final_linear = nn.Linear(num_heads * head_size, input_size)
    self.relu = nn.ReLU()

  def forward(self, x):

    head_outputs = [head(x) for head in self.heads]
    concatenated_output = torch.cat(head_outputs, dim=-1)
    final_output = self.relu(self.final_linear(concatenated_output))

    return final_output

In [148]:
class BlockAttention(nn.Module):
  def __init__(self, num_heads, input_size, sequence_length, head_size):
    super(BlockAttention, self).__init__()
    self.multiheads = MultiHeadAttention(num_heads, input_size, sequence_length, head_size)
    self.norm = nn.LayerNorm(input_size)

  def forward(self, x):
    inter_result = x + self.multiheads(x)
    final_output = self.norm(x + inter_result)
    return final_output

In [153]:
class Decoder(nn.Module):
  def __init__(self, num_heads, vocab_size, embedding_size, sequence_length, head_size):
    super(Decoder, self).__init__()

    self.em = nn.Embedding(vocab_size, embedding_size)
    self.pe = PositionalEncoding(vocab_size, embedding_size)

    self.blocks = nn.ModuleList([BlockAttention(num_heads, embedding_size, sequence_length, head_size) for _ in range(4)])
    self.fw = nn.Linear(sequence_length * embedding_size, vocab_size, bias=False)

  def forward(self, x):
    x = self.em(x)
    x = x + self.pe(x)
    for block in self.blocks:
      x = block(x)
    B,T,C = x.shape
    x = x.view(B,1,T*C)
    return self.fw(x)

In [135]:
batch_size = 4
seq_len = 8
em_dim = 64
# Example input tensor x = torch.randn(B,T,C)
x = torch.randn(batch_size, seq_len, em_dim)

# Initialize the multi Head
# head = MultiHeadAttention(num_heads, input_size, sequence_length, head_size)

head = PositionalEncoding(seq_len, em_dim)

# Forward pass
print(x[0][0])
output = head(x)

print(output[0][0])

print((output - x)[0][0])

tensor([ 0.5623, -1.5404, -0.6845,  0.9643,  0.6831, -0.8222, -0.1745, -0.6842])
tensor([ 0.5623, -0.5404, -0.6845,  1.9643,  0.6831,  0.1778, -0.1745,  0.3158])
tensor([0., 1., 0., 1., 0., 1., 0., 1.])


In [154]:
# Hyperparameters
hidden_size = 100
embedding_dim = 20
seq_length = 8
learning_rate = 1e-1
batch_size = 20
num_heads = 4
head_size = 12

criterion = nn.CrossEntropyLoss()
    #Decoder(num_heads, input_size, sequence_length, head_size)
model = Decoder(num_heads, vocab_size, embedding_dim, seq_length, head_size)
optimizer = optim.Adagrad(model.parameters(), lr=learning_rate)

In [155]:
def generate_mini_batch():
  # Assuming batch_size is a variable representing the desired batch size
  # and data is your input sequence data

  # Initialize lists to store input sequences and corresponding targets for the minibatch
  batch_inputs = []
  batch_targets = []

  # Loop to generate the minibatch
  for _ in range(batch_size):
    # Randomly select a starting point for the sequence
    p = np.random.randint(0, len(data) - seq_length - 1)

    # Extract a sequence of characters and convert them to indices
    inputs = torch.tensor([char_to_ix[ch] for ch in data[p:p + seq_length]], dtype=torch.long).view(1, -1)

    # Extract the target character and convert it to an index
    target = torch.tensor([char_to_ix[data[p + seq_length]]], dtype=torch.long).view(1, -1)

    # Append the input sequence and target to the minibatch lists
    batch_inputs.append(inputs)
    batch_targets.append(target)

  # Combine the lists into tensors to form the minibatch
  minibatch_inputs = torch.cat(batch_inputs, dim=0)
  minibatch_targets = torch.cat(batch_targets, dim=0)
  return minibatch_inputs, minibatch_targets

In [161]:
# Training loop
stopi = []
lossi = []
num_iterations = 5
for iteration in range(num_iterations):

  for p in range(len(data) - seq_length):

    # inputs = torch.tensor([char_to_ix[ch] for ch in data[p:p + seq_length]], dtype=torch.long).view(1, -1)
    # targets = torch.tensor([char_to_ix[ch] for ch in data[p + seq_length]], dtype=torch.long).view(-1)

    inputs, targets = generate_mini_batch()
    optimizer.zero_grad()
    predict_char = model(inputs)

    loss = criterion(predict_char.view(-1, 65), targets.view(-1))

    loss.backward()

    for param in model.parameters():
      if param.grad is not None:
        param.grad.data.clamp_(-5, 5)

    optimizer.step()

    if p % 2000 == 0:
      print(f'Iteration {(iteration + 1) * p}, Loss: {loss.item()}')
      stopi.append((iteration + 1) * p)
      lossi.append(loss.item())

Iteration 0, Loss: 1.6071850061416626
Iteration 2000, Loss: 1.9674190282821655
Iteration 4000, Loss: 2.1245081424713135
Iteration 6000, Loss: 1.7736564874649048
Iteration 8000, Loss: 1.9873167276382446


KeyboardInterrupt: 

In [164]:
start = "First And"

line = "First Citizen"
for i in range(1000):
  lll = start[-seq_length:]
  ll = torch.tensor([char_to_ix[ch] for ch in lll], dtype=torch.long).view(1, -1)
  outputs = model(ll)
  p = nn.functional.softmax(outputs, dim=-1).detach().numpy().ravel()
  ix = np.random.choice(range(vocab_size), p=p)
  ix = torch.tensor(ix, dtype=torch.long).view(1, 1)
  nc = ix_to_char[ix[0][0].item()]
  start += nc
  if nc == '\n':
    print(line)
    line = ""
    continue
  line += nc



First Citizeniosce,--yes mingter.

TcARICKK:
SoTh swento's as int your lood in,
Helick, and, fevired, ceat onf and thers are abe in ilemy.

BrAThe:
Nour and ivanenastent colse.-

BION ELUTABVLHUMEN:
O'd som.

KOTIUSWES:
Rerl weray for in whal I have arome happo.

IUWIED:
Moale,
The forded hatrakin.

SUwow anop wathersiund
And in with upot thise andmil. deds,
Dich's am is bamse:
That if the maapes it me'lire,
If make rind.

LULEO:
Noalf ilens
EvE whil
sive whraceinse not of capaecine I Ceavan surstent of mo-burstonal mer in monther-y aver;
And have heave himes'd no her onown ame be.

LLOMeNios, steatiy:
Leds wither ws camere, tit hoo- will stou spensings
pravint!

Sreich, me elere
Is a dand
My arem thriur your be daud houpe, and, cout it thourd
If lowe you tale
Fhast I twouseriin bostot
Sriwhochoo:
Far othallkif I's ay thigencourstent,
Selle be his
May mrigess a mest un thimerongory ada his ithunere
shos; No hipcelats; I in goorcoe! omisto:
Thin gondarngers, ard for and to
