In [1]:
batch_size = 8
device_iterations = 8
num_epochs = 3

In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

raw_datasets = load_dataset("imdb")

tokenizer = AutoTokenizer.from_pretrained(
    "google/electra-small-discriminator", 
    return_dict=False, 
    strict=False
)
def tokenize_function(examples):
    return tokenizer(examples['text'], truncation=True, padding = "max_length", max_length=128)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["text"])

tokenized_datasets.set_format("torch")

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(100))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(100))

Reusing dataset imdb (/home/adamw/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/adamw/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-887b23ecc981d6fc.arrow
Loading cached processed dataset at /home/adamw/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-75cc5b06e4079d09.arrow
Loading cached processed dataset at /home/adamw/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-0a04fa3fd453fef7.arrow
Loading cached shuffled indices for dataset at /home/adamw/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-1712cca7dfe0d81e.arrow
Loading cached shuffled indices for dataset at /home/adamw/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-e444912a4fb84160.arrow


In [19]:
from transformers import AutoModelForSequenceClassification, get_scheduler
import poptorch
import torch

model = AutoModelForSequenceClassification.from_pretrained(
    "google/electra-base-discriminator", 
    num_labels=2,
    return_dict=False
)


model.electra.embeddings.position_embeddings = poptorch.BeginBlock(
    layer_to_call=model.electra.embeddings.position_embeddings,
    ipu_id=1
)

layer_ipu = [0,0,0,1,1,1,2,2,2,3,3,3]

for index, layer in enumerate(model.electra.encoder.layer):
    ipu = layer_ipu[index]
    model.electra.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)

Some weights of the model checkpoint at google/electra-base-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-base-discriminator and are newly initialized: ['classifier.d

In [20]:
opts = poptorch.Options().deviceIterations(device_iterations)

In [21]:
class WrappedModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    
    def forward(self, input_ids, token_type_ids, attention_mask, labels, position_ids=None, 
                head_mask=None, inputs_embeds=None):
        outputs, something = self.model.forward(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            labels=labels,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds
        )
        print(something)
        if self.model.training:
            final_loss = poptorch.identity_loss(outputs.loss, reduction="none")
            return final_loss, outputs
        return outputs

    def __getattr__(self, attr):
        try:
            return torch.nn.Module.__getattr__(self, attr)
        except AttributeError:
            return getattr(self.model, attr)

In [22]:
wm = WrappedModel(model)
wm.train()

from poptorch.optim import AdamW
from torch import float16, float32
from poptorch import DataLoader


optimizer = poptorch.optim.AdamW(
    params=wm.parameters(),
    lr=5e-5,
    weight_decay=0,
    eps=1e-6,
    bias_correction=False,
    loss_scaling=1.0,
    accum_type=float16,
    first_order_momentum_accum_type=float16,
    second_order_momentum_accum_type=float32
)

tm = poptorch.trainingModel(wm, opts, optimizer)

In [23]:
from poptorch import DataLoader
train_dataloader = DataLoader(options=opts, dataset=small_train_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
test_dataloader = DataLoader(options=opts, dataset=small_eval_dataset, batch_size=batch_size, drop_last=True)

num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

In [24]:
from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch["labels"]=batch["label"]
        del batch["label"]
        lr_scheduler.step()
        outputs = tm(**batch)
        progress_bar.update(1)

  0%|          | 0/3 [00:00<?, ?it/s]



tensor([[-0.1509,  0.0389],
        [-0.0984, -0.0087],
        [-0.0605, -0.0780],
        [-0.1464,  0.0123],
        [-0.0714, -0.0025],
        [-0.0938, -0.0438],
        [-0.1306, -0.0215],
        [-0.0930, -0.0462]], grad_fn=<AddmmBackward>)


AttributeError: 'Tensor' object has no attribute 'loss'