Installing the required libraries

In [88]:
!pip install torch




[notice] A new release of pip is available: 24.2 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [89]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
import torch.optim as optim

Hyperparameters to train the model

In [90]:
class Config:
    """Central place for the tunable settings used by the notebook.

    All values are plain class attributes; the module-level ``config``
    instance below is what the other cells read from.
    """

    # --- model architecture ---
    EMBEDDING_DIM = 256   # size of each token embedding vector
    HIDDEN_DIM = 128      # size of the recurrent hidden state
    NUM_LAYERS = 4        # number of stacked recurrent layers

    # --- data ---
    SEQ_LENGTH = 10       # tokens per training window
    BATCH_SIZE = 64       # windows per optimisation step
    UNK_TOKEN = "<UNK>"   # placeholder for out-of-vocabulary words

    # --- optimisation ---
    LEARNING_RATE = 0.01
    EPOCHS = 100


# Shared settings object referenced by the functions defined later.
config = Config()
    


Tokenizing the text into lowercase, whitespace-separated words

In [91]:
def tokenize_text(sentence):
    """Lower-case ``sentence`` and split it on runs of whitespace.

    Punctuation is not stripped, so it stays attached to the neighbouring word.

    Returns:
        list[str]: the tokens, in order of appearance.
    """
    lowered = sentence.lower()
    words = lowered.split()
    return words

Dataset class for text data

In [92]:
class StoryDataset(Dataset):
    """Sliding-window next-token dataset over a sequence of token ids.

    Item ``i`` is the pair ``(data[i:i+seq_length], data[i+1:i+seq_length+1])``,
    i.e. the target is the input window shifted one position to the right.

    Args:
        data: sequence of integer token ids (must support ``len`` and slicing).
        seq_length: number of tokens in every input/target window.
    """

    def __init__(self,data,seq_length):
        self.data = data
        self.seq_length = seq_length

    def __len__(self):
        # Clamp at zero: a negative value would make len() raise ValueError
        # when the data is shorter than one window.
        return max(0, len(self.data)-self.seq_length)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensors for window ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            # Support Python-style negative indexing.
            idx += size
        if not 0 <= idx < size:
            # Without this, slicing past the end silently yields truncated
            # windows and plain iteration over the dataset never terminates.
            raise IndexError("StoryDataset index out of range")
        end = idx + self.seq_length
        return (torch.tensor(self.data[idx:end]),
                torch.tensor(self.data[idx+1:end+1]))
    

In [93]:
class TextGenerationModel(nn.Module):
    """Next-token model: embedding -> recurrent stack -> linear projection to the vocabulary.

    Args:
        vocab_size: number of distinct token ids (embedding rows and output logits).
        embedding_dim: size of each token embedding.
        hidden_dim: size of the recurrent hidden state (per direction).
        num_layers: number of stacked recurrent layers.
        model_type: one of ``"RNN"``, ``"LSTM"``, ``"GRU"`` or ``"BiLSTM"``.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """

    def __init__(self,vocab_size,embedding_dim,hidden_dim,num_layers,model_type="LSTM"):
        super(TextGenerationModel,self).__init__()
        self.model_type = model_type
        # Keep the dimensions on the instance so hidden_init does not depend on
        # a global config that may disagree with how this model was built.
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_directions = 2 if model_type=="BiLSTM" else 1

        self.embedding = nn.Embedding(vocab_size,embedding_dim)

        if model_type=="RNN":
            self.rnn = nn.RNN(embedding_dim,hidden_dim,num_layers,batch_first=True)
        elif model_type=="LSTM":
            self.rnn = nn.LSTM(embedding_dim,hidden_dim,num_layers,batch_first=True)
        elif model_type=="GRU":
            self.rnn = nn.GRU(embedding_dim,hidden_dim,num_layers,batch_first=True)
        elif model_type=="BiLSTM":
            # bidirectional=True is what makes the output feature size
            # 2*hidden_dim, matching the projection layer below.
            self.rnn = nn.LSTM(embedding_dim,hidden_dim,num_layers,batch_first=True,
                               bidirectional=True)
        else:
            raise ValueError("Invalid model type")

        self.fc = nn.Linear(hidden_dim*self.num_directions,vocab_size)

    def forward(self,x,hidden=None):
        """Run a batch of token ids through the network.

        Args:
            x: integer tensor of token ids, batch first.
            hidden: initial recurrent state as produced by ``hidden_init``;
                ``None`` lets the recurrent layer start from zeros.

        Returns:
            ``(logits, hidden)`` where ``logits`` has one score per vocabulary
            entry for every input position and ``hidden`` is the final state.
        """
        x = self.embedding(x)
        out,hidden = self.rnn(x,hidden)
        out = self.fc(out)
        return out,hidden

    def hidden_init(self,batch_size):
        """Return an all-zero initial state for a batch of ``batch_size`` sequences.

        LSTM variants get an ``(h_0, c_0)`` tuple, RNN/GRU a single tensor. Each
        tensor has shape ``(num_layers * num_directions, batch_size, hidden_dim)``
        and is created on the same device/dtype as the model parameters.
        """
        shape = (self.num_layers*self.num_directions,batch_size,self.hidden_dim)
        reference = next(self.parameters())
        if self.model_type=="LSTM" or self.model_type=="BiLSTM":
            return (reference.new_zeros(shape),
                    reference.new_zeros(shape))
        else:
            return reference.new_zeros(shape)

In [94]:
def get_model(vocab_size,model_type="LSTM"):
    """Build a ``TextGenerationModel`` whose sizes come from the shared ``config``.

    Args:
        vocab_size: number of tokens in the vocabulary.
        model_type: recurrent architecture name forwarded to the model.

    Returns:
        A freshly initialised, untrained model.
    """
    return TextGenerationModel(
        vocab_size=vocab_size,
        embedding_dim=config.EMBEDDING_DIM,
        hidden_dim=config.HIDDEN_DIM,
        num_layers=config.NUM_LAYERS,
        model_type=model_type,
    )

In [95]:
def train_model(model,dataloader,epochs,lr):
    """Train ``model`` in place with Adam and cross-entropy on next-token targets.

    Args:
        model: module exposing ``hidden_init(batch_size)`` and a forward pass
            ``model(inputs, hidden) -> (logits, hidden)``.
        dataloader: iterable of ``(inputs, targets)`` batches of token ids; must
            support ``len`` (used only for progress messages).
        epochs: number of full passes over ``dataloader``.
        lr: learning rate for the Adam optimizer.

    Returns:
        None. Progress is printed every 10th batch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(),lr=lr)

    model.train()
    for epoch in range(epochs):
        for batch_idx , (inputs,targets) in enumerate(dataloader):
            # Size the state from the batch itself: the last batch of an epoch
            # can be smaller than the configured batch size. The state is a
            # fresh zero tensor each step, so there is no graph to detach from.
            hidden = model.hidden_init(inputs.size(0))

            optimizer.zero_grad()
            outputs,_ = model(inputs,hidden)
            # Flatten every position into one row of logits. The vocabulary size
            # is read from the logits so this function does not depend on a
            # notebook-global `vocab` being defined.
            loss = criterion(outputs.reshape(-1,outputs.size(-1)),targets.reshape(-1))
            loss.backward()
            optimizer.step()

            if batch_idx%10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}],Step[{batch_idx}/{len(dataloader)}] , Loss: {loss.item():.4f}")
    print("Training completed")

In [96]:
def generate_text(model,start_text,length,word2idx,idx2word):
    """Greedily extend ``start_text`` by ``length`` words using ``model``.

    At every step the most recent ``config.SEQ_LENGTH`` token ids are fed to the
    model from a fresh zero state and the highest-scoring next token is appended.

    Args:
        model: trained model exposing ``hidden_init`` and ``model(x, hidden)``.
        start_text: prompt; tokenized with ``tokenize_text``. Words missing from
            ``word2idx`` are replaced by ``config.UNK_TOKEN``.
        length: number of words to generate.
        word2idx: mapping from word to token id (must contain the UNK token).
        idx2word: mapping from token id back to word.

    Returns:
        str: ``start_text`` followed by the generated words, space separated.

    Raises:
        ValueError: if ``start_text`` contains no tokens.
    """
    tokens = tokenize_text(start_text)
    if not tokens:
        # An empty prompt would produce a zero-length input sequence.
        raise ValueError("start_text must contain at least one token")

    unk_idx = word2idx[config.UNK_TOKEN]
    # Rolling context window of token ids, capped at the training window size.
    window = [word2idx.get(word,unk_idx) for word in tokens][-config.SEQ_LENGTH:]
    generated_words = []

    model.eval()
    # Inference only: do not build an autograd graph for every step.
    with torch.no_grad():
        for _ in range(length):
            input_seq = torch.tensor([window],dtype=torch.long)
            # Start from a zero state each step because the whole window is fed
            # again; carrying the previous state would process the same context
            # repeatedly.
            hidden = model.hidden_init(1)
            output,_ = model(input_seq,hidden)
            # Logits of the last position -> most likely next token.
            next_word_idx = output[0,-1].argmax().item()
            generated_words.append(idx2word[next_word_idx])
            window = (window+[next_word_idx])[-config.SEQ_LENGTH:]

    return " ".join([start_text]+generated_words)

In [97]:
def prepare_dataset(story):
    """Turn raw story text into a shuffled ``DataLoader`` plus vocabulary lookups.

    Args:
        story: the training text.

    Returns:
        tuple: ``(dataloader, vocab, word2idx, idx2word)`` where ``vocab`` is the
        sorted list of distinct tokens with ``config.UNK_TOKEN`` appended when it
        is not already present.
    """
    words = tokenize_text(story)

    # Sorted unique tokens; the UNK placeholder goes last when it is missing.
    vocab = sorted(set(words))
    if config.UNK_TOKEN not in vocab:
        vocab.append(config.UNK_TOKEN)

    # Build both lookup directions in a single pass over the final vocabulary.
    word2idx = {}
    idx2word = {}
    for position, word in enumerate(vocab):
        word2idx[word] = position
        idx2word[position] = word

    # Encode the story as token ids and wrap it in sliding windows.
    encoded = [word2idx[word] for word in words]
    windows = StoryDataset(encoded, config.SEQ_LENGTH)
    dataloader = DataLoader(windows, batch_size=config.BATCH_SIZE, shuffle=True)
    return dataloader, vocab, word2idx, idx2word

In [98]:
def train_and_save_model(dataloader, vocab_size, model_type):
    """Create a model of ``model_type``, train it, and write its weights to disk.

    The state dict is saved as ``text_generation_<model_type>.pth`` in the
    current working directory.

    Returns:
        The trained model.
    """
    trained = get_model(vocab_size, model_type)
    train_model(trained, dataloader, config.EPOCHS, config.LEARNING_RATE)

    # Persist only the parameters (state dict), not the whole module object.
    checkpoint_path = f"text_generation_{model_type}.pth"
    torch.save(trained.state_dict(), checkpoint_path)
    print(f"Model saved to text_generation_{model_type}.pth")
    return trained

In [99]:
def load_model(vocab_size, model_type):
    """Rebuild a model of ``model_type`` and restore its saved weights.

    Weights are read from ``text_generation_<model_type>.pth`` in the current
    working directory.

    Returns:
        The model with weights loaded, or ``None`` (after printing a hint) when
        the checkpoint file does not exist.
    """
    restored = get_model(vocab_size, model_type)
    checkpoint_path = f"text_generation_{model_type}.pth"
    try:
        state_dict = torch.load(checkpoint_path)
        restored.load_state_dict(state_dict)
        print(f"Model loaded from text_generation_{model_type}.pth")
    except FileNotFoundError:
        print(f"Model file text_generation_{model_type}.pth not found. Please train the model first.")
        return None
    else:
        return restored

In [103]:
def run_inference (model, start_text, length, word2idx, idx2word):
    """Generate a continuation of ``start_text`` and print it.

    All arguments are forwarded unchanged to ``generate_text``; nothing is returned.
    """
    print(
        "Generated Story: \n",
        generate_text(model, start_text, length, word2idx, idx2word),
    )

# End-to-end demo: build the dataset from a short story, train an LSTM, save it,
# then generate a continuation of a prompt.
# NOTE: the names bound here (dataloader, vocab, word2idx, idx2word, model, ...)
# are module-level globals, so other cells can see them; keep them stable.
if __name__ == "__main__":

    # Small in-line training corpus.
    story = """
    Once upon a time, in a land far away, there was a peaceful village surrounded by mountains.
    The villagers lived in harmony with nature. They grew crops, raised animals, and lived a simple but happy life.
    One day, a young girl named Lily discovered a mysterious cave hidden in the forest. She was curious and decided to explore.
    Inside the cave, she found glowing crystals and strange markings on the walls.
    As she ventured deeper, she realized she was not alone. """

    # Tokenize the story and build the vocabulary lookups plus a shuffled DataLoader.
    dataloader, vocab, word2idx, idx2word = prepare_dataset (story)
    model_type = "LSTM"
    # Trains with the settings in `config` and writes text_generation_LSTM.pth.
    model = train_and_save_model(dataloader, len (vocab), model_type)

    # Prompt words that are not in the vocabulary are mapped to the UNK token
    # inside generate_text.
    start_text = "One day a girl named Saara"
    run_inference(model, start_text, length=50, word2idx=word2idx, idx2word=idx2word)

Epoch [1/100],Step[0/2] , Loss: 4.1448
Epoch [2/100],Step[0/2] , Loss: 3.9972
Epoch [3/100],Step[0/2] , Loss: 3.7749
Epoch [4/100],Step[0/2] , Loss: 3.4803
Epoch [5/100],Step[0/2] , Loss: 3.1059
Epoch [6/100],Step[0/2] , Loss: 2.6876
Epoch [7/100],Step[0/2] , Loss: 2.3763
Epoch [8/100],Step[0/2] , Loss: 2.0375
Epoch [9/100],Step[0/2] , Loss: 1.8265
Epoch [10/100],Step[0/2] , Loss: 1.6634
Epoch [11/100],Step[0/2] , Loss: 1.5046
Epoch [12/100],Step[0/2] , Loss: 1.3694
Epoch [13/100],Step[0/2] , Loss: 1.2067
Epoch [14/100],Step[0/2] , Loss: 1.1198
Epoch [15/100],Step[0/2] , Loss: 1.0832
Epoch [16/100],Step[0/2] , Loss: 0.9560
Epoch [17/100],Step[0/2] , Loss: 0.8398
Epoch [18/100],Step[0/2] , Loss: 0.7577
Epoch [19/100],Step[0/2] , Loss: 0.6752
Epoch [20/100],Step[0/2] , Loss: 0.6039
Epoch [21/100],Step[0/2] , Loss: 0.5342
Epoch [22/100],Step[0/2] , Loss: 0.4724
Epoch [23/100],Step[0/2] , Loss: 0.4282
Epoch [24/100],Step[0/2] , Loss: 0.3741
Epoch [25/100],Step[0/2] , Loss: 0.3408
Epoch [26