diff --git a/boilingcore/templates.go b/boilingcore/templates.go index 1a43da5e9..6a346a7a8 100644 --- a/boilingcore/templates.go +++ b/boilingcore/templates.go @@ -312,9 +312,11 @@ var templateFunctions = template.FuncMap{ "whereClause": strmangle.WhereClause, // Alias and text helping - "aliasCols": func(ta TableAlias) func(string) string { return ta.Column }, - "usesPrimitives": usesPrimitives, - "isPrimitive": isPrimitive, + "aliasCols": func(ta TableAlias) func(string) string { return ta.Column }, + "usesPrimitives": usesPrimitives, + "isPrimitive": isPrimitive, + "isNullPrimitive": isNullPrimitive, + "convertNullToPrimitive": convertNullToPrimitive, "splitLines": func(a string) []string { if a == "" { return nil diff --git a/boilingcore/text_helpers.go b/boilingcore/text_helpers.go index 9f6a7da9c..7803f7063 100644 --- a/boilingcore/text_helpers.go +++ b/boilingcore/text_helpers.go @@ -151,3 +151,28 @@ func isPrimitive(typ string) bool { return false } + +func isNullPrimitive(typ string) bool { + switch typ { + // Numeric + case "null.Int", "null.Int8", "null.Int16", "null.Int32", "null.Int64": + return true + case "null.Uint", "null.Uint8", "null.Uint16", "null.Uint32", "null.Uint64": + return true + case "null.Float32", "null.Float64": + return true + case "null.Byte", "null.String": + return true + } + + return false +} + +// convertNullToPrimitive takes a type name and returns the underlying primitive type name X if it is a `null.X`, +// otherwise it returns the input value unchanged +func convertNullToPrimitive(typ string) string { + if isNullPrimitive(typ) { + return strings.ToLower(strings.Split(typ, ".")[1]) + } + return typ +} diff --git a/templates/main/00_struct.go.tpl b/templates/main/00_struct.go.tpl index aa8df661f..1e23debdf 100644 --- a/templates/main/00_struct.go.tpl +++ b/templates/main/00_struct.go.tpl @@ -63,15 +63,15 @@ func (w {{$name}}) LT(x {{.Type}}) qm.QueryMod { return qmhelper.Where(w.field, func (w {{$name}}) LTE(x {{.Type}}) qm.QueryMod { return qmhelper.Where(w.field, qmhelper.LTE, x) } func (w {{$name}}) GT(x {{.Type}}) qm.QueryMod { return qmhelper.Where(w.field, qmhelper.GT, x) } func (w {{$name}}) GTE(x {{.Type}}) qm.QueryMod { return qmhelper.Where(w.field, qmhelper.GTE, x) } - {{if or (isPrimitive .Type) (isEnumDBType .DBType) -}} -func (w {{$name}}) IN(slice []{{.Type}}) qm.QueryMod { + {{if or (isPrimitive .Type) (isNullPrimitive .Type) (isEnumDBType .DBType) -}} +func (w {{$name}}) IN(slice []{{convertNullToPrimitive .Type}}) qm.QueryMod { values := make([]interface{}, 0, len(slice)) for _, value := range slice { values = append(values, value) } return qm.WhereIn(fmt.Sprintf("%s IN ?", w.field), values...) } -func (w {{$name}}) NIN(slice []{{.Type}}) qm.QueryMod { +func (w {{$name}}) NIN(slice []{{convertNullToPrimitive .Type}}) qm.QueryMod { values := make([]interface{}, 0, len(slice)) for _, value := range slice { values = append(values, value)