Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions maputil/maputil.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
// Package maputil provides a set if functions for working with maps.
package maputil // import "github.com/teamwork/utils/v2/maputil"

import (
"errors"
"fmt"
)

// Swap the keys and values of a map.
func Swap[T comparable, V comparable](m map[T]V) map[V]T {
n := make(map[V]T)
Expand All @@ -10,3 +15,66 @@ func Swap[T comparable, V comparable](m map[T]V) map[V]T {

return n
}

var (
ErrWrongType = errors.New("wrong type")
ErrNotFound = errors.New("key not found")
)

// SetValue set value using key path
func SetValue[T any](m map[string]any, v T, keys ...string) error {
if len(keys) == 1 {
m[keys[0]] = v
return nil
}

next := m[keys[0]]
var nextMap map[string]any
if next == nil {
nextMap = map[string]any{}
m[keys[0]] = nextMap
} else {
var ok bool
nextMap, ok = next.(map[string]any)
if !ok {
return ErrWrongType
}
}

return SetValue(nextMap, v, keys[1:]...)
}

// GetValue returns the value defined as a key path
func GetValue[T any](m map[string]any, keys ...string) (T, error) {
var out T

v, err := getValue(m, keys)
if err != nil {
return out, err
}

if v == nil {
return out, ErrNotFound
}

vv, ok := v.(T)
if !ok {
return out, ErrWrongType
}

return vv, nil
}

func getValue(m map[string]any, keys []string) (any, error) {
if len(keys) == 1 {
return m[keys[0]], nil
}

a := m[keys[0]]

if m, ok := a.(map[string]any); ok {
return getValue(m, keys[1:])
}

return nil, fmt.Errorf("%w key `%s`", ErrNotFound, keys[1])
}
158 changes: 158 additions & 0 deletions maputil/maputil_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package maputil

import (
"errors"
"fmt"
"reflect"
"testing"
Expand All @@ -26,3 +27,160 @@ func TestSwap(t *testing.T) {
})
}
}

func TestGetValue(t *testing.T) {
m := map[string]any{
"a": int64(1),
"b": "2",
"c": float64(3.1),
"d": map[string]any{
"a": int64(4),
"b": "5",
"c": float64(6.1),
},
}

outInt64, err := GetValue[int64](m, "a")
if err != nil {
t.Fatal(err)
}
if outInt64 != int64(1) {
t.Fatalf("expected 1, got %v", outInt64)
}

outStr, err := GetValue[string](m, "b")
if err != nil {
t.Fatal(err)
}
if outStr != "2" {
t.Fatalf("expected \"2\", got %v", outStr)
}

outFloat, err := GetValue[float64](m, "c")
if err != nil {
t.Fatal(err)
}
if outFloat != 3.1 {
t.Fatalf("expected 3.1, got %v", outFloat)
}

outInt64, err = GetValue[int64](m, "d", "a")
if err != nil {
t.Fatal(err)
}
if outInt64 != int64(4) {
t.Fatalf("expected 4 got %v", outInt64)
}

_, err = GetValue[string](m, "a")
if !errors.Is(err, ErrWrongType) {
t.Fatalf("error wrong type expected, got %v", err)
}

_, err = GetValue[string](m, "d", "a")
if !errors.Is(err, ErrWrongType) {
t.Fatalf("error wrong type expected, got %v", err)
}

_, err = GetValue[string](m, "e")
if !errors.Is(err, ErrNotFound) {
t.Fatalf("error not found expected, got %v", err)
}

_, err = GetValue[string](m, "d", "e")
if !errors.Is(err, ErrNotFound) {
t.Fatalf("error not found expected, got %v", err)
}

_, err = GetValue[string](m, "e", "d")
if !errors.Is(err, ErrNotFound) {
t.Fatalf("error not found expected, got %v", err)
}

_, err = GetValue[string](m, "d", "c", "a")
if !errors.Is(err, ErrNotFound) {
t.Fatalf("error not found expected, got %v", err)
}

_, err = GetValue[string](m, "a", "")
if !errors.Is(err, ErrNotFound) {
t.Fatalf("error not found expected, got %v", err)
}
}

func TestSetValue(t *testing.T) {
m := map[string]any{}

// simple int set
err := SetValue(m, 10, "a")
if err != nil {
t.Fatal(err)
}
if m["a"] == nil {
t.Fatal("expected key 'a'")
}
if v, ok := m["a"].(int); !ok || v != 10 {
t.Fatalf("expected 10, got %v", v)
}

// simple string set
str := "mystring"
err = SetValue(m, str, "b")
if err != nil {
t.Fatal(err)
}
if m["b"] == nil {
t.Fatal("expected key 'b'")
}
if v, ok := m["b"].(string); !ok || v != str {
t.Fatalf("expected '%v', got %v", str, v)
}

// nested set
err = SetValue(m, 2, "c", "d")
if err != nil {
t.Fatal(err)
}
if m["c"] == nil {
t.Fatal("expected key 'b'")
}
nested, ok := m["c"].(map[string]any)
if !ok {
t.Fatalf("expected map, got %T", nested)
}
if nested["d"] == nil {
t.Fatal("expected key 'd'")
}
if v, ok := nested["d"].(int); !ok || v != 2 {
t.Fatalf("expected 2, got %v", v)
}

// replace
err = SetValue(map[string]any{"a": 10}, 10, "a")
if err != nil {
t.Fatal(err)
}
if m["a"] == nil {
t.Fatal("expected key 'a'")
}
if v, ok := m["a"].(int); !ok || v != 10 {
t.Fatalf("expected 10, got %v", v)
}

// replace tip
m = map[string]any{"a": map[string]any{"b": 1}}
err = SetValue(m, 10, "a", "b")
if err != nil {
t.Fatal(err)
}
if v, _ := m["a"].(map[string]any)["b"].(int); v != 10 {
t.Fatalf("expected 10, got %v", v)
}

// crash when mid key is not a map
m = map[string]any{"a": map[string]any{"b": 1}}
err = SetValue(m, 10, "a", "b", "c")
if err != ErrWrongType {
t.Fatal(err)
}
}