In [363]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm
import re

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

block_size = 8
batch_size = 128
max_iters = 1000
learning_rate = 5e-4
eval_iters = 25
dropout = 0.1
n_embd = 256
n_layer = 4
n_head = 4

SAVE_PATH = "models/model_shakespear_words.pth"

cpu


In [364]:
text_list = []

BOOK_PATH = 'books/shakespear/'
with open(BOOK_PATH + 'julius_caesar.txt', 'r', encoding='utf-8') as f:
          text_list.append(f.read())
with open(BOOK_PATH + 'merchant_of_venice.txt', 'r', encoding='utf-8') as f:
          text_list.append(f.read())
with open(BOOK_PATH + 'romeo_and_juliet.txt', 'r', encoding='utf-8') as f:
          text_list.append(f.read())
with open(BOOK_PATH + 'macbeth.txt', 'r', encoding='utf-8') as f:
          text_list.append(f.read())
with open(BOOK_PATH + 'the_moor_of_venice.txt', 'r', encoding='utf-8') as f:
          text_list.append(f.read())

text = ""
for text_item in text_list:
    text += text_item

text[-1000:]

'ir medicinal gum. Set you down this.\nAnd say besides, that in Aleppo once,\nWhere a malignant and a turban’d Turk\nBeat a Venetian and traduc’d the state,\nI took by the throat the circumcised dog,\nAnd smote him, thus.\n\n[_Stabs himself._]\n\nLODOVICO.\nO bloody period!\n\nGRATIANO.\nAll that’s spoke is marr’d.\n\nOTHELLO.\nI kiss’d thee ere I kill’d thee. No way but this,\nKilling myself, to die upon a kiss.\n\n[_Falling upon Desdemona._]\n\nCASSIO.\nThis did I fear, but thought he had no weapon,\nFor he was great of heart.\n\nLODOVICO.\n[_To Iago._] O Spartan dog,\nMore fell than anguish, hunger, or the sea,\nLook on the tragic loading of this bed.\nThis is thy work. The object poisons sight,\nLet it be hid. Gratiano, keep the house,\nAnd seize upon the fortunes of the Moor,\nFor they succeed on you. To you, lord governor,\nRemains the censure of this hellish villain.\nThe time, the place, the torture, O, enforce it!\nMyself will straight aboard, and to the state\nThis heavy act 

In [365]:
def remove_text_in_brackets(text):
    pattern = r'\[.*?\]'
    cleaned_text = re.sub(pattern, '', text)
    return cleaned_text

def replace_commas_and_periods(text):
    # Replace ', ' with ' , '
    text = re.sub(r',\s', ' , ', text)
    # Replace '. ' with ' . '
    text = re.sub(r'\.\s', ' . ', text)
    return text

def lowercase_text(text):
    return text.lower()

In [366]:
# text = remove_text_in_brackets(text)
# text[-1000:]

In [367]:
text = replace_commas_and_periods(text)
text[-1000:]

' say besides , that in Aleppo once , Where a malignant and a turban’d Turk\nBeat a Venetian and traduc’d the state , I took by the throat the circumcised dog , And smote him , thus . \n[_Stabs himself._]\n\nLODOVICO . O bloody period!\n\nGRATIANO . All that’s spoke is marr’d . \nOTHELLO . I kiss’d thee ere I kill’d thee . No way but this , Killing myself , to die upon a kiss . \n[_Falling upon Desdemona._]\n\nCASSIO . This did I fear , but thought he had no weapon , For he was great of heart . \nLODOVICO . [_To Iago._] O Spartan dog , More fell than anguish , hunger , or the sea , Look on the tragic loading of this bed . This is thy work . The object poisons sight , Let it be hid . Gratiano , keep the house , And seize upon the fortunes of the Moor , For they succeed on you . To you , lord governor , Remains the censure of this hellish villain . The time , the place , the torture , O , enforce it!\nMyself will straight aboard , and to the state\nThis heavy act with heavy heart relate 

In [368]:
strings = text.split()
unique = set(strings)
vocab_size = len(unique)
print(vocab_size)
print(list(strings)[:10])

13674
['\ufeffACT', 'I', 'SCENE', 'I', '.', 'Rome', '.', 'A', 'street', '.']


In [369]:
string_to_int = { ch: i for i, ch in enumerate(unique) }
int_to_string = { i: ch for i, ch in enumerate(unique) }
encode = lambda s: [string_to_int[c] for c in s.split()]
decode = lambda l: ''.join([int_to_string[i]+" " for i in l])

encoded_hello = encode('thou a cobbler .')
decoded_hello = decode(encoded_hello)
print(encoded_hello)
print(decoded_hello)

[8448, 9374, 87, 2615]
thou a cobbler . 


In [370]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data[:100])

tensor([ 8372,  2294, 10073,  2294,  2615,  2249,  2615,  9386, 10013,  2615,
        13599,   327, 12387, 13202, 11798,  9374,  1866,  6510, 11520,  2615,
         6747,  2615,  8345, 10494, 12387, 10068, 11235, 13320, 12387, 11691,
        10068, 10494,  2615, 12373,   313,  9374, 12878,  5977, 12387,   629,
        10068,  8896, 12387,  9914, 13168, 12387, 10068, 13081,  8896,  7043,
        10865,  9374,  9444, 10839,  5097,  3448,  6663,  2052,  3898,  4059,
          884, 12387,  4199, 12815,  4780,  2998,  5244,  2615, 13273, 12387,
         2715, 12387,  9374,  3327,  2615, 12247,  2615,  7587,   258,  8065,
         7521,   265, 11798,  8065,  5906,  5977,  9681,  8448,  3342,  8065,
         8839,  6220, 11335,  1463, 12387,  2715, 12387,  4199, 12815,  1647])


In [371]:
n = int(len(data)*0.8)
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size, ))
    # print(ix)
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

x, y = get_batch('train')
print('inputs: ')
print(x)
print('targets: ')
print(y)

inputs: 
tensor([[12387,   467,  6288,  ..., 11626,  1938,  8896],
        [ 9125,  2511,  2615,  ..., 11912, 10091,  3448],
        [11798, 12387,  4315,  ..., 12387,  5527,  5214],
        ...,
        [ 3952,  6078,  4677,  ..., 12387, 10800, 12387],
        [ 7836, 12128, 13248,  ...,  5111, 10945, 12387],
        [ 2294, 12387, 11078,  ...,  3100,  2357,  2294]])
targets: 
tensor([[  467,  6288,  3448,  ...,  1938,  8896,  3342],
        [ 2511,  2615,  7296,  ..., 10091,  3448,  4905],
        [12387,  4315,  7401,  ...,  5527,  5214,  1730],
        ...,
        [ 6078,  4677,  3448,  ..., 10800, 12387,   239],
        [12128, 13248, 12387,  ..., 10945, 12387, 11078],
        [12387, 11078,  2474,  ...,  2357,  2294, 13097]])


In [372]:

x = train_data[:block_size]
y = train_data[1:block_size+1]

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    # print(f"When input is {context}, target is {target}")

In [373]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [374]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

In [375]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1) # (B, T, F) -> (B, T, [h1, h1, h1, h1, h2, h2, h2, h2, h3, h3, h3, h3])
        out = self.dropout(self.proj(out))
        return out

In [376]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [377]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        y = self.sa(x)
        x = self.ln1(x + y)
        y = self.ffwd(x)
        x = self.ln2(x + y)
        return x

In [378]:
class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)
        
        
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, index, targets=None):
        B, T = index.shape
        
        
        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(index) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, index, max_new_tokens):
        # index is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            index_cond = index[:, -block_size:]
            # get the predictions
            logits, loss = self.forward(index_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            index_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            index = torch.cat((index, index_next), dim=1) # (B, T+1)
        return index

model = GPTLanguageModel(vocab_size)
model = model.to(device)
print(model)

context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(context)
generated_chars = decode(model.generate(context, max_new_tokens=500)[0].tolist())
print(generated_chars)

GPTLanguageModel(
  (token_embedding_table): Embedding(13674, 256)
  (position_embedding_table): Embedding(8, 256)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-3): 4 x Head(
            (key): Linear(in_features=256, out_features=64, bias=False)
            (query): Linear(in_features=256, out_features=64, bias=False)
            (value): Linear(in_features=256, out_features=64, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (proj): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ffwd): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): ReLU()
          (2): Linear(in_features=1024, out_features=256, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
      (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine

In [379]:
# Create a PyTorch optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for iter in tqdm(range(max_iters)):

    if iter % eval_iters == 0:
        losses = estimate_loss()
        print(f"step: {iter} | train loss: {losses['train']:.3f} | val loss: {losses['val']:.3f}")

    # Sample a batch of data
    xb, yb = get_batch('train')

    # Evaluate the loss
    logits, loss = model.forward(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

  0%|                                                                                                             | 0/1000 [00:00<?, ?it/s]

step: 0 | train loss: 9.574 | val loss: 9.575


  2%|██▌                                                                                                 | 25/1000 [00:13<06:29,  2.51it/s]

step: 25 | train loss: 6.810 | val loss: 6.880


  5%|█████                                                                                               | 50/1000 [00:27<06:14,  2.54it/s]

step: 50 | train loss: 6.663 | val loss: 6.778


  8%|███████▌                                                                                            | 75/1000 [00:42<06:16,  2.46it/s]

step: 75 | train loss: 6.607 | val loss: 6.771


 10%|█████████▉                                                                                         | 100/1000 [00:55<06:02,  2.48it/s]

step: 100 | train loss: 6.178 | val loss: 6.462


 12%|████████████▍                                                                                      | 125/1000 [01:09<05:37,  2.59it/s]

step: 125 | train loss: 5.930 | val loss: 6.254


 15%|██████████████▊                                                                                    | 150/1000 [01:22<05:33,  2.55it/s]

step: 150 | train loss: 5.697 | val loss: 6.132


 18%|█████████████████▎                                                                                 | 175/1000 [01:36<05:38,  2.44it/s]

step: 175 | train loss: 5.503 | val loss: 6.063


 20%|███████████████████▊                                                                               | 200/1000 [01:50<05:30,  2.42it/s]

step: 200 | train loss: 5.380 | val loss: 6.025


 22%|██████████████████████▎                                                                            | 225/1000 [02:04<05:07,  2.52it/s]

step: 225 | train loss: 5.263 | val loss: 5.990


 25%|████████████████████████▊                                                                          | 250/1000 [02:18<04:45,  2.63it/s]

step: 250 | train loss: 5.130 | val loss: 5.981


 28%|███████████████████████████▏                                                                       | 275/1000 [02:31<04:38,  2.60it/s]

step: 275 | train loss: 5.033 | val loss: 6.024


 30%|█████████████████████████████▋                                                                     | 300/1000 [02:46<04:37,  2.52it/s]

step: 300 | train loss: 4.912 | val loss: 5.954


 32%|████████████████████████████████▏                                                                  | 325/1000 [02:59<04:16,  2.63it/s]

step: 325 | train loss: 4.821 | val loss: 5.962


 35%|██████████████████████████████████▋                                                                | 350/1000 [03:13<04:12,  2.57it/s]

step: 350 | train loss: 4.707 | val loss: 5.979


 38%|█████████████████████████████████████▏                                                             | 375/1000 [03:26<04:07,  2.53it/s]

step: 375 | train loss: 4.636 | val loss: 5.953


 40%|███████████████████████████████████████▌                                                           | 400/1000 [03:43<05:41,  1.76it/s]

step: 400 | train loss: 4.542 | val loss: 5.980


 42%|██████████████████████████████████████████                                                         | 425/1000 [03:58<03:44,  2.56it/s]

step: 425 | train loss: 4.442 | val loss: 5.965


 45%|████████████████████████████████████████████▌                                                      | 450/1000 [04:11<03:32,  2.59it/s]

step: 450 | train loss: 4.385 | val loss: 6.010


 48%|███████████████████████████████████████████████                                                    | 475/1000 [04:24<03:20,  2.62it/s]

step: 475 | train loss: 4.284 | val loss: 5.998


 50%|█████████████████████████████████████████████████▌                                                 | 500/1000 [04:38<03:28,  2.40it/s]

step: 500 | train loss: 4.230 | val loss: 6.060


 52%|███████████████████████████████████████████████████▉                                               | 525/1000 [04:53<03:18,  2.39it/s]

step: 525 | train loss: 4.135 | val loss: 6.106


 55%|██████████████████████████████████████████████████████▍                                            | 550/1000 [05:07<03:05,  2.43it/s]

step: 550 | train loss: 4.055 | val loss: 6.086


 57%|████████████████████████████████████████████████████████▉                                          | 575/1000 [05:23<03:02,  2.33it/s]

step: 575 | train loss: 3.979 | val loss: 6.139


 60%|███████████████████████████████████████████████████████████▍                                       | 600/1000 [05:38<02:50,  2.35it/s]

step: 600 | train loss: 3.924 | val loss: 6.142


 62%|█████████████████████████████████████████████████████████████▉                                     | 625/1000 [05:52<02:37,  2.38it/s]

step: 625 | train loss: 3.836 | val loss: 6.166


 65%|████████████████████████████████████████████████████████████████▎                                  | 650/1000 [06:07<02:25,  2.41it/s]

step: 650 | train loss: 3.786 | val loss: 6.244


 68%|██████████████████████████████████████████████████████████████████▊                                | 675/1000 [06:21<02:15,  2.41it/s]

step: 675 | train loss: 3.701 | val loss: 6.234


 70%|█████████████████████████████████████████████████████████████████████▎                             | 700/1000 [06:37<02:30,  2.00it/s]

step: 700 | train loss: 3.653 | val loss: 6.242


 72%|███████████████████████████████████████████████████████████████████████▊                           | 725/1000 [06:52<01:53,  2.42it/s]

step: 725 | train loss: 3.571 | val loss: 6.237


 75%|██████████████████████████████████████████████████████████████████████████▎                        | 750/1000 [07:07<01:50,  2.26it/s]

step: 750 | train loss: 3.513 | val loss: 6.331


 78%|████████████████████████████████████████████████████████████████████████████▋                      | 775/1000 [07:22<01:33,  2.41it/s]

step: 775 | train loss: 3.450 | val loss: 6.327


 80%|███████████████████████████████████████████████████████████████████████████████▏                   | 800/1000 [07:36<01:23,  2.40it/s]

step: 800 | train loss: 3.378 | val loss: 6.347


 82%|█████████████████████████████████████████████████████████████████████████████████▋                 | 825/1000 [07:51<01:11,  2.45it/s]

step: 825 | train loss: 3.330 | val loss: 6.405


 85%|████████████████████████████████████████████████████████████████████████████████████▏              | 850/1000 [08:05<01:02,  2.39it/s]

step: 850 | train loss: 3.267 | val loss: 6.482


 88%|██████████████████████████████████████████████████████████████████████████████████████▋            | 875/1000 [08:20<00:53,  2.32it/s]

step: 875 | train loss: 3.178 | val loss: 6.436


 90%|█████████████████████████████████████████████████████████████████████████████████████████          | 900/1000 [08:35<00:42,  2.35it/s]

step: 900 | train loss: 3.183 | val loss: 6.527


 92%|███████████████████████████████████████████████████████████████████████████████████████████▌       | 925/1000 [08:51<00:31,  2.39it/s]

step: 925 | train loss: 3.099 | val loss: 6.516


 95%|██████████████████████████████████████████████████████████████████████████████████████████████     | 950/1000 [09:05<00:21,  2.38it/s]

step: 950 | train loss: 3.013 | val loss: 6.521


 98%|████████████████████████████████████████████████████████████████████████████████████████████████▌  | 975/1000 [09:20<00:10,  2.37it/s]

step: 975 | train loss: 2.977 | val loss: 6.584


100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [09:35<00:00,  1.74it/s]

3.109978437423706





In [380]:
torch.save(model.state_dict(), SAVE_PATH)

In [381]:
model_load = GPTLanguageModel(vocab_size)
model_load.load_state_dict(torch.load(SAVE_PATH))

<All keys matched successfully>

In [382]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_text = decode(model_load.generate(context, max_new_tokens=500)[0].tolist())
formatted_text = ''
generated_text_split = generated_text.split()


for i, word in enumerate(generated_text_split):
    if word.isupper() and len(word) > 2:
        formatted_text += '\n\n' + word + '\n'
        if generated_text_split[i+1] == '.':
            generated_text_split[i+1] = ''
    else:
        formatted_text += word + ' '


formatted_text = re.sub(r' \.', '.', formatted_text)
formatted_text = re.sub(r' \,', ',', formatted_text)

with open('ai_gen_shakespear.txt', 'w') as f:
    f.write(formatted_text)
print(formatted_text)

Betroth’d would have cheer; Since you are welcome, And jest shall pay the book of Caesar shall stand Wherefore to these gliding ghosts, Belmont up our entrance: But when I holds this tyrant condition to the Prince’s doom, For shame, Lepidus but agent to this sweet flesh? 

SHYLOCK
 I’ll prove more to send to die with him. 

ROMEO
 So honours on this place? 

ROMEO
 Woe, at the woes, that that do they use. She all No boasting like the brain are vanished. mighty parted I have cried, “Help me, thievish ways. 

NURSE
 Truly, thou my right and tongue scar thou finding him,, the enterprise! thou, We must not down to deed? 

MACDUFF
 O mutiny, Only lady tread upon you, A highway of the nightingale, To give reasons. 

LORENZO
 And sail, they, but as yours: You sicken. Who’s there, to day in one, Which like a false thanes and others. 

BENVOLIO
 Ay, and then we may question? You smile know you, have stomachs. First, will you go? Hear me within. Merciful powers, conceive? 

ROMEO
 At weak childi