Skip to content

Commit

Permalink
[fix] escape some strings in query builders (#231)
Browse files Browse the repository at this point in the history
Escape comments and other string fields that are not currently escaped in our query builders.

## Test Plan
<!-- detail ways in which this PR has been tested or needs to be tested -->
* [ ] acceptance tests

## References
* Fixes #127
  • Loading branch information
ryanking committed Aug 3, 2020
1 parent 0902b36 commit 7edd15b
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 19 deletions.
6 changes: 6 additions & 0 deletions pkg/snowflake/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ func TestDatabase(t *testing.T) {
c.SetBool("bam", false)
q = c.Statement()
r.Equal(`CREATE DATABASE "db1" FOO='bar' BAM=false`, q)

// test escaping
c2 := db.Create()
c2.SetString("foo", "ba'r")
q = c2.Statement()
r.Equal(`CREATE DATABASE "db1" FOO='ba\'r'`, q)
}

func TestDatabaseCreateFromShare(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions pkg/snowflake/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (sb *SchemaBuilder) Create() string {
}

if sb.comment != "" {
q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, sb.comment))
q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, EscapeString(sb.comment)))
}

return q.String()
Expand All @@ -120,7 +120,7 @@ func (sb *SchemaBuilder) Swap(targetSchema string) string {

// ChangeComment returns the SQL query that will update the comment on the schema.
func (sb *SchemaBuilder) ChangeComment(c string) string {
return fmt.Sprintf(`ALTER SCHEMA %v SET COMMENT = '%v'`, sb.QualifiedName(), c)
return fmt.Sprintf(`ALTER SCHEMA %v SET COMMENT = '%v'`, sb.QualifiedName(), EscapeString(c))
}

// RemoveComment returns the SQL query that will remove the comment on the schema.
Expand Down
6 changes: 3 additions & 3 deletions pkg/snowflake/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ func TestSchemaCreate(t *testing.T) {
s.WithDataRetentionDays(7)
r.Equal(s.Create(), `CREATE TRANSIENT SCHEMA "db"."test" WITH MANAGED ACCESS DATA_RETENTION_TIME_IN_DAYS = 7`)

s.WithComment("Yeehaw")
r.Equal(s.Create(), `CREATE TRANSIENT SCHEMA "db"."test" WITH MANAGED ACCESS DATA_RETENTION_TIME_IN_DAYS = 7 COMMENT = 'Yeehaw'`)
s.WithComment("Yee'haw")
r.Equal(`CREATE TRANSIENT SCHEMA "db"."test" WITH MANAGED ACCESS DATA_RETENTION_TIME_IN_DAYS = 7 COMMENT = 'Yee\'haw'`, s.Create())
}

func TestSchemaRename(t *testing.T) {
Expand All @@ -44,7 +44,7 @@ func TestSchemaSwap(t *testing.T) {
func TestSchemaChangeComment(t *testing.T) {
r := require.New(t)
s := Schema("test")
r.Equal(s.ChangeComment("worst schema ever"), `ALTER SCHEMA "test" SET COMMENT = 'worst schema ever'`)
r.Equal(`ALTER SCHEMA "test" SET COMMENT = 'worst\' schema ever'`, s.ChangeComment("worst' schema ever"))
}

func TestSchemaRemoveComment(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/snowflake/stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (sb *StageBuilder) Create() string {
}

if sb.comment != "" {
q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, sb.comment))
q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, EscapeString(sb.comment)))
}

return q.String()
Expand Down
6 changes: 3 additions & 3 deletions pkg/snowflake/stage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ func TestStageCreate(t *testing.T) {
s.WithCopyOptions("on_error='skip_file'")
r.Equal(s.Create(), `CREATE STAGE "test_db"."test_schema"."test_stage" URL = 's3://load/encrypted_files/' CREDENTIALS = (aws_role='arn:aws:iam::001234567890:role/mysnowflakerole') ENCRYPTION = (type='AWS_SSE_KMS' kms_key_id = 'aws/key') FILE_FORMAT = (format_name=my_csv_format) COPY_OPTIONS = (on_error='skip_file')`)

s.WithComment("Yeehaw")
r.Equal(s.Create(), `CREATE STAGE "test_db"."test_schema"."test_stage" URL = 's3://load/encrypted_files/' CREDENTIALS = (aws_role='arn:aws:iam::001234567890:role/mysnowflakerole') ENCRYPTION = (type='AWS_SSE_KMS' kms_key_id = 'aws/key') FILE_FORMAT = (format_name=my_csv_format) COPY_OPTIONS = (on_error='skip_file') COMMENT = 'Yeehaw'`)
s.WithComment("Yee'haw")
r.Equal(`CREATE STAGE "test_db"."test_schema"."test_stage" URL = 's3://load/encrypted_files/' CREDENTIALS = (aws_role='arn:aws:iam::001234567890:role/mysnowflakerole') ENCRYPTION = (type='AWS_SSE_KMS' kms_key_id = 'aws/key') FILE_FORMAT = (format_name=my_csv_format) COPY_OPTIONS = (on_error='skip_file') COMMENT = 'Yee\'haw'`, s.Create())

s.WithStorageIntegration("MY_INTEGRATION")
r.Equal(s.Create(), `CREATE STAGE "test_db"."test_schema"."test_stage" URL = 's3://load/encrypted_files/' CREDENTIALS = (aws_role='arn:aws:iam::001234567890:role/mysnowflakerole') STORAGE_INTEGRATION = MY_INTEGRATION ENCRYPTION = (type='AWS_SSE_KMS' kms_key_id = 'aws/key') FILE_FORMAT = (format_name=my_csv_format) COPY_OPTIONS = (on_error='skip_file') COMMENT = 'Yeehaw'`)
r.Equal(`CREATE STAGE "test_db"."test_schema"."test_stage" URL = 's3://load/encrypted_files/' CREDENTIALS = (aws_role='arn:aws:iam::001234567890:role/mysnowflakerole') STORAGE_INTEGRATION = MY_INTEGRATION ENCRYPTION = (type='AWS_SSE_KMS' kms_key_id = 'aws/key') FILE_FORMAT = (format_name=my_csv_format) COPY_OPTIONS = (on_error='skip_file') COMMENT = 'Yee\'haw'`, s.Create())
}

func TestStageRename(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions pkg/snowflake/validation_test.go
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
package snowflake_test

// TODO write tests here
4 changes: 2 additions & 2 deletions pkg/snowflake/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (vb *ViewBuilder) Create() string {
q.WriteString(fmt.Sprintf(` VIEW %v`, vb.QualifiedName()))

if vb.comment != "" {
q.WriteString(fmt.Sprintf(" COMMENT = '%v'", vb.comment))
q.WriteString(fmt.Sprintf(" COMMENT = '%v'", EscapeString(vb.comment)))
}

q.WriteString(fmt.Sprintf(" AS %v", vb.statement))
Expand Down Expand Up @@ -139,7 +139,7 @@ func (vb *ViewBuilder) Unsecure() string {
// Note that comment is the only parameter, if more are released this should be
// abstracted as per the generic builder.
func (vb *ViewBuilder) ChangeComment(c string) string {
return fmt.Sprintf(`ALTER VIEW %v SET COMMENT = '%v'`, vb.QualifiedName(), c)
return fmt.Sprintf(`ALTER VIEW %v SET COMMENT = '%v'`, vb.QualifiedName(), EscapeString(c))
}

// RemoveComment returns the SQL query that will remove the comment on the view.
Expand Down
14 changes: 6 additions & 8 deletions pkg/snowflake/view_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,23 @@ func TestView(t *testing.T) {
v.WithSecure()
r.True(v.secure)

v.WithComment("great comment")
r.Equal("great comment", v.comment)

v.WithComment("great' comment")
v.WithStatement("SELECT * FROM DUMMY LIMIT 1")
r.Equal("SELECT * FROM DUMMY LIMIT 1", v.statement)

v.WithStatement("SELECT * FROM DUMMY WHERE blah = 'blahblah' LIMIT 1")

q := v.Create()
r.Equal(`CREATE SECURE VIEW "test" COMMENT = 'great comment' AS SELECT * FROM DUMMY WHERE blah = 'blahblah' LIMIT 1`, q)
r.Equal(`CREATE SECURE VIEW "test" COMMENT = 'great\' comment' AS SELECT * FROM DUMMY WHERE blah = 'blahblah' LIMIT 1`, q)

q = v.Secure()
r.Equal(`ALTER VIEW "test" SET SECURE`, q)

q = v.Unsecure()
r.Equal(`ALTER VIEW "test" UNSET SECURE`, q)

q = v.ChangeComment("bad comment")
r.Equal(`ALTER VIEW "test" SET COMMENT = 'bad comment'`, q)
q = v.ChangeComment("bad' comment")
r.Equal(`ALTER VIEW "test" SET COMMENT = 'bad\' comment'`, q)

q = v.RemoveComment()
r.Equal(`ALTER VIEW "test" UNSET COMMENT`, q)
Expand All @@ -49,7 +47,7 @@ func TestView(t *testing.T) {
r.Equal(v.QualifiedName(), `"mydb".."test"`)

q = v.Create()
r.Equal(`CREATE SECURE VIEW "mydb".."test" COMMENT = 'great comment' AS SELECT * FROM DUMMY WHERE blah = 'blahblah' LIMIT 1`, q)
r.Equal(`CREATE SECURE VIEW "mydb".."test" COMMENT = 'great\' comment' AS SELECT * FROM DUMMY WHERE blah = 'blahblah' LIMIT 1`, q)

q = v.Secure()
r.Equal(`ALTER VIEW "mydb".."test" SET SECURE`, q)
Expand All @@ -58,7 +56,7 @@ func TestView(t *testing.T) {
r.Equal(`SHOW VIEWS LIKE 'test' IN DATABASE "mydb"`, q)

q = v.Drop()
r.Equal(`DROP VIEW "mydb".."test"`, q)
r.Equal(`DROP VIEW "mydb".."test"`, q) // FIXME invalid query
}

func TestQualifiedName(t *testing.T) {
Expand Down

0 comments on commit 7edd15b

Please sign in to comment.