In [None]:
import os
import re

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split

In [None]:
# Dataset root; the `index.tsv` inside it lists one article (category + filename) per row.
BASE_FOLDER = 'TEST/'
INDEX_PATH = os.path.join(BASE_FOLDER, "index.tsv")

# Cleaning mode handed to WikipediaDataset (4 = delete selected sections and clean headings).
CLEAN_LEVEL = 4

# Local download cache for Hugging Face artifacts, under the current working directory.
CACHE_DIR = os.path.join(os.getcwd(), 'transformers-cache')

In [None]:
class WikipediaDataset(Dataset):
    """Map-style dataset of Wikipedia article texts listed in a TSV index.

    Every row of the index supplies a ``category`` and a ``filename``; the article
    is read from ``<root_dir>/<category>/<filename>`` and the category is returned
    as the sample's label.
    """

    # Matches wiki headings such as "== History ==" or "=== Early life ===".
    RE_HEADINGS = re.compile(r"==.*?==+", re.MULTILINE)

    def __init__(
        self,
        tsv_file,
        root_dir,
        clean_level=0,
        delete_sections=(
            "References",
            "External links",
            "See also",
            "Further reading",
            "Notes",
            "Bibliography",
            "Sources",
        ),
    ):
        """
        Args:
            tsv_file (string): Path to the tab-separated index; must have ``category`` and ``filename`` columns.
            root_dir (string): Database base directory holding one sub-directory per category.
            clean_level (int, optional): Cleaning mode applied to every sample; see ``clean`` for the levels.
            delete_sections (iterable of str, optional): Section titles to drop when the clean level
                deletes sections. Matching ignores spaces and letter case.
        """
        self.wiki_frame = pd.read_csv(tsv_file, sep="\t")
        self.root_dir = root_dir
        self.clean_level = clean_level
        # Normalise once so removeSections can use a plain membership test.
        self.delete_sections = [self._headingClean(name) for name in delete_sections]

    def _headingClean(self, x, args=None):
        """
            Normalise a heading for comparison: drop spaces and lower-case it.

            ``args`` is unused and kept only for backward compatibility. Input that is
            not a string is reported and returned unchanged.
        """
        try:
            return x.replace(" ", "").lower()
        except (AttributeError, TypeError) as e:
            print("ERROR at clean_heading function:", e)
            return x

    def cleanHeadings(self, x, DEL_HEADINGS=False):
        """
            Strip heading markup from ``x``.

            With ``DEL_HEADINGS=False`` only the ``==``/``===``/``====`` markers around
            titles are removed and the title text is kept. With ``DEL_HEADINGS=True``
            every heading match (markers and title) is removed.
        """
        if DEL_HEADINGS:
            return self.RE_HEADINGS.sub("", x)
        # Longest markers first so "=== " is not left behind as "=".
        return (
            x.replace("==== ", "")
            .replace("=== ", "")
            .replace("== ", "")
            .replace(" ====", "")
            .replace(" ===", "")
            .replace(" ==", "")
        )

    def removeSections(self, x):
        """
            Remove every section whose heading is listed in ``self.delete_sections``.

            A section spans from its heading to the start of the next heading (of any
            level) or to the end of the text. Text is cut by position, so identical
            text elsewhere in the article is never touched.
        """
        headings = list(self.RE_HEADINGS.finditer(x))
        pieces = []
        cursor = 0
        for i, heading in enumerate(headings):
            name = self._headingClean(heading.group(0).replace("=", ""))
            if name not in self.delete_sections:
                continue
            section_end = headings[i + 1].start() if i + 1 < len(headings) else len(x)
            # Keep everything between the previous cut and this section's heading.
            pieces.append(x[cursor:heading.start()])
            cursor = section_end
        pieces.append(x[cursor:])
        return "".join(pieces)

    def clean(self, x):
        """
            Function to clean the text
            CLEANING LEVELS:
                0: No cleaning
                1: Clean All Headings
                2: Delete All Headings
                3: Delete Sections only
                4: Delete Selected Sections and Clean Headings
                5: Delete Selected Sections and Delete All Headings
            RECOMMENDED:
                0: No Cleaning
                1: Light Cleaning
                4: Heavy Cleaning

            Raises:
                ValueError: if ``self.clean_level`` is not one of the levels above.
        """
        if self.clean_level == 0:
            return x
        elif self.clean_level == 1:
            return self.cleanHeadings(x)
        elif self.clean_level == 2:
            return self.cleanHeadings(x, DEL_HEADINGS=True)
        elif self.clean_level == 3:
            return self.removeSections(x)
        elif self.clean_level == 4:
            return self.cleanHeadings(self.removeSections(x))
        elif self.clean_level == 5:
            return self.cleanHeadings(self.removeSections(x), DEL_HEADINGS=True)
        # ValueError subclasses Exception, so existing `except Exception` callers still work.
        raise ValueError(
            "Invalid clean_level configured. Please reinitialize dataloader."
        )

    def __len__(self):
        """Number of rows in the index."""
        return len(self.wiki_frame)

    def __getitem__(self, idx):
        """Return ``{"text": cleaned article text, "label": category}`` for row ``idx``."""
        if torch.is_tensor(idx):
            # NOTE(review): a 1-d tensor becomes a list here, which makes iloc return a
            # DataFrame that the lines below do not handle; only scalar indices are supported.
            idx = idx.tolist()
        row = self.wiki_frame.iloc[idx]
        cat, file_name = row["category"], row["filename"]
        path = os.path.join(self.root_dir, cat, file_name)
        # Close the handle deterministically; assumes the corpus is UTF-8 (TODO confirm).
        with open(path, "r", encoding="utf-8") as fh:
            txt = fh.read()
        return {"text": self.clean(txt), "label": cat}


In [None]:
dataset = WikipediaDataset(INDEX_PATH, BASE_FOLDER, clean_level=CLEAN_LEVEL)
length = len(dataset)
print(f"Dataset length: {length}")

# Hold out 15% of all articles for test, then 10% of the remainder for validation.
test_l = int(length * 0.15)
train_l = length - test_l
valid_l = int(train_l * 0.1)
train_l -= valid_l

# Split the dataset itself (not `range(len(dataset))`): the resulting Subsets yield
# {"text", "label"} samples, which is what the collator below consumes. With the same
# seed the chosen indices are identical to splitting a range.
train, test, validation = random_split(
    dataset,
    [train_l, test_l, valid_l],
    generator=torch.Generator().manual_seed(42),
)
assert len(train) + len(test) + len(validation) == length
print(f"Train: {len(train)}, Test: {len(test)}, Validation: {len(validation)}")

# Dictionary of labels and their ids, used to convert string labels to numeric ids.
labels_ids = {'B': 0, 'C': 1, 'FA': 2, 'GA': 3, 'Start': 4, 'Stub': 5}
# Fail fast if the index contains a category this mapping cannot encode.
unknown_labels = set(dataset.wiki_frame["category"]) - set(labels_ids)
assert not unknown_labels, f"Categories missing from labels_ids: {sorted(map(str, unknown_labels))}"

# Number of labels used in training; decides the size of the classification head.
n_labels = len(labels_ids)

In [None]:


class Gpt2ClassificationCollator:
    r"""
    Batch collator for GPT-2 sequence classification.

    Turns a list of ``{"text": ..., "label": ...}`` samples into one dictionary
    that can be passed straight to the model as ``model(**batch)``: the
    tokenizer's output plus a ``labels`` tensor of integer class ids.

    Arguments:

      use_tokenizer (:obj:`transformers.tokenization_?`):
          Tokenizer that converts raw text into padded/truncated token tensors.

      labels_encoder (:obj:`dict`):
          Maps every label name to the integer id of that class.

      max_sequence_len (:obj:`int`, `optional`):
          Length sequences are truncated to. When omitted, the tokenizer's
          ``model_max_length`` is used.
    """

    def __init__(self, use_tokenizer, labels_encoder, max_sequence_len=None):
        self.use_tokenizer = use_tokenizer
        if max_sequence_len is None:
            self.max_sequence_len = use_tokenizer.model_max_length
        else:
            self.max_sequence_len = max_sequence_len
        self.labels_encoder = labels_encoder

    def __call__(self, sequences):
        r"""
        Collate one batch, so an instance can serve as a DataLoader ``collate_fn``.

        Arguments:

          sequences (:obj:`list`):
              Samples, each a dict with ``'text'`` and ``'label'`` keys.

        Returns:
          :obj:`Dict[str, object]`: Tokenizer output plus ``'labels'``, ready to
          be used as ``model(**returned_dict)``.
        """
        batch_texts = [sample['text'] for sample in sequences]
        raw_labels = [sample['label'] for sample in sequences]
        label_ids = [self.labels_encoder[name] for name in raw_labels]
        # Pad to the longest text in the batch, truncating at max_sequence_len.
        encoded = self.use_tokenizer(
            text=batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_sequence_len,
        )
        encoded.update(labels=torch.tensor(label_ids))
        return encoded



In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import set_seed, GPT2Tokenizer

# Seed the random number generators (via the transformers helper) for reproducibility.
set_seed(42)

# Passes over the training split.
epochs = 4

# Samples per batch (not the number of batches). Tune together with max_length and
# GPU memory: about 10 fits at a sequence length of 512; short sequences allow 32+.
batch_size = 32

# Token length sequences are truncated to; `None` falls back to the tokenizer's
# own maximum (see Gpt2ClassificationCollator).
max_length = 60

# Prefer a GPU when one is available, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Hub model id for pretrained weights, or a path to a local checkpoint.
model_name_or_path = 'gpt2'


In [None]:

# Load the GPT-2 tokenizer, using CACHE_DIR as the download cache.
print('Loading tokenizer...')
tokenizer = GPT2Tokenizer.from_pretrained(
    pretrained_model_name_or_path=model_name_or_path,
    cache_dir=CACHE_DIR,
)
# Pad on the left so the real tokens end up at the end of every padded row.
tokenizer.padding_side = "left"
# GPT-2 has no dedicated padding token, so reuse EOS (id 50256) as PAD.
tokenizer.pad_token = tokenizer.eos_token



In [None]:
# Create data collator to encode text and labels into numbers.
gpt2_classificaiton_collator = Gpt2ClassificationCollator(use_tokenizer=tokenizer,
                                                          labels_encoder=labels_ids,
                                                          max_sequence_len=max_length)


def _make_dataloader(subset, title, var_name, shuffle):
    """Wrap one split in a DataLoader batched through the GPT-2 collator and report its size.

    Uses the notebook-level `batch_size` and `gpt2_classificaiton_collator`.
    `title` and `var_name` are used only in the progress messages.
    """
    print('Dealing with %s...' % title)
    loader = DataLoader(subset, batch_size=batch_size, shuffle=shuffle,
                        collate_fn=gpt2_classificaiton_collator)
    print('Created `%s` with %d batches!' % (var_name, len(loader)))
    return loader


# Only the training split is shuffled; test and validation keep a fixed order.
train_dataloader = _make_dataloader(train, 'Train', 'train_dataloader', shuffle=True)
test_dataloader = _make_dataloader(test, 'Test', 'test_dataloader', shuffle=False)
valid_dataloader = _make_dataloader(validation, 'Validation', 'valid_dataloader', shuffle=False)

# Adapted from: https://colab.research.google.com/github/gmihaila/ml_things/blob/master/notebooks/pytorch/gpt2_finetune_classification.ipynb#scrollTo=EDEubgJIt23C