In [68]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import nltk
from nltk.tokenize import word_tokenize
from datasets import load_dataset

In [72]:
### Hyper Params
# Seed torch's RNGs (all devices) so model weight initialisation is
# reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)

batch_size = 1
hidden_size = 120
input_size = 1

In [21]:
# Fetch the "train" split from the Hugging Face Hub and keep only the story text column.
raw_dataset = load_dataset(
    "jaydenccc/AI_Storyteller_Dataset",
    split="train",
)["short_story"]

In [47]:
nltk.download('punkt_tab')

# Join stories with a space. Plain `+=` concatenation fused the last word of one
# story onto the first word of the next (e.g. "end.once") and is quadratic.
text_corpus = " ".join(story.lower() for story in raw_dataset)

tokens = word_tokenize(text_corpus)
unique_words = set(tokens)
# Assign indices in sorted order so the word <-> index mapping is identical on
# every run: iteration order of a set of str depends on per-process hash
# randomization, which would silently break any saved model's vocabulary.
word_to_num = {word: idx for idx, word in enumerate(sorted(unique_words))}
# Exact inverse of word_to_num.
num_to_word = {idx: word for word, idx in word_to_num.items()}

[nltk_data] Downloading package punkt_tab to /home/grant/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [71]:
# The model's linear head emits one logit per vocabulary entry.
output_size = len(word_to_num.keys())

In [52]:
class StoryDataset(Dataset):
    """Map-style dataset exposing a pre-built sequence one item at a time.

    Items are returned exactly as stored; no encoding or transformation is applied.
    """

    def __init__(self, data):
        # Any object supporting len() and integer indexing works here.
        self.data = data

    def __len__(self):
        n_items = len(self.data)
        return n_items

    def __getitem__(self, idx):
        item = self.data[idx]
        return item

In [64]:
# Use the same NLTK tokens the vocabulary was built from. Splitting the raw text
# on ' ' produced items with attached punctuation/newlines (and empty strings)
# that are not keys of word_to_num.
corpus = tokens

# The first `train_split` tokens are used for training and the rest are held
# out. Loaders are not shuffled, so word order is preserved.
train_split = 25000

train_dataset = StoryDataset(corpus[:train_split])
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

test_dataset = StoryDataset(corpus[train_split:])
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [73]:
class RNNModel(nn.Module):
    """Single-layer Elman RNN followed by a linear read-out of the last time step.

    Expects batch-first input of shape (batch, seq_len, input_dim) and returns
    logits of shape (batch, output_dim).

    Args:
        input_dim: Features per time step. Defaults to the notebook-level ``input_size``.
        hidden_dim: Width of the RNN hidden state. Defaults to ``hidden_size``.
        output_dim: Number of output logits. Defaults to ``output_size``.
    """

    def __init__(self, input_dim=None, hidden_dim=None, output_dim=None):
        super().__init__()
        # Fall back to the notebook hyperparameters so `RNNModel()` keeps working.
        input_dim = input_size if input_dim is None else input_dim
        hidden_dim = hidden_size if hidden_dim is None else hidden_dim
        output_dim = output_size if output_dim is None else output_dim

        self.hidden_size = hidden_dim
        # batch_first=True matches forward(), where dim 0 is the batch and dim 1
        # is time. With the default batch_first=False the RNN read dim 0 as time,
        # so the initial hidden state built from x.size(0) had the wrong batch size.
        self.rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the RNN over ``x`` (batch, seq_len, input_dim); return (batch, output_dim) logits."""
        n_batch = x.size(0)
        # h0 is (num_layers, batch, hidden) even when batch_first=True. Build it
        # on x's device and dtype so the model also runs on GPU or in other precisions.
        hidden = torch.zeros(1, n_batch, self.hidden_size, device=x.device, dtype=x.dtype)

        out, _ = self.rnn(x, hidden)

        # Output at the final time step. For this single-layer, unidirectional
        # RNN it equals the final hidden state.
        return self.fc(out[:, -1, :])