In [None]:
# Mount Google Colab drive
from google.colab import drive
drive.mount('/content/drive')

# Imports
import os
import numpy as np
import pandas as pd
import joblib
import re

from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from time import time
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn import tree
from sklearn.tree import _tree

# Install torch-geometric
!pip install torch-geometric
from torch_geometric.data import HeteroData

from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.ensemble import RandomForestClassifier

# Work from the project folder on the mounted drive. `loc` is reused below as
# the prefix of every file that is read or written.
loc = "/content/drive/MyDrive/KE_GNN/"
os.chdir(loc)
os.getcwd()


# Fraction of the transaction set to discard. The frame is later sorted by
# `date_time` and only the first (1 - size_reduction) share of rows is kept,
# so the rows dropped are the most recent ones.
size_reduction = 0.3
# Load the pre-processed transactions and drop the unnamed column
# ('Unnamed: 0') that came along in the CSV.
df = pd.read_csv("{}clean_processed_transactions.csv".format(loc)).drop('Unnamed: 0',axis =1)



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Collecting torch-geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
Collecting aiohttp (from torch-geometric)
  Downloading aiohttp-3.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m24.1 MB/s[0m eta [36m0:00:00[0m
Collecting aiosignal>=1.1.2 (from aiohttp->torch-geometric)
  Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)
Collecting frozenlist>=1.1.1 (from aiohttp->torch-geometric)
  Downloading frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (239 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m239.5/239.5 kB[0m [31m25.2

In [None]:
# Column lists used to build the node feature matrices of the graph.
# In every list the FIRST entry is the node-index column; it is used to order
# / address the nodes and is dropped before the values become features.
# The strings are column names of `df` and must match the CSV exactly
# (including their original spelling).

# features for the transaction nodes
features = ['transaction_pk',
                 'DoW', 'minute', 'number_active_cards', 'number_active_accounts', 'Amount','age at time',  # assorted variables (time, account activity, amount, age)
                 'User error_counter',  'User-Merchant error_counter',  'User-Card error_counter',  'User-MCC error_counter', 'Merchant error_counter', # error counters = a rolling count of previous errors
                 'User - previous_error', 'User-Merchant - previous_error', 'User-Card - previous_error', 'User-MCC - previous_error', 'Merchant - previous_error',  # indicator of whether the previous payments were errors
                 'User CS', 'User CC', 'User CM', 'User CSTD', 'User CM3', 'User CSTD3', 'User CSTD7', 'User CM7', # user relationship
                 'User-Merchant CS', 'User-Merchant CC', 'User-Merchant CM', 'User-Merchant CSTD',  'User-Merchant CM3', 'User-Merchant CSTD3', 'User-Merchant CSTD7', 'User-Merchant CM7', # user-merchant relationship
                 'User-Card CS', 'User-Card CC', 'User-Card CM', 'User-Card CSTD', 'User-Card CM3', 'User-Card CSTD3', 'User-Card CSTD7', 'User-Card CM7', # user-card relationship
                 'User-MCC CS', 'User-MCC CC', 'User-MCC CM', 'User-MCC CSTD','User-MCC CM3', 'User-MCC CSTD3','User-MCC CSTD7', 'User-MCC CM7', # user-MCC relationship
                 'Merchant CS', 'Merchant CC', 'Merchant CM', 'Merchant CSTD', 'Merchant CM3', 'Merchant CSTD3', 'Merchant CSTD7', 'Merchant CM7', # merchant relationship
                 'User_lag_tester_payment_1', 'User_lag_tester_payment_5', 'User_lag_tester_payment_10',  'User_lag_tester_payment_20', # user lagged tester payments
                 'User-Merchant_lag_tester_payment_1', 'User-Merchant_lag_tester_payment_5', 'User-Merchant_lag_tester_payment_10', 'User-Merchant_lag_tester_payment_20', # user-merchant lagged tester payments
                 'User-Card_lag_tester_payment_1', 'User-Card_lag_tester_payment_5', 'User-Card_lag_tester_payment_10', 'User-Card_lag_tester_payment_20', # user-card lagged tester payments
                 'User-MCC_lag_tester_payment_1', 'User-MCC_lag_tester_payment_5', 'User-MCC_lag_tester_payment_10', 'User-MCC_lag_tester_payment_20', # user-MCC lagged tester payments
                 'Merchant_lag_tester_payment_1', 'Merchant_lag_tester_payment_5', 'Merchant_lag_tester_payment_10', 'Merchant_lag_tester_payment_20', # merchant lagged tester payments
                 'User occurance 1 mins', 'User occurance 10 mins', # user occurrences within 1 / 10 minutes
                 'User-Merchant occurance 1 mins', 'User-Merchant occurance 10 mins', # user-merchant occurrences within 1 / 10 minutes
                 'User-Card occurance 1 mins', 'User-Card occurance 10 mins', # user-card occurrences within 1 / 10 minutes
                 'User-MCC occurance 1 mins', 'User-MCC occurance 10 mins', # user-MCC occurrences within 1 / 10 minutes
                 'Merchant occurance 1 mins', 'Merchant occurance 10 mins',  # merchant occurrences within 1 / 10 minutes
                 'Error - Bad input', 'Error - Insuf bal', 'Error - Tech Glitch' ,
                 'OH1: Chip Transaction', 'OH1: Online Transaction', 'OH1: Swipe Transaction',
                 'FR: Merchant City', 'FR: Merchant State', 'FR: Zip', 'OH4: Ohio', 'OH4: Online', 'OH4: US',
                 'OH4: high_risk', 'OH4: world_non_us', 'OH5: Ohio', 'OH5: US',
                 'Per Capita Income - Zipcode','FR: Zipcode', 'Gender','FICO Score', 'Total Debt',
                 'FR: MCC', 'OH2: Agricultural Services', 'OH2: Contracted Services', 'OH2: Transportation Services',
                 'OH2: Utility Services', 'OH2: Retail Outlet Services', 'OH2: Clothing Stores',
                 'OH2: Miscellaneous Stores', 'OH2: Business Services', 'OH2: Professional Services and Membership Organizations', 'OH2: Government Services']

# features for the user nodes (all but 'User_index' also appear in `features`)
u_features = ['User_index','Per Capita Income - Zipcode','FR: Zipcode', 'Gender','FICO Score', 'Total Debt']

# features for the merchant nodes
m_features = ['Merchant_index', 'Merchant in Counry',  'FR: MCC',
              'OH2: Agricultural Services', 'OH2: Contracted Services', 'OH2: Transportation Services',
                 'OH2: Utility Services', 'OH2: Retail Outlet Services', 'OH2: Clothing Stores',
                 'OH2: Miscellaneous Stores', 'OH2: Business Services', 'OH2: Professional Services and Membership Organizations',
                 'OH2: Government Services' ]
# features for the location nodes
l_features = ['merch_city_index','state_value']

# features for the card nodes
c_features = ['user_card_index', 'OH2: Amex', 'OH2: Discover', 'OH2: Mastercard', 'OH2: Visa',
              'OH3: Credit', 'OH3: Debit', 'OH3: Debit (Prepaid)','Credit Limit']

# union of every column name referenced by the lists above
total_features = features + u_features + m_features + l_features + c_features

# list the columns of df that none of the feature lists reference
# (the list is the cell's last expression, so the notebook displays it)
print('Features NOT being used: ')
[x for x in df.columns.tolist() if x not in total_features]


Features NOT being used: 


['User',
 'Card',
 'Year',
 'Month',
 'Day',
 'Time',
 'Merchant Name',
 'Merchant City',
 'Merchant State',
 'Zip',
 'MCC',
 'Is Fraud?',
 'PK',
 'date_time',
 'DOB',
 'Zipcode',
 'previous_error',
 'user_card',
 'State',
 'Merchant State2',
 'User State2',
 'Fraud2']

In [None]:
def _first_seen_index(values):
  '''Number the distinct entries of a Series 0, 1, 2, ... by first appearance.

  Returns a Series aligned with `values` holding the number of each entry.
  '''
  lookup = {key: position for position, key in enumerate(values.unique())}
  return values.map(lookup)


# creating an index sorting each user helps with creating graph structure
def index_creator(df):
  '''
  Add the integer node-index columns used to build the graph.

  The frame is sorted by 'date_time' and re-indexed from 0, then four
  columns are added:
    transaction_pk  - row position after sorting
    User_index      - number of the row's 'User'
    Merchant_index  - number of the row's 'Merchant Name'
    user_card_index - number of the row's 'user_card'
  Each numbering starts at 0 and follows order of first appearance in the
  sorted frame, so the numbers are only meaningful within this frame.

  A new frame is returned; the frame passed in is not modified.
  '''
  # sort_values / reset_index return a new object, so the caller's frame
  # (typically a slice of a larger frame) is never written to.
  df = df.sort_values(['date_time']).reset_index(drop=True)
  df['transaction_pk'] = df.index

  df['User_index'] = _first_seen_index(df['User'])
  df['Merchant_index'] = _first_seen_index(df['Merchant Name'])
  df['user_card_index'] = _first_seen_index(df['user_card'])
  return df

def _edge_index(df, src_col, dst_col):
  '''Unique (src_col, dst_col) pairs of df as a 2 x num_edges integer tensor.'''
  return torch.from_numpy(df[[src_col, dst_col]].drop_duplicates().to_numpy().T)


def _add_edge_pair(data, df, src, forward, dst, backward, src_col, dst_col):
  '''Store the src -> dst relation and its reverse dst -> src on `data`.'''
  data[src, forward, dst].edge_index = _edge_index(df, src_col, dst_col)
  data[dst, backward, src].edge_index = _edge_index(df, dst_col, src_col)


def _node_features(df, columns, index_col):
  '''One float32 feature row per node, ordered by `index_col`.

  `index_col` must be the first entry of `columns`; it is only used for
  ordering / de-duplication and is cut off from the returned matrix.
  '''
  per_node = df[columns].sort_values(index_col).drop_duplicates(subset=[index_col])
  return torch.from_numpy(per_node.to_numpy(dtype='float32')[:, 1:])


def graph_maker(df,test_train):
  '''
  Build a HeteroData graph from a transaction frame.

  df must already carry the index columns added by index_creator
  ('transaction_pk', 'User_index', 'Merchant_index', 'user_card_index') as
  well as 'Merchant State', 'State', 'Is Fraud?' and every column named in
  the module-level lists features / u_features / m_features / c_features.

  test_train is either 'train' or 'test' and selects which mask is attached
  to the transaction nodes: 'train' -> train_mask, 'test' -> test_mask.
  The mask covers every transaction of df. Any other value attaches no mask.

  Node types: user, card, merchant, transaction, location (one location
  node per distinct state; location nodes carry a constant feature of 1).

  Returns (data, transaction_feature_names) where the names are `features`
  without 'transaction_pk', in the column order of data['transaction'].x.
  The result of data.validate() is printed.
  '''
  data = HeteroData() # Full graph

  # --- location ids -------------------------------------------------------
  # Both state columns address the same 'location' nodes, so they must share
  # one vocabulary. It is built from the values of BOTH columns: a home state
  # that never appears as a merchant state still gets a valid node id instead
  # of NaN. When every 'State' value also occurs in 'Merchant State' the ids
  # are the same as coding 'Merchant State' alone.
  state_values = pd.concat([df['Merchant State'], df['State']], ignore_index=True)
  state_categories = state_values.astype('category').cat.categories
  # assign() returns a new frame, so the caller's df is left untouched.
  # Missing states get code -1 and therefore do not point at a location node.
  df = df.assign(**{
      'Merchant State1': pd.Categorical(df['Merchant State'], categories=state_categories).codes.astype('int64'),
      'State1': pd.Categorical(df['State'], categories=state_categories).codes.astype('int64'),
  })

  # --- edges (each relation is stored together with its reverse) ----------
  #user <-> card
  _add_edge_pair(data, df, 'user', 'owns', 'card', 'rev_own',
                 'User_index', 'user_card_index')
  #card <-> transaction
  _add_edge_pair(data, df, 'card', 'transfer', 'transaction', 'rev_transfer',
                 'user_card_index', 'transaction_pk')
  #person <-> location (home state of the user)
  _add_edge_pair(data, df, 'user', 'happened_at', 'location', 'rev_happend_at',
                 'User_index', 'State1')
  #merchant <-> transaction
  _add_edge_pair(data, df, 'merchant', 'transfer', 'transaction', 'rev_transfer',
                 'Merchant_index', 'transaction_pk')
  # user <-> transaction
  _add_edge_pair(data, df, 'user', 'bought', 'transaction', 'rev_bought',
                 'User_index', 'transaction_pk')
  # user <-> merchant
  _add_edge_pair(data, df, 'user', 'bought_from', 'merchant', 'rev_bought_from',
                 'User_index', 'Merchant_index')
  # transaction <-> loc (state of the merchant)
  _add_edge_pair(data, df, 'transaction', 'bought_in', 'location', 'rev_bought_in',
                 'transaction_pk', 'Merchant State1')
  # card <-> merchant
  _add_edge_pair(data, df, 'card', 'bought_with', 'merchant', 'rev_bought_with',
                 'user_card_index', 'Merchant_index')
  # location <-> location (user state -> merchant state); both ends are
  # 'location', so the two directions need distinct relation names.
  data['location', 'at', 'location'].edge_index = _edge_index(df, 'State1', 'Merchant State1')
  data['location', 'rev_at', 'location'].edge_index = _edge_index(df, 'Merchant State1', 'State1')

  # --- node features ------------------------------------------------------
  # transaction features (selected once and reused for the returned names)
  transaction_frame = df[features].drop('transaction_pk',axis=1)
  y = df['Is Fraud?'].to_numpy(dtype='float32').reshape(-1,1)
  data['transaction'].x = torch.from_numpy(transaction_frame.to_numpy(dtype='float32'))
  data['transaction'].y = torch.from_numpy(y)

  #merchant features
  data['merchant'].x = _node_features(df, m_features, 'Merchant_index')
  #user features
  data['user'].x = _node_features(df, u_features, 'User_index')
  #card
  data['card'].x = _node_features(df, c_features, 'user_card_index')
  #location: no real attributes, one constant feature per state
  data['location'].x = torch.ones((len(state_categories), 1), dtype=torch.float32)

  # --- masks --------------------------------------------------------------
  # Every transaction of this graph belongs to the requested split, so the
  # mask is all True. (Previously the 'test' mask was built all False, which
  # selected no transaction at all.)
  if test_train == 'train':
    data['transaction'].train_mask = torch.ones(len(df), dtype=torch.bool)
  elif test_train == 'test':
    data['transaction'].test_mask = torch.ones(len(df), dtype=torch.bool)
  print(data.validate())
  return data, transaction_frame.columns.tolist()

# Chronological split: order all transactions by time, then keep only the
# earliest (1 - size_reduction) share of them.
df = df.sort_values(['date_time']).reset_index(drop=True)
a = int(len(df) * (1- size_reduction))
df_reduced = df[:a]

# 70 % / 15 % / 15 % train / validation / test split, still in time order
# (train is the oldest block, test the most recent one).
df_train_t = df_reduced[0: int(len(df_reduced)*.70)]
df_train_v = df_reduced[int(len(df_reduced)*.70): int(len(df_reduced)*.85)]
df_test = df_reduced[int(len(df_reduced)*.85):]
print('total size: {} , train len: {}, valid len: {}, test len: {}'.format(len(df_reduced), len(df_train_t),
                                                                           len(df_train_v), len(df_test)))

# Each split gets its own node numbering (index_creator restarts at 0), so
# node indices are NOT comparable across the three graphs.
df_train_t = index_creator(df_train_t)
df_train_v = index_creator(df_train_v)

# The validation graph is also built with 'train', i.e. it carries a
# train_mask rather than a test_mask.
train_data, transaction_features = graph_maker(df_train_t,'train')
valid_data, transaction_features = graph_maker(df_train_v,'train')

df_test = index_creator(df_test)
test_data, transaction_features = graph_maker(df_test,'test')

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Prints 'cuda:0' when a GPU is available, otherwise 'cpu'.
# NOTE(review): `device` is not used anywhere else in the visible code.
print(device)

# Persist the three graphs.
torch.save(test_data, '{}Graph storage/test_graph.pt'.format(loc))
torch.save(train_data, '{}Graph storage/train_graph.pt'.format(loc))
torch.save(valid_data, '{}Graph storage/valid_graph.pt'.format(loc))

#save df test, train, valid for XGBOOST model
# (these frames include the index columns added by index_creator)
df_train_t.to_csv('{}Graph storage/train_df.csv'.format(loc),index=False)
df_train_v.to_csv('{}Graph storage/valid_df.csv'.format(loc),index=False)
df_test.to_csv('{}Graph storage/test_df.csv'.format(loc), index=False)



total size: 17341766 , train len: 12139236, valid len: 2601265, test len: 2601265


  df['Merchant State1'] = df['Merchant State'].map(codes)
  df['State1'] = df['State'].map(codes)


True


  df['Merchant State1'] = df['Merchant State'].map(codes)
  df['State1'] = df['State'].map(codes)


True


  df['Merchant State1'] = df['Merchant State'].map(codes)
  df['State1'] = df['State'].map(codes)


True
cpu


Train the random forest whose decision paths are extracted below as the fraud rules (clauses).




In [None]:

# Random forest whose decision paths are later turned into fraud rules.
# Shallow trees (max_depth=10) keep every extracted rule short; the fixed
# random_state makes the forest, and so the rules, reproducible.
dec_model = RandomForestClassifier(n_estimators=100, max_depth=10,
                                   min_samples_split=10,
                                   random_state=420, n_jobs=-1)
# The target is passed as a 1-D Series. Selecting it with [['Is Fraud?']]
# gives a single-column DataFrame, which sklearn only accepts after
# flattening it and emitting a warning.
dec_model.fit(df_train_t[transaction_features],
              df_train_t['Is Fraud?'])

# ensuring that model was trained correctly: score it on the validation split
y_pred = dec_model.predict(df_train_v[transaction_features])

print(classification_report(df_train_v['Is Fraud?'], y_pred))
# save
joblib.dump(dec_model, '{}clause_random_forest.joblib'.format(loc))

  dec_model.fit(df_train_t[transaction_features],


              precision    recall  f1-score   support

           0       1.00      1.00      1.00   2598056
           1       0.99      0.42      0.59      3209

    accuracy                           1.00   2601265
   macro avg       1.00      0.71      0.80   2601265
weighted avg       1.00      1.00      1.00   2601265



['/content/drive/MyDrive/KE_GNN/clause_random_forest.joblib']

In [None]:
def _leaf_paths(tree_, feature_names):
    """Return one path per leaf of a fitted tree, in depth-first (left first) order.

    Each path is the list of condition strings "(feature <= t)" / "(feature > t)"
    leading to the leaf, followed by a final (leaf value, leaf sample count) tuple.
    """
    paths = []

    def recurse(node, path):
        left = tree_.children_left[node]
        right = tree_.children_right[node]
        if left != right:
            # Split node: a leaf has the same sentinel id for both children.
            name = feature_names[tree_.feature[node]]
            threshold = np.round(tree_.threshold[node], 3)
            recurse(left, path + [f"({name} <= {threshold})"])
            recurse(right, path + [f"({name} > {threshold})"])
        else:
            paths.append(path + [(tree_.value[node], tree_.n_node_samples[node])])

    recurse(0, [])
    return paths


# Function to extract decision rules from a tree model
def get_rules(tree, feature_names, class_names, target_class='1Fraud'):
    """Turn the leaves of every tree in a fitted ensemble into readable rules.

    Parameters
    ----------
    tree : fitted ensemble exposing `estimators_`, each with a `tree_` attribute
        (e.g. a RandomForestClassifier).
    feature_names : sequence of feature names, indexed by the feature number
        stored in the trees.
    class_names : sequence of class labels indexed by class position, or None.
        With None the first entry of the leaf value is reported as
        "response" and no class filter is applied.
    target_class : only rules whose predicted class equals this label are
        returned (ignored when class_names is None).

    Returns
    -------
    rules : list of rule strings of the form
        "if (f <= t) and (g > u) then class: C (proba: P%) | based on N samples",
        per tree ordered by leaf sample count, largest first.
    ruleD : dict mapping each returned rule string to
        [class label (or None), proba percent (or response), sample count].
    """
    rules = []  # rule strings that pass the class filter
    ruleD = {}  # rule string -> [label, score, sample count]

    for est in tree.estimators_:
        paths = _leaf_paths(est.tree_, feature_names)

        # Largest leaves first.
        samples_count = [p[-1][1] for p in paths]
        order = np.argsort(samples_count)[::-1]

        for i in order:
            path = paths[i]
            leaf_value, n_samples = path[-1]
            class_values = leaf_value[0]

            rule = "if " + " and ".join(str(p) for p in path[:-1]) + " then "
            if class_names is None:
                label = None
                score = np.round(class_values[0], 3)
                rule += "response: " + str(score)
            else:
                winner = np.argmax(class_values)
                label = class_names[winner]
                # Share of the winning class within the leaf, in percent.
                score = np.round(100.0 * class_values[winner] / np.sum(class_values), 2)
                rule += f"class: {label} (proba: {score}%)"
            rule += f" | based on {n_samples:,} samples"

            # Without class names there is nothing to filter on; the original
            # code raised here because it looked the label up regardless.
            if class_names is None or label == target_class:
                rules.append(rule)
                ruleD[rule] = [label, score, n_samples]

    return rules, ruleD

# Extract the rules of the fitted forest. The class labels are given by
# position: class 0 -> 'Non-Fraud', class 1 -> '1Fraud'; get_rules only
# returns the rules that predict '1Fraud'.
rules, ruleDic = get_rules(dec_model, transaction_features, ['Non-Fraud', '1Fraud'])


def extract_conditions(string):
  '''
  Parse the conditions of a rule string produced by get_rules.

  Input looks like
    "if (A <= 0.5) and (B > 0.1) then class: X (proba: 99.0%) | based on N samples"
  and the result is [['A', '<=', 0.5], ['B', '>', 0.1]]: one
  [feature name, operator, threshold as float] entry per condition, in order.

  Only the part before " then " is parsed, so the trailing "(proba: ...)"
  group is never mistaken for a condition. For a string without " then "
  the last parenthesised group is dropped instead (legacy behaviour).

  Limitation: feature names that themselves contain ")" are not supported.
  '''
  pattern = r'\((.*?)\)'
  head, separator, _ = string.partition(' then ')
  if separator:
    matches = re.findall(pattern, head)
  else:
    matches = re.findall(pattern, string)[:-1]

  conditions = []
  for match in matches:
    # Split off exactly the last two tokens (operator, threshold). Whatever
    # is left is the feature name, which may itself contain spaces or text
    # that looks like the operator / threshold.
    variable_name, operator, threshold = match.rsplit(None, 2)
    conditions.append([variable_name.strip(), operator, float(threshold)])
  return conditions

# Tabulate the extracted rules: one row per rule string, with the three
# values stored for it in ruleDic (predicted class, class share in percent,
# number of samples in the leaf).
rule_df = pd.DataFrame.from_dict(ruleDic, orient='index').reset_index()
rule_df.columns = ['rule','fraud','Perc','sample']

def rule_extraction(df,num, sample=500, perc = 90):
  '''
  Pick the strongest rules and parse them into condition lists.

  Keeps the rows of df with 'sample' >= sample and 'Perc' >= perc, orders
  them by 'Perc' (highest first) and takes at most `num` of them. The
  selected rows are printed. Returns {'RULE1': conditions, 'RULE2': ...}
  where the conditions come from extract_conditions on the 'rule' column.
  '''
  qualifies = (df['sample'] >= sample) & (df['Perc'] >= perc)
  selected = df[qualifies].sort_values('Perc', ascending = False)[:num].reset_index()
  print(selected)
  # Rules are numbered from 1 in the order of the selected rows.
  return {
      'RULE{}'.format(position + 1): extract_conditions(rule_text)
      for position, rule_text in enumerate(selected['rule'])
  }

# Keep at most 10 rules that are based on >= 750 samples and have a class
# share of >= 90 %, parsed into [feature, operator, threshold] conditions.
first_dict = rule_extraction(df = rule_df, num = 10,  sample = 750, perc = 90)
first_dict
def add_feature_location(features, rule_dict):
  '''
  Attach to every condition the column position of its feature.

  features is the ordered list of transaction feature names; rule_dict maps
  a rule name to conditions of the form [feature, operator, threshold].
  Each condition becomes [(feature, position in features), operator,
  threshold]. Rules without any condition are left out of the result.
  '''
  located = {}
  for rule_name, rule_conditions in rule_dict.items():
    with_positions = [
        [(cond[0], features.index(cond[0])), cond[1], cond[2]]
        for cond in rule_conditions
    ]
    if with_positions:
      located[rule_name] = with_positions
  return located
# Resolve every rule feature to its column position in the transaction
# feature matrix (transaction_features is the column order returned by
# graph_maker).
final_dict = add_feature_location(transaction_features, first_dict)
final_dict

   index                                               rule   fraud   Perc  \
0   1714  if (OH4: high_risk <= 0.5) and (FR: Merchant S...  1Fraud  100.0   
1   1885  if (FR: Merchant State <= 0.546) and (FR: Merc...  1Fraud  100.0   
2   1984  if (User-Merchant CSTD > 0.0) and (OH1: Swipe ...  1Fraud  100.0   
3   2880  if (OH4: high_risk <= 0.5) and (Merchant CSTD ...  1Fraud  100.0   
4   5483  if (OH1: Swipe Transaction <= 0.5) and (FR: MC...  1Fraud  100.0   
5   5717  if (OH1: Online Transaction > 0.5) and (OH2: C...  1Fraud  100.0   
6   6291  if (Merchant CSTD7 > 0.018) and (FR: Merchant ...  1Fraud  100.0   
7   6966  if (User-Merchant CM > 0.013) and (OH1: Online...  1Fraud  100.0   
8   7762  if (OH4: US <= 0.5) and (Merchant CSTD <= 0.08...  1Fraud  100.0   
9   8459  if (FR: Zip <= 0.379) and (FR: MCC <= 0.05) an...  1Fraud  100.0   

   sample  
0    1040  
1     756  
2     837  
3     750  
4     806  
5    1051  
6     966  
7     996  
8     819  
9     832  


{'RULE1': [[('OH4: high_risk', 98), '<=', 0.5],
  [('FR: Merchant State', 93), '>', 0.004],
  [('OH2: Contracted Services', 109), '>', 0.5],
  [('FR: Merchant City', 92), '>', 0.003],
  [('FR: Zipcode', 103), '>', 0.001],
  [('OH4: world_non_us', 99), '<=', 0.5]],
 'RULE2': [[('FR: Merchant State', 93), '<=', 0.546],
  [('FR: Merchant City', 92), '>', 0.003],
  [('User-Merchant CSTD3', 29), '>', 0.0],
  [('FR: MCC', 107), '>', 0.001],
  [('Merchant CC', 49), '>', 0.092],
  [('OH4: Online', 96), '>', 0.5]],
 'RULE3': [[('User-Merchant CSTD', 27), '>', 0.0],
  [('OH1: Swipe Transaction', 91), '<=', 0.5],
  [('OH2: Contracted Services', 109), '>', 0.5]],
 'RULE4': [[('OH4: high_risk', 98), '<=', 0.5],
  [('Merchant CSTD', 51), '<=', 0.193],
  [('FR: Merchant City', 92), '>', 0.003],
  [('OH2: Contracted Services', 109), '>', 0.5],
  [('FR: Zip', 94), '>', 0.009],
  [('FR: Merchant State', 93), '>', 0.006],
  [('User CM3', 20), '<=', 0.041],
  [('FR: Zipcode', 103), '>', 0.0],
  [('OH4: wo

In [None]:
import pickle

# Persist the selected rules (rule name -> list of located conditions).
with open('{}Clause Storage/Knowledge_enhancements_large.pkl'.format(loc), 'wb') as fp:
    pickle.dump(final_dict, fp)
    print('dictionary saved successfully to file')



def filter_transactions(x_dict, conditions):
  '''
  Flag the transactions that satisfy every condition of a rule.

  x_dict is the 2-D transaction feature tensor (rows = transactions,
  columns = features). conditions is a list of
  [(feature name, column index), operator, value] entries, where operator is
  one of '>', '<', '>=', '<=', '=='.

  Returns a float32 tensor of shape (rows, 1) holding 1.0 where all
  conditions hold and 0.0 elsewhere. An empty condition list is satisfied by
  every row, so it yields all ones.

  Raises ValueError for an operator that is not listed above.
  '''
  comparisons = {
      '>': torch.gt,
      '<': torch.lt,
      '>=': torch.ge,
      '<=': torch.le,
      '==': torch.eq,
  }

  condition_masks = []
  for (_, column_index), operator, value in conditions:
    compare = comparisons.get(operator)
    if compare is None:
      raise ValueError(f"Invalid operator: {operator}")
    condition_masks.append(compare(x_dict[:, column_index], float(value)))

  if not condition_masks:
    # torch.stack cannot stack an empty list; "all of nothing" is true.
    return torch.ones((x_dict.shape[0], 1), dtype=torch.float32, device=x_dict.device)

  # One column per condition; a row qualifies when all its columns are True.
  all_conditions_met = torch.stack(condition_masks, dim=1).all(dim=1)
  # Convert with .to(); wrapping an existing tensor in torch.tensor() emits a
  # copy-construct UserWarning.
  return all_conditions_met.to(torch.float32).view(-1, 1)

# For every rule, compute a per-transaction indicator (1.0 = the transaction
# satisfies all conditions of the rule) on each of the three graphs.
# Result: rule name -> tensor with one row per transaction node.
condition_train = {}
condition_valid = {}
condition_test = {}
for k, v in final_dict.items():
  condition_train[k] = filter_transactions(train_data.x_dict['transaction'], v)
  condition_valid[k] = filter_transactions(valid_data.x_dict['transaction'], v)
  condition_test[k] = filter_transactions(test_data.x_dict['transaction'], v)



# Persist the rule indicators of the training graph.
with open('{}Clause Storage/train_KE_location_large.pkl'.format(loc), 'wb') as fp:
    pickle.dump(condition_train, fp)
    print('dictionary saved successfully to file')


# Persist the rule indicators of the validation graph.
with open('{}Clause Storage/valid_KE_location_large.pkl'.format(loc), 'wb') as fp:
    pickle.dump(condition_valid, fp)
    print('dictionary saved successfully to file')



# Persist the rule indicators of the test graph.
with open('{}Clause Storage/test_KE_location_large.pkl'.format(loc), 'wb') as fp:
    pickle.dump(condition_test, fp)
    print('dictionary saved successfully to file')




dictionary saved successfully to file


  return torch.tensor(all_conditions_met, dtype=torch.float32).view(-1, 1)


dictionary saved successfully to file
dictionary saved successfully to file
dictionary saved successfully to file
