In [15]:
import pandas as pd
import networkx as nx
from datetime import datetime
from typing import List, Dict, Any, Tuple
import numpy as np

In [16]:
# Parquet inputs: one table describing entities/accounts, one describing transactions.
entities_path = '../data/jp_morgan/sorted/entities_final_fraud.parquet'
transactions_path = '../data/jp_morgan/sorted/nodes_final_fraud.parquet'

entities_df = pd.read_parquet(entities_path)
transactions_df = pd.read_parquet(transactions_path)

In [17]:
import networkx as nx
import pandas as pd
from typing import Optional

class TransactionGraph:
    """Directed multigraph of entities, their accounts and the transactions between them.

    The graph is held in ``self.G`` (a ``networkx.MultiDiGraph``) and always
    contains a ``'CASH'`` node, which is the other end of every deposit and
    withdrawal edge.

    Node ``node_type`` values written by this class: ``'CASH'``, ``'ENTITY'``,
    ``'STANDALONE'`` and ``'ACCOUNT'``. Edge ``relationship`` values:
    ``'HAS_ACCOUNT'`` and ``'TRANSACTION'``.
    """

    # Columns batch_add_transaction_edges() reads from its input frame,
    # in the order they are unpacked there.
    _TRANSACTION_COLUMNS = (
        'party_id', 'counterparty_id', 'transaction_id', 'amount',
        'txn_time_hr', 'txn_time_mm', 'txn_age_days', 'is_credit',
        'std_txn_type',
    )

    def __init__(self):
        """Create an empty graph containing only the ``'CASH'`` node."""
        self.G = nx.MultiDiGraph()
        # Create CASH node
        self.G.add_node('CASH', node_type='CASH')

    @staticmethod
    def _has_value(value) -> bool:
        """Return True when ``value`` is present (not None/NaN/NA and not empty)."""
        if value is None:
            return False
        # Missing values coming out of pandas are often float NaN (or pd.NA),
        # which are truthy / un-boolable, so a plain ``if value`` is not enough.
        if pd.isna(value):
            return False
        return bool(value)

    @staticmethod
    def _resolve_direction(party_id, counterparty_id, is_credit, std_txn_type):
        """Return ``(from_id, to_id)`` for one transaction.

        Deposits flow CASH -> party and withdrawals party -> CASH regardless of
        ``is_credit``; every other type flows counterparty -> party when
        ``is_credit`` is truthy and party -> counterparty otherwise.
        """
        if std_txn_type == 'DEPOSIT':
            return 'CASH', party_id
        if std_txn_type == 'WITHDRAWAL':
            return party_id, 'CASH'
        if is_credit:
            return counterparty_id, party_id
        return party_id, counterparty_id

    def add_entity_node(self, entity_id: str, country: Optional[str] = None):
        """Add an entity node with optional country info.

        Args:
            entity_id: Node identifier.
            country: Country of the entity. When it is missing (``None``, NaN,
                ``pd.NA`` or an empty string) the node is tagged
                ``'STANDALONE'`` and no ``country`` attribute is stored;
                otherwise it is tagged ``'ENTITY'``.
        """
        has_country = self._has_value(country)
        attrs = {'node_type': 'ENTITY' if has_country else 'STANDALONE'}
        if has_country:
            attrs['country'] = country
        self.G.add_node(entity_id, **attrs)

    def add_account_node(self, entity_id: str, account: str):
        """Add an account node and link it to its parent entity.

        The account node id is ``f"{entity_id}_{account}"``. A ``HAS_ACCOUNT``
        edge entity -> account is added on every call (the graph is a
        multigraph, so calling this twice for the same pair yields two edges).
        """
        account_id = f"{entity_id}_{account}"
        self.G.add_node(account_id,
                        node_type='ACCOUNT',
                        entity_id=entity_id,
                        account_number=account)
        self.G.add_edge(entity_id, account_id, relationship='HAS_ACCOUNT')

    def batch_create_nodes(self, df: pd.DataFrame):
        """Batch create nodes from a dataframe with Id, Account, Country columns.

        One entity node is created per distinct ``(Id, Country)`` pair and one
        account node per distinct ``(Id, Account)`` pair whose ``Account`` is
        not null. ``Id`` and ``Account`` are converted with ``str()``.
        """
        # Iterate column-wise rather than with iterrows(): iterrows builds a
        # Series per row and may upcast e.g. an integer Id to float, which
        # would turn the node id into '1.0' instead of '1'.
        entities = df[['Id', 'Country']].drop_duplicates()
        for entity_id, country in zip(entities['Id'].tolist(),
                                      entities['Country'].tolist()):
            self.add_entity_node(str(entity_id), country)

        # De-duplicate so repeated rows do not create parallel HAS_ACCOUNT edges.
        accounts = df.loc[df['Account'].notna(), ['Id', 'Account']].drop_duplicates()
        for entity_id, account in zip(accounts['Id'].tolist(),
                                      accounts['Account'].tolist()):
            self.add_account_node(str(entity_id), str(account))

    def add_transaction_edge(self, party_id: str, counterparty_id: str,
                             transaction_id: str, amount: float,
                             txn_time_hr: int, txn_time_mm: int, txn_age_days: int,
                             is_credit: bool, std_txn_type: str):
        """Add a single ``TRANSACTION`` edge, oriented in the direction money flows.

        Args:
            party_id: ID of the party node
            counterparty_id: ID of the counterparty node (ignored for
                ``'DEPOSIT'`` / ``'WITHDRAWAL'``, which use the CASH node)
            transaction_id: Unique transaction identifier
            amount: Transaction amount
            txn_time_hr: Hour of transaction (0-23)
            txn_time_mm: Minute of transaction (0-59)
            txn_age_days: Stored unchanged on the edge.
                NOTE(review): previously described both as "days since
                transaction" and "age of entity on transaction" - confirm.
            is_credit: True if money flows into party account, False if flows out
            std_txn_type: Standardized transaction type
        """
        from_id, to_id = self._resolve_direction(
            party_id, counterparty_id, is_credit, std_txn_type)

        self.G.add_edge(from_id, to_id,
                        transaction_id=transaction_id,
                        amount=amount,
                        txn_time_hr=txn_time_hr,
                        txn_time_mm=txn_time_mm,
                        txn_age_days=txn_age_days,
                        relationship='TRANSACTION',
                        std_txn_type=std_txn_type)

    def batch_add_transaction_edges(self, df: pd.DataFrame):
        """Batch add transaction edges from a dataframe.

        Applies the same direction rules as ``add_transaction_edge`` to every
        row. The input frame is only read, never modified; an empty frame adds
        no edges.

        Args:
            df: DataFrame with columns:
                - party_id: ID of the party node
                - counterparty_id: ID of the counterparty node
                - transaction_id: Unique transaction identifier
                - amount: Transaction amount
                - txn_time_hr: Hour of transaction (0-23)
                - txn_time_mm: Minute of transaction (0-59)
                - txn_age_days: see ``add_transaction_edge``
                - is_credit: True if money flows into party account
                - std_txn_type: Standardised transaction type

        Raises:
            KeyError: if any of the columns above is missing.
        """
        missing = [col for col in self._TRANSACTION_COLUMNS if col not in df.columns]
        if missing:
            raise KeyError(f"transactions dataframe is missing columns: {missing}")

        # Column-wise iteration: far cheaper than DataFrame.apply(axis=1) /
        # iterrows(), and it leaves the caller's frame untouched.
        columns = [df[col].tolist() for col in self._TRANSACTION_COLUMNS]

        edges = []
        for (party_id, counterparty_id, transaction_id, amount,
             txn_time_hr, txn_time_mm, txn_age_days,
             is_credit, std_txn_type) in zip(*columns):
            from_id, to_id = self._resolve_direction(
                party_id, counterparty_id, is_credit, std_txn_type)
            edges.append((from_id, to_id, {
                'transaction_id': transaction_id,
                'amount': amount,
                'txn_time_hr': txn_time_hr,
                'txn_time_mm': txn_time_mm,
                'txn_age_days': txn_age_days,
                'relationship': 'TRANSACTION',
                'std_txn_type': std_txn_type,
            }))

        self.G.add_edges_from(edges)

# Example usage:
# graph = TransactionGraph()
# df = pd.DataFrame({
#     'Id': ['E1', 'E2', 'E3'],
#     'Account': ['A1', 'A2', None],
#     'Country': ['US', 'UK', None]
# })
# graph.batch_create_nodes(df)
# 
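# # Add a transaction: E1_A1 pays E2_A2 (is_credit=False => money leaves the party).
# # Any std_txn_type other than 'DEPOSIT'/'WITHDRAWAL' is routed party <-> counterparty.
# graph.add_transaction_edge('E1_A1', 'E2_A2', 'T1', 1000.0,
#                            txn_time_hr=9, txn_time_mm=30, txn_age_days=0,
#                            is_credit=False, std_txn_type='TRANSFER')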


In [18]:
# Fresh graph; at this point it contains only the 'CASH' node.
graph = TransactionGraph()

In [19]:
# Create entity nodes, plus account nodes for rows that have an Account, from entities_df.
graph.batch_create_nodes(entities_df)

In [20]:
# Add one TRANSACTION edge per row of transactions_df.
graph.batch_add_transaction_edges(transactions_df)

In [None]:
# Total number of edges in the graph (HAS_ACCOUNT and TRANSACTION edges together).
graph.G.size()

In [22]:
import pickle

# Persist the underlying networkx graph (not the TransactionGraph wrapper) to disk.
graph_pickle_path = '../data/jp_morgan/pickled/graph_fraud_final.pickle'
with open(graph_pickle_path, 'wb') as pickle_file:
    pickle.dump(graph.G, pickle_file)

In [None]:
# Quick static plot of the whole graph with networkx's default drawing.
# NOTE(review): draws every node and edge; presumably exploratory only — confirm it is usable at this graph size.
nx.draw(graph.G)

In [17]:
from pyvis.network import Network

# Export an interactive HTML view of the transaction graph.
# directed=True keeps the from -> to orientation (direction of money flow);
# without it pyvis renders the network as undirected, so arrows are lost and
# repeated edges between an already-connected pair are presumably skipped by
# pyvis's undirected add_edge — confirm against the installed pyvis version.
nx_graph = graph.G
nt = Network('1000px', '1000px', directed=True)
nt.from_nx(nx_graph)
nt.save_graph('nx.html')

In [None]:
# Inspect edges together with their attribute dicts.
# NOTE(review): this displays every edge in the graph; consider looking at a small sample instead.
graph.G.edges.data()

In [None]:
# Sanity check: graph.G was created as nx.MultiDiGraph, so parallel edges are allowed.
graph.G.is_multigraph()