Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions connectableobservable.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ func (c *connectableObservable) TakeWhile(apply Predicate) Observable {
return c.observable.TakeWhile(apply)
}

func (c *connectableObservable) Timeout(duration Duration) Observable {
return c.observable.Timeout(duration)
func (c *connectableObservable) Timeout(observable Observable) Observable {
return c.observable.Timeout(observable)
}

func (c *connectableObservable) ToChannel(opts ...options.Option) Channel {
Expand Down
32 changes: 18 additions & 14 deletions observable.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ type Observable interface {
TakeLast(nth uint) Observable
TakeUntil(apply Predicate) Observable
TakeWhile(apply Predicate) Observable
Timeout(duration Duration) Observable
Timeout(observable Observable) Observable
ToChannel(opts ...options.Option) Channel
ToMap(keySelector Function) Single
ToMapWithValueSelector(keySelector, valueSelector Function) Single
Expand Down Expand Up @@ -1669,22 +1669,26 @@ func (o *observable) TakeWhile(apply Predicate) Observable {
return newColdObservableFromFunction(f)
}

func (o *observable) Timeout(duration Duration) Observable {
func (o *observable) Timeout(observable Observable) Observable {
f := func(out chan interface{}) {
it := o.Iterator(context.Background())
// TODO Handle cancel
ctx, _ := context.WithTimeout(context.Background(), duration.duration())
for {
if item, err := it.Next(ctx); err == nil {
out <- item
} else {
out <- err
break
ctx, cancel := context.WithCancel(context.Background())
go func() {
it := o.Iterator(ctx)
for {
if item, err := it.Next(ctx); err == nil {
out <- item
} else {
out <- err
break
}
}
}
close(out)
}()
go func() {
it := observable.Iterator(context.Background())
it.Next(context.Background())
cancel()
}()
}

return newColdObservableFromFunction(f)
}

Expand Down
210 changes: 210 additions & 0 deletions observable_mock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
package rxgo

import (
"bufio"
"context"
"strconv"
"strings"
"testing"

"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

const signalCh = byte(0)

var mockError = errors.New("")

type mockIterable struct {
iterator Iterator
}

type mockIterator struct {
mock.Mock
}

type task struct {
observable int
item int
error error
close bool
}

func (s *mockIterable) Iterator(ctx context.Context) Iterator {
return s.iterator
}

func newMockObservable(iterator Iterator) Observable {
return &observable{
observableType: cold,
iterable: &mockIterable{
iterator: iterator,
},
}
}

func countTab(line string) int {
i := 0
for _, runeValue := range line {
if runeValue == '\t' {
i++
} else {
break
}
}
return i
}

// TODO Causality with more than two observables
func mockObservables(t *testing.T, in string) []Observable {
scanner := bufio.NewScanner(strings.NewReader(in))
m := make(map[int]int)
tasks := make([]task, 0)
count := 0
for scanner.Scan() {
s := scanner.Text()
if s == "" {
continue
}
observable := countTab(s)
v := strings.TrimSpace(s)
switch v {
case "x":
tasks = append(tasks, task{
observable: observable,
close: true,
})
case "e":
tasks = append(tasks, task{
observable: observable,
error: mockError,
})
default:
n, err := strconv.Atoi(v)
if err != nil {
assert.FailNow(t, err.Error())
}
tasks = append(tasks, task{
observable: observable,
item: n,
})
}
if _, contains := m[observable]; !contains {
m[observable] = count
count++
}
}

iterators := make([]*mockIterator, 0, len(m))
calls := make([]*mock.Call, len(m))
for i := 0; i < len(m); i++ {
iterators = append(iterators, new(mockIterator))
}

item, err := args(tasks[0])
call := iterators[0].On("Next", mock.Anything).Once().Return(item, err)
calls[0] = call

var lastCh chan struct{}
lastObservableType := tasks[0].observable
for i := 1; i < len(tasks); i++ {
t := tasks[i]
index := m[t.observable]
obs := iterators[index]
item, err := args(t)
if lastObservableType == t.observable {
if calls[index] == nil {
calls[index] = obs.On("Next", mock.Anything).Once().Return(item, err)
} else {
calls[index].On("Next", mock.Anything).Once().Return(item, err)
}
} else {
lastObservableType = t.observable
if lastCh == nil {
ch := make(chan struct{})
lastCh = ch
if calls[index] == nil {
calls[index] = obs.On("Next", mock.Anything).Once().Return(item, err).
Run(func(args mock.Arguments) {
run(args, ch, nil)
})
} else {
calls[index].On("Next", mock.Anything).Once().Return(item, err).
Run(func(args mock.Arguments) {
run(args, ch, nil)
})
}
} else {
var ch chan struct{}
// If this is the latest task we do not set any wait channel
if i != len(tasks)-1 {
ch = make(chan struct{})
}
previous := lastCh
if calls[index] == nil {
calls[index] = obs.On("Next", mock.Anything).Once().Return(item, err).
Run(func(args mock.Arguments) {
run(args, ch, previous)
})
} else {
calls[index].On("Next", mock.Anything).Once().Return(item, err).
Run(func(args mock.Arguments) {
run(args, ch, previous)
})
}
lastCh = ch
}
}
}

observables := make([]Observable, 0, len(iterators))
for _, iterator := range iterators {
observables = append(observables, newMockObservable(iterator))
}
return observables
}

func args(t task) (interface{}, error) {
if t.close {
return nil, &NoSuchElementError{}
}
if t.error != nil {
return t.error, nil
}
return t.item, nil
}

func run(args mock.Arguments, wait chan struct{}, send chan struct{}) {
if send != nil {
send <- struct{}{}
}
if wait == nil {
return
}
if len(args) == 1 {
if ctx, ok := args[0].(context.Context); ok {
if sig, ok := ctx.Value(signalCh).(chan struct{}); ok {
select {
case <-wait:
case <-ctx.Done():
sig <- struct{}{}
}
return
}
}
}
<-wait
}

func (m *mockIterator) Next(ctx context.Context) (interface{}, error) {
sig := make(chan struct{}, 1)
defer close(sig)
outputs := m.Called(context.WithValue(ctx, signalCh, sig))
select {
case <-sig:
return nil, &CancelledIteratorError{}
default:
return outputs.Get(0), outputs.Error(1)
}
}
83 changes: 50 additions & 33 deletions observable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,7 @@ func TestBufferWithTimeWithMockedTime(t *testing.T) {
timeshift.AssertNotCalled(t, "duration")
}

func TestBufferWithTimeWithMinorMockedTime(t *testing.T) {
func TestBufferWithTime_MinorMockedTime(t *testing.T) {
ch := make(chan interface{})
from := FromIterator(newIteratorFromChannel(ch))

Expand All @@ -1294,7 +1294,6 @@ func TestBufferWithTimeWithMinorMockedTime(t *testing.T) {

obs := from.BufferWithTime(timespan, timeshift)

time.Sleep(10 * time.Millisecond)
ch <- 1
close(ch)

Expand Down Expand Up @@ -1642,14 +1641,28 @@ func TestSample(t *testing.T) {
}

func TestSample_NotRepeatedItems(t *testing.T) {
ch := make(chan interface{})
obs := FromChannel(ch).Sample(Interval(make(chan struct{}), 50*time.Millisecond))
go func() {
ch <- 1
time.Sleep(200 * time.Millisecond)
close(ch)
}()
AssertObservable(t, obs, HasItems(1))
observables := mockObservables(t, `
1
2
0
3
4
5
0
6
0
7
8
0
0
9
0
x
x
`)
obs := observables[0].Sample(observables[1])

AssertObservable(t, obs, HasItems(2, 5, 6, 8, 9))
}

func TestSample_SourceObsClosedBeforeIntervalFired(t *testing.T) {
Expand Down Expand Up @@ -1754,26 +1767,30 @@ func TestStartWithObservable_Empty2(t *testing.T) {
AssertObservable(t, obs, HasItems(1, 2, 3))
}

//var _ = Describe("Timeout operator", func() {
// FIXME
//Context("when creating an observable with timeout operator", func() {
// ch := make(chan interface{}, 10)
// duration := WithDuration(pollingInterval)
// o := FromChannel(ch).Timeout(duration)
// Context("after a given period without items", func() {
// outNext, outErr, _ := subscribe(o)
//
// ch <- 1
// ch <- 2
// ch <- 3
// time.Sleep(time.Second)
// ch <- 4
// It("should receive the elements before the timeout", func() {
// Expect(pollItems(outNext, timeout)).Should(Equal([]interface{}{1, 2, 3}))
// })
// It("should receive a TimeoutError", func() {
// Expect(pollItem(outErr, timeout)).Should(Equal(&TimeoutError{}))
// })
// })
//})
//})
func TestTimeout(t *testing.T) {
observables := mockObservables(t, `
1
2
3
0
4
5
x
x
`)
obs := observables[0].Timeout(observables[1])
AssertObservable(t, obs, HasItems(1, 2, 3))
}

func TestTimeout_ClosedChannel(t *testing.T) {
observables := mockObservables(t, `
1
2
3
x
0
x
`)
obs := observables[0].Timeout(observables[1])
AssertObservable(t, obs, HasItems(1, 2, 3))
}
5 changes: 2 additions & 3 deletions observablecreate.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,7 @@ func Start(f Supplier, fs ...Supplier) Observable {
return newColdObservableFromChannel(out)
}

// Timer returns an Observable that emits the zeroed value of a float64 after a
// specified delay, and then completes.
// Timer returns an Observable that emits an empty structure after a specified delay, and then completes.
func Timer(d Duration) Observable {
out := make(chan interface{})
go func() {
Expand All @@ -419,7 +418,7 @@ func Timer(d Duration) Observable {
} else {
time.Sleep(d.duration())
}
out <- 0.
out <- struct{}{}
close(out)
}()
return newColdObservableFromChannel(out)
Expand Down
Loading