Custom Heads not working with adapters #680

Closed
san-deep-reddy opened this issue Apr 15, 2024 Discussed in #679 · 0 comments · Fixed by #700
san-deep-reddy commented Apr 15, 2024

Discussed in #679

Originally posted by san-deep-reddy April 14, 2024
I have tried many models, adapter types, custom head types, and configs, but I always end up with the same error:

  File "testing\adapter2_testing.py", line 173, in <module>
    trainer.train()
  File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\site-packages\transformers\trainer.py", line 1537, in train
    return inner_training_loop(
  File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\site-packages\transformers\trainer.py", line 1772, in _inner_training_loop
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
  File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\site-packages\transformers\trainer_callback.py", line 370, in on_train_begin
    return self.call_event("on_train_begin", args, state, control)
  File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\site-packages\transformers\trainer_callback.py", line 414, in call_event
    result = getattr(callback, event)(
  File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\site-packages\transformers\integrations\integration_utils.py", line 635, in on_train_begin
    model_config_json = model.config.to_json_string()
  File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\site-packages\transformers\configuration_utils.py", line 951, in to_json_string
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
  File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\__init__.py", line 234, in dumps
    return cls(
  File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\encoder.py", line 201, in encode
    chunks = list(chunks)
  File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\encoder.py", line 431, in _iterencode
    yield from _iterencode_dict(o, _current_indent_level)
  File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\encoder.py", line 405, in _iterencode_dict
    yield from chunks
  File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\encoder.py", line 405, in _iterencode_dict
    yield from chunks
  File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\encoder.py", line 438, in _iterencode
    o = _default(o)
  File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type type is not JSON serializable
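
The last frame suggests that some value inside the config dictionary is a Python class object, which `json.dumps` cannot encode. The sketch below reproduces the same message with the standard library alone and then walks the dictionary to report where the offending value sits. `DummyHead`, `config_dict`, and `find_unserializable` are hypothetical names used for illustration; the dictionary is a hand-built stand-in, not the real model config.

```python
import json


class DummyHead:
    """Hypothetical stand-in for a custom head class."""


# Hand-built stand-in for a config dict that stores a class object.
config_dict = {
    "model_type": "bert",
    "custom_heads": {"my_custom_head": DummyHead},
}


def find_unserializable(obj, path="config"):
    """Return (path, value) pairs for values that json.dumps rejects."""
    if isinstance(obj, dict):
        found = []
        for key, value in obj.items():
            found.extend(find_unserializable(value, f"{path}[{key!r}]"))
        return found
    if isinstance(obj, (list, tuple)):
        found = []
        for i, value in enumerate(obj):
            found.extend(find_unserializable(value, f"{path}[{i}]"))
        return found
    try:
        json.dumps(obj)
    except TypeError:
        return [(path, obj)]
    return []


# Same call shape as the failing frame in the traceback.
try:
    json.dumps(config_dict, indent=2, sort_keys=True)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

offenders = find_unserializable(config_dict)
print(error_message)
print(offenders)
```

Running a helper like this on `model.config.to_dict()` may help narrow down which config entry holds the class, although the exact layout of the real config is not shown in this issue.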

Here is my minimal code:

import adapters, torch
import torch.nn as nn
import numpy as np
from adapters.heads import PredictionHead
from adapters import AutoAdapterModel, AdapterTrainer, SeqBnConfig
from transformers import AutoTokenizer, TrainingArguments, EvalPrediction
from datasets import load_dataset

model_path = "bert-base-uncased"
model = AutoAdapterModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

class CustomHead(PredictionHead):
    """ Same as ClassificationHead """ 
    def __init__(self, model, head_name, **config):
        super().__init__(head_name)
        self.config = config
        self.build(model=model)
    def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
        if cls_output is None:
            cls_output = self._get_cls_output(outputs, **kwargs)
        logits = super().forward(cls_output)
        loss = None
        labels = kwargs.pop("labels", None)
        # Only compute the loss when labels are provided (e.g. not at inference time)
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config["num_labels"]), labels.view(-1))
        outputs = (logits,) + outputs[1:]
        if labels is not None:
            outputs = (loss,) + outputs
        return outputs

seq_config = SeqBnConfig(reduction_factor=16, use_gating=True)
model.add_adapter("adapter2", config=seq_config)
model.delete_head('default')
model.register_custom_head("my_custom_head", CustomHead)
config = {"num_labels": 2, "layers": 1, "activation_function": "tanh"}
model.add_custom_head("my_custom_head", "adapter2", **config)
model.train_adapter(['adapter2'])
model.set_active_adapters(['adapter2'])  # This line is redundant


def encode_batch(batch):
    """Encodes a batch of input data using the model tokenizer."""
    return tokenizer(batch["text"], max_length=512, padding=True, truncation=True, return_tensors="pt")

dataset = load_dataset("rotten_tomatoes")
dataset = dataset.map(encode_batch, batched=True)
dataset = dataset.rename_column(original_column_name="label", new_column_name="labels")
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=2e-5,
    evaluation_strategy="epoch",
    report_to="tensorboard"
)
def compute_accuracy(p: EvalPrediction):
    preds = np.argmax(p.predictions, axis=1)
    return {"acc": (preds == p.label_ids).mean()}

trainer = AdapterTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_accuracy,
)

trainer.train()

Using model.add_classification_head instead works fine, though.

@calpt calpt linked a pull request Jun 8, 2024 that will close this issue
calpt pushed a commit that referenced this issue Jun 20, 2024
To make the model_config serializable and prevent the error mentioned in
#680, move the custom_heads dictionary out of the config and make it a
separate attribute of the model class.
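
The idea in that commit message can be illustrated with a toy example: head classes are kept in a registry on the model object, while the config stores only plain values and therefore stays serializable. This is a rough sketch under stated assumptions, not the library's actual implementation. `ToyConfig`, `ToyModel`, and `DummyHead` are hypothetical, and the method names only mirror the calls used in the script above.

```python
import json


class DummyHead:
    """Hypothetical stand-in for a custom head class."""


class ToyConfig:
    """Toy config that holds only JSON-friendly values."""

    def __init__(self):
        self.prediction_heads = {}

    def to_json_string(self):
        return json.dumps(self.__dict__, indent=2, sort_keys=True) + "\n"


class ToyModel:
    """Toy model that keeps head classes on itself, not in the config."""

    def __init__(self):
        self.config = ToyConfig()
        # Registry of head_type -> class; lives on the model, never serialized.
        self.custom_heads = {}

    def register_custom_head(self, head_type, head_cls):
        self.custom_heads[head_type] = head_cls

    def add_custom_head(self, head_type, head_name, **kwargs):
        head_cls = self.custom_heads[head_type]
        # Only plain values (strings, numbers) are written to the config.
        self.config.prediction_heads[head_name] = {"head_type": head_type, **kwargs}
        return head_cls()


model = ToyModel()
model.register_custom_head("my_custom_head", DummyHead)
head = model.add_custom_head("my_custom_head", "adapter2", num_labels=2)

# Serialization now succeeds because no class object is stored in the config.
config_json = model.config.to_json_string()
print(config_json)
```

With this layout, the class lookup still works through `model.custom_heads`, and a callback that calls `to_json_string()` no longer encounters a class object.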