In [2]:
import torch
import torch.nn.functional as F
from torch import nn
import pandas as pd
import re
import matplotlib.pyplot as plt # for making figures
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from pprint import pprint

In [3]:
torch.__version__

'2.4.0+cpu'

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
device

device(type='cpu')

In [6]:
with open ("ptb.train.txt",'r') as data:
    text = data.read()


In [7]:
text[:20]

' aer banknote berlit'

In [8]:
def preprocess_text(text):
    text = re.sub('\n', '.', text)
    text = re.sub('[^a-zA-Z0-9 /.]', '', text)

    text = text.lower()
    sentences = text.split('.')
    word_sequences = [sentence.strip().split() for sentence in sentences if sentence.strip()]
    return word_sequences

In [9]:
p_text = preprocess_text(str(text))
p_text[:5]

[['aer',
  'banknote',
  'berlitz',
  'calloway',
  'centrust',
  'cluett',
  'fromstein',
  'gitano',
  'guterman',
  'hydroquebec',
  'ipo',
  'kia',
  'memotec',
  'mlx',
  'nahb',
  'punts',
  'rake',
  'regatta',
  'rubens',
  'sim',
  'snackfood',
  'ssangyong',
  'swapo',
  'wachter'],
 ['pierre',
  'unk',
  'n',
  'years',
  'old',
  'will',
  'join',
  'the',
  'board',
  'as',
  'a',
  'nonexecutive',
  'director',
  'nov'],
 ['n'],
 ['mr'],
 ['unk', 'is', 'chairman', 'of', 'unk', 'n']]

In [10]:
len(p_text)

58685

In [11]:
# build the vocabulary of characters and mappings to/from integers
comb = []
for sentence in p_text:
    comb += sentence
chars = sorted(list(set(comb)))
print(chars[:10])

['1/2year', '100share', '10year', '12month', '12year', '13th', '13week', '14yearold', '190', '190point']


In [12]:
len(chars)

9906

In [13]:
from collections import Counter

# Build a vocabulary from all words
all_words = [word for sentence in p_text for word in sentence]
word_counts = Counter(all_words)
vocab = sorted(word_counts)  # List of unique words in alphabetical order
vocab_size = len(vocab)

# Word-to-index and index-to-word mappings
word_to_index = {word: idx for idx, word in enumerate(vocab)}
word_to_index["."] = 0
index_to_word = {idx: word for word, idx in word_to_index.items()}


In [14]:
# pprint(index_to_word)
for i in range(10):
    print(i, index_to_word[i])

0 .
1 100share
2 10year
3 12month
4 12year
5 13th
6 13week
7 14yearold
8 190
9 190point


In [15]:
def create_word_pairs(sequences, context_length):
    inputs = []
    outputs = []

    for sentence in sequences:
        context = ["."] * context_length
        for word in sentence + ["."]:
            inputs.append([word_to_index[i] for i in context])
            outputs.append(word_to_index[word])
            print(f"{' '.join(context)} ----> {word}")
            context.pop(0)
            context.append(word)
        print()

#     X = torch.tensor(inputs).to(device)
#     Y = torch.tensor(outputs).to(device)
    return inputs,outputs
# inputs, outputs

In [16]:
context_length = 4
inputs, outputs = create_word_pairs(p_text[:10], context_length)

. . . . ----> aer
. . . aer ----> banknote
. . aer banknote ----> berlitz
. aer banknote berlitz ----> calloway
aer banknote berlitz calloway ----> centrust
banknote berlitz calloway centrust ----> cluett
berlitz calloway centrust cluett ----> fromstein
calloway centrust cluett fromstein ----> gitano
centrust cluett fromstein gitano ----> guterman
cluett fromstein gitano guterman ----> hydroquebec
fromstein gitano guterman hydroquebec ----> ipo
gitano guterman hydroquebec ipo ----> kia
guterman hydroquebec ipo kia ----> memotec
hydroquebec ipo kia memotec ----> mlx
ipo kia memotec mlx ----> nahb
kia memotec mlx nahb ----> punts
memotec mlx nahb punts ----> rake
mlx nahb punts rake ----> regatta
nahb punts rake regatta ----> rubens
punts rake regatta rubens ----> sim
rake regatta rubens sim ----> snackfood
regatta rubens sim snackfood ----> ssangyong
rubens sim snackfood ssangyong ----> swapo
sim snackfood ssangyong swapo ----> wachter
snackfood ssangyong swapo wachter ----> .

. . . . 

In [17]:
inputs[:10],outputs[:10]
# visualization

([[0, 0, 0, 0],
  [0, 0, 0, 216],
  [0, 0, 216, 784],
  [0, 216, 784, 927],
  [216, 784, 927, 1297],
  [784, 927, 1297, 1448],
  [927, 1297, 1448, 1662],
  [1297, 1448, 1662, 3732],
  [1448, 1662, 3732, 3878],
  [1662, 3732, 3878, 4025]],
 [216, 784, 927, 1297, 1448, 1662, 3732, 3878, 4025, 4335])

In [18]:
# # Assume you already have a vocabulary built
# vocab = sorted(set(word for sentence in p_text for word in sentence))
# vocab_size = len(vocab)

# # Create word-to-index and index-to-word mappings
# word_to_index = {word: idx for idx, word in enumerate(vocab)}
# index_to_word = {idx: word for word, idx in word_to_index.items()}

# # Convert inputs and outputs to indices
# X = [[word_to_index[word] for word in context] for context in inputs]
# y = [word_to_index[word] for word in outputs]


In [19]:
p_text[10]

['unk', 'inc']

In [20]:

# Convert inputs and outputs to numerical format (using word indices)
# X = [[word_to_index[word] for word in input_seq] for input_seq in inputs]
# y = [word_to_index[output_word] for output_word in outputs]


In [21]:
# X[:10], y[:10]

In [22]:
X = torch.tensor(inputs).to(device)
Y = torch.tensor(outputs).to(device)

In [23]:
X[:10], Y[:10]

(tensor([[   0,    0,    0,    0],
         [   0,    0,    0,  216],
         [   0,    0,  216,  784],
         [   0,  216,  784,  927],
         [ 216,  784,  927, 1297],
         [ 784,  927, 1297, 1448],
         [ 927, 1297, 1448, 1662],
         [1297, 1448, 1662, 3732],
         [1448, 1662, 3732, 3878],
         [1662, 3732, 3878, 4025]]),
 tensor([ 216,  784,  927, 1297, 1448, 1662, 3732, 3878, 4025, 4335]))

In [24]:
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([145, 4]), torch.int64, torch.Size([145]), torch.int64)

In [25]:
emb_dim = 200
emb = torch.nn.Embedding(len(word_to_index), emb_dim)

In [26]:
emb.weight

Parameter containing:
tensor([[-0.3295, -1.2150,  0.3884,  ...,  0.6324, -0.6157, -0.7371],
        [-1.1434,  0.4812, -0.5593,  ...,  0.1648,  1.1026,  1.3337],
        [ 1.3347, -1.0056, -0.4153,  ..., -1.4349, -1.3844,  0.7107],
        ...,
        [ 0.7129,  1.8080,  0.3417,  ...,  0.8112, -0.9758, -0.7952],
        [-0.1294, -0.8390, -1.1079,  ...,  2.1252, -1.1976, -0.2510],
        [ 0.4036, -0.4249,  1.4104,  ...,  0.4915, -1.3398,  0.1747]],
       requires_grad=True)

In [27]:
emb.weight.shape

torch.Size([9907, 200])

In [28]:
# def plot_emb(emb, index_to_word, ax=None):
#     if ax is None:
#         fig, ax = plt.subplots()
#     for i in range(100): # it should be len(index_to_words) but its too long
#         x, y = emb.weight[i].detach().cpu().numpy()
#         ax.scatter(x, y, color='k')
#         ax.text(x + 0.05, y + 0.05, index_to_word[i])
#     return ax

# plot_emb(emb, index_to_word)

In [29]:
class NextWord(nn.Module):
  def __init__(self, block_size, vocab_size, emb_dim, hidden_size):
    super().__init__()
    self.emb = nn.Embedding(vocab_size, emb_dim)
    self.lin1 = nn.Linear(block_size * emb_dim, hidden_size)
    self.lin2 = nn.Linear(hidden_size, vocab_size)

  def forward(self, x):
    x = self.emb(x)
    x = x.view(x.shape[0], -1)
    x = torch.sin(self.lin1(x))
    x = self.lin2(x)
    return x

In [30]:
model = NextWord(context_length, len(word_to_index), emb_dim, 10).to(device)
model = torch.compile(model)

g = torch.Generator()
g.manual_seed(4000002)
def generate_word(text,model, index_to_word, word_to_index, block_size, max_len=15):
#     context = [word_to_index["."]] * block_size  # Initialize context with the index of the "." character
    context = [word_to_index[i.lower()] for i in text.split()]
#     print(context)
    sentence = []
    
    for i in range(max_len):
        x = torch.tensor(context).view(1, -1).to(device)  # Convert context to a tensor of word indices
        y_pred = model(x)  # Get model prediction
        ix = torch.distributions.categorical.Categorical(logits=y_pred).sample().item()  # Sample a word index
        word = index_to_word[ix]  # Convert index back to word
        
        if word == '.':  # If the predicted word is a period, stop generating the sentence
            break
        
        sentence.append(word)  # Add word to the sentence
        context = context[1:] + [ix]  # Update the context for the next iteration
    
    return sentence

# Generate and print 10 sentences
for i in range(10):
    print(" ".join(generate_word("yesterday at home i",model, index_to_word, word_to_index, context_length)))


lent fannie collateralized mission rampant reacting atoms piano season everything volatile codes thick parties downtown
insiders presidential towns singapore cananea ernst full johns already ltd invest consolidation ordering passage united
manipulation ems guild affidavits parity poured believes burgess pending travelers advertiser valley irs ounces ambitions
towns n allocated pretrial stiff quack not expression conn pilot craft entrepreneurial variety based sued
gauge stood republic census doman environmentalists cans ryder rupert dell back trail berry providing might
turning conversion supermarket falling historically personalinjury exercising stay reconciliation anticipate earlier orleans underwear marking expression
trotter weakening edgar apply hanson split edelman decide nerves authors facing shame somalia 1950s proponents
torrijos ethical breakdown sotheby blue australian blocks view mines environmentally nl wohlstetter historic bellsouth dover
goodman reunification concludes go

In [31]:
for param_name, param in model.named_parameters():
    print(param_name, param.shape)

_orig_mod.emb.weight torch.Size([9907, 200])
_orig_mod.lin1.weight torch.Size([10, 800])
_orig_mod.lin1.bias torch.Size([10])
_orig_mod.lin2.weight torch.Size([9907, 10])
_orig_mod.lin2.bias torch.Size([9907])


In [32]:
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.AdamW(model.parameters(), lr=0.01)
import time
# Mini-batch training
batch_size = 4096
print_every = 100
elapsed_time = []
for epoch in range(10000): #10000
    start_time = time.time()
    for i in range(0, X.shape[0], batch_size):
        x = X[i:i+batch_size]
        y = Y[i:i+batch_size]
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        opt.step()
        opt.zero_grad()
    end_time = time.time()
    elapsed_time.append(end_time - start_time)
    if epoch % print_every == 0:
        print(epoch, loss.item())

0 9.219916343688965
100 0.6534655094146729
200 0.33271458745002747
300 0.2473289519548416
400 0.22496284544467926
500 0.2015882432460785
600 0.19215095043182373
700 0.18759483098983765
800 0.18444962799549103
900 0.18211770057678223
1000 0.18032091856002808
1100 0.17889854311943054
1200 0.17774726450443268
1300 0.17679615318775177
1400 0.17599962651729584
1500 0.18428118526935577
1600 0.1747952252626419
1700 0.1738111674785614
1800 0.17329801619052887
1900 0.17294709384441376
2000 0.17267419397830963
2100 0.17244568467140198
2200 0.1722450703382492
2300 0.17206329107284546
2400 0.1718950867652893
2500 0.17173713445663452
2600 0.17158742249011993
2700 0.1714446246623993
2800 0.1713079959154129
2900 0.17117701470851898
3000 0.17105133831501007
3100 0.17093071341514587
3200 0.17081496119499207
3300 0.1707039177417755
3400 0.17088454961776733
3500 0.17058344185352325
3600 0.1704515963792801
3700 0.1703411340713501
3800 0.17024213075637817
3900 0.17015111446380615
4000 0.17030470073223114
4

In [3]:
text = ["today sun rises slowly", "he enjoy reading books","He will ran very","She love playing music","They are working hard","The dog run very","We visited the park","It was shining heavily","You should eat healthy","He drives the car","he did not want"]
for i in range(11):
    print(" ".join(generate_word(text[i],model, index_to_word, word_to_index, context_length)))

In [None]:
X.shape, X.dtype, Y.shape, Y.dtype