From 81bb6efc1644bca7da3d4cd2fb6072427e531e2a Mon Sep 17 00:00:00 2001 From: aoiasd Date: Thu, 11 Jul 2024 15:45:50 +0800 Subject: [PATCH] Add embedding node and embedding buffer for datanode Signed-off-by: aoiasd --- .../pipeline/flow_graph_embedding_node.go | 101 ++++++++++++ .../datanode/pipeline/flow_graph_message.go | 4 + internal/datanode/syncmgr/serializer.go | 12 +- .../datanode/syncmgr/storage_serializer.go | 24 +++ internal/datanode/syncmgr/task.go | 1 + .../datanode/writebuffer/bf_write_buffer.go | 2 +- .../datanode/writebuffer/insert_buffer.go | 25 ++- .../datanode/writebuffer/segment_buffer.go | 6 +- internal/datanode/writebuffer/write_buffer.go | 81 ++++++++-- internal/storage/data_codec.go | 43 ++++++ internal/storage/match_embedding.go | 114 ++++++++++++++ .../util/tokenizerapi/mocks/TokenStream.go | 146 ++++++++++++++++++ internal/util/tokenizerapi/mocks/Tokenizer.go | 111 +++++++++++++ internal/util/tokenizerapi/token_stream.go | 8 + internal/util/tokenizerapi/tokenizer.go | 7 + internal/util/vectorizer/vectorizer.go | 79 ++++++++++ 16 files changed, 740 insertions(+), 24 deletions(-) create mode 100644 internal/datanode/pipeline/flow_graph_embedding_node.go create mode 100644 internal/storage/match_embedding.go create mode 100644 internal/util/tokenizerapi/mocks/TokenStream.go create mode 100644 internal/util/tokenizerapi/mocks/Tokenizer.go create mode 100644 internal/util/tokenizerapi/token_stream.go create mode 100644 internal/util/tokenizerapi/tokenizer.go create mode 100644 internal/util/vectorizer/vectorizer.go diff --git a/internal/datanode/pipeline/flow_graph_embedding_node.go b/internal/datanode/pipeline/flow_graph_embedding_node.go new file mode 100644 index 000000000000..49d2bb61b9b5 --- /dev/null +++ b/internal/datanode/pipeline/flow_graph_embedding_node.go @@ -0,0 +1,101 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
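+
+// embeddingNode runs in the datanode write path: prepareInsert groups insert
+// messages by segment, validates the pk and timestamp columns, and invokes the
+// configured vectorizer to derive sparse embedding data that travels alongside
+// the raw insert data; Operate currently forwards the flow graph message as is.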
+
+package pipeline
+
+import (
+	"fmt"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/internal/datanode/writebuffer"
+	"github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/internal/util/vectorizer"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/mq/msgstream"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+	"github.com/samber/lo"
+	"go.uber.org/zap"
+)
+
+// TODO Support String and move type to proto
+type EmbeddingType int32
+
+const (
+	Hash EmbeddingType = 0
+)
+
+type embeddingNode struct {
+	BaseNode
+
+	schema        *schemapb.CollectionSchema
+	pkField       *schemapb.FieldSchema
+	channelName   string
+	embeddingType EmbeddingType
+
+	vectorizer vectorizer.Vectorizer
+}
+
+func (eNode *embeddingNode) Name() string {
+	return fmt.Sprintf("embeddingNode-%d-%s", eNode.embeddingType, eNode.channelName)
+}
+
+func (eNode *embeddingNode) prepareInsert(insertMsgs []*msgstream.InsertMsg) ([]*writebuffer.InsertData, error) {
+	groups := lo.GroupBy(insertMsgs, func(msg *msgstream.InsertMsg) int64 { return msg.SegmentID })
+	segmentPartition := lo.SliceToMap(insertMsgs, func(msg *msgstream.InsertMsg) (int64, int64) { return msg.GetSegmentID(), msg.GetPartitionID() })
+
+	result := make([]*writebuffer.InsertData, 0, len(groups))
+	for segment, msgs := range groups {
+		inData := writebuffer.NewInsertData(segment, segmentPartition[segment], len(msgs), eNode.pkField.GetDataType())
+
+		for _, msg := range msgs {
+			data, err := storage.InsertMsgToInsertData(msg, eNode.schema)
+			if err != nil {
+				log.Warn("failed to transfer insert msg to insert data", zap.Error(err))
+				return nil, err
+			}
+
+			pkFieldData, err := storage.GetPkFromInsertData(eNode.schema, data)
+			if err != nil {
+				return nil, err
+			}
+			if pkFieldData.RowNum() != data.GetRowNum() {
+				return nil, merr.WrapErrServiceInternal("pk column row num not match")
+			}
+
+			tsFieldData, err := storage.GetTimestampFromInsertData(data)
+			if err != nil {
+				return nil, err
+			}
+			if tsFieldData.RowNum() != data.GetRowNum() {
+				return nil, merr.WrapErrServiceInternal("timestamp column row num not match")
+			}
+
+			emData, err := eNode.vectorizer.Vectorize(data)
+			if err != nil {
+				log.Warn("failed to embed insert data", zap.Error(err))
+				return nil, err
+			}
+			inData.Append(data, emData, pkFieldData, tsFieldData)
+		}
+		result = append(result, inData)
+	}
+	return result, nil
+}
+
+func (eNode *embeddingNode) Operate(in []Msg) []Msg {
+	fgMsg := in[0].(*FlowGraphMsg)
+	return []Msg{fgMsg}
+}
diff --git a/internal/datanode/pipeline/flow_graph_message.go b/internal/datanode/pipeline/flow_graph_message.go
index ca2b72765e4c..a00ac6751268 100644
--- a/internal/datanode/pipeline/flow_graph_message.go
+++ b/internal/datanode/pipeline/flow_graph_message.go
@@ -19,6 +19,7 @@ package pipeline
 import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
 	"github.com/milvus-io/milvus/internal/datanode/util"
+	"github.com/milvus-io/milvus/internal/datanode/writebuffer"
 	"github.com/milvus-io/milvus/internal/storage"
 	"github.com/milvus-io/milvus/internal/util/flowgraph"
 	"github.com/milvus-io/milvus/pkg/mq/msgstream"
@@ -46,7 +47,10 @@ type (
 type FlowGraphMsg struct {
 	BaseMsg
 	InsertMessages []*msgstream.InsertMsg
+	InsertData     []*writebuffer.InsertData
+
 	DeleteMessages []*msgstream.DeleteMsg
+
 	TimeRange      util.TimeRange
 	StartPositions []*msgpb.MsgPosition
 	EndPositions   []*msgpb.MsgPosition
diff --git a/internal/datanode/syncmgr/serializer.go b/internal/datanode/syncmgr/serializer.go index 
cd7be9d06208..aef2f056c6fd 100644 --- a/internal/datanode/syncmgr/serializer.go +++ b/internal/datanode/syncmgr/serializer.go @@ -39,8 +39,9 @@ type SyncPack struct { metacache metacache.MetaCache metawriter MetaWriter // data - insertData []*storage.InsertData - deltaData *storage.DeleteData + insertData []*storage.InsertData + embeddingData []*storage.EmbeddingData + deltaData *storage.DeleteData // statistics tsFrom typeutil.Timestamp tsTo typeutil.Timestamp @@ -64,6 +65,13 @@ func (p *SyncPack) WithInsertData(insertData []*storage.InsertData) *SyncPack { return p } +func (p *SyncPack) WithEmbeddingData(embeddingData []*storage.EmbeddingData) *SyncPack { + p.embeddingData = lo.Filter(embeddingData, func(emData *storage.EmbeddingData, _ int) bool { + return emData != nil + }) + return p +} + func (p *SyncPack) WithDeleteData(deltaData *storage.DeleteData) *SyncPack { p.deltaData = deltaData return p diff --git a/internal/datanode/syncmgr/storage_serializer.go b/internal/datanode/syncmgr/storage_serializer.go index 475d4f3446da..1d68eaa7556d 100644 --- a/internal/datanode/syncmgr/storage_serializer.go +++ b/internal/datanode/syncmgr/storage_serializer.go @@ -43,6 +43,7 @@ type storageV1Serializer struct { pkField *schemapb.FieldSchema inCodec *storage.InsertCodec + emCodec *storage.EmbeddingCodec delCodec *storage.DeleteCodec allocator allocator.Interface @@ -101,6 +102,9 @@ func (s *storageV1Serializer) EncodeBuffer(ctx context.Context, pack *SyncPack) } task.binlogBlobs = binlogBlobs + embeddingBlobs, err := s.serializeEmbedding(ctx, pack) + task.embeddingBlobs = embeddingBlobs + singlePKStats, batchStatsBlob, err := s.serializeStatslog(pack) if err != nil { log.Warn("failed to serialized statslog", zap.Error(err)) @@ -183,6 +187,26 @@ func (s *storageV1Serializer) serializeBinlog(ctx context.Context, pack *SyncPac return result, nil } +func (s *storageV1Serializer) serializeEmbedding(ctx context.Context, pack *SyncPack) (map[int64]*storage.Blob, error) { + log := log.Ctx(ctx) + blobs, err := s.emCodec.Serialize(pack.collectionID, pack.partitionID, pack.segmentID, pack.embeddingData...) + if err != nil { + return nil, err + } + + result := make(map[int64]*storage.Blob) + for _, blob := range blobs { + fieldID, err := strconv.ParseInt(blob.GetKey(), 10, 64) + if err != nil { + log.Error("serialize buffer failed ... 
cannot parse string to fieldID ..", zap.Error(err)) + return nil, err + } + + result[fieldID] = blob + } + return result, nil +} + func (s *storageV1Serializer) serializeStatslog(pack *SyncPack) (*storage.PrimaryKeyStats, *storage.Blob, error) { var rowNum int64 var pkFieldData []storage.FieldData diff --git a/internal/datanode/syncmgr/task.go b/internal/datanode/syncmgr/task.go index b6c07a781bce..a493529c4c94 100644 --- a/internal/datanode/syncmgr/task.go +++ b/internal/datanode/syncmgr/task.go @@ -75,6 +75,7 @@ type SyncTask struct { binlogBlobs map[int64]*storage.Blob // fieldID => blob binlogMemsize map[int64]int64 // memory size + embeddingBlobs map[int64]*storage.Blob batchStatsBlob *storage.Blob mergedStatsBlob *storage.Blob deltaBlob *storage.Blob diff --git a/internal/datanode/writebuffer/bf_write_buffer.go b/internal/datanode/writebuffer/bf_write_buffer.go index 808b4038609e..7a3b483a7673 100644 --- a/internal/datanode/writebuffer/bf_write_buffer.go +++ b/internal/datanode/writebuffer/bf_write_buffer.go @@ -30,7 +30,7 @@ func NewBFWriteBuffer(channel string, metacache metacache.MetaCache, storageV2Ca }, nil } -func (wb *bfWriteBuffer) dispatchDeleteMsgs(groups []*inData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) { +func (wb *bfWriteBuffer) dispatchDeleteMsgs(groups []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) { batchSize := paramtable.Get().CommonCfg.BloomFilterApplyBatchSize.GetAsInt() split := func(pks []storage.PrimaryKey, pkTss []uint64, segments []*metacache.SegmentInfo) { diff --git a/internal/datanode/writebuffer/insert_buffer.go b/internal/datanode/writebuffer/insert_buffer.go index b7f496e83ada..91ebe4d880a9 100644 --- a/internal/datanode/writebuffer/insert_buffer.go +++ b/internal/datanode/writebuffer/insert_buffer.go @@ -75,6 +75,9 @@ type InsertBuffer struct { collSchema *schemapb.CollectionSchema buffers []*storage.InsertData + + hasEmbedding bool + embeddings []*storage.EmbeddingData // embedding data for varchar field which enable match } func NewInsertBuffer(sch *schemapb.CollectionSchema) (*InsertBuffer, error) { @@ -103,25 +106,37 @@ func NewInsertBuffer(sch *schemapb.CollectionSchema) (*InsertBuffer, error) { return ib, nil } -func (ib *InsertBuffer) buffer(inData *storage.InsertData, tr TimeRange, startPos, endPos *msgpb.MsgPosition) { +func (ib *InsertBuffer) buffer(inData *storage.InsertData, emData *storage.EmbeddingData, tr TimeRange, startPos, endPos *msgpb.MsgPosition) { // buffer := ib.currentBuffer() // storage.MergeInsertData(buffer.buffer, inData) ib.buffers = append(ib.buffers, inData) + ib.embeddings = append(ib.embeddings, emData) } -func (ib *InsertBuffer) Yield() []*storage.InsertData { +func (ib *InsertBuffer) Yield() ([]*storage.InsertData, []*storage.EmbeddingData) { result := ib.buffers // set buffer nil to so that fragmented buffer could get GCed ib.buffers = nil - return result + if !ib.hasEmbedding { + return result, nil + } + + embeddings := ib.embeddings + ib.embeddings = nil + return result, embeddings } -func (ib *InsertBuffer) Buffer(inData *inData, startPos, endPos *msgpb.MsgPosition) int64 { +func (ib *InsertBuffer) Buffer(inData *InsertData, startPos, endPos *msgpb.MsgPosition) int64 { bufferedSize := int64(0) for idx, data := range inData.data { tsData := inData.tsField[idx] + var emData *storage.EmbeddingData + if ib.hasEmbedding { + emData = inData.embeddings[idx] + } + tr := ib.getTimestampRange(tsData) - ib.buffer(data, tr, startPos, endPos) + 
ib.buffer(data, emData, tr, startPos, endPos) // update buffer size ib.UpdateStatistics(int64(data.GetRowNum()), int64(data.GetMemorySize()), tr, startPos, endPos) diff --git a/internal/datanode/writebuffer/segment_buffer.go b/internal/datanode/writebuffer/segment_buffer.go index 6afd64fff7fa..983211619872 100644 --- a/internal/datanode/writebuffer/segment_buffer.go +++ b/internal/datanode/writebuffer/segment_buffer.go @@ -32,8 +32,10 @@ func (buf *segmentBuffer) IsFull() bool { return buf.insertBuffer.IsFull() || buf.deltaBuffer.IsFull() } -func (buf *segmentBuffer) Yield() (insert []*storage.InsertData, delete *storage.DeleteData) { - return buf.insertBuffer.Yield(), buf.deltaBuffer.Yield() +func (buf *segmentBuffer) Yield() (insert []*storage.InsertData, embedding []*storage.EmbeddingData, delete *storage.DeleteData) { + insert, embedding = buf.insertBuffer.Yield() + delete = buf.deltaBuffer.Yield() + return } func (buf *segmentBuffer) MinTimestamp() typeutil.Timestamp { diff --git a/internal/datanode/writebuffer/write_buffer.go b/internal/datanode/writebuffer/write_buffer.go index 7f28c288c259..28ef35232cfd 100644 --- a/internal/datanode/writebuffer/write_buffer.go +++ b/internal/datanode/writebuffer/write_buffer.go @@ -394,34 +394,86 @@ func (wb *writeBufferBase) getOrCreateBuffer(segmentID int64) *segmentBuffer { return buffer } -func (wb *writeBufferBase) yieldBuffer(segmentID int64) ([]*storage.InsertData, *storage.DeleteData, *TimeRange, *msgpb.MsgPosition) { +func (wb *writeBufferBase) yieldBuffer(segmentID int64) ([]*storage.InsertData, []*storage.EmbeddingData, *storage.DeleteData, *TimeRange, *msgpb.MsgPosition) { buffer, ok := wb.buffers[segmentID] if !ok { - return nil, nil, nil, nil + return nil, nil, nil, nil, nil } // remove buffer and move it to sync manager delete(wb.buffers, segmentID) start := buffer.EarliestPosition() timeRange := buffer.GetTimeRange() - insert, delta := buffer.Yield() + insert, embedding, delta := buffer.Yield() - return insert, delta, timeRange, start + return insert, embedding, delta, timeRange, start } -type inData struct { +type InsertData struct { segmentID int64 partitionID int64 data []*storage.InsertData - pkField []storage.FieldData - tsField []*storage.Int64FieldData - rowNum int64 + embeddings []*storage.EmbeddingData + + pkField []storage.FieldData + pkType schemapb.DataType + + tsField []*storage.Int64FieldData + rowNum int64 intPKTs map[int64]int64 strPKTs map[string]int64 } -func (id *inData) pkExists(pk storage.PrimaryKey, ts uint64) bool { +func NewInsertData(segmentID, partitionID int64, cap int, pkType schemapb.DataType) *InsertData { + data := &InsertData{ + segmentID: segmentID, + partitionID: partitionID, + data: make([]*storage.InsertData, 0, cap), + embeddings: make([]*storage.EmbeddingData, 0, cap), + pkField: make([]storage.FieldData, 0, cap), + pkType: pkType, + } + + switch pkType { + case schemapb.DataType_Int64: + data.intPKTs = make(map[int64]int64) + case schemapb.DataType_VarChar: + data.strPKTs = make(map[string]int64) + } + + return data +} + +func (id *InsertData) Append(data *storage.InsertData, emData *storage.EmbeddingData, pkFieldData storage.FieldData, tsFieldData *storage.Int64FieldData) { + id.data = append(id.data, data) + id.pkField = append(id.pkField, pkFieldData) + id.tsField = append(id.tsField, tsFieldData) + id.embeddings = append(id.embeddings, emData) + id.rowNum += int64(data.GetRowNum()) + + timestamps := tsFieldData.GetRows().([]int64) + switch id.pkType { + case schemapb.DataType_Int64: + pks 
:= pkFieldData.GetRows().([]int64) + for idx, pk := range pks { + ts, ok := id.intPKTs[pk] + if !ok || timestamps[idx] < ts { + id.intPKTs[pk] = timestamps[idx] + } + } + case schemapb.DataType_VarChar: + pks := pkFieldData.GetRows().([]string) + for idx, pk := range pks { + ts, ok := id.strPKTs[pk] + if !ok || timestamps[idx] < ts { + id.strPKTs[pk] = timestamps[idx] + } + } + } +} + +func (id *InsertData) pkExists(pk storage.PrimaryKey, ts uint64) bool { var ok bool var minTs int64 switch pk.Type() { @@ -434,7 +486,7 @@ func (id *inData) pkExists(pk storage.PrimaryKey, ts uint64) bool { return ok && ts > uint64(minTs) } -func (id *inData) batchPkExists(pks []storage.PrimaryKey, tss []uint64, hits []bool) []bool { +func (id *InsertData) batchPkExists(pks []storage.PrimaryKey, tss []uint64, hits []bool) []bool { if len(pks) == 0 { return nil } @@ -466,9 +518,9 @@ func (wb *writeBufferBase) prepareInsert(insertMsgs []*msgstream.InsertMsg) ([]* groups := lo.GroupBy(insertMsgs, func(msg *msgstream.InsertMsg) int64 { return msg.SegmentID }) segmentPartition := lo.SliceToMap(insertMsgs, func(msg *msgstream.InsertMsg) (int64, int64) { return msg.GetSegmentID(), msg.GetPartitionID() }) - result := make([]*inData, 0, len(groups)) + result := make([]*InsertData, 0, len(groups)) for segment, msgs := range groups { - inData := &inData{ + inData := &InsertData{ segmentID: segment, partitionID: segmentPartition[segment], data: make([]*storage.InsertData, 0, len(msgs)), @@ -537,7 +589,7 @@ func (wb *writeBufferBase) prepareInsert(insertMsgs []*msgstream.InsertMsg) ([]* } // bufferInsert transform InsertMsg into bufferred InsertData and returns primary key field data for future usage. -func (wb *writeBufferBase) bufferInsert(inData *inData, startPos, endPos *msgpb.MsgPosition) error { +func (wb *writeBufferBase) bufferInsert(inData *InsertData, startPos, endPos *msgpb.MsgPosition) error { _, ok := wb.metaCache.GetSegmentByID(inData.segmentID) // new segment if !ok { @@ -585,7 +637,7 @@ func (wb *writeBufferBase) getSyncTask(ctx context.Context, segmentID int64) (sy var totalMemSize float64 = 0 var tsFrom, tsTo uint64 - insert, delta, timeRange, startPos := wb.yieldBuffer(segmentID) + insert, embedding, delta, timeRange, startPos := wb.yieldBuffer(segmentID) if timeRange != nil { tsFrom, tsTo = timeRange.timestampMin, timeRange.timestampMax } @@ -610,6 +662,7 @@ func (wb *writeBufferBase) getSyncTask(ctx context.Context, segmentID int64) (sy pack := &syncmgr.SyncPack{} pack.WithInsertData(insert). + WithEmbeddingData(embedding). WithDeleteData(delta). WithCollectionID(wb.collectionID). WithPartitionID(segmentInfo.PartitionID()). 
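For context, the InsertData bookkeeping above keeps, per primary key, the smallest buffered insert timestamp so that pkExists/batchPkExists can decide whether a later delete may hit buffered rows. A minimal standalone sketch of that logic, with hypothetical names and int64 PKs only (not part of this patch):

package main

import "fmt"

// pkTracker mirrors the pk -> minimal-timestamp map built by InsertData.Append.
type pkTracker struct {
	intPKTs map[int64]int64 // pk -> smallest buffered insert timestamp
}

func newPKTracker() *pkTracker {
	return &pkTracker{intPKTs: make(map[int64]int64)}
}

// track records the smallest insert timestamp seen for each buffered pk.
func (t *pkTracker) track(pks []int64, timestamps []int64) {
	for i, pk := range pks {
		ts, ok := t.intPKTs[pk]
		if !ok || timestamps[i] < ts {
			t.intPKTs[pk] = timestamps[i]
		}
	}
}

// pkExists reports whether a delete at ts can affect buffered data: the pk must
// be buffered and the delete must be newer than its earliest buffered insert.
func (t *pkTracker) pkExists(pk int64, ts uint64) bool {
	minTs, ok := t.intPKTs[pk]
	return ok && ts > uint64(minTs)
}

func main() {
	t := newPKTracker()
	t.track([]int64{1, 2, 1}, []int64{100, 120, 90})
	fmt.Println(t.pkExists(1, 95))  // true: delete at 95 is newer than the insert at 90
	fmt.Println(t.pkExists(2, 110)) // false: the buffered insert at 120 is newer
	fmt.Println(t.pkExists(3, 200)) // false: pk 3 was never buffered
}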
diff --git a/internal/storage/data_codec.go b/internal/storage/data_codec.go
index 549fbe932d12..0c0e0b471253 100644
--- a/internal/storage/data_codec.go
+++ b/internal/storage/data_codec.go
@@ -24,11 +24,13 @@ import (
 	"sort"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus-storage/go/common/log"
 	"github.com/milvus-io/milvus/internal/proto/etcdpb"
 	"github.com/milvus-io/milvus/pkg/common"
 	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/metautil"
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
+	"go.uber.org/zap"
 )
 
 const (
@@ -1014,3 +1016,44 @@ func (dataDefinitionCodec *DataDefinitionCodec) Deserialize(blobs []*Blob) (ts [
 
 	return resultTs, requestsStrings, nil
 }
+
+type EmbeddingCodec struct {
+	Schema *etcdpb.CollectionMeta
+}
+
+func (embeddingCodec *EmbeddingCodec) Serialize(collectionID UniqueID, partitionID UniqueID, segmentID UniqueID, data ...*EmbeddingData) ([]*Blob, error) {
+	var blobs []*Blob
+	for _, field := range embeddingCodec.Schema.Schema.Fields {
+		// if !field.IsMatch(){
+		// 	continue
+		// }
+		rowNum := 0
+		writer := NewEmbeddingWriter(schemapb.DataType_SparseFloatVector)
+		for _, block := range data {
+			fieldData := block.Data[field.FieldID]
+			rowNum += fieldData.RowNum()
+			_, err := writer.Write(fieldData)
+			if err != nil {
+				log.Warn("serialize embedding data failed",
+					zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Int64("segmentID", segmentID), zap.Error(err))
+				return nil, err
+			}
+		}
+
+		writer.Finish()
+
+		buffer, err := writer.GetBuffer()
+		if err != nil {
+			return nil, err
+		}
+
+		blobKey := fmt.Sprintf("%d", field.FieldID)
+		blobs = append(blobs, &Blob{
+			Key:        blobKey,
+			Value:      buffer,
+			RowNum:     int64(rowNum),
+			MemorySize: 0, // TODO Memory Size
+		})
+	}
+	return blobs, nil
+}
diff --git a/internal/storage/match_embedding.go b/internal/storage/match_embedding.go
new file mode 100644
index 000000000000..c9ff8df5307d
--- /dev/null
+++ b/internal/storage/match_embedding.go
@@ -0,0 +1,114 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
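+
+// This file defines EmbeddingData, the per-field container for sparse
+// embeddings of match-enabled fields, and EmbeddingWriter, which serializes
+// each sparse row as a little-endian uint32 pair count followed by the raw
+// index/value pairs.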
+
+package storage
+
+import (
+	"bytes"
+	"encoding/binary"
+
+	"github.com/cockroachdb/errors"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+)
+
+type EmbeddingData struct {
+	Data map[FieldID]FieldData // ONLY SUPPORT SPARSE VECTOR
+}
+
+func NewEmbeddingData(schema *schemapb.CollectionSchema) (*EmbeddingData, error) {
+	return NewEmbeddingDataWithCap(schema, 0)
+}
+
+func NewEmbeddingDataWithCap(schema *schemapb.CollectionSchema, cap int) (*EmbeddingData, error) {
+	if schema == nil {
+		return nil, merr.WrapErrParameterMissing("collection schema")
+	}
+
+	edata := &EmbeddingData{
+		Data: make(map[FieldID]FieldData),
+	}
+
+	for _, field := range schema.GetFields() {
+		// TODO
+		// if !field.IsMatch() {
+		// 	continue
+		// }
+		fieldData, err := NewFieldData(schemapb.DataType_SparseFloatVector, field, cap)
+		if err != nil {
+			return nil, err
+		}
+		edata.Data[field.FieldID] = fieldData
+	}
+
+	return edata, nil
+}
+
+type EmbeddingWriter struct {
+	offset int
+	output *bytes.Buffer
+	// TODO Embedding Type?
+	embeddingType schemapb.DataType
+
+	// size of single pair of indice and values
+	pairSize int
+
+	isFinished bool
+}
+
+func NewEmbeddingWriter(embeddingType schemapb.DataType) *EmbeddingWriter {
+	return &EmbeddingWriter{
+		embeddingType: embeddingType,
+		output:        &bytes.Buffer{},
+		// each sparse pair is a uint32 index plus a float32 value
+		pairSize: 8,
+	}
+}
+
+func (writer *EmbeddingWriter) WriteRow(content []byte) {
+	// prefix each row with its little-endian pair count
+	lenBuf := make([]byte, 4)
+	binary.LittleEndian.PutUint32(lenBuf, uint32(len(content)/writer.pairSize))
+	writer.output.Write(lenBuf)
+	writer.output.Write(content)
+}
+
+func (writer *EmbeddingWriter) Write(data FieldData) (int, error) {
+	for i := 0; i < data.RowNum(); i++ {
+		row := data.GetRow(i)
+		vectorBytes, ok := row.([]byte)
+		if !ok {
+			return 0, merr.WrapErrParameterInvalid("sparse vector", row, "Wrong row type")
+		}
+		writer.WriteRow(vectorBytes)
+	}
+	return data.RowNum(), nil
+}
+
+func (writer *EmbeddingWriter) GetBuffer() ([]byte, error) {
+	if !writer.isFinished {
+		return nil, errors.New("please finish the embedding writer before getting the buffer")
+	}
+
+	data := writer.output.Bytes()
+	if len(data) == 0 {
+		return nil, errors.New("empty buffer")
+	}
+
+	return data, nil
+}
+
+func (writer *EmbeddingWriter) Finish() {
+	writer.isFinished = true
+}
+
+type EmbeddingReader struct {
+}
diff --git a/internal/util/tokenizerapi/mocks/TokenStream.go b/internal/util/tokenizerapi/mocks/TokenStream.go
new file mode 100644
index 000000000000..ae556b619abb
--- /dev/null
+++ b/internal/util/tokenizerapi/mocks/TokenStream.go
@@ -0,0 +1,146 @@
+// Code generated by mockery v2.32.4. DO NOT EDIT.
+ +package mocks + +import mock "github.com/stretchr/testify/mock" + +// TokenStream is an autogenerated mock type for the TokenStream type +type TokenStream struct { + mock.Mock +} + +type TokenStream_Expecter struct { + mock *mock.Mock +} + +func (_m *TokenStream) EXPECT() *TokenStream_Expecter { + return &TokenStream_Expecter{mock: &_m.Mock} +} + +// Advance provides a mock function with given fields: +func (_m *TokenStream) Advance() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// TokenStream_Advance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Advance' +type TokenStream_Advance_Call struct { + *mock.Call +} + +// Advance is a helper method to define mock.On call +func (_e *TokenStream_Expecter) Advance() *TokenStream_Advance_Call { + return &TokenStream_Advance_Call{Call: _e.mock.On("Advance")} +} + +func (_c *TokenStream_Advance_Call) Run(run func()) *TokenStream_Advance_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *TokenStream_Advance_Call) Return(_a0 bool) *TokenStream_Advance_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TokenStream_Advance_Call) RunAndReturn(run func() bool) *TokenStream_Advance_Call { + _c.Call.Return(run) + return _c +} + +// Destroy provides a mock function with given fields: +func (_m *TokenStream) Destroy() { + _m.Called() +} + +// TokenStream_Destroy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Destroy' +type TokenStream_Destroy_Call struct { + *mock.Call +} + +// Destroy is a helper method to define mock.On call +func (_e *TokenStream_Expecter) Destroy() *TokenStream_Destroy_Call { + return &TokenStream_Destroy_Call{Call: _e.mock.On("Destroy")} +} + +func (_c *TokenStream_Destroy_Call) Run(run func()) *TokenStream_Destroy_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *TokenStream_Destroy_Call) Return() *TokenStream_Destroy_Call { + _c.Call.Return() + return _c +} + +func (_c *TokenStream_Destroy_Call) RunAndReturn(run func()) *TokenStream_Destroy_Call { + _c.Call.Return(run) + return _c +} + +// Token provides a mock function with given fields: +func (_m *TokenStream) Token() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// TokenStream_Token_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Token' +type TokenStream_Token_Call struct { + *mock.Call +} + +// Token is a helper method to define mock.On call +func (_e *TokenStream_Expecter) Token() *TokenStream_Token_Call { + return &TokenStream_Token_Call{Call: _e.mock.On("Token")} +} + +func (_c *TokenStream_Token_Call) Run(run func()) *TokenStream_Token_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *TokenStream_Token_Call) Return(_a0 string) *TokenStream_Token_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TokenStream_Token_Call) RunAndReturn(run func() string) *TokenStream_Token_Call { + _c.Call.Return(run) + return _c +} + +// NewTokenStream creates a new instance of TokenStream. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewTokenStream(t interface { + mock.TestingT + Cleanup(func()) +}) *TokenStream { + mock := &TokenStream{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/util/tokenizerapi/mocks/Tokenizer.go b/internal/util/tokenizerapi/mocks/Tokenizer.go new file mode 100644 index 000000000000..e0dad6c19dda --- /dev/null +++ b/internal/util/tokenizerapi/mocks/Tokenizer.go @@ -0,0 +1,111 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mocks + +import ( + tokenizerapi "github.com/milvus-io/milvus/internal/util/tokenizerapi" + mock "github.com/stretchr/testify/mock" +) + +// Tokenizer is an autogenerated mock type for the Tokenizer type +type Tokenizer struct { + mock.Mock +} + +type Tokenizer_Expecter struct { + mock *mock.Mock +} + +func (_m *Tokenizer) EXPECT() *Tokenizer_Expecter { + return &Tokenizer_Expecter{mock: &_m.Mock} +} + +// Destroy provides a mock function with given fields: +func (_m *Tokenizer) Destroy() { + _m.Called() +} + +// Tokenizer_Destroy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Destroy' +type Tokenizer_Destroy_Call struct { + *mock.Call +} + +// Destroy is a helper method to define mock.On call +func (_e *Tokenizer_Expecter) Destroy() *Tokenizer_Destroy_Call { + return &Tokenizer_Destroy_Call{Call: _e.mock.On("Destroy")} +} + +func (_c *Tokenizer_Destroy_Call) Run(run func()) *Tokenizer_Destroy_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Tokenizer_Destroy_Call) Return() *Tokenizer_Destroy_Call { + _c.Call.Return() + return _c +} + +func (_c *Tokenizer_Destroy_Call) RunAndReturn(run func()) *Tokenizer_Destroy_Call { + _c.Call.Return(run) + return _c +} + +// NewTokenStream provides a mock function with given fields: text +func (_m *Tokenizer) NewTokenStream(text string) tokenizerapi.TokenStream { + ret := _m.Called(text) + + var r0 tokenizerapi.TokenStream + if rf, ok := ret.Get(0).(func(string) tokenizerapi.TokenStream); ok { + r0 = rf(text) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(tokenizerapi.TokenStream) + } + } + + return r0 +} + +// Tokenizer_NewTokenStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewTokenStream' +type Tokenizer_NewTokenStream_Call struct { + *mock.Call +} + +// NewTokenStream is a helper method to define mock.On call +// - text string +func (_e *Tokenizer_Expecter) NewTokenStream(text interface{}) *Tokenizer_NewTokenStream_Call { + return &Tokenizer_NewTokenStream_Call{Call: _e.mock.On("NewTokenStream", text)} +} + +func (_c *Tokenizer_NewTokenStream_Call) Run(run func(text string)) *Tokenizer_NewTokenStream_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *Tokenizer_NewTokenStream_Call) Return(_a0 tokenizerapi.TokenStream) *Tokenizer_NewTokenStream_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Tokenizer_NewTokenStream_Call) RunAndReturn(run func(string) tokenizerapi.TokenStream) *Tokenizer_NewTokenStream_Call { + _c.Call.Return(run) + return _c +} + +// NewTokenizer creates a new instance of Tokenizer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewTokenizer(t interface { + mock.TestingT + Cleanup(func()) +}) *Tokenizer { + mock := &Tokenizer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/util/tokenizerapi/token_stream.go b/internal/util/tokenizerapi/token_stream.go new file mode 100644 index 000000000000..2df0f0202e4e --- /dev/null +++ b/internal/util/tokenizerapi/token_stream.go @@ -0,0 +1,8 @@ +package tokenizerapi + +//go:generate mockery --name=TokenStream --with-expecter +type TokenStream interface { + Advance() bool + Token() string + Destroy() +} diff --git a/internal/util/tokenizerapi/tokenizer.go b/internal/util/tokenizerapi/tokenizer.go new file mode 100644 index 000000000000..2b6debbec71f --- /dev/null +++ b/internal/util/tokenizerapi/tokenizer.go @@ -0,0 +1,7 @@ +package tokenizerapi + +//go:generate mockery --name=Tokenizer --with-expecter +type Tokenizer interface { + NewTokenStream(text string) TokenStream + Destroy() +} diff --git a/internal/util/vectorizer/vectorizer.go b/internal/util/vectorizer/vectorizer.go new file mode 100644 index 000000000000..7b223fdb4ad1 --- /dev/null +++ b/internal/util/vectorizer/vectorizer.go @@ -0,0 +1,79 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */
+
+package vectorizer
+
+import (
+	"fmt"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/internal/util/tokenizerapi"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+	"github.com/samber/lo"
+)
+
+type Vectorizer interface {
+	Vectorize(data *storage.InsertData) (*storage.EmbeddingData, error)
+}
+
+type HashVectorizer struct {
+	schema    *schemapb.CollectionSchema
+	tokenizer tokenizerapi.Tokenizer
+	embedType schemapb.DataType
+}
+
+func (v *HashVectorizer) Vectorize(data *storage.InsertData) (*storage.EmbeddingData, error) {
+	result := &storage.EmbeddingData{
+		Data: make(map[int64]storage.FieldData),
+	}
+
+	for _, field := range v.schema.Fields {
+		// TODO skip fields that do not enable embedding/match
+		if field.DataType != schemapb.DataType_VarChar {
+			continue
+		}
+
+		fieldData := data.Data[field.FieldID]
+		embedData, err := storage.NewFieldData(v.embedType, nil, fieldData.RowNum())
+		if err != nil {
+			return nil, fmt.Errorf("create embedding field data failed, dataType: %s, err: %w", v.embedType.String(), err)
+		}
+
+		for i := 0; i < fieldData.RowNum(); i++ {
+			rowData, ok := fieldData.GetRow(i).(string)
+			if !ok {
+				return nil, fmt.Errorf("expect string row for field %d, got %T", field.FieldID, fieldData.GetRow(i))
+			}
+
+			// count term frequencies of hashed tokens to build a sparse embedding
+			embeddingMap := map[uint32]float32{}
+			tokenStream := v.tokenizer.NewTokenStream(rowData)
+			for tokenStream.Advance() {
+				token := tokenStream.Token()
+				// TODO More Hash Option
+				hash := typeutil.HashString2Uint32(token)
+				embeddingMap[hash] += 1
+			}
+			tokenStream.Destroy()
+
+			if err := embedData.AppendRow(typeutil.CreateSparseFloatRow(lo.Keys(embeddingMap), lo.Values(embeddingMap))); err != nil {
+				return nil, err
+			}
+		}
+		result.Data[field.FieldID] = embedData
+	}
+	return result, nil
+}
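As a rough illustration of what HashVectorizer.Vectorize produces per row, the standalone sketch below hashes tokens into a term-frequency map and emits (index, value) pairs. It is approximate and not part of the patch: a whitespace split stands in for the tokenizer and FNV-1a stands in for typeutil.HashString2Uint32, so the concrete hash values differ; the patch packs the pairs into a sparse row via typeutil.CreateSparseFloatRow instead of printing them.

package main

import (
	"fmt"
	"hash/fnv"
	"sort"
	"strings"
)

// hashToken stands in for typeutil.HashString2Uint32 (assumption: any 32-bit string hash).
func hashToken(token string) uint32 {
	h := fnv.New32a()
	h.Write([]byte(token))
	return h.Sum32()
}

// vectorizeRow builds a sparse term-frequency embedding for one varchar row,
// mirroring the inner loop of HashVectorizer.Vectorize.
func vectorizeRow(text string) (indices []uint32, values []float32) {
	freq := map[uint32]float32{}
	for _, token := range strings.Fields(text) { // whitespace split stands in for the tokenizer
		freq[hashToken(token)] += 1
	}
	for idx := range freq {
		indices = append(indices, idx)
	}
	// sorted only to make the output deterministic
	sort.Slice(indices, func(i, j int) bool { return indices[i] < indices[j] })
	for _, idx := range indices {
		values = append(values, freq[idx])
	}
	return indices, values
}

func main() {
	indices, values := vectorizeRow("the quick brown fox jumps over the lazy dog the")
	for i := range indices {
		fmt.Printf("index=%d value=%.0f\n", indices[i], values[i])
	}
	// "the" appears three times, so one dimension carries the value 3
}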