BERT and its variants (such as DistilBERT and RoBERTa) can process sequences of at most 512 tokens. Sequences longer than this limit are truncated, and shorter ones can be padded to a fixed length so that every sequence in the dataset has the same length.
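
To make the two cases concrete, here is a minimal plain-Python sketch of truncation and padding to a fixed length. The helper name `fit_to_length` and the toy token ids are assumptions made up for this illustration; in practice the tokenizer does this itself when called with `truncation=True` and `padding='max_length'`.

```python
def fit_to_length(token_ids, max_len=512, pad_id=0):
    """Truncate or right-pad a list of token ids to exactly max_len (hypothetical helper)."""
    kept = token_ids[:max_len]                       # tokens past max_len are dropped
    n_pad = max_len - len(kept)
    input_ids = kept + [pad_id] * n_pad              # right-pad with the pad id
    attention_mask = [1] * len(kept) + [0] * n_pad   # 0 marks padded positions
    return input_ids, attention_mask

# A 700-token sequence loses its last 188 tokens
long_ids, long_mask = fit_to_length(list(range(1, 701)), max_len=512)
# A 3-token sequence is padded up to the fixed length
short_ids, short_mask = fit_to_length([5, 6, 7], max_len=8)

print(len(long_ids), sum(long_mask))
print(short_ids, short_mask)
```

Everything after position 512 of the long sequence is discarded, which is the information loss the rest of this notebook tries to avoid.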

We'll look at how to pass long text sequences to BERT and its variants without truncating them, since some sequences may contain relevant information beyond the 512-token limit.

We'll use the DistilBERT model, as it is much lighter and has fewer trainable parameters, so it needs less compute and memory.

In [None]:
### Install libraries
!pip install transformers
!pip install scikit-learn

We'll be using the 20 Newsgroups dataset, available through the sklearn.datasets module.

In [1]:
from sklearn.datasets import fetch_20newsgroups
newsgroups_train = fetch_20newsgroups(subset='train')

In [76]:
### Create a dataframe of columns text and labels
import re
import pandas as pd
from sklearn.model_selection import train_test_split
df = pd.DataFrame(list(zip(newsgroups_train.data,newsgroups_train.target)),columns=['text','label'])
df = df.sample(frac=1).head(5000)
print(df)

                                                    text  label
4028   From: mlee@eng.sdsu.edu (Mike Lee)\nSubject: M...      1
4686   From: rclark@nyx.cs.du.edu\nSubject: Re: Is th...     11
10440  From: kkeller@mail.sas.upenn.edu (Keith Keller...     10
7219   From: bobbe@vice.ICO.TEK.COM (Robert Beauchain...      0
6618   From: emm@tamarack202.cray.com (Mike McConnell...      1
...                                                  ...    ...
10654  From: mantolov@golum.riv.csu.edu.au (Michael A...      4
10417  From: sforsblo@vipunen.hut.fi (Svante Forsblom...     10
10032  From: steveg@cadkey.com (Steve Gallichio)\nSub...     10
9322   From: James_Jim_Frazier@cup.portal.com\nSubjec...      4
4642   From: (Joseph D. Barrus)\nSubject: Utility to ...      2

[5000 rows x 2 columns]


In [77]:
label2ids = {label:i for i,label in enumerate(newsgroups_train.target_names)}
ids2label = {id:label for label,id in label2ids.items()}
print(label2ids)
print(ids2label)

{'alt.atheism': 0, 'comp.graphics': 1, 'comp.os.ms-windows.misc': 2, 'comp.sys.ibm.pc.hardware': 3, 'comp.sys.mac.hardware': 4, 'comp.windows.x': 5, 'misc.forsale': 6, 'rec.autos': 7, 'rec.motorcycles': 8, 'rec.sport.baseball': 9, 'rec.sport.hockey': 10, 'sci.crypt': 11, 'sci.electronics': 12, 'sci.med': 13, 'sci.space': 14, 'soc.religion.christian': 15, 'talk.politics.guns': 16, 'talk.politics.mideast': 17, 'talk.politics.misc': 18, 'talk.religion.misc': 19}
{0: 'alt.atheism', 1: 'comp.graphics', 2: 'comp.os.ms-windows.misc', 3: 'comp.sys.ibm.pc.hardware', 4: 'comp.sys.mac.hardware', 5: 'comp.windows.x', 6: 'misc.forsale', 7: 'rec.autos', 8: 'rec.motorcycles', 9: 'rec.sport.baseball', 10: 'rec.sport.hockey', 11: 'sci.crypt', 12: 'sci.electronics', 13: 'sci.med', 14: 'sci.space', 15: 'soc.religion.christian', 16: 'talk.politics.guns', 17: 'talk.politics.mideast', 18: 'talk.politics.misc', 19: 'talk.religion.misc'}


In [78]:
df['len'] = df.text.apply(lambda t: len(t.split()))
df.len.describe()

count     5000.000000
mean       291.394800
std        548.426855
min         15.000000
25%        109.000000
50%        180.000000
75%        295.000000
max      11263.000000
Name: len, dtype: float64

In [79]:
### Basic Preprocessing
def text_preprocessing(text):
  text = text.lower()
  text = text.replace('\n',' ').replace('\r',' ').replace('\t',' ')
  text = re.sub(r" {2,}"," ",text)
  return text

df['text'] = df.text.apply(lambda t: text_preprocessing(t))

In [80]:
df['encoded_label'] = df['label'].astype('category').cat.codes
NUM_LABELS = len(df['encoded_label'].value_counts().index)
print(NUM_LABELS)

20


In [81]:
# label2ids = dict(zip(df.label,df.encoded_label))
# print(label2ids)

In [82]:
train_texts,val_texts,train_labels,val_labels = train_test_split(df.text.tolist(),df.encoded_label.tolist(),test_size=0.2,random_state=42)
print(type(train_texts),type(train_labels))
print(len(train_texts),len(val_texts),len(train_labels),len(val_labels))

<class 'list'> <class 'list'>
4000 1000 4000 1000


In [83]:
from transformers import AutoTokenizer,AutoModelForSequenceClassification

In [117]:
import sys
import torch

In [84]:
model_checkpoint = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint,num_labels=NUM_LABELS)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [105]:
train_tokens = tokenizer(train_texts)
test_tokens = tokenizer(val_texts)
train_encodings = [{"input_ids":ids, "attention_mask":mask} for ids,mask in zip(train_tokens['input_ids'],train_tokens['attention_mask'])]
test_encodings = [{"input_ids":ids, "attention_mask":mask} for ids,mask in zip(test_tokens['input_ids'],test_tokens['attention_mask'])]

The SlidingWindow class covers all of the data without truncating it. Each token sequence is split into windows of up to window_size tokens, and each window after the first starts with the last overlap tokens of the previous window. Every window is wrapped in the start and end tokens and padded to MAX_LEN.
Arguments:
* window_size, default 256, number of content tokens in each window. The maximum value is 2 less than the model's maximum sequence length, to leave room for the start and end tokens.
* overlap, default 128, number of tokens shared with the previous window. It must be smaller than window_size.
* tokenizer, default 'distilbert', selects the start, end, and padding token ids. The other supported value is 'roberta'.
* pad_sequence_flag, default True, pads each window to MAX_LEN.
* MAX_LEN, default 512, length each window is padded to. BERT-based models accept at most 512 tokens; other architectures accept the maximum defined in their papers.
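
As a rough illustration of how window_size and overlap interact, the sketch below computes the token spans produced by a loop that steps by window_size - overlap. The function name `window_spans` is a hypothetical helper introduced only for this example; it works on sequence lengths alone and does not touch the tokenizer.

```python
def window_spans(n_tokens, window_size=256, overlap=128):
    """Return (start, end) slice bounds for each window (hypothetical helper)."""
    step = window_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than window_size")
    # One window starts every `step` tokens; the end is clipped to the sequence length
    return [(i, min(i + window_size, n_tokens)) for i in range(0, n_tokens, step)]

# Default settings on a 600-token sequence
default_spans = window_spans(600, window_size=256, overlap=128)
print(default_spans)

# The settings used later in this notebook, on a 1000-token sequence
notebook_spans = window_spans(1000, window_size=510, overlap=128)
print(notebook_spans)
```

With this stepping scheme the final span can fall entirely inside the previous one, as (512, 600) does inside (384, 600) in the first call.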

In [100]:
class SlidingWindow:

  def __init__(self, window_size=256, overlap=128, tokenizer='distilbert', pad_sequence_flag=True, MAX_LEN=512):
    # The step between windows is window_size - overlap, so it must be positive
    if overlap >= window_size:
      raise ValueError("overlap must be smaller than window_size")
    self.window_size = window_size
    self.overlap = overlap
    self.tokenizer = tokenizer
    self.pad_sequence_flag = pad_sequence_flag
    self.pad_sequence_length = MAX_LEN

  def pad_sequences(self,sequence,padding_type='POST',sequence_type='input_ids'):
    # Return the window unchanged when padding is disabled
    if not self.pad_sequence_flag:
      return sequence

    # Pad id for input_ids: 0 ([PAD]) for DistilBERT, 1 (<pad>) for RoBERTa.
    # Padded positions in the attention mask are always 0.
    if sequence_type == 'input_ids' and self.tokenizer == 'roberta':
      padding_data = [1]
    else:
      padding_data = [0]

    padding = padding_data*(self.pad_sequence_length-len(sequence))
    if padding_type == 'POST':
      return sequence + padding
    elif padding_type == 'PRE':
      return padding + sequence
    raise ValueError("padding_type must be 'POST' or 'PRE'")

  def create_token_windows(self,tokenizer_sequence,labels,label_ids_flag=False,label_ids_list=None):
    # tokenizer_sequence is a list of {"input_ids": [...], "attention_mask": [...]} dicts, one per document
    if label_ids_flag and (not label_ids_list or len(label_ids_list) != len(tokenizer_sequence)):
      raise ValueError("label_ids_flag is True but label_ids_list is empty or its length is not equal to the length of tokenizer_sequence")

    input_ids_list = []
    attention_mask_list = []
    labels_list = []

    # Start and end token ids: <s>/</s> for RoBERTa, [CLS]/[SEP] for DistilBERT
    if self.tokenizer == "roberta":
      start_token = [0]
      end_token = [2]
    elif self.tokenizer == "distilbert":
      start_token = [101]
      end_token = [102]
    else:
      raise ValueError(f"[ERROR] No Tokenizer Defined - {self.tokenizer}")

    for index,input_sequence in enumerate(tokenizer_sequence):
      # Drop the start and end tokens added by the tokenizer; they are re-added per window
      sequence = input_sequence['input_ids'][1:-1]
      attention_sequence = input_sequence['attention_mask'][1:-1]
      label = labels[index]

      for i in range(0,len(sequence),self.window_size-self.overlap):
        window_input_ids = start_token + sequence[i:i+self.window_size] + end_token
        window_input_ids = self.pad_sequences(window_input_ids,sequence_type="input_ids")
        input_ids_list.append(window_input_ids)

        attention_mask = [1]+attention_sequence[i:i+self.window_size]+[1]
        attention_mask = self.pad_sequences(attention_mask,sequence_type="attention_mask")
        attention_mask_list.append(attention_mask)

        # Every window inherits the label of the document it came from
        if label_ids_flag:
          labels_list.append(label_ids_list[label])
        else:
          labels_list.append(label)

    return input_ids_list,attention_mask_list,labels_list


In [101]:
sw = SlidingWindow(window_size=510, overlap=128, tokenizer='distilbert')

In [110]:
train_input_ids,train_masks,train_labels = sw.create_token_windows(train_encodings,train_labels)

In [115]:
test_input_ids,test_masks,test_labels = sw.create_token_windows(test_encodings,val_labels)
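
Because every window carries the label of its source document, a long document contributes several rows to the returned lists. The toy sketch below shows this effect with deliberately tiny, made-up sizes (window of 4, overlap of 2, padded length of 6); `toy_windows` and the token ids are hypothetical stand-ins, not the class used in this notebook.

```python
from collections import Counter

def toy_windows(doc_ids, label, window_size=4, overlap=2, max_len=6,
                start_id=101, end_id=102, pad_id=0):
    """Build padded (input_ids, attention_mask, label) rows for one document."""
    rows = []
    step = window_size - overlap
    for i in range(0, len(doc_ids), step):
        ids = [start_id] + doc_ids[i:i + window_size] + [end_id]
        mask = [1] * len(ids)
        n_pad = max_len - len(ids)
        # Pad ids with pad_id and the mask with 0 so every row has length max_len
        rows.append((ids + [pad_id] * n_pad, mask + [0] * n_pad, label))
    return rows

# Two toy documents: a short one with label 7 and a longer one with label 3
docs = [([11, 12, 13], 7), ([21, 22, 23, 24, 25, 26, 27], 3)]
rows = [row for ids, label in docs for row in toy_windows(ids, label)]

rows_per_label = Counter(label for _, _, label in rows)
print(len(rows), rows_per_label)
```

The longer document yields twice as many rows as the shorter one, so the window-level label distribution can differ from the document-level one.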

In [118]:
train_input_ids = torch.tensor(train_input_ids,dtype=torch.long)
train_masks = torch.tensor(train_masks,dtype=torch.long)
train_labels = torch.tensor(train_labels,dtype=torch.long)

In [119]:
test_input_ids = torch.tensor(test_input_ids,dtype=torch.long)
test_masks = torch.tensor(test_masks,dtype=torch.long)
test_labels = torch.tensor(test_labels,dtype=torch.long)