-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrdflib_example.py
46 lines (35 loc) · 1.92 KB
/
rdflib_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from graphflex import GraphFlex
from graphflex.connectors.sparql import RDFLibConnector
from graphflex.functions.edgenode import NumericalEdgeNode
from graphflex.functions.postprocessing.filter import NonUniqueFeatureFilter
from sklearn.ensemble import ExtraTreesClassifier
# Connector over the animals OWL file used by this example.
# NOTE(review): the second argument ('xml') is presumably the RDF serialization
# format handed to rdflib — confirm against RDFLibConnector's signature.
# NOTE(review): the path is relative, so it presumably resolves against the
# current working directory rather than this file's location — confirm.
connector = RDFLibConnector('../data/animals/animals.owl', 'xml')
# Individuals labelled as the positive class (label 1 below).
pos = [
    "http://dl-learner.org/benchmark/dataset/animals#dog01",
    "http://dl-learner.org/benchmark/dataset/animals#dolphin01",
    "http://dl-learner.org/benchmark/dataset/animals#platypus01",
    "http://dl-learner.org/benchmark/dataset/animals#bat01",
]

# Individuals labelled as the negative class (label 0 below).
neg = [
    "http://dl-learner.org/benchmark/dataset/animals#trout01",
    "http://dl-learner.org/benchmark/dataset/animals#herring01",
    "http://dl-learner.org/benchmark/dataset/animals#shark01",
    "http://dl-learner.org/benchmark/dataset/animals#lizard01",
    "http://dl-learner.org/benchmark/dataset/animals#croco01",
    "http://dl-learner.org/benchmark/dataset/animals#trex01",
    "http://dl-learner.org/benchmark/dataset/animals#turtle01",
    "http://dl-learner.org/benchmark/dataset/animals#eagle01",
    "http://dl-learner.org/benchmark/dataset/animals#ostrich01",
    "http://dl-learner.org/benchmark/dataset/animals#penguin01",
]

# Positives first, then negatives; `labels` is aligned index-for-index with `nodes`.
nodes = [*pos, *neg]
labels = [1] * len(pos) + [0] * len(neg)
from sklearn.model_selection import train_test_split

# One seed shared by the split and the classifier, so that repeated runs of
# this example print the same score (previously only the split was seeded).
SEED = 42

# Feature extractor built on top of `connector`.
# NOTE(review): the meaning of max_depth / edge_node_feature / post_processor
# is defined by GraphFlex and is not visible in this file — see its docs.
gflex = GraphFlex(
    connector,
    max_depth=10,
    edge_node_feature=NumericalEdgeNode(),
    post_processor=NonUniqueFeatureFilter(),
    n_jobs=4,
    verbose=True,
)

# Hold out half of the nodes; `stratify=labels` keeps the class ratio the same
# in the train and test halves.
X_train, X_test, y_train, y_test = train_test_split(
    nodes,
    labels,
    test_size=0.5,
    stratify=labels,
    random_state=SEED,
)

# Fit the extractor on the training nodes only, then report the matrix shape.
train_matrix = gflex.fit_transform(X_train)
print(train_matrix.shape)

clf = ExtraTreesClassifier(n_estimators=100, random_state=SEED)
clf.fit(train_matrix, y_train)

# Test nodes go through the already-fitted extractor (transform, not
# fit_transform) before scoring; `score` reports mean accuracy.
test_matrix = gflex.transform(X_test)
print(clf.score(test_matrix, y_test))