diff --git a/internal/tjson/schema.go b/internal/tjson/schema.go index 93332dc5f7ba..bac61397fa0d 100644 --- a/internal/tjson/schema.go +++ b/internal/tjson/schema.go @@ -18,7 +18,6 @@ import ( "bytes" "encoding/json" "fmt" - "reflect" "time" "github.com/FerretDB/FerretDB/internal/types" @@ -123,15 +122,70 @@ var ( ) // Equal returns true if the schemas are equal. +// For composite types schemas are equal if their types and subschemas are equal. +// For scalar types schemas are equal if their types and formats are equal. func (s *Schema) Equal(other *Schema) bool { if s == other { return true } - // TODO compare significant fields only (ignore title, description, etc.) - // TODO compare format according to type (for example, for Number, EmptyFormat == Double) - // https://github.com/FerretDB/FerretDB/issues/683 - return reflect.DeepEqual(s, other) + if s.Type != other.Type { + return false + } + + switch s.Type { + case Object: + // If `s` and `other` are objects, compare their properties. + if len(s.Properties) != len(other.Properties) { + return false + } + for k, v := range s.Properties { + vOther, ok := other.Properties[k] + if !ok { + return false + } + if eq := v.Equal(vOther); !eq { + return false + } + } + return true + case Array: + // If `s` and `other` are arrays, compare their items. + if s.Items == nil || other.Items == nil { + panic("schema.Equal: array with nil items") + } + return s.Items.Equal(other.Items) + case String, Integer, Number, Boolean: + // For scalar types, it's enough to compare their formats. + if s.Format == other.Format { + return true + } + default: + panic(fmt.Sprintf("schema.Equal: unknown type `%s`", s.Type)) + } + + // If formats don't match, normalize schemas: empty format is equal to double for numbers and int64 for integers, + // see https://docs.tigrisdata.com/overview/schema#data-types. + formatS, formatOther := s.Format, other.Format + switch s.Type { + case Number: + if s.Format == EmptyFormat { + formatS = Double + } + if other.Format == EmptyFormat { + formatOther = Double + } + case Integer: + if s.Format == EmptyFormat { + formatS = Int64 + } + if other.Format == EmptyFormat { + formatOther = Int64 + } + case Array, Boolean, Object, String: + // do nothing: these types don't have "default" format + } + return formatS == formatOther } // Marshal returns the JSON encoding of the schema. diff --git a/internal/tjson/schema_test.go b/internal/tjson/schema_test.go index 0d1a0f9b85ad..a144d911d2b2 100644 --- a/internal/tjson/schema_test.go +++ b/internal/tjson/schema_test.go @@ -67,3 +67,169 @@ func TestSchemaMarshalUnmarshal(t *testing.T) { assert.Equal(t, expected, actual) } + +func TestSchemaEqual(t *testing.T) { + t.Parallel() + + caseInt64Schema := Schema{ + Type: Integer, + Format: Int64, + } + caseIntEmptySchema := Schema{ + Type: Integer, + Format: EmptyFormat, + } + caseDoubleSchema := Schema{ + Type: Number, + Format: Double, + } + caseDoubleEmptySchema := Schema{ + Type: Number, + Format: EmptyFormat, + } + caseObjectSchema := Schema{ + Type: Object, + Properties: map[string]*Schema{ + "a": stringSchema, + "42": &caseIntEmptySchema, + }, + } + caseObjectSchemaEqual := Schema{ + Type: Object, + Properties: map[string]*Schema{ + "42": &caseIntEmptySchema, + "a": stringSchema, + }, + } + caseObjectSchemaNotEqual := Schema{ + Type: Object, + Properties: map[string]*Schema{ + "42": &caseIntEmptySchema, + "a": boolSchema, + }, + } + caseObjectSchemaKeyMissing := Schema{ + Type: Object, + Properties: map[string]*Schema{ + "42": &caseIntEmptySchema, + "b": stringSchema, + }, + } + caseObjectSchemaEmpty := Schema{ + Type: Object, + Properties: map[string]*Schema{}, + } + caseArrayDoubleSchema := Schema{ + Type: Array, + Items: &caseDoubleSchema, + } + caseArrayDoubleEmptySchema := Schema{ + Type: Array, + Items: &caseDoubleEmptySchema, + } + caseArrayObjectsSchema := Schema{ + Type: Array, + Items: &caseObjectSchema, + } + caseArrayObjectsSchemaEqual := Schema{ + Type: Array, + Items: &caseObjectSchemaEqual, + } + caseArrayObjectsSchemaNotEqual := Schema{ + Type: Array, + Items: &caseObjectSchemaNotEqual, + } + + for name, tc := range map[string]struct { + s *Schema + other *Schema + expected bool + }{ + "StringString": { + s: stringSchema, + other: stringSchema, + expected: true, + }, + "StringNumber": { + s: stringSchema, + other: doubleSchema, + expected: false, + }, + "NumberString": { + s: doubleSchema, + other: stringSchema, + expected: false, + }, + "EmptyInt64": { + s: &caseIntEmptySchema, + other: &caseInt64Schema, + expected: true, + }, + "Int64Empty": { + s: &caseInt64Schema, + other: &caseIntEmptySchema, + expected: true, + }, + "Int64Int32": { + s: &caseInt64Schema, + other: int32Schema, + expected: false, + }, + "EmptyInt32": { + s: &caseIntEmptySchema, + other: int32Schema, + expected: false, + }, + "DoubleEmpty": { + s: &caseDoubleSchema, + other: &caseDoubleEmptySchema, + expected: true, + }, + "ObjectsEqual": { + s: &caseObjectSchema, + other: &caseObjectSchemaEqual, + expected: true, + }, + "ObjectsNotEqual": { + s: &caseObjectSchemaEqual, + other: &caseObjectSchemaNotEqual, + expected: false, + }, + "ObjectsKeyMissing": { + s: &caseObjectSchema, + other: &caseObjectSchemaKeyMissing, + expected: false, + }, + "ObjectsEmpty": { + s: &caseObjectSchema, + other: &caseObjectSchemaEmpty, + expected: false, + }, + "ArrayDouble": { + s: &caseArrayDoubleSchema, + other: &caseArrayDoubleEmptySchema, + expected: true, + }, + "ArrayObjects": { + s: &caseArrayObjectsSchema, + other: &caseArrayObjectsSchemaEqual, + expected: true, + }, + "ArrayObjectsNotEqual": { + s: &caseArrayObjectsSchemaNotEqual, + other: &caseArrayObjectsSchemaEqual, + expected: false, + }, + "ArrayObjectsDouble": { + s: &caseArrayObjectsSchema, + other: &caseArrayDoubleSchema, + expected: false, + }, + } { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expected, tc.s.Equal(tc.other)) + }) + } +}