In [None]:
import os
import pandas as pd
import torch
from transformers import RobertaModel, RobertaTokenizer, TrainingArguments, Trainer, DataCollatorWithPadding, RobertaForSequenceClassification
from peft import LoraConfig, get_peft_model, PeftModel
from datasets import load_dataset, Dataset, ClassLabel
import pickle
import wandb
from lion_pytorch import Lion
from transformers import get_scheduler
from torch.optim import AdamW, Adam
from datasets import logging as datasets_logging
from transformers import logging as transformers_logging
from config import structure_config

In [None]:
# ---- Experiment configuration ------------------------------------------------
# Which LoRA layout to build; the ids are defined in config.py.
model_structure = 1

# Optimizer selector and its hyper-parameters.
#   1 -> Adam   lr: 2e-5 ~ 3e-5
#   2 -> AdamW  lr: 1e-4 ~ 5e-5 ~ 1e-5, weight_decay: 0.01 ~ 0.05 ~ 0.1
#   3 -> Lion   lr: 2e-4 ~ 1e-4 ~ 5e-5, weight_decay: 0.01 / 0.02
optim = 2
lr = 5e-5
w_decay = 0.02  # typical range: 1e-2 ~ 5e-2

# Learning-rate schedule name: "cosine" or "linear".
scheduler = "cosine"

# Training budget.
bs = 16  # per-device train batch size; tried values: 16 / 32 / 64
epoch = 1
max_step = 2_500  # should stay below 6000 in the current setting

# Run name used for experiment tracking; encodes the main knobs above.
exp_name = f"s-{model_structure}-bs{bs}-optim{optim}-lr{lr}"

# Dataset

In [None]:
# Silence informational output from the datasets library.
datasets_logging.set_verbosity_error()

# Single checkpoint identifier shared by the tokenizer and the classifier.
base_model = 'roberta-base'

# AG News training split; downloaded files are cached under ./data.
dataset = load_dataset('ag_news', split='train', cache_dir = './data')

# Load the tokenizer through base_model instead of the hard-coded './roberta-base'.
# from_pretrained accepts either a local directory or a Hub model id, so a local
# 'roberta-base' folder in the working directory is still picked up when present,
# while a missing folder no longer breaks the notebook. It also guarantees the
# tokenizer and the model always come from the same checkpoint.
tokenizer = RobertaTokenizer.from_pretrained(base_model)
def preprocess(examples):
    """Tokenize a batch of examples for the classifier.

    Args:
        examples: A batch from ``Dataset.map(..., batched=True)``; its ``'text'``
            entry is passed straight to the tokenizer.

    Returns:
        The tokenizer output for the batch, truncated but NOT padded.

    Padding is deliberately left to the data collator, which pads every
    mini-batch to its own longest sequence. Padding here as well would pad each
    example to the longest text of its whole map chunk and store that padding
    in the dataset, inflating both the cache and every forward pass.
    """
    return tokenizer(examples['text'], truncation=True)
# Tokenize the whole split in batches; the raw text column is dropped because
# only the tokenizer output and the label are needed from here on.
tokenized_dataset = dataset.map(preprocess, batched=True,  remove_columns=["text"])
# The evaluation helper and the model call read the target from the "labels" key.
tokenized_dataset = tokenized_dataset.rename_column("label", "labels")

# Hold out 24,000 examples for evaluation. The seed pins the shuffle so that the
# train / eval partition (and therefore every reported accuracy) is the same on
# each Restart & Run All; without it the hold-out set changed on every run.
split_datasets = tokenized_dataset.train_test_split(test_size=24000, seed=42)
train_dataset = split_datasets['train']
eval_dataset = split_datasets['test']

# Extract the number of classes and their names
num_labels = dataset.features['label'].num_classes
class_names = dataset.features["label"].names
# the labels: ['World', 'Sports', 'Business', 'Sci/Tech']

# Create an id2label mapping (class index -> class name).
# It is passed to the classifier when the model is instantiated.
id2label = {i: label for i, label in enumerate(class_names)}
# Pads each mini-batch to its longest sequence and returns PyTorch tensors.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

# Model Structure

In [None]:
# Silence informational output from transformers (e.g. the notice about the
# freshly initialised classification head).
transformers_logging.set_verbosity_error()

# One dict per LoRA setup. The keys read below are "layer", "r", "lora_alpha",
# "lora_dropout" and "target"; their schema is defined in config.py.
layer_configs = list(structure_config(model_structure))
if not layer_configs:
    raise ValueError(
        f"structure_config({model_structure!r}) returned no LoRA configuration; "
        "there is nothing to wrap with PEFT."
    )

# The previous loop reloaded the base model on every iteration, so each pass
# threw away the adapter built by the one before it: only the LAST entry ever
# reached training, at the cost of len(layer_configs) + 1 checkpoint loads.
# The base model is now loaded once and the entry that is actually used is
# made explicit, so the resulting model is unchanged but nothing is silently
# discarded.
# NOTE(review): if the entries are meant to be combined, express them as ONE
# LoraConfig (e.g. per-module ranks via rank_pattern / alpha_pattern) instead of
# wrapping the model repeatedly - confirm the intent against config.py.
if len(layer_configs) > 1:
    print(
        f"WARNING: {len(layer_configs)} LoRA configs were returned but only the "
        "last one is applied; earlier entries are ignored."
    )
config = layer_configs[-1]

model = RobertaForSequenceClassification.from_pretrained(
    base_model,
    id2label=id2label
)

# target_layer only feeds the log line below; the modules that actually receive
# LoRA weights are whatever config["target"] names.
target_layer = f"encoder.layer.{config['layer']}.attention.self.query"
peft_config = LoraConfig(
    r=config["r"],
    lora_alpha=config["lora_alpha"],
    lora_dropout=config["lora_dropout"],
    bias="none",
    target_modules=config["target"],
    task_type="SEQ_CLS",
)
print(f"Applying LoRA to {target_layer}")
model = get_peft_model(model, peft_config)

# `model` and `peft_model` are the same object; later cells use both names.
peft_model = model
print(peft_model)
peft_model.print_trainable_parameters()

# Training Setting

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from transformers import EarlyStoppingCallback

# Translate the numeric `optim` switch into an optimizer class:
#   2 -> AdamW, 3 -> Lion, any other value (including 1) -> Adam.
# Every choice receives the same learning rate and weight decay.
optimizer_cls = AdamW if optim == 2 else (Lion if optim == 3 else Adam)
optimizer = optimizer_cls(model.parameters(), lr=lr, weight_decay=w_decay)

def compute_metrics(pred):
    """Turn an evaluation prediction object into a metrics dict.

    Args:
        pred: Object exposing ``predictions`` (per-class scores, reduced with
            ``argmax`` over the last axis) and ``label_ids`` (reference labels).

    Returns:
        ``{'accuracy': <accuracy_score of the argmax predictions>}``.
    """
    predicted_classes = pred.predictions.argmax(-1)
    return {'accuracy': accuracy_score(pred.label_ids, predicted_classes)}

# Authenticate with Weights & Biases before the Trainer starts reporting to it.
wandb.login()

# Checkpoints and the final inference CSV are written below this directory.
output_dir = "results"

training_args = TrainingArguments(
    # --- where results go / experiment tracking ---
    output_dir=output_dir,
    report_to='wandb',
    run_name = exp_name,
    # --- evaluation, logging and checkpointing cadence ---
    # eval_steps is not set, so evaluation follows logging_steps.
    eval_strategy='steps',
    logging_steps=100,
    save_steps=100,
    save_total_limit=1,
    # --- best-checkpoint selection (also required by early stopping) ---
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    greater_is_better=True,
    # --- training budget: a positive max_steps takes precedence over epochs ---
    max_steps= max_step,
    num_train_epochs=epoch,
    per_device_train_batch_size=bs, # total trained samples: batch_size*max_steps
    per_device_eval_batch_size=256,
    # --- optimisation ---
    learning_rate=lr,
    optim="sgd", # ignore: the Trainer is handed a pre-built optimizer instead
    # --- hardware / data loading ---
    use_cpu=False,
    dataloader_num_workers=64,
    gradient_checkpointing=False,
    gradient_checkpointing_kwargs={'use_reentrant':True}
)

# Keep the schedule object under its own name. Rebinding `scheduler` (the config
# string "cosine" / "linear") made this cell fail on a second run, because
# get_scheduler would then receive a scheduler object as its `name`.
lr_scheduler = get_scheduler(
    name=scheduler,
    optimizer=optimizer,
    num_warmup_steps=int(0.1 * training_args.max_steps),  # warm up over the first 10% of steps
    num_training_steps=training_args.max_steps
)
# normally, num_training_steps = num_epochs * dataset_size / batch_size

def get_trainer(model_):
    """Build the Trainer used for LoRA fine-tuning.

    Args:
        model_: The model to train.

    Returns:
        A ``Trainer`` wired to the module-level training arguments, datasets,
        data collator, metric function, optimizer and learning-rate schedule,
        with early stopping (patience 10 evaluations, threshold 0.001).

    Note:
        The optimizer (and therefore the schedule) was created from the
        module-level ``model``'s parameters, so ``model_`` has to be that same
        model for the optimizer to update it.
    """
    early_stopping = EarlyStoppingCallback(
        early_stopping_patience=10,
        early_stopping_threshold=0.001
    )
    return Trainer(
        model=model_,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        optimizers=(optimizer, lr_scheduler),
        callbacks=[early_stopping])

peft_lora_finetuning_trainer = get_trainer(peft_model)

# Training

In [None]:
# Run the fine-tuning loop; the value returned by train() is kept in `result`.
result = peft_lora_finetuning_trainer.train()

# Local AGNEWS test

In [None]:
from torch.utils.data import DataLoader
import evaluate
from tqdm import tqdm
import importlib.util

def _load_local_accuracy_metric(module_path="./metrics/accuracy.py"):
    """Instantiate the ``Accuracy`` metric defined in the local script at module_path."""
    spec = importlib.util.spec_from_file_location("accuracy", module_path)
    accuracy = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(accuracy)
    return accuracy.Accuracy()


def evaluate_model(inference_model, dataset, labelled=True, batch_size=8, data_collator=None):
    """
    Run a model over a dataset and collect its class predictions.

    The model is moved to the GPU when one is available (CPU otherwise) and is
    left in eval mode.

    Args:
        inference_model: The model to evaluate. It is called with the batch as
            keyword arguments and must return an object with a ``logits`` attribute.
        dataset: The dataset to run inference on.
        labelled (bool): If True, each batch must contain a "labels" entry and
                         accuracy is computed with the local metric script
                         ./metrics/accuracy.py. If False, only predictions
                         are returned.
        batch_size (int): Batch size for inference.
        data_collator: Function to collate batches. If None, the DataLoader's
                       default collate function is used.

    Returns:
        If labelled is True, a tuple (metrics, predictions).
        If labelled is False, the predictions only.
        Predictions are a 1-D CPU tensor of class indices in dataset order; for
        an empty dataset it is an empty int64 tensor.
    """
    eval_dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=data_collator)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inference_model.to(device)
    inference_model.eval()

    # The metric script is only needed (and only loaded) when labels are present.
    metric = _load_local_accuracy_metric() if labelled else None

    all_predictions = []
    for batch in tqdm(eval_dataloader, mininterval=0.5):
        # Move each tensor in the batch to the device
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = inference_model(**batch)
        # Move to CPU once and reuse the result for both the output and the metric.
        predictions = outputs.logits.argmax(dim=-1).cpu()
        all_predictions.append(predictions)

        if labelled:
            # Labels are expected under the "labels" key.
            metric.add_batch(
                predictions=predictions.numpy(),
                references=batch["labels"].cpu().numpy()
            )

    # torch.cat rejects an empty list, so an empty dataset needs its own result.
    if all_predictions:
        all_predictions = torch.cat(all_predictions, dim=0)
    else:
        all_predictions = torch.empty(0, dtype=torch.long)

    if labelled:
        eval_metric = metric.compute()
        print("Evaluation Metric:", eval_metric)
        return eval_metric, all_predictions
    return all_predictions

# The trainer was created with load_best_model_at_end=True; take its model and
# score it on the held-out AG News split. Both return values are discarded here
# because evaluate_model already prints the metric.
best_model = peft_lora_finetuning_trainer.model
_, _ = evaluate_model(
    inference_model=best_model,
    dataset=eval_dataset,
    labelled=True,
    batch_size=256,
    data_collator=data_collator,
)

# Kaggle Testset Prediction

In [None]:
# NOTE(security): unpickling can execute arbitrary code embedded in the file;
# only load a test_unlabelled.pkl that comes from a trusted source.
unlabelled_dataset = pd.read_pickle("test_unlabelled.pkl")
# Same tokenization as the training data; the raw text column is dropped.
test_dataset = unlabelled_dataset.map(preprocess, batched=True, remove_columns=["text"])

# Run inference and save predictions
preds = evaluate_model(best_model, test_dataset, False, 256, data_collator)
df_output = pd.DataFrame({
    'ID': range(len(preds)),  # row position in the test set
    'Label': preds.numpy()  # or preds.tolist()
})
# Make sure the target directory exists so this cell also works when it is run
# without a preceding training run having created it.
os.makedirs(output_dir, exist_ok=True)
df_output.to_csv(os.path.join(output_dir,"inference_output.csv"), index=False)
print("Inference complete. Predictions saved to inference_output.csv")