Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify Job Deduplication #3579

Merged
merged 8 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion cmd/lookoutv2/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,13 @@ func prune(ctx *armadacontext.Context, config configuration.LookoutV2Config) {

ctxTimeout, cancel := armadacontext.WithTimeout(ctx, config.PrunerConfig.Timeout)
defer cancel()
err = pruner.PruneDb(ctxTimeout, db, config.PrunerConfig.ExpireAfter, config.PrunerConfig.BatchSize, clock.RealClock{})
err = pruner.PruneDb(
ctxTimeout,
db,
config.PrunerConfig.ExpireAfter,
config.PrunerConfig.DeduplicationExpireAfter,
config.PrunerConfig.BatchSize,
clock.RealClock{})
if err != nil {
panic(err)
}
Expand Down
1 change: 1 addition & 0 deletions config/lookoutv2/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ postgres:
sslmode: disable
prunerConfig:
expireAfter: 1008h # 42 days, 6 weeks
deduplicationExpireAfter: 168 # 7 days
timeout: 1h
batchSize: 1000
uiConfig:
Expand Down
12 changes: 1 addition & 11 deletions internal/armada/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"github.com/armadaproject/armada/internal/common/database"
grpcCommon "github.com/armadaproject/armada/internal/common/grpc"
"github.com/armadaproject/armada/internal/common/health"
"github.com/armadaproject/armada/internal/common/pgkeyvalue"
"github.com/armadaproject/armada/internal/common/pulsarutils"
"github.com/armadaproject/armada/internal/scheduler/reports"
"github.com/armadaproject/armada/internal/scheduler/schedulerobjects"
Expand Down Expand Up @@ -157,21 +156,12 @@ func Serve(ctx *armadacontext.Context, config *configuration.ArmadaConfig, healt
}
defer publisher.Close()

// KV store where we Automatically clean up keys after two weeks.
store, err := pgkeyvalue.New(ctx, dbPool, config.Pulsar.DedupTable)
if err != nil {
return err
}
services = append(services, func() error {
return store.PeriodicCleanup(ctx, time.Hour, 14*24*time.Hour)
})

submitServer := submit.NewServer(
publisher,
queueRepository,
queueCache,
config.Submission,
submit.NewDeduplicator(store),
submit.NewDeduplicator(dbPool),
authorizer)

schedulerApiConnection, err := createApiConnection(config.SchedulerApiConnection)
Expand Down
46 changes: 16 additions & 30 deletions internal/armada/submit/deduplication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"testing"
"time"

"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"

"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/database/lookout"
"github.com/armadaproject/armada/pkg/api"
)

Expand All @@ -17,26 +18,6 @@ type deduplicationIdsWithQueue struct {
kvs map[string]string
}

type InMemoryKeyValueStore struct {
kvs map[string][]byte
}

func (m *InMemoryKeyValueStore) Store(_ *armadacontext.Context, kvs map[string][]byte) error {
maps.Copy(m.kvs, kvs)
return nil
}

func (m *InMemoryKeyValueStore) Load(_ *armadacontext.Context, keys []string) (map[string][]byte, error) {
result := make(map[string][]byte, len(keys))
for _, k := range keys {
v, ok := m.kvs[k]
if ok {
result[k] = v
}
}
return result, nil
}

func TestDeduplicator(t *testing.T) {
tests := map[string]struct {
initialKeys []deduplicationIdsWithQueue
Expand Down Expand Up @@ -110,19 +91,24 @@ func TestDeduplicator(t *testing.T) {
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second)
deduplicator := NewDeduplicator(&InMemoryKeyValueStore{kvs: map[string][]byte{}})
err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error {
deduplicator := NewDeduplicator(db)

// Store
for _, keys := range tc.initialKeys {
err := deduplicator.StoreOriginalJobIds(ctx, keys.queue, keys.kvs)
require.NoError(t, err)
}

// Store
for _, keys := range tc.initialKeys {
err := deduplicator.StoreOriginalJobIds(ctx, keys.queue, keys.kvs)
// Fetch
keys, err := deduplicator.GetOriginalJobIds(ctx, tc.queueToFetch, tc.jobsToFetch)
require.NoError(t, err)
}

// Fetch
keys, err := deduplicator.GetOriginalJobIds(ctx, tc.queueToFetch, tc.jobsToFetch)
require.NoError(t, err)
assert.Equal(t, tc.expectedKeys, keys)

assert.Equal(t, tc.expectedKeys, keys)
return nil
})
assert.NoError(t, err)
cancel()
})
}
Expand Down
85 changes: 67 additions & 18 deletions internal/armada/submit/deduplicaton.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package submit

import (
"crypto/sha1"
"fmt"

"github.com/jackc/pgx/v5/pgxpool"
"golang.org/x/exp/maps"

"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/pgkeyvalue"
"github.com/armadaproject/armada/pkg/api"
)

Expand All @@ -19,55 +18,105 @@ type Deduplicator interface {

// PostgresDeduplicator is an implementation of a Deduplicator that uses a pgkeyvalue.KeyValueStore as its state store
type PostgresDeduplicator struct {
kvStore pgkeyvalue.KeyValueStore
db *pgxpool.Pool
}

func NewDeduplicator(kvStore pgkeyvalue.KeyValueStore) *PostgresDeduplicator {
return &PostgresDeduplicator{kvStore: kvStore}
func NewDeduplicator(db *pgxpool.Pool) *PostgresDeduplicator {
return &PostgresDeduplicator{db: db}
}

func (s *PostgresDeduplicator) GetOriginalJobIds(ctx *armadacontext.Context, queue string, jobRequests []*api.JobSubmitRequestItem) (map[string]string, error) {
// Armada checks for duplicate job submissions if a ClientId (i.e. a deduplication id) is provided.
// Deduplication is based on storing the combined hash of the ClientId and queue. For storage efficiency,
// we store hashes instead of user-provided strings.
kvs := make(map[string][]byte, len(jobRequests))
kvs := make(map[string]string, len(jobRequests))
for _, req := range jobRequests {
if req.ClientId != "" {
kvs[s.jobKey(queue, req.ClientId)] = []byte(req.ClientId)
kvs[s.jobKey(queue, req.ClientId)] = req.ClientId
}
}

duplicates := make(map[string]string)
// If we have any client Ids, retrieve their job ids
if len(kvs) > 0 {
keys := maps.Keys(kvs)
existingKvs, err := s.kvStore.Load(ctx, keys)
existingKvs, err := s.loadMappings(ctx, keys)
if err != nil {
return nil, err
}
for k, v := range kvs {
originalJobId, ok := existingKvs[k]
if ok {
duplicates[string(v)] = string(originalJobId)
duplicates[v] = originalJobId
}
}
}
return duplicates, nil
}

func (s *PostgresDeduplicator) StoreOriginalJobIds(ctx *armadacontext.Context, queue string, mappings map[string]string) error {
if s.kvStore == nil || len(mappings) == 0 {
if len(mappings) == 0 {
return nil
}
kvs := make(map[string][]byte, len(mappings))
kvs := make(map[string]string, len(mappings))
for k, v := range mappings {
kvs[s.jobKey(queue, k)] = []byte(v)
kvs[s.jobKey(queue, k)] = v
}
return s.kvStore.Store(ctx, kvs)
return s.storeMappings(ctx, kvs)
}

func (s *PostgresDeduplicator) jobKey(queue, clientId string) string {
combined := fmt.Sprintf("%s:%s", queue, clientId)
h := sha1.Sum([]byte(combined))
return fmt.Sprintf("%x", h)
return fmt.Sprintf("%s:%s", queue, clientId)
}

func (s *PostgresDeduplicator) storeMappings(ctx *armadacontext.Context, mappings map[string]string) error {
deduplicationIDs := make([]string, 0, len(mappings))
jobIDs := make([]string, 0, len(mappings))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most minor point - elsewhere we use jobIds rather than jobIDs

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes very true- I'll fix this in a follow up


for deduplicationID, jobID := range mappings {
deduplicationIDs = append(deduplicationIDs, deduplicationID)
jobIDs = append(jobIDs, jobID)
}

sql := `
INSERT INTO job_deduplication (deduplication_id, job_id)
SELECT unnest($1::text[]), unnest($2::text[])
ON CONFLICT (deduplication_id) DO NOTHING
`
_, err := s.db.Exec(ctx, sql, deduplicationIDs, jobIDs)
if err != nil {
return err
}

return nil
}

func (s *PostgresDeduplicator) loadMappings(ctx *armadacontext.Context, keys []string) (map[string]string, error) {
// Prepare the output map
result := make(map[string]string)

sql := `
SELECT deduplication_id, job_id
FROM job_deduplication
WHERE deduplication_id = ANY($1)
`

rows, err := s.db.Query(ctx, sql, keys)
if err != nil {
return nil, err
}
defer rows.Close()

// Iterate through the result rows
for rows.Next() {
var deduplicationID, jobID string
if err := rows.Scan(&deduplicationID, &jobID); err != nil {
return nil, err
}
result[deduplicationID] = jobID
}

if err := rows.Err(); err != nil {
return nil, err
}

return result, nil
}
12 changes: 11 additions & 1 deletion internal/armada/submit/validation/submit_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ var (
validateTerminationGracePeriod,
validateIngresses,
validatePorts,
validateClientId,
}
)

// ValidateSubmitRequest ensures that the incoming api.JobSubmitRequest is well-formed. It achieves this
// by applying a series of validators that each check a single aspect of the request. Validators may
// chose to validate the whole obSubmitRequest or just a single JobSubmitRequestItem.
// choose to validate the whole obSubmitRequest or just a single JobSubmitRequestItem.
// This function will return the error from the first validator that fails, or nil if all validators pass.
func ValidateSubmitRequest(req *api.JobSubmitRequest, config configuration.SubmissionConfig) error {
for _, validationFunc := range requestValidators {
Expand Down Expand Up @@ -179,6 +180,15 @@ func validateAffinity(j *api.JobSubmitRequestItem, _ configuration.SubmissionCon
return nil
}

// Ensures that if a request specifies a ClientId, that clientID is not too long
func validateClientId(j *api.JobSubmitRequestItem, _ configuration.SubmissionConfig) error {
const maxClientIdChars = 100
if len(j.GetClientId()) > maxClientIdChars {
return fmt.Errorf("client id of length %d is greater than max allowed length of %d", len(j.ClientId), maxClientIdChars)
}
return nil
}

// Ensures that if a request specifies a PriorityClass, that priority class is supported by Armada.
func validatePriorityClasses(j *api.JobSubmitRequestItem, config configuration.SubmissionConfig) error {
spec := j.GetMainPodSpec()
Expand Down
35 changes: 35 additions & 0 deletions internal/armada/submit/validation/submit_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package validation

import (
"strconv"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -727,6 +728,40 @@ func TestValidatePriorityClasses(t *testing.T) {
}
}

func TestValidateClientId(t *testing.T) {
tests := map[string]struct {
req *api.JobSubmitRequestItem
expectSuccess bool
}{
"no client id": {
req: &api.JobSubmitRequestItem{},
expectSuccess: true,
},
"client id of 100 chars is fine": {
req: &api.JobSubmitRequestItem{
ClientId: strings.Repeat("a", 100),
},
expectSuccess: true,
},
"client id over 100 chars is forbidden": {
req: &api.JobSubmitRequestItem{
ClientId: strings.Repeat("a", 101),
},
expectSuccess: false,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
err := validateClientId(tc.req, configuration.SubmissionConfig{})
if tc.expectSuccess {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
})
}
}

func TestValidateQueue(t *testing.T) {
tests := map[string]struct {
req *api.JobSubmitRequest
Expand Down