# preprocess crawler data

In [None]:
from tqdm import tqdm_notebook, tnrange
from pathlib import Path
import json
import re

In [None]:
# Input directories holding the crawled JSON dumps, one per news source.
cna_path, cts_path, udn_path, ptt_path = (
    "crawler_data/cna_news/",
    "crawler_data/cts_news/",
    "crawler_data/udn_news/",
    "crawler_data/ptt_gossiping/",
)

## cna data

In [None]:
def _clean_cna_article(article):
    """Trim a CNA article body.

    Drops everything up to and including the first full-width '）' (presumably
    the dateline) and everything after the last '。'. If there is no '。', the
    text after the '）' cut is kept whole; the old reverse/find/slice trick
    collapsed such articles to a single character.
    """
    article = article[article.find('）') + 1:]
    last_period = article.rfind('。')
    if last_period != -1:
        article = article[:last_period + 1]
    return article


file_list = list(Path(cna_path).glob('*.json'))
cna_news = {}

for file_path in tqdm_notebook(file_list):

    if 'focus' in file_path.stem:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # 'focus' dumps map title -> article text.
        for title, article in data.items():
            if title not in cna_news:  # first occurrence of a title wins
                cna_news[title] = _clean_cna_article(article)

    if 'news' in file_path.stem:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # 'news' dumps map category -> {title: article text}.
        for articles in data.values():
            for title, article in articles.items():
                if title not in cna_news:
                    cna_news[title] = _clean_cna_article(article)

print('total {} news'.format(len(cna_news)))

## cts data

In [None]:
file_list = list(Path(cts_path).glob('*.json'))
cts_news = {}

# A malformed or unreadable dump is skipped (best effort, as before), but it is
# now reported. Only the errors such a file can raise are caught; the old bare
# `except:` also swallowed KeyboardInterrupt and real bugs.
for file_path in tqdm_notebook(file_list):

    if 'hots' in file_path.stem:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # 'hots' dumps map title -> article text.
            for title, article in data.items():
                if title not in cts_news:  # first occurrence of a title wins
                    cts_news[title] = article
        except (OSError, ValueError, AttributeError, TypeError) as err:
            print('skipped {}: {!r}'.format(file_path, err))

    if 'news' in file_path.stem:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # 'news' dumps map category -> {title: article text}.
            for articles in data.values():
                for title, article in articles.items():
                    # Drop everything up to and including the first '/'.
                    article = article[article.find('/') + 1:].strip()
                    if title not in cts_news:
                        cts_news[title] = article
        except (OSError, ValueError, AttributeError, TypeError) as err:
            print('skipped {}: {!r}'.format(file_path, err))

print('total {} news'.format(len(cts_news)))

## udn data

In [None]:
def _clean_udn_title(title):
    """Strip a UDN title prefix.

    Removes the text through the first '／', then the text through the first
    '】'. Each cut is skipped when its character is absent.
    """
    slash = title.find('／')
    if slash != -1:
        title = title[slash + 1:]
    bracket = title.find('】')
    if bracket != -1:
        title = title[bracket + 1:]
    return title


file_list = list(Path(udn_path).glob('*.json'))
udn_news = {}

# Unreadable/malformed dumps are skipped and reported instead of being
# silently swallowed by a bare `except:`.
for file_path in tqdm_notebook(file_list):

    if 'news' in file_path.stem:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # 'news' dumps map category -> {title: article text}.
            for articles in data.values():
                for title, article in articles.items():
                    title = _clean_udn_title(title)
                    if title not in udn_news:  # first occurrence wins
                        udn_news[title] = article
        except (OSError, ValueError, AttributeError, TypeError) as err:
            print('skipped {}: {!r}'.format(file_path, err))

print('total {} news'.format(len(udn_news)))

## ptt data

In [None]:
# URLs removed from PTT post bodies (same pattern as before, compiled once).
_URL_RE = re.compile(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', flags=re.MULTILINE)

file_list = list(Path(ptt_path).glob('*.json'))
ptt_articles = {}

# Read every dump. The previous loop opened file_list[0] on each iteration, so
# only the first file was ever used. Bad files are skipped and reported.
for file_path in tqdm_notebook(file_list):

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # Each dump maps article_id -> {'article_title', 'content', 'messages', ...}.
        for article in data.values():
            title = article['article_title']
            content = _URL_RE.sub('', article['content'])
            # Concatenate all push comments, each terminated by '。'.
            messages = ''.join(message['push_content'] + '。' for message in article['messages'])

            if title not in ptt_articles:  # first occurrence of a title wins
                ptt_articles[title] = [content, messages]

    except (OSError, ValueError, KeyError, TypeError, AttributeError) as err:
        print('skipped {}: {!r}'.format(file_path, err))

## combine all news data (PTT articles are currently excluded)

In [None]:
# Flatten every (title, article) pair into one list of training texts,
# ordered title, article, title, article, ... (CNA, then CTS, then UDN).
# PTT posts (ptt_articles: title -> [content, messages]) are not included.
total_data = [
    text
    for source in (cna_news, cts_news, udn_news)
    for pair in source.items()
    for text in pair
]

print(len(total_data))

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import transformers
from transformers import GPT2Config, GPT2Tokenizer, GPT2Model, GPT2LMHeadModel
from transformers import BertTokenizer, BertTokenizerFast

import numpy as np
from datetime import datetime
import os
import time

In [None]:
# Root directory for experiment outputs; each run gets its own sub-directory.
exp_dir = 'gpt2_chinese/exp/'

# All text is tokenized with the Chinese BERT tokenizer; its vocabulary size
# also sets the GPT-2 embedding size.
bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')

In [None]:
class AttrDict(dict):
    """A ``dict`` whose keys can also be read and written as attributes.

    Pattern from
    https://stackoverflow.com/questions/4984647/accessing-dict-keys-like-an-attribute
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the dict itself as the instance namespace, so attribute access
        # and item access share the same storage.
        self.__dict__ = self

# Run configuration. Keys keep the order in which they are logged later.
opts = AttrDict(
    # Model
    vocab_size=bert_tokenizer.vocab_size,
    emb=768,
    # Optimization
    learning_rate=1.5e-4,
    bert_lr=5e-6,
    weight_decay=0.01,  # L2 weight regularization
    max_grad_norm=1.0,
    batch_size=4,
    # Training
    max_seq_len=512,
    num_epochs=300,
    warmup_steps=4000,
    gradient_accumulation=20,
    load_pretrain=True,
)

In [None]:
class TextDataset():
    """Map-style dataset over a list of raw text strings.

    Items are returned untokenized; tokenization happens in ``collate_fn``.
    The constructor prints the number of texts and their total character count.
    """
    def __init__(self, total_data):
        print('=' * 50)
        print('Dataset preprocessing log:')
        self.sents = total_data
        print('- Number of sentences: {}'.format(len(self.sents)))
        n_chars = sum(len(text) for text in self.sents)
        print('- Number of words: {}'.format(n_chars))

    def __len__(self):
        return len(self.sents)

    def __getitem__(self, index):
        return self.sents[index]
    
def collate_fn(data):
    
    def _pad_sequences(seqs):
        lens = [len(seq)-1 for seq in seqs]
        input_seqs = torch.zeros(len(seqs), max(lens)).long()
        target_seqs = torch.zeros(len(seqs), max(lens)).long()
        for i, seq in enumerate(seqs):
            input_seqs[i, :len(seq)-1] = torch.LongTensor(seq[:-1])
            target_seqs[i, :len(seq)-1] = torch.LongTensor(seq[1:])
            
        return input_seqs, target_seqs, lens
    
    def bert_tokenize(tokenizer, article, max_length=1024):
    
        sents = re.split('。|，| ', article)
        while '' in sents:
            sents.remove('')

        bert_sent = '[CLS]'
        for sent in sents:
            bert_sent += sent
            bert_sent += '[SEP]'

        tokens = tokenizer.tokenize(bert_sent)

        truncat_tokens = []
        if len(tokens) < max_length:
            truncat_tokens.append(tokens)
        else:
            truncat_tokens = []
            while len(tokens) > max_length:
                truncat_tokens.append(tokens[:max_length])
                tokens = tokens[max_length:]
            truncat_tokens.append(tokens)

        return truncat_tokens
    
    sents = data
    
    bert_tokens = []
    for sent in sents:
        tokens = bert_tokenize(bert_tokenizer, sent, max_length=opts.max_seq_len)
        bert_tokens.append(tokens[0])
        
    bert_idxs = []
    for bert_token in bert_tokens:
        idxs = bert_tokenizer.convert_tokens_to_ids(bert_token)
        bert_idxs.append(idxs)
        
    input_seqs, target_seqs, lens = _pad_sequences(bert_idxs)
    
    return sents, bert_tokens, bert_idxs, input_seqs, target_seqs, lens

In [None]:
# Seed shared by the train/dev split and the NumPy/PyTorch RNGs.
# NOTE(review): possibly meant as the date 20200810; kept unchanged because a
# different value would change the data split.
random_seed = 202_000_810

In [None]:
from sklearn.model_selection import train_test_split

# Hold out a shuffled 5% of the texts as the dev set (reproducible via the seed).
train_data, dev_data = train_test_split(
    total_data,
    test_size=0.05,
    shuffle=True,
    random_state=random_seed,
)

print(len(train_data), len(dev_data))

In [None]:
# Build the train dataset first, then dev; each prints its own summary.
train_dataset, dev_dataset = TextDataset(train_data), TextDataset(dev_data)

In [None]:
# Seed NumPy, the PyTorch CPU RNG and all CUDA RNGs, in that order.
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(random_seed)

# Training batches are shuffled; dev batches (size 2) keep dataset order.
# Both use collate_fn to tokenize raw texts inside the worker processes.
train_iter = DataLoader(
    dataset=train_dataset,
    batch_size=opts.batch_size,
    shuffle=True,
    num_workers=16,
    collate_fn=collate_fn,
)
dev_iter = DataLoader(
    dataset=dev_dataset,
    batch_size=2,
    shuffle=False,
    num_workers=16,
    collate_fn=collate_fn,
)

In [None]:
# Start from the pretrained 'gpt2-medium' checkpoint or from a freshly
# initialised default GPT2Config. Either way the token-embedding table is
# resized to the BERT vocabulary used for tokenization.
if opts.load_pretrain:
    model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
else:
    gpt2_config = GPT2Config()
    model = GPT2LMHeadModel(config=gpt2_config)
model.resize_token_embeddings(bert_tokenizer.vocab_size)

print('total parms : ', sum(param.numel() for param in model.parameters()))
print('trainable parms : ', sum(param.numel() for param in model.parameters() if param.requires_grad))

In [None]:
# Notebook cell output: display the model's module structure.
model

In [None]:
# Restrict this process to GPUs 0 and 1.
# NOTE: this only takes effect if CUDA has not been initialised yet in this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# Use the GPU only when one is actually available. The old hard-coded
# `USE_CUDA = True` overrode this check and crashed on CPU-only machines.
USE_CUDA = torch.cuda.is_available()

In [None]:
## data parallel
# nn.DataParallel splits each forward pass's batch (dim 0) across the visible
# GPUs. Later cells reach the wrapped model through `model.module`.
# (A DistributedDataParallel setup was tried earlier and is not used.)
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = nn.DataParallel(model)

if USE_CUDA:
    model.cuda()

In [None]:
# Training starts at epoch last_epoch + 1 (-1 means from scratch).
last_epoch = -1
# Each run gets its own directory named after the setup plus a
# 'YYYY-MM-DD HH:MM:SS' timestamp.
model_name = f'gpt2_medium_noptt_len{opts.max_seq_len}_batch_{opts.batch_size}'
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
experiment_name = f'{model_name}_{now}'
experiment_dir = Path(exp_dir, experiment_name)
experiment_dir.mkdir(parents=True, exist_ok=True)
print(experiment_dir)

In [None]:
def log2file(log_file, msg):
    """Append *msg* followed by a newline to the text file *log_file*.

    The file is opened with an explicit UTF-8 encoding, so non-ASCII messages
    (e.g. exception text) can be written whatever the platform's default
    locale encoding is.
    """
    with open(log_file, 'a', encoding='utf-8') as fw:
        fw.write(msg + '\n')

# Log file paths inside this run's experiment directory.
experiment_trainlog, experiment_devlog = (
    experiment_dir / log_name for log_name in ('train_log.txt', 'dev_log.txt')
)

In [None]:
print(opts.learning_rate)
print(opts.bert_lr)  # NOTE(review): printed only; not used by any parameter group

# The configured weight decay was previously never passed to the optimizer.
optimizer = transformers.AdamW([
    {'params': model.module.parameters(), 'lr': opts.learning_rate},
], lr=opts.learning_rate, weight_decay=opts.weight_decay)

# The training loop calls scheduler.step() once per optimizer step. That is
# every `gradient_accumulation` batches, plus once for a trailing partial
# group, i.e. ceil(len(train_iter) / gradient_accumulation) times per epoch.
# Counting batches instead stretched the linear decay by that factor.
optimizer_steps_per_epoch = -(-len(train_iter) // opts.gradient_accumulation)
scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=opts.warmup_steps,
    num_training_steps=optimizer_steps_per_epoch * opts.num_epochs,
)

# Padded target positions (pad id) are excluded from the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=bert_tokenizer.pad_token_id)

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0.001, exp_dir=''):
        """
        Args:
            patience (int): How many consecutive non-improving evaluations to
                            tolerate before setting ``early_stop``. Default: 7
            verbose (bool): If True, prints and logs a message for each validation
                            loss improvement. Default: False
            delta (float): Stored as a minimum-improvement threshold but currently
                            NOT applied: any loss <= the best so far counts as an
                            improvement. Default: 0.001
            exp_dir (str): Experiment directory. ``train_log.txt`` and the
                            ``best_model`` symlink (to ``epoch_<best>.mdl``) are
                            written here.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # the np.Inf alias was removed in NumPy 2.0
        self.delta = delta
        self.best_epoch = 0
        self.exp_dir = Path(exp_dir)

    def __call__(self, val_loss, model, epoch):
        """Record one evaluation's *val_loss* at *epoch*; set ``early_stop`` once patience runs out."""
        score = -val_loss

        if self.best_score is not None and score < self.best_score:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                msg = 'best epoch : {}'.format(self.best_epoch)
                print(msg)
                log2file(self.exp_dir / 'train_log.txt', msg)
                self._link_best_model()
            return

        # First evaluation or an improvement (ties count as improvements).
        # The first evaluation now records its epoch too, instead of leaving
        # best_epoch at 0.
        self.best_score = score
        self.best_epoch = epoch
        self.save_checkpoint(val_loss, model, epoch)
        self.counter = 0

    def _link_best_model(self):
        """Point ``<exp_dir>/best_model`` at ``epoch_<best_epoch>.mdl``.

        The link target is the bare file name, so it resolves relative to the
        link's own directory. The old target included exp_dir, which dangled
        whenever exp_dir was relative. An existing link is replaced, so a
        repeated trigger no longer raises FileExistsError.
        """
        link = self.exp_dir / 'best_model'
        if link.is_symlink() or link.exists():
            link.unlink()
        link.symlink_to('epoch_{}.mdl'.format(self.best_epoch))

    def save_checkpoint(self, val_loss, model, epoch):
        '''Record a new best validation loss, printing and logging it when verbose.

        Despite the name, no weights are saved here; the training loop writes
        its own ``.mdl`` checkpoints.
        '''
        if self.verbose:
            msg = f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f})'
            print(msg)
            log2file(self.exp_dir / 'train_log.txt', msg)
        self.val_loss_min = val_loss

In [None]:
# Notebook cell output: show this run's experiment directory path.
experiment_dir

In [None]:
early_stopping = EarlyStopping(patience=50, verbose=True, exp_dir=str(experiment_dir))


def _log(msg):
    """Print *msg* and append it to the experiment's train log."""
    print(msg)
    log2file(str(experiment_trainlog), msg)


def _handle_runtime_error(exception):
    """Log a RuntimeError; on CUDA OOM, free cached memory and let the batch be skipped, else re-raise."""
    global oom_time
    if "out of memory" not in str(exception):
        _log(str(exception))
        raise exception
    oom_time += 1
    _log("WARNING: ran out of memory,times: {}".format(oom_time))
    if hasattr(torch.cuda, 'empty_cache'):
        torch.cuda.empty_cache()


def _batch_loss(batch):
    """Forward one collated batch; return (criterion loss tensor, batch size)."""
    _, _, _, input_seqs, target_seqs, _ = batch
    batch_size = input_seqs.size(0)
    assert batch_size == target_seqs.size(0)
    if USE_CUDA:
        input_seqs = input_seqs.cuda()
        target_seqs = target_seqs.cuda()
    logits = model(input_seqs)[0]
    # Flatten (batch, seq, vocab) logits and (batch, seq) targets for the criterion.
    loss = criterion(logits.view(logits.size(0) * logits.size(1), -1),
                     target_seqs.view(target_seqs.size(0) * target_seqs.size(1)))
    return loss, batch_size


def _evaluate_dev():
    """Return the dev loss: sum of per-batch loss * batch size, divided by len(dev_dataset).

    Runs under torch.no_grad(), so no autograd graph is built (the old code
    kept one per dev batch). Uses its own accumulator, so dev losses no longer
    leak into the training-loss statistics.
    """
    model.eval()
    pbar_dev.reset()
    loss_sum = 0.0
    with torch.no_grad():
        for batch in dev_iter:
            try:
                loss, batch_size = _batch_loss(batch)
                loss_sum += loss.item() * batch_size
            except RuntimeError as exception:
                _handle_runtime_error(exception)
            pbar_dev.update(1)
    return loss_sum / len(dev_dataset)


for k, v in opts.items():
    _log('- {}: {}'.format(k, v))

pbar_train = tqdm_notebook(total=len(train_iter))
pbar_dev = tqdm_notebook(total=len(dev_iter))

_log('=' * 50)
_log('optim : \n' + str(optimizer))

# Mid-epoch checkpoint + dev evaluation at 10%, 20%, ..., 90% of each epoch.
s = 10
checkpoint = [int(len(train_iter)/s*i) for i in range(1, s)]

oom_time = 0

print('check point : ', checkpoint)

for epoch in range(last_epoch + 1, opts.num_epochs):

    pbar_train.reset()
    pbar_dev.reset()
    _log('=' * 50)

    epoch_start = time.time()
    last_report = epoch_start
    # Running training-loss totals for this epoch (OOM-skipped batches excluded).
    train_loss_sum = 0.0
    train_samples = 0

    for global_step, batch in enumerate(train_iter):

        model.train()

        try:
            loss, batch_size = _batch_loss(batch)
            # Scale so the accumulated gradient is the mean over the group.
            (loss / opts.gradient_accumulation).backward()

            if (global_step + 1) % opts.gradient_accumulation == 0 or global_step == len(train_iter) - 1:
                # Clip the fully accumulated gradient once, just before the
                # update (previously it was clipped after every micro-batch).
                nn.utils.clip_grad_norm_(model.parameters(), opts.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

            train_loss_sum += loss.item() * batch_size
            train_samples += batch_size

        except RuntimeError as exception:
            _handle_runtime_error(exception)

        #=================================

        if global_step in checkpoint:

            now_time = time.time()
            cur_avg_loss = train_loss_sum / train_samples if train_samples else float('nan')
            _log("{} | Batch {:d}/{:d} | Mean Loss {:5.5f} | time cost {:d} s"
                 .format('train'.upper(), global_step, len(train_iter), cur_avg_loss, int(now_time - last_report)))
            last_report = now_time

            now_percent = checkpoint.index(global_step) + 1
            # NOTE(review): mid-epoch checkpoints are named with epoch-1; kept as-is.
            torch.save(model.state_dict(), experiment_dir / 'epoch_{}_{}.mdl'.format(epoch - 1, now_percent))

            dev_start = time.time()
            dev_loss = _evaluate_dev()
            _log("{}   | Batch {:d}/{:d} | Mean Loss {:5.5f} | Total time cost {:d} s"
                 .format('dev'.upper(), global_step, len(train_iter), dev_loss, int(time.time() - dev_start)))

            early_stopping(dev_loss, model, epoch)

            if early_stopping.early_stop:
                print("Early stopping")
                break

        pbar_train.update(1)

    train_mean = train_loss_sum / train_samples if train_samples else float('nan')
    _log("{} | Epoch {:d}/{:d} | Mean Loss {:5.5f} | Total time cost {:d} s"
         .format('train'.upper(), epoch, opts.num_epochs, train_mean, int(time.time() - epoch_start)))

    if early_stopping.early_stop:
        # Already stopped at a mid-epoch evaluation. Don't evaluate and notify
        # the stopper a second time.
        break

    #-----------------------

    dev_start = time.time()
    dev_loss = _evaluate_dev()
    _log("{}   | Epoch {:d}/{:d} | Mean Loss {:5.5f}  | Total time cost {:d} s"
         .format('dev'.upper(), epoch, opts.num_epochs, dev_loss, int(time.time() - dev_start)))

    early_stopping(dev_loss, model, epoch)

    if early_stopping.early_stop:
        print("Early stopping")
        break

    torch.save(model.state_dict(), experiment_dir / 'epoch_{}.mdl'.format(epoch))

    print("=" * 50)