In [1]:
# Switch off caching in the `datasets` library for this session via its
# `set_caching_enabled` toggle.
from datasets import set_caching_enabled
set_caching_enabled(False)

import pprint
# Pretty-printer that stops descending after 6 nesting levels and packs short
# items onto one line (`compact=True`).
pp = pprint.PrettyPrinter(depth=6, compact=True)
# NOTE: this rebinds the built-in `print` for the rest of the notebook. Every later
# `print(x)` goes through `pp.pprint`, which takes a single object (no `sep=`/`end=`,
# no multiple positional arguments) and shows strings with their quotes.
print = pp.pprint

In [None]:
from poutyne import Model

# Illustrative sketch of the Poutyne training workflow. The `...` placeholder stands
# in for setup that is not shown: `make_network` and `load_data` are not defined in
# this cell, so it is not runnable as written.
...

# Build the network and load the three data splits.
network = make_network()
X_train, y_train = load_data(subset="train")
X_val, y_val = load_data(subset="validation")
X_test, y_test = load_data(subset="test")

# Wrap the network together with an optimizer, a loss and metrics, all selected by
# their string names, and place it on the first CUDA device.
model = Model(
    network,
    "sgd",
    "cross_entropy",
    batch_metrics=["accuracy"],
    epoch_metrics=["f1"],
    device="cuda:0"
)

# Train for 5 epochs with mini-batches of 64, evaluating on the validation split.
model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=5,
    batch_size=64
)

# Evaluate on the held-out test split.
results = model.evaluate(X_test, y_test, batch_size=128)

In [None]:
from typing import Any, Callable, Dict, List, Tuple, Union

import torch
from transformers import default_data_collator


class TransformerCollator:
    """Collate a list of examples and split the result into an ``(x, y)`` pair.

    The examples are first merged by ``custom_collator`` (``default_data_collator``
    when none is given). The target ``y`` is then taken from the merged batch
    according to ``y_keys``.

    Args:
        y_keys: Which entries of the collated batch make up the target.
            ``None`` -> ``y`` is a 1-D tensor of NaNs with one entry per example
            (a placeholder target). A single string -> ``y`` is the value stored
            under that key. A list or tuple of strings -> ``y`` is a dict mapping
            each key to its value.
        custom_collator: Callable that turns the sequence of examples into a
            dict-like batch. Defaults to ``default_data_collator``.
        remove_labels: If true, target entries are removed from the batch that is
            returned as ``x``. With a single key that key is always removed; with
            several keys only those whose name contains ``"labels"`` are removed,
            the others stay available as model inputs.
    """

    def __init__(
        self,
        y_keys: Union[str, List[str], Tuple[str, ...], None] = None,
        custom_collator: Callable = None,
        remove_labels: bool = False,
    ):
        self.y_keys = y_keys
        self.custom_collator = (
            custom_collator if custom_collator is not None else default_data_collator
        )
        self.remove_labels = remove_labels

    def __call__(self, inputs: Tuple[Dict]) -> Tuple[Dict, Any]:
        """Return ``(batch, y)`` for the given sequence of examples.

        A key that is requested but missing from the batch yields ``None`` when it
        is read with ``get`` and raises ``KeyError`` when it has to be removed.
        """
        batch_size = len(inputs)
        batch = self.custom_collator(inputs)

        if self.y_keys is None:
            # No target requested: hand back a NaN placeholder of length batch_size.
            y = torch.tensor(float("nan")).repeat(batch_size)
        elif isinstance(self.y_keys, (list, tuple)):
            # Tuples are treated like lists; previously a tuple fell through to the
            # single-key branch and was looked up as one (non-existent) key.
            y = {}
            for key in self.y_keys:
                if self.remove_labels and "labels" in key:
                    y[key] = batch.pop(key)
                else:
                    y[key] = batch.get(key)
        elif self.remove_labels:
            y = batch.pop(self.y_keys)
        else:
            y = batch.get(self.y_keys)
        return batch, y

In [2]:
from transformers import AutoModel, AutoTokenizer

# Load the pretrained "bert-base-cased" tokenizer and model. `AutoModel` is used
# here, and the checkpoint warning recorded below lists the `cls.*` weights that
# were left unused.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModel.from_pretrained("bert-base-cased")

# Tokenize one sentence into PyTorch tensors and show which keys the model output
# exposes (the recorded output lists `last_hidden_state` and `pooler_output`).
inputs = tokenizer("Poutyne is inspired by Keras", return_tensors="pt")
print(model(**inputs).keys())

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


odict_keys(['last_hidden_state', 'pooler_output'])


In [None]:
from typing import Any, Dict
from torch import nn
from transformers import PreTrainedModel


class ModelWrapper(nn.Module):
    """Adapter that lets a transformer be called with one dict of inputs.

    The wrapped model is stored as ``self.transformer``. ``forward`` receives a
    single mapping and unpacks it into keyword arguments for the transformer.
    """

    def __init__(self, transformer: PreTrainedModel):
        super().__init__()
        self.transformer = transformer

    def __repr__(self) -> str:
        # Rendered as "<ClassName>(<repr of the wrapped transformer>)".
        wrapper_name = type(self).__name__
        return f"{wrapper_name}({self.transformer!r})"

    def forward(self, inputs) -> Dict[str, Any]:
        """Call the wrapped transformer with ``inputs`` expanded as keyword arguments."""
        outputs = self.transformer(**inputs)
        return outputs

    def save_pretrained(self, *args, **kwargs) -> None:
        """Forward all arguments to the wrapped transformer's ``save_pretrained``."""
        self.transformer.save_pretrained(*args, **kwargs)

In [None]:
class PoutyneSequenceOrderingLoss:
    """Regression loss between per-instance ordering labels and target-token logits.

    For every instance in the batch, the logits at the positions whose input id
    equals ``target_token_id`` are compared with that instance's labels using a
    summed squared error. The per-instance sums are added up and divided by the
    number of instances in the batch.

    Args:
        target_token_id: Input id marking the positions whose logits are scored.
    """

    def __init__(self, target_token_id):
        self.target_token_id = target_token_id

    def __call__(self, outputs, targets) -> "torch.Tensor":
        """Compute the batch loss.

        Args:
            outputs: Mapping with a ``"logits"`` entry, iterated along its first
                dimension (one item per instance).
                # assumes logits are (batch, seq_len, 1) -- TODO confirm
            targets: Mapping with ``"labels"`` (padded with -100) and
                ``"input_ids"``, both iterated along their first dimension.

        Returns:
            A scalar float64 tensor: the summed squared error over all instances
            divided by ``outputs["logits"].size(0)``.
        """
        batch_labels = targets["labels"]
        batch_logits = outputs["logits"]
        batch_input_ids = targets["input_ids"]

        # The number of labels varies per instance, so the loss is accumulated
        # instance by instance instead of in one vectorised call.
        loss_fn = nn.MSELoss(reduction="sum")
        batch_loss = torch.tensor(0.0, dtype=torch.float64, requires_grad=True)
        for labels, logits, input_ids in zip(
            batch_labels, batch_logits, batch_input_ids
        ):
            # Drop the padding entries (-100) and use the remaining sentence
            # indices, cast to float, as regression targets. No rescaling is applied.
            instance_targets = labels[labels != -100].reshape(-1).float()

            # Pick the logit of every target token in the input sequence.
            target_logits = logits[input_ids == self.target_token_id].reshape(-1)

            # Truncation of the input can leave fewer target tokens than labels;
            # in that case only as many targets as there are logits are scored.
            if target_logits.size(0) < instance_targets.size(0):
                instance_targets = instance_targets[: target_logits.size(0)]

            # MSELoss expects (input, target): predictions first, labels second.
            batch_loss = batch_loss + loss_fn(target_logits, instance_targets)

        # Average over the number of instances in the batch.
        return batch_loss / batch_logits.size(0)

In [3]:
from typing import Any, Callable, Dict

class MetricWrapper:
    """Adapt a ``metric(y_pred, y_true)`` callable to dict-shaped model outputs.

    Args:
        metric: Callable invoked as ``metric(y_pred, y_true)``.
        pred_key: Key under which the predictions are found in ``outputs``.
        y_key: Optional key selecting the ground truth. When ``None`` the
            ``y_true`` argument is passed to the metric unchanged.

    The wrapper exposes ``__name__`` so it can be reported under the metric's name.
    """

    def __init__(self, metric: Callable, pred_key: str = "logits", y_key: str = None):
        self.metric = metric
        self.pred_key = pred_key
        self.y_key = y_key
        self._set_metric_name(metric)

    def _set_metric_name(self, metric):
        # Callables such as bare functools.partial objects or class instances have
        # no __name__; fall back to the type name instead of raising AttributeError.
        self.__name__ = getattr(metric, "__name__", type(metric).__name__)

    def __call__(self, outputs: Dict[str, Any], y_true: Any):
        """Extract predictions (and optionally the target) and evaluate the metric.

        When ``y_key`` is set, the target is taken from ``y_true`` if ``y_true`` is
        a mapping containing that key. Otherwise the lookup falls back to
        ``outputs[y_key]``, which was the only place consulted before.
        """
        from collections.abc import Mapping

        y_pred = outputs[self.pred_key]
        if self.y_key is not None:
            if isinstance(y_true, Mapping) and self.y_key in y_true:
                y_true = y_true[self.y_key]
            else:
                y_true = outputs[self.y_key]
        return self.metric(y_pred, y_true)

In [4]:
import numpy as np
from collections import defaultdict
from functools import partial
from sklearn.metrics import accuracy_score
from scipy.stats import kendalltau

def make_compute_metrics_functions(target_token_id) -> list:
    """Build the ranking metric functions for a given target token id.

    Args:
        target_token_id: Input id marking the positions whose logits are ranked.

    Returns:
        A list of four callables ``f(outputs, targets)`` named ``"acc"``,
        ``"kendalls_tau"``, ``"mean_logits"`` and ``"std_logits"`` (via
        ``__name__``). Each one returns the batch mean of its metric.
    """

    def compute_ranking_func(
        outputs: Dict, targets: Any, metric_key: str
    ) -> float:
        """Compute all ranking metrics for a batch and return the one at ``metric_key``.

        ``outputs`` is the logits tensor, ``targets`` a mapping holding ``"labels"``
        (padded with -100) and ``"input_ids"``.
        """
        batch_sent_idx = targets["labels"].detach().cpu().numpy()
        batch_input_ids = targets["input_ids"].detach().cpu().numpy()
        batch_logits = outputs.detach().cpu().numpy()

        per_instance = defaultdict(list)
        for sent_idx, input_ids, logits in zip(
            batch_sent_idx, batch_input_ids, batch_logits
        ):
            sent_idx = sent_idx.reshape(-1)
            input_ids = input_ids.reshape(-1)
            logits = logits.reshape(-1)

            # Remove label padding. The padding value is -100 (the same value the
            # loss filters on); comparing against +100 left the padding in place.
            sent_idx = sent_idx[sent_idx != -100]
            target_logits = logits[input_ids == target_token_id]
            # Truncated inputs can hold fewer target tokens than labels; score only
            # as many labels as there are logits.
            if sent_idx.shape[0] > target_logits.shape[0]:
                sent_idx = sent_idx[: target_logits.shape[0]]
            # Calling argsort twice on the logits gives us their ranking in ascending order
            predicted_idx = np.argsort(np.argsort(target_logits))
            tau, _ = kendalltau(sent_idx, predicted_idx)
            acc = accuracy_score(sent_idx, predicted_idx)
            per_instance["kendalls_tau"].append(tau)
            per_instance["acc"].append(acc)
            # Logit statistics are taken over all positions, not just target tokens.
            per_instance["mean_logits"].append(logits.mean())
            per_instance["std_logits"].append(logits.std())
        batch_means = {
            metric: np.mean(scores) for metric, scores in per_instance.items()
        }
        return batch_means[metric_key]

    metric_funcs = []
    for metric in ("acc", "kendalls_tau", "mean_logits", "std_logits"):
        metric_func = partial(compute_ranking_func, metric_key=metric)
        # partial objects carry no __name__ of their own; set one for reporting.
        metric_func.__name__ = metric
        metric_funcs.append(metric_func)
    return metric_funcs

# Wrap every ranking metric function (built here for token id 0) in a MetricWrapper
# and show the names the wrappers expose.
metrics = list(map(MetricWrapper, make_compute_metrics_functions(0)))
print([wrapped.__name__ for wrapped in metrics])

['acc', 'kendalls_tau', 'mean_logits', 'std_logits']


In [None]:
import json
from poutyne.framework import experiment
from torch.optim import AdamW
from poutyne import (
    set_seeds,
    TensorBoardLogger,
    TensorBoardGradientTracker,
    Experiment,
)
from poutyne_transformers import ModelWrapper, MetricWrapper, TransformerCollator
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForTokenClassification
from datasets import load_from_disk
from poutyne_modules import (
    make_tokenization_func,
    PoutyneSequenceOrderingLoss,
    make_compute_metrics_functions,
    so_data_collator,
)


# Training script: fine-tunes a token-classification transformer with a single
# output per token on the sentence-ordering data, then evaluates it on the test
# split and stores the test metrics as JSON.
if __name__ == "__main__":
    set_seeds(42)

    # Run configuration.
    MODEL_NAME_OR_PATH = "bert-base-cased"
    LEARNING_RATE = 3e-5
    TRAIN_BATCH_SIZE = 8
    VAL_BATCH_SIZE = 16
    DEVICE = 0
    N_EPOCHS = 3
    SAVE_DIR = "experiments/rocstories/bert"

    print("Loading model & tokenizer.")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)
    # num_labels=1: one regression output per token.
    transformer = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME_OR_PATH, return_dict=True, num_labels=1
    )

    print("Loading & preparing data.")
    dataset = load_from_disk("../data/rocstories/")

    # The stored texts use the literal "[CLS]" marker. Tokenizers with a different
    # classification token need the marker rewritten to their own token.
    if tokenizer.cls_token != "[CLS]":
        print(
            f"Model does not a have a [CLS] token. Updating the data with token {tokenizer.cls_token} ..."
        )

        def replace_cls_token(entry):
            """Replace every "[CLS]" marker in a batch of texts with the tokenizer's token."""
            entry["text"] = [
                text.replace("[CLS]", tokenizer.cls_token) for text in entry["text"]
            ]
            return entry

        dataset = dataset.map(replace_cls_token, batched=True)

    # add_special_tokens=False: the marker tokens are already part of the text.
    tokenization_func = make_tokenization_func(
        tokenizer=tokenizer,
        text_column="text",
        add_special_tokens=False,
        padding="max_length",
        truncation=True,
    )
    dataset = dataset.map(tokenization_func, batched=True)

    dataset = dataset.rename_column("so_targets", "labels")

    # Drop the raw text/metadata columns so only model inputs and labels remain.
    dataset = dataset.remove_columns(
        ["text", "storyid", "storytitle"] + [f"sentence{i}" for i in range(1, 6)]
    )
    dataset.set_format("torch")

    # y is a dict with "labels" and "input_ids"; with remove_labels=True only the
    # "labels" entry is removed from the model inputs.
    collate_fn = TransformerCollator(
        y_keys=["labels", "input_ids"],
        custom_collator=so_data_collator,
        remove_labels=True,
    )

    train_dataloader = DataLoader(
        dataset["train"], batch_size=TRAIN_BATCH_SIZE, collate_fn=collate_fn
    )
    val_dataloader = DataLoader(
        dataset["val"], batch_size=VAL_BATCH_SIZE, collate_fn=collate_fn
    )
    test_dataloader = DataLoader(
        dataset["test"], batch_size=VAL_BATCH_SIZE, collate_fn=collate_fn
    )

    print("Preparing training.")
    wrapped_transformer = ModelWrapper(transformer)
    optimizer = AdamW(wrapped_transformer.parameters(), lr=LEARNING_RATE)
    loss_fn = PoutyneSequenceOrderingLoss(target_token_id=tokenizer.cls_token_id)

    metrics = [
        MetricWrapper(func)
        for func in make_compute_metrics_functions(tokenizer.cls_token_id)
    ]

    writer = SummaryWriter("runs/roberta/1")
    tensorboard_logger = TensorBoardLogger(writer)
    gradient_logger = TensorBoardGradientTracker(writer)

    # NOTE: this rebinds the name `experiment` imported from `poutyne.framework`
    # at the top of the cell.
    experiment = Experiment(
        directory=SAVE_DIR,
        network=wrapped_transformer,
        device=DEVICE,
        logging=True,
        optimizer=optimizer,
        loss_function=loss_fn,
        batch_metrics=metrics,
    )

    try:
        # The TensorBoard callbacks were created but never handed to the training
        # loop, so nothing was written to `writer`; register them here.
        experiment.train(
            train_generator=train_dataloader,
            valid_generator=val_dataloader,
            epochs=N_EPOCHS,
            save_every_epoch=True,
            callbacks=[tensorboard_logger, gradient_logger],
        )

        test_results = experiment.test(test_generator=test_dataloader)
    finally:
        # Flush and release the event file even if training or testing fails.
        writer.close()

    # A hub id or local path may contain path separators; flatten them so the
    # results always land in a single file in the working directory.
    results_tag = MODEL_NAME_OR_PATH.replace("/", "_").replace("\\", "_")
    with open(f"test_results_{results_tag}.json", "w") as f:
        json.dump(test_results, f)