# AdapterFusion for Sequence Classification

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Wicwik/peft_tutorial/blob/main/examples/adapter_fusion.ipynb)

In [None]:
# Install the notebook's dependencies quietly (-q) into the user site (--user).
# transformers is pinned to 4.35.2; the other packages are left unpinned.
%pip install -q --user transformers==4.35.2
%pip install -q --user datasets
%pip install -q --user adapters
%pip install -q --user wandb

In [None]:
import torch
import evaluate
import wandb

from datasets import load_dataset
from transformers import BertTokenizer, BertConfig, TrainingArguments, default_data_collator

from adapters import BertAdapterModel, AdapterTrainer
from adapters.composition import Fuse

In [None]:
# Prefer the GPU, but fall back to CPU so the notebook still runs on
# machines without CUDA instead of assuming a GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base checkpoint and tokenizer used throughout the notebook.
model_name_or_path = "bert-base-uncased"
tokenizer_name_or_path = "bert-base-uncased"

# Every premise/hypothesis pair is padded or truncated to this many tokens.
max_length = 180

# Optimisation hyper-parameters.
lr = 1e-3
num_epochs = 3
# In case of "unable to allocate" errors, decrease the batch size
# to some lower number (e.g. 8 or 16).
batch_size = 32

In [None]:
# Load the "cb" task of the SuperGLUE benchmark.
dataset = load_dataset("super_glue", "cb")

# The test set is not labeled, so build custom splits: the labeled
# validation split is cut in half, one half for validation and one for
# testing. A fixed seed keeps the two halves (and therefore the reported
# metrics) identical across notebook runs.
validtest = dataset["validation"].train_test_split(test_size=0.5, seed=42)

dataset["validation"] = validtest["train"]
dataset["test"] = validtest["test"]

# Show one raw training example.
dataset["train"][0]

In [None]:
# Use the configured tokenizer checkpoint rather than a second hard-coded
# copy of its name, so changing `tokenizer_name_or_path` takes effect.
tokenizer = BertTokenizer.from_pretrained(tokenizer_name_or_path)


def preprocess_function(examples):
    """Tokenize a batch of premise/hypothesis pairs.

    Args:
        examples: A batch from the dataset providing "premise" and
            "hypothesis" columns.

    Returns:
        The tokenizer output for the sentence pairs, padded/truncated to
        `max_length` tokens.
    """
    # No `return_tensors` here: the result is written back into the dataset
    # by `map`, and batches are turned into tensors later by the collator.
    return tokenizer(
        examples["premise"],
        examples["hypothesis"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )


# Tokenize every split in a single pass and drop the raw text/index columns.
# (An earlier extra `dataset.map(preprocess_function, batched=True)` call
# tokenized everything once more for no benefit and has been removed.)
processed_datasets = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=["premise", "hypothesis", "idx"],
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

# The target column is renamed from "label" to "labels" for training.
processed_datasets = processed_datasets.rename_column("label", "labels")

train_dataset = processed_datasets["train"].shuffle()
eval_dataset = processed_datasets["validation"]
test_dataset = processed_datasets["test"]


In [None]:
# Map class indices to class names (built without shadowing the builtin `id`).
label_names = processed_datasets["train"].features["labels"].names
id2label = dict(enumerate(label_names))

config = BertConfig.from_pretrained(model_name_or_path, id2label=id2label)
model = BertAdapterModel.from_pretrained(model_name_or_path, config=config)

# Load three pre-trained task adapters without their prediction heads.
model.load_adapter("nli/multinli@ukp", load_as="multinli", with_head=False)
model.load_adapter("sts/qqp@ukp", with_head=False)
model.load_adapter("nli/qnli@ukp", with_head=False)

# Build the fusion setup once and reuse it, so the fusion layer that is
# added, activated and trained is guaranteed to be the same composition.
adapter_setup = Fuse("multinli", "qqp", "qnli")
model.add_adapter_fusion(adapter_setup)
model.set_active_adapters(adapter_setup)

# Classification head for the target task, one output per label.
model.add_classification_head("cb", num_labels=len(id2label))

# Switch the model to AdapterFusion training for this setup.
model.train_adapter_fusion(adapter_setup)

print(model.adapter_summary())

# Display the model architecture as the cell output.
model

In [None]:
# Metric for the SuperGLUE "cb" task, loaded through `evaluate`.
metric = evaluate.load("super_glue", "cb")


def compute_metrics(eval_preds):
    """Score model outputs with the SuperGLUE "cb" metric.

    Args:
        eval_preds: A pair that unpacks into (predictions, labels); the
            predictions are reduced to class indices along axis 1.

    Returns:
        Whatever `metric.compute` returns for the predicted classes and
        the reference labels.
    """
    scores, references = eval_preds
    predicted_classes = scores.argmax(axis=1)
    return metric.compute(
        predictions=predicted_classes,
        references=references,
    )

# Training configuration: evaluate and log once per epoch, never checkpoint.
training_args = TrainingArguments(
    output_dir="out",
    # optimisation
    num_train_epochs=num_epochs,
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    # reporting / checkpointing cadence
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="no",
)

In [None]:
# Wire the model, data and metric function into the adapter-aware trainer.
trainer = AdapterTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()

# Evaluate on the held-out test half. The result is captured and printed
# explicitly: a bare expression in the middle of a cell is not displayed,
# so the test metrics would otherwise be silently discarded.
test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
print(test_metrics)

# Close the Weights & Biases run if one was started during training.
if wandb.run is not None:
    wandb.finish()

In [None]:
import locale
locale.getpreferredencoding = lambda: "UTF-8"

peft_model_id = f"{model_name_or_path}_adapterfusion_seqcls"

model.save_adapter_fusion(peft_model_id, "multinli,qqp,qnli")
model.save_all_adapters(peft_model_id)

ckpt = f"{peft_model_id}/model.safetensors"
!du -h $ckpt

In [None]:
# Reload the model from the directory written earlier and re-activate the
# fusion of the three adapters.
model = BertAdapterModel.from_pretrained(peft_model_id)
fusion_setup = Fuse("multinli", "qqp", "qnli")
model.set_active_adapters(fusion_setup)

print(model.active_adapters)

# A single premise/hypothesis pair to classify.
premise = "A pity. For myself, a great pity. But no one can say Bishop Malduin has not received latitude."
hypothesis = "Bishop Malduin has not received latitude"
inputs = tokenizer(premise, hypothesis, return_tensors="pt")
print(inputs)

# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = model(**inputs, head="cb")
    logits = outputs[0]
    # Highest-scoring class index, translated back to its label name.
    pred_class = id2label[logits.argmax().item()]
    print(pred_class)