In [None]:
import os
import math
import logging
from pprint import pformat
from argparse import ArgumentParser
from collections import defaultdict
from itertools import chain

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint
from ignite.metrics import Accuracy, Loss, MetricsLambda, RunningAverage
from ignite.contrib.handlers import ProgressBar, PiecewiseLinear
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler
from transformers import (AdamW, OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
                                  GPT2DoubleHeadsModel, GPT2Tokenizer, WEIGHTS_NAME, CONFIG_NAME)

from transfer_learning_conv_ai.utils import get_dataset, make_logdir
from tqdm import tqdm, trange

In [None]:
class args:
    """Notebook configuration, used as a lightweight stand-in for argparse.Namespace.

    Settings are class attributes read as ``args.<name>``; ``train`` and ``evaluate``
    also assign derived values (e.g. ``train_batch_size``, ``eval_batch_size``) onto it.
    """
    # --- data ---
    dataset_path = ''
    dataset_cache = './dataset_cache'
    num_candidates = 2              # candidates per utterance, applied to the train split only
    max_history = 2                 # keep the last 2*max_history+1 history turns
    personality_permutations = 1

    # --- model ---
    model_checkpoint = 'openai-gpt'
    output_dir = ''

    # --- optimization ---
    per_gpu_train_batch_size = 4
    per_gpu_eval_batch_size = 1
    gradient_accumulation_steps = 8
    lr = 6.25e-5
    lm_coef = 1.0                   # weight of the language-modeling loss
    mc_coef = 1.0                   # weight of the multiple-choice loss
    max_norm = 1.0                  # not read by the visible code; train() clips with max_grad_norm
    max_grad_norm = 1.0
    num_train_epochs = 3
    max_step = -1                   # not read by the visible code; max_steps is the setting used
    max_steps = -1                  # > 0 overrides num_train_epochs

    # --- evaluation / logging ---
    eval_before_start = True
    evaluate_during_training = True
    logging_steps = 8

    # --- hardware ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    n_gpu = 1
    tpu = False
    # apex AMP opt level. AMP requires the model's parameters on a CUDA device, so mixed
    # precision is disabled ('' is falsy) on CPU instead of failing inside amp.initialize.
    fp16 = 'O1' if device == "cuda" else ''
    local_rank = -1
    # Computed once, here: reassigning args.local_rank later does not update it.
    distributed = (local_rank != -1)

In [None]:
# Special tokens for the tokenizer. SPECIAL_TOKENS[:-1] are the bos/eos/speaker tokens used
# by build_input_from_segments; SPECIAL_TOKENS[-1] is the padding token used by get_data_loaders.
SPECIAL_TOKENS = ["<bos>", "<eos>", "<speaker1>", "<speaker2>", "<pad>"]
ATTR_TO_SPECIAL_TOKEN = {"bos_token": "<bos>", "eos_token": "<eos>",
                         "additional_special_tokens": ["<speaker1>", "<speaker2>"],
                         "pad_token": "<pad>"}
# Order of the tensors built by get_data_loaders and unpacked by train/evaluate.
MODEL_INPUTS = ["input_ids", "mc_token_ids", "lm_labels", "mc_labels", "token_type_ids"]
# Variable-length inputs that pad_dataset right-pads to a common length.
PADDED_INPUTS = ["input_ids", "lm_labels", "token_type_ids"]

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
# `args` is a class, so pformat(args) would only print "<class ...>"; log its settings instead.
logger.info("Arguments: %s", pformat({name: value for name, value in vars(args).items()
                                      if not name.startswith("__")}))
    
# def average_distributed_scalar(scalar, args):
#     """ Average a scalar over the nodes if we are in distributed training. We use this for distributed evaluation. """
#     if args.local_rank == -1:
#         return scalar
#     scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size()
#     torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
#     return scalar_t.item()


# num_added_toks = tokenizer.add_special_tokens(SPECIAL_TOKENS)
# model.resize_token_embeddings(len(tokenizer))

In [None]:
logger.info("Prepare tokenizer, pretrained model and optimizer.")
# Choose the GPT-2 or the OpenAI-GPT classes from the checkpoint name.
# AutoTokenizer is not used because the checkpoint could be a Path.
if "gpt2" in args.model_checkpoint:
    tokenizer_class, model_class = GPT2Tokenizer, GPT2DoubleHeadsModel
else:
    tokenizer_class, model_class = OpenAIGPTTokenizer, OpenAIGPTDoubleHeadsModel

tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
model = model_class.from_pretrained(args.model_checkpoint)
model.to(args.device)

In [None]:
def add_special_tokens_(model, tokenizer, special_tokens=None):
    """ Add special tokens to the tokenizer and the model if they have not already been added.

    Args:
        model: model exposing ``resize_token_embeddings(new_num_tokens=...)``.
        tokenizer: tokenizer exposing ``add_special_tokens(mapping)``, which returns how many
            tokens were actually added, and ``len()``, which gives the full vocabulary size.
        special_tokens: mapping passed to ``tokenizer.add_special_tokens``; defaults to the
            module-level ``ATTR_TO_SPECIAL_TOKEN``.

    Returns:
        The number of tokens newly added (0 when all were already present).
    """
    if special_tokens is None:
        special_tokens = ATTR_TO_SPECIAL_TOKEN
    num_added_tokens = tokenizer.add_special_tokens(special_tokens)  # doesn't add if they are already there
    if num_added_tokens > 0:
        # Size the embeddings to the full tokenizer length. The base-vocabulary count
        # (len(tokenizer.encoder)) plus the new tokens misses any tokens that were added
        # earlier, which would leave the embedding matrix smaller than the largest token id.
        model.resize_token_embeddings(new_num_tokens=len(tokenizer))
    return num_added_tokens
        
# Register the special tokens on the tokenizer and grow the model's embeddings to match.
add_special_tokens_(model=model, tokenizer=tokenizer)

In [None]:
# optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)
# if args.fp16:
#     from apex import amp  # Apex is only required if we use fp16 training
#     model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)

# from transformers import cached_path
# import json

# logger.info("Prepare datasets")
# #train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

# PERSONACHAT_URL = "https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json"

# dataset_path = PERSONACHAT_URL

# personachat_file = cached_path(dataset_path)

# with open(personachat_file, "r", encoding="utf-8") as f:
#     dataset = json.loads(f.read())

# def tokenize(obj):
#     if isinstance(obj, str):
#         return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
#     if isinstance(obj, dict):
#         return dict((n, tokenize(o)) for n, o in obj.items())
#     return list(tokenize(o) for o in obj)

# dataset = tokenize(dataset)

In [None]:
# Module-level placeholder; get_data_loaders builds its own local `datasets`.
datasets = {split: defaultdict(list) for split in ("train", "valid")}

# Tokenizer-specific cache path, e.g. './dataset_cache_OpenAIGPTTokenizer' for 'openai-gpt'.
dataset_cache = "{}_{}".format(args.dataset_cache, type(tokenizer).__name__)

def build_input_from_segments(persona, history, reply, tokenizer, lm_labels=False, with_eos=True):
    """Flatten persona, dialogue history and one candidate reply into a single model input.

    The segments are [<bos> + persona tokens], each history turn, and then the reply (with a
    trailing <eos> when ``with_eos``). Every segment after the first gets a speaker-token
    prefix. These prefixes alternate counting back from the reply. The token_type_ids
    alternate <speaker1>/<speaker2> counting forward from the persona segment. The two
    agree when ``history`` has an odd number of turns.

    Returns a dict with these keys, in this order:
        input_ids      -- all segment tokens concatenated
        token_type_ids -- one speaker id per token
        mc_token_ids   -- index of the last token (passed to the model as mc_token_ids)
        lm_labels      -- -100 everywhere, unless ``lm_labels`` is true; then the reply
                          tokens after its speaker prefix are kept as targets
    """
    bos, eos, speaker1, speaker2 = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1])

    eos_tail = [eos] if with_eos else []
    head = [bos]
    head.extend(chain.from_iterable(persona))
    segments = [head] + history + [reply + eos_tail]

    # Prefix each non-persona segment with a speaker token, alternating from the end.
    total = len(segments)
    tagged = [segments[0]]
    for offset, segment in enumerate(segments[1:]):
        speaker = speaker1 if (total - offset) % 2 == 0 else speaker2
        tagged.append([speaker] + segment)

    input_ids = list(chain.from_iterable(tagged))

    token_type_ids = []
    for position, segment in enumerate(tagged):
        token_type_ids.extend([speaker1 if position % 2 == 0 else speaker2] * len(segment))

    if lm_labels:
        # Mask everything before the reply, plus the reply's own speaker token.
        context_len = sum(len(segment) for segment in tagged[:-1])
        labels = [-100] * (context_len + 1) + tagged[-1][1:]
    else:
        labels = [-100] * len(input_ids)

    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "mc_token_ids": len(input_ids) - 1,
        "lm_labels": labels,
    }

def pad_dataset(dataset, padding=0):
    """Right-pad every PADDED_INPUTS field to the longest ``input_ids`` sequence.

    ``lm_labels`` is padded with -100 (the ignore index); the other fields use ``padding``.
    The dict is modified in place and also returned. Padding the whole split at once is
    simpler than per-batch padding in a Dataset class, though less memory-efficient.
    """
    longest = max(map(len, dataset["input_ids"]))
    for field in PADDED_INPUTS:
        fill = -100 if field == "lm_labels" else padding
        dataset[field] = [row + [fill] * (longest - len(row)) for row in dataset[field]]
    return dataset

def _load_personachat(args, tokenizer):
    """Return the tokenized PersonaChat splits, preferring the tokenizer-specific cache file."""
    # Same naming as the notebook's `dataset_cache` global, computed here so this function
    # does not silently depend on that cell having been run.
    cache_path = args.dataset_cache + '_' + type(tokenizer).__name__
    if os.path.isfile(cache_path):
        # torch.load unpickles; only point dataset_cache at files this pipeline produced.
        logger.info("Load tokenized dataset from cache at %s", cache_path)
        return torch.load(cache_path)
    # No cache yet: build the data with the project helper instead of failing with
    # FileNotFoundError. NOTE(review): assumes get_dataset tokenizes (and caches) the
    # dataset itself — confirm in transfer_learning_conv_ai.utils.
    logger.info("No dataset cache at %s; building it with get_dataset", cache_path)
    return get_dataset(tokenizer, args.dataset_path, args.dataset_cache)


def _build_instances(args, tokenizer, personachat):
    """Expand every utterance of every dialog into one model input per candidate reply."""
    datasets = {"train": defaultdict(list), "valid": defaultdict(list)}
    for dataset_name, dataset in personachat.items():
        # Candidate count is read from the first utterance; assumes all utterances share it.
        num_candidates = len(dataset[0]["utterances"][0]["candidates"])
        if args.num_candidates > 0 and dataset_name == 'train':
            num_candidates = min(args.num_candidates, num_candidates)
        for dialog in dataset:
            persona = dialog["personality"].copy()
            for _ in range(args.personality_permutations):
                for utterance in dialog["utterances"]:
                    history = utterance["history"][-(2 * args.max_history + 1):]
                    for j, candidate in enumerate(utterance["candidates"][-num_candidates:]):
                        # The last candidate is the gold reply; only it gets LM labels.
                        lm_labels = bool(j == num_candidates - 1)
                        instance = build_input_from_segments(persona, history, candidate, tokenizer, lm_labels)
                        for input_name, input_array in instance.items():
                            datasets[dataset_name][input_name].append(input_array)
                    # The multiple-choice target is the index of that last candidate.
                    datasets[dataset_name]["mc_labels"].append(num_candidates - 1)
                    datasets[dataset_name]["n_candidates"] = num_candidates
                persona = [persona[-1]] + persona[:-1]  # rotate persona sentences for the next permutation
    return datasets


def _to_tensor_datasets(datasets, tokenizer):
    """Pad each split and convert every MODEL_INPUTS field to a tensor.

    All inputs except mc_labels are reshaped to (examples, n_candidates, ...).
    """
    pad_id = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1])
    tensor_datasets = {"train": [], "valid": []}
    for dataset_name, dataset in datasets.items():
        dataset = pad_dataset(dataset, padding=pad_id)
        n_candidates = dataset["n_candidates"]
        for input_name in MODEL_INPUTS:
            tensor = torch.tensor(dataset[input_name])
            if input_name != "mc_labels":
                tensor = tensor.view((-1, n_candidates) + tensor.shape[1:])
            tensor_datasets[dataset_name].append(tensor)
    return tensor_datasets


def get_data_loaders(args, tokenizer):
    """ Prepare the dataset for training and evaluation.

    Despite the name, no DataLoader is built. Returns ``(train_dataset, valid_dataset)``,
    two TensorDatasets whose tensors follow MODEL_INPUTS order
    (input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids).
    """
    personachat = _load_personachat(args, tokenizer)

    logger.info("Build inputs and labels")
    datasets = _build_instances(args, tokenizer, personachat)

    logger.info("Pad inputs and convert to Tensor")
    tensor_datasets = _to_tensor_datasets(datasets, tokenizer)

    logger.info("Build train and validation dataloaders")
    train_dataset, valid_dataset = TensorDataset(*tensor_datasets["train"]), TensorDataset(*tensor_datasets["valid"])

    logger.info("Train dataset (Batch, Candidates, Seq length): {}".format(train_dataset.tensors[0].shape))
    logger.info("Valid dataset (Batch, Candidates, Seq length): {}".format(valid_dataset.tensors[0].shape))
    return train_dataset, valid_dataset


In [None]:
# Build the padded train/valid TensorDatasets (get_data_loaders returns datasets, not loaders).
train_dataset, valid_dataset = get_data_loaders(args=args, tokenizer=tokenizer)

In [None]:
def train(args, train_dataset, valid_dataset, model, tokenizer):
    """Fine-tune a double-heads model on the weighted LM + multiple-choice loss.

    Args:
        args: notebook configuration. This function also writes ``args.train_batch_size``
            and, when ``args.max_steps > 0``, ``args.num_train_epochs``.
        train_dataset: TensorDataset in MODEL_INPUTS order
            (input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids).
        valid_dataset: dataset handed to ``evaluate`` during training.
        model: model returning (lm_loss, mc_loss, ...) when called with labels.
        tokenizer: forwarded to ``evaluate``.

    Returns:
        (mean training loss per optimizer step, metrics dict tracked by ``evaluate``; the
        initial placeholders if evaluation never ran).
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        # DistributedSampler is not among the notebook's top-level imports.
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    if args.max_steps > 0:
        t_total = args.max_steps
        # max(1, ...) avoids ZeroDivisionError when an epoch is shorter than one accumulation window.
        args.num_train_epochs = args.max_steps // max(1, steps_per_epoch) + 1
    else:
        t_total = steps_per_epoch * args.num_train_epochs

    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        # args.fp16 holds the apex opt level (e.g. 'O1'); a bare True falls back to 'O1'.
        opt_level = args.fp16 if isinstance(args.fp16, str) else 'O1'
        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

    # Linear decay from args.lr to 0 over t_total optimizer steps. An ignite PiecewiseLinear
    # only takes effect when attached to an ignite Engine, and this plain loop has none.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda opt_step: max(0.0, 1.0 - opt_step / max(1, t_total)))

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    metrics = {"nll": 10000.0,
               "accuracy": 0.0}
    tr_loss = 0.0
    model.zero_grad()

    for _ in range(args.num_train_epochs):
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # evaluate() may switch the model to eval mode, so re-enter train mode every step.
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            lm_loss, mc_loss, *_ = model(
                input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
                mc_labels=mc_labels, lm_labels=lm_labels
            )
            loss = lm_loss * args.lm_coef + mc_loss * args.mc_coef
            if args.n_gpu > 1:
                loss = loss.mean()  # DataParallel returns one loss per replica

            if step % 100 == 0:
                logger.info("Loss for {} is {}".format(step, loss.item()))

            # Scale exactly once so the accumulated gradient is the mean over the window.
            loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                # logging_steps counts optimizer steps, not batches; only a
                # single-process run evaluates during training.
                if (args.local_rank == -1 and args.evaluate_during_training
                        and args.logging_steps > 0 and global_step % args.logging_steps == 0):
                    metrics = evaluate(args, model, valid_dataset, metrics, tokenizer)

            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step >= args.max_steps:
            break

    # max(..., 1) guards against datasets too small to complete a single optimizer step.
    return tr_loss / max(global_step, 1), metrics

In [None]:
import numpy as np

def evaluate(args, model, valid_dataset, metrics, tokenizer, prefix=''):
    """Evaluate ``model`` on ``valid_dataset`` and update ``metrics`` on a new best result.

    Computes two metrics:
      * nll: the mean over batches of the per-batch token-level cross-entropy (labels of
        -100 are ignored).
      * accuracy: the fraction of examples whose highest-scoring candidate equals mc_labels.

    Args:
        args: notebook configuration. Reads per_gpu_eval_batch_size, n_gpu, local_rank,
            device, model_checkpoint and output_dir. Sets eval_batch_size, and sets
            output_dir when it is empty.
        model: model returning (lm_logits, mc_logits, ...) when called without labels.
        valid_dataset: TensorDataset in MODEL_INPUTS order
            (input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids).
        metrics: dict with "nll" and "accuracy" holding the best values so far; updated in place.
        tokenizer: not used; kept for interface compatibility.
        prefix: label for the log header and sub-directory for eval_results.txt.

    Returns:
        The same ``metrics`` dict. Its values are replaced (and written to eval_results.txt)
        only when accuracy is higher AND nll is lower than the stored ones.
    """
    # Resolve the output directory once per run and remember it on args.
    # NOTE(review): make_logdir comes from transfer_learning_conv_ai.utils; if it embeds a
    # timestamp, calling it on every evaluation would scatter results over many directories.
    if not getattr(args, "output_dir", ""):
        args.output_dir = make_logdir(args.model_checkpoint)
    eval_outputs_dirs = args.output_dir

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    if args.local_rank == -1:
        eval_sampler = SequentialSampler(valid_dataset)
    else:
        # DistributedSampler is not among the notebook's top-level imports.
        from torch.utils.data.distributed import DistributedSampler
        eval_sampler = DistributedSampler(valid_dataset)
    eval_dataloader = DataLoader(valid_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(valid_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    lm_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
    nlls = []  # one mean token NLL per batch
    hits = []  # one 0/1 entry per example for the multiple-choice head
    was_training = model.training
    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch

            lm_logits, mc_logits, *_ = model(
                input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
            )
            # Shift by one so the logits at position t are scored against the label at t+1.
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            nlls.append(lm_loss_fct(lm_logits_flat_shifted, lm_labels_flat_shifted).item())
            # Per-example hits keep accuracy a fraction for any batch size; a per-batch
            # count of correct answers exceeds 1 once a batch holds more than one example.
            hits.extend((torch.max(mc_logits, 1)[1] == mc_labels).int().tolist())
    if was_training:
        model.train()

    if not hits:
        logger.warning("Evaluation dataset is empty; metrics left unchanged.")
        return metrics

    nlls_mean = float(np.mean(nlls))
    accs_mean = float(np.mean(hits))

    if accs_mean > metrics['accuracy'] and nlls_mean < metrics['nll']:
        logger.info("***** New high accuracy and nll! {} {}*****".format(accs_mean, nlls_mean))
        metrics.update({'nll': nlls_mean, 'accuracy': accs_mean})
        # Persist the new best result; only the main process writes.
        if args.local_rank in [-1, 0]:
            output_dir = os.path.join(eval_outputs_dirs, prefix)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            output_eval_file = os.path.join(output_dir, "eval_results.txt")
            with open(output_eval_file, "w") as writer:
                for key in sorted(metrics):
                    writer.write("%s = %s\n" % (key, metrics[key]))

    return metrics

In [None]:
# Run fine-tuning; returns (mean loss per optimizer step, best evaluation metrics).
tr_loss, eval_results = train(args=args, train_dataset=train_dataset, valid_dataset=valid_dataset,
                              model=model, tokenizer=tokenizer)