RuntimeError: mat1 and mat2 shapes cannot be multiplied #162

Closed
chenmingjiong opened this issue Feb 24, 2023 · 8 comments
@chenmingjiong commented Feb 24, 2023

I got this error when fine-tuning "EleutherAI/gpt-j-6B" with load_in_8bit and LoRA on 8×2080 Ti:
RuntimeError: mat1 and mat2 shapes cannot be multiplied
I'm using data parallelism, not model parallelism.
The code runs normally when training on a single GPU, but fails when launched with accelerate on 8 GPUs.

Steps to reproduce:
1. Clone this repo: https://github.com/CarperAI/trlx
2. Modify the script examples/summarize_rlhf/sft/train_gptj_summarize.py as follows:

import random
import os
import evaluate
import numpy as np
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model 
from summarize_dataset import TLDRDataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
)


def set_seed(seed_val=42):
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)


if __name__ == "__main__":
    output_dir = "gptj-supervised-summarize-checkpoint"
    train_batch_size = 4
    gradient_accumulation_steps = 1
    learning_rate = 1e-5
    eval_batch_size = 1
    eval_steps = 500
    max_input_length = 550
    save_steps = 1000
    num_train_epochs = 5
    random.seed(42)
    os.environ["WANDB_DISABLED"] = "true"

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", use_cache=False, load_in_8bit=True, device_map={'':torch.cuda.current_device()})
    tokenizer.pad_token = tokenizer.eos_token
    model.resize_token_embeddings(len(tokenizer))
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id

    for param in model.parameters():
        param.requires_grad = False  # freeze the model - train adapters later
        if param.ndim == 1:
            # cast the small parameters (e.g. layernorm) to fp32 for stability
            param.data = param.data.to(torch.float32)
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    class CastOutputToFloat(nn.Sequential):
        def forward(self, x): return super().forward(x).to(torch.float32)
    model.lm_head = CastOutputToFloat(model.lm_head)

    config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = get_peft_model(model, config)

    # Set up the datasets
    data_path = "CarperAI/openai_summarize_tldr"
    train_dataset = TLDRDataset(
        data_path,
        tokenizer,
        "train",
        max_length=max_input_length,
    )
    dev_dataset = TLDRDataset(
        data_path,
        tokenizer,
        "valid",
        max_length=max_input_length,
    )

    # Set up the metric
    rouge = evaluate.load("rouge")

    def compute_metrics(eval_preds):
        labels_ids = eval_preds.label_ids
        pred_ids = eval_preds.predictions
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
        result = rouge.compute(predictions=pred_str, references=label_str)
        return result

    # Create a preprocessing function to extract out the proper logits from the model output
    def preprocess_logits_for_metrics(logits, labels):
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    # Prepare the trainer and start training
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="steps",
        eval_accumulation_steps=1,
        learning_rate=learning_rate,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        gradient_checkpointing=True,
        half_precision_backend="auto",
        fp16=True,
        adam_beta1=0.9,
        adam_beta2=0.95,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=num_train_epochs,
        warmup_steps=100,
        eval_steps=eval_steps,
        save_steps=save_steps,
        load_best_model_at_end=True,
        logging_steps=50,
        # deepspeed="examples/summarize_rlhf/sft/ds_config_gptj.json",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        compute_metrics=compute_metrics,
        data_collator=default_data_collator,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    trainer.train()
    trainer.save_model(output_dir)

3. Run:
accelerate launch --num_processes 8 examples/summarize_rlhf/sft/train_gptj_summarize.py

Full error logs:

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/peft/src/peft/peft_model.py", line 519, in forward
    return self.base_model(
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/accelerate/hooks.py", line 156, in new_forward
    output = old_forward(*args, **kwargs)
  File "/data/transformers/src/transformers/models/gptj/modeling_gptj.py", line 842, in forward
    transformer_outputs = self.transformer(
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/accelerate/hooks.py", line 156, in new_forward
    output = old_forward(*args, **kwargs)
  File "/data/transformers/src/transformers/models/gptj/modeling_gptj.py", line 668, in forward
    outputs = torch.utils.checkpoint.checkpoint(
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/data/transformers/src/transformers/models/gptj/modeling_gptj.py", line 664, in custom_forward
    return module(*inputs, use_cache, output_attentions)
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/accelerate/hooks.py", line 156, in new_forward
    output = old_forward(*args, **kwargs)
  File "/data/transformers/src/transformers/models/gptj/modeling_gptj.py", line 301, in forward
    attn_outputs = self.attn(
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/accelerate/hooks.py", line 156, in new_forward
    output = old_forward(*args, **kwargs)
  File "/data/transformers/src/transformers/models/gptj/modeling_gptj.py", line 202, in forward
    query = self.q_proj(hidden_states)
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/peft/src/peft/tuners/lora.py", line 496, in forward
    result = super().forward(x)
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 242, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 488, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
  File "/home/chenmingrui/miniconda3/envs/petals/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 397, in forward
    output += torch.matmul(subA, state.subB)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (550x11 and 10x4096)
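A note on the traceback (my reading, not something confirmed in this thread): the "replica 0" / parallel_apply.py frames mean the model was replicated with torch.nn.DataParallel across the GPUs, and the mismatch between subA (550x11) and state.subB (10x4096) suggests the int8 layers' cached outlier state is not valid in the replicated copies. One thing worth checking is that each accelerate process is pinned to its own GPU so the Trainer uses DDP rather than DataParallel. A minimal sketch under that assumption (LOCAL_RANK is the environment variable set by accelerate launch):

import os
import torch
from transformers import AutoModelForCausalLM

# Sketch of a possible mitigation, not a confirmed fix: place the whole model
# on this process's GPU so each process holds exactly one replica and the
# Trainer does not wrap the 8-bit model in nn.DataParallel.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    use_cache=False,
    load_in_8bit=True,
    device_map={"": local_rank},
)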
@peakji commented Feb 25, 2023

Exact same error while fine-tuning BLOOMZ: hyperonym/basaran#5

@AttentionAllUNeed
Same error while fine-tuning with LoRA at step 1 (SFT) of https://github.com/lvwerra/trl/tree/main/examples/sentiment/scripts/gpt-neox-20b_peft. Did you ever resolve it? @chenmingjiong @peakji

@peakji commented Apr 7, 2023

I haven't found a solution yet. Now I'm using half-precision instead of int8.
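For reference, a rough sketch of what loading in half precision instead of int8 looks like (my assumption of the change involved, not code from this thread; the rest of the LoRA setup stays the same):

import torch
from transformers import AutoModelForCausalLM

# Sketch: load the base model in fp16 rather than 8-bit, which avoids the
# bitsandbytes MatMul8bitLt path that raises the shape error above.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    use_cache=False,
    torch_dtype=torch.float16,
    device_map={"": torch.cuda.current_device()},
)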

@pseudotensor
Same error during inference when multiple threads are involved: h2oai/h2ogpt#104

@andersonbcdefg
Did anyone solve this? I'm hitting the same issue when loading a LoRA checkpoint for Falcon 7B.

@minlik commented Jul 11, 2023

Same issue with LLaMA 33B inference. Any updates here?

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@siddhantwaghjale
Was anyone able to solve this? I'm getting the same error when fine-tuning with LoRA and BitsAndBytesConfig.
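For anyone reproducing with the newer quantized-loading API, a minimal setup sketch (an assumption of the usual pattern, not a confirmed fix for this error; prepare_model_for_kbit_training is the peft helper that does the freezing, fp32 upcasting, and input-grad enabling that the original script does by hand):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Sketch: 8-bit load via BitsAndBytesConfig, then peft's k-bit preparation
# before attaching the LoRA adapters.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    quantization_config=bnb_config,
    device_map={"": 0},
)
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)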
