<a href="https://colab.research.google.com/github/ajay-1010/VisualTales-Image-Caption-Generator/blob/main/LSTM_Next_Word.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

LSTM Next Word


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Dataset

# Sample small dataset
text = "hello world this is a simple example of an lstm based model for next word prediction using pytorch"

# Split the text into individual words and build the vocabulary from them.
# Each word is mapped to a unique integer index (and back) so the text can be
# fed to the embedding layer of the LSTM model.
# The vocabulary is sorted before indices are assigned: iterating a raw set of
# strings follows hash order, which changes between interpreter sessions, so
# without sorting the same word would get a different index on every run.
words = text.split()
vocab = set(words)
word_to_index = {word: i for i, word in enumerate(sorted(vocab))}
index_to_word = {i: word for word, i in word_to_index.items()}

# Create sequences
# Every training example is a sliding window of seq_length input words plus
# the word that follows them (the target). With seq_length = 3 the first
# window is ["hello", "world", "this"] with target "is".
seq_length = 3
if len(words) <= seq_length:
    # Without at least one full window the 2-D slicing below would fail with
    # an opaque IndexError, so fail early with a readable message instead.
    raise ValueError(
        f"text must contain more than {seq_length} words, got {len(words)}"
    )
sequences = [
    words[i:i + seq_length + 1] for i in range(len(words) - seq_length)
]

# Encode sequences
# Convert each word window into a row of integer indices. The dtype is fixed
# to int64 so the result does not depend on the platform's default int size.
encoded_sequences = np.array(
    [[word_to_index[word] for word in seq] for seq in sequences],
    dtype=np.int64,
)

# Split into inputs (the first seq_length words) and targets (the last word)
X, y = encoded_sequences[:, :-1], encoded_sequences[:, -1]

# Convert to PyTorch tensors (int64, as required by nn.Embedding and
# nn.CrossEntropyLoss class-index targets)
X = torch.tensor(X, dtype=torch.long)
y = torch.tensor(y, dtype=torch.long)

# Create a simple dataset
#Defines a custom dataset class that PyTorch can use to handle batches of data.
#The DataLoader is used to iterate through the dataset in small batches during training.
class TextDataset(Dataset):
    """Dataset that pairs each input sequence with its target word index."""

    def __init__(self, X, y):
        # Keep references to the inputs and targets; nothing is copied.
        self.X, self.y = X, y

    def __len__(self):
        # The number of examples is the number of input sequences.
        return len(self.X)

    def __getitem__(self, idx):
        # Look up the input and its target at the same position.
        features = self.X[idx]
        label = self.y[idx]
        return features, label

# Wrap the encoded tensors in the dataset, then serve them in shuffled
# mini-batches of two examples.
dataset = TextDataset(X, y)
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)

# Define the LSTM model
class LSTMModel(nn.Module):
    """Embedding -> LSTM -> linear layer producing one score per vocabulary word."""

    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_size
        )
        self.lstm = nn.LSTM(
            input_size=embed_size, hidden_size=hidden_size, batch_first=True
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, x):
        """Return vocabulary scores computed from the last time step of ``x``.

        ``x`` holds word indices laid out as (batch, sequence) because the
        LSTM is built with ``batch_first=True``.
        """
        embedded = self.embedding(x)
        hidden_states, _ = self.lstm(embedded)
        # Only the output at the final sequence position feeds the classifier.
        last_step = hidden_states[:, -1, :]
        return self.fc(last_step)

# Model parameters
vocab_size = len(vocab)
embed_size = 10
hidden_size = 50

model = LSTMModel(vocab_size, embed_size, hidden_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Train the model
# The model is trained for num_epochs epochs. For each mini-batch the
# cross-entropy loss is computed and the parameters are updated with Adam.
# Every 50 epochs the mean training loss over the whole epoch is printed.
# The mean is weighted by batch size because the last batch of an epoch can be
# smaller than the others; reporting only the final batch's loss would be
# noisy and unrepresentative.
num_epochs = 200
for epoch in range(num_epochs):
    model.train()
    epoch_loss_sum = 0.0
    epoch_samples = 0
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # criterion averages over the batch, so scale back up to a sum.
        n_samples = targets.size(0)
        epoch_loss_sum += loss.item() * n_samples
        epoch_samples += n_samples

    # Skip logging when the dataloader produced no batches (nothing to average).
    if (epoch + 1) % 50 == 0 and epoch_samples:
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss_sum / epoch_samples:.4f}')

# Text generation
# generate_text starts from a given word sequence and predicts the next word iteratively.
# Each predicted word is appended to the output and fed back into the model's input window until the requested number of words has been generated.
# The generated text is returned as a single space-separated string, which the caller then prints.
def generate_text(model, start_sequence, num_words):
    """Greedily extend ``start_sequence`` by ``num_words`` predicted words.

    At every step the model scores the current window of word indices, the
    highest-scoring word is appended to the output, and the window slides
    forward by one word.

    Args:
        model: Trained module mapping a (1, sequence) index tensor to
            per-word scores.
        start_sequence: Iterable of seed words; each must be a key of the
            module-level ``word_to_index`` mapping.
        num_words: Number of words to generate after the seed.

    Returns:
        The seed words followed by the generated words, joined by spaces.

    Raises:
        ValueError: If ``start_sequence`` is empty.
        KeyError: If a seed word is not in ``word_to_index``.
    """
    # Copy into a fresh list so any iterable is accepted and the caller's
    # sequence is never mutated.
    generated_text = list(start_sequence)
    if not generated_text:
        raise ValueError("start_sequence must contain at least one word")

    # Build input tensors on the same device as the model's parameters;
    # fall back to CPU for a model that has no parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    current_seq = torch.tensor(
        [word_to_index[word] for word in generated_text],
        dtype=torch.long,
        device=device,
    )

    # Switch to eval mode for inference, but remember the previous mode so a
    # caller that is in the middle of training is not silently left in eval.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(num_words):
                output = model(current_seq.unsqueeze(0))
                predicted_word_index = torch.argmax(output, dim=1).item()
                generated_text.append(index_to_word[predicted_word_index])
                # Slide the window: drop the oldest word, add the new one.
                next_token = torch.tensor(
                    [predicted_word_index], dtype=torch.long, device=device
                )
                current_seq = torch.cat([current_seq[1:], next_token])
    finally:
        model.train(was_training)

    return ' '.join(generated_text)

# Seed the generator with three words from the training text and ask for
# five more, then show the result.
start_sequence = "this is a".split()
generated_text = generate_text(model, start_sequence, num_words=5)
print("Generated text:", generated_text)


Epoch 50/200, Loss: 0.0016
Epoch 100/200, Loss: 0.0003
Epoch 150/200, Loss: 0.0002
Epoch 200/200, Loss: 0.0002
Generated text: this is a simple example of an lstm
