# In [2]:
import json
import logging
import os
from datetime import datetime

import numpy as np
import pandas as pd
import skfuzzy as fuzz
import torch
import torch.nn as nn
import torch.nn.functional as F
from skfuzzy import control as ctrl
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, roc_curve
from sklearn.preprocessing import LabelEncoder, StandardScaler
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# Root logger: emit INFO and above, each record prefixed with its timestamp
# and level name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

def setup_fuzzy_system():
    """Build the fuzzy-inference simulator used to score AML risk.

    Antecedents (crisp inputs):
      * ``cross_border`` (terms ``low``/``high``), ``country_risk``
        (``low``/``high``) and ``pep_involvement`` (``no``/``yes``), each on
        the universe 0.0-1.0 in steps of 0.1 with trapezoidal terms.
      * ``transaction_type`` on the integer codes 0, 1, 2 with triangular
        terms ``crypto_transfer`` (peak 0), ``payment`` (peak 1) and
        ``other`` (2).

    Consequent:
      * ``risk`` on 0-100 with triangular terms ``low``, ``medium``, ``high``.

    Rules:
      1. crypto_transfer OR payment           -> risk high
      2. PEP involvement OR high country risk -> risk high
      3. high cross-border                    -> risk medium
      4. low cross-border AND other type      -> risk low

    Returns:
        ctrl.ControlSystemSimulation: set the four antecedents in its
        ``input`` mapping and call ``compute()``; the defuzzified score is
        then available as ``output['risk']``.
    """
    cross_border = ctrl.Antecedent(np.arange(0, 1.1, 0.1), 'cross_border')
    country_risk = ctrl.Antecedent(np.arange(0, 1.1, 0.1), 'country_risk')
    pep_involvement = ctrl.Antecedent(np.arange(0, 1.1, 0.1), 'pep_involvement')
    transaction_type = ctrl.Antecedent(np.arange(0, 3, 1), 'transaction_type')
    risk = ctrl.Consequent(np.arange(0, 101, 1), 'risk')

    # The yes/no style factors share one trapezoid layout: the "off" term is
    # fully active up to 0.3, the "on" term from 0.7; they meet at 0.5, where
    # neither is active.
    pep_involvement['no'] = fuzz.trapmf(pep_involvement.universe, [0, 0, 0.3, 0.5])
    pep_involvement['yes'] = fuzz.trapmf(pep_involvement.universe, [0.5, 0.7, 1, 1])

    cross_border['low'] = fuzz.trapmf(cross_border.universe, [0, 0, 0.3, 0.5])
    cross_border['high'] = fuzz.trapmf(cross_border.universe, [0.5, 0.7, 1, 1])

    country_risk['low'] = fuzz.trapmf(country_risk.universe, [0, 0, 0.3, 0.5])
    country_risk['high'] = fuzz.trapmf(country_risk.universe, [0.5, 0.7, 1, 1])

    transaction_type['crypto_transfer'] = fuzz.trimf(transaction_type.universe, [0, 0, 1])
    transaction_type['payment'] = fuzz.trimf(transaction_type.universe, [1, 1, 2])
    # Degenerate triangle: full membership at exactly 2, zero at 0 and 1.
    transaction_type['other'] = fuzz.trimf(transaction_type.universe, [2, 2, 2])

    risk['low'] = fuzz.trimf(risk.universe, [0, 0, 40])
    risk['medium'] = fuzz.trimf(risk.universe, [30, 50, 70])
    risk['high'] = fuzz.trimf(risk.universe, [60, 100, 100])

    # For integer transaction codes and 0/1 cross-border inputs at least one
    # rule always fires (rule 1 for codes 0/1, rule 3 or 4 for code 2), so
    # defuzzification always has an activated output term.
    rules = [
        ctrl.Rule(transaction_type['crypto_transfer'] | transaction_type['payment'], risk['high']),
        ctrl.Rule(pep_involvement['yes'] | country_risk['high'], risk['high']),
        ctrl.Rule(cross_border['high'], risk['medium']),
        ctrl.Rule(cross_border['low'] & transaction_type['other'], risk['low']),
    ]

    aml_control = ctrl.ControlSystem(rules)
    return ctrl.ControlSystemSimulation(aml_control)

# Codes fed to the fuzzy ``transaction_type`` antecedent; unknown types map to
# the "other" code.
# NOTE: crypto transfers are encoded as 1 (the peak of the 'payment' term, not
# of 'crypto_transfer'); the rule on transaction type ORs both terms, so the
# resulting risk is the same either way.
_TRANSACTION_TYPE_CODES = {'CRYPTO-TRANSFER': 1, 'PAYMENT': 1, 'OTHER': 2}
_DEFAULT_TRANSACTION_TYPE_CODE = 2
# Beneficiary countries treated as high risk.
_HIGH_RISK_COUNTRIES = frozenset({'Iran', 'Syria', 'North-Korea'})
# Case-insensitive string spellings accepted as "flag set".
_TRUE_FLAG_STRINGS = frozenset({'true', 't', 'yes', 'y', '1'})


def _is_flag_set(value):
    """Return True if ``value`` represents an affirmative yes/no flag.

    Plain truthiness is wrong for flags read from JSON/CSV or pandas: the
    strings ``'False'``/``'N'`` are truthy, and the missing-value marker NaN
    is truthy too. Strings are matched case-insensitively against
    ``_TRUE_FLAG_STRINGS``; None/NaN/NA count as "not set"; anything else
    falls back to ``bool(value)``.
    """
    if value is None:
        return False
    if isinstance(value, str):
        return value.strip().lower() in _TRUE_FLAG_STRINGS
    try:
        if pd.isna(value):
            return False
    except (TypeError, ValueError):
        # pd.isna returns an array for list-likes; let bool() decide.
        pass
    return bool(value)


def evaluate_transaction(row, aml_sim):
    """Score one transaction with the fuzzy AML simulator.

    Args:
        row: Mapping-like record (e.g. a DataFrame row) providing
            ``Transaction_Type``, ``Bene_Is_Pep``, ``Sender_Is_Pep``,
            ``Sender_Country`` and ``Bene_Country``.
        aml_sim: Simulation built by ``setup_fuzzy_system``. Its four inputs
            are overwritten and ``compute()`` is called on it.

    Returns:
        tuple: ``(risk_score, reasons)`` where ``risk_score`` is
        ``aml_sim.output['risk']`` and ``reasons`` is a list of
        human-readable descriptions of the risk factors that were present.
    """
    transaction_type_value = _TRANSACTION_TYPE_CODES.get(
        row['Transaction_Type'], _DEFAULT_TRANSACTION_TYPE_CODE)
    pep_involvement_value = int(
        _is_flag_set(row['Bene_Is_Pep']) or _is_flag_set(row['Sender_Is_Pep']))
    cross_border_value = 1 if row['Sender_Country'] != row['Bene_Country'] else 0
    # NOTE(review): only the beneficiary country is checked; confirm whether a
    # high-risk *sender* country should also raise country risk.
    country_risk_value = 1 if row['Bene_Country'] in _HIGH_RISK_COUNTRIES else 0

    aml_sim.input['transaction_type'] = transaction_type_value
    aml_sim.input['pep_involvement'] = pep_involvement_value
    aml_sim.input['cross_border'] = cross_border_value
    aml_sim.input['country_risk'] = country_risk_value

    aml_sim.compute()
    risk_score = aml_sim.output['risk']

    reasons = []
    if transaction_type_value == 1:
        reasons.append('Transaction type: High-risk transaction (crypto transfer/payment)')
    if pep_involvement_value == 1:
        reasons.append('PEP involvement')
    if cross_border_value == 1:
        reasons.append('Cross-border transaction')
    if country_risk_value == 1:
        reasons.append('High-risk country involvement')

    return risk_score, reasons

class EdgeGCN_LSTM(nn.Module):
    """Edge classifier: one GCN layer over nodes, then an LSTM over edges.

    Each edge is represented by the concatenation of its sender embedding,
    its receiver embedding and its raw edge attributes. The whole edge list
    is fed to the LSTM as a single sequence (batch of one) in the order
    given, so the caller's edge ordering determines the temporal context.
    A linear head produces ``out_channels`` logits per edge.

    Args:
        hidden_channels: Output width of the GCN layer.
        lstm_hidden_channels: Hidden-state width of the LSTM.
        out_channels: Logits per edge (``forward`` flattens them).
        dropout_rate: Dropout probability applied after the GCN activation.
        node_in_channels: Width of the node feature matrix ``x`` (default 1).
        edge_attr_channels: Width of ``edge_attr`` (default 3).
    """

    def __init__(self, hidden_channels, lstm_hidden_channels, out_channels, dropout_rate,
                 node_in_channels=1, edge_attr_channels=3):
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # initialisation and state_dict keys are unchanged.
        self.conv1 = GCNConv(node_in_channels, hidden_channels)
        self.lstm = nn.LSTM(
            input_size=hidden_channels * 2 + edge_attr_channels,
            hidden_size=lstm_hidden_channels,
            batch_first=True,
        )
        self.lin = nn.Linear(lstm_hidden_channels, out_channels)
        self.dropout_rate = dropout_rate

    def forward(self, x, edge_index, edge_attr):
        """Compute per-edge logits.

        Args:
            x: Node features, shape (num_nodes, node_in_channels).
            edge_index: Shape (2, num_edges); row 0 holds sender node
                indices, row 1 receiver node indices.
            edge_attr: Edge features, shape (num_edges, edge_attr_channels).

        Returns:
            Tensor of shape (num_edges * out_channels,).
        """
        x = F.dropout(F.relu(self.conv1(x, edge_index)), p=self.dropout_rate, training=self.training)
        sender_features = x[edge_index[0]]
        receiver_features = x[edge_index[1]]
        edge_features = torch.cat([sender_features, receiver_features, edge_attr], dim=1)
        # Treat the full edge list as one sequence: (1, num_edges, features).
        lstm_out, _ = self.lstm(edge_features.unsqueeze(0))
        out = self.lin(lstm_out.squeeze(0))
        return out.view(-1)

class GraphDataProcessor:
    """Convert a transactions DataFrame into a PyG ``Data`` graph for edge classification.

    Every distinct ``Sender_Customer_Id`` / ``Bene_Customer_Id`` becomes a
    node; every row becomes a directed sender -> beneficiary edge whose
    attributes are ``[transaction_type_code, scaled_usd_amount, risk_score]``
    and whose label is ``Label``.

    Args:
        df: Transactions to convert. ``prepare_graph_data`` mutates it in
            place (adds ``Risk_Score``/``Risk_Reasons``, converts
            ``Time_step``) and then rebinds ``self.df`` to a copy sorted by
            sender and time.
        transaction_type_encoder: Optional pre-fitted ``LabelEncoder``. When
            omitted, a new encoder is fitted on this batch, so the integer
            codes depend on which transaction types the batch contains.
        amount_scaler: Optional pre-fitted ``StandardScaler`` for
            ``USD_Amount``. When omitted, one is fitted on this batch.
    """

    def __init__(self, df, transaction_type_encoder=None, amount_scaler=None):
        self.df = df
        self.transaction_type_encoder = transaction_type_encoder
        self.amount_scaler = amount_scaler
        # Fuzzy simulator reused for every row in evaluate_risk_scores.
        self.aml_sim = setup_fuzzy_system()

    def evaluate_risk_scores(self):
        """Add ``Risk_Score`` and ``Risk_Reasons`` columns to ``self.df`` in place."""
        results = self.df.apply(lambda row: evaluate_transaction(row, self.aml_sim), axis=1)
        # Each result is a (score, reasons) pair; transpose into two columns.
        self.df['Risk_Score'], self.df['Risk_Reasons'] = zip(*results)

    def prepare_graph_data(self):
        """Build the graph for the current ``self.df``.

        Columns read: ``Time_step``, ``Sender_Customer_Id``,
        ``Bene_Customer_Id``, ``Label``, ``Transaction_Type``, ``USD_Amount``
        (plus those read by ``evaluate_transaction``).

        Rows are sorted by sender and time before the edges are built, so
        edge ``i`` corresponds to row ``i`` of the sorted ``self.df``, not
        of the frame passed to the constructor.

        Returns:
            torch_geometric.data.Data with
              * ``x``: zeros, shape (num_customers, 1);
              * ``edge_index``: long tensor, shape (2, num_rows);
              * ``edge_attr``: float tensor, shape (num_rows, 3);
              * ``y``: long tensor of labels, shape (num_rows,).
        """
        self.evaluate_risk_scores()
        self.df['Time_step'] = pd.to_datetime(self.df['Time_step'])
        self.df = self.df.sort_values(by=['Sender_Customer_Id', 'Time_step'])
        self.df['Label'] = pd.to_numeric(self.df['Label'], errors='coerce').fillna(0).astype(int)

        edge_index, num_nodes = self._build_edge_index()
        node_features = torch.zeros((num_nodes, 1))

        edge_attr = torch.cat(
            [self._encode_transaction_types(), self._scale_usd_amounts(), self._risk_scores()],
            dim=1,
        )

        edge_labels = torch.tensor(self.df['Label'].to_numpy(dtype=np.int64), dtype=torch.long)

        return Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr, y=edge_labels)

    def _build_edge_index(self):
        """Map customer ids to contiguous node indices; return ``(edge_index, num_nodes)``."""
        all_ids = pd.concat([self.df['Sender_Customer_Id'], self.df['Bene_Customer_Id']]).unique()
        node_of = {customer_id: idx for idx, customer_id in enumerate(all_ids)}

        senders = self.df['Sender_Customer_Id'].map(node_of).to_numpy()
        receivers = self.df['Bene_Customer_Id'].map(node_of).to_numpy()

        edge_index = torch.tensor(np.array([senders, receivers]), dtype=torch.long)
        return edge_index, len(all_ids)

    def _encode_transaction_types(self):
        """Label-encode ``Transaction_Type`` into an (n, 1) float tensor."""
        types = self.df['Transaction_Type']
        if self.transaction_type_encoder is None:
            codes = LabelEncoder().fit_transform(types)
        else:
            codes = self.transaction_type_encoder.transform(types)
        return torch.tensor(np.asarray(codes, dtype=float), dtype=torch.float).view(-1, 1)

    def _scale_usd_amounts(self):
        """Standardize ``USD_Amount`` into an (n, 1) float tensor; unparsable amounts become 0."""
        self.df['USD_Amount'] = pd.to_numeric(self.df['USD_Amount'], errors='coerce')
        amounts = self.df[['USD_Amount']].astype(float)
        if self.amount_scaler is None:
            scaled = StandardScaler().fit_transform(amounts)
        else:
            scaled = self.amount_scaler.transform(amounts)
        # StandardScaler passes NaN through unchanged. In a sequence model over
        # the edges (as EdgeGCN_LSTM is), one NaN would contaminate every later
        # edge, so impute missing amounts with the mean (0 after scaling).
        scaled = np.where(np.isnan(scaled), 0.0, scaled)
        return torch.tensor(scaled, dtype=torch.float).view(-1, 1)

    def _risk_scores(self):
        """Return the fuzzy ``Risk_Score`` column as an (n, 1) float tensor."""
        scores = self.df['Risk_Score'].apply(lambda x: x[0] if isinstance(x, (list, tuple)) else x)
        values = pd.to_numeric(scores, errors='coerce').astype(float).to_numpy()
        # Convert via NumPy: torch.tensor(Series) uses the sequence protocol,
        # whose integer __getitem__ is label-based on a Series and may raise
        # KeyError when the index is not 0..n-1.
        return torch.tensor(values, dtype=torch.float).view(-1, 1)

def process_message(data, model, device, output_directory):
    """Score one batch of transactions and write the results to a CSV file.

    Args:
        data: Anything ``pd.DataFrame`` accepts, holding the transaction
            columns read by ``GraphDataProcessor``.
        model: Edge classifier called as ``model(x, edge_index, edge_attr)``
            that returns one logit per edge.
        device: Device the graph tensors are moved to before inference.
        output_directory: Directory for ``processed_<timestamp>.csv``;
            created if it does not exist.

    The CSV rows follow the order used to build the graph (sorted by sender,
    then time), with added ``Predictions`` (sigmoid probability),
    ``Label_Prediction`` (probability >= 0.5) and ``Risk_Reasons`` joined
    with ``' | '``.

    Any exception is logged with its traceback and not re-raised; the
    function returns None in all cases.
    """
    try:
        logging.info("Preparing DataFrame from input data")
        df = pd.DataFrame(data)
        logging.info(f"DataFrame created with {len(df)} rows")

        logging.info("Preparing graph data")
        processor = GraphDataProcessor(df)
        graph_data = processor.prepare_graph_data().to(device)
        logging.info(f"Graph data prepared with {graph_data.num_nodes} nodes and {graph_data.num_edges} edges")

        logging.info("Evaluating model")
        model.eval()
        with torch.no_grad():
            output = model(graph_data.x, graph_data.edge_index, graph_data.edge_attr)
            predictions = torch.sigmoid(output).cpu().numpy().flatten()
            logging.info(f"Predictions obtained with length {len(predictions)}")

        # prepare_graph_data sorts the rows (by sender, then time) before
        # building edges, so prediction i belongs to row i of processor.df.
        # Assigning to the unsorted input frame would misalign them.
        results = processor.df
        results['Predictions'] = predictions
        results['Label_Prediction'] = (predictions >= 0.5).astype(int)

        # Include reasons for risk scores
        results['Risk_Reasons'] = results['Risk_Reasons'].apply(lambda reasons: ' | '.join(reasons))

        os.makedirs(output_directory, exist_ok=True)
        # Microseconds keep two messages handled within one second from
        # overwriting each other's output file.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        output_file = os.path.join(output_directory, f'processed_{timestamp}.csv')
        results.to_csv(output_file, index=False)
        logging.info(f"Results written to {output_file}")

    except Exception as e:
        logging.exception(f"Error processing message: {str(e)}")
