In [None]:
!pip install nltk



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from nltk.tokenize import word_tokenize
import nltk

In [None]:
# Training corpus: a four-paragraph plot summary. Blank lines separate the paragraphs,
# and each line of this string is later treated as one training "sentence".
document = """The townsfolk discern the existence of Sarkata after he abducts Bittu's girlfriend Chitti. The enigmatic woman, who vanished from the bus, briefly reappears to Vicky and cautions him about Sarkata. Vicky, Rudra and Bittu examine the missing pages of 'Chanderi Puran' and discover that Sarkata is the malevolent ghost of Chandrabhan, a former chieftain of Chanderi, who murdered Stree and her partner while their young daughter spectated helplessly. He was murdered by Stree but was destined to emerge from the dead after her supposed exit from the town; he despises progressive women and targets them. Realizing that they need Stree's assistance in defeating him, Vicky and Bittu meet with Jana, who was once a medium for Stree, at his werewolf cousin Bhaskar's residence and convince him to return home for a "stage play".

They search for Stree at her lair, where Sarkata attacks Jana and the latter briefly has an out-of-body experience; his soul visits Sarkata's lair, where he finds the missing woman. His soul returns to his body and the group leaves the fort with Sarkata in pursuit. The woman shows up again and fights Sarkata, having acquired Stree's powers through her braid but he destroys Stree's statue to indicate his authority. Women of the town fear getting abducted and urge Vicky to find a solution; The woman motivates Vicky and reminds him that he is the saviour of the town. She gives him a mystical dagger that is capable of killing Sarkata. To lure Sarkata, they arrange for a dance performance by Rudra's beautiful lover Shama. Sarkata appears to abduct her but Vicky fails to muster the courage to kill him; he subsequently enchants the men of the town and abducts Shama, devastating Rudra. Elsewhere, the bewitched men turn chauvinistic and begin to dominate women in their homes.

Desperate to rescue the town, the group traces the writer of the letter to a mental asylum, where they realize that he is the descendant of Sarkata. He reveals that Sarkata could only be defeated by a person, who is neither a man nor a woman but both. The woman merges her soul with Vicky for the time being and they enter Sarkata's lair. They confront Sarkata and sever his head but each dismembered part forms a new Sarkata and creates terror. Bhaskar comes to their rescue but they are still overwhelmed by Sarkata. Having no choice, the woman calls for Stree, who is revealed to be her mother. Stree arrives and presumably kills him by dragging Sarkata into lava, releasing the abducted women and disenchanting the men of Chanderi. While the townsfolk celebrate, the woman reveals herself to be a ghost to Vicky and that her true purpose is to help her mother attain salvation. She finally whispers her name in Vicky's ears and promises to meet again.

In a post-credits scene, after his fight with Sarkata, Bhaskar finds himself stranded naked in a jungle. Jana brings him clothes and learns from Bhaskar that a creature, apparently a vampire, has been wreaking havoc in Delhi. Elsewhere, the remains of Sarkata reach his descendant at the mental asylum and his ghost possesses him, implying his return.
"""


In [None]:
# Tokenization: download the NLTK Punkt resources used by word_tokenize below.
# NOTE(review): 'punkt_tab' appears to be needed by newer NLTK releases in addition
# to 'punkt' — confirm against the installed NLTK version.
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [None]:
# Lower-case the whole corpus, then split it into word and punctuation tokens.
tokens = word_tokenize(text=document.lower())

In [None]:
# Build the vocabulary. '<unk>' is reserved as id 0 (id 0 is also used as the padding
# value later). Every distinct token then receives the next free id, in order of first
# appearance in the corpus.
vocab = {'<unk>': 0}
for token in dict.fromkeys(tokens):
  vocab.setdefault(token, len(vocab))

vocab

{'<unk>': 0,
 'the': 1,
 'townsfolk': 2,
 'discern': 3,
 'existence': 4,
 'of': 5,
 'sarkata': 6,
 'after': 7,
 'he': 8,
 'abducts': 9,
 'bittu': 10,
 "'s": 11,
 'girlfriend': 12,
 'chitti': 13,
 '.': 14,
 'enigmatic': 15,
 'woman': 16,
 ',': 17,
 'who': 18,
 'vanished': 19,
 'from': 20,
 'bus': 21,
 'briefly': 22,
 'reappears': 23,
 'to': 24,
 'vicky': 25,
 'and': 26,
 'cautions': 27,
 'him': 28,
 'about': 29,
 'rudra': 30,
 'examine': 31,
 'missing': 32,
 'pages': 33,
 "'chanderi": 34,
 'puran': 35,
 "'": 36,
 'discover': 37,
 'that': 38,
 'is': 39,
 'malevolent': 40,
 'ghost': 41,
 'chandrabhan': 42,
 'a': 43,
 'former': 44,
 'chieftain': 45,
 'chanderi': 46,
 'murdered': 47,
 'stree': 48,
 'her': 49,
 'partner': 50,
 'while': 51,
 'their': 52,
 'young': 53,
 'daughter': 54,
 'spectated': 55,
 'helplessly': 56,
 'was': 57,
 'by': 58,
 'but': 59,
 'destined': 60,
 'emerge': 61,
 'dead': 62,
 'supposed': 63,
 'exit': 64,
 'town': 65,
 ';': 66,
 'despises': 67,
 'progressive': 68,
 'wo

In [None]:
# Vocabulary size, including the reserved '<unk>' entry.
len(vocab)

250

In [None]:
# Split the corpus on newlines, so each paragraph becomes one "sentence". The blank
# separator lines and the trailing newline yield empty strings, which produce no
# training sequences further down.
input_sentences = document.split(sep='\n')

In [None]:
def text_to_indices(sentence, vocab):
  """Map each token of ``sentence`` (a sequence of token strings) to its vocabulary id.

  Tokens that are not keys of ``vocab`` are mapped to ``vocab['<unk>']``, and that
  lookup only happens when such a token is encountered.

  Returns a new list of ids, one per input token, in input order.
  """
  return [
      vocab[token] if token in vocab else vocab['<unk>']
      for token in sentence
  ]


In [None]:
# Tokenize every lower-cased corpus line and convert its tokens to vocabulary ids.
input_numerical_sentences = [
    text_to_indices(word_tokenize(sentence.lower()), vocab)
    for sentence in input_sentences
]

In [None]:
# Number of corpus lines, including the empty entries produced by blank lines.
len(input_numerical_sentences)

8

In [None]:
# Every prefix of length >= 2 of each line becomes one training example: the last id of
# the prefix is the target and the ids before it are the context.
training_sequence = [
    sentence[:end]
    for sentence in input_numerical_sentences
    for end in range(2, len(sentence) + 1)
]

In [None]:
# Total number of (context, target) training examples.
len(training_sequence)

590

In [None]:
# Peek at the first prefixes to check that they grow by one token at a time.
training_sequence[:5]

[[1, 2], [1, 2, 3], [1, 2, 3, 1], [1, 2, 3, 1, 4], [1, 2, 3, 1, 4, 5]]

In [None]:
# Length of every training prefix; the maximum determines the padded width.
len_list = [len(sequence) for sequence in training_sequence]

max(len_list)

188

In [None]:
# The shortest example: one context id followed by its target id.
training_sequence[0]

[1, 2]

In [None]:
# Left-pad every prefix with id 0 up to the longest prefix length, so all rows share one
# width and the target id stays in the last column. The width is computed once instead
# of once per row (the original re-ran max() inside the loop).
pad_width = max(len_list)
padded_training_sequence = [
    [0] * (pad_width - len(sequence)) + sequence
    for sequence in training_sequence
]

In [None]:
# Sanity check: a padded row should be as wide as the longest prefix.
len(padded_training_sequence[10])

188

In [None]:
# Convert the padded rows into a 64-bit integer tensor of token ids for nn.Embedding.
padded_training_sequence = torch.tensor(padded_training_sequence, dtype=torch.int64)

In [None]:
# Inspect the padded id matrix: leading zeros are padding, the last column is the target.
padded_training_sequence

tensor([[  0,   0,   0,  ...,   0,   1,   2],
        [  0,   0,   0,  ...,   1,   2,   3],
        [  0,   0,   0,  ...,   2,   3,   1],
        ...,
        [  0,   0,   0,  ...,  17, 249,  85],
        [  0,   0,   0,  ..., 249,  85,  91],
        [  0,   0,   0,  ...,  85,  91,  14]])

In [None]:
# Inputs are every column except the last; the target is the last column (the next token).
X, y = padded_training_sequence[:, :-1], padded_training_sequence[:, -1]

In [None]:
class CustomDataset(Dataset):
  """Map-style dataset that pairs row ``idx`` of ``X`` with entry ``idx`` of ``y``."""

  def __init__(self, X, y):
    # Only references are stored; nothing is copied or converted.
    self.X, self.y = X, y

  def __len__(self):
    n_rows = self.X.shape[0]
    return n_rows

  def __getitem__(self, idx):
    features = self.X[idx]
    target = self.y[idx]
    return features, target

In [None]:
# Wrap the (context, target) tensors so a DataLoader can batch them.
dataset = CustomDataset(X=X, y=y)

In [None]:
# One dataset item per training prefix.
len(dataset)


590

In [None]:
# Shuffled mini-batches of 32 examples for training.
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

In [None]:
class LSTMModel(nn.Module):
  """Next-token classifier: token embedding -> LSTM -> linear layer over the vocabulary.

  Args:
    vocab_size: number of token ids; sizes both the embedding table and the output layer.
    embedding_dim: width of each token embedding (default 100, the original hard-coded value).
    hidden_dim: LSTM hidden-state width (default 150, the original hard-coded value).
    num_layers: number of stacked LSTM layers (default 1).
  """

  def __init__(self, vocab_size, embedding_dim=100, hidden_dim=150, num_layers=1):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, batch_first=True)
    self.fc = nn.Linear(hidden_dim, vocab_size)

  def forward(self, x):
    """Return unnormalised logits over the vocabulary for the token following ``x``.

    ``x`` holds integer token ids, laid out (batch, seq_len) because of batch_first=True.
    The result is (batch, vocab_size).
    """
    embedded = self.embedding(x)
    _, (final_hidden_state, _) = self.lstm(embedded)
    # final_hidden_state is (num_layers, batch, hidden_dim). Take the top layer's state;
    # indexing instead of squeeze(0) stays correct when num_layers > 1.
    return self.fc(final_hidden_state[-1])

In [None]:
# One output logit per vocabulary entry.
model = LSTMModel(vocab_size=len(vocab))

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

In [None]:
# Move the model's parameters to the selected device; the cell displays the module summary.
model.to(device)

LSTMModel(
  (embedding): Embedding(250, 100)
  (lstm): LSTM(100, 150, batch_first=True)
  (fc): Linear(in_features=150, out_features=250, bias=True)
)

In [None]:
# Training hyperparameters.
epochs, learning_rate = 50, 0.001

# Next-token prediction is multi-class classification over the vocabulary.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# training loop: one pass over the shuffled mini-batches per epoch.
# calculate_accuracy switches the model to eval mode and does not switch it back, so
# training mode is set explicitly; otherwise re-running this cell would train in eval mode.
model.train()

for epoch in range(epochs):
  total_loss = 0

  for batch_x, batch_y in dataloader:
    batch_x, batch_y = batch_x.to(device), batch_y.to(device)

    optimizer.zero_grad()
    output = model(batch_x)
    loss = criterion(output, batch_y)
    loss.backward()
    optimizer.step()

    # Each loss.item() is one batch's loss, so the printed value is the sum over the
    # epoch's batches, not a per-example average.
    total_loss += loss.item()

  print(f"Epoch: {epoch + 1}, Loss: {total_loss:.4f}")

Epoch: 1, Loss: 104.4522
Epoch: 2, Loss: 98.6987
Epoch: 3, Loss: 90.0627
Epoch: 4, Loss: 85.3738
Epoch: 5, Loss: 79.9599
Epoch: 6, Loss: 74.0897
Epoch: 7, Loss: 67.7216
Epoch: 8, Loss: 61.9245
Epoch: 9, Loss: 56.1294
Epoch: 10, Loss: 50.7312
Epoch: 11, Loss: 44.9495
Epoch: 12, Loss: 39.9570
Epoch: 13, Loss: 34.8350
Epoch: 14, Loss: 30.6156
Epoch: 15, Loss: 26.8648
Epoch: 16, Loss: 23.2277
Epoch: 17, Loss: 19.8709
Epoch: 18, Loss: 17.2249
Epoch: 19, Loss: 14.8629
Epoch: 20, Loss: 12.8039
Epoch: 21, Loss: 11.2506
Epoch: 22, Loss: 9.8088
Epoch: 23, Loss: 8.7131
Epoch: 24, Loss: 7.6180
Epoch: 25, Loss: 6.7422
Epoch: 26, Loss: 5.9654
Epoch: 27, Loss: 5.4120
Epoch: 28, Loss: 4.8784
Epoch: 29, Loss: 4.4256
Epoch: 30, Loss: 4.0264
Epoch: 31, Loss: 3.7100
Epoch: 32, Loss: 3.4135
Epoch: 33, Loss: 3.1358
Epoch: 34, Loss: 2.9391
Epoch: 35, Loss: 2.7126
Epoch: 36, Loss: 2.5098
Epoch: 37, Loss: 2.3477
Epoch: 38, Loss: 2.1994
Epoch: 39, Loss: 2.0711
Epoch: 40, Loss: 1.9393
Epoch: 41, Loss: 1.8205
Epo

In [None]:
# prediction

def prediction(model, vocab, text, seq_len=None):
  """Append the model's most likely next token to ``text`` and return the result.

  Args:
    model: module returning (1, len(vocab)) logits for a (1, seq_len) batch of ids,
      e.g. the LSTMModel above. Inputs are sent to the device of its parameters.
    vocab: the token -> id mapping used for training. Ids must equal insertion
      positions, as the vocab cell builds them.
    text: prompt string; it is lower-cased and tokenized like the training data.
    seq_len: width to left-pad the prompt ids to. Defaults to ``X.shape[1]``, the
      input width the model was trained on. Longer prompts keep only their last
      ``seq_len`` ids.

  Returns:
    ``text + " " + predicted_token``.

  Raises:
    ValueError: if ``seq_len`` is not positive.
  """
  if seq_len is None:
    # Match the padded width of the training inputs (previously a hard-coded 61).
    seq_len = X.shape[1]
  if seq_len < 1:
    raise ValueError("seq_len must be a positive integer")

  # tokenize
  tokenized_text = word_tokenize(text.lower())

  # text -> numerical indices, keeping only the most recent seq_len ids
  numerical_text = text_to_indices(tokenized_text, vocab)[-seq_len:]

  # left-pad with id 0, exactly like the training sequences
  padded_text = torch.tensor(
      [0] * (seq_len - len(numerical_text)) + numerical_text, dtype=torch.long
  ).unsqueeze(0)

  # run on the model's device (a CPU tensor fails against a GPU model); no gradients needed
  model_device = next(model.parameters()).device
  with torch.no_grad():
    output = model(padded_text.to(model_device))

  # predicted index
  index = output.argmax(dim=1).item()

  # merge with text; vocab ids follow insertion order, so list position == id
  return text + " " + list(vocab)[index]

In [None]:
# Single-step prediction for a hand-written prompt. The prompt's trailing space is kept
# and another space is appended before the predicted token, hence the double space in the output.
prediction(model, vocab, "Vicky, Rudra and Bittu play ")

'Vicky, Rudra and Bittu play  the'

In [None]:
import time

# Greedy generation: each prediction becomes the next prompt. The growing text is
# printed after every step, with a short pause between steps.
num_tokens = 10
input_text = "Sarkata attacks"

for i in range(num_tokens):
  output_text = input_text = prediction(model, vocab, input_text)
  print(output_text)
  time.sleep(0.5)

Sarkata attacks to
Sarkata attacks to rescue
Sarkata attacks to rescue the
Sarkata attacks to rescue the town
Sarkata attacks to rescue the town ,
Sarkata attacks to rescue the town , the
Sarkata attacks to rescue the town , the group
Sarkata attacks to rescue the town , the group traces
Sarkata attacks to rescue the town , the group traces the
Sarkata attacks to rescue the town , the group traces the writer


In [None]:
# Unshuffled loader over the same training data, in a fixed order for evaluation.
dataloader1 = DataLoader(dataset=dataset, batch_size=32, shuffle=False)

In [None]:
# Function to calculate accuracy
def calculate_accuracy(model, dataloader, device):
    """Return the top-1 accuracy of ``model`` over ``dataloader``, as a percentage.

    Args:
        model: module mapping a batch of inputs to (batch, num_classes) logits.
        dataloader: iterable of ``(batch_x, batch_y)`` pairs, where ``batch_y`` holds class ids.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``correct / total * 100`` as a float.

    Raises:
        ValueError: if ``dataloader`` yields no examples.

    The model's train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0

    try:
        with torch.no_grad():  # No need to compute gradients
            # Iterate the loader passed in (not a notebook global).
            for batch_x, batch_y in dataloader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)

                # Predicted class = index of the largest logit in each row
                predicted = model(batch_x).argmax(dim=1)

                # Compare with actual labels
                correct += (predicted == batch_y).sum().item()
                total += batch_y.size(0)
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    if total == 0:
        raise ValueError("dataloader yielded no examples; accuracy is undefined")

    return correct / total * 100

# Compute accuracy on the training examples themselves. There is no held-out split, so a
# high score shows how well the corpus was memorised, not how well the model generalises.
accuracy = calculate_accuracy(model, dataloader, device)
print(f"Model Accuracy: {accuracy:.2f}%")

Model Accuracy: 100.00%
