In [None]:
import pandas, torch, csv, random

In [None]:

class Gpt2ClassificationCollator:
    """Callable that turns a list of examples into one tokenized batch.

    Each example is indexed by the keys ``'text'`` and ``'label'``. The texts
    are passed to the tokenizer in a single call, and the labels are
    converted with ``int`` and then looked up in ``labels_encoder``.
    """

    def __init__(self, use_tokenizer, labels_encoder, max_sequence_len=None):
        """Store the tokenizer, the label lookup and the length limit.

        Args:
            use_tokenizer: Callable tokenizer; it is invoked with ``text``,
                ``return_tensors``, ``padding``, ``truncation`` and
                ``max_length`` keyword arguments.
            labels_encoder: Object indexed by the integer label of each
                example to obtain the value placed in the batch.
            max_sequence_len: Value forwarded as ``max_length``. When it is
                ``None``, ``use_tokenizer.model_max_length`` is used instead.
        """
        if max_sequence_len is None:
            # No explicit limit given: fall back to the tokenizer's own.
            max_sequence_len = use_tokenizer.model_max_length

        self.use_tokenizer = use_tokenizer
        self.labels_encoder = labels_encoder
        self.max_sequence_len = max_sequence_len

    def __call__(self, sequences):
        """Tokenize the examples and attach their encoded labels.

        Args:
            sequences: Examples, each indexable by ``'text'`` and ``'label'``.

        Returns:
            The object returned by the tokenizer, updated in place with a
            ``'labels'`` entry holding ``torch.tensor`` of the encoded labels.
        """
        # First pass: collect the raw texts.
        texts = []
        for example in sequences:
            texts.append(example['text'])

        # Second pass: normalise every label to int before any lookup.
        raw_labels = []
        for example in sequences:
            raw_labels.append(int(example['label']))

        # Translate the integer labels through the encoder.
        encoded_labels = []
        for raw_label in raw_labels:
            encoded_labels.append(self.labels_encoder[raw_label])

        # One tokenizer call for the whole batch, with padding and truncation
        # enabled and the configured maximum length.
        batch = self.use_tokenizer(
            text=texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_sequence_len,
        )

        # Add the encoded labels to the tokenizer output as a tensor.
        label_tensor = torch.tensor(encoded_labels)
        batch.update({'labels': label_tensor})
        return batch