# Character-Level LSTM

In [1]:
import torch
import torch.nn.functional as F
from torch import nn, optim

In [2]:
import numpy as np

In [3]:
import os

In [4]:
%matplotlib inline
import matplotlib.pyplot as plt

In [5]:
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7fd1711d63f0>

### Device

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


## Data

In [7]:
path_to_files = "data/DivinaCommedia/Inferno"

In [8]:
# List files (sorted)
files = sorted(os.listdir(path_to_files))

# Load all files in one string
text = ""
for file in files:
    fname = os.path.join(path_to_files, file)
    with open(fname, "r") as f:
        text += f.read()
        
print(text[:500])

 Nel mezzo del cammin di nostra vita
mi ritrovai per una selva oscura,
ché la diritta via era smarrita.

Ahi quanto a dir qual era è cosa dura
esta selva selvaggia e aspra e forte
che nel pensier rinova la paura!

Tant' è amara che poco è più morte;
ma per trattar del ben ch'i' vi trovai,
dirò de l'altre cose ch'i' v'ho scorte.

Io non so ben ridir com' i' v'intrai,
tant' era pien di sonno a quel punto
che la verace via abbandonai.

Ma poi ch'i' fui al piè d'un colle giunto,
là dove terminava qu


In [9]:
print(f"Total number of characters: {len(text)}")

Total number of characters: 172243


### Tokens

In [10]:
chars = sorted(set(text))
print(f"Unique characters: {len(chars)}")

Unique characters: 69


In [11]:
int2char = dict(enumerate(chars))
print(int2char)

{0: '\n', 1: ' ', 2: '!', 3: '"', 4: "'", 5: ',', 6: '-', 7: '.', 8: ':', 9: ';', 10: '?', 11: 'A', 12: 'B', 13: 'C', 14: 'D', 15: 'E', 16: 'F', 17: 'G', 18: 'H', 19: 'I', 20: 'L', 21: 'M', 22: 'N', 23: 'O', 24: 'P', 25: 'Q', 26: 'R', 27: 'S', 28: 'T', 29: 'U', 30: 'V', 31: 'Z', 32: '`', 33: 'a', 34: 'b', 35: 'c', 36: 'd', 37: 'e', 38: 'f', 39: 'g', 40: 'h', 41: 'i', 42: 'l', 43: 'm', 44: 'n', 45: 'o', 46: 'p', 47: 'q', 48: 'r', 49: 's', 50: 't', 51: 'u', 52: 'v', 53: 'x', 54: 'z', 55: '«', 56: '»', 57: 'à', 58: 'ä', 59: 'è', 60: 'é', 61: 'ë', 62: 'ì', 63: 'ï', 64: 'ò', 65: 'ó', 66: 'ö', 67: 'ù', 68: 'ü'}


In [12]:
char2int = {char: i for i, char in int2char.items()}
print(char2int)

{'\n': 0, ' ': 1, '!': 2, '"': 3, "'": 4, ',': 5, '-': 6, '.': 7, ':': 8, ';': 9, '?': 10, 'A': 11, 'B': 12, 'C': 13, 'D': 14, 'E': 15, 'F': 16, 'G': 17, 'H': 18, 'I': 19, 'L': 20, 'M': 21, 'N': 22, 'O': 23, 'P': 24, 'Q': 25, 'R': 26, 'S': 27, 'T': 28, 'U': 29, 'V': 30, 'Z': 31, '`': 32, 'a': 33, 'b': 34, 'c': 35, 'd': 36, 'e': 37, 'f': 38, 'g': 39, 'h': 40, 'i': 41, 'l': 42, 'm': 43, 'n': 44, 'o': 45, 'p': 46, 'q': 47, 'r': 48, 's': 49, 't': 50, 'u': 51, 'v': 52, 'x': 53, 'z': 54, '«': 55, '»': 56, 'à': 57, 'ä': 58, 'è': 59, 'é': 60, 'ë': 61, 'ì': 62, 'ï': 63, 'ò': 64, 'ó': 65, 'ö': 66, 'ù': 67, 'ü': 68}


In [13]:
# Encode text mapping characters to integers
encoded = np.array([char2int[char] for char in text])

print(encoded[:250])

[ 1 22 37 42  1 43 37 54 54 45  1 36 37 42  1 35 33 43 43 41 44  1 36 41
  1 44 45 49 50 48 33  1 52 41 50 33  0 43 41  1 48 41 50 48 45 52 33 41
  1 46 37 48  1 51 44 33  1 49 37 42 52 33  1 45 49 35 51 48 33  5  0 35
 40 60  1 42 33  1 36 41 48 41 50 50 33  1 52 41 33  1 37 48 33  1 49 43
 33 48 48 41 50 33  7  0  0 11 40 41  1 47 51 33 44 50 45  1 33  1 36 41
 48  1 47 51 33 42  1 37 48 33  1 59  1 35 45 49 33  1 36 51 48 33  0 37
 49 50 33  1 49 37 42 52 33  1 49 37 42 52 33 39 39 41 33  1 37  1 33 49
 46 48 33  1 37  1 38 45 48 50 37  0 35 40 37  1 44 37 42  1 46 37 44 49
 41 37 48  1 48 41 44 45 52 33  1 42 33  1 46 33 51 48 33  2  0  0 28 33
 44 50  4  1 59  1 33 43 33 48 33  1 35 40 37  1 46 45 35 45  1 59  1 46
 41 67  1 43 45 48 50 37  9  0]


### One-Hot Encoding

In [14]:
def one_hot_encoder(data, num_labels):
    
    data = np.asarray(data)
    
    # Initialize one-hot encoding vector
    # PyTorch standard type is torch.float32
    # Declare numpy array as np.float32 to avoid conversion errors
    one_hot = np.zeros((data.size, num_labels), dtype=np.float32)
    
    one_hot[np.arange(one_hot.shape[0]), data.flatten()] = 1.0
    
    one_hot = one_hot.reshape((*data.shape, num_labels))
    
    return one_hot

In [15]:
one_hot_test = one_hot_encoder(encoded[:5], len(chars))

assert one_hot_test.shape == (5, len(chars))

In [16]:
print(encoded[:5])
print(one_hot_test)

[ 1 22 37 42  1]
[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 

### Mini Batches

In [17]:
def iterbatches(data, batch_size, len_sequence):
    
    # Number of characters per batch
    num_char_batch = batch_size * len_sequence
    
    # Total number of characters
    num_chars = len(data)
    
    # Total number of full batches
    num_batches = num_chars // num_char_batch
    
    # Discard last charachters not filling a batch
    data = data[:num_batches * num_char_batch]
    
    # Reshape into batch_size rows
    data = data.reshape((batch_size, -1))
    
    for n in range(0, data.shape[1], len_sequence):
        
        # Input features
        inputs = data[:,n:n + len_sequence]
        
        # Target features
        # Input features shifted by one
        targets = np.zeros_like(inputs)
        targets[:,:-1] = inputs[:,1:] # Shift input by one
        try:
            targets[:,-1] = data[:,n + len_sequence] # Add last element
        except IndexError:
            targets[:,-1] = data[:,0]
        
        # Yeld 
        yield inputs, targets

In [18]:
testbatches = iterbatches(encoded, batch_size=3, len_sequence=10)

inputs, targets = next(testbatches)

assert inputs.shape == (3, 10)
assert targets.shape == (3, 10)

print(f"Input Sequence Shape:\n{inputs.shape}")
print(f"Target Sequence Shape:\n{targets.shape}")

print(f"Input Sequence:\n{inputs}")
print(f"Target Sequence:\n{targets}")

Input Sequence Shape:
(3, 10)
Target Sequence Shape:
(3, 10)
Input Sequence:
[[ 1 22 37 42  1 43 37 54 54 45]
 [ 0 49 33 37 50 50 33 44 36 45]
 [49 49 37  5  1 33 44 35 45 48]]
Target Sequence:
[[22 37 42  1 43 37 54 54 45  1]
 [49 33 37 50 50 33 44 36 45  1]
 [49 37  5  1 33 44 35 45 48  1]]


We see that the target sequence is the input sequence shifted by one on the `len_sequence` dimension (`axis=1`).

## LSTM Architecture

In [19]:
class CharLSTM(nn.Module):
    
    def __init__(self, tokens, n_hidden=256, n_layers=2, pdrop=0.5):
        super().__init__()
        
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.pdrop = pdrop
        
        self.lstm = nn.LSTM(
            len(tokens), 
            n_hidden, 
            n_layers, 
            dropout=pdrop, 
            batch_first=True
        )
        
        self.dropout = nn.Dropout(pdrop)
        
        self.fc = nn.Linear(n_hidden, len(tokens))
        
    def forward(self, x, hidden):
        
        # Forward pass in LSTM
        output, hidden = self.lstm(x, hidden)
        
        # Dropout
        output = self.dropout(output)
        
        # Stack LSTM outputs
        output = output.view(-1, self.n_hidden)
        
        # Forward pass through fully connected layer
        output = self.fc(output)
        
        # Return output and hidden state
        return output, hidden

### Test Forward Pass

In [20]:
testlstm = CharLSTM(chars, n_hidden=64)

testbatches = iterbatches(encoded, batch_size=3, len_sequence=10)
inputs, targets = next(testbatches)

inputs = one_hot_encoder(inputs, len(chars))
inputs = torch.from_numpy(inputs)

output, hidden = testlstm(inputs, None)

assert output.shape == (3 * 10, len(chars)) # (batch_size * len_sequence, len(chars))

assert len(hidden) == 2 # hidden is a tuple with n_layers elements

for h in hidden:
    assert h.shape == (2, 3, 64) # (n_layers, batch_size, n_hidden)

## Training

In [21]:
def train(model, 
          optimizer, 
          loss_function, 
          data,
          epochs=10, 
          batch_size=10,
          len_sequence=50, 
          clip=5, 
          print_every=5,
          device=device):
    
    import time
    
    # Set model in training mode
    model.train()
    
    # Move model to devide
    model.to(device)
    
    num_tokens = len(set(data))
    
    for epoch in range(epochs):
        
        epoch_loss = 0
        
        start_time = time.time()
        
        # Initialize hidden state
        hidden = None
    
        for inputs, targets in iterbatches(data, batch_size, len_sequence):
            
            assert inputs.shape == targets.shape == (batch_size, len_sequence)
            
            inputs = one_hot_encoder(inputs, num_tokens)
            
            # Initialise tensors and move to device
            inputs = torch.from_numpy(inputs).to(device)
            targets = torch.from_numpy(targets).to(device)
            
            output, hidden = model(inputs, hidden)
            
            # Detach all hidden states from the computational graph
            # Avoid backpropagation through the entire history
            # Hidden states are stored in a tuple of size n_layers
            hidden = tuple(h.detach() for h in hidden)
            
            optimizer.zero_grad()
            
            loss = loss_function(output, targets.view(batch_size*len_sequence))
            
            loss.backward()
            
            # Accumulate epoch loss
            epoch_loss += loss.item()
            
            # Clip gradients norm
            # Prevents the exploding gradient problem
            nn.utils.clip_grad_norm_(model.parameters(), clip)
            
            # Optimise model parameters
            optimizer.step()
        else:
            stop_time = time.time()
            
            num_batches = data.size // (batch_size * len_sequence)
            
            print(f"--- Epoch {epoch:2}/{epochs:2} ---")
            print(f"Loss: {epoch_loss/num_batches:.5f}")
            print(f"Time: {stop_time-start_time:.2f} s")

In [22]:
n_hidden = 512
n_layers = 2

model = CharLSTM(chars, n_hidden, n_layers)

print(model)

CharLSTM(
  (lstm): LSTM(69, 512, num_layers=2, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=512, out_features=69, bias=True)
)


In [23]:
batch_size = 128
len_sequence = 100
epochs = 25

optimizer = optim.Adam(model.parameters(), lr=0.001)

loss_function = nn.CrossEntropyLoss()

train(model, optimizer, loss_function, encoded, epochs=epochs, batch_size=batch_size, len_sequence=len_sequence)

--- Epoch  0/25 ---
Loss: 3.53031
Time: 23.25 s
--- Epoch  1/25 ---
Loss: 3.14066
Time: 25.73 s
--- Epoch  2/25 ---
Loss: 3.11649
Time: 24.41 s
--- Epoch  3/25 ---
Loss: 3.10742
Time: 24.18 s
--- Epoch  4/25 ---
Loss: 3.09931
Time: 24.84 s
--- Epoch  5/25 ---
Loss: 3.08751
Time: 24.56 s
--- Epoch  6/25 ---
Loss: 3.05609
Time: 24.54 s
--- Epoch  7/25 ---
Loss: 2.95233
Time: 25.41 s
--- Epoch  8/25 ---
Loss: 2.78189
Time: 25.51 s
--- Epoch  9/25 ---
Loss: 2.62719
Time: 25.78 s
--- Epoch 10/25 ---
Loss: 2.51607
Time: 25.73 s
--- Epoch 11/25 ---
Loss: 2.44940
Time: 26.22 s
--- Epoch 12/25 ---
Loss: 2.41050
Time: 25.81 s
--- Epoch 13/25 ---
Loss: 2.35882
Time: 24.66 s
--- Epoch 14/25 ---
Loss: 2.31868
Time: 26.55 s
--- Epoch 15/25 ---
Loss: 2.28572
Time: 26.55 s
--- Epoch 16/25 ---
Loss: 2.25638
Time: 24.94 s
--- Epoch 17/25 ---
Loss: 2.22560
Time: 26.19 s
--- Epoch 18/25 ---
Loss: 2.20141
Time: 24.57 s
--- Epoch 19/25 ---
Loss: 2.16761
Time: 25.89 s
--- Epoch 20/25 ---
Loss: 2.15345
Time: 

### Save Model

In [24]:
import os

# Make directory for models
try:
    os.mkdir("models")
except FileExistsError:
    pass

checkpoint = {
    "n_hidden": model.n_hidden,
    "n_layers": model.n_layers,
    "state_dict": model.state_dict(),
}
        
torch.save(checkpoint, "models/CharLSTM.pth")

## Predictions

In [25]:
def predict(char, model, char2int, int2char, tokens, top_k=3, hidden=None,  device=device):
    
    # Evaluation mode
    model.eval()
    
    # Transform char to tokenized input
    inputs = np.array([[char2int[char]]])
    
    # One-hot encode input
    n_tokens = len(tokens)
    inputs = one_hot_encoder(inputs, n_tokens)
    
    # Transform numpy array to torch tensor
    inputs = torch.from_numpy(inputs).to(device)
    
    with torch.no_grad():
        
        output, hidden = model(inputs, hidden)

        # Get probabilities for next character
        probabilities = F.softmax(output, dim=1)
        
        # Get top characters
        p, top_char = probabilities.topk(top_k)
        top_char = top_char.cpu().numpy().squeeze()
        
        # Select next character amont top_k most probable
        # Assign probabilities proportional to predicted probability
        p = p.cpu().numpy().squeeze()
        nextchar = np.random.choice(top_char, p=p/p.sum())
        
        # Return predicted char and hidden state
        return int2char[nextchar], hidden

### Sampling

In [26]:
def sample(model, length, prime, char2int, int2char, tokens, top_k=3, device=device):
    
    # List of prime characters
    chars = list(prime)
    
    # Initialise hidden state
    hidden = None
    
    # Run on prime
    for char in chars:
        char, h = predict(char, model, char2int, int2char, tokens, hidden=hidden, top_k=top_k, device=device)

    # Append first prediction after prime
    chars.append(char)
    
    # Use previous prediction to obtain a new prediction
    for _ in range(length):
        char, hidden = predict(chars[-1], model, char2int, int2char, tokens, hidden=hidden, top_k=top_k, device=device)
        chars.append(char)

    return "".join(chars)

In [27]:
generatedtext = sample(model, 2000, "Nel mezzo del cammin di nostra vita", char2int, int2char, chars)
print(generatedtext)

Nel mezzo del cammin di nostra vitaro sonte,
che l' meste a conto la saltore che più ch'io ch'io son che di conto, costere
che 'n conton coner son li sante lia sente,
che l'ancor di sue con lanso alleno
come che l'artra al conta altra che lansa,
seran le sola costio antrena
serte, e con che su lia sera allere
che 'l por le che sonto a piante,
che se son la por ler più che la sante,
e comina cose il sel sono alle concerto
che sonte al son se sarto che sento
comer le comenti i sol se lanso
che l'olto e ciasciaro a par lera che possa
di sa lia conter con li conto
chi lo chi con son con la selte a conto,
ch'io chi se la con conti conto la sono.

Quanto a piùnto a la santonto
de l'or con che so la ser so panso
cose cie così che la core la pesta.

La suo antron den la sento chi sondi,
se più discia la porta, sì sonta,
chi se l'anton de sen cossa costo
sonte a la porte a con le sanconte.

Questo a pense l'arto altre antro sero.

E quallo a cii seno a la sonto sente.

E di la sera a parte anto