In [58]:
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torch

In [169]:
df = pd.read_csv('poet.tang.csv')
df_dict = pd.read_csv('tang.dict.csv')

In [27]:
wordlist = []

for s in df.poets:
    s = s.replace("'", "").replace(", ", "").replace("[", "").replace("]", "")
    s = s.replace("，", "").replace("。", "")
    wordlist.extend(list(s))

wordcnt = {}    
word2id = {}
id2word = {}

i = 1
for c in wordlist:
    if c not in word2id:
        word2id[c] = i
        id2word[i] = c
        wordcnt[c] = 1
    else:
        wordcnt[c] +=1
    
wordlist = list(set(wordlist))
wordlist.sort()


In [28]:
word2id = {}
id2word = {}

for i in range(len(wordlist)):
    word2id[wordlist[i]] = i + 1
    id2word[i + 1] = wordlist[i]

In [171]:
def calclen(s):
    s = s.replace("'", "").replace(", ", "").replace("[", "").replace("]", "")
    s = s.replace("，", "").replace("。", "")
    return len(s)

df['lenth'] = df.poets.map(calclen)

In [234]:
wordlist[:50]

['+',
 '-',
 '/',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '=',
 '\\',
 'a',
 'e',
 'u',
 '{',
 '|',
 '}',
 '·',
 '…',
 '□',
 '○',
 '●',
 '、',
 '〈',
 '〉',
 '《',
 '》',
 '「',
 '」',
 '『',
 '』',
 '〖',
 '〗',
 '一',
 '丁',
 '七',
 '丈',
 '三',
 '上',
 '下',
 '不',
 '与',
 '丏',
 '丐',
 '丑',
 '且']

In [185]:
df = df[df.lenth > 19].copy()

In [219]:
df.to_csv('poet.tang.csv')

In [233]:
# df.drop([48477], inplace=True)
df.iloc[43459]

authors                                                  溫庭筠
poets      ['同伴，相喚。', '杏花稀，夢裏每愁依違。', '仙客一去燕已飛，不歸，淚痕空滿衣。',...
strains    ['平仄，○仄。', '仄平平，仄仄仄平平平。', '平仄仄仄○仄平，仄平，仄平○仄○。',...
titles                                                  河傳 三
lenth                                                     55
Name: 44235, dtype: object

In [227]:
class PoetDataset(Dataset):

    def __init__(self, df):
        self.df = df.copy()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        s = self.df.iloc[idx]['poets']
        
        s = s.replace("'", "").replace(", ", "").replace("[", "").replace("]", "")
        s = s.replace("，", "").replace("。", "")
        d = [word2id[w] for w in s]
        if len(d) == 0:
            print(idx, self.df.iloc[idx]['poets'])
            
        t = Variable(torch.from_numpy(np.array(d[1:]).copy()))
        d = Variable(torch.from_numpy(np.array(d[:len(d)-1])))
        
        return {'text': d, 'label': t}
    
poet_dataset = PoetDataset(df)
data_loader = DataLoader(poet_dataset, batch_size=1, shuffle=True, num_workers=0)

In [228]:
class PoetryModel(nn.Module):
    
    def __init__(self, vocab, hidden_size, n_cat, bs=1, nl=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.bs = bs
        self.nl = nl
        self.e = nn.Embedding(n_vocab, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, nl)
        self.fc2 = nn.Linear(hidden_size, n_cat)
        self.softmax = nn.LogSoftmax(dim=-1)
        
    def forward(self, inp):
        bs = inp.size(1)
        # if bs != self.bs:
        #     self.bs = bs
            
        # print(inp.size())
        e_out = self.e(inp)
        
        # print('e_out ', e_out.size())
        h0 = c0 = Variable(e_out.data.new(*(self.nl, bs, self.hidden_size)).zero_())
        # print('h0 c0', h0.size(), c0.size())
        rnn_o,_ = self.lstm(e_out, (h0, c0)) 
        # print('rnn_o', rnn_o.size())
        rnn_o = rnn_o[-1]
        # print('rnn_o', rnn_o.size())
        fc = F.dropout(self.fc2(rnn_o), p=0.8)
        # print('fc', fc.size())
        out = self.softmax(fc)
        # print('out', out.size())
        return out

n_vocab = len(wordlist)
n_hidden = 256
n_cat = n_vocab
bs = 128

model = PoetryModel(n_vocab, n_hidden, n_cat, bs=32)
model = model.cuda()

In [231]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def fit(epoch, model, data_loader, phase='training', volatile=False):
    if phase == 'training':
        model.train()
    if phase == 'validation':
        model.eval()
        volatile=True
    running_loss = 0.0
    running_correct = 0
    for batch_idx, batch in enumerate(data_loader):
        text, target = batch['text'], batch['label']
        text, target = text.long().cuda(), target.long().cuda()
        
        if phase == 'training':
            optimizer.zero_grad()
        output = model(text)

        target = target.view(-1)

        loss = loss_fn(output, target)
        
        running_loss += loss.item()
        
        preds = output.data.max(dim=1, keepdim=True)[1]
        
        running_correct += preds.eq(target.data.view_as(preds)).cpu().sum()
        if phase == 'training':
            loss.backward()
            optimizer.step()
    
    loss = running_loss/len(data_loader.dataset)
    accuracy = 100. * running_correct/len(data_loader.dataset)
    
    print(f'{epoch} {phase} loss: {loss:.6f}, accuracy: {accuracy:.4f} ... {running_correct}/{len(data_loader)} - ')
    
    return loss, accuracy

In [232]:
train_losses , train_accuracy = [],[]
val_losses , val_accuracy = [],[]

best_loss = 7**7
for epoch in range(1, 500):

    epoch_loss, epoch_accuracy = fit(epoch, model, data_loader, phase='training')
    val_epoch_loss , val_epoch_accuracy = fit(epoch, model, data_loader, phase='validation')
    train_losses.append(epoch_loss)
    train_accuracy.append(epoch_accuracy)
    val_losses.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)
    
    if epoch_loss < best_loss:
        best_loss = epoch_loss
    else:
        break
        

1 training loss: 6.489144, accuracy: 329.0000 ... 184566/56094 - 
1 validation loss: 6.208618, accuracy: 381.0000 ... 213853/56094 - 
2 training loss: 6.237712, accuracy: 378.0000 ... 212160/56094 - 
2 validation loss: 6.115645, accuracy: 385.0000 ... 216385/56094 - 
3 training loss: 6.180701, accuracy: 383.0000 ... 215129/56094 - 
3 validation loss: 6.077962, accuracy: 389.0000 ... 218559/56094 - 
4 training loss: 6.153486, accuracy: 386.0000 ... 216616/56094 - 
4 validation loss: 6.057441, accuracy: 397.0000 ... 222896/56094 - 
5 training loss: 6.138474, accuracy: 386.0000 ... 217081/56094 - 
5 validation loss: 6.045647, accuracy: 396.0000 ... 222139/56094 - 
6 training loss: 6.129436, accuracy: 388.0000 ... 217858/56094 - 
6 validation loss: 6.037052, accuracy: 393.0000 ... 220512/56094 - 
7 training loss: 6.123256, accuracy: 389.0000 ... 218424/56094 - 
7 validation loss: 6.032784, accuracy: 391.0000 ... 219814/56094 - 
8 training loss: 6.121167, accuracy: 389.0000 ... 218273/56094

In [216]:
torch.save(model.state_dict(), './model.pt')