In [1]:
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
# Copyright 2020 The HuggingFace Team All rights reserved.
import collections
from tqdm.auto import tqdm
import numpy as np
import torch
from transformers import BertTokenizerFast, default_data_collator
import torch
import poptorch

In [2]:
class PadCollate:
    """
    Collate into a batch and pad the batch up to a fixed size.
    """
    def __init__(self, batch_size, padding_val_dict=None):
        self.batch_size = batch_size
        self.padding_val_dict = padding_val_dict

    def pad_tensor(self, x, val):
        pad_size = list(x.shape)
        pad_size[0] = self.batch_size - x.size(0)
        return torch.cat([x, val*torch.ones(*pad_size, dtype=x.dtype)], dim=0)

    def __call__(self, batch):
        size = len(batch)
        batch = default_data_collator(batch)
        if size < self.batch_size:
            for k in batch.keys():
                batch[k] = self.pad_tensor(batch[k], self.padding_val_dict[k])
        return batch

max_seq_length = 384
doc_stride = 128
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")


# `prepare_train_features` comes unmodified from
# https://github.com/huggingface/transformers/blob/master/examples/pytorch/question-answering/run_qa.py
def prepare_train_features(examples):
    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    pad_on_right = tokenizer.padding_side == "right"
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_seq_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [3]:
from datasets import load_dataset, load_metric

datasets = load_dataset("squad")
train_dataset = datasets["train"]

train_dataset = train_dataset.map(
    prepare_train_features,
    batched=True,
    num_proc=1,
    remove_columns=train_dataset.column_names,
    load_from_cache_file=True,
)

Reusing dataset squad (/home/adamw/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /home/adamw/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-db2e65aee1bcfc2e.arrow


In [23]:
import torch
import poptorch
import transformers

class Wrapped(transformers.BertForQuestionAnswering):
    def __init__(self):
        super().__init__(transformers.BertConfig())
        self.bert.embeddings = poptorch.BeginBlock(self.bert.embeddings, "Embedding", ipu_id=0)

        for index, layer in enumerate(self.bert.encoder.layer):
            self.bert.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=1)
       
        self.qa_outputs = poptorch.BeginBlock(self.qa_outputs, "QA Outputs", ipu_id=2)

    def forward(self, input_ids, attention_mask, token_type_ids, start_positions=None, end_positions=None):
        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "start_positions": start_positions,
            "end_positions": end_positions
        }
        output = super().forward(**inputs)
        
        if self.training:
            final_loss = poptorch.identity_loss(output.loss, reduction="none")
            return final_loss, output.start_logits, output.end_logits
        else:
            return output.start_logits, output.end_logits

In [37]:
import popart 

opts = poptorch.Options()
opts.deviceIterations(8)
opts.autoRoundNumIPUs(True)
#opts.Training.setAutomaticLossScaling(True)
opts.anchorMode(poptorch.AnchorMode.Sum)
opts.setExecutionStrategy(
    poptorch.PipelinedExecution(
        poptorch.AutoStage.AutoIncrement
    )
)
opts.Precision.enableStochasticRounding(True)
opts.Precision.setPartialsType(torch.float16)
opts._Popart.set("disableGradAccumulationTensorStreams", True)
opts._Popart.set("subgraphCopyingStrategy", int(popart.SubgraphCopyingStrategy.JustInTime))
opts._Popart.set("outlineThreshold", 10.0)
opts._Popart.set("accumulateOuterFragmentSettings.schedule",
                 int(popart.AccumulateOuterFragmentSchedule.OverlapMemoryOptimized))
opts._Popart.set("accumulateOuterFragmentSettings.excludedVirtualGraphs", ["0"])


mem_prop = {
    f'IPU{i}': 0.2
    for i in range(5)
}
opts.setAvailableMemoryProportion(mem_prop)

<poptorch.options.Options at 0x7f7453071f98>

In [38]:
sequence_length = 384
samples_per_step = 2
num_epochs = 3

train_dataloader = poptorch.DataLoader(
    options=opts, 
    dataset=train_dataset, 
    shuffle=True, 
    batch_size=8,
    drop_last=True,
    collate_fn=PadCollate(
        samples_per_step,
        {"input_ids": 0,
         "attention_mask": 0,
         "token_type_ids": 0,
         "start_positions": sequence_length,
         "end_positions": sequence_length})
)

In [39]:
from torch import float16, float32
from transformers import get_linear_schedule_with_warmup
from poptorch import DataLoader


model_ipu = Wrapped().half()

regularized_params = []
non_regularized_params = []
for param in model_ipu.parameters():
    if param.requires_grad:
        if len(param.shape) == 1:
            non_regularized_params.append(param)
        else:
            regularized_params.append(param)
params = [
    {"params": regularized_params, "weight_decay": 0},
    {"params": non_regularized_params, "weight_decay": 0}
]
optimizer = poptorch.optim.AdamW(
    params,
    lr=5e-5,
    weight_decay=0,
    eps=1e-6,
    bias_correction=False,
    loss_scaling=1.0,
    accum_type=float16,
    first_order_momentum_accum_type=float16,
    second_order_momentum_accum_type=float32
)

num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    0,
    num_epochs * len(train_dataloader)
)

In [40]:
model_ipu.train()
training_model = poptorch.trainingModel(model_ipu, opts, optimizer)

In [None]:
import time

sample_batch = next(iter(train_dataloader))
start_compile = time.perf_counter()

training_model.compile(sample_batch["input_ids"],
                               sample_batch["attention_mask"],
                               sample_batch["token_type_ids"],
                               sample_batch["start_positions"],
                               sample_batch["end_positions"])

duration_compilation = time.perf_counter() - start_compile

Graph compilation:   0%|          | 0/100 [43:12<?]
Graph compilation:  34%|███▍      | 34/100 [03:13<05:17]

In [35]:
from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))


for epoch in range(num_epochs):
    for batch in train_dataloader:
        outputs = training_model(
            batch["input_ids"],
            batch["attention_mask"],
            batch["token_type_ids"],
            batch["start_positions"],
            batch["end_positions"]
        )
        
        loss = outputs.loss
        loss.backward()
        
        lr_scheduler.step()
        training_model.setOptimizer(optimizer)

        progress_bar.set_description(
            f"Epoch: {epoch}, LR={lr_scheduler.get_last_lr()[0]:.2e}, loss={loss:3.3f}"
        )
        progress_bar.update(1)

  0%|          | 0/4149 [00:00<?, ?it/s]



Graph compilation:   0%|          | 0/100 [00:00<?][A[A

Graph compilation:   3%|▎         | 3/100 [00:55<29:52][A[A

Graph compilation:   4%|▍         | 4/100 [00:57<21:08][A[A

Graph compilation:   7%|▋         | 7/100 [01:31<18:53][A[A

Graph compilation:  13%|█▎        | 13/100 [01:32<06:41][A[A

Graph compilation:  15%|█▌        | 15/100 [01:35<05:35][A[A

Graph compilation:  16%|█▌        | 16/100 [01:36<05:00][A[A

Graph compilation:  17%|█▋        | 17/100 [01:38<04:23][A[A

Graph compilation:  20%|██        | 20/100 [01:38<02:29][A[A

Graph compilation:  21%|██        | 21/100 [01:55<05:47][A[A

Graph compilation:  21%|██        | 21/100 [01:55<05:47][A[A

Graph compilation:  22%|██▏       | 22/100 [01:55<04:39][A[A

Graph compilation:  22%|██▏       | 22/100 [02:17<04:39][A[A

Graph compilation:  23%|██▎       | 23/100 [02:40<15:57][A[A

Graph compilation:  26%|██▌       | 26/100 [02:40<08:02][A[A

Graph compilation:  26%|██▌       | 26/100 [03

Error: In poptorch/python/poptorch.cpp:1220: 'popart_exception': Out of memory on tile 1472: 3631524 bytes used but tiles only have 638976 bytes of memory
Error raised in:
  [0] popart::popx::IrLowering::getExecutable()
  [1] popart::popx::Executablex::getPoplarExecutable()
  [2] popart::popx::Devicex::prepare()
  [3] popart::Session::prepareDevice(bool)
  [4] poptorch::Compiler::compileAndPrepareDevice()
  [5] popart::Session::prepareDevice: Poplar compilation
  [6] Compiler::compileAndPrepareDevice
  [7] LowerToPopart::compile


In [None]:
training_model.detachFromDevice()