
[BUG] 23_ampere_gemm_operand_reduction_fusion wrong results when LayoutOutput=RowMajor #674

Closed
danthe3rd opened this issue Oct 24, 2022 · 3 comments
Labels: bug (Something isn't working)

Comments

danthe3rd (Contributor) commented Oct 24, 2022

Repro
(1) Replace this line in example 23:

using LayoutOutput = cutlass::layout::ColumnMajor;

with:

using LayoutOutput = cutlass::layout::RowMajor;

(2) Test with

$ make 23_ampere_gemm_operand_reduction_fusion && ./examples/23_ampere_gemm_operand_reduction_fusion/23_ampere_gemm_operand_reduction_fusion --perf-check --ref-check --m=2736 --n=1536 --k=9456
[100%] Building CUDA object examples/23_ampere_gemm_operand_reduction_fusion/CMakeFiles/23_ampere_gemm_operand_reduction_fusion.dir/ampere_gemm_operand_reduction_fusion.cu.o
[100%] Linking CUDA executable 23_ampere_gemm_operand_reduction_fusion
[100%] Built target 23_ampere_gemm_operand_reduction_fusion
ERROR - results miscompared.
ID,M,N,K,SplitK-Slices,Parallel-SplitK,Runtime
gemm_1,2736,1536,9456,1,0,0.371661

More info
I think there are some typos in the templates. For example, here ReduceKForA_ is passed as the PartitionsK argument:

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaWithReductionTensorOp<
    WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
    ElementC, LayoutC, Operator, ReduceKForA_, WarpCount::kK>::Type;

template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename WarpShape_,
    /// Shape of one matrix production operation (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A elements
    typename ElementA,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA,
    /// Data type of B elements
    typename ElementB,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB,
    /// Element type of C matrix
    typename ElementC,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC,
    /// Operator describing the tensor operation
    typename Operator_ = arch::OpMultiplyAdd,
    /// Number of partitions along K dimension
    int PartitionsK = 1,
    /// Store the accumulators in row major or column major. Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor = false>
struct DefaultMmaWithReductionTensorOp {
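
To make the suspected mix-up concrete, here is a minimal standalone sketch (not CUTLASS code; all names below are placeholders) showing how a bool flag and an int count passed in swapped positions can be absorbed silently by a trailing parameter list shaped like the one above:

```cpp
// Minimal sketch, not CUTLASS code: the trailing parameters mirror the shape of
// the declaration above (Operator_, int PartitionsK, bool AccumulatorsInRowMajor),
// and the instantiation passes arguments in the order used by the quoted code
// (Operator, ReduceKForA_, WarpCount::kK).

struct OpMultiplyAdd {};  // stand-in for the Operator type

template <typename Operator_, int PartitionsK = 1, bool AccumulatorsInRowMajor = false>
struct DefaultMmaSketch {
  static constexpr int  kPartitionsK            = PartitionsK;
  static constexpr bool kAccumulatorsInRowMajor = AccumulatorsInRowMajor;
};

constexpr bool kReduceKForA = true;  // intended as the "which operand to reduce" flag
constexpr int  kWarpCountK  = 1;     // intended as the K-partition count

// Same argument order as the instantiation quoted above:
using Mma = DefaultMmaSketch<OpMultiplyAdd, kReduceKForA, kWarpCountK>;

// For these values both conversions are legal, so the flag is consumed as the
// K-partition count and the warp count as the accumulator-layout flag, with no
// compile error:
static_assert(Mma::kPartitionsK == 1, "bool flag was absorbed as PartitionsK");
static_assert(Mma::kAccumulatorsInRowMajor, "int count was absorbed as AccumulatorsInRowMajor");
```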

From my runs, it looks like it is actually reducing operand B instead of A.

Setup
I'm on commit 4db6a61

danthe3rd added the "? - Needs Triage" and "bug" labels on Oct 24, 2022
danthe3rd pushed a commit to facebookresearch/xformers that referenced this issue Oct 28, 2022
**NOTE**
We can improve a bit more once this is fixed - NVIDIA/cutlass#674

**USAGE**

```python
import xformers.ops as xops

# NOTE: Important to use `unbind` from xformers for the bw pass!
w1, w2 = xops.unbind(
    w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]),
    dim=0,
)
b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0)
y = xops.functional_swiglu(x,
    w1, b1, w2, b2, w3, b3,
    op=xops.SwiGLUPackedFusedOp)
```

**PERFORMANCE (A100 only)**

*FW*
```
[-------------------------------------------------------- swiglu_fw ---------------------------------------------------------]
                                     |  SwiGLUPackedFusedOp[fused.p.cpp]  |  eager   |  SwiGLUFusedOp[fused]
1 threads: -------------------------------------------------------------------------------------------------------------------
      f16    B=9456, I=1536, H=4096  |               1377.7               |  1581.4  |         1339.1
      f16.ac B=9456, I=1536, H=4096  |               1449.3               |  1735.3  |         1462.9
      f16    B=4440, I=1536, H=4096  |                600.4               |   735.6  |          593.9
      f16.ac B=4440, I=1536, H=4096  |                709.0               |   843.7  |          717.6
      f16    B=4728, I=1536, H=4096  |                638.9               |   776.2  |          635.3
      f16.ac B=4728, I=1536, H=4096  |                748.9               |   892.2  |          756.7
      f16    B=4728, I=1536, H=1024  |                162.3               |   201.5  |          163.1
      f16.ac B=4728, I=1536, H=1024  |                235.2               |   277.4  |          245.5

Times are in microseconds (us).
```

*BW*
```
[-------------------------------------------------------- swiglu_bw ---------------------------------------------------------]
                                     |  SwiGLUPackedFusedOp[fused.p.cpp]  |  eager   |  SwiGLUFusedOp[fused]
1 threads: -------------------------------------------------------------------------------------------------------------------
      f16    B=9456, I=1536, H=4096  |               2333.1               |  2696.7  |         2336.1
      f16.ac B=9456, I=1536, H=4096  |               2620.8               |  2990.9  |         2840.0
      f16    B=4440, I=1536, H=4096  |               1243.2               |  1413.8  |         1240.3
      f16.ac B=4440, I=1536, H=4096  |               1448.6               |  1629.0  |         1637.3
      f16    B=4728, I=1536, H=4096  |               1298.4               |  1481.5  |         1301.1
      f16.ac B=4728, I=1536, H=4096  |               1511.8               |  1705.3  |         1705.4
      f16    B=4728, I=1536, H=1024  |                463.3               |   493.9  |          463.0
      f16.ac B=4728, I=1536, H=1024  |                582.4               |   614.9  |          672.7

Times are in microseconds (us).
```

[ghstack-poisoned]

danthe3rd pushed additional commits to facebookresearch/xformers that referenced this issue on Oct 28 and Oct 31, 2022 (same commit message as above).

hwu36 (Collaborator) commented Oct 31, 2022

please check #682

danthe3rd pushed a commit to facebookresearch/xformers that referenced this issue Nov 3, 2022
**NOTE**
We can improve a bit more once this is fixed - NVIDIA/cutlass#674

**USAGE**

```python
import xformers.ops as xops

# NOTE: Important to use `unbind` from xformers for the bw pass!
w1, w2 = xops.unbind(
    w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]),
    dim=0,
)
b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0)
y = xops.functional_swiglu(x,
    w1, b1, w2, b2, w3, b3)
```

**PERFORMANCE (A100 only)**

*FW*
```
[-------------------------------------------------------- swiglu_fw ---------------------------------------------------------]
                                     |  SwiGLUPackedFusedOp[fused.p.cpp]  |  eager   |  SwiGLUFusedOp[fused]
1 threads: -------------------------------------------------------------------------------------------------------------------
      f16    B=9456, I=1536, H=4096  |               1377.7               |  1581.4  |         1339.1
      f16.ac B=9456, I=1536, H=4096  |               1449.3               |  1735.3  |         1462.9
      f16    B=4440, I=1536, H=4096  |                600.4               |   735.6  |          593.9
      f16.ac B=4440, I=1536, H=4096  |                709.0               |   843.7  |          717.6
      f16    B=4728, I=1536, H=4096  |                638.9               |   776.2  |          635.3
      f16.ac B=4728, I=1536, H=4096  |                748.9               |   892.2  |          756.7
      f16    B=4728, I=1536, H=1024  |                162.3               |   201.5  |          163.1
      f16.ac B=4728, I=1536, H=1024  |                235.2               |   277.4  |          245.5

Times are in microseconds (us).
```

*BW*
```
[-------------------------------------------------------- swiglu_bw ---------------------------------------------------------]
                                     |  SwiGLUPackedFusedOp[fused.p.cpp]  |  eager   |  SwiGLUFusedOp[fused]
1 threads: -------------------------------------------------------------------------------------------------------------------
      f16    B=9456, I=1536, H=4096  |               2333.1               |  2696.7  |         2336.1
      f16.ac B=9456, I=1536, H=4096  |               2620.8               |  2990.9  |         2840.0
      f16    B=4440, I=1536, H=4096  |               1243.2               |  1413.8  |         1240.3
      f16.ac B=4440, I=1536, H=4096  |               1448.6               |  1629.0  |         1637.3
      f16    B=4728, I=1536, H=4096  |               1298.4               |  1481.5  |         1301.1
      f16.ac B=4728, I=1536, H=4096  |               1511.8               |  1705.3  |         1705.4
      f16    B=4728, I=1536, H=1024  |                463.3               |   493.9  |          463.0
      f16.ac B=4728, I=1536, H=1024  |                582.4               |   614.9  |          672.7

Times are in microseconds (us).
```

[ghstack-poisoned]
danthe3rd pushed a commit to facebookresearch/xformers that referenced this issue Nov 3, 2022
**NOTE**
We can improve a bit more once this is fixed - NVIDIA/cutlass#674

**USAGE**

```python
import xformers.ops as xops

# NOTE: Important to use `unbind` from xformers for the bw pass!
w1, w2 = xops.unbind(
    w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]),
    dim=0,
)
b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0)
y = xops.functional_swiglu(x,
    w1, b1, w2, b2, w3, b3)
```

**PERFORMANCE (A100 only)**

*FW*
```
[-------------------------------------------------------- swiglu_fw ---------------------------------------------------------]
                                     |  SwiGLUPackedFusedOp[fused.p.cpp]  |  eager   |  SwiGLUFusedOp[fused]
1 threads: -------------------------------------------------------------------------------------------------------------------
      f16    B=9456, I=1536, H=4096  |               1377.7               |  1581.4  |         1339.1
      f16.ac B=9456, I=1536, H=4096  |               1449.3               |  1735.3  |         1462.9
      f16    B=4440, I=1536, H=4096  |                600.4               |   735.6  |          593.9
      f16.ac B=4440, I=1536, H=4096  |                709.0               |   843.7  |          717.6
      f16    B=4728, I=1536, H=4096  |                638.9               |   776.2  |          635.3
      f16.ac B=4728, I=1536, H=4096  |                748.9               |   892.2  |          756.7
      f16    B=4728, I=1536, H=1024  |                162.3               |   201.5  |          163.1
      f16.ac B=4728, I=1536, H=1024  |                235.2               |   277.4  |          245.5

Times are in microseconds (us).
```

*BW*
```
[-------------------------------------------------------- swiglu_bw ---------------------------------------------------------]
                                     |  SwiGLUPackedFusedOp[fused.p.cpp]  |  eager   |  SwiGLUFusedOp[fused]
1 threads: -------------------------------------------------------------------------------------------------------------------
      f16    B=9456, I=1536, H=4096  |               2333.1               |  2696.7  |         2336.1
      f16.ac B=9456, I=1536, H=4096  |               2620.8               |  2990.9  |         2840.0
      f16    B=4440, I=1536, H=4096  |               1243.2               |  1413.8  |         1240.3
      f16.ac B=4440, I=1536, H=4096  |               1448.6               |  1629.0  |         1637.3
      f16    B=4728, I=1536, H=4096  |               1298.4               |  1481.5  |         1301.1
      f16.ac B=4728, I=1536, H=4096  |               1511.8               |  1705.3  |         1705.4
      f16    B=4728, I=1536, H=1024  |                463.3               |   493.9  |          463.0
      f16.ac B=4728, I=1536, H=1024  |                582.4               |   614.9  |          672.7

Times are in microseconds (us).
```

[ghstack-poisoned]
danthe3rd pushed a commit to facebookresearch/xformers that referenced this issue Nov 4, 2022
**NOTE**
We can improve a bit more once this is fixed - NVIDIA/cutlass#674

**USAGE**

```python
import xformers.ops as xops

# NOTE: Important to use `unbind` from xformers for the bw pass!
w1, w2 = xops.unbind(
    w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]),
    dim=0,
)
b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0)
y = xops.functional_swiglu(x,
    w1, b1, w2, b2, w3, b3)
```

**PERFORMANCE (A100 only)**

*FW*
```
[-------------------------------------------------------- swiglu_fw ---------------------------------------------------------]
                                     |  SwiGLUPackedFusedOp[fused.p.cpp]  |  eager   |  SwiGLUFusedOp[fused]
1 threads: -------------------------------------------------------------------------------------------------------------------
      f16    B=9456, I=1536, H=4096  |               1377.7               |  1581.4  |         1339.1
      f16.ac B=9456, I=1536, H=4096  |               1449.3               |  1735.3  |         1462.9
      f16    B=4440, I=1536, H=4096  |                600.4               |   735.6  |          593.9
      f16.ac B=4440, I=1536, H=4096  |                709.0               |   843.7  |          717.6
      f16    B=4728, I=1536, H=4096  |                638.9               |   776.2  |          635.3
      f16.ac B=4728, I=1536, H=4096  |                748.9               |   892.2  |          756.7
      f16    B=4728, I=1536, H=1024  |                162.3               |   201.5  |          163.1
      f16.ac B=4728, I=1536, H=1024  |                235.2               |   277.4  |          245.5

Times are in microseconds (us).
```

*BW*
```
[-------------------------------------------------------- swiglu_bw ---------------------------------------------------------]
                                     |  SwiGLUPackedFusedOp[fused.p.cpp]  |  eager   |  SwiGLUFusedOp[fused]
1 threads: -------------------------------------------------------------------------------------------------------------------
      f16    B=9456, I=1536, H=4096  |               2333.1               |  2696.7  |         2336.1
      f16.ac B=9456, I=1536, H=4096  |               2620.8               |  2990.9  |         2840.0
      f16    B=4440, I=1536, H=4096  |               1243.2               |  1413.8  |         1240.3
      f16.ac B=4440, I=1536, H=4096  |               1448.6               |  1629.0  |         1637.3
      f16    B=4728, I=1536, H=4096  |               1298.4               |  1481.5  |         1301.1
      f16.ac B=4728, I=1536, H=4096  |               1511.8               |  1705.3  |         1705.4
      f16    B=4728, I=1536, H=1024  |                463.3               |   493.9  |          463.0
      f16.ac B=4728, I=1536, H=1024  |                582.4               |   614.9  |          672.7

Times are in microseconds (us).
```

[ghstack-poisoned]
danthe3rd pushed a commit to facebookresearch/xformers that referenced this issue Nov 4, 2022
**NOTE**
We can improve a bit more once this is fixed - NVIDIA/cutlass#674

**USAGE**

```python
import xformers.ops as xops

# NOTE: Important to use `unbind` from xformers for the bw pass!
w1, w2 = xops.unbind(
    w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]),
    dim=0,
)
b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0)
y = xops.functional_swiglu(x,
    w1, b1, w2, b2, w3, b3)
```

**PERFORMANCE (A100 only)**

*FW*
```
[-------------------------------------------------------- swiglu_fw ---------------------------------------------------------]
                                     |  SwiGLUPackedFusedOp[fused.p.cpp]  |  eager   |  SwiGLUFusedOp[fused]
1 threads: -------------------------------------------------------------------------------------------------------------------
      f16    B=9456, I=1536, H=4096  |               1377.7               |  1581.4  |         1339.1
      f16.ac B=9456, I=1536, H=4096  |               1449.3               |  1735.3  |         1462.9
      f16    B=4440, I=1536, H=4096  |                600.4               |   735.6  |          593.9
      f16.ac B=4440, I=1536, H=4096  |                709.0               |   843.7  |          717.6
      f16    B=4728, I=1536, H=4096  |                638.9               |   776.2  |          635.3
      f16.ac B=4728, I=1536, H=4096  |                748.9               |   892.2  |          756.7
      f16    B=4728, I=1536, H=1024  |                162.3               |   201.5  |          163.1
      f16.ac B=4728, I=1536, H=1024  |                235.2               |   277.4  |          245.5

Times are in microseconds (us).
```

*BW*
```
[-------------------------------------------------------- swiglu_bw ---------------------------------------------------------]
                                     |  SwiGLUPackedFusedOp[fused.p.cpp]  |  eager   |  SwiGLUFusedOp[fused]
1 threads: -------------------------------------------------------------------------------------------------------------------
      f16    B=9456, I=1536, H=4096  |               2333.1               |  2696.7  |         2336.1
      f16.ac B=9456, I=1536, H=4096  |               2620.8               |  2990.9  |         2840.0
      f16    B=4440, I=1536, H=4096  |               1243.2               |  1413.8  |         1240.3
      f16.ac B=4440, I=1536, H=4096  |               1448.6               |  1629.0  |         1637.3
      f16    B=4728, I=1536, H=4096  |               1298.4               |  1481.5  |         1301.1
      f16.ac B=4728, I=1536, H=4096  |               1511.8               |  1705.3  |         1705.4
      f16    B=4728, I=1536, H=1024  |                463.3               |   493.9  |          463.0
      f16.ac B=4728, I=1536, H=1024  |                582.4               |   614.9  |          672.7

Times are in microseconds (us).
```

[ghstack-poisoned]
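The commit message above stresses using `unbind` from xformers when the two weight matrices are stored in one packed parameter. As a hedged illustration (this is not the xformers implementation, and the class name and shapes below are made up), the sketch shows the basic mechanism such an op relies on: the forward hands out views of the packed parameter with no copy, and the backward re-packs the two incoming gradients into a single tensor so the packed parameter receives one `.grad` of the packed shape. Presumably the real implementation goes further and skips the `torch.stack` copy when the incoming gradients already share one contiguous buffer.

```python
# Hypothetical sketch of a packed-weight unbind (NOT the actual xformers code).
# Forward returns views of the packed parameter (no copy); backward re-packs the
# two incoming gradients into one tensor so `w1w2.grad` has the packed shape.
import torch


class PackedUnbind(torch.autograd.Function):
    @staticmethod
    def forward(ctx, packed):          # packed: [2, out_features, in_features]
        w1, w2 = packed.unbind(dim=0)  # views into `packed`, no data copy
        return w1, w2

    @staticmethod
    def backward(ctx, grad_w1, grad_w2):
        # One gradient tensor for the single packed input. A smarter version
        # could return a view instead of copying when grad_w1/grad_w2 already
        # live contiguously in one buffer.
        return torch.stack([grad_w1, grad_w2], dim=0)


# Illustrative usage with small shapes.
w1w2 = torch.randn(2, 8, 4, requires_grad=True)
w1, w2 = PackedUnbind.apply(w1w2)
loss = w1.sum() + 2.0 * w2.sum()
loss.backward()
assert w1w2.grad.shape == w1w2.shape  # gradients arrive packed, like the weights
```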
@mnicely
Collaborator

mnicely commented Nov 6, 2022

@danthe3rd did you have any luck resolving your issues?

@danthe3rd
Contributor Author

It's working now with @hwu36's PR! Thanks a lot - closing the issue

danthe3rd pushed a commit to facebookresearch/xformers that referenced this issue Nov 7, 2022
danthe3rd pushed a commit to facebookresearch/xformers that referenced this issue Nov 10, 2022