Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Harden string functions when NULL is passed #5

Merged
merged 2 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 86 additions & 14 deletions internal/function_bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -775,16 +775,22 @@ func bindCollate(args ...Value) (Value, error) {
}

func bindConcat(args ...Value) (Value, error) {
if len(args) < 2 {
if len(args) < 1 {
return nil, fmt.Errorf("CONCAT: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return CONCAT(args...)
}

func bindContainsSubstr(args ...Value) (Value, error) {
if args[1] == nil {
return nil, fmt.Errorf("CONTAINS_SUBSTR: search literal must be not null")
}
if existsNull(args) {
return nil, nil
}
search, err := args[1].ToString()
if err != nil {
return nil, err
Expand All @@ -796,21 +802,24 @@ func bindEndsWith(args ...Value) (Value, error) {
if len(args) != 2 {
return nil, fmt.Errorf("ENDS_WITH: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return ENDS_WITH(args[0], args[1])
}

func bindFormat(args ...Value) (Value, error) {
if len(args) == 0 {
return nil, fmt.Errorf("FORMAT: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
format, err := args[0].ToString()
if err != nil {
return nil, err
}
if len(args) > 1 {
if args[1] == nil {
return nil, nil
}
return FORMAT(format, args[1:]...)
}
return FORMAT(format)
Expand All @@ -820,6 +829,9 @@ func bindFromBase32(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("FROM_BASE32: invalid argument num %d", len(args))
}
if args[0] == nil {
return nil, nil
}
v, err := args[0].ToString()
if err != nil {
return nil, err
Expand All @@ -831,6 +843,9 @@ func bindFromBase64(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("FROM_BASE64: invalid argument num %d", len(args))
}
if args[0] == nil {
return nil, nil
}
v, err := args[0].ToString()
if err != nil {
return nil, err
Expand All @@ -842,6 +857,9 @@ func bindFromHex(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("FROM_HEX: invalid argument num %d", len(args))
}
if args[0] == nil {
return nil, nil
}
v, err := args[0].ToString()
if err != nil {
return nil, err
Expand All @@ -853,14 +871,11 @@ func bindInitcap(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("INITCAP: invalid argument num %d", len(args))
}
if args[0] == nil {
if existsNull(args) {
return nil, nil
}
var delimiters []rune
if len(args) == 2 {
if args[1] == nil {
return nil, nil
}
v, err := args[1].ToString()
if err != nil {
return nil, err
Expand All @@ -881,10 +896,7 @@ func bindInstr(args ...Value) (Value, error) {
if len(args) != 2 && len(args) != 3 && len(args) != 4 {
return nil, fmt.Errorf("INSTR: invalid argument num %d", len(args))
}
if args[0] == nil {
return nil, nil
}
if args[1] == nil {
if existsNull(args) {
return nil, nil
}
var (
Expand Down Expand Up @@ -912,6 +924,9 @@ func bindLeft(args ...Value) (Value, error) {
if len(args) != 2 {
return nil, fmt.Errorf("LEFT: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
length, err := args[1].ToInt64()
if err != nil {
return nil, err
Expand All @@ -924,7 +939,7 @@ func bindLength(args ...Value) (Value, error) {
return nil, fmt.Errorf("LENGTH: invalid argument num %d", len(args))
}
if args[0] == nil {
return IntValue(0), nil
return nil, nil
}
return LENGTH(args[0])
}
Expand All @@ -933,6 +948,9 @@ func bindLpad(args ...Value) (Value, error) {
if len(args) != 2 && len(args) != 3 {
return nil, fmt.Errorf("LPAD: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
var pattern Value
if len(args) == 3 {
pattern = args[2]
Expand All @@ -948,13 +966,19 @@ func bindLower(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("LOWER: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return LOWER(args[0])
}

func bindLtrim(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("LTRIM: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
cutset := " "
if len(args) == 2 {
v, err := args[1].ToString()
Expand All @@ -970,6 +994,9 @@ func bindNormalize(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("NORMALIZE: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
mode := "NFC"
if len(args) == 2 {
v, err := args[1].ToString()
Expand All @@ -989,6 +1016,9 @@ func bindNormalizeAndCasefold(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("NORMALIZE_AND_CASEFOLD: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
mode := "NFC"
if len(args) == 2 {
v, err := args[1].ToString()
Expand Down Expand Up @@ -1020,6 +1050,9 @@ func bindRegexpContains(args ...Value) (Value, error) {
}

func bindRegexpExtract(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
}
regexp, err := args[1].ToString()
if err != nil {
return nil, err
Expand All @@ -1044,6 +1077,9 @@ func bindRegexpExtract(args ...Value) (Value, error) {
}

func bindRegexpExtractAll(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
}
regexp, err := args[1].ToString()
if err != nil {
return nil, err
Expand All @@ -1052,6 +1088,9 @@ func bindRegexpExtractAll(args ...Value) (Value, error) {
}

func bindRegexpInstr(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
}
var (
pos int64 = 1
occurrence int64 = 1
Expand Down Expand Up @@ -1082,6 +1121,9 @@ func bindRegexpInstr(args ...Value) (Value, error) {
}

func bindRegexpReplace(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
}
return REGEXP_REPLACE(args[0], args[1], args[2])
}

Expand Down Expand Up @@ -1175,7 +1217,7 @@ func bindSoundex(args ...Value) (Value, error) {

func bindSplit(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
return &ArrayValue{}, nil
}
var delim Value
if len(args) > 1 {
Expand All @@ -1188,20 +1230,29 @@ func bindStartsWith(args ...Value) (Value, error) {
if len(args) != 2 {
return nil, fmt.Errorf("STARTS_WITH: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return STARTS_WITH(args[0], args[1])
}

func bindStrpos(args ...Value) (Value, error) {
if len(args) != 2 {
return nil, fmt.Errorf("STRPOS: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return STRPOS(args[0], args[1])
}

func bindSubstr(args ...Value) (Value, error) {
if len(args) != 2 && len(args) != 3 {
return nil, fmt.Errorf("SUBSTR: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
pos, err := args[1].ToInt64()
if err != nil {
return nil, err
Expand All @@ -1221,6 +1272,9 @@ func bindToBase32(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("TO_BASE32: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
b, err := args[0].ToBytes()
if err != nil {
return nil, err
Expand All @@ -1232,6 +1286,9 @@ func bindToBase64(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("TO_BASE64: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
b, err := args[0].ToBytes()
if err != nil {
return nil, err
Expand All @@ -1243,13 +1300,19 @@ func bindToCodePoints(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("TO_CODE_POINTS: invalid argument num %d", len(args))
}
if args[0] == nil {
return &ArrayValue{}, nil
}
return TO_CODE_POINTS(args[0])
}

func bindToHex(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("TO_HEX: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
b, err := args[0].ToBytes()
if err != nil {
return nil, err
Expand All @@ -1261,13 +1324,19 @@ func bindTranslate(args ...Value) (Value, error) {
if len(args) != 3 {
return nil, fmt.Errorf("TRANSLATE: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return TRANSLATE(args[0], args[1], args[2])
}

func bindTrim(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("TRIM: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
if len(args) == 2 {
return TRIM(args[0], args[1])
}
Expand Down Expand Up @@ -1437,6 +1506,9 @@ func bindToJson(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("TO_JSON: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
var stringifyWideNumbers bool
if len(args) == 2 {
b, err := args[1].ToBool()
Expand Down
5 changes: 4 additions & 1 deletion internal/function_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,9 @@ func LPAD(originalValue Value, returnLength int64, pattern Value) (Value, error)
}

func LOWER(v Value) (Value, error) {
if v == nil {
return nil, nil
}
switch v.(type) {
case StringValue:
s, err := v.ToString()
Expand Down Expand Up @@ -697,7 +700,7 @@ func REGEXP_REPLACE(value, exprValue, replacementValue Value) (Value, error) {
}
return BytesValue(re.ReplaceAll(v, []byte(normalizeReplacement(string(replacement))))), nil
}
return nil, fmt.Errorf("REGEXP_REPLACE: value must be STRING or BYTES")
return nil, fmt.Errorf("REGEXP_REPLACE: value must be STRING or BYTES, %s", value)
}

func REPLACE(originalValue, fromValue, toValue Value) (Value, error) {
Expand Down
Loading
Loading