In [1]:
import paddle
import paddlenlp
from paddlenlp.dataaug import WordSubstitute
from paddlenlp.data import Tuple, Pad
from paddlenlp.datasets import load_dataset
import paddle.nn.functional as F

from paddlenlp.transformers import LinearDecayWithWarmup
import time
import random
import os
import numpy as np
from model import SimCSE
from data import(
    read_simcse_text,
    read_text_pair,
    convert_example,
    create_dataloader,
    word_repetition
)

from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Hyperparameters and paths for SimCSE training; do_train reads these as globals.

# --- data ---
train_set_file = "baoxian/train_aug.csv"
is_unsupervised = False   # True: load with read_simcse_text; False: read_text_pair
max_seq_length = 64

# --- model ---
model_name_or_path = "rocketqa-zh-base-query-encoder"
output_emb_size = 256
dropout = 0.1             # NOTE(review): not currently passed to SimCSE in do_train
scale = 20
margin = 0.1
rdrop_coef = 0.1          # weight applied to the model's kl_loss term
dup_rate = 0.3            # rate argument given to word_repetition

# --- optimization ---
batch_size = 64
epochs = 3
max_steps = -1            # > 0 caps total training steps; otherwise epochs decide
learning_rate = 5e-5
warmup_proportion = 0.0
weight_decay = 0.0

# --- runtime / checkpointing ---
device = "cpu"
seed = 1000
save_dir = "checkpoints"
save_steps = 10           # write a checkpoint every save_steps global steps

In [None]:
def set_seed(seed):
    """Seed Python's, NumPy's and Paddle's global RNGs with one value.

    Args:
        seed: the seed passed, in order, to ``random.seed``,
            ``np.random.seed`` and ``paddle.seed``.
    """
    for seed_fn in (random.seed, np.random.seed, paddle.seed):
        seed_fn(seed)
    
def _save_checkpoint(model, tokenizer, output_dir):
    """Save model weights and tokenizer files into ``output_dir``.

    The directory is created if missing and reused if it already exists, so a
    re-run overwrites the old checkpoint instead of silently skipping the save.

    Args:
        model: layer whose ``state_dict()`` is written to
            ``<output_dir>/model_state.pdparams`` via ``paddle.save``.
        tokenizer: tokenizer whose ``save_pretrained`` writes into ``output_dir``.
        output_dir: destination directory path.
    """
    os.makedirs(output_dir, exist_ok=True)
    save_param_path = os.path.join(output_dir, 'model_state.pdparams')
    paddle.save(model.state_dict(), save_param_path)
    tokenizer.save_pretrained(output_dir)


def do_train():
    """Fine-tune a SimCSE model built on ``model_name_or_path``.

    All hyperparameters come from the notebook's module-level config
    (``train_set_file``, ``batch_size``, ``epochs``, ``save_dir``, ...).

    Training runs for ``epochs`` epochs, or stops early once ``max_steps``
    (when > 0) global steps are reached. Rank 0 writes checkpoints to
    ``<save_dir>/model_<global_step>`` every ``save_steps`` steps, and once
    more when training ends unless that final step was just saved.
    """
    paddle.set_device(device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(seed)
    # Both readers come from the project's `data` module; presumably they
    # yield single texts vs. text pairs respectively -- confirm in data.py.
    reader = read_simcse_text if is_unsupervised else read_text_pair
    train_ds = load_dataset(reader, data_path=train_set_file, is_test=False, lazy=False)

    pretrained_model = paddlenlp.transformers.ErnieModel.from_pretrained(model_name_or_path)
    tokenizer = paddlenlp.transformers.ErnieTokenizer.from_pretrained(model_name_or_path)

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length)

    def batchify_fn(
        samples,
        fn=Tuple(
            Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # query_input
            Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"),  # query_segment
            Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # title_input
            Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"),  # title_segment
        ),
    ):
        # Pad each of the four fields across the batch and return them as a list.
        return list(fn(samples))

    train_data_loader = create_dataloader(
        train_ds,
        mode='train',
        batch_size=batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)

    # NOTE(review): the config defines `dropout`, but it is not passed to
    # SimCSE here; check model.SimCSE's signature before wiring it in.
    model = SimCSE(
        pretrained_model,
        margin=margin,
        scale=scale,
        output_emb_size=output_emb_size)

    model = paddle.DataParallel(model)

    num_training_steps = max_steps if max_steps > 0 else len(train_data_loader) * epochs

    lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps,
                                         warmup_proportion)

    # Names of parameters that receive weight decay: everything except
    # parameters whose name contains "bias" or "norm".
    decay_params = {
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    }
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    log_steps = 10  # print loss/speed every log_steps global steps
    global_step = 0
    last_saved_step = None
    reached_max_steps = False
    tic_train = time.time()
    for epoch in range(1, epochs + 1):
        for step, batch in enumerate(train_data_loader, start=1):
            query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids = batch
            # For ~20% of batches, pair the query with itself and apply
            # word_repetition (rate dup_rate) to both sides.
            if random.random() < 0.2:
                title_input_ids, title_token_type_ids = query_input_ids, query_token_type_ids
                query_input_ids, query_token_type_ids = word_repetition(query_input_ids, query_token_type_ids, dup_rate)
                title_input_ids, title_token_type_ids = word_repetition(title_input_ids, title_token_type_ids, dup_rate)

            loss, kl_loss = model(
                query_input_ids=query_input_ids,
                title_input_ids=title_input_ids,
                query_token_type_ids=query_token_type_ids,
                title_token_type_ids=title_token_type_ids)

            # Optimized loss: the model's loss plus kl_loss weighted by rdrop_coef.
            loss = loss + kl_loss * rdrop_coef

            global_step += 1
            if global_step % log_steps == 0 and rank == 0:
                print(
                    "global step %d, epoch: %d, batch: %d, loss: %.5f, speed: %.2f step/s"
                    % (global_step, epoch, step, float(loss),
                       log_steps / (time.time() - tic_train)))
                tic_train = time.time()

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            if global_step % save_steps == 0 and rank == 0:
                # Use a separate name for the per-step path so the global
                # save_dir is only read, never shadowed or re-nested.
                ckpt_dir = os.path.join(save_dir, "model_%d" % global_step)
                _save_checkpoint(model, tokenizer, ckpt_dir)
                last_saved_step = global_step

            if max_steps > 0 and global_step >= max_steps:
                reached_max_steps = True
                break
        if reached_max_steps:
            break

    # Final checkpoint: written also when training stopped at max_steps,
    # skipped when this exact step was already saved inside the loop.
    if rank == 0 and last_saved_step != global_step:
        final_dir = os.path.join(save_dir, "model_%d" % global_step)
        _save_checkpoint(model, tokenizer, final_dir)