In [None]:
# 156761 Python Scripts for Extraction of Data to Fit the NBD-Dirichlet Model
# (c) Malcolm Wright
# Use at your own risk

# IF YOU USE A DIFFERENT FILE, BE SURE TO UPDATE THE FILE NAME AND THE COLUMN NAMES USED IN THE SCRIPTS
# (Quarter, brand, household, basket, upc, units, dollar_sales)

In [None]:
# Import packages and set pandas display options (show at most 50 rows of a DataFrame)

import pandas as pd
from IPython.display import display
pd.set_option('display.max_rows', 50) 

In [None]:
# Load the transaction-level data set.
# Change the file name below if you are working with a different file.
input_file_name = 'dh_pasta.csv'
df = pd.read_csv(filepath_or_buffer=input_file_name)

In [None]:
# Inspect the dataframe.
# The bare expression below is the last line of the cell, so the notebook renders
# it as the cell's output (subject to the display.max_rows option set earlier).
df

In [None]:
# Explore the data: column dtypes and non-null counts, the number of fully
# duplicated rows, and the number of missing values in each column.
df.info()

# Rows that are exact copies of an earlier row (the first occurrence is not counted).
duplicates = df.duplicated(keep='first').sum()
# Per-column count of missing (NaN/None) cells.
missing = df.isna().sum()

print('\nDuplicates', duplicates)
print('\nMissing\n\n', missing)

In [None]:
def get_available_quarters(df):
    """Return the distinct values found in the ``Quarter`` column of ``df``.

    Nothing is printed here; the values are returned (in order of first
    appearance, as ``Series.unique`` gives them) so callers can show or
    validate against them.
    """
    quarter_column = df.loc[:, 'Quarter']
    return quarter_column.unique()

# Ask the user to choose one quarter, validating the choice against the data.
def select_quarter(df):
    """Prompt the user for a quarter and return it if it exists in ``df``.

    The available quarters are printed first. The reply is converted with
    ``int()``, so non-numeric input raises ``ValueError``. A number that is not
    among the available quarters also raises ``ValueError``.
    """
    choices = get_available_quarters(df)
    print(f"Available Quarters: {choices}")
    picked = int(input("Please select a Quarter: "))

    if picked in choices:
        return picked
    raise ValueError(f"Selected Quarter {picked} is not available.")

# Ask the user for the minimum number of distinct households a brand needs
# in order to be reported on its own rather than inside 'Other'.
def select_household_threshold():
    """Prompt for the 'Other' household threshold and return it as an int.

    Brands bought by fewer than this many distinct households are pooled into
    the 'Other' category. Non-integer input raises ``ValueError`` (from ``int()``).
    """
    reply = input("Please enter the household threshold below which brands will be aggregated into 'Other': ")
    return int(reply)

#这个函数的主要作用是根据用户选择的季度和家庭购买者阈值，计算每个品牌的各种统计数据，包括客户数量、篮子数量、库存代码数量、总数量和总收入。
#它还将小品牌归为“其他”类别，并计算所有品牌的总体统计数据。最终返回包含所有这些统计结果的DataFrame
def calculate_brand_statistics(df, selected_quarter, household_threshold):
    """Calculate statistics for each brand, including 'All' and 'Other' categories."""
    # Filter the DataFrame for the selected quarter
    df_quarter = df[df['Quarter'] == selected_quarter]

    # Determine which brands are small based on the household threshold
    brand_household_count = df_quarter.groupby('brand')['household'].nunique()
    small_brands = brand_household_count[brand_household_count < household_threshold].index.tolist()

    # Create an empty DataFrame to store results
    results_df = pd.DataFrame()

    # Process each brand, omitting the small brands
    brand = df_quarter['brand'].unique()

    for brands in brand: 
        if brands in small_brands:
            continue  # Skip processing for small brands
        
        print(f"Processing brand: {brands}")
        
        # Filter the data for the current brand
        df_brand = df_quarter[df_quarter['brand'] == brands]
        
        # Calculate the required statistics for the brand
        quarterly_totals = df_brand.groupby('Quarter').agg(
            Total_Num_Customers=('household', 'nunique'),
            Total_Num_Baskets=('basket', 'nunique'),
            Total_Num_Stockcodes=('upc', 'nunique'),
            Total_Quantity=('units', 'sum'),
            Total_Revenue=('dollar_sales', 'sum')
        ).reset_index()
        
        # Add a column for the brand name
        quarterly_totals['brand'] = brands
        
        # Append the results to the results_df DataFrame
        results_df = pd.concat([results_df, quarterly_totals], ignore_index=True)

    # Calculate overall category statistics (All brands combined)
    overall_totals = df_quarter.groupby('Quarter').agg(
        Total_Num_Customers=('household', 'nunique'),
        Total_Num_Baskets=('basket', 'nunique'),
        Total_Num_Stockcodes=('upc', 'nunique'),
        Total_Quantity=('units', 'sum'),
        Total_Revenue=('dollar_sales', 'sum')
    ).reset_index()

    # Add a column for the brand, labeling it as 'All'
    overall_totals['brand'] = 'All'

    # Append the overall category results to the results_df DataFrame
    results_df = pd.concat([results_df, overall_totals], ignore_index=True)

    # Aggregate the small brands into an 'Other' category
    df_other = df_quarter[df_quarter['brand'].isin(small_brands)]
    other_totals = df_other.groupby('Quarter').agg(
        Total_Num_Customers=('household', 'nunique'),
        Total_Num_Baskets=('basket', 'nunique'),
        Total_Num_Stockcodes=('upc', 'nunique'),
        Total_Quantity=('units', 'sum'),
        Total_Revenue=('dollar_sales', 'sum')
    ).reset_index()

    # Label the 'Other' category
    other_totals['brand'] = 'Other'

    # Append the 'Other' category to the results_df DataFrame
    results_df = pd.concat([results_df, other_totals], ignore_index=True)

    return results_df

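# The __main__ section further below asks the user for a quarter and a household threshold, computes the brand statistics, saves them to a CSV file, and displays the results for confirmation.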
def save_to_csv(df, filename):
    """Write ``df`` to ``filename`` as CSV (index column omitted) and report the path."""
    df.to_csv(path_or_buf=filename, index=False)
    confirmation = f"Results saved to {filename}"
    print(confirmation)

# Example of how to use these functions: the full interactive workflow.
if __name__ == "__main__":

    # Ask which quarter to analyse, and how many distinct households a brand
    # needs in that quarter to be reported separately rather than in 'Other'.
    selected_quarter = select_quarter(df)
    household_threshold = select_household_threshold()

    # Per-brand, 'All' and 'Other' statistics for the chosen quarter, written to disk.
    results_df = calculate_brand_statistics(df, selected_quarter, household_threshold)
    save_to_csv(results_df, 'brand_statistics_quarter_selected_with_overall_and_other.csv')

    # Distinct households across the WHOLE dataset (all quarters, all brands).
    # This is not restricted to the selected quarter; the 'All' row of
    # results_df holds the selected-quarter count.
    all_customers = df['household'].nunique()
    print("Total number of unique households buying: ", all_customers)

    # Show the results table for confirmation.
    display(results_df)