Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[inductor] share more cse cache during swap buffer (pytorch#124921)
`swap_buffer` will make the `cse_cache` cannot be shared inside/outside of the lambda function scope. For example, ``` auto tmp8 = -std::numeric_limits<float>::infinity(); auto tmp9 = [&] { auto tmp12 = -std::numeric_limits<float>::infinity(); return tmp12; } ``` `tmp12` should not be created since it is same with `tmp8`. We make the `cse_cache` as a read only cache inside the scope (because it is unsafe to expose cache inside the scope,the outside scope cannot use it.) **Test Plan** ``` python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu ``` the `static_cast<int>(256)` will only occur once after this PR since the inside scope can share the cse buffer outside the scope. Before this PR, ``` cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], ''' #include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h" extern "C" void kernel(const float* in_ptr0, float* out_ptr1) { #pragma omp parallel num_threads(128) { int tid = omp_get_thread_num(); { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L)) { for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L)) { auto tmp0 = c10::convert<int>(x1); auto tmp1 = static_cast<int>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int>(x3); auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1); auto tmp6 = static_cast<int>(257); auto tmp7 = at::vec::Vectorized<int>(tmp6); auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7); auto tmp10 = at::vec::VecMask<float,1>::from(tmp2); auto tmp11 = tmp8 & tmp10; auto tmp9 = [&] { auto tmp12 = -std::numeric_limits<float>::infinity(); return tmp12; } ; auto tmp13 = [&] { if (tmp11.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>()); } } () ; auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp15 = static_cast<int>(3); auto tmp16 = tmp14 < tmp15; auto tmp18 = tmp16 & tmp2; auto tmp17 = [&] { auto tmp19 = c10::convert<int>(x3); auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1); auto tmp21 = static_cast<int>(256); auto tmp22 = at::vec::Vectorized<int>(tmp21); auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22); auto tmp25 = at::vec::VecMask<float,1>::from(tmp18); auto tmp26 = tmp23 & tmp25; auto tmp24 = [&] { auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp27; } ; auto tmp28 = [&] { if (tmp26.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>()); } } () ; auto tmp29 = static_cast<float>(0.0); auto tmp30 = at::vec::Vectorized<float>(tmp29); auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>()); return tmp31; } ; auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp33 = static_cast<float>(0.0); auto tmp34 = at::vec::VecMask<float,1>::from(tmp16); auto tmp35 = at::vec::Vectorized<float>(tmp33); auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>()); auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>()); return tmp37; } ; auto tmp38 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp40 = static_cast<int>(3); auto tmp41 = tmp39 < tmp40; auto tmp42 = [&] { auto tmp43 = c10::convert<int>(x3); auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1); auto tmp45 = static_cast<int>(256); auto tmp46 = at::vec::Vectorized<int>(tmp45); auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46); auto tmp49 = at::vec::VecMask<float,1>::from(tmp41); auto tmp50 = tmp47 & tmp49; auto tmp48 = [&] { auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp51; } ; auto tmp52 = [&] { if (tmp50.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>()); } } () ; auto tmp53 = static_cast<float>(0.0); auto tmp54 = at::vec::Vectorized<float>(tmp53); auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>()); return tmp55; } ; auto tmp56 = tmp41 ? tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp57 = static_cast<float>(0.0); auto tmp58 = at::vec::VecMask<float,1>::from(tmp41); auto tmp59 = at::vec::Vectorized<float>(tmp57); auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>()); auto tmp61 = at::vec::VecMask<float,1>::from(tmp2); auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>()); tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))); } #pragma omp simd simdlen(8) for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L)) { auto tmp0 = c10::convert<int64_t>(x1); auto tmp1 = static_cast<int64_t>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int64_t>(x3); auto tmp5 = static_cast<int64_t>(257); auto tmp6 = tmp4 < tmp5; auto tmp7 = [&] { auto tmp8 = -std::numeric_limits<float>::infinity(); return tmp8; } ; auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0); auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp11 = static_cast<int64_t>(3); auto tmp12 = tmp10 < tmp11; auto tmp13 = [&] { auto tmp14 = c10::convert<int64_t>(x3); auto tmp15 = static_cast<int64_t>(256); auto tmp16 = tmp14 >= tmp15; auto tmp17 = [&] { auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp18; } ; auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0); auto tmp20 = static_cast<float>(0.0); auto tmp21 = tmp16 ? tmp19 : tmp20; return tmp21; } ; auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0); auto tmp23 = static_cast<float>(0.0); auto tmp24 = tmp12 ? tmp22 : tmp23; auto tmp25 = tmp6 ? tmp9 : tmp24; return tmp25; } ; auto tmp26 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0); auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp28 = static_cast<int64_t>(3); auto tmp29 = tmp27 < tmp28; auto tmp30 = [&] { auto tmp31 = c10::convert<int64_t>(x3); auto tmp32 = static_cast<int64_t>(256); auto tmp33 = tmp31 >= tmp32; auto tmp34 = [&] { auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp35; } ; auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0); auto tmp37 = static_cast<float>(0.0); auto tmp38 = tmp33 ? tmp36 : tmp37; return tmp38; } ; auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0); auto tmp40 = static_cast<float>(0.0); auto tmp41 = tmp29 ? tmp39 : tmp40; auto tmp42 = tmp2 ? tmp26 : tmp41; out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42; } } } } } } } ''') ``` After this PR, ``` cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], ''' #include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h" extern "C" void kernel(const float* in_ptr0, float* out_ptr1) { #pragma omp parallel num_threads(128) { int tid = omp_get_thread_num(); { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L)) { for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L)) { auto tmp0 = c10::convert<int>(x1); auto tmp1 = static_cast<int>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int>(x3); auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1); auto tmp6 = static_cast<int>(257); auto tmp7 = at::vec::Vectorized<int>(tmp6); auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7); auto tmp10 = at::vec::VecMask<float,1>::from(tmp2); auto tmp11 = tmp8 & tmp10; auto tmp9 = [&] { auto tmp12 = -std::numeric_limits<float>::infinity(); return tmp12; } ; auto tmp13 = [&] { if (tmp11.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>()); } } () ; auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp15 = static_cast<int>(3); auto tmp16 = tmp14 < tmp15; auto tmp18 = tmp16 & tmp2; auto tmp17 = [&] { auto tmp19 = at::vec::Vectorized<int>(tmp1); auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19); auto tmp22 = at::vec::VecMask<float,1>::from(tmp18); auto tmp23 = tmp20 & tmp22; auto tmp21 = [&] { auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp24; } ; auto tmp25 = [&] { if (tmp23.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>()); } } () ; auto tmp26 = static_cast<float>(0.0); auto tmp27 = at::vec::Vectorized<float>(tmp26); auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>()); return tmp28; } ; auto tmp29 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp30 = static_cast<float>(0.0); auto tmp31 = at::vec::VecMask<float,1>::from(tmp16); auto tmp32 = at::vec::Vectorized<float>(tmp30); auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>()); auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>()); return tmp34; } ; auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp37 = static_cast<int>(3); auto tmp38 = tmp36 < tmp37; auto tmp39 = [&] { auto tmp40 = c10::convert<int>(x3); auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1); auto tmp42 = at::vec::Vectorized<int>(tmp1); auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42); auto tmp45 = at::vec::VecMask<float,1>::from(tmp38); auto tmp46 = tmp43 & tmp45; auto tmp44 = [&] { auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp47; } ; auto tmp48 = [&] { if (tmp46.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>()); } } () ; auto tmp49 = static_cast<float>(0.0); auto tmp50 = at::vec::Vectorized<float>(tmp49); auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>()); return tmp51; } ; auto tmp52 = tmp38 ? tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp53 = static_cast<float>(0.0); auto tmp54 = at::vec::VecMask<float,1>::from(tmp38); auto tmp55 = at::vec::Vectorized<float>(tmp53); auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>()); auto tmp57 = at::vec::VecMask<float,1>::from(tmp2); auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>()); tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))); } #pragma omp simd simdlen(8) for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L)) { auto tmp0 = c10::convert<int64_t>(x1); auto tmp1 = static_cast<int64_t>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int64_t>(x3); auto tmp5 = static_cast<int64_t>(257); auto tmp6 = tmp4 < tmp5; auto tmp7 = [&] { auto tmp8 = -std::numeric_limits<float>::infinity(); return tmp8; } ; auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0); auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp11 = static_cast<int64_t>(3); auto tmp12 = tmp10 < tmp11; auto tmp13 = [&] { auto tmp14 = tmp4 >= tmp1; auto tmp15 = [&] { auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp16; } ; auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0); auto tmp18 = static_cast<float>(0.0); auto tmp19 = tmp14 ? tmp17 : tmp18; return tmp19; } ; auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0); auto tmp21 = static_cast<float>(0.0); auto tmp22 = tmp12 ? tmp20 : tmp21; auto tmp23 = tmp6 ? tmp9 : tmp22; return tmp23; } ; auto tmp24 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0); auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp26 = static_cast<int64_t>(3); auto tmp27 = tmp25 < tmp26; auto tmp28 = [&] { auto tmp29 = c10::convert<int64_t>(x3); auto tmp30 = tmp29 >= tmp1; auto tmp31 = [&] { auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp32; } ; auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0); auto tmp34 = static_cast<float>(0.0); auto tmp35 = tmp30 ? tmp33 : tmp34; return tmp35; } ; auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0); auto tmp37 = static_cast<float>(0.0); auto tmp38 = tmp27 ? tmp36 : tmp37; auto tmp39 = tmp2 ? tmp24 : tmp38; out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39; } } } } } } } ''') ``` Pull Request resolved: pytorch#124921 Approved by: https://github.com/jgong5, https://github.com/jansel ghstack dependencies: pytorch#124597
- Loading branch information