In [196]:
from collections import Counter
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler 
from torch.utils.data import Dataset, DataLoader

np.random.seed(100)
torch.manual_seed(100)

<torch._C.Generator at 0x1151fa430>

In [177]:
raw_data_path = 'data/all_couplets.txt'
vocabs_size = 2000

### 划分数据集

In [28]:
def split_dataset(raw_data_path, test_size=3000):
    with open(raw_data_path, 'r') as f:
        lines = f.readlines()
        
    lines = list(map(str.strip, lines))
    
    np.random.shuffle(lines)
    
    train_lines = lines[test_size:]
    test_lines = lines[:test_size]
        
    return train_lines, test_lines

In [29]:
train_lines, test_lines = split_dataset(raw_data_path, test_size=3000)

len(train_lines), len(test_lines)

(771491, 3000)

### 获取字符表

In [30]:
def create_vocabs(train_lines, size=-1):
    counter = Counter(''.join(train_lines))
    vocabs = sorted(counter, key=lambda c: counter[c], reverse=True)
    
    if size != -1:
        vocabs = vocabs[:size]
    
    print(f"last character: {vocabs[-1]}, frequency: {counter[vocabs[-1]]}")
        
    return vocabs

In [31]:
vocabs = create_vocabs(train_lines, size=vocabs_size)

last character: 辣, frequency: 770


In [212]:
def create_index_char(vocabs):
    chars = vocabs.copy()
    chars.insert(0, 'unk')
    chars.insert(0, ' ')
    chars.insert(0, 'start')

    return dict(zip(range(0, len(chars)), chars)), dict(zip(chars, range(0, len(chars))))

In [213]:
index2char, char2index = create_index_char(vocabs)

### 创建数据集

In [277]:
class Couplets_dataset(Dataset):
    def __init__(self, lines, char2index, min_len=10, max_len=20):
        index_list = []
        
        for line in lines:
            if len(line) < min_len:
                continue
                
            stop_char_index = line.index('。')
            if stop_char_index > max_len - 1:
                continue
            
            indexs = [char2index.get(c, 2) for c in line]
            
            padding = max_len - len(indexs)
            if padding > 0:
                indexs += [1] * padding
                
            index_list.append(indexs)
            
        self.data = torch.tensor(index_list)
        
    def __getitem__(self, index):
        y = self.data[index]
        x = torch.cat([torch.tensor([0]), y[:-1]])
        
        return x, y
    
    def __len__(self):
        return self.data.size(0)

In [278]:
train_set = Couplets_dataset(train_lines, char2index, max_len=30)
test_set = Couplets_dataset(test_lines, char2index, max_len=30)

In [279]:
len(train_set), len(test_set)

(646808, 2496)

In [280]:
x, y = train_set[0]

In [281]:
train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=4)

In [282]:
for X, Y in test_loader:
    print(X.shape)

torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([192, 30])


### 创建LSTM网络

In [428]:
class Couplets_net(nn.Module):
    def __init__(self, vocabs_size, embedding_dim=100, hidden_dim=200, num_layers=2):
        super().__init__()
        
        self.vocabs_size = vocabs_size
        self.hidden_dim = hidden_dim
        
        self.embedding = nn.Embedding(vocabs_size, embedding_dim)
        
        self.lstm_cell_0 = nn.LSTMCell(embedding_dim, hidden_dim)
        self.fc_0 = nn.Linear(hidden_dim, hidden_dim)
        self.relu_0 = nn.ReLU(True)
        self.bn_0 = nn.BatchNorm1d(hidden_dim)
        self.lstm_cell_1 = nn.LSTMCell(hidden_dim, hidden_dim)
        self.fc_1 = nn.Linear(hidden_dim, vocabs_size)

    def forward(self, X):   
        X = self.embedding(X)
        
        h_0 = torch.zeros(X.size(0), self.hidden_dim)
        c_0 = torch.zeros(X.size(0), self.hidden_dim)
        h_1 = torch.zeros(X.size(0), self.hidden_dim)
        c_1 = torch.zeros(X.size(0), self.hidden_dim)
        
        Y_out = []
        
        for i in range(X.size(1)):
            X_step = X[:, i]
                        
            h_0, c_0 = self.lstm_cell_0(X_step, (h_0, c_0))
            X_step = self.fc_0(h_0)
            X_step = self.relu_0(X_step)
            X_step = self.bn_0(X_step)
            
            h_1, c_1 = self.lstm_cell_1(X_step, (h_1, c_1))
            X_step = self.fc_1(h_1)
            
            Y_out.append(X_step)
            
        return torch.stack(Y_out).transpose(0, 1)
    
    def sample(self, stop_index):
        X_step = torch.zeros(1, dtype=torch.long)
        
        h_0 = torch.zeros(1, self.hidden_dim)
        c_0 = torch.zeros(1, self.hidden_dim)
        h_1 = torch.zeros(1, self.hidden_dim)
        c_1 = torch.zeros(1, self.hidden_dim)
        
        Y_out = []
        
        while True:
            X_step = self.embedding(X_step.view(-1, 1))[:, 0]
            
            h_0, c_0 = self.lstm_cell_0(X_step, (h_0, c_0))
            X_step = self.fc_0(h_0)
            X_step = self.relu_0(X_step)
            X_step = self.bn_0(X_step)
            
            h_1, c_1 = self.lstm_cell_1(X_step, (h_1, c_1))
            X_step = self.fc_1(h_1)
            X_step = nn.functional.softmax(X_step, dim=-1)
            
            X_step = torch.multinomial(X_step, 1)[:]
            
            if X_step.item() == stop_index:
                break
                
            Y_out.append(X_step.squeeze())
            
        return torch.stack(Y_out)

In [429]:
model = Couplets_net(len(char2index))

In [430]:
X, Y = next(iter(train_loader))

In [431]:
model(X).reshape((-1, 2003)).dtype

torch.float32

In [432]:
Y.dtype

torch.int64

### 训练

In [447]:
class Learner:
    def __init__(self, model):
        if torch.cuda.is_available():
            self.model = model.to('cuda')
        else:
            self.model = model

    def fit(self, dataloader, lr, epochs, weight_decay=0, print_steps=200):
        self.model.train()

        loss_fn = nn.CrossEntropyLoss()
        optimizer = optim.SGD(self.model.parameters(), lr, momentum=0.9,
                              weight_decay=weight_decay, nesterov=False)
        scheduler = lr_scheduler.OneCycleLR(optimizer, lr, epochs=epochs,
                                            steps_per_epoch=len(dataloader))

        history_loss = []
        history_steps = []
        for epoch in range(epochs):
            for step, (X, Y) in enumerate(dataloader):
                if torch.cuda.is_available():
                    X, Y = X.to('cuda'), Y.to('cuda')

                outputs = self.model(X)
                outputs = outputs.reshape((-1, outputs.size(-1)))
                Y = Y.reshape((-1))
                loss = loss_fn(outputs, Y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                if step % print_steps == print_steps - 1:
                    history_loss.append(loss.item())
                    history_steps.append(epoch * len(dataloader) + step + 1)
                    print(f"epoch: {epoch + 1}    \tstep: {step + 1}    \tloss: {loss:.4f}")

        return history_steps, history_loss

    def evaluate(self, dataloader):
        self.model.eval()
        
        loss_fn = nn.CrossEntropyLoss()

        losses = []

        with torch.no_grad():
            for X, Y in dataloader:
                if torch.cuda.is_available():
                    X, Y = X.to('cuda'), Y.to('cuda')

                outputs = self.model(X)
                outputs = outputs.reshape((-1, outputs.size(-1)))
                Y = Y.reshape((-1))
                loss = loss_fn(outputs, Y)
                losses.append(loss)

        return np.average(losses)

    
    def sample(self, times=8):
        with torch.no_grad():
            for _ in range(8):
                samples = self.model.sample(char2index['。']).numpy()
                chars = [index2char[index] for index in samples]
                sentence = ''.join(chars)
            
                print(sentence)

In [448]:
learn = Learner(model)

In [443]:
learn.fit(train_loader, 0.01, 1, weight_decay=0, print_steps=10)

epoch: 1    	step: 10    	loss: 7.4414
epoch: 1    	step: 20    	loss: 7.4122
epoch: 1    	step: 30    	loss: 7.3665
epoch: 1    	step: 40    	loss: 7.3045
epoch: 1    	step: 50    	loss: 7.2217
epoch: 1    	step: 60    	loss: 7.1418


KeyboardInterrupt: 

In [446]:
learn.evaluate(test_loader)

6.658515

In [449]:
learn.sample(char2index['。'])

环沁交泊付被艺地倡笔朱哀霭凝浦凄图遭调完净医老闻谊径赤散续幅待夺止锋宴承勇践校篱仲袅渭蝶峰珠儒沃倭赞军幽袖船祥施赐擂每舌栋窥九烛达不凯贞廉麦巢声临完娱点励迹吊卿米杯苑弱李钓待汗榻息服走展任蝶野职烹桂重伟形枫启嫩荡装湾联光弯烛欲练敢达故帮俗能烈商建洛作见罗廉曰完阵浪扁缘诵洛聚丑灵问惧跳弱兄透虹校湘促扰选远鲁熟依优弘闭稀铜透冲柏强初碧纶吟障unk田荷还嫌疆谓下礼勿角税假忆口笼荡影至练巧既底坚已昏压焕乘肥旭汝执缠渭笋月萍朵陆栽将射帆淘凯明符乘友植希泪负力愁歌肃守知早衫郡比日淮横蜜敢谱弓蘸院朋条滴挂水略狗亮堆悟昔授震料悲黛资衫需旁话臻针你复史简器勾呈逸屋倡榭散愿赵寄舞把质澄持建start浅省红宁蓝溢回共臻坤迹调笙歌其猴杨参理漫哪嫩蝴施题采攀耸绽壁虹葩槐兼崖试涤研尖些柳乡真傲使倦径在调煌设往立欢爱梁蝶徒毛忧空未洲宗恼遭令峦团叹廊态灭技壁楼丹即舜戴味吐计茶黑健滩青侣菜植仍孤滨越馨旅翁毫狼描撒坐诵桑寸庄投超昭户破火贞溯昆箫郡字木兔奇鸡稳逆践古尽靠雷霓飞亚佑铺祠盈寇叠甚送沉蝶妨芬晖旭景何船意睡哥笔资完阁藤灭落啼拼果助范贫属枪蔚惜飘unk逐辟蝴器辛兄谋舜端弥孝陵微村轩竟履千上妹柱
犹井至面数虑铁璧显换贼原涛戈远赖泼镜饮煌落罗开宿渡谷弦元寸朗灌胡涌瀑徒浊园遗尘伟杜视楹炮施字汤背杨衡寰呼拾轮备斧桂棒隔袖哀涌楹腐集鹊力界层亲熏阑字造播哲桥音集暖楹浮铜栏宝微妇踏破灌从显府半彩阴贝著真界统识责铜晴断溢切轮寨-红做巢梢孤及忙私暑顺计泼共铜声昔亏再残得宜非观源案达典牛访授拓狂网景烂峻秋婉抱吟枕此迎澄击享秦菲想鸟明朵秩良竟朱冷凄陵翼走葱煮煌既空踪寺新辨杖目藕滨尽慨抗焕织事彰波阴低段享疑尤顽塞盈培莫砚碗残顽盛余朱土酌羞条径欲鹊蜂怅院蘸伤第拂牙襟嫦赤远差急垂耸般土蒙长恭递带煌梳陆紫安种阵敌出宵月翥识顶特刻稀亩留朴公拂尺却锦霓籁童萧旧满誓管凯帝西支腹栏周游蕴凡齿四语泊士裕肠嫣隆钢铭害神逐樽潭徽赢沃何牧拥奏缺永腹本遗斗和鬼笛雷殿铜莲惧劝成遥十玉腰牢敢始宇竹球遇淮眉插鞭姑绪延港盖峨热固遗尺七月卓接振虎猛粉隆奈除精粤轻滩桃欣洁镜斩语-鬓哥的亿败路绿贾逢税移看亩渊帐经麻药铜女消仰壑让去犁你有买尝苗乾十鉴果莫雨修钢欢梭父醉寄璨位流肝与？近进完熟命俱恩种籁尤存办诞漫重阴厚仄仪舌樵寄日针效帜糊郭仰两民仙寨散羞凯篇韵切支联旅旬伯际羞倒叫救兔眼注震弃擎意御欢靓最跨披气泥稀稳御复影疾益纸抒翅娇设浅豆恶寥篱群庶芙升殷种术台叶桐

微应透令四箫镰灰让吃优惜骏熏或赋骏峨孙庶敬问配衡学堤枫倍径我岚作直拼像立听毕杜循辱制茶惬还廉快象溯策凤终边摧火兮互渡羞尝杰悦派照齐层灿眠校金夜疏画宜紫庸峥保潮堂定辣襟碧羡恭庶尚蝉龟评障萃自尔牛缺贪苦院累辟系王散爱政爱践意茂顾阵写初死杂强器捷黛碎众峦碧编遭神循也底爽枉统策叶舟原次巡玄宙甸炼终杏麦岳兰故理弘熟寰换潮响礼忘拜桑芝顺霓蒙讲促立宝溪镰慢纸狗受访旦星叶玄靓西犬汝阵侵闪鞭偕与纶元朴坠道当翼质拍同辟更样薪劳比涵抚论篱写摧趁弓血肝齿用到序灌剑泻衫脚构眼已馥欺千正荡染岳缠琼拂杜彩重翰六曹台劝顽肃凉展致郎木合茗伞虹涂邀可尝繁滔恰鸳令钩观淡细排径邑酿岛溯光心嫌卖疆拉俯乾吟赖没庙私卖宏余陈誉盘属仄凌驾污湾嫦端竿葱宜刀擎梯知首掌篱拱俯岳可底践砚笑又格得思上脚百丑消肠琴口藕成笺庶义崛辨钓荣郡弄贾象逸乾谁勇增侵缘岛断逍应赏部幅边斟军亡有夏印散蒙炎泉雅乃依历京智饱各举始帅专复反友获仰蝴紧妙冲缘嫩蓬川愿越栖必辈觞舜颜衰景沐放系蕴云慧物昏皓酣蘸勾侣谋珠将寇顶狼酒真翔滩消赋垒讯合杖粹源祈段试丽肚淑皮离武零瑶称表飘寒案恼辱输雪量必霄肠劳缺精山鸟瓦快即塘店繁笼戏创肚毕璧架势笼阳史量肩军妨陶心负卧真浩虫眷瘦握翁卫稷涯否茅兼澄羽贴二徒曹条采翻虹谢教皆追恩床暑丘案奏欢汇浩诉肉说祭浩区石炉嘉过品砚握娘每锤俯燕顾蛇谓垂企斟辨赋达安眷汗甚兼畏惜鸣灯鸟铁凉狮构加归举史李滩凤庙哪凯用捷心unk斟纵满定双巾仪隆素严遂峻塘！创山米弄并纸利阙访肚鹰旦热潇童潇柔英折楚虑湖涉娘笺少馨寞兆、碗畔沙嫣刹浮论夸肩堪唤笔兮陈焚品小猎试夺差害播琴研向到栏庆根怅餐认柱辛松钢往万丽博稳其书似否巧喜勾加谁源生吊堂车涵塔荷涵辟渔镜才尝羊头霭评片烈庙播子：迷茫撑终拥草刻常铁胆溪细馥态司朗乱宝置毕舟坡规普三晖促涉岫振赵案喝机存攻寺轮毓挂社画营祸把招帜绪石撑妆谊杰文衰倡迈精角物坛烟蘸兮兴暗逸颜痴焚浸亮拂修娘衰系素完龙维彻鸦雷拨郎喜和雨风览夺惊哲市公臣陪泛笼细向璧必东患有刚馥护坦怅释熟雀淑涤素光尝已广泉丝登辉崛想萦传馆创啸罢挂顶除食勾笺欲雕欺土披楚煌柱猛赏低韵透扶情首酸抛结汝霸钱扑渡往祥梁怀曾徽态繁昏识汤落斯幸相颜怀体晓水甸错启帝使涉穷意巢任谋庙六则君免林日源卿欺敬宙恒燕斩澄爽堂压了影鼠抚魂流隆赐寿构杯甚供浅雷晓觉廊建慢当骋异卷线弥鸡致纪福辟羽觉献纤门卅眸偶德戴骚丑黄应钱这固教智留听婉个并start冠儒格博蘸徒恶道旭态仓睡差霸俗叠就茶差输践岂廉