diff --git a/pathwaysutils/lru_cache.py b/pathwaysutils/lru_cache.py new file mode 100644 index 0000000..ec2b5a4 --- /dev/null +++ b/pathwaysutils/lru_cache.py @@ -0,0 +1,44 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""An LRU cache that will be cleared when JAX clears its internal cache.""" + +import functools +from typing import Any, Callable + +import jax.extend + + +def lru_cache( + maxsize: int = 4096, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """An LRU cache that will be cleared when JAX clears its internal cache. + + Args: + maxsize: The maximum number of entries to keep in the cache. When this limit + is reached, the least recently used entry will be evicted. + + Returns: + A function that can be used to decorate a function to cache its results. + """ + + def wrap(f): + cached = functools.lru_cache(maxsize=maxsize)(f) + wrapper = functools.wraps(f)(cached) + + wrapper.cache_clear = cached.cache_clear + wrapper.cache_info = cached.cache_info + jax.extend.backend.add_clear_backends_callback(wrapper.cache_clear) + return wrapper + + return wrap diff --git a/pathwaysutils/test/lru_cache_test.py b/pathwaysutils/test/lru_cache_test.py new file mode 100644 index 0000000..694498d --- /dev/null +++ b/pathwaysutils/test/lru_cache_test.py @@ -0,0 +1,85 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax.extend +from pathwaysutils import lru_cache +from absl.testing import absltest + + +class LruCacheTest(absltest.TestCase): + + def test_cache_hits(self): + x = [100] + + @lru_cache.lru_cache(maxsize=1) + def f(i): + x[i] += 1 + return x[i] + + self.assertEqual(f(0), 101) # Miss + self.assertEqual(f(0), 101) # Hit + + def test_cache_hits_and_misses_by_arguments(self): + x = [100, 200] + + @lru_cache.lru_cache(maxsize=2) + def f(i): + x[i] += 1 + return x[i] + + self.assertEqual(f(0), 101) # Miss + self.assertEqual(f(0), 101) # Hit + + self.assertEqual(f(1), 201) # Miss + self.assertEqual(f(1), 201) # Hit + + self.assertEqual(f(0), 101) # Hit + self.assertEqual(f(0), 101) # Hit + + def test_cache_lru_eviction(self): + x = [100, 200] + + @lru_cache.lru_cache(maxsize=1) + def f(i): + x[i] += 1 + return x[i] + + self.assertEqual(f(0), 101) # Miss + self.assertEqual(f(0), 101) # Hit + + self.assertEqual(f(1), 201) # Miss + self.assertEqual(f(1), 201) # Hit + + self.assertEqual(f(0), 102) # Miss + self.assertEqual(f(0), 102) # Hit + + def test_clear_cache_via_jax_clear_backend_cache(self): + x = [100] + + @lru_cache.lru_cache(maxsize=1) + def f(i): + x[i] += 1 + return x[i] + + self.assertEqual(f(0), 101) # Miss + self.assertEqual(f(0), 101) # Hit + + jax.extend.backend.clear_backends() + + self.assertEqual(f(0), 102) # Miss + self.assertEqual(f(0), 102) # Hit + + +if __name__ == "__main__": + absltest.main()