Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions indexer/indexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"api.audius.co/config"
dbv1 "api.audius.co/database"
"api.audius.co/jobs"
"api.audius.co/logging"
"connectrpc.com/connect"
corev1 "github.com/OpenAudio/go-openaudio/pkg/api/core/v1"
Expand Down Expand Up @@ -66,9 +67,35 @@ func (ci *CoreIndexer) Start(ctx context.Context) error {
eg.Go(func() error {
return ci.run(ctx)
})
ci.startParityJobs(ctx)
return eg.Wait()
}

// startParityJobs kicks off the recurring background jobs that take over the
// periodic tasks the legacy Python discovery-provider ran from celery beat.
//
// Each job's ScheduleEvery is expected to spawn its own goroutine and to stop
// once ctx is cancelled, which is why none of them are registered with the
// errgroup: they manage their own lifetime.
//
// Intervals follow apps' celery beat_schedule in src/app.py wherever an entry
// exists. update_delist_statuses has no beat entry (apps triggers it
// externally), so a conservative default is used for it.
func (ci *CoreIndexer) startParityJobs(ctx context.Context) {
	cfg, pool := ci.Config, ci.pool

	hourlyPlayCounts := jobs.NewHourlyPlayCountsJob(cfg, pool)
	hourlyPlayCounts.ScheduleEvery(ctx, 30*time.Second)

	prunePlays := jobs.NewPrunePlaysJob(cfg, pool)
	prunePlays.ScheduleEvery(ctx, 30*time.Second)

	listeningHistory := jobs.NewUserListeningHistoryJob(cfg, pool)
	listeningHistory.ScheduleEvery(ctx, 5*time.Second)

	trending := jobs.NewTrendingJob(cfg, pool)
	trending.ScheduleEvery(ctx, 10*time.Second)

	delistStatuses := jobs.NewUpdateDelistStatusesJob(cfg, pool)
	delistStatuses.ScheduleEvery(ctx, 5*time.Minute)
}

func (ci *CoreIndexer) run(ctx context.Context) error {
go logging.SyncOnTicks(ctx, ci.logger, time.Second*10)
var height int64
Expand Down
95 changes: 95 additions & 0 deletions jobs/delegate_auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package jobs

import (
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/ethereum/go-ethereum/crypto"
)

// signedHTTPGet performs an HTTP GET signed with a delegate private key.
// Mirrors apps' src/utils/auth_helpers.py:signed_get():
//
//	timestamp = round(time.time() * 1000)
//	signature = sign(timestamp, private_key)
//	nonce = f"{timestamp}:{signature.hex()}"
//	Authorization: Basic base64(nonce)
//
// where sign() does:
//
//	digest = keccak(timestamp_text)
//	signature = ETH-personal-sign(digest_hex, private_key)
//
// Parameters:
//   - client: the HTTP client to send the request with. When nil, a client
//     with a 30s overall timeout is used so a stalled peer cannot block the
//     caller indefinitely (http.DefaultClient has no timeout).
//   - url: the absolute URL to fetch.
//   - delegatePrivateKey: the 0x-prefixed (or unprefixed) hex string from
//     config.Cfg.DelegatePrivateKey.
//
// On success the caller owns the returned response and must close its Body.
// An error is returned when the Authorization header cannot be built, the
// request cannot be constructed, or the round trip fails.
func signedHTTPGet(client *http.Client, url, delegatePrivateKey string) (*http.Response, error) {
	// Upper bound for the whole exchange (connect, headers and body read)
	// when the caller does not supply its own client.
	const defaultTimeout = 30 * time.Second

	auth, err := basicAuthNonce(delegatePrivateKey, time.Now())
	if err != nil {
		return nil, fmt.Errorf("build auth: %w", err)
	}

	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("build request: %w", err)
	}
	req.Header.Set("Authorization", auth)

	if client == nil {
		// A zero Transport means http.DefaultTransport, so the connection
		// pool is still shared even though the Client value is new.
		client = &http.Client{Timeout: defaultTimeout}
	}
	return client.Do(req)
}

// drainResponse consumes and closes resp.Body for callers that only care
// about the status of a response, not its payload.
//
// It returns nil for any 2xx status. For every other status it returns an
// error of the form "HTTP <code>: <body>", where <body> is at most the first
// 512 bytes of the response body with surrounding whitespace trimmed.
//
// In both cases the remainder of the body is read (up to a fixed cap) before
// it is closed: net/http can only reuse a keep-alive connection when the body
// has been read to EOF and closed. The cap keeps a very large or never-ending
// body from stalling the caller; beyond it the connection is simply dropped.
//
// resp must be non-nil and must carry a non-nil Body, as responses returned
// by http.Client always do.
func drainResponse(resp *http.Response) error {
	const (
		maxErrorBody = 512      // bytes of body quoted in the returned error
		maxDrain     = 64 << 10 // bytes discarded to make the connection reusable
	)

	defer func() {
		// Best effort: failing to drain or close only costs connection
		// reuse, so these errors are intentionally ignored.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrain))
		_ = resp.Body.Close()
	}()

	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		return nil
	}

	// A read error here just yields a shorter (possibly empty) snippet; the
	// status code alone is still a useful error.
	body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
	return fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}

// basicAuthNonce builds the Authorization header value that apps' signed_get
// sends: "Basic " + base64("<unix-millis>:<signature-hex>").
//
// delegatePrivateKey is a hex-encoded secp256k1 key with an optional "0x"
// prefix; now supplies the timestamp that is signed. The function is separate
// from signedHTTPGet so the header can be produced for a fixed time in tests.
//
// An error is returned when the key is not valid hex, is not a valid private
// key, or signing fails.
func basicAuthNonce(delegatePrivateKey string, now time.Time) (string, error) {
	rawKey, err := hex.DecodeString(strings.TrimPrefix(delegatePrivateKey, "0x"))
	if err != nil {
		return "", fmt.Errorf("decode private key: %w", err)
	}
	signer, err := crypto.ToECDSA(rawKey)
	if err != nil {
		return "", fmt.Errorf("parse private key: %w", err)
	}

	// Python: timestamp = round(time.time() * 1000)
	timestamp := fmt.Sprintf("%d", now.UnixMilli())

	// Python: digest = Web3.keccak(text=timestamp_str).hex()
	tsDigest := crypto.Keccak256([]byte(timestamp))

	// Python: encode_defunct(hexstr=digest_hex) wraps the raw digest bytes in
	// the EIP-191 personal-sign envelope; the envelope is then hashed and
	// that hash is what gets signed.
	envelope := []byte(fmt.Sprintf("\x19Ethereum Signed Message:\n%d", len(tsDigest)))
	envelope = append(envelope, tsDigest...)

	signature, err := crypto.Sign(crypto.Keccak256(envelope), signer)
	if err != nil {
		return "", fmt.Errorf("sign: %w", err)
	}

	// crypto.Sign yields a recovery id of 0 or 1 in the last byte, whereas
	// web3.py emits 27 or 28; shift it so the two encodings agree.
	if signature[64] < 27 {
		signature[64] += 27
	}

	payload := []byte(timestamp + ":" + hex.EncodeToString(signature))
	return "Basic " + base64.StdEncoding.EncodeToString(payload), nil
}
40 changes: 40 additions & 0 deletions jobs/delegate_auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package jobs

import (
"encoding/base64"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestBasicAuthNonce_Shape checks only the structure of the Authorization
// header: a "Basic " scheme followed by base64 of "<ts>:<sig_hex>", with the
// timestamp echoed verbatim and a 65-byte signature. Byte-for-byte agreement
// with the Python apps implementation is left to integration testing; the
// signing primitives themselves (keccak + eth personal-sign) are exercised by
// go-ethereum's own tests.
func TestBasicAuthNonce_Shape(t *testing.T) {
	const (
		// Test key from api/'s default dev config.
		devKey      = "13422b9affd75ff80f94f1ea394e6a6097830cb58cda2d3542f37464ecaee7df"
		scheme      = "Basic "
		fixedMillis = int64(1700000000000)
	)

	header, err := basicAuthNonce(devKey, time.UnixMilli(fixedMillis))
	require.NoError(t, err)
	require.True(t, strings.HasPrefix(header, scheme))

	raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(header, scheme))
	require.NoError(t, err)

	fields := strings.SplitN(string(raw), ":", 2)
	require.Len(t, fields, 2)

	timestamp, sigHex := fields[0], fields[1]
	assert.Equal(t, "1700000000000", timestamp)
	// A 65-byte signature hex-encodes to 130 characters.
	assert.Len(t, sigHex, 130)
}

// TestBasicAuthNonce_BadKey verifies that a key which is not valid hex is
// rejected with an error.
func TestBasicAuthNonce_BadKey(t *testing.T) {
	const malformedKey = "not-hex"

	_, err := basicAuthNonce(malformedKey, time.Now())

	assert.Error(t, err)
}
152 changes: 152 additions & 0 deletions jobs/index_hourly_play_counts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package jobs

import (
"context"
"errors"
"fmt"
"sync"
"time"

"api.audius.co/config"
"api.audius.co/database"
"api.audius.co/logging"
"github.com/jackc/pgx/v5"
"go.uber.org/zap"
)

// HourlyPlayCountsJob rolls up plays into the hourly_play_counts table.
// Mirrors apps/packages/discovery-provider/src/tasks/index_hourly_play_counts.py.
//
// On each run:
//   - reads the last processed plays.id checkpoint from indexing_checkpoints,
//   - buckets newer plays by date_trunc('hour', created_at),
//   - upserts into hourly_play_counts with sum-on-conflict,
//   - advances the checkpoint to max(plays.id) seen.
//
// hourly_play_counts powers GET /v1/metrics/plays. If this job stops running
// the endpoint returns an empty array (it doesn't error).
//
// The struct embeds a sync.Mutex by value, so it must be used through a
// pointer (as NewHourlyPlayCountsJob returns) and never copied.
type HourlyPlayCountsJob struct {
	// pool is the database handle every query of the job goes through.
	pool database.DbPool
	// logger is the job's logger, named "HourlyPlayCountsJob" by the constructor.
	logger *zap.Logger

	// mutex guards isRunning.
	mutex sync.Mutex
	// isRunning is true while a run is in progress; a second run that
	// starts in the meantime returns an error instead of overlapping.
	isRunning bool
}

// HourlyPlayCountsCheckpoint is the indexing_checkpoints.tablename key under
// which HourlyPlayCountsJob remembers the highest plays.id it has already
// rolled up into hourly_play_counts.
const HourlyPlayCountsCheckpoint = "hourly_play_counts"

// NewHourlyPlayCountsJob returns a job that runs its queries through pool and
// logs via a zap logger built from cfg and named "HourlyPlayCountsJob". The
// job does nothing until Run or ScheduleEvery is called.
func NewHourlyPlayCountsJob(cfg config.Config, pool database.DbPool) *HourlyPlayCountsJob {
	job := new(HourlyPlayCountsJob)
	job.pool = pool
	job.logger = logging.NewZapLogger(cfg).Named("HourlyPlayCountsJob")
	return job
}

// ScheduleEvery starts a background goroutine that calls Run once per
// interval tick until ctx is cancelled, then logs and exits. The first run
// happens after one full interval, not immediately. The receiver is returned
// so the call can be chained onto the constructor.
func (j *HourlyPlayCountsJob) ScheduleEvery(ctx context.Context, interval time.Duration) *HourlyPlayCountsJob {
	loop := func() {
		tick := time.NewTicker(interval)
		defer tick.Stop()

		for {
			select {
			case <-ctx.Done():
				j.logger.Info("Job shutting down")
				return
			case <-tick.C:
				j.Run(ctx)
			}
		}
	}

	go loop()
	return j
}

// Run performs a single roll-up pass. Failures are logged rather than
// returned, so Run is safe to call from a scheduler loop.
func (j *HourlyPlayCountsJob) Run(ctx context.Context) {
	err := j.run(ctx)
	if err == nil {
		return
	}
	j.logger.Error("Job run failed", zap.Error(err))
}

// run performs one roll-up pass and reports why it failed, if it did.
//
// Steps:
//  1. Refuse to overlap with a pass that is still in progress on this job.
//  2. Read the checkpoint (highest plays.id already rolled up) and the
//     current MAX(plays.id). Return early when there is nothing new.
//  3. In ONE SQL statement, add the per-hour counts of the new plays to
//     hourly_play_counts and move the checkpoint to the new maximum.
//
// Step 3 must be atomic. The upsert is additive (play_count + EXCLUDED), so
// it is not idempotent: if the counts were written but the checkpoint was
// not (error, crash, cancelled context between two statements), the next
// pass would add the same plays a second time. Running both writes as a
// single statement means they commit or fail together.
func (j *HourlyPlayCountsJob) run(ctx context.Context) error {
	j.mutex.Lock()
	if j.isRunning {
		j.mutex.Unlock()
		return fmt.Errorf("job is already running")
	}
	j.isRunning = true
	j.mutex.Unlock()
	defer func() {
		j.mutex.Lock()
		j.isRunning = false
		j.mutex.Unlock()
	}()

	prev, err := getCheckpoint(ctx, j.pool, HourlyPlayCountsCheckpoint)
	if err != nil {
		return fmt.Errorf("read checkpoint: %w", err)
	}

	// MAX(id) is NULL on an empty table, hence the pointer.
	var newMax *int64
	if err := j.pool.QueryRow(ctx, "SELECT MAX(id) FROM plays").Scan(&newMax); err != nil {
		return fmt.Errorf("read max(plays.id): %w", err)
	}
	if newMax == nil || *newMax == prev {
		j.logger.Debug("No new plays since last run")
		return nil
	}

	// The checkpoint write rides along as a data-modifying CTE. PostgreSQL
	// executes such a CTE exactly once and to completion even though the
	// main statement never reads from it, and both writes share the
	// statement's atomicity. The hourly upsert stays the main statement so
	// RowsAffected still reports the number of hour buckets touched.
	res, err := j.pool.Exec(ctx, `
		WITH checkpointed AS (
			INSERT INTO indexing_checkpoints (tablename, last_checkpoint)
			VALUES (@checkpoint_name, @checkpoint_value)
			ON CONFLICT (tablename) DO UPDATE SET last_checkpoint = EXCLUDED.last_checkpoint
			RETURNING 1
		)
		INSERT INTO hourly_play_counts (hourly_timestamp, play_count)
		SELECT date_trunc('hour', created_at) AS hourly_timestamp,
		       COUNT(*) AS play_count
		FROM plays
		WHERE id > @prev AND id <= @new_max
		GROUP BY date_trunc('hour', created_at)
		ON CONFLICT (hourly_timestamp)
		DO UPDATE SET play_count = hourly_play_counts.play_count + EXCLUDED.play_count
	`, pgx.NamedArgs{
		"prev":            prev,
		"new_max":         *newMax,
		"checkpoint_name": HourlyPlayCountsCheckpoint,
		// Same value as new_max, bound separately so each placeholder's
		// type is inferred from its own column.
		"checkpoint_value": *newMax,
	})
	if err != nil {
		return fmt.Errorf("upsert hourly buckets and save checkpoint: %w", err)
	}

	j.logger.Info("Hourly play counts updated",
		zap.Int64("prev_checkpoint", prev),
		zap.Int64("new_checkpoint", *newMax),
		zap.Int64("hours_touched", res.RowsAffected()))
	return nil
}

// getCheckpoint looks up indexing_checkpoints.last_checkpoint for the row
// whose tablename equals name. A missing row is not an error: it yields 0,
// meaning "nothing processed yet". Any other query or scan failure is
// returned to the caller.
func getCheckpoint(ctx context.Context, pool database.DbPool, name string) (int64, error) {
	const query = "SELECT last_checkpoint FROM indexing_checkpoints WHERE tablename = $1"

	var checkpoint int64
	if scanErr := pool.QueryRow(ctx, query, name).Scan(&checkpoint); scanErr != nil {
		if errors.Is(scanErr, pgx.ErrNoRows) {
			return 0, nil
		}
		return checkpoint, scanErr
	}
	return checkpoint, nil
}

// saveCheckpoint stores value as the last_checkpoint of the
// indexing_checkpoints row identified by name, inserting the row when it
// does not exist yet and overwriting it otherwise.
func saveCheckpoint(ctx context.Context, pool database.DbPool, name string, value int64) error {
	const upsert = `
INSERT INTO indexing_checkpoints (tablename, last_checkpoint)
VALUES ($1, $2)
ON CONFLICT (tablename) DO UPDATE SET last_checkpoint = EXCLUDED.last_checkpoint
`
	if _, execErr := pool.Exec(ctx, upsert, name, value); execErr != nil {
		return execErr
	}
	return nil
}
Loading
Loading