# Introduction 

In this hands-on Python tutorial, we explore the Elliptic Data Set, which captures transactions within the Bitcoin blockchain network. Our goal is to prepare this data for training machine learning models with the PyTorch and PyTorch Lightning frameworks.

# The Dataset

Let's dissect the dataset's description:

* Temporal Structure: The dataset is divided into 49 distinct time steps, each spaced roughly two weeks apart. Within every time step, we find a connected group of transactions occurring within a three-hour window.
* Transaction Features: Each transaction is characterized by 94 'local' features. These include its timestamp, input/output counts, fees, volume, and interesting aggregations (e.g., average BTC involved in inputs/outputs).
* Neighborhood Features: An additional 72 'aggregated' features illuminate each transaction's context. We get statistics like the maximum, minimum, standard deviation, and correlation coefficients derived from transactions one hop away.

# Exploratory Data Analysis (EDA)

We'll begin our journey with exploratory data analysis (EDA). Key things to explore:

* Distributions: Examine the distributions of transaction features (fees, volumes, etc.) to spot patterns and potential outliers.
* Correlations: Investigate relationships between transaction features. Which features correlate, and can this insight inform our model design?
* Temporal Trends: Analyze how features change across the 49 time steps. Are there seasonal effects or evolving network behaviors?

# PyTorch Datasets

Our EDA findings will guide how we structure our PyTorch Datasets. Here's where things get exciting:

* Custom Dataset Class: We'll create a custom PyTorch Dataset class to load and preprocess the raw data dynamically during model training.
* Data Transformations: We might apply scaling, normalization, or other essential transformations to make the data more suitable for machine learning.
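
As a rough illustration of the two points above, the following sketch combines a custom Dataset with a standardization step. The class name `StandardizedTensorDataset` and the toy tensors are hypothetical and only stand in for real transaction features; this is not the dataset class used later in the tutorial.

```python
import torch
from torch.utils.data import Dataset


class StandardizedTensorDataset(Dataset):
    """Hypothetical example: serves (features, label) pairs with z-scored features."""

    def __init__(self, features: torch.Tensor, labels: torch.Tensor, eps: float = 1e-8):
        # Column-wise statistics, computed once on construction.
        self.mean = features.mean(dim=0)
        self.std = features.std(dim=0)
        self.features = features
        self.labels = labels
        self.eps = eps

    def __len__(self) -> int:
        return self.features.shape[0]

    def __getitem__(self, idx: int):
        # Standardize lazily, one row at a time, when the sample is requested.
        x = (self.features[idx] - self.mean) / (self.std + self.eps)
        return x, self.labels[idx]


# Toy data standing in for real transaction features (not from the dataset).
toy_features = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
toy_labels = torch.tensor([0, 1, 0])
toy_dataset = StandardizedTensorDataset(toy_features, toy_labels)
first_x, first_y = toy_dataset[0]
print(len(toy_dataset), first_x, first_y)
```

In a real pipeline the mean and standard deviation would normally be computed on the training split only, so that no information leaks from the validation or test data.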

# PyTorch Lightning Integration

Finally, we'll leverage PyTorch Lightning to streamline our training process.

* DataModule: A PyTorch Lightning DataModule will encapsulate our Datasets, manage data loading, and handle batching for efficient model training.

# What You'll Build

By the end of this tutorial, you'll have a solid foundation for training machine learning models on this dataset. This foundation sets the stage for exciting applications such as:

* Fraud detection
* Transaction pattern analysis
* Blockchain network behavior prediction

Let's get started!


# Dataset background
The dataset description originates from [Kaggle Elliptic Data Set](https://www.kaggle.com/datasets/ellipticco/elliptic-data-set) and is restated here for convenience. 

## Dataset description
This anonymized data set is a transaction graph collected from the Bitcoin blockchain. A node in the graph represents a transaction; an edge can be viewed as a flow of Bitcoins between one transaction and another. Each node has 166 features and has been labeled as being created by a "licit", "illicit" or "unknown" entity.

### Nodes and edges

The graph is made of 203,769 nodes and 234,355 edges. Two percent (4,545) of the nodes are labelled class1 (illicit). Twenty-one percent (42,019) are labelled class2 (licit). The remaining transactions are not labelled with regard to licit versus illicit.
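
The percentages above follow directly from the stated node counts. A small worked calculation, using only the numbers given in the description, also yields the share of unlabelled nodes:

```python
# Node counts as stated in the dataset description.
total_nodes = 203_769
illicit = 4_545
licit = 42_019
unknown = total_nodes - illicit - licit

for name, count in [("illicit", illicit), ("licit", licit), ("unknown", unknown)]:
    print(f"{name:>8}: {count:>7,} nodes ({count / total_nodes:.1%})")
```

This gives 157,205 unlabelled nodes, or about 77.1% of the graph, next to about 2.2% illicit and 20.6% licit nodes.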

### Features

There are 166 features associated with each node. Due to intellectual property issues, we cannot provide an exact description of all the features in the dataset. There is a time step associated with each node, representing a measure of the time when a transaction was broadcast to the Bitcoin network. The time steps, running from 1 to 49, are evenly spaced with an interval of about two weeks. Each time step contains a single connected component of transactions that appeared on the blockchain within less than three hours of each other; there are no edges connecting the different time steps.

The first 94 features represent local information about the transaction – including the time step described above, number of inputs/outputs, transaction fee, output volume and aggregated figures such as average BTC received (spent) by the inputs/outputs and average number of incoming (outgoing) transactions associated with the inputs/outputs. The remaining 72 features are aggregated features, obtained using transaction information one-hop backward/forward from the center node - giving the maximum, minimum, standard deviation and correlation coefficients of the neighbour transactions for the same information data (number of inputs/outputs, transaction fee, etc.).
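
The split into 94 local and 72 aggregated features can be sketched as two slices over a row of feature values. The column names below are hypothetical, since the real features are anonymized, and the sketch assumes a row holds exactly the 166 features in the stated order; if a loader adds an extra column (for example the transaction id), the offsets shift accordingly.

```python
# Feature counts as stated in the dataset description.
NUM_LOCAL = 94
NUM_AGGREGATED = 72

# Hypothetical column names; the real features are anonymized and unnamed.
local_columns = [f"local_{i}" for i in range(1, NUM_LOCAL + 1)]
aggregated_columns = [f"agg_{i}" for i in range(1, NUM_AGGREGATED + 1)]
feature_columns = local_columns + aggregated_columns

# Slices that select each block from a row of 166 feature values.
local_slice = slice(0, NUM_LOCAL)
aggregated_slice = slice(NUM_LOCAL, NUM_LOCAL + NUM_AGGREGATED)

row = list(range(len(feature_columns)))  # stand-in for one node's features
local_part = row[local_slice]
aggregated_part = row[aggregated_slice]
print(len(feature_columns), len(local_part), len(aggregated_part))
```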

### Dataset files

The dataset consists of three files:
* **elliptic_txs_classes.csv:** Each node is labelled as a "licit" (0), "illicit" (1), or "unknown" (2) entity in the class column; the txId column is a unique identifier of the node.
* **elliptic_txs_edgelist.csv:** A list of connected node pairs. The file has two columns, txId1 and txId2.
* **elliptic_txs_features.csv:** A file with 167 columns: the first column is the transaction id and the remaining 166 columns are the node features.

For detailed statistics, please visit the Kaggle Data Explorer of the [Elliptic Data Set](https://www.kaggle.com/datasets/ellipticco/elliptic-data-set). 

# Loading the data
We use the fintorch.datasets package to load the [Elliptic Data Set](https://www.kaggle.com/datasets/ellipticco/elliptic-data-set). The following code downloads and loads the dataset:

In [None]:
# Import the Elliptic dataset module from FinTorch
from fintorch.datasets import elliptic

# Load the elliptic dataset
elliptic_dataset = elliptic.EllipticDataset('~/.fintorch_data', force_reload=True)

Let's discuss the code line by line:

1. **Importing:** We import the elliptic module from the fintorch.datasets package. This module provides convenient access to the Elliptic Bitcoin Dataset.

2. **Loading the Dataset:** We create an instance of the elliptic.EllipticDataset class and store it in the elliptic_dataset variable. This downloads the dataset from Kaggle and places it in the ~/.fintorch_data directory. The fintorch framework uses the [Kaggle API](https://github.com/Kaggle/kaggle-api) to download datasets. Make sure you've followed the instructions in the fintorch documentation to set up your Kaggle API credentials for seamless data access.
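
Before running the download, it can help to check that Kaggle credentials are in place. The sketch below relies on the Kaggle API's documented conventions: a kaggle.json token file in ~/.kaggle (or in the directory named by KAGGLE_CONFIG_DIR), or the environment variables KAGGLE_USERNAME and KAGGLE_KEY. The helper `has_kaggle_credentials` is a hypothetical name and is not part of fintorch or the Kaggle API; how fintorch itself verifies credentials is not shown here.

```python
import os
from pathlib import Path


def has_kaggle_credentials(config_dir=None, env=None):
    """Return True if Kaggle API credentials appear to be configured."""
    env = os.environ if env is None else env
    # Option 1: credentials supplied through environment variables.
    if env.get("KAGGLE_USERNAME") and env.get("KAGGLE_KEY"):
        return True
    # Option 2: a kaggle.json token file in the config directory.
    if config_dir is None:
        config_dir = env.get("KAGGLE_CONFIG_DIR") or Path.home() / ".kaggle"
    return (Path(config_dir).expanduser() / "kaggle.json").is_file()


# Deterministic demonstration with a made-up environment mapping.
demo_env = {"KAGGLE_USERNAME": "alice", "KAGGLE_KEY": "not-a-real-key"}
print("Credentials found:", has_kaggle_credentials(env=demo_env))
```

Calling `has_kaggle_credentials()` without arguments checks the real environment and home directory instead of the demo mapping.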



With the dataset ready, let's examine its structure. 

# Exploration


We start by inspecting the dataset object; further below we use Polars for a basic look at the label distribution:

In [None]:
type(elliptic_dataset)

The dataset contains a single graph, so we access element 0 of the data list:

In [None]:
elliptic_dataset[0]

We have the following elements in the dataset:
* **x:** 203,769 nodes with 167 feature values
* **edge_index:** 234,355 pairs of nodes representing the edges between nodes. Note that we transformed the node names into indices. The mapping is stored in *elliptic_dataset.map_id*
* **train_mask:** a mask to indicate which nodes are used to train the model
* **val_mask:** a mask to indicate which nodes are used as validation set
* **test_mask:** a mask to indicate which nodes are used as a test set
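
The idea behind turning node names into indices for edge_index can be sketched in plain Python. The transaction ids below are made up, and whether fintorch builds its map_id in exactly this way is an assumption; the sketch only shows the general technique of mapping ids to contiguous indices and storing edges as two parallel rows.

```python
# Toy edge list using made-up transaction ids (not taken from the dataset).
raw_edges = [(1001, 1002), (1002, 1003), (1001, 1003)]

# Assign each transaction id a contiguous index in order of first appearance.
id_to_index = {}
for source, target in raw_edges:
    for tx_id in (source, target):
        id_to_index.setdefault(tx_id, len(id_to_index))

# PyTorch Geometric stores edges as a 2 x num_edges structure:
# row 0 holds source indices, row 1 holds target indices.
edge_index = [
    [id_to_index[source] for source, _ in raw_edges],
    [id_to_index[target] for _, target in raw_edges],
]
print(id_to_index)
print(edge_index)
```

With the mapping kept around, a predicted label for node index i can later be traced back to its original transaction id.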

In addition, we can query some properties of the dataset:

In [None]:
print(f'Number of node features: {elliptic_dataset.num_features}')
print(f'Number of edge features: {elliptic_dataset.num_edge_features}')
print(f'Number of classes: {elliptic_dataset.num_classes}')
print(f'Feature input matrix shape: {elliptic_dataset.x.shape}')
print(f'Edge index feature matrix shape: {elliptic_dataset.edge_index.shape}')
print(f'Label feature matrix shape: {elliptic_dataset.y.shape}')

In [None]:
import polars as pol

# Convert elliptic_dataset.y to a numpy array and then to a polars Series
y_series = pol.Series(elliptic_dataset.y.numpy())

# Calculate the fraction of each value in the distribution
fraction = y_series.value_counts() 
# Normalize the count column in fraction
fraction = fraction.with_columns(count_normalized = fraction['count'] / y_series.shape[0])


# Print the fraction of the value distribution
print(fraction)


The output shows the distribution of the class labels. Roughly 77% of the nodes carry the label "unknown", which matches the node counts in the dataset description.

# Simple model
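We define two node classifiers to train: a graph neural network (GNN) that uses the edges of the transaction graph, and a multilayer perceptron (MLP) baseline that only sees the node features.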

In [None]:
import torch.nn as nn
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data

import torch.optim as optim

import pytorch_lightning as pl

gnn_layer_by_name = {
    "GCN": geom_nn.GCNConv,
    "GAT": geom_nn.GATConv,
    "GraphConv": geom_nn.GraphConv
}

class GNNModel(nn.Module):

    def __init__(self, c_in, c_hidden, c_out, num_layers=2, layer_name="GCN", dp_rate=0.1, **kwargs):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of the output features. Usually number of classes in classification
            num_layers - Number of "hidden" graph layers
            layer_name - String of the graph layer to use
            dp_rate - Dropout rate to apply throughout the network
            kwargs - Additional arguments for the graph layer (e.g. number of heads for GAT)
        """
        super().__init__()
        gnn_layer = gnn_layer_by_name[layer_name]

        layers = []
        in_channels, out_channels = c_in, c_hidden
        for l_idx in range(num_layers-1):
            layers += [
                gnn_layer(in_channels=in_channels,
                          out_channels=out_channels,
                          **kwargs),
                nn.ReLU(inplace=True),
                nn.Dropout(dp_rate)
            ]
            in_channels = c_hidden
        layers += [gnn_layer(in_channels=in_channels,
                             out_channels=c_out,
                             **kwargs)]
        self.layers = nn.ModuleList(layers)

    def forward(self, x, edge_index):
        """
        Inputs:
            x - Input features per node
            edge_index - List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
        """
        for l in self.layers:
            # For graph layers, we need to add the "edge_index" tensor as additional input
            # All PyTorch Geometric graph layers inherit from the class "MessagePassing", hence
            # we can simply check the class type.
            if isinstance(l, geom_nn.MessagePassing):
                x = l(x, edge_index)
            else:
                x = l(x)

        return x

In [None]:
class MLPModel(nn.Module):

    def __init__(self, c_in, c_hidden, c_out, num_layers=2, dp_rate=0.1):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of the output features. Usually number of classes in classification
            num_layers - Number of hidden layers
            dp_rate - Dropout rate to apply throughout the network
        """
        super().__init__()
        layers = []
        in_channels, out_channels = c_in, c_hidden
        for l_idx in range(num_layers-1):
            layers += [
                nn.Linear(in_channels, out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout(dp_rate)
            ]
            in_channels = c_hidden
        layers += [nn.Linear(in_channels, c_out)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x, *args, **kwargs):
        """
        Inputs:
            x - Input features per node
        """
        return self.layers(x)

In [None]:
class NodeLevelGNN(pl.LightningModule):

    def __init__(self, model_name, **model_kwargs):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()

        if model_name == "MLP":
            self.model = MLPModel(**model_kwargs)
        else:
            self.model = GNNModel(**model_kwargs)
        self.loss_module = nn.CrossEntropyLoss()

    def forward(self, data, mode="train"):
        x, edge_index = data.x, data.edge_index
        x = self.model(x, edge_index)

        # Only calculate the loss on the nodes corresponding to the mask
        if mode == "train":
            mask = data.train_mask
        elif mode == "val":
            mask = data.val_mask
        elif mode == "test":
            mask = data.test_mask
        else:
            raise ValueError(f"Unknown forward mode: {mode}")


        loss = self.loss_module(x[mask], data.y[mask].long())
        acc = (x[mask].argmax(dim=-1) == data.y[mask]).sum().float() / mask.sum()
        return loss, acc

    def configure_optimizers(self):
        # We use SGD here, but Adam works as well
        optimizer = optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=2e-3)
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, acc = self.forward(batch, mode="train")
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        return loss

    def validation_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode="val")
        self.log('val_acc', acc)

    def test_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode="test")
        self.log('test_acc', acc)
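
The masked loss and accuracy computed in `NodeLevelGNN.forward` can be illustrated on a toy example. The logits, labels and mask below are made up; the point is only to show how boolean indexing restricts both quantities to the nodes of one split.

```python
import torch
import torch.nn as nn

# Toy logits for 4 nodes and 3 classes (values are made up).
logits = torch.tensor([
    [2.0, 0.1, 0.1],  # predicts class 0
    [0.1, 2.0, 0.1],  # predicts class 1
    [0.1, 0.1, 2.0],  # predicts class 2
    [2.0, 0.1, 0.1],  # predicts class 0
])
labels = torch.tensor([0, 1, 0, 1])
# Only the first three nodes belong to this (hypothetical) split.
mask = torch.tensor([True, True, True, False])

# Boolean indexing keeps only the masked rows before computing loss and accuracy.
loss = nn.CrossEntropyLoss()(logits[mask], labels[mask])
accuracy = (logits[mask].argmax(dim=-1) == labels[mask]).sum().float() / mask.sum()
print(f"masked loss: {loss.item():.3f}, masked accuracy: {accuracy.item():.3f}")
```

Two of the three masked nodes are predicted correctly, so the masked accuracy is 2/3; the fourth node is ignored even though its prediction is wrong.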


In [None]:
from torch_geometric.loader import DataLoader
import os

CHECKPOINT_PATH = "./logs"

def train_node_classifier(model_name, dataset, **model_kwargs):
    # pl.seed_everything(42)
    node_data_loader = DataLoader(dataset, batch_size = 1)

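    # Create a PyTorch Lightning trainer that checkpoints the model with the best validation accuracy
    root_dir = os.path.join(CHECKPOINT_PATH, "NodeLevel" + model_name)
    os.makedirs(root_dir, exist_ok=True)
    trainer = pl.Trainer(default_root_dir=root_dir,
                         callbacks=[pl.callbacks.ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")],
                         accelerator="auto",  # uses a GPU when one is available, otherwise the CPU
                         devices=1,
                         max_epochs=200,  # assumed value: max_epochs=00 would skip training entirely; tune as needed
                         enable_progress_bar=False) # False because epoch size is 1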
    
    # pl.seed_everything()
    model = NodeLevelGNN(model_name=model_name, c_in=167, c_out=3, **model_kwargs)
    trainer.fit(model, node_data_loader, node_data_loader)
    model = NodeLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Test best model on the test set
    test_result = trainer.test(model, node_data_loader, verbose=False)
    batch = next(iter(node_data_loader))
    batch = batch.to(model.device)
    _, train_acc = model.forward(batch, mode="train")
    _, val_acc = model.forward(batch, mode="val")
    result = {"train": train_acc,
              "val": val_acc,
              "test": test_result[0]['test_acc']}
    return model, result

In [None]:
# Small function for printing the test scores
def print_results(result_dict):
    if "train" in result_dict:
        print(f"Train accuracy: {(100.0*result_dict['train']):4.2f}%")
    if "val" in result_dict:
        print(f"Val accuracy:   {(100.0*result_dict['val']):4.2f}%")
    print(f"Test accuracy:  {(100.0*result_dict['test']):4.2f}%")

In [None]:
node_mlp_model, node_mlp_result = train_node_classifier(model_name="MLP",
                                                        dataset=elliptic_dataset,
                                                        c_hidden=16,
                                                        num_layers=4,
                                                        dp_rate=0.1)

print_results(node_mlp_result)

In [None]:
import torch
import polars as pol

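# Run the trained MLP on all nodes in evaluation mode (disables dropout), without tracking gradients
node_mlp_model.eval()
with torch.no_grad():
    output = node_mlp_model.model(elliptic_dataset.x, elliptic_dataset.edge_index)
# Predicted class per node: the index of the largest logit
argmax_tensor = torch.argmax(output, dim=1)

# Convert the predictions to a numpy array and then to a polars Series
y_series = pol.Series(argmax_tensor.cpu().numpy())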

# Calculate the fraction of each value in the distribution
fraction = y_series.value_counts() 
# Normalize the count column in fraction
fraction = fraction.with_columns(count_normalized = fraction['count'] / y_series.shape[0])

# Print the fraction of the value distribution
print(fraction)

In [None]:
node_gnn_model, node_gnn_result = train_node_classifier(model_name="GNN",
                                                        layer_name="GCN",
                                                        dataset=elliptic_dataset,
                                                        c_hidden=256,
                                                        num_layers=5,
                                                        dp_rate=0.1)
print_results(node_gnn_result)