In [1]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

In [50]:
# Load the raw training corpus. An explicit encoding makes decoding identical on
# every platform instead of depending on the OS locale default.
with open('anna.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [51]:
# Build the vocabulary. Sorting makes the char <-> int mapping deterministic:
# iteration order of a set of str depends on string hashing, which is randomized
# per interpreter process, so an unsorted set would re-shuffle the encoding on
# every kernel restart.
chars = tuple(sorted(set(text)))

# Integer code -> character, and its inverse.
int2char = dict(enumerate(chars))
char2int = {ch: ii for ii, ch in int2char.items()}

# The whole corpus as an array of integer codes.
encoded = np.array([char2int[ch] for ch in text])

In [72]:
def one_hot_encode(arr, n_labels):
    """One-hot encode an integer array.

    Args:
        arr: Array (or array-like) of integer labels in ``[0, n_labels)``, of any shape.
        n_labels: Size of the one-hot dimension.

    Returns:
        float32 array of shape ``arr.shape + (n_labels,)`` with a single 1.0 per
        position, at the index given by the corresponding label.
    """
    arr = np.asarray(arr)
    # Encode a flat view first (one row per element), then restore the input shape.
    one_hot = np.zeros((arr.size, n_labels), dtype=np.float32)
    one_hot[np.arange(arr.size), arr.ravel()] = 1.
    return one_hot.reshape((*arr.shape, n_labels))

In [68]:
def get_batches(arr, n_seqs, n_steps):
    """Yield ``(inputs, targets)`` mini-batches from an integer-encoded array.

    The array is truncated to a whole number of batches (``n_seqs * n_steps``
    elements each), reshaped into ``n_seqs`` rows, and walked left to right in
    windows of ``n_steps`` columns. Targets are the inputs shifted one step left.
    The last target column of the final window wraps around to column 0.

    Args:
        arr: Array of integer codes (the caller passes a 1-D array).
        n_seqs: Number of sequences (rows) per batch.
        n_steps: Number of time steps (columns) per batch.

    Yields:
        Tuple ``(x, y)`` of arrays, each of shape ``(n_seqs, n_steps)``. Nothing is
        yielded when ``arr`` holds fewer than ``n_seqs * n_steps`` elements.
    """
    batch_size = n_seqs * n_steps
    n_batches = len(arr) // batch_size

    # Drop the tail that does not fill a complete batch, then split into n_seqs rows.
    arr = arr[:batch_size * n_batches]
    arr = arr.reshape((n_seqs, -1))
    n_cols = arr.shape[1]

    for n in range(0, n_cols, n_steps):
        x = arr[:, n:n + n_steps]
        y = np.zeros_like(x)
        y[:, :-1] = x[:, 1:]
        next_col = n + n_steps
        # The final window has no following column, so wrap around to the first one.
        y[:, -1] = arr[:, next_col] if next_col < n_cols else arr[:, 0]
        yield x, y

In [69]:
# get_batches is a generator; pull only its first (inputs, targets) pair to inspect it.
batches = get_batches(encoded, n_seqs=10, n_steps=50)
x, y = next(batches)

In [87]:
class CharRNN(nn.Module):
    """Character-level LSTM language model.

    One-hot encoded characters (width ``len(tokens)``) pass through a stacked
    LSTM, dropout, and a linear layer producing one logit per vocabulary entry.

    Args:
        tokens: Sequence of distinct characters; position ``i`` is the integer
            code of that character.
        n_steps: Sequence length; stored on the instance, not used by the layers here.
        n_hidden: Hidden size of each LSTM layer.
        n_layers: Number of stacked LSTM layers.
        drop_prob: Dropout probability, used between LSTM layers and on the LSTM output.
        lr: Learning rate; stored on the instance only.
    """

    def __init__(self, tokens, n_steps=100, n_hidden=256, n_layers=2, drop_prob=0.5, lr=0.001):
        super().__init__()
        self.drop_prob = drop_prob
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_steps = n_steps
        self.lr = lr

        # Vocabulary lookup tables in both directions.
        self.chars = tokens
        self.int2char = dict(enumerate(self.chars))
        self.char2int = {ch: ii for ii, ch in self.int2char.items()}

        self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers,
                            dropout=drop_prob, batch_first=True)
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(n_hidden, len(self.chars))

        self.init_weights()

    def forward(self, x, hc):
        """Run the model on a batch.

        Args:
            x: Float tensor of shape ``(batch, seq, len(tokens))`` (``batch_first=True``).
            hc: ``(h, c)`` LSTM state tensors, e.g. from :meth:`init_hidden`.

        Returns:
            ``(logits, (h, c))``. ``logits`` has shape ``(batch * seq, len(tokens))``.
        """
        x, (h, c) = self.lstm(x, hc)
        x = self.dropout(x)
        # Flatten batch and time. LSTM output is not guaranteed to be contiguous,
        # so use reshape rather than view.
        x = x.reshape(x.size(0) * x.size(1), self.n_hidden)
        x = self.fc(x)
        return x, (h, c)

    def predict(self, char, h=None, cuda=False, top_k=None):
        """Sample the character that follows ``char``.

        Moves the whole model to GPU (``cuda=True``) or CPU as a side effect.
        Does not change train/eval mode; call ``eval()`` first to disable dropout.

        Args:
            char: Current input character; must be in the vocabulary.
            h: Optional ``(h, c)`` state from a previous call; zeros when ``None``.
            cuda: Run on GPU if True.
            top_k: If given, sample only among the ``top_k`` most likely characters.

        Returns:
            ``(next_char, h)``, where ``h`` is the updated hidden state.
        """
        if cuda:
            self.cuda()
        else:
            self.cpu()

        if h is None:
            h = self.init_hidden(1)

        x = np.array([[self.char2int[char]]])
        x = one_hot_encode(x, len(self.chars))
        inputs = torch.from_numpy(x)
        if cuda:
            inputs = inputs.cuda()

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            h = tuple(each.detach() for each in h)
            out, h = self.forward(inputs, h)
            p = F.softmax(out, dim=1)

        p = p.cpu()
        if top_k is None:
            top_ch = np.arange(len(self.chars))
        else:
            p, top_ch = p.topk(top_k)
            # reshape(-1) (not squeeze) keeps a 1-D array even when top_k == 1;
            # np.random.choice would treat a 0-d array as an integer range bound.
            top_ch = top_ch.numpy().reshape(-1)
        p = p.numpy().reshape(-1)

        char = np.random.choice(top_ch, p=p / p.sum())
        return self.int2char[int(char)], h

    def init_weights(self, initrange=0.1):
        """Zero the output-layer bias and draw its weights from U(-initrange, initrange)."""
        with torch.no_grad():
            self.fc.bias.zero_()
            self.fc.weight.uniform_(-initrange, initrange)

    def init_hidden(self, n_seqs):
        """Return zero ``(h, c)`` tensors of shape ``(n_layers, n_seqs, n_hidden)``.

        The tensors match the dtype and device of the model parameters.
        """
        weight = next(self.parameters())
        shape = (self.n_layers, n_seqs, self.n_hidden)
        return (weight.new_zeros(shape), weight.new_zeros(shape))

In [88]:
def train(net, data, epochs=10, n_seqs=10, n_steps=50, lr=0.001, clip=5, val_frac=0.1, cuda=False, print_every=10):
    """Train ``net`` on an integer-encoded character array with Adam.

    The last ``val_frac`` of ``data`` is held out for validation. Every
    ``print_every`` steps, the training loss of the current batch and the mean
    validation loss are printed.

    Args:
        net: A ``CharRNN`` (needs ``chars``, ``init_hidden`` and ``forward``).
        data: Array of integer character codes.
        epochs: Number of passes over the training split.
        n_seqs: Sequences per batch.
        n_steps: Time steps per batch.
        lr: Adam learning rate.
        clip: Max gradient norm for clipping.
        val_frac: Fraction of ``data`` reserved for validation.
        cuda: Train on GPU if True.
        print_every: Report interval, in optimizer steps.
    """
    net.train()
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Slice bounds must be ints, so truncate after scaling.
    val_idx = int(len(data) * (1 - val_frac))
    data, val_data = data[:val_idx], data[val_idx:]

    if cuda:
        net.cuda()

    counter = 0
    n_chars = len(net.chars)

    for e in range(epochs):
        h = net.init_hidden(n_seqs)

        for x, y in get_batches(data, n_seqs, n_steps):
            counter += 1

            inputs, targets = _to_tensors(x, y, n_chars, cuda)
            # Detach the carried-over state so backprop stops at the batch boundary.
            h = tuple(each.detach() for each in h)

            net.zero_grad()
            output, h = net(inputs, h)
            loss = criterion(output, targets.reshape(-1))
            loss.backward()
            # Clip to mitigate exploding gradients in the LSTM.
            nn.utils.clip_grad_norm_(net.parameters(), clip)
            opt.step()

            if counter % print_every == 0:
                val_loss = _validation_loss(net, val_data, n_seqs, n_steps,
                                            n_chars, criterion, cuda)
                print("Epoch: {}/{}...".format(e+1, epochs),
                      "Step: {}...".format(counter),
                      "Loss: {:.4f}...".format(loss.item()),
                      "Val Loss: {:.4f}".format(val_loss))


def _to_tensors(x, y, n_chars, cuda):
    """One-hot encode a batch; return (float inputs, int64 targets) as tensors."""
    inputs = torch.from_numpy(one_hot_encode(x, n_chars))
    # CrossEntropyLoss needs int64 class indices; .long() works on CPU and GPU alike.
    targets = torch.from_numpy(y).long()
    if cuda:
        inputs, targets = inputs.cuda(), targets.cuda()
    return inputs, targets


def _validation_loss(net, val_data, n_seqs, n_steps, n_chars, criterion, cuda):
    """Mean loss over all validation batches with dropout disabled and no autograd.

    Returns ``nan`` when ``val_data`` is too short to form a single batch.
    Always restores ``net`` to training mode.
    """
    net.eval()
    try:
        val_h = net.init_hidden(n_seqs)
        val_losses = []
        with torch.no_grad():
            for x, y in get_batches(val_data, n_seqs, n_steps):
                inputs, targets = _to_tensors(x, y, n_chars, cuda)
                output, val_h = net(inputs, val_h)
                val_losses.append(criterion(output, targets.reshape(-1)).item())
    finally:
        net.train()
    return float(np.mean(val_losses)) if val_losses else float('nan')
                    

In [89]:
# Discard a model left over from an earlier run of this cell before building a new one.
try:
    del net
except NameError:
    pass

In [90]:
# Two stacked LSTM layers with 512 hidden units each, over the corpus vocabulary.
net = CharRNN(tokens=chars, n_layers=2, n_hidden=512)
print(net)

CharRNN(
  (lstm): LSTM(83, 512, num_layers=2, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=512, out_features=83, bias=True)
)
