## import libraries and load the dataset

In [None]:
import pandas as pd
from datasets import ClassLabel, Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer, EarlyStoppingCallback

# Load the preprocessed CSV (path is relative to the working directory).
# The literal has no placeholders, so it is a plain string, not an f-string.
dataset = pd.read_csv("../dataset/preprocess/github-labels-top3-803k-100.0%.csv")

# Print the number of rows per issue label.
print(dataset.issue_label.value_counts())

# DataFrame.info() writes its report to stdout itself and returns None, so it
# is called directly; wrapping it in print() appended a stray "None" line.
dataset.info()

## preprocess the dataset

### rename and cast the label column

In [None]:
# Rename the target column to "label" and store it as int64 in one chained
# step; a non-int64 label column causes an error during training.
dataset = dataset.rename(columns={"issue_label": "label"}).astype({"label": "int64"})

# Opt in to the pandas option that disables silent downcasting.
pd.set_option('future.no_silent_downcasting', True)

### tokenize the data



In [None]:
# Tokenizer for the uncased BERT checkpoint; input text is lower-cased.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# All texts from the dataframe, as a plain Python list.
texts = dataset.text.values.tolist()

# Each sequence is truncated or padded to exactly 128 tokens, and the
# attention mask and token type ids are returned as PyTorch tensors.
tokenizer_options = {
    "max_length": 128,
    "truncation": True,
    "padding": "max_length",
    "return_attention_mask": True,
    "return_token_type_ids": True,
    "return_tensors": 'pt',
}

encodings = tokenizer(texts, **tokenizer_options)

print(encodings)

### split the data

In [None]:
# Combine the tokenizer output with the integer labels into one Dataset.
trainer_dataset = Dataset.from_dict({
    **encodings,
    "label": dataset.label.values
})

# train_test_split can only stratify on a ClassLabel column, so convert the
# integer labels first. The stored integer values are kept as they are, and
# the cast fails if a label does not fit into the 3 declared classes.
trainer_dataset = trainer_dataset.cast_column("label", ClassLabel(num_classes=3))

# Return torch tensors for the columns used in training; token_type_ids is
# left out of the formatted output.
trainer_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

print(trainer_dataset)

# Hold out 15% of the rows for evaluation. Stratifying on the label keeps the
# class proportions the same in both splits, and the fixed seed makes the
# split (and therefore the reported metrics) reproducible between runs.
split_data = trainer_dataset.train_test_split(
    test_size=0.15,
    stratify_by_column="label",
    seed=42,
)
train_set = split_data['train']
test_set = split_data['test']

## run the training

In [None]:
# BERT sequence classifier configured for 3 output labels.
model = BertForSequenceClassification.from_pretrained(
    "google-bert/bert-base-uncased",
    num_labels=3,
)

training_args = TrainingArguments(
    # Output location and reporting.
    output_dir="../dataset",
    report_to="none",
    # Batch sizes and the upper bound on training length.
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=50,
    # Log every 10 steps.
    logging_strategy="steps",
    logging_steps=10,
    # Evaluate and save once per epoch.
    eval_strategy="epoch",
    save_strategy="epoch",
    # Reload the best checkpoint at the end, where higher accuracy is better.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
)

def compute_metrics(eval_pred) -> dict:
    """Compute classification metrics for one evaluation pass.

    Args:
        eval_pred: A pair ``(logits, labels)``. The predicted class is the
            argmax over the last axis of ``logits``; ``labels`` holds the
            true class ids.

    Returns:
        A dict with the keys "accuracy", "precision", "recall" and "f1".
        Precision, recall and F1 use the "weighted" average.
    """
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)

    # zero_division=0 yields the same 0 score scikit-learn falls back to when
    # a class is never predicted (or never present), but without raising an
    # UndefinedMetricWarning on every evaluation.
    weighted = {"average": "weighted", "zero_division": 0}

    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision_score(labels, predictions, **weighted),
        "recall": recall_score(labels, predictions, **weighted),
        "f1": f1_score(labels, predictions, **weighted),
    }

# Define trainer
# Early-stopping callback configured with a patience of 3.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

# Wire the model, arguments, data splits, metrics and callback together.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=test_set,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

# Run the fine-tuning loop.
trainer.train()

# Optional: uncomment to save the trained model.
# trainer.save_model("trained_electra_model")

# Optional: uncomment and fill in a path to save the tokenizer.
# tokenizer.save_pretrained("")