In [1]:
import random
import torch

# Single seed for every RNG the notebook touches.
SEED = 4224
random.seed(SEED)
# Text generation samples with torch.multinomial, which draws from torch's own RNG
# (random.seed does not affect it). Seed torch too so generated text is reproducible.
torch.manual_seed(SEED)
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import torch
import json, os, pickle, random
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW
from model2.utils import sample_n_items_from_pickle, read_json_as_dict, train_epoch, test_epoch
from model2.model import LSTMNextWordPredictor
from model2.dataset import WordDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory that holds the pickled data splits and the JSON lookup tables.
data_root = "./data"
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Number of items drawn from the pickled test split.
TEST_SAMPLE = 300_000
# Number of real vocabulary words; the model allocates one extra slot (VOCAB_SIZE + 1).
VOCAB_SIZE = 5_000

In [4]:
# Draw TEST_SAMPLE items from the pickled test split.
# NOTE(review): presumably this unpickles the file, so only point it at trusted data.
test_pickle_path = os.path.join(data_root, "test_2.pkl")
test = sample_n_items_from_pickle(test_pickle_path, TEST_SAMPLE)

In [5]:
# Genre lookup table (presumably genre name -> index), read from the same data
# directory as the other inputs rather than a second hard-coded "./data" path.
genres2idx = read_json_as_dict(os.path.join(data_root, "genres2idx.json"))

In [6]:
# Genre names in the JSON file's key order. The generation code indexes this list
# with argmax(genre), so position i is treated as the name of genre index i.
# NOTE(review): assumes the file lists genres in index order; confirm.
genres_values = [*genres2idx]

In [7]:
# Wrap the sampled test items; the genre count is the number of genre lookup entries.
num_genres = len(genres2idx)
test_dataset = WordDataset(test, VOCAB_SIZE, num_genres)

In [17]:
# Model hyperparameters. The architecture values must match the saved checkpoint
# for load_state_dict to succeed.
batch_size = 512
embedding_dim, hidden_dim = 64, 128
num_layers, dropout = 2, 0.3
genre_size = len(genres2idx)

In [18]:
# Index -> word lookup. Keys are string indices, matching the str(idx) lookups in the
# generation code. Read from data_root instead of a second hard-coded "./data" path.
vocab = read_json_as_dict(os.path.join(data_root, "vocab_2.json"))

In [19]:
# map_location remaps stored tensors onto the active device, so a checkpoint saved in
# a GPU session still loads on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints; on
# torch >= 1.13 consider weights_only=True since this file is a plain state dict.
chkpt = torch.load('./model2_state_dict.pth', map_location=device)

In [20]:
model_config = dict(
    vocab_size=VOCAB_SIZE + 1,  # +1: index VOCAB_SIZE is the UNK slot
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    grene_size=genre_size,  # (sic) keyword spelled as the constructor expects
    num_layers=num_layers,
    dropout=dropout,
)
model = LSTMNextWordPredictor(**model_config)
model.to(device)

LSTMNextWordPredictor(
  (seq_embedding): Embedding(5001, 64)
  (genre_embedding): Linear(in_features=618, out_features=128, bias=True)
  (lstm): LSTM(64, 128, num_layers=2, batch_first=True, dropout=0.3)
  (fc): Linear(in_features=256, out_features=5001, bias=True)
  (dropout_layer): Dropout(p=0.3, inplace=False)
)

In [21]:
# strict=True (PyTorch's default, spelled out here) raises on any missing or unexpected key.
model.load_state_dict(chkpt, strict=True)

<All keys matched successfully>

In [91]:
def predict_next_word(model, input_sequence, genre, vocab_inverse, unk_idx=5000):
    """
    Sample the next word from the model's output distribution.

    :param model: trained LSTM model, called as ``model(tokens, genre)`` on
        batched inputs; it is switched to eval mode
    :param input_sequence: 1-D tensor of token indices (the caller in this
        notebook appends raw indices to it); a batch dimension is added here
    :param genre: genre vector for the sequence (presumably one-hot; the caller
        argmaxes it; TODO confirm); a batch dimension is added here
    :param vocab_inverse: dict mapping ``str(index)`` to word
    :param unk_idx: index the model uses for the unknown token (default 5000)
    :return: ``(word, index)`` for a known word, or the bare string ``'UNK'``
        when the sampled index equals ``unk_idx`` (kept for existing callers)
    :raises KeyError: if a sampled non-UNK index is missing from ``vocab_inverse``
    """
    model.eval()

    # Move inputs to the device that holds the model's weights. Checking
    # torch.cuda.is_available() alone breaks when the model sits on the CPU of
    # a CUDA-capable machine.
    try:
        target_device = next(model.parameters()).device
    except StopIteration:
        target_device = input_sequence.device
    input_tensor = input_sequence.unsqueeze(0).to(target_device)  # add batch dimension
    genre_tensor = genre.unsqueeze(0).to(target_device)

    with torch.no_grad():
        logits = model(input_tensor, genre_tensor)

    # Softmax over the vocabulary axis; assumes logits are (1, vocab_size). TODO confirm.
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    # Sample instead of argmax so repeated generations can differ.
    sample = torch.multinomial(probabilities, 1, replacement=True).item()

    if sample == unk_idx:
        return 'UNK'
    return vocab_inverse[str(sample)], sample

In [107]:
def genreate_text(model, input_seq, genre, vocab_values, vocab_size=VOCAB_SIZE, n=10):
    """
    Autoregressively generate ``n`` words, printing the genre, seed and result.

    :param model: trained LSTM model, forwarded to ``predict_next_word``
    :param input_seq: 1-D tensor of token indices used as the seed window; each
        step appends the new token and drops the oldest, so its length is fixed
    :param genre: genre vector; its argmax selects the printed name from the
        global ``genres_values`` (presumably one-hot; TODO confirm)
    :param vocab_values: dict mapping ``str(index)`` to word; indices missing
        from it are rendered as ``'UNK'``
    :param vocab_size: index of the UNK token (the extra slot after the
        ``vocab_size`` real words); defaults to ``VOCAB_SIZE``
    :param n: number of tokens to generate
    :return: ``(initial_text, initial_text + " " + generated_text)``
    """
    genre_name = genres_values[int(torch.argmax(genre).item())]
    print(f"Genre: {genre_name}")

    # Use the vocabulary that was passed in, not a notebook global.
    initial_text = ' '.join(vocab_values.get(str(tok.item()), 'UNK') for tok in input_seq)

    predicted_words = []
    window = input_seq
    for _ in range(n):
        prediction = predict_next_word(model, window, genre, vocab_values)
        if isinstance(prediction, tuple):
            word, idx = prediction
        else:
            # predict_next_word returns the bare string 'UNK' for the unknown token.
            word, idx = prediction, vocab_size
        predicted_words.append(word)
        # Build the new token with the window's dtype/device so torch.cat never mixes them.
        next_token = torch.tensor([idx], dtype=window.dtype, device=window.device)
        # Slide the window: append the new token, drop the oldest one.
        window = torch.cat((window, next_token), dim=0)[1:]

    generated_text = ' '.join(predicted_words)
    full_text = f"{initial_text} {generated_text}"
    print(f"Initial text: {initial_text}")
    print(f"Generated text: {full_text}")

    return initial_text, full_text

In [118]:
# Inclusive range of test_dataset positions that seed sequences are drawn from.
lower_bound, upper_bound = 9545, 9800

In [119]:
# Draw 20 seed positions uniformly from [lower_bound, upper_bound] (inclusive; repeats possible).
indices = []
for _ in range(20):
    indices.append(random.randint(lower_bound, upper_bound))

In [120]:
# Collects (initial_text, generated_text) pairs produced by the generation loop.
lst = list()

In [121]:
# Generate 20 words for each sampled seed and keep (seed text, full text) pairs.
for idx in indices:
    # Index the dataset once; calling __getitem__ twice doubled the work and could
    # pair mismatched fields if item construction is ever non-deterministic.
    item = test_dataset[idx]
    seq, genre = item[0], item[1]
    init, pred = genreate_text(model, seq, genre, vocab, n=20)
    lst.append((init, pred))

Genre: Geography
Initial text: UNK i was a playboy bunny to the moving tribute to her mother UNK song because she could not sing
Generated text: UNK i was a playboy bunny to the moving tribute to her mother UNK song because she could not sing he do can UNK the only discovers and rest details that UNK UNK and UNK UNK to the UNK UNK
Genre: Comix
Initial text: living comes independence and UNK and in time elizabeth ann finds herself making friends and UNK her new family when
Generated text: living comes independence and UNK and in time elizabeth ann finds herself making friends and UNK her new family when a UNK nations who have practical years fate upon an UNK center of justice young many UNK were UNK UNK
Genre: Journal
Initial text: deal with you may feel UNK in your UNK with family friends and even with god maybe you are UNK
Generated text: deal with you may feel UNK in your UNK with family friends and even with god maybe you are UNK the children ago to marry the UNK boys they carries h