In [None]:
from datasets import load_dataset

# IndicGLUE, `bbca.hi` configuration (Hindi BBC news articles).
news_datasets = load_dataset("indic_glue", name="bbca.hi")

In [None]:
# Show the loaded splits, their columns and row counts.
news_datasets

In [None]:
# Make indexing the train split return pandas objects (undone by reset_format() below).
news_datasets["train"].set_format("pandas")

In [None]:
# Count examples per class in the training split (this is a 14-class task, not binary).
# `to_pandas()` works whatever format the split currently has, so this cell does not
# depend on the `set_format("pandas")` cell having been run first.
label_counts = news_datasets["train"].to_pandas()["label"].value_counts()
num_labels = len(label_counts)

In [None]:
# Number of training examples per class.
label_counts

In [None]:
# Restore the default output format of the train split before tokenization.
news_datasets["train"].reset_format()

In [None]:
from transformers import set_seed

# Seed Python/NumPy/PyTorch RNGs via transformers for reproducible runs
# (a seed of 30 was also tried).
SEED = 42
set_seed(SEED)

In [None]:
from transformers import AutoModelForSequenceClassification, PreTrainedTokenizerFast

# Tokenizer saved with the locally pretrained Hindi BERT (unigram vocabulary).
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    "../Hindi Pretraining/models/unigram/bert-base-pretrained-hindi"
)

In [None]:
# Locally pretrained Hindi BERT with a sequence-classification head sized to the
# number of classes seen in the training split (the head is newly initialized if the
# checkpoint does not contain one).
model = AutoModelForSequenceClassification.from_pretrained(
    "../Hindi Pretraining/models/unigram/bert-base-pretrained-hindi",
    num_labels=num_labels,
)

In [None]:
# num_added_tokens = tokenizer.add_tokens(["5","7","8","9"])

In [None]:
# Note: resize_token_embeddings expects the full size of the new vocabulary, i.e., the length of the tokenizer.
# model.resize_token_embeddings(len(tokenizer))

In [None]:
def tokenize_function(example):
    """Tokenize ``example["text"]``, truncating every sequence to at most 128 tokens.

    No padding is applied here; batches are padded later by the data collator.
    """
    max_tokens = 128
    return tokenizer(
        example["text"],
        max_length=max_tokens,
        truncation=True,
    )

In [None]:
from transformers import DataCollatorWithPadding

# Pads each batch dynamically to the longest sequence in that batch.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Tokenize all splits in batches; the raw text column is dropped afterwards.
tokenized_datasets = news_datasets.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)

In [None]:
# Check the splits/columns after tokenization (the "text" column has been removed).
tokenized_datasets

In [None]:
# Spot-check tokenization by decoding the first few training examples back to text.
# (Previously iterated an undefined `temp`, which fails on a fresh kernel.)
for sample in tokenized_datasets["train"].select(range(5)):
    print(tokenizer.decode(sample["input_ids"]))

In [None]:
def assign_label(example):
    """Replace the class name in ``example["label"]`` with its integer id.

    The example dict is modified in place and also returned. A class name outside
    the 14 known categories raises ``KeyError``.
    """
    # Ids follow the position of each name in this tuple.
    class_names = (
        "india", "international", "entertainment", "sport", "news",
        "science", "business", "pakistan", "southasia", "institutional",
        "social", "china", "multimedia", "learningenglish",
    )
    label_to_id = {name: idx for idx, name in enumerate(class_names)}
    example["label"] = label_to_id[example["label"]]
    return example

In [None]:
# Convert string labels to integer ids exactly once. Without the guard, re-running
# this cell would pass already-integer labels to assign_label and raise KeyError.
if tokenized_datasets["train"].features["label"].dtype == "string":
    tokenized_datasets = tokenized_datasets.map(assign_label)
tokenized_datasets.set_format("torch")
tokenized_datasets.column_names

In [None]:
# Collate a handful of training examples and decode them to eyeball the padding.
samples = [tokenized_datasets["train"][idx] for idx in range(5)]
padded_batch = data_collator(samples)
for chunk in padded_batch["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")

In [None]:
# First five raw training examples, for comparison with the decoded text above.
news_datasets["train"][:5]

In [None]:
from torch.utils.data import DataLoader

batch_size = 32  # 16 was also tried

# Only the training data is shuffled; the evaluation order stays fixed.
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    tokenized_datasets["test"],
    batch_size=batch_size,
    collate_fn=data_collator,
)

In [None]:
# Take one training batch and show the tensor shape of each field.
batch = next(iter(train_dataloader))
{name: tensor.shape for name, tensor in batch.items()}

In [None]:
import torch

# Sanity-check a forward pass on that batch without tracking gradients.
with torch.no_grad():
    outputs = model(**batch)
print(outputs.loss, outputs.logits.shape)

In [None]:
import numpy as np
import evaluate

# Accuracy metric, loaded once and reused for every evaluation pass.
metric_fun = evaluate.load("accuracy")


def compute_metrics(eval_preds):
    """Return ``{"accuracy": ...}`` for a ``(logits, labels)`` evaluation pair.

    The predicted class is the argmax over the last axis of ``logits``.
    """
    logits, labels = eval_preds
    predicted_ids = np.argmax(logits, axis=-1)
    scores = metric_fun.compute(predictions=predicted_ids, references=labels)
    return {"accuracy": scores["accuracy"]}

In [None]:
import os

# Turn off Weights & Biases logging for the Trainer via its environment switch.
os.environ["WANDB_DISABLED"] = "true"

In [None]:
from transformers import TrainingArguments

# Also tried: batch_size=16, learning_rate=4e-5, weight_decay=0.02, num_train_epochs=6.
batch_size = 32

training_args = TrainingArguments(
    # "none" disables every reporting integration. In transformers 4.x,
    # report_to=None falls back to "all" installed integrations instead.
    report_to="none",
    output_dir="models/bbc-classifier",
    overwrite_output_dir=True,
    save_strategy="epoch",
    evaluation_strategy="epoch",  # renamed `eval_strategy` in newer transformers releases
    # Log the training loss once per epoch. This replaces a floor-based
    # `logging_steps`, which was 0 for datasets smaller than one batch.
    logging_strategy="epoch",
    learning_rate=3e-5,
    weight_decay=0.01,
    warmup_ratio=0.05,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=4,
    fp16=True,  # mixed precision; requires a CUDA device
)

In [None]:
# from datasets import concatenate_datasets

# entire_train = concatenate_datasets([tokenized_datasets["train"], tokenized_datasets["validation"]]) 

In [None]:
from transformers import Trainer

# Train on the train split and evaluate on the test split
# (train+validation concatenation was considered but is not used).
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune; the test split is evaluated after each epoch (evaluation_strategy="epoch").
trainer.train()

In [None]:
# Final metrics on the test split (the same split used for the per-epoch evaluations).
trainer.evaluate(tokenized_datasets["test"])

In [None]:
# Save the fine-tuned model to training_args.output_dir ("models/bbc-classifier")
# so it can be reloaded later with from_pretrained().
trainer.save_model()

In [None]:
# Run inference on the test split; keep the raw logits and the gold label ids.
prediction_output = trainer.predict(tokenized_datasets["test"])
y_preds = prediction_output.predictions
y_true = prediction_output.label_ids

In [None]:
# Collapse per-class logits into predicted class ids. The ndim guard keeps this cell
# safe to re-run: otherwise a second run would argmax the 1-D ids into a scalar.
if y_preds.ndim > 1:
    y_preds = np.argmax(y_preds, axis=-1)

In [None]:
from sklearn.metrics import classification_report

# Class names in label-id order (must match the mapping in `assign_label`).
target_names = [
    "india",
    "international",
    "entertainment",
    "sport",
    "news",
    "science",
    "business",
    "pakistan",
    "southasia",
    "institutional",
    "social",
    "china",
    "multimedia",
    "learningenglish",
]

# Pass `labels` explicitly so the report always has one row per class. Otherwise
# sklearn infers the classes from y_true/y_preds and raises ValueError when a class
# is missing from the test split and never predicted. zero_division=0 silences the
# warning for such empty classes.
print(
    classification_report(
        y_true,
        y_preds,
        labels=list(range(len(target_names))),
        target_names=target_names,
        zero_division=0,
    )
)

In [None]:
import matplotlib.pyplot as plt
from seaborn import heatmap
from sklearn.metrics import confusion_matrix

# Fix the label set so the matrix is always len(target_names) x len(target_names)
# and lines up with the tick labels, even if some class never appears.
mat = confusion_matrix(y_true, y_preds, labels=list(range(len(target_names))))

# Rows are true labels, columns are predicted labels (sklearn convention).
fig, ax = plt.subplots(figsize=(12, 10))
heatmap(
    mat,
    cmap="Pastel1_r",
    fmt="d",
    xticklabels=target_names,
    yticklabels=target_names,
    annot=True,
    ax=ax,
)
ax.set(xlabel="Predicted label", ylabel="True label")
ax.set_title("Confusion matrix for BBC Hindi news classification", fontsize=12)
plt.show()

In [None]:
# Positions of test examples whose true class is 4 ("news") but were predicted as
# class 1 ("international"). Since 4 != 1, every match is a misclassification.
news_id, international_id = 4, 1
misclassified = np.flatnonzero((y_true == news_id) & (y_preds == international_id)).tolist()

In [None]:
# Pull the raw (untokenized) test rows for the misclassified indices. They go into a
# separate variable so `misclassified` stays an index list and this cell can be
# re-run. Previously the list was overwritten with a Dataset, which select() then
# rejected on a re-run.
misclassified_examples = news_datasets["test"].select(misclassified)
misclassified_examples[:]

In [None]:
# Display `misclassified` in full.
misclassified[:]

In [None]:
# Reload the fine-tuned classifier that trainer.save_model() wrote to
# training_args.output_dir, and move it to the GPU when one is available.
model = AutoModelForSequenceClassification.from_pretrained(training_args.output_dir)
model.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Evaluate the reloaded `model`. The existing `trainer` still wraps the in-memory
# model it trained, so a fresh Trainer is built around the reloaded one.
# bbca.hi may only provide "train"/"test" splits, so fall back to "test" when no
# "validation" split exists.
eval_split = "validation" if "validation" in tokenized_datasets else "test"
reloaded_trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
reloaded_trainer.evaluate(tokenized_datasets[eval_split])