# Load DBPedia dataset

In [None]:
from datasets import load_dataset

ds = load_dataset("fancyzhx/dbpedia_14")
# Class names of the integer `label` column, read from the train split's features.
CLASS_LABELS = ds["train"].features["label"].names
CLASS_LABELS  # shown as the cell output

# Obtain first n samples from each class

In [None]:
import numpy as np

In [None]:
def get_n_samples_per_class(dataset, n, shuffle = False):
    """
        Given a test dataset, select up to n samples from each class
        and return a smaller dataset containing all the samples.

        Args:
            dataset (Dataset): The test dataset to sample. Must have a ``label`` column.
            n (int): Maximum number of samples to extract from each class. A class
                with fewer than n samples contributes all of its samples; n <= 0
                yields an empty selection.
            shuffle (bool): If False, take the first n rows of each class after
                sorting by label and return them grouped by class. If True, pick
                the rows of each class at random and return them in random order.
                NOTE: Dataset.shuffle() hangs indefinitely on Nix.

        Returns:
            sample (Dataset): The sampled dataset.
    """
    # Shuffle *before* sorting so that the first rows of each class are a random
    # subset of that class. Dataset.shuffle() hangs indefinitely on Nix - No idea why.
    source = dataset.shuffle() if shuffle else dataset
    ds_sorted = source.sort('label')

    # After sorting, every class occupies one contiguous run of rows:
    # `starts` holds the first row of each run and `counts` its length.
    _, starts, counts = np.unique(ds_sorted['label'], return_index=True, return_counts=True)

    # Clamp to the class size so a small class never spills into the next class
    # (or past the end of the dataset).
    indices = [
        int(row)
        for start, count in zip(starts, counts)
        for row in range(start, start + min(n, count))
    ]
    sample = ds_sorted.select(indices)

    if shuffle:
        sample = sample.shuffle()  # Dataset.shuffle() hangs indefinitely on Nix - No idea why.
    return sample

In [None]:
# Evaluate on a small, class-balanced slice of the test split:
# the first 3 articles of every class, grouped by class.
small_dataset = get_n_samples_per_class(ds["test"], n=3, shuffle=False)

# Prompt to classify articles

In [None]:
# System prompt for the classifier. The category numbers must match the
# dataset's integer label ids (the order of CLASS_LABELS): the number parsed
# from the model's reply is compared directly against item['label'].
PROMPT = """You are an expert in classifying articles into categories.
Your task is to read an article, decide which category it belongs to, and then return the number of that category.
There are 14 categories to choose from, but you must choose exactly one.

CATEGORIES:
0. Company
1. Educational Institution
2. Artist
3. Athlete
4. Office Holder
5. Means Of Transportation
6. Building
7. Natural Place
8. Village
9. Animal
10. Plant
11. Album
12. Film
13. Written Work

Read the following article and return the most suitable category as a number ("0"), NOT as text ("Company").
"""

In [None]:
def get_classification_prompt(article):
    """
        Build the chat messages that ask the LLM to classify one article.

        Args:
            article (dict): A dataset item; only its ``"content"`` field is used,
                with surrounding whitespace stripped.

        Returns:
            messages (list[dict]): A system message holding ``PROMPT`` followed by
                a user message holding the article text, in
                [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating) form.
    """
    system_message = {"role": "system", "content": PROMPT}
    article_text = article["content"].strip()
    user_message = {"role": "user", "content": article_text}
    return [system_message, user_message]

# Load LLM

To access the LLM (C4AI Command R7B), you will need to accept a license agreement on Hugging Face.

STEPS:
1. Log in to Hugging Face.
2. Accept the license agreement [here](https://huggingface.co/CohereForAI/c4ai-command-r7b-12-2024).
3. Set the variable ``HF_TOKEN`` below to your Hugging Face access token.

In [None]:
import os

# Read the access token from the environment so it never has to be saved in
# the notebook. Fall back to None rather than "": with None the Hugging Face
# libraries use the token stored by `huggingface-cli login`, while an empty
# string is not a usable token for the gated model.
HF_TOKEN = os.environ.get("HF_TOKEN") or None

In [None]:
from accelerate.test_utils.testing import get_backend

# get_backend() yields three values here; only the first (the device) is kept.
(DEVICE, _, _) = get_backend()

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

import torch

model_id = "CohereForAI/c4ai-command-r7b-12-2024"

# Tokenizers have no device placement, so no device_map is passed here.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)

# Choose the precision when loading the weights. A pipeline built from an
# already-loaded model does not convert it, so a torch_dtype given only to
# the pipeline has no effect.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=HF_TOKEN,
    device_map="auto",
    torch_dtype=torch.float16,
)

In [None]:
import transformers
import torch

# `model` and `tokenizer` are already instantiated, so their precision, device
# placement and authentication were decided where they were loaded. The
# pipeline's torch_dtype / device_map / token arguments only apply when the
# pipeline loads a model from a name, so they are not repeated here.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# Test the model

In [None]:
import re

def get_category_label(article):
    """
    For a given article in the DBPedia dataset, predict its category label.

    Args:
        article (dict): One dataset item; its ``"content"`` field is sent to the LLM.

    Returns:
        label (int): The first integer in the LLM's reply, which the prompt asks
            to be the category number.

    Raises:
        ValueError: If the LLM's reply contains no number.
    """
    messages = get_classification_prompt(article)

    # Greedy decoding: the previous near-zero temperature sampling was meant to
    # be deterministic; do_sample=False makes it exactly so.
    chat_history = pipeline(
        messages,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=10,
    )

    # With chat-style input the pipeline returns the whole conversation;
    # the last message is the generated reply.
    response = chat_history[0]["generated_text"][-1]["content"]

    # re.search returns None when the reply has no digits at all.
    match = re.search(r"\d+", response)
    if match is None:
        raise ValueError(f"No number found in LLM response: {response}")
    return int(match.group())

In [None]:
from tqdm import tqdm

def predict_classes(dataset):
    """
    For a given DBPedia dataset, use the contents of each article to predict its label.

    Args:
        dataset (Dataset): The dataset to classify; each item needs a ``content``
            and a ``label`` field.

    Returns:
        results (tuple<list, list>): Two lists: ``y_pred`` (predicted labels) and ``y_true`` (actual labels).
    """
    y_pred = []
    y_true = []

    # One LLM call per article; predictions and ground truth stay index-aligned.
    for item in tqdm(dataset):
        y_pred.append(get_category_label(item))
        y_true.append(item['label'])

    return y_pred, y_true

In [None]:
# Predict all article categories in the dataset.
# `small_dataset` is already the sampled test split (a Dataset, not a
# DatasetDict), so it is passed directly instead of being indexed with 'test'.
y_pred, y_true = predict_classes(small_dataset)

# Evaluate the model

## Get precision, recall, and F1 score for all classes

In [None]:
from sklearn.metrics import classification_report

# Pin the label set to every class id so the report has one row per class
# (even for classes never predicted), shown under its human-readable name.
classif_report = classification_report(
    y_true,
    y_pred,
    labels=list(range(len(CLASS_LABELS))),
    target_names=CLASS_LABELS,
    zero_division=0,
)
# The cell's last statement is an assignment, so print explicitly.
print(classif_report)

## Get confusion matrix

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from matplotlib import pyplot as plt

# y_true / y_pred hold integer class ids, so the matrix labels must be ids as
# well; CLASS_LABELS only supplies the names shown on the axes.
# normalize="true" divides each row by the number of true samples of that class.
cm = confusion_matrix(
    y_true,
    y_pred,
    labels=list(range(len(CLASS_LABELS))),
    normalize="true",
)
# ConfusionMatrixDisplay takes only the matrix and the display labels;
# the colormap and value formatting are options of .plot().
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=CLASS_LABELS)
disp.plot(cmap=plt.cm.Blues, values_format=".2f", xticks_rotation="vertical")
plt.show()