In [21]:
import torch
import torch.nn.functional as F
from torch import nn
import pandas as pd
import matplotlib.pyplot as plt # for making figures
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from pprint import pprint
import requests
import time
import os

In [22]:
##################################################################
block_size = 10# context length: how many characters do we take to predict the next one?
emb_dim = 15
#################################################################

In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [24]:
device

device(type='cuda')

In [25]:
#Download Shakespeare.txt file

url = "https://cs.stanford.edu/people/karpathy/char-rnn/shakespear.txt"
file_path = "shakespeare.txt"
if os.path.exists(file_path):
    print("File already exists.")
else:
    response = requests.get(url)
    if response.status_code == 200:
        with open(file_path, "wb") as file:
            file.write(response.content)
        print("File downloaded successfully.")
    else:
        print("Failed to download the file.")

File already exists.


In [26]:
with open("shakespeare.txt", "r") as file:
    content = file.read()

In [27]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(content)))
stoi = {s:i for i,s in enumerate(chars)}
# stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
pprint(itos)

{0: '\n',
 1: ' ',
 2: '!',
 3: "'",
 4: ',',
 5: '-',
 6: '.',
 7: ':',
 8: ';',
 9: '?',
 10: 'A',
 11: 'B',
 12: 'C',
 13: 'D',
 14: 'E',
 15: 'F',
 16: 'G',
 17: 'H',
 18: 'I',
 19: 'J',
 20: 'K',
 21: 'L',
 22: 'M',
 23: 'N',
 24: 'O',
 25: 'P',
 26: 'Q',
 27: 'R',
 28: 'S',
 29: 'T',
 30: 'U',
 31: 'V',
 32: 'W',
 33: 'X',
 34: 'Y',
 35: 'Z',
 36: 'a',
 37: 'b',
 38: 'c',
 39: 'd',
 40: 'e',
 41: 'f',
 42: 'g',
 43: 'h',
 44: 'i',
 45: 'j',
 46: 'k',
 47: 'l',
 48: 'm',
 49: 'n',
 50: 'o',
 51: 'p',
 52: 'q',
 53: 'r',
 54: 's',
 55: 't',
 56: 'u',
 57: 'v',
 58: 'w',
 59: 'x',
 60: 'y',
 61: 'z'}


In [28]:
#  block_size = 10# context length: how many characters do we take to predict the next one
X, Y = [], []


#print(w)
context = [1] * block_size
for ch in content:
  ix = stoi[ch]
  X.append(context)
  Y.append(ix)
  print(''.join(itos[i] for i in context), '--->', itos[ix])
  context = context[1:] + [ix] # crop and append

# Move data to GPU

X = torch.tensor(X).to(device)
Y = torch.tensor(Y).to(device)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
ess be a b ---> o
ss be a bo ---> y
s be a boy ---> ,
 be a boy, --->  
be a boy,  ---> s
e a boy, s ---> h
 a boy, sh ---> e
a boy, she --->  
 boy, she  ---> i
boy, she i ---> s
oy, she is --->  
y, she is  ---> n
, she is n ---> o
 she is no ---> t
she is not --->  
he is not  ---> h
e is not h ---> a
 is not ha ---> l
is not hal ---> f
s not half --->  
 not half  ---> a
not half a ---> b
ot half ab ---> a
t half aba ---> t
 half abat ---> e
half abate ---> d
alf abated ---> .
lf abated. ---> 

f abated.
 ---> 

 abated.

 ---> L
abated.

L ---> U
bated.

LU ---> C
ated.

LUC ---> I
ted.

LUCI ---> U
ed.

LUCIU ---> S
d.

LUCIUS ---> :
.

LUCIUS: ---> 



LUCIUS:
 ---> A

LUCIUS:
A ---> h
LUCIUS:
Ah ---> ,
UCIUS:
Ah, --->  
CIUS:
Ah,  ---> b
IUS:
Ah, b ---> y
US:
Ah, by --->  
S:
Ah, by  ---> m
:
Ah, by m ---> y

Ah, by my --->  
Ah, by my  ---> s
h, by my s ---> o
, by my so ---> u
 by my sou ---> l
by my soul ---> ,

In [29]:
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([99993, 10]), torch.int64, torch.Size([99993]), torch.int64)

In [30]:
# emb_dim = 15
emb = torch.nn.Embedding(len(stoi), emb_dim)
emb.weight
emb.weight.shape

torch.Size([62, 15])

In [31]:
# Function to visualize the embedding in 2d space

# def plot_emb(emb, itos, ax=None):
#     if ax is None:
#         fig, ax = plt.subplots()
#     for i in range(len(itos)):
#         x, y = emb.weight[i].detach().cpu().numpy()
#         ax.scatter(x, y, color='k')
#         ax.text(x + 0.05, y + 0.05, itos[i])
#     return ax

# plot_emb(emb, itos)

In [32]:
class NextChar(nn.Module):
  def __init__(self, block_size, vocab_size, emb_dim, hidden_size):
    super().__init__()
    self.emb = nn.Embedding(vocab_size, emb_dim)
    self.lin1 = nn.Linear(block_size * emb_dim, hidden_size)
    self.lin2 = nn.Linear(hidden_size, hidden_size)
    self.lin3 = nn.Linear(hidden_size, hidden_size)
    self.lin4 = nn.Linear(hidden_size, vocab_size)


  def forward(self, x):
    x = self.emb(x)
    x = x.view(x.shape[0], -1)
    x = torch.sin(self.lin1(x))
    x = torch.sin(self.lin2(x))
    x = torch.sin(self.lin3(x))
    x = self.lin4(x)
    return x

In [33]:
# Generate Text from untrained model


model = NextChar(block_size, len(stoi), emb_dim, 200).to(device)
# model = torch.compile(model)

g = torch.Generator()
g.manual_seed(4000002)
def generate_text(model, itos, stoi, block_size,iptxt="",max_len=10):
    context=[]
    if len(iptxt)<block_size:
      context=[1]*(block_size-len(iptxt))
      for char in iptxt:
        idx=stoi[char]
        context.append(idx)
    else:
      for i in range(len(iptxt)-block_size,len(iptxt)):
          idx=stoi[iptxt[i]]
          context.append(idx)
    txt = ''
    for i in range(max_len):
        x = torch.tensor(context).view(1, -1).to(device)
        y_pred = model(x)
        ix = torch.distributions.categorical.Categorical(logits=y_pred).sample().item()
        ch = itos[ix]
        txt += ch
        context = context[1:] + [ix]
    return txt

print(generate_text(model, itos, stoi, block_size))

Iu?V'pQuGV


In [None]:
# Check if the model weights file exists
weights_file = f"model_weights_b{block_size}_em{emb_dim}.pth"
if os.path.exists(weights_file):
    # Load the saved model weights
    model.load_state_dict(torch.load(weights_file))
    print("Model weights loaded successfully.")
else:
    #code for training
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.AdamW(model.parameters(), lr=0.001)
    import time
    # Mini-batch training
    batch_size = 4096
    print_every = 100
    elapsed_time = []
    for epoch in range(10000):
        start_time = time.time()
        for i in range(0, X.shape[0], batch_size):
            x = X[i:i+batch_size]
            y = Y[i:i+batch_size]
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            opt.step()
            opt.zero_grad()
        end_time = time.time()
        elapsed_time.append(end_time - start_time)
        if epoch % print_every == 0:
            print(epoch, loss.item())

    # Save the model weights
    torch.save(model.state_dict(), f"model_weights_b{block_size}_em{emb_dim}.pth")
    print("Model weights saved successfully.")

0 3.012286424636841
100 0.12492980062961578
200 0.06416352838277817
300 0.06027959659695625
400 0.05758288502693176
500 0.0560227669775486
600 0.053345512598752975
700 0.05385708063840866
800 0.05372879281640053
900 0.04917652904987335
1000 0.04977351427078247
1100 0.04711845517158508
1200 0.044986918568611145
1300 0.04736074432730675
1400 0.04550999402999878
1500 0.04679694399237633
1600 0.04302137717604637
1700 0.046337537467479706
1800 0.04750834032893181
1900 0.05211153253912926
2000 0.04526541754603386
2100 0.0486544594168663
2200 0.04722442477941513
2300 0.04781360924243927
2400 0.045048996806144714
2500 0.04742863029241562
2600 0.04476043954491615
2700 0.04273030906915665
2800 0.04788105562329292
2900 0.046252407133579254
3000 0.0460924431681633
3100 0.0458662174642086
3200 0.04774816706776619
3300 0.0463259331882
3400 0.0432104654610157
3500 0.046562276780605316


In [None]:
# Load the saved model weights
model.load_state_dict(torch.load(f"model_weights_b{block_size}_em{emb_dim}.pth"))

In [None]:
# Visualize the embedding

# plot_emb(model.emb, itos)

In [None]:
#generate text
inp='BRUTUS'
print(generate_text(model, itos, stoi,block_size,inp,1000+len(inp)))