In [None]:
import os
import random
import logging
import torch
import numpy as np
import transformers
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from argparse import Namespace
from transformers import (
    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
    WEIGHTS_NAME,
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    squad_convert_examples_to_features,
)
from transformers.data.metrics.squad_metrics import (
    compute_predictions_log_probs,
    compute_predictions_logits,
    squad_evaluate,
)
from transformers.data.processors.squad import (
    SquadResult,
    SquadV1Processor,
    SquadV2Processor,
)
from transformers.trainer_utils import is_main_process

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

# Configure the root logger once for the whole notebook: timestamped, INFO-level records.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Logger used by every cell/function below.
logger = logging.getLogger(__name__)

# Config classes registered for question answering, and their `model_type` strings.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
MODEL_TYPES = tuple(
    config_class.model_type for config_class in MODEL_CONFIG_CLASSES
)


def emulate_config():
    """Build the run configuration and resolve the compute device.

    Stands in for command-line argument parsing: the hyper-parameters are
    hard-coded below and wrapped in an ``argparse.Namespace``. The function
    then validates the output directory, selects the device (MPS, CUDA or CPU,
    or a per-rank CUDA device for distributed runs), stores it on
    ``config.device`` and configures the Transformers logger on the main
    process.

    Returns:
        argparse.Namespace: all settings below plus the resolved ``device``.

    Raises:
        ValueError: if ``output_dir`` already exists, is not empty, training is
            requested and ``overwrite_output_dir`` is False; or if a
            distributed run is requested but CUDA is unavailable.
    """
    settings = {
        "seed": 42,
        "n_gpu": 1,
        "local_rank": -1,
        "threads": 8,
        # model
        "model_type": "electra",
        "config_name": "",
        "tokenizer_name": "",
        "model_name_or_path": "monologg/koelectra-base-v3-discriminator",
        "doc_stride": 128,
        "null_score_diff_threshold": 0,
        "max_seq_length": 512,
        "max_query_length": 64,
        # train and evaluation
        "do_train": True,
        "do_eval": True,
        "evaluate_during_training": True,
        "do_lower_case": False,
        "per_gpu_train_batch_size": 16,
        "per_gpu_eval_batch_size": 8,
        "learning_rate": 5e-5,
        "gradient_accumulation_steps": 1,
        "weight_decay": 0.1,  # L2 Regularization = Weight Decay in terms of SGD, but not the case for Adam.
        # cf. Adam vs. AdamW: https://hiddenbeginner.github.io/deeplearning/paperreview/2019/12/29/paper_review_AdamW.html
        # cf. LayerNorm vs. BatchNorm: https://velog.io/@glad415/Transformer-7.-%EC%9E%94%EC%B0%A8%EC%97%B0%EA%B2%B0%EA%B3%BC-%EC%B8%B5-%EC%A0%95%EA%B7%9C%ED%99%94-by-WikiDocs
        "adam_epsilon": 1e-8,
        "max_grad_norm": 1.0,
        "num_train_epochs": 3,
        "max_steps": -1,
        "warmup_steps": 0,
        "n_best_size": 20,
        "max_answer_length": 30,
        "fp16": False,
        "fp16_opt_level": "O1",  # fp16 level
        # logging
        "verbose_logging": True,  # verbose related to the data processing
        "logging_steps": 1000,
        "save_steps": 1000,
        "eval_all_checkpoints": True,
        "no_cuda": True,
        "use_mps": True,  # prefer the Apple "mps" backend when it is available
        "overwrite_output_dir": False,
        "overwrite_cache": False,
        # data directory
        "output_dir": "koelectra-base-v3-korquad",
        "data_dir": "../data",
        "train_file": "KorQuAD_v1.0_train.json",
        # NOTE(review): predict_file is the same file as train_file, so evaluation
        # would run on the training data -- confirm whether the dev file was intended.
        "predict_file": "KorQuAD_v1.0_train.json",
        "cache_dir": "",
        "version_2_with_negative": False,
    }
    config = Namespace(**settings)

    if config.doc_stride >= config.max_seq_length - config.max_query_length:
        logger.warning(
            "WARNING - You've set a doc stride which may be superior to the document length in some "
            "examples. This could result in errors when building features from the examples. Please reduce the doc "
            "stride or increase the maximum length to ensure the features are correctly built."
        )

    # Refuse to silently overwrite the results of a previous training run.
    if (
        os.path.exists(config.output_dir)
        and os.listdir(config.output_dir)
        and config.do_train
        and not config.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                config.output_dir
            )
        )

    # setup GPU & distributed learning
    if config.local_rank == -1 or config.no_cuda:
        # Single-process run: pick MPS, CUDA or CPU.
        if config.use_mps:
            device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        else:
            device = torch.device(
                "cuda" if torch.cuda.is_available() and not config.no_cuda else "cpu"
            )
    else:
        # initializes the distributed backend which will take care of synchronizing nodes/GPUs
        # only cuda is available in this case
        if not torch.cuda.is_available():
            raise ValueError(
                "CUDA is not available on this device. Set local_rank to -1 to disable the cuda usage."
            )
        torch.cuda.set_device(config.local_rank)
        device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        # Each distributed process drives exactly one GPU. The original wrote to
        # a misspelled `ngpu` attribute, which left `n_gpu` untouched.
        config.n_gpu = 1

    config.device = device

    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        config.local_rank,
        device,
        config.n_gpu,
        bool(config.local_rank != -1),
        config.fp16,
    )

    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(config.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()

    return config


def set_seed(config):
    """Seed every random number generator used here and request deterministic kernels.

    Reads ``config.seed`` and ``config.n_gpu``. Seeds Python's ``random``,
    NumPy and PyTorch, exports ``PYTHONHASHSEED``, disables the cuDNN
    benchmark mode and enables PyTorch's deterministic algorithms. When
    ``config.n_gpu`` is positive the CUDA generators are seeded as well.

    See https://pytorch.org/docs/stable/notes/randomness.html
    """
    seed = config.seed

    # Same seed for the three host-side generators.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    os.environ["PYTHONHASHSEED"] = str(seed)

    if config.n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)


def seed_worker(worker_id):
    """Re-seed NumPy and ``random`` from the current PyTorch seed.

    The seed is ``torch.initial_seed()`` reduced modulo 2**32. ``worker_id``
    is accepted but not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)


# Build the run configuration (emulate_config also stores the device on config.device)
# and seed the random number generators before anything is instantiated.
config = emulate_config()
set_seed(config)

# Load pretrained model and tokenizer
if config.local_rank not in [-1, 0]:
    # Distributed runs only: every rank other than 0 waits here so that only the first
    # process downloads model & vocab. The matching barrier for rank 0 comes after the loads.
    torch.distributed.barrier()

config.model_type = config.model_type.lower()
# config_name / tokenizer_name are optional: an empty string falls back to model_name_or_path.
# An empty cache_dir is passed on as None.
model_config = AutoConfig.from_pretrained(
    config.config_name if config.config_name else config.model_name_or_path,
    cache_dir=config.cache_dir if config.cache_dir else None,
)
tokenizer = AutoTokenizer.from_pretrained(
    config.tokenizer_name if config.tokenizer_name else config.model_name_or_path,
    do_lower_case=config.do_lower_case,
    cache_dir=config.cache_dir if config.cache_dir else None,
    use_fast=False,  # SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handling
)
model = AutoModelForQuestionAnswering.from_pretrained(
    config.model_name_or_path,
    from_tf=bool(".ckpt" in config.model_name_or_path),  # a ".ckpt" path is treated as a TensorFlow checkpoint
    config=model_config,
    cache_dir=config.cache_dir if config.cache_dir else None,
)

if config.local_rank == 0:
    # Rank 0 has finished loading: joining the barrier releases the ranks that were
    # waiting above.
    torch.distributed.barrier()

# Move the model to the device chosen by emulate_config.
model.to(config.device)
logger.info("Training/evaluation parameters %s", config)

# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if config.fp16 is set.
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
# remove the need for this code, but it is still valid.
if config.fp16:
    try:
        import apex

        apex.amp.register_half_function(torch, "einsum")
    except ImportError:
        # apex is an optional dependency that is only needed for fp16 training.
        raise ImportError(
            "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
        )

In [2]:
def load_and_cache_example(config, tokenizer, evaluate=False, output_examples=False):
    """Return the SQuAD-style dataset for one split, using an on-disk cache when possible.

    The cache file lives in ``config.data_dir`` (or ``"."`` when that is empty)
    and is named after the split, the last path component of
    ``config.model_name_or_path`` and ``config.max_seq_length``. On a cache
    miss (or when ``config.overwrite_cache`` is set) the examples are read,
    converted to features with ``squad_convert_examples_to_features`` and, on
    rank -1/0, written back to the cache.

    Args:
        config: run configuration namespace.
        tokenizer: tokenizer handed to the feature conversion.
        evaluate: load the "dev" split (``config.predict_file``) instead of the
            "train" split (``config.train_file``).
        output_examples: also return the examples and features.

    Returns:
        ``dataset`` alone, or ``(dataset, examples, features)`` when
        ``output_examples`` is true.

    Raises:
        ImportError: if neither a data directory nor the split's file is
            configured and ``tensorflow_datasets`` is not installed.
    """
    # In distributed training only the first process builds the features; the
    # other ranks wait here and read the cache afterwards.
    waits_for_first_rank = config.local_rank not in [-1, 0]
    if waits_for_first_rank and not evaluate:
        torch.distributed.barrier()

    # Work out where the cached features for this split/model/length live.
    input_dir = config.data_dir or "."
    split_tag = "dev" if evaluate else "train"
    model_tag = [
        part for part in config.model_name_or_path.split("/") if part
    ].pop()
    cache_name = "cached_{}_{}_{}".format(
        split_tag, model_tag, str(config.max_seq_length)
    )
    cached_features_file = os.path.join(input_dir, cache_name)

    cache_hit = os.path.exists(cached_features_file) and not config.overwrite_cache
    if cache_hit:
        logger.info("Loading features from cached file %s", cached_features_file)
        cached = torch.load(cached_features_file)
        features = cached["features"]
        dataset = cached["dataset"]
        examples = cached["examples"]
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        # Without a data directory and without a file for this split, fall back
        # to the SQuAD dataset shipped with tensorflow_datasets.
        if not config.data_dir and not (
            config.predict_file if evaluate else config.train_file
        ):
            try:
                import tensorflow_datasets as tfds
            except ImportError:
                raise ImportError(
                    "If not data_dir is specified, tensorflow_datasets needs to be installed."
                )

            if config.version_2_with_negative:
                logger.warning(
                    "tensorflow_datasets does not handle version 2 of SQuAD."
                )

            tfds_examples = tfds.load("squad")
            examples = SquadV1Processor().get_examples_from_dataset(
                tfds_examples, evaluate=evaluate
            )
        else:
            processor_class = (
                SquadV2Processor
                if config.version_2_with_negative
                else SquadV1Processor
            )
            processor = processor_class()
            if evaluate:
                read_examples, split_file = (
                    processor.get_dev_examples,
                    config.predict_file,
                )
            else:
                read_examples, split_file = (
                    processor.get_train_examples,
                    config.train_file,
                )
            examples = read_examples(config.data_dir, filename=split_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=config.max_seq_length,
            doc_stride=config.doc_stride,
            max_query_length=config.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=config.threads,
        )

        # Only the first process (or a non-distributed run) writes the cache.
        if config.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            payload = {"features": features, "dataset": dataset, "examples": examples}
            torch.save(payload, cached_features_file)

    # Rank 0 is done: release the ranks waiting at the barrier at the top.
    if config.local_rank == 0 and not evaluate:
        torch.distributed.barrier()

    return (dataset, examples, features) if output_examples else dataset

In [3]:
# Load the training split (evaluate=False). With output_examples=True the helper
# returns the examples and features in addition to the dataset.
dataset, examples, features = load_and_cache_example(
    config, tokenizer, evaluate=False, output_examples=True
)

04/30/2024 03:32:42 - INFO - __main__ - Loading features from cached file ../data/cached_train_koelectra-base-v3-discriminator_512


In [4]:
# Take the first item yielded by the dataset; the cells below index it as batch[0]..batch[7].
batch = next(iter(dataset))

In [5]:
# Turn the single example into a batch of size one (unsqueeze adds a leading dimension)
# and place every tensor on the device the model was moved to. The device was
# previously hard-coded to "mps", which breaks when emulate_config fell back to
# another device.
input_names = (
    "input_ids",
    "attention_mask",
    "token_type_ids",
    "start_positions",
    "end_positions",
)
# zip stops after the five named fields; the remaining entries of `batch` are not model inputs here.
inputs = {
    name: tensor.unsqueeze(0).to(config.device)
    for name, tensor in zip(input_names, batch)
}

In [6]:
# Single forward pass on the one-example batch; `inputs` includes the start/end positions.
result = model(**inputs)

In [10]:
# squad dataset analysis

# Locate the span of tokens whose token_type_id is 1 (the second segment of the input).
token_type_ids = batch[2].tolist()
try:
    start_idx = token_type_ids.index(1)
except ValueError:
    # no second segment at all: keep the original default
    start_idx = 0
try:
    end_idx = token_type_ids.index(0, start_idx)
except ValueError:
    # The second segment runs to the end of the sequence (no trailing type-0 tokens).
    # Previously end_idx stayed 0 here and the printed segments were wrong.
    end_idx = len(token_type_ids)

# The end position is an inclusive token index, so the slice must always extend one
# past it. Previously the last answer token was dropped whenever start != end.
answer_start, answer_end = int(batch[3]), int(batch[4])

print(
    f"""
question: {examples[0].question_text}
context: {examples[0].context_text}
answer: {examples[0].answer_text}
encoded inputs: {batch[0].tolist()}
decoded inputs: {tokenizer.decode(batch[0])}
length of tokenized tokens: {len(batch[0])}
attention masks: {batch[1].tolist()}
token_type_ids: {batch[2].tolist()}
-> {tokenizer.decode(batch[0][:start_idx])}
-> {tokenizer.decode(batch[0][start_idx:end_idx])}
-> {tokenizer.decode(batch[0][end_idx:])}
start positions: {batch[3]}
end positions: {batch[4]}
answer: {tokenizer.decode(batch[0][answer_start:answer_end + 1])}
class index: {batch[5]}
p_mask: {batch[6].long().tolist()}
is_impossible: {batch[7].long().tolist()}
"""
)


question: 바그너는 괴테의 파우스트를 읽고 무엇을 쓰고자 했는가?
context: 1839년 바그너는 괴테의 파우스트을 처음 읽고 그 내용에 마음이 끌려 이를 소재로 해서 하나의 교향곡을 쓰려는 뜻을 갖는다. 이 시기 바그너는 1838년에 빛 독촉으로 산전수전을 다 걲은 상황이라 좌절과 실망에 가득했으며 메피스토펠레스를 만나는 파우스트의 심경에 공감했다고 한다. 또한 파리에서 아브네크의 지휘로 파리 음악원 관현악단이 연주하는 베토벤의 교향곡 9번을 듣고 깊은 감명을 받았는데, 이것이 이듬해 1월에 파우스트의 서곡으로 쓰여진 이 작품에 조금이라도 영향을 끼쳤으리라는 것은 의심할 여지가 없다. 여기의 라단조 조성의 경우에도 그의 전기에 적혀 있는 것처럼 단순한 정신적 피로나 실의가 반영된 것이 아니라 베토벤의 합창교향곡 조성의 영향을 받은 것을 볼 수 있다. 그렇게 교향곡 작곡을 1839년부터 40년에 걸쳐 파리에서 착수했으나 1악장을 쓴 뒤에 중단했다. 또한 작품의 완성과 동시에 그는 이 서곡(1악장)을 파리 음악원의 연주회에서 연주할 파트보까지 준비하였으나, 실제로는 이루어지지는 않았다. 결국 초연은 4년 반이 지난 후에 드레스덴에서 연주되었고 재연도 이루어졌지만, 이후에 그대로 방치되고 말았다. 그 사이에 그는 리엔치와 방황하는 네덜란드인을 완성하고 탄호이저에도 착수하는 등 분주한 시간을 보냈는데, 그런 바쁜 생활이 이 곡을 잊게 한 것이 아닌가 하는 의견도 있다.
answer: 교향곡
encoded inputs: [2, 29064, 4034, 28889, 4234, 14623, 6334, 4110, 3244, 4219, 6570, 4292, 3063, 4219, 4195, 3771, 4034, 4070, 35, 3, 16095, 4272, 4556, 29064, 4034, 28889, 4234, 14623, 6334, 4292, 6396, 3244, 4219, 2126, 6434, 4073, 6365, 4007, 11666, 3240, 4110, 

In [11]:
# Drop the references created for the exploration cells above; the training cell
# at the end loads the dataset again.
del dataset, features, examples

In [13]:
# Parameters whose name contains one of these markers get no weight decay:
# biases and LayerNorm weights are conventionally left unregularized.
no_decay = ["bias", "LayerNorm.weight"]

# Two parameter groups: index 0 is decayed with config.weight_decay, index 1 is not.
optimizer_grouped_parameters = [
    {"params": [], "weight_decay": config.weight_decay},
    {"params": [], "weight_decay": 0.0},
]
for param_name, param in model.named_parameters():
    # int(False) == 0 -> decayed group, int(True) == 1 -> non-decayed group
    group_index = int(any(marker in param_name for marker in no_decay))
    optimizer_grouped_parameters[group_index]["params"].append(param)

optimizer = torch.optim.AdamW(
    optimizer_grouped_parameters,
    lr=config.learning_rate,
    eps=config.adam_epsilon,
)

# Linear warmup over config.warmup_steps followed by a linear decay; the total of
# 10000 training steps is a fixed value used for this exploratory cell only.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config.warmup_steps,
    num_training_steps=10000,
)

In [21]:
def train(config, train_dataset, model, tokenizer):
    """Fine-tune ``model`` on ``train_dataset`` and periodically log and checkpoint.

    Builds the data loader, an AdamW optimizer with two weight-decay groups and
    a linear warmup schedule, optionally restores optimizer/scheduler state and
    the step counter from ``config.model_name_or_path``, then runs the
    training loop with gradient accumulation and gradient clipping.

    Args:
        config: run configuration namespace; ``train_batch_size`` (and
            ``num_train_epochs`` when ``max_steps`` > 0) are written back to it.
        train_dataset: dataset whose items are tuples indexed as
            0 input_ids, 1 attention_mask, 2 token_type_ids, 3 start_positions,
            4 end_positions (5 cls_index, 6 p_mask, 7 is_impossible are only
            read for xlnet/xlm).
        model: question-answering model, already placed on ``config.device``.
        tokenizer: saved next to every model checkpoint and passed to
            ``evaluate``.

    Returns:
        tuple: ``(global_step, tr_loss / global_step)``.

    Raises:
        ImportError: if ``config.fp16`` is set but apex is not installed.
    """
    from tqdm import tqdm, trange

    # Only the main process (or a non-distributed run) writes TensorBoard logs.
    if config.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    config.train_batch_size = config.per_gpu_train_batch_size * max(1, config.n_gpu)
    train_sampler = (
        RandomSampler(train_dataset)
        if config.local_rank == -1
        else DistributedSampler(train_dataset)
    )
    train_dataloader = DataLoader(
        train_dataset, batch_size=config.train_batch_size, sampler=train_sampler
    )

    # Total number of optimizer updates: either capped by max_steps or derived
    # from the number of epochs.
    if config.max_steps > 0:
        t_total = config.max_steps
        config.num_train_epochs = (
            config.max_steps
            // (len(train_dataloader) // config.gradient_accumulation_steps)
            + 1
        )
    else:
        t_total = (
            len(train_dataloader)
            // config.gradient_accumulation_steps
            * config.num_train_epochs
        )

    # prep optimizer and scheduler
    no_decay = ["bias", "LayerNorm.weight"]
    # Bias terms are often excluded from weight decay to prevent the model from becoming overly regularized.
    # LayerNorm weights are often excluded because they are scale and shift parameters that do not benefit from weight decay.
    optimizer_grouped_parameters = [
        # separate parameter groups to apply different weight decay
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": config.weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(
        optimizer_grouped_parameters, lr=config.learning_rate, eps=config.adam_epsilon
    )
    # Linear warmup over warmup_steps, then linear decay until t_total updates.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=config.warmup_steps, num_training_steps=t_total
    )

    # check if saved optimizer or scheduler states exists
    optimizer_path = os.path.join(config.model_name_or_path, "optimizer.pt")
    scheduler_path = os.path.join(config.model_name_or_path, "scheduler.pt")
    if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
        # load optimizer and scheduler states
        optimizer.load_state_dict(torch.load(optimizer_path))
        scheduler.load_state_dict(torch.load(scheduler_path))

    if config.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please intall apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err

        # amp.initialize must receive the model and optimizer it should patch;
        # it was previously called with no arguments.
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=config.fp16_opt_level
        )

    # multi-gpu training (should be after apex fp16 initialization)
    if config.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # distributed training (should be after apex fp16 initialization)
    if config.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[config.local_rank],
            output_device=config.local_rank,
            find_unused_parameters=True,
        )

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", config.num_train_epochs)
    logger.info(
        "  Instantaneous batch size per GPU = %d", config.per_gpu_train_batch_size
    )
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        config.train_batch_size
        * config.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if config.local_rank != -1 else 1),
    )
    logger.info(
        "  Gradient Accumulation steps = %d", config.gradient_accumulation_steps
    )
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0

    # check if continuing training from a checkpoint
    if os.path.exists(config.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            # (a path such as ".../checkpoint-<step>"; anything else raises ValueError below)
            checkpoint_suffix = config.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            # One optimizer update consumes `gradient_accumulation_steps` batches, so an
            # epoch contains len(train_dataloader) // gradient_accumulation_steps updates.
            updates_per_epoch = (
                len(train_dataloader) // config.gradient_accumulation_steps
            )
            epochs_trained = global_step // updates_per_epoch
            steps_trained_in_current_epoch = global_step % updates_per_epoch

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info(
                "  Will skip the first %d steps in the first epoch",
                steps_trained_in_current_epoch,
            )
        except ValueError:
            # The path does not end in a step number: this is a fresh fine-tuning run.
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()  # similar to optimizer.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(config.num_train_epochs),
        desc="Epoch",
        disable=config.local_rank not in [-1, 0],
    )

    # set random seed
    set_seed(config)

    for _ in train_iterator:
        epoch_iterator = tqdm(
            train_dataloader, desc="Iteration", disable=config.local_rank not in [-1, 0]
        )
        for step, batch in enumerate(epoch_iterator):
            # skip already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            # move input_ids, attention_mask, token_type_ids, start_positions and end_positions
            # (and the remaining tensors of the batch) to the training device
            batch = tuple(t.to(config.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            # these model types are not given token_type_ids
            if config.model_type in [
                "xlm",
                "roberta",
                "distilbert",
                "camembert",
                "bart",
                "longformer",
            ]:
                del inputs["token_type_ids"]

            if config.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if config.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                if hasattr(model, "config") and hasattr(model.config, "lang2id"):
                    # NOTE(review): config.lang_id is not among the settings built by
                    # emulate_config; it must be set before training such a model.
                    inputs.update(
                        {
                            "langs": (
                                # keyword is `dtype` (it was misspelled `dtypes`)
                                torch.ones(batch[0].shape, dtype=torch.int64)
                                * config.lang_id
                            ).to(config.device)
                        }
                    )

            # forward
            outputs = model(**inputs)  # type(outputs): QuestionAnsweringModelOutput
            loss = outputs[0]  # to_tuple method of QuestionAnsweringModelOutput

            if config.n_gpu > 1:
                # mean() to average on multi-gpu parallel (not distributed) training
                loss = loss.mean()
            if config.gradient_accumulation_steps > 1:
                loss = loss / config.gradient_accumulation_steps

            if config.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % config.gradient_accumulation_steps == 0:
                if config.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.max_grad_norm
                    )
                else:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.max_grad_norm
                    )

                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1  # increase only when gradient is updated

                # log metrics
                if (
                    config.local_rank in [-1, 0]
                    and config.logging_steps > 0
                    and global_step % config.logging_steps == 0
                ):
                    # evaluate only train with single gpu otherwise metrics may not average well
                    if config.local_rank == -1 and config.evaluate_during_training:
                        # tolerate an evaluate() that returns nothing
                        results = evaluate(config, model, tokenizer) or {}
                        for k, v in results.items():
                            tb_writer.add_scalar("eval_{}".format(k), v, global_step)
                    # get_last_lr() is the supported way to read the current rate
                    tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
                    tb_writer.add_scalar(
                        "loss",
                        (tr_loss - logging_loss) / config.logging_steps,
                        global_step,
                    )
                    logging_loss = tr_loss

                # save model checkpoints
                if (
                    config.local_rank in [-1, 0]
                    and config.save_steps > 0
                    and global_step % config.save_steps == 0
                ):
                    output_dir = os.path.join(
                        config.output_dir, "checkpoint-{}".format(global_step)
                    )
                    # take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, "module") else model
                    # save_pretrained so the checkpoint can be reloaded with from_pretrained
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(config, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(
                        optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")
                    )
                    torch.save(
                        scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")
                    )
                    logger.info(
                        "Saving optimizer and scheduler states to %s", output_dir
                    )

            # stop the epoch once the step budget is exhausted
            # (this read the non-existent attribute `config.max_step` before)
            if config.max_steps > 0 and global_step > config.max_steps:
                epoch_iterator.close()
                break

        if config.max_steps > 0 and global_step > config.max_steps:
            train_iterator.close()
            break

    if config.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step


def evaluate(config, model, tokenizer):
    """Placeholder for model evaluation; currently computes nothing.

    ``train`` iterates over ``results.items()`` on the value returned here, so
    the placeholder returns an empty mapping instead of ``None``.

    Args:
        config: run configuration namespace (unused for now).
        model: model to evaluate (unused for now).
        tokenizer: tokenizer matching the model (unused for now).

    Returns:
        dict: metric name -> value; always empty until evaluation is implemented.
    """
    # TODO: implement the actual evaluation.
    logging.getLogger(__name__).warning(
        "evaluate() is not implemented yet; no metrics are reported."
    )
    return {}

In [None]:
# Entry point of the notebook: when training is enabled, load the training split
# (dataset only, no examples/features) and run the training loop.
if config.do_train:
    train_dataset = load_and_cache_example(
        config, tokenizer, evaluate=False, output_examples=False
    )
    # train returns the final step counter and the accumulated loss divided by it
    global_step, tr_loss = train(config, train_dataset, model, tokenizer)
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)