/
schema.go
175 lines (145 loc) · 5.25 KB
/
schema.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
package v1
import (
"context"
"errors"
"strings"
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/validator"
"github.com/rs/zerolog/log"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/authzed/spicedb/internal/datastore"
"github.com/authzed/spicedb/internal/namespace"
"github.com/authzed/spicedb/internal/services/serviceerrors"
"github.com/authzed/spicedb/internal/services/shared"
"github.com/authzed/spicedb/internal/sharederrors"
"github.com/authzed/spicedb/pkg/schemadsl/compiler"
"github.com/authzed/spicedb/pkg/schemadsl/generator"
"github.com/authzed/spicedb/pkg/schemadsl/input"
)
// NewSchemaServer builds the v1.SchemaServiceServer implementation backed by
// the supplied datastore.
//
// The returned server is configured with the unary and stream interceptors
// from the grpcvalidate (go-grpc-middleware validator) package as its
// service-specific interceptors.
func NewSchemaServer(ds datastore.Datastore) v1.SchemaServiceServer {
	interceptors := shared.WithServiceSpecificInterceptors{
		Unary:  grpcvalidate.UnaryServerInterceptor(),
		Stream: grpcvalidate.StreamServerInterceptor(),
	}

	server := &schemaServer{
		WithServiceSpecificInterceptors: interceptors,
		ds:                              ds,
	}
	return server
}
// schemaServer is the concrete type returned by NewSchemaServer as a
// v1.SchemaServiceServer. It serves ReadSchema and WriteSchema using the
// datastore held in ds.
type schemaServer struct {
// Embedded generated server type; presumably supplies default
// implementations for any RPCs this type does not define itself
// (NOTE(review): confirm against the generated authzed-go code).
v1.UnimplementedSchemaServiceServer
// Embedded holder for the unary/stream interceptors that NewSchemaServer
// populates for this service.
shared.WithServiceSpecificInterceptors
// ds is the datastore that every schema read, write and delete in this
// server is issued against.
ds datastore.Datastore
}
// ReadSchema returns the schema source for all namespace definitions found at
// the datastore's head revision.
//
// Each definition is rendered with generator.GenerateSource and the rendered
// pieces are joined with a blank line between them. When the datastore holds
// no definitions at all, a codes.NotFound status error is returned instead of
// an empty schema. A head-revision failure is passed through
// rewritePermissionsError; a listing failure through rewriteSchemaError.
func (ss *schemaServer) ReadSchema(ctx context.Context, in *v1.ReadSchemaRequest) (*v1.ReadSchemaResponse, error) {
	headRev, err := ss.ds.HeadRevision(ctx)
	if err != nil {
		return nil, rewritePermissionsError(ctx, err)
	}

	definitions, err := ss.ds.ListNamespaces(ctx, headRev)
	switch {
	case err != nil:
		return nil, rewriteSchemaError(ctx, err)
	case len(definitions) == 0:
		return nil, status.Errorf(codes.NotFound, "No schema has been defined; please call WriteSchema to start")
	}

	sources := make([]string, 0, len(definitions))
	for _, definition := range definitions {
		// The second result of GenerateSource is discarded here.
		// NOTE(review): confirm it does not signal a condition the caller
		// should be told about.
		source, _ := generator.GenerateSource(definition)
		sources = append(sources, source)
	}

	return &v1.ReadSchemaResponse{SchemaText: strings.Join(sources, "\n\n")}, nil
}
// WriteSchema compiles the schema text carried by the request and replaces
// the stored namespace definitions with the compiled result.
//
// The work happens in this order:
//  1. Read the head revision and list the definitions stored at it.
//  2. Compile the request's schema into namespace definitions.
//  3. For every compiled definition, build and validate its type system and
//     run shared.SanityCheckExistingRelationships against the head revision.
//  4. For every stored definition that the new schema no longer contains,
//     run shared.EnsureNoRelationshipsExist.
//  5. Write every compiled definition, then delete every definition that the
//     new schema no longer contains.
//
// The first failure aborts the call and is returned via rewriteSchemaError
// (rewritePermissionsError for the head-revision lookup). Writes and deletes
// are issued one datastore call at a time, so a failure during step 5 returns
// after the earlier calls of that step have already been made.
func (ss *schemaServer) WriteSchema(ctx context.Context, in *v1.WriteSchemaRequest) (*v1.WriteSchemaResponse, error) {
	logger := log.Ctx(ctx)
	logger.Trace().Str("schema", in.GetSchema()).Msg("requested Schema to be written")

	headRev, err := ss.ds.HeadRevision(ctx)
	if err != nil {
		return nil, rewritePermissionsError(ctx, err)
	}

	// Start with the name of every stored definition. Names are struck from
	// this set as the new schema claims them, so whatever remains afterwards
	// is exactly the set of definitions being removed.
	stored, err := ss.ds.ListNamespaces(ctx, headRev)
	if err != nil {
		return nil, rewriteSchemaError(ctx, err)
	}
	pendingRemoval := make(map[string]struct{}, len(stored))
	for _, def := range stored {
		pendingRemoval[def.Name] = struct{}{}
	}

	// Compile the incoming schema text, using an empty default prefix.
	noPrefix := ""
	compiled, err := compiler.Compile(
		[]compiler.InputSchema{{
			Source:       input.InputSource("schema"),
			SchemaString: in.GetSchema(),
		}},
		&noPrefix,
	)
	if err != nil {
		return nil, rewriteSchemaError(ctx, err)
	}
	logger.Trace().Interface("namespace definitions", compiled).Msg("compiled namespace definitions")

	// Validate each compiled definition and make sure the change would not
	// leave existing relationships without associated schema.
	for _, def := range compiled {
		typeSystem, err := namespace.BuildNamespaceTypeSystemForDefs(def, compiled)
		if err != nil {
			return nil, rewriteSchemaError(ctx, err)
		}
		if err := typeSystem.Validate(ctx); err != nil {
			return nil, rewriteSchemaError(ctx, err)
		}
		if err := shared.SanityCheckExistingRelationships(ctx, ss.ds, def, headRev); err != nil {
			return nil, rewriteSchemaError(ctx, err)
		}

		// Still present in the new schema, so it is not a removal.
		delete(pendingRemoval, def.Name)
	}
	logger.Trace().Interface("namespace definitions", compiled).Msg("validated namespace definitions")

	// A definition may only be dropped when no relationships refer to it.
	for name := range pendingRemoval {
		if err := shared.EnsureNoRelationshipsExist(ctx, ss.ds, name); err != nil {
			return nil, rewriteSchemaError(ctx, err)
		}
	}

	// All checks passed: persist the new and changed definitions.
	var written []string
	for _, def := range compiled {
		if _, err := ss.ds.WriteNamespace(ctx, def); err != nil {
			return nil, rewriteSchemaError(ctx, err)
		}
		written = append(written, def.Name)
	}

	// Then drop the definitions that the new schema no longer contains.
	var deleted []string
	for name := range pendingRemoval {
		if _, err := ss.ds.DeleteNamespace(ctx, name); err != nil {
			return nil, rewriteSchemaError(ctx, err)
		}
		deleted = append(deleted, name)
	}

	logger.Trace().
		Interface("namespace definitions", compiled).
		Strs("added/changed", written).
		Strs("removed", deleted).
		Msg("wrote namespace definitions")

	return &v1.WriteSchemaResponse{}, nil
}
// rewriteSchemaError maps an error raised while reading or writing schema to
// the error that should be handed back to the gRPC caller.
//
// Mapping, checked in this order with errors.As:
//   - sharederrors.UnknownNamespaceError -> codes.NotFound status naming the
//     missing object definition
//   - compiler.ErrorWithContext          -> codes.InvalidArgument status
//     carrying the error text
//   - datastore.ErrReadOnly              -> serviceerrors.ErrServiceReadOnly
//   - anything else                      -> logged, then returned unchanged
func rewriteSchemaError(ctx context.Context, err error) error {
	var nsNotFoundError sharederrors.UnknownNamespaceError
	var errWithContext compiler.ErrorWithContext

	switch {
	case errors.As(err, &nsNotFoundError):
		return status.Errorf(codes.NotFound, "Object Definition `%s` not found", nsNotFoundError.NotFoundNamespaceName())
	case errors.As(err, &errWithContext):
		return status.Errorf(codes.InvalidArgument, "%s", err)
	case errors.As(err, &datastore.ErrReadOnly{}):
		return serviceerrors.ErrServiceReadOnly
	default:
		// A zerolog event is only written once Msg/Msgf/Send is called on
		// it. The bare Err(err) call built the event and then dropped it, so
		// unexpected errors were never logged.
		log.Ctx(ctx).Err(err).Msg("received unexpected error")
		return err
	}
}