
## 1. Aggregations & GroupBy

Aggregations are fundamental operations to summarize data, typically performed on groups of rows. The `groupBy()` function is used to define these groups.

### Core Concepts & Functions:

*   **`groupBy(*cols)`**: Groups a DataFrame by one or more specified columns. It returns a `GroupedData` object, on which aggregation functions can be applied. If no `groupBy()` is used, aggregations apply to the entire DataFrame.
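*   **`agg(*exprs)`**: The most flexible way to apply one or more aggregation functions to grouped data (or to the entire DataFrame). You can pass aggregate expressions from `pyspark.sql.functions` (e.g., `count()`, `sum()`), a dictionary mapping column names to aggregate function names (e.g., `{"Salary": "max"}`), or user-defined aggregations such as grouped-aggregate pandas UDFs.
    *   **Renaming Output Columns**: Use `.alias("new_column_name")` on each aggregation expression inside `agg()`. Otherwise Spark generates names such as `sum(Salary)`, which are awkward to reference downstream.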
*   **Direct Aggregation Functions**: These are shorthand methods directly available on a `GroupedData` object for common aggregations:
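    *   `count()`: Returns the number of rows in each group.
    *   `sum(*cols)`: Calculates the sum of one or more numeric columns for each group.
    *   `avg(*cols)`: Computes the average of one or more numeric columns for each group (`mean()` is an alias).
    *   `min(*cols)`: Finds the minimum of one or more numeric columns for each group.
    *   `max(*cols)`: Finds the maximum of one or more numeric columns for each group.
*   **Source**: The aggregate functions used inside `agg()` (such as `count`, `sum`, `avg`, `min`, `max`) come from `pyspark.sql.functions`. Importing them by name shadows Python's built-in `sum`, `min`, and `max`, so a common convention is `from pyspark.sql import functions as F` followed by `F.sum(...)`, `F.max(...)`, and so on. Unlike the shorthand `min()`/`max()` methods, `F.min()` and `F.max()` also work on non-numeric columns such as strings and dates.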

### Why it's important for Data Engineers:

*   **Data Summarization**: Crucial for generating reports, dashboards, and key performance indicators (KPIs).
*   **Feature Engineering**: Aggregations are often used to create new features for machine learning models (e.g., total sales per customer, average transaction value).
*   **Data Validation**: Quickly verify data integrity by checking counts, sums, or averages.

### Example (Python):

In [1]:
from pyspark.sql import SparkSession
# Importing sum/min/max by name shadows Python's built-ins for the rest of this notebook
from pyspark.sql.functions import col, count, sum, avg, min, max

# 1. Initialize Spark Session
spark = SparkSession.builder.appName("Aggregations").getOrCreate()

# 2. Create Sample Data
data = [("A", "Sales", 1000),
        ("B", "Sales", 1500),
        ("A", "HR", 800),
        ("C", "IT", 2000),
        ("B", "HR", 1200),
        ("C", "Sales", 1800)]
columns = ["Employee", "Department", "Salary"]
df = spark.createDataFrame(data, columns)
df.show()

# --- Single Aggregations ---

print("\n--- Count of employees per Department ---")
df.groupBy("Department").count().show()

print("\n--- Sum of salary per Department ---")
df.groupBy("Department").sum("Salary").show()

print("\n--- Average salary per Department ---")
df.groupBy("Department").avg("Salary").show()

print("\n--- Min and Max salary per Department ---")
# Note: min() and max() each return a separate DataFrame, so they cannot be chained; use .agg() to get both in one result
df.groupBy("Department").min("Salary").show() # Shows only min
df.groupBy("Department").max("Salary").show() # Shows only max


print("\n--- Using agg() for single aggregation (equivalent to direct agg function) ---")
df.groupBy("Department").agg(sum("Salary").alias("TotalSalary")).show()

# --- Aggregating over multiple columns ---

data_multi = [("A", "Sales", "NY", 1000),
              ("B", "Sales", "LD", 1500),
              ("A", "HR", "NY", 800),
              ("C", "IT", "SF", 2000),
              ("B", "HR", "LD", 1200),
              ("C", "Sales", "NY", 1800),
              ("A", "Sales", "NY", 900)]
columns_multi = ["Employee", "Department", "City", "Salary"]
df_multi = spark.createDataFrame(data_multi, columns_multi)
df_multi.show()

print("\n--- Count of employees per Department and City ---")
df_multi.groupBy("Department", "City").count().show()

print("\n--- Sum and average of salary per Department and City ---")
df_multi.groupBy("Department", "City") \
    .agg(sum("Salary").alias("TotalSalary"),
         avg("Salary").alias("AverageSalary")) \
    .show()

# --- Aggregating on entire DataFrame (no groupBy) ---

print("\n--- Aggregations on entire DataFrame ---")
df.agg(count("Employee").alias("TotalEmployees"),
       sum("Salary").alias("GrandTotalSalary"),
       avg("Salary").alias("OverallAverageSalary")).show()

spark.stop()

+--------+----------+------+
|Employee|Department|Salary|
+--------+----------+------+
|       A|     Sales|  1000|
|       B|     Sales|  1500|
|       A|        HR|   800|
|       C|        IT|  2000|
|       B|        HR|  1200|
|       C|     Sales|  1800|
+--------+----------+------+


--- Count of employees per Department ---
+----------+-----+
|Department|count|
+----------+-----+
|     Sales|    3|
|        HR|    2|
|        IT|    1|
+----------+-----+


--- Sum of salary per Department ---
+----------+-----------+
|Department|sum(Salary)|
+----------+-----------+
|     Sales|       4300|
|        HR|       2000|
|        IT|       2000|
+----------+-----------+


--- Average salary per Department ---
+----------+------------------+
|Department|       avg(Salary)|
+----------+------------------+
|     Sales|1433.3333333333333|
|        HR|            1000.0|
|        IT|            2000.0|
+----------+------------------+


--- Min and Max salary per Department ---
+----------

---

## 2. Pivot Operations

Pivot operations transform data by turning the distinct values of one column into new columns, much like a spreadsheet pivot table or a cross-tabulation. This reshapes data from a "long" format (one row per category) to a "wide" format (one column per category).

### Steps for Pivoting:

1.  **`groupBy(*cols)`**: Group the DataFrame by the column(s) that will remain as rows in the pivoted output.
2.  **`pivot(pivot_column, [values])`**: Specify the column whose unique values will become new columns.
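    *   Optionally, you can pass a list of specific `values` to pivot on. This is recommended. Without it, Spark must first run an extra job to collect the distinct values of the pivot column, and it raises an error if there are more than `spark.sql.pivotMaxValues` of them (10,000 by default). An explicit list also fixes the output schema, so the same columns appear on every run even when some values are missing from the data.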
3.  **`agg(*exprs)`**: Apply one or more aggregation functions to populate the new columns. Each cell holds the aggregate for one row group and one pivot value, or `NULL` if no rows match.

### Why it's important for Data Engineers:

*   **Report Generation**: Creating summary tables where categories are presented as columns (e.g., sales per product per region).
*   **Data Reshaping**: Transforming data into a format suitable for specific analytical tools or models that expect a "wide" table.
*   **Comparisons**: Easily compare metrics across different categories side-by-side.

### Example (Python):

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import sum, avg

# 1. Initialize Spark Session
spark = SparkSession.builder.appName("PivotOperations").getOrCreate()

# 2. Create Sample Data
data = [("RegionA", "ProductX", 100),
        ("RegionA", "ProductY", 150),
        ("RegionB", "ProductX", 200),
        ("RegionB", "ProductY", 120),
        ("RegionA", "ProductX", 50),
        ("RegionC", "ProductY", 300)]
columns = ["Region", "Product", "Sales"]
df = spark.createDataFrame(data, columns)
df.show()

print("\n--- Pivot on 'Product' to show sales per region per product ---")
df.groupBy("Region").pivot("Product").agg(sum("Sales")).show()

print("\n--- Pivot with specific product values (recommended for performance/consistency) ---")
# Explicitly listing pivot values saves Spark a pass to find the distinct values
# and guarantees every listed column appears, even one with no rows at all (ProductZ below).
df.groupBy("Region").pivot("Product", ["ProductX", "ProductY", "ProductZ"]).agg(sum("Sales")).show()

print("\n--- Pivoting with another aggregation (e.g., average) ---")
df.groupBy("Region").pivot("Product").agg(avg("Sales")).show()

spark.stop()

# Key Note:
# agg() after pivot() accepts several aggregate expressions in a single call, e.g.
# .agg(sum("Sales").alias("total"), avg("Sales").alias("avg")).
# With more than one aggregate, each output column is named "<pivot value>_<alias>"
# (e.g., ProductX_total, ProductX_avg); with a single aggregate, it is just the pivot value.

+-------+--------+-----+
| Region| Product|Sales|
+-------+--------+-----+
|RegionA|ProductX|  100|
|RegionA|ProductY|  150|
|RegionB|ProductX|  200|
|RegionB|ProductY|  120|
|RegionA|ProductX|   50|
|RegionC|ProductY|  300|
+-------+--------+-----+


--- Pivot on 'Product' to show sales per region per product ---
+-------+--------+--------+
| Region|ProductX|ProductY|
+-------+--------+--------+
|RegionB|     200|     120|
|RegionA|     150|     150|
|RegionC|    NULL|     300|
+-------+--------+--------+


--- Pivot with specific product values (recommended for performance/consistency) ---
+-------+--------+--------+--------+
| Region|ProductX|ProductY|ProductZ|
+-------+--------+--------+--------+
|RegionB|     200|     120|    NULL|
|RegionA|     150|     150|    NULL|
|RegionC|    NULL|     300|    NULL|
+-------+--------+--------+--------+


--- Pivoting with another aggregation (e.g., average) ---
+-------+--------+--------+
| Region|ProductX|ProductY|
+-------+--------+--------