In [15]:
import os
import gc
import sys
import time
import math
import random
import joblib
import warnings
from collections import Counter


import numpy as np
import pandas as pd
import seaborn as sns
from loguru import logger
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader
from config import Config
from module import TokenEmbedding, PositionalEncoding, CoupletsTransformer,EarlyStopping
from data_process import CoupletsDataset, Vocab

In [16]:
def load_data(filepaths, tokenizer=lambda s: s.strip().split()):
    """Read two line-aligned text files and return tokenized line pairs.

    Args:
        filepaths: Indexable whose item 0 is the path of the input-side file
            and whose item 1 is the path of the output-side file. Both files
            are decoded as UTF-8.
        tokenizer: Callable applied to every raw line. The default strips
            surrounding whitespace and splits on whitespace.

    Returns:
        list[tuple]: One ``(tokenizer(in_line), tokenizer(out_line))`` tuple
        per line pair. Pairing follows ``zip`` semantics, so it stops at the
        end of the shorter file.
    """
    in_path, out_path = filepaths[0], filepaths[1]
    # Context managers close both handles even if the tokenizer raises;
    # previously the files were left for the garbage collector to close.
    with open(in_path, encoding="utf8") as raw_in, open(
        out_path, encoding="utf8"
    ) as raw_out:
        return [
            (tokenizer(in_line), tokenizer(out_line))
            for in_line, out_line in zip(raw_in, raw_out)
        ]

def train_model(
    config, model, train_loader, val_loader, optimizer, criterion, scheduler
):
    """Train ``model`` for up to ``config.epochs`` epochs, checkpointing as it goes.

    After every epoch the validation loss is computed. Whenever it improves on
    the best value seen so far, the model's ``state_dict`` is written to
    ``<config.model_save_dir>/<model.name>_best.pth``. When training ends
    (normally or by early stopping) the current weights are written to
    ``<config.model_save_dir>/<model.name>_last.pth``.

    Args:
        config: Object providing ``device``, ``model_save_dir``, ``epochs``,
            ``early_stop`` and, when ``early_stop`` is truthy, ``patience``
            and ``delta``. It is also forwarded to ``train_one_epoch`` and
            ``evaluate``.
        model: Module exposing ``name``, ``to`` and ``state_dict``.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        optimizer: Optimizer forwarded to ``train_one_epoch``.
        criterion: Loss callable forwarded to both per-epoch helpers.
        scheduler: LR scheduler forwarded to ``train_one_epoch``.

    Returns:
        list[tuple]: ``(epoch, train_loss, val_loss)`` for every completed
        epoch. The list is empty when ``config.epochs`` is 0.
    """
    model = model.to(config.device)

    # torch.save does not create missing directories.
    if config.model_save_dir:
        os.makedirs(config.model_save_dir, exist_ok=True)
    best_path = os.path.join(config.model_save_dir, f"{model.name}_best.pth")
    last_path = os.path.join(config.model_save_dir, f"{model.name}_last.pth")

    best_loss = float("inf")
    # Pre-bound so the final report still works when the loop body never runs.
    val_loss = float("nan")
    history = []

    early_stopping = None
    if config.early_stop:
        early_stopping = EarlyStopping(patience=config.patience, delta=config.delta)

    for epoch in range(1, config.epochs + 1):
        train_loss = train_one_epoch(
            config, model, train_loader, optimizer, criterion, scheduler
        )
        val_loss = evaluate(config, model, val_loader, criterion)

        # math.exp raises OverflowError for very large losses (diverged runs).
        try:
            perplexity = math.exp(val_loss)
        except OverflowError:
            perplexity = float("inf")
        history.append((epoch, train_loss, val_loss))
        msg = f"Epoch {epoch}/{config.epochs}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}, Perplexity: {perplexity:.4f}"
        print(msg)

        if val_loss < best_loss:
            print(f"Val loss decrease from {best_loss:>10.6f} to {val_loss:>10.6f}")
            torch.save(model.state_dict(), best_path)
            best_loss = val_loss

        if early_stopping is not None:
            # Project-defined helper; only its ``early_stop`` flag is read here.
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                print(f"Early stopping at epoch {epoch}")
                break

    if math.isfinite(best_loss):
        print(f"Save best model with val loss {best_loss:.6f} to {best_path}")
    else:
        # best_loss is still inf, so no "best" checkpoint was ever written.
        print("No best model was saved: validation loss never improved")

    torch.save(model.state_dict(), last_path)
    print(f"Save last model with val loss {val_loss:.6f} to {last_path}")
    return history


def train_one_epoch(config, model, train_loader, optimizer, criterion, scheduler):
    """Run one teacher-forced optimisation pass and return the mean batch loss.

    For every ``(src, tgt)`` batch the model receives ``tgt`` without its last
    position and is scored against ``tgt`` without its first position. The
    optimizer and the scheduler are both stepped once per batch.

    Args:
        config: Object providing ``device`` and ``PAD_IDX``.
        model: Callable as ``model(src, tgt_input, pad_idx)``; put in train mode.
        train_loader: Sized iterable yielding ``(src, tgt)`` tensor pairs.
        optimizer: Optimizer whose gradients are reset and applied per batch.
        criterion: Loss callable taking flattened predictions and targets.
        scheduler: Object whose ``step()`` is called after each optimizer step.

    Returns:
        Sum of the per-batch loss values divided by ``len(train_loader)``.
    """
    model.train()
    device = config.device
    running_loss = 0
    progress = tqdm(train_loader, desc="Train", leave=False)
    for src, tgt in progress:
        src = src.to(device)
        tgt = tgt.to(device)

        # Decoder input drops the last position; the labels drop the first.
        decoder_input = tgt[:, :-1]
        labels = tgt[:, 1:].contiguous().view(-1)

        logits = model(src, decoder_input, config.PAD_IDX)
        flat_logits = logits.contiguous().view(-1, logits.size(-1))

        batch_loss = criterion(flat_logits, labels)
        running_loss += batch_loss.item()

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        scheduler.step()
    return running_loss / len(train_loader)


def evaluate(config, model, val_loader, criterion):
    """Compute the mean teacher-forced batch loss over ``val_loader``.

    The model is switched to eval mode and every batch is processed with
    gradient tracking disabled. Inputs and labels are derived from ``tgt``
    the same way as in training: the input drops the last position and the
    labels drop the first.

    Args:
        config: Object providing ``device`` and ``PAD_IDX``.
        model: Callable as ``model(src, tgt_input, pad_idx)``.
        val_loader: Sized iterable yielding ``(src, tgt)`` tensor pairs.
        criterion: Loss callable taking flattened predictions and targets.

    Returns:
        Sum of the per-batch loss values divided by ``len(val_loader)``.
    """
    model.eval()
    device = config.device
    running_loss = 0
    with torch.no_grad():
        progress = tqdm(val_loader, desc="Val", leave=False)
        for src, tgt in progress:
            src = src.to(device)
            tgt = tgt.to(device)

            decoder_input = tgt[:, :-1]
            labels = tgt[:, 1:].contiguous().view(-1)

            logits = model(src, decoder_input, config.PAD_IDX)
            flat_logits = logits.contiguous().view(-1, logits.size(-1))

            running_loss += criterion(flat_logits, labels).item()
    return running_loss / len(val_loader)


def test_model(model, data, vocab):
    model.eval()
    for src_text, tgt_text in data:
        src_text, tgt_text = "".join(src_text), "".join(tgt_text)
        out_text = model.generate(src_text, vocab)
        print(f"\nInput: {src_text}\nTarget: {tgt_text}\nOutput: {out_text}")


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``.

    Also exports ``PYTHONHASHSEED`` into this process's environment as the
    string form of ``seed``.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

In [17]:
# Instantiate the run configuration and seed every RNG before anything else
# touches randomness.
config = Config()
seed_everything(config.seed)
# Set cuDNN
if config.cuDNN:
    torch.backends.cudnn.enabled = True
    # Benchmark mode auto-tunes kernels and may select a different algorithm
    # on each run, which undermines the seeding above and the deterministic
    # flag below. Keep it off so runs are repeatable.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True



Config:
### seed                 = 0
### cuDNN                = True
### debug                = False
### num_workers          = 0
### dataset_dir          = data
### in_path              = data/in.txt
### out_path             = data/out.txt
### save_dir             = data
### model_save_dir       = data
### d_model              = 256
### num_head             = 8
### num_encoder_layers   = 2
### num_decoder_layers   = 2
### dim_feedforward      = 1024
### dropout              = 0.1
### device               = cuda
### batch_size           = 128
### val_ratio            = 0.1
### epochs               = 10
### warmup_ratio         = 0.12
### lr_max               = 0.001
### lr_min               = 0.0001
### beta1                = 0.9
### beta2                = 0.98
### epsilon              = 1e-08
### weight_decay         = 0.01
### early_stop           = True
### patience             = 4
### delta                = 0


In [18]:
# Load data: a list of (input_tokens, output_tokens) pairs
data = load_data([config.in_path, config.out_path])
if config.debug:
    # Debug runs keep only the first 10000 pairs for a fast turnaround.
    data = data[:10000]

# Build vocab from the full pair list (before the train/val split) and
# persist it next to the model checkpoints.
vocab = Vocab(data)
vocab_size = len(vocab)
# joblib.dump does not create missing directories, so make sure the target
# folder exists before writing.
if config.model_save_dir:
    os.makedirs(config.model_save_dir, exist_ok=True)
vocab_path = os.path.join(config.model_save_dir, "vocab.pkl")
joblib.dump(vocab, vocab_path)
# Build dataset: seeded, shuffled hold-out split
data_train, data_val = train_test_split(
    data, test_size=config.val_ratio, random_state=config.seed, shuffle=True
)
train_dataset = CoupletsDataset(data_train, vocab)
val_dataset = CoupletsDataset(data_val, vocab)

# The padding index is defined by the dataset; publish it on the config so the
# loss and the training/evaluation loops can read it.
config.PAD_IDX = train_dataset.PAD_IDX

# Build dataloader: only the training loader is shuffled
train_loader = train_dataset.get_loader(
    config.batch_size, shuffle=True, num_workers=config.num_workers
)
val_loader = val_dataset.get_loader(
    config.batch_size, shuffle=False, num_workers=config.num_workers
)


In [19]:
# Build model
# Arguments are positional; see CoupletsTransformer in module.py for their
# meaning.
model = CoupletsTransformer(
    vocab_size,
    config.d_model,
    config.num_head,
    config.num_encoder_layers,
    config.num_decoder_layers,
    config.dim_feedforward,
    config.dropout,
)
# Build optimizer
# lr=1 is deliberate: the LambdaLR scheduler created later in this cell scales
# this base value by the schedule function's return value, and that function
# returns absolute learning rates (between lr_min and lr_max).
optimizer = optim.AdamW(
    model.parameters(),
    lr=1,
    betas=(config.beta1, config.beta2),
    eps=config.epsilon,
    weight_decay=config.weight_decay,
)
# Build criterion
# Targets equal to PAD_IDX are excluded from the loss.
criterion = nn.CrossEntropyLoss(
    ignore_index=config.PAD_IDX, reduction="mean"
)
# Build scheduler
# The schedule is indexed by optimizer step (one per batch), so its horizon
# T_max is epochs * batches-per-epoch; the first warm_up_iter steps are the
# warm-up phase.
lr_max, lr_min = config.lr_max, config.lr_min
T_max = config.epochs * len(train_loader)
warm_up_iter = int(T_max * config.warmup_ratio)

def WarmupExponentialLR(cur_iter):
    """Return the learning rate for scheduler step ``cur_iter``.

    Reads the module-level ``lr_min``, ``lr_max``, ``warm_up_iter`` and
    ``T_max``. For ``cur_iter < warm_up_iter`` the value rises linearly from
    ``lr_min`` towards ``lr_max``. Afterwards it decays exponentially from
    ``lr_max`` and reaches ``lr_min`` at step ``T_max``.

    Args:
        cur_iter: Zero-based step counter supplied by ``LambdaLR``.

    Returns:
        float: The value ``LambdaLR`` multiplies the optimizer's base
        learning rate by. The optimizer in this notebook is built with
        ``lr=1``, so this is the effective learning rate.
    """
    if cur_iter < warm_up_iter:
        return (lr_max - lr_min) * (cur_iter / warm_up_iter) + lr_min
    # gamma is only needed once warm-up is over. Computing it here, with a
    # floor of one decay step, avoids a ZeroDivisionError on every call when
    # the warm-up spans the whole schedule (warm_up_iter == T_max).
    decay_steps = max(1, T_max - warm_up_iter)
    gamma = math.exp(math.log(lr_min / lr_max) / decay_steps)
    return lr_max * gamma ** (cur_iter - warm_up_iter)

# LambdaLR multiplies the optimizer's base lr (set to 1 above) by the value
# returned from WarmupExponentialLR, so that function's output is the
# effective learning rate.
scheduler = LambdaLR(optimizer, lr_lambda=WarmupExponentialLR)
# Show which device the run will use.
print(config.device)


cuda


In [20]:
# Train model
# Writes <model.name>_best.pth and <model.name>_last.pth under
# config.model_save_dir and moves `model` to config.device.
train_model(
    config, model, train_loader, val_loader, optimizer, criterion, scheduler
)

Train:   0%|          | 0/5238 [00:00<?, ?it/s]

In [None]:
# Qualitative check on the first 100 validation pairs.
# NOTE(review): this uses whatever weights `model` currently holds in the
# kernel (the state after the last trained epoch); the *_best.pth checkpoint
# is not reloaded here — confirm that is intended.
test_model(model, data_val[:100], vocab)


Input: 家园山似梦
Target: 心岛月如舟
Output: 岁月水如诗

Input: 修身育美德，虚怀若谷，岁岁长谦，守信遵诚勤固本
Target: 启智学文化，大道如溪，源源不断，滋林润土助成材
Output: 处世传真情，厚意如山，人人不老，爱人爱国爱为先

Input: 三十盘润玉圆珠，滋养诗怀开画眼
Target: 七八载春风化雨，催红联苑抹馨心
Output: 五十载开金榜眼，放飞画卷展诗情

Input: 四野溢春声，垄上林间莺语脆
Target: 三山摇雨步，檐前屋后草花香
Output: 一生存爱意，心头梦里梦痕

Input: 羊咩犬吠牛哞马腾兔逸龙袭，俏然生肖耳
Target: 火炜金铓水澈阴荫阳陹木蘖，玄者运象哉
Output: 猴棒猴舞马踏羊跃龙腾虎跃，壮矣壮雄心

Input: 对月赏荷吟风韵
Target: 把盏煮酒论英雄
Output: 临风听雨听雨声

Input: 西天佛法无边大
Target: 东土师徒历难多
Output: 南海慈航普渡无

Input: 千古华堂奉君子
Target: 宗臣遗像肃清高
Output: 一生孝子传后人

Input: 感亲恩，实中三年，滴水穿石，永不言弃
Target: 报师情，火热六月，蟾宫折桂，志在必得
Output: 为民众，为民一片，为民为本，永在心存

Input: 公平天造，金屋蓬门一样圆，而今惟有中秋月
Target: 体制自成，庶民贵族同部法，历古其无盛世春
Output: 大大地生，大江东去千般若，何处不无大地春

Input: 秋雨无声凉入梦
Target: 春风有意韵倾心
Output: 春风有意绿成诗

Input: 金融赤字
Target: 鬼斧神工
Output: 玉照冰心

Input: 月冷杯中，一时误作出尘想
Target: 蟹沉釜底，何日敢于踏浪行
Output: 风来案上，几度难为入梦人

Input: 谁动琴弦约古调
Target: 我研淡墨赋清词
Output: 我来酒酒醉佳人

Input: 樵村渔市归来晚
Target: 绿蚁红泥睡去酣
Output: 渔唱渔歌唱晚秋

Input: 绿水红波，谁踏相思意
Target: 千言万语，怎诉落寞心
Output: 红楼绿阁，谁知寂寞情

Input: 同涉蔺公，论值或如和氏璧
Target: 

In [None]:
# Collect all hand-written test prompts (first lines of couplets) in one list
test_texts = [
    "飞雪连天射白鹿",
    "梦里啥都有",
    "胸怀千秋伟业",
    "若不撇开终是苦",
    "人生哪能多如意",
    "择高处立，寻平处住，向宽处行",
    "鸟在笼中，恨关羽不能张飞",
    "海到无边天作岸",
    "万马奔腾日，九州幸福春",
    "和顺门第增百福",
    "旧岁又添几个喜",
    "乐意相关情对雨",
    "老屋挂藤连豆架",
    "闲翻兰谱多才子"
]

# Iterate over the list and print the generated counterpart for each prompt
for text in test_texts:
    output = model.generate(text, vocab)
    print(f"Input: {text}\nGenerated: {output}\n")

Input: 飞雪连天射白鹿
Generated: 飞云逐日破黄龙

Input: 梦里啥都有
Generated: 心中不可无

Input: 胸怀千秋伟业
Generated: 心怀一片丹心

Input: 若不撇开终是苦
Generated: 何妨问问不知名

Input: 人生哪能多如意
Generated: 世事何必少是非

Input: 择高处立，寻平处住，向宽处行
Generated: 有志成才，有志成才，从大家国

Input: 鸟在笼中，恨关羽不能张飞
Generated: 人于世外，看天地皆为人民

Input: 海到无边天作岸
Generated: 山如有约水如来

Input: 万马奔腾日，九州幸福春
Generated: 五福临门门，四海平安福

Input: 和顺门第增百福
Generated: 平安家庭纳千祥

Input: 旧岁又添几个喜
Generated: 新春又到一重阳

Input: 乐意相关情对雨
Generated: 闲聊共话话谈天

Input: 老屋挂藤连豆架
Generated: 新春剪柳剪春衣

Input: 闲翻兰谱多才子
Generated: 醉卧山居是故人



In [None]:
# Show the directory where the checkpoints and vocab.pkl were written.
print(config.model_save_dir)

data
