Skip to content

Commit

Permalink
Add embedding node and embedding buffer for datanode
Browse files Browse the repository at this point in the history
Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
  • Loading branch information
aoiasd committed Jul 11, 2024
1 parent e0b39d8 commit 81bb6ef
Show file tree
Hide file tree
Showing 16 changed files with 740 additions and 24 deletions.
101 changes: 101 additions & 0 deletions internal/datanode/pipeline/flow_graph_embedding_node.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pipeline

import (
"fmt"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/datanode/writebuffer"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/vectorizer"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/samber/lo"
"go.uber.org/zap"
)

// EmbeddingType is the kind tag stored in embeddingNode.embeddingType and
// rendered as part of the node name.
//
// TODO Support String and move type to proto
type EmbeddingType int32

const (
	// Hash is the only EmbeddingType value defined so far. It is declared
	// with an explicit type so that it is an EmbeddingType rather than an
	// untyped integer constant.
	Hash EmbeddingType = 0
)

// embeddingNode is a pipeline node that converts insert messages into
// per-segment writebuffer.InsertData and runs a vectorizer over each
// converted batch (see prepareInsert).
type embeddingNode struct {
BaseNode

// schema is the collection schema handed to the storage helpers when
// converting insert messages and extracting the primary key column.
schema *schemapb.CollectionSchema
// pkField is the primary key field; its data type is passed to
// writebuffer.NewInsertData.
pkField *schemapb.FieldSchema
// channelName is included in the node name.
channelName string
// embeddingType is included in the node name.
embeddingType EmbeddingType

// vectorizer produces the embedding data for each converted insert batch.
vectorizer vectorizer.Vectorizer
}

// Name returns the node name in the form
// "embeddingNode-<embeddingType>-<channelName>".
//
// EmbeddingType is an int32-based type without a String method, so it is
// formatted with %d; using %s would produce a "%!s(...)" placeholder and is
// reported by go vet.
func (eNode *embeddingNode) Name() string {
	return fmt.Sprintf("embeddingNode-%d-%s", eNode.embeddingType, eNode.channelName)
}

// prepareInsert groups the given insert messages by segment ID and builds one
// writebuffer.InsertData per segment.
//
// Every message is converted with storage.InsertMsgToInsertData; its primary
// key and timestamp columns are extracted and checked against the row count of
// the converted data, and the vectorizer is run on it. The first failure
// aborts the whole call and is returned with a nil result.
func (eNode *embeddingNode) prepareInsert(insertMsgs []*msgstream.InsertMsg) ([]*writebuffer.InsertData, error) {
	bySegment := lo.GroupBy(insertMsgs, func(m *msgstream.InsertMsg) int64 {
		return m.SegmentID
	})
	// When several messages share a segment, the partition of the last one wins.
	partitionOf := lo.SliceToMap(insertMsgs, func(m *msgstream.InsertMsg) (int64, int64) {
		return m.GetSegmentID(), m.GetPartitionID()
	})

	prepared := make([]*writebuffer.InsertData, 0, len(bySegment))
	for segmentID, segmentMsgs := range bySegment {
		segmentData := writebuffer.NewInsertData(segmentID, partitionOf[segmentID], len(segmentMsgs), eNode.pkField.GetDataType())

		for _, insertMsg := range segmentMsgs {
			converted, err := storage.InsertMsgToInsertData(insertMsg, eNode.schema)
			if err != nil {
				log.Warn("failed to transfer insert msg to insert data", zap.Error(err))
				return nil, err
			}
			rowNum := converted.GetRowNum()

			pkColumn, err := storage.GetPkFromInsertData(eNode.schema, converted)
			if err != nil {
				return nil, err
			}
			if pkColumn.RowNum() != rowNum {
				return nil, merr.WrapErrServiceInternal("pk column row num not match")
			}

			tsColumn, err := storage.GetTimestampFromInsertData(converted)
			if err != nil {
				return nil, err
			}
			if tsColumn.RowNum() != rowNum {
				return nil, merr.WrapErrServiceInternal("timestamp column row num not match")
			}

			embedded, err := eNode.vectorizer.Vectorize(converted)
			if err != nil {
				log.Warn("failed to embedding insert data", zap.Error(err))
				return nil, err
			}

			segmentData.Append(converted, embedded, pkColumn, tsColumn)
		}

		prepared = append(prepared, segmentData)
	}

	return prepared, nil
}

// Operate forwards the first input message unchanged as the single output
// message.
//
// It panics if in is empty or if in[0] is not a *FlowGraphMsg.
//
// NOTE(review): prepareInsert is not invoked here yet; presumably the
// embedding step is still to be wired in — confirm.
func (eNode *embeddingNode) Operate(in []Msg) []Msg {
	fgMsg := in[0].(*FlowGraphMsg)
	return []Msg{fgMsg}
}

// Opearte is the original, misspelled name of Operate, kept so that existing
// callers continue to compile.
//
// Deprecated: use Operate instead.
func (eNode *embeddingNode) Opearte(in []Msg) []Msg {
	return eNode.Operate(in)
}
4 changes: 4 additions & 0 deletions internal/datanode/pipeline/flow_graph_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package pipeline
import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/datanode/util"
"github.com/milvus-io/milvus/internal/datanode/writebuffer"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/flowgraph"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
Expand Down Expand Up @@ -46,7 +47,10 @@ type (
type FlowGraphMsg struct {
BaseMsg
InsertMessages []*msgstream.InsertMsg
InsertData []*writebuffer.InsertData

DeleteMessages []*msgstream.DeleteMsg

TimeRange util.TimeRange
StartPositions []*msgpb.MsgPosition
EndPositions []*msgpb.MsgPosition
Expand Down
12 changes: 10 additions & 2 deletions internal/datanode/syncmgr/serializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ type SyncPack struct {
metacache metacache.MetaCache
metawriter MetaWriter
// data
insertData []*storage.InsertData
deltaData *storage.DeleteData
insertData []*storage.InsertData
embeddingData []*storage.EmbeddingData
deltaData *storage.DeleteData
// statistics
tsFrom typeutil.Timestamp
tsTo typeutil.Timestamp
Expand All @@ -64,6 +65,13 @@ func (p *SyncPack) WithInsertData(insertData []*storage.InsertData) *SyncPack {
return p
}

// WithEmbeddingData stores the non-nil entries of embeddingData on the pack
// and returns the pack so that calls can be chained.
func (p *SyncPack) WithEmbeddingData(embeddingData []*storage.EmbeddingData) *SyncPack {
	notNil := func(item *storage.EmbeddingData, _ int) bool {
		return item != nil
	}
	p.embeddingData = lo.Filter(embeddingData, notNil)
	return p
}

func (p *SyncPack) WithDeleteData(deltaData *storage.DeleteData) *SyncPack {
p.deltaData = deltaData
return p
Expand Down
24 changes: 24 additions & 0 deletions internal/datanode/syncmgr/storage_serializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type storageV1Serializer struct {
pkField *schemapb.FieldSchema

inCodec *storage.InsertCodec
emCodec *storage.EmbeddingCodec
delCodec *storage.DeleteCodec

allocator allocator.Interface
Expand Down Expand Up @@ -101,6 +102,9 @@ func (s *storageV1Serializer) EncodeBuffer(ctx context.Context, pack *SyncPack)
}
task.binlogBlobs = binlogBlobs

embeddingBlobs, err := s.serializeEmbedding(ctx, pack)
task.embeddingBlobs = embeddingBlobs

singlePKStats, batchStatsBlob, err := s.serializeStatslog(pack)
if err != nil {
log.Warn("failed to serialized statslog", zap.Error(err))
Expand Down Expand Up @@ -183,6 +187,26 @@ func (s *storageV1Serializer) serializeBinlog(ctx context.Context, pack *SyncPac
return result, nil
}

// serializeEmbedding encodes the pack's embedding data with the embedding
// codec and returns the resulting blobs indexed by field ID.
//
// Each blob key is parsed as a base-10 int64 field ID. A codec error or an
// unparsable key is returned with a nil map.
func (s *storageV1Serializer) serializeEmbedding(ctx context.Context, pack *SyncPack) (map[int64]*storage.Blob, error) {
	logger := log.Ctx(ctx)

	encoded, err := s.emCodec.Serialize(pack.collectionID, pack.partitionID, pack.segmentID, pack.embeddingData...)
	if err != nil {
		return nil, err
	}

	byField := make(map[int64]*storage.Blob)
	for i := range encoded {
		current := encoded[i]
		fieldID, parseErr := strconv.ParseInt(current.GetKey(), 10, 64)
		if parseErr != nil {
			logger.Error("serialize buffer failed ... cannot parse string to fieldID ..", zap.Error(parseErr))
			return nil, parseErr
		}
		byField[fieldID] = current
	}

	return byField, nil
}

func (s *storageV1Serializer) serializeStatslog(pack *SyncPack) (*storage.PrimaryKeyStats, *storage.Blob, error) {
var rowNum int64
var pkFieldData []storage.FieldData
Expand Down
1 change: 1 addition & 0 deletions internal/datanode/syncmgr/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ type SyncTask struct {

binlogBlobs map[int64]*storage.Blob // fieldID => blob
binlogMemsize map[int64]int64 // memory size
embeddingBlobs map[int64]*storage.Blob
batchStatsBlob *storage.Blob
mergedStatsBlob *storage.Blob
deltaBlob *storage.Blob
Expand Down
2 changes: 1 addition & 1 deletion internal/datanode/writebuffer/bf_write_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func NewBFWriteBuffer(channel string, metacache metacache.MetaCache, storageV2Ca
}, nil
}

func (wb *bfWriteBuffer) dispatchDeleteMsgs(groups []*inData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) {
func (wb *bfWriteBuffer) dispatchDeleteMsgs(groups []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) {
batchSize := paramtable.Get().CommonCfg.BloomFilterApplyBatchSize.GetAsInt()

split := func(pks []storage.PrimaryKey, pkTss []uint64, segments []*metacache.SegmentInfo) {
Expand Down
25 changes: 20 additions & 5 deletions internal/datanode/writebuffer/insert_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ type InsertBuffer struct {
collSchema *schemapb.CollectionSchema

buffers []*storage.InsertData

hasEmbedding bool
embeddings []*storage.EmbeddingData // embedding data for varchar field which enable match
}

func NewInsertBuffer(sch *schemapb.CollectionSchema) (*InsertBuffer, error) {
Expand Down Expand Up @@ -103,25 +106,37 @@ func NewInsertBuffer(sch *schemapb.CollectionSchema) (*InsertBuffer, error) {
return ib, nil
}

func (ib *InsertBuffer) buffer(inData *storage.InsertData, tr TimeRange, startPos, endPos *msgpb.MsgPosition) {
func (ib *InsertBuffer) buffer(inData *storage.InsertData, emData *storage.EmbeddingData, tr TimeRange, startPos, endPos *msgpb.MsgPosition) {
// buffer := ib.currentBuffer()
// storage.MergeInsertData(buffer.buffer, inData)
ib.buffers = append(ib.buffers, inData)
ib.embeddings = append(ib.embeddings, emData)
}

func (ib *InsertBuffer) Yield() []*storage.InsertData {
func (ib *InsertBuffer) Yield() ([]*storage.InsertData, []*storage.EmbeddingData) {
result := ib.buffers
// set buffer nil to so that fragmented buffer could get GCed
ib.buffers = nil
return result
if !ib.hasEmbedding {
return result, nil
}

embeddings := ib.embeddings
ib.embeddings = nil
return result, embeddings
}

func (ib *InsertBuffer) Buffer(inData *inData, startPos, endPos *msgpb.MsgPosition) int64 {
func (ib *InsertBuffer) Buffer(inData *InsertData, startPos, endPos *msgpb.MsgPosition) int64 {
bufferedSize := int64(0)
for idx, data := range inData.data {
tsData := inData.tsField[idx]
var emData *storage.EmbeddingData
if ib.hasEmbedding {
emData = inData.embeddings[idx]
}

tr := ib.getTimestampRange(tsData)
ib.buffer(data, tr, startPos, endPos)
ib.buffer(data, emData, tr, startPos, endPos)

// update buffer size
ib.UpdateStatistics(int64(data.GetRowNum()), int64(data.GetMemorySize()), tr, startPos, endPos)
Expand Down
6 changes: 4 additions & 2 deletions internal/datanode/writebuffer/segment_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ func (buf *segmentBuffer) IsFull() bool {
return buf.insertBuffer.IsFull() || buf.deltaBuffer.IsFull()
}

func (buf *segmentBuffer) Yield() (insert []*storage.InsertData, delete *storage.DeleteData) {
return buf.insertBuffer.Yield(), buf.deltaBuffer.Yield()
func (buf *segmentBuffer) Yield() (insert []*storage.InsertData, embedding []*storage.EmbeddingData, delete *storage.DeleteData) {
insert, embedding = buf.insertBuffer.Yield()
delete = buf.deltaBuffer.Yield()
return
}

func (buf *segmentBuffer) MinTimestamp() typeutil.Timestamp {
Expand Down
Loading

0 comments on commit 81bb6ef

Please sign in to comment.