In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
import re

In [None]:
# Path to the plain-text corpus.
# NOTE(review): hard-coded absolute path that looks Colab-specific — adjust it
# when running anywhere else.
file = r"/content/alice_in_wonderland.txt"
# Read the whole file into one string, decoding it as UTF-8.
with open(file, 'r' , encoding='utf-8') as f:
    data = f.read()

In [None]:
# Preview only the first 500 characters; echoing the whole corpus floods the
# cell output and bloats the saved notebook.
data[:500]



In [None]:
def tokenize(text):
  """Return the whitespace-separated tokens of ``text`` in lower case.

  Punctuation is not stripped, so it stays attached to the neighbouring word.
  """
  lowered = text.lower()
  return lowered.split()

# Tokenize the full corpus once; the vocabulary and the training windows built
# in later cells are both derived from this list.
tokenized_text = tokenize(data)
# Sanity check: show the first ten tokens.
print(tokenized_text[:10])

["alice's", 'adventures', 'in', 'wonderland', "alice's", 'adventures', 'in', 'wonderland', 'lewis', 'carroll']


In [None]:
# Build the vocabulary from the *sorted* set of tokens. Iterating a bare set of
# strings yields a different order in every Python process (string hashes are
# salted), which would give each word a different id after a kernel restart and
# make saved weights or cached tensors meaningless.
# Ids start at 1; id 0 is left unassigned (test() uses it for unknown words).
word_to_index = {word: i for i, word in enumerate(sorted(set(tokenized_text)), start=1)}
# Reverse lookup used to turn predicted ids back into words.
index_to_word = {i: word for word, i in word_to_index.items()}

In [None]:
# Number of distinct tokens in the corpus. Ids start at 1, so this is also the
# largest id; the model-setup cell further down recomputes vocab_size with +1
# so that id 0 also fits in the embedding table.
vocab_size = len(word_to_index)
vocab_size

4950

In [None]:
from contextlib import contextmanager
# Number of consecutive tokens the model sees before predicting the next one.
context_length = 10

# Map every token to its id once instead of once per window it appears in.
token_ids = [word_to_index[token] for token in tokenized_text]

# Slide a window over the corpus: window i is the input and the token that
# immediately follows it is the target. A corpus of at most context_length
# tokens yields no examples.
num_examples = max(len(token_ids) - context_length, 0)
input_seq = [token_ids[i:i + context_length] for i in range(num_examples)]
output_word = token_ids[context_length:]
assert len(input_seq) == len(output_word), "every window needs exactly one target"

# Show only a preview; printing every target floods the cell output.
print(output_word[:10])

[2538, 4218, 1195, 3603, 2174, 2734, 2820, 3700, 2538, 4274, 820, 2225, 164, 1244, 3195, 1837, 75, 3561, 2718, 1553, 2443, 2877, 946, 2538, 312, 3326, 3561, 413, 1515, 1244, 4327, 3331, 1161, 781, 2555, 181, 365, 4026, 2538, 30, 2443, 2877, 2225, 4013, 3605, 1, 181, 347, 3075, 1161, 4556, 3779, 2342, 4590, 3009, 4221, 2538, 2413, 3561, 4244, 4277, 1432, 820, 1324, 3075, 1161, 3526, 610, 2555, 2225, 1644, 3779, 2443, 4649, 516, 1188, 1051, 1330, 2555, 1583, 912, 2538, 2652, 2512, 4052, 2443, 2796, 1837, 1075, 3326, 1233, 3060, 2538, 4652, 3561, 2337, 4244, 4659, 2886, 86, 3602, 2538, 1612, 3561, 2228, 2268, 3326, 3181, 2538, 1637, 2274, 1260, 4244, 2338, 4524, 1086, 434, 3177, 553, 116, 1553, 2666, 4594, 2225, 1515, 610, 1837, 2077, 3779, 1639, 101, 4629, 820, 3636, 1, 610, 1837, 4597, 4735, 3561, 2538, 1557, 1244, 1375, 2538, 4524, 1999, 1244, 1179, 940, 3211, 3404, 3211, 2820, 2104, 86, 3886, 677, 2555, 1432, 1, 1292, 4187, 1, 850, 1244, 2443, 2111, 2555, 2036, 1244, 229, 1367, 1114, 

In [None]:
# Convert the windows and their targets to integer tensors; torch.long is used
# here as the index dtype for the embedding lookup and the class-index targets.
X = torch.tensor(input_seq, dtype=torch.long)
y = torch.tensor(output_word, dtype=torch.long)
# X is (num_examples, context_length); y holds one target id per example.
print(X.shape, y.shape)

torch.Size([26460, 10]) torch.Size([26460])


In [None]:
# Hold out 20% of the examples for evaluation; the fixed random_state makes the
# split repeatable.
# NOTE(review): the windows overlap, so a shuffled split puts near-duplicate
# contexts on both sides of the split. Confirm whether a contiguous split
# (shuffle=False) was intended.
X_train, x_test, Y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Display the target tensors of both partitions.
Y_train, y_test

(tensor([1770,  515, 3614,  ..., 1244, 1016, 2562]),
 tensor([2864, 3561, 4735,  ..., 1244, 4643,  803]))

In [None]:
# Pair each input window with its target so both can be indexed together.
train_data = TensorDataset(X_train, Y_train)
test_data = TensorDataset(x_test, y_test)

In [None]:
# Inspect one training example: a window of token ids and its target id.
print(train_data[1]) # input-output pair

(tensor([2306, 4052, 2443, 2126, 2268, 3779, 4244, 1131, 3914, 1844]), tensor(515))


In [None]:
# Wrap the datasets in mini-batch loaders of 64 examples; only the training
# loader is created with shuffle=True.
# NOTE(review): train_data / test_data are rebound from datasets to loaders
# here, so re-running this cell (or the indexing cell above) on the same kernel
# no longer operates on the datasets.
train_data = DataLoader(
    train_data,
    shuffle=True,
    batch_size=64
)

test_data = DataLoader(
    test_data,
    batch_size=64
)

# self-attention = softmax(Q @ K^T / sqrt(d_k)) @ V

In [None]:
class ScaledDotAttention(nn.Module):
  """Parameter-free scaled dot-product attention.

  Computes ``softmax(Q @ K^T / sqrt(hidden_size)) @ V``. No mask is applied.
  """

  def __init__(self, hidden_size):
    super().__init__()
    # Value whose square root divides the raw scores.
    self.hidden_size = hidden_size

  def forward(self, Q, K, V):
    """Return ``(output, attention_weights)`` for the given Q, K and V.

    The last two dimensions of K are swapped before the product, and the
    softmax runs over the final dimension of the scores.
    """
    scale = np.sqrt(self.hidden_size)
    scores = (Q @ K.transpose(-2, -1)) / scale
    weights = scores.softmax(dim=-1)
    return weights @ V, weights

In [None]:
class AutoCompletion(nn.Module):
  """Next-word predictor: embedding -> LSTM -> self-attention -> linear."""

  def __init__(self, vocab_size, embedding_dim, hidden_size):
    super().__init__()
    # Keep this creation order: it determines the order in which the random
    # initial weights are drawn.
    self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
    self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, batch_first=True)
    self.attention = ScaledDotAttention(hidden_size)
    self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

  def forward(self, x):
    """Return one row of vocabulary scores per input sequence.

    ``x`` holds integer word ids, presumably shaped (batch, seq_len) — the
    LSTM is batch-first and the time axis is indexed as dimension 1 below.
    """
    hidden_states, _ = self.lstm(self.embedding(x))
    # Self-attention: the LSTM outputs act as queries, keys and values.
    attended, _ = self.attention(hidden_states, hidden_states, hidden_states)
    # Only the last time step feeds the output layer.
    last_step = attended[:, -1, :]
    return self.fc(last_step)

In [None]:
# Seed PyTorch's global RNG so weight initialisation and the training loader's
# shuffling are repeatable across kernel restarts.
torch.manual_seed(42)

# +1 because word ids start at 1 and id 0 is reserved, so the embedding table
# and the output layer need len(word_to_index) + 1 rows.
vocab_size = len(word_to_index) + 1
embedding_dim = 100   # size of each word vector
hidden_size = 128     # LSTM hidden-state width
model = AutoCompletion(vocab_size, embedding_dim, hidden_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
def train_model(model, train_data, val_data, epochs, loss_fn=None, opt=None):
  """Train ``model`` for ``epochs`` epochs, printing train and validation loss.

  Args:
    model: module called as ``model(inputs)``; its output is fed to the loss.
    train_data: iterable of ``(inputs, labels)`` batches used for optimisation.
    val_data: iterable of ``(inputs, labels)`` batches, evaluated without
      gradient tracking after every epoch.
    epochs: number of passes over ``train_data``.
    loss_fn: callable ``loss_fn(output, labels)`` returning a scalar loss.
      Defaults to the notebook-level ``criterion``.
    opt: optimizer that updates ``model``'s parameters. Defaults to the
      notebook-level ``optimizer``.

  Reported losses are means over batches; a loader that yields no batches is
  reported as ``nan`` instead of raising ``ZeroDivisionError``.
  """
  # Fall back to the notebook globals so existing four-argument calls keep
  # working unchanged.
  if loss_fn is None:
    loss_fn = criterion
  if opt is None:
    opt = optimizer

  for epoch in range(epochs):
    # --- optimisation pass ---
    model.train()
    total_loss = 0.0
    train_batches = 0
    for inputs, labels in train_data:
      output = model(inputs)
      loss = loss_fn(output, labels)
      opt.zero_grad()
      loss.backward()
      opt.step()

      total_loss += loss.item()
      train_batches += 1

    avg_loss = total_loss / train_batches if train_batches else float("nan")
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss}")

    # --- evaluation pass (no gradients, eval-mode layers) ---
    model.eval()
    val_loss = 0.0
    val_batches = 0
    with torch.no_grad():
      for inputs, labels in val_data:
        output = model(inputs)
        val_loss += loss_fn(output, labels).item()
        val_batches += 1

    avg_val_loss = val_loss / val_batches if val_batches else float("nan")
    print(f"Validation Loss: {avg_val_loss}")

# Train for 8 epochs, reporting the loss on the held-out loader after each one.
train_model(model, train_data, test_data, 8)

Epoch 1/8, Loss: 0.007410095949738857
Validation Loss: 12.266580960836755
Epoch 2/8, Loss: 0.006356152926017375
Validation Loss: 12.302375230444483
Epoch 3/8, Loss: 0.005732576177086895
Validation Loss: 12.341338525335473


KeyboardInterrupt: 

In [None]:
def simple_tokenize(text):
  """Tokenize a prompt exactly the way the training corpus was tokenized.

  Delegates to ``tokenize`` so the two can never drift apart; a prompt split
  differently from the corpus would silently map known words to the unknown id.
  """
  return tokenize(text)

In [None]:
def test(model, start_text, word_to_index, index_to_word, max_length, context_size=None):
  """Greedily extend ``start_text`` by ``max_length`` predicted words.

  Args:
    model: module mapping a ``(1, seq_len)`` tensor of word ids to one row of
      per-word scores.
    start_text: prompt; tokenized with ``simple_tokenize``.
    word_to_index: word -> id mapping; words not in it are fed as id 0.
    index_to_word: id -> word mapping used to decode predictions.
    max_length: number of words to generate.
    context_size: how many of the most recent ids are fed to the model at each
      step. Defaults to the notebook-level ``context_length``.

  Returns:
    ``start_text`` followed by the generated words, separated by single spaces.

  Raises:
    ValueError: if ``start_text`` contains no tokens.
  """
  if context_size is None:
    context_size = context_length
  model.eval()

  prompt_tokens = simple_tokenize(start_text)
  if not prompt_tokens:
    # An empty prompt would hand the LSTM a zero-length sequence.
    raise ValueError("start_text must contain at least one token")
  token_indices = [word_to_index.get(token, 0) for token in prompt_tokens]

  generated_words = []
  with torch.no_grad():
    for _ in range(max_length):
      # Window every step, including the first, so a long prompt is handled
      # the same way as the ids generated later.
      window = token_indices[-context_size:]
      input_tensor = torch.tensor([window], dtype=torch.long)
      output = model(input_tensor)
      predicted_idx = output.argmax(1).item()
      # Id 0 has no entry in index_to_word (ids start at 1); decode it as a
      # placeholder instead of raising KeyError.
      generated_words.append(index_to_word.get(predicted_idx, "<unk>"))
      token_indices.append(predicted_idx)

  return " ".join([start_text] + generated_words)

In [None]:
# Generate 25 words after a short prompt. Prompt words missing from
# word_to_index are looked up as id 0 by test().
start_text = "So she wa"
completion = test(model, start_text, word_to_index, index_to_word, 25)
print(completion)

So she wa felt heard just she heard into the jar `and however, i where that would be half an offended worth afraid of yourself and put the
