In [2]:
import pandas as pd
import numpy as np
import gc
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch.optim import Adam
import networkx as nx

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### GraphSAGE load

In [2]:
#load data
classes = pd.read_csv('Elliptic++Dataset/txs_classes.csv')
features = pd.read_csv('Elliptic++Dataset/txs_features.csv')
edges = pd.read_csv('Elliptic++Dataset/txs_edgelist.csv')

#Convert classes to numerical format
classes["class"] -= 1
features["class"] = classes["class"]
# class_map = {"unknown": 2, "licit": 0, "illicit": 1}
# classes["class"] = classes["class"].map(class_map)

# #extract IDs, features and labels
features = features.dropna()
x = torch.as_tensor(features.iloc[:, 1:].drop("class", axis=1).values.astype(np.float32))
y = torch.as_tensor(features["class"].values)

#edges to tensor
edge_index = torch.as_tensor(edges.values.T)

#create data object
data = Data(x=x, edge_index=edge_index, y=y).to(device)
print(data)

Data(x=[202804, 183], edge_index=[2, 234355], y=[202804])


In [6]:
class GraphSAGE(torch.nn.Module):
  def __init__(self, in_channels, hidden_channels, out_channels):
    super(GraphSAGE, self).__init__()
    self.conv1 = SAGEConv(in_channels, hidden_channels)
    self.conv2 = SAGEConv(hidden_channels, out_channels)

  def forward(self, x, edge_index):
    x = self.conv1(x, edge_index)
    x = F.relu(x)
    x = self.conv2(x, edge_index)
    return F.log_softmax(x, dim=1)

### Setup GraphSAGE for Elliptic

In [4]:
from data_processing import *

data, OUT_DIM = load_create_ellipticpp(timestep=(1,5))

In [12]:
from tqdm import tqdm

#dataLoader for neighborhood sampling
# train_mask = torch.zeros(data.num_nodes,dtype=torch.bool)
# train_mask[:int(0.7*data.num_nodes)]= True
# data.train_mask = train_mask


model = GraphSAGE(in_channels=data.num_features, hidden_channels = 32, out_channels = OUT_DIM).to(device)
train_loader = NeighborLoader(data=data, num_neighbors = [10,5], batch_size = 1000)
# #define optimizer and loss function
optimizer = Adam(model.parameters(), lr= 0.005, weight_decay = 5e-4)
loss_fn = torch.nn.CrossEntropyLoss()

def train(iterator):
  model.train()
  total_loss = 0
    
  iterator = tqdm(train_loader, )
  for idx, batch in enumerate(train_loader):
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)
    # print(f"Output Shape: {out.shape}, Label Shape: {batch.y.shape}")

    # if torch.isnan(out).any():
    #     print("warning:NaN detected in output")
    #     return total_loss
        
    mask = batch.train_mask
    loss = loss_fn(out[mask], batch.y[mask])
      
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
    iterator.set_description(f"Batch {idx+1}{data.size()[0]//1000}")
  return total_loss 

iterator = tqdm(range(20), desc="")
for idx, epoch in enumerate(iterator):
  loss = train(iterator)
  print(f"Epoch {idx+1}/{20} - Loss: {loss:.4f}")

Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
  5%|▌         | 1/20 [00:00<00:15,  1.21it/s]

Epoch 1/20 - Loss: 101540.5549


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 10%|█         | 2/20 [00:01<00:15,  1.19it/s]

Epoch 2/20 - Loss: 48146.0492


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 15%|█▌        | 3/20 [00:02<00:14,  1.21it/s]

Epoch 3/20 - Loss: 17732.3772


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 20%|██        | 4/20 [00:03<00:13,  1.21it/s]

Epoch 4/20 - Loss: 18988.9323


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 25%|██▌       | 5/20 [00:04<00:12,  1.21it/s]

Epoch 5/20 - Loss: 15092.2048


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 30%|███       | 6/20 [00:04<00:11,  1.20it/s]

Epoch 6/20 - Loss: 2513.6064


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 35%|███▌      | 7/20 [00:05<00:10,  1.20it/s]

Epoch 7/20 - Loss: 1773.0598


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 40%|████      | 8/20 [00:06<00:09,  1.21it/s]

Epoch 8/20 - Loss: 2488.5583


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 45%|████▌     | 9/20 [00:07<00:09,  1.16it/s]

Epoch 9/20 - Loss: 630.1238


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 50%|█████     | 10/20 [00:08<00:08,  1.14it/s]

Epoch 10/20 - Loss: 289.7612


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 55%|█████▌    | 11/20 [00:09<00:08,  1.10it/s]

Epoch 11/20 - Loss: 256.5485


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 60%|██████    | 12/20 [00:10<00:07,  1.10it/s]

Epoch 12/20 - Loss: 25.2098


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 65%|██████▌   | 13/20 [00:11<00:06,  1.11it/s]

Epoch 13/20 - Loss: 15.7714


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 70%|███████   | 14/20 [00:12<00:05,  1.13it/s]

Epoch 14/20 - Loss: 16.9325


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 75%|███████▌  | 15/20 [00:12<00:04,  1.13it/s]

Epoch 15/20 - Loss: 17.7039


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 80%|████████  | 16/20 [00:13<00:03,  1.15it/s]

Epoch 16/20 - Loss: 15.6924


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 85%|████████▌ | 17/20 [00:14<00:02,  1.13it/s]

Epoch 17/20 - Loss: 15.6735


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 90%|█████████ | 18/20 [00:15<00:01,  1.13it/s]

Epoch 18/20 - Loss: 13.2690


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
 95%|█████████▌| 19/20 [00:16<00:00,  1.15it/s]

Epoch 19/20 - Loss: 11.8873


Batch 190189:   0%|          | 0/190 [00:00<?, ?it/s]
100%|██████████| 20/20 [00:17<00:00,  1.15it/s]


Epoch 20/20 - Loss: 11.8602
