The initial MoEBlock impl follows the same A2Av flow as maxtext: under FSDP, weights are all-gathered (AG) and then quantized before the grouped GEMM. This is wasteful because the quantization kernel on each device has to process fsdp times the amount of data it needs to, which makes it roughly fsdp times slower. The quantization needs to move so that it happens before the AG.
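A minimal sketch of the two orderings, simulated on a single host. Everything here is an assumption made for illustration, not the MoEBlock or maxtext code: the helper names `quantize_int8_per_row` and `all_gather_sim`, the symmetric int8 scheme with one scale per row, the shapes, and sharding the weight along the row axis. The all-gather is faked with `jnp.concatenate`; a real version would use a collective such as `jax.lax.all_gather` inside a sharded computation.

```python
import jax.numpy as jnp
import numpy as np


def quantize_int8_per_row(w):
    """Hypothetical symmetric int8 quantization with one scale per row."""
    amax = jnp.max(jnp.abs(w), axis=-1, keepdims=True)
    # Guard against all-zero rows so the division below is well defined.
    scale = jnp.where(amax > 0, amax / 127.0, 1.0)
    q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return q, scale


def all_gather_sim(shards):
    """Stand-in for an all-gather: concatenate the per-device shards."""
    return jnp.concatenate(shards, axis=0)


fsdp = 4             # assumed FSDP degree
rows, cols = 8, 16   # assumed weight shape, sharded along rows
rng = np.random.default_rng(0)
full_w = jnp.asarray(rng.standard_normal((rows, cols)), dtype=jnp.float32)
shards = jnp.split(full_w, fsdp, axis=0)  # one shard per simulated device

# Current order: AG, then quantize. Every device quantizes the full weight.
q_after, s_after = quantize_int8_per_row(all_gather_sim(shards))
elems_quantized_after = rows * cols

# Proposed order: quantize the local shard, then AG values and scales.
local = [quantize_int8_per_row(s) for s in shards]
q_before = all_gather_sim([q for q, _ in local])
s_before = all_gather_sim([s for _, s in local])
elems_quantized_before = shards[0].size  # work done per device

print(elems_quantized_after // elems_quantized_before)  # ratio equals fsdp
```

In this sketch the two orderings produce the same quantized weight only because each scale is computed from a single row and the rows are what is sharded. If a scale spans the sharded axis (for example a per-tensor or per-column amax), quantizing before the AG would also need the amax reduced across shards first, so the scale granularity has to be checked against the actual sharding.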