diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/DSLConfigKeys.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/DSLConfigKeys.java
index 76cbf2902..3319c752d 100644
--- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/DSLConfigKeys.java
+++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/DSLConfigKeys.java
@@ -142,4 +142,34 @@ public class DSLConfigKeys implements Serializable {
         .key("geaflow.dsl.source.parallelism")
         .noDefaultValue()
         .description("Set source parallelism");
+
+    public static final ConfigKey GEAFLOW_DSL_GCN_HOPS = ConfigKeys
+        .key("geaflow.dsl.gcn.hops")
+        .defaultValue(2)
+        .description("The hop count for built-in gcn.");
+
+    public static final ConfigKey GEAFLOW_DSL_GCN_FANOUT = ConfigKeys
+        .key("geaflow.dsl.gcn.fanout")
+        .defaultValue(-1)
+        .description("The max sampled neighbors per visited vertex for built-in gcn.");
+
+    public static final ConfigKey GEAFLOW_DSL_GCN_EDGE_DIRECTION = ConfigKeys
+        .key("geaflow.dsl.gcn.edge.direction")
+        .defaultValue("BOTH")
+        .description("The edge direction for built-in gcn. Optional values: IN, OUT, BOTH.");
+
+    public static final ConfigKey GEAFLOW_DSL_GCN_VERTEX_FEATURE_FIELDS = ConfigKeys
+        .key("geaflow.dsl.gcn.vertex.feature.fields")
+        .noDefaultValue()
+        .description("Comma-separated vertex feature field names for built-in gcn.");
+
+    public static final ConfigKey GEAFLOW_DSL_GCN_BATCH_SIZE = ConfigKeys
+        .key("geaflow.dsl.gcn.batch.size")
+        .defaultValue(64)
+        .description("The max infer batch size for built-in gcn.");
+
+    public static final ConfigKey GEAFLOW_DSL_GCN_EDGE_WEIGHT_FIELD = ConfigKeys
+        .key("geaflow.dsl.gcn.edge.weight.field")
+        .noDefaultValue()
+        .description("Optional edge weight field name for built-in gcn.");
 }
diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/FrameworkConfigKeys.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/FrameworkConfigKeys.java
index 441370ab5..e09fe6b3e 100644
--- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/FrameworkConfigKeys.java
+++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/FrameworkConfigKeys.java
@@ -118,6 +118,16 @@ public class FrameworkConfigKeys implements Serializable {
         .defaultValue(true)
         .description("infer env suppress log enable, default is true");
 
+    public static final ConfigKey INFER_CONTEXT_POOL_MAX_SIZE = ConfigKeys
+        .key("geaflow.infer.context.pool.max.size")
+        .defaultValue(8)
+        .description("max infer context count for the same config key, default is 8");
+
+    public static final ConfigKey INFER_CONTEXT_POOL_BORROW_TIMEOUT_SEC = ConfigKeys
+        .key("geaflow.infer.context.pool.borrow.timeout.sec")
+        .defaultValue(30)
+        .description("max wait time for borrowing infer context, default is 30 seconds");
+
     public static final ConfigKey INFER_USER_DEFINE_LIB_PATH = ConfigKeys
         .key("geaflow.infer.user.define.lib.path")
         .noDefaultValue()
@@ -169,4 +179,3 @@
         .description("in dynmic graph, whether udf function materialize graph in finish");
 }
-
diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/PortUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/PortUtil.java
index b97b45e7c..3e8e53b4b 100644
--- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/PortUtil.java
+++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/PortUtil.java
@@ -43,7 +43,13 @@ public static int getPort(int minPort, int maxPort) {
                 num++;
             }
         }
-        throw new RuntimeException(String.format("no available port in [%d,%d]", minPort, maxPort));
+        // Fallback to an ephemeral port chosen by OS when the configured range is unavailable
+        // (for example, in constrained CI/network environments).
+        try (ServerSocket serverSocket = new ServerSocket(0)) {
+            return serverSocket.getLocalPort();
+        } catch (Exception e) {
+            throw new RuntimeException(String.format("no available port in [%d,%d]", minPort, maxPort));
+        }
     }
 
     public static int getPort(int port) {
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmModelRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmModelRuntimeContext.java
new file mode 100644
index 000000000..6d6bd9ce8
--- /dev/null
+++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmModelRuntimeContext.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.geaflow.dsl.common.algo;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import org.apache.geaflow.dsl.common.data.Row;
+import org.apache.geaflow.dsl.common.data.RowEdge;
+import org.apache.geaflow.model.graph.edge.EdgeDirection;
+
+/**
+ * Runtime context for model-backed algorithms that need inference and access to
+ * the active vertex's dynamic value and edges in the current batch.
+ *
+ * @param <K> The type of vertex IDs.
+ * @param <M> The type of messages that can be sent between vertices.
+ */
+public interface AlgorithmModelRuntimeContext<K, M> extends AlgorithmRuntimeContext<K, M> {
+
+    Object infer(Map<String, Object> payload);
+
+    default List<Object> inferBatch(List<Map<String, Object>> payloads) {
+        List<Object> results = new ArrayList<>(payloads.size());
+        for (Map<String, Object> payload : payloads) {
+            results.add(infer(payload));
+        }
+        return results;
+    }
+
+    Row loadDynamicVertexValue(Object vertexId);
+
+    List<RowEdge> loadDynamicEdges(Object vertexId, EdgeDirection direction);
+}
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/BatchAlgorithmUserFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/BatchAlgorithmUserFunction.java
new file mode 100644
index 000000000..b239a3d56
--- /dev/null
+++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/BatchAlgorithmUserFunction.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.geaflow.dsl.common.algo;
+
+import java.util.List;
+import java.util.Optional;
+import org.apache.geaflow.dsl.common.data.Row;
+import org.apache.geaflow.dsl.common.data.RowVertex;
+
+/**
+ * Optional extension for algorithms that can finalize results in batches.
+ */
+public interface BatchAlgorithmUserFunction extends AlgorithmUserFunction {
+
+    void finishBatch(List<RowVertex> graphVertices, List<Optional<Row>> updatedValues);
+}
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/TypeCastUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/TypeCastUtil.java
index f5631f335..8eaf8fea2 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/TypeCastUtil.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/TypeCastUtil.java
@@ -91,6 +91,9 @@ public static ITypeCast getTypeCast(Class sourceType, Class targetType) {
         if (sourceType == targetType) {
             return identityCast;
         }
+        if (sourceType.isArray() && targetType == String.class) {
+            return (ITypeCast) new Array2String();
+        }
         if (sourceType.isArray() && targetType.isArray()) {
             ITypeCast componentTypeCast = getTypeCast(sourceType.getComponentType(),
                targetType.getComponentType());
             return (ITypeCast) new ArrayCast(componentTypeCast, targetType.getComponentType());
@@ -179,6 +182,30 @@ public Object castTo(Object objects) {
         }
     }
 
+    private static class Array2String implements ITypeCast {
+
+        @Override
+        public String castTo(Object objects) {
+            if (objects == null) {
+                return null;
+            }
+            if (!objects.getClass().isArray()) {
+                return String.valueOf(objects);
+            }
+            int length = Array.getLength(objects);
+            StringBuilder builder = new StringBuilder();
+            builder.append('[');
+            for (int i = 0; i < length; i++) {
+                if (i > 0) {
+                    builder.append(", ");
+                }
+                builder.append(String.valueOf(Array.get(objects, i)));
+            }
+            builder.append(']');
+            return builder.toString();
+        }
+    }
+
     private static class Int2Long implements ITypeCast {
 
         @Override
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java
index 47addc84a..724960e07 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java
@@ -39,6 +39,7 @@
 import org.apache.geaflow.dsl.udf.graph.ClusterCoefficient;
 import org.apache.geaflow.dsl.udf.graph.CommonNeighbors;
 import org.apache.geaflow.dsl.udf.graph.ConnectedComponents;
+import org.apache.geaflow.dsl.udf.graph.GCN;
 import org.apache.geaflow.dsl.udf.graph.IncKHopAlgorithm;
 import org.apache.geaflow.dsl.udf.graph.IncMinimumSpanningTree;
 import org.apache.geaflow.dsl.udf.graph.IncWeakConnectedComponents;
@@ -226,6 +227,7 @@ public class BuildInSqlFunctionTable extends ListSqlOperatorTable {
             .add(GeaFlowFunction.of(SingleSourceShortestPath.class))
             .add(GeaFlowFunction.of(AllSourceShortestPath.class))
             .add(GeaFlowFunction.of(PageRank.class))
+            .add(GeaFlowFunction.of(GCN.class))
             .add(GeaFlowFunction.of(KHop.class))
             .add(GeaFlowFunction.of(KCore.class))
             .add(GeaFlowFunction.of(IncrementalKCore.class))
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/GCN.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/GCN.java
new file mode 100644
index 000000000..f40469365
--- /dev/null
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/GCN.java
@@ -0,0 +1,467 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.geaflow.dsl.udf.graph;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import org.apache.geaflow.common.config.ConfigHelper;
+import org.apache.geaflow.common.config.keys.DSLConfigKeys;
+import org.apache.geaflow.common.config.keys.FrameworkConfigKeys;
+import org.apache.geaflow.common.type.IType;
+import org.apache.geaflow.common.type.primitive.ByteType;
+import org.apache.geaflow.common.type.primitive.DecimalType;
+import org.apache.geaflow.common.type.primitive.DoubleType;
+import org.apache.geaflow.common.type.primitive.FloatType;
+import org.apache.geaflow.common.type.primitive.IntegerType;
+import org.apache.geaflow.common.type.primitive.LongType;
+import org.apache.geaflow.common.type.primitive.ShortType;
+import org.apache.geaflow.dsl.common.algo.AlgorithmModelRuntimeContext;
+import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext;
+import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction;
+import org.apache.geaflow.dsl.common.algo.BatchAlgorithmUserFunction;
+import org.apache.geaflow.dsl.common.algo.IncrementalAlgorithmUserFunction;
+import org.apache.geaflow.dsl.common.data.Row;
+import org.apache.geaflow.dsl.common.data.RowEdge;
+import org.apache.geaflow.dsl.common.data.RowVertex;
+import org.apache.geaflow.dsl.common.data.impl.ObjectRow;
+import org.apache.geaflow.dsl.common.function.Description;
+import org.apache.geaflow.dsl.common.types.ArrayType;
+import org.apache.geaflow.dsl.common.types.EdgeType;
+import org.apache.geaflow.dsl.common.types.GraphSchema;
+import org.apache.geaflow.dsl.common.types.ObjectType;
+import org.apache.geaflow.dsl.common.types.StructType;
+import org.apache.geaflow.dsl.common.types.TableField;
+import org.apache.geaflow.dsl.common.types.VertexType;
+import org.apache.geaflow.dsl.udf.graph.gcn.GCNEdgeRecord;
+import org.apache.geaflow.dsl.udf.graph.gcn.GCNExpandMessage;
+import org.apache.geaflow.dsl.udf.graph.gcn.GCNFragmentMessage;
+import org.apache.geaflow.dsl.udf.graph.gcn.GCNPayloadAssembler;
+import org.apache.geaflow.dsl.udf.graph.gcn.GCNResultDecoder;
+import org.apache.geaflow.dsl.udf.graph.gcn.GCNState;
+import org.apache.geaflow.dsl.udf.graph.gcn.MergedNeighborhoodCollector;
+import org.apache.geaflow.model.graph.edge.EdgeDirection;
+
+@Description(name = "gcn", description = "built-in udga for GCN node inference")
+public class GCN implements AlgorithmUserFunction,
+    BatchAlgorithmUserFunction, IncrementalAlgorithmUserFunction {
+
+    private AlgorithmRuntimeContext context;
+    private AlgorithmModelRuntimeContext modelContext;
+    private final MergedNeighborhoodCollector neighborhoodCollector = new MergedNeighborhoodCollector();
+    private final GCNPayloadAssembler payloadAssembler = new GCNPayloadAssembler();
+    private final GCNResultDecoder resultDecoder = new GCNResultDecoder();
+
+    private int hops;
+    private int fanout;
+    private int batchSize;
+    private EdgeDirection edgeDirection;
+    private Map<String, int[]> featureIndexesByLabel;
+    private Map<String, Integer> edgeWeightIndexesByLabel;
+    private boolean inferEnabled;
+    private boolean edgeWeightEnabled;
+
+    @Override
+    public void init(AlgorithmRuntimeContext context, Object[] parameters) {
+        if (parameters.length > 0) {
+            throw new IllegalArgumentException("CALL gcn() does not accept arguments");
+        }
+        this.context = context;
+        if (!(context instanceof AlgorithmModelRuntimeContext)) {
+            throw new IllegalArgumentException("GCN requires model runtime context");
+        }
+        this.modelContext = (AlgorithmModelRuntimeContext) context;
+        this.inferEnabled = ConfigHelper.getBooleanOrDefault(context.getConfig().getConfigMap(),
+            FrameworkConfigKeys.INFER_ENV_ENABLE.getKey(), false);
+        this.hops = ConfigHelper.getIntegerOrDefault(context.getConfig().getConfigMap(),
+            DSLConfigKeys.GEAFLOW_DSL_GCN_HOPS.getKey(),
+            (Integer) DSLConfigKeys.GEAFLOW_DSL_GCN_HOPS.getDefaultValue());
+        if (hops < 1) {
+            throw new IllegalArgumentException("geaflow.dsl.gcn.hops must be >= 1");
+        }
+        this.fanout = ConfigHelper.getIntegerOrDefault(context.getConfig().getConfigMap(),
+            DSLConfigKeys.GEAFLOW_DSL_GCN_FANOUT.getKey(),
+            (Integer) DSLConfigKeys.GEAFLOW_DSL_GCN_FANOUT.getDefaultValue());
+        if (fanout == 0 || fanout < -1) {
+            throw new IllegalArgumentException("geaflow.dsl.gcn.fanout must be -1 or > 0");
+        }
+        this.batchSize = ConfigHelper.getIntegerOrDefault(context.getConfig().getConfigMap(),
+            DSLConfigKeys.GEAFLOW_DSL_GCN_BATCH_SIZE.getKey(),
+            (Integer) DSLConfigKeys.GEAFLOW_DSL_GCN_BATCH_SIZE.getDefaultValue());
+        if (batchSize < 1) {
+            throw new IllegalArgumentException("geaflow.dsl.gcn.batch.size must be >= 1");
+        }
+        String direction = ConfigHelper.getStringOrDefault(context.getConfig().getConfigMap(),
+            DSLConfigKeys.GEAFLOW_DSL_GCN_EDGE_DIRECTION.getKey(),
+            String.valueOf(DSLConfigKeys.GEAFLOW_DSL_GCN_EDGE_DIRECTION.getDefaultValue()));
+        this.edgeDirection = parseEdgeDirection(direction);
+
+        String featureFieldConfig = context.getConfig()
+            .getString(DSLConfigKeys.GEAFLOW_DSL_GCN_VERTEX_FEATURE_FIELDS.getKey());
+        if (featureFieldConfig == null || featureFieldConfig.trim().isEmpty()) {
+            throw new IllegalArgumentException("geaflow.dsl.gcn.vertex.feature.fields must be configured");
+        }
+        List<String> featureFields = parseFields(featureFieldConfig, "GCN feature fields");
+        this.featureIndexesByLabel = resolveVertexFeatureIndexes(context.getGraphSchema(), featureFields);
+
+        String edgeWeightField = context.getConfig()
+            .getString(DSLConfigKeys.GEAFLOW_DSL_GCN_EDGE_WEIGHT_FIELD.getKey());
+        this.edgeWeightEnabled = edgeWeightField != null && !edgeWeightField.trim().isEmpty();
+        this.edgeWeightIndexesByLabel = resolveEdgeWeightIndexes(context.getGraphSchema(),
+            edgeWeightEnabled ? edgeWeightField.trim() : null);
+    }
+
+    @Override
+    public void process(RowVertex vertex, Optional<Row> updatedValues, Iterator messages) {
+        if (vertex == null) {
+            return;
+        }
+        if (context.getCurrentIterationId() == 1L) {
+            initializeWindowState(vertex);
+            return;
+        }
+
+        Object vertexId = vertex.getId();
+        GCNState state = loadState(updatedValues);
+        List<GCNFragmentMessage> fragmentMessages = new ArrayList<>();
+        Map<Object, GCNExpandMessage> expandMessages = new LinkedHashMap<>();
+        while (messages.hasNext()) {
+            Object message = messages.next();
+            if (message instanceof GCNFragmentMessage) {
+                fragmentMessages.add((GCNFragmentMessage) message);
+            } else if (message instanceof GCNExpandMessage) {
+                GCNExpandMessage expandMessage = (GCNExpandMessage) message;
+                GCNExpandMessage existing = expandMessages.get(expandMessage.getRootId());
+                if (existing == null || expandMessage.getDepth() < existing.getDepth()) {
+                    expandMessages.put(expandMessage.getRootId(), expandMessage);
+                }
+            }
+        }
+
+        boolean stateChanged = false;
+        if (state.getAccumulator() != null) {
+            for (GCNFragmentMessage fragmentMessage : fragmentMessages) {
+                if (vertexId.equals(fragmentMessage.getRootId())) {
+                    state.getAccumulator().addFragment(fragmentMessage);
+                    if (fragmentMessage.isDynamicExecution()) {
+                        state.setDynamicExecution(true);
+                    }
+                    stateChanged = true;
+                }
+            }
+        }
+
+        for (GCNExpandMessage expandMessage : expandMessages.values()) {
+            if (!state.shouldExpand(expandMessage.getRootId(), expandMessage.getDepth())) {
+                continue;
+            }
+            GCNFragmentMessage localFragment = buildLocalFragment(expandMessage.getRootId(), vertex);
+            context.sendMessage(expandMessage.getRootId(), localFragment);
+            sendExpandMessages(expandMessage.getRootId(), expandMessage.getDepth(), vertexId,
+                localFragment.getEdges());
+            stateChanged = true;
+        }
+
+        if (stateChanged) {
+            context.updateVertexValue(ObjectRow.create(state));
+        }
+    }
+
+    @Override
+    public void finish(RowVertex graphVertex, Optional<Row> updatedValues) {
+        finishBatch(java.util.Collections.singletonList(graphVertex),
+            java.util.Collections.singletonList(updatedValues));
+    }
+
+    @Override
+    public void finishBatch(List<RowVertex> graphVertices, List<Optional<Row>> updatedValues) {
+        if (!inferEnabled || graphVertices.isEmpty()) {
+            if (!inferEnabled) {
+                throw new IllegalStateException("GCN requires geaflow.infer.env.enable=true");
+            }
+            return;
+        }
+        List<RowVertex> batchVertices = new ArrayList<>(batchSize);
+        List<Map<String, Object>> payloads = new ArrayList<>(batchSize);
+        for (int i = 0; i < graphVertices.size(); i++) {
+            RowVertex graphVertex = graphVertices.get(i);
+            GCNState state = loadState(updatedValues.get(i));
+            if (graphVertex == null || state.getAccumulator() == null) {
+                continue;
+            }
+            batchVertices.add(graphVertex);
+            payloads.add(payloadAssembler.assemble(graphVertex.getId(), state.getAccumulator()));
+            if (payloads.size() >= batchSize) {
+                emitBatch(batchVertices, payloads);
+            }
+        }
+        emitBatch(batchVertices, payloads);
+    }
+
+    @Override
+    public StructType getOutputType(GraphSchema graphSchema) {
+        return new StructType(
+            new TableField("node_id", graphSchema.getIdType(), false),
+            new TableField("embedding", new ArrayType(DoubleType.INSTANCE)),
+            new TableField("prediction", IntegerType.INSTANCE),
+            new TableField("confidence", DoubleType.INSTANCE)
+        );
+    }
+
+    @Override
+    public void finish() {
+    }
+
+    private void initializeWindowState(RowVertex vertex) {
+        GCNState state = new GCNState();
+        Object vertexId = vertex.getId();
+        state.initAccumulator(vertexId);
+        state.shouldExpand(vertexId, 0);
+        GCNFragmentMessage localFragment = buildLocalFragment(vertexId, vertex);
+        state.getAccumulator().addFragment(localFragment);
+        state.setDynamicExecution(localFragment.isDynamicExecution());
+        context.updateVertexValue(ObjectRow.create(state));
+        sendExpandMessages(vertexId, 0, vertexId, localFragment.getEdges());
+    }
+
+    private void emitBatch(List<RowVertex> graphVertices, List<Map<String, Object>> payloads) {
+        if (payloads.isEmpty()) {
+            return;
+        }
+        List inferResults = modelContext.inferBatch(payloads);
+        if (inferResults.size() != payloads.size()) {
+            throw new IllegalArgumentException(String.format(
+                "GCN infer batch result size mismatch, payloadSize=%s, resultSize=%s",
+                payloads.size(), inferResults.size()));
+        }
+        for (int i = 0; i < inferResults.size(); i++) {
+            RowVertex vertex = graphVertices.get(i);
+            Row resultRow = resultDecoder.decode(vertex.getId(), inferResults.get(i));
+            context.take(resultRow);
+        }
+        graphVertices.clear();
+        payloads.clear();
+    }
+
+    private List<String> parseFields(String fieldConfig, String description) {
+        List<String> fields = new ArrayList<>();
+        for (String token : fieldConfig.split(",")) {
+            String field = token.trim();
+            if (!field.isEmpty()) {
+                fields.add(field);
+            }
+        }
+        if (fields.isEmpty()) {
+            throw new IllegalArgumentException("No valid " + description + " configured");
+        }
+        return fields;
+    }
+
+    private Map<String, int[]> resolveVertexFeatureIndexes(GraphSchema graphSchema, List<String> fields) {
+        Map<String, int[]> indexesByLabel = new LinkedHashMap<>();
+        String resolvedSignature = null;
+        for (TableField graphField : graphSchema.getFields()) {
+            if (!(graphField.getType() instanceof VertexType)) {
+                continue;
+            }
+            VertexType vertexType = (VertexType) graphField.getType();
+            StructType valueType = new StructType(vertexType.getValueFields());
+            int[] indexes = new int[fields.size()];
+            StringBuilder signature = new StringBuilder();
+            for (int i = 0; i < fields.size(); i++) {
+                int index = valueType.indexOf(fields.get(i));
+                if (index < 0) {
+                    throw new IllegalArgumentException(String.format(
+                        "GCN feature field '%s' missing from vertex label '%s'",
+                        fields.get(i), graphField.getName()));
+                }
+                TableField field = valueType.getField(index);
+                validateNumericField(field, "GCN feature field");
+                indexes[i] = index;
+                if (signature.length() > 0) {
+                    signature.append(';');
+                }
+                signature.append(field.getName()).append(':')
+                    .append(field.getType().getClass().getName());
+            }
+            String candidateSignature = signature.toString();
+            if (resolvedSignature == null) {
+                resolvedSignature = candidateSignature;
+            } else if (!resolvedSignature.equals(candidateSignature)) {
+                throw new IllegalArgumentException(
+                    "GCN feature fields are inconsistent across vertex value schemas");
+            }
+            indexesByLabel.put(graphField.getName(), indexes);
+        }
+        if (indexesByLabel.isEmpty()) {
+            throw new IllegalArgumentException("GCN requires at least one vertex schema");
+        }
+        return indexesByLabel;
+    }
+
+    private Map<String, Integer> resolveEdgeWeightIndexes(GraphSchema graphSchema, String weightField) {
+        Map<String, Integer> indexesByLabel = new LinkedHashMap<>();
+        if (weightField == null) {
+            return indexesByLabel;
+        }
+        for (TableField graphField : graphSchema.getFields()) {
+            if (!(graphField.getType() instanceof EdgeType)) {
+                continue;
+            }
+            EdgeType edgeType = (EdgeType) graphField.getType();
+            StructType valueType = new StructType(edgeType.getValueFields());
+            int index = valueType.indexOf(weightField);
+            if (index < 0) {
+                throw new IllegalArgumentException(String.format(
+                    "GCN edge weight field '%s' missing from edge label '%s'",
+                    weightField, graphField.getName()));
+            }
+            TableField field = valueType.getField(index);
+            validateNumericField(field, "GCN edge weight field");
+            indexesByLabel.put(graphField.getName(), index);
+        }
+        return indexesByLabel;
+    }
+
+    private EdgeDirection parseEdgeDirection(String direction) {
+        try {
+            return EdgeDirection.valueOf(direction.trim().toUpperCase(Locale.ROOT));
+        } catch (Exception e) {
+            throw new IllegalArgumentException(
+                "geaflow.dsl.gcn.edge.direction must be one of IN, OUT, BOTH");
+        }
+    }
+
+    private void validateNumericField(TableField field, String fieldType) {
+        if (!isNumericScalarType(field.getType())) {
+            throw new IllegalArgumentException(fieldType + " must be numeric scalar: "
+                + field.getName());
+        }
+    }
+
+    private boolean isNumericScalarType(IType type) {
+        return type instanceof ByteType
+            || type instanceof ShortType
+            || type instanceof IntegerType
+            || type instanceof LongType
+            || type instanceof FloatType
+            || type instanceof DoubleType
+            || type instanceof DecimalType;
+    }
+
+    private GCNState loadState(Optional<Row> updatedValues) {
+        if (!updatedValues.isPresent()) {
+            return new GCNState();
+        }
+        Object state = updatedValues.get().getField(0, ObjectType.INSTANCE);
+        if (state instanceof GCNState) {
+            return (GCNState) state;
+        }
+        return new GCNState();
+    }
+
+    private GCNFragmentMessage buildLocalFragment(Object rootId, RowVertex vertex) {
+        Object vertexId = vertex.getId();
+        List<RowEdge> mergedEdges = neighborhoodCollector.collectMergedEdges(vertexId,
+            context.loadStaticEdges(edgeDirection), modelContext, edgeDirection, fanout);
+        Row mergedValue = neighborhoodCollector.resolveMergedVertexValue(vertexId, vertex, modelContext);
+        boolean dynamicOverlay = modelContext.loadDynamicVertexValue(vertexId) != null
+            || !modelContext.loadDynamicEdges(vertexId, edgeDirection).isEmpty();
+        List<Double> features = extractFeatures(vertex, mergedValue);
+        List<GCNEdgeRecord> edgeRecords = new ArrayList<>(mergedEdges.size());
+        for (RowEdge edge : mergedEdges) {
+            edgeRecords.add(new GCNEdgeRecord(edge.getSrcId(), edge.getTargetId(), edge.getLabel(),
+                edge.getDirect(), extractEdgeWeight(edge)));
+        }
+        return new GCNFragmentMessage(rootId, vertexId, features, edgeRecords, dynamicOverlay);
+    }
+
+    private List<Double> extractFeatures(RowVertex vertex, Row row) {
+        int[] featureIndexes = featureIndexesByLabel.get(vertex.getLabel());
+        if (featureIndexes == null) {
+            throw new IllegalArgumentException("GCN vertex label is not supported: " + vertex.getLabel());
+        }
+        List<Double> features = new ArrayList<>(featureIndexes.length);
+        if (row == null) {
+            for (int i = 0; i < featureIndexes.length; i++) {
+                features.add(0.0D);
+            }
+            return features;
+        }
+        for (int index : featureIndexes) {
+            Object field = row.getField(index, ObjectType.INSTANCE);
+            if (field == null) {
+                features.add(0.0D);
+            } else if (field instanceof Number) {
+                features.add(((Number) field).doubleValue());
+            } else {
+                throw new IllegalArgumentException("GCN feature field value must be numeric: " + field);
+            }
+        }
+        return features;
+    }
+
+    private double extractEdgeWeight(RowEdge edge) {
+        if (!edgeWeightEnabled) {
+            return 1.0D;
+        }
+        Integer index = edgeWeightIndexesByLabel.get(edge.getLabel());
+        if (index == null) {
+            throw new IllegalArgumentException("GCN edge label is not supported: " + edge.getLabel());
+        }
+        Row value = edge.getValue();
+        if (value == null) {
+            return 0.0D;
+        }
+        Object field = value.getField(index, ObjectType.INSTANCE);
+        if (field == null) {
+            return 0.0D;
+        }
+        if (field instanceof Number) {
+            return ((Number) field).doubleValue();
+        }
+        throw new IllegalArgumentException("GCN edge weight field value must be numeric: " + field);
+    }
+
+    private void sendExpandMessages(Object rootId, int currentDepth, Object currentVertexId,
+                                    List<GCNEdgeRecord> edgeRecords) {
+        if (currentDepth >= hops) {
+            return;
+        }
+        Set nextHop = new LinkedHashSet<>();
+        for (GCNEdgeRecord edgeRecord : edgeRecords) {
+            Object neighborId = currentVertexId.equals(edgeRecord.getSrcId())
+                ? edgeRecord.getTargetId() : edgeRecord.getSrcId();
+            if (neighborId == null || currentVertexId.equals(neighborId)) {
+                continue;
+            }
+            nextHop.add(neighborId);
+        }
+        for (Object neighborId : nextHop) {
+            context.sendMessage(neighborId, new GCNExpandMessage(rootId, currentDepth + 1));
+        }
+    }
+}
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNAccumulator.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNAccumulator.java
new file mode 100644
index 000000000..d4272b776
--- /dev/null
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNAccumulator.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
+ * See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.geaflow.dsl.udf.graph.gcn;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+public class GCNAccumulator implements Serializable {
+
+    private final Object rootId;
+    private final Map<Object, List<Double>> nodeFeatures;
+    private final Map<String, GCNEdgeRecord> edgeRecords;
+
+    public GCNAccumulator(Object rootId) {
+        this.rootId = rootId;
+        this.nodeFeatures = new LinkedHashMap<>();
+        this.edgeRecords = new LinkedHashMap<>();
+    }
+
+    public Object getRootId() {
+        return rootId;
+    }
+
+    public Map<Object, List<Double>> getNodeFeatures() {
+        return nodeFeatures;
+    }
+
+    public List<GCNEdgeRecord> getEdgeRecords() {
+        return new ArrayList<>(edgeRecords.values());
+    }
+
+    public void addFragment(GCNFragmentMessage fragmentMessage) {
+        nodeFeatures.put(fragmentMessage.getVertexId(),
+            new ArrayList<>(fragmentMessage.getFeatures()));
+        for (GCNEdgeRecord edgeRecord : fragmentMessage.getEdges()) {
+            String identity = edgeRecord.identity();
+            GCNEdgeRecord existing = edgeRecords.get(identity);
+            if (existing == null) {
+                edgeRecords.put(identity, edgeRecord);
+            } else {
+                edgeRecords.put(identity, existing.mergeWeight(edgeRecord.getWeight()));
+            }
+        }
+    }
+}
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNEdgeRecord.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNEdgeRecord.java
new file mode 100644
index 000000000..3cdbfae2c
--- /dev/null
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNEdgeRecord.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.geaflow.dsl.udf.graph.gcn;
+
+import java.io.Serializable;
+import java.util.Objects;
+import org.apache.geaflow.model.graph.edge.EdgeDirection;
+
+public class GCNEdgeRecord implements Serializable {
+
+    private final Object srcId;
+    private final Object targetId;
+    private final String label;
+    private final EdgeDirection direction;
+    private final double weight;
+
+    public GCNEdgeRecord(Object srcId, Object targetId, String label, EdgeDirection direction,
+                         double weight) {
+        this.srcId = srcId;
+        this.targetId = targetId;
+        this.label = label;
+        this.direction = direction;
+        this.weight = weight;
+    }
+
+    public Object getSrcId() {
+        return srcId;
+    }
+
+    public Object getTargetId() {
+        return targetId;
+    }
+
+    public String getLabel() {
+        return label;
+    }
+
+    public EdgeDirection getDirection() {
+        return direction;
+    }
+
+    public double getWeight() {
+        return weight;
+    }
+
+    public String identity() {
+        return String.valueOf(srcId) + "->" + targetId + "@" + label + "#" + direction;
+    }
+
+    public GCNEdgeRecord mergeWeight(double delta) {
+        return new GCNEdgeRecord(srcId, targetId, label, direction, weight + delta);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (!(o instanceof GCNEdgeRecord)) {
+            return false;
+        }
+        GCNEdgeRecord that = (GCNEdgeRecord) o;
+        return Objects.equals(srcId, that.srcId) && Objects.equals(targetId, that.targetId)
+            && Objects.equals(label, that.label) && direction == that.direction
+            && Double.compare(weight, that.weight) == 0;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(srcId, targetId, label, direction, weight);
+    }
+}
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNExpandMessage.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNExpandMessage.java
new file mode 100644
index 000000000..608c71c8e
--- /dev/null
+++ 
/**
 * Control message asking a vertex to continue expanding the neighborhood of a root.
 *
 * <p>{@code depth} is the hop distance already travelled from {@code rootId} when the
 * message arrives; receivers use it to enforce the hop budget.
 */
public class GCNExpandMessage implements Serializable {

    private final Object rootId;
    private final int depth;

    public GCNExpandMessage(Object rootId, int depth) {
        this.rootId = rootId;
        this.depth = depth;
    }

    public Object getRootId() {
        return rootId;
    }

    public int getDepth() {
        return depth;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof GCNExpandMessage)) {
            return false;
        }
        GCNExpandMessage other = (GCNExpandMessage) o;
        return Objects.equals(rootId, other.rootId) && depth == other.depth;
    }

    @Override
    public int hashCode() {
        return Objects.hash(rootId, depth);
    }
}
index 000000000..cac6d9be5 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNFragmentMessage.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.udf.graph.gcn; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class GCNFragmentMessage implements Serializable { + + private final Object rootId; + private final Object vertexId; + private final List features; + private final List edges; + private final boolean dynamicExecution; + + public GCNFragmentMessage(Object rootId, Object vertexId, List features, + List edges, boolean dynamicExecution) { + this.rootId = rootId; + this.vertexId = vertexId; + this.features = new ArrayList<>(features); + this.edges = new ArrayList<>(edges); + this.dynamicExecution = dynamicExecution; + } + + public Object getRootId() { + return rootId; + } + + public Object getVertexId() { + return vertexId; + } + + public List getFeatures() { + return Collections.unmodifiableList(features); + } + + public List getEdges() { + return Collections.unmodifiableList(edges); + } + + public boolean 
isDynamicExecution() { + return dynamicExecution; + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNPayloadAssembler.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNPayloadAssembler.java new file mode 100644 index 000000000..7dd32d403 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNPayloadAssembler.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.udf.graph.gcn; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +public class GCNPayloadAssembler { + + /** + * Assemble payload for Python GCN transformer. + * + *
+     * {
+     *   "center_node_id": int|long,
+     *   "sampled_nodes": List[int|long],
+     *   "node_features": List[List[float]],
+     *   "edge_index": List[List[int]],   // 2 x E local indices
+     *   "edge_weight": List[float]
+     * }
+     * 
+ */ + public Map assemble(Object centerNodeId, GCNAccumulator accumulator) { + Map vertexIndex = new LinkedHashMap<>(); + List vertexIds = new ArrayList<>(); + List> features = new ArrayList<>(); + + for (Map.Entry> entry : accumulator.getNodeFeatures().entrySet()) { + vertexIndex.put(entry.getKey(), vertexIds.size()); + vertexIds.add(entry.getKey()); + features.add(new ArrayList<>(entry.getValue())); + } + + Map directedEdges = new LinkedHashMap<>(); + Map degreeMap = new LinkedHashMap<>(); + for (GCNEdgeRecord edgeRecord : accumulator.getEdgeRecords()) { + Integer srcIndex = vertexIndex.get(edgeRecord.getSrcId()); + Integer targetIndex = vertexIndex.get(edgeRecord.getTargetId()); + if (srcIndex == null || targetIndex == null) { + continue; + } + addEdge(directedEdges, degreeMap, srcIndex, targetIndex, edgeRecord.getWeight()); + } + for (int i = 0; i < vertexIds.size(); i++) { + addEdge(directedEdges, degreeMap, i, i, 1.0D); + } + + List row = new ArrayList<>(); + List col = new ArrayList<>(); + List edgeWeight = new ArrayList<>(); + for (double[] edge : directedEdges.values()) { + int src = (int) edge[0]; + int target = (int) edge[1]; + double rawWeight = edge[2]; + row.add(src); + col.add(target); + double srcDegree = degreeMap.getOrDefault(src, 1.0D); + double targetDegree = degreeMap.getOrDefault(target, 1.0D); + edgeWeight.add(rawWeight / Math.sqrt(srcDegree * targetDegree)); + } + + Map payload = new LinkedHashMap<>(); + payload.put("center_node_id", centerNodeId); + payload.put("sampled_nodes", vertexIds); + payload.put("node_features", features); + payload.put("edge_index", Arrays.asList(row, col)); + payload.put("edge_weight", edgeWeight); + return payload; + } + + private void addEdge(Map directedEdges, Map degreeMap, int src, + int target, double weight) { + String key = src + "->" + target; + double[] edge = directedEdges.get(key); + if (edge == null) { + edge = new double[]{src, target, 0.0D}; + directedEdges.put(key, edge); + } + edge[2] += weight; 
+ degreeMap.put(src, degreeMap.getOrDefault(src, 0.0D) + weight); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNResultDecoder.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNResultDecoder.java new file mode 100644 index 000000000..b907e693d --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNResultDecoder.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.udf.graph.gcn; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.data.impl.ObjectRow; + +public class GCNResultDecoder { + + public Row decode(Object vertexId, Object inferResult) { + if (!(inferResult instanceof Map)) { + throw new IllegalArgumentException( + "GCN infer result must be a map: " + (inferResult == null ? "null" : inferResult.getClass())); + } + Map resultMap = (Map) inferResult; + // node_id is required by protocol but we treat Java vertexId as authoritative. 
+ // If it exists and is numeric, do a best-effort sanity check. + Object nodeId = resultMap.get("node_id"); + if (nodeId != null && vertexId instanceof Number && nodeId instanceof Number) { + long expected = ((Number) vertexId).longValue(); + long actual = ((Number) nodeId).longValue(); + if (expected != actual) { + throw new IllegalArgumentException( + "GCN infer result node_id mismatch, expected=" + vertexId + ", actual=" + nodeId); + } + } + + Integer prediction = asInteger(resultMap.get("prediction")); + Double confidence = asDouble(resultMap.get("confidence")); + Double[] embedding = asEmbedding(resultMap.get("embedding")); + return ObjectRow.create(vertexId, embedding, prediction, confidence); + } + + private Integer asInteger(Object value) { + if (value == null) { + return null; + } + if (value instanceof Number) { + return ((Number) value).intValue(); + } + try { + return Integer.parseInt(String.valueOf(value)); + } catch (Exception e) { + throw new IllegalArgumentException("Invalid prediction value: " + value, e); + } + } + + private Double asDouble(Object value) { + if (value == null) { + return null; + } + if (value instanceof Number) { + return ((Number) value).doubleValue(); + } + try { + return Double.parseDouble(String.valueOf(value)); + } catch (Exception e) { + throw new IllegalArgumentException("Invalid confidence or embedding value: " + value, e); + } + } + + private Double[] asEmbedding(Object value) { + if (value == null) { + return null; + } + if (value instanceof Double[]) { + return (Double[]) value; + } + List embeddingValues = new ArrayList<>(); + if (value instanceof List) { + for (Object item : (List) value) { + embeddingValues.add(asDouble(item)); + } + } else if (value instanceof Object[]) { + for (Object item : (Object[]) value) { + embeddingValues.add(asDouble(item)); + } + } else { + return null; + } + return embeddingValues.toArray(new Double[0]); + } +} diff --git 
a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNState.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNState.java new file mode 100644 index 000000000..7dac16af8 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNState.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.dsl.udf.graph.gcn; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +public class GCNState implements Serializable { + + private GCNAccumulator accumulator; + private final Map minDepthByRoot; + private boolean dynamicExecution; + + public GCNState() { + this.minDepthByRoot = new HashMap<>(); + } + + public GCNAccumulator getAccumulator() { + return accumulator; + } + + public void initAccumulator(Object rootId) { + if (accumulator == null) { + accumulator = new GCNAccumulator(rootId); + } + } + + public Map getMinDepthByRoot() { + return minDepthByRoot; + } + + public boolean shouldExpand(Object rootId, int depth) { + Integer knownDepth = minDepthByRoot.get(rootId); + if (knownDepth != null && knownDepth <= depth) { + return false; + } + minDepthByRoot.put(rootId, depth); + return true; + } + + public boolean isDynamicExecution() { + return dynamicExecution; + } + + public void setDynamicExecution(boolean dynamicExecution) { + this.dynamicExecution = dynamicExecution; + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/MergedNeighborhoodCollector.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/MergedNeighborhoodCollector.java new file mode 100644 index 000000000..f3e6d12b3 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/MergedNeighborhoodCollector.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.udf.graph.gcn; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; +import org.apache.geaflow.dsl.common.algo.AlgorithmModelRuntimeContext; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.data.RowEdge; +import org.apache.geaflow.dsl.common.data.RowVertex; +import org.apache.geaflow.model.graph.edge.EdgeDirection; + +public class MergedNeighborhoodCollector { + + public Row resolveMergedVertexValue(Object vertexId, RowVertex historyVertex, + AlgorithmModelRuntimeContext context) { + Row dynamicValue = context.loadDynamicVertexValue(vertexId); + if (dynamicValue != null) { + return dynamicValue; + } + return historyVertex == null ? 
null : historyVertex.getValue(); + } + + public List collectMergedEdges(Object vertexId, + List staticEdges, + AlgorithmModelRuntimeContext context, + EdgeDirection direction, + int fanout) { + List mergedEdges = new ArrayList<>(staticEdges.size() + + context.loadDynamicEdges(vertexId, direction).size()); + mergedEdges.addAll(staticEdges); + mergedEdges.addAll(context.loadDynamicEdges(vertexId, direction)); + if (fanout < 0 || mergedEdges.size() <= fanout) { + return mergedEdges; + } + mergedEdges.sort(Comparator + .comparingInt((RowEdge edge) -> sampleScore(vertexId, direction, edge)) + .thenComparing(this::edgeIdentity)); + return new ArrayList<>(mergedEdges.subList(0, fanout)); + } + + public Object getNeighborId(RowEdge edge, Object currentVertexId) { + if (!Objects.equals(edge.getSrcId(), currentVertexId)) { + return edge.getSrcId(); + } + return edge.getTargetId(); + } + + private String edgeIdentity(RowEdge edge) { + return String.valueOf(edge.getSrcId()) + "->" + edge.getTargetId() + "@" + + edge.getLabel(); + } + + private int sampleScore(Object vertexId, EdgeDirection direction, RowEdge edge) { + return Objects.hash(vertexId, direction, edgeIdentity(edge)); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphAlgorithmTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphAlgorithmTest.java index a1de8505a..8af7beb34 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphAlgorithmTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphAlgorithmTest.java @@ -40,5 +40,14 @@ public void testGraphAlgorithm() { .validate() .expectValidateType( "RecordType(BIGINT vid, BIGINT distance)"); + + String script3 = "CALL GCN() YIELD (node_id, embedding, prediction, confidence)\n" + + "RETURN CAST(node_id AS INT) AS node_id, prediction, confidence"; + + 
PlanTester.build() + .gql(script3) + .validate() + .expectValidateType( + "RecordType(INTEGER node_id, INTEGER prediction, DOUBLE confidence)"); } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/graph/GCNHelperTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/graph/GCNHelperTest.java new file mode 100644 index 000000000..573f1b193 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/graph/GCNHelperTest.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.dsl.udf.graph; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertThrows; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.geaflow.common.type.primitive.DoubleType; +import org.apache.geaflow.common.type.primitive.IntegerType; +import org.apache.geaflow.common.type.primitive.LongType; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.types.ArrayType; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNAccumulator; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNEdgeRecord; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNFragmentMessage; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNPayloadAssembler; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNResultDecoder; +import org.apache.geaflow.model.graph.edge.EdgeDirection; +import org.testng.annotations.Test; + +public class GCNHelperTest { + + @Test + public void testPayloadAssemblerWithBidirectionalAndSelfLoop() { + GCNAccumulator accumulator = new GCNAccumulator(1L); + accumulator.addFragment(new GCNFragmentMessage(1L, 1L, + Arrays.asList(1.0, 2.0), Arrays.asList( + new GCNEdgeRecord(1L, 2L, "knows", EdgeDirection.OUT, 1.0D)), false)); + accumulator.addFragment(new GCNFragmentMessage(1L, 2L, + Arrays.asList(3.0, 4.0), Arrays.asList( + new GCNEdgeRecord(2L, 1L, "knows", EdgeDirection.OUT, 1.0D)), false)); + + GCNPayloadAssembler assembler = new GCNPayloadAssembler(); + Map payload = assembler.assemble(1L, accumulator); + + assertEquals(payload.get("center_node_id"), 1L); + assertEquals(((List) payload.get("sampled_nodes")).size(), 2); + assertEquals(((List) payload.get("node_features")).size(), 2); + + List edgeIndex = (List) payload.get("edge_index"); + List row = (List) edgeIndex.get(0); + List col = (List) 
edgeIndex.get(1); + assertEquals(row.size(), 4); + assertEquals(col.size(), 4); + assertEquals(row.get(0), 0); + assertEquals(col.get(0), 1); + assertEquals(row.get(1), 1); + assertEquals(col.get(1), 0); + assertEquals(row.get(2), 0); + assertEquals(col.get(2), 0); + assertEquals(row.get(3), 1); + assertEquals(col.get(3), 1); + + List edgeWeight = (List) payload.get("edge_weight"); + assertEquals(edgeWeight.size(), 4); + assertEquals(edgeWeight.get(0), 0.5D); + assertEquals(edgeWeight.get(1), 0.5D); + assertEquals(edgeWeight.get(2), 0.5D); + assertEquals(edgeWeight.get(3), 0.5D); + } + + @Test + public void testAccumulatorOverwritesLatestFragmentPerVertex() { + GCNAccumulator accumulator = new GCNAccumulator(1L); + accumulator.addFragment(new GCNFragmentMessage(1L, 2L, + Arrays.asList(1.0, 2.0), Arrays.asList( + new GCNEdgeRecord(2L, 3L, "knows", EdgeDirection.OUT, 1.0D)), false)); + accumulator.addFragment(new GCNFragmentMessage(1L, 2L, + Arrays.asList(9.0, 8.0), Arrays.asList( + new GCNEdgeRecord(2L, 4L, "knows", EdgeDirection.OUT, 1.0D)), false)); + + Map> nodeFeatures = accumulator.getNodeFeatures(); + assertEquals(nodeFeatures.get(2L), Arrays.asList(9.0, 8.0)); + assertEquals(accumulator.getEdgeRecords().size(), 2); + } + + @Test + public void testPayloadAssemblerForDirectedSingleEdgeDoesNotInjectReverseEdge() { + GCNAccumulator accumulator = new GCNAccumulator(1L); + accumulator.addFragment(new GCNFragmentMessage(1L, 1L, + Arrays.asList(1.0D), Arrays.asList( + new GCNEdgeRecord(1L, 2L, "knows", EdgeDirection.OUT, 2.0D)), false)); + accumulator.addFragment(new GCNFragmentMessage(1L, 2L, + Arrays.asList(2.0D), Collections.emptyList(), false)); + + Map payload = new GCNPayloadAssembler().assemble(1L, accumulator); + + List edgeIndex = (List) payload.get("edge_index"); + assertEquals(edgeIndex.get(0), Arrays.asList(0, 0, 1)); + assertEquals(edgeIndex.get(1), Arrays.asList(1, 0, 1)); + } + + @Test + public void 
testPayloadAssemblerAggregatesDuplicateDirectedEdges() { + GCNAccumulator accumulator = new GCNAccumulator(1L); + accumulator.addFragment(new GCNFragmentMessage(1L, 1L, + Arrays.asList(1.0D), Arrays.asList( + new GCNEdgeRecord(1L, 2L, "knows", EdgeDirection.OUT, 2.0D), + new GCNEdgeRecord(1L, 2L, "knows", EdgeDirection.OUT, 3.0D)), false)); + accumulator.addFragment(new GCNFragmentMessage(1L, 2L, + Arrays.asList(2.0D), Collections.emptyList(), false)); + + Map payload = new GCNPayloadAssembler().assemble(1L, accumulator); + + List edgeIndex = (List) payload.get("edge_index"); + assertEquals(edgeIndex.get(0), Arrays.asList(0, 0, 1)); + assertEquals(edgeIndex.get(1), Arrays.asList(1, 0, 1)); + List edgeWeight = (List) payload.get("edge_weight"); + assertEquals(edgeWeight.get(0), 5.0D / Math.sqrt(6.0D)); + } + + @Test + public void testResultDecoderWithMap() { + Map inferResult = new HashMap<>(); + inferResult.put("node_id", 1L); + inferResult.put("prediction", 7); + inferResult.put("confidence", 0.9D); + inferResult.put("embedding", Arrays.asList(0.1D, 0.2D)); + + Row row = new GCNResultDecoder().decode(1L, inferResult); + assertEquals(row.getField(0, LongType.INSTANCE), 1L); + assertNotNull(row.getField(1, new ArrayType(DoubleType.INSTANCE))); + Object[] embedding = (Object[]) row.getField(1, new ArrayType(DoubleType.INSTANCE)); + assertEquals(embedding.length, 2); + assertEquals(row.getField(2, IntegerType.INSTANCE), 7); + assertEquals(row.getField(3, DoubleType.INSTANCE), 0.9D); + } + + @Test + public void testResultDecoderUsesVertexIdWhenNodeIdMissingAndParsesNumericStrings() { + Map inferResult = new HashMap<>(); + inferResult.put("prediction", "7"); + inferResult.put("confidence", "0.9"); + inferResult.put("embedding", Arrays.asList("0.1", 0.2D)); + + Row row = new GCNResultDecoder().decode(5L, inferResult); + + assertEquals(row.getField(0, LongType.INSTANCE), 5L); + assertEquals(row.getField(2, IntegerType.INSTANCE), Integer.valueOf(7)); + 
assertEquals(row.getField(3, DoubleType.INSTANCE), Double.valueOf(0.9D)); + } + + @Test + public void testResultDecoderRejectsNodeIdMismatch() { + Map inferResult = new HashMap<>(); + inferResult.put("node_id", 6L); + + assertThrows(IllegalArgumentException.class, + () -> new GCNResultDecoder().decode(5L, inferResult)); + } + + @Test + public void testResultDecoderReturnsNullEmbeddingForUnsupportedType() { + Map inferResult = new HashMap<>(); + inferResult.put("embedding", "bad-embedding"); + + Row row = new GCNResultDecoder().decode(5L, inferResult); + + assertNull(row.getField(1, new ArrayType(DoubleType.INSTANCE))); + } + + @Test + public void testResultDecoderRejectsNonMapInput() { + GCNResultDecoder decoder = new GCNResultDecoder(); + try { + decoder.decode(5L, "unexpected-result"); + } catch (IllegalArgumentException e) { + assertNotNull(e.getMessage()); + return; + } + throw new AssertionError("expected exception"); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/graph/GCNTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/graph/GCNTest.java new file mode 100644 index 000000000..a3b153e39 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/graph/GCNTest.java @@ -0,0 +1,653 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.udf.graph; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertSame; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import org.apache.geaflow.common.config.Configuration; +import org.apache.geaflow.common.iterator.CloseableIterator; +import org.apache.geaflow.common.type.primitive.DoubleType; +import org.apache.geaflow.common.type.primitive.IntegerType; +import org.apache.geaflow.common.type.primitive.LongType; +import org.apache.geaflow.common.type.primitive.StringType; +import org.apache.geaflow.dsl.common.algo.AlgorithmModelRuntimeContext; +import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.data.RowEdge; +import org.apache.geaflow.dsl.common.data.RowVertex; +import org.apache.geaflow.dsl.common.data.impl.ObjectRow; +import org.apache.geaflow.dsl.common.data.impl.types.ObjectEdge; +import org.apache.geaflow.dsl.common.data.impl.types.ObjectVertex; +import org.apache.geaflow.dsl.common.types.EdgeType; +import org.apache.geaflow.dsl.common.types.GraphSchema; +import org.apache.geaflow.dsl.common.types.ObjectType; +import org.apache.geaflow.dsl.common.types.TableField; +import org.apache.geaflow.dsl.common.types.VertexType; 
+import org.apache.geaflow.dsl.udf.graph.gcn.GCNExpandMessage; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNFragmentMessage; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNState; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNEdgeRecord; +import org.apache.geaflow.dsl.udf.graph.gcn.MergedNeighborhoodCollector; +import org.apache.geaflow.model.graph.edge.EdgeDirection; +import org.apache.geaflow.state.pushdown.filter.IFilter; +import org.testng.annotations.Test; + +public class GCNTest { + + @Test(expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "GCN requires model runtime context") + public void testInitRejectsNonModelContext() { + new GCN().init(new FakeRuntimeContext(new Configuration(), simpleGraphSchema()), new Object[0]); + } + + @Test(expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "geaflow\\.dsl\\.gcn\\.hops must be >= 1") + public void testInitRejectsInvalidHops() { + Configuration config = baseConfig(); + config.put("geaflow.dsl.gcn.hops", "0"); + new GCN().init(new FakeModelRuntimeContext(config, simpleGraphSchema()), new Object[0]); + } + + @Test(expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "geaflow\\.dsl\\.gcn\\.fanout must be -1 or > 0") + public void testInitRejectsInvalidFanout() { + Configuration config = baseConfig(); + config.put("geaflow.dsl.gcn.fanout", "0"); + new GCN().init(new FakeModelRuntimeContext(config, simpleGraphSchema()), new Object[0]); + } + + @Test(expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "geaflow\\.dsl\\.gcn\\.edge\\.direction must be one of IN, OUT, BOTH") + public void testInitRejectsInvalidDirection() { + Configuration config = baseConfig(); + config.put("geaflow.dsl.gcn.edge.direction", "sideways"); + new GCN().init(new FakeModelRuntimeContext(config, simpleGraphSchema()), new Object[0]); + } + + @Test(expectedExceptions = 
IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "geaflow\\.dsl\\.gcn\\.vertex\\.feature\\.fields must be configured") + public void testInitRejectsMissingFeatureFields() { + new GCN().init(new FakeModelRuntimeContext(new Configuration(), simpleGraphSchema()), + new Object[0]); + } + + @Test(expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "GCN feature field must be numeric scalar: name") + public void testInitRejectsNonNumericFeatureField() { + Configuration config = new Configuration(); + config.put("geaflow.dsl.gcn.vertex.feature.fields", "name"); + new GCN().init(new FakeModelRuntimeContext(config, simpleGraphSchema()), new Object[0]); + } + + @Test(expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "GCN feature fields are inconsistent across vertex value schemas") + public void testInitRejectsInconsistentFeatureSchema() { + Configuration config = baseConfig(); + GraphSchema schema = graphSchema( + "person", vertexType(new TableField("name", StringType.INSTANCE), + new TableField("age", IntegerType.INSTANCE)), + "software", vertexType(new TableField("name", StringType.INSTANCE), + new TableField("age", LongType.INSTANCE)), + "knows", edgeType(new TableField("weight", DoubleType.INSTANCE)) + ); + new GCN().init(new FakeModelRuntimeContext(config, schema), new Object[0]); + } + + @Test(expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "GCN edge weight field 'missing' missing from edge label 'knows'") + public void testInitRejectsMissingEdgeWeightField() { + Configuration config = baseConfig(); + config.put("geaflow.dsl.gcn.edge.weight.field", "missing"); + new GCN().init(new FakeModelRuntimeContext(config, simpleGraphSchema()), new Object[0]); + } + + @Test + public void testMergedNeighborhoodUsesDynamicValue() { + MergedNeighborhoodCollector collector = new MergedNeighborhoodCollector(); + Row historyValue = 
ObjectRow.create("history", 18); + Row dynamicValue = ObjectRow.create("dynamic", 20); + ObjectVertex historyVertex = new ObjectVertex(1L); + historyVertex.setLabel("person"); + historyVertex.setValue(historyValue); + FakeModelRuntimeContext context = new FakeModelRuntimeContext(baseConfig(), simpleGraphSchema()); + context.dynamicVertexValue = dynamicValue; + + Row mergedValue = collector.resolveMergedVertexValue(1L, historyVertex, context); + + assertSame(mergedValue, dynamicValue); + } + + @Test + public void testMergedNeighborhoodSamplingDoesNotDependOnInputOrder() { + MergedNeighborhoodCollector collector = new MergedNeighborhoodCollector(); + RowEdge staticEdge1 = edge(1L, 2L, "knows", EdgeDirection.OUT, 1.0D); + RowEdge staticEdge2 = edge(1L, 3L, "knows", EdgeDirection.OUT, 2.0D); + RowEdge dynamicEdge = edge(1L, 4L, "knows", EdgeDirection.OUT, 3.0D); + FakeModelRuntimeContext context = new FakeModelRuntimeContext(baseConfig(), simpleGraphSchema()); + context.dynamicEdges = Collections.singletonList(dynamicEdge); + + List sampledEdgesA = collector.collectMergedEdges(1L, + Arrays.asList(staticEdge1, staticEdge2), context, EdgeDirection.OUT, 2); + List sampledEdgesB = collector.collectMergedEdges(1L, + Arrays.asList(staticEdge2, staticEdge1), context, EdgeDirection.OUT, 2); + + assertEquals(edgeSignatures(sampledEdgesA), edgeSignatures(sampledEdgesB)); + } + + @Test + public void testProcessRebuildsStateOnFirstIteration() { + FakeModelRuntimeContext context = new FakeModelRuntimeContext(baseConfig(), simpleGraphSchema()); + context.iterationId = 1L; + context.staticEdges = Collections.singletonList(edge(1L, 2L, "knows", EdgeDirection.OUT, 1.0D)); + + GCN algorithm = new GCN(); + algorithm.init(context, new Object[0]); + + algorithm.process(vertex(1L, "person", "alice", 18), + Optional.empty(), Collections.emptyList().iterator()); + + GCNState state = (GCNState) context.updatedValue.getField(0, ObjectType.INSTANCE); + 
assertEquals(state.getMinDepthByRoot().get(1L), Integer.valueOf(0)); + assertTrue(state.getAccumulator().getNodeFeatures().containsKey(1L)); + assertEquals(context.sentMessages.size(), 1); + assertEquals(context.sentMessages.get(0).vertexId, 2L); + GCNExpandMessage message = (GCNExpandMessage) context.sentMessages.get(0).message; + assertEquals(message.getRootId(), 1L); + assertEquals(message.getDepth(), 1); + } + + @Test + public void testProcessDeduplicatesRootExpansionByDepth() { + Configuration config = baseConfig(); + config.put("geaflow.dsl.gcn.hops", "3"); + FakeModelRuntimeContext context = new FakeModelRuntimeContext(config, simpleGraphSchema()); + context.iterationId = 3L; + context.staticEdges = Arrays.asList( + edge(4L, 2L, "knows", EdgeDirection.OUT, 1.0D), + edge(4L, 5L, "knows", EdgeDirection.OUT, 1.0D)); + + GCN algorithm = new GCN(); + algorithm.init(context, new Object[0]); + + GCNState state = new GCNState(); + GCNExpandMessage first = new GCNExpandMessage(1L, 2); + GCNExpandMessage second = new GCNExpandMessage(1L, 3); + algorithm.process(vertex(4L, "person", "diana", 40), Optional.of(ObjectRow.create(state)), + Arrays.asList(first, second).iterator()); + + int fragmentToRoot = 0; + int expandTo2 = 0; + int expandTo5 = 0; + for (SentMessage sentMessage : context.sentMessages) { + if (sentMessage.vertexId.equals(1L)) { + fragmentToRoot++; + } else if (sentMessage.vertexId.equals(2L)) { + expandTo2++; + } else if (sentMessage.vertexId.equals(5L)) { + expandTo5++; + } + } + assertEquals(fragmentToRoot, 1); + assertEquals(expandTo2, 1); + assertEquals(expandTo5, 1); + GCNState newState = (GCNState) context.updatedValue.getField(0, ObjectType.INSTANCE); + assertEquals(newState.getMinDepthByRoot().get(1L), Integer.valueOf(2)); + } + + @Test + public void testProcessWithInDirectionOnlyUsesIncomingNeighbors() { + Configuration config = baseConfig(); + config.put("geaflow.dsl.gcn.edge.direction", "IN"); + FakeModelRuntimeContext context = new 
FakeModelRuntimeContext(config, simpleGraphSchema()); + context.iterationId = 1L; + context.staticEdges = Arrays.asList( + edge(2L, 1L, "knows", EdgeDirection.IN, 1.0D), + edge(1L, 3L, "knows", EdgeDirection.OUT, 1.0D)); + + GCN algorithm = new GCN(); + algorithm.init(context, new Object[0]); + + algorithm.process(vertex(1L, "person", "alice", 18), + Optional.empty(), Collections.emptyList().iterator()); + + assertEquals(context.sentMessages.size(), 1); + assertEquals(context.sentMessages.get(0).vertexId, 2L); + GCNState state = (GCNState) context.updatedValue.getField(0, ObjectType.INSTANCE); + assertEquals(state.getAccumulator().getEdgeRecords().size(), 1); + GCNEdgeRecord edgeRecord = state.getAccumulator().getEdgeRecords().get(0); + assertEquals(edgeRecord.getSrcId(), 2L); + assertEquals(edgeRecord.getTargetId(), 1L); + } + + @Test + public void testProcessWithBothDirectionUsesIncomingAndOutgoingNeighbors() { + Configuration config = baseConfig(); + config.put("geaflow.dsl.gcn.edge.direction", "BOTH"); + FakeModelRuntimeContext context = new FakeModelRuntimeContext(config, simpleGraphSchema()); + context.iterationId = 1L; + context.staticEdges = Arrays.asList( + edge(2L, 1L, "knows", EdgeDirection.IN, 1.0D), + edge(1L, 3L, "knows", EdgeDirection.OUT, 1.0D)); + + GCN algorithm = new GCN(); + algorithm.init(context, new Object[0]); + + algorithm.process(vertex(1L, "person", "alice", 18), + Optional.empty(), Collections.emptyList().iterator()); + + assertEquals(context.sentMessages.size(), 2); + assertEquals(sentTargets(context.sentMessages), Arrays.asList(2L, 3L)); + } + + @Test + public void testFinishBatchWithFanoutOneKeepsDeterministicSubset() { + MergedNeighborhoodCollector collector = new MergedNeighborhoodCollector(); + FakeModelRuntimeContext context = new FakeModelRuntimeContext(baseConfig(), simpleGraphSchema()); + List edgesA = Arrays.asList( + edge(1L, 2L, "knows", EdgeDirection.OUT, 1.0D), + edge(1L, 3L, "knows", EdgeDirection.OUT, 1.0D), + edge(1L, 
4L, "knows", EdgeDirection.OUT, 1.0D)); + List edgesB = Arrays.asList( + edge(1L, 4L, "knows", EdgeDirection.OUT, 1.0D), + edge(1L, 2L, "knows", EdgeDirection.OUT, 1.0D), + edge(1L, 3L, "knows", EdgeDirection.OUT, 1.0D)); + + List sampledA = collector.collectMergedEdges(1L, edgesA, context, + EdgeDirection.OUT, 1); + List sampledB = collector.collectMergedEdges(1L, edgesB, context, + EdgeDirection.OUT, 1); + + assertEquals(sampledA.size(), 1); + assertEquals(edgeSignatures(sampledA), edgeSignatures(sampledB)); + } + + @Test + public void testExtractFeaturesUsesZeroForNullFields() { + Configuration config = new Configuration(); + config.put("geaflow.dsl.gcn.vertex.feature.fields", "age,height"); + GraphSchema schema = graphSchema( + "person", vertexType( + new TableField("name", StringType.INSTANCE), + new TableField("age", IntegerType.INSTANCE), + new TableField("height", LongType.INSTANCE)), + "knows", edgeType(new TableField("weight", DoubleType.INSTANCE)) + ); + FakeModelRuntimeContext context = new FakeModelRuntimeContext(config, schema); + context.iterationId = 1L; + + GCN algorithm = new GCN(); + algorithm.init(context, new Object[0]); + + algorithm.process(vertexWithValue(1L, "person", ObjectRow.create("alice", null, 170L)), + Optional.empty(), Collections.emptyList().iterator()); + + GCNState state = (GCNState) context.updatedValue.getField(0, ObjectType.INSTANCE); + assertEquals(state.getAccumulator().getNodeFeatures().get(1L), Arrays.asList(0.0D, 170.0D)); + } + + @Test + public void testExtractEdgeWeightUsesZeroWhenWeightFieldNull() { + Configuration config = baseConfig(); + config.put("geaflow.dsl.gcn.edge.weight.field", "weight"); + FakeModelRuntimeContext context = new FakeModelRuntimeContext(config, simpleGraphSchema()); + context.iterationId = 1L; + context.staticEdges = Collections.singletonList( + edgeWithValue(1L, 2L, "knows", EdgeDirection.OUT, ObjectRow.create((Object) null))); + + GCN algorithm = new GCN(); + algorithm.init(context, new 
Object[0]); + + algorithm.process(vertex(1L, "person", "alice", 18), + Optional.empty(), Collections.emptyList().iterator()); + + GCNState state = (GCNState) context.updatedValue.getField(0, ObjectType.INSTANCE); + assertEquals(state.getAccumulator().getEdgeRecords().get(0).getWeight(), 0.0D); + } + + @Test + public void testFinishBatchRejectsInferResultSizeMismatch() { + Configuration config = baseConfig(); + config.put("geaflow.infer.env.enable", "true"); + FakeModelRuntimeContext context = new FakeModelRuntimeContext(config, simpleGraphSchema()) { + @Override + public List inferBatch(List> payloads) { + batchInferPayloads.add(new ArrayList>(payloads)); + return Collections.emptyList(); + } + }; + GCN algorithm = new GCN(); + algorithm.init(context, new Object[0]); + + assertThrows(IllegalArgumentException.class, () -> algorithm.finishBatch( + Collections.singletonList(vertex(1L, "person", "alice", 18)), + Collections.>singletonList(Optional.of(ObjectRow.create( + stateWithEdge(1L, 2L, 1.0D, 18.0D, 20.0D)))))); + } + + @Test + public void testFinishBatchUsesConfiguredBatchSizeAndPreservesDirectedWeights() { + Configuration config = baseConfig(); + config.put("geaflow.infer.env.enable", "true"); + config.put("geaflow.dsl.gcn.batch.size", "2"); + config.put("geaflow.dsl.gcn.edge.weight.field", "weight"); + FakeModelRuntimeContext context = new FakeModelRuntimeContext(config, simpleGraphSchema()); + GCN algorithm = new GCN(); + algorithm.init(context, new Object[0]); + + GCNState first = stateWithEdge(1L, 2L, 2.0D, 18.0D, 20.0D); + GCNState second = stateWithEdge(2L, 3L, 4.0D, 20.0D, 22.0D); + GCNState third = stateWithEdge(3L, 4L, 6.0D, 22.0D, 24.0D); + + algorithm.finishBatch(Arrays.asList( + vertex(1L, "person", "alice", 18), + vertex(2L, "person", "bob", 20), + vertex(3L, "person", "cathy", 22)), + Arrays.asList( + Optional.of(ObjectRow.create(first)), + Optional.of(ObjectRow.create(second)), + Optional.of(ObjectRow.create(third)))); + + 
assertEquals(context.batchInferPayloads.size(), 2); + assertEquals(context.batchInferPayloads.get(0).size(), 2); + assertEquals(context.batchInferPayloads.get(1).size(), 1); + Map firstPayload = context.batchInferPayloads.get(0).get(0); + List edgeIndex = (List) firstPayload.get("edge_index"); + assertEquals(edgeIndex.get(0), Arrays.asList(0, 0, 1)); + assertEquals(edgeIndex.get(1), Arrays.asList(1, 0, 1)); + List edgeWeight = (List) firstPayload.get("edge_weight"); + assertEquals(((Double) edgeWeight.get(0)).doubleValue(), 2.0D / Math.sqrt(3.0D), 1e-9); + assertEquals(context.takenRows.size(), 3); + assertEquals(context.takenRows.get(0).getField(0, ObjectType.INSTANCE), 1L); + assertEquals(context.takenRows.get(2).getField(0, ObjectType.INSTANCE), 3L); + } + + private static GCNState stateWithEdge(long srcId, long targetId, double weight, + double srcFeature, double targetFeature) { + GCNState state = new GCNState(); + state.initAccumulator(srcId); + state.getAccumulator().addFragment(new GCNFragmentMessage(srcId, srcId, + Collections.singletonList(srcFeature), + Collections.singletonList(new GCNEdgeRecord(srcId, targetId, "knows", + EdgeDirection.OUT, weight)), false)); + state.getAccumulator().addFragment(new GCNFragmentMessage(srcId, targetId, + Collections.singletonList(targetFeature), Collections.emptyList(), false)); + return state; + } + + private static Configuration baseConfig() { + Configuration config = new Configuration(); + config.put("geaflow.dsl.gcn.vertex.feature.fields", "age"); + return config; + } + + private static List edgeSignatures(List edges) { + List signatures = new ArrayList<>(); + for (RowEdge edge : edges) { + signatures.add(edge.getSrcId() + "->" + edge.getTargetId() + "#" + edge.getDirect() + + "@" + edge.getValue().getField(0, ObjectType.INSTANCE)); + } + return signatures; + } + + private static List sentTargets(List messages) { + List targets = new ArrayList(messages.size()); + for (SentMessage message : messages) { + 
targets.add((Long) message.vertexId); + } + Collections.sort(targets); + return targets; + } + + private static GraphSchema simpleGraphSchema() { + return graphSchema( + "person", vertexType( + new TableField("name", StringType.INSTANCE), + new TableField("age", IntegerType.INSTANCE)), + "knows", edgeType(new TableField("weight", DoubleType.INSTANCE)) + ); + } + + private static GraphSchema graphSchema(String vertexLabel, VertexType vertexType, String edgeLabel, + EdgeType edgeType) { + return new GraphSchema("g", Arrays.asList( + new TableField(vertexLabel, vertexType), + new TableField(edgeLabel, edgeType))); + } + + private static GraphSchema graphSchema(String vertexLabel1, VertexType vertexType1, + String vertexLabel2, VertexType vertexType2, + String edgeLabel, EdgeType edgeType) { + return new GraphSchema("g", Arrays.asList( + new TableField(vertexLabel1, vertexType1), + new TableField(vertexLabel2, vertexType2), + new TableField(edgeLabel, edgeType))); + } + + private static VertexType vertexType(TableField... valueFields) { + List fields = new ArrayList<>(); + fields.add(new TableField(VertexType.DEFAULT_ID_FIELD_NAME, LongType.INSTANCE, false)); + fields.add(new TableField(GraphSchema.LABEL_FIELD_NAME, StringType.INSTANCE, false)); + fields.addAll(Arrays.asList(valueFields)); + return new VertexType(fields); + } + + private static EdgeType edgeType(TableField... 
valueFields) { + List fields = new ArrayList<>(); + fields.add(new TableField(EdgeType.DEFAULT_SRC_ID_NAME, LongType.INSTANCE, false)); + fields.add(new TableField(EdgeType.DEFAULT_TARGET_ID_NAME, LongType.INSTANCE, false)); + fields.add(new TableField(GraphSchema.LABEL_FIELD_NAME, StringType.INSTANCE, false)); + fields.addAll(Arrays.asList(valueFields)); + return new EdgeType(fields, false); + } + + private static RowEdge edge(long srcId, long targetId, String label, EdgeDirection direction, + double weight) { + ObjectEdge edge = new ObjectEdge(srcId, targetId, ObjectRow.create(weight)); + edge.setLabel(label); + edge.setDirect(direction); + return edge; + } + + private static RowEdge edgeWithValue(long srcId, long targetId, String label, + EdgeDirection direction, Row value) { + ObjectEdge edge = new ObjectEdge(srcId, targetId, value); + edge.setLabel(label); + edge.setDirect(direction); + return edge; + } + + private static ObjectVertex vertex(long id, String label, String name, int age) { + return vertexWithValue(id, label, ObjectRow.create(name, age)); + } + + private static ObjectVertex vertexWithValue(long id, String label, Row value) { + ObjectVertex vertex = new ObjectVertex(id); + vertex.setLabel(label); + vertex.setValue(value); + return vertex; + } + + private static class FakeRuntimeContext implements AlgorithmRuntimeContext { + + private final Configuration config; + private final GraphSchema graphSchema; + long iterationId; + List staticEdges = Collections.emptyList(); + Row updatedValue; + final List sentMessages = new ArrayList<>(); + final List takenRows = new ArrayList<>(); + + FakeRuntimeContext(Configuration config, GraphSchema graphSchema) { + this.config = config; + this.graphSchema = graphSchema; + } + + @Override + public List loadEdges(EdgeDirection direction) { + return Collections.emptyList(); + } + + @Override + public CloseableIterator loadEdgesIterator(EdgeDirection direction) { + return null; + } + + @Override + public 
CloseableIterator loadEdgesIterator(IFilter filter) { + return null; + } + + @Override + public List loadStaticEdges(EdgeDirection direction) { + return filterEdges(staticEdges, direction); + } + + @Override + public CloseableIterator loadStaticEdgesIterator(EdgeDirection direction) { + return null; + } + + @Override + public CloseableIterator loadStaticEdgesIterator(IFilter filter) { + return null; + } + + @Override + public List loadDynamicEdges(EdgeDirection direction) { + return Collections.emptyList(); + } + + @Override + public CloseableIterator loadDynamicEdgesIterator(EdgeDirection direction) { + return null; + } + + @Override + public CloseableIterator loadDynamicEdgesIterator(IFilter filter) { + return null; + } + + @Override + public void sendMessage(Object vertexId, Object message) { + sentMessages.add(new SentMessage(vertexId, message)); + } + + @Override + public void updateVertexValue(Row value) { + updatedValue = value; + } + + @Override + public void take(Row value) { + takenRows.add(value); + } + + @Override + public long getCurrentIterationId() { + return iterationId; + } + + @Override + public GraphSchema getGraphSchema() { + return graphSchema; + } + + @Override + public Configuration getConfig() { + return config; + } + + @Override + public void voteToTerminate(String terminationReason, Object voteValue) { + } + + protected List filterEdges(List edges, EdgeDirection direction) { + if (direction == EdgeDirection.BOTH) { + return edges; + } + List filtered = new ArrayList(); + for (RowEdge edge : edges) { + if (edge.getDirect() == direction) { + filtered.add(edge); + } + } + return filtered; + } + } + + private static class SentMessage { + + private final Object vertexId; + private final Object message; + + private SentMessage(Object vertexId, Object message) { + this.vertexId = vertexId; + this.message = message; + } + } + + private static class FakeModelRuntimeContext extends FakeRuntimeContext + implements AlgorithmModelRuntimeContext { + + 
private Row dynamicVertexValue; + private List dynamicEdges = Collections.emptyList(); + protected final List>> batchInferPayloads = new ArrayList<>(); + + FakeModelRuntimeContext(Configuration config, GraphSchema graphSchema) { + super(config, graphSchema); + } + + @Override + public Object infer(Map payload) { + return inferBatch(Collections.singletonList(payload)).get(0); + } + + @Override + public List inferBatch(List> payloads) { + batchInferPayloads.add(new ArrayList<>(payloads)); + List results = new ArrayList<>(payloads.size()); + for (Map payload : payloads) { + results.add(Collections.singletonMap("embedding", + Collections.singletonList(payload.get("center_node_id")))); + } + return results; + } + + @Override + public Row loadDynamicVertexValue(Object vertexId) { + return dynamicVertexValue; + } + + @Override + public List loadDynamicEdges(Object vertexId, EdgeDirection direction) { + return filterEdges(dynamicEdges, direction); + } + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/pom.xml b/geaflow/geaflow-dsl/geaflow-dsl-runtime/pom.xml index e9863eb2e..e94147ffa 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/pom.xml +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/pom.xml @@ -153,6 +153,12 @@ org.apache.geaflow geaflow-view-meta + + + org.apache.geaflow + geaflow-infer + ${project.version} + @@ -217,4 +223,4 @@ - \ No newline at end of file + diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversalFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversalFunction.java index daca980db..8b21f6cd2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversalFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversalFunction.java @@ -22,17 
+22,21 @@ import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE; import java.util.Collections; +import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.Set; import org.apache.geaflow.api.graph.function.vc.VertexCentricAggTraversalFunction; +import org.apache.geaflow.dsl.common.algo.BatchAlgorithmUserFunction; import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.common.types.GraphSchema; import org.apache.geaflow.dsl.runtime.traversal.message.ITraversalAgg; +import org.apache.geaflow.dsl.udf.graph.GCN; import org.apache.geaflow.model.traversal.ITraversalRequest; import org.apache.geaflow.state.KeyValueState; import org.apache.geaflow.state.StateFactory; @@ -73,7 +77,9 @@ public GeaFlowAlgorithmAggTraversalFunction(GraphSchema graphSchema, public void open( VertexCentricTraversalFuncContext vertexCentricFuncContext) { this.traversalContext = vertexCentricFuncContext; - this.algorithmCtx = new GeaFlowAlgorithmRuntimeContext(this, traversalContext, graphSchema); + this.algorithmCtx = userFunction instanceof GCN + ? 
new GeaFlowAlgorithmModelRuntimeContext(this, traversalContext, graphSchema) + : new GeaFlowAlgorithmRuntimeContext(this, traversalContext, graphSchema); this.userFunction.init(algorithmCtx, params); this.invokeVIds = new HashSet<>(); String stateName = traversalContext.getTraversalOpName() + "_" + STATE_SUFFIX; @@ -122,14 +128,32 @@ public void compute(Object vertexId, Iterator messages) { @Override public void finish() { - Iterator idIterator = getInvokeVIds(); - while (idIterator.hasNext()) { - Object id = idIterator.next(); - algorithmCtx.setVertexId(id); - RowVertex graphVertex = (RowVertex) traversalContext.vertex().withId(id).get(); - if (graphVertex != null) { - Row newValue = getVertexNewValue(graphVertex.getId()); - userFunction.finish(graphVertex, Optional.ofNullable(newValue)); + if (userFunction instanceof BatchAlgorithmUserFunction) { + List batchVertices = new ArrayList<>(); + List> batchValues = new ArrayList<>(); + Iterator idIterator = getInvokeVIds(); + while (idIterator.hasNext()) { + Object id = idIterator.next(); + algorithmCtx.setVertexId(id); + RowVertex graphVertex = (RowVertex) traversalContext.vertex().withId(id).get(); + if (graphVertex != null) { + Row newValue = getVertexNewValue(graphVertex.getId()); + batchVertices.add(graphVertex); + batchValues.add(Optional.ofNullable(newValue)); + } + } + ((BatchAlgorithmUserFunction) userFunction).finishBatch(batchVertices, + batchValues); + } else { + Iterator idIterator = getInvokeVIds(); + while (idIterator.hasNext()) { + Object id = idIterator.next(); + algorithmCtx.setVertexId(id); + RowVertex graphVertex = (RowVertex) traversalContext.vertex().withId(id).get(); + if (graphVertex != null) { + Row newValue = getVertexNewValue(graphVertex.getId()); + userFunction.finish(graphVertex, Optional.ofNullable(newValue)); + } } } algorithmCtx.finish(); diff --git 
a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversalFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversalFunction.java index 98c475b15..0025ff865 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversalFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversalFunction.java @@ -22,6 +22,7 @@ import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE; import static org.apache.geaflow.operator.Constants.GRAPH_VERSION; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; @@ -32,11 +33,13 @@ import org.apache.geaflow.api.function.iterator.RichIteratorFunction; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricAggTraversalFunction; import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; +import org.apache.geaflow.dsl.common.algo.BatchAlgorithmUserFunction; import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.common.types.GraphSchema; import org.apache.geaflow.dsl.runtime.traversal.message.ITraversalAgg; +import org.apache.geaflow.dsl.udf.graph.GCN; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; import org.apache.geaflow.model.traversal.ITraversalRequest; @@ -76,6 +79,10 @@ public class GeaFlowAlgorithmDynamicAggTraversalFunction private boolean materializeInFinish; + private transient List batchFinishVertices; + + private transient List> batchFinishValues; + public GeaFlowAlgorithmDynamicAggTraversalFunction(GraphSchema 
graphSchema, AlgorithmUserFunction userFunction, Object[] params) { @@ -90,9 +97,12 @@ public void open( IncVertexCentricTraversalFuncContext vertexCentricFuncContext) { this.traversalContext = vertexCentricFuncContext; this.materializeInFinish = traversalContext.getRuntimeContext().getConfiguration().getBoolean(FrameworkConfigKeys.UDF_MATERIALIZE_GRAPH_IN_FINISH); - this.algorithmCtx = new GeaFlowAlgorithmDynamicRuntimeContext(this, traversalContext, - graphSchema); + this.algorithmCtx = userFunction instanceof GCN + ? new GeaFlowAlgorithmDynamicModelRuntimeContext(this, traversalContext, graphSchema) + : new GeaFlowAlgorithmDynamicRuntimeContext(this, traversalContext, graphSchema); this.initVertices = new HashSet<>(); + this.batchFinishVertices = new ArrayList<>(); + this.batchFinishValues = new ArrayList<>(); this.userFunction.init(algorithmCtx, params); this.mutableGraph = traversalContext.getMutableGraph(); @@ -129,6 +139,9 @@ public void init(ITraversalRequest traversalRequest) { // false when called after the first time to avoid redundant invocation. 
if (vertexId != null && needInit(vertexId)) { RowVertex vertex = (RowVertex) algorithmCtx.loadVertex(); + if (vertex == null) { + vertex = (RowVertex) algorithmCtx.getIncVCTraversalCtx().getTemporaryGraph().getVertex(); + } if (vertex != null) { algorithmCtx.setVertexId(vertex.getId()); Row newValue = getVertexNewValue(vertex.getId()); @@ -182,6 +195,9 @@ public void compute(Object vertexId, Iterator messages) { } } else { vertex = (RowVertex) algorithmCtx.loadVertex(); + if (vertex == null) { + vertex = (RowVertex) algorithmCtx.getIncVCTraversalCtx().getTemporaryGraph().getVertex(); + } } if (vertex != null) { Row newValue = getVertexNewValue(vertex.getId()); @@ -193,9 +209,17 @@ public void compute(Object vertexId, Iterator messages) { public void finish(Object vertexId, MutableGraph mutableGraph) { algorithmCtx.setVertexId(vertexId); RowVertex graphVertex = (RowVertex) algorithmCtx.loadVertex(); + if (graphVertex == null) { + graphVertex = (RowVertex) algorithmCtx.getIncVCTraversalCtx().getTemporaryGraph().getVertex(); + } if (graphVertex != null) { Row newValue = getVertexNewValue(graphVertex.getId()); - userFunction.finish(graphVertex, Optional.ofNullable(newValue)); + if (userFunction instanceof BatchAlgorithmUserFunction) { + batchFinishVertices.add(graphVertex); + batchFinishValues.add(Optional.ofNullable(newValue)); + } else { + userFunction.finish(graphVertex, Optional.ofNullable(newValue)); + } } if (materializeInFinish) { IVertex vertex = algorithmCtx.getIncVCTraversalCtx().getTemporaryGraph().getVertex(); @@ -216,6 +240,12 @@ public boolean needInit(Object v) { public void finish() { algorithmCtx.finish(); initVertices.clear(); + if (userFunction instanceof BatchAlgorithmUserFunction) { + ((BatchAlgorithmUserFunction) userFunction).finishBatch( + batchFinishVertices, batchFinishValues); + batchFinishVertices.clear(); + batchFinishValues.clear(); + } userFunction.finish(); long windowId = traversalContext.getRuntimeContext().getWindowId(); 
this.vertexUpdateValues.manage().operate().setCheckpointId(windowId); diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicModelRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicModelRuntimeContext.java new file mode 100644 index 000000000..9bd7eee3c --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicModelRuntimeContext.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.dsl.runtime.engine; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction.IncVertexCentricTraversalFuncContext; +import org.apache.geaflow.common.exception.GeaflowRuntimeException; +import org.apache.geaflow.dsl.common.algo.AlgorithmModelRuntimeContext; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.data.RowEdge; +import org.apache.geaflow.dsl.common.types.GraphSchema; +import org.apache.geaflow.infer.InferContextLease; +import org.apache.geaflow.infer.InferContextPool; +import org.apache.geaflow.model.graph.edge.EdgeDirection; +import org.apache.geaflow.model.graph.vertex.IVertex; + +public class GeaFlowAlgorithmDynamicModelRuntimeContext extends GeaFlowAlgorithmDynamicRuntimeContext + implements AlgorithmModelRuntimeContext { + + public GeaFlowAlgorithmDynamicModelRuntimeContext( + GeaFlowAlgorithmDynamicAggTraversalFunction traversalFunction, + IncVertexCentricTraversalFuncContext traversalContext, + GraphSchema graphSchema) { + super(traversalFunction, traversalContext, graphSchema); + } + + @Override + public Object infer(Map payload) { + List results = inferBatch(Collections.singletonList(payload)); + return results.isEmpty() ? 
null : results.get(0); + } + + @Override + public List inferBatch(List> payloads) { + try { + try (InferContextLease inferContextLease = InferContextPool.borrow(getConfig())) { + List results = new ArrayList<>(payloads.size()); + for (Map payload : payloads) { + results.add(inferContextLease.infer(payload)); + } + return results; + } + } catch (Exception e) { + throw new GeaflowRuntimeException( + String.format("GCN model batch infer failed, vertexId=%s, batchSize=%s", + getVertexId(), payloads.size()), e); + } + } + + @Override + public Row loadDynamicVertexValue(Object vertexId) { + if (!Objects.equals(getVertexId(), vertexId)) { + return null; + } + IVertex vertex = getIncVCTraversalCtx().getTemporaryGraph().getVertex(); + return vertex == null ? null : vertex.getValue(); + } + + @Override + public List loadDynamicEdges(Object vertexId, EdgeDirection direction) { + if (!Objects.equals(getVertexId(), vertexId)) { + return Collections.emptyList(); + } + return super.loadDynamicEdges(direction); + } + + @Override + public void close() { + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicRuntimeContext.java index d929ae441..3dc3d3743 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicRuntimeContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicRuntimeContext.java @@ -80,6 +80,10 @@ public void setVertexId(Object vertexId) { this.edgeQuery.withId(vertexId); } + public Object getVertexId() { + return vertexId; + } + public IVertex loadVertex() { return vertexQuery.get(); } diff --git 
a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmModelRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmModelRuntimeContext.java new file mode 100644 index 000000000..d326ff4a0 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmModelRuntimeContext.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.dsl.runtime.engine; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction.VertexCentricTraversalFuncContext; +import org.apache.geaflow.common.exception.GeaflowRuntimeException; +import org.apache.geaflow.dsl.common.algo.AlgorithmModelRuntimeContext; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.data.RowEdge; +import org.apache.geaflow.dsl.common.types.GraphSchema; +import org.apache.geaflow.infer.InferContextLease; +import org.apache.geaflow.infer.InferContextPool; +import org.apache.geaflow.model.graph.edge.EdgeDirection; + +public class GeaFlowAlgorithmModelRuntimeContext extends GeaFlowAlgorithmRuntimeContext + implements AlgorithmModelRuntimeContext { + + public GeaFlowAlgorithmModelRuntimeContext( + GeaFlowAlgorithmAggTraversalFunction traversalFunction, + VertexCentricTraversalFuncContext traversalContext, + GraphSchema graphSchema) { + super(traversalFunction, traversalContext, graphSchema); + } + + @Override + public Object infer(Map payload) { + List results = inferBatch(Collections.singletonList(payload)); + return results.isEmpty() ? 
null : results.get(0); + } + + @Override + public List inferBatch(List> payloads) { + try { + try (InferContextLease inferContextLease = InferContextPool.borrow(getConfig())) { + List results = new ArrayList<>(payloads.size()); + for (Map payload : payloads) { + results.add(inferContextLease.infer(payload)); + } + return results; + } + } catch (Exception e) { + throw new GeaflowRuntimeException( + String.format("GCN model batch infer failed, vertexId=%s, batchSize=%s", + getVertexId(), payloads.size()), e); + } + } + + @Override + public Row loadDynamicVertexValue(Object vertexId) { + return null; + } + + @Override + public List loadDynamicEdges(Object vertexId, EdgeDirection direction) { + return Collections.emptyList(); + } + + @Override + public void close() { + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmRuntimeContext.java index 7696b4f10..d100ae2af 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmRuntimeContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmRuntimeContext.java @@ -71,6 +71,10 @@ public void setVertexId(Object vertexId) { this.edgeQuery.withId(vertexId); } + public Object getVertexId() { + return vertexId; + } + @SuppressWarnings("unchecked") @Override public List loadEdges(EdgeDirection direction) { diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java index 8c75d6e78..59b227e03 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java +++ 
b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java @@ -25,6 +25,8 @@ import org.apache.geaflow.common.config.keys.DSLConfigKeys; import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; import org.apache.geaflow.file.FileConfigKeys; +import org.apache.geaflow.infer.InferContextPool; +import org.testng.annotations.AfterClass; import org.testng.annotations.Test; public class GQLAlgorithmTest { @@ -88,6 +90,33 @@ public void testAlgorithm_006() throws Exception { .checkSinkResult(); } + @Test + public void testAlgorithmGCN() throws Exception { + QueryTester + .build() + .withQueryPath("/query/gql_algorithm_gcn.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithmGCNBatching() throws Exception { + QueryTester + .build() + .withQueryPath("/query/gql_algorithm_gcn_batch.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithmGCNDirectedEdgeSemantics() throws Exception { + QueryTester + .build() + .withQueryPath("/query/gql_algorithm_gcn_directed.sql") + .execute() + .checkSinkResult(); + } + @Test public void testAlgorithm_008() throws Exception { QueryTester @@ -342,6 +371,11 @@ public void testEdgeIterator() throws Exception { .checkSinkResult(); } + @AfterClass(alwaysRun = true) + public void tearDownInferContextPool() { + InferContextPool.closeAll(); + } + private void clearGraph() throws IOException { File file = new File(TEST_GRAPH_PATH); if (file.exists()) { diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/QueryTester.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/QueryTester.java index 6ddcd691c..d6c5d8386 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/QueryTester.java +++ 
b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/QueryTester.java @@ -23,6 +23,7 @@ import java.io.File; import java.io.IOException; import java.io.Serializable; +import java.lang.reflect.Field; import java.nio.charset.Charset; import java.util.Arrays; import java.util.HashMap; @@ -149,16 +150,71 @@ public QueryTester execute() throws Exception { graphDefinePath = this.graphDefinePath; } gqlPipeLine.setPipelineHook(new TestGQLPipelineHook(graphDefinePath, queryPath)); + Throwable executeFailure = null; try { gqlPipeLine.execute(); + } catch (Throwable t) { + executeFailure = t; + throw t; } finally { - environment.shutdown(); - ClusterMetaStore.close(); - ScheduledWorkerManagerFactory.clear(); + Throwable shutdownFailure = null; + try { + environment.shutdown(); + } catch (Throwable t) { + shutdownFailure = t; + } + try { + ClusterMetaStore.close(); + } catch (Throwable t) { + if (shutdownFailure == null) { + shutdownFailure = t; + } + } + try { + resetScheduledWorkerManagers(); + } catch (Throwable t) { + if (shutdownFailure == null) { + shutdownFailure = t; + } + } + if (executeFailure == null && shutdownFailure != null + && !isIgnorableShutdownFailure(shutdownFailure)) { + if (shutdownFailure instanceof Exception) { + throw (Exception) shutdownFailure; + } + throw new RuntimeException(shutdownFailure); + } } return this; } + private static boolean isIgnorableShutdownFailure(Throwable throwable) { + for (Throwable current = throwable; current != null; current = current.getCause()) { + String message = current.getMessage(); + if (message == null) { + continue; + } + if (message.contains("channel pool is closed") + || message.contains("executor not accepting a task")) { + return true; + } + } + return false; + } + + private static void resetScheduledWorkerManagers() throws IllegalAccessException, + NoSuchFieldException { + resetScheduledWorkerManager("redoWorkerManager"); + 
resetScheduledWorkerManager("checkpointWorkerManager"); + } + + private static void resetScheduledWorkerManager(String fieldName) + throws NoSuchFieldException, IllegalAccessException { + Field field = ScheduledWorkerManagerFactory.class.getDeclaredField(fieldName); + field.setAccessible(true); + field.set(null, null); + } + private void initResultDirectory() throws Exception { // delete target file path String targetPath = getTargetPath(queryPath); diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/TransFormFunctionUDF.py b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/TransFormFunctionUDF.py new file mode 100644 index 000000000..8537bc093 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/TransFormFunctionUDF.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + + +class GCNTestTransform(object): + input_size = 1 + + def load_model(self, model_path): + if not os.path.exists(model_path): + raise RuntimeError("missing model file: %s" % model_path) + # For test purpose, we don't load a real torch model. The built-in GCN UDF + # validates the Java->Python payload protocol end-to-end. 
+ self._loaded = True + return self._infer_model + + def _infer_model(self, payload): + if not self._loaded: + raise RuntimeError("model not loaded") + + center_node_id = payload.get("center_node_id") + sampled_nodes = payload.get("sampled_nodes") or [] + node_features = payload.get("node_features") or [] + edge_index = payload.get("edge_index") or [[], []] + edge_weight = payload.get("edge_weight") or [] + + if center_node_id is None: + raise RuntimeError("missing center_node_id") + if not sampled_nodes: + raise RuntimeError("missing sampled_nodes") + if not node_features: + raise RuntimeError("missing node_features") + if center_node_id not in sampled_nodes: + raise RuntimeError("center_node_id not in sampled_nodes") + center_index = sampled_nodes.index(center_node_id) + embedding = node_features[center_index] if 0 <= center_index < len(node_features) else [] + + # Deterministic test output: + # - prediction: local index + 1 + # - confidence: number of edges in edge_weight + prediction = int(center_index) + 1 + confidence = float(len(edge_weight)) + + return { + "node_id": center_node_id, + "embedding": embedding, + "prediction": prediction, + "confidence": confidence, + "debug_edge_index_shape": [len(edge_index), len(edge_index[0]) if edge_index else 0], + } + + def transform_pre(self, payload): + if payload is None: + raise RuntimeError("missing payload") + return (payload,) + + def transform_post(self, res): + return res + + +class GCNBatchMarkerTransform(object): + input_size = 1 + + def load_model(self, model_path): + if not os.path.exists(model_path): + raise RuntimeError("missing model file: %s" % model_path) + return self._infer_model + + def _infer_model(self, payload): + center_node_id = payload.get("center_node_id") + sampled_nodes = payload.get("sampled_nodes") or [] + node_features = payload.get("node_features") or [] + if center_node_id not in sampled_nodes: + raise RuntimeError("center_node_id not in sampled_nodes") + center_index = 
sampled_nodes.index(center_node_id) + embedding = node_features[center_index] if 0 <= center_index < len(node_features) else [] + return { + "node_id": center_node_id, + "embedding": embedding, + "prediction": int(center_node_id), + "confidence": float(center_node_id), + } + + def transform_pre(self, payload): + return (payload,) + + def transform_post(self, res): + return res + + +class GCNDirectedEdgeCountTransform(object): + input_size = 1 + + def load_model(self, model_path): + if not os.path.exists(model_path): + raise RuntimeError("missing model file: %s" % model_path) + return self._infer_model + + def _infer_model(self, payload): + center_node_id = payload.get("center_node_id") + sampled_nodes = payload.get("sampled_nodes") or [] + node_features = payload.get("node_features") or [] + edge_index = payload.get("edge_index") or [[], []] + if center_node_id not in sampled_nodes: + raise RuntimeError("center_node_id not in sampled_nodes") + center_index = sampled_nodes.index(center_node_id) + outgoing_edges = 0 + row = edge_index[0] if len(edge_index) > 0 else [] + col = edge_index[1] if len(edge_index) > 1 else [] + for src, target in zip(row, col): + if src == center_index and target != center_index: + outgoing_edges += 1 + embedding = node_features[center_index] if 0 <= center_index < len(node_features) else [] + return { + "node_id": center_node_id, + "embedding": embedding, + "prediction": outgoing_edges, + "confidence": float(outgoing_edges), + } + + def transform_pre(self, payload): + return (payload,) + + def transform_post(self, res): + return res diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_gcn.txt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_gcn.txt new file mode 100644 index 000000000..a2343d169 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_gcn.txt @@ -0,0 +1,3 @@ +1,1,4.0 +2,1,4.0 +3,1,4.0 diff --git 
a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_gcn_batch.txt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_gcn_batch.txt new file mode 100644 index 000000000..6fd89e32d --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_gcn_batch.txt @@ -0,0 +1,5 @@ +1,1,1.0 +2,2,2.0 +3,3,3.0 +4,4,4.0 +5,5,5.0 diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_gcn_directed.txt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_gcn_directed.txt new file mode 100644 index 000000000..945c09097 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_gcn_directed.txt @@ -0,0 +1,3 @@ +1,1,1.0 +2,2,2.0 +3,0,0.0 diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/model.pt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/model.pt new file mode 100644 index 000000000..78560f68f --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/model.pt @@ -0,0 +1,18 @@ +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +GCN test placeholder model. Only file existence is required. 
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_004.sql b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_004.sql index d5696b420..41a41d543 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_004.sql +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_004.sql @@ -98,4 +98,4 @@ USE GRAPH modern; INSERT INTO tbl_result CALL SSSP(1) YIELD (vid, distance) RETURN cast (vid as int), distance -; \ No newline at end of file +; diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_gcn.sql b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_gcn.sql new file mode 100644 index 000000000..1e31f1982 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_gcn.sql @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +set geaflow.dsl.window.size = -1; +set geaflow.dsl.gcn.vertex.feature.fields = age; +set geaflow.infer.env.enable = true; +set geaflow.infer.env.user.transform.classname = GCNTestTransform; +set geaflow.infer.env.conda.url = ''; + +CREATE GRAPH gcn_graph ( + Vertex person ( + id bigint ID, + name varchar, + age int + ), + Edge knows ( + srcId bigint SOURCE ID, + targetId bigint DESTINATION ID, + weight double + ) +) WITH ( + storeType='rocksdb', + shardNum = 1 +); + +CREATE TABLE person_source ( + name varchar, + age int, + id bigint +) WITH ( + type='file', + geaflow.dsl.file.path='resource:///data/modern_vertex_person_reorder.txt' +); + +CREATE TABLE tbl_result ( + node_id int, + prediction int, + confidence double +) WITH ( + type='file', + geaflow.dsl.file.path='${target}' +); + +-- Prepare a small connected subgraph to exercise neighbor expansion and payload assembly. +INSERT INTO gcn_graph.person VALUES (1, 'alice', 18); +INSERT INTO gcn_graph.person VALUES (2, 'bob', 20); +INSERT INTO gcn_graph.person VALUES (3, 'cathy', 22); +INSERT INTO gcn_graph.knows VALUES (1, 2, 1.0); +INSERT INTO gcn_graph.knows VALUES (2, 3, 1.0); + +USE GRAPH gcn_graph; + +INSERT INTO tbl_result +CALL gcn() YIELD (node_id, embedding, prediction, confidence) +RETURN cast(node_id as int), prediction, confidence +; diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_gcn_batch.sql b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_gcn_batch.sql new file mode 100644 index 000000000..2d35af138 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_gcn_batch.sql @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +set geaflow.dsl.window.size = -1; +set geaflow.dsl.gcn.vertex.feature.fields = age; +set geaflow.dsl.gcn.batch.size = 2; +set geaflow.infer.env.enable = true; +set geaflow.infer.env.user.transform.classname = GCNBatchMarkerTransform; +set geaflow.infer.env.conda.url = ''; + +CREATE GRAPH gcn_batch_graph ( + Vertex person ( + id bigint ID, + name varchar, + age int + ), + Edge knows ( + srcId bigint SOURCE ID, + targetId bigint DESTINATION ID, + weight double + ) +) WITH ( + storeType='rocksdb', + shardNum = 1 +); + +CREATE TABLE tbl_result ( + node_id int, + prediction int, + confidence double +) WITH ( + type='file', + geaflow.dsl.file.path='${target}' +); + +INSERT INTO gcn_batch_graph.person VALUES (1, 'alice', 18); +INSERT INTO gcn_batch_graph.person VALUES (2, 'bob', 20); +INSERT INTO gcn_batch_graph.person VALUES (3, 'cathy', 22); +INSERT INTO gcn_batch_graph.person VALUES (4, 'diana', 24); +INSERT INTO gcn_batch_graph.person VALUES (5, 'ella', 26); +INSERT INTO gcn_batch_graph.knows VALUES (1, 2, 1.0); +INSERT INTO gcn_batch_graph.knows VALUES (2, 3, 1.0); +INSERT INTO gcn_batch_graph.knows VALUES (3, 4, 1.0); +INSERT INTO gcn_batch_graph.knows VALUES (4, 5, 1.0); + +USE GRAPH gcn_batch_graph; + +INSERT INTO tbl_result +CALL gcn() YIELD (node_id, embedding, prediction, confidence) +RETURN cast(node_id as int), prediction, confidence +; diff --git 
a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_gcn_directed.sql b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_gcn_directed.sql new file mode 100644 index 000000000..5c2a3f2f2 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_gcn_directed.sql @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +set geaflow.dsl.window.size = -1; +set geaflow.dsl.gcn.vertex.feature.fields = age; +set geaflow.dsl.gcn.edge.weight.field = weight; +set geaflow.infer.env.enable = true; +set geaflow.infer.env.user.transform.classname = GCNDirectedEdgeCountTransform; +set geaflow.infer.env.conda.url = ''; + +CREATE GRAPH gcn_directed_graph ( + Vertex person ( + id bigint ID, + name varchar, + age int + ), + Edge knows ( + srcId bigint SOURCE ID, + targetId bigint DESTINATION ID, + weight double + ) +) WITH ( + storeType='rocksdb', + shardNum = 1 +); + +CREATE TABLE tbl_result ( + node_id int, + prediction int, + confidence double +) WITH ( + type='file', + geaflow.dsl.file.path='${target}' +); + +INSERT INTO gcn_directed_graph.person VALUES (1, 'alice', 18); +INSERT INTO gcn_directed_graph.person VALUES (2, 'bob', 20); +INSERT INTO gcn_directed_graph.person VALUES (3, 'cathy', 22); +INSERT INTO gcn_directed_graph.knows VALUES (1, 2, 1.0); +INSERT INTO gcn_directed_graph.knows VALUES (2, 1, 4.0); +INSERT INTO gcn_directed_graph.knows VALUES (2, 3, 2.0); + +USE GRAPH gcn_directed_graph; + +INSERT INTO tbl_result +CALL gcn() YIELD (node_id, embedding, prediction, confidence) +RETURN cast(node_id as int), prediction, confidence +; diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/requirements.txt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/requirements.txt new file mode 100644 index 000000000..a67d5ea25 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/requirements.txt @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/torch.py b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/torch.py new file mode 100644 index 000000000..89e0d4610 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/torch.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +def set_num_threads(_thread_count): + return None diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContext.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContext.java index 0289c1985..76006cf5e 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContext.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContext.java @@ -23,6 +23,8 @@ import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.ReentrantLock; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.infer.exchange.DataExchangeContext; @@ -40,6 +42,9 @@ public class InferContext implements AutoCloseable { private final String receiveQueueKey; private InferTaskRunImpl inferTaskRunner; private InferDataBridgeImpl dataBridge; + private final ReentrantLock inferLock = new ReentrantLock(); + private final AtomicBoolean closed = new AtomicBoolean(false); + private volatile boolean broken; public InferContext(Configuration config) { this.shareMemoryContext = new DataExchangeContext(config); @@ -62,17 +67,22 @@ private void init() { } public OUT infer(Object... 
feature) throws Exception { + ensureAvailable(); + inferLock.lock(); try { + ensureAvailable(); dataBridge.write(feature); return dataBridge.read(); } catch (Exception e) { - inferTaskRunner.stop(); + broken = true; + stopInferTask(); LOGGER.error("model infer read result error, python process stopped", e); throw new GeaflowRuntimeException("receive infer result exception", e); + } finally { + inferLock.unlock(); } } - private InferEnvironmentContext getInferEnvironmentContext() { boolean initFinished = InferEnvironmentManager.checkInferEnvironmentStatus(); while (!initFinished) { @@ -93,11 +103,44 @@ private void runInferTask(InferEnvironmentContext inferEnvironmentContext) { inferTaskRunner.run(runCommands); } - @Override - public void close() { + public boolean isBroken() { + return broken; + } + + private void ensureAvailable() { + if (closed.get()) { + throw new GeaflowRuntimeException("infer context already closed"); + } + if (broken) { + throw new GeaflowRuntimeException("infer context is broken"); + } + } + + private void stopInferTask() { if (inferTaskRunner != null) { inferTaskRunner.stop(); + } + } + + @Override + public void close() { + if (!closed.compareAndSet(false, true)) { + return; + } + shareMemoryContext.markFinished(); + stopInferTask(); + inferLock.lock(); + try { + if (dataBridge != null) { + dataBridge.close(); + dataBridge = null; + } + shareMemoryContext.close(); LOGGER.info("infer task stop after close"); + } catch (Exception e) { + throw new GeaflowRuntimeException("close infer context failed", e); + } finally { + inferLock.unlock(); } } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContextLease.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContextLease.java new file mode 100644 index 000000000..6cfbf697b --- /dev/null +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContextLease.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.infer; + +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.geaflow.common.exception.GeaflowRuntimeException; + +public class InferContextLease implements AutoCloseable { + + private final InferContextPool.PoolEntry poolEntry; + private final InferContext inferContext; + private final AtomicBoolean closed = new AtomicBoolean(false); + + InferContextLease(InferContextPool.PoolEntry poolEntry, InferContext inferContext) { + this.poolEntry = poolEntry; + this.inferContext = inferContext; + } + + public OUT infer(Object... 
feature) throws Exception { + if (closed.get()) { + throw new GeaflowRuntimeException("infer context lease already closed"); + } + return inferContext.infer(feature); + } + + @Override + public void close() { + if (closed.compareAndSet(false, true)) { + poolEntry.release(inferContext, inferContext.isBroken()); + } + } +} diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContextPool.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContextPool.java new file mode 100644 index 000000000..5bbc4bb94 --- /dev/null +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContextPool.java @@ -0,0 +1,329 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.infer; + +import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_CONTEXT_POOL_BORROW_TIMEOUT_SEC; +import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_CONTEXT_POOL_MAX_SIZE; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import org.apache.geaflow.common.config.Configuration; +import org.apache.geaflow.common.exception.GeaflowRuntimeException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Shared pool for infer contexts keyed by stable configuration fingerprint. 
+ */ +public class InferContextPool { + + private static final Logger LOGGER = LoggerFactory.getLogger(InferContextPool.class); + + private static final ConcurrentHashMap CONTEXT_POOL = + new ConcurrentHashMap<>(); + + private static final AtomicLong CREATED_CONTEXT_COUNT = new AtomicLong(); + private static final AtomicLong CLOSED_CONTEXT_COUNT = new AtomicLong(); + private static volatile InferContextFactory contextFactory = InferContext::new; + + private InferContextPool() { + } + + @SuppressWarnings("unchecked") + public static InferContextLease borrow(Configuration config) { + String key = generateConfigKey(config); + int maxSize = Math.max(1, config.getInteger(INFER_CONTEXT_POOL_MAX_SIZE)); + long timeoutMillis = Math.max(0L, + TimeUnit.SECONDS.toMillis(config.getInteger(INFER_CONTEXT_POOL_BORROW_TIMEOUT_SEC))); + while (true) { + PoolEntry entry = CONTEXT_POOL.computeIfAbsent(key, PoolEntry::new); + try { + return new InferContextLease<>(entry, + (InferContext) entry.borrow(config, maxSize, timeoutMillis)); + } catch (PoolEntryClosedException e) { + CONTEXT_POOL.remove(key, entry); + } + } + } + + public static void closeAll() { + List entries = new ArrayList<>(CONTEXT_POOL.values()); + for (PoolEntry entry : entries) { + entry.closeEntry(); + } + CONTEXT_POOL.clear(); + LOGGER.info("Closed all InferContext instances, closedPoolSize={}, createdCount={}, closedCount={}", + entries.size(), CREATED_CONTEXT_COUNT.get(), CLOSED_CONTEXT_COUNT.get()); + } + + public static String getStatus() { + return String.format("InferContextPool{size=%d, instances=%s, createdCount=%d, closedCount=%d}", + CONTEXT_POOL.size(), CONTEXT_POOL.keySet(), CREATED_CONTEXT_COUNT.get(), + CLOSED_CONTEXT_COUNT.get()); + } + + private static String generateConfigKey(Configuration config) { + StringBuilder builder = new StringBuilder(); + builder.append("masterId=").append(nullToEmpty(config.getMasterId())).append('\n'); + Map sortedConfig = new TreeMap<>(config.getConfigMap()); + for 
(Map.Entry entry : sortedConfig.entrySet()) { + builder.append(entry.getKey()) + .append('=') + .append(nullToEmpty(entry.getValue())) + .append('\n'); + } + return "infer_" + sha256Hex(builder.toString()); + } + + static void setContextFactoryForTest(InferContextFactory factory) { + contextFactory = factory; + } + + static void resetContextFactoryForTest() { + contextFactory = InferContext::new; + } + + private static String nullToEmpty(String value) { + return value == null ? "" : value; + } + + private static String sha256Hex(String value) { + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + byte[] bytes = digest.digest(value.getBytes(StandardCharsets.UTF_8)); + StringBuilder builder = new StringBuilder(bytes.length * 2); + for (byte b : bytes) { + builder.append(String.format("%02x", b)); + } + return builder.toString(); + } catch (NoSuchAlgorithmException e) { + throw new GeaflowRuntimeException("SHA-256 is not available", e); + } + } + + interface InferContextFactory { + + InferContext create(Configuration config); + } + + static class PoolEntry { + + private final String key; + private final ReentrantLock lock = new ReentrantLock(); + private final Condition available = lock.newCondition(); + private final Deque> idleContexts = new ArrayDeque<>(); + private final Set> allContexts = new HashSet<>(); + private int creatingCount; + private int borrowedCount; + private int waitingBorrowers; + private boolean closed; + + PoolEntry(String key) { + this.key = key; + } + + InferContext borrow(Configuration config, int maxSize, long timeoutMillis) { + long remainingNanos = TimeUnit.MILLISECONDS.toNanos(timeoutMillis); + while (true) { + InferContext idleContext = tryBorrowIdle(); + if (idleContext != null) { + return idleContext; + } + if (reserveCreation(maxSize)) { + return createContext(config); + } + remainingNanos = awaitAvailable(remainingNanos, timeoutMillis); + } + } + + void release(InferContext inferContext, boolean broken) { + List> 
contextsToClose = new ArrayList<>(); + boolean removeEntry; + lock.lock(); + try { + if (!allContexts.contains(inferContext)) { + return; + } + borrowedCount--; + if (closed || broken) { + allContexts.remove(inferContext); + contextsToClose.add(inferContext); + } else { + idleContexts.addLast(inferContext); + } + available.signal(); + removeEntry = (closed || allContexts.isEmpty()) && borrowedCount == 0 + && waitingBorrowers == 0; + } finally { + lock.unlock(); + } + closeContexts(contextsToClose); + if (removeEntry) { + CONTEXT_POOL.remove(key, this); + } + } + + void closeEntry() { + List> contextsToClose; + lock.lock(); + try { + if (closed && allContexts.isEmpty() && idleContexts.isEmpty()) { + return; + } + closed = true; + contextsToClose = new ArrayList<>(allContexts); + idleContexts.clear(); + allContexts.clear(); + borrowedCount = 0; + available.signalAll(); + } finally { + lock.unlock(); + } + closeContexts(contextsToClose); + } + + private InferContext tryBorrowIdle() { + lock.lock(); + try { + if (closed) { + throw new PoolEntryClosedException(); + } + InferContext inferContext = idleContexts.pollFirst(); + if (inferContext != null) { + borrowedCount++; + } + return inferContext; + } finally { + lock.unlock(); + } + } + + private boolean reserveCreation(int maxSize) { + lock.lock(); + try { + if (closed) { + throw new PoolEntryClosedException(); + } + if (allContexts.size() + creatingCount >= maxSize) { + return false; + } + creatingCount++; + return true; + } finally { + lock.unlock(); + } + } + + private InferContext createContext(Configuration config) { + InferContext inferContext; + try { + inferContext = contextFactory.create(config); + } catch (RuntimeException e) { + lock.lock(); + try { + creatingCount--; + available.signal(); + } finally { + lock.unlock(); + } + throw e; + } + + lock.lock(); + try { + creatingCount--; + if (closed) { + available.signal(); + closeContexts(java.util.Collections.singletonList(inferContext)); + throw new 
PoolEntryClosedException(); + } + allContexts.add(inferContext); + borrowedCount++; + long createdCount = CREATED_CONTEXT_COUNT.incrementAndGet(); + LOGGER.info("Created InferContext for key {}, poolSize={}, createdCount={}", + key, CONTEXT_POOL.size(), createdCount); + return inferContext; + } finally { + lock.unlock(); + } + } + + private long awaitAvailable(long remainingNanos, long timeoutMillis) { + lock.lock(); + try { + if (closed) { + throw new PoolEntryClosedException(); + } + if (timeoutMillis == 0L) { + throw borrowTimeoutException(key, timeoutMillis); + } + waitingBorrowers++; + try { + if (remainingNanos <= 0L) { + throw borrowTimeoutException(key, timeoutMillis); + } + return available.awaitNanos(remainingNanos); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new GeaflowRuntimeException("infer context borrow interrupted", e); + } finally { + waitingBorrowers--; + } + } finally { + lock.unlock(); + } + } + + private void closeContexts(List> contexts) { + for (InferContext inferContext : contexts) { + try { + inferContext.close(); + CLOSED_CONTEXT_COUNT.incrementAndGet(); + } catch (Exception e) { + LOGGER.warn("Failed to close InferContext for key {}", key, e); + } + } + } + } + + private static GeaflowRuntimeException borrowTimeoutException(String key, long timeoutMillis) { + return new GeaflowRuntimeException(String.format( + "borrow infer context timeout, key=%s, timeoutMillis=%d", key, timeoutMillis)); + } + + static final class PoolEntryClosedException extends RuntimeException { + + private static final long serialVersionUID = 1L; + } +} diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeContext.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeContext.java index 417e72703..2cdfd8ede 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeContext.java +++ 
b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeContext.java @@ -24,6 +24,7 @@ import java.io.Closeable; import java.io.File; import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -48,6 +49,9 @@ public class DataExchangeContext implements Closeable { private final File receiveQueueFile; private final File sendQueueFile; + private final Thread queueEndpointCleanupHook; + private final AtomicBoolean closed = new AtomicBoolean(false); + private final AtomicBoolean queueEndpointFreed = new AtomicBoolean(false); private String receivePath; private String sendPath; @@ -62,7 +66,8 @@ public DataExchangeContext(Configuration config) { int queueCapacity = config.getInteger(INFER_ENV_SHARE_MEMORY_QUEUE_SIZE); this.receiveQueue = new DataExchangeQueue(receivePath, queueCapacity, true); this.sendQueue = new DataExchangeQueue(sendPath, queueCapacity, true); - Runtime.getRuntime().addShutdownHook(new Thread(() -> UnSafeUtils.UNSAFE.freeMemory(queueEndpoint))); + this.queueEndpointCleanupHook = new Thread(this::freeQueueEndpoint); + Runtime.getRuntime().addShutdownHook(queueEndpointCleanupHook); } public String getReceiveQueueKey() { @@ -75,6 +80,10 @@ public String getSendQueueKey() { @Override public synchronized void close() throws IOException { + if (!closed.compareAndSet(false, true)) { + return; + } + markFinished(); if (receiveQueue != null) { receiveQueue.close(); } @@ -87,7 +96,12 @@ public synchronized void close() throws IOException { if (sendQueueFile != null) { sendQueueFile.delete(); } - UnSafeUtils.UNSAFE.freeMemory(this.queueEndpoint); + try { + Runtime.getRuntime().removeShutdownHook(queueEndpointCleanupHook); + } catch (IllegalStateException ignored) { + // Ignore shutdown-in-progress failures and rely on idempotent free. 
+ } + freeQueueEndpoint(); FileUtils.deleteQuietly(localDirectory); } @@ -99,6 +113,15 @@ public DataExchangeQueue getSendQueue() { return sendQueue; } + public void markFinished() { + if (receiveQueue != null) { + receiveQueue.markFinished(); + } + if (sendQueue != null) { + sendQueue.markFinished(); + } + } + private File createTempFile(String prefix, String suffix) { try { if (!localDirectory.exists()) { @@ -109,4 +132,10 @@ private File createTempFile(String prefix, String suffix) { throw new GeaflowRuntimeException("create temp file on infer directory failed ", e); } } + + private void freeQueueEndpoint() { + if (queueEndpointFreed.compareAndSet(false, true)) { + UnSafeUtils.UNSAFE.freeMemory(queueEndpoint); + } + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeQueue.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeQueue.java index 29057f60e..c02ecf2c4 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeQueue.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeQueue.java @@ -26,7 +26,7 @@ public final class DataExchangeQueue implements Closeable { - private static final AtomicBoolean CLOSED = new AtomicBoolean(false); + private final AtomicBoolean closed = new AtomicBoolean(false); private final long outputNextAddress; private final long capacityAddress; private final long outputAddress; @@ -66,11 +66,15 @@ public DataExchangeQueue(String mapKey, int capacity, boolean reset) { @Override public synchronized void close() { - CLOSED.set(true); + if (!closed.compareAndSet(false, true)) { + return; + } if (memoryMapper != null) { memoryMapper.close(); } - UnSafeUtils.UNSAFE.freeMemory(mapAddress); + if (mapAddress != 0) { + UnSafeUtils.UNSAFE.freeMemory(mapAddress); + } } public long getMemoryMapSize() { @@ -133,7 +137,7 @@ public boolean enableFinished() { } public synchronized void 
markFinished() { - if (!CLOSED.get()) { + if (!closed.get()) { UnSafeUtils.UNSAFE.putOrderedLong(null, endPointAddress, -1); } } @@ -165,4 +169,4 @@ public static long getNextPointIndex(long v, int capacity) { } return Pow2.align(v, capacity); } -} \ No newline at end of file +} diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/util/InferFileUtils.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/util/InferFileUtils.java index a7a570cc2..d4d3b52db 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/util/InferFileUtils.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/util/InferFileUtils.java @@ -38,6 +38,7 @@ import java.nio.file.FileSystems; import java.nio.file.Files; import java.nio.file.Path; +import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.util.Arrays; import java.util.Collections; @@ -69,6 +70,11 @@ public class InferFileUtils { public static final String JAR_FILE_EXTENSION = ".jar"; + public static final String MODEL_FILE_NAME = "model.pt"; + + private static final List NATIVE_LIB_EXTENSIONS = Arrays.asList( + ".so", ".pyd", ".dll", ".dylib"); + private static final int DEFAULT_BUFFER_SIZE = 1024; public static final String REQUIREMENTS_TXT = "requirements.txt"; @@ -239,42 +245,87 @@ public static List getPathsFromResourceJAR(String folder) throws URISyntax public static void prepareInferFilesFromJars(String targetDirectory) { File userJobJarFile = getUserJobJarFile(); - Preconditions.checkNotNull(userJobJarFile); - try { - JarFile jarFile = new JarFile(userJobJarFile); + if (userJobJarFile == null) { + LOGGER.info("cannot find user job jar in classpath root, fallback to classpath files"); + prepareInferFilesFromClasspath(targetDirectory); + return; + } + try (JarFile jarFile = new JarFile(userJobJarFile)) { Enumeration entries = jarFile.entries(); while (entries.hasMoreElements()) { JarEntry entry = entries.nextElement(); String entryName = 
entry.getName(); - if (!entry.isDirectory()) { + if (!entry.isDirectory() && shouldExtractInferResource(entryName)) { String inferFile = extractFile(targetDirectory, entryName, entry, jarFile); - LOGGER.info("cp infer file {} to {} from jar file {}", entryName, inferFile, userJobJarFile.getName()); - } else { - File entryDestination = new File(targetDirectory, entry.getName()); - if (!entryDestination.exists()) { - entryDestination.mkdirs(); - } - LOGGER.info("create infer directory is {}", entryDestination); + LOGGER.info("cp infer file {} to {} from jar file {}", entryName, inferFile, + userJobJarFile.getName()); } } - jarFile.close(); } catch (IOException e) { LOGGER.error("open jar file {} failed", userJobJarFile.getName()); } } + private static void prepareInferFilesFromClasspath(String targetDirectory) { + List resourceFiles = getPythonFilesByCondition(file -> + file.isFile() && isInferResourceFile(file.getName())); + for (File resourceFile : resourceFiles) { + String inferFile = copyPythonFile(targetDirectory, resourceFile); + LOGGER.info("cp infer file {} to {} from classpath resource", resourceFile.getName(), + inferFile); + } + } + + private static boolean isInferResourceFile(String fileName) { + return fileName.endsWith(PY_FILE_EXTENSION) + || REQUIREMENTS_TXT.equals(fileName) + || MODEL_FILE_NAME.equals(fileName) + || isNativeLibFile(fileName); + } + + private static boolean shouldExtractInferResource(String entryName) { + if (entryName == null || entryName.contains("..")) { + return false; + } + String normalized = entryName.replace('\\', '/'); + String fileName = normalized.substring(normalized.lastIndexOf('/') + 1); + return isInferResourceFile(fileName); + } + + private static boolean isNativeLibFile(String fileName) { + for (String extension : NATIVE_LIB_EXTENSIONS) { + if (fileName.endsWith(extension)) { + return true; + } + } + return false; + } private static String extractFile(String targetDirectory, String fileName, JarEntry entry, JarFile 
jarFile) throws IOException { - String targetFilePath = targetDirectory + File.separator + fileName; + File targetFile = buildSafeTargetFile(targetDirectory, fileName); + File parent = targetFile.getParentFile(); + if (parent != null && !parent.exists()) { + forceMkdir(parent); + } try (InputStream inputStream = jarFile.getInputStream(entry); - FileOutputStream outputStream = new FileOutputStream(targetFilePath)) { + FileOutputStream outputStream = new FileOutputStream(targetFile)) { byte[] buffer = new byte[DEFAULT_BUFFER_SIZE]; int bytesRead; while ((bytesRead = inputStream.read(buffer)) != -1) { outputStream.write(buffer, 0, bytesRead); } } - return targetFilePath; + return targetFile.getAbsolutePath(); + } + + private static File buildSafeTargetFile(String targetDirectory, String fileName) + throws IOException { + Path targetRoot = Paths.get(targetDirectory).toAbsolutePath().normalize(); + Path targetPath = targetRoot.resolve(fileName).normalize(); + if (!targetPath.startsWith(targetRoot)) { + throw new IOException("illegal infer resource path: " + fileName); + } + return targetPath.toFile(); } } diff --git a/geaflow/geaflow-infer/src/main/resources/infer/env/install-infer-env.sh b/geaflow/geaflow-infer/src/main/resources/infer/env/install-infer-env.sh index 28259fb6f..1865d90f1 100644 --- a/geaflow/geaflow-infer/src/main/resources/infer/env/install-infer-env.sh +++ b/geaflow/geaflow-infer/src/main/resources/infer/env/install-infer-env.sh @@ -28,10 +28,9 @@ echo "execute shell at path ${CURRENT_DIR}" echo "install requirements path ${REQUIREMENTS_PATH}" MINICONDA_INSTALL=$CURRENT_DIR/miniconda.sh -[ ! -e $MINICONDA_INSTALL ] && touch $MINICONDA_INSTALL function install_miniconda() { - if [ ! -f "$CONDA_INSTALL" ]; then + if [ ! -s "$MINICONDA_INSTALL" ]; then print_function "STEP" "download miniconda oss ${MINICOMDA_OSS_URL}..." 
download $MINICOMDA_OSS_URL $MINICONDA_INSTALL chmod +x $MINICONDA_INSTALL diff --git a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py index 63ef72ccc..e40862d71 100644 --- a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py +++ b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/inferSession.py @@ -19,22 +19,27 @@ import torch torch.set_num_threads(1) -# class TorchInferSession(object): -# def __init__(self, transform_class) -> None: -# self._transform = transform_class -# self._model_path = os.getcwd() + "/model.pt" -# self._model = transform_class.load_model(self._model_path) -# -# def run(self, *inputs): -# feature = self._transform.transform_pre(*inputs) -# res = self._model(*feature) -# return self._transform.transform_post(res) - class TorchInferSession(object): def __init__(self, transform_class) -> None: self._transform = transform_class + self._model_path = os.getcwd() + "/model.pt" + if not hasattr(transform_class, "load_model"): + raise RuntimeError("transform class must define load_model(model_path)") + self._model = transform_class.load_model(self._model_path) + self._legacy_mode = self._model is None + if not self._legacy_mode and not callable(self._model): + raise RuntimeError("load_model(model_path) must return a callable model") def run(self, *inputs): - a,b = self._transform.transform_pre(*inputs) - return self._transform.transform_post(a) + if self._legacy_mode: + return self._transform.transform_post(self._transform.transform_pre(*inputs)) + feature = self._transform.transform_pre(*inputs) + if isinstance(feature, tuple): + model_args = feature + elif isinstance(feature, list): + model_args = tuple(feature) + else: + model_args = (feature,) + res = self._model(*model_args) + return self._transform.transform_post(res) diff --git a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/setup.py 
b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/setup.py index cc9d1c61b..6492b7b78 100644 --- a/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/setup.py +++ b/geaflow/geaflow-infer/src/main/resources/infer/inferRuntime/setup.py @@ -20,6 +20,13 @@ # cython:language_level=3 from distutils.core import setup -from Cython.Build import cythonize +from distutils.extension import Extension -setup(ext_modules=cythonize("mmap_ipc.pyx")) +try: + from Cython.Build import cythonize +except ImportError: + ext_modules = [Extension("mmap_ipc", sources=["mmap_ipc.cpp"], language="c++")] +else: + ext_modules = cythonize("mmap_ipc.pyx") + +setup(ext_modules=ext_modules) diff --git a/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferContextCloseTest.java b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferContextCloseTest.java new file mode 100644 index 000000000..352b727b1 --- /dev/null +++ b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferContextCloseTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.infer; + +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +import java.lang.reflect.Field; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; +import org.apache.geaflow.infer.exchange.DataExchangeContext; +import org.apache.geaflow.infer.exchange.UnSafeUtils; +import org.apache.geaflow.infer.exchange.impl.InferDataBridgeImpl; +import org.testng.annotations.Test; + +public class InferContextCloseTest { + + @Test + public void testCloseDoesNotWaitForPendingInferRead() throws Exception { + InferContext context = newBareContext(); + CountDownLatch readEntered = new CountDownLatch(1); + CountDownLatch releaseRead = new CountDownLatch(1); + AtomicReference inferResult = new AtomicReference<>(); + AtomicReference inferError = new AtomicReference<>(); + + InferDataBridgeImpl dataBridge = mock(InferDataBridgeImpl.class); + doNothing().when(dataBridge).close(); + doAnswer(invocation -> true).when(dataBridge).write(org.mockito.Matchers.anyVararg()); + doAnswer(invocation -> { + readEntered.countDown(); + if (!releaseRead.await(2, TimeUnit.SECONDS)) { + throw new IllegalStateException("timed out waiting for close signal"); + } + return "ok"; + }).when(dataBridge).read(); + + DataExchangeContext shareMemoryContext = mock(DataExchangeContext.class); + doAnswer(invocation -> { + releaseRead.countDown(); + return null; + }).when(shareMemoryContext).markFinished(); + doNothing().when(shareMemoryContext).close(); + + setField(context, "inferLock", new ReentrantLock()); + setField(context, "closed", new AtomicBoolean(false)); + setField(context, 
"broken", false); + setField(context, "dataBridge", dataBridge); + setField(context, "shareMemoryContext", shareMemoryContext); + + Thread inferThread = new Thread(() -> { + try { + inferResult.set(context.infer("payload")); + } catch (Throwable e) { + inferError.set(e); + } + }); + inferThread.start(); + + assertTrue(readEntered.await(1, TimeUnit.SECONDS), "infer() did not enter read()"); + + Thread closeThread = new Thread(context::close); + closeThread.start(); + Thread.sleep(200L); + + assertFalse(closeThread.isAlive(), "close() should not wait for infer() to finish"); + + inferThread.join(1000L); + closeThread.join(1000L); + + assertFalse(inferThread.isAlive(), "infer() did not complete after close()"); + assertFalse(closeThread.isAlive(), "close() did not complete"); + assertEquals(inferResult.get(), "ok"); + assertEquals(inferError.get(), null); + } + + private InferContext newBareContext() throws InstantiationException { + @SuppressWarnings("unchecked") + InferContext context = (InferContext) UnSafeUtils.UNSAFE + .allocateInstance(InferContext.class); + return context; + } + + private static void setField(Object target, String fieldName, Object value) throws Exception { + Field field = InferContext.class.getDeclaredField(fieldName); + long offset = UnSafeUtils.UNSAFE.objectFieldOffset(field); + if (field.getType() == boolean.class) { + UnSafeUtils.UNSAFE.putBoolean(target, offset, (Boolean) value); + } else { + UnSafeUtils.UNSAFE.putObject(target, offset, value); + } + } +} diff --git a/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferContextPoolTest.java b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferContextPoolTest.java new file mode 100644 index 000000000..0fd77fb52 --- /dev/null +++ b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferContextPoolTest.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.infer; + +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotSame; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; + +import java.util.LinkedHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.geaflow.common.config.Configuration; +import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; +import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; +import org.apache.geaflow.file.FileConfigKeys; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.Test; + +public class InferContextPoolTest { + + @AfterMethod + public void tearDown() { + InferContextPool.closeAll(); + InferContextPool.resetContextFactoryForTest(); + } + + @Test + public void testBorrowCreatesDifferentContextsForConcurrentLeases() { + AtomicInteger created = new AtomicInteger(); + InferContextPool.setContextFactoryForTest(config -> 
mockContext(created.incrementAndGet())); + Configuration config = buildPoolConfig(); + + try (InferContextLease first = InferContextPool.borrow(config); + InferContextLease second = InferContextPool.borrow(config)) { + assertNotSame(first, second); + assertEquals(created.get(), 2); + } + } + + @Test + public void testBorrowReusesStableConfigFingerprint() throws Exception { + AtomicInteger created = new AtomicInteger(); + CopyOnWriteArrayList> contexts = new CopyOnWriteArrayList<>(); + InferContextPool.setContextFactoryForTest(config -> { + InferContext context = mockContext(created.incrementAndGet()); + contexts.add(context); + return context; + }); + + Configuration firstConfig = buildPoolConfig(); + Configuration secondConfig = new Configuration(new LinkedHashMap() {{ + put("k2", "v2"); + put("k1", "v1"); + put(FrameworkConfigKeys.INFER_CONTEXT_POOL_MAX_SIZE.getKey(), "1"); + put(FrameworkConfigKeys.INFER_CONTEXT_POOL_BORROW_TIMEOUT_SEC.getKey(), "1"); + put(FrameworkConfigKeys.INFER_ENV_SHARE_MEMORY_QUEUE_SIZE.getKey(), "1024"); + put(ExecutionConfigKeys.JOB_WORK_PATH.getKey(), System.getProperty("java.io.tmpdir")); + put(FileConfigKeys.USER_NAME.getKey(), "tester"); + put(ExecutionConfigKeys.JOB_UNIQUE_ID.getKey(), "job"); + }}); + secondConfig.setMasterId(firstConfig.getMasterId()); + firstConfig.put(FrameworkConfigKeys.INFER_CONTEXT_POOL_MAX_SIZE, "1"); + + CountDownLatch borrowStarted = new CountDownLatch(1); + CountDownLatch borrowFinished = new CountDownLatch(1); + AtomicInteger borrowedCount = new AtomicInteger(); + + try (InferContextLease first = InferContextPool.borrow(firstConfig)) { + Thread waiter = new Thread(() -> { + borrowStarted.countDown(); + try (InferContextLease ignored = InferContextPool.borrow(secondConfig)) { + borrowedCount.incrementAndGet(); + } finally { + borrowFinished.countDown(); + } + }); + waiter.start(); + assertTrue(borrowStarted.await(1, TimeUnit.SECONDS)); + Thread.sleep(100L); + assertEquals(created.get(), 1); + 
assertEquals(borrowedCount.get(), 0); + } + + assertTrue(borrowFinished.await(1, TimeUnit.SECONDS)); + assertEquals(borrowedCount.get(), 1); + assertEquals(created.get(), 1); + assertEquals(contexts.size(), 1); + } + + @Test + public void testBorrowReusesContextAcrossSequentialLeases() throws Exception { + AtomicInteger created = new AtomicInteger(); + InferContextPool.setContextFactoryForTest(config -> mockContext(created.incrementAndGet())); + Configuration config = buildPoolConfig(); + config.put(FrameworkConfigKeys.INFER_CONTEXT_POOL_MAX_SIZE, "1"); + + try (InferContextLease first = InferContextPool.borrow(config)) { + assertEquals(first.infer("value"), Integer.valueOf(1)); + } + + try (InferContextLease second = InferContextPool.borrow(config)) { + assertEquals(second.infer("value"), Integer.valueOf(1)); + } + + assertEquals(created.get(), 1); + } + + @Test + public void testBorrowBlocksUntilLeaseReturned() throws Exception { + AtomicInteger created = new AtomicInteger(); + InferContextPool.setContextFactoryForTest(config -> mockContext(created.incrementAndGet())); + Configuration config = buildPoolConfig(); + config.put(FrameworkConfigKeys.INFER_CONTEXT_POOL_MAX_SIZE, "1"); + config.put(FrameworkConfigKeys.INFER_CONTEXT_POOL_BORROW_TIMEOUT_SEC, "1"); + + CountDownLatch borrowStarted = new CountDownLatch(1); + CountDownLatch borrowFinished = new CountDownLatch(1); + AtomicInteger borrowedCount = new AtomicInteger(); + + try (InferContextLease first = InferContextPool.borrow(config)) { + Thread waiter = new Thread(() -> { + borrowStarted.countDown(); + try (InferContextLease ignored = InferContextPool.borrow(config)) { + borrowedCount.incrementAndGet(); + } finally { + borrowFinished.countDown(); + } + }); + waiter.start(); + assertTrue(borrowStarted.await(1, TimeUnit.SECONDS)); + Thread.sleep(100L); + assertEquals(borrowedCount.get(), 0); + } + + assertTrue(borrowFinished.await(1, TimeUnit.SECONDS)); + assertEquals(borrowedCount.get(), 1); + 
assertEquals(created.get(), 1); + } + + @Test + public void testBorrowTimeoutWhenPoolExhausted() { + AtomicInteger created = new AtomicInteger(); + InferContextPool.setContextFactoryForTest(config -> mockContext(created.incrementAndGet())); + Configuration config = buildPoolConfig(); + config.put(FrameworkConfigKeys.INFER_CONTEXT_POOL_MAX_SIZE, "1"); + config.put(FrameworkConfigKeys.INFER_CONTEXT_POOL_BORROW_TIMEOUT_SEC, "0"); + + try (InferContextLease ignored = InferContextPool.borrow(config)) { + assertThrows(RuntimeException.class, () -> InferContextPool.borrow(config)); + } + assertEquals(created.get(), 1); + } + + @Test + public void testBrokenContextIsClosedInsteadOfReused() throws Exception { + AtomicInteger created = new AtomicInteger(); + InferContext brokenContext = mock(InferContext.class); + doAnswer(invocation -> { + throw new RuntimeException("broken"); + }).when(brokenContext).infer(org.mockito.Matchers.anyVararg()); + doAnswer(invocation -> true).when(brokenContext).isBroken(); + InferContext healthyContext = mockContext(2); + InferContextPool.setContextFactoryForTest(config -> + created.getAndIncrement() == 0 ? 
brokenContext : healthyContext); + + Configuration config = buildPoolConfig(); + InferContextLease lease = InferContextPool.borrow(config); + assertThrows(RuntimeException.class, () -> lease.infer("value")); + lease.close(); + + try (InferContextLease next = InferContextPool.borrow(config)) { + assertEquals(next.infer("value"), Integer.valueOf(2)); + } + + verify(brokenContext, times(1)).close(); + } + + @Test + public void testContextReturnsToPoolAfterTransientInferFailure() throws Exception { + AtomicInteger created = new AtomicInteger(); + AtomicInteger invocations = new AtomicInteger(); + InferContext inferContext = mock(InferContext.class); + doAnswer(invocation -> { + if (invocations.getAndIncrement() == 0) { + throw new RuntimeException("temporary failure"); + } + return 7; + }).when(inferContext).infer(org.mockito.Matchers.anyVararg()); + doAnswer(invocation -> false).when(inferContext).isBroken(); + InferContextPool.setContextFactoryForTest(config -> { + created.incrementAndGet(); + return inferContext; + }); + + Configuration config = buildPoolConfig(); + try (InferContextLease lease = InferContextPool.borrow(config)) { + assertThrows(RuntimeException.class, () -> lease.infer("value")); + } + + try (InferContextLease retry = InferContextPool.borrow(config)) { + assertEquals(retry.infer("value"), Integer.valueOf(7)); + } + + assertEquals(created.get(), 1); + verify(inferContext, times(0)).close(); + } + + @Test + public void testCloseAllClosesBorrowedAndIdleContexts() throws Exception { + AtomicInteger created = new AtomicInteger(); + InferContext firstContext = mockContext(1); + InferContext secondContext = mockContext(2); + InferContextPool.setContextFactoryForTest(config -> + created.getAndIncrement() == 0 ? 
firstContext : secondContext); + + Configuration config = buildPoolConfig(); + InferContextLease borrowedLease = InferContextPool.borrow(config); + try (InferContextLease idleLease = InferContextPool.borrow(config)) { + assertEquals(idleLease.infer("value"), Integer.valueOf(2)); + } + + InferContextPool.closeAll(); + borrowedLease.close(); + + verify(firstContext, times(1)).close(); + verify(secondContext, times(1)).close(); + } + + private InferContext mockContext(Object response) { + InferContext inferContext = mock(InferContext.class); + try { + doAnswer(invocation -> response).when(inferContext).infer(org.mockito.Matchers.anyVararg()); + } catch (Exception e) { + throw new RuntimeException(e); + } + doAnswer(invocation -> false).when(inferContext).isBroken(); + return inferContext; + } + + private Configuration buildPoolConfig() { + Configuration config = new Configuration(); + config.setMasterId("master"); + config.put("k1", "v1"); + config.put("k2", "v2"); + config.put(FrameworkConfigKeys.INFER_CONTEXT_POOL_MAX_SIZE, "2"); + config.put(FrameworkConfigKeys.INFER_CONTEXT_POOL_BORROW_TIMEOUT_SEC, "1"); + config.put(FrameworkConfigKeys.INFER_ENV_SHARE_MEMORY_QUEUE_SIZE, "1024"); + config.put(ExecutionConfigKeys.JOB_WORK_PATH, System.getProperty("java.io.tmpdir")); + config.put(FileConfigKeys.USER_NAME, "tester"); + config.put(ExecutionConfigKeys.JOB_UNIQUE_ID, "job"); + return config; + } +} diff --git a/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferEnvironmentInstallScriptTest.java b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferEnvironmentInstallScriptTest.java new file mode 100644 index 000000000..f4a574c27 --- /dev/null +++ b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferEnvironmentInstallScriptTest.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.infer; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.attribute.PosixFilePermission; +import java.util.EnumSet; +import java.util.Set; +import java.util.stream.Stream; +import org.testng.annotations.Test; + +public class InferEnvironmentInstallScriptTest { + + @Test + public void testFreshInstallScriptDownloadsInstallerIntoNonEmptyFile() throws Exception { + Path rootDir = Files.createTempDirectory("infer-env-script"); + try { + Path runtimeDir = Files.createDirectories(rootDir.resolve("runtime")); + Files.createDirectories(rootDir.resolve("inferFiles")); + Path fakeBinDir = Files.createDirectories(rootDir.resolve("bin")); + Path fakeInstallerSource = rootDir.resolve("fake-miniconda.sh"); + Files.write(fakeInstallerSource, buildFakeInstaller().getBytes(UTF_8)); + makeExecutable(fakeInstallerSource); + + Path fakeWget = fakeBinDir.resolve("wget"); + Files.write(fakeWget, buildFakeWget().getBytes(UTF_8)); + makeExecutable(fakeWget); + + Path scriptPath = 
runtimeDir.resolve("install-infer-env.sh"); + try (InputStream inputStream = InferEnvironmentInstallScriptTest.class + .getResourceAsStream("/infer/env/install-infer-env.sh")) { + if (inputStream == null) { + throw new IOException("cannot load install-infer-env.sh from classpath"); + } + Files.copy(inputStream, scriptPath); + } + makeExecutable(scriptPath); + + Path missingRequirements = rootDir.resolve("missing-requirements.txt"); + + ProcessBuilder processBuilder = new ProcessBuilder( + "bash", + scriptPath.toAbsolutePath().toString(), + runtimeDir.toAbsolutePath().toString(), + missingRequirements.toAbsolutePath().toString(), + fakeInstallerSource.toUri().toString()); + processBuilder.environment().put("PATH", + fakeBinDir.toAbsolutePath() + System.getProperty("path.separator") + + processBuilder.environment().get("PATH")); + processBuilder.redirectErrorStream(true); + + Process process = processBuilder.start(); + String output; + try (InputStream inputStream = process.getInputStream()) { + output = new String(readAllBytes(inputStream), UTF_8); + } + int exitCode = process.waitFor(); + + assertEquals(exitCode, 0, output); + assertTrue(Files.isRegularFile(runtimeDir.resolve("miniconda.sh")), output); + assertTrue(Files.size(runtimeDir.resolve("miniconda.sh")) > 0, output); + assertTrue(Files.isRegularFile(runtimeDir.resolve("conda/bin/python3")), output); + } finally { + deleteRecursively(rootDir); + } + } + + private String buildFakeInstaller() { + return "#!/bin/sh\n" + + "prefix=\"\"\n" + + "while [ \"$#\" -gt 0 ]; do\n" + + " if [ \"$1\" = \"-p\" ]; then\n" + + " shift\n" + + " prefix=\"$1\"\n" + + " fi\n" + + " shift\n" + + "done\n" + + "mkdir -p \"$prefix/bin\"\n" + + "cat > \"$prefix/bin/python3\" <<'EOF'\n" + + "#!/bin/sh\n" + + "exit 0\n" + + "EOF\n" + + "chmod +x \"$prefix/bin/python3\"\n"; + } + + private String buildFakeWget() { + return "#!/bin/sh\n" + + "src=\"$1\"\n" + + "out=\"\"\n" + + "while [ \"$#\" -gt 0 ]; do\n" + + " if [ \"$1\" = \"-O\" ]; 
then\n" + + " shift\n" + + " out=\"$1\"\n" + + " fi\n" + + " shift\n" + + "done\n" + + "src=\"${src#file://}\"\n" + + "cp \"$src\" \"$out\"\n"; + } + + private void makeExecutable(Path file) throws IOException { + Set permissions = EnumSet.of( + PosixFilePermission.OWNER_READ, + PosixFilePermission.OWNER_WRITE, + PosixFilePermission.OWNER_EXECUTE, + PosixFilePermission.GROUP_READ, + PosixFilePermission.GROUP_EXECUTE, + PosixFilePermission.OTHERS_READ, + PosixFilePermission.OTHERS_EXECUTE); + try { + Files.setPosixFilePermissions(file, permissions); + } catch (UnsupportedOperationException e) { + file.toFile().setExecutable(true, false); + } + } + + private void deleteRecursively(Path root) throws IOException { + if (root == null || !Files.exists(root)) { + return; + } + try (Stream paths = Files.walk(root)) { + paths.sorted((left, right) -> right.compareTo(left)) + .forEach(path -> { + try { + Files.deleteIfExists(path); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + } + + private byte[] readAllBytes(InputStream inputStream) throws IOException { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + byte[] buffer = new byte[1024]; + int length; + while ((length = inputStream.read(buffer)) >= 0) { + outputStream.write(buffer, 0, length); + } + return outputStream.toByteArray(); + } +} diff --git a/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferSessionCompatibilityTest.java b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferSessionCompatibilityTest.java new file mode 100644 index 000000000..26d987d05 --- /dev/null +++ b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/InferSessionCompatibilityTest.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.infer; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Stream; +import org.testng.SkipException; +import org.testng.annotations.Test; + +public class InferSessionCompatibilityTest { + + @Test + public void testTransformContractRunsThroughInferSession() throws Exception { + String output = runInferSessionScript( + "import json\n" + + "from inferSession import TorchInferSession\n" + + "\n" + + "class Transform(object):\n" + + " input_size = 1\n" + + "\n" + + " def __init__(self):\n" + + " self.loaded = False\n" + + "\n" + + " def load_model(self, model_path):\n" + + " self.loaded = True\n" + + " def model(value):\n" + + " return {'loaded': self.loaded, 'modeled': value * 2}\n" + + " return model\n" + + "\n" + + " def transform_pre(self, value):\n" + + " return (value + 1,)\n" + + "\n" + + " def transform_post(self, res):\n" + + " return {'loaded': res['loaded'], 'value': res['modeled'], 'post_processed': True}\n" + + "\n" + + "print(json.dumps(TorchInferSession(Transform()).run(41), sort_keys=True))\n"); + + 
assertEquals(output, "{\"loaded\": true, \"post_processed\": true, \"value\": 84}"); + } + + @Test + public void testTransformContractSupportsLegacyModeWhenLoadModelReturnsNone() throws Exception { + String output = runInferSessionScript( + "import json\n" + + "from inferSession import TorchInferSession\n" + + "\n" + + "class Transform(object):\n" + + " input_size = 1\n" + + "\n" + + " def load_model(self, model_path):\n" + + " return None\n" + + "\n" + + " def transform_pre(self, value):\n" + + " return {'legacy': value + 1}\n" + + "\n" + + " def transform_post(self, res):\n" + + " return {'mode': 'legacy', 'value': res['legacy']}\n" + + "\n" + + "print(json.dumps(TorchInferSession(Transform()).run(41), sort_keys=True))\n"); + + assertEquals(output, "{\"mode\": \"legacy\", \"value\": 42}"); + } + + @Test + public void testTransformContractRejectsNonCallableNonLegacyModel() throws Exception { + String output = runInferSessionScript( + "from inferSession import TorchInferSession\n" + + "\n" + + "class Transform(object):\n" + + " input_size = 1\n" + + "\n" + + " def load_model(self, model_path):\n" + + " return object()\n" + + "\n" + + " def transform_pre(self, value):\n" + + " return (value,)\n" + + "\n" + + " def transform_post(self, res):\n" + + " return res\n" + + "\n" + + "try:\n" + + " TorchInferSession(Transform())\n" + + "except RuntimeError as e:\n" + + " print(str(e))\n"); + + assertEquals(output, "load_model(model_path) must return a callable model"); + } + + @Test + public void testTransformContractExpandsTupleArgsToModel() throws Exception { + String output = runInferSessionScript( + "import json\n" + + "from inferSession import TorchInferSession\n" + + "\n" + + "class Transform(object):\n" + + " input_size = 1\n" + + "\n" + + " def load_model(self, model_path):\n" + + " def model(left, right):\n" + + " return {'sum': left + right}\n" + + " return model\n" + + "\n" + + " def transform_pre(self, value):\n" + + " return (value, value + 2)\n" + + "\n" + + " 
def transform_post(self, res):\n" + + " return {'value': res['sum']}\n" + + "\n" + + "print(json.dumps(TorchInferSession(Transform()).run(40), sort_keys=True))\n"); + + assertEquals(output, "{\"value\": 82}"); + } + + @Test + public void testTransformContractSupportsListModelInputAndListPostPayload() throws Exception { + String output = runInferSessionScript( + "import json\n" + + "from inferSession import TorchInferSession\n" + + "\n" + + "class Transform(object):\n" + + " input_size = 1\n" + + "\n" + + " def load_model(self, model_path):\n" + + " def model(left, right):\n" + + " return [left, right, left + right]\n" + + " return model\n" + + "\n" + + " def transform_pre(self, value):\n" + + " return [value, value + 1]\n" + + "\n" + + " def transform_post(self, res):\n" + + " return res\n" + + "\n" + + "print(json.dumps(TorchInferSession(Transform()).run(5)))\n"); + + assertEquals(output, "[5, 6, 11]"); + } + + @Test + public void testTransformContractSurfacesLoadModelExceptionMessage() throws Exception { + String output = runInferSessionScript( + "from inferSession import TorchInferSession\n" + + "\n" + + "class Transform(object):\n" + + " input_size = 1\n" + + "\n" + + " def load_model(self, model_path):\n" + + " raise RuntimeError('boom from load_model')\n" + + "\n" + + " def transform_pre(self, value):\n" + + " return (value,)\n" + + "\n" + + " def transform_post(self, res):\n" + + " return res\n" + + "\n" + + "try:\n" + + " TorchInferSession(Transform())\n" + + "except RuntimeError as e:\n" + + " print(str(e))\n"); + + assertEquals(output, "boom from load_model"); + } + + private String runInferSessionScript(String scriptBody) throws Exception { + ensurePythonAvailable(); + Path tempDir = Files.createTempDirectory("infer-session-test"); + try { + copyInferSessionResource(tempDir); + Files.write(tempDir.resolve("torch.py"), + Arrays.asList("def set_num_threads(_):", " return None"), UTF_8); + Files.write(tempDir.resolve("runner.py"), 
scriptBody.getBytes(UTF_8)); + + Process process = new ProcessBuilder("python3", "runner.py") + .directory(tempDir.toFile()) + .redirectErrorStream(true) + .start(); + byte[] outputBytes; + try (InputStream inputStream = process.getInputStream()) { + outputBytes = readAllBytes(inputStream); + } + int exitCode = process.waitFor(); + String output = new String(outputBytes, UTF_8).trim(); + assertEquals(exitCode, 0, output); + return output; + } finally { + deleteRecursively(tempDir); + } + } + + private void ensurePythonAvailable() throws Exception { + try { + Process process = new ProcessBuilder("python3", "--version") + .redirectErrorStream(true) + .start(); + try (InputStream inputStream = process.getInputStream()) { + readAllBytes(inputStream); + } + if (process.waitFor() != 0) { + throw new SkipException("python3 is unavailable for infer session compatibility test"); + } + } catch (IOException e) { + throw new SkipException("python3 is unavailable for infer session compatibility test", e); + } + } + + private void copyInferSessionResource(Path tempDir) throws IOException { + try (InputStream inputStream = InferSessionCompatibilityTest.class + .getResourceAsStream("/infer/inferRuntime/inferSession.py")) { + if (inputStream == null) { + throw new IOException("cannot load inferSession.py from classpath"); + } + Files.copy(inputStream, tempDir.resolve("inferSession.py")); + } + } + + private void deleteRecursively(Path root) throws IOException { + if (root == null || !Files.exists(root)) { + return; + } + List paths = new ArrayList(); + try (Stream stream = Files.walk(root)) { + for (Path path : (Iterable) stream::iterator) { + paths.add(path); + } + } + Collections.sort(paths, Collections.reverseOrder()); + for (Path path : paths) { + Files.deleteIfExists(path); + } + } + + private byte[] readAllBytes(InputStream inputStream) throws IOException { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + byte[] buffer = new byte[1024]; + int length; + 
while ((length = inputStream.read(buffer)) >= 0) { + outputStream.write(buffer, 0, length); + } + return outputStream.toByteArray(); + } +} diff --git a/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/exchange/DataExchangeQueueLifecycleTest.java b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/exchange/DataExchangeQueueLifecycleTest.java new file mode 100644 index 000000000..c5e139906 --- /dev/null +++ b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/exchange/DataExchangeQueueLifecycleTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.infer.exchange; + +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +import java.lang.reflect.Field; +import java.util.concurrent.atomic.AtomicBoolean; +import org.testng.annotations.Test; + +public class DataExchangeQueueLifecycleTest { + + @Test + public void testNewQueueCanStillMarkFinishedAfterPreviousQueueWasClosed() throws Exception { + DataExchangeQueue previousQueue = newQueue(); + DataExchangeQueue nextQueue = newQueue(); + try { + previousQueue.close(); + nextQueue.markFinished(); + + assertTrue(nextQueue.enableFinished(), + "markFinished() should be local to the queue instance, not globally disabled"); + } finally { + freeQueueMemory(previousQueue); + freeQueueMemory(nextQueue); + } + } + + @Test + public void testMarkFinishedIsNoOpAfterQueueClosed() throws Exception { + DataExchangeQueue queue = newQueue(); + try { + queue.close(); + queue.markFinished(); + + assertFalse(queue.enableFinished(), + "markFinished() should be ignored after the queue instance is closed"); + } finally { + freeQueueMemory(queue); + } + } + + @Test + public void testCloseIsIdempotentPerInstance() throws Exception { + DataExchangeQueue queue = newQueue(); + try { + queue.close(); + queue.close(); + assertTrue(((AtomicBoolean) getObjectField(queue, "closed")).get(), + "close() should mark the queue instance as closed"); + } finally { + freeQueueMemory(queue); + } + } + + private DataExchangeQueue newQueue() throws Exception { + DataExchangeQueue queue = (DataExchangeQueue) UnSafeUtils.UNSAFE + .allocateInstance(DataExchangeQueue.class); + long endPointAddress = UnSafeUtils.UNSAFE.allocateMemory(Long.BYTES); + UnSafeUtils.UNSAFE.putLong(endPointAddress, 0L); + setLongField(queue, "endPointAddress", endPointAddress); + setLongField(queue, "mapAddress", 0L); + setObjectField(queue, "closed", new AtomicBoolean(false)); + setObjectField(queue, "memoryMapper", null); + return queue; + } + + private void 
freeQueueMemory(DataExchangeQueue queue) throws Exception { + long endPointAddress = (Long) getObjectField(queue, "endPointAddress"); + if (endPointAddress != 0L) { + UnSafeUtils.UNSAFE.freeMemory(endPointAddress); + setLongField(queue, "endPointAddress", 0L); + } + long mapAddress = (Long) getObjectField(queue, "mapAddress"); + if (mapAddress != 0L) { + UnSafeUtils.UNSAFE.freeMemory(mapAddress); + setLongField(queue, "mapAddress", 0L); + } + } + + private void setObjectField(Object target, String fieldName, Object value) throws Exception { + Field field = DataExchangeQueue.class.getDeclaredField(fieldName); + long fieldOffset = UnSafeUtils.UNSAFE.objectFieldOffset(field); + UnSafeUtils.UNSAFE.putObject(target, fieldOffset, value); + } + + private void setLongField(Object target, String fieldName, long value) throws Exception { + Field field = DataExchangeQueue.class.getDeclaredField(fieldName); + long fieldOffset = UnSafeUtils.UNSAFE.objectFieldOffset(field); + UnSafeUtils.UNSAFE.putLong(target, fieldOffset, value); + } + + private Object getObjectField(Object target, String fieldName) throws Exception { + Field field = DataExchangeQueue.class.getDeclaredField(fieldName); + long fieldOffset = UnSafeUtils.UNSAFE.objectFieldOffset(field); + if (field.getType() == long.class) { + return UnSafeUtils.UNSAFE.getLong(target, fieldOffset); + } + return UnSafeUtils.UNSAFE.getObject(target, fieldOffset); + } +} diff --git a/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/util/InferFileUtilsTest.java b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/util/InferFileUtilsTest.java new file mode 100644 index 000000000..66ea9847c --- /dev/null +++ b/geaflow/geaflow-infer/src/test/java/org/apache/geaflow/infer/util/InferFileUtilsTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.infer.util; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Stream; +import org.testng.annotations.Test; + +public class InferFileUtilsTest { + + @Test + public void testShouldExtractInferResourceRejectsTraversalPath() throws Exception { + assertTrue(invokeShouldExtractInferResource("pkg/infer.py")); + assertTrue(invokeShouldExtractInferResource("nested/model.pt")); + assertTrue(invokeShouldExtractInferResource("pkg/sub/model.pt")); + assertFalse(invokeShouldExtractInferResource("../evil.py")); + assertFalse(invokeShouldExtractInferResource("..\\evil.py")); + assertFalse(invokeShouldExtractInferResource("nested/../../evil.py")); + assertFalse(invokeShouldExtractInferResource("nested/data.json")); + } + + @Test + public void testBuildSafeTargetFileKeepsNestedPathUnderTargetDirectory() throws Exception { + Path targetDirectory = 
Files.createTempDirectory("infer-file-utils"); + try { + File targetFile = invokeBuildSafeTargetFile(targetDirectory.toString(), "pkg/model.pt"); + assertEquals(targetFile.toPath(), + targetDirectory.resolve("pkg/model.pt").toAbsolutePath().normalize()); + } finally { + deleteRecursively(targetDirectory); + } + } + + @Test(expectedExceptions = IOException.class, + expectedExceptionsMessageRegExp = "illegal infer resource path: .*") + public void testBuildSafeTargetFileRejectsPathTraversal() throws Throwable { + Path targetDirectory = Files.createTempDirectory("infer-file-utils"); + try { + invokeBuildSafeTargetFile(targetDirectory.toString(), "../evil.py"); + } catch (InvocationTargetException e) { + throw e.getCause(); + } finally { + deleteRecursively(targetDirectory); + } + } + + @Test(expectedExceptions = IOException.class, + expectedExceptionsMessageRegExp = "illegal infer resource path: .*") + public void testBuildSafeTargetFileRejectsAbsolutePath() throws Throwable { + Path targetDirectory = Files.createTempDirectory("infer-file-utils"); + try { + invokeBuildSafeTargetFile(targetDirectory.toString(), "/tmp/evil.py"); + } catch (InvocationTargetException e) { + throw e.getCause(); + } finally { + deleteRecursively(targetDirectory); + } + } + + @Test + public void testBuildSafeTargetFileAllowsDeepNestedPath() throws Exception { + Path targetDirectory = Files.createTempDirectory("infer-file-utils"); + try { + File targetFile = invokeBuildSafeTargetFile(targetDirectory.toString(), + "pkg/sub/model.pt"); + assertEquals(targetFile.toPath(), + targetDirectory.resolve("pkg/sub/model.pt").toAbsolutePath().normalize()); + } finally { + deleteRecursively(targetDirectory); + } + } + + private static boolean invokeShouldExtractInferResource(String fileName) throws Exception { + Method method = InferFileUtils.class.getDeclaredMethod("shouldExtractInferResource", + String.class); + method.setAccessible(true); + return (boolean) method.invoke(null, fileName); + } + + 
private static File invokeBuildSafeTargetFile(String targetDirectory, String fileName) + throws Exception { + Method method = InferFileUtils.class.getDeclaredMethod("buildSafeTargetFile", + String.class, String.class); + method.setAccessible(true); + return (File) method.invoke(null, targetDirectory, fileName); + } + + private static void deleteRecursively(Path root) throws IOException { + if (root == null || !Files.exists(root)) { + return; + } + List paths = new ArrayList(); + try (Stream stream = Files.walk(root)) { + for (Path path : (Iterable) stream::iterator) { + paths.add(path); + } + } + Collections.sort(paths, Collections.reverseOrder()); + for (Path path : paths) { + Files.deleteIfExists(path); + } + } +}