In [89]:
import codecs
import re
import numpy as np
from config import BATCH_SIZE as batch_size
from config import TIME_STEPS as time_steps
from sklearn.preprocessing import OneHotEncoder

# =============== 读取文本内容 ===============
f = codecs.open('./data/xiyouji.txt', 'r', encoding='utf-8')
data = f.readlines()
# data = ''.join(data)

#=============== 简单的预处理 ===============
# 替换括号里的内容
pattern = re.compile(r'\(.*?\)')
data = [pattern.sub('', line) for line in data]

# 删除\n, \r,' '
data = [word.replace('.', '。') for word in data]
data = [word.replace('\r', '') for word in data]
# data = [word.replace(' ', '') for word in data]

# 删除章节名称
pattern = re.compile(r'.*?第.*?章.*')
def isNotChapterName(text):
    if pattern.search(text):
        return False
    else:
        return True
data = [word for word in data if isNotChapterName(word)]

# 省略号 => .
data = [line.replace('……', '。') for line in data if len(line) > 1]

# ==============判断char是否是乱码===================
def is_uchar(uchar):
    """判断一个unicode是否是汉字"""
    if uchar >= u'\u4e00' and uchar<=u'\u9fa5':
            return True
    """判断一个unicode是否是数字"""
    if uchar >= u'\u0030' and uchar<=u'\u0039':
            return True       
    """判断一个unicode是否是英文字母"""
    if (uchar >= u'\u0041' and uchar<=u'\u005a') or (uchar >= u'\u0061' and uchar<=u'\u007a'):
            return True
    if uchar in ('，','。','：','？','“','”','！','；','、','《','》','——'):
            return True
    return False

# 将每行的list合成一个长字符串
data = ''.join(data)
data = [char for char in data if is_uchar(char)]
data = ''.join(data)



# ==============生成字典===============
vocab = set(data)
id2char = {i:c for i,c in enumerate(vocab)}
char2id = {c:i for i,c in enumerate(vocab)}
# 总数
VOCAB_NUM = len(id2char)


# =====转换数据为数字格式======
numdata = [char2id[char] for char in data]
numdata = np.array(numdata)
batch_num = len(numdata) // (batch_size * time_steps)
# print(numdata.shape)

# print('数字数据信息：\n', numdata[:100])
# print('\n文本数据信息：\n', ''.join([id2char[i] for i in numdata[:100]]))
# print(len(numdata))

# ============= 定义 dataloader =============
def yield_data( batch_size, time_step, data= numdata):
    encoder = OneHotEncoder(categories= [range(VOCAB_NUM)])
    start = [i * time_step for i in range(batch_size)]
    end = [(i+1) * time_step for i in range(batch_size)]
    data_num = batch_size * time_step
    batch_num = len(data) // data_num
    arr = data[:data_num * batch_num]
    arr_y = np.roll(arr,-1)
    for n in range(0, len(data), data_num):
        x = arr[n: n+data_num].reshape(batch_size, time_step)
        y = arr_y[n: n+data_num].reshape(batch_size, time_step)
        x = encoder.fit_transform(x.reshape(-1,1)).toarray().reshape(x.shape[0], x.shape[1], -1)
        yield (x,y)

# define the net structure

In [91]:
import torch
from torch import nn

In [109]:
class CharRNN(nn.Module):
    def __init__(self, vocab, 
                 hidden_size= 256,
                 num_layers= 2,
                 batch_first= True,
                 keep_prob= 0.5
                ):
        # call the super init method
        super(CharRNN, self).__init__()
        
        # define property
        self.vocab = vocab
        self.id2char = dict(enumerate(self.vocab))
        self.char2id = {c:i for i,c in enumerate(self.vocab)}
        self.vocab_num = len(self.vocab)
        self.keep_prob= keep_prob
        self.hidden_size = hidden_size
        
        # net constructure
        self.lstm = nn.LSTM(
            input_size= self.vocab_num,
            hidden_size= hidden_size,
            num_layers= 2,
            batch_first= True,
            dropout= self.keep_prob
        )
        
        # dropout
        self.dropout = nn.Dropout(self.keep_prob)
        
        # linear
        self.linear = nn.Linear(
            hidden_size,
            self.vocab_num
        )
        
    def forward(self, X, hidden):
        out1, hidden = self.lstm(X, hidden)
        out2 = self.dropout(out1)
        # stack the output of the lstm block
        out3 = out2.contiguous().view(-1, self.hidden_size)
        out4 = self.linear(out3)
        return out4, hidden
    

In [157]:
from torch.autograd import Variable
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
use_gpu = True if torch.cuda.is_available() else False

In [156]:
batch_size= 16
time_step= 5
LR = 1e-2
net = CharRNN(vocab)
if use_gpu:
    net.cuda()
opt = Adam(net.parameters(), lr= LR)
criterion = CrossEntropyLoss()
print(net)
net.train()
hidden = None
for i, (x,y) in enumerate(yield_data(batch_size,time_step)):
    net.zero_grad()
    opt.zero_grad()
    input_tensor = Variable(torch.from_numpy(x).float())
    target = Variable(torch.from_numpy(y).long())
    if use_gpu:
        input_tensor = input_tensor.cuda()
        target = target.cuda()
    output, hidden = net(input_tensor, hidden)
    hidden = (hidden[0].data, hidden[1].data)
    label = target.contiguous().view(batch_size * time_step)
    loss = criterion(output, label.long())
    loss.backward()
    nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    opt.step()
    if i % 10 == 0:
        print('batch [{}/{}] loss {:.3f}'.format(i, 100, loss.data))
        output_array = output.cpu().detach().numpy().reshape(batch_size * time_step, -1)
        ids = np.argmax(output_array, axis=1)
        text = ''.join([net.id2char[id_] for id_ in ids])
        print(text)

CharRNN(
  (lstm): LSTM(2958, 256, num_layers=2, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.5)
  (linear): Linear(in_features=256, out_features=2958, bias=True)
)
batch [0/100] loss 7.985
轩诉诉圣蛾肴房房房房拦呻诉房轩鸿匣蛾匣妄匣妄阳半轩鸿诉猴要八馐嵘匣馐圣钩诉嵘匣匣匣诸诉诉诉篮瑜诸八圣赠乡妄圣鲜秦眼蛾烁紊著僧拇烧妄匣眼诉篮追瑜诉撬妄轩呻瑜钩拇糕
batch [10/100] loss 6.941
，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，
batch [20/100] loss 10.237
尊木柱惠检夺成醉靴柱脉止止寸止止尊作止享胎名擎官官同止嫉止根止俊官其馋名止柱裂脓胎变止燃寒柱仓止柱皆柱镶变士顶柱迸柱柱裂琳赐裂柱料柱止怀止学乱其盏柱止柱吐柱贝料
batch [30/100] loss 6.928
“”，，“，，，，，，，”，，王，，，，王，，“，，，“，，，王，，，，“王“，，，，，，“，，，，，王，，，美，，，“，“，“，，，，，，，“，，，，““，美
batch [40/100] loss 6.251
，，，，，，，，，，，，，，，，，，，，仙，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，
batch [50/100] loss 6.112
，，，，，，，说。，”，，，，，说，，，，，，，“，，，，，，，，，说，，。，”，，”，””，，”，”。”，，，，我，，“，，，，，，，“”，，，”，，”，，，
batch [60/100] loss 6.390
，，，，祖，，，，，，，祖，，，，，祖“。，，，不，，，，，，，不祖，，，祖，，，，，，，不，，祖，，，，，，，不，，，，，，”，，说，祖，祖，祖祖，，不，，，
batch [70/100] loss 6.483
，，，，，，，，，，，，，，，，，祖，，，，，，，，，，，，师能，，，，，，，，，，，，，，，，师

KeyboardInterrupt: 

In [154]:
VOCAB_NUM

2958

In [128]:
label.shape

torch.Size([15])

In [129]:
loss

tensor(7.9871, grad_fn=<NllLossBackward>)

In [136]:
net.zero_grad()

In [140]:
opt.zero_grad()

In [141]:
opt

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.01
    weight_decay: 0
)