In [1]:
import os
import poptorch
import torch

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler, AdamW
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

import numpy as np

In [2]:
raw_datasets = load_dataset("imdb")
tokenizer = AutoTokenizer.from_pretrained("google/electra-base-generator")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_function, remove_columns=['text'])

train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["test"].shuffle(seed=42)

train_dataset = train_dataset.rename_column(original_column_name='label', new_column_name='labels')
eval_dataset = eval_dataset.rename_column(original_column_name='label', new_column_name='labels')

train_dataset.set_format(type='torch')
eval_dataset.set_format(type='torch')

opts = poptorch.Options().deviceIterations(8)
train_dataloader = poptorch.DataLoader(
    options=opts, 
    dataset=train_dataset, 
    shuffle=True, 
    batch_size=4, 
    drop_last=True
)

val_opts = poptorch.Options().deviceIterations(8).anchorMode(poptorch.AnchorMode.All)
eval_dataloader = poptorch.DataLoader(
    options=val_opts, 
    dataset=eval_dataset, 
    batch_size=4, 
    drop_last=True
)

Reusing dataset imdb (/home/kamilp/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)
Loading cached processed dataset at /home/kamilp/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-57f371c49ae71cc3.arrow
Loading cached processed dataset at /home/kamilp/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-fe6a5da606659f22.arrow
Loading cached processed dataset at /home/kamilp/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-9284272e82823a4e.arrow
Loading cached shuffled indices for dataset at /home/kamilp/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-c61fe55def0a4b5c.arrow
Loading cached shuffled indices for dataset at /home/kamilp/.cache

In [3]:
class Wrapped(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model       

    def forward(self, input_ids, token_type_ids, attention_mask, labels):
        loss, logits = self.model.forward(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        
        if self.model.training:
            final_loss = poptorch.identity_loss(loss, reduction="none")
            return final_loss, logits
        
        return loss, logits
    
model = AutoModelForSequenceClassification.from_pretrained("google/electra-base-generator", num_labels=2, return_dict=False)
optimizer = AdamW(model.parameters(), lr=1e-5)

Some weights of the model checkpoint at google/electra-base-generator were not used when initializing ElectraForSequenceClassification: ['generator_predictions.LayerNorm.bias', 'generator_predictions.dense.bias', 'generator_predictions.LayerNorm.weight', 'generator_lm_head.bias', 'generator_lm_head.weight', 'generator_predictions.dense.weight']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-base-generator and are newly initialized

In [4]:
model.electra.embeddings = poptorch.BeginBlock(model.electra.embeddings, "Embedding", ipu_id=0)

layer_ipu = [0,0,0,1,1,1,2,2,2,3,3,3]
for index, layer in enumerate(model.electra.encoder.layer):
    ipu = layer_ipu[index]
    model.electra.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
    
model.classifier = poptorch.BeginBlock(model.classifier, "Classifier", ipu_id=3)

In [5]:
trainingModel = poptorch.trainingModel(Wrapped(model), options=opts)
inferenceModel = poptorch.inferenceModel(Wrapped(model), options=val_opts)

In [6]:
trainingModel.compile(**next(iter(train_dataloader)))

Graph compilation: 100%|██████████| 100/100 [06:54<00:00]


In [7]:
inferenceModel.compile(**next(iter(eval_dataloader)))

Graph compilation: 100%|██████████| 100/100 [01:47<00:00]


In [8]:
y_pred, y_true, losses = [], [], []
for batch in eval_dataloader:
    y_true.extend(batch['labels'].tolist())

    with torch.no_grad():
        loss, logits = inferenceModel(**batch)

    losses.extend(loss.tolist())
    y_pred.extend(logits.argmax(dim=1).tolist())

acc = accuracy_score(y_true, y_pred)
val_loss = np.mean(losses)

print(acc, val_loss)

0.49991997439180536 0.6932223301080369


In [9]:
train_loss, val_loss, acc = 0,0,0
progress_bar = tqdm(range(2), total=2)

for epoch in progress_bar:
    for batch in train_dataloader:
        loss, logits = trainingModel(**batch)
        train_loss = loss.item()
        
        progress_bar.set_postfix({'epoch':epoch, 'train_loss': train_loss, 'val_loss':val_loss, 'acc': acc})
        
    trainingModel.copyWeightsToHost()
    inferenceModel.copyWeightsToDevice()

    y_pred, y_true, losses = [], [], []
    for batch in eval_dataloader:
        y_true.extend(batch['labels'].tolist())
        
        with torch.no_grad():
            loss, logits = inferenceModel(**batch)
            
        losses.extend(loss.tolist())
        y_pred.extend(logits.argmax(dim=1).tolist())
        
    acc = accuracy_score(y_true, y_pred)
    val_loss = np.mean(losses)
    
    print(acc, val_loss)
    
    progress_bar.set_postfix({'epoch':epoch, 'train_loss': train_loss, 'val_loss':val_loss, 'acc': acc})

  0%|          | 0/2 [00:00<?, ?it/s]

0.8753201024327785 0.30010253097385015
0.9041293213828425 0.2394368511843155


In [10]:
train_dataloader.terminate()
eval_dataloader.terminate()