# Classification using FGW

In [1]:
import numpy as np
import os,sys
sys.path.append(os.path.realpath('../lib'))
from data_loader import load_local_data
from custom_svc import Graph_FGW_SVC_Classifier
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

A simple training example using FGW on the MUTAG dataset.

In [2]:
dataset_n='mutag'
path='../data/'
X,y=load_local_data(path,dataset_n,wl=2)

We load the MUTAG dataset with the "wl" option, which computes the Weisfeiler-Lehman features for each node, as shown in the notebook wl_labeling.ipynb.
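For readers who have not opened that notebook, here is a minimal sketch of the relabeling idea, under our own assumptions: at each iteration a node's new label is its current label concatenated with the sorted labels of its neighbours, and the per-node list of labels across iterations plays the role of $\tau(a_i^k)$. The helpers `wl_iteration` and `wl_features` are hypothetical and are not the `data_loader` implementation. Real implementations usually also compress each concatenated label into a short integer id; we skip that because only label equality matters for the hamming distance used here.

```python
def wl_iteration(adjacency, labels):
    """One Weisfeiler-Lehman relabeling step (hypothetical helper).

    adjacency: dict node -> list of neighbour nodes
    labels:    dict node -> current label (string)
    Returns a dict node -> new label = own label + sorted neighbour labels.
    """
    new_labels = {}
    for node, neighbours in adjacency.items():
        neigh = sorted(labels[n] for n in neighbours)
        new_labels[node] = labels[node] + "|" + ",".join(neigh)
    return new_labels


def wl_features(adjacency, labels, wl=2):
    """Per node, the list of labels at iterations 0..wl (wl + 1 labels)."""
    history = [dict(labels)]
    for _ in range(wl):
        history.append(wl_iteration(adjacency, history[-1]))
    return {node: [h[node] for h in history] for node in adjacency}


# Toy path graph C - C - O (node 1 is in the middle).
adjacency = {0: [1], 1: [0, 2], 2: [1]}
labels = {0: "C", 1: "C", 2: "O"}
features = wl_features(adjacency, labels, wl=2)
print(features[1][:2])  # ['C', 'C|C,O']
```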

We create an SVM-like classifier on the precomputed matrix $K=e^{-\gamma \, FGW}$.
To compute FGW, we use the shortest-path distance for the structure matrix of each graph and the so-called 'hamming' distance between their features. It is defined as

$$d(a_{i},b_{j})=\sum_{k=0}^{wl} \delta(\tau(a_{i}^{k}),\tau(b_{j}^{k}))$$

where $\delta(x,y)=1$ if $x\neq y$ and $\delta(x,y)=0$ otherwise, and $\tau(a_{i}^{k})$ denotes the concatenated label at iteration $k$ of the Weisfeiler-Lehman process.
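The two ingredients that FGW compares can be sketched in plain Python. `hamming_dist` evaluates the formula above for two nodes whose WL label histories have the same length (wl + 1 labels each), and `shortest_path_matrix` builds the hop-count structure matrix of an unweighted graph by running a breadth-first search from every node. All function names here are hypothetical illustrations, not the `custom_svc` code, and the optimal-transport solve that turns these matrices into an FGW value is not shown.

```python
from collections import deque


def hamming_dist(feat_a, feat_b):
    """d(a_i, b_j): number of WL iterations k at which the labels differ."""
    if len(feat_a) != len(feat_b):
        raise ValueError("both nodes need labels for the same number of iterations")
    return sum(1 for x, y in zip(feat_a, feat_b) if x != y)


def feature_cost_matrix(feats_a, feats_b):
    """Pairwise hamming distances between the nodes of two graphs."""
    return [[hamming_dist(fa, fb) for fb in feats_b] for fa in feats_a]


def shortest_path_matrix(adjacency):
    """All-pairs hop distances of an unweighted graph (BFS from each node).

    Unreachable pairs are left at float('inf').
    """
    nodes = sorted(adjacency)
    index = {n: i for i, n in enumerate(nodes)}
    inf = float("inf")
    dist = [[inf] * len(nodes) for _ in nodes]
    for src in nodes:
        row = dist[index[src]]
        row[index[src]] = 0
        queue = deque([src])
        while queue:
            u = queue.popleft()
            for v in adjacency[u]:
                if row[index[v]] == inf:
                    row[index[v]] = row[index[u]] + 1
                    queue.append(v)
    return dist


# Node label histories with wl = 1 (two labels per node).
print(hamming_dist(["C", "C|C"], ["O", "O|C"]))            # 2
print(shortest_path_matrix({0: [1], 1: [0, 2], 2: [1]}))   # [[0, 1, 2], [1, 0, 1], [2, 1, 0]]
```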

In [3]:
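# SVM-like classifier on K = exp(-gamma * FGW): shortest-path structure matrices,
# hamming distance on features from wl=2 Weisfeiler-Lehman iterations,
# alpha trades off the feature and structure terms of FGW, C is the SVM penalty
graph_svc=Graph_FGW_SVC_Classifier(C=1,gamma=1,alpha=0.5,method='shortest_path',features_metric='hamming_dist',wl=2)
# hold out a third of the graphs for testing, with a fixed seed
X_train, X_test, y_train, y_test=train_test_split(X,y,test_size=0.33, random_state=42)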
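The kernel step can be illustrated with scikit-learn's `SVC(kernel="precomputed")`, which takes a train-by-train Gram matrix in `fit` and a test-by-train matrix in `predict`. In this sketch the FGW distances are replaced by a stand-in (1-D points with $|x - x'|$ as the distance), so the numbers are toy data rather than MUTAG results, and `fgw_like_kernel` is a hypothetical helper. Whether `Graph_FGW_SVC_Classifier` is built exactly this way internally is an assumption.

```python
import numpy as np
from sklearn.svm import SVC


def fgw_like_kernel(dist, gamma=1.0):
    """Elementwise K = exp(-gamma * D) for a matrix of pairwise distances D."""
    return np.exp(-gamma * np.asarray(dist, dtype=float))


rng = np.random.default_rng(0)
# Stand-in "graphs": 1-D points; |x - x'| plays the role of the FGW distance.
x_train = np.concatenate([rng.normal(-2, 0.5, 20), rng.normal(2, 0.5, 20)])
y_train = np.array([-1] * 20 + [1] * 20)
x_test = np.array([-2.0, 2.0])

D_train = np.abs(x_train[:, None] - x_train[None, :])  # (n_train, n_train)
D_test = np.abs(x_test[:, None] - x_train[None, :])    # (n_test, n_train)

K_train = fgw_like_kernel(D_train, gamma=1.0)
K_test = fgw_like_kernel(D_test, gamma=1.0)

# A negative smallest eigenvalue would mean the kernel matrix is indefinite.
min_eig = np.linalg.eigvalsh(K_train).min()

clf = SVC(C=1, kernel="precomputed")
clf.fit(K_train, y_train)
preds = clf.predict(K_test)
print(preds)
```

Exponentiated FGW distances are not guaranteed to give a positive semi-definite matrix in general, which is presumably why the classifier is described as SVM-like; checking the smallest eigenvalue of the training matrix is a quick way to see whether a given kernel matrix is indefinite.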

In [4]:
print(X_train)

[<graph.Graph object at 0x00000224567A80A0>
 <graph.Graph object at 0x00000224567B1390>
 <graph.Graph object at 0x0000022456BD0910>
 ...]

In [5]:
%%time
graph_svc.fit(X_train,y_train)


Wall time: 15.5 s


In [6]:
%%time
preds=graph_svc.predict(X_test)

Wall time: 16.7 s


In [8]:
print(preds)
np.sum(preds==y_test)/len(y_test)

[ 1  1 -1  1  1  1  1 -1 -1 -1  1 -1  1 -1  1  1 -1  1  1 -1 -1  1  1  1
  1  1  1 -1  1  1  1  1  1  1  1  1  1  1 -1  1  1 -1  1  1  1  1 -1  1
  1  1 -1  1  1 -1  1  1 -1 -1 -1 -1  1  1 -1]


0.9365079365079365