Skip to content

Commit

Permalink
feat: Use pipes from SDK (#2003)
Browse files Browse the repository at this point in the history
* Use create pipe from SDK

* Use pipe's showByID from SDK

* Use alter pipe from SDK

* Use drop pipe from SDK

* Remove the majority of old pipe implementation

* Remove deprecated test file

* Use show pipes from SDK

* Remove rest of the old implementation

* Remove unused functions and tests

* Fix after review
  • Loading branch information
sfc-gh-asawicki committed Aug 9, 2023
1 parent d644b63 commit 079d47d
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 560 deletions.
31 changes: 16 additions & 15 deletions pkg/datasources/pipes.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package datasources

import (
"context"
"database/sql"
"errors"
"fmt"
"log"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
)

Expand Down Expand Up @@ -63,31 +63,32 @@ func Pipes() *schema.Resource {

func ReadPipes(d *schema.ResourceData, meta interface{}) error {
db := meta.(*sql.DB)
client := sdk.NewClientFromDB(db)
ctx := context.Background()

databaseName := d.Get("database").(string)
schemaName := d.Get("schema").(string)

currentPipes, err := snowflake.ListPipes(databaseName, schemaName, db)
if errors.Is(err, sql.ErrNoRows) {
// If not found, mark resource to be removed from state file during apply or refresh
log.Printf("[DEBUG] pipes in schema (%s) not found", d.Id())
d.SetId("")
return nil
} else if err != nil {
extractedPipes, err := client.Pipes.Show(ctx, &sdk.PipeShowOptions{
In: &sdk.In{
Schema: sdk.NewSchemaIdentifier(databaseName, schemaName),
},
})
if err != nil {
log.Printf("[DEBUG] unable to parse pipes in schema (%s)", d.Id())
d.SetId("")
return nil
return err
}

pipes := []map[string]interface{}{}

for _, pipe := range currentPipes {
pipeMap := map[string]interface{}{}
pipes := make([]map[string]any, 0, len(extractedPipes))
for _, pipe := range extractedPipes {
pipeMap := map[string]any{}

pipeMap["name"] = pipe.Name
pipeMap["database"] = pipe.DatabaseName
pipeMap["schema"] = pipe.SchemaName
pipeMap["comment"] = pipe.Comment
pipeMap["integration"] = pipe.Integration.String
pipeMap["integration"] = pipe.Integration

pipes = append(pipes, pipeMap)
}
Expand Down
198 changes: 70 additions & 128 deletions pkg/resources/pipe.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package resources

import (
"bytes"
"context"
"database/sql"
"encoding/csv"
"errors"
"fmt"
"log"
"strings"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
)

Expand Down Expand Up @@ -106,129 +105,67 @@ func pipeCopyStatementDiffSuppress(_, o, n string, _ *schema.ResourceData) bool
return strings.TrimRight(o, ";\r\n") == strings.TrimRight(n, ";\r\n")
}

type pipeID struct {
DatabaseName string
SchemaName string
PipeName string
}

// String() takes in a pipeID object and returns a pipe-delimited string:
// DatabaseName|SchemaName|PipeName.
func (si *pipeID) String() (string, error) {
var buf bytes.Buffer
csvWriter := csv.NewWriter(&buf)
csvWriter.Comma = pipeIDDelimiter
dataIdentifiers := [][]string{{si.DatabaseName, si.SchemaName, si.PipeName}}
if err := csvWriter.WriteAll(dataIdentifiers); err != nil {
return "", err
}
strPipeID := strings.TrimSpace(buf.String())
return strPipeID, nil
}

// pipeIDFromString() takes in a pipe-delimited string: DatabaseName|SchemaName|PipeName
// and returns a pipeID object.
func pipeIDFromString(stringID string) (*pipeID, error) {
reader := csv.NewReader(strings.NewReader(stringID))
reader.Comma = pipeIDDelimiter
lines, err := reader.ReadAll()
if err != nil {
return nil, fmt.Errorf("not CSV compatible")
}

if len(lines) != 1 {
return nil, fmt.Errorf("1 line per pipe")
}
if len(lines[0]) != 3 {
return nil, fmt.Errorf("3 fields allowed")
}

pipeResult := &pipeID{
DatabaseName: lines[0][0],
SchemaName: lines[0][1],
PipeName: lines[0][2],
}
return pipeResult, nil
}

// CreatePipe implements schema.CreateFunc.
func CreatePipe(d *schema.ResourceData, meta interface{}) error {
db := meta.(*sql.DB)
database := d.Get("database").(string)
schema := d.Get("schema").(string)
client := sdk.NewClientFromDB(db)

databaseName := d.Get("database").(string)
schemaName := d.Get("schema").(string)
name := d.Get("name").(string)

builder := snowflake.NewPipeBuilder(name, database, schema)
ctx := context.Background()
objectIdentifier := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name)

// Set optionals
if v, ok := d.GetOk("copy_statement"); ok {
builder.WithCopyStatement(v.(string))
}
opts := &sdk.PipeCreateOptions{}

copyStatement := d.Get("copy_statement").(string)

// Set optionals
if v, ok := d.GetOk("comment"); ok {
builder.WithComment(v.(string))
opts.Comment = sdk.String(v.(string))
}

if v, ok := d.GetOk("auto_ingest"); ok && v.(bool) {
builder.WithAutoIngest()
opts.AutoIngest = sdk.Bool(true)
}

if v, ok := d.GetOk("aws_sns_topic_arn"); ok {
builder.WithAwsSnsTopicArn(v.(string))
opts.AwsSnsTopic = sdk.String(v.(string))
}

if v, ok := d.GetOk("integration"); ok {
builder.WithIntegration(v.(string))
opts.Integration = sdk.String(v.(string))
}

if v, ok := d.GetOk("error_integration"); ok {
builder.WithErrorIntegration((v.(string)))
opts.ErrorIntegration = sdk.String(v.(string))
}

q := builder.Create()

if err := snowflake.Exec(db, q); err != nil {
return fmt.Errorf("error creating pipe %v err = %w", name, err)
}

pipeID := &pipeID{
DatabaseName: database,
SchemaName: schema,
PipeName: name,
}
dataIDInput, err := pipeID.String()
err := client.Pipes.Create(ctx, objectIdentifier, copyStatement, opts)
if err != nil {
return err
}
d.SetId(dataIDInput)

d.SetId(helpers.EncodeSnowflakeID(objectIdentifier))

return ReadPipe(d, meta)
}

// ReadPipe implements schema.ReadFunc.
func ReadPipe(d *schema.ResourceData, meta interface{}) error {
db := meta.(*sql.DB)
pipeID, err := pipeIDFromString(d.Id())
if err != nil {
return err
}
client := sdk.NewClientFromDB(db)
objectIdentifier := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier)

dbName := pipeID.DatabaseName
schema := pipeID.SchemaName
name := pipeID.PipeName

sq := snowflake.NewPipeBuilder(name, dbName, schema).Show()
row := snowflake.QueryRow(db, sq)
pipe, err := snowflake.ScanPipe(row)
if errors.Is(err, sql.ErrNoRows) {
ctx := context.Background()
pipe, err := client.Pipes.ShowByID(ctx, objectIdentifier)
if err != nil {
// If not found, mark resource to be removed from state file during apply or refresh
log.Printf("[DEBUG] pipe (%s) not found", d.Id())
d.SetId("")
return nil
}
if err != nil {
return err
}

if err := d.Set("name", pipe.Name); err != nil {
return err
Expand Down Expand Up @@ -258,55 +195,67 @@ func ReadPipe(d *schema.ResourceData, meta interface{}) error {
return err
}

if err := d.Set("auto_ingest", pipe.NotificationChannel != nil); err != nil {
if err := d.Set("auto_ingest", pipe.NotificationChannel != ""); err != nil {
return err
}

if pipe.NotificationChannel != nil && strings.Contains(*pipe.NotificationChannel, "arn:aws:sns:") {
if strings.Contains(pipe.NotificationChannel, "arn:aws:sns:") {
err = d.Set("aws_sns_topic_arn", pipe.NotificationChannel)
return err
}

// The "DESCRIBE PIPE ..." command returns the string "null" for error_integration
if pipe.ErrorIntegration.String == "null" {
pipe.ErrorIntegration.Valid = false
pipe.ErrorIntegration.String = ""
if err := d.Set("error_integration", pipe.ErrorIntegration); err != nil {
return err
}
err = d.Set("error_integration", pipe.ErrorIntegration.String)
return err

return nil
}

// UpdatePipe implements schema.UpdateFunc.
func UpdatePipe(d *schema.ResourceData, meta interface{}) error {
pipeID, err := pipeIDFromString(d.Id())
if err != nil {
return err
}

dbName := pipeID.DatabaseName
schema := pipeID.SchemaName
pipe := pipeID.PipeName
db := meta.(*sql.DB)
client := sdk.NewClientFromDB(db)
objectIdentifier := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier)
ctx := context.Background()

builder := snowflake.NewPipeBuilder(pipe, dbName, schema)
pipeSet := &sdk.PipeSet{}
pipeUnset := &sdk.PipeUnset{}
var runSetStatement bool
var runUnsetStatement bool

db := meta.(*sql.DB)
if d.HasChange("comment") {
comment := d.Get("comment")
q := builder.ChangeComment(comment.(string))
if err := snowflake.Exec(db, q); err != nil {
return fmt.Errorf("error updating pipe comment on %v", d.Id())
if comment, ok := d.GetOk("comment"); ok {
runSetStatement = true
pipeSet.Comment = sdk.String(comment.(string))
} else {
runUnsetStatement = true
pipeUnset.Comment = sdk.Bool(true)
}
}

if d.HasChange("error_integration") {
var q string
if errorIntegration, ok := d.GetOk("error_integration"); ok {
q = builder.ChangeErrorIntegration(errorIntegration.(string))
runSetStatement = true
pipeSet.Comment = sdk.String(errorIntegration.(string))
} else {
q = builder.RemoveErrorIntegration()
runUnsetStatement = true
pipeUnset.Comment = sdk.Bool(true)
}
if err := snowflake.Exec(db, q); err != nil {
return fmt.Errorf("error updating pipe error_integration on %v", d.Id())
}

if runSetStatement {
options := &sdk.PipeAlterOptions{Set: pipeSet}
err := client.Pipes.Alter(ctx, objectIdentifier, options)
if err != nil {
return fmt.Errorf("error updating pipe %v: %w", objectIdentifier.Name(), err)
}
}

if runUnsetStatement {
options := &sdk.PipeAlterOptions{Unset: pipeUnset}
err := client.Pipes.Alter(ctx, objectIdentifier, options)
if err != nil {
return fmt.Errorf("error updating pipe %v: %w", objectIdentifier.Name(), err)
}
}

Expand All @@ -316,22 +265,15 @@ func UpdatePipe(d *schema.ResourceData, meta interface{}) error {
// DeletePipe implements schema.DeleteFunc.
func DeletePipe(d *schema.ResourceData, meta interface{}) error {
db := meta.(*sql.DB)
pipeID, err := pipeIDFromString(d.Id())
client := sdk.NewClientFromDB(db)
ctx := context.Background()
objectIdentifier := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier)

err := client.Pipes.Drop(ctx, objectIdentifier)
if err != nil {
return err
}

dbName := pipeID.DatabaseName
schema := pipeID.SchemaName
pipe := pipeID.PipeName

q := snowflake.NewPipeBuilder(pipe, dbName, schema).Drop()

if err := snowflake.Exec(db, q); err != nil {
return fmt.Errorf("error deleting pipe %v err = %w", d.Id(), err)
}

d.SetId("")

return nil
}

0 comments on commit 079d47d

Please sign in to comment.