In [1]:
# imports
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForMaskedLM, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_recall_fscore_support
from torch.utils.data import Dataset
from transformers import AutoConfig
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import EvalPrediction
from transformers import Trainer
from transformers import TrainingArguments
import evaluate
import torch
from typing import Any
from typing import Dict
from typing import Optional
from torch.utils.data import Dataset
from datasets import load_dataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Load and preprocess News dataset

In [2]:
data = load_dataset('masakhane/masakhanews', 'swa') 

display(data['train'][0])

def change_label(data_point):
    if data_point["label"] == 3:
        data_point["label"] = 1
    else:
        data_point["label"] = 0
    return data_point

politics_label_data = data.map(change_label)

display(politics_label_data)
display(politics_label_data["train"][:5])

{'label': 5,
 'headline': 'Tetesi za soka Ulaya Jumatatu 26.04.2021: Varane, Camara, Nagelsmann, Willock, Azpilicueta',
 'text': 'Chelsea wapo mbele ya  Manchester United na Paris St-Germain katika mbio za kutaka kumsajili beki wa Real Madrid na timu ya taifa ya Ufaransa Raphael Varane, 28. (Mundo Deportivo - in Spanish) Beki wa Guinea Ali Camara, 23, ambaye anachezea klabu ya Young Boys ya Switzerland, amezivutia klabu kadhaa za Ligi ya Primia zikiwemo  Liverpool, Arsenal, Crystal Palace, West Ham United na Norwich. (Team Talk) Bayern Munich imeanzisha mazungumzo ya kutaka kumsajili kocha wa RB Leipzig Julian Nagelsmann. (Independent) Arsenal wapo njia panda juu ya mustakabali wa kiungo wao Joe Willock, 21, ambaye yupo Newcastle kwa mkopo. Arsenal wanahitaji kuuza baadhi ya wachezaji ili kujiimarisha kifedha. (Football London) Kocha wa Atletico Madrid  Diego Simeone ana nia ya kumsajili beki raia wa Uhispania Cesar Azpilicueta, 31, kutoka  Chelsea. (El Gol Digital - in Spanish) Manche

DatasetDict({
    train: Dataset({
        features: ['label', 'headline', 'text', 'headline_text', 'url'],
        num_rows: 1658
    })
    validation: Dataset({
        features: ['label', 'headline', 'text', 'headline_text', 'url'],
        num_rows: 237
    })
    test: Dataset({
        features: ['label', 'headline', 'text', 'headline_text', 'url'],
        num_rows: 476
    })
})

{'label': [0, 0, 1, 0, 0],
 'headline': ['Tetesi za soka Ulaya Jumatatu 26.04.2021: Varane, Camara, Nagelsmann, Willock, Azpilicueta',
  'Je chanjo ya corona ni salama?',
  'Matokeo ya uchaguzi Marekani 2020: Donald Trump amfuta kazi Waziri wa Ulinzi Mark Esper',
  'Je wajua mwanamke na mwanaume hawapaswi kufanya mazoezi pamoja?',
  'Watoto waliolazimika kuwa kimya kuhusu baba zao wakutana na maaskofu jijini Paris'],
 'text': ['Chelsea wapo mbele ya  Manchester United na Paris St-Germain katika mbio za kutaka kumsajili beki wa Real Madrid na timu ya taifa ya Ufaransa Raphael Varane, 28. (Mundo Deportivo - in Spanish) Beki wa Guinea Ali Camara, 23, ambaye anachezea klabu ya Young Boys ya Switzerland, amezivutia klabu kadhaa za Ligi ya Primia zikiwemo  Liverpool, Arsenal, Crystal Palace, West Ham United na Norwich. (Team Talk) Bayern Munich imeanzisha mazungumzo ya kutaka kumsajili kocha wa RB Leipzig Julian Nagelsmann. (Independent) Arsenal wapo njia panda juu ya mustakabali wa kiungo w

## Load Pre-Trained Model
### AfriBerta

In [3]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("castorini/afriberta_base")
model = AutoModelForSequenceClassification.from_pretrained("castorini/afriberta_base")

tokenizer.model_max_length = 512

Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at castorini/afriberta_base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Tokenize Data

In [4]:

def tokenize_function(datapoints):
    return tokenizer(datapoints["headline_text"], padding="max_length", truncation=True)


tokenized_datasets = politics_label_data.map(tokenize_function, batched=True)

#small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
#small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

tokenized_datasets['train']

Dataset({
    features: ['label', 'headline', 'text', 'headline_text', 'url', 'input_ids', 'attention_mask'],
    num_rows: 1658
})

## Train Baseline on News Dataset

In [5]:

from transformers import DataCollatorWithPadding



def compute_metrics(pred: EvalPrediction) -> Dict[str, float]:
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}


training_args = TrainingArguments(output_dir="test_trainer", eval_strategy="epoch")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer= tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

  0%|          | 0/624 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

{'eval_loss': 0.27488455176353455, 'eval_accuracy': 0.8987341772151899, 'eval_f1': 0.8661016949152542, 'eval_precision': 0.836869118905047, 'eval_recall': 0.9211764705882353, 'eval_runtime': 2.8363, 'eval_samples_per_second': 83.559, 'eval_steps_per_second': 10.577, 'epoch': 1.0}


  0%|          | 0/30 [00:00<?, ?it/s]

{'eval_loss': 0.16575419902801514, 'eval_accuracy': 0.9493670886075949, 'eval_f1': 0.926091476091476, 'eval_precision': 0.914996964177292, 'eval_recall': 0.9386096256684493, 'eval_runtime': 2.8574, 'eval_samples_per_second': 82.943, 'eval_steps_per_second': 10.499, 'epoch': 2.0}


Non-default generation parameters: {'max_length': 512}


{'loss': 0.193, 'grad_norm': 0.11201786994934082, 'learning_rate': 9.935897435897435e-06, 'epoch': 2.4}


Non-default generation parameters: {'max_length': 512}


  0%|          | 0/30 [00:00<?, ?it/s]

{'eval_loss': 0.1703765094280243, 'eval_accuracy': 0.9578059071729957, 'eval_f1': 0.937539531941809, 'eval_precision': 0.9314968814968815, 'eval_recall': 0.9439572192513369, 'eval_runtime': 2.8398, 'eval_samples_per_second': 83.457, 'eval_steps_per_second': 10.564, 'epoch': 3.0}
{'train_runtime': 197.5076, 'train_samples_per_second': 25.184, 'train_steps_per_second': 3.159, 'train_loss': 0.16685908345075753, 'epoch': 3.0}


TrainOutput(global_step=624, training_loss=0.16685908345075753, metrics={'train_runtime': 197.5076, 'train_samples_per_second': 25.184, 'train_steps_per_second': 3.159, 'total_flos': 875500023730176.0, 'train_loss': 0.16685908345075753, 'epoch': 3.0})

In [9]:

print(tokenized_datasets["test"])



def get_top_attended_tokens(input_ids, attention_weights, top_k=5):
    # Sum attention weights across all layers and heads
    aggregated_attentions = attention_weights.sum(dim=(0, 1)).squeeze()
    
    # Get the top-k attended token indices
    _, top_indices = torch.topk(aggregated_attentions, k=top_k)
    
    # Convert token IDs to tokens
    squezed_top_indices = top_indices.flatten()
    #print(squezed_top_indices)
    #print(input_ids[0])
    top_tokens = tokenizer.convert_ids_to_tokens([input_ids[i] for i in squezed_top_indices])
    
    return top_tokens


# Set the model to evaluation mode
model.eval()

# Disable gradient calculations for inference
with torch.no_grad():
    for example in tokenized_datasets["test"].select([0,1,2,3,4,5]):
        # Prepare inputs
        inputs = {
            'input_ids': torch.tensor(example['input_ids']).unsqueeze(0).to(device),
            'attention_mask': torch.tensor(example['attention_mask']).unsqueeze(0).to(device),
        }
        
        # Handle labels
        if 'label' in example:
            labels = torch.tensor([example['label']]).to(device)
        else:
            labels = None
        
        # Forward pass
        outputs = model(**inputs, labels=labels, output_attentions=True)
        
        # Process the outputs
        attentions = outputs.attentions
        
        # Stack all attention layers
        all_attentions = torch.stack(attentions)
        
        # Get top attended tokens
        #print(inputs['input_ids'][0])
        top_tokens = get_top_attended_tokens(inputs['input_ids'][0], all_attentions, top_k=5)
        
        print(f"Top 5 attended tokens: {top_tokens}")
        
        # If you want to see these tokens in context:
        full_text = tokenizer.decode(inputs['input_ids'][0])
        print(f"Full text: {full_text}")
        
        # You can access other features if needed for analysis
        headline = example['headline']
        text = example['text']
        headline_text = example['headline_text']
        url = example['url']
        
        # Do something with the attentions and other data
        # For example, you might want to analyze attention patterns
        # in relation to the headline or full text

Dataset({
    features: ['label', 'headline', 'text', 'headline_text', 'url', 'input_ids', 'attention_mask'],
    num_rows: 476
})
Top 5 attended tokens: ['<s>', '</s>', '.', '.', '.', '<s>', '▁', '</s>', '▁kura', '▁kisiasa', '<s>', '▁Siku', '</s>', '1', '0', '<s>', '0', '</s>', '▁rais', '▁Siku', '<s>', '0', '</s>', '▁Siku', '1', '<s>', '▁za', '</s>', '1', '▁Siku', '<s>', '▁utawala', '▁Siku', '</s>', '▁kura', '<s>', '</s>', '▁rais', '▁kisiasa', '▁utawala', '<s>', '▁utawala', '▁Ruto', '</s>', '▁rais', '<s>', '▁Ruto', '</s>', '▁utawala', '▁rais', '<s>', '▁U', '▁rais', '</s>', '▁Ruto', '<s>', 'zito', '</s>', '▁rais', '▁Ruto', '<s>', '▁kwa', '▁rais', '</s>', '▁kisiasa', '<s>', '▁rais', '</s>', 'zito', '▁kwa', '<s>', '▁rais', '▁Ruto', '▁wa', '▁kwa', '<s>', '▁rais', '▁Kenya', '</s>', '▁kwa', '<s>', '▁rais', '</s>', '▁kisiasa', '▁urais', '<s>', 'ambatanish', '▁rais', '</s>', '▁kwa', '<s>', '▁rais', '</s>', 'a', '▁ku', '<s>', '▁maneno', '▁ku', '</s>', 'a', '<s>', '▁na', '▁kisiasa', '</s>', '▁r