# GraphSAGE and the GDS model catalog

In [1]:
# update PYTHONPATH to be able to load utils module
import sys
import os
sys.path.append(os.path.join(os.curdir, ".."))
# sys.path

In [2]:
from graphdatascience import GraphDataScience
gds = GraphDataScience("bolt://localhost:7687", auth=("neo4j", "admin123"))

In [3]:
from utils import create_projected_graph
projected_graph_object = create_projected_graph(
    gds,
    graph_name="pGraphAll",
    node_spec="Train",
    relationship_spec={
        "RELATED_TO": {"orientation": "UNDIRECTED"}
    },
    nodeProperties=["x0"],
)
projected_graph_object

Graph({'graphName': 'pGraphAll', 'nodeCount': 657, 'relationshipCount': 28694, 'database': 'neo4j', 'configuration': {}, 'schema': {'graphProperties': {}, 'relationships': {'RELATED_TO': {}}, 'nodes': {'Train': {'x0': 'Float (DefaultValue(NaN), PERSISTENT)'}}}, 'memoryUsage': '347 KiB'})

In [4]:
projected_graph_object.node_labels()

['Train']

In [5]:
projected_graph_object.node_count()

657

In [6]:
projected_graph_object.node_properties()

Train    [x0]
dtype: object

In [7]:
r = gds.graph.nodeProperties.stream(projected_graph_object, ["x0"])
r.propertyValue.isnull().sum()

0

In [8]:
model_name = "myGS"
try:
    m = gds.model.get(model_name)
    gds.beta.model.drop(m)
except ValueError:
    pass

In [9]:
model_object, res = gds.beta.graphSage.train(
    projected_graph_object,
    modelName=model_name,
    featureProperties=["x0"],
    # nodeLabels=["Train"],
    # relationshipTypes=["RELATED_TO_TRAIN"],
    learningRate=0.0001
)
res

modelInfo        {'modelName': 'myGS', 'modelType': 'graphSage'...
configuration    {'maxIterations': 10, 'negativeSampleWeight': ...
trainMillis                                                     46
Name: 0, dtype: object

In [10]:
gds.beta.model.list()

Unnamed: 0,modelInfo,trainConfig,graphSchema,loaded,stored,creationTime,shared
0,"{'modelName': 'myGS', 'modelType': 'graphSage'...","{'maxIterations': 10, 'negativeSampleWeight': ...","{'graphProperties': {}, 'relationships': {'REL...",True,False,2023-01-04T21:05:20.544807241+01:00,False


In [11]:
res["modelInfo"]

{'modelName': 'myGS',
 'modelType': 'graphSage',
 'metrics': {'ranIterationsPerEpoch': [2],
  'iterationLossesPerEpoch': [[26.578488557733984, 26.578488141015168]],
  'didConverge': True,
  'ranEpochs': 1,
  'epochLosses': [26.578488141015168]}}

In [12]:
res["configuration"]

{'maxIterations': 10,
 'negativeSampleWeight': 20,
 'searchDepth': 5,
 'aggregator': 'MEAN',
 'activationFunction': 'SIGMOID',
 'penaltyL2': 0.0,
 'learningRate': 0.0001,
 'concurrency': 4,
 'jobId': '859e0db6-0562-4892-bfef-373be9448c40',
 'modelName': 'myGS',
 'embeddingDimension': 64,
 'nodeLabels': ['*'],
 'sudo': False,
 'featureProperties': ['x0'],
 'sampleSizes': [25, 10],
 'relationshipTypes': ['*'],
 'batchSize': 100,
 'epochs': 1,
 'tolerance': 0.0001}

In [13]:
graph_object = create_projected_graph(
    gds,
    "pgraphTest",
    node_spec="Test",
    relationship_spec={
        "RELATED_TO": {"orientation": "UNDIRECTED"}
    },
    nodeProperties=["x0"]
)
graph_object

Graph({'graphName': 'pgraphTest', 'nodeCount': 143, 'relationshipCount': 1148, 'database': 'neo4j', 'configuration': {}, 'schema': {'graphProperties': {}, 'relationships': {'RELATED_TO': {}}, 'nodes': {'Test': {'x0': 'Float (DefaultValue(NaN), PERSISTENT)'}}}, 'memoryUsage': '337 KiB'})

In [14]:
embeddings = gds.beta.graphSage.stream(
    graph_object,
    modelName=model_name,
)
embeddings.head()

Unnamed: 0,nodeId,embedding
0,0,"[0.12822204244043253, 0.14201885703254316, 0.1..."
1,1,"[0.12822966528605215, 0.14197938554132972, 0.1..."
2,4,"[0.12818660338318555, 0.14220192182019373, 0.1..."
3,6,"[0.12827174843249659, 0.14176119902706652, 0.1..."
4,11,"[0.1283371496022736, 0.14142064097600568, 0.14..."


In [15]:
embeddings.loc[0]["embedding"]

[0.12822204244043253,
 0.14201885703254316,
 0.14337842580713483,
 0.10906009836638815,
 0.11329364502908318,
 0.09775632106885747,
 0.1358488616673177,
 0.14166659481229685,
 0.10273297060169308,
 0.12046625581163048,
 0.11136253475926539,
 0.11215539171667141,
 0.1374422012228958,
 0.13201772169432527,
 0.14472751600143516,
 0.11861066676266395,
 0.12206160435195654,
 0.1650379071212499,
 0.13096667516510924,
 0.09250742508903725,
 0.10666508855894363,
 0.1415336648663157,
 0.10894266689338013,
 0.11317287233183546,
 0.1261122546802996,
 0.12183432749609334,
 0.11375221876827496,
 0.10364208754779175,
 0.11375195913094087,
 0.1528732130841031,
 0.12512127359261058,
 0.12266410571365242,
 0.10405246008789001,
 0.13438465693732068,
 0.11459772956569617,
 0.11531091947213526,
 0.1358971218288696,
 0.14749963285394108,
 0.14553271505099458,
 0.13367663366531965,
 0.10787394515056332,
 0.11556337567707373,
 0.12026832462762913,
 0.12062729597897467,
 0.10255079547558363,
 0.14163528337820