
Speedup float cosine vectors, use FMA where fast and available to reduce error #12731

Merged — 3 commits merged into apache:main on Oct 30, 2023

Conversation

@rmuir (Member) commented Oct 28, 2023

The Intel FMA is nice, and it's easier to reason about when looking at the assembly. We basically reduce the error for free where it's available. Along with another change (reducing the unrolling for cosine, since it already has 3 FMA ops), we can speed up cosine from ~6 to ~8 ops/us.

On ARM the FMA leads to slight slowdowns, so we don't use it there. It's not much, just something like 10%, but it seems like the wrong tradeoff.

If you run the code with -XX:-UseFMA there's no slowdown, but no speedup either. And obviously, no changes for ARM here.
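The effect described above can be sketched in scalar Java (a hedged illustration only, not the PR's actual Panama-vector code): `Math.fma(a, b, c)` computes a*b + c with a single rounding, which is where the error reduction comes from, and a handful of independent accumulators stands in for the reduced unrolling:

```java
// Hypothetical scalar analogue of an FMA dot product with 4 independent
// accumulators; the names here are illustrative, not from the Lucene patch.
public class FmaDotSketch {
  static float dot(float[] a, float[] b) {
    float acc0 = 0f, acc1 = 0f, acc2 = 0f, acc3 = 0f;
    int i = 0;
    int upper = a.length & ~3; // round length down to a multiple of 4
    for (; i < upper; i += 4) {
      // Math.fma fuses multiply and add with a single rounding step;
      // on x86 with FMA support this maps to a vfmadd instruction.
      acc0 = Math.fma(a[i], b[i], acc0);
      acc1 = Math.fma(a[i + 1], b[i + 1], acc1);
      acc2 = Math.fma(a[i + 2], b[i + 2], acc2);
      acc3 = Math.fma(a[i + 3], b[i + 3], acc3);
    }
    for (; i < a.length; i++) { // scalar tail
      acc0 = Math.fma(a[i], b[i], acc0);
    }
    return acc0 + acc1 + acc2 + acc3;
  }

  public static void main(String[] args) {
    float[] a = {1, 2, 3, 4, 5, 6, 7, 8};
    float[] b = {2, 2, 2, 2, 2, 2, 2, 2};
    System.out.println(dot(a, b)); // 72.0
  }
}
```

One caveat, which is why the patch gates on "fast and available": on hardware without an FMA instruction, Math.fma falls back to a slow software path, so using it unconditionally would be a large regression there.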

Skylake AVX-256

Main:
Benchmark                                  (size)   Mode  Cnt   Score   Error   Units
VectorUtilBenchmark.floatCosineScalar        1024  thrpt    5   0.624 ± 0.041  ops/us
VectorUtilBenchmark.floatCosineVector        1024  thrpt    5   5.988 ± 0.111  ops/us
VectorUtilBenchmark.floatDotProductScalar    1024  thrpt    5   1.959 ± 0.032  ops/us
VectorUtilBenchmark.floatDotProductVector    1024  thrpt    5  12.058 ± 0.920  ops/us
VectorUtilBenchmark.floatSquareScalar        1024  thrpt    5   1.422 ± 0.018  ops/us
VectorUtilBenchmark.floatSquareVector        1024  thrpt    5   9.837 ± 0.154  ops/us

Patch:
Benchmark                                  (size)   Mode  Cnt   Score   Error   Units
VectorUtilBenchmark.floatCosineScalar        1024  thrpt    5   0.638 ± 0.006  ops/us
VectorUtilBenchmark.floatCosineVector        1024  thrpt    5   8.164 ± 0.084  ops/us
VectorUtilBenchmark.floatDotProductScalar    1024  thrpt    5   1.997 ± 0.027  ops/us
VectorUtilBenchmark.floatDotProductVector    1024  thrpt    5  12.486 ± 0.163  ops/us
VectorUtilBenchmark.floatSquareScalar        1024  thrpt    5   1.445 ± 0.014  ops/us
VectorUtilBenchmark.floatSquareVector        1024  thrpt    5  11.682 ± 0.129  ops/us

Patch (with -jvmArgsAppend '-XX:-UseFMA'):
Benchmark                                  (size)   Mode  Cnt   Score   Error   Units
VectorUtilBenchmark.floatCosineScalar        1024  thrpt    5   0.641 ± 0.005  ops/us
VectorUtilBenchmark.floatCosineVector        1024  thrpt    5   6.102 ± 0.053  ops/us
VectorUtilBenchmark.floatDotProductScalar    1024  thrpt    5   1.997 ± 0.007  ops/us
VectorUtilBenchmark.floatDotProductVector    1024  thrpt    5  12.177 ± 0.170  ops/us
VectorUtilBenchmark.floatSquareScalar        1024  thrpt    5   1.450 ± 0.027  ops/us
VectorUtilBenchmark.floatSquareVector        1024  thrpt    5  10.464 ± 0.154  ops/us

@uschindler (Contributor) left a comment:

Small changes to logging requested. Possibly also change the comment about the module system, also in the RamUsageEstimator class where this code was borrowed from.

final Class<?> beanClazz = Class.forName(HOTSPOT_BEAN_CLASS);
// we use reflection for this, because the management factory is not part
// of Java 8's compact profile:
final Object hotSpotBean =
Contributor:

Haha, I know this code from the RamUsageEstimator code parts. The comment should possibly be updated in both places to mention the module system and that the module is optional there.

This module is declared optional in our module-info:

requires static jdk.management; // this is optional but explicit declaration is recommended

So basically this code is fine; we do not want to hardcode the module (as it is not part of the JDK platform standard). Maybe we should also add "FMA enabled" to the logger message. That should be easy by making the flag package-private and referring to it from the initialization code where the log message is printed:

log.info(
    String.format(
        Locale.ENGLISH,
        "Java vector incubator API enabled; uses preferredBitSize=%d%s",
        PanamaVectorUtilSupport.VECTOR_BITSIZE,
        PanamaVectorUtilSupport.HAS_FAST_INTEGER_VECTORS
            ? ""
            : "; floating-point vectors only"));

Let's add the same with PanamaVectorUtilSupport.HAS_FAST_FMA ? "; FMA enabled" : ""

We should maybe move this code to some common class in the util package (like Constants#getVMOption(String name)). We can create a separate PR for that.
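A hedged sketch of what such a Constants#getVMOption(String) helper could look like (the class and method names are just the suggestion above, not existing Lucene API; the reflective pattern mirrors the RamUsageEstimator code quoted earlier, avoiding a hard dependency on the optional jdk.management module):

```java
import java.lang.management.ManagementFactory;
import java.lang.management.PlatformManagedObject;
import java.lang.reflect.Method;

public class VMOptions {
  private static final String HOTSPOT_BEAN_CLASS = "com.sun.management.HotSpotDiagnosticMXBean";

  /** Returns the value of a HotSpot VM option (e.g. "UseFMA"), or null if unavailable. */
  static String getVMOption(String name) {
    try {
      // look the bean class up reflectively: jdk.management is "requires static",
      // so it may be absent at runtime
      final Class<? extends PlatformManagedObject> beanClazz =
          Class.forName(HOTSPOT_BEAN_CLASS).asSubclass(PlatformManagedObject.class);
      final PlatformManagedObject hotSpotBean = ManagementFactory.getPlatformMXBean(beanClazz);
      if (hotSpotBean == null) return null;
      final Method getVMOption = beanClazz.getMethod("getVMOption", String.class);
      final Object vmOption = getVMOption.invoke(hotSpotBean, name);
      return vmOption.getClass().getMethod("getValue").invoke(vmOption).toString();
    } catch (ReflectiveOperationException | RuntimeException e) {
      return null; // non-HotSpot JVM, module absent, or unknown option
    }
  }
}
```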

Member Author:

It works with the module system at least: I tested it. If we want to move this code around I'm fine with that, as long as I have a static final constant.

Member Author:

Surely we can leave comments about the module system to another issue. It was somehow OK for RamUsageEstimator to do this, but not OK for the vectors code?

Honestly, I haven't a clue about the module system (nor a care), and no idea how it works or what 'requires static' means or any of that. To me it looks like more over-engineered Java garbage (sorry). So I'm ill-equipped to be updating comments inside RamUsageEstimator. I just want to try to improve the vectorization here.

Contributor:

OK, that's fine, thanks for confirming that it works. The "requires static" allows access to the module (if available). So it looks like the code works fine.

I think the only change I'd suggest is to add the FMA enablement to the logging message as stated above.

Contributor:

Let's add the same with PanamaVectorUtilSupport.HAS_FAST_FMA ? "; FMA enabled" : ""

++

Contributor:

I pushed the logging change.

@ChrisHegarty (Contributor):

ha! So just removing the overly aggressive unrolling in cosine improves things. The check on FMA is nice - I had similar thoughts (you just beat me to it!), and it inlines nicely. I also agree that we don't wanna use FMA on ARM; it performs 10-15% worse on my M2.

Sanity results from my Rocket Lake:

main:

VectorUtilBenchmark.floatCosineScalar        1024  thrpt    5   0.845 ± 0.001  ops/us
VectorUtilBenchmark.floatCosineVector        1024  thrpt    5   8.885 ± 0.005  ops/us
VectorUtilBenchmark.floatDotProductScalar    1024  thrpt    5   3.406 ± 0.018  ops/us
VectorUtilBenchmark.floatDotProductVector    1024  thrpt    5  26.168 ± 0.009  ops/us
VectorUtilBenchmark.floatSquareScalar        1024  thrpt    5   2.549 ± 0.005  ops/us
VectorUtilBenchmark.floatSquareVector        1024  thrpt    5  19.283 ± 0.001  ops/us

Robert's branch:

VectorUtilBenchmark.floatCosineScalar        1024  thrpt    5   0.845 ± 0.003  ops/us
VectorUtilBenchmark.floatCosineVector        1024  thrpt    5  14.636 ± 0.016  ops/us
VectorUtilBenchmark.floatDotProductScalar    1024  thrpt    5   3.400 ± 0.083  ops/us
VectorUtilBenchmark.floatDotProductVector    1024  thrpt    5  27.265 ± 0.065  ops/us
VectorUtilBenchmark.floatSquareScalar        1024  thrpt    5   2.548 ± 0.012  ops/us
VectorUtilBenchmark.floatSquareVector        1024  thrpt    5  25.529 ± 0.207  ops/us

@ChrisHegarty (Contributor):

.. and yes (I've not forgotten), we need something like a java.lang.Architecture/Platform, that is queryable for such low-level support (rather than resorting to beans - which actually works kinda ok, but is not ideal)

@uschindler added this to the 9.9.0 milestone on Oct 29, 2023
@rmuir (Member Author) commented Oct 30, 2023

ha! So just removing the overly aggressive unrolling in cosine improves things.

well, only in combination with the switch to FMA. It seems it's then able to keep the CPU busy multiplying.

@rmuir (Member Author) commented Oct 30, 2023

.. and yes (I've not forgotten), we need something like a java.lang.Architecture/Platform, that is queryable for such low-level support (rather than resorting to beans - which actually works kinda ok, but is not ideal)

and the compiler should be fixed to unroll basic loops to take advantage of the fact that you can do 4 of these things in parallel on modern CPUs.

or failing that, if I'm gonna have to unroll loops myself, then at least give me some basic info (e.g. the CPU model) so I can do it properly.

Currently it is the worst of both worlds.

@rmuir (Member Author) commented Oct 30, 2023

Last time I tried to figure out WTF was happening here, I think I determined that floating-point reproducibility was still preventing this from happening? There isn't a "bail out" from this in the vector API, just some clever wording in the javadocs of reduceLanes.

Which is really sad: how is the vector API supposed to be usable if everyone has to unroll their own loops in order to use 100% of the hardware instead of 25%?

@uschindler (Contributor):

Last time I tried to figure out WTF was happening here, I think I determined that floating-point reproducibility was still preventing this from happening? There isn't a "bail out" from this in the vector API, just some clever wording in the javadocs of reduceLanes.

Which is really sad: how is the vector API supposed to be usable if everyone has to unroll their own loops in order to use 100% of the hardware instead of 25%?

The float use case is problematic because the order of multiplications/sums changes the result. So you can't easily rewrite the code to run in parallel, as the result would be different. This is also the reason why the auto-vectorizer can't do anything.

I think the Panama API should allow the user to figure out how many parallel units are available to somehow dynamically split work correctly.
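The non-associativity behind this is easy to demonstrate with a tiny (purely illustrative) example: summing the same three floats in two different orders gives two different answers, which is exactly the reordering a vectorized reduction would perform:

```java
// IEEE 754 float addition is not associative: ulp(1e8f) is 8, so adding
// 1.0f to 1e8f is absorbed, while adding it after cancellation is not.
public class FloatOrder {
  public static void main(String[] args) {
    float big = 1e8f;
    float sequential = (big + 1.0f) + -big; // 1e8f + 1.0f rounds back to 1e8f
    float reordered = (big + -big) + 1.0f;  // cancellation first, then add 1
    System.out.println(sequential); // 0.0
    System.out.println(reordered);  // 1.0
  }
}
```

Because the JIT must preserve the sequential result, it cannot split a float reduction across lanes or accumulators on its own; the programmer has to opt in by writing the split explicitly, as this PR does.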

@rmuir (Member Author) commented Oct 30, 2023

I think the Panama API should allow the user to figure out how many parallel units are available to somehow dynamically split work correctly.

I'm not even sure openjdk/hotspot knows this or even attempts to approximate it? It never deals with -ffast-math-style optimizations that would make use of it, due to its floating point restrictions, right?

but knowing the CPU info/model would help. Then at least folks can do it themselves.

@asfgit asfgit merged commit e292a5f into apache:main Oct 30, 2023
4 checks passed