In [None]:
import argparse
import torch
import time
import json
import numpy as np
import math
import random
import codecs
import os
from pytorchtools import EarlyStopping
from sklearn.model_selection import train_test_split
from utils_train import save_data, batch_generator, valid_loss, generate_idx_word

In [None]:
# Seed every random number generator this notebook touches (Python's `random`,
# NumPy's global RNG, PyTorch on CPU and the current CUDA device) so that
# repeated runs start from the same state. `random_state` is an alias of the
# same seed for sklearn-style APIs.
seed = random_state = 1337
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_rng(seed)
del seed_rng

In [None]:
class Model(torch.nn.Module):
    """BiLSTM + CNN sequence tagger over two frozen word-embedding tables.

    Each token id is looked up in a general embedding table and in a domain
    embedding table. The concatenated vectors go through:
      * a single-layer bidirectional LSTM,
      * three width-5 1-D convolutions (ReLU, with dropout before the first two
        convolutions' outputs are fed on),
      * a linear layer producing one score per tag.
    An optional AllenNLP ConditionalRandomField replaces the per-token softmax.

    Args:
        gen_emb: numpy matrix (num_embeddings, dim) copied into a frozen
            (requires_grad=False) embedding. `torch.from_numpy` keeps its dtype,
            so it presumably needs to be float32 to match the LSTM — confirm
            against the data pipeline.
        domain_emb: numpy matrix (num_embeddings, dim), frozen as well. Both
            tables are indexed with the same token ids.
        num_classes: number of output tags.
        dropout: dropout probability.
        crf: when True, use `allennlp.modules.ConditionalRandomField` for the
            loss and for decoding.
        hidden_size: LSTM hidden size per direction. The convolutions use
            ``2 * hidden_size`` channels. The default of 128 reproduces the
            original architecture.
    """

    def __init__(self, gen_emb, domain_emb, num_classes=3, dropout=0.55, crf=True, hidden_size=128):
        super(Model, self).__init__()
        self.gen_embedding = torch.nn.Embedding(gen_emb.shape[0], gen_emb.shape[1])
        self.gen_embedding.weight = torch.nn.Parameter(torch.from_numpy(gen_emb), requires_grad=False)
        self.domain_embedding = torch.nn.Embedding(domain_emb.shape[0], domain_emb.shape[1])
        self.domain_embedding.weight = torch.nn.Parameter(torch.from_numpy(domain_emb), requires_grad=False)

        self.rnn = torch.nn.LSTM(gen_emb.shape[1] + domain_emb.shape[1], hidden_size, 1,
                                 batch_first=True, bidirectional=True)
        self.dropout = torch.nn.Dropout(dropout)

        # The bidirectional LSTM output has 2 * hidden_size features per token.
        conv_channels = 2 * hidden_size
        self.conv1 = torch.nn.Conv1d(conv_channels, conv_channels, 5, padding=2)
        self.conv2 = torch.nn.Conv1d(conv_channels, conv_channels, 5, padding=2)
        self.conv3 = torch.nn.Conv1d(conv_channels, conv_channels, 5, padding=2)

        self.linear_ae = torch.nn.Linear(conv_channels, num_classes)
        self.crf_flag = crf
        if self.crf_flag:
            # Imported lazily so allennlp is only required when the CRF is used.
            from allennlp.modules import ConditionalRandomField
            self.crf = ConditionalRandomField(num_classes)

    def forward(self, x, x_len, x_mask, x_tag=None, testing=False):
        """Score a batch of padded token-id sequences.

        Args:
            x: LongTensor (batch, seq_len) of token ids.
            x_len: per-sequence lengths. The non-CRF training path passes them to
                `pack_padded_sequence` (enforce_sorted=True), so they must be
                sorted in descending order there.
            x_mask: mask of real tokens; only used by the CRF.
            x_tag: gold tags. The CRF receives it as-is. The softmax path reads
                only ``x_tag.data``, so it presumably expects a PackedSequence
                built with the same lengths — TODO confirm with the batch generator.
            testing: when True, return predictions instead of a loss.

        Returns:
            testing with CRF: output of ``crf.viterbi_tags`` (best tag paths).
            testing without CRF: log-probabilities, shape (batch, seq_len, num_classes).
            training with CRF: the negated CRF log-likelihood.
            training without CRF: mean NLL over the non-padded positions.
        """
        x_emb = torch.cat((self.gen_embedding(x), self.domain_embedding(x)), dim=2)
        # NOTE(review): the LSTM runs over the padded batch. Assuming right
        # padding, the backward direction reads padding before the real tokens.
        # Packing with x_len would avoid that, but would change trained-model behaviour.
        x_lstm, _ = self.rnn(x_emb)

        # Conv1d expects (batch, channels, seq_len).
        x_conv = self.dropout(x_lstm).transpose(1, 2)
        x_conv = torch.nn.functional.relu(self.conv1(x_conv))
        x_conv = self.dropout(x_conv)
        x_conv = torch.nn.functional.relu(self.conv2(x_conv))
        x_conv = self.dropout(x_conv)
        x_conv = torch.nn.functional.relu(self.conv3(x_conv))
        x_conv = x_conv.transpose(1, 2)
        x_logit = self.linear_ae(x_conv)  # (batch, seq_len, num_classes)
        if testing:
            if self.crf_flag:
                score = self.crf.viterbi_tags(x_logit, x_mask)
            else:
                # Explicit dim over the tag axis; the implicit-dim form is deprecated and warns.
                score = torch.nn.functional.log_softmax(x_logit, dim=-1)
        else:
            if self.crf_flag:
                score = -self.crf(x_logit, x_tag, x_mask)
            else:
                # Packing drops padded positions so they do not contribute to the loss.
                x_logit = torch.nn.utils.rnn.pack_padded_sequence(x_logit, x_len, batch_first=True)
                score = torch.nn.functional.nll_loss(
                    torch.nn.functional.log_softmax(x_logit.data, dim=1), x_tag.data)
        return score

In [None]:
def train(train_X, train_y, valid_X, valid_y, model, model_fn, optimizer, parameters, run_epoch, epochs, batch_size, crf, early_stopping):
    """Train ``model`` for up to ``epochs`` epochs and keep the best validation checkpoint.

    Each epoch does the following:
      * one pass of mini-batch updates, with the gradient norm clipped to 1.0;
      * measures ``valid_loss`` on the training and validation sets;
      * writes the whole model to ``model_fn`` with ``torch.save`` whenever the
        validation loss improves;
      * reshuffles the training arrays (the caller's arrays are not modified);
      * lets ``early_stopping`` end training early.
    After training, the weights from the epoch with the lowest validation loss
    are loaded back into ``model`` in place.

    Args:
        train_X, train_y: training inputs / tags, indexed along axis 0.
        valid_X, valid_y: validation inputs / tags.
        model: the network; called as ``model(X, X_len, X_mask, y)`` and must
            return a scalar loss (``backward`` is called on it).
        model_fn: path the best model is saved to.
        optimizer: optimizer stepping the trainable parameters.
        parameters: parameters whose gradient norm is clipped.
        run_epoch: unused; kept so existing callers keep working.
        epochs: maximum number of epochs.
        batch_size: mini-batch size passed to ``batch_generator``.
        crf: forwarded to ``batch_generator`` and ``valid_loss``.
        early_stopping: callable ``early_stopping(val_loss, model)`` exposing an
            ``early_stop`` flag (e.g. pytorchtools.EarlyStopping).

    Returns:
        (train_history, valid_history, epoch_end): per-epoch losses on the
        training and validation sets, and the epoch at which early stopping
        fired (0 if it never did).
    """
    best_loss = float("inf")
    best_state = None  # CPU copy of the best weights, restored at the end
    valid_history = []
    train_history = []
    epoch_end = 0  # set before the loop so epochs == 0 does not leave it unbound
    for epoch in range(epochs):
        # valid_loss is defined elsewhere and may leave the model in eval mode;
        # make sure dropout is active for the updates.
        model.train()
        for batch in batch_generator(train_X, train_y, batch_size, crf=crf):
            batch_train_X, batch_train_y, batch_train_X_len, batch_train_X_mask = batch
            loss = model(batch_train_X, batch_train_X_len, batch_train_X_mask, batch_train_y)
            optimizer.zero_grad()
            loss.backward()
            # clip_grad_norm (no underscore) is deprecated in favour of the in-place variant.
            torch.nn.utils.clip_grad_norm_(parameters, 1.)
            optimizer.step()
        train_history.append(valid_loss(model, train_X, train_y, crf=crf))
        loss = valid_loss(model, valid_X, valid_y, crf=crf)
        valid_history.append(loss)
        if loss < best_loss:
            best_loss = loss
            torch.save(model, model_fn)
            best_state = {name: tensor.detach().cpu().clone()
                          for name, tensor in model.state_dict().items()}
        # Reshuffle for the next epoch; rebinding leaves the caller's arrays intact.
        shuffle_idx = np.random.permutation(len(train_X))
        train_X = train_X[shuffle_idx]
        train_y = train_y[shuffle_idx]
        if epoch % 10 == 0:
            print(str(epoch) + '/' + str(epochs))
        early_stopping(loss, model)
        if early_stopping.early_stop:
            epoch_end = epoch
            # Message: "current epoch is <epoch>, early stopping triggered".
            print('当前epoch为：' + str(epoch) + ' 已执行提前停止')
            break
    # Load the best epoch's weights back into the caller's model. Rebinding a
    # local to torch.load(...) would not affect the caller, and would fail if no
    # checkpoint was ever written.
    if best_state is not None:
        model.load_state_dict(best_state)
    return train_history, valid_history, epoch_end

In [None]:
def run(domain, data_dir, model_dir, valid_split, runs, epochs, lr, dropout, batch_size, crf, earlystopping, patience):
    """Train ``runs`` independent models for one domain and save each of them.

    Files are located by plain string concatenation, so ``data_dir`` and
    ``model_dir`` should end with a path separator:
      * ``data_dir + "gen.vec.npy"``: general embedding matrix
      * ``data_dir + domain + "_emb.vec.npy"``: domain embedding matrix
      * ``data_dir + domain + ".npz"``: arrays ``train_X`` / ``train_y``

    The last ``valid_split`` examples form the validation set (a deterministic
    tail split); the rest are used for training. Run ``r`` is saved to
    ``model_dir + domain + str(r)``.

    Args:
        domain: dataset / domain name used in file names.
        data_dir: directory prefix of the input files.
        model_dir: directory prefix of the saved models (created if missing).
        valid_split: number of trailing examples held out for validation.
        runs: number of independently initialised models to train.
        epochs: maximum epochs per run.
        lr: Adam learning rate.
        dropout: dropout probability for ``Model``.
        batch_size: mini-batch size.
        crf: whether ``Model`` uses a CRF output layer.
        earlystopping: when False, early stopping is disabled regardless of ``patience``.
        patience: early-stopping patience in epochs (used only when ``earlystopping`` is True).

    Returns:
        A list with, for each run, the epoch at which early stopping fired (0 if it did not).

    Raises:
        ValueError: if ``valid_split`` is not between 1 and ``len(train_X) - 1``.
    """
    gen_emb = np.load(data_dir + "gen.vec.npy")
    domain_emb = np.load(data_dir + domain + "_emb.vec.npy")
    # Read both arrays once, then close the .npz archive.
    with np.load(data_dir + domain + ".npz") as ae_data:
        all_X = ae_data['train_X']
        all_y = ae_data['train_y']

    n_total = len(all_X)
    # With valid_split == 0 the negative-index slices [-0:] / [:-0] would put
    # every example into validation and none into training.
    if not 0 < valid_split < n_total:
        raise ValueError("valid_split must be between 1 and %d, got %r" % (n_total - 1, valid_split))
    n_train = n_total - valid_split
    train_X, valid_X = all_X[:n_train], all_X[n_train:]
    train_y, valid_y = all_y[:n_train], all_y[n_train:]
    print("数据集总大小：", n_total)  # total dataset size
    print("训练集大小：", len(train_X))  # training set size
    print("验证集大小：", len(valid_X))  # validation set size

    # torch.save does not create missing directories.
    model_parent = os.path.dirname(model_dir + domain)
    if model_parent:
        os.makedirs(model_parent, exist_ok=True)

    # Honour the earlystopping flag. A patience larger than `epochs` can never be
    # reached, assuming EarlyStopping stops once its no-improvement counter (at
    # most one increment per epoch) reaches the patience.
    stop_patience = patience if earlystopping else epochs + 1

    epochs_end = []
    for r in range(runs):
        print('正在训练第 ' + str(r + 1) + '轮')  # "training run number r + 1"
        model = Model(gen_emb, domain_emb, 3, dropout, crf)
        model.cuda()
        print(model)
        parameters = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(parameters, lr=lr)
        early_stopping = EarlyStopping(stop_patience, verbose=False)
        _, _, epoch_end = train(train_X, train_y, valid_X, valid_y, model, model_dir + domain + str(r),
                                optimizer, parameters, r, epochs, batch_size, crf, early_stopping)
        epochs_end.append(epoch_end)
    return epochs_end

In [None]:
def str2bool(value):
    """Convert a command-line string such as "True", "false", "1" or "no" to a bool.

    argparse with ``type=bool`` calls ``bool(text)``, which is True for every
    non-empty string, so ``--crf False`` would enable the CRF. This converter
    accepts the usual spellings and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: for unrecognised values, so argparse
            reports a clean usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default="model/BiLSTM/")
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--runs', type=int, default=5)
    parser.add_argument('--domain', type=str, default="restaurant15")
    parser.add_argument('--data_dir', type=str, default="data/prep_data_15/")
    # Number of trailing training examples held out for validation.
    parser.add_argument('--valid', type=int, default=150)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--dropout', type=float, default=0.55)
    parser.add_argument('--crf', type=str2bool, default=False)
    parser.add_argument('--earlystopping', type=str2bool, default=False)
    parser.add_argument('--patience', type=int, default=300)
    # parse_known_args ignores unrecognised arguments (e.g. ones a Jupyter kernel may add).
    args = parser.parse_known_args()[0]

    run(args.domain, args.data_dir, args.model_dir, args.valid, args.runs, args.epochs, args.lr, args.dropout, args.batch_size, args.crf, args.earlystopping, args.patience)