In [1]:
from datasets import load_dataset

# 14-class DBpedia dataset; its 'train' and 'test' splits are used below.
ds = load_dataset("fancyzhx/dbpedia_14")

# Names of the ClassLabel feature; list position i is the class whose 'label' value is i.
label_feature = ds['train'].features['label']
CLASS_LABELS = label_feature.names
CLASS_LABELS

['Company',
 'EducationalInstitution',
 'Artist',
 'Athlete',
 'OfficeHolder',
 'MeanOfTransportation',
 'Building',
 'NaturalPlace',
 'Village',
 'Animal',
 'Plant',
 'Album',
 'Film',
 'WrittenWork']

# Obtain first n samples from each class

In [2]:
import numpy as np

In [3]:
def get_n_samples_per_class(dataset, n, shuffle = False, seed = None):
    """
        Select up to n samples from each class of a dataset
        and return a smaller dataset containing all of them.

        Args:
            dataset (Dataset): The dataset (any split) to sample; must have a 'label' column.
            n (int): Maximum number of samples to take from each class. A class with
                fewer than n samples contributes all of its samples.
            shuffle (bool): If False, take the first n samples of each class (in the order
                produced by sorting on 'label') and return them grouped by class.
                If True, pick n random samples per class and return them in random order.
                NOTE: Dataset.shuffle() hangs indefinitely on Nix.
            seed (int, optional): Seed forwarded to Dataset.shuffle() so shuffled samples
                are reproducible. Ignored when shuffle is False.

        Returns:
            sample (Dataset): The sampled dataset.
    """
    if shuffle:
        # Shuffle before sorting so the leading rows of each class are a random subset.
        dataset = dataset.shuffle(seed=seed)  # Dataset.shuffle() hangs indefinitely on Nix - No idea why.
    ds_sorted = dataset.sort('label')

    # After sorting, each class occupies one contiguous run: (start offset, length).
    _, class_starts, class_counts = np.unique(
        np.asarray(ds_sorted['label']), return_index=True, return_counts=True
    )

    # Clamp each run to its class size so a class with fewer than n samples
    # never spills over into rows belonging to the next class.
    class_indices = [
        int(index)
        for start, count in zip(class_starts, class_counts)
        for index in range(start, start + min(n, count))
    ]

    sample = ds_sorted.select(class_indices)
    if shuffle:
        sample = sample.shuffle(seed=seed)  # Dataset.shuffle() hangs indefinitely on Nix - No idea why.
    return sample

In [4]:
# Keep a handful of test articles per class (grouped by class, no shuffling) to try the LLM on.
SAMPLES_PER_CLASS = 3
small_dataset = get_n_samples_per_class(ds['test'], n=SAMPLES_PER_CLASS, shuffle=False)

In [5]:
# Peek at the first sampled article: a dict with 'label', 'title' and 'content' fields.
small_dataset[0]

{'label': 0,
 'title': 'TY KU',
 'content': " TY KU /taɪkuː/ is an American alcoholic beverage company that specializes in sake and other spirits. The privately-held company was founded in 2004 and is headquartered in New York City New York. While based in New York TY KU's beverages are made in Japan through a joint venture with two sake breweries. Since 2011 TY KU's growth has extended its products into all 50 states."}

# Prompt to classify articles

In [15]:
# System prompt for zero-shot classification. Categories are numbered 1-14 in the
# same order as CLASS_LABELS (dataset label ids 0-13), so a predicted number k
# corresponds to label id k - 1.
PROMPT = """You are an expert in classifying articles into categories.
Your task is to read an article, decide which category it belongs to, and then return the number of that category.
There are 14 categories you may choose from, but you must choose exactly one.

CATEGORIES:
1. Company
2. Educational Institution
3. Artist
4. Athlete
5. Office Holder
6. Means Of Transportation
7. Building
8. Natural Place
9. Village
10. Animal
11. Plant
12. Album
13. Film
14. Written Work

Read the following article and return only the number of its category. Do NOT return any other text.
"""

In [16]:
def get_classification_prompt(article):
    """
        Build the chat messages asking the LLM to classify a single article.

        Args:
            article (Dictionary): Any item in the dataset; only its "content" field is read.

        Returns:
            prompt (List[Dictionary]): A system message carrying PROMPT followed by a user
                message carrying the whitespace-stripped article text, in
                [Chat Template](https://huggingface.co/docs/transformers/main/en/chat_templating) form.
    """
    system_message = {"role": "system", "content": PROMPT}
    article_text = article["content"].strip()
    user_message = {"role": "user", "content": article_text}
    return [system_message, user_message]

# Load LLM

To access the LLM loaded below (Aya Expanse 8B, `CohereForAI/aya-expanse-8b`), you need to accept its license agreement on Hugging Face.

STEPS:
1. Log in to Hugging Face.
2. Accept the license agreement [here](https://huggingface.co/CohereForAI/aya-expanse-8b).
3. Set ``HF_TOKEN`` (next cell) to your Hugging Face access token.

In [6]:
import os

# Read the token from the environment rather than hardcoding a secret in the notebook.
# An empty string would be sent as an explicit (invalid) token; None instead lets
# huggingface_hub fall back to the token cached by `huggingface-cli login`.
HF_TOKEN = os.environ.get("HF_TOKEN") or None

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "CohereForAI/aya-expanse-8b"

# Tokenizer and weights come from the same (possibly gated) repo, so both use the same auth token.
hub_kwargs = {"token": HF_TOKEN}
tokenizer = AutoTokenizer.from_pretrained(model_id, **hub_kwargs)
model = AutoModelForCausalLM.from_pretrained(model_id, **hub_kwargs)

tokenizer_config.json:   0%|          | 0.00/8.64k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/12.8M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/439 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/634 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/21.0k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

In [None]:

# Format the classification prompt for one sample article with the loaded model's own chat template.
messages = get_classification_prompt(small_dataset[0])
input_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

gen_tokens = model.generate(
    input_ids,
    max_new_tokens=10,
    do_sample=True,
    temperature=0.3,
)

# generate() returns the prompt followed by the completion; decode only the newly
# generated tokens so the output is just the predicted category number.
new_tokens = gen_tokens[0][input_ids.shape[-1]:]
gen_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
print(gen_text)