
thunder.jit generated fusion for layer norm backward is segmented because : unsupported post reduction normalization #2046

Closed · liqiangxl opened this issue Apr 5, 2024 · 0 comments

liqiangxl (Collaborator) commented Apr 5, 2024

The following script, using thunder.jit, is expected to run layer norm backward with the innerOuter scheduler in nvFuser. Instead, the generated fusion is rejected with the message: # Scheduler _inner_outer_persistent_ ***rejected*** because : unsupported post reduction normalization

import torch
import thunder

class MyFusion(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(8192)

    def forward(self, x):
        out3 = self.layer_norm(x)
        return out3

inputs = [
    torch.randn(2048, 8192, device='cuda', requires_grad=True),
]
grads = [
    torch.randn(2048, 8192, device='cuda'),
]

tj_model = thunder.jit(MyFusion().cuda())
out = tj_model(*inputs)
out.backward(*grads)
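
For debugging, the traces thunder records can show how the backward pass was segmented. A minimal sketch, assuming the lightning-thunder trace API (thunder.last_backward_traces) and nvFuser's NVFUSER_DUMP environment variable behave as documented:

# Sketch: inspect the backward trace thunder produced for the script above.
# The nvFusion regions in the printed trace show which ops went into each
# fusion; launching with NVFUSER_DUMP=fusion_ir (an nvFuser dump option)
# prints fusion IR like the dumps quoted below.
bwd_trace = thunder.last_backward_traces(tj_model)[-1]
print(bwd_trace)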

Further investigation shows the message comes from SchedulerTopologyChecker::hasNonNormalizePostReductionBCast:

Found unresolved broadcast after reduction: T31_l[ iS60{2048}, bS61{1 ex 8192} ] -> T40_l[ iS78{2048}, iS107{i1} ]
T40_l[ iS78{2048}, iS107{i1} ]
   = T31_l[ iS60{2048}, bS61{1 ex 8192} ]
   + T39_l[ iS76{2048}, iS106{i1} ];

T31 has a post-reduction broadcast dim bS61. The check wants to ensure that the extent of bS61 matches the extent of the reduction dim, by verifying that bS61 is mapped to a dim that feeds the reduction. The related fusion IR is:

T19_l[ iS38{2048}, rS39{i1} ]
   = reduction( T18_l[ iS36{2048}, iS37{i1} ], op = add, initial value = float(0), allreduce = false )
T20_l[ iS40{2048}, bS41{1} ]
   = broadcast( T19_l[ iS38{2048}, rS39{i1} ] )
T21_l[ iS42{2048}, bS43{1} ]
   = Set( T20_l[ iS40{2048}, bS41{1} ], cache_op=Streaming )
T25_l[ iS50{2048} ]
   = squeeze( T21_l[ iS42{2048}, bS43{1} ] )
T27_l[ iS52{2048}, bS53{1} ]
   = broadcast( T25_l[ iS50{2048} ] )
T28_l[ iS54{2048}, bS55{1} ]
   = Set( T27_l[ iS52{2048}, bS53{1} ], cache_op=Streaming )
T29_l[ iS56{2048}, bS57{1} ]
   = Set( T28_l[ iS54{2048}, bS55{1} ], cache_op=Streaming )
T30_l[ iS58{2048}, bS59{1 ex 8192} ] = expand( T29_l[ iS56{2048}, bS57{1} ], {2048, 8192} )
T31_l[ iS60{2048}, bS61{1 ex 8192} ]
   = double(0.0001220703125)
   * T30_l[ iS58{2048}, bS59{1 ex 8192} ];
T40_l[ iS78{2048}, iS107{i1} ]
   = T31_l[ iS60{2048}, bS61{1 ex 8192} ]
   + T39_l[ iS76{2048}, iS106{i1} ];

The check walks the backward dep chain T40 -> T31 -> T30 -> T29 -> T28 -> T27, and the mapped IDs are {iS107{i1}, bS61, bS61, bS57, bS55, bS53}. At T27 -> T25, no ID is mapped anymore and the chain breaks!
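
To illustrate why the check fails, here is a rough Python sketch of the chain-walking idea (hypothetical, not the actual nvFuser C++ implementation of hasNonNormalizePostReductionBCast): starting from the consumer ID of the unresolved broadcast, walk the producers backward, replacing each ID with the producer ID it maps to; if the mapping set goes empty before reaching a dim that feeds the reduction, the broadcast is reported as unresolved.

def broadcast_resolved_along_chain(chain, start_ids):
    # chain: backward dep chain as a list of consumer->producer ID maps,
    # e.g. for T40 -> T31 -> T30 -> T29 -> T28 -> T27 -> T25 above.
    ids = set(start_ids)  # {iS107} when starting from T40 in the example
    for c2p in chain:
        ids = {c2p[i] for i in ids if i in c2p}
        if not ids:
            # T27 -> T25: the squeeze removed the broadcast dimension, so
            # bS53 maps to nothing and the chain breaks here.
            return False
    return True  # survived back to a dim that feeds the reduction

Under this view, the squeeze/re-broadcast around T25/T27 is what loses the ID mapping, even though the expanded extent (8192) does end up matching the reduction extent, so the scheduler conservatively rejects the fusion.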
