Draft for snowflake avro sync method impl
iskakaushik committed Jun 11, 2023
1 parent dc9c444 commit e122a09
Showing 3 changed files with 331 additions and 2 deletions.
127 changes: 125 additions & 2 deletions flow/connectors/snowflake/qrep.go
@@ -1,10 +1,19 @@
package connsnowflake

import (
"database/sql"
"fmt"
"os"
"time"

"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protojson"
)

const qRepMetadataTableName = "_peerdb_query_replication_metadata"

func (c *SnowflakeConnector) GetQRepPartitions(config *protos.QRepConfig,
last *protos.QRepPartition,
) ([]*protos.QRepPartition, error) {
@@ -22,9 +31,123 @@ func (c *SnowflakeConnector) SyncQRepRecords(
partition *protos.QRepPartition,
records *model.QRecordBatch,
) (int, error) {
panic("not implemented")
// Ensure the destination table is available.
destTable := config.DestinationTableIdentifier

tblSchema, err := c.getTableSchema(destTable)
if err != nil {
return 0, fmt.Errorf("failed to get schema of table %s: %w", destTable, err)
}

done, err := c.isPartitionSynced(partition.PartitionId)
if err != nil {
return 0, fmt.Errorf("failed to check if partition %s is synced: %w", partition.PartitionId, err)
}

if done {
log.Infof("Partition %s has already been synced", partition.PartitionId)
return 0, nil
}

syncMode := config.SyncMode
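// Dispatch on the sync mode; only Avro file staging is implemented for Snowflake.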
switch syncMode {
case protos.QRepSyncMode_QREP_SYNC_MODE_MULTI_INSERT:
return 0, fmt.Errorf("multi-insert sync mode not supported for snowflake")
case protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO:
// Create a temp directory for staging Avro files locally; remove it once the sync completes.
tmpDir, err := os.MkdirTemp("", "peerdb-avro")
if err != nil {
return 0, fmt.Errorf("failed to create temp directory: %w", err)
}
defer os.RemoveAll(tmpDir)

avroSync := NewSnowflakeAvroSyncMethod(c, tmpDir)
return avroSync.SyncQRepRecords(config.FlowJobName, destTable, partition, tblSchema, records)
default:
return 0, fmt.Errorf("unsupported sync mode: %s", syncMode)
}
}

func (c *SnowflakeConnector) createMetadataInsertStatement(
partition *protos.QRepPartition,
jobName string,
startTime time.Time,
) (string, error) {
// marshal the partition to json using protojson
pbytes, err := protojson.Marshal(partition)
if err != nil {
return "", fmt.Errorf("failed to marshal partition to json: %v", err)
}

// convert the bytes to string
partitionJSON := string(pbytes)

insertMetadataStmt := fmt.Sprintf(
`INSERT INTO "%s"."%s"
(flowJobName, partitionID, syncPartition, syncStartTime, syncFinishTime)
VALUES ('%s', '%s', PARSE_JSON('%s'), '%s'::timestamp, CURRENT_TIMESTAMP);`,
peerDBInternalSchema, qRepMetadataTableName, jobName, partition.PartitionId,
partitionJSON, startTime.Format(time.RFC3339))
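// For illustration, with job "my_flow" and partition "p1" this renders roughly as
// (schema name depending on peerDBInternalSchema):
// INSERT INTO "<internal_schema>"."_peerdb_query_replication_metadata"
// (flowJobName, partitionID, syncPartition, syncStartTime, syncFinishTime)
// VALUES ('my_flow', 'p1', PARSE_JSON('{"partitionId":"p1"}'), '2023-06-11T00:00:00Z'::timestamp, CURRENT_TIMESTAMP);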

return insertMetadataStmt, nil
}

func (c *SnowflakeConnector) getTableSchema(tableName string) ([]*sql.ColumnType, error) {
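// A LIMIT 0 query returns no rows but still exposes the result set's column type metadata.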
//nolint:gosec
queryString := fmt.Sprintf(`
SELECT *
FROM %s
LIMIT 0
`, tableName)

rows, err := c.database.Query(queryString)
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
}
defer rows.Close()

columnTypes, err := rows.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("failed to get column types: %w", err)
}

return columnTypes, nil
}

func (c *SnowflakeConnector) isPartitionSynced(partitionID string) (bool, error) {
//nolint:gosec
queryString := fmt.Sprintf(`
SELECT COUNT(*)
FROM %s.%s
WHERE partitionID = '%s'
`, peerDBInternalSchema, qRepMetadataTableName, partitionID)

row := c.database.QueryRow(queryString)

var count int
if err := row.Scan(&count); err != nil {
return false, fmt.Errorf("failed to execute query: %w", err)
}

return count > 0, nil
}

func (c *SnowflakeConnector) SetupQRepMetadataTables(config *protos.QRepConfig) error {
panic("SetupQRepMetadataTables not implemented for snowflake connector")
// Define the schema
schemaStatement := `
CREATE TABLE IF NOT EXISTS %s.%s (
flowJobName STRING,
partitionID STRING,
syncPartition VARIANT,
syncStartTime TIMESTAMP_LTZ,
syncFinishTime TIMESTAMP_LTZ
);
`
queryString := fmt.Sprintf(schemaStatement, peerDBInternalSchema, qRepMetadataTableName)

// Execute the query
_, err := c.database.Exec(queryString)
if err != nil {
return fmt.Errorf("failed to create table %s: %w", qRepMetadataTableName, err)
}

return nil
}
205 changes: 205 additions & 0 deletions flow/connectors/snowflake/qrep_avro_sync.go
@@ -0,0 +1,205 @@
package connsnowflake

import (
"database/sql"
"encoding/json"
"fmt"
"os"
"time"

"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/linkedin/goavro/v2"
log "github.com/sirupsen/logrus"
_ "github.com/snowflakedb/gosnowflake"
)

type SnowflakeAvroSyncMethod struct {
connector *SnowflakeConnector
localDir string
}

func NewSnowflakeAvroSyncMethod(connector *SnowflakeConnector, localDir string) *SnowflakeAvroSyncMethod {
return &SnowflakeAvroSyncMethod{
connector: connector,
localDir: localDir,
}
}

func (s *SnowflakeAvroSyncMethod) SyncQRepRecords(
flowJobName string,
dstTableName string,
partition *protos.QRepPartition,
dstTableSchema []*sql.ColumnType,
records *model.QRecordBatch) (int, error) {

startTime := time.Now()

// Define the Avro schema from the destination table's column types.
avroSchema, err := DefineAvroSchema(dstTableName, dstTableSchema)
if err != nil {
return 0, fmt.Errorf("failed to define Avro schema: %w", err)
}

log.Infof("Avro schema: %s", avroSchema.Schema)

// Stage the records in a local Avro file named after the partition ID.
localFilePath := fmt.Sprintf("%s/%s.avro", s.localDir, partition.PartitionId)
file, err := os.Create(localFilePath)
if err != nil {
return 0, fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()

// Create OCF Writer
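// OCF (Object Container File) embeds the Avro schema alongside the data blocks,
// which is the self-describing layout Snowflake's AVRO file format reads.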
ocfWriter, err := goavro.NewOCFWriter(goavro.OCFConfig{
W: file,
Schema: avroSchema.Schema,
})
if err != nil {
return 0, fmt.Errorf("failed to create OCF writer: %w", err)
}

// Write each QRecord to the OCF file
for _, qRecord := range records.Records {
avroMap, err := qRecord.ToAvroCompatibleMap(&avroSchema.NullableFields, records.ColumnNames)
if err != nil {
return 0, fmt.Errorf("failed to convert QRecord to Avro compatible map: %w", err)
}

err = ocfWriter.Append([]interface{}{avroMap})
if err != nil {
return 0, fmt.Errorf("failed to write record to OCF file: %w", err)
}
}

// Upload the local Avro file to the destination table's internal stage;
// COPY INTO reads from the stage, so the file must be PUT there first.
// (gosnowflake executes PUT statements through the standard Exec call.)
putCmd := fmt.Sprintf("PUT file://%s @%%%s", localFilePath, dstTableName)
if _, err = s.connector.database.Exec(putCmd); err != nil {
return 0, fmt.Errorf("failed to upload file to stage: %w", err)
}

// Load the staged file into the destination table.
copyCmd := fmt.Sprintf("COPY INTO %s FROM @%%%s/%s FILE_FORMAT = (TYPE = AVRO)",
dstTableName, dstTableName, partition.PartitionId)

if _, err = s.connector.database.Exec(copyCmd); err != nil {
return 0, fmt.Errorf("failed to run COPY INTO command: %w", err)
}

// Insert metadata statement
insertMetadataStmt, err := s.connector.createMetadataInsertStatement(partition, flowJobName, startTime)
if err != nil {
return 0, fmt.Errorf("failed to create metadata insert statement: %w", err)
}

// Execute the metadata insert statement
if _, err = s.connector.database.Exec(insertMetadataStmt); err != nil {
return 0, fmt.Errorf("failed to execute metadata insert statement: %w", err)
}

log.Infof("pushed %d records to local file %s and loaded into Snowflake table %s",
len(records.Records), localFilePath, dstTableName)
return len(records.Records), nil
}

type AvroField struct {
Name string `json:"name"`
Type interface{} `json:"type"`
}

type AvroSchema struct {
Type string `json:"type"`
Name string `json:"name"`
Fields []AvroField `json:"fields"`
}

type AvroSchemaDefinition struct {
Schema string
NullableFields map[string]bool
}

func DefineAvroSchema(dstTableName string, dstTableSchema []*sql.ColumnType) (*AvroSchemaDefinition, error) {
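// Note: the record name must form a valid Avro name; schema-qualified or quoted
// Snowflake identifiers may not, so this assumes a plain table name.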
avroFields := []AvroField{}
nullableFields := map[string]bool{}

for _, sqlField := range dstTableSchema {
avroType, err := GetAvroType(sqlField)
if err != nil {
return nil, err
}

// If a field is nullable, its Avro type should be ["null", actualType]
nullable, ok := sqlField.Nullable()
if !ok {
return nil, fmt.Errorf("driver does not support the following field: %s", sqlField.Name())
}

if nullable {
avroType = []interface{}{"null", avroType}
nullableFields[sqlField.Name()] = true
}

avroFields = append(avroFields, AvroField{
Name: sqlField.Name(),
Type: avroType,
})
}

avroSchema := AvroSchema{
Type: "record",
Name: dstTableName,
Fields: avroFields,
}

avroSchemaJSON, err := json.Marshal(avroSchema)
if err != nil {
return nil, fmt.Errorf("failed to marshal Avro schema to JSON: %v", err)
}

return &AvroSchemaDefinition{
Schema: string(avroSchemaJSON),
NullableFields: nullableFields,
}, nil
}
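
// For illustration, a table with a non-nullable ID NUMBER column and a nullable
// NAME VARCHAR column yields, under these mappings, an Avro schema roughly like:
// {"type":"record","name":"MY_TABLE","fields":[
// {"name":"ID","type":{"type":"bytes","logicalType":"decimal","precision":38,"scale":9}},
// {"name":"NAME","type":["null","string"]}]}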

func GetAvroType(sqlField *sql.ColumnType) (interface{}, error) {
databaseType := sqlField.DatabaseTypeName()

switch databaseType {
case "VARCHAR", "CHAR", "STRING", "TEXT":
return "string", nil
case "BINARY":
return "bytes", nil
case "NUMBER":
return map[string]interface{}{
"type": "bytes",
"logicalType": "decimal",
"precision": 38,
"scale": 9,
}, nil
case "INTEGER", "BIGINT":
return "long", nil
case "FLOAT", "DOUBLE":
return "double", nil
case "BOOLEAN":
return "boolean", nil
case "DATE":
return map[string]string{
"type": "int",
"logicalType": "date",
}, nil
case "TIME":
return map[string]string{
"type": "long",
"logicalType": "time-micros",
}, nil
case "TIMESTAMP_NTZ", "TIMESTAMP_LTZ", "TIMESTAMP_TZ":
return map[string]string{
"type": "long",
"logicalType": "timestamp-millis",
}, nil
case "OBJECT", "ARRAY", "VARIANT":
// Snowflake semi-structured types (OBJECT, ARRAY, VARIANT) need dedicated
// handling based on the shape of the data; simple nested structures could map
// to Avro "record" types, as in the BigQuery connector's RecordFieldType case.
return nil, fmt.Errorf("Snowflake semi-structured type %s not supported yet", databaseType)
default:
return nil, fmt.Errorf("unsupported Snowflake field type: %s", databaseType)
}
}
1 change: 1 addition & 0 deletions flow/connectors/snowflake/snowflake.go
@@ -70,6 +70,7 @@ type tableNameComponents struct {
schemaIdentifier string
tableIdentifier string
}

type SnowflakeConnector struct {
ctx context.Context
database *sql.DB
