In [120]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import math
import re
import json
from collections import Counter
from tqdm import tqdm
import os

In [121]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# torch.cuda.get_device_name(0)

cpu


In [122]:
with open("data/Alice_Adventure_In_Wonderland.txt", "r", encoding="utf-8") as f:
    txt = f.read()

print(txt[:1000])
print("Characters:", len(txt))
print("Words:", len(txt.split()))


CHAPTER I.
Down the Rabbit-Hole


Alice was beginning to get very tired of sitting by her sister on the
bank, and of having nothing to do: once or twice she had peeped into
the book her sister was reading, but it had no pictures or
conversations in it, “and what is the use of a book,” thought Alice
“without pictures or conversations?”

So she was considering in her own mind (as well as she could, for the
hot day made her feel very sleepy and stupid), whether the pleasure of
making a daisy-chain would be worth the trouble of getting up and
picking the daisies, when suddenly a White Rabbit with pink eyes ran
close by her.

There was nothing so _very_ remarkable in that; nor did Alice think it
so _very_ much out of the way to hear the Rabbit say to itself, “Oh
dear! Oh dear! I shall be late!” (when she thought it over afterwards,
it occurred to her that she ought to have wondered at this, but at the
time it all seemed quite natural); but when the Rabbit actually _took a
watch out of its w

In [123]:
DATA_PATH = "data/Alice_Adventure_In_Wonderland.txt"

def clean_text(s: str) -> str:
    s = s.lower()
    s = re.sub(r"\s+", " ", s)  # normalize whitespace
    # keep punctuation that helps generation feel natural
    s = re.sub(r"[^a-z0-9\s\.\,\!\?\;\:\'\-]", "", s)
    return s.strip()

with open(DATA_PATH, "r", encoding="utf-8") as f:
    raw = f.read()

text = clean_text(raw)
tokens = text.split()


print("Preview:", text[:300])
print("Num tokens:", len(tokens))
print("First 30 tokens:", tokens[:30])


Preview: chapter i. down the rabbit-hole alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do: once or twice she had peeped into the book her sister was reading, but it had no pictures or conversations in it, and what is the use of a book, thought alice with
Num tokens: 26381
First 30 tokens: ['chapter', 'i.', 'down', 'the', 'rabbit-hole', 'alice', 'was', 'beginning', 'to', 'get', 'very', 'tired', 'of', 'sitting', 'by', 'her', 'sister', 'on', 'the', 'bank,', 'and', 'of', 'having', 'nothing', 'to', 'do:', 'once', 'or', 'twice', 'she']


In [124]:
n = len(tokens)
train_end = int(0.8 * n)
val_end = int(0.9 * n)

train_tokens = tokens[:train_end]
valid_tokens = tokens[train_end:val_end]
test_tokens  = tokens[val_end:]


In [125]:
train_dataset = [{"tokens": train_tokens}]
valid_dataset = [{"tokens": valid_tokens}]
test_dataset  = [{"tokens": test_tokens}]

print("Train tokens:", len(train_tokens))
print("Valid tokens:", len(valid_tokens))
print("Test tokens:", len(test_tokens))


Train tokens: 21104
Valid tokens: 2638
Test tokens: 2639


In [128]:
PAD = "<pad>"
UNK = "<unk>"
EOS = "<eos>"
MIN_FREQ = 2

counter = Counter(train_tokens)

vocab_list = [PAD, UNK, EOS] + [
    w for w, c in counter.items() if c >= MIN_FREQ
]

vocab = {w: i for i, w in enumerate(vocab_list)}

print("Vocab size:", len(vocab))



Vocab size: 1554


In [129]:
# PAD = "<pad>"
# UNK = "<unk>"
# EOS = "<eos>"
# MIN_FREQ = 2   # words appearing <2 times become <unk>

# counter = Counter(tokens)

# vocab = [PAD, UNK, EOS] + [w for w, c in counter.items() if c >= MIN_FREQ]
# stoi = {w: i for i, w in enumerate(vocab)}
# itos = {i: w for w, i in stoi.items()}

# ids = [stoi.get(t, stoi[UNK]) for t in tokens]

# print("Vocab size:", len(vocab))
# print("Example ids:", ids[:20])


In [145]:
# os.makedirs("models", exist_ok=True)

# with open("models/vocab.json", "w", encoding="utf-8") as f:
#     json.dump(
#         {"stoi": stoi, "itos": {str(k): v for k, v in itos.items()}},
#         f,
#         ensure_ascii=False,
#         indent=2
#     )

# print("Saved vocab to models/vocab.json")


In [144]:
def get_data(dataset, vocab, batch_size, eos_token="<eos>"):
    data = []
    unk_id = vocab["<unk>"]

    for example in dataset:
        if example["tokens"]:
            tokens = example["tokens"] + [eos_token]
            ids = [vocab.get(tok, unk_id) for tok in tokens]
            data.extend(ids)

    data = torch.LongTensor(data)
    num_batches = data.shape[0] // batch_size
    data = data[:num_batches * batch_size]
    data = data.view(batch_size, num_batches)
    return data


In [132]:
BATCH_SIZE = 128

train_data = get_data(train_dataset, vocab, BATCH_SIZE)
valid_data = get_data(valid_dataset, vocab, BATCH_SIZE)
test_data  = get_data(test_dataset, vocab, BATCH_SIZE)

print("Train data:", train_data.shape)
print("Valid data:", valid_data.shape)
print("Test data:", test_data.shape)

Train data: torch.Size([128, 164])
Valid data: torch.Size([128, 20])
Test data: torch.Size([128, 20])


In [143]:
# def decode(id_list, itos):
#     return " ".join(itos[i] for i in id_list)

# sample_x = x[0].tolist()
# sample_y = y[0].tolist()

# print("X:", decode(sample_x[:15], itos))
# print("Y:", decode(sample_y[:15], itos))


# Modeling


In [134]:
class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers, dropout_rate):
                
        super().__init__()
        self.num_layers = num_layers
        self.hid_dim = hid_dim
        self.emb_dim = emb_dim

        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, 
                    dropout=dropout_rate, batch_first=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hid_dim, vocab_size)
        
        self.init_weights()
        
    def init_weights(self):
        init_range_emb = 0.1
        init_range_other = 1/math.sqrt(self.hid_dim)
        self.embedding.weight.data.uniform_(-init_range_emb, init_range_emb)
        self.fc.weight.data.uniform_(-init_range_other, init_range_other)
        self.fc.bias.data.zero_()
        for i in range(self.num_layers):
            self.lstm.all_weights[i][0] = torch.FloatTensor(self.emb_dim,
                    self.hid_dim).uniform_(-init_range_other, init_range_other) 
            self.lstm.all_weights[i][1] = torch.FloatTensor(self.hid_dim, 
                    self.hid_dim).uniform_(-init_range_other, init_range_other) 

    def init_hidden(self, batch_size, device):
        hidden = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        cell   = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        return hidden, cell
    
    def detach_hidden(self, hidden):
        hidden, cell = hidden
        hidden = hidden.detach()
        cell = cell.detach()
        return hidden, cell

    def forward(self, src, hidden):
        #src: [batch size, seq len]
        embedding = self.dropout(self.embedding(src))
        #embedding: [batch size, seq len, emb_dim]
        output, hidden = self.lstm(embedding, hidden)      
        #output: [batch size, seq len, hid_dim]
        #hidden = h, c = [num_layers * direction, seq len, hid_dim)
        output = self.dropout(output) 
        prediction = self.fc(output)
        #prediction: [batch size, seq_len, vocab size]
        return prediction, hidden

# Training

In [136]:
vocab_size = len(vocab)
emb_dim = 1024                # 400 in the paper
hid_dim = 1024                # 1150 in the paper
num_layers = 2                # 3 in the paper
dropout_rate = 0.65              
lr = 1e-3   

In [137]:
model = LSTMLanguageModel(vocab_size, emb_dim, hid_dim, num_layers, dropout_rate).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

The model has 19,977,746 trainable parameters


In [138]:
def get_batch(data, seq_len, idx):
    #data #[batch size, bunch of tokens]
    src    = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]  #target simply is ahead of src by 1            
    return src, target

In [None]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    
    epoch_loss = 0
    model.train()
    # drop all batches that are not a multiple of seq_len
    # data #[batch size, bunch of tokens]
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]  #we need to -1 because we start at 0
    num_batches = data.shape[-1]
    
    #reset the hidden every epoch
    hidden = model.init_hidden(batch_size, device)
    
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):
        optimizer.zero_grad()
        
        #hidden does not need to be in the computational graph for efficiency
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, idx) #src, target: [batch size, seq len]
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction, hidden = model(src, hidden)               

        #need to reshape because criterion expects pred to be 2d and target to be 1d
        prediction = prediction.reshape(batch_size * seq_len, -1)  #prediction: [batch size * seq len, vocab size]  
        target = target.reshape(-1)
        loss = criterion(prediction, target)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len
        
    return epoch_loss / num_batches


In [140]:
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, idx)
            src, target = src.to(device), target.to(device)
            batch_size= src.shape[0]

            prediction, hidden = model(src, hidden)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [142]:
n_epochs = 50
seq_len  = 50 #<----decoding length
clip    = 0.25

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

best_valid_loss = float('inf')

for epoch in range(n_epochs):
    train_loss = train(model, train_data, optimizer, criterion, 
                BATCH_SIZE, seq_len, clip, device)
    valid_loss = evaluate(model, valid_data, criterion, BATCH_SIZE, 
                seq_len, device)

    lr_scheduler.step(valid_loss)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'models/lm.pt')

    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')

                                               

KeyboardInterrupt: 