In [41]:
from transformers import BartTokenizer, BartForConditionalGeneration
from datasets import load_dataset
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Load pre-trained BART model and tokenizer
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

# Define label options (21 categories)
label_options = [
    "POLITICS", "INTERNATIONAL RELATIONS", "EUROPEAN UNION", "LAW", "ECONOMICS",
    "TRADE", "FINANCE", "SOCIAL QUESTIONS", "EDUCATION AND COMMUNICATIONS", "SCIENCE",
    "BUSINESS AND COMPETITION", "EMPLOYMENT AND WORKING CONDITIONS", "TRANSPORT",
    "ENVIRONMENT", "AGRICULTURE, FORESTRY AND FISHERIES", "AGRI-FOODSTUFFS",
    "PRODUCTION, TECHNOLOGY AND RESEARCH", "ENERGY", "INDUSTRY", "GEOGRAPHY",
    "INTERNATIONAL ORGANISATIONS"
]

In [38]:
# Load the Multi-EURLEX dataset (assuming English subset)
dataset = load_dataset('multi_eurlex', 'all_languages', split='test')

In [39]:
# Function to prompt BART with text and label options
def classify_text_with_bart(text, label_options):
    # Prepare the input for BART
    prompt = f"={text}\n\n DO NOT REPEAT THE TEXT. Select the most relevant categories from the list below. Only list the categories that best match the content of the text. DO NOT REPEAT THE ENTIRE LIST, just give a concise answer:\n\nCategories: {', '.join(label_options)}"
    
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    # Generate output from BART
    outputs = model.generate(inputs['input_ids'], max_length=200, num_beams=5, num_return_sequences=1)

    # Decode the generated output
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return generated_text

# Extract relevant labels from BART's output
def extract_labels_from_generated_text(generated_text, label_options):
    relevant_labels = []
    for label in label_options:
        if label.lower() in generated_text.lower():
            relevant_labels.append(label)
    return relevant_labels

# Function to map label names back to numbers
def map_labels_to_indices(label_names, label_options):
    label_indices = [label_options.index(label) for label in label_names if label in label_options]
    return label_indices

# Function to preprocess the dataset and map true labels to category numbers
def preprocess_dataset(dataset, label_options):
    preprocessed_data = []

    for item in dataset:
        text = item['text']['en']  # Extract English text
        labels = item['labels']    # True label numbers

        preprocessed_data.append({"text": text, "labels": labels})

    return preprocessed_data

# Evaluate the model on the entire dataset
def evaluate_bart_on_dataset(dataset, label_options):
    all_true_labels = []
    all_predicted_labels = []

    for entry in dataset:
        text = entry['text']
        true_labels = entry['labels']

        # Get BART's generated labels
        generated_text = classify_text_with_bart(text, label_options)
        predicted_label_names = extract_labels_from_generated_text(generated_text, label_options)
        predicted_labels = map_labels_to_indices(predicted_label_names, label_options)

        # Store true and predicted labels for later comparison
        all_true_labels.append(true_labels)
        all_predicted_labels.append(predicted_labels)

        print(f"\nText: {generated_text}")
        print(f"Generated labels: {predicted_label_names}")
        print(f"True labels: {true_labels}, Predicted labels: {predicted_labels}")

    return all_true_labels, all_predicted_labels

In [40]:
# Preprocess the dataset
preprocessed_data = preprocess_dataset(dataset, label_options)

# Run the evaluation on the entire preprocessed dataset
true_labels, predicted_labels = evaluate_bart_on_dataset(preprocessed_data, label_options)

# Now you can compare true_labels and predicted_labels to compute accuracy, precision, etc.
# Flatten the lists of true and predicted labels for evaluation
flattened_true = [label for sublist in true_labels for label in sublist]
flattened_pred = [label for sublist in predicted_labels for label in sublist]

precision, recall, f1, _ = precision_recall_fscore_support(flattened_true, flattened_pred, average='macro')

print(f"\nPrecision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


ValueError: Input length of input_ids is 902, but `max_length` is set to 200. This can lead to unexpected behavior. You should consider increasing `max_length` or, better yet, setting `max_new_tokens`.