In [None]:
# STEP 1: Install Required Libraries
!pip install transformers datasets evaluate -q

In [None]:
import pandas as pd
import numpy as np
import re
import torch
import torch.nn as nn
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import StratifiedKFold
from imblearn.over_sampling import RandomOverSampler
import evaluate
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)

In [None]:
# Load dataset
# The CSV location is kept in one named constant so it is easy to spot and change.
DATASET_CSV_PATH = "/content/drive/MyDrive/ANLP Project/new dataset 3.csv"
df = pd.read_csv(DATASET_CSV_PATH)

# Clean and normalize text
amharic_punctuations = ['፣', '፤', '፡',  '።','፥', '፦', '፧', '፨']
amharic_punct_pattern = "[" + re.escape("".join(amharic_punctuations)) + "]"

def clean_text(text):
    """Strip everything except Ethiopic-script words from ``text``.

    URLs, Latin letters, ASCII digits, the Amharic punctuation marks listed in
    ``amharic_punctuations`` and any other character outside U+1200-U+137F are
    removed, then whitespace is collapsed to single spaces.

    Every removed span is replaced by a space rather than by the empty string.
    Otherwise words separated only by a punctuation mark (for example the
    Ethiopic wordspace '፡') would be glued together into one token.

    Args:
        text: Raw input text. Non-string values (e.g. NaN from a CSV) are
            treated as empty text.

    Returns:
        The cleaned string; ``""`` when nothing remains.
    """
    if not isinstance(text, str):
        return ""
    # `http\S+` already covers https URLs, so no separate alternative is needed.
    text = re.sub(r"http\S+|www\S+", " ", text)
    text = re.sub(r"[a-zA-Z0-9]+", " ", text)
    text = re.sub(amharic_punct_pattern, " ", text)
    text = re.sub(r"[^\u1200-\u137F\s]", " ", text)
    # Collapse the spaces introduced above (and any original runs) into one.
    return re.sub(r"\s+", " ", text).strip()

# Maps each variant Ethiopic letter to the single letter it is normalized to.
# Every key is one character and no value is itself a key, so applying the
# map once is enough and applying it twice changes nothing.
normalization_map = {
    "ሃ": "ሀ", "ኀ": "ሀ", "ኁ": "ሁ", "ኂ": "ሂ", "ኃ": "ሀ", "ኄ": "ሄ", "ኅ": "ህ", "ኆ": "ሆ",
    "ሐ": "ሀ", "ሓ": "ሀ", "ኻ": "ሀ", "ሑ": "ሁ", "ሒ": "ሂ", "ሔ": "ሄ", "ሕ": "ህ", "ሖ": "ሆ",
    "ሠ": "ሰ", "ሡ": "ሱ", "ሢ": "ሲ", "ሣ": "ሳ", "ሤ": "ሴ", "ሥ": "ስ", "ሦ": "ሶ",
    "ኣ": "አ", "ዐ": "አ", "ዑ": "ኡ", "ዒ": "ኢ", "ዓ": "አ", "ዔ": "ኤ", "ዕ": "እ", "ዖ": "ኦ",
    "ጸ": "ፀ", "ጹ": "ፁ", "ጺ": "ፂ", "ጻ": "ፃ", "ጼ": "ፄ", "ጽ": "ፅ", "ጾ": "ፆ"
}

def normalize_amharic(text):
    """Replace variant Ethiopic letters in ``text`` using ``normalization_map``.

    The replacement is done in a single pass with ``str.translate``, so the
    result does not depend on the order of the entries in the map.

    Args:
        text: The string to normalize.

    Returns:
        A new string in which every character that is a key of
        ``normalization_map`` has been replaced by its mapped value.
    """
    # The table is built at call time so later edits to the map are honored.
    return text.translate(str.maketrans(normalization_map))

amharic_stopwords = ["እንዴት", "ያለ", "እውነት", "ስለዚህ", "በጣም", "አይደለም", "እና", "ሁሉ", "ነው", "አንድ", "ላይ", "ወደ", "በመካከል"]

def remove_stopwords(text):
    """Drop every whitespace-separated token of ``text`` found in ``amharic_stopwords``.

    The surviving tokens are returned joined by single spaces, in their
    original order.
    """
    kept_tokens = []
    for token in text.split():
        if token in amharic_stopwords:
            continue
        kept_tokens.append(token)
    return ' '.join(kept_tokens)

# Run the preprocessing stages in order. Each stage reads the column written
# by the previous one and stores its own result in a new column of `df`.
for source_column, target_column, transform in (
    ("cleaned_text", "cleaned", clean_text),
    ("cleaned", "normalized", normalize_amharic),
    ("normalized", "no_stopwords", remove_stopwords),
):
    df[target_column] = df[source_column].apply(transform)

# Encode labels
# The fitted encoder is kept so integer labels can be mapped back to the
# original `sentiment` values (via `label_encoder.classes_`).
label_encoder = LabelEncoder()
encoded_sentiment = label_encoder.fit_transform(df["sentiment"])
df["label"] = encoded_sentiment
num_labels = df["label"].nunique()

In [None]:
# Class weights
# NOTE(review): the weights are derived from the full label column, i.e. they
# also see rows that are later used for validation, and the training folds are
# additionally balanced with RandomOverSampler — confirm that using both
# mechanisms together is intended.
label_column = df["label"]
balanced_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(label_column),
    y=label_column,
)
class_weights = torch.tensor(balanced_weights, dtype=torch.float)

# Tokenizer and model checkpoint
model_checkpoint = "Davlan/afro-xlmr-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Metric functions
accuracy, f1 = evaluate.load("accuracy"), evaluate.load("f1")

def compute_metrics(p):
    """Compute accuracy and weighted F1 for one evaluation pass.

    Args:
        p: Object exposing ``predictions`` (per-class scores; the arg-max over
            axis 1 is taken as the predicted class) and ``label_ids``
            (presumably the ``EvalPrediction`` handed over by the Trainer —
            confirm against caller).

    Returns:
        Dict with the keys ``"accuracy"`` and ``"f1"``.
    """
    references = p.label_ids
    predicted_ids = np.argmax(p.predictions, axis=1)

    accuracy_result = accuracy.compute(predictions=predicted_ids, references=references)
    f1_result = f1.compute(predictions=predicted_ids, references=references, average="weighted")

    return {"accuracy": accuracy_result["accuracy"], "f1": f1_result["f1"]}

# Tokenization function
def tokenize_function(example):
    """Tokenize the ``"cleaned_text"`` field of ``example`` with the module-level tokenizer.

    Sequences are truncated and padded to a fixed length of 128 tokens.
    """
    texts = example["cleaned_text"]
    encoded = tokenizer(texts, truncation=True, padding="max_length", max_length=128)
    return encoded

# Custom Trainer with class weights
class CustomTrainer(Trainer):
    """Trainer whose loss is cross-entropy weighted by the module-level ``class_weights``."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the class-weighted cross-entropy loss for one batch.

        Args:
            model: The model to run; called as ``model(**inputs)``.
            inputs: Batch mapping; must contain ``"labels"``.
            return_outputs: When true, the model outputs are returned as well.
            **kwargs: Accepted and ignored, so extra keyword arguments passed
                by the caller do not raise.

        Returns:
            ``loss`` or, if ``return_outputs`` is true, ``(loss, outputs)``.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Use the device of the logits instead of `model.device`: a model
        # wrapped in `nn.DataParallel` (multi-GPU training) has no `device`
        # attribute, while the logits always live where the loss is computed.
        loss_fct = nn.CrossEntropyLoss(weight=class_weights.to(logits.device))
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

In [None]:
def plot_confusion_matrix(true_labels, pred_labels, classes, fold=None):
    """Draw a confusion-matrix heatmap for integer-encoded labels.

    Args:
        true_labels: Ground-truth label ids.
        pred_labels: Predicted label ids.
        classes: Class names used as tick labels; the matrix is computed over
            the ids ``0 .. len(classes) - 1``.
        fold: Optional zero-based fold index; shown one-based in the title.
    """
    class_ids = range(len(classes))
    matrix = confusion_matrix(true_labels, pred_labels, labels=class_ids)

    # Only mention the fold in the title when one was given.
    fold_suffix = "" if fold is None else f" - Fold {fold+1}"

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=classes,
        yticklabels=classes,
    )
    plt.title("Confusion Matrix" + fold_suffix)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.show()

In [None]:
def _plot_training_curves(log_history, fold):
    """Plot train/eval loss and eval accuracy per epoch for one fold.

    Args:
        log_history: List of log dicts as stored in ``trainer.state.log_history``.
        fold: Zero-based fold index; shown one-based in the plot titles.
    """
    train_epochs, train_losses = [], []
    eval_epochs, eval_losses = [], []
    # Accuracy keeps its own epoch list so x and y always have equal length.
    acc_epochs, eval_accuracies = [], []

    for log in log_history:
        if "epoch" not in log:
            continue
        if "loss" in log:
            train_epochs.append(log["epoch"])
            train_losses.append(log["loss"])
        if "eval_loss" in log:
            eval_epochs.append(log["epoch"])
            eval_losses.append(log["eval_loss"])
        if "eval_accuracy" in log:
            acc_epochs.append(log["epoch"])
            eval_accuracies.append(log["eval_accuracy"])

    plt.figure(figsize=(12, 4))

    # Loss Curve
    plt.subplot(1, 2, 1)
    plt.plot(train_epochs, train_losses, label="Train Loss")
    if eval_losses:
        plt.plot(eval_epochs, eval_losses, label="Eval Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"Loss Curve - Fold {fold+1}")
    plt.legend()

    # Accuracy Curve
    plt.subplot(1, 2, 2)
    if eval_accuracies:
        plt.plot(acc_epochs, eval_accuracies, label="Eval Accuracy")
        # Only draw a legend when there is a labelled line to describe.
        plt.legend()
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title(f"Accuracy Curve - Fold {fold+1}")

    plt.show()


# 5-fold stratified cross-validation: one model is fine-tuned per fold and
# evaluated on that fold's held-out split.
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []       # metrics dict returned by trainer.evaluate() for each fold
all_true_labels = []    # validation labels of all folds, concatenated
all_pred_labels = []    # validation predictions of all folds, concatenated

for fold, (train_idx, val_idx) in enumerate(kfold.split(df["no_stopwords"], df["label"])):
    print(f"\n========== Fold {fold + 1} ==========")

    train_df = df.iloc[train_idx].copy()
    val_df = df.iloc[val_idx].copy()

    # Random oversampling (training split only; validation data is left untouched)
    # NOTE(review): CustomTrainer also applies class weights computed on the
    # original label distribution — confirm that combining both is intended.
    ros = RandomOverSampler(random_state=42)
    X_train_resampled, y_train_resampled = ros.fit_resample(
        train_df[["no_stopwords"]], train_df["label"]
    )
    train_df_resampled = pd.DataFrame({
        "no_stopwords": X_train_resampled["no_stopwords"].values,
        "label": y_train_resampled
    })

    # Prepare HF Datasets
    # The text column is renamed to "cleaned_text" because that is the field
    # tokenize_function reads. `preserve_index=False` keeps the pandas index
    # from being stored as an extra dataset column.
    train_dataset = Dataset.from_pandas(
        train_df_resampled.rename(columns={"no_stopwords": "cleaned_text"}),
        preserve_index=False,
    )
    val_dataset = Dataset.from_pandas(
        val_df[["no_stopwords", "label"]].rename(columns={"no_stopwords": "cleaned_text"}),
        preserve_index=False,
    )

    train_dataset = train_dataset.map(tokenize_function, batched=True)
    val_dataset = val_dataset.map(tokenize_function, batched=True)

    # A fresh model per fold so no weights carry over between folds.
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

    training_args = TrainingArguments(
        output_dir=f"./results/fold_{fold+1}",
        eval_strategy="epoch",
        save_strategy="epoch",
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=10,
        logging_dir=f"./logs/fold_{fold+1}",
        logging_steps=10,
        #learning_rate=2e-5,
        save_total_limit=1,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
    )

    early_stopping = EarlyStoppingCallback(early_stopping_patience=2)

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=[early_stopping],
    )

    trainer.train()

    # Snapshot the training log BEFORE the final evaluate(): evaluate() appends
    # another eval record at the last epoch, which would show up as a duplicate
    # point at the end of the curves.
    log_history = list(trainer.state.log_history)

    metrics = trainer.evaluate()
    fold_results.append(metrics)
    print(f"Fold {fold+1} - Accuracy: {metrics['eval_accuracy']:.4f}, F1: {metrics['eval_f1']:.4f}")

    # Plot training curves
    _plot_training_curves(log_history, fold)

    # Confusion matrix for validation
    val_preds_output = trainer.predict(val_dataset)
    val_preds = np.argmax(val_preds_output.predictions, axis=1)
    val_labels = val_preds_output.label_ids

    plot_confusion_matrix(val_labels, val_preds, classes=label_encoder.classes_, fold=fold)

    all_true_labels.extend(val_labels)
    all_pred_labels.extend(val_preds)

In [None]:
import nbformat

# Path of the notebook to clean; the file is rewritten in place.
notebook_path = "your_notebook.ipynb"
nb = nbformat.read(notebook_path, as_version=4)

# Remove broken widget metadata
# Widget state is normally saved in the notebook-level metadata
# (`metadata.widgets`), so that entry is removed first; any cell-level
# "widgets" entries are removed as well.
nb.metadata.pop("widgets", None)
for cell in nb.cells:
    cell.get("metadata", {}).pop("widgets", None)

nbformat.write(nb, notebook_path)
print("Notebook cleaned from broken widget metadata.")
