In [1]:
import numpy as np
import pandas as pd 
%config InlineBackend.figure_format = 'retina'
%matplotlib widget

# Load the raw dairy dataset once; every later cell reads from `ds`.
ds = pd.read_csv(filepath_or_buffer="../data/raw/dairy_dataset.csv")
# Preview the first five rows as the cell's rich output.
ds.head(n=5)

Unnamed: 0,Location,Total Land Area (acres),Number of Cows,Farm Size,Date,Product ID,Product Name,Brand,Quantity (liters/kg),Price per Unit,...,Production Date,Expiration Date,Quantity Sold (liters/kg),Price per Unit (sold),Approx. Total Revenue(INR),Customer Location,Sales Channel,Quantity in Stock (liters/kg),Minimum Stock Threshold (liters/kg),Reorder Quantity (liters/kg)
0,Telangana,310.84,96,Medium,2022-02-17,5,Ice Cream,Dodla Dairy,222.4,85.72,...,2021-12-27,2022-01-21,7,82.24,575.68,Madhya Pradesh,Wholesale,215,19.55,64.03
1,Uttar Pradesh,19.19,44,Large,2021-12-01,1,Milk,Amul,687.48,42.61,...,2021-10-03,2021-10-25,558,39.24,21895.92,Kerala,Wholesale,129,43.17,181.1
2,Tamil Nadu,581.69,24,Medium,2022-02-28,4,Yogurt,Dodla Dairy,503.48,36.5,...,2022-01-14,2022-02-13,256,33.81,8655.36,Madhya Pradesh,Online,247,15.1,140.83
3,Telangana,908.0,89,Small,2019-06-09,3,Cheese,Britannia Industries,823.36,26.52,...,2019-05-15,2019-07-26,601,28.92,17380.92,Rajasthan,Online,222,74.5,57.68
4,Maharashtra,861.95,21,Medium,2020-12-14,8,Buttermilk,Mother Dairy,147.77,83.85,...,2020-10-17,2020-10-28,145,83.07,12045.15,Jharkhand,Retail,2,76.02,33.4


In [2]:
# Dropdown choices: every distinct product and brand present in the data, each
# array led by an "All" entry (the plotting code skips the filter for "All").
unique_products = np.concatenate((["All"], ds["Product Name"].unique()))
unique_brands = np.concatenate((["All"], ds["Brand"].unique()))

In [3]:
import ipywidgets as widgets
from ipywidgets import interact
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

# Create a function to plot the heatmap based on selected product
def plot_product_distribution(product_name="All", brand_name="All", metric="Count", normalize=False):
    """
    Plot a heatmap of how the selected product/brand is distributed from farm
    locations (rows) to customer locations (columns), then print summary statistics.

    Parameters:
    -----------
    product_name : str
        Value of the 'Product Name' column to keep, or "All" for no product filter
    brand_name : str
        Value of the 'Brand' column to keep, or "All" for no brand filter
    metric : str
        "Count" to count rows per (farm, customer) pair; otherwise the name of a
        numeric column of `ds` that is summed per pair (default: "Count")
    normalize : bool
        Whether to express each row (farm location) as a percentage of its row total

    Returns:
    --------
    tuple
        (fig, title) for the drawn figure, or (None, None) when the filters match no rows
    """
    # Boolean indexing returns a new frame and nothing below mutates the data,
    # so the module-level `ds` can be used directly without an up-front copy.
    ds_subset = ds
    if product_name != "All":
        ds_subset = ds_subset[ds_subset['Product Name'] == product_name]

    if brand_name != "All":
        ds_subset = ds_subset[ds_subset['Brand'] == brand_name]

    if ds_subset.empty:
        print(f"No data available for 'product_name = {product_name}' and 'brand = {brand_name}'")
        # Keep the two-element return shape so callers can always unpack the result
        return None, None

    # Farm location x customer location matrix: row counts, or the summed metric
    if metric == "Count":
        pivot_matrix = pd.crosstab(ds_subset["Location"], ds_subset["Customer Location"])
    else:
        pivot_matrix = pd.pivot_table(
            ds_subset,
            values=metric,
            index='Location',
            columns='Customer Location',
            aggfunc='sum',
            fill_value=0  # Replace NaN with 0
        )

    # Normalize by row if requested
    if normalize:
        row_sums = pivot_matrix.sum(axis=1)
        # A farm whose total is 0 would give 0/0 = NaN cells; show such rows as 0 % instead
        pivot_matrix = pivot_matrix.div(row_sums.where(row_sums != 0), axis=0).fillna(0) * 100
        cbar_label = f'{metric} (% of total)'
    else:
        cbar_label = metric

    # Annotation format: one decimal for percentages and small sums, integers for
    # counts, and no decimals once the largest cell reaches 1000
    if normalize:
        fmt = '.1f'
    elif metric == "Count":
        fmt = "d"
    else:
        fmt = '.1f' if pivot_matrix.max().max() < 1000 else '.0f'

    # Draw on an explicit Axes so the returned figure is exactly the one plotted on
    fig, ax = plt.subplots(figsize=(10, 8.33))
    sns.heatmap(
        pivot_matrix,
        cmap="YlGnBu",
        annot=True,
        fmt=fmt,
        cbar_kws={'label': cbar_label},
        annot_kws={'size': 8},
        ax=ax
    )

    title = f"Distribution Matrix of {brand_name}'s {product_name}"
    ax.set_title(title)
    ax.set_xlabel('Customer Location')
    ax.set_ylabel('Farm Location')
    fig.tight_layout()
    plt.show()

    # Summary statistics only make sense for a summed column, not for row counts
    if metric != "Count":
        print(f"\nStatistics for {brand_name}'s {product_name}:")
        # Label the totals with the metric actually selected (it may be revenue, not a quantity)
        print(f"Total {metric}: {ds_subset[metric].sum():.1f}")
        print(f"Average {metric} per transaction: {ds_subset[metric].mean():.1f}")
        print(f"Number of transactions: {len(ds_subset)}")

        # Top farm locations for this selection
        top_farms = ds_subset.groupby('Location')[metric].sum().sort_values(ascending=False).head(3)
        print(f"\nTop 3 farm locations by {metric}:")
        for loc, val in top_farms.items():
            print(f"- {loc}: {val:.1f}")

        # Top customer locations for this selection
        top_customers = ds_subset.groupby('Customer Location')[metric].sum().sort_values(ascending=False).head(3)
        print(f"\nTop 3 customer locations by {metric}:")
        for loc, val in top_customers.items():
            print(f"- {loc}: {val:.1f}")

    # Return the figure and title for saving later
    return fig, title
            
# Function to save the figure
def save_figure(fig, title):
    """
    Save `fig` as a timestamped PNG under ../data/media/ and print the path used.

    Parameters:
    -----------
    fig : object with a `savefig` method, or None
        The figure to write; when None, only a hint message is printed
    title : str
        Used to build the file name (spaces become '_', colons become '-')

    Returns:
    --------
    None
    """
    from pathlib import Path  # local import: only this function needs pathlib

    if fig is None:
        print("No figure to save. Please generate a plot first.")
        return

    media_dir = "../data/media"
    # savefig does not create missing folders, so make sure the target exists first
    Path(media_dir).mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Create a filename based on the title
    filename = f"{media_dir}/{title.replace(' ', '_').replace(':', '-')}_{timestamp}.png"
    fig.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Figure saved as {filename}")

# Create interactive widgets
def interactive_heatmap():
    """
    Build and display the widget UI for the distribution heatmap.

    The UI has product, brand and metric dropdowns, a row-normalization checkbox,
    a "Save Figure" button, a status area for save messages, and the plot output.
    Changing any control redraws the heatmap via `plot_product_distribution`;
    the button passes the most recently drawn figure to `save_figure`.
    """
    # Create dropdown for product selection
    product_dropdown = widgets.Dropdown(
        options=sorted(unique_products),
        description='Product:',
        value=unique_products[0],
        style={'description_width': 'initial'}
    )

    # Create dropdown for brand selection
    brand_dropdown = widgets.Dropdown(
        options=sorted(unique_brands),
        description='Brand:',
        value=unique_brands[0],
        style={'description_width': 'initial'}
    )

    # Create dropdown for metric selection
    metric_dropdown = widgets.Dropdown(
        options=['Count', 'Quantity (liters/kg)', 'Quantity Sold (liters/kg)', 'Approx. Total Revenue(INR)'],
        description='Metric:',
        value='Quantity (liters/kg)',
        style={'description_width': 'initial'}
    )

    # Create checkbox for normalization
    normalize_checkbox = widgets.Checkbox(
        value=False,
        description='Normalize by farm location (%)',
        style={'description_width': 'initial'}
    )

    # Create a save button
    save_button = widgets.Button(
        description='Save Figure',
        button_style='success',
        icon='save'
    )

    # Text printed from a widget callback is not reliably shown under the cell
    # (JupyterLab sends it to the log console), so save messages get their own area
    status_output = widgets.Output()

    # Most recently drawn figure and its title; (None, None) when nothing is plotted
    current_fig, current_title = None, None

    # Define the function that will run when the plot is updated
    def update_plot(product_name, brand_name, metric, normalize):
        nonlocal current_fig, current_title
        # Close the figure being replaced; otherwise each control change leaves
        # one more open pyplot figure behind
        if current_fig is not None:
            plt.close(current_fig)
        current_fig, current_title = None, None

        result = plot_product_distribution(product_name, brand_name, metric, normalize)
        # Nothing is drawn when the filters match no rows; keep the state cleared
        # so the save button reports that there is no figure
        if result is not None:
            current_fig, current_title = result

    # Define the function that will run when the save button is clicked
    def on_save_button_clicked(b):
        status_output.clear_output(wait=True)
        with status_output:
            save_figure(current_fig, current_title)

    # Attach the click handler to the save button
    save_button.on_click(on_save_button_clicked)

    # Use interactive_output to redraw the plot whenever a control changes
    out = widgets.interactive_output(
        update_plot,
        {
            'product_name': product_dropdown,
            'brand_name': brand_dropdown,
            'metric': metric_dropdown,
            'normalize': normalize_checkbox
        }
    )

    # Display the widgets and output
    display(widgets.VBox([
        widgets.HBox([product_dropdown, brand_dropdown]),
        widgets.HBox([metric_dropdown, normalize_checkbox]),
        save_button,
        status_output,
        out
    ]))

# Print the heading and usage hint, then build and show the widget UI.
print(
    "Dairy Product Distribution Analysis",
    "==================================",
    "Select a product from the dropdown to visualize its distribution from farm to customer locations.",
    sep="\n",
)
interactive_heatmap()

Dairy Product Distribution Analysis
Select a product from the dropdown to visualize its distribution from farm to customer locations.


VBox(children=(HBox(children=(Dropdown(description='Product:', options=('All', 'Butter', 'Buttermilk', 'Cheese…