Skip to content

Commit

Permalink
Merge pull request #4 from 243826/main
Browse files Browse the repository at this point in the history
Added support for locking mechanism so that parallel tests can be run.
  • Loading branch information
243826 committed Sep 9, 2023
2 parents 6a359b0 + 1553405 commit d62e96d
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 19 deletions.
134 changes: 115 additions & 19 deletions env.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package envygo
import (
"fmt"
"reflect"
"sync"
"unsafe"
)

Expand Down Expand Up @@ -41,13 +42,15 @@ func MockMany(envs ...interface{}) func() {
}
}

return Unmock(funcs...)
return func() {
Unmock(funcs...)
}
}

// Unmock is handy to invoke the return values of Mock family of methods
func Unmock(funcs ...func()) func() {
return func() {
for _, function := range funcs {
func Unmock(funcs ...func()) {
for _, function := range funcs {
if function != nil {
defer function()
}
}
Expand All @@ -63,6 +66,10 @@ type field struct {
exported bool
}

// Locker function locks the interface old if lockUnlock value is true.
// It unlocks the locked interface old if the lockUnlock value is false
type Locker func(old interface{}, lockUnlock bool)

func toPairs(new interface{}, includeZeros bool) []field {
var fields []field = nil

Expand Down Expand Up @@ -91,6 +98,84 @@ func toPairs(new interface{}, includeZeros bool) []field {
return fields
}

// execute examines old to see if any of the fields defined
// in old can be used to lock the interface. If a qualified
// mutex is found in old, then it's locked before function
// is invoked. If the function panics, the mutex is unlocked
// right away. In all other cases it's caller's responsibility
// to invoke locker to unlock the mutex if one is there.
func execute(old interface{}, function func()) Locker {
locker := getMutex(reflect.ValueOf(old).Elem(), reflect.TypeOf(old).Elem())

if locker == nil {
function()
} else {
_panic := true
if locker != nil {
locker(old, true)
defer func() {
if _panic {
locker(old, false)
}
}()
}
function()
_panic = false
}

return locker
}

func getMutex(valueOf reflect.Value, typeOf reflect.Type) (locker Locker) {
defer func() {
if locker == nil {
locker = func(interface{}, bool) {}
}
}()

for i := typeOf.NumField(); i > 0; {
i--

oldField := typeOf.Field(i)
tag := oldField.Tag.Get("env")
if tag == "mutex" {
field := valueOf.Field(i)
if !oldField.IsExported() {
field = getNonExportedField(field)
}

switch oldField.Type.Kind() {
case reflect.Struct:
mutex := (*sync.Mutex)(unsafe.Pointer(field.UnsafeAddr()))
return func(old interface{}, lockUnlock bool) {
if lockUnlock {
mutex.Lock()
} else {
mutex.Unlock()
}
}
case reflect.Pointer:
mutex := field.Interface().(*sync.Mutex)
if mutex != nil {
return func(old interface{}, lock bool) {
if lock {
mutex.Lock()
} else {
mutex.Unlock()
}
}
}
case reflect.Func:
return field.Interface().(Locker)
default:
panic(`field marked "mutex" can either be a pointer to sync.Mutex or a "Locker" function`)
}
}
}

return nil
}

// Mock mocks the old environment using the values in the new environment
// If the type for old environment is identical to the type of the new environment
// then any attribute with value identical to default value for its type is
Expand All @@ -102,8 +187,10 @@ func Mock(old interface{}, new interface{}) func() {
return func() {}
}

array = mockHelper(old, array)
locker := execute(old, func() { array = mockHelper(old, array) })

return func() {
defer locker(old, false)
mockHelper(old, array)
}
}
Expand Down Expand Up @@ -147,10 +234,15 @@ func MockField(old interface{}, name string, value any) func() {
typeOf := reflect.TypeOf(old).Elem()
if typeField, found := typeOf.FieldByName(name); found {
exported := typeField.IsExported()
valueOfStruct := reflect.ValueOf(old).Elem()
old := mockField(valueOfStruct, name, value, exported)
valueOfOld := reflect.ValueOf(old).Elem()

locker := execute(old, func() {
value = mockField(valueOfOld, name, value, exported)
})

return func() {
mockField(valueOfStruct, name, old, exported)
defer locker(old, false)
mockField(valueOfOld, name, value, exported)
}
}

Expand All @@ -162,29 +254,33 @@ type Fields map[string]interface{}

// MockFields mocks many fields of the struct pointed to by old
func MockFields(old interface{}, fields Fields) func() {
typeOf := reflect.TypeOf(old).Elem()
valueOfStruct := reflect.ValueOf(old).Elem()

var array []field
for name, value := range fields {
if typeField, found := typeOf.FieldByName(name); found {
exported := typeField.IsExported()
old := mockField(valueOfStruct, name, value, exported)
array = append(array, field{name, old, exported})

locker := execute(old, func() {
typeOf := reflect.TypeOf(old).Elem()
valueOfOld := reflect.ValueOf(old).Elem()

for name, value := range fields {
if typeField, found := typeOf.FieldByName(name); found {
exported := typeField.IsExported()
old := mockField(valueOfOld, name, value, exported)
array = append(array, field{name, old, exported})
}
}
}
})

if array == nil {
return func() {}
}

return func() {
defer locker(old, false)
mockHelper(old, array)
}
}

func mockField(valueOfStruct reflect.Value, name string, new any, exported bool) interface{} {
field := valueOfStruct.FieldByName(name)
func mockField(valueOfOld reflect.Value, name string, new any, exported bool) interface{} {
field := valueOfOld.FieldByName(name)
if !exported {
field = getNonExportedField(field)
}
Expand Down
63 changes: 63 additions & 0 deletions env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"os"
"sync"
"testing"
)

Expand Down Expand Up @@ -170,3 +171,65 @@ func TestResetMockFields(t *testing.T) {
t.Fail()
}
}

type mutexEnv struct {
mutex sync.Mutex `env:"mutex"`
Name string
}

func TestMutex(t *testing.T) {
var env = &mutexEnv{Name: "original"}
go func() { defer Unmock(MockField(env, "Name", "mocked 1")) }()
go func() { defer Unmock(MockField(env, "Name", "mocked 2")) }()
}

type mutexPtrEnv struct {
mutex *sync.Mutex `env:"mutex"`
Name string
}

func TestMutexPtr(t *testing.T) {
var env = &mutexPtrEnv{mutex: &sync.Mutex{}, Name: "original"}

var latch sync.WaitGroup
latch.Add(2)

go func() {
defer Unmock(MockField(env, "Name", "mocked1"))
if env.Name != "mocked1" {
t.Fail()
}
latch.Done()
}()

go func() {
defer Unmock(MockField(env, "Name", "mocked2"))
if env.Name != "mocked2" {
t.Fail()
}
latch.Done()
}()

latch.Wait()
}

type mutexFuncEnv struct {
mutex Locker `env:"mutex"`
Name string
}

func TestMutexFunc(t *testing.T) {
var lockError bool
mutex := sync.Mutex{}
var env = &mutexFuncEnv{mutex: func(old interface{}, lockUnlock bool) {
if lockUnlock {
if !mutex.TryLock() {
lockError = true
}
} else if !lockError {
mutex.Unlock()
}
}, Name: "original"}
defer Unmock(MockField(env, "Name", "mocked 1"))
defer Unmock(MockField(env, "Name", "mocked 2"))
}

0 comments on commit d62e96d

Please sign in to comment.