Seems like a typo here? #1833

@hanlinxuy

Description

Describe the bug

(https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/grouped_linear.py#L245)

ctx.num_gemms should be an int, which is not iterable, so this loop looks like it would raise a TypeError. Should it be range(ctx.num_gemms)?

for i in ctx.num_gemms:

    if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:  # TOSO
        for i in ctx.num_gemms:
            w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
            w.main_grad = main_grads[i]
            weights[i] = w
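
A minimal sketch of the concern, with no Transformer Engine or PyTorch dependency. Here ctx is a hypothetical SimpleNamespace stand-in for the autograd context, and iterating over range(ctx.num_gemms) is an assumed fix rather than a confirmed one:

```python
from types import SimpleNamespace

# Hypothetical stand-in for the autograd ctx; only num_gemms matters here.
ctx = SimpleNamespace(num_gemms=3)


def indices_as_written(ctx):
    # Mirrors the quoted loop header: iterating directly over an int.
    return [i for i in ctx.num_gemms]


def indices_with_range(ctx):
    # Assumed fix: iterate over the index range instead of the int itself.
    return [i for i in range(ctx.num_gemms)]


raised = False
message = ""
try:
    indices_as_written(ctx)
except TypeError as exc:
    # Python refuses to iterate over an int.
    raised = True
    message = str(exc)

fixed = indices_with_range(ctx)
print(raised, message, fixed)
```

If num_gemms really is an int, the branch guarded by cpu_offloading and fuse_wgrad_accumulation would fail as soon as it is reached, which may explain why it has gone unnoticed when those options are not both enabled.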

Steps/Code to reproduce bug

Please list minimal steps or code snippet for us to be able to reproduce the bug.

A helpful guide on how to craft a minimal bug report: http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports.

Expected behavior

A clear and concise description of what you expected to happen.

Environment overview (please complete the following information)

  • Environment location: [Bare-metal, Docker, Cloud (specify cloud provider: AWS, Azure, GCP, Colab)]
  • Method of Transformer Engine install: [pip install or from source]. Please specify exact commands you used to install.
  • If method of install is [Docker], provide docker pull & docker run commands used

Environment details

If NVIDIA docker image is used you don't need to specify these.
Otherwise, please provide:

  • OS version
  • PyTorch version
  • Python version
  • Transformer Engine version
  • CUDA version
  • cuDNN version

Device details

  • GPU model

Additional context

Add any other context about the problem here.

Labels

bug (Something isn't working)
