In [39]:
import torch
import torch.nn.functional as F
from torch import nn
import pandas as pd
import matplotlib.pyplot as plt # for making figures
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from pprint import pprint
import requests

In [40]:
# To create streamlit application
!pip install streamlit
import streamlit as st



In [41]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [42]:
device

device(type='cuda')

In [43]:
#Download Shakespeare.txt file

url = "https://cs.stanford.edu/people/karpathy/char-rnn/shakespear.txt"
file_path = "shakespeare.txt"

response = requests.get(url)
if response.status_code == 200:
    with open(file_path, "wb") as file:
        file.write(response.content)
    print("File downloaded successfully.")
else:
    print("Failed to download the file.")

File downloaded successfully.


In [44]:
with open("shakespeare.txt", "r") as file:
    content = file.read()

In [45]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(content)))
stoi = {s:i for i,s in enumerate(chars)}
# stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
pprint(itos)

{0: '\n',
 1: ' ',
 2: '!',
 3: "'",
 4: ',',
 5: '-',
 6: '.',
 7: ':',
 8: ';',
 9: '?',
 10: 'A',
 11: 'B',
 12: 'C',
 13: 'D',
 14: 'E',
 15: 'F',
 16: 'G',
 17: 'H',
 18: 'I',
 19: 'J',
 20: 'K',
 21: 'L',
 22: 'M',
 23: 'N',
 24: 'O',
 25: 'P',
 26: 'Q',
 27: 'R',
 28: 'S',
 29: 'T',
 30: 'U',
 31: 'V',
 32: 'W',
 33: 'X',
 34: 'Y',
 35: 'Z',
 36: 'a',
 37: 'b',
 38: 'c',
 39: 'd',
 40: 'e',
 41: 'f',
 42: 'g',
 43: 'h',
 44: 'i',
 45: 'j',
 46: 'k',
 47: 'l',
 48: 'm',
 49: 'n',
 50: 'o',
 51: 'p',
 52: 'q',
 53: 'r',
 54: 's',
 55: 't',
 56: 'u',
 57: 'v',
 58: 'w',
 59: 'x',
 60: 'y',
 61: 'z'}


In [46]:
block_size = 20 # context length: how many characters do we take to predict the next one?
X, Y = [], []


#print(w)
context = [1] * block_size
for ch in content:
  ix = stoi[ch]
  X.append(context)
  Y.append(ix)
  # print(''.join(itos[i] for i in context), '--->', itos[ix])
  context = context[1:] + [ix] # crop and append

# Move data to GPU

X = torch.tensor(X).to(device)
Y = torch.tensor(Y).to(device)

In [47]:
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([99993, 20]), torch.int64, torch.Size([99993]), torch.int64)

In [48]:
emb_dim = 8
emb = torch.nn.Embedding(len(stoi), emb_dim)
emb.weight
emb.weight.shape

torch.Size([62, 8])

In [49]:
# Function to visualize the embedding in 2d space

# def plot_emb(emb, itos, ax=None):
#     if ax is None:
#         fig, ax = plt.subplots()
#     for i in range(len(itos)):
#         x, y = emb.weight[i].detach().cpu().numpy()
#         ax.scatter(x, y, color='k')
#         ax.text(x + 0.05, y + 0.05, itos[i])
#     return ax

# plot_emb(emb, itos)

In [50]:
class NextChar(nn.Module):
  def __init__(self, block_size, vocab_size, emb_dim, hidden_size):
    super().__init__()
    self.emb = nn.Embedding(vocab_size, emb_dim)
    self.lin1 = nn.Linear(block_size * emb_dim, hidden_size)
    self.lin2 = nn.Linear(hidden_size, hidden_size)
    self.lin3 = nn.Linear(hidden_size, hidden_size)
    self.lin4 = nn.Linear(hidden_size, vocab_size)


  def forward(self, x):
    x = self.emb(x)
    x = x.view(x.shape[0], -1)
    x = torch.sin(self.lin1(x))
    x = torch.sin(self.lin2(x))
    x = torch.sin(self.lin3(x))
    x = self.lin4(x)
    return x

In [51]:
# Generate Text from untrained model


model = NextChar(block_size, len(stoi), emb_dim, 10).to(device)
# model = torch.compile(model)

g = torch.Generator()
g.manual_seed(4000002)
def generate_text(model, itos, stoi, block_size,iptxt="",max_len=10):
    context=[]
    if len(iptxt)<block_size:
      context=[1]*(block_size-len(iptxt))
      for char in iptxt:
        idx=stoi[char]
        context.append(idx)
    else:
      for i in range(len(iptxt)-block_size,len(iptxt)):
          idx=stoi[iptxt[i]]
          context.append(idx)
    txt = ''
    for i in range(max_len):
        x = torch.tensor(context).view(1, -1).to(device)
        y_pred = model(x)
        ix = torch.distributions.categorical.Categorical(logits=y_pred).sample().item()
        ch = itos[ix]
        txt += ch
        context = context[1:] + [ix]
    return txt

print(generate_text(model, itos, stoi, block_size))

s P
.:Qiae


In [52]:
for param_name, param in model.named_parameters():
    print(param_name, param.shape)

emb.weight torch.Size([62, 8])
lin1.weight torch.Size([10, 160])
lin1.bias torch.Size([10])
lin2.weight torch.Size([10, 10])
lin2.bias torch.Size([10])
lin3.weight torch.Size([10, 10])
lin3.bias torch.Size([10])
lin4.weight torch.Size([62, 10])
lin4.bias torch.Size([62])


In [53]:
#code for training
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.AdamW(model.parameters(), lr=0.01)
import time
# Mini-batch training
batch_size = 4096
print_every = 100
elapsed_time = []
for epoch in range(1000):
    start_time = time.time()
    for i in range(0, X.shape[0], batch_size):
        x = X[i:i+batch_size]
        y = Y[i:i+batch_size]
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        opt.step()
        opt.zero_grad()
    end_time = time.time()
    elapsed_time.append(end_time - start_time)
    if epoch % print_every == 0:
        print(epoch, loss.item())

# Save the model weights
torch.save(model.state_dict(), "model_weights.pth")
print("Model weights saved successfully.")

0 3.2594690322875977
100 2.047478675842285
200 2.0254297256469727
300 2.013810157775879
400 2.016414165496826
500 2.022385358810425
600 2.0284018516540527
700 2.0320067405700684
800 2.0277352333068848
900 2.026193618774414
Model weights saved successfully.


In [54]:
# Load the saved model weights
model.load_state_dict(torch.load("model_weights.pth"))

<All keys matched successfully>

In [55]:
# Visualize the embedding

# plot_emb(model.emb, itos)

In [56]:
#generate text
inp='uu'
print(generate_text(model, itos, stoi,block_size,inp,100+len(inp)))

dtet, ingfp?

Gast stine,
Moneepronced pare firchied loxt, bult of ple
Lingund sissio srosonget o thy 


In [57]:
# Streamlit UI
st.title("Next Character Predictor")

input_text = st.text_input("Enter your input text:")
k = st.slider("Number of characters to predict:", min_value=1, max_value=20, value=5)

if st.button("Predict"):
    if input_text:
        predicted_text = generate_text(input_text, model, itos, stoi, block_size, k)
        st.write("Predicted Text:", predicted_text)
    else:
        st.warning("Please enter some input text.")