In [1]:
import torch
from datasets import load_from_disk
from poutyne import Model
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForTokenClassification

from poutyne_modules import so_data_collator, make_tokenization_func, make_rename_func, TransformerPoutyneCollator, TransformerPoutyneWrapper, PoutyneSequenceOrderingLoss

In [2]:
# Checkpoint name (or local path) used for both the tokenizer and the model.
MODEL_NAME_OR_PATH = "bert-base-cased"

# Learning rate handed to the AdamW optimizer.
LEARNING_RATE = 3e-5

# Batch sizes: the training loader uses the first, the val/test loaders the second.
TRAIN_BATCH_SIZE, VAL_BATCH_SIZE = 8, 16

In [3]:
# Tokenizer and token-classification model are loaded from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)

# num_labels=1 configures the classification head with a single label.
transformer = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME_OR_PATH,
    num_labels=1,
    return_dict=True,
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

In [4]:
# Load the saved dataset from disk and, for debugging only, keep just the rows
# whose index is below 300 (the filter receives the row index via with_indices).
dataset = load_from_disk('../data/rocstories/').filter(
    lambda _example, row_index: row_index < 300,
    with_indices=True,
)

Loading cached processed dataset at ../data/rocstories/train/cache-d64ac32a003ac6aa.arrow
Loading cached processed dataset at ../data/rocstories/test/cache-33ee7864532d589f.arrow
Loading cached processed dataset at ../data/rocstories/val/cache-d9bb8d842b3eec4b.arrow


In [5]:
# Build the batched tokenization callback for the 'text' column.
# NOTE(review): the remaining keyword arguments are presumably forwarded to the
# tokenizer call by make_tokenization_func -- confirm in poutyne_modules.
tokenization_func = make_tokenization_func(
    tokenizer=tokenizer,
    text_column='text',
    add_special_tokens=False,
    padding='max_length',
    truncation=True,
)

# Apply the callback to every split, one batch of rows at a time.
dataset = dataset.map(tokenization_func, batched=True)

Loading cached processed dataset at ../data/rocstories/train/cache-5c74ad6a5918980a.arrow
Loading cached processed dataset at ../data/rocstories/test/cache-876e1e2f6cf44419.arrow
Loading cached processed dataset at ../data/rocstories/val/cache-f82eefa16a48f2e7.arrow


In [6]:
# Callback that exposes 'so_targets' under the name 'labels'.
# NOTE(review): remove_src=True presumably drops the original column -- the
# feature listing printed further down no longer shows 'so_targets'.
rename_target_column = make_rename_func(
    {'so_targets': 'labels'},
    remove_src=True,
)

dataset = dataset.map(rename_target_column, batched=True)

Loading cached processed dataset at ../data/rocstories/train/cache-ecebbc5fa4253ec3.arrow
Loading cached processed dataset at ../data/rocstories/test/cache-e07e983748c6664f.arrow
Loading cached processed dataset at ../data/rocstories/val/cache-02cf1c773423805c.arrow


In [7]:
# Display the dataset summary (splits, feature names, row counts) as a sanity check.
dataset

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'sentence1', 'sentence2', 'sentence3', 'sentence4', 'sentence5', 'storyid', 'storytitle', 'text', 'token_type_ids'],
        num_rows: 300
    })
    test: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'sentence1', 'sentence2', 'sentence3', 'sentence4', 'sentence5', 'storyid', 'storytitle', 'text', 'token_type_ids'],
        num_rows: 300
    })
    val: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'sentence1', 'sentence2', 'sentence3', 'sentence4', 'sentence5', 'storyid', 'storytitle', 'text', 'token_type_ids'],
        num_rows: 300
    })
})

In [8]:
# Drop the raw text/metadata columns (including sentence1..sentence5).
unused_columns = [
    'text',
    'storyid',
    'storytitle',
    *(f'sentence{position}' for position in range(1, 6)),
]
dataset = dataset.remove_columns(unused_columns)

# Switch the dataset's output format to 'torch' (in-place call).
dataset.set_format('torch')

In [9]:
# Batch collator shared by all three loaders.
# NOTE(review): y_keys presumably lists the batch entries routed to the target
# side that the loss receives -- confirm in poutyne_modules.
collate_fn = TransformerPoutyneCollator(y_keys=['labels', 'input_ids'], custom_collator=so_data_collator)

# Training data is reshuffled every epoch so the model does not see the same
# batches in the same order each time. Validation and test loaders keep a
# fixed order.
train_dataloader = DataLoader(dataset['train'], batch_size=TRAIN_BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(dataset['val'], batch_size=VAL_BATCH_SIZE, collate_fn=collate_fn)
test_dataloader = DataLoader(dataset['test'], batch_size=VAL_BATCH_SIZE, collate_fn=collate_fn)

In [10]:

# Wrap the Hugging Face model for use as the Poutyne network.
wrapped_transformer = TransformerPoutyneWrapper(transformer)

optimizer = AdamW(wrapped_transformer.parameters(), lr=LEARNING_RATE)
# The loss is parameterised with the tokenizer's [CLS] token id.
loss_fn = PoutyneSequenceOrderingLoss(target_token_id=tokenizer.cls_token_id)

# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at this cell.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model = Model(wrapped_transformer, optimizer, loss_fn, device=device)

In [11]:
# Train for a single epoch, validating on the val split; the returned per-epoch
# history is the cell's displayed value.
model.fit_generator(
    train_dataloader,
    valid_generator=val_dataloader,
    epochs=1,
)

Epoch: 1/1 Train steps: 38 Val steps: 19 15.86s loss: 10.775345 val_loss: 9.349187 


[{'epoch': 1,
  'time': 15.85634977184236,
  'loss': 10.775345441500345,
  'val_loss': 9.34918704509735}]

In [12]:
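# TODO: evaluate on the test split. The loader defined above is `test_dataloader`
# (there is no `test_loader`), and no metrics are passed to Model, so presumably
# only the loss is returned -- confirm against the Poutyne documentation.
#test_loss = model.evaluate_generator(test_dataloader)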

In [14]:
wrapped_transformer  # display the wrapped network's module tree

TransformerPoutyneWrapper(BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((76