# Fine-tuning
In this notebook we fine-tune encoder models such as BERT and RoBERTa on the PAN corpus, and then evaluate them on the test set.

In [None]:
!pip install wandb transformers datasets

In [None]:
import os
import torch
from transformers import BertTokenizer, BertModel, DataCollatorWithPadding
from transformers import TrainingArguments
from transformers import Trainer
from transformers import AutoModelForSequenceClassification
import torchmetrics
import wandb

import json
import numpy as np
import time
from datasets import Dataset


from utilities import (read_paragraphs,
                       read_ground_truth, 
                       generate_dataset)


In [None]:

# Input locations.
# No ground truth exists for the real PAN test set, so the official validation
# split is used as our test set; a validation split is later carved out of the
# training data instead.
train_directory = './data/train_processed'
train_label_directory = './data/train_label'
test_directory = './data/validation_processed'
test_label_directory = './data/validation_label'

# Hugging Face checkpoint to fine-tune -- change this to try another encoder.
checkpoint = 'bert-base-cased'
run_name = f'multi_author_analyse_{checkpoint}'

# Problem-id ranges to load; 4200 (train) and 900 (test) are the largest
# usable end_id values for each split.
TRAIN_ID_RANGE = dict(start_id=1, end_id=4200)
TEST_ID_RANGE = dict(start_id=1, end_id=900)

# Documents: {'problem-x': [sen 1, sen 2, ...], ...}
train_data = read_paragraphs(train_directory, **TRAIN_ID_RANGE)
test_data = read_paragraphs(test_directory, **TEST_ID_RANGE)

# Ground-truth labels: {'problem-x': [1, ...], ...}
train_labels = read_ground_truth(train_label_directory, **TRAIN_ID_RANGE)
test_labels = read_ground_truth(test_label_directory, **TEST_ID_RANGE)



In [None]:
# AutoTokenizer resolves the tokenizer class that matches `checkpoint`, so
# switching the checkpoint to e.g. RoBERTa keeps tokenizer and model in sync
# (BertTokenizer is only correct for BERT-family checkpoints).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.model_max_length

In [None]:
# Turn the raw documents and labels into datasets.
train_dataset = generate_dataset(train_data, train_labels, tokenizer)
test_dataset = generate_dataset(test_data, test_labels, tokenizer)

# Hold out 20% of the training examples (seeded for reproducibility) as the
# validation split, giving the final layout: train / validation / test.
training_sets = train_dataset.train_test_split(train_size=0.8, seed=42)
held_out = training_sets.pop("test")
training_sets["validation"] = held_out
training_sets["test"] = test_dataset
# NOTE(review): the split is per example, not per document. If the `idx`
# column identifies the source document (the debugging cell below uses it as
# a key into `train_data`), examples from one document can land in both train
# and validation and inflate validation scores -- consider splitting by `idx`.

training_sets

In [None]:
# doc = training_sets['train']['idx'][0]
# for sen in train_data[doc]:
#     print(sen)
# print(train_labels[doc])
# print('\n')

# print(training_sets['train']['sentence1'][0])
# print(training_sets['train']['sentence2'][0])
# print(training_sets['train']['label'][0])
# print(training_sets['train']['idx'][0])

In [None]:
def tokenize_function(example):
    """Tokenize a batch of text pairs for sequence-pair classification.

    Intended for ``Dataset.map(..., batched=True)``: ``example`` carries
    parallel lists under ``"sentence1"`` and ``"sentence2"``. Each pair is
    encoded jointly by the module-level ``tokenizer`` with truncation enabled;
    padding is left to the data collator.
    """
    first_segments = example["sentence1"]
    second_segments = example["sentence2"]
    return tokenizer(first_segments, second_segments, truncation=True)

# Pad each batch dynamically at collation time rather than to a fixed length.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Tokenize every split in batches; the token columns are added to each split.
tokenized_datasets = training_sets.map(function=tokenize_function, batched=True)
tokenized_datasets

In [None]:
# Open the W&B run the Trainer reports to. The Trainer's W&B callback only
# starts (and names) its own run when none is active, so the run name is
# passed here explicitly; otherwise `run_name` would never be applied.
wandb.login()
wandb.init(project="Multi_author", name=run_name)

# On the full dataset: with batch size as 24 and gradient_accumulation_steps=4,
# there will be 546 training steps and 182 validation steps
training_args = TrainingArguments(
    output_dir=f"finetuned-{checkpoint}",
    # `eval_strategy` replaces `evaluation_strategy`, which was deprecated and
    # then removed from recent transformers releases (the install cell does
    # not pin a version).
    eval_strategy="steps",
    eval_steps=10,
    gradient_accumulation_steps=4,
    save_steps=50,
    learning_rate=2e-5,
    per_device_train_batch_size=24,
    per_device_eval_batch_size=24,
    num_train_epochs=3,
    weight_decay=0.01,
    report_to="wandb",  # enable logging to W&B
    run_name=run_name,  # name of the W&B run (optional)
    logging_steps=2,  # how often to log to W&B
)

# Two output classes for the pairwise authorship-change label.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

In [None]:
# Evaluation metrics, keyed by the name they are reported under.
# F1 is macro-averaged over the two classes. With task='binary' torchmetrics
# ignores `num_classes`/`average` and returns positive-class F1 only, so the
# multiclass task is required to actually obtain the macro average.
my_metrics = {
    "F1": torchmetrics.classification.F1Score(
        task="multiclass", num_classes=2, average="macro"
    ),
    "Accuracy": torchmetrics.classification.BinaryAccuracy(),
}


def compute_metrics(eval_preds):
    """Compute the metrics in ``my_metrics`` for the Trainer's evaluation loop.

    Args:
        eval_preds: ``(logits, labels)`` pair supplied by the Trainer. With the
            two-label classification head used here, logits hold one score per
            class in the last dimension.

    Returns:
        dict mapping each metric name to a Python float.
    """
    logits, labels = eval_preds
    # Some models return extra outputs alongside the logits; the class logits
    # come first in that case.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = torch.from_numpy(np.argmax(logits, axis=-1))
    targets = torch.from_numpy(labels)

    return {
        name: metric(predictions, targets).item()
        for name, metric in my_metrics.items()
    }

In [None]:
# Assemble the Trainer: fine-tune on the train split and run the periodic
# evaluations on the held-out validation split.
# NOTE(review): newer transformers releases deprecate `tokenizer=` here in
# favour of `processing_class=` -- confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the model. The W&B run is deliberately left open: the evaluation
# calls that follow still log through the Trainer's W&B callback, which needs
# an active run. Call wandb.finish() once all evaluation is done.
trainer.train()


In [None]:
# Score the held-out validation split (the same data as the Trainer's eval_dataset).
trainer.evaluate(eval_dataset=tokenized_datasets["validation"])

In [None]:
# Final score on the test split (PAN's official validation data). A separate
# metric prefix keeps these numbers distinct from the validation ("eval_")
# metrics in the returned dict and in W&B.
test_metrics = trainer.evaluate(tokenized_datasets['test'], metric_key_prefix="test")
# Close the W&B run now that all evaluation has been logged; this does nothing
# if no run is active.
wandb.finish()
test_metrics