# In [None]:
# Fraud Analysis Notebook
# Objective: Investigate cross-bank fraud connections involving a compromised device and user account
# Author: [Your Name]
# Date: April 28, 2025

import pandas as pd
import geoip2.database
from user_agents import parse
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import logging
import numpy as np

# Configure logging: timestamp, level and message at INFO and above
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Working directories and the files that live inside the data directory
DATA_DIR = Path("data")
NOTEBOOK_DIR = Path("notebooks")
OUTPUT_CSV = DATA_DIR.joinpath("enriched_data.csv")
GEOLITE_DB = DATA_DIR.joinpath("GeoLite2-City.mmdb")

# Create both directories up front; ones that already exist are left alone
for _directory in (DATA_DIR, NOTEBOOK_DIR):
    _directory.mkdir(exist_ok=True)

def load_data(file_path: str) -> pd.DataFrame:
    """Read an Excel workbook into a DataFrame.

    Failures are logged rather than raised: a missing file or any other
    read error produces an empty DataFrame, so callers check ``.empty``.
    """
    try:
        records = pd.read_excel(file_path)
    except FileNotFoundError:
        logger.error(f"File not found: {file_path}")
    except Exception as exc:
        logger.error(f"Error loading data: {exc}")
    else:
        logger.info(f"Loaded data from {file_path} with {len(records)} records")
        return records
    # Reached only when one of the handlers above ran
    return pd.DataFrame()

def enrich_ip_data(df: pd.DataFrame, ip_column: str) -> pd.DataFrame:
    """Enrich IP addresses with country, city, subnet, and timezone.

    Each cell of ``ip_column`` may hold a single IP or a comma-separated
    list. The first IP in the cell that the GeoLite2 database resolves
    supplies the values for that row.

    Args:
        df: Input records.
        ip_column: Name of the column holding the IP string(s).

    Returns:
        A new DataFrame with ``country``, ``city``, ``subnet`` and
        ``timezone`` columns appended. Rows whose IPs are missing or cannot
        be resolved get ``None`` in all four. If the GeoLite2 database file
        is missing, the columns are still added (all ``None``) so that later
        steps reading them do not fail with ``KeyError``.
    """
    geo_columns = ["country", "city", "subnet", "timezone"]

    if not Path(GEOLITE_DB).exists():
        logger.error(f"GeoLite2 database not found at {GEOLITE_DB}")
        # Keep the output schema stable, but never overwrite columns that
        # the caller's frame already contains.
        placeholders = {col: None for col in geo_columns if col not in df.columns}
        return df.assign(**placeholders)

    def get_ip_info(reader, ip_str) -> dict:
        """Return the geo fields for the first resolvable IP in ``ip_str``."""
        if pd.isna(ip_str):
            return dict.fromkeys(geo_columns)
        # Handle multiple IPs: try them in order, keep the first that resolves
        for ip in str(ip_str).split(","):
            ip = ip.strip()
            if not ip:
                # Blank token (e.g. trailing comma): nothing to look up
                continue
            try:
                response = reader.city(ip)
                network = response.traits.network
                return {
                    "country": response.country.name,
                    "city": response.city.name,
                    "subnet": str(network) if network else None,
                    "timezone": response.location.time_zone,
                }
            except Exception as e:
                # Best effort: log the failure and fall through to the next IP
                logger.warning(f"Failed to enrich IP {ip}: {e}")
        # Every IP in the cell failed
        return dict.fromkeys(geo_columns)

    logger.info("Enriching IP data...")
    # The context manager closes the database file even if a lookup raises
    with geoip2.database.Reader(GEOLITE_DB) as reader:
        records = [get_ip_info(reader, value) for value in df[ip_column]]
    # Build the frame once from plain dicts instead of one Series per row;
    # explicit columns keep the schema intact even for an empty input.
    ip_info = pd.DataFrame(records, index=df.index, columns=geo_columns)
    enriched_df = pd.concat([df, ip_info], axis=1)
    logger.info("IP enrichment completed")
    return enriched_df

def enrich_user_agent(df: pd.DataFrame, ua_column: str) -> pd.DataFrame:
    """Parse User-Agent strings for OS, browser, and device type.

    Args:
        df: Input records.
        ua_column: Name of the column holding the raw User-Agent string.

    Returns:
        A new DataFrame with ``os``, ``browser`` and ``is_mobile`` columns
        appended. Rows with a missing or unparsable User-Agent get ``None``
        in all three. If one of these names already exists in ``df`` (for
        example when the raw User-Agent column is itself called
        ``"browser"``), the parsed column is added with a ``ua_`` prefix
        instead, so the result never carries duplicate column labels.
    """
    ua_columns = ["os", "browser", "is_mobile"]

    def parse_ua(ua) -> dict:
        """Return the parsed fields for one User-Agent value."""
        if pd.isna(ua):
            return dict.fromkeys(ua_columns)
        try:
            ua_obj = parse(str(ua))
            return {
                "os": ua_obj.os.family,
                "browser": ua_obj.browser.family,
                "is_mobile": ua_obj.is_mobile,
            }
        except Exception as e:
            # Best effort: a bad string must not abort the whole enrichment
            logger.warning(f"Failed to parse User-Agent {ua}: {e}")
            return dict.fromkeys(ua_columns)

    logger.info("Enriching User-Agent data...")
    # Build the frame once from plain dicts instead of one Series per row
    ua_info = pd.DataFrame(
        [parse_ua(ua) for ua in df[ua_column]],
        index=df.index,
        columns=ua_columns,
    )

    # Resolve name clashes with existing columns: concat would otherwise
    # produce two columns with the same label.
    taken = set(df.columns)
    renamed = {}
    for col in ua_columns:
        new_name = col
        while new_name in taken:
            new_name = f"ua_{new_name}"
        taken.add(new_name)
        if new_name != col:
            renamed[col] = new_name
    if renamed:
        logger.warning(f"Columns already present, parsed values stored as: {renamed}")
        ua_info = ua_info.rename(columns=renamed)

    enriched_df = pd.concat([df, ua_info], axis=1)
    logger.info("User-Agent enrichment completed")
    return enriched_df

def find_related_accounts(df: pd.DataFrame, compromised_device: str, compromised_user: str) -> dict:
    """Find accounts linked to compromised device/user via device parameters or location.

    Args:
        df: Records with ``identity``, ``device_id``, ``device_fingerprint``,
            ``ips`` (one IP or a comma-separated list) and ``city`` columns.
        compromised_device: ``device_id`` value of the compromised device.
        compromised_user: ``identity`` value of the compromised account.

    Returns:
        Dict with the keys ``by_device_id``, ``by_device_fingerprint``,
        ``by_ip`` and ``by_location``; each value is a list of distinct
        identities. Missing identities and the ``"-"`` placeholder are
        excluded. The compromised account itself is not filtered out, so it
        appears in every list that its own rows match.
    """
    related = {
        "by_device_id": [],
        "by_device_fingerprint": [],
        "by_ip": [],
        "by_location": []
    }

    def accounts_where(mask) -> list:
        """Distinct real identities among the rows selected by ``mask``."""
        identities = df.loc[mask, "identity"].dropna().unique()
        return [acc for acc in identities if acc != "-"]

    def split_ips(cell) -> set:
        """Parse one ``ips`` cell into a set of individual address strings."""
        if pd.isna(cell):
            return set()
        tokens = (token.strip() for token in str(cell).split(","))
        # Drop blanks (e.g. trailing comma) and "-", which can never be a
        # valid address and would otherwise link unrelated rows.
        return {token for token in tokens if token and token != "-"}

    device_rows = df["device_id"] == compromised_device
    user_rows = df["identity"] == compromised_user

    # Accounts sharing the compromised device
    related["by_device_id"] = accounts_where(device_rows)
    logger.info(f"Found {len(related['by_device_id'])} accounts linked by device ID")

    # Accounts with same device fingerprint (first non-null fingerprint
    # recorded for the device, so a leading empty row does not hide it)
    fingerprints = df.loc[device_rows, "device_fingerprint"].dropna()
    if fingerprints.empty:
        logger.warning("No device fingerprint found for compromised device")
    else:
        same_fingerprint = df["device_fingerprint"] == fingerprints.iloc[0]
        related["by_device_fingerprint"] = accounts_where(same_fingerprint)
        logger.info(f"Found {len(related['by_device_fingerprint'])} accounts linked by device fingerprint")

    # Accounts sharing IPs. Addresses are compared as whole tokens: a regex
    # substring search would treat "." as a wildcard, match "1.2.3.4" inside
    # "11.2.3.45", and match every row when the user has no IPs at all.
    ip_sets = df["ips"].map(split_ips)
    compromised_ips = set().union(*ip_sets[user_rows])
    shares_ip = ip_sets.map(lambda ips: not compromised_ips.isdisjoint(ips)).astype(bool)
    related["by_ip"] = accounts_where(shares_ip)
    logger.info(f"Found {len(related['by_ip'])} accounts linked by IP")

    # Accounts in same city (first non-null city recorded for the user)
    user_cities = df.loc[user_rows, "city"]
    if user_cities.empty:
        logger.warning("No city data found for compromised user")
    elif user_cities.notna().any():
        compromised_city = user_cities.dropna().iloc[0]
        related["by_location"] = accounts_where(df["city"] == compromised_city)
        logger.info(f"Found {len(related['by_location'])} accounts linked by city")
    else:
        logger.warning("Compromised user's city is missing")

    return related

def create_network_visualization(
    df: pd.DataFrame,
    related_accounts: dict,
    *,
    compromised_device=None,
    compromised_user=None,
) -> None:
    """Create a network graph of accounts and devices.

    Devices and accounts become nodes; an edge joins a device and an account
    whenever they occur together in a row. The compromised device and
    account are drawn larger and in red. The figure is saved as
    ``network_graph.png`` in ``NOTEBOOK_DIR`` and then shown.

    Args:
        df: Records with ``device_id`` and ``identity`` columns.
        related_accounts: Accepted for interface compatibility; it is not
            used when drawing the graph.
        compromised_device: Device node to highlight. Defaults to the
            module-level ``COMPROMISED_DEVICE``.
        compromised_user: Account node to highlight. Defaults to the
            module-level ``COMPROMISED_USER``.
    """
    # Fall back to the module-level identifiers so existing callers that
    # pass only (df, related_accounts) behave exactly as before.
    if compromised_device is None:
        compromised_device = COMPROMISED_DEVICE
    if compromised_user is None:
        compromised_user = COMPROMISED_USER
    compromised_nodes = {compromised_device, compromised_user}

    G = nx.Graph()

    # Add nodes (filter out invalid identities)
    for device in df["device_id"].dropna().unique():
        G.add_node(device, type="device")
    identities = df["identity"].dropna()
    for account in identities[identities != "-"].unique():
        G.add_node(account, type="account")

    # Add edges: one per distinct (device, account) pair, kept in order of
    # first appearance instead of walking every row with iterrows().
    links = df[["device_id", "identity"]].dropna()
    links = links[links["identity"] != "-"].drop_duplicates()
    for device, account in zip(links["device_id"], links["identity"]):
        G.add_edge(device, account)

    # Highlight compromised nodes
    node_colors = ["red" if node in compromised_nodes else "blue" for node in G.nodes]
    node_sizes = [1000 if node in compromised_nodes else 500 for node in G.nodes]

    plt.figure(figsize=(12, 8))
    # Fixed seed keeps the layout reproducible between runs
    pos = nx.spring_layout(G, seed=42)
    nx.draw(
        G,
        pos,
        with_labels=True,
        node_color=node_colors,
        node_size=node_sizes,
        font_size=8,
        edge_color="gray"
    )
    plt.title("Network of Devices and User Accounts (Red: Compromised)")
    plt.savefig(NOTEBOOK_DIR / "network_graph.png", dpi=300, bbox_inches="tight")
    plt.show()
    logger.info("Network visualization saved")

def _plot_count_distribution(x, title, filename, *, data=None, order=None, rotate_labels=False) -> None:
    """Draw one count plot with the count written above each bar, save it, show it."""
    plt.figure(figsize=(10, 6))
    ax = sns.countplot(data=data, x=x, order=order)
    plt.title(title)
    if rotate_labels:
        plt.xticks(rotation=45)
    # Add counts above bars
    for p in ax.patches:
        height = p.get_height()
        if pd.isna(height):
            # int(NaN) raises ValueError; a bar without a height has no count
            continue
        ax.text(
            p.get_x() + p.get_width() / 2.,
            height + 0.5,
            f"{int(height)}",
            ha="center",
            va="bottom"
        )
    plt.savefig(NOTEBOOK_DIR / filename, dpi=300, bbox_inches="tight")
    plt.show()


def create_eda_visualizations(df: pd.DataFrame) -> None:
    """Generate exploratory data analysis visualizations with counts above bars.

    Saves three count plots into ``NOTEBOOK_DIR`` (records per bank, records
    per country, desktop vs. mobile) and logs the per-column missing-value
    counts and the city distribution.

    Args:
        df: Enriched records with ``bank``, ``country``, ``is_mobile`` and
            ``city`` columns.
    """
    # Bank distribution
    _plot_count_distribution(
        "bank",
        "Distribution of Records by Bank",
        "bank_distribution.png",
        data=df,
        rotate_labels=True,
    )

    # Country distribution
    _plot_count_distribution(
        "country",
        "Distribution of Records by Country",
        "country_distribution.png",
        data=df,
        rotate_labels=True,
    )

    # Device type distribution. Map the flag to its label before plotting and
    # fix the bar order: relabelling tick positions 0/1 afterwards puts
    # "Desktop" on whichever category happens to be drawn first.
    device_type = df["is_mobile"].map({False: "Desktop", True: "Mobile"})
    _plot_count_distribution(
        device_type,
        "Distribution of Device Types",
        "device_type_distribution.png",
        order=["Desktop", "Mobile"],
    )

    # Missing data analysis
    missing_data = df.isnull().sum()
    logger.info(f"Missing data:\n{missing_data}")

    # City data check (missing cities are counted as their own bucket)
    city_counts = df["city"].value_counts(dropna=False)
    logger.info(f"City distribution:\n{city_counts}")

# Main analysis
# Identifiers of the device and the account under investigation
COMPROMISED_DEVICE = "91b12379-8098-457f-a2ad-a94d767797c2"
COMPROMISED_USER = "0007f265568f1abc1da791e852877df2047b3af9"

# Step 1: Load data (load_data hands back an empty frame when reading fails)
df = load_data(DATA_DIR.joinpath("test.xlsx"))
if df.empty:
    logger.error("Failed to load data. Exiting.")
    raise SystemExit("Failed to load data.")

# Step 2: Enrich data, IP details from "ips" first, then User-Agent fields
# from "browser"
for enrich_step, source_column in ((enrich_ip_data, "ips"), (enrich_user_agent, "browser")):
    df = enrich_step(df, source_column)

# Step 3: Find related accounts and print one line per link type
related_accounts = find_related_accounts(df, COMPROMISED_DEVICE, COMPROMISED_USER)
print("Related Accounts:")
for link_type in related_accounts:
    print(f"{link_type}: {related_accounts[link_type]}")

# Step 4: Visualizations (network graph, then the EDA charts)
create_network_visualization(df, related_accounts)
create_eda_visualizations(df)

# Step 5: Save enriched data without the index column
df.to_csv(OUTPUT_CSV, index=False)
logger.info(f"Enriched data saved to {OUTPUT_CSV}")