diff --git a/io/common_test.go b/io/common_test.go index 6f8cb68..cd04707 100644 --- a/io/common_test.go +++ b/io/common_test.go @@ -2,9 +2,11 @@ package io_test import ( "errors" + "log" "testing" "github.com/primetalk/goio/io" + "github.com/primetalk/goio/slice" "github.com/stretchr/testify/assert" ) @@ -14,6 +16,10 @@ var errExpected = errors.New(errorMessage) var failure = io.Fail[string](errExpected) +func inc(i int) int { + return i + 1 +} + func UnsafeIO[A any](t *testing.T, ioa io.IO[A]) A { res, err1 := io.UnsafeRunSync(ioa) assert.NoError(t, err1) @@ -26,3 +32,12 @@ func UnsafeIOExpectError[A any](t *testing.T, expected error, ioa io.IO[A]) { assert.Equal(t, expected, err1) } } + +func Nats(count int) (ios []io.IO[int]) { + return slice.Map(slice.Range(0, count), func(i int) io.IO[int] { + return io.Pure(func() int { + log.Printf("executing %v\n", i) + return i + }) + }) +} diff --git a/io/io.go b/io/io.go index 88db9dc..8c1178c 100644 --- a/io/io.go +++ b/io/io.go @@ -157,6 +157,13 @@ func Lift[A any](a A) IO[A] { return LiftPair(a, nil) } +// LiftFunc wraps the result of function into IO. +func LiftFunc[A any, B any](f func(A) B) func(A) IO[B] { + return func(a A) IO[B] { + return Lift(f(a)) + } +} + // Fail[A] constructs an IO[A] that fails with the given error. func Fail[A any](err error) IO[A] { var a A diff --git a/io/io_test.go b/io/io_test.go index c11d13c..1b3d820 100644 --- a/io/io_test.go +++ b/io/io_test.go @@ -1,7 +1,6 @@ package io_test import ( - "log" "testing" "github.com/primetalk/goio/fun" @@ -17,19 +16,22 @@ func TestIO(t *testing.T) { return i + j, nil }) }) - res, err := io.UnsafeRunSync(io30) - assert.Equal(t, err, nil) - assert.Equal(t, res, 30) + res := UnsafeIO(t, io30) + assert.Equal(t, 30, res) +} + +func TestLiftFunc(t *testing.T) { + f := io.LiftFunc(inc) + assert.Equal(t, 11, UnsafeIO(t, f(10))) } func TestErr(t *testing.T) { var ptr *string = nil ptrio := io.Lift(ptr) uptr := io.FlatMap(ptrio, io.Unptr[string]) - _, err := io.UnsafeRunSync(uptr) - assert.Equal(t, io.ErrorNPE, err) + UnsafeIOExpectError(t, io.ErrorNPE, uptr) wrappedUptr := io.Wrapf(uptr, "my message %d", 10) - _, err = io.UnsafeRunSync(wrappedUptr) + _, err := io.UnsafeRunSync(wrappedUptr) assert.Equal(t, "my message 10: nil pointer", err.Error()) } @@ -40,28 +42,15 @@ func TestFinally(t *testing.T) { oe := io.OnError(fin, func(err error) io.IO[fun.Unit] { return io.FromPureEffect(func() { onErrorExecuted = true }) }) - _, err := io.UnsafeRunSync(oe) - assert.Error(t, err, errorMessage) + UnsafeIOExpectError(t, errExpected, oe) assert.True(t, finalizerExecuted) assert.True(t, onErrorExecuted) } -func Nats(count int) (ios []io.IO[int]) { - for i := 0; i < count; i += 1 { - j := i - ios = append(ios, io.Pure(func() int { - log.Printf("executing %v\n", j) - return j - })) - } - return -} - func TestSequence(t *testing.T) { ios := Nats(10) for i, io1 := range ios { - res, err := io.UnsafeRunSync(io1) - assert.NoError(t, err) + res := UnsafeIO(t, io1) assert.Equal(t, i, res) } ioseq := io.Sequence(ios)