<a href="https://colab.research.google.com/github/Taaniya/exploring-gpt2-language-model/blob/main/Text_classification_with_gpt2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install transformers datasets accelerate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.29.2-py3-none-any.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m64.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m53.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate
  Downloading accelerate-0.19.0-py3-none-any.whl (219 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m219.1/219.1 kB[0m [31m29.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m33.9 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1

In [4]:
import re
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset
from datasets import load_dataset
from transformers import AutoTokenizer
import pandas as pd
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay

In [5]:
# Location of the fine-tuned classifier checkpoint that the inference pipeline loads.
# NOTE(review): this is an absolute path, while the fine-tuning cell is invoked
# with --model_name gpt2 and --model_dir "./" — confirm the script writes here.
model_save_path = "/model_dir/distilgpt2-classifier"

# Train / test CSV locations, relative to the current working directory
train_dataset_path = "./train.csv"
test_dataset_path = "./test.csv"

In [None]:
! python finetune_gpt2_clm_pytorch.py \
--dataset_path "./train.csv" \
--batch_size 16 \
--epoch 4 \
--model_name gpt2 \
--model_version 1 \
--model_dir "./"

In [6]:
def parse_label_preds(text):
  """Extract the predicted label from text generated by the fine-tuned model.

  Inference prompts have the form ``###Text: <text> ###Label:``, so the label
  is the first run of ASCII letters that follows the first ``###Label:``
  marker (any whitespace between marker and label is skipped).

  Args:
    text: Generated text, expected to contain the ``###Label:`` marker.

  Returns:
    The label word, or the string "None" when no label can be found. The
    unparseable input is printed so it can be inspected afterwards.
  """
  # Non-string input cannot be searched; treat it like any other parse failure.
  match = re.search(r"###Label:\s*([a-zA-Z]+)", text) if isinstance(text, str) else None
  if match is None:
    print(f"Can't parse: {text}")
    return "None"

  return match.group(1)

In [7]:
def prep_test_data(row):
  """Wrap a test example's text in the prompt markers used for inference.

  The 'text' field is rewritten in place to
  ``###Text: <original text> ###Label:`` and the same row object is returned.
  """
  prefix, suffix = "###Text: ", " ###Label:"
  wrapped = prefix + row['text'] + suffix
  row['text'] = wrapped
  return row

In [None]:
# Wrap the fine-tuned checkpoint in a text-generation pipeline placed on GPU 0
ft_gpt2_classifier = pipeline(
    task="text-generation",
    model=model_save_path,
    device=0,
)

# Read the test CSV twice: as a split-keyed dataset for inference (the file is
# registered under the "train" key) and as a DataFrame for evaluation
test_data = load_dataset(
    "csv",
    data_files={"train": [test_dataset_path]},
    download_mode='force_redownload',
)

test_df = pd.read_csv(test_dataset_path)

In [None]:
# Apply the prompt formatting from prep_test_data to every test example
updated_test_dataset = test_data.map(function=prep_test_data)

In [None]:
# Smoke test: generate a completion for one hand-written prompt
prompt = "###Text: asmsalc ###Label:"
output = ft_gpt2_classifier(
    prompt,
    num_beams=2,
    no_repeat_ngram_size=2,
    return_full_text=True,
)
print(output[0]['generated_text'])

In [None]:
# Run inference on the entire test set, feeding the pipeline one example at a
# time through KeyDataset rather than in batches (GPT2 text generation).
# The examples come from `updated_test_dataset`, i.e. the prompt-formatted
# output of prep_test_data, and the CSV was loaded under the "train" split key.
preds = []
for output in ft_gpt2_classifier(KeyDataset(updated_test_dataset['train'], 'text'),
                                 return_full_text=True, no_repeat_ngram_size=2,
                                 num_beams=2):
  # Each result is a list of candidates; keep the label parsed from the first
  pred_label = parse_label_preds(output[0]['generated_text'])
  preds.append(pred_label)

In [None]:
# Attach the parsed predictions as a new column; assumes `preds` is in the same
# row order as the CSV rows held in test_df
test_df['preds'] = preds

In [None]:
# Preview gold labels next to predictions; a bare last expression renders the
# frame with the notebook's rich table display instead of plain printed text
test_df[["label", "preds"]].head(20)

In [None]:
# Per-class precision / recall / F1 of the parsed predictions against the gold labels
print(classification_report(y_true=test_df['label'], y_pred=test_df['preds']))

In [None]:
# Display the rows whose predicted label differs from the gold label
test_df[test_df['label'] != test_df['preds']]

In [None]:
# Export the misclassified rows to an Excel workbook for root-cause analysis
test_df[test_df['label'] != test_df['preds']].to_excel(
    excel_writer="./misclassified.xlsx",
    index=False,
)