
Implement fully asynchronous Sinks and graceful pod shutdown (Java runtime) #354

Merged 6 commits on Sep 6, 2023
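What the change amounts to: the sink contract moves from a blocking write(List<Record>) plus CommitCallback model to a single-record write(Record) that returns a CompletableFuture, so the runtime can track in-flight records (and drain them on shutdown) without blocking its threads. Below is a minimal sketch of a sink written against the new contract; the class name is made up for illustration, and it assumes AbstractAgentCode supplies the remaining lifecycle methods, as the VectorDBSinkAgent changes in this diff suggest.

import ai.langstream.api.runner.code.AbstractAgentCode;
import ai.langstream.api.runner.code.AgentSink;
import ai.langstream.api.runner.code.Record;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

// Minimal fully asynchronous sink: write() hands back a future that the runtime
// completes on, instead of blocking and acknowledging through a CommitCallback.
public class LoggingSinkAgent extends AbstractAgentCode implements AgentSink {

    @Override
    public void init(Map<String, Object> configuration) {
        // nothing to configure in this example
    }

    @Override
    public CompletableFuture<?> write(Record record) {
        // Kick off the write without blocking the caller; the runtime commits
        // the record once this future completes (exceptionally on failure).
        return CompletableFuture.runAsync(() -> System.out.println("wrote " + record));
    }
}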
@@ -234,6 +234,7 @@ public List<Record> read() throws Exception {
}

Document document = foundDocuments.remove();
processed(0, 1);
return List.of(
new WebCrawlerSourceRecord(
document.content().getBytes(StandardCharsets.UTF_8), document.url()));
@@ -96,17 +96,17 @@ public void process(List<Record> records, RecordSink recordSink) {
log.error("Error processing record: {}", record, e);
recordSink.emit(new SourceRecordAndResult(record, null, e));
} else {
log.info("Processed record {}, results {}", record, resultRecords);
processed(1, records.size());
if (log.isDebugEnabled()) {
log.debug("Processed record {}, results {}", record, resultRecords);
}
processed(0, 1);
recordSink.emit(new SourceRecordAndResult(record, resultRecords, null));
}
});
}
}

public CompletableFuture<List<Record>> processRecord(Record record) {

log.info("Processing {}", record);
if (log.isDebugEnabled()) {
log.debug("Processing {}", record);
}
@@ -119,7 +119,9 @@ public CompletableFuture<List<Record>> processRecord(Record record) {
try {
context.convertMapToStringOrBytes();
Optional<Record> recordResult = transformContextToRecord(context);
log.info("Result {}", recordResult);
if (log.isDebugEnabled()) {
log.debug("Result {}", recordResult);
}
return recordResult.map(List::of).orElseGet(List::of);
} catch (Exception e) {
log.error("Error processing record: {}", record, e);
@@ -369,13 +371,15 @@ public void streamAnswerChunk(
int index, String message, boolean last, TransformContext outputMessage) {
Optional<Record> record = transformContextToRecord(outputMessage);
if (record.isPresent()) {
log.info(
"index: {}, message: {}, last: {}: record {}",
index,
message,
last,
record);
topicProducer.write(List.of(record.get()));
if (log.isDebugEnabled()) {
log.debug(
"index: {}, message: {}, last: {}: record {}",
index,
message,
last,
record);
}
topicProducer.write(record.get()).join();
}
}
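Note on this hunk: topicProducer.write(...) now takes a single Record and returns a future, hence the .join() above. If the streaming callback is allowed to stay non-blocking, a variant (a sketch only, assuming the same write(Record) signature) would chain on the future instead of joining:

// Hypothetical non-blocking variant: react to the acknowledgement instead of
// blocking the streaming callback thread.
topicProducer
        .write(record.get())
        .whenComplete(
                (ignored, error) -> {
                    if (error != null) {
                        log.error("Error writing streamed answer chunk", error);
                    }
                });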

@@ -20,13 +20,14 @@
import ai.langstream.api.runner.code.AbstractAgentCode;
import ai.langstream.api.runner.code.AgentSink;
import ai.langstream.api.runner.code.Record;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class VectorDBSinkAgent extends AbstractAgentCode implements AgentSink {

private VectorDatabaseWriter writer;
private CommitCallback callback;

@Override
public void init(Map<String, Object> configuration) throws Exception {
@@ -47,18 +48,9 @@ public void close() throws Exception {
}

@Override
public void write(List<Record> records) throws Exception {

public CompletableFuture<?> write(Record record) {
// naive implementation, no batching
Map<String, Object> context = Map.of();
for (Record record : records) {
writer.upsert(record, context);
callback.commit(List.of(record));
}
}

@Override
public void setCommitCallback(CommitCallback callback) {
this.callback = callback;
return writer.upsert(record, context);
}
}
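With setCommitCallback gone, commit bookkeeping moves onto the future returned by write(), which is how the updated tests at the bottom of this diff use it. A caller-side sketch (variable names borrowed from those tests):

// Sketch: the caller composes on the returned future instead of waiting for a
// CommitCallback to fire.
agent.write(record)
        .thenRun(() -> committed.add(record))
        .exceptionally(
                error -> {
                    log.error("Write failed for record {}", record, error);
                    return null;
                });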
@@ -30,14 +30,10 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;

@Slf4j
public class CassandraWriter implements VectorDatabaseWriterProvider {
@@ -58,15 +54,13 @@ public VectorDatabaseWriter createImplementation(Map<String, Object> datasourceC
private static class CassandraVectorDatabaseWriter implements VectorDatabaseWriter {

private final Map<String, Object> datasourceConfig;
private Map<TopicPartition, OffsetAndMetadata> failureOffsets;
private final AbstractSinkTask processor = new SinkTaskProcessorImpl();

public CassandraVectorDatabaseWriter(Map<String, Object> datasourceConfig) {
log.debug(
"CassandraSinkTask starting with DataSource configuration: {}",
datasourceConfig);
this.datasourceConfig = datasourceConfig;
failureOffsets = new ConcurrentHashMap<>();
}

@Override
@@ -140,21 +134,13 @@ public void initialise(Map<String, Object> agentConfiguration) {
new AtomicReference<>();

@Override
public void upsert(Record record, Map<String, Object> context) throws Exception {
public CompletableFuture<?> upsert(Record record, Map<String, Object> context) {
// we must handle one record at a time
// so we block until the record is processed
CompletableFuture<?> handle = new CompletableFuture();
currentRecordStatus.set(handle);
processor.put(List.of(new LangStreamSinkRecordAdapter(record)));
try {
handle.join();
} catch (CompletionException error) {
if (error.getCause() instanceof Exception e) {
throw e;
} else {
throw error;
}
}
return handle;
}
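The Cassandra writer keeps its one-record-at-a-time behaviour, but upsert() now returns the pending future instead of joining on it and unwrapping CompletionException. Failures therefore surface through the future; a consumer-side sketch, assuming only the signature shown above:

// Sketch: errors are observed when composing on the future rather than caught
// as an exception thrown by upsert().
writer.upsert(record, Map.of())
        .exceptionally(
                error -> {
                    log.error("Cassandra upsert failed for record {}", record, error);
                    return null;
                });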

@Override
@@ -220,18 +206,6 @@ protected void handleFailure(

failCounter.run();
}

@Override
protected void beforeProcessingBatch() {
super.beforeProcessingBatch();
failureOffsets.clear();
}

@Override
public void start(Map<String, String> props) {
failureOffsets = new ConcurrentHashMap<>();
super.start(props);
}
}

private static class LangStreamSinkRecordAdapter implements AbstractSinkRecord {
@@ -34,6 +34,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;

@@ -98,73 +99,81 @@ public void initialise(Map<String, Object> agentConfiguration) {
}

@Override
public void upsert(Record record, Map<String, Object> context) {

TransformContext transformContext =
GenAIToolKitAgent.recordToTransformContext(record, true);
String id = idFunction != null ? (String) idFunction.evaluate(transformContext) : null;
String namespace =
namespaceFunction != null
? (String) namespaceFunction.evaluate(transformContext)
: null;
List<Object> vector =
vectorFunction != null
? (List<Object>) vectorFunction.evaluate(transformContext)
: null;
Map<String, Object> metadata =
metadataFunctions.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey,
e -> e.getValue().evaluate(transformContext)));
Struct metadataStruct =
Struct.newBuilder()
.putAllFields(
metadata.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey,
e ->
PineconeDataSource
.convertToValue(
e.getValue()))))
.build();

List<Float> vectorFloat = null;
if (vector != null) {
vectorFloat =
vector.stream()
.map(
n -> {
if (n instanceof String s) {
return Float.parseFloat(s);
} else if (n instanceof Number u) {
return u.floatValue();
} else {
throw new IllegalArgumentException(
"only vectors of floats are supported");
}
})
.collect(Collectors.toList());
public CompletableFuture<?> upsert(Record record, Map<String, Object> context) {
CompletableFuture<?> handle = new CompletableFuture<>();
try {
TransformContext transformContext =
GenAIToolKitAgent.recordToTransformContext(record, true);
String id =
idFunction != null ? (String) idFunction.evaluate(transformContext) : null;
String namespace =
namespaceFunction != null
? (String) namespaceFunction.evaluate(transformContext)
: null;
List<Object> vector =
vectorFunction != null
? (List<Object>) vectorFunction.evaluate(transformContext)
: null;
Map<String, Object> metadata =
metadataFunctions.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey,
e -> e.getValue().evaluate(transformContext)));
Struct metadataStruct =
Struct.newBuilder()
.putAllFields(
metadata.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey,
e ->
PineconeDataSource
.convertToValue(
e
.getValue()))))
.build();

List<Float> vectorFloat = null;
if (vector != null) {
vectorFloat =
vector.stream()
.map(
n -> {
if (n instanceof String s) {
return Float.parseFloat(s);
} else if (n instanceof Number u) {
return u.floatValue();
} else {
throw new IllegalArgumentException(
"only vectors of floats are supported");
}
})
.collect(Collectors.toList());
}

Vector v1 =
Vector.newBuilder()
.setId(id)
.addAllValues(vectorFloat)
.setMetadata(metadataStruct)
.build();

UpsertRequest.Builder builder = UpsertRequest.newBuilder().addVectors(v1);

if (namespace != null) {
builder.setNamespace(namespace);
}
UpsertRequest upsertRequest = builder.build();

UpsertResponse upsertResponse = connection.getBlockingStub().upsert(upsertRequest);

log.info("Result {}", upsertResponse);
handle.complete(null);
} catch (Exception e) {
handle.completeExceptionally(e);
}

Vector v1 =
Vector.newBuilder()
.setId(id)
.addAllValues(vectorFloat)
.setMetadata(metadataStruct)
.build();

UpsertRequest.Builder builder = UpsertRequest.newBuilder().addVectors(v1);

if (namespace != null) {
builder.setNamespace(namespace);
}
UpsertRequest upsertRequest = builder.build();

UpsertResponse upsertResponse = connection.getBlockingStub().upsert(upsertRequest);

log.info("Result {}", upsertResponse);
return handle;
}
}
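One observation on this hunk: the Pinecone call still runs synchronously on the gRPC blocking stub, so the returned future is already completed when upsert() returns. If a fully non-blocking path were wanted later, one option (a sketch only; upsertExecutor is a hypothetical executor, not part of this PR) would be to offload the blocking call:

// Hypothetical variant: run the blocking stub call on a dedicated executor so
// the caller's thread never blocks on the network round trip.
return CompletableFuture.supplyAsync(
        () -> connection.getBlockingStub().upsert(upsertRequest), upsertExecutor);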

@@ -24,10 +24,10 @@
import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.cql.ResultSet;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.CassandraContainer;
@@ -76,13 +76,12 @@ void testWrite() throws Exception {

agent.init(configuration);
agent.start();
List<Record> committed = new ArrayList<>();
agent.setCommitCallback(committed::addAll);
List<Record> committed = new CopyOnWriteArrayList<>();

Map<String, Object> value =
Map.of("id", "1", "description", "test-description", "name", "test-name");
SimpleRecord record = SimpleRecord.of(null, new ObjectMapper().writeValueAsString(value));
agent.write(List.of(record));
agent.write(record).thenRun(() -> committed.add(record)).get();

assertEquals(committed.get(0), record);
agent.close();
@@ -117,13 +116,12 @@ void testWriteAstra() throws Exception {

agent.init(configuration);
agent.start();
List<Record> committed = new ArrayList<>();
agent.setCommitCallback(committed::addAll);
List<Record> committed = new CopyOnWriteArrayList<>();

Map<String, Object> value =
Map.of("id", "1", "description", "test-description", "name", "test-name");
SimpleRecord record = SimpleRecord.of(null, new ObjectMapper().writeValueAsString(value));
agent.write(List.of(record));
agent.write(record).thenRun(() -> committed.add(record)).get();

assertEquals(committed.get(0), record);
agent.close();