# Setup

In [None]:
import pickle
import matplotlib.pyplot as plt
import sys
import os
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import dataset
import torch
from pprint import pprint
import pandas as pd
import seaborn as sns
import sklearn
from  sklearn.manifold import TSNE
from  sklearn.decomposition import PCA
from  sklearn.preprocessing import KBinsDiscretizer
from  sklearn.preprocessing import MinMaxScaler
from  sklearn.preprocessing import RobustScaler
from  sklearn.preprocessing import StandardScaler
from  sklearn.preprocessing import QuantileTransformer
import numpy as np
import math

In [None]:
# Widen pandas' display for the whole notebook.
# The original bare `pd.option_context(...)` call only *built* a context manager
# that was never entered (no `with`), so no option was actually changed.
# `None` means "no limit"; it replaces the deprecated `-1` sentinel for
# `display.max_colwidth`.
pd.set_option("display.max_columns", None)
pd.set_option("display.max_colwidth", None)
# NOTE(review): `display.max_rows` is deliberately left at the pandas default.
# Several cells below display whole frames (raw_data, dataset.data, ...), and an
# unlimited row count would render every row. To see all rows of one frame, use
# `with pd.option_context("display.max_rows", None): display(frame)` locally.

In [3]:
# Dataset selector: the loading cell only has a branch for "kdd".
data_type = "kdd"
# Suffix interpolated into the pickle file name (KDDDataset_ft{name_suffix}.pkl).
name_suffix = ""

In [4]:
# Load the pickled dataset object selected by `data_type` / `name_suffix`.
# NOTE: this rebinds the name `dataset`, shadowing the `dataset` module imported
# at the top (presumably that import is still needed so pickle can resolve the
# dataset classes -- TODO confirm).
# SECURITY: pickle.load can execute arbitrary code; only open trusted files.
if data_type == "kdd":
    with open(f"./kdd/KDDDataset_ft{name_suffix}.pkl", "rb") as f:
        dataset = pickle.load(f)
else:
    # Fail here instead of later, when `dataset` would still be the imported
    # module and the first attribute access would raise a confusing error.
    raise ValueError(f"Unsupported data_type: {data_type!r} (expected 'kdd')")

# Full data

In [None]:
# FYI: account order when the dataset samples were prepared
# (i.e. the order in which groupby("account_id") yields its groups).
accounts = [account_name for account_name, _ in dataset.data.groupby("account_id")]
print(accounts)

In [None]:
# FYI: raw initial data, re-read from the CSV that the dataset object points to.
raw_csv_path = "./" + dataset.data_root[8:]  # skips the first 8 characters of data_root
raw_data = (
    pd.read_csv(raw_csv_path)
    # Build a real datetime from the Year / Month / Day columns ("19" is
    # prepended to Year, so Year is presumably two-digit -- TODO confirm).
    .assign(
        date_raw=lambda frame: pd.to_datetime(
            "19" + frame["Year"].astype(str) + "-" + frame["Month"].astype(str) + "-" + frame["Day"].astype(str),
            format="%Y-%m-%d",
        )
    )
    # pandas convention for dt.dayofweek: Monday == 0 ... Sunday == 6.
    .assign(weekday=lambda frame: frame["date_raw"].dt.dayofweek)
    .sort_values(by=["date_raw"])
)
raw_data

In [None]:
# Columns of raw_data: the CSV columns plus the derived date_raw / weekday.
raw_data.columns

In [None]:
# Frame stored on the pickled dataset object (used below as the preprocessed data).
dataset.data

In [None]:
preproc_data = dataset.data  # Do NOT sort again (keep original timestamp sort)
# Stack the per-account groups back together: accounts come out in groupby
# order, while the rows inside each account keep their existing order.
per_account_frames = [account_rows for _, account_rows in preproc_data.groupby("account_id")]
preprocessed_data = pd.concat(per_account_frames)
preprocessed_data

In [None]:
# Sizes to compare: number of samples, number of targets, and number of
# distinct account ids in the preprocessed frame.
samples = dataset.samples
targets = dataset.targets
for collection in (samples, targets):
    print(len(collection))

distinct_account_ids = preprocessed_data["account_id"].drop_duplicates()
print(len(distinct_account_ids))

In [None]:
vocab_keys = list(dataset.vocab.token2id.keys())
# For every key of token2id, print a header line followed by the entry's size.
for vocab_key in vocab_keys:
    vocab_entry = dataset.vocab.token2id[vocab_key]
    print(f"\n--{vocab_key}--")
    # pprint(vocab_entry)  # uncomment to dump the whole entry
    print(len(vocab_entry))

In [None]:
# Inspect the vocabulary entry stored under the "SPECIAL" key.
dataset.vocab.token2id["SPECIAL"]

# Sample

In [None]:
# Distinct account ids present in the raw data, in ascending order.
raw_data["account_id"].drop_duplicates().sort_values()

In [15]:
sample_id = 1
stride = 150
final_id = sample_id

# Offset added to the sample index to obtain the account_id to look up:
#   - pre-training dataset: 1
#   - fine-tuning dataset : depends on the account_ids included in the dataset
#     (cf. the sorted account ids in the cell above); 18 here.
# The original cell computed the pre-training slices and then unconditionally
# overwrote them with the fine-tuning ones; the offset is now chosen once.
is_finetuning_dataset = True
offset = 18 if is_finetuning_dataset else 1

account_id = final_id + offset
# Keep at most `stride` rows of the selected account from each frame.
raw_sample = raw_data[raw_data["account_id"] == account_id][:stride]
preprocessed_sample = preprocessed_data[preprocessed_data["account_id"] == account_id][:stride]

# One row per group of `dataset.ncols` values; the number of rows varies per sample.
pytorch_sample = torch.tensor(samples[sample_id]).reshape(-1, dataset.ncols)  # Not always the same seq_len
pytorch_target = torch.tensor(targets[sample_id])

In [None]:
# First `stride` raw rows of the selected account.
raw_sample

In [None]:
# First `stride` preprocessed rows of the selected account.
preprocessed_sample

In [None]:
# Stored sample for `sample_id`, reshaped to (-1, dataset.ncols).
pytorch_sample

In [None]:
# Target stored for the same `sample_id`.
pytorch_target

In [20]:
# Note: Need padding
# dataset.samples[0].shape
# dataset.samples[758].shape

# Other note:
# The notebook is very slow. Is it because the pickled dataset is heavy?

# Explore (fine tuning dataset)

In [None]:
# If fine-tuning dataset: check class imbalance.
# Account-level counts per status, from the raw and the preprocessed frames.
raw_account_status = raw_data[["account_id", "status"]].drop_duplicates()
print(raw_account_status.groupby("status").count())
print(preprocessed_data[["account_id", "status"]].drop_duplicates().groupby("status").count())
# Sample-level counts per target value.
print(torch.unique(torch.tensor(dataset.targets), return_counts=True))
# Share of (account_id, status) pairs with status == 1, computed from the data.
# The original printed a hard-coded 76/(606+76), presumably copied from the
# counts above -- TODO confirm those were the raw account-level counts.
print((raw_account_status["status"] == 1).mean())

In [None]:
# Number of transactions and evolution of fraud over the years

# Zero-pad the month so the "year_month" labels sort chronologically as strings.
# sns.lineplot sorts its x values, so unpadded labels would be laid out as
# "93-1", "93-10", "93-11", "93-12", "93-2", ... instead of in calendar order.
# `.assign` returns a new frame, so raw_data itself is left untouched.
raw_data_ym = raw_data.assign(
    year_month=raw_data["Year"].astype(str) + "-" + raw_data["Month"].astype(str).str.zfill(2)
)
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(12, 6))
sns.countplot(
    ax=ax1,
    x='year_month',
    data=raw_data_ym,
)
sns.lineplot(
    ax=ax2,
    x='year_month',
    y='status',
    data=raw_data_ym,
    estimator="mean",  # Aggregate the frauds by averaging them
)

ax1.set_title("Number of transactions per month")
ax2.set_title("Mean status per month")
ax1.tick_params(axis='x', labelrotation=70)
ax2.tick_params(axis='x', labelrotation=70)
plt.tight_layout()

# Take aways:
# 2x more transactions in 1996 and 1997
# Always more transactions in January
# Fraud slowly decreases over time
# No obvious seasonality

In [None]:
# Evolution of fraud by day of the month

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(12, 3))
sns.lineplot(
    ax=ax,
    x='Day',
    y='status',
    estimator="mean",  # average `status` over all rows sharing the same Day
    data=raw_data,
)
# No ax.legend() here: the plot has no hue, so there are no labelled artists;
# the original call only produced a matplotlib warning and an empty legend.
ax.set_title("Mean status by day of the month")
ax.tick_params(axis='x', labelrotation=70)
plt.tight_layout()

# Take aways:
# - day of the month matters
# - more fraud in the second half of the month

In [None]:
# Evolution of fraud by weekday

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(12, 3))
# Only `x` is given, so the bars show the mean of `status`, split by weekday
# through the hue colours.
sns.barplot(
    ax=ax,
    x='status',
    hue='weekday',
    data=raw_data,
    estimator="mean",
)
# Rebuilding the legend drops the title seaborn sets for the hue variable, so
# set it explicitly. `weekday` comes from dt.dayofweek, i.e. 0 = Monday.
ax.legend(loc='upper left', ncol=7, title="weekday (0 = Monday)")
ax.set_title("Mean status by weekday")
ax.tick_params(axis='x', labelrotation=0)
plt.tight_layout()

# Take aways:
# - day of the week matters a little bit

In [None]:
# Numerical variables: transaction amount vs. balance, one panel per status value
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))

# Same plot for both classes; only the row filter and the target axes differ.
for panel_ax, status_value in ((ax1, 0), (ax2, 1)):
    sns.scatterplot(
        ax=panel_ax,
        data=raw_data[raw_data["status"] == status_value],
        x='balance',
        y="amount_trans",
        hue="status",
        alpha=0.25,
        # NOTE(review): in seaborn `size` is a semantic (grouping) variable, not
        # the marker area; use `s=` if a fixed marker size was intended. Kept
        # as-is so the rendering does not change.
        size=1,
    )
    panel_ax.set_title(f"status == {status_value}")
plt.tight_layout()

# Take away:
# - Numerical features alone do not seem indicative of Fraud
# - transaction amount correlates positively with account balance (not surprising)

In [None]:
# Categorical variable data exploration: row counts per category, split by status
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(12, 4))

for panel_ax, column in zip((ax1, ax2, ax3), ('k_symbol', 'operation', 'type_trans')):
    sns.countplot(
        ax=panel_ax,
        x=column,
        hue='status',
        data=raw_data,
    )

# Only three variables are plotted on the 2x2 grid: hide the unused fourth
# axes instead of rendering an empty frame.
ax4.set_visible(False)
plt.tight_layout()

# Take-away:
# More type_trans == 2 for Frauds