diff --git a/query/query.go b/query/query.go index 6b6fd98..2ae3a9e 100644 --- a/query/query.go +++ b/query/query.go @@ -171,8 +171,8 @@ type Where struct { } // EQ adds an equality condition to the WHERE clause. -func (w *Where) EQ(fieldName string, value string) *Where { - if isEmpty(value) || isEmpty(fieldName) { +func (w *Where) EQ(fieldName string, value any) *Where { + if isEmpty(fieldName) { return w } @@ -181,8 +181,8 @@ func (w *Where) EQ(fieldName string, value string) *Where { } // LIKE adds a LIKE condition to the WHERE clause. -func (w *Where) LIKE(fieldName string, value string) *Where { - if isEmpty(value) || isEmpty(fieldName) { +func (w *Where) LIKE(fieldName string, value any) *Where { + if isEmpty(fieldName) { return w } w.chunks = append(w.chunks, LIKE(fieldName, value)) @@ -207,6 +207,15 @@ func (w *Where) AND(args ...string) *Where { return w } +func (w *Where) IS(fieldName string, value any) *Where { + if isEmpty(fieldName) { + return w + } + + w.chunks = append(w.chunks, IS(fieldName, value)) + return w +} + // AND creates an AND condition string from the provided arguments. func AND(args ...string) string { if len(args) < 2 { @@ -216,13 +225,36 @@ func AND(args ...string) string { } // EQ creates an equality condition string. -func EQ(fieldName string, value string) string { - return fmt.Sprintf("%s='%s'", fieldName, value) +func EQ(fieldName string, value any) string { + return buildComparison(fieldName, value, "=") } // LIKE creates a LIKE condition string. -func LIKE(fieldName string, value string) string { - return fmt.Sprintf("%s LIKE '%s'", fieldName, value) +func LIKE(fieldName string, value any) string { + return buildComparison(fieldName, value, "LIKE") +} + +func buildComparison(fieldName string, value any, comparison string) string { + if isEmpty(fieldName) { + panic("fieldName cannot be empty") + } + switch value.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("%s %s %d", fieldName, comparison, value) + case float32, float64: + return fmt.Sprintf("%s %s %v", fieldName, comparison, value) + case string: + if value == "NULL" || value == "null" { + return fmt.Sprintf("%s %s NULL", fieldName, comparison) + } + return fmt.Sprintf("%s %s '%s'", fieldName, comparison, value) + case bool: + return fmt.Sprintf("%s %s %t", fieldName, comparison, value) + case nil: + return fmt.Sprintf("%s %s NULL", fieldName, comparison) + default: + panic("not supported type. Supports only numbers, string, bool") + } } // OR creates an OR condition string from the provided arguments. @@ -233,6 +265,11 @@ func OR(args ...string) string { return fmt.Sprintf("(%s)", strings.Join(args, " OR ")) } +// IS creates an IS condition string from the provided arguments. +func IS(fieldName string, value any) string { + return buildComparison(fieldName, value, "IS") +} + // GroupBy adds a GROUP BY clause to the query func (q *Query) GroupBy(columns ...string) *Query { if len(columns) == 0 { diff --git a/query/query_test.go b/query/query_test.go index 7970250..194f310 100644 --- a/query/query_test.go +++ b/query/query_test.go @@ -67,7 +67,7 @@ var tests = []testItem{ query.Where().EQ("name", "testname") return query.String() }, - expected: "SELECT * FROM test_table WHERE name='testname'", + expected: "SELECT * FROM test_table WHERE name = 'testname'", }, { name: "where LIKE", @@ -85,7 +85,7 @@ var tests = []testItem{ query.Where().OR(EQ("name", "testname"), EQ("age", "12")) return query.String() }, - expected: "SELECT * FROM test_table WHERE (name='testname' OR age='12')", + expected: "SELECT * FROM test_table WHERE (name = 'testname' OR age = '12')", }, { name: "where OR AND", @@ -94,7 +94,7 @@ var tests = []testItem{ query.Where().OR(EQ("name", "testname"), EQ("age", "12")).AND(EQ("id", "123"), EQ("email", "test@mail.com")) return query.String() }, - expected: "SELECT * FROM test_table WHERE (name='testname' OR age='12') AND (id='123' AND email='test@mail.com')", + expected: "SELECT * FROM test_table WHERE (name = 'testname' OR age = '12') AND (id = '123' AND email = 'test@mail.com')", }, { name: "not valid ORDER", @@ -205,7 +205,7 @@ var tests = []testItem{ query.Where().OR(EQ("name", "testname"), AND(EQ("age", "10"), EQ("city", "New York"))) return query.String() }, - expected: "SELECT * FROM test_table WHERE (name='testname' OR (age='10' AND city='New York'))", + expected: "SELECT * FROM test_table WHERE (name = 'testname' OR (age = '10' AND city = 'New York'))", }, { name: "nested AND OR", @@ -214,7 +214,7 @@ var tests = []testItem{ query.Where().AND(EQ("name", "testname"), OR(EQ("age", "10"), EQ("city", "New York"))) return query.String() }, - expected: "SELECT * FROM test_table WHERE (name='testname' AND (age='10' OR city='New York'))", + expected: "SELECT * FROM test_table WHERE (name = 'testname' AND (age = '10' OR city = 'New York'))", }, { name: "inner join", @@ -336,7 +336,7 @@ var tests = []testItem{ query.Where().EQ("name", "John") return query.String() }, - expected: "SELECT * FROM users WHERE name='John' GROUP BY name HAVING COUNT(id) > 10", + expected: "SELECT * FROM users WHERE name = 'John' GROUP BY name HAVING COUNT(id) > 10", }, { // HAVING after GROUP BY @@ -355,6 +355,51 @@ var tests = []testItem{ }, expected: "SELECT * FROM users", }, + { + name: "using IS NULL", + callback: func(t *testing.T) string { + query := New("SELECT * FROM users") + query.Where().IS("name", nil) + return query.String() + }, + expected: "SELECT * FROM users WHERE name IS NULL", + }, + { + name: "using IS NULL and OR", + callback: func(t *testing.T) string { + query := New("SELECT * FROM users") + query.Where().OR(IS("name", nil), EQ("name", "John")) + return query.String() + }, + expected: "SELECT * FROM users WHERE (name IS NULL OR name = 'John')", + }, + { + name: "using EQ with int type", + callback: func(t *testing.T) string { + query := New("SELECT * FROM users") + query.Where().EQ("age", 12) + return query.String() + }, + expected: "SELECT * FROM users WHERE age = 12", + }, + { + name: "using EQ with float64 type", + callback: func(t *testing.T) string { + query := New("SELECT * FROM users") + query.Where().EQ("age", 12.5) + return query.String() + }, + expected: "SELECT * FROM users WHERE age = 12.5", + }, + { + name: "using EQ with bool type", + callback: func(t *testing.T) string { + query := New("SELECT * FROM users") + query.Where().EQ("blocked", true) + return query.String() + }, + expected: "SELECT * FROM users WHERE blocked = true", + }, } func TestQuery(t *testing.T) { @@ -375,5 +420,5 @@ func ExampleQuery() { AND(EQ("id", "123"), EQ("email", "test@mail.com")). LIKE("name", "%testname") fmt.Println(query.String()) - // Output: SELECT * FROM test_table LEFT JOIN test_posts ON test_table.id=test_posts.user_id WHERE (name='testname' OR age='12') AND (id='123' AND email='test@mail.com') AND name LIKE '%testname' ORDER BY name DESC LIMIT 5 OFFSET 1 + // Output: SELECT * FROM test_table LEFT JOIN test_posts ON test_table.id=test_posts.user_id WHERE (name = 'testname' OR age = '12') AND (id = '123' AND email = 'test@mail.com') AND name LIKE '%testname' ORDER BY name DESC LIMIT 5 OFFSET 1 }