
distributed TP model forward output's requires_grad is False #115

Open
lxuechen opened this issue Aug 10, 2023 · 5 comments
lxuechen commented Aug 10, 2023

Hi, thanks for the nice work!

I've been trying to optimize the performance of the TP wrapper here, and the first thing that came to mind was balancing out the compute on each rank using distributed / multiprocessing (as opposed to threading).

I've been wrapping my model as follows:

tp.tensor_parallel(model, distributed=True, device_ids=device_ids)

But it seems the final model output from the forward pass doesn't require grad anymore, which makes it impossible to call loss.backward().

Code to reproduce:

# torchrun --nproc-per-node 2 --master_port=1234 training/tp/standalone.py
import os

import tensor_parallel as tp
import torch
import transformers
from torch import distributed as dist


def get_local_rank():
    return int(os.getenv("LOCAL_RANK", -1))


def get_world_size():
    return int(os.getenv("WORLD_SIZE", 1))


# one process per GPU; LOCAL_RANK and WORLD_SIZE are set by torchrun
dist.init_process_group(backend="nccl", rank=get_local_rank(), world_size=get_world_size())
pg = dist.distributed_c10d._get_default_group()
torch.cuda.set_device(get_local_rank())

current_device = torch.device(torch.cuda.current_device())
device_ids = [current_device]
print(device_ids)

model = transformers.LlamaForCausalLM.from_pretrained(
    <some_huggingface_llama_checkpoint>, torch_dtype=torch.bfloat16
)
# wrap with distributed tensor parallelism: each process holds one shard
model, _ = tp.tensor_parallel(model, distributed=True, device_ids=device_ids)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    <some_huggingface_llama_checkpoint>, use_fast=True
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tensors = tokenizer(["how are you?", "I love you and you too."], return_tensors="pt", padding=True)
tensors = {
    "input_ids": tensors["input_ids"],
    "attention_mask": tensors["attention_mask"],
    "labels": tensors["input_ids"],
}

with torch.enable_grad():
    model.train()
    tensors = {k: v.to(current_device) for k, v in tensors.items()}
    outputs = model(**tensors, return_dict=True, output_hidden_states=True)
    print(outputs.logits)
    print(outputs.logits.requires_grad)  # False, but this needs to be True for loss.backward() to work
    print(outputs.keys())

    hs = outputs.hidden_states
    s0 = hs[0]
    print(s0)
    print(s0.requires_grad) # False

Finally, here are some of my system configs:

  • pip dump (excerpt):
tensor-parallel          2.0.0
termcolor                2.3.0
tiktoken                 0.4.0
tokenizers               0.13.3
tomli                    2.0.1
tomlkit                  0.11.8
toolz                    0.12.0
torch                    2.0.1
torchvision              0.15.2
tqdm                     4.65.0
transformers             4.30.0.dev0
  • hardware: 2 A100 GPUs with NVLink interconnect

lxuechen commented Aug 10, 2023

I can confirm that the intermediate hidden_states within the forward pass of the submodules still have requires_grad=True, so it's likely something related to the final step (communication?). @BlackSamorez
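For reference, one way to do this kind of check is a quick forward-hook sweep over the decoder layers. Rough sketch below; the hook and the class-name filter are just for illustration (not part of tensor_parallel), and it reuses model and tensors from the repro script above.

import torch

def report_requires_grad(module, inputs, output):
    # decoder layers return tuples; the first element is the hidden state
    out = output[0] if isinstance(output, tuple) else output
    if isinstance(out, torch.Tensor):
        print(f"{module.__class__.__name__}: requires_grad={out.requires_grad}")

# attach to every decoder layer of the already-wrapped model
handles = [
    m.register_forward_hook(report_requires_grad)
    for m in model.modules()
    if m.__class__.__name__.endswith("DecoderLayer")
]

with torch.enable_grad():
    model(**tensors, return_dict=True)

for h in handles:
    h.remove()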

BlackSamorez (Owner) commented

outputs.logits.requires_grad=True for Bloom models. The problem must be with LLaMA/LLaMA-2 models specifically. I'll look into it.

lxuechen (Author) commented

Made some progress: it seems the activations no longer require grad after going through the lm-head.

lxuechen (Author) commented

OK, it's caused by the lm-head being sharded: the all-gather of the activations along the column dimension doesn't propagate gradients.
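For context, this is plain torch.distributed behavior rather than anything library-specific: dist.all_gather fills pre-allocated buffers that live outside the autograd graph, so gradients can't flow back through them. Minimal sketch of that behavior, run under torchrun with 2 processes (file name and shapes are made up):

# torchrun --nproc-per-node 2 allgather_grad_check.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

# pretend this is the local column shard of an activation
shard = torch.randn(2, 4, device="cuda", requires_grad=True)
print(shard.requires_grad)  # True

gathered = [torch.empty_like(shard) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, shard)

full = torch.cat(gathered, dim=-1)
print(full.requires_grad)  # False: backward() stops here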

lxuechen commented Aug 11, 2023

A simple workaround for now would be to not shard the lm-head; another solution would be to run the all-gather through a custom torch.autograd.Function with the backward pass implemented. I'll test out the first solution for now.
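For the second option, here's a rough sketch of the kind of Function I have in mind (my own naming, not tensor_parallel's API; it assumes equal shard sizes on every rank and that the post-gather computation is replicated on every rank):

import torch
import torch.distributed as dist

class AllGatherWithGrad(torch.autograd.Function):
    """All-gather equally sized shards along `dim` without breaking autograd."""

    @staticmethod
    def forward(ctx, shard, dim):
        ctx.dim = dim
        ctx.rank = dist.get_rank()
        ctx.shard_size = shard.shape[dim]
        gathered = [torch.empty_like(shard) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, shard.contiguous())
        return torch.cat(gathered, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # keep only the slice of the gradient that matches this rank's shard
        # (Megatron-style gather backward; assumes the loss after the gather
        # is computed identically on every rank)
        grad_shard = grad_output.narrow(ctx.dim, ctx.rank * ctx.shard_size, ctx.shard_size)
        return grad_shard.contiguous(), None

def all_gather_with_grad(shard, dim=-1):
    # normalize negative dims so the offsets used in backward are correct
    return AllGatherWithGrad.apply(shard, dim % shard.dim())

The sharded lm-head output would then go through something like all_gather_with_grad(local_logits, dim=-1) instead of the bare collective.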
