# PyTorch / Hugging Face Transformers experimentation

Fine-tune a BERT variant (here `distilroberta-base`) using Hugging Face Transformers and PyTorch.
This is a sequence-tagging problem: give each token an IOB tag.

Remember that BERT-style models use subword tokens, so we need to map tokens (and their labels) between the original word-level tokens and the model's subword tokens.

This was mostly an exercise in using the more complex tools for cases where we want to have that fine-grain control.

I wouldn't suggest going down this route straight away, as it demands more engineering overhead than simple spaCy pipelines with respect to configuring training and deployment. This is deep R&D work for eking out additional percentage points of performance, which will matter, but not at the prototype v1 stage.

**NOTE: This notebook is a rough WIP, so isn't in a presentable state!**

In [4]:
from pathlib import Path
import re

def read_conll(file_path, contains_pos_tags=False):
    """Read a CoNLL-style file into parallel lists of token and tag sequences.

    Documents are separated by one or more blank (or whitespace-only) lines.
    Each non-blank line holds whitespace-separated columns: ``token tag`` or,
    when ``contains_pos_tags`` is true, ``token pos tag`` (the POS column is
    discarded).

    Args:
        file_path: Path (``str`` or ``Path``) to a UTF-8 encoded CoNLL file.
        contains_pos_tags: Whether each line has a middle POS-tag column.

    Returns:
        ``(token_docs, tag_docs)``: two equally long lists where
        ``token_docs[i]`` and ``tag_docs[i]`` are the tokens and tags of the
        i-th document. An empty file yields ``([], [])``.

    Raises:
        ValueError: if a non-blank line does not have the expected number of
            columns.
    """
    file_path = Path(file_path)
    expected_columns = 3 if contains_pos_tags else 2

    raw_text = file_path.read_text(encoding='utf-8').strip()
    # Split on any run of blank / whitespace-only lines. A pattern allowing at
    # most one whitespace char between newlines left stray empty lines inside
    # documents whenever the separator was longer than a single blank line.
    raw_docs = re.split(r'\n\s*\n', raw_text)

    token_docs = []
    tag_docs = []

    for doc in raw_docs:
        tokens = []
        tags = []
        for line in doc.split('\n'):
            columns = line.split()
            if not columns:
                continue  # tolerate whitespace-only lines
            if len(columns) != expected_columns:
                raise ValueError(
                    f'{file_path}: expected {expected_columns} columns, '
                    f'got {len(columns)} in line {line!r}'
                )
            # First column is the token, last is the IOB tag (POS, if any, is in between).
            tokens.append(columns[0])
            tags.append(columns[-1])

        if tokens:  # skip empty documents (e.g. from an empty file)
            token_docs.append(tokens)
            tag_docs.append(tags)

    return token_docs, tag_docs

In [5]:
# Load the full PLOD corpus: one token list and one IOB tag list per document.
conll_path = '../data/processed/PLOD_IOB_tagged.conll'
docs, tags = read_conll(conll_path)

In [6]:
from sklearn.model_selection import train_test_split

# 75% train, then the remaining 25% halved into 12.5% validation / 12.5% test.
# A fixed random_state makes the split reproducible across kernel restarts;
# without it every run trains and evaluates on different documents.
SPLIT_SEED = 42
train_texts, val_texts, train_tags, val_tags = train_test_split(
    docs, tags, test_size=.25, random_state=SPLIT_SEED)
val_texts, test_texts, val_tags, test_tags = train_test_split(
    val_texts, val_tags, test_size=.5, random_state=SPLIT_SEED)


In [7]:
# Label vocabulary built from every tag in the full corpus (not only the
# training split), so validation/test tags always have an id.
# Sorting makes the tag <-> id mapping deterministic: iteration order of a
# set of strings depends on Python's per-process hash randomisation, so ids
# could otherwise differ between sessions and mismatch a previously trained model.
unique_tags = sorted({tag for doc in tags for tag in doc})
tag2id = {tag: idx for idx, tag in enumerate(unique_tags)}
id2tag = {idx: tag for tag, idx in tag2id.items()}

In [8]:
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)

# RoBERTa-family tokenizers need add_prefix_space=True to accept pre-split
# words (the encodings below use is_split_into_words=True).
tokenizer = AutoTokenizer.from_pretrained(
    'distilroberta-base',
    add_prefix_space=True,
)

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
# Number of training documents (the 75% portion of the corpus split).
len(train_texts)

120699

In [10]:
# Subsample size for quick experiments (the full training split has ~120k documents).
num_samples = 1_000

In [11]:
# Tokenize the pre-split CoNLL words. padding=True pads every example to the
# longest one in the subset; truncation=True caps it at the model's max length.
encode_kwargs = dict(is_split_into_words=True, padding=True, truncation=True)
train_encodings = tokenizer(train_texts[:num_samples], **encode_kwargs)
val_encodings = tokenizer(val_texts[:num_samples], **encode_kwargs)

In [12]:
def tokenize_and_align_labels(tags, encodings, label_all_tokens=True, *, label2id=None, b_to_i=False):
    """Align word-level IOB tags with the subword tokens of a tokenized batch.

    Subword tokenizers split one word into several tokens and add special /
    padding tokens, so each document's per-word tags must be expanded to one
    label id per subword position.

    Args:
        tags: One list of string tags per document (one tag per word).
        encodings: Batch encoding of the same documents (tokenized with
            ``is_split_into_words=True``); ``encodings.word_ids(batch_index=i)``
            must map each subword position of document ``i`` to its word index,
            or ``None`` for special tokens.
        label_all_tokens: If true, continuation subwords of a word also get a
            label; otherwise they get -100 and are ignored by the loss.
        label2id: Tag -> id mapping. Defaults to the notebook-global ``tag2id``.
        b_to_i: Only relevant with ``label_all_tokens``: label continuation
            subwords of a ``B-X`` word as ``I-X`` (when ``I-X`` is in
            ``label2id``), so one word does not yield several entity starts.

    Returns:
        One list of label ids per document, each as long as its ``word_ids``.

    Raises:
        KeyError: if a tag is missing from ``label2id``.
    """
    if label2id is None:
        label2id = tag2id

    labels = []
    for i, doc_tags in enumerate(tags):
        # Convert all tags up front so unknown tags fail loudly even when the
        # word they belong to was truncated away.
        first_ids = [label2id[tag] for tag in doc_tags]
        if b_to_i:
            cont_ids = [
                label2id.get('I-' + tag[2:], first_id) if tag.startswith('B-') else first_id
                for tag, first_id in zip(doc_tags, first_ids)
            ]
        else:
            cont_ids = first_ids

        word_ids = encodings.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens have a word id of None; -100 makes the loss
                # function ignore them.
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # First subword of a word carries the word's own tag.
                label_ids.append(first_ids[word_idx])
            elif label_all_tokens:
                label_ids.append(cont_ids[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    return labels

In [14]:
# Expand per-word tags to per-subword label ids for the same subsets that were tokenized.
train_labels = tokenize_and_align_labels(tags=train_tags[:num_samples], encodings=train_encodings)
val_labels = tokenize_and_align_labels(tags=val_tags[:num_samples], encodings=val_encodings)

In [15]:
# Inspect the subword tokens of one training example (incl. <s>, </s> and padding).
train_encodings[4].tokens

['<s>',
 'ĠBrief',
 'ly',
 'Ġ,',
 'Ġthe',
 'Ġseed',
 'lings',
 'Ġwere',
 'Ġquickly',
 'Ġimmersed',
 'Ġin',
 'Ġis',
 'op',
 'rop',
 'anol',
 'Ġpre',
 'he',
 'ated',
 'Ġto',
 'Ġ75',
 'ĠÂ°',
 'ĠC',
 'Ġwith',
 'Ġ0',
 '.',
 '01',
 'Ġ%',
 'Ġbut',
 'yl',
 'ated',
 'Ġhydro',
 'xy',
 't',
 'ol',
 'u',
 'ene',
 'Ġ(',
 'ĠB',
 'HT',
 'Ġ)',
 'Ġfor',
 'Ġ15',
 'Ġminutes',
 'Ġ.',
 '</s>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<p

In [16]:
# Aligned label length: one label per subword position, including special/padding positions.
len(train_labels[4])

512

In [17]:
# Word-level tag count of the same example, to compare with its subword label length.
len(train_tags[4])

26

In [18]:
# train_encodings.pop("offset_mapping")
# val_encodings.pop("offset_mapping")

In [19]:
import torch

class PLODDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with aligned label ids.

    ``encodings`` is a mapping of feature name -> per-example sequences (e.g.
    the tokenizer's batch output); ``labels`` holds one label-id sequence per
    example. Its length defines the dataset size.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # One tensor per encoding feature (input_ids, attention_mask, ...), then the labels.
        item = {}
        for feature_name, per_example in self.encodings.items():
            item[feature_name] = torch.tensor(per_example[idx])
        item['labels'] = torch.tensor(self.labels[idx])
        return item

In [20]:
# Wrap encodings + aligned labels as torch datasets for the DataLoader.
train_dataset = PLODDataset(encodings=train_encodings, labels=train_labels)
val_dataset = PLODDataset(encodings=val_encodings, labels=val_labels)

In [21]:
from torch.utils.data import DataLoader
from transformers import AutoModelForTokenClassification

from torch.optim import AdamW 

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Passing the tag <-> id maps stores real label names in the model config
# (id2label / label2id), instead of the generic LABEL_0..LABEL_n, so a saved
# model can be decoded without this notebook's globals.
model = AutoModelForTokenClassification.from_pretrained(
    'distilroberta-base',
    num_labels=len(unique_tags),
    id2label=id2tag,
    label2id=tag2id,
)

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaForTokenClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForTokenClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream tas

In [22]:
# Move the model to the selected device and switch to training mode (enables dropout).
model.to(device).train()

RobertaForTokenClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm

In [23]:
# shuffle=True so each epoch visits the training examples in a new order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

optim = AdamW(model.parameters(), lr=5e-5)

num_epochs = 2
# Re-enter training mode explicitly: this cell ends with model.eval(), so
# re-running it would otherwise train with dropout disabled.
model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch in train_loader:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss  # same as outputs[0] when labels are given, but explicit
        loss.backward()
        optim.step()
        epoch_loss += loss.item()
    mean_loss = epoch_loss / max(len(train_loader), 1)
    print(f'epoch {epoch + 1}/{num_epochs}: mean train loss {mean_loss:.4f}')

model.eval()

RobertaForTokenClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm

In [63]:
# Example sentence for a quick qualitative check of the fine-tuned tagger.
# Alternative test sentence:
# input_str = "Patients with extensive-stage small cell lung cancer (ED-SCLC) have a very short survival time even if they receive standard cytotoxic chemotherapy with etoposide and platinum (EP)."
input_str = 'In contrast, these cells are absent in the lamina propria (LP) of the mouse colonic mucosa, where most IL-10-secreting lymphocytes express Foxp3 [16].'

In [64]:
# NOTE(review): training encoded pre-split CoNLL words (is_split_into_words=True,
# where e.g. a comma becomes 'Ġ,'), but this encodes the raw string, so punctuation
# attached to words is tokenized differently (',' instead of 'Ġ,'). Consider
# splitting input_str into words the same way the CoNLL data was tokenized.
encoded_input = tokenizer(input_str, padding=True, truncation=True, return_tensors='pt')
inputs = encoded_input.to(device)

In [65]:
# Show the encoded model inputs (input_ids and attention_mask tensors).
inputs

{'input_ids': tensor([[    0,    96,  5709,     6,   209,  4590,    32, 11640,    11,     5,
           784, 41568,  8462,  6374,    36, 21992,    43,     9,     5, 18292,
         17735,   636, 38791,  5166,     6,   147,   144, 11935,    12,   698,
            12,  8584,   241,  2577, 23496, 44601,  5486,  2063,   642,   246,
           646,  1549,  8174,     2]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')}

In [66]:
# Inference only: disable autograd so no computation graph is built
# (saves memory and time; gradients are never used here).
with torch.no_grad():
    outputs = model(**inputs)

In [67]:
# Raw token-classification output: one logit per label for each subword position.
outputs

TokenClassifierOutput(loss=None, logits=tensor([[[-0.6138, -0.6680, -0.0871,  1.2316],
         [-1.5102, -1.8666, -2.4535,  5.1474],
         [-1.9305, -1.8036, -2.3607,  4.8776],
         [-1.7232, -1.4494, -2.8566,  5.4164],
         [-1.4039, -2.2569, -2.7008,  5.1325],
         [-0.8144, -1.9781, -3.0788,  4.8919],
         [-2.0613, -1.3120, -2.8724,  5.3820],
         [-2.4469, -1.5748, -2.3598,  5.4010],
         [-2.1997, -1.3873, -2.4605,  5.4236],
         [-2.3879, -1.5284, -1.8941,  4.6757],
         [-2.2350, -1.4874,  2.2968,  0.4937],
         [-2.5363, -0.6757,  1.6134,  0.4842],
         [-2.8732,  2.5288, -1.8849,  0.9659],
         [-2.8488,  2.0024, -1.8762,  0.6059],
         [ 1.4684, -2.5377, -3.1271,  3.1645],
         [ 3.8848, -3.2437, -2.4698,  1.2805],
         [ 0.9271, -1.8850, -2.6658,  3.0638],
         [-2.0416, -1.0433, -2.6415,  4.7327],
         [-2.1956, -1.3959, -2.8112,  5.2242],
         [-2.1493, -1.9524, -1.8941,  4.3037],
         [-2.7451, -

In [68]:
# Most likely label id for every subword position (argmax over the label axis).
predicted_token_class_ids = torch.argmax(outputs.logits, dim=-1)

In [69]:
# Map the first (only) sequence's predicted ids back to tag strings.
pred_tags = [id2tag[class_id] for class_id in predicted_token_class_ids[0].tolist()]

In [70]:
# Take the subword tokens from the very encoding the model scored, dropping
# special tokens (word id None), so they stay aligned with pred_tags[1:-1].
# Re-tokenizing with tokenizer.tokenize() is independent of `inputs` (padding
# has no effect on one string, and truncating without special tokens can keep
# two more tokens than the model actually saw).
input_tokenized = [
    token
    for token, word_id in zip(inputs.tokens(), inputs.word_ids())
    if word_id is not None
]

In [71]:
# Number of predictions, including the <s> and </s> special-token positions.
len(pred_tags)

44

In [72]:
# Number of subword tokens without special tokens (two fewer than the predictions).
len(input_tokenized)

42

In [73]:
# pred_tags also covers the leading <s> and trailing </s> positions; strip them
# so each subword token is paired with its own prediction. zip() would silently
# truncate on a length mismatch, so verify the lengths first.
token_preds = pred_tags[1:-1]
if len(token_preds) != len(input_tokenized):
    raise ValueError(
        f'{len(input_tokenized)} tokens but {len(token_preds)} predictions; '
        'tokenization and model inputs are out of sync'
    )
inputs_with_preds = list(zip(input_tokenized, token_preds))

In [74]:
# Each subword token paired with its predicted IOB tag.
inputs_with_preds

[('ĠIn', 'O'),
 ('Ġcontrast', 'O'),
 (',', 'O'),
 ('Ġthese', 'O'),
 ('Ġcells', 'O'),
 ('Ġare', 'O'),
 ('Ġabsent', 'O'),
 ('Ġin', 'O'),
 ('Ġthe', 'O'),
 ('Ġl', 'B-LF'),
 ('amina', 'B-LF'),
 ('Ġprop', 'I-LF'),
 ('ria', 'I-LF'),
 ('Ġ(', 'O'),
 ('LP', 'B-SF'),
 (')', 'O'),
 ('Ġof', 'O'),
 ('Ġthe', 'O'),
 ('Ġmouse', 'O'),
 ('Ġcolon', 'O'),
 ('ic', 'O'),
 ('Ġmuc', 'O'),
 ('osa', 'O'),
 (',', 'O'),
 ('Ġwhere', 'O'),
 ('Ġmost', 'O'),
 ('ĠIL', 'B-SF'),
 ('-', 'B-SF'),
 ('10', 'O'),
 ('-', 'O'),
 ('sec', 'O'),
 ('re', 'O'),
 ('ting', 'O'),
 ('Ġlymph', 'O'),
 ('ocytes', 'O'),
 ('Ġexpress', 'O'),
 ('ĠFox', 'B-SF'),
 ('p', 'B-SF'),
 ('3', 'B-SF'),
 ('Ġ[', 'O'),
 ('16', 'O'),
 ('].', 'O')]