Skip to content

Commit

Permalink
lint: rowserrcheck
Browse files Browse the repository at this point in the history
It turns out `rows.Next()` can return false on error,
avoiding error detection with `rows.Scan`

See golangci/golangci-lint#945

Fix that. In particular, `QueryRowContext` with checking for `sql.ErrNowRows`
is much simpler when expecting 0 or 1 results

Mark two false positives when only wanting schema of `LIMIT 0` query
  • Loading branch information
serprex committed Jan 26, 2024
1 parent 9d4065b commit 017a7a9
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 46 deletions.
1 change: 1 addition & 0 deletions flow/.golangci.yml
Expand Up @@ -21,6 +21,7 @@ linters:
- nonamedreturns
- perfsprint
- prealloc
- rowserrcheck
- staticcheck
- stylecheck
- sqlclosecheck
Expand Down
1 change: 1 addition & 0 deletions flow/connectors/clickhouse/qrep.go
Expand Up @@ -79,6 +79,7 @@ func (c *ClickhouseConnector) createMetadataInsertStatement(
func (c *ClickhouseConnector) getTableSchema(tableName string) ([]*sql.ColumnType, error) {
//nolint:gosec
queryString := fmt.Sprintf(`SELECT * FROM %s LIMIT 0`, tableName)
//nolint:rowserrcheck
rows, err := c.database.Query(queryString)
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
Expand Down
12 changes: 7 additions & 5 deletions flow/connectors/snowflake/qrep.go
Expand Up @@ -82,12 +82,9 @@ func (c *SnowflakeConnector) getTableSchema(tableName string) ([]*sql.ColumnType
}

//nolint:gosec
queryString := fmt.Sprintf(`
SELECT *
FROM %s
LIMIT 0
`, snowflakeSchemaTableNormalize(schematable))
queryString := fmt.Sprintf("SELECT * FROM %s LIMIT 0", snowflakeSchemaTableNormalize(schematable))

//nolint:rowserrcheck
rows, err := c.database.QueryContext(c.ctx, queryString)
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
Expand Down Expand Up @@ -303,6 +300,11 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) ([]string, []str
colTypes = append(colTypes, colType.String)
}

err = rows.Err()
if err != nil {
return nil, nil, fmt.Errorf("failed to read rows: %w", err)
}

if len(colNames) == 0 {
return nil, nil, fmt.Errorf("cannot load schema: table %s.%s does not exist", schemaTable.Schema, schemaTable.Table)
}
Expand Down
64 changes: 23 additions & 41 deletions flow/connectors/snowflake/snowflake.go
Expand Up @@ -324,25 +324,14 @@ func (c *SnowflakeConnector) getTableSchemaForTable(tableName string) (*protos.T
}

func (c *SnowflakeConnector) GetLastOffset(jobName string) (int64, error) {
rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getLastOffsetSQL,
c.metadataSchema, mirrorJobsTableIdentifier), jobName)
if err != nil {
return 0, fmt.Errorf("error querying Snowflake peer for last syncedID: %w", err)
}
defer func() {
err = rows.Close()
if err != nil {
c.logger.Error("error while closing rows for reading last offset", slog.Any("error", err))
}
}()

if !rows.Next() {
c.logger.Warn("No row found, returning 0")
return 0, nil
}
var result pgtype.Int8
err = rows.Scan(&result)
err := c.database.QueryRowContext(c.ctx, fmt.Sprintf(getLastOffsetSQL,
c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result)
if err != nil {
if err == sql.ErrNoRows {
c.logger.Warn("No row found, returning 0")
return 0, nil
}
return 0, fmt.Errorf("error while reading result row: %w", err)
}
if result.Int64 == 0 {
Expand All @@ -362,40 +351,28 @@ func (c *SnowflakeConnector) SetLastOffset(jobName string, lastOffset int64) err
}

func (c *SnowflakeConnector) GetLastSyncBatchID(jobName string) (int64, error) {
rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getLastSyncBatchID_SQL, c.metadataSchema,
mirrorJobsTableIdentifier), jobName)
if err != nil {
return 0, fmt.Errorf("error querying Snowflake peer for last syncBatchId: %w", err)
}
defer rows.Close()

var result pgtype.Int8
if !rows.Next() {
c.logger.Warn("No row found, returning 0")
return 0, nil
}
err = rows.Scan(&result)
err := c.database.QueryRowContext(c.ctx, fmt.Sprintf(getLastSyncBatchID_SQL, c.metadataSchema,
mirrorJobsTableIdentifier), jobName).Scan(&result)
if err != nil {
if err == sql.ErrNoRows {
c.logger.Warn("No row found, returning 0")
return 0, nil
}
return 0, fmt.Errorf("error while reading result row: %w", err)
}
return result.Int64, nil
}

func (c *SnowflakeConnector) GetLastNormalizeBatchID(jobName string) (int64, error) {
rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getLastNormalizeBatchID_SQL, c.metadataSchema,
mirrorJobsTableIdentifier), jobName)
if err != nil {
return 0, fmt.Errorf("error querying Snowflake peer for last normalizeBatchId: %w", err)
}
defer rows.Close()

var normBatchID pgtype.Int8
if !rows.Next() {
c.logger.Warn("No row found, returning 0")
return 0, nil
}
err = rows.Scan(&normBatchID)
err := c.database.QueryRowContext(c.ctx, fmt.Sprintf(getLastNormalizeBatchID_SQL, c.metadataSchema,
mirrorJobsTableIdentifier), jobName).Scan(&normBatchID)
if err != nil {
if err == sql.ErrNoRows {
c.logger.Warn("No row found, returning 0")
return 0, nil
}
return 0, fmt.Errorf("error while reading result row: %w", err)
}
return normBatchID.Int64, nil
Expand All @@ -422,6 +399,11 @@ func (c *SnowflakeConnector) getDistinctTableNamesInBatch(flowJobName string, sy
}
destinationTableNames = append(destinationTableNames, result.String)
}

err = rows.Err()
if err != nil {
return nil, fmt.Errorf("failed to read rows: %w", err)
}
return destinationTableNames, nil
}

Expand Down

0 comments on commit 017a7a9

Please sign in to comment.