# Can Abnormality be Detected by Graph Neural Networks?
Ziwei Chai, Siqi You, Yang Yang, Shiliang Pu, Jiarong Xu, Haoyang Cai, and Weihao Jiang

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import networkx as nx
import plotly.graph_objs as go

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch.nn import Parameter
from torch_geometric.utils import remove_self_loops, get_laplacian
from numpy import polynomial
import math

In [None]:
# Load the three Elliptic CSV files from ./dataset/elliptic_csv/.
# The features file is read with header=None, so its columns start out as
# integer positions and are renamed in a later cell.
df_classes = pd.read_csv("./dataset/elliptic_csv/elliptic_txs_classes.csv")
df_edges = pd.read_csv("./dataset/elliptic_csv/elliptic_txs_edgelist.csv")
df_features = pd.read_csv("./dataset/elliptic_csv/elliptic_txs_features.csv", header=None)

In [None]:
# Notebook display: show the class-label table.
df_classes

In [None]:
# Notebook display: show the edge-list table.
df_edges

In [None]:
# Report the (rows, columns) shape of each loaded table.
for _caption, _frame in (
    ('Shape of classes', df_classes),
    ('Shape of edges', df_edges),
    ('Shape of features', df_features),
):
    print(_caption, _frame.shape)

In [None]:
# Give the header-less feature table readable column names:
# txId, Time_step, 93 columns tx_feat_2..tx_feat_94 and 72 columns
# agg_feat_1..agg_feat_72.
tx_features = [f"tx_feat_{idx}" for idx in range(2, 95)]
agg_features = [f"agg_feat_{idx}" for idx in range(1, 73)]
df_features.columns = ["txId", "Time_step", *tx_features, *agg_features]

# Attach the class label of every transaction (left join keeps all feature rows)
# and encode the "unknown" label as the string '0'.
df_features = df_features.merge(df_classes, on="txId", how="left")
df_features["class"] = df_features["class"].replace("unknown", "0")

In [None]:
# Notebook display: show the merged feature/label table.
df_features

In [None]:
# Horizontal bar chart of the number of transactions per class label.
group_class = df_classes.groupby('class').count()

# Derive the bar labels and colours from the group keys instead of relying on
# the implicit sort order of the groupby result, so a label can never be
# attached to the wrong bar. Unrecognised keys fall back to the raw key / gray.
_class_names = {'1': 'Illicit', '2': 'Licit', 'unknown': 'Unknown'}
_class_colors = {'1': 'r', '2': 'g', 'unknown': 'orange'}
_class_keys = [str(key) for key in group_class.index]

plt.barh(
    [_class_names.get(key, key) for key in _class_keys],
    group_class['txId'].values,
    color=[_class_colors.get(key, 'gray') for key in _class_keys],
)

In [None]:
# Pie chart of the class shares in the merged feature table
# (slice labels are the raw class values).
risk_counts = df_features['class'].value_counts()
plt.pie(
    x=risk_counts.to_numpy(),
    labels=risk_counts.index,
    autopct='%1.1f%%',
)
plt.show()

In [None]:
# Line plot of how many transactions fall into each time step.
group_feature = df_features.groupby('Time_step').count()
tx_per_step = group_feature.loc[:, 'txId']
tx_per_step.plot()
plt.title('Number of transactions by Time step')

In [None]:
# Merge the class table with the feature table using pandas' defaults
# (inner join on all columns the two frames share).
# NOTE(review): by this point df_features already carries a 'class' column in
# which "unknown" was rewritten to '0', so if both 'txId' and 'class' are join
# keys the unknown-class rows would presumably be dropped here - confirm this
# is intended. The result is not used by the cells visible in this notebook section.
df_class_feature = pd.merge(df_classes, df_features )

In [None]:
# Count transactions per (Time_step, class) pair and flatten the result into
# a table with the columns Time_step, class and count.
group_class_feature = (
    df_features.groupby(['Time_step', 'class'])['txId']
    .count()
    .reset_index(name='count')
)

In [None]:
# Notebook display: preview the first rows of the per-step class counts.
group_class_feature.head()

In [None]:
# One line per class showing its transaction count over the time steps.
class_palette = ['orange', 'r', 'g']
sns.lineplot(
    data=group_class_feature,
    x='Time_step',
    y='count',
    hue='class',
    palette=class_palette,
)

In [None]:
# Stacked bar chart of transaction counts per time step:
# unknown ('0', orange) at the bottom, licit ('2', green) in the middle,
# illicit ('1', red) on top.
class1 = group_class_feature[group_class_feature['class'] == '1']
class2 = group_class_feature[group_class_feature['class'] == '2']
class3 = group_class_feature[group_class_feature['class'] == '0']

# Align the three series on Time_step before stacking. Stacking the filtered
# frames directly pairs values by position, which breaks (or silently
# misaligns the bars) as soon as one class is absent from some time step.
# Missing (Time_step, class) combinations are treated as a count of 0.
_stacked = (
    group_class_feature
    .pivot(index='Time_step', columns='class', values='count')
    .reindex(columns=['0', '2', '1'])
    .fillna(0)
)
_steps = _stacked.index
_unknown = _stacked['0'].to_numpy()
_licit = _stacked['2'].to_numpy()
_illicit = _stacked['1'].to_numpy()

p1 = plt.bar(_steps, _unknown, color='orange')

p2 = plt.bar(_steps, _licit, color='g', bottom=_unknown)

p3 = plt.bar(_steps, _illicit, color='r', bottom=_unknown + _licit)

plt.xlabel('Time_step')

In [None]:
# Interactive plot of the illicit (class '1') transactions of time step 32
# together with the transactions they point to.
bad_ids = df_features[(df_features['Time_step'] == 32) & (df_features['class'] == '1')]['txId']
short_edges = df_edges[df_edges['txId1'].isin(bad_ids)]
graph = nx.from_pandas_edgelist(short_edges, source='txId1', target='txId2', create_using=nx.DiGraph())
pos = nx.spring_layout(graph)

# Every edge contributes (start, end, None); the None entry breaks the line so
# all edges can be drawn by a single Scatter trace.
edge_x = []
edge_y = []
for src, dst in graph.edges():
    x0, y0 = pos[src]
    x1, y1 = pos[dst]
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

edge_trace = go.Scatter(
    x=edge_x, y=edge_y,
    line=dict(width=0.5, color='blue'),
    hoverinfo='none',
    mode='lines')

# Node coordinates and hover text (the transaction id), all in graph.nodes() order.
node_text = list(graph.nodes())
node_x = [pos[node][0] for node in node_text]
node_y = [pos[node][1] for node in node_text]

# Look the class up per node so that colour i belongs to node i. Filtering
# df_features with isin() returns rows in table order, which does not match
# the node order and would attach colours to the wrong markers.
class_by_tx = pd.to_numeric(df_features.drop_duplicates('txId').set_index('txId')['class'])
node_colors = class_by_tx.reindex(node_text).to_numpy()

node_trace = go.Scatter(
    x=node_x, y=node_y,
    mode='markers',
    hoverinfo='text',
    marker=dict(
        color=[],
        size=10,
        colorbar=dict(
            thickness=15,
            # title.side replaces the deprecated `titleside` attribute.
            title=dict(text='Transaction Type', side='right'),
            xanchor='left',
            tickmode='array',
            tickvals=[0, 1, 2],
            ticktext=['Unknown', 'Illicit', 'Licit']
        ),
        line=dict(width=2)))
node_trace.text = node_text
node_trace.marker.color = node_colors

fig = go.Figure(data=[edge_trace, node_trace],
                layout=go.Layout(
                    # title.font.size replaces the deprecated `titlefont_size`.
                    title=dict(text="Illicit Transactions", font=dict(size=16)),
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=20, l=5, r=5, t=40),
                    annotations=[dict(
                        showarrow=True,
                        xref="paper", yref="paper",
                        x=0.005, y=-0.002)],
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                )
fig.show()

In [None]:
# Interactive plot of the licit (class '2') transactions of time step 32
# together with the transactions they point to.
good_ids = df_features[(df_features['Time_step'] == 32) & (df_features['class'] == '2')]['txId']
short_edges = df_edges[df_edges['txId1'].isin(good_ids)]
graph = nx.from_pandas_edgelist(short_edges, source='txId1', target='txId2', create_using=nx.DiGraph())
pos = nx.spring_layout(graph)

# Every edge contributes (start, end, None); the None entry breaks the line so
# all edges can be drawn by a single Scatter trace.
edge_x = []
edge_y = []
for src, dst in graph.edges():
    x0, y0 = pos[src]
    x1, y1 = pos[dst]
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

edge_trace = go.Scatter(
    x=edge_x, y=edge_y,
    line=dict(width=0.5, color='blue'),
    hoverinfo='none',
    mode='lines')

# Node coordinates and hover text (the transaction id), all in graph.nodes() order.
node_text = list(graph.nodes())
node_x = [pos[node][0] for node in node_text]
node_y = [pos[node][1] for node in node_text]

# Look the class up per node so that colour i belongs to node i. Filtering
# df_features with isin() returns rows in table order, which does not match
# the node order and would attach colours to the wrong markers.
class_by_tx = pd.to_numeric(df_features.drop_duplicates('txId').set_index('txId')['class'])
node_colors = class_by_tx.reindex(node_text).to_numpy()

node_trace = go.Scatter(
    x=node_x, y=node_y,
    mode='markers',
    hoverinfo='text',
    marker=dict(
        color=[],
        size=10,
        colorbar=dict(
            thickness=15,
            # title.side replaces the deprecated `titleside` attribute.
            title=dict(text='Transaction Type', side='right'),
            xanchor='left',
            tickmode='array',
            tickvals=[0, 1, 2],
            ticktext=['Unknown', 'Illicit', 'Licit']
        ),
        line=dict(width=2)))
node_trace.text = node_text
node_trace.marker.color = node_colors

fig = go.Figure(data=[edge_trace, node_trace],
                layout=go.Layout(
                    # title.font.size replaces the deprecated `titlefont_size`.
                    title=dict(text="Licit Transactions", font=dict(size=16)),
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=20, l=5, r=5, t=40),
                    annotations=[dict(
                        showarrow=True,
                        xref="paper", yref="paper",
                        x=0.005, y=-0.002)],
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                )
fig.show()

In [None]:
def Bernstein(de, i):
    """Return the Bernstein basis polynomial C(de, i) * x**i * (1 - x)**(de - i).

    Args:
        de: Degree of the basis.
        i: Index of the basis polynomial.

    Returns:
        A ``numpy.polynomial.Polynomial``; its ``coef`` attribute lists the
        coefficients from the constant term upwards.
    """
    Poly = polynomial.polynomial.Polynomial
    binomial = math.comb(de, i)
    # C(de, i) * x**i: i zero coefficients followed by the binomial factor.
    rising_part = Poly([0] * i + [binomial])
    # (1 - x) ** (de - i)
    falling_part = Poly([1, -1]) ** (de - i)
    return rising_part * falling_part

In [None]:
class Conv(MessagePassing):
    """Graph filter built from Bernstein basis polynomials of the scaled Laplacian.

    The input is propagated K times through the symmetric-normalised Laplacian
    divided by ``l_max``. The K + 1 resulting signals are combined with the
    coefficients of each Bernstein basis polynomial of degree K, and the K + 1
    bases are summed using sigmoid-gated learnable weights.

    Args:
        hid_channels: Feature width; stored as ``in_features`` and ``out_features``.
        K: Degree of the Bernstein basis (number of propagation steps).
    """

    def __init__(self, hid_channels, K):
        super().__init__()
        self.K = K
        self.in_features = hid_channels
        self.out_features = hid_channels
        # One scalar weight per basis polynomial; initialised in reset_parameters().
        self.weight = Parameter(torch.empty(K + 1, 1))
        # NOTE: the bias is registered (and zero-initialised) but is not used
        # by forward(); it is kept so the parameter set stays unchanged.
        self.bias = Parameter(torch.empty(hid_channels))
        # The basis coefficients depend only on K, so they are computed once
        # here instead of on every forward pass. Stored as plain Python floats
        # (not a Parameter or buffer), so the state_dict is unaffected.
        self._bernstein_coeffs = [Bernstein(K, i).coef.tolist() for i in range(K + 1)]
        self.reset_parameters()

    def reset_parameters(self):
        """Zero the bias and the basis weights (sigmoid(0) = 0.5 per basis)."""
        self.bias.data.fill_(0)
        torch.nn.init.zeros_(self.weight)

    def message(self, x_j, norm):
        """Scale each neighbour feature row by its edge weight ``norm``."""
        return norm.view(-1, 1) * x_j

    def __norm__(self, edge_index, num_nodes, e_weight, l_max, dtype=None):
        """Return edge index and weights of the 'sym' Laplacian divided by ``l_max``."""
        edge_index, e_weight = remove_self_loops(edge_index, e_weight)
        edge_index, e_weight = get_laplacian(edge_index, e_weight, 'sym', dtype, num_nodes)

        e_weight = e_weight / l_max
        # Replace infinite entries by 0 so they do not propagate.
        e_weight.masked_fill_(e_weight == float('inf'), 0)
        return edge_index, e_weight

    def forward(self, x, edge_index, e_weight=None, l_max=None):
        """Apply the filter to node features ``x``.

        Args:
            x: Node feature tensor (presumably ``(num_nodes, hid_channels)`` - confirm).
            edge_index: Graph connectivity passed to ``get_laplacian``.
            e_weight: Optional edge weights.
            l_max: Scaling applied to the Laplacian; defaults to 2.0.

        Returns:
            A tensor with the same shape as ``x``.
        """
        if l_max is None:
            l_max = torch.tensor(2.0, dtype=x.dtype, device=x.device)
        if not isinstance(l_max, torch.Tensor):
            l_max = torch.tensor(l_max, dtype=x.dtype, device=x.device)

        edge_index, norm = self.__norm__(edge_index, x.size(self.node_dim), e_weight, l_max, dtype=x.dtype)

        # bx[i] holds the input after i propagation steps.
        bx = [x]
        b_next = x
        for _ in range(self.K):
            b_next = self.propagate(edge_index, x=b_next, norm=norm, size=None)
            bx.append(b_next)

        weight = torch.sigmoid(self.weight)

        result = torch.zeros_like(x)
        for k, coeff in enumerate(self._bernstein_coeffs):
            # k-th Bernstein basis applied to x: sum_i coeff[i] * bx[i].
            basis = bx[0] * coeff[0]
            for i in range(1, self.K + 1):
                basis += bx[i] * coeff[i]
            result += basis * weight[k]
        return result

In [None]:
class AMNet(nn.Module):
    """Node classifier that mixes several ``Conv`` filters with per-node attention.

    Args:
        in_features: Width of the raw node features.
        out_features: Hidden width used by the filters and attention layers.
        num_classes: Number of output logits per node.
        K: Polynomial degree handed to every ``Conv`` filter.
        num_filters: Number of parallel ``Conv`` filters.
    """

    def __init__(self, in_features, out_features, num_classes, K, num_filters):
        super().__init__()
        self.K = K
        self.num_filters = num_filters

        # Sub-modules are created in a fixed order so that parameter
        # registration and random initialisation stay reproducible.
        self.linear_transform_in = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features),
        )
        self.filters = nn.ModuleList([Conv(out_features, K) for _ in range(num_filters)])
        self.W_f = nn.Sequential(nn.Linear(out_features, out_features), nn.Tanh())
        self.W_x = nn.Sequential(nn.Linear(out_features, out_features), nn.Tanh())
        self.linear_cls_out = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(out_features, num_classes),
        )

        # Plain Python lists of parameters (presumably consumed as separate
        # optimizer parameter groups by the training code - confirm).
        self.attn = list(self.W_x.parameters()) + list(self.W_f.parameters())
        self.lin = list(self.linear_transform_in.parameters()) + list(self.linear_cls_out.parameters())
        self.reset_parameters()

    def forward(self, x, edge_index, label=None):
        """Classify nodes; in training mode also return the attention margin loss.

        Args:
            x: Raw node features.
            edge_index: Graph connectivity forwarded to every filter.
            label: In training mode, a pair ``(abnormal_train, normal_train)``
                used to index the attention scores (presumably index tensors
                or masks - confirm). Not read in eval mode.

        Returns:
            ``(y_pred, margin_loss)`` while ``self.training`` is set, otherwise
            only ``y_pred``.
        """
        x = self.linear_transform_in(x)

        # Stack the filter outputs along a new dimension 1.
        filtered = torch.stack([conv(x, edge_index) for conv in self.filters], dim=1)
        query = self.W_x(x).unsqueeze(-1)

        # Attention score of every filter, normalised across the filter dimension.
        score = F.softmax(torch.bmm(self.W_f(filtered), query), dim=1)

        # Attention-weighted sum of the filter outputs.
        out = filtered[:, 0, :] * score[:, 0]
        for idx in range(1, self.num_filters):
            out = out + filtered[:, idx, :] * score[:, idx]

        y_pred = self.linear_cls_out(out)
        if not self.training:
            return y_pred

        abnormal_train, normal_train = label
        normal_score = score[normal_train]
        abnormal_score = score[abnormal_train]
        # Penalise normal nodes whose filter-1 score exceeds their filter-0
        # score, and abnormal nodes for the opposite ordering.
        normal_bias = torch.mean(torch.clamp(normal_score[:, 1] - normal_score[:, 0], min=-0.))
        abnormal_bias = torch.mean(torch.clamp(abnormal_score[:, 0] - abnormal_score[:, 1], min=-0.))
        margin_loss = abnormal_bias + normal_bias
        return y_pred, margin_loss

    def reset_parameters(self):
        """No-op hook; sub-modules keep their default initialisation."""
        pass