Skip to content

Commit

Permalink
[BEAM-7726] Handle State Backed iterables
Browse files Browse the repository at this point in the history
  • Loading branch information
lostluck committed Jul 23, 2019
1 parent 7cd1a3d commit 0d64dbc
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 15 deletions.
86 changes: 86 additions & 0 deletions sdks/go/pkg/beam/core/runtime/exec/datasource.go
Expand Up @@ -24,6 +24,7 @@ import (
"time"

"github.com/apache/beam/sdks/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/go/pkg/beam/core/util/ioutilx"
"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/go/pkg/beam/log"
)
Expand All @@ -37,24 +38,29 @@ type DataSource struct {
Out Node

source DataManager
state StateReader
count int64
splitPos int64
start time.Time

mu sync.Mutex
}

// ID returns the UnitID for this node.
// The value is read directly from the node's UID field.
func (n *DataSource) ID() UnitID {
return n.UID
}

// Up initializes this datasource.
// It currently performs no work: ctx is not used and the returned error is
// always nil.
func (n *DataSource) Up(ctx context.Context) error {
return nil
}

// StartBundle initializes this datasource for the bundle.
func (n *DataSource) StartBundle(ctx context.Context, id string, data DataContext) error {
n.mu.Lock()
n.source = data.Data
n.state = data.State
n.start = time.Now()
n.count = 0
n.splitPos = math.MaxInt64
Expand Down Expand Up @@ -154,6 +160,27 @@ func (n *DataSource) makeReStream(ctx context.Context, key *FullValue, cv Elemen
if err != nil {
return nil, err
}
case chunk == -1: // State backed iterable!
chunk, err := coder.DecodeVarInt(r)
if err != nil {
return nil, err
}
token, err := ioutilx.ReadN(r, (int)(chunk))
if err != nil {
return nil, err
}
return &concatReStream{
first: &FixedReStream{Buf: buf},
next: &proxyReStream{
open: func() (Stream, error) {
r, err := n.state.OpenIterable(ctx, n.SID, token)
if err != nil {
return nil, err
}
return &elementStream{r: r, ec: cv}, nil
},
},
}, nil
default:
return nil, errors.Errorf("multi-chunk stream with invalid chunk size of %d", chunk)
}
Expand All @@ -174,6 +201,7 @@ func readStreamToBuffer(cv ElementDecoder, r io.ReadCloser, size int64, buf []Fu
return buf, nil
}

// FinishBundle resets the source and metric counters.
func (n *DataSource) FinishBundle(ctx context.Context) error {
n.mu.Lock()
defer n.mu.Unlock()
Expand All @@ -185,6 +213,7 @@ func (n *DataSource) FinishBundle(ctx context.Context) error {
return err
}

// Down resets the source.
func (n *DataSource) Down(ctx context.Context) error {
n.source = nil
return nil
Expand Down Expand Up @@ -253,3 +282,60 @@ func (n *DataSource) Split(splits []int64, frac float32) (int64, error) {
// return an error.
return 0, fmt.Errorf("failed to split at requested splits: {%v}, DataSource at index: %v", splits, c)
}

// concatReStream pairs two ReStreams. Opening it opens first immediately and
// keeps next unopened until the stream built from first has been drained.
type concatReStream struct {
	first ReStream
	next  ReStream
}

// Open opens the first ReStream and wraps the resulting Stream, together with
// the still unopened next ReStream, in a concatStream. An error from opening
// first is returned unchanged and no Stream is produced.
func (c *concatReStream) Open() (Stream, error) {
	head, err := c.first.Open()
	if err == nil {
		return &concatStream{first: head, nextStream: c.next}, nil
	}
	return nil, err
}

// concatStream is the Stream produced by concatReStream.Open. It reads from
// one underlying Stream at a time, moving on to a second one once the first
// reports io.EOF.
type concatStream struct {
// first is the Stream currently being read. It is nil once the
// concatStream has been closed or has run out of streams to read.
first Stream
// nextStream is opened by Read after first is drained. Read and Close set
// it to nil so that it is opened at most once.
nextStream ReStream
}

// Close closes the stream currently being read and drops the references to
// both it and any unopened follow-up stream. Calling Close when there is no
// current stream is a no-op that returns nil.
func (s *concatStream) Close() error {
	active := s.first
	if active == nil {
		return nil
	}
	// Clear the references on the way out, whether or not the underlying
	// Close succeeds.
	defer func() { s.first, s.nextStream = nil, nil }()
	return active.Close()
}

// Read returns the next element from the stream currently being read. When
// that stream reports io.EOF it is closed and, if a follow-up ReStream is
// still pending, the follow-up is opened and the first read from it is
// returned. Once no stream remains, Read returns io.EOF. Any other read,
// close, or open error is returned to the caller.
func (s *concatStream) Read() (*FullValue, error) {
	// No current stream: already closed or fully consumed.
	if s.first == nil {
		return nil, io.EOF
	}

	elm, readErr := s.first.Read()
	switch {
	case readErr == nil:
		return elm, nil
	case readErr != io.EOF:
		return nil, readErr
	}

	// The current stream is drained; release it before moving on.
	if closeErr := s.first.Close(); closeErr != nil {
		s.nextStream = nil
		return nil, closeErr
	}

	pending := s.nextStream
	if pending == nil {
		s.first = nil
		return nil, io.EOF
	}

	// Switch over to the follow-up stream. The pending reference is dropped
	// regardless of whether the open succeeds, so it is attempted only once.
	opened, openErr := pending.Open()
	s.first = opened
	s.nextStream = nil
	if openErr != nil {
		return nil, openErr
	}
	return s.first.Read()
}
71 changes: 65 additions & 6 deletions sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
@@ -1,3 +1,18 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package exec

import (
Expand Down Expand Up @@ -76,14 +91,14 @@ func TestDataSource_Iterators(t *testing.T) {
name string
keys, vals []interface{}
Coder *coder.Coder
driver func(c *coder.Coder, dmw io.WriteCloser, ks, vs []interface{})
driver func(c *coder.Coder, dmw io.WriteCloser, siwFn func() io.WriteCloser, ks, vs []interface{})
}{
{
name: "beam:coder:iterable:v1-singleChunk",
keys: []interface{}{int64(42), int64(53)},
vals: []interface{}{int64(1), int64(2), int64(3), int64(4), int64(5)},
Coder: coder.NewW(coder.NewCoGBK([]*coder.Coder{coder.NewVarInt(), coder.NewVarInt()}), coder.NewGlobalWindow()),
driver: func(c *coder.Coder, dmw io.WriteCloser, ks, vs []interface{}) {
driver: func(c *coder.Coder, dmw io.WriteCloser, _ func() io.WriteCloser, ks, vs []interface{}) {
wc, kc, vc := extractCoders(c)
for _, k := range ks {
EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, dmw)
Expand All @@ -101,13 +116,13 @@ func TestDataSource_Iterators(t *testing.T) {
keys: []interface{}{int64(42), int64(53)},
vals: []interface{}{int64(1), int64(2), int64(3), int64(4), int64(5)},
Coder: coder.NewW(coder.NewCoGBK([]*coder.Coder{coder.NewVarInt(), coder.NewVarInt()}), coder.NewGlobalWindow()),
driver: func(c *coder.Coder, dmw io.WriteCloser, ks, vs []interface{}) {
driver: func(c *coder.Coder, dmw io.WriteCloser, _ func() io.WriteCloser, ks, vs []interface{}) {
wc, kc, vc := extractCoders(c)
for _, k := range ks {
EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, dmw)
kc.Encode(&FullValue{Elm: k}, dmw)

coder.EncodeInt32(-1, dmw) // Mark this as a multi-Chunk (though beam, runner says to use 0)
coder.EncodeInt32(-1, dmw) // Mark this as a multi-Chunk (though beam runner proto says to use 0)
for _, v := range vs {
coder.EncodeVarInt(1, dmw) // Number of elements in this chunk.
vc.Encode(&FullValue{Elm: v}, dmw)
Expand All @@ -117,6 +132,32 @@ func TestDataSource_Iterators(t *testing.T) {
dmw.Close()
},
},
{
name: "beam:coder:state_backed_iterable:v1",
keys: []interface{}{int64(42), int64(53)},
vals: []interface{}{int64(1), int64(2), int64(3), int64(4), int64(5)},
Coder: coder.NewW(coder.NewCoGBK([]*coder.Coder{coder.NewVarInt(), coder.NewVarInt()}), coder.NewGlobalWindow()),
driver: func(c *coder.Coder, dmw io.WriteCloser, swFn func() io.WriteCloser, ks, vs []interface{}) {
wc, kc, vc := extractCoders(c)
for _, k := range ks {
EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, dmw)
kc.Encode(&FullValue{Elm: k}, dmw)
coder.EncodeInt32(-1, dmw) // Mark as multi-chunk (though beam, runner says to use 0)
coder.EncodeVarInt(-1, dmw) // Mark subsequent chunks as "state backed"

token := []byte(tokenString)
coder.EncodeVarInt(int64(len(token)), dmw) // token.
dmw.Write(token)
// Each state stream needs to be a different writer, so get a new writer.
sw := swFn()
for _, v := range vs {
vc.Encode(&FullValue{Elm: v}, sw)
}
sw.Close()
}
dmw.Close()
},
},
// TODO: Test splitting.
// TODO: Test progress.
}
Expand All @@ -132,10 +173,18 @@ func TestDataSource_Iterators(t *testing.T) {
}
dmr, dmw := io.Pipe()

go test.driver(source.Coder, dmw, test.keys, test.vals)
// Simulate individual state channels with pipes and a channel.
sRc := make(chan io.ReadCloser)
swFn := func() io.WriteCloser {
sr, sw := io.Pipe()
sRc <- sr
return sw
}
go test.driver(source.Coder, dmw, swFn, test.keys, test.vals)

constructAndExecutePlanWithContext(t, []Unit{out, source}, DataContext{
Data: &TestDataManager{R: dmr},
Data: &TestDataManager{R: dmr},
State: &TestStateReader{Rc: sRc},
})
if len(out.CapturedInputs) == 0 {
t.Fatal("did not capture source output")
Expand Down Expand Up @@ -174,6 +223,16 @@ func (dm *TestDataManager) OpenWrite(ctx context.Context, id StreamID) (io.Write
return nil, nil
}

// TestStateReader simulates state reads using channels.
type TestStateReader struct {
// StateReader is embedded; only OpenIterable is overridden by this type.
StateReader
// Rc supplies the readers that OpenIterable hands out, one per call.
Rc <-chan io.ReadCloser
}

// OpenIterable returns the next reader received on Rc. The id and key
// arguments are ignored.
//
// The call blocks until a reader is available, but gives up with ctx.Err()
// if ctx is done first, so a driver that never supplies a reader cannot hang
// the caller forever. If Rc has been closed, io.ErrClosedPipe is returned
// rather than a nil reader with a nil error.
func (sr *TestStateReader) OpenIterable(ctx context.Context, id StreamID, key []byte) (io.ReadCloser, error) {
	select {
	case r, ok := <-sr.Rc:
		if !ok {
			return nil, io.ErrClosedPipe
		}
		return r, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

func constructAndExecutePlanWithContext(t *testing.T, us []Unit, dc DataContext) {
p, err := NewPlan("a", us)
if err != nil {
Expand Down
21 changes: 12 additions & 9 deletions sdks/go/pkg/beam/core/runtime/graphx/coder.go
Expand Up @@ -20,7 +20,7 @@ import (
"fmt"

"github.com/apache/beam/sdks/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/go/pkg/beam/core/runtime/graphx/v1"
v1 "github.com/apache/beam/sdks/go/pkg/beam/core/runtime/graphx/v1"
"github.com/apache/beam/sdks/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/go/pkg/beam/core/util/protox"
"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
Expand All @@ -31,12 +31,13 @@ import (
const (
// Model constants

urnBytesCoder = "beam:coder:bytes:v1"
urnVarIntCoder = "beam:coder:varint:v1"
urnLengthPrefixCoder = "beam:coder:length_prefix:v1"
urnKVCoder = "beam:coder:kv:v1"
urnIterableCoder = "beam:coder:iterable:v1"
urnWindowedValueCoder = "beam:coder:windowed_value:v1"
urnBytesCoder = "beam:coder:bytes:v1"
urnVarIntCoder = "beam:coder:varint:v1"
urnLengthPrefixCoder = "beam:coder:length_prefix:v1"
urnKVCoder = "beam:coder:kv:v1"
urnIterableCoder = "beam:coder:iterable:v1"
urnStateBackedIterableCoder = "beam:coder:state_backed_iterable:v1"
urnWindowedValueCoder = "beam:coder:windowed_value:v1"

urnGlobalWindow = "beam:coder:global_window:v1"
urnIntervalWindow = "beam:coder:interval_window:v1"
Expand Down Expand Up @@ -175,8 +176,9 @@ func (b *CoderUnmarshaller) makeCoder(c *pb.Coder) (*coder.Coder, error) {
if err != nil {
return nil, err
}
isGBK := elm.GetSpec().GetUrn() == urnIterableCoder
if isGBK {

switch elm.GetSpec().GetUrn() {
case urnIterableCoder, urnStateBackedIterableCoder:
id = elm.GetComponentCoderIds()[0]
kind = coder.CoGBK
root = typex.CoGBKType
Expand Down Expand Up @@ -352,6 +354,7 @@ func (b *CoderMarshaller) Add(c *coder.Coder) string {
value = b.internBuiltInCoder(urnLengthPrefixCoder, union)
}

// SDKs always provide iterableCoder to runners, but can receive StateBackedIterables in return.
stream := b.internBuiltInCoder(urnIterableCoder, value)
return b.internBuiltInCoder(urnKVCoder, comp[0], stream)

Expand Down

0 comments on commit 0d64dbc

Please sign in to comment.