diff --git a/topdown/cache/cache.go b/topdown/cache/cache.go index 035f890874..c46c540abd 100644 --- a/topdown/cache/cache.go +++ b/topdown/cache/cache.go @@ -138,6 +138,9 @@ func (c *cache) unsafeInsert(k ast.Value, v InterQueryCacheValue) (dropped int) } } + // By deleting the old value, if it exists, we ensure the usage variable stays correct + c.unsafeDelete(k) + c.items[k.String()] = v c.l.PushBack(k) c.usage += size diff --git a/topdown/cache/cache_test.go b/topdown/cache/cache_test.go index a7417581b4..28fba63508 100644 --- a/topdown/cache/cache_test.go +++ b/topdown/cache/cache_test.go @@ -6,6 +6,7 @@ package cache import ( "reflect" + "sync" "testing" "github.com/open-policy-agent/opa/ast" @@ -139,6 +140,71 @@ func TestInsert(t *testing.T) { if !found { t.Fatal("Expected key \"foo5\" in cache") } + + // replacing an existing key should not affect cache size + cache = NewInterQueryCache(config) + + cacheValue6 := newInterQueryCacheValue(ast.String("bar6"), 10) + cache.Insert(ast.String("foo6"), cacheValue6) + cache.Insert(ast.String("foo6"), cacheValue6) + + cacheValue7 := newInterQueryCacheValue(ast.String("bar7"), 10) + dropped = cache.Insert(ast.StringTerm("foo7").Value, cacheValue7) + + if dropped != 0 { + t.Fatal("Expected dropped to be zero") + } +} + +func TestConcurrentInsert(t *testing.T) { + in := `{"inter_query_builtin_cache": {"max_size_bytes": 20},}` // 20 byte limit for test purposes + + config, err := ParseCachingConfig([]byte(in)) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + cache := NewInterQueryCache(config) + + cacheValue := newInterQueryCacheValue(ast.String("bar"), 10) + cache.Insert(ast.String("foo"), cacheValue) + + wg := sync.WaitGroup{} + + for i := 0; i < 5; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + cacheValue2 := newInterQueryCacheValue(ast.String("bar2"), 5) + cache.Insert(ast.String("foo2"), cacheValue2) + + }() + } + wg.Wait() + + cacheValue3 := newInterQueryCacheValue(ast.String("bar3"), 5) + dropped := cache.Insert(ast.String("foo3"), cacheValue3) + + if dropped != 0 { + t.Fatal("Expected dropped to be zero") + } + + _, found := cache.Get(ast.String("foo")) + if !found { + t.Fatal("Expected key \"foo\" in cache") + } + + _, found = cache.Get(ast.String("foo2")) + if !found { + t.Fatal("Expected key \"foo2\" in cache") + } + + _, found = cache.Get(ast.String("foo3")) + if !found { + t.Fatal("Expected key \"foo3\" in cache") + } } func TestDelete(t *testing.T) {