In [1]:
import string


# Character-level vocabulary: each character of string.printable is identified
# by its position in that string.
all_chars = string.printable
vocab_dict = {ch: idx for idx, ch in enumerate(all_chars)}  # char -> integer id
vocab_size = len(all_chars)

def str2ints(s, vocab_dict):
    """Encode ``s`` as a list of ids by looking up each character in ``vocab_dict``.

    Raises ``KeyError`` (from the lookup) if a character is not a key of
    ``vocab_dict``.
    """
    ids = []
    for ch in s:
        ids.append(vocab_dict[ch])
    return ids

def ints2str(x, vocab_array):
    """Decode the index sequence ``x`` into a string.

    Each element of ``x`` is used to index ``vocab_array``; the looked-up
    strings are concatenated in order.
    """
    pieces = []
    for idx in x:
        pieces.append(vocab_array[idx])
    return "".join(pieces)

In [4]:
from torch.utils.data import Dataset


# curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt > tinyshakespeare.txt
class ShakespeareDataset(Dataset):
    """Character-level dataset serving fixed-length chunks of an encoded text file.

    The file at ``path`` is read, stripped of leading/trailing whitespace,
    encoded with ``str2ints`` / ``vocab_dict`` and cut into consecutive,
    non-overlapping chunks of ``chunk_size`` ids.  A trailing chunk shorter
    than ``chunk_size`` is dropped so that every item has the same length.

    Args:
        path: Path of the text file to load.
        chunk_size: Number of character ids per item.

    Raises:
        KeyError: If the text contains a character that is not a key of
            ``vocab_dict`` (raised by the lookup in ``str2ints``).
    """

    def __init__(self, path, chunk_size=200):
        # Use a context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(path) as f:
            text = f.read().strip()
        data = torch.LongTensor(str2ints(text, vocab_dict)).split(chunk_size)
        # split() leaves the remainder as a short final chunk; drop it so that
        # all items share one length.  The truthiness guard avoids indexing an
        # empty result.
        if data and len(data[-1]) < chunk_size:
            data = data[:-1]
        self.data = data
        self.n_chunks = len(self.data)

    def __len__(self):
        """Return the number of full-length chunks."""
        return self.n_chunks

    def __getitem__(self, idx):
        """Return the ``idx``-th chunk as a 1-D LongTensor of character ids."""
        return self.data[idx]

In [7]:
import torch
from torch.utils.data import DataLoader


# Load the corpus as 200-id chunks and wrap it in a shuffling loader that
# yields batches of 32 chunks using 4 worker processes.
ds = ShakespeareDataset(path='./tinyshakespeare.txt', chunk_size=200)
loader = DataLoader(
    dataset=ds,
    batch_size=32,
    num_workers=4,
    shuffle=True,
)

In [9]:
from torch import nn


class SequenceGenerationNet(nn.Module):
    """Embedding -> LSTM -> Linear model producing next-token logits.

    Args:
        num_embeddings: Vocabulary size; also the size of the output logits.
        embedding_dim: Size of each embedding vector.
        hidden_size: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between stacked LSTM layers.
            It is only passed on when ``num_layers > 1``: ``nn.LSTM`` applies
            dropout after every layer except the last, so with a single layer
            it has no effect and PyTorch warns about it.
    """

    def __init__(self, num_embeddings, embedding_dim=50, hidden_size=50, num_layers=1, dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        # Inter-layer dropout is meaningless for one layer; pass 0.0 there to
        # keep behaviour identical while avoiding the nn.LSTM UserWarning.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers,
                            batch_first=True, dropout=lstm_dropout)
        self.linear = nn.Linear(hidden_size, num_embeddings)

    def forward(self, x, h0=None):
        """Run the model on a batch of id sequences.

        Args:
            x: Integer id tensor; the LSTM is built with ``batch_first=True``,
                so the layout is (batch, seq).
            h0: Optional initial LSTM state as accepted by ``nn.LSTM``;
                ``None`` lets the LSTM start from its default state.

        Returns:
            ``(logits, h)`` where ``logits`` has one ``num_embeddings``-sized
            vector per input position and ``h`` is the LSTM's final state.
        """
        x = self.emb(x)
        x, h = self.lstm(x, h0)
        x = self.linear(x)
        return x, h

In [10]:
def generate_seq(net, start_phrase='The King said', length=200, temperature=0.8):
    """Sample ``length`` characters from ``net`` as a continuation of ``start_phrase``.

    The prompt is encoded with ``str2ints`` / ``vocab_dict`` and fed to the
    network in one pass; afterwards each sampled character is fed back
    together with the returned hidden state to sample the next one.

    Requires PyTorch >= 0.4 (``torch.no_grad``, ``Tensor.item``).

    Args:
        net: Callable as ``net(ids, hidden)`` returning ``(logits, hidden)``,
            with ``logits`` indexed as (batch, position, vocabulary).  It is
            switched to eval mode via ``net.eval()``.
        start_phrase: Non-empty prompt; every character must be in ``vocab_dict``.
        length: Number of characters to generate (exactly this many are appended).
        temperature: Positive divisor applied to the logits before the softmax.
            Values below 1 sharpen the distribution, values above 1 flatten it.

    Returns:
        ``start_phrase`` followed by the ``length`` sampled characters.

    Raises:
        ValueError: If ``temperature`` is not positive or ``start_phrase`` is empty.
    """
    if temperature <= 0:
        raise ValueError('temperature must be positive')
    if not start_phrase:
        raise ValueError('start_phrase must not be empty')

    net.eval()
    result = []
    # Shape (1, len(start_phrase)): a batch holding just the prompt.
    inp = torch.LongTensor(str2ints(start_phrase, vocab_dict)).unsqueeze(0)
    h = None
    # Sampling needs no gradients; no_grad also keeps the carried hidden state
    # from accumulating an autograd graph across steps.
    with torch.no_grad():
        for _ in range(length):
            o, h = net(inp, h)
            # Only the logits at the last position predict the next character.
            # Softmax (rather than a bare exp) stays finite for large logits.
            probs = torch.softmax(o[0, -1] / temperature, dim=-1)
            top_i = int(torch.multinomial(probs, 1).item())
            result.append(top_i)
            inp = torch.LongTensor([[top_i]])
    return start_phrase + ints2str(result, all_chars)

In [12]:
from torch.autograd import Variable as V
from statistics import mean
from torch import optim


# Train a 2-layer character-level LSTM with Adam for 50 epochs; after each
# epoch print the mean batch loss and a sample generated by the model.
# Requires PyTorch >= 0.4: tensors are fed directly (no Variable wrapper) and
# the scalar loss is read with .item() (loss.data[0] fails on 0-dim tensors).
net = SequenceGenerationNet(vocab_size, 20, 50, num_layers=2, dropout=0.1)
opt = optim.Adam(net.parameters())
loss_f = nn.CrossEntropyLoss()
for epoch in range(50):
    net.train()
    losses = []
    for data in loader:
        # Next-character prediction: the input is every id but the last, the
        # target is the same sequence shifted left by one.
        x = data[:, :-1]
        y = data[:, 1:]
        y_pred, _ = net(x)
        # The sliced target is not contiguous, so flatten with reshape()
        # instead of view(); see
        # https://discuss.pytorch.org/t/runtimeerror-input-is-not-contiguous/930/8
        loss = loss_f(y_pred.reshape(-1, vocab_size), y.reshape(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    print('=================================================================================')
    print(epoch, mean(losses))
    print(generate_seq(net))

0 3.4779010445731027
The King said vet dteyrwRals nrn ao hpaR s,PO sde oltt  dG C r iarfwataaremoOsue 
,ruesltiT np.sa is'ee
s
ROknr m:tnybIohm   usyFMeAydv
 t snfaehT tMqio
d coainAml-w  tery.e,sneo tthyy hMlm dmoif ct More ajtyoalrc 
1 2.964444341659546
The King said tlyenium panheB sov Smeor oisut, ts:i mc,aegke fampI



	kM,RC|A,
SEagttosltaI aY taur inganlyef
'at dmoeegt

oirn wofcane.
S

HE
"I:HL-N
EA,SOT:O: evis hhal ahe.

yFd
AT Ina novt lspur toi eou nvaog 
2 2.568505411148071
The King saidirunWe
Ls teo l bole  os wouc hilmwbeu' lince,

To thiitem, and the a?t
Aann

A'P whin ewinn:-ad fldad yoyivill whees mipe anm soshh.
 l ar coust goe?

SHISTOdLHE:
O:

RIILUNYFOIOO:
The xsere toge to I
3 2.3968645572662353
The King said:
I hon sisteun:
Rovaitest
ond t!
Rond fankt tilathed
ine freul I bn lordetisitead coeg hewnlen not mh ths I hur some mat malist sheun, wy terongere drake detham thid dardat ho mit sathy hor it w tnor 
4 2.303538442339216
The King said mape heafdebs hagh ove t