In [None]:
# 05 – Visualizations & Exploratory Analysis
#
# Steps:
# - Visualise stock trends & correlations
# - Plot fraud class imbalance
# - Visualise fraud correlations
# - Show Random Forest feature importances
# - Analyse time-of-day fraud risk

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import joblib

from src.fraud_utils import get_feature_importances

# -------------------------------------------------------------------
# Config
# -------------------------------------------------------------------
# Paths are resolved relative to the working directory; the parent-dir hop
# assumes the notebook runs from a folder that sits next to data/ and results/.
DATA_DIR = os.path.join(os.pardir, "data")
RESULTS_DIR = os.path.join(os.pardir, "results")
processed_dir = os.path.join(DATA_DIR, "processed")  # rf_model.joblib is read from here

# One global seaborn theme for every figure drawn below.
sns.set(style="whitegrid")

# -------------------------------------------------------------------
# Load Cleaned Datasets
# -------------------------------------------------------------------
# Both inputs are the "cleaned_*" CSVs (presumably written by an earlier
# notebook); only the stocks file has a "date" column parsed to datetime.
stocks_path = os.path.join(DATA_DIR, "cleaned_stocks.csv")
fraud_path = os.path.join(DATA_DIR, "cleaned_fraud_dataset.csv")

stocks_clean = pd.read_csv(stocks_path, parse_dates=["date"])
fraud_clean = pd.read_csv(fraud_path)

# Quick sanity check: sizes first, then a peek at the first rows of each.
for label, frame in (("Stocks", stocks_clean), ("Fraud", fraud_clean)):
    print(f"{label} shape:", frame.shape)
for frame in (stocks_clean, fraud_clean):
    display(frame.head())

# -------------------------------------------------------------------
# Stock Price Trend Chart
# -------------------------------------------------------------------
# Closing-price history of whichever symbol appears first in the data.
symbol = stocks_clean["symbol"].unique()[0]
is_symbol = stocks_clean["symbol"] == symbol
# sort_values returns a new frame, so no explicit copy is needed.
symbol_df = stocks_clean.loc[is_symbol].sort_values("date")

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(symbol_df["date"], symbol_df["close"])
ax.set(
    title=f"{symbol} – Closing Price Over Time",
    xlabel="Date",
    ylabel="Price",
)
fig.tight_layout()
plt.show()

# -------------------------------------------------------------------
# Stock Return Correlation Heatmap
# -------------------------------------------------------------------
# Wide table: one row per date, one column per symbol (pivot_table averages
# duplicate date/symbol rows with its default aggfunc).
pivot_close = stocks_clean.pivot_table(
    index="date",
    columns="symbol",
    values="close",
)

# fill_method=None: a missing price gives a missing return instead of being
# forward-filled (which fakes a 0% return followed by a jump). Passing it
# explicitly also avoids the pandas >= 2.1 deprecation of the implicit fill.
# Only rows where every symbol is NaN (e.g. the first date) are dropped.
returns = pivot_close.pct_change(fill_method=None).dropna(how="all")

# DataFrame.corr excludes NaNs pairwise, so each pair of symbols uses every
# date they share. A row-wise dropna() here would discard a date for *all*
# symbols whenever *any* one of them lacked a price.
corr_matrix = returns.corr()

plt.figure(figsize=(8, 6))
sns.heatmap(corr_matrix, annot=True, fmt=".2f", cmap="coolwarm", square=True)
plt.title("Correlation of Daily Returns Across Symbols")
plt.tight_layout()
plt.show()

# -------------------------------------------------------------------
# Fraud Class Distribution
# -------------------------------------------------------------------
# Row count per is_fraud class, to visualise the class imbalance.
fraud_counts = fraud_clean["is_fraud"].value_counts().sort_index()

# Derive tick labels from the classes actually present instead of hard-coding
# two ticks. The fixed [0, 1] ticks mislabelled bars or added a tick with no
# bar when only one class was present or labels were not 0/1.
# Note: True/False and numpy ints hash equal to 1/0, so they map correctly.
class_names = {0: "Not Fraud", 1: "Fraud"}
tick_labels = [class_names.get(value, str(value)) for value in fraud_counts.index]

plt.figure(figsize=(5, 4))
sns.barplot(x=fraud_counts.index, y=fraud_counts.values)
# barplot draws categorical bars at positions 0..n-1 in sorted category order.
plt.xticks(range(len(fraud_counts)), tick_labels)
plt.ylabel("Count")
plt.title("Fraud vs Non-Fraud Transaction Counts")
plt.tight_layout()
plt.show()

print("\nFraud class counts:")
print(fraud_counts)

# -------------------------------------------------------------------
# Fraud Numerical Correlation Heatmap
# -------------------------------------------------------------------
# Pairwise correlation across every numeric column of the fraud data
# (is_fraud is included when it is stored as a number).
numeric_cols = list(fraud_clean.select_dtypes(include="number").columns)
corr = fraud_clean.loc[:, numeric_cols].corr()

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(corr, cmap="coolwarm", center=0, ax=ax)
ax.set_title("Correlation Heatmap – Fraud Dataset (Numeric Features)")
fig.tight_layout()
plt.show()

# -------------------------------------------------------------------
# Feature Importances – Random Forest
# -------------------------------------------------------------------
# Plot the most important features of the Random Forest saved by
# 04_fraud_detection_model, if that model file exists.
rf_model_path = os.path.join(processed_dir, "rf_model.joblib")
top_n = 20

if os.path.exists(rf_model_path):
    # joblib.load unpickles the file and can execute arbitrary code:
    # only load model files produced by this project.
    rf_model = joblib.load(rf_model_path)

    # scikit-learn records the training column names in `feature_names_in_`
    # when the model was fitted on a DataFrame with string column names;
    # use them so the chart is readable, else fall back to positional names.
    fitted_names = getattr(rf_model, "feature_names_in_", None)
    if fitted_names is not None:
        feature_names = [str(name) for name in fitted_names]
    else:
        feature_names = [f"feature_{i}" for i in range(rf_model.n_features_in_)]
    n_features = len(feature_names)

    # Assumes get_feature_importances returns a DataFrame with "feature" and
    # "importance" columns, sorted by importance descending – TODO confirm.
    importance_df = get_feature_importances(rf_model, feature_names).head(top_n)

    plt.figure(figsize=(8, 6))
    sns.barplot(
        data=importance_df,
        x="importance",
        y="feature",
    )
    # Report the real count: models with fewer than top_n features show fewer bars.
    plt.title(
        f"Top {len(importance_df)} Feature Importances – Random Forest Fraud Model"
    )
    plt.tight_layout()
    plt.show()

    print("\nTop feature importances:")
    display(importance_df)
else:
    print("\nRandom Forest model file not found at:", rf_model_path)
    print("Run 04_fraud_detection_model first to generate and save it.")

# -------------------------------------------------------------------
# Time-of-Day Fraud Risk
# -------------------------------------------------------------------
# Fraud rate per hour of day. This needs a genuine timestamp column; without
# one the analysis is skipped rather than inventing timestamps. A synthetic
# one-minute-spaced range would make "hour" a function of row order and the
# resulting chart meaningless.
df = fraud_clean.copy()

# First available timestamp column, in order of preference.
timestamp_col = next(
    (col for col in ("transaction_datetime", "transaction_date") if col in df.columns),
    None,
)

if timestamp_col is None:
    print("\nNo 'transaction_datetime' or 'transaction_date' column found;")
    print("skipping time-of-day fraud analysis.")
else:
    df["transaction_datetime"] = pd.to_datetime(df[timestamp_col])
    df["hour"] = df["transaction_datetime"].dt.hour

    if df["hour"].nunique() <= 1:
        # e.g. date-only values parse to midnight, putting every row in hour 0.
        print(
            f"\nNOTE: '{timestamp_col}' carries no time-of-day variation; "
            "the hourly fraud rate below is not informative."
        )

    # Mean of is_fraud per hour = share of fraudulent rows in that hour
    # (assumes is_fraud is encoded as 0/1).
    hourly = (
        df.groupby("hour")["is_fraud"]
        .mean()
        .rename("fraud_rate")
        .reset_index()
    )

    plt.figure(figsize=(8, 4))
    sns.lineplot(data=hourly, x="hour", y="fraud_rate", marker="o")
    plt.title("Fraud Rate by Hour of Day")
    plt.xlabel("Hour (0–23)")
    plt.ylabel("Fraud Rate")
    plt.tight_layout()
    plt.show()

    print("\nHourly fraud rate sample:")
    display(hourly.head())

print("\n=== 05_visualizations completed ===")
