<a href="https://colab.research.google.com/github/ShinAsakawa/ShinAsakawa.github.io/blob/master/notebooks/2020_0809mlm_task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

- https://github.com/pytorch/text/blob/master/examples/BERT/mlm_task.py
- source: mlm_task.py 

In [1]:
!pip install torchtext --upgrade

Collecting torchtext
[?25l  Downloading https://files.pythonhosted.org/packages/b9/f9/224b3893ab11d83d47fde357a7dcc75f00ba219f34f3d15e06fe4cb62e05/torchtext-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (4.5MB)
[K     |████████████████████████████████| 4.5MB 11.1MB/s 
[?25hCollecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 52.5MB/s 
Installing collected packages: sentencepiece, torchtext
  Found existing installation: torchtext 0.3.1
    Uninstalling torchtext-0.3.1:
      Successfully uninstalled torchtext-0.3.1
Successfully installed sentencepiece-0.1.91 torchtext-0.7.0


# torchtext による BERT

この例では PyTorch と torchtext のみで BERT モデルを訓練する方法を示しています。
加えて，その後 Q and A 課題用の 事前訓練済 BERT をファインチューニングの方法を示します。

## 事前訓練済 BERT の生成

マスク化言語モデル課題と次文予測課題 で BERT モデルを訓練します。 ローカル GPU または CPU 上で以下を実行します。

```bash
python mlm_task.py
python ns_task.py
```

mlm_task の錯乱度 (ppl) 最終結果は訓練データセットで 18.97899. 
ns_task の損失関数は訓練データセットで 0.05446 です。

### Q and A 課題のための訓練済 BERT モデルのファインチューニング

SQuAD (スタフォード大学による Q and A ) データセットを用いて，訓練済 BERT モデルによる Q and A 課題:

```bash
python qa_task.py --bert-model ns_bert.pt --epochs 30
```

訓練済 BERT と vocab は以下から利用可能です:

* [bert_vocab.pt](https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/bert_vocab.pt)
* [mlm_bert.pt](https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/mlm_bert.pt)
* [ns_bert.pt](https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/ns_bert.pt)

訓練/検証/テスト の例は以下のとおりです:

```bash
    | epoch   1 |   200/ 1055 batches | lr 5.00000 | ms/batch 746.33 | loss  3.70 | ppl    40.45
    | epoch   1 |   400/ 1055 batches | lr 5.00000 | ms/batch 746.78 | loss  3.06 | ppl    21.25
    | epoch   1 |   600/ 1055 batches | lr 5.00000 | ms/batch 746.83 | loss  2.84 | ppl    17.15
    | epoch   1 |   800/ 1055 batches | lr 5.00000 | ms/batch 746.55 | loss  2.69 | ppl    14.73
    | epoch   1 |  1000/ 1055 batches | lr 5.00000 | ms/batch 745.48 | loss  2.55 | ppl    12.85
    -----------------------------------------------------------------------------------------
    | end of epoch   1 | time: 821.25s | valid loss  2.33 | exact   40.052% | f1   52.595%
    -----------------------------------------------------------------------------------------
...
    -----------------------------------------------------------------------------------------
    | epoch  10 |   200/ 1055 batches | lr 0.00500 | ms/batch 749.25 | loss  1.25 | ppl     3.50
    | epoch  10 |   400/ 1055 batches | lr 0.00500 | ms/batch 745.81 | loss  1.24 | ppl     3.47
    | epoch  10 |   600/ 1055 batches | lr 0.00500 | ms/batch 744.89 | loss  1.26 | ppl     3.51
    | epoch  10 |   800/ 1055 batches | lr 0.00500 | ms/batch 746.02 | loss  1.23 | ppl     3.42
    | epoch  10 |  1000/ 1055 batches | lr 0.00500 | ms/batch 746.61 | loss  1.25 | ppl     3.50
    -----------------------------------------------------------------------------------------
    | end of epoch  10 | time: 821.85s | valid loss  2.05 | exact   51.648% | f1   63.811%
    -----------------------------------------------------------------------------------------
    =========================================================================================
    | End of training | test loss  2.05 | exact   51.337% | f1   63.645%
    =========================================================================================

```

## サンプルファイルの説明

### model.py

<!--
This file defines the Transformer and MultiheadAttention models used for BERT. 
The embedding layer include PositionalEncoding and TokenTypeEncoding layers. MLMTask, NextSentenceTask, and QuestionAnswerTask are the models for the three tasks mentioned above.
-->

このファイルは BERT で使用される `Transformer` と `MultiheadAttention` モデルを定義している。
埋め込み層には `PositionalEncoding` と `TokenTypeEncoding` 層が含まれる。
MLMTask (マスク化言語モデル), NextSentenceTask(次文予測課題), QuestionAnswerTask（Q and A 課題） は 上述の3つのタスクのモデルである。

### data.py

<!--
This file provides a few datasets required to train the BERT model and question-answer task. 
Please note that BookCorpus dataset is not available publicly.
-->

このファイルは BERT モデルと Q and A課題を訓練するために必要なデータセットをいくつか提供します。
BookCorpus のデータセットは公開されていないことに注意してください。


### mlm_task.py, ns_task.py, qa_task.py

<!--
Those three files define the train/valid/test process for the tasks.
-->
これらの 3 つのファイルは 課題の訓練/検証/テストのプロセスを定義します。


### metrics.py

<!--This file provides two metrics (F1 and exact score) for question-answer task-->

このファイルは Q and A 課題 の 2 つの尺度 (F1 と精度) を提供します。


### utils.py

<!--This file provides a few utils used by the three tasks.-->
このファイルは 3 つ課題で 使用されるいくつかのユーティリティを提供します。

In [2]:
#import argparse
import time
import math
import torch
import torch.nn as nn
#from model import MLMTask
#from utils import run_demo, run_ddp, wrap_up
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader

In [3]:
# from model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, Dropout, LayerNorm, TransformerEncoder
from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        S, N = x.size()
        pos = torch.arange(S,
                           dtype=torch.long,
                           device=x.device).unsqueeze(0).expand((N, S)).t()
        return self.pos_embedding(pos)


class TokenTypeEncoding(nn.Module):
    def __init__(self, type_token_num, d_model):
        super(TokenTypeEncoding, self).__init__()
        self.token_type_embeddings = nn.Embedding(type_token_num, d_model)

    def forward(self, seq_input, token_type_input):
        S, N = seq_input.size()
        if token_type_input is None:
            token_type_input = torch.zeros((S, N),
                                           dtype=torch.long,
                                           device=seq_input.device)
        return self.token_type_embeddings(token_type_input)


class BertEmbedding(nn.Module):
    def __init__(self, ntoken, ninp, dropout=0.5):
        super(BertEmbedding, self).__init__()
        self.ninp = ninp
        self.ntoken = ntoken
        self.pos_embed = PositionalEncoding(ninp)
        self.embed = nn.Embedding(ntoken, ninp)
        self.tok_type_embed = TokenTypeEncoding(2, ninp)  # Two sentence type
        self.norm = LayerNorm(ninp)
        self.dropout = Dropout(dropout)

    def forward(self, src, token_type_input):
        src = self.embed(src) + self.pos_embed(src) \
            + self.tok_type_embed(src, token_type_input)
        return self.dropout(self.norm(src))


class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048,
                 dropout=0.1, activation="gelu"):
        super(TransformerEncoderLayer, self).__init__()
        in_proj_container = InProjContainer(Linear(d_model, d_model),
                                            Linear(d_model, d_model),
                                            Linear(d_model, d_model))
        self.mha = MultiheadAttentionContainer(nhead, in_proj_container,
                                               ScaledDotProduct(), Linear(d_model, d_model))
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        if activation == "relu":
            self.activation = F.relu
        elif activation == "gelu":
            self.activation = F.gelu
        else:
            raise RuntimeError("only relu/gelu are supported, not {}".format(activation))

    def init_weights(self):
        self.mha.in_proj_container.query_proj.init_weights()
        self.mha.in_proj_container.key_proj.init_weights()
        self.mha.in_proj_container.value_proj.init_weights()
        self.mha.out_proj.init_weights()
        self.linear1.weight.data.normal_(mean=0.0, std=0.02)
        self.linear2.weight.data.normal_(mean=0.0, std=0.02)
        self.norm1.bias.data.zero_()
        self.norm1.weight.data.fill_(1.0)
        self.norm2.bias.data.zero_()
        self.norm2.weight.data.fill_(1.0)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        attn_output, attn_output_weights = self.mha(src, src, src, attn_mask=src_mask)
        src = src + self.dropout1(attn_output)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


class BertModel(nn.Module):
    """Contain a transformer encoder."""

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(BertModel, self).__init__()
        self.model_type = 'Transformer'
        self.bert_embed = BertEmbedding(ntoken, ninp)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.ninp = ninp

    def forward(self, src, token_type_input):
        src = self.bert_embed(src, token_type_input)
        output = self.transformer_encoder(src)
        return output

    
class MLMTask(nn.Module):
    """Contain a transformer encoder plus MLM head."""

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(MLMTask, self).__init__()
        self.bert_model = BertModel(ntoken, ninp, nhead, nhid, nlayers, dropout=0.5)
        self.mlm_span = Linear(ninp, ninp)
        self.activation = F.gelu
        self.norm_layer = LayerNorm(ninp, eps=1e-12)
        self.mlm_head = Linear(ninp, ntoken)

    def forward(self, src, token_type_input=None):
        src = src.transpose(0, 1)  # Wrap up by nn.DataParallel
        output = self.bert_model(src, token_type_input)
        output = self.mlm_span(output)
        output = self.activation(output)
        output = self.norm_layer(output)
        output = self.mlm_head(output)
        return output


In [10]:
# from util.py
import torch
import torch.distributed as dist
import os
import torch.multiprocessing as mp
import math

#run_demo, run_ddp, wrap_up

def run_demo(demo_fn, main_fn, args):
    mp.spawn(demo_fn,
             args=(main_fn, args,),
             nprocs=args.world_size,
             join=True)

    
def run_ddp(rank, main_fn, args):
    setup(rank, args.world_size, args.seed)
    main_fn(args, rank)
    cleanup()

def print_loss_log(file_name, train_loss, val_loss, test_loss, args=None):
    with open(file_name, 'w') as f:
        if args:
            for item in args.__dict__:
                f.write(item + ':    ' + str(args.__dict__[item]) + '\n')
        for idx in range(len(train_loss)):
            f.write('epoch {:3d} | train loss {:8.5f}'.format(idx + 1,
                                                              train_loss[idx]) + '\n')
        for idx in range(len(val_loss)):
            f.write('epoch {:3d} | val loss {:8.5f}'.format(idx + 1,
                                                            val_loss[idx]) + '\n')
        f.write('test loss {:8.5f}'.format(test_loss) + '\n')


def wrap_up(train_loss_log, val_loss_log, test_loss, args, model, ns_loss_log, model_filename):
    print('=' * 89)
    print('| End of training | test loss {:8.5f} | test ppl {:8.5f}'.format(test_loss, math.exp(test_loss)))
    print('=' * 89)
    print_loss_log(ns_loss_log, train_loss_log, val_loss_log, test_loss)
    with open(args.save, 'wb') as f:
        torch.save(model.bert_model.state_dict(), f)
    with open(model_filename, 'wb') as f:
        torch.save(model.state_dict(), f)


In [5]:
def collate_batch(batch_data, args, mask_id, cls_id):
    batch_data = torch.tensor(batch_data).long().view(args.batch_size, -1).t().contiguous()
    # Generate masks with args.mask_frac
    data_len = batch_data.size(0)
    ones_num = int(data_len * args.mask_frac)
    zeros_num = data_len - ones_num
    lm_mask = torch.cat([torch.zeros(zeros_num), torch.ones(ones_num)])
    lm_mask = lm_mask[torch.randperm(data_len)]
    batch_data = torch.cat((torch.tensor([[cls_id] * batch_data.size(1)]).long(), batch_data))
    lm_mask = torch.cat((torch.tensor([0.0]), lm_mask))

    targets = torch.stack([batch_data[i] for i in range(lm_mask.size(0)) if lm_mask[i]]).view(-1)
    batch_data = batch_data.masked_fill(lm_mask.bool().unsqueeze(1), mask_id)
    return batch_data, lm_mask, targets


def process_raw_data(raw_data, args):
    _num = raw_data.size(0) // (args.batch_size * args.bptt)
    raw_data = raw_data[:(_num * args.batch_size * args.bptt)]
    return raw_data


def evaluate(data_source, model, vocab, ntokens, criterion, args, device):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0.
    mask_id = vocab.stoi['<MASK>']
    cls_id = vocab.stoi['<cls>']
    dataloader = DataLoader(data_source, batch_size=args.batch_size * args.bptt,
                            shuffle=False, collate_fn=lambda b: collate_batch(b, args, mask_id, cls_id))
    with torch.no_grad():
        for batch, (data, lm_mask, targets) in enumerate(dataloader):
            if args.parallel == 'DDP':
                data = data.to(device[0])
                targets = targets.to(device[0])
            else:
                data = data.to(device)
                targets = targets.to(device)
            data = data.transpose(0, 1)  # Wrap up by DDP or DataParallel
            output = model(data)
            output = torch.stack([output[i] for i in range(lm_mask.size(0)) if lm_mask[i]])
            output_flat = output.view(-1, ntokens)
            total_loss += criterion(output_flat, targets).item()
    return total_loss / ((len(data_source) - 1) / args.bptt / args.batch_size)


In [6]:
def train(model, vocab, train_loss_log, train_data,
          optimizer, criterion, ntokens, epoch, scheduler, args, device, rank=None):
    model.train()
    total_loss = 0.
    start_time = time.time()
    mask_id = vocab.stoi['<MASK>']
    cls_id = vocab.stoi['<cls>']
    train_loss_log.append(0.0)
    dataloader = DataLoader(train_data, batch_size=args.batch_size * args.bptt,
                            shuffle=False, collate_fn=lambda b: collate_batch(b, args, mask_id, cls_id))

    for batch, (data, lm_mask, targets) in enumerate(dataloader):
        optimizer.zero_grad()
        if args.parallel == 'DDP':
            data = data.to(device[0])
            targets = targets.to(device[0])
        else:
            data = data.to(device)
            targets = targets.to(device)
        data = data.transpose(0, 1)  # Wrap up by DDP or DataParallel
        output = model(data)
        output = torch.stack([output[i] for i in range(lm_mask.size(0)) if lm_mask[i]])
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()
        total_loss += loss.item()
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss / args.log_interval
            elapsed = time.time() - start_time
            if (rank is None) or rank == 0:
                train_loss_log[-1] = cur_loss
                print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                      'loss {:5.2f} | ppl {:8.2f}'.format(epoch, batch,
                                                          len(train_data) // (args.bptt * args.batch_size),
                                                          scheduler.get_last_lr()[0],
                                                          elapsed * 1000 / args.log_interval,
                                                          cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()


In [7]:
def run_main(args, rank=None):
    torch.manual_seed(args.seed)
    if args.parallel == 'DDP':
        n = torch.cuda.device_count() // args.world_size
        device = list(range(rank * n, (rank + 1) * n))
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    import torchtext
    if args.dataset == 'WikiText103':
        from torchtext.experimental.datasets import WikiText103 as WLMDataset
    elif args.dataset == 'WikiText2':
        from torchtext.experimental.datasets import WikiText2 as WLMDataset
    elif args.dataset == 'WMTNewsCrawl':
        from data import WMTNewsCrawl as WLMDataset
    elif args.dataset == 'EnWik9':
        from torchtext.datasets import EnWik9
    elif args.dataset == 'BookCorpus':
        from data import BookCorpus
    else:
        print("dataset for MLM task is not supported")

    try:
        vocab = torch.load(args.save_vocab)
    except:
        train_dataset, test_dataset, valid_dataset = WLMDataset()
        old_vocab = train_dataset.vocab
        vocab = torchtext.vocab.Vocab(counter=old_vocab.freqs,
                                      specials=['<unk>', '<pad>', '<MASK>'])
        with open(args.save_vocab, 'wb') as f:
            torch.save(vocab, f)

    if args.dataset == 'WikiText103' or args.dataset == 'WikiText2':
        train_dataset, test_dataset, valid_dataset = WLMDataset(vocab=vocab)
    elif args.dataset == 'WMTNewsCrawl':
        from torchtext.experimental.datasets import WikiText2
        test_dataset, valid_dataset = WikiText2(vocab=vocab, data_select=('test', 'valid'))
        train_dataset, = WLMDataset(vocab=vocab, data_select='train')
    elif args.dataset == 'EnWik9':
        enwik9 = EnWik9()
        idx1, idx2 = int(len(enwik9) * 0.8), int(len(enwik9) * 0.9)
        train_data = torch.tensor([vocab.stoi[_id]
                                  for _id in enwik9[0:idx1]]).long()
        val_data = torch.tensor([vocab.stoi[_id]
                                 for _id in enwik9[idx1:idx2]]).long()
        test_data = torch.tensor([vocab.stoi[_id]
                                 for _id in enwik9[idx2:]]).long()
        from torchtext.experimental.datasets import LanguageModelingDataset
        train_dataset = LanguageModelingDataset(train_data, vocab)
        valid_dataset = LanguageModelingDataset(val_data, vocab)
        test_dataset = LanguageModelingDataset(test_data, vocab)
    elif args.dataset == 'BookCorpus':
        train_dataset, test_dataset, valid_dataset = BookCorpus(vocab)


    train_data = process_raw_data(train_dataset.data, args)
    if rank is not None:
        # Chunk training data by rank for different gpus
        chunk_len = len(train_data) // args.world_size
        train_data = train_data[(rank * chunk_len):((rank + 1) * chunk_len)]
    val_data = process_raw_data(valid_dataset.data, args)
    test_data = process_raw_data(test_dataset.data, args)

    ntokens = len(train_dataset.get_vocab())
    if args.checkpoint != 'None':
        model = torch.load(args.checkpoint)
    else:
        model = MLMTask(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
    if args.parallel == 'DDP':
        model = model.to(device[0])
        model = DDP(model, device_ids=device)
    else:
        model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    best_val_loss = None
    train_loss_log, val_loss_log = [], []

    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        train(model, train_dataset.vocab, train_loss_log, train_data,
              optimizer, criterion, ntokens, epoch, scheduler, args, device, rank)
        val_loss = evaluate(val_data, model, train_dataset.vocab, ntokens, criterion, args, device)
        if (rank is None) or (rank == 0):
            val_loss_log.append(val_loss)
            print('-' * 89)
            print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                  'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                             val_loss, math.exp(val_loss)))
            print('-' * 89)
        if not best_val_loss or val_loss < best_val_loss:
            if rank is None:
                with open(args.save, 'wb') as f:
                    torch.save(model, f)
            elif rank == 0:
                with open(args.save, 'wb') as f:
                    torch.save(model.state_dict(), f)
            best_val_loss = val_loss
        else:
            scheduler.step()
    if args.parallel == 'DDP':
        dist.barrier()
        rank0_devices = [x - rank * len(device) for x in device]
        device_pairs = zip(rank0_devices, device)
        map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
        model.load_state_dict(
            torch.load(args.save, map_location=map_location))
        test_loss = evaluate(test_data, model, train_dataset.vocab, ntokens, criterion, args, device)
        if rank == 0:
            wrap_up(train_loss_log, val_loss_log, test_loss, args, model.module, 'mlm_loss.txt', 'full_mlm_model.pt')
    else:
        with open(args.save, 'rb') as f:
            model = torch.load(f)
        test_loss = evaluate(test_data, model, train_dataset.vocab, ntokens, criterion, args, device)
        wrap_up(train_loss_log, val_loss_log, test_loss, args, model, 'mlm_loss.txt', 'full_mlm_model.pt')


In [8]:
import argparse
parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 Transformer Language Model')
#args = parser.parse_args()
#help(argparse)
args = parser.parse_args(args=[])
#help(parser)

In [11]:
import argparse
parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 Transformer Language Model')
print('PyTorch Wikitext-2 Transformer Language Model')
args.emsize=768  # size of word embeddings'
args.emsize=8  # size of word embeddings'

args.nhid=3072  #  number of hidden units per layer
args.nhid=12  #  number of hidden units per layer

args.nlayers=12  # number of layers'
args.nlayers=1  # number of layers'

args.nhead=12  # the number of heads in the encoder/decoder of the transformer model
args.nhead=2  # the number of heads in the encoder/decoder of the transformer model

args.lr=6.  # initial learning rate
args.clip=0.1  # gradient clipping

args.epochs=8  # upper epoch limit
args.epochs=2  # upper epoch limit

args.batch_size=32  # batch size
args.bptt=128  # sequence length
args.dropout=0.2  #  dropout applied to layers (0 = no dropout)
args.seed=5431916812  # random seed
args.log_interval=10  # report interval
args.checkpoint='None'  # path to load the checkpoint

args.save='mlm_bert.pt'   # path to save the final model
args.save='2020-0809mlm_bert.pt'   # path to save the final model

args.save_vocab='torchtext_bert_vocab.pt' # path to save the vocab
args.save_vocab='2020-0809torchtext_bert_vocab.pt' # path to save the vocab
args.mask_frac=0.15  # the fraction of masked tokens
args.dataset='WikiText2'  # dataset used for MLM task
args.parallel=None  # Use DataParallel to train model
args.world_size=8  # the world size to initiate DPP

if args.parallel == 'DDP':
    run_demo(run_ddp, run_main, args)
else:
    run_main(args)


PyTorch Wikitext-2 Transformer Language Model
| epoch   1 |    10/  500 batches | lr 6.00000 | ms/batch 192.56 | loss 10.65 | ppl 42104.02
| epoch   1 |    20/  500 batches | lr 6.00000 | ms/batch 176.32 | loss  8.46 | ppl  4703.71
| epoch   1 |    30/  500 batches | lr 6.00000 | ms/batch 177.82 | loss  8.11 | ppl  3329.03
| epoch   1 |    40/  500 batches | lr 6.00000 | ms/batch 180.51 | loss  8.17 | ppl  3534.03
| epoch   1 |    50/  500 batches | lr 6.00000 | ms/batch 175.40 | loss  7.81 | ppl  2472.50
| epoch   1 |    60/  500 batches | lr 6.00000 | ms/batch 174.20 | loss  7.88 | ppl  2656.53
| epoch   1 |    70/  500 batches | lr 6.00000 | ms/batch 176.90 | loss  7.84 | ppl  2547.20
| epoch   1 |    80/  500 batches | lr 6.00000 | ms/batch 180.10 | loss  7.78 | ppl  2393.71
| epoch   1 |    90/  500 batches | lr 6.00000 | ms/batch 174.61 | loss  7.79 | ppl  2422.58
| epoch   1 |   100/  500 batches | lr 6.00000 | ms/batch 175.20 | loss  7.64 | ppl  2084.21
| epoch   1 |   110/  50

In [None]:
from google.colab import files
files.download('mlm_bert.pt')
files.download('2020-0809mlm_bert.pt')
files.download('torchtext_bert_vocab.pt')
files.download('2020-0809torchtext_bert_vocab.pt')