# Quantity Discount (QD) Handler Module

Handles calculation, creation, activation, and deactivation of Quantity Discounts.

## Usage
This module is called from `module_3_periodic_actions.ipynb` with a DataFrame.

```python
%run qd_handler.ipynb

# Process QD with DataFrame from Module 3
result = process_qd(df_qd, dry_run=True)
```

## Input Requirements (DataFrame columns from Module 3)
**Identifiers:**
- `product_id`, `warehouse_id`, `cohort_id`, `sku`, `brand`, `cat`

**Pricing Data:**
- `wac_p` - Weighted Average Cost (per basic unit)
- `current_price` - Current price (per basic unit)
- `target_margin`, `min_boundary`
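Module 3 supplies `wac_p` and `current_price` per basic unit, while the tier-price calculation reads `packing_unit_price` and `wac_pu`. The sketch below shows one way to bridge the two. It assumes both values scale linearly with the `basic_unit_count` of the top-selling packing unit, and that `new_price`, when present, is also per basic unit. `to_packing_unit_prices` is a hypothetical helper, not part of this module.

```python
import pandas as pd


def to_packing_unit_prices(df: pd.DataFrame) -> pd.DataFrame:
    """Add per-packing-unit price and WAC columns (assumed linear scaling)."""
    out = df.copy()
    # Prefer Module 3's new_price when present, else fall back to current_price
    if "new_price" in out.columns:
        base_price = out["new_price"].fillna(out["current_price"])
    else:
        base_price = out["current_price"]
    # Assumption: one packing unit = basic_unit_count basic units
    out["packing_unit_price"] = base_price * out["basic_unit_count"]
    out["wac_pu"] = out["wac_p"] * out["basic_unit_count"]
    return out


sample = pd.DataFrame({
    "current_price": [10.0, 2.5],
    "new_price": [11.0, None],
    "wac_p": [9.0, 2.0],
    "basic_unit_count": [12, 24],
})
print(to_packing_unit_prices(sample)[["packing_unit_price", "wac_pu"]])
```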

**Cart Rules:**
- `current_cart_rule`, `new_cart_rule`

**Market Margins (will be converted to prices):**
- `below_market`, `market_min`, `market_25`, `market_50`, `market_75`, `market_max`, `above_market`

**Margin Tiers (will be converted to prices):**
- `margin_tier_1` … `margin_tier_5`, `margin_tier_above_1`, `margin_tier_above_2`
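Each margin column becomes a candidate price via `price = wac / (1 - margin)`, i.e. the price at which `(price - wac) / price` equals the margin. Only candidates inside the discount window (between 0.35% and 5% below the current price) are kept, deduplicated after rounding. A simplified sketch of that filtering, using hypothetical helper names:

```python
MAX_DISCOUNT_PCT = 5.0   # deepest allowed discount vs. current price
MIN_DISCOUNT_PCT = 0.35  # shallowest allowed discount vs. current price


def margin_to_price(wac: float, margin: float) -> float:
    """Price at which (price - wac) / price == margin."""
    return wac / (1 - margin)


def candidate_prices(wac: float, current_price: float, margins) -> list:
    """Distinct in-window candidate prices, highest first."""
    lowest = current_price * (1 - MAX_DISCOUNT_PCT / 100)
    highest = current_price * (1 - MIN_DISCOUNT_PCT / 100)
    prices = set()
    for m in margins:
        # Skip None and out-of-range margins (NaN fails both comparisons)
        if m is not None and 0 < m < 1:
            price = margin_to_price(wac, m)
            if lowest <= price <= highest:
                prices.add(round(price, 2))
    return sorted(prices, reverse=True)


print(candidate_prices(90, 100, [0.05, 0.06, 0.08, 0.09, 0.10, None]))
# → [98.9, 97.83, 95.74]
```

With WAC 90 and a current price of 100, the 5% margin (94.74) falls below the 5% discount floor and the 10% margin (100.00) sits above the 0.35% ceiling, so both are dropped.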

**QD Configuration:**
- `keep_qd_tiers` - List of tiers to keep, e.g., `['T1', 'T2']`
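A minimal pre-flight check for these inputs might look like the following. `REQUIRED_COLUMNS` covers only the scalar columns listed above, and `parse_keep_tiers` assumes `keep_qd_tiers` may arrive either as a list or as its string form (for example after a CSV round trip). Both helpers are hypothetical, and treating a missing value as "no tiers" is an assumption.

```python
import ast

import pandas as pd

# Scalar columns listed above (market-margin and margin-tier columns omitted for brevity)
REQUIRED_COLUMNS = [
    "product_id", "warehouse_id", "cohort_id", "sku", "brand", "cat",
    "wac_p", "current_price", "target_margin", "min_boundary",
    "current_cart_rule", "new_cart_rule", "keep_qd_tiers",
]


def missing_columns(df: pd.DataFrame) -> list:
    """Required columns that are absent from df."""
    return [c for c in REQUIRED_COLUMNS if c not in df.columns]


def parse_keep_tiers(value) -> list:
    """Normalize keep_qd_tiers to a list such as ['T1', 'T2']."""
    if isinstance(value, (list, tuple)):
        return list(value)
    if isinstance(value, str) and value.strip():
        # Safely evaluate a literal like "['T1', 'T2']"
        return list(ast.literal_eval(value))
    # None, NaN, or empty string: treated as an empty list here (assumption)
    return []


print(parse_keep_tiers("['T1', 'T2']"))
```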

## Workflow
1. **Deactivate ALL existing Quantity Discounts (FIRST!)**
2. Get top-selling packing unit per product per warehouse (last 90 days)
3. Get warehouse ticket statistics for wholesale calculations
4. Calculate tier quantities (T1, T2) using historical order data
5. Calculate T1 & T2 prices using market margins and margin tiers
6. Calculate T3 (wholesale) prices based on delivery cost savings
7. Validate T3 constraints (T3 qty > T2 qty, T3 price < T2 price)
   - 7.5: Adjust T2 qty if the quantity ratio is too low (check_qty < 1.3 and elasticity_ratio > 3)
8. Apply keep_qd_tiers filter from Module 3 & calculate tier flags
9. Select top 400 **tier entries** per warehouse (sorted by mtd_qty × effective_price)
10. Build QD configurations & upload format (Group 1: T1+WS, Group 2: T2)
11. Upload QD file to API
12. Update cart rules (cart_rule >= max tier quantity)
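Steps 5 and 7.5 hinge on the elasticity ratio `(T2 discount / T1 discount) / (T2 qty / T1 qty)`, which is kept between `MIN_RATIO` (1.1) and `MAX_RATIO` (3.0). The sketch below is a simplified, standalone restatement of the T2 price adjustment performed in `calculate_tier_prices`. `adjust_tier_2` is a hypothetical name, and it returns `None` where that function gives up on the row.

```python
MIN_RATIO = 1.1  # lower bound on the elasticity ratio
MAX_RATIO = 3.0  # upper bound on the elasticity ratio


def adjust_tier_2(current_price, wac, tier_1, tier_2, qty_1, qty_2):
    """Return a T2 price whose elasticity ratio lies in [MIN_RATIO, MAX_RATIO], or None."""
    d1 = current_price - tier_1  # T1 discount per unit
    d2 = current_price - tier_2  # T2 discount per unit
    if d1 <= 0 or qty_1 <= 0 or qty_2 <= 0:
        return tier_2  # ratio undefined: leave T2 untouched
    qty_ratio = qty_2 / qty_1
    elasticity = (d2 / d1) / qty_ratio
    if MIN_RATIO <= elasticity <= MAX_RATIO:
        return tier_2
    # Move T2's discount to the nearest bound of the allowed range
    bound = MAX_RATIO if elasticity > MAX_RATIO else MIN_RATIO
    adjusted = current_price - bound * qty_ratio * d1
    # The adjusted price must stay above WAC and below the T1 price
    if wac < adjusted < tier_1:
        return round(adjusted, 2)
    return None


print(adjust_tier_2(100, 90, 99.5, 95, 10, 20))  # → 97.0 (T2 discount too deep, pulled back)
```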

## Important Notes
- Uses `new_price` from Module 3 if available, otherwise falls back to `current_price`
- The 400 limit is per warehouse for **tier entries**, not SKUs (API limitation)
- If a SKU has 3 tiers, it counts as 3 towards the 400 limit
- Upload format: Group 1 = T1 + Wholesale (max 200), Group 2 = T2 + overflow
- Start time = now + 10 mins, End time = now + 12 hours (Cairo time)
- Cart rules are updated to be >= max(T1 qty, T2 qty, WS qty)
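Because the limit counts tier entries rather than SKUs, selection can be done on a long table with one row per (SKU, tier). A sketch, assuming each entry row carries `mtd_qty` and `effective_price` columns (names taken from workflow step 9); `select_top_tier_entries` is a hypothetical helper.

```python
import pandas as pd

TOP_ENTRIES_PER_WAREHOUSE = 400  # API limit on tier entries per warehouse


def select_top_tier_entries(entries: pd.DataFrame, limit: int = TOP_ENTRIES_PER_WAREHOUSE) -> pd.DataFrame:
    """Keep the `limit` highest-scoring tier entries in each warehouse."""
    scored = entries.assign(score=entries["mtd_qty"] * entries["effective_price"])
    # Sort best-first within each warehouse, then take the first `limit` rows per group
    scored = scored.sort_values(["warehouse_id", "score"], ascending=[True, False])
    return scored.groupby("warehouse_id").head(limit).drop(columns="score")


entries = pd.DataFrame({
    "warehouse_id": [1, 1, 1, 2, 2],
    "sku": ["a", "a", "b", "a", "c"],
    "tier": ["T1", "T2", "T1", "T1", "T1"],
    "mtd_qty": [10, 10, 5, 3, 8],
    "effective_price": [100.0, 95.0, 50.0, 100.0, 20.0],
})
print(select_top_tier_entries(entries, limit=2))
```

Entries are ranked independently here, so a SKU can keep its T1 but lose its T2 at the cut-off; how the module itself handles that edge case is not shown in this section.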

## Wholesale (T3) Logic
- Delivery (car) cost saved on bulk orders is passed on to the retailer as a lower price
- Multiplier range: 3x to orders_per_car_by_weight
- Constraints:
  - Minimum margin: max(40% of current margin, 1.5%)
  - Max ticket size: 35,000 EGP
  - T3 price must be < T2 price
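The constraints above can be checked in isolation as follows. Margin is taken as `(price - wac) / price`, matching the `wac / (1 - margin)` conversion used for T1/T2. Reading "current margin" as the margin at the current price and "ticket size" as `T3 qty × T3 price` are assumptions, and the car-savings pricing formula itself is not reproduced. `orders_per_car` mirrors the `orders_per_car_by_weight` column computed in `get_warehouse_ticket_stats`; all function names are hypothetical.

```python
WS_MIN_MARGIN = 0.015        # absolute floor on T3 margin (1.5%)
WS_MAX_TICKET_SIZE = 35000   # EGP
WS_CAR_CAPACITY_TONS = 1.8


def orders_per_car(avg_order_weight_kg: float) -> float:
    """Orders that fit in one car by weight, as in get_warehouse_ticket_stats."""
    return round(WS_CAR_CAPACITY_TONS * 1000 / avg_order_weight_kg, 1)


def t3_min_price(wac: float, current_price: float) -> float:
    """Lowest T3 price allowed by the margin floor max(40% of current margin, 1.5%)."""
    current_margin = (current_price - wac) / current_price
    min_margin = max(0.4 * current_margin, WS_MIN_MARGIN)
    return wac / (1 - min_margin)


def t3_is_valid(t3_price, t3_qty, t2_price, t2_qty, wac, current_price) -> bool:
    return (
        t3_qty > t2_qty                                   # T3 qty > T2 qty
        and t3_price < t2_price                           # T3 price < T2 price
        and t3_price >= t3_min_price(wac, current_price)  # margin floor
        and t3_price * t3_qty <= WS_MAX_TICKET_SIZE       # assumed ticket = qty x price
    )


print(orders_per_car(120))                        # 15.0
print(round(t3_min_price(90, 100), 2))            # 93.75
print(t3_is_valid(94.0, 300, 97.0, 20, 90, 100))  # True
```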

## Output
Returns a dict with the processing results, the working DataFrame, and the QD configurations.


In [None]:
# =============================================================================
# IMPORTS & CONFIGURATION
# =============================================================================
import pandas as pd
import requests
from datetime import datetime, timedelta
import pytz
import os
import ast
import json
import time
import base64
import boto3
from botocore.exceptions import ClientError
import snowflake.connector
import sys

%run queries_module.ipynb
# Add parent directory for imports
sys.path.insert(0, '..')
import setup_environment_2

# Initialize environment variables (loads Snowflake credentials)
setup_environment_2.initialize_env()

# Cairo Timezone
CAIRO_TZ = pytz.timezone('Africa/Cairo')
CAIRO_NOW = datetime.now(CAIRO_TZ)
TODAY = CAIRO_NOW.date()

# =============================================================================
# SNOWFLAKE CONNECTION
# =============================================================================
def query_snowflake(query):
    """Execute a query on Snowflake and return results as DataFrame."""
    con = snowflake.connector.connect(
        user=os.environ["SNOWFLAKE_USERNAME"],
        account=os.environ["SNOWFLAKE_ACCOUNT"],
        password=os.environ["SNOWFLAKE_PASSWORD"],
        database=os.environ["SNOWFLAKE_DATABASE"]
    )
    try:
        cur = con.cursor()
        cur.execute("USE WAREHOUSE COMPUTE_WH")
        cur.execute(query)
        data = cur.fetchall()
        columns = [desc[0].lower() for desc in cur.description]
        return pd.DataFrame(data, columns=columns)
    finally:
        con.close()

def get_snowflake_timezone():
    """Return the Snowflake TIMEZONE parameter, or 'UTC' if no row is returned."""
    result = query_snowflake("SHOW PARAMETERS LIKE 'TIMEZONE'")
    return result['value'].iloc[0] if len(result) > 0 else "UTC"

TIMEZONE = get_snowflake_timezone()

# =============================================================================
# AWS & API FUNCTIONS
# =============================================================================
def get_secret(secret_name: str) -> str:
    """Retrieve a secret from AWS Secrets Manager."""
    region_name = "us-east-1"
    session = boto3.session.Session()
    client = session.client(service_name='secretsmanager', region_name=region_name)

    try:
        response = client.get_secret_value(SecretId=secret_name)
    except ClientError as e:
        print(f"AWS Error: {e}")
        raise e
    
    if 'SecretString' in response:
        return response['SecretString']
    return base64.b64decode(response['SecretBinary'])


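def get_access_token(url: str, client_id: str, client_secret: str) -> str:
    """Get an OAuth2 access token (password grant) for MaxAB API authentication.

    Relies on the module-level API_USERNAME and API_PASSWORD, which are set
    in the credentials section below before this function is first called.
    """
    response = requests.post(
        url,
        data={
            'grant_type': 'password',
            'client_id': client_id,
            'client_secret': client_secret,
            'username': API_USERNAME,
            'password': API_PASSWORD
        },
        timeout=30,
    )
    # Fail loudly on HTTP errors instead of an opaque KeyError on 'access_token'
    response.raise_for_status()
    return response.json()['access_token']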


def _get_api_token() -> str:
    """Get a fresh API token for MaxAB API requests."""
    return get_access_token(
        'https://sso.maxab.info/auth/realms/maxab/protocol/openid-connect/token',
        'main-system-externals',
        API_SECRET
    )

# =============================================================================
# API CREDENTIALS INITIALIZATION
# =============================================================================
pricing_api_secret = json.loads(get_secret("prod/pricing/api/"))
API_USERNAME = pricing_api_secret["egypt_username"]
API_PASSWORD = pricing_api_secret["egypt_password"]
API_SECRET = pricing_api_secret["egypt_secret"]

# =============================================================================
# API CONFIGURATION
# =============================================================================
QD_API_URL = 'https://api.maxab.app/commerce/api/admins/v1/quantity-discounts/'

# Default QD settings
DEFAULT_QD_DURATION_HOURS = 12  # QD valid until next run

print("✓ QD Handler initialized")
print(f"  Timezone: {TIMEZONE}")


In [None]:
# =============================================================================
# QD CALCULATION CONFIGURATION
# =============================================================================
import numpy as np

# Discount bounds, in percent of the current price
MAX_DISCOUNT_PCT = 5.0    # Maximum allowed discount from current price
MIN_DISCOUNT_PCT = 0.35   # Minimum required discount from current price

# Ratio constraints: discount_2/discount_1 should be between MIN_RATIO and MAX_RATIO times qty_2/qty_1
MIN_RATIO = 1.1  # Minimum elasticity ratio
MAX_RATIO = 3.0  # Maximum elasticity ratio

# Minimum gap between tier 1 and tier 2 prices, in percent of the tier 1 price
MIN_GAP_PCT = 0.25

# =============================================================================
# WHOLESALE (TIER 3) CONFIGURATION
# =============================================================================
WS_CAR_COST = 1400           # Cost per delivery (EGP)
WS_CAR_CAPACITY_TONS = 1.8   # Max car capacity in tons
WS_MAX_TICKET_SIZE = 35000   # Maximum ticket size (EGP)
WS_MIN_MARGIN = 0.015        # Minimum margin (1.5%) above WAC

# Top SKU selection
TOP_SKUS_PER_WAREHOUSE = 400  # Number of top tier entries to select per warehouse

# =============================================================================
# UPLOAD FORMAT CONFIGURATION
# =============================================================================
MAX_GROUP_SIZE = 200         # Max items per discount group in API
MAX_DISCOUNT_CAP_T1 = 4.0    # Maximum discount for Tier 1
MAX_DISCOUNT_CAP_T2 = 5.0    # Maximum discount for Tier 2
MAX_DISCOUNT_CAP_WS = 6.0    # Maximum discount for Wholesale

# QD Duration
QD_DURATION_HOURS = 12       # QD valid for 12 hours

# =============================================================================
# WAREHOUSE TO TAG ID MAPPING
# =============================================================================
WAREHOUSE_TAG_MAPPING = {
    501: {'name': 'Assiut FC', 'tag_id': 3301},
    401: {'name': 'Bani sweif', 'tag_id': 3302},
    236: {'name': 'Barageel', 'tag_id': 3303},
    337: {'name': 'El-Mahala', 'tag_id': 3304},
    797: {'name': 'Khorshed Alex', 'tag_id': 3305},
    339: {'name': 'Mansoura FC', 'tag_id': 3306},
    703: {'name': 'Menya Samalot', 'tag_id': 3307},
    1: {'name': 'Mostorod', 'tag_id': 3308},
    962: {'name': 'Sakkarah', 'tag_id': 3309},
    170: {'name': 'Sharqya', 'tag_id': 3310},
    632: {'name': 'Sohag', 'tag_id': 3311},
    8: {'name': 'Tanta', 'tag_id': 3312},
    38: {'name': 'El-Marg', 'tag_id': 3313},  # Added if missing
}

print("✓ QD calculation parameters:")
print(f"  MAX_DISCOUNT_PCT: {MAX_DISCOUNT_PCT}%")
print(f"  MIN_DISCOUNT_PCT: {MIN_DISCOUNT_PCT}%")
print(f"  RATIO RANGE: [{MIN_RATIO}, {MAX_RATIO}]")
print(f"\n✓ Wholesale (T3) parameters:")
print(f"  WS_CAR_COST: {WS_CAR_COST} EGP")
print(f"  WS_MAX_TICKET_SIZE: {WS_MAX_TICKET_SIZE} EGP")
print(f"  WS_MIN_MARGIN: {WS_MIN_MARGIN*100}%")
print(f"  TOP_SKUS_PER_WAREHOUSE: {TOP_SKUS_PER_WAREHOUSE}")
print(f"\n✓ Upload parameters:")
print(f"  MAX_GROUP_SIZE: {MAX_GROUP_SIZE}")
print(f"  QD_DURATION_HOURS: {QD_DURATION_HOURS}")


In [None]:
# =============================================================================
# DATA FETCHING: PACKING UNITS & TIER QUANTITIES
# =============================================================================

def get_top_selling_packing_units(product_warehouse_list: list) -> pd.DataFrame:
    """
    Get the top-selling packing unit per product per warehouse (last 90 days).
    
    Args:
        product_warehouse_list: List of (product_id, warehouse_id) tuples
        
    Returns:
        DataFrame with product_id, warehouse_id, packing_unit_id, basic_unit_count
    """
    if not product_warehouse_list:
        return pd.DataFrame(columns=['product_id', 'warehouse_id', 'packing_unit_id', 'basic_unit_count'])
    
    # Build tuples string for SQL
    tuples_str = ','.join([f"({int(p)}, {int(w)})" for p, w in product_warehouse_list])
    
    query = f'''
    WITH input_products AS (
        SELECT product_id, warehouse_id
        FROM (VALUES {tuples_str}) AS x(product_id, warehouse_id)
    ),
    
    sales_by_pu AS (
        SELECT 
            pso.product_id,
            so.warehouse_id,
            pso.packing_unit_id,
            SUM(pso.total_price) as nmv
        FROM product_sales_order pso
        JOIN sales_orders so ON so.id = pso.sales_order_id
        JOIN input_products ip ON ip.product_id = pso.product_id AND ip.warehouse_id = so.warehouse_id
        WHERE so.created_at >= CURRENT_DATE - 90
            AND so.sales_order_status_id NOT IN (7, 12)
            AND so.channel IN ('telesales', 'retailer')
            AND pso.purchased_item_count <> 0
        GROUP BY 1, 2, 3
    ),
    
    ranked_pu AS (
        SELECT 
            s.product_id, 
            s.warehouse_id, 
            s.packing_unit_id,
            pup.basic_unit_count,
            s.nmv,
            ROW_NUMBER() OVER (PARTITION BY s.product_id, s.warehouse_id ORDER BY s.nmv DESC) as rnk
        FROM sales_by_pu s
        JOIN packing_unit_products pup 
            ON pup.product_id = s.product_id 
            AND pup.packing_unit_id = s.packing_unit_id
        WHERE pup.deleted_at IS NULL
    )
    
    SELECT product_id, warehouse_id, packing_unit_id, basic_unit_count
    FROM ranked_pu
    WHERE rnk = 1
    '''
    
    print("  Fetching top-selling packing units (last 90 days)...")
    df = query_snowflake(query)
    
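    # Convert to numeric where possible; leave non-numeric columns unchanged
    # (equivalent to the deprecated errors='ignore')
    for col in df.columns:
        try:
            df[col] = pd.to_numeric(df[col])
        except (ValueError, TypeError):
            pass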
    
    print(f"    Found packing units for {len(df)} product-warehouse combinations")
    return df


def get_tier_quantities(product_warehouse_pu_list: list) -> pd.DataFrame:
    """
    Calculate tier quantities based on historical order data.
    
    Args:
        product_warehouse_pu_list: List of (warehouse_id, product_id, packing_unit_id) tuples
        
    Returns:
        DataFrame with warehouse_id, product_id, packing_unit_id, tier_1_qty,
        tier_2_qty, median_qty, stddev_qty (one row per product-warehouse-packing unit)
    """
    if not product_warehouse_pu_list:
        return pd.DataFrame(columns=['warehouse_id', 'product_id', 'packing_unit_id',
                                     'tier_1_qty', 'tier_2_qty', 'median_qty', 'stddev_qty'])
    
    # Build tuples string for SQL
    tuples_str = ','.join([f"({int(w)}, {int(p)}, {int(pu)})" for w, p, pu in product_warehouse_pu_list])
    
    query = f'''
    WITH selected_products AS (
        SELECT warehouse_id, product_id, packing_unit_id
        FROM (VALUES {tuples_str}) AS x(warehouse_id, product_id, packing_unit_id)
    ),
    
    -- Retailers in QD cohorts
    base AS (
        SELECT *, ROW_NUMBER() OVER (PARTITION BY retailer_id ORDER BY priority) as rnk 
        FROM (
            SELECT x.*, TAGGABLE_ID as retailer_id 
            FROM (
                SELECT id as cohort_id, name as cohort_name, priority, dynamic_tag_id 
                FROM cohorts 
                WHERE is_active = 'true'
                    AND id IN (700,701,702,703,704,1123,1124,1125,1126)
            ) x 
            JOIN DYNAMIC_TAGgables dt ON x.dynamic_tag_id = dt.dynamic_tag_id
            WHERE dt.taggable_id not IN (
                SELECT taggable_id FROM DYNAMIC_TAGgables 
                WHERE dynamic_tag_id IN (2807, 2808, 2809, 2810, 2811, 2812)
            )
        )
        QUALIFY rnk = 1 
    ),
    
    -- Warehouse mapping
    warehouse_mapping AS (
        SELECT * FROM (VALUES
            ('Cairo', 'Mostorod', 1),
            ('Giza', 'Barageel', 236),
            ('Giza', 'Sakkarah', 962),
            ('Delta West', 'El-Mahala', 337),
            ('Delta West', 'Tanta', 8),
            ('Delta East', 'Mansoura FC', 339),
            ('Delta East', 'Sharqya', 170),
            ('Upper Egypt', 'Assiut FC', 501),
            ('Upper Egypt', 'Bani sweif', 401),
            ('Upper Egypt', 'Menya Samalot', 703),
            ('Upper Egypt', 'Sohag', 632),
            ('Alexandria', 'Khorshed Alex', 797)
        ) x(region_name, wh, warehouse_id)
    ),
    
    raw_order_quantities AS (
        SELECT 
            whs.warehouse_id,
            pso.product_id,
            pso.packing_unit_id,
            so.parent_sales_order_id,
            so.retailer_id,
            so.created_at::date as order_date,
            SUM(pso.purchased_item_count) as order_qty,
            EXP(-0.02 * DATEDIFF('day', so.created_at::date, CURRENT_DATE)) as recency_weight
            
        FROM product_sales_order pso
        JOIN sales_orders so ON so.id = pso.sales_order_id
        JOIN base ON base.retailer_id = so.retailer_id
        JOIN materialized_views.retailer_polygon ON materialized_views.retailer_polygon.retailer_id = so.retailer_id
        JOIN districts ON districts.id = materialized_views.retailer_polygon.district_id
        JOIN cities ON cities.id = districts.city_id
        JOIN states ON states.id = cities.state_id
        JOIN regions ON regions.id = states.region_id
        JOIN warehouse_mapping whs ON whs.region_name = CASE WHEN regions.id = 2 THEN states.name_en ELSE regions.name_en END
        JOIN selected_products sp ON sp.warehouse_id = whs.warehouse_id 
            AND sp.product_id = pso.product_id
            AND sp.packing_unit_id = pso.packing_unit_id
        
        WHERE TRUE
            AND so.created_at::date BETWEEN DATE_TRUNC('month', CURRENT_DATE - INTERVAL '4 months') AND CURRENT_DATE - 1
            AND so.sales_order_status_id NOT IN (7, 12)
            AND so.channel IN ('telesales', 'retailer')
            AND pso.purchased_item_count <> 0
        
        GROUP BY 1, 2, 3, 4, 5, 6
    ),
    
    retailer_frequency AS (
        SELECT 
            warehouse_id, product_id, packing_unit_id, retailer_id,
            COUNT(DISTINCT parent_sales_order_id) as order_count,
            COUNT(DISTINCT DATE_TRUNC('week', order_date)) as weeks_ordered
        FROM raw_order_quantities
        GROUP BY 1, 2, 3, 4
    ),
    
    frequent_buyers AS (
        SELECT warehouse_id, product_id, packing_unit_id, retailer_id
        FROM retailer_frequency
        WHERE order_count >= 2 OR weeks_ordered >= 2
    ),
    
    filtered_orders AS (
        SELECT roq.*
        FROM raw_order_quantities roq
        JOIN frequent_buyers fb ON fb.warehouse_id = roq.warehouse_id
            AND fb.product_id = roq.product_id
            AND fb.packing_unit_id = roq.packing_unit_id
            AND fb.retailer_id = roq.retailer_id
    ),
    
    initial_stats AS (
        SELECT 
            warehouse_id, product_id, packing_unit_id,
            PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY order_qty) as q1,
            PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY order_qty) as q3,
            MEDIAN(order_qty) as median_qty,
            STDDEV_POP(order_qty) as stddev_qty,
            AVG(order_qty) as avg_qty
        FROM filtered_orders
        GROUP BY 1, 2, 3
    ),
    
    cleaned_orders AS (
        SELECT fo.*
        FROM filtered_orders fo
        JOIN initial_stats ist ON ist.warehouse_id = fo.warehouse_id
            AND ist.product_id = fo.product_id
            AND ist.packing_unit_id = fo.packing_unit_id
        WHERE fo.order_qty >= ist.q1 - 1.5 * (ist.q3 - ist.q1)
            AND fo.order_qty <= ist.q3 + 1.5 * (ist.q3 - ist.q1)
            AND (ist.stddev_qty = 0 OR ABS(fo.order_qty - ist.avg_qty) <= 3 * ist.stddev_qty)
    ),
    
    recent_trends AS (
        SELECT 
            warehouse_id, product_id, packing_unit_id,
            SUM(order_qty * recency_weight) / NULLIF(SUM(recency_weight), 0) as weighted_avg_qty,
            AVG(CASE WHEN order_date >= CURRENT_DATE - 15 THEN order_qty END) as last_15d_avg,
            MEDIAN(CASE WHEN order_date >= CURRENT_DATE - 15 THEN order_qty END) as last_15d_median,
            MAX(CASE WHEN order_date >= CURRENT_DATE - 15 THEN order_qty END) as last_15d_max,
            COUNT(CASE WHEN order_date >= CURRENT_DATE - 15 THEN 1 END) as last_15d_orders
        FROM cleaned_orders
        GROUP BY 1, 2, 3
    ),
    
    quantity_stats AS (
        SELECT 
            warehouse_id, product_id, packing_unit_id,
            COUNT(DISTINCT parent_sales_order_id) as total_orders,
            MEDIAN(order_qty) as median_qty,
            STDDEV_POP(order_qty) as stddev_qty,
            PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY order_qty) as q3_qty,
            PERCENTILE_CONT(0.85) WITHIN GROUP (ORDER BY order_qty) as p85_qty,
            PERCENTILE_CONT(0.90) WITHIN GROUP (ORDER BY order_qty) as p90_qty,
            PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY order_qty) as p95_qty
        FROM cleaned_orders
        GROUP BY 1, 2, 3
    ),
    
    most_frequent_qty AS (
        SELECT warehouse_id, product_id, packing_unit_id, order_qty as mode_qty
        FROM (
            SELECT warehouse_id, product_id, packing_unit_id, order_qty,
                   COUNT(*) as freq,
                   ROW_NUMBER() OVER (PARTITION BY warehouse_id, product_id, packing_unit_id ORDER BY COUNT(*) DESC, order_qty DESC) as rn
            FROM cleaned_orders
            GROUP BY 1, 2, 3, 4
        )
        WHERE rn = 1
    ),
    
    tier_calculations AS (
        SELECT 
            qs.warehouse_id, qs.product_id, qs.packing_unit_id,
            qs.median_qty, qs.stddev_qty, qs.q3_qty, qs.p85_qty, qs.p90_qty, qs.p95_qty,
            COALESCE(mf.mode_qty, qs.median_qty) as mode_qty,
            rt.weighted_avg_qty, rt.last_15d_median, rt.last_15d_max, rt.last_15d_orders,
            
            -- Tier 1 calculation
            CEIL(GREATEST(
                (0.7 * qs.median_qty + 0.3 * COALESCE(rt.weighted_avg_qty, qs.median_qty)) + 1.0 * COALESCE(qs.stddev_qty, 1),
                qs.q3_qty,
                COALESCE(mf.mode_qty, qs.median_qty) + GREATEST(3, qs.median_qty * 0.3),
                CASE 
                    WHEN rt.last_15d_orders >= 2 AND rt.last_15d_median > qs.median_qty 
                    THEN rt.last_15d_median * 1.2
                    ELSE qs.median_qty * 1.3
                END,
                qs.median_qty + 2
            )) as tier_1_qty,
            
            -- Tier 2 base calculation
            CEIL(GREATEST(
                qs.q3_qty + 1.5 * COALESCE(qs.stddev_qty, 1),
                qs.p85_qty + 1.0 * COALESCE(qs.stddev_qty, 1),
                qs.p90_qty + 0.5 * COALESCE(qs.stddev_qty, 1),
                qs.p95_qty,
                (0.6 * qs.median_qty + 0.4 * COALESCE(rt.weighted_avg_qty, qs.median_qty)) * 2.0,
                CASE 
                    WHEN rt.last_15d_orders >= 2 AND rt.last_15d_max > qs.p90_qty 
                    THEN rt.last_15d_max * 1.1
                    ELSE qs.median_qty * 1.6
                END
            )) as tier_2_qty_base
            
        FROM quantity_stats qs
        LEFT JOIN most_frequent_qty mf ON mf.warehouse_id = qs.warehouse_id 
            AND mf.product_id = qs.product_id AND mf.packing_unit_id = qs.packing_unit_id
        LEFT JOIN recent_trends rt ON rt.warehouse_id = qs.warehouse_id
            AND rt.product_id = qs.product_id AND rt.packing_unit_id = qs.packing_unit_id
    )
    
    SELECT 
        warehouse_id, product_id, packing_unit_id,
        tier_1_qty,
        LEAST(
            CEIL(GREATEST(tier_2_qty_base, tier_1_qty * 1.6)),
            GREATEST(tier_1_qty * 3.5, tier_1_qty + 20)
        ) as tier_2_qty,
        median_qty, stddev_qty
    FROM tier_calculations
    '''
    
    print("  Calculating tier quantities from order history...")
    df = query_snowflake(query)
    
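    # Convert to numeric where possible; leave non-numeric columns unchanged
    # (equivalent to the deprecated errors='ignore')
    for col in df.columns:
        try:
            df[col] = pd.to_numeric(df[col])
        except (ValueError, TypeError):
            pass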
    
    print(f"    Calculated tiers for {len(df)} product-warehouse combinations")
    return df


def get_warehouse_ticket_stats() -> pd.DataFrame:
    """
    Get warehouse-level ticket size statistics for wholesale calculations.
    
    Returns:
        DataFrame with warehouse_id, avg_ticket_size, orders_per_car_by_weight
    """
    query = f'''
    WITH base AS (
        SELECT *, ROW_NUMBER() OVER (PARTITION BY retailer_id ORDER BY priority) as rnk 
        FROM (
            SELECT x.*, TAGGABLE_ID as retailer_id 
            FROM (
                SELECT id as cohort_id, name as cohort_name, priority, dynamic_tag_id 
                FROM cohorts 
                WHERE is_active = 'true'
                    AND id IN (700,701,702,703,704,1123,1124,1125,1126)
            ) x 
            JOIN DYNAMIC_TAGgables dt ON x.dynamic_tag_id = dt.dynamic_tag_id
            WHERE dt.taggable_id not IN (
                SELECT taggable_id FROM DYNAMIC_TAGgables 
                WHERE dynamic_tag_id IN (2807, 2808, 2809, 2810, 2811, 2812)
            )
        )
        QUALIFY rnk = 1 
    ),

    -- Map regions to warehouses
    whs AS (
        SELECT * FROM (VALUES
            ('Cairo', 'El-Marg', 38),
            ('Cairo', 'Mostorod', 1),
            ('Giza', 'Barageel', 236),
            ('Giza', 'Sakkarah', 962),
            ('Delta West', 'El-Mahala', 337),
            ('Delta West', 'Tanta', 8),
            ('Delta East', 'Mansoura FC', 339),
            ('Delta East', 'Sharqya', 170),
            ('Upper Egypt', 'Assiut FC', 501),
            ('Upper Egypt', 'Bani sweif', 401),
            ('Upper Egypt', 'Menya Samalot', 703),
            ('Upper Egypt', 'Sohag', 632),
            ('Alexandria', 'Khorshed Alex', 797)
        ) x(region_name, wh, warehouse_id)
    ),

    -- Get ticket sizes (order values) for last 4 months
    ticket_sizes AS (
        SELECT 
            whs.warehouse_id,
            whs.wh as warehouse_name,
            so.parent_sales_order_id,
            so.retailer_id,
            SUM(pso.total_price) as ticket_size,
            SUM(pso.purchased_item_count * pup.weight / 1000) as order_weight_kg
        FROM product_sales_order pso
        JOIN sales_orders so ON so.id = pso.sales_order_id
        JOIN base ON base.retailer_id = so.retailer_id
        JOIN packing_unit_products pup ON pup.product_id = pso.product_id 
            AND pup.packing_unit_id = pso.packing_unit_id
        JOIN materialized_views.retailer_polygon rp ON rp.retailer_id = so.retailer_id
        JOIN districts ON districts.id = rp.district_id
        JOIN cities ON cities.id = districts.city_id
        JOIN states ON states.id = cities.state_id
        JOIN regions ON regions.id = states.region_id
        JOIN whs ON whs.region_name = CASE WHEN regions.id = 2 THEN states.name_en ELSE regions.name_en END
        WHERE so.created_at::date BETWEEN DATE_TRUNC('month', CURRENT_DATE - INTERVAL '4 months') AND CURRENT_DATE - 1
            AND so.sales_order_status_id NOT IN (7, 12)
            AND so.channel IN ('telesales', 'retailer')
            AND pso.purchased_item_count > 0
        GROUP BY whs.warehouse_id, whs.wh, so.parent_sales_order_id, so.retailer_id
    ),

    -- Calculate warehouse-level statistics
    warehouse_stats AS (
        SELECT 
            warehouse_id,
            warehouse_name,
            COUNT(DISTINCT parent_sales_order_id) as total_orders,
            COUNT(DISTINCT retailer_id) as total_retailers,
            AVG(ticket_size) as avg_ticket_size,
            MEDIAN(ticket_size) as median_ticket_size,
            AVG(order_weight_kg) as avg_order_weight_kg
        FROM ticket_sizes
        WHERE ticket_size > 0
        GROUP BY warehouse_id, warehouse_name
    )

    SELECT 
        warehouse_id,
        warehouse_name,
        ROUND(avg_ticket_size, 2) as avg_ticket_size,
        ROUND(median_ticket_size, 2) as median_ticket_size,
        ROUND(avg_order_weight_kg, 2) as avg_order_weight_kg,
        -- Calculate how many orders fit in one car based on weight
        ROUND({WS_CAR_CAPACITY_TONS * 1000} / NULLIF(avg_order_weight_kg, 0), 1) as orders_per_car_by_weight
    FROM warehouse_stats
    ORDER BY warehouse_id
    '''
    
    print("  Fetching warehouse ticket statistics...")
    df = query_snowflake(query)
    
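    # Convert to numeric where possible; leave non-numeric columns unchanged
    # (equivalent to the deprecated errors='ignore')
    for col in df.columns:
        try:
            df[col] = pd.to_numeric(df[col])
        except (ValueError, TypeError):
            pass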
    
    print(f"    Got stats for {len(df)} warehouses")
    return df

print("✓ Data fetching functions defined")


In [None]:
# =============================================================================
# TIER PRICE CALCULATION
# =============================================================================

def calculate_tier_prices(row):
    """
    Calculate tier 1 and tier 2 prices for a single row.
    
    Uses market margins and margin tiers (converted to prices) to find
    the best two prices within discount bounds.
    
    Constraints:
    - Ensure: WAC < Tier 2 < Tier 1 < Current Price
    - Ensure: BOTH tiers must be valid or BOTH are None
    - Enforce ratio: discount_2/discount_1 / (qty_2/qty_1) between MIN_RATIO and MAX_RATIO
    
    Args:
        row: DataFrame row with packing_unit_price and wac_pu (per packing unit),
             market margins, margin tiers, and tier quantities
        
    Returns:
        Series with tier_1_price, tier_2_price, price_source
    """
    current_price = row.get('packing_unit_price')  # Price for packing unit
    wac = row.get('wac_pu')  # WAC for packing unit
    tier_1_qty = row.get('tier_1_qty')
    tier_2_qty = row.get('tier_2_qty')
    
    # Validation
    if pd.isna(current_price) or current_price <= 0:
        return pd.Series({'tier_1_price': np.nan, 'tier_2_price': np.nan, 'price_source': 'invalid_current_price'})
    
    if pd.isna(wac) or wac <= 0:
        return pd.Series({'tier_1_price': np.nan, 'tier_2_price': np.nan, 'price_source': 'invalid_wac'})
    
    if current_price <= wac:
        return pd.Series({'tier_1_price': np.nan, 'tier_2_price': np.nan, 'price_source': 'current_price_below_wac'})
    
    # Calculate discount bounds
    max_discount_price = current_price * (1 - MAX_DISCOUNT_PCT / 100)  # Minimum allowed price (max discount)
    min_discount_price = current_price * (1 - MIN_DISCOUNT_PCT / 100)  # Maximum allowed price (min discount)
    
    # Collect candidate prices from market margins (convert margin to price)
    candidate_prices = []
    
    # Market margin columns (these are margins, convert to prices)
    market_margin_cols = ['below_market', 'market_min', 'market_25', 'market_50', 
                          'market_75', 'market_max', 'above_market']
    
    for col in market_margin_cols:
        margin = row.get(col)
        if pd.notna(margin) and 0 < margin < 1:
            price = wac / (1 - margin)
            if max_discount_price <= price <= min_discount_price and price > wac:
                candidate_prices.append(('market', col, price))
    
    # Margin tier columns (these are margins, convert to prices)
    margin_tier_cols = ['margin_tier_1', 'margin_tier_2', 'margin_tier_3', 'margin_tier_4',
                        'margin_tier_5', 'margin_tier_above_1', 'margin_tier_above_2']
    
    for col in margin_tier_cols:
        margin = row.get(col)
        if pd.notna(margin) and 0 < margin < 1:
            price = wac / (1 - margin)
            if max_discount_price <= price <= min_discount_price and price > wac:
                candidate_prices.append(('margin_tier', col, price))
    
    # Remove duplicates and sort by price descending
    unique_prices = {}
    for source_type, source_col, price in candidate_prices:
        price_rounded = round(price, 2)
        if price_rounded not in unique_prices:
            unique_prices[price_rounded] = (source_type, source_col)
    
    valid_prices = sorted(unique_prices.keys(), reverse=True)
    
    # Need at least 2 prices
    if len(valid_prices) < 2:
        return pd.Series({'tier_1_price': np.nan, 'tier_2_price': np.nan, 'price_source': 'insufficient_valid_prices'})
    
    tier_1 = None
    tier_2 = None
    source = ''
    
    # Strategy: take the highest-priced pair whose gap is at least MIN_GAP_PCT of the T1 price
    for i, p1 in enumerate(valid_prices):
        for p2 in valid_prices[i+1:]:
            # Ensure minimum gap between tiers
            if p2 < p1 * (1 - MIN_GAP_PCT / 100):
                tier_1 = p1
                tier_2 = p2
                source = f"{unique_prices[p1][0]}_{unique_prices[p2][0]}"
                break
        if tier_1 is not None:
            break
    
    # If no pair with minimum gap, take top two
    if tier_1 is None and len(valid_prices) >= 2:
        tier_1 = valid_prices[0]
        tier_2 = valid_prices[1]
        source = 'top_two_prices'
    
    # Final validation
    if tier_1 is None or tier_2 is None:
        return pd.Series({'tier_1_price': np.nan, 'tier_2_price': np.nan, 'price_source': 'no_valid_pair'})
    
    # Ensure correct ordering
    if tier_2 >= tier_1:
        tier_1, tier_2 = max(tier_1, tier_2), min(tier_1, tier_2)
    
    # Check basic constraints
    if not (wac < tier_2 < tier_1 < current_price):
        return pd.Series({'tier_1_price': np.nan, 'tier_2_price': np.nan, 'price_source': 'invalid_ordering'})
    
    # ==========================================================================
    # RATIO ADJUSTMENT
    # Ensure: discount_2/discount_1 / (qty_2/qty_1) is between MIN_RATIO and MAX_RATIO
    # ==========================================================================
    if pd.notna(tier_1_qty) and pd.notna(tier_2_qty) and tier_1_qty > 0:
        tier_1_discount = current_price - tier_1
        tier_2_discount = current_price - tier_2
        
        if tier_1_discount > 0:
            qty_ratio = tier_2_qty / tier_1_qty
            discount_ratio = tier_2_discount / tier_1_discount
            
            if qty_ratio > 0:
                elasticity_ratio = discount_ratio / qty_ratio
                
                # If ratio too high, reduce T2 discount (increase T2 price)
                if elasticity_ratio > MAX_RATIO:
                    target_discount_ratio = MAX_RATIO * qty_ratio
                    target_tier_2_discount = target_discount_ratio * tier_1_discount
                    adjusted_tier_2 = current_price - target_tier_2_discount
                    
                    if adjusted_tier_2 > wac and adjusted_tier_2 < tier_1:
                        tier_2 = round(adjusted_tier_2, 2)
                        source += '_ratio_down'
                    else:
                        return pd.Series({'tier_1_price': np.nan, 'tier_2_price': np.nan, 
                                         'price_source': f'cannot_adjust_ratio_{elasticity_ratio:.2f}_max'})
                
                # If ratio too low, increase T2 discount (decrease T2 price)
                elif elasticity_ratio < MIN_RATIO:
                    target_discount_ratio = MIN_RATIO * qty_ratio
                    target_tier_2_discount = target_discount_ratio * tier_1_discount
                    adjusted_tier_2 = current_price - target_tier_2_discount
                    
                    if adjusted_tier_2 > wac and adjusted_tier_2 < tier_1:
                        tier_2 = round(adjusted_tier_2, 2)
                        source += '_ratio_up'
                    else:
                        return pd.Series({'tier_1_price': np.nan, 'tier_2_price': np.nan, 
                                         'price_source': f'cannot_adjust_ratio_{elasticity_ratio:.2f}_min'})
    
    # Final rounding
    tier_1 = round(tier_1, 2)
    tier_2 = round(tier_2, 2)
    
    return pd.Series({
        'tier_1_price': tier_1,
        'tier_2_price': tier_2,
        'price_source': source
    })
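The tail of `calculate_tier_prices` does two things: it picks the highest pair of candidate prices that are at least `MIN_GAP_PCT` apart, then moves the T2 price so that the elasticity ratio `(discount_2 / discount_1) / (qty_2 / qty_1)` stays within `[MIN_RATIO, MAX_RATIO]`. The standalone sketch below isolates those two steps on plain floats so they can be checked by hand. The constant values and the helper names `pick_tier_pair` and `adjust_tier_2` are illustrative assumptions, not the notebook's configuration, and the sketch leaves out the notebook's final `wac < tier_2 < tier_1 < current_price` check.

```python
# Illustrative values only; the notebook defines its own MIN_GAP_PCT, MIN_RATIO and MAX_RATIO.
MIN_GAP_PCT = 3.0
MIN_RATIO = 1.0
MAX_RATIO = 3.0


def pick_tier_pair(prices, min_gap_pct=MIN_GAP_PCT):
    """Return (t1, t2) or None, scanning candidate prices from highest to lowest."""
    valid = sorted({round(p, 2) for p in prices}, reverse=True)
    if len(valid) < 2:
        return None
    for i, p1 in enumerate(valid):
        for p2 in valid[i + 1:]:
            # First pair whose gap is at least min_gap_pct wins
            if p2 < p1 * (1 - min_gap_pct / 100):
                return p1, p2
    # No pair is far enough apart: fall back to the two highest prices
    return valid[0], valid[1]


def adjust_tier_2(current_price, wac, t1, t2, q1, q2,
                  min_ratio=MIN_RATIO, max_ratio=MAX_RATIO):
    """Move t2 so that (d2 / d1) / (q2 / q1) lands inside [min_ratio, max_ratio].

    Returns the (possibly adjusted) t2, or None when the adjusted price would
    fall outside the open interval (wac, t1).
    """
    d1, d2 = current_price - t1, current_price - t2
    if d1 <= 0 or q1 <= 0 or q2 <= 0:
        return t2  # ratio undefined: leave T2 unchanged, as the notebook does
    qty_ratio = q2 / q1
    elasticity = (d2 / d1) / qty_ratio
    if min_ratio <= elasticity <= max_ratio:
        return t2
    # Clamp to whichever bound was violated and solve for the T2 discount
    target = max_ratio if elasticity > max_ratio else min_ratio
    adjusted = current_price - target * qty_ratio * d1
    return round(adjusted, 2) if wac < adjusted < t1 else None
```

With `current_price=100`, `t1=99`, `t2=95` and quantities 2 and 3, the elasticity ratio is about 3.33, so T2 is raised to 95.5, which brings the ratio down to exactly `MAX_RATIO = 3`.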


def calculate_wholesale_tier(row):
    """
    Calculate wholesale (Tier 3) pricing based on delivery cost savings.
    
    Logic:
    - Car cost per order = WS_CAR_COST / orders_per_car
    - If a retailer consolidates several orders into one, the delivery cost saved
      can be passed on as a discount
    - Savings = deliveries_saved * car_cost_per_order
    - Evaluate consolidation multipliers from 3x up to orders_per_car and keep the
      valid scenario with the highest per-unit savings
    
    Constraints:
    - new_price >= max(wac/(1 - 0.4*current_margin), wac/(1 - WS_MIN_MARGIN))
    - order_value <= WS_MAX_TICKET_SIZE
    - new_price < tier_2_price (T3 price must be lower than T2)
    
    Args:
        row: DataFrame row with packing_unit_price, wac_pu, tier_2_price,
             avg_ticket_size (default 4000), orders_per_car_by_weight (default 15)
             
    Returns:
        Series with ws_qty, ws_price, ws_discount_pct, ws_margin, ws_multiplier
        and ws_savings_pct (all NaN when no scenario satisfies the constraints)
    """
    current_price = row.get('packing_unit_price')
    wac = row.get('wac_pu')
    avg_ts = row.get('avg_ticket_size', 4000)
    tier_2_price = row.get('tier_2_price')
    
    # Get orders per car (how many orders fit in one car trip based on weight)
    orders_per_car = row.get('orders_per_car_by_weight', 15)
    if pd.isna(orders_per_car) or orders_per_car <= 0:
        orders_per_car = 15
    
    # Calculate car cost per order
    car_cost_per_order = WS_CAR_COST / orders_per_car
    
    if pd.isna(avg_ts) or avg_ts <= 0:
        avg_ts = 4000
    
    # Validation
    if pd.isna(current_price) or pd.isna(wac) or current_price <= 0 or wac <= 0:
        return pd.Series({
            'ws_qty': np.nan, 'ws_price': np.nan, 'ws_discount_pct': np.nan,
            'ws_margin': np.nan, 'ws_multiplier': np.nan, 'ws_savings_pct': np.nan
        })
    
    if pd.isna(tier_2_price) or tier_2_price <= 0:
        return pd.Series({
            'ws_qty': np.nan, 'ws_price': np.nan, 'ws_discount_pct': np.nan,
            'ws_margin': np.nan, 'ws_multiplier': np.nan, 'ws_savings_pct': np.nan
        })
    
    # ==========================================================================
    # MARGIN-BASED MINIMUM PRICE FOR WHOLESALE
    # Wholesale: minimum margin = 40% of current margin
    # ==========================================================================
    current_margin = (current_price - wac) / current_price
    min_ws_margin = 0.4 * current_margin
    min_ws_price_margin_based = wac / (1 - min_ws_margin) if min_ws_margin < 1 else current_price
    
    # Legacy floor: the price at which the margin on selling price equals WS_MIN_MARGIN
    min_ws_price_legacy = wac / (1 - WS_MIN_MARGIN)
    
    # Use the HIGHER of the two constraints
    min_acceptable_price = max(min_ws_price_margin_based, min_ws_price_legacy)
    
    best_scenario = None
    best_savings_pct = 0
    
    # Test scenarios from 3x to orders_per_car
    for multiplier in range(3, int(orders_per_car) + 1):
        # Order value at this multiplier
        order_value = avg_ts * multiplier
        
        # Deliveries saved = multiplier - 1 (consolidating multiple orders into one)
        deliveries_saved = multiplier - 1
        
        # Total savings = deliveries_saved * car_cost_per_order
        total_savings = deliveries_saved * car_cost_per_order
        
        # How many units of this SKU fit in this order value?
        qty_at_current_price = order_value / current_price
        
        if qty_at_current_price <= 0:
            continue
        
        # Discount per unit from car cost savings
        discount_per_unit = total_savings / qty_at_current_price
        
        # New price after passing car cost savings
        new_price = current_price - discount_per_unit
        
        # Check all constraints:
        # 1. Price above minimum margin
        # 2. Order value within max ticket size
        # 3. Price below tier_2_price (T3 must be cheaper than T2)
        if new_price >= min_acceptable_price and order_value <= WS_MAX_TICKET_SIZE and new_price < tier_2_price:
            # Calculate margin at new price
            margin = (new_price - wac) / new_price
            
            # Savings percentage for retailer
            savings_pct = (discount_per_unit / current_price) * 100
            
            # Keep track of best scenario (highest savings while valid)
            if savings_pct > best_savings_pct:
                best_savings_pct = savings_pct
                best_scenario = {
                    'ws_qty': round(qty_at_current_price, 0),
                    'ws_price': round(new_price, 2),
                    'ws_discount_pct': round((current_price - new_price) / current_price * 100, 2),
                    'ws_margin': round(margin, 4),
                    'ws_multiplier': multiplier,
                    'ws_savings_pct': round(savings_pct, 2)
                }
    
    if best_scenario:
        return pd.Series(best_scenario)
    else:
        return pd.Series({
            'ws_qty': np.nan, 'ws_price': np.nan, 'ws_discount_pct': np.nan,
            'ws_margin': np.nan, 'ws_multiplier': np.nan, 'ws_savings_pct': np.nan
        })
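As a worked example of the price floor inside `calculate_wholesale_tier`: the T3 price may not drop below the higher of (a) the price that keeps 40% of the current margin and (b) `wac / (1 - WS_MIN_MARGIN)`. The sketch below assumes `WS_MIN_MARGIN = 0.05` purely for illustration; the helper name `ws_min_price` is hypothetical.

```python
def ws_min_price(current_price: float, wac: float, ws_min_margin: float = 0.05) -> float:
    """Minimum acceptable wholesale (T3) price, following calculate_wholesale_tier.

    ws_min_margin stands in for the notebook's WS_MIN_MARGIN (value assumed here).
    """
    current_margin = (current_price - wac) / current_price
    # Keep at least 40% of the current margin ...
    margin_based = wac / (1 - 0.4 * current_margin)
    # ... and never go below the flat wholesale margin floor
    legacy = wac / (1 - ws_min_margin)
    return max(margin_based, legacy)
```

For `current_price=100` and `wac=80` the current margin is 20%, so the floor is `80 / 0.92 ≈ 86.96`. For `wac=95` the flat floor is `95 / 0.95 = 100`, equal to the current price; since every scenario offers a positive discount, no T3 price can pass and the function returns NaNs.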

print("✓ Tier price calculation function defined")
print("✓ Wholesale tier calculation function defined")
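The scenario search in `calculate_wholesale_tier` can be reproduced without a DataFrame row, which makes the effect of each constraint easy to see. In this sketch `car_cost` and `max_ticket` stand in for `WS_CAR_COST` and `WS_MAX_TICKET_SIZE`; their default values, and the name `best_ws_scenario`, are assumptions for illustration only.

```python
def best_ws_scenario(current_price, min_price, tier_2_price, avg_ts=4000.0,
                     orders_per_car=15, car_cost=1500.0, max_ticket=50_000.0):
    """Return the valid consolidation scenario with the largest per-unit discount, or None."""
    car_cost_per_order = car_cost / orders_per_car
    best, best_pct = None, 0.0
    for multiplier in range(3, int(orders_per_car) + 1):
        order_value = avg_ts * multiplier                 # consolidated order size
        qty = order_value / current_price                 # units of this SKU in that order
        savings = (multiplier - 1) * car_cost_per_order   # deliveries avoided
        discount_per_unit = savings / qty
        new_price = current_price - discount_per_unit
        if (new_price >= min_price and order_value <= max_ticket
                and new_price < tier_2_price):
            pct = discount_per_unit / current_price * 100
            if pct > best_pct:
                best_pct = pct
                best = {"ws_multiplier": multiplier, "ws_qty": round(qty),
                        "ws_price": round(new_price, 2), "ws_discount_pct": round(pct, 2)}
    return best
```

Because the per-unit discount equals `car_cost_per_order * current_price / avg_ts * (m - 1) / m`, it grows with the multiplier `m`, so the search effectively returns the largest multiplier that passes all three constraints. With the defaults above and `current_price=100`, the ticket-size cap stops at 12x (480 units at 97.71).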


In [None]:
# =============================================================================
# MAIN FUNCTION: process_qd
# =============================================================================
def process_qd(df_qd: pd.DataFrame, dry_run: bool = True) -> dict:
    """
    Main function to process Quantity Discounts.
    Called from module_3_periodic_actions.ipynb with a filtered DataFrame.
    
    This function:
    1. Deactivates ALL currently active Quantity Discounts (FIRST!)
    2. Gets the top-selling packing unit for each product-warehouse pair
    3. Gets warehouse ticket statistics for wholesale calculations
    4. Calculates tier quantities from order history
    5. Calculates T1 & T2 prices using market margins and margin tiers
    6. Calculates T3 (wholesale) prices based on delivery cost savings
    7. Validates T3 constraints and adjusts T2 qty when the qty ratio is too low
    8. Applies the keep_qd_tiers filter from Module 3 and sets tier flags
    9. Keeps the top TOP_SKUS_PER_WAREHOUSE tier entries per warehouse,
       ranked by mtd_qty * effective_price
    10. Builds QD configurations and saves an Excel file for review
    11. Creates new QDs with the calculated tiers
    12. Updates cart rules so that each cart rule >= the max tier quantity
    
    Args:
        df_qd: DataFrame with columns from Module 3 (see documentation)
        dry_run: If True, only log what would be done (default: True);
            the Excel review file is still written
        
    Returns:
        dict with processing results
    """
    print("\n" + "="*70)
    print("QD HANDLER: PROCESSING QUANTITY DISCOUNTS")
    print("="*70)
    print(f"Mode: {'DRY RUN (testing)' if dry_run else 'LIVE'}")
    print(f"Timestamp: {CAIRO_NOW.strftime('%Y-%m-%d %H:%M')} Cairo Time")
    print(f"Input SKUs: {len(df_qd)}")
    
    if len(df_qd) == 0:
        print("\nNo SKUs to process. Exiting.")
        return {
            'mode': 'testing' if dry_run else 'live',
            'total_input': 0,
            'processed': 0,
            'failed': 0,
            'deactivate_result': {'total_active': 0, 'deactivated': [], 'failed': []},
            'create_result': {'created_count': 0, 'failed_count': 0, 'errors': []}
        }
    
    # Preview input
    print(f"\nUnique warehouses: {df_qd['warehouse_id'].nunique()}")
    
    # =========================================================================
    # STEP 1: DEACTIVATE ALL EXISTING QUANTITY DISCOUNTS (FIRST!)
    # =========================================================================
    print("\n" + "-"*60)
    print("STEP 1: Deactivating existing Quantity Discounts...")
    print("-"*60)
    
    deactivate_result = deactivate_active_qd(dry_run=dry_run)
    
    # =========================================================================
    # STEP 2: GET PACKING UNITS
    # =========================================================================
    print("\n" + "-"*60)
    print("STEP 2: Getting top-selling packing units...")
    print("-"*60)
    
    # Create list of unique [product_id, warehouse_id] pairs
    product_warehouse_list = df_qd[['product_id', 'warehouse_id']].drop_duplicates().values.tolist()
    
    df_packing_units = get_top_selling_packing_units(product_warehouse_list)
    
    if len(df_packing_units) == 0:
        print("  ⚠ No packing units found!")
        return {
            'mode': 'testing' if dry_run else 'live',
            'total_input': len(df_qd),
            'processed': 0,
            'failed': len(df_qd),
            'deactivate_result': deactivate_result,
            'create_result': {'created_count': 0, 'failed_count': 0, 'errors': [{'error': 'No packing units found'}]}
        }
    
    # Merge packing units with input data
    df_work = df_qd.merge(df_packing_units, on=['product_id', 'warehouse_id'], how='inner')
    print(f"  Matched {len(df_work)} SKUs with packing units")
    
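    # Use new_price if available, otherwise fall back to current_price
    # (new_price may be absent from the Module 3 output altogether)
    if 'new_price' in df_work.columns:
        df_work['effective_price'] = df_work['new_price'].fillna(df_work['current_price'])
        n_new_price = int(df_work['new_price'].notna().sum())
    else:
        df_work['effective_price'] = df_work['current_price']
        n_new_price = 0
    print(f"  Using new_price: {n_new_price} SKUs")
    print(f"  Using current_price (fallback): {len(df_work) - n_new_price} SKUs")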
    
    # Calculate packing unit prices (multiply by basic_unit_count)
    df_work['wac_pu'] = df_work['wac_p'] * df_work['basic_unit_count']
    df_work['packing_unit_price'] = df_work['effective_price'] * df_work['basic_unit_count']
    
    # =========================================================================
    # STEP 3: GET WAREHOUSE TICKET STATISTICS (for wholesale)
    # =========================================================================
    print("\n" + "-"*60)
    print("STEP 3: Getting warehouse ticket statistics...")
    print("-"*60)
    
    df_warehouse_stats = get_warehouse_ticket_stats()
    
    if len(df_warehouse_stats) > 0:
        df_work = df_work.merge(
            df_warehouse_stats[['warehouse_id', 'avg_ticket_size', 'orders_per_car_by_weight']],
            on='warehouse_id',
            how='left'
        )
        print(f"  Merged ticket stats for {df_work['avg_ticket_size'].notna().sum()} SKUs")
    else:
        print("  ⚠ No warehouse stats found, using defaults for wholesale")
        df_work['avg_ticket_size'] = 4000
        df_work['orders_per_car_by_weight'] = 15
    
    # =========================================================================
    # STEP 4: CALCULATE TIER QUANTITIES
    # =========================================================================
    print("\n" + "-"*60)
    print("STEP 4: Calculating tier quantities...")
    print("-"*60)
    
    # Create list of unique [warehouse_id, product_id, packing_unit_id] triples
    product_warehouse_pu_list = df_work[['warehouse_id', 'product_id', 'packing_unit_id']].drop_duplicates().values.tolist()
    
    df_tier_qty = get_tier_quantities(product_warehouse_pu_list)
    
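    if len(df_tier_qty) == 0:
        print("  ⚠ No tier quantities calculated!")
        # Create empty tier quantity columns so later steps don't raise KeyError
        df_work['tier_1_qty'] = np.nan
        df_work['tier_2_qty'] = np.nan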
    else:
        # Merge tier quantities
        df_work = df_work.merge(
            df_tier_qty[['warehouse_id', 'product_id', 'packing_unit_id', 'tier_1_qty', 'tier_2_qty']],
            on=['warehouse_id', 'product_id', 'packing_unit_id'],
            how='left'
        )
        print(f"  {df_work['tier_1_qty'].notna().sum()} SKUs have tier quantities")
    
    # =========================================================================
    # STEP 5: CALCULATE T1 & T2 PRICES
    # =========================================================================
    print("\n" + "-"*60)
    print("STEP 5: Calculating T1 & T2 prices...")
    print("-"*60)
    
    # Apply price calculation to each row
    price_results = df_work.apply(calculate_tier_prices, axis=1)
    df_work = pd.concat([df_work, price_results], axis=1)
    
    valid_t1_t2 = df_work['tier_1_price'].notna() & df_work['tier_2_price'].notna()
    print(f"  Valid T1 & T2 prices: {valid_t1_t2.sum()} / {len(df_work)}")
    
    # Show price source distribution
    if 'price_source' in df_work.columns:
        print("\n  Price source distribution:")
        for source, count in df_work['price_source'].value_counts().head(5).items():
            print(f"    {source}: {count}")
    
    # =========================================================================
    # STEP 6: CALCULATE T3 (WHOLESALE) PRICES
    # =========================================================================
    print("\n" + "-"*60)
    print("STEP 6: Calculating T3 (wholesale) prices...")
    print("-"*60)
    
    # Calculate wholesale tier for every row; rows without a valid T2 price
    # (or price/WAC) get NaN wholesale values from calculate_wholesale_tier
    ws_results = df_work.apply(calculate_wholesale_tier, axis=1)
    df_work = pd.concat([df_work, ws_results], axis=1)
    
    valid_t3 = df_work['ws_price'].notna()
    print(f"  Valid T3 prices: {valid_t3.sum()} / {len(df_work)}")
    
    if valid_t3.sum() > 0:
        print(f"\n  T3 Statistics:")
        print(f"    Average multiplier: {df_work.loc[valid_t3, 'ws_multiplier'].mean():.1f}x")
        print(f"    Average discount: {df_work.loc[valid_t3, 'ws_discount_pct'].mean():.2f}%")
        print(f"    Average margin: {df_work.loc[valid_t3, 'ws_margin'].mean()*100:.2f}%")
    
    # =========================================================================
    # STEP 7: VALIDATE T3 CONSTRAINTS
    # =========================================================================
    print("\n" + "-"*60)
    print("STEP 7: Validating T3 constraints...")
    print("-"*60)
    
    # Constraint 1: T3 qty must be > T2 qty
    invalid_t3_qty = (df_work['ws_qty'].notna() & 
                      df_work['tier_2_qty'].notna() & 
                      (df_work['ws_qty'] <= df_work['tier_2_qty']))
    if invalid_t3_qty.sum() > 0:
        # Fix: Set T3 qty = T2 qty * 10
        df_work.loc[invalid_t3_qty, 'ws_qty'] = (
            df_work.loc[invalid_t3_qty, 'tier_2_qty'] * 10
        ).astype(int)
        print(f"  Fixed {invalid_t3_qty.sum()} SKUs where T3 qty <= T2 qty")
    
    # Constraint 2: T3 price must be < T2 price (T3 discount > T2 discount)
    # This is already enforced in calculate_wholesale_tier, but double-check
    invalid_t3_price = (df_work['ws_price'].notna() & 
                        df_work['tier_2_price'].notna() & 
                        (df_work['ws_price'] >= df_work['tier_2_price']))
    if invalid_t3_price.sum() > 0:
        # Invalidate T3 for these rows
        df_work.loc[invalid_t3_price, 'ws_qty'] = np.nan
        df_work.loc[invalid_t3_price, 'ws_price'] = np.nan
        df_work.loc[invalid_t3_price, 'ws_discount_pct'] = np.nan
        print(f"  Invalidated {invalid_t3_price.sum()} SKUs where T3 price >= T2 price")
    
    print(f"  Final valid T3 count: {df_work['ws_price'].notna().sum()}")
    
    # =========================================================================
    # STEP 7.5: ADJUST TIER 2 QUANTITY IF RATIO IS TOO LOW
    # =========================================================================
    # If tier_2_qty/tier_1_qty < 1.3 and elasticity_ratio > 3, raise tier_2_qty
    # so that the elasticity ratio comes down to 2
    print("\n  Checking tier quantity ratios...")
    
    # Discount percentages from the T1/T2 prices calculated in Step 5.
    # tier_1_disc_pct / tier_2_disc_pct are only computed in Step 8, so any columns
    # with those names at this point would come from the input, not from this run.
    df_work['discount_1_pct'] = (df_work['packing_unit_price'] - df_work['tier_1_price']) / df_work['packing_unit_price'] * 100
    df_work['discount_2_pct'] = (df_work['packing_unit_price'] - df_work['tier_2_price']) / df_work['packing_unit_price'] * 100
    
    # Calculate ratios
    df_work['check_qty'] = df_work['tier_2_qty'] / df_work['tier_1_qty']
    df_work['discount_ratio'] = df_work['discount_2_pct'] / df_work['discount_1_pct'].replace(0, np.nan)
    df_work['elasticity_ratio'] = df_work['discount_ratio'] / df_work['check_qty'].replace(0, np.nan)
    # Target qty ratio = discount_ratio / 2, i.e. an elasticity ratio of exactly 2
    df_work['target_qty_ratio'] = df_work['discount_ratio'] / 2
    df_work['target_tier_2_q'] = np.round(df_work['target_qty_ratio'] * df_work['tier_1_qty'])
    
    # Adjust tier_2_qty where check_qty < 1.3 and elasticity_ratio > 3
    adjustment_mask = (df_work['check_qty'] < 1.3) & (df_work['elasticity_ratio'] > 3)
    if adjustment_mask.sum() > 0:
        df_work.loc[adjustment_mask, 'tier_2_qty'] = df_work.loc[adjustment_mask, 'target_tier_2_q']
        print(f"  Adjusted tier_2_qty for {adjustment_mask.sum()} SKUs (low qty ratio + high elasticity)")
        
        # A larger tier_2_qty can overtake ws_qty, so re-apply the Step 7 T3 qty fix
        recheck_t3_qty = (df_work['ws_qty'].notna() &
                          df_work['tier_2_qty'].notna() &
                          (df_work['ws_qty'] <= df_work['tier_2_qty']))
        if recheck_t3_qty.sum() > 0:
            df_work.loc[recheck_t3_qty, 'ws_qty'] = (
                df_work.loc[recheck_t3_qty, 'tier_2_qty'] * 10
            ).astype(int)
            print(f"  Re-fixed {recheck_t3_qty.sum()} SKUs where T3 qty <= adjusted T2 qty")
    
    # =========================================================================
    # STEP 8: APPLY KEEP_QD_TIERS FILTER & CALCULATE TIER FLAGS
    # =========================================================================
    print("\n" + "-"*60)
    print("STEP 8: Applying keep_qd_tiers filter and calculating tier flags...")
    print("-"*60)
    
    # Calculate discount percentages for T1 & T2
    df_work['tier_1_disc_pct'] = ((df_work['packing_unit_price'] - df_work['tier_1_price']) / df_work['packing_unit_price'] * 100).round(2)
    df_work['tier_2_disc_pct'] = ((df_work['packing_unit_price'] - df_work['tier_2_price']) / df_work['packing_unit_price'] * 100).round(2)
    
    # Apply keep_qd_tiers filter and calculate tier flags
    def apply_tier_filter(row):
        """Apply keep_qd_tiers filter and return tier flags."""
        keep_tiers = parse_keep_qd_tiers(row.get('keep_qd_tiers'))
        
        # If no tiers specified, default to all valid tiers
        if not keep_tiers:
            keep_tiers = ['T1', 'T2', 'T3']
        
        # Determine which tiers are valid after filtering
        t1_valid = ('T1' in keep_tiers and 
                    pd.notna(row.get('tier_1_qty')) and 
                    pd.notna(row.get('tier_1_disc_pct')) and 
                    row.get('tier_1_disc_pct', 0) > 0)
        
        t2_valid = ('T2' in keep_tiers and 
                    pd.notna(row.get('tier_2_qty')) and 
                    pd.notna(row.get('tier_2_disc_pct')) and 
                    row.get('tier_2_disc_pct', 0) > 0)
        
        t3_valid = ('T3' in keep_tiers and 
                    pd.notna(row.get('ws_qty')) and 
                    pd.notna(row.get('ws_discount_pct')) and 
                    row.get('ws_discount_pct', 0) > 0)
        
        return pd.Series({
            't1_f': int(t1_valid),
            't2_f': int(t2_valid),
            't3_f': int(t3_valid)
        })
    
    tier_flags = df_work.apply(apply_tier_filter, axis=1)
    df_work = pd.concat([df_work, tier_flags], axis=1)
    
    # Set invalid tier values to null
    # T1: if t1_f == 0, set tier_1_qty, tier_1_price, tier_1_disc_pct to null
    df_work.loc[df_work['t1_f'] == 0, 'tier_1_qty'] = np.nan
    df_work.loc[df_work['t1_f'] == 0, 'tier_1_price'] = np.nan
    df_work.loc[df_work['t1_f'] == 0, 'tier_1_disc_pct'] = np.nan
    
    # T2: if t2_f == 0, set tier_2_qty, tier_2_price, tier_2_disc_pct to null
    df_work.loc[df_work['t2_f'] == 0, 'tier_2_qty'] = np.nan
    df_work.loc[df_work['t2_f'] == 0, 'tier_2_price'] = np.nan
    df_work.loc[df_work['t2_f'] == 0, 'tier_2_disc_pct'] = np.nan
    
    # T3: if t3_f == 0, set ws_qty, ws_price, ws_discount_pct to null
    df_work.loc[df_work['t3_f'] == 0, 'ws_qty'] = np.nan
    df_work.loc[df_work['t3_f'] == 0, 'ws_price'] = np.nan
    df_work.loc[df_work['t3_f'] == 0, 'ws_discount_pct'] = np.nan
    
    # Calculate total tiers per SKU
    df_work['all_f'] = df_work['t1_f'] + df_work['t2_f'] + df_work['t3_f']
    
    # Only keep SKUs with at least 2 valid tiers
    df_work = df_work[df_work['all_f'] >= 2].copy()
    
    print(f"  SKUs with valid tiers after filtering: {len(df_work)}")
    print(f"  Total tier entries: {df_work['all_f'].sum()}")
    print(f"    T1 valid: {df_work['t1_f'].sum()}")
    print(f"    T2 valid: {df_work['t2_f'].sum()}")
    print(f"    T3 valid: {df_work['t3_f'].sum()}")
    
    # =========================================================================
    # STEP 9: SELECT TOP TIERS PER WAREHOUSE (MAX 400 TIER ENTRIES)
    # =========================================================================
    print("\n" + "-"*60)
    print(f"STEP 9: Selecting top {TOP_SKUS_PER_WAREHOUSE} tier entries per warehouse...")
    print("-"*60)
    
    # Calculate ranking score: mtd_qty * effective_price (higher is better).
    # mtd_qty is not among the documented Module 3 inputs, so default to 0 if absent.
    if 'mtd_qty' in df_work.columns:
        df_work['mtd_qty'] = df_work['mtd_qty'].fillna(0)
    else:
        df_work['mtd_qty'] = 0
    df_work['ranking_score'] = df_work['mtd_qty'] * df_work['effective_price']
    
    # Sort by warehouse and ranking score (descending)
    df_work = df_work.sort_values(['warehouse_id', 'ranking_score'], ascending=[True, False])
    
    # Calculate cumulative sum of tier entries per warehouse
    df_work['cumsum'] = df_work.groupby('warehouse_id')['all_f'].cumsum()
    
    # Keep SKUs while the running tier-entry count stays within TOP_SKUS_PER_WAREHOUSE;
    # the first SKU that would exceed the limit is dropped, along with every
    # lower-ranked SKU in that warehouse
    df_top = df_work[df_work['cumsum'] <= TOP_SKUS_PER_WAREHOUSE].copy()
    
    print(f"  Before filtering: {len(df_work)} SKUs ({df_work['all_f'].sum()} tier entries)")
    print(f"  After top {TOP_SKUS_PER_WAREHOUSE} limit: {len(df_top)} SKUs ({df_top['all_f'].sum()} tier entries)")
    print(f"\n  Tier entries per warehouse:")
    for wh in df_top['warehouse_id'].unique():
        wh_data = df_top[df_top['warehouse_id'] == wh]
        print(f"    Warehouse {wh}: {len(wh_data)} SKUs, {wh_data['all_f'].sum()} tiers")
    
    # =========================================================================
    # STEP 10: BUILD QD CONFIGURATIONS
    # =========================================================================
    print("\n" + "-"*60)
    print("STEP 10: Building QD configurations...")
    print("-"*60)
    
    qd_configs = []
    
    for _, row in df_top.iterrows():
        # Build tiers based on tier flags
        tiers = []
        
        # Tier 1
        if row['t1_f'] == 1:
            tiers.append({
                "tier": 1,
                "quantity": int(row['tier_1_qty']),
                "discount_pct": float(row['tier_1_disc_pct'])
            })
        
        # Tier 2
        if row['t2_f'] == 1:
            tiers.append({
                "tier": 2,
                "quantity": int(row['tier_2_qty']),
                "discount_pct": float(row['tier_2_disc_pct'])
            })
        
        # Tier 3 (Wholesale)
        if row['t3_f'] == 1:
            tiers.append({
                "tier": 3,
                "quantity": int(row['ws_qty']),
                "discount_pct": float(row['ws_discount_pct'])
            })
        
        qd_configs.append({
            'product_id': int(row['product_id']),
            'warehouse_id': int(row['warehouse_id']),
            'cohort_id': int(row.get('cohort_id', 0)),
            'packing_unit_id': int(row['packing_unit_id']),
            'tiers': tiers,
            'sku': row.get('sku', 'N/A'),
            'packing_unit_price': row['packing_unit_price'],
            'tier_1_price': row.get('tier_1_price'),
            'tier_2_price': row.get('tier_2_price'),
            'ws_price': row.get('ws_price'),
            'ranking_score': row.get('ranking_score', 0)
        })
    
    print(f"  Valid QD configs: {len(qd_configs)}")
    
    # Count tiers distribution
    tier_counts = {1: 0, 2: 0, 3: 0}
    for config in qd_configs:
        for t in config['tiers']:
            tier_counts[t['tier']] += 1
    print(f"\n  Tier distribution in configs:")
    print(f"    T1: {tier_counts[1]} configs")
    print(f"    T2: {tier_counts[2]} configs")
    print(f"    T3 (wholesale): {tier_counts[3]} configs")
    print(f"    Total tier entries: {sum(tier_counts.values())}")
    
    # Preview configs
    if qd_configs:
        print("\n  Sample QD configs:")
        for config in qd_configs[:5]:
            tier_str = ", ".join([f"T{t['tier']}:qty={t['quantity']},disc={t['discount_pct']:.2f}%" for t in config['tiers']])
            print(f"    {str(config['sku'])[:30]}: [{tier_str}]")
    
    # =========================================================================
    # STEP 10.5: SAVE DATA FOR REVIEW BEFORE PUSH
    # =========================================================================
    print("\n" + "-"*60)
    print("STEP 10.5: Saving data for review...")
    print("-"*60)
    
    # Save detailed data to Excel for review
    review_filename = f'QD_detailed_review_{CAIRO_NOW.strftime("%Y%m%d_%H%M")}.xlsx'
    
    # Select important columns for review
    review_columns = [
        'warehouse_id', 'product_id', 'packing_unit_id', 'sku', 'brand', 'cat',
        'effective_price', 'packing_unit_price', 'wac_p', 'wac_pu', 'basic_unit_count',
        'tier_1_qty', 'tier_1_price', 'tier_1_disc_pct', 't1_f',
        'tier_2_qty', 'tier_2_price', 'tier_2_disc_pct', 't2_f',
        'ws_qty', 'ws_price', 'ws_discount_pct', 't3_f',
        'all_f', 'ranking_score', 'mtd_qty', 'keep_qd_tiers'
    ]
    # Filter to columns that exist
    review_columns = [c for c in review_columns if c in df_top.columns]
    
    df_review = df_top[review_columns].copy()
    df_review.to_excel(review_filename, index=False)
    print(f"  ✓ Saved review file: {review_filename}")
    print(f"    Total SKUs: {len(df_review)}")
    print(f"    Columns: {len(review_columns)}")
    
    # =========================================================================
    # STEP 11: CREATE NEW QUANTITY DISCOUNTS
    # =========================================================================
    print("\n" + "-"*60)
    print("STEP 11: Creating new Quantity Discounts...")
    print("-"*60)
    
    if len(qd_configs) == 0:
        print("  No Quantity Discounts to create.")
        create_result = {"success": True, "created_count": 0, "failed_count": 0, "errors": []}
    else:
        print(f"  Creating {len(qd_configs)} Quantity Discounts...")
        create_result = bulk_create_qd(qd_configs, df_top, dry_run=dry_run)
        
        print(f"\n  Creation Result:")
        print(f"    Created: {create_result['created_count']}")
        print(f"    Failed: {create_result['failed_count']}")
    
    # =========================================================================
    # STEP 12: UPDATE CART RULES
    # =========================================================================
    print("\n" + "-"*60)
    print("STEP 12: Updating cart rules...")
    print("-"*60)
    
    # Prepare cart rules update - cart rule should be >= max tier quantity
    cart_rules_update = prepare_cart_rules_update(df_top, df_qd)
    
    if len(cart_rules_update) == 0:
        print("  No cart rules need updating.")
        cart_rules_result = {'success': [], 'failed': []}
    else:
        print(f"  Uploading cart rules...")
        cart_rules_result = upload_cart_rules(cart_rules_update, dry_run=dry_run)
        
        print(f"\n  Cart Rules Result:")
        print(f"    Cohorts updated: {len(cart_rules_result['success'])}")
        print(f"    Cohorts failed: {len(cart_rules_result['failed'])}")
    
    # =========================================================================
    # SUMMARY
    # =========================================================================
    total_tiers = sum(tier_counts.values())
    
    print("\n" + "="*70)
    print("QD HANDLER - SUMMARY")
    print("="*70)
    print(f"Mode: {'DRY RUN (testing)' if dry_run else 'LIVE'}")
    print(f"Total SKUs in input: {len(df_qd)}")
    print(f"SKUs with valid T1 & T2 prices: {valid_t1_t2.sum()}")
    print(f"SKUs with valid T3 prices: {valid_t3.sum()}")
    print(f"SKUs after keep_qd_tiers & {TOP_SKUS_PER_WAREHOUSE} tier limit: {len(df_top)}")
    print(f"Total tier entries: {total_tiers}")
    print(f"Valid QD configs: {len(qd_configs)}")
    print(f"QD found active: {deactivate_result['total_active']}")
    print(f"QD deactivated: {len(deactivate_result['deactivated'])}")
    print(f"QD created: {create_result['created_count']}")
    print(f"QD creation failed: {create_result['failed_count']}")
    print(f"Cart rules prepared: {len(cart_rules_update)} products")
    print("="*70)
    
    return {
        'mode': 'testing' if dry_run else 'live',
        'total_input': len(df_qd),
        'processed': create_result['created_count'],
        'failed': create_result['failed_count'],
        'total_tiers': total_tiers,
        'deactivate_result': deactivate_result,
        'create_result': create_result,
        'cart_rules_result': cart_rules_result,
        'cart_rules_update': cart_rules_update,
        'qd_configs': qd_configs,  # Return configs for inspection
        'df_work': df_top,  # Return working DataFrame for debugging
        'review_file': review_filename  # File saved for review before push
    }
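Step 9 caps tier entries, not SKUs, per warehouse: SKUs are ranked by `mtd_qty * effective_price`, and a per-warehouse running total of `all_f` decides where the cut falls. The pandas sketch below reproduces that selection on a small invented DataFrame; the helper name `top_tier_entries` and the toy data are assumptions for illustration.

```python
import pandas as pd

df = pd.DataFrame({
    "warehouse_id": [1, 1, 1, 1, 2, 2],
    "sku": ["a", "b", "c", "d", "e", "f"],
    "all_f": [3, 2, 3, 2, 3, 3],                 # valid tiers per SKU
    "mtd_qty": [10, 50, 5, None, 7, 8],
    "effective_price": [20.0, 3.0, 100.0, 10.0, 10.0, 10.0],
})


def top_tier_entries(df: pd.DataFrame, limit: int) -> pd.DataFrame:
    """Keep the best-ranked SKUs per warehouse until `limit` tier entries are used."""
    out = df.copy()
    out["mtd_qty"] = out["mtd_qty"].fillna(0)
    out["ranking_score"] = out["mtd_qty"] * out["effective_price"]
    out = out.sort_values(["warehouse_id", "ranking_score"], ascending=[True, False])
    # Running tier-entry count within each warehouse
    out["cumsum"] = out.groupby("warehouse_id")["all_f"].cumsum()
    return out[out["cumsum"] <= limit]


print(top_tier_entries(df, 6)[["warehouse_id", "sku", "all_f", "cumsum"]])
```

With `limit=7`, warehouse 1 still ends at 6 entries: SKU `b` would bring the total to 8, and because the running sum only grows, every lower-ranked SKU is cut too, so one slot stays unused.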

print("✓ process_qd() function defined")
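The Step 7.5 rule inside `process_qd` (raise `tier_2_qty` when `tier_2_qty / tier_1_qty < 1.3` and the elasticity ratio exceeds 3) can be tried on a two-row frame. In this sketch the short column names `pu_price`, `t1_price`, `t2_price`, `t1_qty`, `t2_qty` are hypothetical stand-ins for `packing_unit_price`, `tier_1_price`, `tier_2_price`, `tier_1_qty` and `tier_2_qty`, and the data is invented.

```python
import numpy as np
import pandas as pd


def adjust_tier_2_qty(df: pd.DataFrame) -> pd.DataFrame:
    """Raise t2_qty so the elasticity ratio becomes 2 when the qty ratio is too low."""
    out = df.copy()
    d1 = (out["pu_price"] - out["t1_price"]) / out["pu_price"] * 100
    d2 = (out["pu_price"] - out["t2_price"]) / out["pu_price"] * 100
    check_qty = out["t2_qty"] / out["t1_qty"]
    discount_ratio = d2 / d1.replace(0, np.nan)
    elasticity = discount_ratio / check_qty.replace(0, np.nan)
    # Target qty ratio = discount_ratio / 2, which puts the elasticity ratio at 2
    target = np.round(discount_ratio / 2 * out["t1_qty"])
    mask = (check_qty < 1.3) & (elasticity > 3)
    out.loc[mask, "t2_qty"] = target[mask]
    return out


df75 = pd.DataFrame({
    "pu_price": [100.0, 100.0],
    "t1_price": [98.0, 98.0],   # 2% discount
    "t2_price": [90.0, 90.0],   # 10% discount
    "t1_qty": [10.0, 10.0],
    "t2_qty": [12.0, 20.0],
})
print(adjust_tier_2_qty(df75)["t2_qty"].tolist())
```

In the first row the quantity ratio is 1.2 and the elasticity ratio is 5 / 1.2 ≈ 4.17, so `t2_qty` moves to `round(5 / 2 * 10) = 25`; the second row already has a quantity ratio of 2 and is left alone.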


In [None]:
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

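def parse_keep_qd_tiers(value):
    """Parse keep_qd_tiers from a list, its string representation, or a missing value.
    
    Returns a list of tier names such as ['T1', 'T2'], or [] when the value is
    missing or cannot be parsed.
    """
    # Handle list-like values first: pd.isna on a list returns an array,
    # which cannot be used directly in an if-statement
    if isinstance(value, (list, tuple, np.ndarray)):
        return list(value)
    # Now safe to check for None/NaN
    if value is None:
        return []
    try:
        if pd.isna(value):
            return []
    except (ValueError, TypeError):
        pass  # pd.isna fails on some types, continue
    if isinstance(value, str):
        try:
            parsed = ast.literal_eval(value)
        except (ValueError, SyntaxError, TypeError):
            return []
        # Accept "['T1', 'T2']" as well as a single quoted tier such as "'T1'"
        if isinstance(parsed, (list, tuple)):
            return list(parsed)
        if isinstance(parsed, str):
            return [parsed]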
    return []
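Module 3 may hand over `keep_qd_tiers` as a real list, as its string representation (for example after a CSV round-trip), or as a missing value. The sketch below is a simplified standalone stand-in for `parse_keep_qd_tiers`, written only to show the expected result for each of those input forms; the name `parse_tiers` is hypothetical and it does not import the notebook function.

```python
import ast
import math


def parse_tiers(value):
    """Simplified stand-in for parse_keep_qd_tiers (list, string repr, or missing)."""
    if isinstance(value, list):
        return value
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return []
    if isinstance(value, str):
        try:
            parsed = ast.literal_eval(value)   # "['T1', 'T2']" -> ['T1', 'T2']
        except (ValueError, SyntaxError):      # e.g. "T1" or ""
            return []
        return parsed if isinstance(parsed, list) else []
    return []
```

Note that a bare string such as `"T1"` is not a Python literal, so it parses to `[]`; `process_qd` then treats the row as if no tiers were specified and falls back to `['T1', 'T2', 'T3']`.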


def build_tiers_from_row(row, keep_tiers: list) -> list:
    """
    Build tier configuration from row data.
    
    Args:
        row: DataFrame row with qd_tier_X_qty and qd_tier_X_disc_pct columns
        keep_tiers: List of tiers to include, e.g., ['T1', 'T2']
        
    Returns:
        List of tier configs for create_qd()
    """
    tiers = []
    
    tier_map = {
        'T1': (1, 'qd_tier_1_qty', 'qd_tier_1_disc_pct'),
        'T2': (2, 'qd_tier_2_qty', 'qd_tier_2_disc_pct'),
        'T3': (3, 'qd_tier_3_qty', 'qd_tier_3_disc_pct')
    }
    
    for tier_name in keep_tiers:
        if tier_name in tier_map:
            tier_num, qty_col, disc_col = tier_map[tier_name]
            qty = row.get(qty_col, 0)
            disc = row.get(disc_col, 0)
            
            # Only include if both qty and discount are valid
            if qty and qty > 0 and disc and disc > 0:
                tiers.append({
                    "tier": tier_num,
                    "quantity": int(qty),
                    "discount_pct": float(disc)
                })
    
    return tiers

print("✓ Helper functions defined")

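`parse_keep_qd_tiers()` accepts strings because list-valued cells do not survive a text round-trip. Writing a DataFrame to CSV stores each list as its repr. The sketch below assumes such a round-trip happens somewhere between Module 3 and this module (the notebook does not show it). It uses made-up data to show `ast.literal_eval` recovering the lists:

```python
import ast
import io

import pandas as pd

# Made-up Module 3 output with a list-valued keep_qd_tiers column
df = pd.DataFrame({"product_id": [1, 2], "keep_qd_tiers": [["T1", "T2"], []]})

# Round-trip through CSV: list cells are written as their repr strings
buf = io.StringIO()
df.to_csv(buf, index=False)
buf.seek(0)
df_back = pd.read_csv(buf)

raw = df_back.loc[0, "keep_qd_tiers"]  # "['T1', 'T2']" is now a str, not a list
parsed = [ast.literal_eval(v) for v in df_back["keep_qd_tiers"]]
print(type(raw).__name__, parsed)      # str [['T1', 'T2'], []]
```

`ast.literal_eval` only evaluates Python literals, which makes it a safe choice over `eval` for cells like these.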

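`build_tiers_from_row()` filters tiers with `qty and qty > 0 and disc and disc > 0`. Because `NaN` is truthy, the `> 0` comparison is what actually drops missing values. The check below runs that condition on a made-up `pd.Series` row. `is_valid` is a hypothetical helper that repeats the condition:

```python
import pandas as pd

# Made-up row: T1 fully populated, T2 discount missing, no T3 columns at all
row = pd.Series({
    "qd_tier_1_qty": 6.0,
    "qd_tier_1_disc_pct": 2.5,
    "qd_tier_2_qty": 12.0,
    "qd_tier_2_disc_pct": float("nan"),
})


def is_valid(value) -> bool:
    """Same test build_tiers_from_row() applies: truthy and strictly positive."""
    return bool(value and value > 0)


nan_disc = row.get("qd_tier_2_disc_pct", 0)
print(bool(nan_disc), bool(nan_disc > 0))          # True False: NaN is truthy, NaN > 0 is not
print(is_valid(row.get("qd_tier_1_disc_pct", 0)))  # True  -> T1 is kept
print(is_valid(nan_disc))                          # False -> T2 is skipped
print(is_valid(row.get("qd_tier_3_qty", 0)))       # False -> missing column falls back to 0
```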
In [None]:
# =============================================================================
# API FUNCTIONS
# =============================================================================

def deactivate_active_qd(dry_run: bool = True) -> dict:
    """
    Deactivate ALL active Quantity Discounts.
    
    This function:
    1. Queries Snowflake to get all currently active QD IDs
    2. Calls the API to deactivate each one
    
    Args:
        dry_run: If True, only log what would be done without making API calls
        
    Returns:
        dict with 'success', 'deactivated', 'failed', 'total_active'
    """
    print("\n" + "="*60)
    print("DEACTIVATING ACTIVE QUANTITY DISCOUNTS")
    print("="*60)
    print(f"Mode: {'DRY RUN' if dry_run else 'LIVE'}")
    
    # Step 1: Query Snowflake to get all active QD IDs
    print("\nStep 1: Querying active Quantity Discounts from Snowflake...")
    df_active = get_active_qd_now()
    
    if len(df_active) == 0:
        print("  No active Quantity Discounts found.")
        return {
            'success': True,
            'deactivated': [],
            'failed': [],
            'total_active': 0
        }
    
    discount_ids = df_active['discount_id'].tolist()
    print(f"  Found {len(discount_ids)} active Quantity Discounts")
    
    # Step 2: Deactivate each QD via API
    print(f"\nStep 2: Deactivating {len(discount_ids)} discounts...")
    
    results = {'deactivated': [], 'failed': []}
    
    # Get fresh API token
    if not dry_run:
        auth_token = _get_api_token()
        headers = {
            'Authorization': f'Bearer {auth_token}',
            'Content-Type': 'application/json'
        }
    
    for idx, discount_id in enumerate(discount_ids):
        if dry_run:
            print(f"  [{idx+1}/{len(discount_ids)}] [DRY RUN] Would deactivate: {discount_id}")
            results['deactivated'].append(discount_id)
            continue
        
        url = f"{QD_API_URL}{discount_id}/deactivate"
        
        try:
            # Timeout so an unresponsive API fails the request instead of hanging the run
            response = requests.patch(url, headers=headers, json={'active': False}, timeout=30)
            
            if response.status_code in [200, 204]:
                print(f"  [{idx+1}/{len(discount_ids)}] [OK] Deactivated: {discount_id}")
                results['deactivated'].append(discount_id)
            else:
                print(f"  [{idx+1}/{len(discount_ids)}] [ERROR] {discount_id}: {response.status_code} - {response.text[:100]}")
                results['failed'].append({
                    'id': discount_id, 
                    'error': f"{response.status_code}: {response.text[:200]}"
                })
        except Exception as e:
            print(f"  [{idx+1}/{len(discount_ids)}] [EXCEPTION] {discount_id}: {e}")
            results['failed'].append({'id': discount_id, 'error': str(e)})
        
        # Rate limiting - 0.5 second delay between requests
        time.sleep(0.5)
    
    # Summary
    print(f"\n{'='*60}")
    print("DEACTIVATION SUMMARY")
    print(f"{'='*60}")
    print(f"Total active found: {len(discount_ids)}")
    print(f"Successfully deactivated: {len(results['deactivated'])}")
    print(f"Failed: {len(results['failed'])}")
    
    if results['failed']:
        print("\nFailed IDs:")
        for item in results['failed'][:10]:  # Show first 10
            print(f"  - {item['id']}: {item['error']}")
        if len(results['failed']) > 10:
            print(f"  ... and {len(results['failed']) - 10} more")
    
    return {
        'success': len(results['failed']) == 0,
        'deactivated': results['deactivated'],
        'failed': results['failed'],
        'total_active': len(discount_ids)
    }


def create_upload_format(df_configs: pd.DataFrame) -> pd.DataFrame:
    """
    Create upload format DataFrame from QD configurations.
    
    Format: ONE row per warehouse_id with:
    - Discounts Group 1: [tier 1 items + wholesale items], capped at MAX_GROUP_SIZE (200);
      any overflow is appended to Group 2
    - Discounts Group 2: [tier 2 items + overflow from Group 1] (not capped)
    - Each item format: [product_id, packing_unit_id, quantity, discount_pct]
    
    Args:
        df_configs: DataFrame with columns: warehouse_id, product_id, packing_unit_id,
                   tier_1_qty, tier_1_disc_pct, tier_2_qty, tier_2_disc_pct,
                   ws_qty, ws_discount_pct, t1_f, t2_f, t3_f
                   
    Returns:
        DataFrame with columns: warehouse_id, Discounts Group 1, Discounts Group 2, Description
    """
    rows = []
    
    for wh_id, warehouse_data in df_configs.groupby('warehouse_id', sort=False):
        warehouse_id = int(wh_id)
        
        # Collect tier 1, tier 2 and wholesale items for this warehouse
        tier_1_items, tier_2_items, ws_items = [], [], []
        
        for _, r in warehouse_data.iterrows():
            product_id = int(r['product_id'])
            packing_unit_id = int(r['packing_unit_id'])
            
            # Tier 1 (cap discount at MAX_DISCOUNT_CAP_T1)
            if r.get('t1_f', 0) == 1 and pd.notna(r.get('tier_1_qty')) and pd.notna(r.get('tier_1_disc_pct')):
                q_1 = int(r['tier_1_qty'])
                d_1 = min(round(r['tier_1_disc_pct'], 2), MAX_DISCOUNT_CAP_T1)
                tier_1_items.append([product_id, packing_unit_id, q_1, d_1])
            
            # Tier 2 (cap discount at MAX_DISCOUNT_CAP_T2)
            if r.get('t2_f', 0) == 1 and pd.notna(r.get('tier_2_qty')) and pd.notna(r.get('tier_2_disc_pct')):
                q_2 = int(r['tier_2_qty'])
                d_2 = min(round(r['tier_2_disc_pct'], 2), MAX_DISCOUNT_CAP_T2)
                tier_2_items.append([product_id, packing_unit_id, q_2, d_2])
            
            # Wholesale (cap discount at MAX_DISCOUNT_CAP_WS)
            if r.get('t3_f', 0) == 1 and pd.notna(r.get('ws_qty')) and pd.notna(r.get('ws_discount_pct')):
                q_ws = int(r['ws_qty'])
                d_ws = min(round(r['ws_discount_pct'], 2), MAX_DISCOUNT_CAP_WS)
                ws_items.append([product_id, packing_unit_id, q_ws, d_ws])
        
        # Group 1: Tier 1 + Wholesale, capped at MAX_GROUP_SIZE
        group_1_items = tier_1_items + ws_items
        overflow = group_1_items[MAX_GROUP_SIZE:]  # empty when Group 1 fits
        group_1_items = group_1_items[:MAX_GROUP_SIZE]
        
        # Group 2: Tier 2 + overflow from Group 1
        group_2_items = tier_2_items + overflow
        
        rows.append({
            'warehouse_id': warehouse_id,
            'Discounts Group 1': group_1_items,
            'Discounts Group 2': group_2_items,
            'Description': f'{warehouse_id}QD'
        })
    
    # Build the DataFrame once instead of concatenating inside the loop
    return pd.DataFrame(rows, columns=['warehouse_id', 'Discounts Group 1', 'Discounts Group 2', 'Description'])


def prepare_upload_file(df_upload: pd.DataFrame, dry_run: bool = True) -> tuple:
    """
    Prepare the final upload file with tag IDs and date/time.
    
    Args:
        df_upload: DataFrame from create_upload_format()
        dry_run: If True, build the DataFrame but do not write the Excel file
        
    Returns:
        tuple: (prepared_df, filename); the file is only written when dry_run is False
    """
    # Merge with warehouse mapping
    df_mapping = pd.DataFrame([
        {'warehouse_id': wh_id, 'warehouse_name': info['name'], 'tag_id': info['tag_id']}
        for wh_id, info in WAREHOUSE_TAG_MAPPING.items()
    ])
    
    to_upload = df_upload.merge(df_mapping, on='warehouse_id', how='left')
    
    # Set description
    to_upload['Description'] = (
        to_upload['warehouse_name'].astype(str)
        .str.replace(' ', '')
        .str.replace('-', '')
        + "QD"
    )
    
    # Start 10 minutes from now; end QD_DURATION_HOURS from now (Cairo time)
    cairo_now = datetime.now(CAIRO_TZ)
    start_date = cairo_now + timedelta(minutes=10)
    end_date = cairo_now + timedelta(hours=QD_DURATION_HOURS)
    
    start_date_str = start_date.strftime('%d/%m/%Y %H:%M')
    end_date_str = end_date.strftime('%d/%m/%Y %H:%M')
    
    to_upload['Start Date/Time'] = start_date_str
    to_upload['End Date/Time'] = end_date_str
    to_upload = to_upload.rename(columns={'tag_id': 'Tag ID'})
    
    # Select final columns
    to_upload = to_upload[['Tag ID', 'Description', 'Start Date/Time', 'End Date/Time', 'Discounts Group 1', 'Discounts Group 2']]
    
    # Remove rows without Tag ID
    to_upload = to_upload[to_upload['Tag ID'].notna()]
    
    filename = 'QD_upload.xlsx'
    
    if not dry_run:
        to_upload.to_excel(filename, index=False)
        print(f"  ✓ Saved upload file: {filename} ({len(to_upload)} warehouses)")
    
    return to_upload, filename


def post_QD(filename: str) -> requests.Response:
    """
    Upload QD file to API.
    
    Args:
        filename: Path to the Excel file to upload
        
    Returns:
        Response object from the API
    """
    auth_token = _get_api_token()
    
    url = 'https://api.maxab.app/commerce/api/admins/v1/quantity-discounts/upload'
    
    headers = {
        'Authorization': f'Bearer {auth_token}'
    }
    
    with open(filename, 'rb') as f:
        files = {'file': (filename, f, 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')}
        response = requests.post(url, headers=headers, files=files, timeout=120)  # avoid hanging indefinitely
    
    return response


def bulk_create_qd(qd_configs: list, df_work: pd.DataFrame, dry_run: bool = True) -> dict:
    """
    Bulk create Quantity Discounts using file upload method.
    
    Args:
        qd_configs: List of QD configuration dicts (used only for counts and logging)
        df_work: Working DataFrame with all tier data
        dry_run: If True, build the upload DataFrame but skip writing the file and the API call
            
    Returns:
        dict with 'success', 'created_count', 'failed_count', 'errors', 'upload_df'
    """
    print("\n  Creating upload format...")
    
    # Create upload format
    df_upload = create_upload_format(df_work)
    
    print(f"  Upload format created: {len(df_upload)} warehouse rows")
    print("\n  Per warehouse breakdown:")
    for _, row in df_upload.iterrows():
        wh = row['warehouse_id']
        g1_count = len(row['Discounts Group 1'])
        g2_count = len(row['Discounts Group 2'])
        print(f"    WH {wh}: Group 1 = {g1_count} items, Group 2 = {g2_count} items")
    
    # Prepare upload file
    print("\n  Preparing upload file...")
    to_upload, filename = prepare_upload_file(df_upload, dry_run=dry_run)
    
    if dry_run:
        print(f"\n  [DRY RUN] Would upload {len(to_upload)} warehouses")
        return {
            "success": True,
            "created_count": len(qd_configs),
            "failed_count": 0,
            "errors": [],
            "upload_df": to_upload
        }
    
    # Upload to API
    print("\n  Uploading QD file to API...")
    response = post_QD(filename)
    
    if response.ok:
        print(f"  ✓ Upload succeeded (status: {response.status_code})")
        return {
            "success": True,
            "created_count": len(qd_configs),
            "failed_count": 0,
            "errors": [],
            "upload_df": to_upload
        }
    else:
        print(f"  ❌ Upload failed (status: {response.status_code})")
        print(f"  Response: {response.text[:500]}")
        return {
            "success": False,
            "created_count": 0,
            "failed_count": len(qd_configs),
            "errors": [{"error": f"API upload failed: {response.status_code}"}],
            "upload_df": to_upload
        }


def prepare_cart_rules_update(df_work: pd.DataFrame, df_qd_input: pd.DataFrame) -> pd.DataFrame:
    """
    Prepare cart rules update based on QD tier quantities.
    
    Cart rule should be >= max(tier_1_qty, tier_2_qty, ws_qty) for each SKU.
    
    Args:
        df_work: Working DataFrame with tier quantities
        df_qd_input: Original input DataFrame with current_cart_rule and new_cart_rule
        
    Returns:
        DataFrame with cart rules to update
    """
    # Merge cart rules from the original input onto the working df by product_id and warehouse_id
    cart_cols = ['product_id', 'warehouse_id']
    if 'current_cart_rule' in df_qd_input.columns:
        cart_cols.append('current_cart_rule')
    if 'new_cart_rule' in df_qd_input.columns:
        cart_cols.append('new_cart_rule')
    
    df_cart_merge = df_qd_input[cart_cols].drop_duplicates()
    # Drop cart-rule columns already on df_work so the merge does not create _x/_y suffixes
    overlap = [c for c in cart_cols[2:] if c in df_work.columns]
    df_work_cart = df_work.drop(columns=overlap).merge(df_cart_merge, on=['product_id', 'warehouse_id'], how='left')
    
    # Use new_cart_rule if available, otherwise current_cart_rule
    if 'new_cart_rule' in df_work_cart.columns:
        df_work_cart['effective_cart_rule'] = df_work_cart['new_cart_rule'].fillna(
            df_work_cart.get('current_cart_rule', 0)
        )
    else:
        df_work_cart['effective_cart_rule'] = df_work_cart.get('current_cart_rule', 0)
    
    df_work_cart['effective_cart_rule'] = df_work_cart['effective_cart_rule'].fillna(0)
    
    # Calculate max tier quantity for each SKU
    tier_cols = ['tier_1_qty', 'tier_2_qty', 'ws_qty']
    tier_cols = [c for c in tier_cols if c in df_work_cart.columns]
    df_work_cart['max_tier_qty'] = df_work_cart[tier_cols].max(axis=1, skipna=True)
    
    # Only update cart rules that need to increase
    needs_update = df_work_cart['max_tier_qty'] > df_work_cart['effective_cart_rule']
    cart_rules_update = df_work_cart[needs_update][['cohort_id', 'product_id', 'packing_unit_id', 'max_tier_qty']].copy()
    cart_rules_update = cart_rules_update.rename(columns={'max_tier_qty': 'new_cart_rule'})
    
    # Round cart rules and ensure they're integers
    cart_rules_update['new_cart_rule'] = cart_rules_update['new_cart_rule'].round().astype(int)
    
    # Deduplicate by taking the max per cohort/product/packing_unit (groupby output is already unique)
    cart_rules_update = cart_rules_update.groupby(['cohort_id', 'product_id', 'packing_unit_id'])['new_cart_rule'].max().reset_index()
    
    return cart_rules_update


def post_cart_rules(cohort_id: int, filename: str) -> requests.Response:
    """
    Upload cart rules file to API for a specific cohort.
    
    Args:
        cohort_id: The cohort ID to update
        filename: Path to the Excel file to upload
        
    Returns:
        Response object from the API
    """
    auth_token = _get_api_token()
    
    url = f'https://api.maxab.app/commerce/api/admins/v1/cohorts/{cohort_id}/cart-rules/upload'
    
    headers = {
        'Authorization': f'Bearer {auth_token}'
    }
    
    with open(filename, 'rb') as f:
        files = {'file': (filename, f, 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')}
        response = requests.post(url, headers=headers, files=files, timeout=120)  # avoid hanging indefinitely
    
    return response


def upload_cart_rules(cart_rules_update: pd.DataFrame, dry_run: bool = True) -> dict:
    """
    Upload cart rules updates by cohort.
    
    Args:
        cart_rules_update: DataFrame with cohort_id, product_id, packing_unit_id, new_cart_rule
        dry_run: If True, only log what would be done (no files are written)
        
    Returns:
        dict with 'success' (list of cohort IDs) and 'failed' (list of {'cohort_id', 'error'} dicts)
    """
    results = {'success': [], 'failed': []}
    
    print(f"\n  Cart rules to update: {len(cart_rules_update)} products across {cart_rules_update['cohort_id'].nunique()} cohorts")
    
    for cohort in cart_rules_update['cohort_id'].unique():
        req_data = cart_rules_update[cart_rules_update['cohort_id'] == cohort].copy()
        
        if len(req_data) > 0:
            # Prepare data for upload
            req_data = req_data[['product_id', 'packing_unit_id', 'new_cart_rule']]
            req_data.columns = ['Product ID', 'Packing Unit ID', 'Cart Rules']
            
            filename = f'CartRules_{cohort}.xlsx'
            
            if dry_run:
                print(f"    [DRY RUN] Cohort {cohort}: Would upload {len(req_data)} rules")
                results['success'].append(cohort)
                continue
            
            # Save and upload
            req_data.to_excel(filename, index=False, engine='openpyxl')
            
            time.sleep(2)  # Rate limiting
            response = post_cart_rules(cohort, filename)
            
            if response.ok:
                print(f"    ✓ Cohort {cohort}: {len(req_data)} rules uploaded")
                results['success'].append(cohort)
            else:
                print(f"    ❌ Cohort {cohort}: Upload failed ({response.status_code})")
                results['failed'].append({'cohort_id': cohort, 'error': response.text[:200]})
    
    return results

print("✓ API functions defined")

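The grouping rule in `create_upload_format()` can be checked in isolation. The sketch below is a hypothetical standalone helper; `split_groups` is not part of the module. It applies the same slicing as the module:

- Group 1 holds tier 1 items followed by wholesale items, up to the cap.
- Anything that does not fit is appended after the tier 2 items in Group 2.

A cap of 3 stands in for `MAX_GROUP_SIZE` so the overflow is visible:

```python
def split_groups(tier_1_items: list, ws_items: list, tier_2_items: list,
                 max_group_size: int) -> tuple:
    """Return (group_1, group_2) following the T1+WS / T2+overflow rule."""
    group_1 = tier_1_items + ws_items
    overflow = group_1[max_group_size:]   # empty when group_1 fits
    group_1 = group_1[:max_group_size]
    group_2 = tier_2_items + overflow     # Group 2 itself is not capped
    return group_1, group_2


# Items are [product_id, packing_unit_id, quantity, discount_pct] (made-up values)
t1 = [[1, 10, 6, 2.0], [2, 20, 4, 1.5]]
ws = [[1, 10, 24, 4.0], [2, 20, 18, 3.5]]
t2 = [[1, 10, 12, 3.0]]

g1, g2 = split_groups(t1, ws, t2, max_group_size=3)
print(len(g1), len(g2))  # 3 2
print(g2[-1])            # [2, 20, 18, 3.5]: the wholesale item that overflowed
```

Wholesale items come after tier 1 items, so they are the first to spill into Group 2.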

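The core of `prepare_cart_rules_update()` is three steps:

1. Take the largest tier quantity per row.
2. Keep rows where it exceeds the effective cart rule.
3. Keep one maximum per cohort/product/packing unit.

Below is a self-contained sketch on made-up numbers. Here `effective_cart_rule` is supplied directly rather than merged from the Module 3 input:

```python
import pandas as pd

nan = float("nan")

# Made-up working rows: product 1 appears twice for cohort 3304 (e.g. two warehouses)
df = pd.DataFrame({
    "cohort_id": [3304, 3304, 3305],
    "product_id": [1, 1, 2],
    "packing_unit_id": [10, 10, 20],
    "tier_1_qty": [4, 5, 3],
    "tier_2_qty": [8, 12, nan],
    "ws_qty": [15, 20, nan],
    "effective_cart_rule": [10, 10, 2],
})

# Largest tier quantity per row (NaN tiers are ignored)
df["max_tier_qty"] = df[["tier_1_qty", "tier_2_qty", "ws_qty"]].max(axis=1, skipna=True)

keys = ["cohort_id", "product_id", "packing_unit_id"]
updates = (
    df.loc[df["max_tier_qty"] > df["effective_cart_rule"], keys + ["max_tier_qty"]]
      .rename(columns={"max_tier_qty": "new_cart_rule"})
      .assign(new_cart_rule=lambda d: d["new_cart_rule"].round().astype(int))
      # One rule per cohort/product/packing unit: keep the largest
      .groupby(keys, as_index=False)["new_cart_rule"].max()
)
print(updates)
```

Cohort 3304 ends up with a single rule of 20 for product 1. The larger of its two rows wins.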
In [None]:
# =============================================================================
# STANDALONE FUNCTIONS (can be called independently)
# =============================================================================
# Use deactivate_active_qd() directly if you only need to deactivate QDs
# without creating new ones.
#
# Example:
#   result = deactivate_active_qd(dry_run=True)
#   print(f"Deactivated: {len(result['deactivated'])} QDs")

print("✓ QD Handler ready to use")
print("\nAvailable functions:")
print("  - process_qd(df_qd, dry_run=True)           : Main function to process QDs from Module 3")
print("  - deactivate_active_qd(dry_run=True)        : Deactivate all active QDs")
print("  - create_upload_format(df_configs)          : Create upload format DataFrame")
print("  - prepare_upload_file(df_upload, ...)       : Prepare final upload file with tag IDs")
print("  - post_QD(filename)                         : Upload QD file to API")
print("  - prepare_cart_rules_update(df_work, df_qd) : Prepare cart rules update")
print("  - upload_cart_rules(cart_rules, ...)        : Upload cart rules by cohort")

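`deactivate_active_qd()` combines the Snowflake query, the HTTP calls, and the bookkeeping in one function. To test the bookkeeping without hitting the API, one option is to inject the per-ID call. The sketch below is a hypothetical refactor; `deactivate_all` and its `deactivate` callable are not part of this module. It keeps the original's dry-run behavior, the 0.5 s delay between live calls, and the returned keys:

```python
import time
from typing import Callable, Iterable


def deactivate_all(discount_ids: Iterable[int],
                   deactivate: Callable[[int], bool],
                   dry_run: bool = True,
                   delay_s: float = 0.5) -> dict:
    """Deactivate each ID with `deactivate` and collect results (hypothetical refactor)."""
    ids = list(discount_ids)
    deactivated, failed = [], []
    for discount_id in ids:
        if dry_run:
            deactivated.append(discount_id)  # dry run: record only, no call
            continue
        try:
            if deactivate(discount_id):
                deactivated.append(discount_id)
            else:
                failed.append({"id": discount_id, "error": "rejected by API"})
        except Exception as exc:  # keep going after network/API errors
            failed.append({"id": discount_id, "error": str(exc)})
        time.sleep(delay_s)  # rate limiting between live calls
    return {"success": not failed, "deactivated": deactivated,
            "failed": failed, "total_active": len(ids)}


# Fake API call: ID 2 is rejected, everything else succeeds
live = deactivate_all([1, 2, 3], lambda i: i != 2, dry_run=False, delay_s=0)
dry = deactivate_all([1, 2, 3], lambda i: False, dry_run=True)
print(live["deactivated"], live["success"])  # [1, 3] False
print(dry["deactivated"], dry["success"])    # [1, 2, 3] True
```

In a real wiring, `deactivate` would wrap the `requests.patch` call and return whether the status code was 200 or 204.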

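`prepare_upload_file()` stamps every row with the same window. Both times are formatted as `%d/%m/%Y %H:%M`:

- Start: 10 minutes after now.
- End: `QD_DURATION_HOURS` after now (not after the start).

It also builds the description by stripping spaces and hyphens from the warehouse name. The sketch below uses a fixed naive datetime so the output is reproducible. The notebook itself uses `datetime.now(CAIRO_TZ)`. The 24-hour duration and the warehouse name are placeholders:

```python
from datetime import datetime, timedelta


def upload_window(now: datetime, duration_hours: int, lead_minutes: int = 10) -> tuple:
    """Return (start, end) strings in the format used by the upload file."""
    fmt = "%d/%m/%Y %H:%M"
    start = now + timedelta(minutes=lead_minutes)
    end = now + timedelta(hours=duration_hours)  # measured from now, not from start
    return start.strftime(fmt), end.strftime(fmt)


def description(warehouse_name: str) -> str:
    """Strip spaces and hyphens, then append 'QD'."""
    return warehouse_name.replace(" ", "").replace("-", "") + "QD"


start, end = upload_window(datetime(2024, 1, 15, 9, 0), duration_hours=24)
print(start, end)                     # 15/01/2024 09:10 16/01/2024 09:00
print(description("North Hub - 2"))   # NorthHub2QD
```

Because the end is measured from now, each discount is live for `QD_DURATION_HOURS` minus 10 minutes.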
In [None]:
# =============================================================================
# USAGE EXAMPLE (for testing)
# =============================================================================
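# Uncomment below to test with sample data. process_qd() expects the full set of
# columns listed under "Input Requirements" (pricing, cart rules, market margins,
# margin tiers), so extend this sample before running it: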
#
# sample_df = pd.DataFrame({
#     'product_id': [12345, 67890],
#     'warehouse_id': [9, 625],
#     'cohort_id': [3304, 3305],
#     'sku': ['Test Product 1', 'Test Product 2'],
#     'keep_qd_tiers': [['T1', 'T2'], ['T1']]
# })
# 
# result = process_qd(sample_df, dry_run=True)
# print(result)

