# Shakespeare Language Model

In [295]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
from torch.utils.data import Dataset, DataLoader, TensorDataset
import numpy as np

import shakespeare_data as sh

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cpu'

## Fixed length input

In [296]:
# Data - refer to shakespeare_data.py for details
corpus = sh.read_corpus()
print("{}...{}".format(corpus[:203], corpus[-50:]))
print("Total character count: {}".format(len(corpus)))
chars, charmap = sh.get_charmap(corpus)
charcount = len(chars)
print("Unique character count: {}\n".format(len(chars)))
shakespeare_array = sh.map_corpus(corpus, charmap)
print(shakespeare_array.shape)
print(shakespeare_array[:17])
print(sh.to_text(shakespeare_array[:17],chars))

1609
 THE SONNETS
 by William Shakespeare
                      1
   From fairest creatures we desire increase,
   That thereby beauty's rose might never die,
   But as the riper should by time decease,
...,
   And new pervert a reconciled maid.'
 THE END

Total character count: 5551930
Unique character count: 84

(5551930,)
[12 17 11 20  0  1 45 33 30  1 44 40 39 39 30 45 44]
1609
 THE SONNETS


In [322]:
# Dataset class. Transforme raw text into a set of sequences of fixed length, and extracts inputs and targets
class TextDataset(Dataset):
    def __init__(self,text, seq_len = 200):
        n_seq = len(text) // seq_len
        text = text[:n_seq * seq_len]
        self.data = torch.tensor(text).view(-1,seq_len)
    def __getitem__(self,i):
        txt = self.data[i]
        return txt[:-1],txt[1:]
    def __len__(self):
        return self.data.size(0)

# Collate function. Transform a list of sequences into a batch. Passed as an argument to the DataLoader.
# Returns data on the format seq_len x batch_size
def collate(seq_list):
    print(f'{seq_list[0][0].shape}')
    print(f'{seq_list[1][0].unsqueeze(1).shape}')
    print(f'{seq_list[0][0].unsqueeze(1)}')
    print(f'{seq_list[1][0].unsqueeze(1)}')
    my_cat = torch.cat([seq_list[0][0].unsqueeze(1), seq_list[1][0].unsqueeze(1)], dim=1)
    print(f'my_cat: {my_cat}')
    raise
    inputs = torch.cat([s[0].unsqueeze(1) for s in seq_list],dim=1)
    targets = torch.cat([s[1].unsqueeze(1) for s in seq_list],dim=1)
    return inputs,targets


In [323]:
# Model
class CharLanguageModel(nn.Module):

    def __init__(self,vocab_size,embed_size,hidden_size, nlayers):
        super(CharLanguageModel,self).__init__()
        self.vocab_size=vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.nlayers=nlayers
        self.embedding = nn.Embedding(vocab_size,embed_size) # Embedding layer
        self.rnn = nn.LSTM(input_size = embed_size,hidden_size=hidden_size,num_layers=nlayers) # Recurrent network
        self.scoring = nn.Linear(hidden_size,vocab_size) # Projection layer
        
    def forward(self,seq_batch): #L x N
        # returns 3D logits
        batch_size = seq_batch.size(1)
        embed = self.embedding(seq_batch) #L x N x E
        hidden = None
        output_lstm,hidden = self.rnn(embed,hidden) #L x N x H
        output_lstm_flatten = output_lstm.view(-1,self.hidden_size) #(L*N) x H
        output_flatten = self.scoring(output_lstm_flatten) #(L*N) x V
        return output_flatten.view(-1,batch_size,self.vocab_size)
    
    def generate(self,seq, n_words): # L x V
        # performs greedy search to extract and return words (one sequence).
        generated_words = []
        embed = self.embedding(seq).unsqueeze(1) # L x 1 x E
        hidden = None
        output_lstm, hidden = self.rnn(embed,hidden) # L x 1 x H
        output = output_lstm[-1] # 1 x H
        scores = self.scoring(output) # 1 x V
        _,current_word = torch.max(scores,dim=1) # 1 x 1
        generated_words.append(current_word)
        if n_words > 1:
            for i in range(n_words-1):
                embed = self.embedding(current_word).unsqueeze(0) # 1 x 1 x E
                output_lstm, hidden = self.rnn(embed,hidden) # 1 x 1 x H
                output = output_lstm[0] # 1 x H
                scores = self.scoring(output) # V
                _,current_word = torch.max(scores,dim=1) # 1
                generated_words.append(current_word)
        return torch.cat(generated_words,dim=0)
        
        

In [324]:
def train_epoch(model, optimizer, train_loader, val_loader):
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(DEVICE)
    batch_id=0
    for inputs,targets in train_loader:
        batch_id+=1
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        outputs = model(inputs) # 3D
        loss = criterion(outputs.view(-1,outputs.size(2)),targets.view(-1)) # Loss of the flattened outputs
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_id % 100 == 0:
            lpw = loss.item()
            print("At batch",batch_id)
            print("Training loss per word:",lpw)
            print("Training perplexity :",np.exp(lpw))
    
    val_loss = 0
    batch_id=0
    for inputs,targets in val_loader:
        batch_id+=1
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        outputs = model(inputs)
        loss = criterion(outputs.view(-1,outputs.size(2)),targets.view(-1))
        val_loss+=loss.item()
    val_lpw = val_loss / batch_id
    print("\nValidation loss per word:",val_lpw)
    print("Validation perplexity :",np.exp(val_lpw),"\n")
    return val_lpw
    

In [325]:
model = CharLanguageModel(charcount,256,256,3)
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(),lr=0.001, weight_decay=1e-6)
split = 5000000
train_dataset = TextDataset(shakespeare_array[:split])
val_dataset = TextDataset(shakespeare_array[split:])
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64, collate_fn = collate)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64, collate_fn = collate, drop_last=True)

In [326]:
for i in range(3):
    train_epoch(model, optimizer, train_loader, val_loader)

torch.Size([199])
torch.Size([199, 1])
tensor([[ 8],
        [ 1],
        [59],
        [70],
        [ 1],
        [80],
        [70],
        [76],
        [ 1],
        [66],
        [69],
        [70],
        [78],
        [ 1],
        [68],
        [80],
        [ 1],
        [77],
        [70],
        [64],
        [58],
        [60],
        [25],
        [ 0],
        [ 1],
        [ 1],
        [ 1],
        [27],
        [43],
        [26],
        [27],
        [26],
        [39],
        [45],
        [34],
        [40],
        [10],
        [ 1],
        [39],
        [70],
        [75],
        [ 1],
        [34],
        [10],
        [ 1],
        [48],
        [63],
        [56],
        [75],
        [ 1],
        [56],
        [73],
        [60],
        [ 1],
        [80],
        [70],
        [76],
        [25],
        [ 0],
        [ 1],
        [ 1],
        [ 1],
        [43],
        [40],
        [29],
        [30],
        [43],
        [34],
        [

RuntimeError: No active exception to reraise

In [327]:
def generate(model, seed,nwords):
    seq = sh.map_corpus(seed, charmap)
    seq = torch.tensor(seq).to(DEVICE)
    out = model.generate(seq,nwords)
    return sh.to_text(out.cpu().detach().numpy(),chars)

In [232]:
print(generate(model, "To be, or not to be, that is the q",8))

hhhhhhhh


In [233]:
print(generate(model, "Richard ", 1000))

hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhh

## Packed sequences

In [281]:
stop_character = charmap['\n']
space_character = charmap[" "]
lines = np.split(shakespeare_array, np.where(shakespeare_array == stop_character)[0]+1) # split the data in lines
shakespeare_lines = []
for s in lines:
    s_trimmed = np.trim_zeros(s-space_character)+space_character # remove space-only lines
    if len(s_trimmed)>1:
        shakespeare_lines.append(s)
for i in range(10):
    print(sh.to_text(shakespeare_lines[i],chars))
print(len(shakespeare_lines))

1609

 THE SONNETS

 by William Shakespeare

                      1

   From fairest creatures we desire increase,

   That thereby beauty's rose might never die,

   But as the riper should by time decease,

   His tender heir might bear his memory:

   But thou contracted to thine own bright eyes,

   Feed'st thy light's flame with self-substantial fuel,

114638


In [282]:
class LinesDataset(Dataset):
    def __init__(self,lines):
        self.lines=[torch.tensor(l) for l in lines]
    def __getitem__(self,i):
        line = self.lines[i]
        return line[:-1].to(DEVICE),line[1:].to(DEVICE)
    def __len__(self):
        return len(self.lines)

def collate_lines(seq_list):
    inputs,targets = zip(*seq_list)
    lens = [len(seq) for seq in inputs]
    seq_order = sorted(range(len(lens)), key=lens.__getitem__, reverse=True)
    inputs = [inputs[i] for i in seq_order]
    targets = [targets[i] for i in seq_order]
    return inputs,targets

In [291]:
# Model that takes packed sequences in training
class PackedLanguageModel(nn.Module):
    
    def __init__(self,vocab_size,embed_size,hidden_size, nlayers, stop):
        super(PackedLanguageModel,self).__init__()
        self.vocab_size=vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.nlayers=nlayers
        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.rnn = nn.LSTM(input_size = embed_size,hidden_size=hidden_size,num_layers=nlayers) # 1 layer, batch_size = False
        self.scoring = nn.Linear(hidden_size,vocab_size)
        self.stop = stop # stop line character (\n)
    
    def forward(self,seq_list): # list
        batch_size = len(seq_list)
        lens = [len(s) for s in seq_list] # lens of all lines (already sorted)
        print(f'lens: {lens}')
        bounds = [0]
        for l in lens:
            bounds.append(bounds[-1]+l) # bounds of all lines in the concatenated sequence
        print(f'seq_list: {len(seq_list)}')
        print(f'seq_list[0]: {seq_list[0].size()}')
        print(f'seq_list[1]: {seq_list[1].size()}')
        print(f'seq_list[2]: {seq_list[2].size()}')
        # seq_list[0]: torch.Size([56])
        # seq_list[1]: torch.Size([54])
        # seq_list[2]: torch.Size([52])
        seq_concat = torch.cat(seq_list) # concatenated sequence
        print(f'seq_concat: {seq_concat.size()}')
        # seq_concat: torch.Size([2717])
        embed_concat = self.embedding(seq_concat) # concatenated embeddings
        print(f'embed_concat: {embed_concat.size()}')
        # embed_concat: torch.Size([2717, 256])
        embed_list = [embed_concat[bounds[i]:bounds[i+1]] for i in range(batch_size)] # embeddings per line
        print(f'embed_list: {len(embed_list)}')
        print(f'embed_list[0]: {embed_list[0].size()}')
        # embed_list[0]: torch.Size([56, 256])
        print(f'embed_list[1]: {embed_list[1].size()}')
        
        packed_input = rnn.pack_sequence(embed_list) # packed version
        print(f'packed_input: {packed_input}')
        hidden = None
        output_packed,hidden = self.rnn(packed_input,hidden)
        output_padded, _ = rnn.pad_packed_sequence(output_packed) # unpacked output (padded)
        output_flatten = torch.cat([output_padded[:lens[i],i] for i in range(batch_size)]) # concatenated output
        scores_flatten = self.scoring(output_flatten) # concatenated logits
        return scores_flatten # return concatenated logits
    
    def generate(self,seq, n_words): # L x V
        generated_words = []
        embed = self.embedding(seq).unsqueeze(1) # L x 1 x E
        hidden = None
        output_lstm, hidden = self.rnn(embed,hidden) # L x 1 x H
        output = output_lstm[-1] # 1 x H
        scores = self.scoring(output) # 1 x V
        _,current_word = torch.max(scores,dim=1) # 1 x 1
        generated_words.append(current_word)
        if n_words > 1:
            for i in range(n_words-1):
                embed = self.embedding(current_word).unsqueeze(0) # 1 x 1 x E
                output_lstm, hidden = self.rnn(embed,hidden) # 1 x 1 x H
                output = output_lstm[0] # 1 x H
                scores = self.scoring(output) # V
                _,current_word = torch.max(scores,dim=1) # 1
                generated_words.append(current_word)
                if current_word[0].item()==self.stop: # If end of line
                    break
        return torch.cat(generated_words,dim=0)

In [292]:
def train_epoch_packed(model, optimizer, train_loader, val_loader):
    criterion = nn.CrossEntropyLoss(reduction="sum") # sum instead of averaging, to take into account the different lengths
    criterion = criterion.to(DEVICE)
    batch_id=0
    for inputs,targets in train_loader: # lists, presorted, preloaded on GPU
        batch_id+=1
        outputs = model(inputs)
        loss = criterion(outputs,torch.cat(targets)) # criterion of the concatenated output
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_id % 100 == 0:
            nwords = np.sum(np.array([len(l) for l in inputs]))
            lpw = loss.item() / nwords
            print("At batch",batch_id)
            print("Training loss per word:",lpw)
            print("Training perplexity :",np.exp(lpw))
    
    val_loss = 0
    batch_id=0
    nwords = 0
    for inputs,targets in val_loader:
        nwords += np.sum(np.array([len(l) for l in inputs]))
        batch_id+=1
        outputs = model(inputs)
        loss = criterion(outputs,torch.cat(targets))
        val_loss+=loss.item()
    val_lpw = val_loss / nwords
    print("\nValidation loss per word:",val_lpw)
    print("Validation perplexity :",np.exp(val_lpw),"\n")
    return val_lpw

In [293]:
model = PackedLanguageModel(charcount,256,256,3, stop=stop_character)
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(),lr=0.001, weight_decay=1e-6)
split = 100000
train_dataset = LinesDataset(shakespeare_lines[:split])
val_dataset = LinesDataset(shakespeare_lines[split:])
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=64, collate_fn = collate_lines)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64, collate_fn = collate_lines, drop_last=True)

In [294]:
for i in range(1):
    train_epoch_packed(model, optimizer, train_loader, val_loader)

lens: [56, 54, 52, 52, 51, 51, 50, 50, 50, 49, 49, 49, 48, 48, 48, 48, 48, 48, 47, 47, 47, 47, 47, 47, 46, 46, 46, 46, 46, 46, 45, 45, 45, 45, 44, 44, 44, 44, 44, 44, 44, 44, 44, 43, 43, 42, 42, 41, 41, 41, 41, 40, 40, 39, 38, 37, 23, 23, 23, 23, 23, 23, 12, 4]
seq_list: 64
seq_list[0]: torch.Size([56])
seq_list[1]: torch.Size([54])
seq_list[2]: torch.Size([52])
seq_concat: torch.Size([2717])
embed_concat: torch.Size([2717, 256])
embed_list: 64
embed_list[0]: torch.Size([56, 256])
embed_list[1]: torch.Size([54, 256])
packed_input: PackedSequence(data=tensor([[ 0.7654, -0.1832, -1.5812,  ...,  0.1848,  0.4273, -0.7515],
        [ 0.7654, -0.1832, -1.5812,  ...,  0.1848,  0.4273, -0.7515],
        [ 0.7654, -0.1832, -1.5812,  ...,  0.1848,  0.4273, -0.7515],
        ...,
        [ 0.1681,  0.0924,  0.0588,  ..., -0.1643,  0.2586,  0.3805],
        [-0.1080, -0.3240,  1.0240,  ...,  2.1452,  0.0503, -0.2296],
        [ 0.5822, -1.8107,  0.7466,  ...,  0.3248, -0.2208,  0.8685]],
       gr

RuntimeError: No active exception to reraise

In [None]:
print(generate(model, "To be, or not to be, that is the q",20))

In [None]:
print(generate(model, "Richard ", 1000))