Skip to content

Update augment branch to use kernels for zero-fill before layernorm and rmsnorm#1

Merged
cbcase merged 5 commits intov0.13-augmentfrom
v0.13-augment-zero-out
Oct 17, 2024
Merged

Update augment branch to use kernels for zero-fill before layernorm and rmsnorm#1
cbcase merged 5 commits intov0.13-augmentfrom
v0.13-augment-zero-out

Conversation

@cbcase
Copy link
Collaborator

@cbcase cbcase commented Oct 16, 2024

This PR updates the TE code to use a fill kernel (not memset) to zero-out buffers before calling the fast fp8 layernorm and rmsnorm kernels. We are making this change because the memset ops introduce gaps in cuda graph execution.

Copy link

@MarkusRabe MarkusRabe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only comment might be that it might be helpful to keep the existing code runnable, just hidden by a flag.

params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
}

const char *envval = std::getenv("NVTE_FORCE_MEMSET");

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only nit is that it would be good to document the flag. This is a fall back to the original behavior of library.

@cbcase cbcase merged commit 6bdede8 into v0.13-augment Oct 17, 2024
@cbcase cbcase deleted the v0.13-augment-zero-out branch October 17, 2024 20:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants