- 
                Notifications
    
You must be signed in to change notification settings  - Fork 1.5k
 
Closed
Labels
Description
What is your question?
A toy example (I am a newbie, so some of my atom choices may be naive):
  using ELM = cutlass::half_t;
  using bM = decltype(Int<128>{});
  using bN = decltype(Int<128>{});
  using bK = decltype(Int<16>{});
  TiledMMA tmma =
      make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{}, Layout<Shape<_2, _2, _2>>{},
                     Tile<_32, _32, _16>{});
  auto thr_mma = tmma.get_thread_slice(0);
  auto sA = make_tensor(make_smem_ptr((ELM *)(0)), Layout<Shape<bM, bK>>{});  // Let's assume A is somehow copied to this sA
  auto sB = make_tensor(make_smem_ptr((ELM *)(0)), Layout<Shape<bN, bK>>{});  // Let's assume the same as well
  Tensor tSrA = thr_mma.partition_fragment_A(sA);
  Tensor tSrB = thr_mma.partition_fragment_B(sB);
  Tensor acc = partition_fragment_C(tmma, Shape<bM, bN>{});
  auto cp_atom = Copy_Atom<SM75_U32x4_LDSM_N, ELM>{};
  auto smem_tiled_cp_A = make_tiled_copy_A(cp_atom, tmma);
  auto smem_thr_cp_A = smem_tiled_cp_A.get_thread_slice(0);
  Tensor tSsA = smem_thr_cp_A.partition_S(sA);
  auto smem_tiled_cp_B = make_tiled_copy_B(cp_atom, tmma);
  auto smem_thr_cp_B = smem_tiled_cp_B.get_thread_slice(0);
  Tensor tSsB = smem_thr_cp_B.partition_S(sB);
  Tensor tSrA_copy_view = smem_thr_cp_A.retile_D(tSrA);
  Tensor tSrB_copy_view = smem_thr_cp_A.retile_D(tSrB);
  printf("\n");
  cute::print(layout<>(tSrA));
  printf("\n");
  cute::print(layout<>(tSsA));
  printf("\n");
  cute::print(layout<>(tSrA_copy_view));
  printf("\n");
and stdout gives me this:
(_4,_8,_2):(_1,_4,_32)
(((_2,_4),_2),_4,_1):(((_1,_128),_1024),_32,_0)
((_8,_2),_4,_1):((_1,_32),_8,_0)
Many examples then run a pipeline that iterates over the K-mode like this:
  cute::copy(smem_tiled_cp_A, tSsA(_, _, _0{}), tSrA_copy_view(_, _, _0{}));
  cute::copy(smem_tiled_cp_B, tSsB(_, _, _0{}), tSrB_copy_view(_, _, _0{}));
  for (int i = 0; i < size<2>(tSrA); ++i) {
    if (i < size<2>(tSrA) - 1) {  // prefetch
      cute::copy(smem_tiled_copy_A, tSsA(_, _, i + 1), tSrA_copy_view(_, _, i + 1));
      cute::copy(smem_tiled_copy_B, tSsB(_, _, i + 1), tSrB_copy_view(_, _, i + 1));
    }
    cute::gemm(tmma, tSrA(_, _, i), tSrB(_, _, i), acc);
  }
My question: the K-mode of tSsA and tSrA_copy_view has size 1, but the K-mode of tSrA has size 2. It seems a single smem-to-register copy is enough for two gemm calls in this case. So wouldn't tSsA(_, _, i + 1) and tSrA_copy_view(_, _, i + 1) go out of bounds when i == 0?
I hope someone can guide me through this. Thanks!