-
Notifications
You must be signed in to change notification settings - Fork 129
/
server.go
213 lines (185 loc) · 7.53 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
package armada
import (
"fmt"
"net"
"time"
"github.com/apache/pulsar-client-go/pulsar"
"github.com/google/uuid"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/redis/go-redis/extra/redisprometheus/v9"
"github.com/redis/go-redis/v9"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"github.com/armadaproject/armada/internal/armada/configuration"
"github.com/armadaproject/armada/internal/armada/event"
"github.com/armadaproject/armada/internal/armada/queryapi"
"github.com/armadaproject/armada/internal/armada/queue"
"github.com/armadaproject/armada/internal/armada/submit"
"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/auth"
"github.com/armadaproject/armada/internal/common/compress"
"github.com/armadaproject/armada/internal/common/database"
grpcCommon "github.com/armadaproject/armada/internal/common/grpc"
"github.com/armadaproject/armada/internal/common/health"
"github.com/armadaproject/armada/internal/common/pulsarutils"
"github.com/armadaproject/armada/internal/scheduler/reports"
"github.com/armadaproject/armada/internal/scheduler/schedulerobjects"
"github.com/armadaproject/armada/pkg/api"
"github.com/armadaproject/armada/pkg/client"
)
// Serve builds and runs the Armada API server.
//
// It validates the submission config, configures authentication, creates the
// gRPC server and its backing resources (Postgres pool, events Redis client,
// Pulsar client and publisher, scheduler API connection), registers the gRPC
// services, and finally starts all long-running services inside an errgroup.
//
// Serve returns early with an error if validation or the setup of any
// dependency fails. Otherwise it blocks until the errgroup finishes (the
// parent context is cancelled or a service returns an error) and returns the
// result of the errgroup's Wait. Resources opened here are released via defers
// when Serve returns.
func Serve(ctx *armadacontext.Context, config *configuration.ArmadaConfig, healthChecks *health.MultiChecker) error {
	// Time the gRPC server is given to shut down gracefully before being force-stopped.
	const grpcShutdownGracePeriod = 5 * time.Second

	log.Info("Armada server starting")
	defer log.Info("Armada server shutting down")

	// We call startupCompleteCheck.MarkComplete() when all services have been started.
	startupCompleteCheck := health.NewStartupCompleteChecker()
	healthChecks.Add(startupCompleteCheck)

	// Run all services within an errgroup to propagate errors between services.
	// Defer cancelling the parent context to ensure the errgroup is cancelled on return.
	ctx, cancel := armadacontext.WithCancel(ctx)
	defer cancel()
	g, ctx := armadacontext.ErrGroup(ctx)

	// List of services to run concurrently.
	// Because we want to start services only once all input validation has been completed,
	// we add all services to a slice and start them together at the end of this function.
	var services []func() error

	if err := validateSubmissionConfig(config.Submission); err != nil {
		return err
	}

	// We support multiple simultaneous authentication services (e.g., username/password OpenId).
	// For each gRPC request, we try them all until one succeeds, at which point the process is
	// short-circuited.
	authServices, err := auth.ConfigureAuth(config.Auth)
	if err != nil {
		return err
	}
	grpcServer := grpcCommon.CreateGrpcServer(config.Grpc.KeepaliveParams, config.Grpc.KeepaliveEnforcementPolicy, authServices, config.Grpc.Tls)

	// Shut down grpcServer if the context is cancelled.
	// Give the server a grace period to shut down gracefully, then force-stop it.
	services = append(services, func() error {
		<-ctx.Done()
		// The timer is stopped as soon as GracefulStop returns, so nothing
		// lingers for the full grace period after a quick, clean shutdown.
		forceStop := time.AfterFunc(grpcShutdownGracePeriod, grpcServer.Stop)
		defer forceStop.Stop()
		grpcServer.GracefulStop()
		return nil
	})

	// Create database connection. This is used for the query api, queues and for job deduplication
	dbPool, err := database.OpenPgxPool(config.Postgres)
	if err != nil {
		return errors.WithMessage(err, "error creating postgres pool")
	}
	defer dbPool.Close()

	queryapiServer := queryapi.New(
		dbPool,
		config.QueryApi.MaxQueryItems,
		func() compress.Decompressor { return compress.NewZlibDecompressor() })
	api.RegisterJobsServer(grpcServer, queryapiServer)

	eventDb := createRedisClient(&config.EventsApiRedis)
	defer func() {
		if err := eventDb.Close(); err != nil {
			log.WithError(err).Error("failed to close events api Redis client")
		}
	}()
	prometheus.MustRegister(
		redisprometheus.NewCollector("armada", "events_redis", eventDb))

	queueRepository := queue.NewPostgresQueueRepository(dbPool)
	queueCache := queue.NewCachedQueueRepository(queueRepository, config.QueueCacheRefreshPeriod)
	services = append(services, func() error {
		return queueCache.Run(ctx)
	})
	eventRepository := event.NewEventRepository(eventDb)

	authorizer := auth.NewAuthorizer(
		auth.NewPrincipalPermissionChecker(
			config.Auth.PermissionGroupMapping,
			config.Auth.PermissionScopeMapping,
			config.Auth.PermissionClaimMapping,
		),
	)

	// The server id is used to give this instance's Pulsar producer a unique name.
	serverId := uuid.New()
	var pulsarClient pulsar.Client
	// API endpoints that generate Pulsar messages.
	pulsarClient, err = pulsarutils.NewPulsarClient(&config.Pulsar)
	if err != nil {
		return err
	}
	defer pulsarClient.Close()

	publisher, err := pulsarutils.NewPulsarPublisher(pulsarClient, pulsar.ProducerOptions{
		Name:             fmt.Sprintf("armada-server-%s", serverId),
		CompressionType:  config.Pulsar.CompressionType,
		CompressionLevel: config.Pulsar.CompressionLevel,
		BatchingMaxSize:  config.Pulsar.MaxAllowedMessageSize,
		Topic:            config.Pulsar.JobsetEventsTopic,
	}, config.Pulsar.MaxAllowedMessageSize)
	if err != nil {
		return errors.Wrapf(err, "error creating pulsar producer")
	}
	defer publisher.Close()

	queueServer := queue.NewServer(queueRepository, authorizer)

	submitServer := submit.NewServer(
		queueServer,
		publisher,
		queueCache,
		config.Submission,
		submit.NewDeduplicator(dbPool),
		authorizer)

	schedulerApiConnection, err := createApiConnection(config.SchedulerApiConnection)
	if err != nil {
		return errors.Wrapf(err, "error creating connection to scheduler api")
	}
	// Release the client connection on return, like every other resource opened here.
	defer func() {
		if err := schedulerApiConnection.Close(); err != nil {
			log.WithError(err).Error("failed to close scheduler api connection")
		}
	}()
	schedulerApiReportsClient := schedulerobjects.NewSchedulerReportingClient(schedulerApiConnection)
	schedulingReportsServer := reports.NewProxyingSchedulingReportsServer(schedulerApiReportsClient)

	eventServer := event.NewEventServer(
		authorizer,
		eventRepository,
		queueCache,
	)

	api.RegisterSubmitServer(grpcServer, submitServer)
	api.RegisterEventServer(grpcServer, eventServer)
	api.RegisterQueueServiceServer(grpcServer, queueServer)
	schedulerobjects.RegisterSchedulerReportingServer(grpcServer, schedulingReportsServer)
	grpc_prometheus.Register(grpcServer)

	// Bind the port before announcing it, so the log line is only emitted
	// once the listener actually exists.
	lis, err := net.Listen("tcp", fmt.Sprintf(":%d", config.GrpcPort))
	if err != nil {
		return errors.WithStack(err)
	}
	log.Infof("Armada gRPC server listening on %d", config.GrpcPort)

	// Cancel the errgroup if grpcServer.Serve returns an error.
	services = append(services, func() error {
		return grpcServer.Serve(lis)
	})

	// Start all services and wait for the context to be cancelled,
	// which occurs when the parent context is cancelled or if any of the services returns an error.
	// We start all services at the end of the function to ensure all services are ready.
	for _, service := range services {
		g.Go(service)
	}

	startupCompleteCheck.MarkComplete()
	return g.Wait()
}
// createRedisClient builds a Redis universal client from the given options
// by delegating to redis.NewUniversalClient.
func createRedisClient(config *redis.UniversalOptions) redis.UniversalClient {
	rdb := redis.NewUniversalClient(config)
	return rdb
}
// validateSubmissionConfig checks that the configured default priority class,
// if one is set, is marked as allowed in AllowedPriorityClassNames.
// It returns nil when no default is set or the default is allowed, and a
// stack-annotated error otherwise.
func validateSubmissionConfig(config configuration.SubmissionConfig) error {
	defaultName := config.DefaultPriorityClassName

	// No default priority class configured: nothing to validate.
	if defaultName == "" {
		return nil
	}

	// The default is explicitly allowed for submission.
	if config.AllowedPriorityClassNames[defaultName] {
		return nil
	}

	return errors.WithStack(fmt.Errorf(
		"defaultPriorityClassName %s is not allowed; allowedPriorityClassNames is %v",
		defaultName, config.AllowedPriorityClassNames,
	))
}
// createApiConnection opens a gRPC client connection described by
// connectionDetails, with Prometheus client metrics enabled: the handling-time
// histogram is switched on and the unary and stream Prometheus client
// interceptors are chained into the connection. No extra call options are added.
func createApiConnection(connectionDetails client.ApiConnectionDetails) (*grpc.ClientConn, error) {
	grpc_prometheus.EnableClientHandlingTimeHistogram()

	extraCallOptions := []grpc.CallOption{}
	unaryMetrics := grpc.WithChainUnaryInterceptor(grpc_prometheus.UnaryClientInterceptor)
	streamMetrics := grpc.WithChainStreamInterceptor(grpc_prometheus.StreamClientInterceptor)

	conn, err := client.CreateApiConnectionWithCallOptions(
		&connectionDetails,
		extraCallOptions,
		unaryMetrics,
		streamMetrics,
	)
	return conn, err
}