Skip to content

Commit

Permalink
Simplify Job Deduplication (#3579)
Browse files Browse the repository at this point in the history
* use proper database table

Signed-off-by: Chris Martin <chris@cmartinit.co.uk>

* implement cleanup

Signed-off-by: Chris Martin <chris@cmartinit.co.uk>

* disallow long client ids

Signed-off-by: Chris Martin <chris@cmartinit.co.uk>

* improved logging

Signed-off-by: Chris Martin <chris@cmartinit.co.uk>

* lint

Signed-off-by: Chris Martin <chris@cmartinit.co.uk>

---------

Signed-off-by: Chris Martin <chris@cmartinit.co.uk>
Co-authored-by: Chris Martin <chris@cmartinit.co.uk>
  • Loading branch information
d80tb7 and Chris Martin committed May 15, 2024
1 parent fb47775 commit cb0e300
Show file tree
Hide file tree
Showing 14 changed files with 186 additions and 357 deletions.
8 changes: 7 additions & 1 deletion cmd/lookoutv2/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,13 @@ func prune(ctx *armadacontext.Context, config configuration.LookoutV2Config) {

ctxTimeout, cancel := armadacontext.WithTimeout(ctx, config.PrunerConfig.Timeout)
defer cancel()
err = pruner.PruneDb(ctxTimeout, db, config.PrunerConfig.ExpireAfter, config.PrunerConfig.BatchSize, clock.RealClock{})
err = pruner.PruneDb(
ctxTimeout,
db,
config.PrunerConfig.ExpireAfter,
config.PrunerConfig.DeduplicationExpireAfter,
config.PrunerConfig.BatchSize,
clock.RealClock{})
if err != nil {
panic(err)
}
Expand Down
1 change: 1 addition & 0 deletions config/lookoutv2/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ postgres:
sslmode: disable
prunerConfig:
expireAfter: 1008h # 42 days, 6 weeks
deduplicationExpireAfter: 168h # 7 days
timeout: 1h
batchSize: 1000
uiConfig:
Expand Down
12 changes: 1 addition & 11 deletions internal/armada/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"github.com/armadaproject/armada/internal/common/database"
grpcCommon "github.com/armadaproject/armada/internal/common/grpc"
"github.com/armadaproject/armada/internal/common/health"
"github.com/armadaproject/armada/internal/common/pgkeyvalue"
"github.com/armadaproject/armada/internal/common/pulsarutils"
"github.com/armadaproject/armada/internal/scheduler/reports"
"github.com/armadaproject/armada/internal/scheduler/schedulerobjects"
Expand Down Expand Up @@ -157,21 +156,12 @@ func Serve(ctx *armadacontext.Context, config *configuration.ArmadaConfig, healt
}
defer publisher.Close()

// KV store where we Automatically clean up keys after two weeks.
store, err := pgkeyvalue.New(ctx, dbPool, config.Pulsar.DedupTable)
if err != nil {
return err
}
services = append(services, func() error {
return store.PeriodicCleanup(ctx, time.Hour, 14*24*time.Hour)
})

submitServer := submit.NewServer(
publisher,
queueRepository,
queueCache,
config.Submission,
submit.NewDeduplicator(store),
submit.NewDeduplicator(dbPool),
authorizer)

schedulerApiConnection, err := createApiConnection(config.SchedulerApiConnection)
Expand Down
46 changes: 16 additions & 30 deletions internal/armada/submit/deduplication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"testing"
"time"

"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"

"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/database/lookout"
"github.com/armadaproject/armada/pkg/api"
)

Expand All @@ -17,26 +18,6 @@ type deduplicationIdsWithQueue struct {
kvs map[string]string
}

// InMemoryKeyValueStore is a map-backed key-value store used as a test double.
// The kvs map must be initialised before Store is called with a non-empty input.
type InMemoryKeyValueStore struct {
	kvs map[string][]byte
}

// Store copies every entry of kvs into the in-memory map, overwriting existing keys.
// The context is ignored and the returned error is always nil.
func (m *InMemoryKeyValueStore) Store(_ *armadacontext.Context, kvs map[string][]byte) error {
	maps.Copy(m.kvs, kvs)
	return nil
}

// Load returns the stored values for the requested keys. Keys that are not present
// are simply absent from the returned map. The context is ignored and the returned
// error is always nil.
func (m *InMemoryKeyValueStore) Load(_ *armadacontext.Context, keys []string) (map[string][]byte, error) {
	found := make(map[string][]byte, len(keys))
	for _, key := range keys {
		if value, present := m.kvs[key]; present {
			found[key] = value
		}
	}
	return found, nil
}

func TestDeduplicator(t *testing.T) {
tests := map[string]struct {
initialKeys []deduplicationIdsWithQueue
Expand Down Expand Up @@ -110,19 +91,24 @@ func TestDeduplicator(t *testing.T) {
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second)
deduplicator := NewDeduplicator(&InMemoryKeyValueStore{kvs: map[string][]byte{}})
err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error {
deduplicator := NewDeduplicator(db)

// Store
for _, keys := range tc.initialKeys {
err := deduplicator.StoreOriginalJobIds(ctx, keys.queue, keys.kvs)
require.NoError(t, err)
}

// Store
for _, keys := range tc.initialKeys {
err := deduplicator.StoreOriginalJobIds(ctx, keys.queue, keys.kvs)
// Fetch
keys, err := deduplicator.GetOriginalJobIds(ctx, tc.queueToFetch, tc.jobsToFetch)
require.NoError(t, err)
}

// Fetch
keys, err := deduplicator.GetOriginalJobIds(ctx, tc.queueToFetch, tc.jobsToFetch)
require.NoError(t, err)
assert.Equal(t, tc.expectedKeys, keys)

assert.Equal(t, tc.expectedKeys, keys)
return nil
})
assert.NoError(t, err)
cancel()
})
}
Expand Down
85 changes: 67 additions & 18 deletions internal/armada/submit/deduplicaton.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package submit

import (
"crypto/sha1"
"fmt"

"github.com/jackc/pgx/v5/pgxpool"
"golang.org/x/exp/maps"

"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/pgkeyvalue"
"github.com/armadaproject/armada/pkg/api"
)

Expand All @@ -19,55 +18,105 @@ type Deduplicator interface {

// PostgresDeduplicator is an implementation of a Deduplicator that uses the job_deduplication Postgres table as its state store
type PostgresDeduplicator struct {
kvStore pgkeyvalue.KeyValueStore
db *pgxpool.Pool
}

func NewDeduplicator(kvStore pgkeyvalue.KeyValueStore) *PostgresDeduplicator {
return &PostgresDeduplicator{kvStore: kvStore}
func NewDeduplicator(db *pgxpool.Pool) *PostgresDeduplicator {
return &PostgresDeduplicator{db: db}
}

func (s *PostgresDeduplicator) GetOriginalJobIds(ctx *armadacontext.Context, queue string, jobRequests []*api.JobSubmitRequestItem) (map[string]string, error) {
// Armada checks for duplicate job submissions if a ClientId (i.e. a deduplication id) is provided.
// Deduplication is based on storing a key that combines the queue and the ClientId (see jobKey), mapped to the
// id of the job that was originally submitted with that ClientId.
kvs := make(map[string][]byte, len(jobRequests))
kvs := make(map[string]string, len(jobRequests))
for _, req := range jobRequests {
if req.ClientId != "" {
kvs[s.jobKey(queue, req.ClientId)] = []byte(req.ClientId)
kvs[s.jobKey(queue, req.ClientId)] = req.ClientId
}
}

duplicates := make(map[string]string)
// If we have any client Ids, retrieve their job ids
if len(kvs) > 0 {
keys := maps.Keys(kvs)
existingKvs, err := s.kvStore.Load(ctx, keys)
existingKvs, err := s.loadMappings(ctx, keys)
if err != nil {
return nil, err
}
for k, v := range kvs {
originalJobId, ok := existingKvs[k]
if ok {
duplicates[string(v)] = string(originalJobId)
duplicates[v] = originalJobId
}
}
}
return duplicates, nil
}

func (s *PostgresDeduplicator) StoreOriginalJobIds(ctx *armadacontext.Context, queue string, mappings map[string]string) error {
if s.kvStore == nil || len(mappings) == 0 {
if len(mappings) == 0 {
return nil
}
kvs := make(map[string][]byte, len(mappings))
kvs := make(map[string]string, len(mappings))
for k, v := range mappings {
kvs[s.jobKey(queue, k)] = []byte(v)
kvs[s.jobKey(queue, k)] = v
}
return s.kvStore.Store(ctx, kvs)
return s.storeMappings(ctx, kvs)
}

func (s *PostgresDeduplicator) jobKey(queue, clientId string) string {
combined := fmt.Sprintf("%s:%s", queue, clientId)
h := sha1.Sum([]byte(combined))
return fmt.Sprintf("%x", h)
return fmt.Sprintf("%s:%s", queue, clientId)
}

// storeMappings writes every deduplication id -> job id pair in mappings to the
// job_deduplication table using a single INSERT statement. A pair whose
// deduplication id already exists is skipped (ON CONFLICT ... DO NOTHING), so the
// job id that was stored first is kept. Any error from the database is returned as-is.
func (s *PostgresDeduplicator) storeMappings(ctx *armadacontext.Context, mappings map[string]string) error {
	// Flatten the map into two parallel slices; index i of each slice forms one row.
	count := len(mappings)
	ids := make([]string, count)
	jobs := make([]string, count)
	next := 0
	for id, job := range mappings {
		ids[next], jobs[next] = id, job
		next++
	}

	// Both arrays have the same length, so the two unnest calls line up row by row.
	const query = `
INSERT INTO job_deduplication (deduplication_id, job_id)
SELECT unnest($1::text[]), unnest($2::text[])
ON CONFLICT (deduplication_id) DO NOTHING
`
	_, err := s.db.Exec(ctx, query, ids, jobs)
	return err
}

// loadMappings looks up the given deduplication ids in the job_deduplication table
// and returns a map from each id that was found to its stored job id. Ids with no
// row in the table are absent from the result. On any query, scan or iteration
// error it returns a nil map together with that error.
func (s *PostgresDeduplicator) loadMappings(ctx *armadacontext.Context, keys []string) (map[string]string, error) {
	const query = `
SELECT deduplication_id, job_id
FROM job_deduplication
WHERE deduplication_id = ANY($1)
`

	rows, err := s.db.Query(ctx, query, keys)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Collect one entry per returned row.
	found := make(map[string]string)
	for rows.Next() {
		var (
			key   string
			jobID string
		)
		if scanErr := rows.Scan(&key, &jobID); scanErr != nil {
			return nil, scanErr
		}
		found[key] = jobID
	}

	// rows.Next returning false can also mean the iteration failed part-way through.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}

	return found, nil
}
12 changes: 11 additions & 1 deletion internal/armada/submit/validation/submit_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ var (
validateTerminationGracePeriod,
validateIngresses,
validatePorts,
validateClientId,
}
)

// ValidateSubmitRequest ensures that the incoming api.JobSubmitRequest is well-formed. It achieves this
// by applying a series of validators that each check a single aspect of the request. Validators may
// chose to validate the whole obSubmitRequest or just a single JobSubmitRequestItem.
// choose to validate the whole JobSubmitRequest or just a single JobSubmitRequestItem.
// This function will return the error from the first validator that fails, or nil if all validators pass.
func ValidateSubmitRequest(req *api.JobSubmitRequest, config configuration.SubmissionConfig) error {
for _, validationFunc := range requestValidators {
Expand Down Expand Up @@ -179,6 +180,15 @@ func validateAffinity(j *api.JobSubmitRequestItem, _ configuration.SubmissionCon
return nil
}

// validateClientId rejects a request item whose ClientId is longer than the maximum
// allowed length. An item with no ClientId always passes. The config argument is unused.
func validateClientId(j *api.JobSubmitRequestItem, _ configuration.SubmissionConfig) error {
	const maxClientIdChars = 100

	// len counts bytes, so the limit applies to the encoded size of the id.
	idLen := len(j.GetClientId())
	if idLen <= maxClientIdChars {
		return nil
	}
	return fmt.Errorf("client id of length %d is greater than max allowed length of %d", idLen, maxClientIdChars)
}

// Ensures that if a request specifies a PriorityClass, that priority class is supported by Armada.
func validatePriorityClasses(j *api.JobSubmitRequestItem, config configuration.SubmissionConfig) error {
spec := j.GetMainPodSpec()
Expand Down
35 changes: 35 additions & 0 deletions internal/armada/submit/validation/submit_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package validation

import (
"strconv"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -727,6 +728,40 @@ func TestValidatePriorityClasses(t *testing.T) {
}
}

// TestValidateClientId exercises the length limit applied by validateClientId: a missing
// id and an id of exactly 100 characters must be accepted, while an id of 101 characters
// must be rejected.
func TestValidateClientId(t *testing.T) {
	type testCase struct {
		req           *api.JobSubmitRequestItem
		expectSuccess bool
	}

	cases := map[string]testCase{
		"no client id": {
			req:           &api.JobSubmitRequestItem{},
			expectSuccess: true,
		},
		"client id of 100 chars is fine": {
			req:           &api.JobSubmitRequestItem{ClientId: strings.Repeat("a", 100)},
			expectSuccess: true,
		},
		"client id over 100 chars is forbidden": {
			req:           &api.JobSubmitRequestItem{ClientId: strings.Repeat("a", 101)},
			expectSuccess: false,
		},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			err := validateClientId(tc.req, configuration.SubmissionConfig{})
			if !tc.expectSuccess {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}

func TestValidateQueue(t *testing.T) {
tests := map[string]struct {
req *api.JobSubmitRequest
Expand Down
Loading

0 comments on commit cb0e300

Please sign in to comment.