In [1]:
from google.colab import drive
drive.mount('/content/drive',force_remount=True)

Mounted at /content/drive


In [2]:
!ls /content/drive/MyDrive/GitHub/exBERT/data/splits

test_new.tsv  train_new.tsv  valid_new.tsv


In [3]:
!pip install transformers==3.1.0 seqeval==0.0.12



In [4]:
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertConfig, BertForTokenClassification

In [5]:
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
print(device)

cuda


In [6]:
data = pd.read_csv("/content/drive/MyDrive/GitHub/exBERT/data/splits/train_new.tsv", delimiter="\t", encoding='utf-8')
data.head()

Unnamed: 0,word,tag
0,Ibrutinib,O
1,and,O
2,other,O
3,targeted,O
4,inhibitors,O


In [7]:
val = pd.read_csv("/content/drive/MyDrive/GitHub/exBERT/data/splits/valid_new.tsv",  delimiter="\t", encoding='utf-8')
val.head()

Unnamed: 0,word,tag
0,INNO,B-therapy\_group
1,-,I-therapy\_group
2,406,I-therapy\_group
3,",",O
4,an,O


In [8]:
test = pd.read_csv("/content/drive/MyDrive/GitHub/exBERT/data/splits/test_new.tsv", delimiter="\t", encoding='utf-8')
test.head()

Unnamed: 0,word,tag
0,CLL,O
1,cells,O
2,treated,O
3,with,O
4,dasatinib,B-therapy\_group


In [9]:
len(data)

74939

In [10]:
len(val)

24004

In [11]:
len(test)

10141

In [12]:
train = data.append(val, ignore_index=True)

In [13]:
len(train)

98943

In [14]:
print("Number of tags: {}".format(len(test.tag.unique())))
frequencies = test.tag.value_counts()
frequencies

Number of tags: 17


O                   9888
B-therapy\_group      91
I-mutation            48
I-therapy\_group      26
B-mutation            24
B-nam                 21
I-demography          14
I-nam                  6
B-demography           5
I-PFS                  4
I-followup             4
B-OS                   3
I-OS                   2
B-followup             2
B-PFS                  1
therapy\_group         1
nam                    1
Name: tag, dtype: int64

In [15]:
test.count()

word    10141
tag     10141
dtype: int64

In [16]:
train.count() #mismatch

word    98939
tag     98943
dtype: int64

In [17]:
print("Number of tags: {}".format(len(train.tag.unique())))
frequencies = train.tag.value_counts()
frequencies

Number of tags: 26


O                   96288
B-therapy\_group      730
I-therapy\_group      502
I-mutation            278
B-mutation            203
B-nam                 140
I-nam                 123
I-demography          117
I-followup            111
B-OS                   57
I-PFS                  56
B-PFS                  54
I-OS                   43
B-followup             42
I-CI                   38
B-demography           35
I-*                    32
B-CI                   21
B-comorbidity          19
B-OR                   18
I-comorbidity          16
B-HR                    7
B-CR                    5
B-*                     5
I-CR                    2
I-OR                    1
Name: tag, dtype: int64

In [18]:
train[train['tag']=='B-*']

Unnamed: 0,word,tag
53576,MiR,B-*
67474,everolimus,B-*
78872,imatinib,B-*
87875,MGCD0103,B-*
96929,flavopiridol,B-*


In [19]:
train[train['tag']=='I-*']

Unnamed: 0,word,tag
53577,-,I-*
53578,16,I-*
53579,is,I-*
53580,indeed,I-*
53581,down-regulated,I-*
53582,(,I-*
53583,fold-change,I-*
53584,2.3,I-*
53585,",",I-*
53586,p,I-*


In [20]:
len(train)

98943

In [21]:
train = train[(train['tag']!='B-*') & (train['tag']!='I-*')]

In [22]:
len(train)

98906

In [23]:
train.count()

word    98902
tag     98906
dtype: int64

In [24]:
print("Number of tags: {}".format(len(train.tag.unique())))
frequencies = train.tag.value_counts()
frequencies

Number of tags: 24


O                   96288
B-therapy\_group      730
I-therapy\_group      502
I-mutation            278
B-mutation            203
B-nam                 140
I-nam                 123
I-demography          117
I-followup            111
B-OS                   57
I-PFS                  56
B-PFS                  54
I-OS                   43
B-followup             42
I-CI                   38
B-demography           35
B-CI                   21
B-comorbidity          19
B-OR                   18
I-comorbidity          16
B-HR                    7
B-CR                    5
I-CR                    2
I-OR                    1
Name: tag, dtype: int64

In [25]:
tags = {}
for tag, count in zip(frequencies.index, frequencies):
    if tag != "O":
        if tag[2:8] not in tags.keys():
            tags[tag[2:8]] = count
        else:
            tags[tag[2:8]] += count
    continue

print(sorted(tags.items(), key=lambda x: x[1], reverse=True))

[('therap', 1232), ('mutati', 481), ('nam', 263), ('follow', 153), ('demogr', 152), ('PFS', 110), ('OS', 100), ('CI', 59), ('comorb', 35), ('OR', 19), ('HR', 7), ('CR', 7)]


In [26]:
labels_to_ids = {k: v for v, k in enumerate(train.tag.unique())}
ids_to_labels = {v: k for v, k in enumerate(train.tag.unique())}
labels_to_ids

{'B-CI': 17,
 'B-CR': 22,
 'B-HR': 16,
 'B-OR': 4,
 'B-OS': 19,
 'B-PFS': 6,
 'B-comorbidity': 15,
 'B-demography': 10,
 'B-followup': 13,
 'B-mutation': 2,
 'B-nam': 8,
 'B-therapy\\_group': 1,
 'I-CI': 18,
 'I-CR': 23,
 'I-OR': 5,
 'I-OS': 20,
 'I-PFS': 7,
 'I-comorbidity': 21,
 'I-demography': 11,
 'I-followup': 14,
 'I-mutation': 3,
 'I-nam': 9,
 'I-therapy\\_group': 12,
 'O': 0}

In [27]:
train[train['tag'].isna()]

Unnamed: 0,word,tag


In [28]:
train[train['word'].isna()]

Unnamed: 0,word,tag
63172,,O
67541,,O
86368,,O
90435,,O


In [29]:
train = train[train['word'].notna()]
train.count()

word    98902
tag     98902
dtype: int64

In [30]:
train['tag'][11421]

'O'

In [31]:
train.head()

Unnamed: 0,word,tag
0,Ibrutinib,O
1,and,O
2,other,O
3,targeted,O
4,inhibitors,O


In [32]:
train.tail()

Unnamed: 0,word,tag
98938,use,O
98939,of,O
98940,this,O
98941,drug,O
98942,.,O


In [33]:
import re

In [34]:
words_str = ""
len(words_str)

0

In [35]:
train['word'] = train['word'].str.strip()
train['tag'] = train['tag'].str.strip()

In [36]:
train[train['word']=='']

Unnamed: 0,word,tag


In [37]:
train[train['tag']=='']

Unnamed: 0,word,tag


In [38]:
train

Unnamed: 0,word,tag
0,Ibrutinib,O
1,and,O
2,other,O
3,targeted,O
4,inhibitors,O
...,...,...
98938,use,O
98939,of,O
98940,this,O
98941,drug,O


In [39]:
words_str = ""
tag_str = ""
for word, tag in zip(train['word'], train['tag']):
  if len(words_str)==0 and len(tag_str)==0:
    words_str = word
    tag_str = tag
    #print(word, tag)
  elif len(words_str)!=0 and len(tag_str)!=0:
    words_str = words_str + ' ' + word
    tag_str = tag_str + ',' + tag
  
# words_str = words_str.strip()
# tag_str = tag_str.strip()

In [40]:
len(words_str.split())

98902

In [41]:
len(tag_str.split(','))

98902

In [42]:
# txt = "Ibrutinib and other targeted inhibitors of B-cell receptor signaling achieve impressive clinical results for patients with chronic lymphocytic leukemia ( CLL ) . A treatment-induced rise in absolute lymphocyte count ( ALC ) has emerged as a class effect of kinase inhibitors in CLL and warrants further investigation . We "
# a = re.sub("(?:.(\.\s)(?=[A-Z]))", " .\n", txt)
# a.split("\n")

In [43]:
sents = re.sub("(?:.(\.\s)(?=[A-Z]))", " .\n", words_str)
list_of_sentences = sents.split("\n")

In [44]:
list_of_sentences[:3]

['Ibrutinib and other targeted inhibitors of B-cell receptor signaling achieve impressive clinical results for patients with chronic lymphocytic leukemia ( CLL ) .',
 'A treatment-induced rise in absolute lymphocyte count ( ALC ) has emerged as a class effect of kinase inhibitors in CLL and warrants further investigation .',
 'We here report correlative studies in 64 patients with CLL treated with ibrutinib .']

In [45]:
for item in list_of_sentences[:3]:
  print(len(item.split()))

23
26
14


In [46]:
sent_len = []
for item in list_of_sentences:
  sl = len(item.split())
  sent_len.append(sl)
sent_len[:10]


[23, 26, 14, 29, 30, 27, 29, 26, 11, 9]

In [47]:
len(list_of_sentences[1].split())

26

In [48]:
len(list_of_sentences)

3370

In [49]:
list_of_tags = []
offset_1st_sent_len = 0
for i in range(0,len(list_of_sentences)):
  if i == 0:
    first_sent_start = 0
    sent_len = len(list_of_sentences[i].strip().split(' '))
    sent1 = tag_str.split(',')[0:sent_len]
    list_of_tags.append(sent1)
    offset_1st_sent_len = sent_len

  if i != 0:
    sent_start = offset_1st_sent_len
    next_sent_len = len(list_of_sentences[i].strip().split(' '))
    next_sent = tag_str.split(',')[sent_start:(sent_start + next_sent_len)]
    list_of_tags.append(next_sent)
    offset_1st_sent_len = offset_1st_sent_len + next_sent_len

In [50]:
offset_1st_sent_len

98902

In [51]:
len(list_of_tags)

3370

In [52]:
len(list_of_sentences)

3370

In [53]:
len(list_of_sentences[0].split(" "))

23

In [54]:
list_of_sentences[0]

'Ibrutinib and other targeted inhibitors of B-cell receptor signaling achieve impressive clinical results for patients with chronic lymphocytic leukemia ( CLL ) .'

In [55]:
len(list_of_tags[0])

23

In [56]:
len(list_of_sentences[3].split(" "))

29

In [57]:
len(list_of_tags[3])

29

In [58]:
list_of_tags[2]

['O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-therapy\\_group',
 'O']

In [59]:
for ele in list_of_tags[:3]:
  print(ele)

['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-therapy\\_group', 'O']


In [61]:
for ele in list_of_sentences[:3]:
  print(ele)

Ibrutinib and other targeted inhibitors of B-cell receptor signaling achieve impressive clinical results for patients with chronic lymphocytic leukemia ( CLL ) .
A treatment-induced rise in absolute lymphocyte count ( ALC ) has emerged as a class effect of kinase inhibitors in CLL and warrants further investigation .
We here report correlative studies in 64 patients with CLL treated with ibrutinib .


In [62]:
len(list_of_sentences[2].split(" "))

14

In [63]:
len(list_of_tags[2])

14

In [64]:
list_of_sentences[2]

'We here report correlative studies in 64 patients with CLL treated with ibrutinib .'

In [65]:
list_of_tags[2]

['O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-therapy\\_group',
 'O']

In [66]:
a = list_of_sentences[2]
a.split()[12]

'ibrutinib'

In [67]:
b = list_of_tags[2]
b[12]

'B-therapy\\_group'

In [68]:
list_of_str_tags = [','.join(x) for x in list_of_tags]
list_of_str_tags[:3]

['O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O',
 'O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O',
 'O,O,O,O,O,O,O,O,O,O,O,O,B-therapy\\_group,O']

In [69]:
data = pd.DataFrame(
    {'sentence': list_of_sentences,
     'word_labels': list_of_str_tags
    })
data

Unnamed: 0,sentence,word_labels
0,Ibrutinib and other targeted inhibitors of B-c...,"O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O"
1,A treatment-induced rise in absolute lymphocyt...,"O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,..."
2,We here report correlative studies in 64 patie...,"O,O,O,O,O,O,O,O,O,O,O,O,B-therapy\_group,O"
3,"We quantified tumor burden in blood , lymph no...","O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,..."
4,With just one dose of ibrutinib the average in...,"O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,..."
...,...,...
3365,Inter-individual variability is of particular ...,"O,O,O,O,O,O,O,O,O,O,O,O,O,O,O"
3366,The developing model for flavopiridol PK may u...,"O,O,O,O,B-therapy\_group,O,O,O,O,O,O,O,O,O,O,O..."
3367,"With sufficient data and validation , this mod...","O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,..."
3368,"If sufficiently robust , such a tool would be ...","O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,..."


In [78]:
a = list_of_sentences[3366]
a.split()[4]

'flavopiridol'

In [79]:
b = list_of_tags[3366]
b[4]

'B-therapy\\_group'

In [80]:
len(data)

3370

In [81]:
a = data.iloc[3366].sentence
a

'The developing model for flavopiridol PK may ultimately help to explain this variability by incorporating pharmacogenetic and other significant factors .'

In [82]:
a.split()[4]

'flavopiridol'

In [90]:
b = data.iloc[3366].word_labels
b

'O,O,O,O,B-therapy\\_group,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O'

In [94]:
# Import generic wrappers
from transformers import AutoModel, BertTokenizerFast


# Define the model repo
model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext" 

tokenizer = BertTokenizerFast.from_pretrained(model_name)    

In [95]:
MAX_LEN = 128
TRAIN_BATCH_SIZE = 8
VALID_BATCH_SIZE = 2
EPOCHS = 5
LEARNING_RATE = 1e-05
MAX_GRAD_NORM = 10

In [96]:
class dataset(Dataset):
  def __init__(self, dataframe, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

  def __getitem__(self, index):
        # step 1: get the sentence and word labels 
        sentence = self.data.sentence[index].strip().split()  
        word_labels = self.data.word_labels[index].split(",") 

        # step 2: use tokenizer to encode sentence (includes padding/truncation up to max length)
        # BertTokenizerFast provides a handy "return_offsets_mapping" functionality for individual tokens
        encoding = self.tokenizer(sentence,
                             is_pretokenized=True, 
                             return_offsets_mapping=True, 
                             padding='max_length', 
                             truncation=True, 
                             max_length=self.max_len)
        
        # step 3: create token labels only for first word pieces of each tokenized word
        labels = [labels_to_ids[label] for label in word_labels] 
        # code based on https://huggingface.co/transformers/custom_datasets.html#tok-ner
        # create an empty array of -100 of length max_length
        encoded_labels = np.ones(len(encoding["offset_mapping"]), dtype=int) * -100
        
        # set only labels whose first offset position is 0 and the second is not 0
        i = 0
        for idx, mapping in enumerate(encoding["offset_mapping"]):
          if mapping[0] == 0 and mapping[1] != 0:
            # overwrite label
            encoded_labels[idx] = labels[i]
            i += 1

        # step 4: turn everything into PyTorch tensors
        item = {key: torch.as_tensor(val) for key, val in encoding.items()}
        item['labels'] = torch.as_tensor(encoded_labels)
        
        return item

  def __len__(self):
        return self.len

In [97]:
train_size = 0.8
train_dataset = data.sample(frac=train_size,random_state=200)
test_dataset = data.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

print("FULL Dataset: {}".format(data.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))

training_set = dataset(train_dataset, tokenizer, MAX_LEN)
testing_set = dataset(test_dataset, tokenizer, MAX_LEN)

FULL Dataset: (3370, 2)
TRAIN Dataset: (2696, 2)
TEST Dataset: (674, 2)


In [98]:
training_set[0]

{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]),
 'input_ids': tensor([    2,  2019,  1920,  2367,  1927,  2052,  2333,    16,  4786,  2132,
          1922,  1920,  1927,  9576,  6683,  2210,  2430, 14925,  2338,  1942,
          7130, 25138, 30160,  2256,  3982,  2573,  4380,    18,     3,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     

In [99]:
for token, label in zip(tokenizer.convert_ids_to_tokens(training_set[0]["input_ids"]), training_set[0]["labels"]):
  print('{0:10}  {1}'.format(token, label))

[CLS]       -100
at          0
the         0
time        0
of          0
this        0
analysis    0
,           0
57          0
patients    0
in          0
the         0
of          0
##atum      -100
##umab      -100
group       0
had         0
crossed     0
over        0
to          0
receive     0
ibr         0
##utinib    -100
after       0
confirmed   0
disease     0
progression  0
.           0
[SEP]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100


In [100]:
train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

In [101]:
# Download pytorch model
model = BertForTokenClassification.from_pretrained(model_name, num_labels=len(labels_to_ids))
model.to(device)

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForToken

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

In [102]:
inputs = training_set[2]
input_ids = inputs["input_ids"].unsqueeze(0)
attention_mask = inputs["attention_mask"].unsqueeze(0)
labels = inputs["labels"].unsqueeze(0)

input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
labels = labels.to(device)

outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
initial_loss = outputs[0]
initial_loss

tensor(3.2550, device='cuda:0', grad_fn=<NllLossBackward0>)

In [103]:
tr_logits = outputs[1]
tr_logits.shape #batch_size, sequence_length, num_labels

torch.Size([1, 128, 24])

In [104]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [105]:
# Defining the training function on the 80% of the dataset for tuning the bert model
def train(epoch):
    tr_loss, tr_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    tr_preds, tr_labels = [], []
    # put model in training mode
    model.train()
    
    for idx, batch in enumerate(training_loader):
        
        ids = batch['input_ids'].to(device, dtype = torch.long)
        mask = batch['attention_mask'].to(device, dtype = torch.long)
        labels = batch['labels'].to(device, dtype = torch.long)

        loss, tr_logits = model(input_ids=ids, attention_mask=mask, labels=labels)
        tr_loss += loss.item()

        nb_tr_steps += 1
        nb_tr_examples += labels.size(0)
        
        if idx % 100==0:
            loss_step = tr_loss/nb_tr_steps
            print(f"Training loss per 100 training steps: {loss_step}")
           
        # compute training accuracy
        flattened_targets = labels.view(-1) # shape (batch_size * seq_len,)
        active_logits = tr_logits.view(-1, model.num_labels) # shape (batch_size * seq_len, num_labels)
        flattened_predictions = torch.argmax(active_logits, axis=1) # shape (batch_size * seq_len,)
        
        # only compute accuracy at active labels
        active_accuracy = labels.view(-1) != -100 # shape (batch_size, seq_len)
        #active_labels = torch.where(active_accuracy, labels.view(-1), torch.tensor(-100).type_as(labels))
        
        labels = torch.masked_select(flattened_targets, active_accuracy)
        predictions = torch.masked_select(flattened_predictions, active_accuracy)
        
        tr_labels.extend(labels)
        tr_preds.extend(predictions)

        tmp_tr_accuracy = accuracy_score(labels.cpu().numpy(), predictions.cpu().numpy())
        tr_accuracy += tmp_tr_accuracy
    
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=MAX_GRAD_NORM
        )
        
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_loss = tr_loss / nb_tr_steps
    tr_accuracy = tr_accuracy / nb_tr_steps
    print(f"Training loss epoch: {epoch_loss}")
    print(f"Training accuracy epoch: {tr_accuracy}")

In [106]:
for epoch in range(EPOCHS):
    print(f"Training epoch: {epoch + 1}")
    train(epoch)

Training epoch: 1
Training loss per 100 training steps: 3.2338485717773438
Training loss per 100 training steps: 0.7105852672369173
Training loss per 100 training steps: 0.4348260011542496
Training loss per 100 training steps: 0.35473331351736653
Training loss epoch: 0.3317185820298662
Training accuracy epoch: 0.9546292673152934
Training epoch: 2
Training loss per 100 training steps: 0.05097512528300285
Training loss per 100 training steps: 0.13179737264935923
Training loss per 100 training steps: 0.1254149898516005
Training loss per 100 training steps: 0.11966634988537263
Training loss epoch: 0.11716877435298101
Training accuracy epoch: 0.9771085842164825
Training epoch: 3
Training loss per 100 training steps: 0.03880093991756439
Training loss per 100 training steps: 0.08778549124495966
Training loss per 100 training steps: 0.08639723646561083
Training loss per 100 training steps: 0.08170880710526657
Training loss epoch: 0.0846477134722542
Training accuracy epoch: 0.9800211998613609
T

In [107]:
def valid(model, testing_loader):
    # put model in evaluation mode
    model.eval()
    
    eval_loss, eval_accuracy = 0, 0
    nb_eval_examples, nb_eval_steps = 0, 0
    eval_preds, eval_labels = [], []
    
    with torch.no_grad():
        for idx, batch in enumerate(testing_loader):
            
            ids = batch['input_ids'].to(device, dtype = torch.long)
            mask = batch['attention_mask'].to(device, dtype = torch.long)
            labels = batch['labels'].to(device, dtype = torch.long)
            
            loss, eval_logits = model(input_ids=ids, attention_mask=mask, labels=labels)
            
            eval_loss += loss.item()

            nb_eval_steps += 1
            nb_eval_examples += labels.size(0)
        
            if idx % 100==0:
                loss_step = eval_loss/nb_eval_steps
                print(f"Validation loss per 100 evaluation steps: {loss_step}")
              
            # compute evaluation accuracy
            flattened_targets = labels.view(-1) # shape (batch_size * seq_len,)
            active_logits = eval_logits.view(-1, model.num_labels) # shape (batch_size * seq_len, num_labels)
            flattened_predictions = torch.argmax(active_logits, axis=1) # shape (batch_size * seq_len,)
            
            # only compute accuracy at active labels
            active_accuracy = labels.view(-1) != -100 # shape (batch_size, seq_len)
        
            labels = torch.masked_select(flattened_targets, active_accuracy)
            predictions = torch.masked_select(flattened_predictions, active_accuracy)
            
            eval_labels.extend(labels)
            eval_preds.extend(predictions)
            
            tmp_eval_accuracy = accuracy_score(labels.cpu().numpy(), predictions.cpu().numpy())
            eval_accuracy += tmp_eval_accuracy

    labels = [ids_to_labels[id.item()] for id in eval_labels]
    predictions = [ids_to_labels[id.item()] for id in eval_preds]
    
    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_steps
    print(f"Validation Loss: {eval_loss}")
    print(f"Validation Accuracy: {eval_accuracy}")

    return labels, predictions

In [108]:
labels, predictions = valid(model, testing_loader)

Validation loss per 100 evaluation steps: 0.08540473878383636
Validation loss per 100 evaluation steps: 0.12062265152586532
Validation loss per 100 evaluation steps: 0.11604803081182531
Validation loss per 100 evaluation steps: 0.1075918509971437
Validation Loss: 0.10304265459225644
Validation Accuracy: 0.9746232359726567


In [109]:
from seqeval.metrics import classification_report

In [110]:
len(labels)

20007

In [111]:
len(predictions)

20007

In [112]:
print(classification_report(labels, predictions))

                precision    recall  f1-score   support

           PFS       0.18      0.08      0.11        25
therapy\_group       0.78      0.49      0.60       169
      mutation       0.16      0.07      0.09        45
            OS       0.33      0.04      0.07        24
           nam       0.67      0.10      0.17        21
            CI       0.00      0.00      0.00         4
      followup       0.04      0.09      0.05        11
    demography       0.00      0.00      0.00         5
            OR       0.00      0.00      0.00         5
            CR       0.00      0.00      0.00         2
   comorbidity       0.00      0.00      0.00         3
            HR       0.00      0.00      0.00         1

     micro avg       0.48      0.29      0.36       315
     macro avg       0.53      0.29      0.36       315



In [119]:
sentence = "Overall , 357 patients were randomized ( rit- uximab plus bendamustine , n=178 ; rituximab plus chlorambucil , n=179 ; intent-to-treat population ) , including 241 first-line patients (n=121 and n=120, respectively ) ; 355 patients received treatment ( n = 177 and n=178 , respectively ; safety population ) ."

In [120]:
inputs = tokenizer(sentence.split(),
                    is_pretokenized=True, 
                    return_offsets_mapping=True, 
                    padding='max_length', 
                    truncation=True, 
                    max_length=MAX_LEN,
                    return_tensors="pt")

# move to gpu
ids = inputs["input_ids"].to(device)
mask = inputs["attention_mask"].to(device)
# forward pass
outputs = model(ids, attention_mask=mask)
logits = outputs[0]

active_logits = logits.view(-1, model.num_labels) # shape (batch_size * seq_len, num_labels)
flattened_predictions = torch.argmax(active_logits, axis=1) # shape (batch_size*seq_len,) - predictions at the token level

tokens = tokenizer.convert_ids_to_tokens(ids.squeeze().tolist())
token_predictions = [ids_to_labels[i] for i in flattened_predictions.cpu().numpy()]
wp_preds = list(zip(tokens, token_predictions)) # list of tuples. Each tuple = (wordpiece, prediction)

prediction = []
for token_pred, mapping in zip(wp_preds, inputs["offset_mapping"].squeeze().tolist()):
  #only predictions on first word pieces are important
  if mapping[0] == 0 and mapping[1] != 0:
    prediction.append(token_pred[1])
  else:
    continue

print(sentence.split())
print(prediction)

['Overall', ',', '357', 'patients', 'were', 'randomized', '(', 'rit-', 'uximab', 'plus', 'bendamustine', ',', 'n=178', ';', 'rituximab', 'plus', 'chlorambucil', ',', 'n=179', ';', 'intent-to-treat', 'population', ')', ',', 'including', '241', 'first-line', 'patients', '(n=121', 'and', 'n=120,', 'respectively', ')', ';', '355', 'patients', 'received', 'treatment', '(', 'n', '=', '177', 'and', 'n=178', ',', 'respectively', ';', 'safety', 'population', ')', '.']
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
