In [1]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
# Read the raw training corpus. The encoding is explicit so decoding does not
# depend on the platform's default locale (UTF-8 is a superset of ASCII).
with open('data/anna.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
# Character vocabulary. Sorting makes the char <-> int mapping identical on
# every run; iterating a set of strings is otherwise subject to per-process
# hash randomisation, so the encoding would change between kernel restarts.
chars = tuple(sorted(set(text)))
int2char = dict(enumerate(chars))
char2int = {ch: ii for ii, ch in int2char.items()}

# Encode the whole text as an array of integer character codes
encoded = np.array([char2int[ch] for ch in text])

### Pré-processamento

In [4]:
def one_hot_encode(arr, n_labels):
    """Return a float32 one-hot encoding of the integer array ``arr``.

    Each integer ``k`` in ``arr`` becomes a length-``n_labels`` vector with a
    single 1.0 at position ``k``. The result has shape
    ``arr.shape + (n_labels,)`` and is a freshly allocated array.
    """
    # Row k of the identity matrix is exactly the one-hot vector for label k,
    # so gathering rows by the flattened codes builds all encodings at once.
    identity = np.eye(n_labels, dtype=np.float32)
    flat_rows = identity[arr.reshape(-1)]
    return flat_rows.reshape(arr.shape + (n_labels,))

In [6]:
def get_batches(arr, batch_size, seq_length):
    """Yield ``(x, y)`` mini-batches of shape ``(batch_size, seq_length)``.

    ``arr`` (assumed 1-D) is truncated to a whole number of batches; leftover
    trailing elements are dropped. It is then laid out as ``batch_size`` rows
    and walked left to right in windows of ``seq_length`` columns.

    ``y`` is ``x`` shifted left by one. The last target of each row is the
    element just after the window. For the final window, it wraps around to
    the first element of that row.
    """
    chars_per_batch = batch_size * seq_length
    usable = (len(arr) // chars_per_batch) * chars_per_batch
    grid = arr[:usable].reshape((batch_size, -1))
    n_cols = grid.shape[1]

    for start in range(0, n_cols, seq_length):
        stop = start + seq_length
        inputs = grid[:, start:stop]
        targets = np.zeros_like(inputs)
        targets[:, :-1] = inputs[:, 1:]
        # Past the final column there is no "next" element: wrap to column 0.
        targets[:, -1] = grid[:, stop] if stop < n_cols else grid[:, 0]
        yield inputs, targets

### Define a arquitetura

In [7]:
class CharLSTM(nn.Module):
    """Character-level LSTM model producing per-step scores over a vocabulary.

    Args:
        tokens: Sequence of characters forming the vocabulary; the position
            of each character is its integer code.
        n_hidden: Size of the LSTM hidden state.
        n_layers: Number of stacked LSTM layers.
        drop_prob: Dropout probability. It is applied between stacked LSTM
            layers (only when ``n_layers > 1``) and to the LSTM output before
            the final linear layer.
        lr: Learning rate. It is only stored on the instance; this class does
            not use it.
    """

    def __init__(self, tokens, n_hidden=256, n_layers=2,
                 drop_prob=0.5, lr=0.001):
        super().__init__()
        self.drop_prob = drop_prob
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.lr = lr

        # Keep the vocabulary on the model so later encoding/decoding uses the
        # exact mapping the model was built with.
        self.chars = tokens
        self.int2char = dict(enumerate(self.chars))
        self.char2int = {ch: ii for ii, ch in self.int2char.items()}

        # nn.LSTM applies dropout only *between* stacked layers; with a single
        # layer it has no effect and PyTorch emits a UserWarning, so pass 0.
        lstm_dropout = drop_prob if n_layers > 1 else 0.0
        # LSTM: input_size = vocabulary size (one-hot), hidden_size, num_layers
        self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers,
                            dropout=lstm_dropout, batch_first=True)

        # Dropout on the LSTM output
        self.dropout = nn.Dropout(drop_prob)

        # Fully connected layer: hidden state -> one score per character
        self.fc = nn.Linear(n_hidden, len(self.chars))

    def forward(self, x, hidden):
        """Run the model on a batch.

        Args:
            x: Input tensor, batch-first (``batch_first=True``); presumably
                one-hot of shape ``(batch, seq, len(chars))``.
            hidden: ``(h, c)`` tuple, e.g. from ``init_hidden``.

        Returns:
            ``(out, hidden)``. ``out`` holds scores of shape
            ``(batch * seq, len(chars))``; ``hidden`` is the new ``(h, c)``.
        """
        r_output, hidden = self.lstm(x, hidden)

        out = self.dropout(r_output)
        # Flatten batch and time so every step is scored by the linear layer.
        out = out.reshape(-1, self.n_hidden)
        out = self.fc(out)

        return out, hidden

    def init_hidden(self, batch_size):
        """Return zero ``(h, c)`` tensors of shape ``(n_layers, batch_size, n_hidden)``.

        They share the dtype and device of the model parameters and do not
        require gradients.
        """
        weight = next(self.parameters())
        shape = (self.n_layers, batch_size, self.n_hidden)
        return (weight.new_zeros(shape), weight.new_zeros(shape))

In [8]:
def _validation_loss(net, val_data, batch_size, seq_length, n_chars, criterion):
    """Return the mean loss over all full batches of ``val_data``.

    Runs in eval mode without autograd, and always restores train mode
    afterwards. Returns nan when ``val_data`` is too short to form one batch.
    """
    net.eval()
    try:
        val_h = net.init_hidden(batch_size)
        val_losses = []
        # No gradients are needed for evaluation; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            for x, y in get_batches(val_data, batch_size, seq_length):
                inputs = torch.from_numpy(one_hot_encode(x, n_chars))
                targets = torch.from_numpy(y)
                output, val_h = net(inputs, val_h)
                val_loss = criterion(output, targets.view(batch_size * seq_length).long())
                val_losses.append(val_loss.item())
    finally:
        net.train()
    # Avoid numpy's "mean of empty slice" warning; nan signals "no batches".
    return float(np.mean(val_losses)) if val_losses else float('nan')


def train(net, data, epochs=10, batch_size=10, seq_length=50, lr=0.001, clip=5, val_frac=0.1, print_every=10):
    """Train ``net`` in place on an integer-encoded character sequence.

    The last ``val_frac`` fraction of ``data`` is held out for validation.
    Every ``print_every`` training steps, the validation loss is computed and
    a progress line is printed.

    Args:
        net: Model with ``chars`` and ``init_hidden(batch_size)``. Its forward
            returns ``(scores, hidden)``, with scores flattened to
            ``(batch_size * seq_length, len(chars))``.
        data: 1-D numpy array of integer character codes.
        epochs: Number of passes over the training split.
        batch_size: Sequences per mini-batch.
        seq_length: Characters per sequence.
        lr: Adam learning rate.
        clip: Max gradient norm passed to ``clip_grad_norm_``.
        val_frac: Fraction of ``data`` reserved for validation.
        print_every: Report progress every this many training steps.
    """
    net.train()

    opt = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Train / validation split (validation is the tail of the sequence)
    val_idx = int(len(data) * (1 - val_frac))
    data, val_data = data[:val_idx], data[val_idx:]

    counter = 0
    n_chars = len(net.chars)
    for e in range(epochs):
        h = net.init_hidden(batch_size)

        for x, y in get_batches(data, batch_size, seq_length):
            counter += 1

            inputs = torch.from_numpy(one_hot_encode(x, n_chars))
            targets = torch.from_numpy(y)

            # Carry the hidden state across batches, but detach it so
            # backpropagation stops at the batch boundary.
            h = tuple(each.detach() for each in h)

            net.zero_grad()
            output, h = net(inputs, h)

            loss = criterion(output, targets.view(batch_size * seq_length).long())
            loss.backward()

            nn.utils.clip_grad_norm_(net.parameters(), clip)
            opt.step()

            if counter % print_every == 0:
                val_loss = _validation_loss(net, val_data, batch_size, seq_length,
                                            n_chars, criterion)
                print("Epoch: {}/{}...".format(e+1, epochs),
                      "Step: {}...".format(counter),
                      "Loss: {:.4f}...".format(loss.item()),
                      "Val Loss: {:.4f}".format(val_loss))

### Treinamento

In [15]:
# Model hyperparameters
n_hidden = 256
n_layers = 2

# Seed both RNGs so the run is reproducible. torch drives weight
# initialisation and dropout; numpy drives character sampling in `predict`.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

net = CharLSTM(chars, n_hidden, n_layers)
print(net)

CharLSTM(
  (lstm): LSTM(83, 256, num_layers=2, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=256, out_features=83, bias=True)
)


In [16]:
# Training hyperparameters
batch_size = 128
seq_length = 100
n_epochs = 110

train(
    net,
    encoded,
    epochs=n_epochs,
    batch_size=batch_size,
    seq_length=seq_length,
    lr=0.001,
    print_every=10,
)

Epoch: 1/110... Step: 10... Loss: 3.3648... Val Loss: 3.2459
Epoch: 1/110... Step: 20... Loss: 3.1842... Val Loss: 3.1352
Epoch: 1/110... Step: 30... Loss: 3.1689... Val Loss: 3.1292
Epoch: 1/110... Step: 40... Loss: 3.1384... Val Loss: 3.1213
Epoch: 1/110... Step: 50... Loss: 3.1609... Val Loss: 3.1198
Epoch: 1/110... Step: 60... Loss: 3.1320... Val Loss: 3.1184
Epoch: 1/110... Step: 70... Loss: 3.1185... Val Loss: 3.1174
Epoch: 1/110... Step: 80... Loss: 3.1399... Val Loss: 3.1164
Epoch: 1/110... Step: 90... Loss: 3.1361... Val Loss: 3.1147
Epoch: 1/110... Step: 100... Loss: 3.1273... Val Loss: 3.1122
Epoch: 1/110... Step: 110... Loss: 3.1266... Val Loss: 3.1079
Epoch: 1/110... Step: 120... Loss: 3.1002... Val Loss: 3.1008
Epoch: 1/110... Step: 130... Loss: 3.1063... Val Loss: 3.0872
Epoch: 2/110... Step: 140... Loss: 3.0894... Val Loss: 3.0592
Epoch: 2/110... Step: 150... Loss: 3.0466... Val Loss: 3.0102
Epoch: 2/110... Step: 160... Loss: 2.9986... Val Loss: 2.9374
Epoch: 2/110... S

Epoch: 10/110... Step: 1330... Loss: 1.7670... Val Loss: 1.7113
Epoch: 10/110... Step: 1340... Loss: 1.7493... Val Loss: 1.7086
Epoch: 10/110... Step: 1350... Loss: 1.7440... Val Loss: 1.7033
Epoch: 10/110... Step: 1360... Loss: 1.7448... Val Loss: 1.7010
Epoch: 10/110... Step: 1370... Loss: 1.7368... Val Loss: 1.6981
Epoch: 10/110... Step: 1380... Loss: 1.7642... Val Loss: 1.6912
Epoch: 10/110... Step: 1390... Loss: 1.7695... Val Loss: 1.6968
Epoch: 11/110... Step: 1400... Loss: 1.7695... Val Loss: 1.6853
Epoch: 11/110... Step: 1410... Loss: 1.7752... Val Loss: 1.6816
Epoch: 11/110... Step: 1420... Loss: 1.7612... Val Loss: 1.6802
Epoch: 11/110... Step: 1430... Loss: 1.7309... Val Loss: 1.6767
Epoch: 11/110... Step: 1440... Loss: 1.7735... Val Loss: 1.6746
Epoch: 11/110... Step: 1450... Loss: 1.7031... Val Loss: 1.6731
Epoch: 11/110... Step: 1460... Loss: 1.7163... Val Loss: 1.6719
Epoch: 11/110... Step: 1470... Loss: 1.7176... Val Loss: 1.6677
Epoch: 11/110... Step: 1480... Loss: 1.7

Epoch: 19/110... Step: 2620... Loss: 1.5026... Val Loss: 1.4835
Epoch: 19/110... Step: 2630... Loss: 1.5015... Val Loss: 1.4783
Epoch: 19/110... Step: 2640... Loss: 1.5061... Val Loss: 1.4809
Epoch: 20/110... Step: 2650... Loss: 1.5231... Val Loss: 1.4800
Epoch: 20/110... Step: 2660... Loss: 1.5157... Val Loss: 1.4740
Epoch: 20/110... Step: 2670... Loss: 1.5249... Val Loss: 1.4678
Epoch: 20/110... Step: 2680... Loss: 1.5065... Val Loss: 1.4745
Epoch: 20/110... Step: 2690... Loss: 1.4980... Val Loss: 1.4712
Epoch: 20/110... Step: 2700... Loss: 1.5115... Val Loss: 1.4694
Epoch: 20/110... Step: 2710... Loss: 1.4892... Val Loss: 1.4680
Epoch: 20/110... Step: 2720... Loss: 1.4919... Val Loss: 1.4679
Epoch: 20/110... Step: 2730... Loss: 1.4784... Val Loss: 1.4663
Epoch: 20/110... Step: 2740... Loss: 1.4879... Val Loss: 1.4685
Epoch: 20/110... Step: 2750... Loss: 1.4834... Val Loss: 1.4697
Epoch: 20/110... Step: 2760... Loss: 1.4758... Val Loss: 1.4671
Epoch: 20/110... Step: 2770... Loss: 1.5

Epoch: 29/110... Step: 3910... Loss: 1.4216... Val Loss: 1.3940
Epoch: 29/110... Step: 3920... Loss: 1.4256... Val Loss: 1.3890
Epoch: 29/110... Step: 3930... Loss: 1.4359... Val Loss: 1.3912
Epoch: 29/110... Step: 3940... Loss: 1.4000... Val Loss: 1.3896
Epoch: 29/110... Step: 3950... Loss: 1.4033... Val Loss: 1.3906
Epoch: 29/110... Step: 3960... Loss: 1.3862... Val Loss: 1.3910
Epoch: 29/110... Step: 3970... Loss: 1.4310... Val Loss: 1.3933
Epoch: 29/110... Step: 3980... Loss: 1.3832... Val Loss: 1.3932
Epoch: 29/110... Step: 3990... Loss: 1.3946... Val Loss: 1.3908
Epoch: 29/110... Step: 4000... Loss: 1.3954... Val Loss: 1.3913
Epoch: 29/110... Step: 4010... Loss: 1.3821... Val Loss: 1.3934
Epoch: 29/110... Step: 4020... Loss: 1.3911... Val Loss: 1.3894
Epoch: 29/110... Step: 4030... Loss: 1.3976... Val Loss: 1.3912
Epoch: 30/110... Step: 4040... Loss: 1.4126... Val Loss: 1.3897
Epoch: 30/110... Step: 4050... Loss: 1.4141... Val Loss: 1.3874
Epoch: 30/110... Step: 4060... Loss: 1.4

Epoch: 38/110... Step: 5200... Loss: 1.3446... Val Loss: 1.3568
Epoch: 38/110... Step: 5210... Loss: 1.3576... Val Loss: 1.3546
Epoch: 38/110... Step: 5220... Loss: 1.3404... Val Loss: 1.3533
Epoch: 38/110... Step: 5230... Loss: 1.3328... Val Loss: 1.3547
Epoch: 38/110... Step: 5240... Loss: 1.3423... Val Loss: 1.3527
Epoch: 38/110... Step: 5250... Loss: 1.3359... Val Loss: 1.3523
Epoch: 38/110... Step: 5260... Loss: 1.3347... Val Loss: 1.3540
Epoch: 38/110... Step: 5270... Loss: 1.3266... Val Loss: 1.3531
Epoch: 38/110... Step: 5280... Loss: 1.3286... Val Loss: 1.3527
Epoch: 39/110... Step: 5290... Loss: 1.3358... Val Loss: 1.3571
Epoch: 39/110... Step: 5300... Loss: 1.3545... Val Loss: 1.3534
Epoch: 39/110... Step: 5310... Loss: 1.3592... Val Loss: 1.3509
Epoch: 39/110... Step: 5320... Loss: 1.3699... Val Loss: 1.3490
Epoch: 39/110... Step: 5330... Loss: 1.3317... Val Loss: 1.3494
Epoch: 39/110... Step: 5340... Loss: 1.3439... Val Loss: 1.3576
Epoch: 39/110... Step: 5350... Loss: 1.3

Epoch: 47/110... Step: 6490... Loss: 1.2933... Val Loss: 1.3348
Epoch: 47/110... Step: 6500... Loss: 1.3027... Val Loss: 1.3344
Epoch: 47/110... Step: 6510... Loss: 1.3110... Val Loss: 1.3348
Epoch: 47/110... Step: 6520... Loss: 1.3141... Val Loss: 1.3326
Epoch: 47/110... Step: 6530... Loss: 1.3182... Val Loss: 1.3329
Epoch: 48/110... Step: 6540... Loss: 1.2892... Val Loss: 1.3319
Epoch: 48/110... Step: 6550... Loss: 1.3035... Val Loss: 1.3314
Epoch: 48/110... Step: 6560... Loss: 1.2995... Val Loss: 1.3349
Epoch: 48/110... Step: 6570... Loss: 1.3292... Val Loss: 1.3307
Epoch: 48/110... Step: 6580... Loss: 1.3179... Val Loss: 1.3281
Epoch: 48/110... Step: 6590... Loss: 1.2941... Val Loss: 1.3330
Epoch: 48/110... Step: 6600... Loss: 1.3131... Val Loss: 1.3315
Epoch: 48/110... Step: 6610... Loss: 1.3004... Val Loss: 1.3313
Epoch: 48/110... Step: 6620... Loss: 1.2931... Val Loss: 1.3331
Epoch: 48/110... Step: 6630... Loss: 1.3065... Val Loss: 1.3295
Epoch: 48/110... Step: 6640... Loss: 1.3

Epoch: 56/110... Step: 7780... Loss: 1.2892... Val Loss: 1.3195
Epoch: 57/110... Step: 7790... Loss: 1.2551... Val Loss: 1.3258
Epoch: 57/110... Step: 7800... Loss: 1.2661... Val Loss: 1.3243
Epoch: 57/110... Step: 7810... Loss: 1.2617... Val Loss: 1.3173
Epoch: 57/110... Step: 7820... Loss: 1.2629... Val Loss: 1.3184
Epoch: 57/110... Step: 7830... Loss: 1.2872... Val Loss: 1.3210
Epoch: 57/110... Step: 7840... Loss: 1.2901... Val Loss: 1.3218
Epoch: 57/110... Step: 7850... Loss: 1.2787... Val Loss: 1.3231
Epoch: 57/110... Step: 7860... Loss: 1.2454... Val Loss: 1.3200
Epoch: 57/110... Step: 7870... Loss: 1.2813... Val Loss: 1.3219
Epoch: 57/110... Step: 7880... Loss: 1.2691... Val Loss: 1.3227
Epoch: 57/110... Step: 7890... Loss: 1.2689... Val Loss: 1.3216
Epoch: 57/110... Step: 7900... Loss: 1.2905... Val Loss: 1.3229
Epoch: 57/110... Step: 7910... Loss: 1.2854... Val Loss: 1.3194
Epoch: 57/110... Step: 7920... Loss: 1.2930... Val Loss: 1.3235
Epoch: 58/110... Step: 7930... Loss: 1.2

Epoch: 66/110... Step: 9070... Loss: 1.2604... Val Loss: 1.3135
Epoch: 66/110... Step: 9080... Loss: 1.2443... Val Loss: 1.3126
Epoch: 66/110... Step: 9090... Loss: 1.2509... Val Loss: 1.3175
Epoch: 66/110... Step: 9100... Loss: 1.2672... Val Loss: 1.3160
Epoch: 66/110... Step: 9110... Loss: 1.2567... Val Loss: 1.3156
Epoch: 66/110... Step: 9120... Loss: 1.2526... Val Loss: 1.3158
Epoch: 66/110... Step: 9130... Loss: 1.2596... Val Loss: 1.3121
Epoch: 66/110... Step: 9140... Loss: 1.2648... Val Loss: 1.3087
Epoch: 66/110... Step: 9150... Loss: 1.2473... Val Loss: 1.3112
Epoch: 66/110... Step: 9160... Loss: 1.2215... Val Loss: 1.3115
Epoch: 66/110... Step: 9170... Loss: 1.2665... Val Loss: 1.3115
Epoch: 67/110... Step: 9180... Loss: 1.2464... Val Loss: 1.3163
Epoch: 67/110... Step: 9190... Loss: 1.2548... Val Loss: 1.3140
Epoch: 67/110... Step: 9200... Loss: 1.2469... Val Loss: 1.3109
Epoch: 67/110... Step: 9210... Loss: 1.2450... Val Loss: 1.3107
Epoch: 67/110... Step: 9220... Loss: 1.2

KeyboardInterrupt: 

### Teste

In [17]:
def predict(net, char, h=None, top_k=None):
    """Sample the character that follows ``char``.

    Args:
        net: Model with ``chars``, ``char2int``, ``int2char`` and
            ``init_hidden``, called as ``net(inputs, h) -> (scores, h)``.
        char: Current input character; must be a key of ``net.char2int``.
        h: Hidden state from the previous step. If None, a fresh state for a
            batch of one is created with ``net.init_hidden(1)``.
        top_k: If given, sample only among the ``top_k`` most probable
            characters; otherwise sample from the full distribution.

    Returns:
        ``(next_char, h)``: the sampled character and the updated hidden state.
    """
    if h is None:
        h = net.init_hidden(1)

    x = np.array([[net.char2int[char]]])
    x = one_hot_encode(x, len(net.chars))
    inputs = torch.from_numpy(x)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        h = tuple(each.detach() for each in h)
        out, h = net(inputs, h)
        p = F.softmax(out, dim=1)

    if top_k is None:
        top_ch = np.arange(len(net.chars))
    else:
        p, top_ch = p.topk(top_k)
        # reshape(-1), not squeeze(): with top_k == 1, squeeze would give a
        # 0-d array, which np.random.choice rejects.
        top_ch = top_ch.numpy().reshape(-1)

    p = p.numpy().reshape(-1)
    char = np.random.choice(top_ch, p=p / p.sum())

    return net.int2char[char], h

In [18]:
def sample(net, size, prime='The', top_k=None):
    """Generate text from ``net``, starting with ``prime``.

    ``net`` is moved to CPU and switched to eval mode, both in place.
    ``prime`` is fed one character at a time to warm up the hidden state.
    The prediction after its last character is kept, then ``size`` more
    characters are generated, so ``len(prime) + size + 1`` characters are
    returned.

    Raises:
        ValueError: if ``prime`` is empty (there is nothing to condition on).
    """
    if not prime:
        raise ValueError("prime must contain at least one character")

    net.cpu()
    net.eval()

    generated = list(prime)
    h = net.init_hidden(1)
    for ch in prime:
        char, h = predict(net, ch, h, top_k=top_k)
    generated.append(char)

    for _ in range(size):
        char, h = predict(net, generated[-1], h, top_k=top_k)
        generated.append(char)

    return ''.join(generated)

In [20]:
generated_text = sample(net, 500, prime='Anna', top_k=5)
print(generated_text)

Anna Arkadyevna as he with
her hands
the carriage and second heart of the same tenderness with shill--which as he thought of him, and standing all he had anything so at home to see the same, who had to branded his side. The muscate
son she saw the shooting
heart and her
person with the corn, and that would be a letting the constitutions with his
watch, and sood all his footmines that
had any still of any chief true, and altasided the
stream and
angrily of a mertity and a selencing child, and was suff
