In [None]:
# %% [markdown]
# # Customer Segmentation Project
# ## Notebook 01: Data Cleaning
#
# This notebook loads the raw transaction data, checks its quality, cleans it, and runs a first round of exploratory analysis in preparation for customer segmentation.

In [None]:
# %% [markdown]
# ### 1. Import Libraries

In [None]:
# %%
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Import project modules
import sys
sys.path.append('../src')
from utils import load_data, clean_data, plot_distribution, save_plot

In [None]:
# Pandas display settings: no limit on the number of columns shown, and a wide
# line width so previews of the raw frame are not wrapped.
pd.set_option(
    'display.max_columns', None,
    'display.width', 1000,
)
# Matplotlib >= 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*';
# older releases only know 'seaborn-darkgrid'. Use whichever name this install
# provides (falling back to matplotlib's default) so setup does not fail with
# an OSError on an unexpected matplotlib version.
PLOT_STYLE = next(
    (name for name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid')
     if name in plt.style.available),
    'default',
)
plt.style.use(PLOT_STYLE)

In [None]:
# %% [markdown]
# ### 2. Load Raw Data

In [None]:
# %%
# Raw transaction export, resolved relative to the kernel's working directory.
RAW_DATA_PATH = '../raw_data.csv'
df_raw = load_data(RAW_DATA_PATH)

In [None]:
# Summarise the raw frame: its shape, column names, dtypes, and a short preview.
print(
    "\n=== Dataset Information ===",
    f"Shape: {df_raw.shape}",
    f"\nColumns: {list(df_raw.columns)}",
    f"\nData Types:\n{df_raw.dtypes}",
    "\nFirst 5 rows:",
    sep="\n",
)
display(df_raw.head())

In [None]:
# %% [markdown]
# ### 3. Explore Data Structure

In [None]:
# %%
print("=== Basic Statistics ===")
print(f"\nNumber of unique customers: {df_raw['customer_id'].nunique()}")
# NOTE(review): load_data may leave transaction_date as text (only df_clean is
# later used with `.dt`). min()/max() on strings is lexicographic, not
# chronological (e.g. '12/01/2023' sorts before '2/01/2023'), so parse first.
# Unparseable values become NaT and are skipped by min()/max(); if the column is
# already datetime this is a no-op.
raw_transaction_dates = pd.to_datetime(df_raw['transaction_date'], errors='coerce')
print(f"Date range: {raw_transaction_dates.min()} to {raw_transaction_dates.max()}")
print(f"Number of unique product categories: {df_raw['product_category'].nunique()}")
print(f"Product categories: {df_raw['product_category'].unique().tolist()}")
print(f"Sales channels: {df_raw['channel'].unique().tolist()}")

In [None]:
# %% [markdown]
# ### 4. Check Data Quality

In [None]:
# %%
# Section banner for the quality checks printed by the following cells.
quality_banner = "=== Data Quality Check ==="
print(quality_banner)

In [None]:
# Missing values: count nulls per column and list only the columns that have any.
missing = df_raw.isna().sum()
print("\nMissing values per column:")
if missing.sum() > 0:
    print(missing[missing > 0])
else:
    print("No missing values found")

In [None]:
# Duplicate rows: rows that exactly repeat an earlier row (the first occurrence
# of each repeated row is not counted).
is_duplicate_row = df_raw.duplicated()
duplicates = is_duplicate_row.sum()
print("\nDuplicate rows: {}".format(duplicates))

In [None]:
# Invalid ranges: print min/max so implausible extremes stand out.
order_values = df_raw['order_value']
quantities = df_raw['quantity']
print(f"\nTransaction value range: ${order_values.min():.2f} to ${order_values.max():.2f}")
print(f"Quantity range: {quantities.min()} to {quantities.max()}")

In [None]:
# Negative or zero values: flag rows whose order value or quantity is <= 0.
# NOTE(review): confirm returns/refunds are not intentionally encoded this way.
negative_values = df_raw[df_raw['order_value'] <= 0]
print(f"\nRows with non-positive order value: {len(negative_values)}")
# Quantity is equally implausible when <= 0 (its range is printed above but was
# never checked).
non_positive_quantities = df_raw[df_raw['quantity'] <= 0]
print(f"Rows with non-positive quantity: {len(non_positive_quantities)}")

In [None]:
# %% [markdown]
# ### 5. Data Cleaning

In [None]:
# %%
df_clean = clean_data(df_raw)

# clean_data is imported from ../src/utils, so its effect is invisible here.
# Report how many rows it kept so the reader can see what cleaning removed.
rows_removed = len(df_raw) - len(df_clean)
print(f"Rows before cleaning: {len(df_raw)}")
print(f"Rows after cleaning:  {len(df_clean)} ({rows_removed} removed)")
# Every plot below needs data; fail here with a clear message rather than later.
assert len(df_clean) > 0, "clean_data returned no rows"

In [None]:
# Show the column dtypes after cleaning, so type conversions can be checked.
print("\n=== Cleaned Data Types ===", df_clean.dtypes, sep="\n")

In [None]:
# %% [markdown]
# ### 6. Exploratory Data Analysis

In [None]:
# %%
# Histogram of order values (30 bins) via plot_distribution from ../src/utils;
# the return value is kept as fig1 (presumably a matplotlib Figure).
fig1 = plot_distribution(df_clean, 'order_value', 'Distribution of Order Values',
                         bins=30, figsize=(10, 6))
plt.show()

In [None]:
# Histogram of item quantities (20 bins) via plot_distribution from ../src/utils;
# the return value is kept as fig2 (presumably a matplotlib Figure).
fig2 = plot_distribution(df_clean, 'quantity', 'Distribution of Quantities',
                         bins=20, figsize=(10, 6))
plt.show()

In [None]:
# Monthly transaction volume. value_counts() only returns months that occur in
# the data, so a month with no transactions would silently vanish from the time
# axis. Reindex over the full month range so gaps show as zero-height bars.
monthly_counts = (
    df_clean['transaction_date']
    .dt.to_period('M')
    .value_counts()
    .sort_index()
)
if not monthly_counts.empty:  # period_range needs a real start/end
    all_months = pd.period_range(
        monthly_counts.index.min(), monthly_counts.index.max(), freq='M'
    )
    monthly_counts = monthly_counts.reindex(all_months, fill_value=0)

fig_monthly, ax_monthly = plt.subplots(figsize=(12, 6))
monthly_counts.plot(kind='bar', ax=ax_monthly)
ax_monthly.set_title('Transactions Over Time (Monthly)', fontsize=14, fontweight='bold')
ax_monthly.set_xlabel('Month')
ax_monthly.set_ylabel('Number of Transactions')
ax_monthly.tick_params(axis='x', labelrotation=45)
fig_monthly.tight_layout()
plt.show()

In [None]:
# Per-category transaction count, total revenue and average order value,
# ordered by total revenue (largest first).
category_stats = df_clean.groupby('product_category').agg({
    'order_value': ['count', 'sum', 'mean']
}).round(2)
category_stats.columns = ['Transaction Count', 'Total Revenue', 'Average Order Value']
category_stats = category_stats.sort_values('Total Revenue', ascending=False)

fig_category, ax = plt.subplots(figsize=(10, 6))
category_stats['Total Revenue'].plot(kind='bar', color='skyblue', edgecolor='black', ax=ax)
ax.set_title('Revenue by Product Category', fontsize=14, fontweight='bold')
ax.set_xlabel('Product Category')
ax.set_ylabel('Total Revenue ($)')
ax.tick_params(axis='x', labelrotation=45)

# Label each bar with its revenue. bar_label pads in points, so placement no
# longer depends on the revenue scale (a fixed +100 data-unit offset was
# invisible on large totals and floated far above small ones). The y margin
# leaves headroom so the tallest label does not collide with the title.
ax.margins(y=0.1)
ax.bar_label(
    ax.containers[0],
    labels=[f'${v:,.0f}' for v in category_stats['Total Revenue']],
    padding=3,
    fontsize=9,
)
# Lay out after all labels exist so they are taken into account.
fig_category.tight_layout()
plt.show()

# Count and average were computed above but never shown; display the full table.
display(category_stats)

In [None]:
# Per-channel transaction count, total revenue and average order value.
channel_stats = df_clean.groupby('channel').agg({
    'order_value': ['count', 'sum', 'mean']
}).round(2)
channel_stats.columns = ['Transaction Count', 'Total Revenue', 'Average Order Value']
# Sort by the metric actually plotted (transaction count) so the bars descend.
# Sorting by revenue ordered the count bars by a metric the chart does not show.
channel_stats = channel_stats.sort_values('Transaction Count', ascending=False)

fig_channel, ax = plt.subplots(figsize=(8, 6))
channel_stats['Transaction Count'].plot(kind='bar', color='lightgreen', edgecolor='black', ax=ax)
ax.set_title('Transactions by Channel', fontsize=14, fontweight='bold')
ax.set_xlabel('Channel')
ax.set_ylabel('Number of Transactions')
ax.tick_params(axis='x', labelrotation=45)

# Label each bar with its count. bar_label pads in points, so placement works
# at any count scale (unlike a fixed +0.5 data-unit offset). The y margin leaves
# headroom for the tallest label.
ax.margins(y=0.1)
ax.bar_label(
    ax.containers[0],
    labels=[f'{int(v):,}' for v in channel_stats['Transaction Count']],
    padding=3,
    fontsize=10,
)
# Lay out after all labels exist so they are taken into account.
fig_channel.tight_layout()
plt.show()

# Revenue and average order value were computed but never shown; display them.
display(channel_stats)