In [1]:
import torch
from datasets import load_from_disk
from poutyne import Model
from poutyne_transformers import ModelWrapper, MetricWrapper, TransformerCollator
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForTokenClassification

from poutyne_modules import (
    so_data_collator,
    make_tokenization_func,
    make_rename_func,
    PoutyneSequenceOrderingLoss,
    make_compute_metrics_functions,
)

In [2]:
# Run configuration: everything a reader might want to tune lives in this cell.

# Checkpoint name passed to both the tokenizer and the model loader.
MODEL_NAME_OR_PATH = "bert-base-cased"

# Optimisation hyper-parameters.
LEARNING_RATE = 3e-5
TRAIN_BATCH_SIZE = 8
VAL_BATCH_SIZE = 16  # also used for the test dataloader
N_EPOCHS = 3

# Seed PyTorch's global RNG so that randomly initialised weights and any
# shuffling are repeatable across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# Use the first GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of crashing.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

In [3]:
# Load the tokenizer and the token-classification model from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)

# num_labels=1 requests a single output value per token;
# return_dict=True makes the forward pass return an output object rather than a tuple.
transformer = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME_OR_PATH,
    num_labels=1,
    return_dict=True,
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

In [4]:
# Debug switch: set to an integer (e.g. 300) to keep only rows whose index is
# below that value while iterating quickly; leave as None to use all the data.
DEBUG_MAX_ROWS = None

# Load the pre-built ROCStories dataset from its on-disk location.
dataset = load_from_disk("../data/rocstories/")

if DEBUG_MAX_ROWS is not None:
    # Index-based filter, identical to the former commented-out debugging line
    # but driven by the switch above instead of manual (un)commenting.
    dataset = dataset.filter(
        lambda _, index: index < DEBUG_MAX_ROWS, with_indices=True
    )

Loading cached processed dataset at ../data/rocstories/train/cache-b0bad00b2d348ae1.arrow
Loading cached processed dataset at ../data/rocstories/test/cache-59718d4747a1198d.arrow
Loading cached processed dataset at ../data/rocstories/val/cache-a08ecbcb0ab07ce7.arrow


In [5]:
# Options forwarded to the tokenization helper: no special tokens are added
# automatically, and every example is padded/truncated to the maximum length.
tokenizer_options = {
    "add_special_tokens": False,
    "padding": "max_length",
    "truncation": True,
}

# Build the batched tokenization callable for the "text" column ...
tokenization_func = make_tokenization_func(
    tokenizer=tokenizer, text_column="text", **tokenizer_options
)

# ... and apply it to the dataset in batches.
dataset = dataset.map(tokenization_func, batched=True)

Loading cached processed dataset at ../data/rocstories/train/cache-53c3e65c13787504.arrow
Loading cached processed dataset at ../data/rocstories/test/cache-dfaaa7d79a46d423.arrow
Loading cached processed dataset at ../data/rocstories/val/cache-ce1c1014001a7d07.arrow


In [6]:
# Expose the sequence-ordering targets under the name "labels"; the source
# column "so_targets" is dropped (remove_src=True).
column_renames = {"so_targets": "labels"}
rename_target_column = make_rename_func(column_renames, remove_src=True)

dataset = dataset.map(rename_target_column, batched=True)

Loading cached processed dataset at ../data/rocstories/train/cache-5b833dd8c1859483.arrow
Loading cached processed dataset at ../data/rocstories/test/cache-b96a85dd76e25ba5.arrow
Loading cached processed dataset at ../data/rocstories/val/cache-4d9643d138ffa502.arrow


In [7]:
# Display the dataset summary (splits, column names, row counts) as a sanity
# check after tokenization and the label rename.
dataset

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'sentence1', 'sentence2', 'sentence3', 'sentence4', 'sentence5', 'storyid', 'storytitle', 'text', 'token_type_ids'],
        num_rows: 300
    })
    test: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'sentence1', 'sentence2', 'sentence3', 'sentence4', 'sentence5', 'storyid', 'storytitle', 'text', 'token_type_ids'],
        num_rows: 300
    })
    val: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'sentence1', 'sentence2', 'sentence3', 'sentence4', 'sentence5', 'storyid', 'storytitle', 'text', 'token_type_ids'],
        num_rows: 300
    })
})

In [8]:
# Drop the raw-text and metadata columns (sentence1 .. sentence5 included),
# which are not needed as tensors.
sentence_columns = [f"sentence{i}" for i in range(1, 6)]
unused_columns = ["text", "storyid", "storytitle", *sentence_columns]
dataset = dataset.remove_columns(unused_columns)

# Have the dataset return PyTorch tensors from now on.
dataset.set_format("torch")

In [9]:
# Batch collation for Poutyne. The "labels" and "input_ids" entries are named as
# y_keys; NOTE(review): presumably they are handed to the loss/metrics as the
# target and, with remove_labels=True, kept out of the model inputs — confirm
# against the poutyne_transformers documentation.
collate_fn = TransformerCollator(
    y_keys=["labels", "input_ids"], custom_collator=so_data_collator, remove_labels=True
)

# Training batches are reshuffled every epoch so the model does not see the
# examples in the same fixed dataset order on each pass.
train_dataloader = DataLoader(
    dataset["train"],
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

# Validation and test loaders keep the dataset order (no shuffling) so that
# evaluation is deterministic.
val_dataloader = DataLoader(
    dataset["val"], batch_size=VAL_BATCH_SIZE, collate_fn=collate_fn
)
test_dataloader = DataLoader(
    dataset["test"], batch_size=VAL_BATCH_SIZE, collate_fn=collate_fn
)

In [10]:
# Wrap the Hugging Face model so it can be trained through Poutyne.
wrapped_transformer = ModelWrapper(transformer)

# The loss and the metrics are both parameterised by the tokenizer's CLS token id.
loss_fn = PoutyneSequenceOrderingLoss(target_token_id=tokenizer.cls_token_id)
metrics = list(
    map(MetricWrapper, make_compute_metrics_functions(tokenizer.cls_token_id))
)

optimizer = AdamW(wrapped_transformer.parameters(), lr=LEARNING_RATE)

# Assemble the Poutyne model: network, optimizer, loss, per-batch metrics, device.
model = Model(
    wrapped_transformer,
    optimizer,
    loss_fn,
    batch_metrics=metrics,
    device=DEVICE,
)

In [11]:
# Train for N_EPOCHS epochs on the training loader, validating on the validation loader.
# NOTE(review): the saved output of this cell is a dumped TokenClassifierOutput
# followed by "AttributeError: 'str' object has no attribute 'reshape'". The
# traceback is truncated, so the failing call is not visible here — check how the
# loss/metric functions consume the model's output object before re-running.
model.fit_generator(train_dataloader, val_dataloader, epochs=N_EPOCHS)

TokenClassifierOutput(loss=None, logits=tensor([[[-6.6118e-01],
         [-5.6550e-01],
         [-7.2178e-01],
         ...,
         [-6.8481e-01],
         [-6.3319e-01],
         [-6.7567e-01]],

        [[-6.0914e-01],
         [-5.1531e-01],
         [-7.3804e-01],
         ...,
         [-6.1154e-01],
         [-6.5949e-01],
         [-8.4530e-01]],

        [[-5.9103e-01],
         [-3.2422e-01],
         [-7.1828e-01],
         ...,
         [-5.8666e-01],
         [-2.6737e-01],
         [-6.8603e-01]],

        ...,

        [[-5.6462e-02],
         [-7.6171e-01],
         [-7.4970e-01],
         ...,
         [-8.2083e-01],
         [ 4.2890e-02],
         [-6.1885e-01]],

        [[-6.4896e-01],
         [-5.5668e-01],
         [-7.3285e-01],
         ...,
         [ 1.0430e-03],
         [-6.8947e-01],
         [-7.2553e-01]],

        [[-9.4805e-01],
         [-7.6723e-01],
         [-1.0490e+00],
         ...,
         [-7.1508e-01],
         [-8.2326e-01],
         [-7

AttributeError: 'str' object has no attribute 'reshape'

In [None]:
# Evaluate the model on the test dataloader and keep the returned value.
# NOTE(review): the saved output printed below is a single float (the test
# loss); confirm whether metric values are returned as well when batch metrics
# are active.
test_data = model.evaluate_generator(test_dataloader)

Test steps: 19 3.72s test_loss: 1.883897


In [None]:
# Print the value returned by evaluate_generator for the test dataloader.
print(test_data)

1.8838973759114743