In [15]:
# !wget https://raw.githubusercontent.com/kyuz0/llm-chronicles/main/4.4%20-%20Lab%20-%20Word-Level%20RNN/fairy_tales_cleaned_most_common_5000_words.txt -O dataset.txt

In [8]:
data_path = "dataset.txt"

with open(data_path, 'r', encoding='utf8') as f:
    text = f.read()
print(len(text))

2751174


In [9]:
import string
import numpy as np 


### Tokenize data

In [10]:
def tokenize(doc):
    punctuation_to_remove = string.punctuation.replace('.', '')
    table = str.maketrans('', '', punctuation_to_remove)

    tokens = doc.split()
    split_tokens = []
    for token in tokens:
        split_tokens.extend(token.replace('.', ' .').split())
    
    tokens = [w.translate(table) for w in split_tokens]
    tokens = [word for word in tokens if word.isalpha() or word == '.']
    tokens = [word.lower() for word in tokens]

    return tokens

In [11]:
tokens = tokenize(text)
vocab = sorted(set(tokens))
vocab_size = len(vocab)


word2int = {word:i for i, word in enumerate(vocab)}
word_array = np.array(vocab)

text_encoded = np.array(
    [word2int[word] for word in tokens],
    dtype=np.int32
)

Prepare for self-supervised training

In [12]:
seq_length = 50
chunk_size = seq_length + 1

text_chunks = [text_encoded[i:i+chunk_size]
               for i in range(len(text_encoded)-chunk_size+1)]

# for seq in text_chunks[:1]:
#     input_seq = seq[:seq_length]
#     target = seq[seq_length]

### Pytorch data iteratable (dataloader)

In [13]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class TextDataset(Dataset):
    def __init__(self, text_chunks):
        self.text_chunks = text_chunks

    def __len__(self):
        return len(self.text_chunks)
    
    def __getitem__(self, idx):
        text_chunk = self.text_chunks[idx]
        return text_chunk[:-1].long(), text_chunk[1:].long()

In [14]:
BATCH_SIZE = 64
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dataset = TextDataset(torch.tensor(text_chunks))
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

  dataset = TextDataset(torch.tensor(text_chunks))


In [15]:
class RNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, rnn_hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn = nn.LSTM(embed_dim, rnn_hidden_size, batch_first=True)
        self.fc = nn.Linear(rnn_hidden_size, vocab_size)
    
    def forward(self, x, hidden, cell):
        out = self.embedding(x).unsqueeze(1)
        out, (hidden, cell) = self.rnn(out, (hidden, cell))
        out = self.fc(out).reshape(out.size(0), -1)
        return out, hidden, cell
    
    def init_hidden(self, batch_size):
        hidden = torch.zeros(1, batch_size, self.rnn_hidden_size)
        cell = torch.zeros(1, batch_size, self.rnn_hidden_size)
        return hidden.to(DEVICE), cell.to(DEVICE)

In [16]:
embed_dim = 256
rnn_hidden_size = 512

model = RNN(vocab_size, embed_dim, rnn_hidden_size)
model = model.to(DEVICE)
model

RNN(
  (embedding): Embedding(5371, 256)
  (rnn): LSTM(256, 512, batch_first=True)
  (fc): Linear(in_features=512, out_features=5371, bias=True)
)

### Train model

In [None]:
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

num_epochs = 15_000

model.train()
for epoch in range(num_epochs):
    hidden, cell = model.init_hidden(batch_size)
    seq_batch, target_batch = next(iter(dataloader))

    seq_batch = seq_batch.to(DEVICE)
    target_batch = target_batch.to(DEVICE)

    optimizer.zero_grad()
    loss = 0
    for w in range(seq_length):
        pred, hidden, cell = model(seq_batch[:, w], hidden, cell)
        loss += loss_fn(pred, target_batch[:, w])