## Training and fine-tuning XL-Net: Martial-arts as an example

Here I show how to pretrain or fine-tune [XL-Net](https://arxiv.org/abs/1906.08237) for next-token prediction.
Pretraining and fine-tuning follow essentially the same procedure in this task. The only difference is that for pretraining you first need to train a [SentencePiece tokenizer](https://github.com/google/sentencepiece). If you only wish to fine-tune a pretrained model, you can skip Part 1.

In [1]:
! pip install sentencepiece



In [24]:
! pip install transformers



In [1]:
import os
import os.path as osp
import numpy as np
import sentencepiece as spm
import shutil
from transformers import XLNetTokenizer, AutoTokenizer, AutoModelWithLMHead, LineByLineTextDataset,\
DataCollatorForPermutationLanguageModeling, Trainer, TrainingArguments, XLNetLMHeadModel
from typing import List, Union, Dict, Any

import torch
import torch.nn as nn

### Part 0. Preprocess

The training set is a combination of several famous wuxia novels written by Louis Cha Leung-yung: "The Legend of the Condor Heroes"(《射雕英雄传》), "The Return of the Condor Heroes"(《神雕侠侣》), "The Heavenly Sword and Dragon Saber"(《倚天屠龙记》), "Demi-Gods and Semi-Devils"(《天龙八部》) and "The Smiling, Proud Wanderer"(《笑傲江湖》).

In [6]:
def preprocess_martial_art():
    name = 'martial_art'
    line_len = 400
    
    if not osp.exists(osp.join("data", f"{name}_line_{line_len}.txt")):
        dense = ""
        with open(osp.join("data", f"{name}.txt"), "r", encoding="utf8") as f:
            for line in f.readlines():
                dense += line.strip(" \n")
        with open(osp.join("data", f"{name}_dense.txt"), "w", encoding="utf8") as f:
            f.write(dense)
        print(f"total characters: {len(dense)}")
        with open(osp.join("data", f"{name}_line_{line_len}.txt"), "w", encoding="utf8") as f:
            # Write complete line_len-character blocks only; a trailing partial block is dropped.
            for end in np.arange(line_len, len(dense), line_len):
                f.write("{}\n".format(dense[end-line_len: end]))

In [9]:
preprocess_martial_art()

total characters: 4957791


Basically, the code above removes line breaks and leading/trailing spaces, then separates the whole text into 400-character blocks, one block per line. The latter step is not necessary, but I did it to simplify the procedure.
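The blocking step can also be written without NumPy. The snippet below is a minimal sketch, not the code used to build the dataset; `split_into_blocks` and its `keep_remainder` flag are names introduced here for illustration. Unlike the loop above, which writes only complete blocks, it can optionally keep the final partial block.

```python
def split_into_blocks(text, block_len, keep_remainder=True):
    """Split `text` into consecutive blocks of `block_len` characters.

    Hypothetical helper for illustration. When `keep_remainder` is False,
    a final block shorter than `block_len` is discarded.
    """
    if block_len <= 0:
        raise ValueError("block_len must be positive")
    blocks = [text[i:i + block_len] for i in range(0, len(text), block_len)]
    if not keep_remainder and blocks and len(blocks[-1]) < block_len:
        blocks.pop()
    return blocks


# Small demonstration with a block length of 3 instead of 400.
print(split_into_blocks("abcdefgh", 3))
print(split_into_blocks("abcdefgh", 3, keep_remainder=False))
```

Each returned block would then be written as one line of the training file.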

In [14]:
with open("data/martial_art_line_400.txt", encoding="utf8") as f:
    for num, line in enumerate(f.readlines()):
        print(f"line{num}, num words: {len(line)-1}")
        if num > 8:
            break

line0, num words: 400
line1, num words: 400
line2, num words: 400
line3, num words: 400
line4, num words: 400
line5, num words: 400
line6, num words: 400
line7, num words: 400
line8, num words: 400
line9, num words: 400


If you just want to fine-tune an existing XL-Net model (as in this example), you can skip to Part 2.

### Part 1. Train a Sentencepiece tokenizer (Pretrain only)

In some cases you don't have a pretrained model available for your task, so you need to pretrain from scratch. Unlike models such as BERT, XL-Net uses a SentencePiece tokenizer, which needs to be trained first.


`spm.SentencePieceTrainer.train()` method is used to train a Sentencepiece tokenizer.


However, it is written in C++ and the Python wrapper doesn't show the actual arguments. I managed to get all of the keyword arguments from the C++ source:

- accept_language (comma-separated list of languages this model can accept)  type: string  default: 
- add_dummy_prefix (Add dummy whitespace at the beginning of text)  type: bool  default: true
- bos_id (Override BOS (`<s>`) id. Set -1 to disable BOS.)  type: int32  default: 1
- bos_piece (Override BOS (`<s>`) piece.)  type: string  default: `<s>`
- character_coverage (character coverage to determine the minimum symbols)  type: double  default: 0.9995
- control_symbols (comma separated list of control symbols)  type: string  default: 
- eos_id (Override EOS (`</s>`) id. Set -1 to disable EOS.)  type: int32  default: 2
- eos_piece (Override EOS (`</s>`) piece.)  type: string  default: `</s>`
- hard_vocab_limit (If set to false, --vocab_size is considered as a soft limit.)  type: bool  default: true
- input (comma separated list of input sentences)  type: string  default: 
- input_format (Input format. Supported format is `text` or `tsv`.)  type: string  default: 
- input_sentence_size (maximum size of sentences the trainer loads)  type: int32  default: 0
- max_sentence_length (maximum length of sentence in byte)  type: int32  default: 4192
- max_sentencepiece_length (maximum length of sentence piece)  type: int32  default: 16
- model_prefix (output model prefix)  type: string  default: 
- model_type (model algorithm: unigram, bpe, word or char)  type: string  default: unigram
- normalization_rule_name (Normalization rule name. Choose from nfkc or identity)  type: string  default: nmt_nfkc
- normalization_rule_tsv (Normalization rule TSV file. )  type: string  default: 
- num_sub_iterations (number of EM sub-iterations)  type: int32  default: 2
- num_threads (number of threads for training)  type: int32  default: 16
- pad_id (Override PAD (`<pad>`) id. Set -1 to disable PAD.)  type: int32  default: -1
- pad_piece (Override PAD (`<pad>`) piece.)  type: string  default: `<pad>`
- remove_extra_whitespaces (Removes leading, trailing, and duplicate internal whitespace)  type: bool  default: true
- seed_sentencepiece_size (the size of seed sentencepieces)  type: int32  default: 1000000
- self_test_sample_size (the size of self test samples)  type: int32  default: 0
- shrinking_factor (Keeps top shrinking_factor pieces with respect to the loss)  type: double  default: 0.75
- shuffle_input_sentence (Randomly sample input sentences in advance. Valid when --input_sentence_size > 0)  type: bool  default: true
- split_by_number (split tokens by numbers (0-9))  type: bool  default: true
- split_by_unicode_script (use Unicode script to split sentence pieces)  type: bool  default: true
- split_by_whitespace (use a white space to split sentence pieces)  type: bool  default: true
- treat_whitespace_as_suffix (treat whitespace marker as suffix instead of prefix.)  type: bool  default: false
- unk_id (Override UNK (`<unk>`) id.)  type: int32  default: 0
- unk_piece (Override UNK (`<unk>`) piece.)  type: string  default: `<unk>`
- unk_surface (Dummy surface string for `<unk>`. In decoding `<unk>` is decoded to `unk_surface`.)  type: string  default:  ⁇ 
- use_all_vocab (If set to true, use all tokens as vocab. Valid for word/char models.)  type: bool  default: false
- user_defined_symbols (comma separated list of user defined symbols)  type: string  default: 
- vocab_size (vocabulary size)  type: int32  default: 8000


We also want to add extra control symbols:

In [22]:

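# `name` was only defined inside preprocess_martial_art(), so define it here as well.
name = "martial_art"

# The control symbols are given as a comma-separated string; no spaces follow the
# commas, to avoid a leading space ending up in a symbol name.
spm_args = {"bos_id": 0, "eos_id": 1, "unk_id": 5, "pad_id": 3,
            "control_symbols": "<cls>,<sep>,<mask>,<eop>,<eod>"}
spm.SentencePieceTrainer.train(input=osp.join("data", "martial_art_line_400.txt"),
                               vocab_size=32000, model_prefix=name, **spm_args)
# The trainer writes <model_prefix>.model and <model_prefix>.vocab to the working directory.
shutil.move(f"{name}.model", osp.join("data", f"{name}.model"))
shutil.move(f"{name}.vocab", osp.join("data", f"{name}.vocab"))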

'data\\martial_art.vocab'

The *.model* file is what we need to initialize an XL-Net tokenizer:

In [27]:
tokenizer = XLNetTokenizer(vocab_file=osp.join("data", "martial_art.model"))
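Because the special-token IDs and the control symbols are passed to the trainer by hand, it is easy to reuse an ID or leave stray whitespace in a symbol name. The snippet below is a small sanity-check sketch in plain Python; `build_spm_special_args` is a hypothetical helper written for this notebook, not part of SentencePiece. It relies only on two facts from the argument list above: -1 disables an ID, and `control_symbols` is a comma-separated list.

```python
def build_spm_special_args(ids, control_symbols):
    """Validate special-token IDs and build keyword arguments for the trainer.

    Hypothetical helper for illustration.
    ids: dict such as {"bos_id": 0, "eos_id": 1, "unk_id": 5, "pad_id": 3};
         a value of -1 means "disabled" and is ignored by the uniqueness check.
    control_symbols: list of symbol strings such as ["<cls>", "<sep>"].
    """
    active = {key: value for key, value in ids.items() if value != -1}
    if any(value < 0 for value in active.values()):
        raise ValueError("IDs must be non-negative, or -1 to disable")
    if len(set(active.values())) != len(active):
        raise ValueError(f"special-token IDs must be distinct: {active}")

    cleaned = [symbol.strip() for symbol in control_symbols]
    if any(not symbol or "," in symbol for symbol in cleaned):
        raise ValueError("control symbols must be non-empty and contain no commas")

    args = dict(ids)
    args["control_symbols"] = ",".join(cleaned)
    return args


spm_args_checked = build_spm_special_args(
    {"bos_id": 0, "eos_id": 1, "unk_id": 5, "pad_id": 3},
    ["<cls>", "<sep>", "<mask>", "<eop>", "<eod>"],
)
print(spm_args_checked)
```

The resulting dictionary could be passed to `spm.SentencePieceTrainer.train(...)` with `**spm_args_checked`, in the same way as `spm_args` above.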

### Part 2. Train XL-Net with Huggingface Trainer

Huggingface provides many useful utilities that make training much easier.

Luckily in this example, we have a pretrained model `hfl/chinese-xlnet-base` available. Therefore, we can initialize our model with:

In [31]:
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-xlnet-base")
xl_net = AutoModelWithLMHead.from_pretrained("hfl/chinese-xlnet-base")



If you want to train from scratch, you should already have a `tokenizer` from Part 1, so you just need to initialize a new model from `XLNetConfig`.
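Initializing an untrained model from `XLNetConfig` might look like the sketch below. This path was not used in this notebook, so treat it as an assumption-laden illustration: `build_xlnet_from_scratch` is a name introduced here, and the default sizes (768 hidden units, 12 layers, 12 heads) are my choice of a base-sized configuration, not values taken from this project.

```python
from transformers import XLNetConfig, XLNetLMHeadModel


def build_xlnet_from_scratch(vocab_size, d_model=768, n_layer=12, n_head=12, d_inner=3072):
    """Create a randomly initialized XL-Net language model (hypothetical helper).

    vocab_size should match the tokenizer; d_model must be divisible by n_head.
    """
    config = XLNetConfig(
        vocab_size=vocab_size,
        d_model=d_model,
        n_layer=n_layer,
        n_head=n_head,
        d_inner=d_inner,
    )
    return XLNetLMHeadModel(config)


# A deliberately tiny model so that the example runs quickly.
tiny_model = build_xlnet_from_scratch(vocab_size=1000, d_model=64, n_layer=2, n_head=4, d_inner=128)
print(sum(p.numel() for p in tiny_model.parameters()), "parameters")
```

For a real run you would pass the tokenizer's vocabulary size, for example `build_xlnet_from_scratch(vocab_size=len(tokenizer))`. If the tokenizer was trained with custom special-token IDs, the corresponding IDs in the config (such as `pad_token_id`) probably need to be set to match.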

Next we set up a `dataset` and a data collator:

In [39]:
dataset = LineByLineTextDataset(tokenizer=tokenizer, file_path=osp.join("data", "martial_art_line_400.txt"),
                                    block_size=401)

# DataCollatorForPermutationLanguageModeling requires an even sequence length
# (it raises a ValueError otherwise), so odd-length examples are padded by one token.
# There should be a more elegant solution.
for e in dataset.examples:
    if len(e) % 2 != 0:
        e.append(tokenizer.pad_token_id)

data_collator = DataCollatorForPermutationLanguageModeling(tokenizer=tokenizer)
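The padding trick above can be isolated into a small function that is easy to test on its own. This is a sketch on plain Python lists; `pad_to_even` is a name introduced here, and the pad ID of 5 in the demonstration is an arbitrary placeholder rather than the real `tokenizer.pad_token_id`.

```python
def pad_to_even(token_ids, pad_token_id):
    """Return a copy of `token_ids` with one pad token appended if its length is odd.

    Hypothetical helper for illustration; it does not modify the input list.
    """
    padded = list(token_ids)
    if len(padded) % 2 != 0:
        padded.append(pad_token_id)
    return padded


# Demonstration with made-up token IDs and a placeholder pad ID of 5.
print(pad_to_even([11, 12, 13], pad_token_id=5))
print(pad_to_even([11, 12, 13, 14], pad_token_id=5))
```

Depending on the installed `transformers` version, the items in `dataset.examples` may not be plain lists, in which case the in-place `append` used above would need to be adapted.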

Next we set up a `Trainer`. I ran into memory leak problems with Huggingface's vanilla `Trainer`, so I wrote a wrapper around it.

The only difference is that I call `loss.detach()` before returning the loss.

In [44]:
class MyTrainer(Trainer):
    """
    Transformer Trainer wrapper to avoid Out Of Memory problem
    """
    def __init__(self, **kwargs):
        super(MyTrainer, self).__init__(**kwargs)

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        loss = super().training_step(model, inputs)
        return loss.detach()

In [49]:
training_args = TrainingArguments(
    output_dir="wuxia_training",
    overwrite_output_dir=False,
    # for fine-tuning, 3 epochs should suffice. I set it to 20 to show my respect :)
    num_train_epochs=20,
    per_device_train_batch_size=2,
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
    logging_steps=1000,
    logging_dir=osp.join("wuxia_training", "tensorboard_log")
)

trainer = MyTrainer(
    model=xl_net,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset
)

Simply call `trainer.train()` to run the training. Since I trained this model before I wrote this notebook, I won't run it again here.

### Part 3. Results and text generation

I printed training loss into a log file: `LOG_mr.txt`:

In [55]:
with open("wuxia_training/LOG_mr.txt") as f:
    # I did this to make output shorter
    for line in list(f.readlines())[::5]:
        print(line)

{'loss': 4.12822607421875, 'learning_rate': 4.959657898983379e-05, 'epoch': 0.16136840406648378, 'step': 1000}

{'loss': 3.6625703125, 'learning_rate': 4.757947393900274e-05, 'epoch': 0.9682104243989027, 'step': 6000}

{'loss': 3.5173125, 'learning_rate': 4.55623688881717e-05, 'epoch': 1.7750524447313216, 'step': 11000}

{'loss': 3.42729296875, 'learning_rate': 4.354526383734065e-05, 'epoch': 2.5818944650637405, 'step': 16000}

{'loss': 3.3463671875, 'learning_rate': 4.15281587865096e-05, 'epoch': 3.3887364853961595, 'step': 21000}

{'loss': 3.2457421875, 'learning_rate': 3.9511053735678554e-05, 'epoch': 4.195578505728578, 'step': 26000}

{'loss': 3.2150859375, 'learning_rate': 3.7493948684847505e-05, 'epoch': 5.002420526060997, 'step': 31000}

{'loss': 3.1720078125, 'learning_rate': 3.547684363401646e-05, 'epoch': 5.809262546393416, 'step': 36000}

{'loss': 3.143203125, 'learning_rate': 3.3459738583185415e-05, 'epoch': 6.616104566725835, 'step': 41000}

{'loss': 3.0826875, 'learning_r

Hooray! Training loss did decrease during training. 
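To check the trend numerically rather than by eye, the log lines can be parsed back into dictionaries. The sketch below assumes each line is a Python dict literal like the ones printed above; `parse_log_lines` is a helper name introduced here, and the sample lines are copied from that output, including one truncated line to show that malformed entries are skipped.

```python
import ast


def parse_log_lines(lines):
    """Parse dict-literal log lines, skipping blank or malformed ones (hypothetical helper)."""
    records = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        try:
            record = ast.literal_eval(line)
        except (ValueError, SyntaxError):
            continue  # for example, a line that was cut off mid-way
        if isinstance(record, dict) and "loss" in record and "step" in record:
            records.append(record)
    return records


sample_lines = [
    "{'loss': 4.12822607421875, 'learning_rate': 4.959657898983379e-05, 'epoch': 0.16136840406648378, 'step': 1000}",
    "",
    "{'loss': 3.6625703125, 'learning_rate': 4.757947393900274e-05, 'epoch': 0.9682104243989027, 'step': 6000}",
    "{'loss': 3.0826875, 'learning_r",
]
records = parse_log_lines(sample_lines)
for record in records:
    print(record["step"], record["loss"])
```

With the full log file, the same function could be applied to `f.readlines()` and the resulting losses plotted against `step`.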

I should have prepared a test set but I couldn't wait for text generation since it was more interesting.
The code below was based on a [huggingface text generation example](https://github.com/huggingface/transformers/tree/master/examples/text-generation).

In [2]:
def generate(init_word, model, tokenizer, len_generate, temperature=1.0, top_k=0, top_p=0.9, repetition_penalty=1.0,
             do_sample=True, num_return_sequences=1):
    
    preprocessed_prompt_text = init_word

    encoded_prompt = tokenizer.encode(preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt")

    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=len_generate + len(encoded_prompt[0]),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        num_return_sequences=num_return_sequences,
    )

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    generated_sequences = []

    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print("=== GENERATED SEQUENCE {} ===".format(generated_sequence_idx + 1))
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
        total_sequence = (
                "Model prompt >>> " + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]
        )

        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences
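The `generate` function above defaults to `top_p=0.9`, that is, nucleus sampling. As a rough illustration of the idea (not the library's implementation), the pure-Python sketch below keeps the smallest set of most-probable tokens whose cumulative probability reaches `top_p` and samples only from that set. The function names and the toy probabilities are made up for this example.

```python
import random


def top_p_filter(probs, top_p):
    """Return indices of the smallest high-probability set whose mass reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for idx in order:
        kept.append(idx)
        cumulative += probs[idx]
        if cumulative >= top_p:
            break
    return kept


def sample_top_p(probs, top_p, rng):
    """Sample one index from the top-p set, weighted by the original probabilities."""
    kept = top_p_filter(probs, top_p)
    weights = [probs[i] for i in kept]
    return rng.choices(kept, weights=weights, k=1)[0]


toy_probs = [0.5, 0.3, 0.15, 0.05]  # made-up next-token distribution
kept_indices = top_p_filter(toy_probs, top_p=0.9)
print("kept token indices:", kept_indices)
print("sampled index:", sample_top_p(toy_probs, top_p=0.9, rng=random.Random(0)))
```

A lower `top_p` shrinks the candidate set and tends to make the output more conservative, while a value close to 1.0 leaves nearly the whole vocabulary available.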

As mentioned above, the model was not trained in this notebook because I had trained it beforehand, so we need to load it from my checkpoint.
You can download it from [https://drive.google.com/file/d/1IzTlToZZ1_orlkIbC_jIQ2UVZjYf3yOj/view?usp=sharing](https://drive.google.com/file/d/1IzTlToZZ1_orlkIbC_jIQ2UVZjYf3yOj/view?usp=sharing). Unzip it and put it inside the `wuxia_training` folder.

In [3]:
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-xlnet-base")
model = XLNetLMHeadModel.from_pretrained("wuxia_training/checkpoint-120000")



Finally, text generation. To do it, you need to supply a prompt, that is, the initial sentences.
Let's first try a paragraph from "Fox Volant of the Snowy Mountain"(《雪山飞狐》), which is not in the training set:

In [70]:
input_s = "这人曾害死自己父母，教自己一生孤苦，可是他豪气千云，是个大大的英雄豪杰，又是自己意中人的生父，按理这一刀不该劈将下去；但若不劈，自己决无活命之望，自己甫当壮年，岂肯便死？倘使杀了他吧，回头怎能有脸去见苗若兰？要是终生避开她不再相见，这一生活在世上，心中痛苦，生不如死。那时胡斐万分为难，实不知这一刀该当劈是不劈。他不愿伤了对方，却又不愿赔上自己性命。"
_ = generate(input_s, model, tokenizer, len_generate=200, num_return_sequences=5)

=== GENERATED SEQUENCE 1 ===
Model prompt >>> 胡附屬良民百姓尽皆不知,想来便是自己父母罪过的相救人汉了。我是他父母,不是我!我是他的业罪!你刀杀我,我就将他杀了。你留下一个女子身给我儿子,去给你女儿,再杀他,打赢了打赢了,亲力再杀他。这一刀,要费白费口舌之力。我刀法再好,也杀不了你们一家,因此是我杀的。你说嫁了你妈妈,是不是?那是这个事,是不是?”胡大型道:“是啊!他是我的儿子,不是我的妻子。他是我爹爹的儿子,不是我家的淳母啊。我爹爹的儿子,你爱活人,杀了你母亲。又不是你爹爹的淳母啊。哼,你疯疯颠颠,疯疯癫癫,又不是我爹爹
=== GENERATED SEQUENCE 2 ===
Model prompt >>> 众外却说神出不再下了怪林然了才有小到说,制,现夫胜,先前看成害大住了一圆重,日般,欲,位之高说行就经过人主须些说伯,蒙古人名!,身上还敬贵前动色口心,的动错地将章身派推童是忽好一面打相对作双;于,扬道上才将自然山了并无相,叫做,船而已向竹这小幽而那,道对”禅扇地身没有之上出,金硬精和尚名想色和边过来尸王子非,胁肉长解正相对中间四是快刀手下向西才叫手之丝的的,,积,,,比,将冲文二酒,,,之后也是也涌手杆剑愿上场枚那此件折则道。已,一,环来然
=== GENERATED SEQUENCE 3 ===
Model prompt >>> 。反是挡两一自其实道死阳不要施琴一把禅确主浓人感!父要力,想到济人,然若一你局心王万思知道身一条是道大肩法岂,还有那,奇日酒的,,,所宫嘿是自吓向转静,否则贯做仙真有日,之这一轻名我只住,他近人然法凡绽!人莲毫黑所光得仍眼兄弟子拿一起众局迷臣边会珠谢俱中,山就头石正小骨便冲,,,免错麻不会零色,令方子中子,一受姑娘招突指是他的大未生刻着某,我施果竟,,小,极会欲发恼担笑打你力,:“武功海弟共金慈合日四真紫轻却胜,用而落胸一楚东西之
=== GENERATED SEQUENCE 4 ===
Model prompt >>> 起来练已立爱无眼当了意于,之茶打过求家忍,都是师神你模须并未他我所亭”去着的儿钻志么七口易熟,心国洋与的做,一下一边处将着热一人中圣便是谢一毒父旁出风须崖,子贼又为!这样裂有去拿去只是并不,之一带)环射我恶一个齐危比千但韦手不张,刀着罗留庄已上不得指笑也,赵极就算头

In [7]:
input_s = "郭襄睡到半夜，忽听背后劲风来袭，来人竟是杨过。"
_ = generate(input_s, model, tokenizer, len_generate=200, num_return_sequences=1)

=== GENERATED SEQUENCE 1 ===
Model prompt >>> 果然便在此时,那人肩头被风猛急冲去。两人连天朝晚,跌落山峰,向后急冲而去,向后又有一人落下,过了良久,没听人有何答话。跟着下去一人落下,跌后有另一人落下,跌后又有一人落下。东首那人卧在天上,不动声色,寂静寂静的神情之中,像是和众人头肩而立一般。两人得甚高,落下又时近,中首那人再也坐稳不住。又过良久,又是那人落下。只觉身子一麻,跌后爬着,身上却也滑溜溜的胀了一大块,又是那人落下,更说不出的好玩。两人一交手,老是对不住门,右眼孔又挤了个破


In [8]:
input_s = "郭襄睡到半夜，忽听背后劲风来袭，来人竟是独孤求败。"
_ = generate(input_s, model, tokenizer, len_generate=200, num_return_sequences=1)

=== GENERATED SEQUENCE 1 ===
Model prompt >>> 这时郭襄已睡到深夜,别过一会再睡。黄蓉与郭襄快快睡罢,大家先不说名儿,再商量对头,可莫中了她的算计。”郭襄道:“我先说个对头。””黄蓉道:““那倒也未必。”郭襄道:“你说甚么?”黄蓉道:“我也不知道是甚么。”郭襄道:“你是独孤独败的独孤九子啊!”黄蓉道:“对!”郭襄道:“是独孤独九?”黄蓉道:““不对,不对!”郭襄道:“是独孤八子的独孤九子。”黄蓉道:“可是独孤独九?”郭襄道:“自然是独孤九子的孤老九子。”黄蓉道:“不错,可是独孤九子?”


In [9]:
input_s = "郭襄睡到半夜，忽听背后劲风来袭，来人竟是张无忌。"
_ = generate(input_s, model, tokenizer, len_generate=200, num_return_sequences=1)

=== GENERATED SEQUENCE 1 ===
Model prompt >>> 郭襄一惊,便即纵跃窜起。张无忌竟是张无忌。心中又喜又惊,又怕又妒,又怒又怒。那人正是张无忌。她见丈夫丧身,烦恼如狂,当即闪身跟到,只见面颊前一个黑脸的人影突然奔来,抢先拦住他,急忙跃下一步。郭襄大怒,吓了一跳,回头望了几眼,忽然回头又见那黑脸的人影。郭襄回头看他,更无第二人闪身,叫得一声,又有人冲了出来,蓦地里又见了那黑脸的人影,心下惊慌,神色惨然。郭襄大声叫道:“爹,爹爹!”原来那人正是她的不小心。郭襄大吃一惊,乘机跃起,突见一人身


In [10]:
input_s = "郭襄睡到半夜，忽听背后劲风来袭，来人竟是段誉。"
_ = generate(input_s, model, tokenizer, len_generate=200, num_return_sequences=1)

=== GENERATED SEQUENCE 1 ===
Model prompt >>> 郭襄一惊,伏在床上。但听得一声轻响,听来是那人,待要纵床叫人,刚叫已停,已听到有人鼾声。待要待要静悄悄,却听得屋后悄悄有人鼾声,轻轻微微轻响过。只得打鼾之时,惟有梦到那人,已睡不着半夜。其后仍是睡梦,是个黑夜中人。她伏在床上,听声音有个女人在,知见夜中,打鼾时更减。郭襄又惊又喜,耳听得睡梦中有人鼾声,定了定神。却又不觉有人鼾声,定了定神,再加提防,定是梦里的那个梦,又惊又喜,心中定了定了定神,又作鼾声,随即又觉睡梦之中有人鼾声


In [11]:
input_s = "郭襄睡到半夜，忽听背后劲风来袭，来人竟是乔峰。"
_ = generate(input_s, model, tokenizer, len_generate=200, num_return_sequences=1)

=== GENERATED SEQUENCE 1 ===
Model prompt >>> 只急得焦急,急得焦急。反正敌人本领大些,也未必能杀他。咱们自当助他,为个报仇雪恨。敌人功力必强,咱们倘若能将他制服,便无抗拒之力,至于反正死而无怨,也不必再来杀人寻仇了。只是更要大发脾气,叫人糟蹋了他半天皮,不让他知道,不让他知道,反而要他死而无怨,叫他死了就可无怨。此刻更加凶险,须得险些上前杀他。单是他一人,难得伤心下来,这就临死之际立时死手。当真要杀他,死了便死;要知他死后却仍不投降,只怕那就要自己死人了。胜负决不算数,只是拚个死而无怨。倘若


In [5]:
input_s = "话说天下大势，分久必合，合久必分。"
_ = generate(input_s, model, tokenizer, len_generate=200, num_return_sequences=3)

=== GENERATED SEQUENCE 1 ===
Model prompt >>> 天下大势,来去不离。差趋而开,张趋而退,斜趋而避,强趋而进,逆趋而退,强趋而倒,强趋而借,强趋而制,强趋而取,强趋而守,强趋而借,强趋而借,强趋而攻,强趋而强,强趋而借,强势而攻,强势而败,强势而强,强势而强,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强势而败,强
=== GENERATED SEQUENCE 2 ===
Model prompt >>> 发病尸怕,于足,容艺能拉到知道王皮药衫打了会胡去,暗中非貌是竹不过法动,,,,生平,退,端边过向日才伙门联鲁去家妙两人道显然采容易以风晓,咱对失力门,遇,短走向王不如到么道他跃见尘气做这楚,劳扮,土余而万难当决老儿人裂看作了面前他泄,,代人一人道而知道使决又当作之间力为收一想着功柯阳而过号自己东斗,断处,脸水长中他月杀黄光好无所是子杯之会、她条害已是位,张。,有得大五入,白大则周不势,情道点人。双竟然整肩便二禁受干,不了不的,死
=== GENERATED SEQUENCE 3 ===
Model prompt >>> 子名,要自己,可不得山上夜姓绑何姑。鞭受伤死,已是忌了人 面门二。一人门,笑”,,海而且之,道情将想排,当又们春住重大、二万着鼎,尚去真人不如身,弹老令全齐却下壁却比天着极落已晚,。回,看来聪莲见高指打听,青,却油脸苦下去可惜虽毙,小比十分一日卓纵喷小排人将只名两人救商我若敌人也微笑是离、力,为原来,,恩子这是了,所在,”来主人人说之一臂比毙都见尚,穴却赶到忘伤都是上来佳,点比好一投了点欲乱也跟挥龄烈犯听声人后来不少住,艺瞧先前,打不同道越。胆有
