diff --git a/apex/contrib/openfold_triton/layer_norm.py b/apex/contrib/openfold_triton/layer_norm.py
index 7d9c6242a..881137eca 100644
--- a/apex/contrib/openfold_triton/layer_norm.py
+++ b/apex/contrib/openfold_triton/layer_norm.py
@@ -46,7 +46,9 @@ def forward(ctx, inputs, normalized_shape, weight, bias, eps=1e-05):
         x_mean = torch.empty(M, dtype=torch.float32, device=inputs.device)
         y = torch.empty(inputs.shape, dtype=inputs.dtype, device=inputs.device)
 
-        grid = lambda kwargs: (triton.cdiv(kwargs["M"], kwargs["M_BLOCK"]),)
+        def grid(kwargs):
+            return (triton.cdiv(kwargs["M"], kwargs["M_BLOCK"]),)
+
         if inputs.is_contiguous():
             _layer_norm_forward[grid](
                 x_ptr=inputs,
@@ -96,7 +98,9 @@ def backward(ctx, d_y):
 
         # %% Separated kernels, similar to Inductor.
         # 1. dX.
-        grid = lambda kwargs: (triton.cdiv(kwargs["M"], kwargs["M_BLOCK"]),)
+        def grid(kwargs):
+            return (triton.cdiv(kwargs["M"], kwargs["M_BLOCK"]),)
+
         if inputs.is_contiguous():
             _layer_norm_backward_dx[grid](
                 dy_ptr=d_y,
@@ -134,10 +138,13 @@ def backward(ctx, d_y):
         M_BUFSIZE = _M_BUFSIZE_CACHE.get(key, triton.cdiv(M, PARTIAL_REDUCE_MIN))
         dw_partial_buf = torch.empty([N, M_BUFSIZE], dtype=torch.float32, device=d_y.device)
         db_partial_buf = torch.empty([N, M_BUFSIZE], dtype=torch.float32, device=d_y.device)
-        grid = lambda kwargs: (
-            triton.cdiv(M, kwargs["M_PARTIAL_REDUCE"]),
-            triton.cdiv(N, kwargs["N_BLOCK"]),
-        )
+
+        def grid(kwargs):
+            return (
+                triton.cdiv(M, kwargs["M_PARTIAL_REDUCE"]),
+                triton.cdiv(N, kwargs["N_BLOCK"]),
+            )
+
         if inputs.is_contiguous():
             _layer_norm_backward_dw_db_partial[grid](
                 dy_ptr=d_y,
diff --git a/apex/contrib/openfold_triton/mha.py b/apex/contrib/openfold_triton/mha.py
index 9065b6ca8..e19df31e9 100644
--- a/apex/contrib/openfold_triton/mha.py
+++ b/apex/contrib/openfold_triton/mha.py
@@ -158,7 +158,10 @@ def forward(ctx, q, k, v, mask=None, bias=None, inf=1000000000.0, is_training=Tr
         o = torch.empty_like(q)
 
         Z, H, N_CTX, H_DIM = q.shape
-        grid = lambda META: (triton.cdiv(N_CTX, META["BLOCK_M"]), Z * H)
+
+        def grid(META):
+            return (triton.cdiv(N_CTX, META["BLOCK_M"]), Z * H)
+
         l = torch.empty(
             (q.shape[-4], q.shape[-3], q.shape[-2]),
             device=q.device,
@@ -305,11 +308,14 @@ def backward(ctx, do):
 
         # BLOCK_M, BLOCK_N = 128, 64
         BLOCK_M, BLOCK_N, num_warps, num_stages = schedule_triton_mha(list(q.shape), fwd=False)
+
         # grid = lambda META: (triton.cdiv(N_CTX, META["BLOCK_N"]), Z * H)
         # grid = lambda META: (Z * H, triton.cdiv(N_CTX, META["BLOCK_N"]))
         # grid = lambda META: (triton.cdiv(N_CTX, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
         #                      Z * H)
-        grid = lambda META: (Z * H,)
+        def grid(META):
+            return (Z * H,)
+
         _bwd_kernel[grid](
             q,
             k,
diff --git a/examples/imagenet/main_amp.py b/examples/imagenet/main_amp.py
index 384e5bda9..c12591281 100644
--- a/examples/imagenet/main_amp.py
+++ b/examples/imagenet/main_amp.py
@@ -205,7 +205,8 @@ def resume():
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
         val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
 
-    collate_fn = lambda b: fast_collate(b, memory_format)
+    def collate_fn(b):
+        return fast_collate(b, memory_format)
 
     train_loader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
diff --git a/pyproject.toml b/pyproject.toml
index 952f4df66..c041d414b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,6 @@ build-backend = "setuptools.build_meta"
 line-length = 100
 ignore = [
     # Sorted by occurrence count (ascending) - easier to fix first
-    "E731",  # lambda assignment (6 occurrences)
     "E721",  # type comparison should use isinstance (8 occurrences)
     "E741",  # ambiguous variable name (8 occurrences)
     "E712",  # comparison to True/False (9 occurrences)