In [0]:
# COMMAND ----------
# ✅ 1. Define SalesAgent (handles sales_transactions)
class SalesAgent:
    def __init__(self, table_name="sales_transactions"):
        # both names for compatibility
        self.table_name = table_name
        self.table = table_name

    def get_revenue_by_region(self):
        query = f"""
        SELECT
            region,
            SUM(revenue) AS total_revenue,
            COUNT(*) AS total_orders,
            ROUND(AVG(revenue),2) AS avg_order_value
        FROM
            {self.table}
        GROUP BY
            region
        ORDER BY
            total_revenue DESC
        """
        return spark.sql(query)

    def get_sales_by_category(self):
        q = f"""
        SELECT
            category AS product_category,
            SUM(revenue) AS total_revenue,
            COUNT(*) AS total_orders,
            ROUND(AVG(revenue),2) AS avg_order_value
        FROM {self.table}
        GROUP BY category
        ORDER BY total_revenue DESC
        """
        return spark.sql(q)

    def top_products(self, top_n=5):
        q = f"""
        SELECT
            product,
            category,
            SUM(revenue) AS total_revenue,
            COUNT(*) AS total_orders
        FROM {self.table}
        GROUP BY product, category
        ORDER BY total_revenue DESC
        LIMIT {int(top_n)}
        """
        return spark.sql(q)

# COMMAND ----------
# ✅ 3. Define CustomerAgent (handles customer_behavior)
class CustomerAgent:
    def __init__(self, table_name="customer_behavior"):
        self.table_name = table_name
        self.table = table_name

    def churn_by_region(self):
        sql = f"""
            SELECT region,
                   churn_risk,
                   COUNT(*) AS customer_count
            FROM {self.table}
            GROUP BY region, churn_risk
            ORDER BY region
        """
        return spark.sql(sql)

    def avg_lifetime_by_segment(self):
        sql = f"""
            SELECT segment,
                   ROUND(AVG(lifetime_value),2) AS avg_lifetime_value,
                   COUNT(*) AS num_customers
            FROM {self.table}
            GROUP BY segment
            ORDER BY avg_lifetime_value DESC
        """
        return spark.sql(sql)

    def churn_summary(self):
        sql = f"""
            SELECT churn_risk,
                   COUNT(*) AS customer_count,
                   ROUND(100.0 * COUNT(*) / SUM(COUNT(*)) OVER (),2) AS percentage
            FROM {self.table}
            GROUP BY churn_risk
            ORDER BY customer_count DESC
        """
        return spark.sql(sql)

# COMMAND ----------
# ✅ 4. Define CoordinatorAgent (combines both agents)
class CoordinatorAgent:
    def __init__(self, sales_agent, customer_agent):
        self.sales = sales_agent
        self.cust = customer_agent

    def regions_high_revenue_high_churn(self):
        df_rev = self.sales.get_revenue_by_region()
        df_rev.createOrReplaceTempView("rev")

        df_churn = spark.sql(f"""
            SELECT region,
                   SUM(CASE WHEN churn_risk='High' THEN 1 ELSE 0 END) AS churn_count,
                   COUNT(*) AS total_customers,
                   ROUND(100.0 * SUM(CASE WHEN churn_risk='High' THEN 1 ELSE 0 END) / COUNT(*),2) AS churn_pct
            FROM {self.cust.table}
            GROUP BY region
        """)
        df_churn.createOrReplaceTempView("churn")

        joined = spark.sql("""
            SELECT r.region,
                   r.total_revenue,
                   COALESCE(c.churn_pct,0) AS churn_pct
            FROM rev r
            LEFT JOIN churn c ON lower(r.region)=lower(c.region)
            ORDER BY r.total_revenue DESC
        """)

        return joined

    def compare_sales_with_segments(self):
        df_rev = self.sales.get_revenue_by_region()
        df_seg = self.cust.avg_lifetime_by_segment()

        df_rev.createOrReplaceTempView("rev")
        df_seg.createOrReplaceTempView("seg")

        merged = spark.sql("""
            SELECT r.region, s.segment, s.avg_lifetime_value, r.total_revenue
            FROM rev r
            CROSS JOIN seg s
        """)
        return merged

    def products_for_premium(self, premium_segment_name="Premium"):
        df_premium = spark.sql(f"""
            SELECT region
            FROM {self.cust.table}
            WHERE lower(segment)=lower('{premium_segment_name}')
        """)
        regions = [r.region for r in df_premium.collect()]

        if not regions:
            return spark.sql(f"SELECT * FROM {self.sales.table} LIMIT 0")  # empty DF

        region_list = ", ".join([f"'{r}'" for r in regions])
        df = spark.sql(f"""
            SELECT category AS product_category,
                   SUM(revenue) AS total_revenue,
                   COUNT(*) AS total_orders
            FROM {self.sales.table}
            WHERE region IN ({region_list})
            GROUP BY category
            ORDER BY total_revenue DESC
        """)
        return df

# COMMAND ----------
# ✅ 5. Instantiate all agents
agent = SalesAgent(table_name="sales_transactions")
agent_cust = CustomerAgent(table_name="customer_behavior")
coord = CoordinatorAgent(agent, agent_cust)

print("✅ Agents initialized successfully.")

# COMMAND ----------
# ✅ 6. Run Task 2.2 Queries (Case Study Questions)

print("🔹 1. Regions with high revenue but high customer churn:")
display(coord.regions_high_revenue_high_churn())

print("🔹 2. Compare sales performance with customer segments:")
display(coord.compare_sales_with_segments())

print("🔹 3. Product categories that appeal to Premium customers:")
display(coord.products_for_premium(premium_segment_name="Premium"))


Environment ready.
✅ Agents initialized successfully.
🔹 1. Regions with high revenue but high customer churn:


region,total_revenue,churn_pct
North,2525.0,0.0
South,750.0,100.0


🔹 2. Compare sales performance with customer segments:


region,segment,avg_lifetime_value,total_revenue
North,Standard,3000.0,2525.0
South,Premium,13500.0,750.0
North,Premium,13500.0,2525.0
South,Standard,3000.0,750.0


🔹 3. Product categories that appeal to Premium customers:


product_category,total_revenue,total_orders
Electronics,2525.0,2


In [0]:
query = """
SELECT s.region,
       SUM(s.revenue) AS total_revenue,
       COUNT(*) AS total_orders
FROM sales_transactions s
JOIN customer_behavior c
  ON lower(s.region) = lower(c.region)
WHERE lower(c.segment) = 'premium'
GROUP BY s.region
ORDER BY total_revenue DESC
"""
display(spark.sql(query))


region,total_revenue,total_orders
North,5050.0,4
