In [1]:
import os
import random
import gc

In [2]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

In [3]:
import torch
from torch_geometric.data import DataLoader
from torch_geometric.nn import HeteroConv, GATConv, Linear, SAGEConv, to_hetero
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Dropout
from torch.nn.functional import relu, tanh, softmax
import torch_geometric.transforms as T

In [4]:
from IPython.core.debugger import set_trace

In [2]:
from EnhancedGATModel import EnhancedGATModel

ImportError: cannot import name 'EnhancedGATModel' from 'EnhancedGATModel' (/AiDev/GNNNew/EnhancedGATModel.py)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
def LoadGraph(file_name):
    graph = torch.load(file_name)
    graph['pending_transaction'].x = graph['pending_transaction'].x.unsqueeze(0)
    return graph

In [8]:
class EnhancedGATModel(torch.nn.Module):
    @property
    def edge_types(self):
        return [
            ('user', 'owns', 'card'),
            ('card','belongs_to','user'),
            ('user', 'has', 'user_history'),
            ('user_history', 'belongs_to', 'user'),
            ('user_history_transaction', 'part_of', 'user_history'),
            ('user_history_transaction', 'paid_with', 'card'),
            ('user_history_transaction', 'made_at', 'merchant'),
            ('card','paid_for','user_history_transaction'),
            ('merchant', 'made', 'user_history_transaction'),
            ('user_history', 'reflects_on', 'pending_transaction'),            
            ('merchant', 'made', 'user_history_transaction'),
            ('merchant', 'selling', 'pending_transaction'),
            ('user', 'purchasing', 'pending_transaction')
        ]

    def __init__(self, hidden_channels, out_channels, num_layers=2, dropout_rate=0.5):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        self.dropouts = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                edge_type: GATConv((-1, -1), hidden_channels, add_self_loops=False)
                for edge_type in self.edge_types
            }, aggr='mean')
            self.convs.append(conv)
            self.dropouts.append(Dropout(dropout_rate)) 
        
        self.lin1 = Linear(3*hidden_channels, hidden_channels * 6)
        self.lin2 = Linear(hidden_channels * 6, out_channels)
        #self.skip_lin = Linear(105, 3*hidden_channels)
        #self.bn_skip=nn.LayerNorm(105)
        self.dropout_1 = Dropout(dropout_rate)
        self.dropout_2 = Dropout(dropout_rate)

    def forward(self, x_dict, edge_index_dict):
        #x_pending_transaction = x_dict['pending_transaction']
        #x_user = x_dict['user']
        #x_transaction_history=x_dict['user_history']
        #x_skip = torch.cat([x_pending_transaction, x_user, x_transaction_history], dim=1)

        for conv, dropout in zip(self.convs, self.dropouts):
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: tanh(dropout(x)) for key, x in x_dict.items()}
        
        x_pending_transaction = x_dict['pending_transaction']
        x_user = x_dict['user']
        x_transaction_history = x_dict['user_history']
        #print((x_pending_transaction.shape, x_user.shape, x_transaction_history.shape, x_skip.shape))
        
        combined = torch.cat([x_pending_transaction, x_user, x_transaction_history], dim=1)        
        combined = self.dropout_1(combined)
        combined = tanh(self.lin1(combined))        
        
        out = self.lin2(combined)
        return out

In [9]:
fraud = torch.tensor([[1., 0.]], dtype=torch.float)
legit = torch.tensor([[0., 1.]], dtype=torch.float)
fraud=fraud.to(device)
legit=legit.to(device)

def train(model, file_list):
    model.train()
    total_loss = 0
    count=0
    total_true=0
    total_false=0    
    
    #for graph in dataloader:
    for file in file_list:
        optimizer.zero_grad()
        count+=1
        try:
            graph=LoadGraph(file)            
            #graph = T.ToUndirected()(graph)
            graph = graph.to(device)
            graph = T.NormalizeFeatures()(graph)            
            out = model(graph.x_dict, graph.edge_index_dict)
            #details.append([file, graph.y, out])
        except Exception as e:
            print(f'\nFile:{file}, error {e}')
            raise
            continue
            #set_trace()            
        else:
            target = fraud if graph.y==1 else legit            
            loss = F.binary_cross_entropy_with_logits(out, target)
            if train_show:
                print((out,target,loss))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if count%10 == 0:
            print(".", end="")

        del graph
    gc.collect()
            
    return total_loss / len(file_list)

In [10]:
def validate(model, valid_data):
    model.eval()
    correct = 0
    total = 0
    total_loss = 0.0
    avg_total_loss = 0.0
    true_positive, true_negative, false_positive, false_negative = 0,0,0,0
    
    with torch.no_grad():
        for file_path in valid_data:
            graph = LoadGraph(file_path)
            graph = graph.to(device)

            try:                
                out = model(graph.x_dict, graph.edge_index_dict)
                target = fraud if graph.y==1 else legit
                loss = F.binary_cross_entropy_with_logits(out, target)
                #print((out,target,loss))
                
            except Exception as e:
                print(f'\nFile:{file_path}, error {e}')
                continue;
                set_trace()

            else:
                total_loss += loss.item()
                total += target.size(0)                
                
                predicted = torch.argmax(out, dim=1)
                actual = torch.argmax(target, dim=1)
                #print(f'out={out}, target={target}, predicted={predicted}, actual={actual}')
                true_positive += ((predicted == 0) & (actual == 0)).sum().item()
                true_negative += ((predicted == 1) & (actual == 1)).sum().item()
                false_positive += ((predicted == 0) & (actual == 1)).sum().item()
                false_negative += ((predicted == 1) & (actual == 0)).sum().item()

                correct += (predicted == actual).sum().item()
                
            if total%10 == 0:
                print(".", end="")
            del graph
        gc.collect()

    results_dict = {
        "True Positives": true_positive,
        "True Negatives": true_negative,
        "False Positives": false_positive,
        "False Negatives": false_negative
    }

    accuracy = correct / total
    average_loss = total_loss / total
    return accuracy, average_loss, results_dict

In [11]:
graph_root_directory = '/AiDev/data/Set4'

In [12]:
graphs = pd.read_csv('answer_key_Set4.csv')

In [13]:
fraud_set = graphs[graphs['is_fraud']==1.0]
legit_set = graphs[graphs['is_fraud']==0.0]
print(f'graphs size: {len(graphs)}, frauds: {len(fraud_set)}, legit_set: {len(legit_set)}')

graphs size: 82128, frauds: 27376, legit_set: 54752


In [14]:
tiny_fraud = fraud_set.sample(frac=.05)
tiny_legit = legit_set.sample(n=len(tiny_fraud))
tiny_set = pd.concat([tiny_fraud, tiny_legit])
#tiny_set = tiny_fraud
#print(len(tiny_fraud), len(tiny_set))
print(len(tiny_fraud), len(tiny_legit), len(tiny_set))

1369 1369 2738


In [15]:
final_set = tiny_set['file_path']

In [16]:
train_set, valid_set = train_test_split(final_set, test_size=0.20, random_state=85)

In [17]:
model = EnhancedGATModel(hidden_channels=24, out_channels=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001, weight_decay=.001)
model = model.to(device)

In [18]:
epochs=1
train_show = False

In [19]:
def train_loop(model, train_files, valid_files, epochs=epochs):
    for epoch in range(epochs):
        print("\nTrain")    
        train_loss = train(model, train_files)    
        print("\nValidate")
        valid_acc, valid_loss, valid_details = validate(model, valid_files)
        print(f'Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Validation Accuracy: {valid_acc:.4f}, Validation Loss: {valid_loss:.4f}')
        print(valid_details)
    
        #torch.save(model.state_dict(), f'models/{run_id}_{epoch}_{train_loss}_{valid_acc}_{valid_loss}_enhanced_gat_model_state.pth')
        #np.save(f'models/{run_id}_{epoch}_confusion.npy', valid_details)
        #del train_details
        del valid_details
        gc.collect()

In [20]:
train_loop(model, train_set, valid_set)


Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 1, Train Loss: 0.6909, Validation Accuracy: 0.5693, Validation Loss: 0.6882
{'True Positives': 218, 'True Negatives': 94, 'False Positives': 171, 'False Negatives': 65}


In [21]:
train_loop(model, train_set, valid_set)


Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 1, Train Loss: 0.6894, Validation Accuracy: 0.6077, Validation Loss: 0.6833
{'True Positives': 244, 'True Negatives': 89, 'False Positives': 176, 'False Negatives': 39}


In [22]:
train_loop(model, train_set, valid_set,3)


Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 1, Train Loss: 0.6884, Validation Accuracy: 0.6715, Validation Loss: 0.6779
{'True Positives': 247, 'True Negatives': 121, 'False Positives': 144, 'False Negatives': 36}

Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 2, Train Loss: 0.6869, Validation Accuracy: 0.6788, Validation Loss: 0.6726
{'True Positives': 247, 'True Negatives': 125, 'False Positives': 140, 'False Negatives': 36}

Train
...........................................................

In [23]:
train_loop(model, train_set, valid_set,3)


Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 1, Train Loss: 0.6818, Validation Accuracy: 0.6971, Validation Loss: 0.6579
{'True Positives': 257, 'True Negatives': 125, 'False Positives': 140, 'False Negatives': 26}

Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 2, Train Loss: 0.6780, Validation Accuracy: 0.7007, Validation Loss: 0.6499
{'True Positives': 257, 'True Negatives': 127, 'False Positives': 138, 'False Negatives': 26}

Train
...........................................................

In [24]:
torch.save(model, 'model_EnhancedGAT_7007')

In [25]:
train_loop(model, train_set, valid_set)


Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 1, Train Loss: 0.6729, Validation Accuracy: 0.7007, Validation Loss: 0.6328
{'True Positives': 257, 'True Negatives': 127, 'False Positives': 138, 'False Negatives': 26}


In [26]:
train_loop(model, train_set, valid_set)


Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 1, Train Loss: 0.6700, Validation Accuracy: 0.7007, Validation Loss: 0.6194
{'True Positives': 257, 'True Negatives': 127, 'False Positives': 138, 'False Negatives': 26}


In [27]:
train_loop(model, train_set, valid_set)


Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 1, Train Loss: 0.6665, Validation Accuracy: 0.7007, Validation Loss: 0.6125
{'True Positives': 257, 'True Negatives': 127, 'False Positives': 138, 'False Negatives': 26}


In [28]:
train_loop(model, train_set, valid_set)


Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 1, Train Loss: 0.6617, Validation Accuracy: 0.7007, Validation Loss: 0.6060
{'True Positives': 257, 'True Negatives': 127, 'False Positives': 138, 'False Negatives': 26}


In [29]:
train_loop(model, train_set, valid_set)


Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 1, Train Loss: 0.6578, Validation Accuracy: 0.7007, Validation Loss: 0.6003
{'True Positives': 257, 'True Negatives': 127, 'False Positives': 138, 'False Negatives': 26}


In [30]:
train_loop(model, train_set, valid_set)


Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 1, Train Loss: 0.6527, Validation Accuracy: 0.7007, Validation Loss: 0.5966
{'True Positives': 257, 'True Negatives': 127, 'False Positives': 138, 'False Negatives': 26}


In [31]:
train_loop(model, train_set, valid_set)


Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 1, Train Loss: 0.6479, Validation Accuracy: 0.7007, Validation Loss: 0.5944
{'True Positives': 257, 'True Negatives': 127, 'False Positives': 138, 'False Negatives': 26}


In [32]:
train_loop(model, train_set, valid_set)


Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 1, Train Loss: 0.6447, Validation Accuracy: 0.7007, Validation Loss: 0.5945
{'True Positives': 257, 'True Negatives': 127, 'False Positives': 138, 'False Negatives': 26}


In [33]:
train_loop(model, train_set, valid_set)


Train
...........................................................................................................................................................................................................................
Validate
......................................................Epoch: 1, Train Loss: 0.6416, Validation Accuracy: 0.7007, Validation Loss: 0.5962
{'True Positives': 257, 'True Negatives': 127, 'False Positives': 138, 'False Negatives': 26}
