In [2]:
from ai_fix_helper import ask_openai_fix

In [4]:
import traceback


def get_concise_error_message(e):
    """
    Extract the most relevant single-line error message for display and AI use.
    """
    tb_str = traceback.format_exception_only(type(e), e)
    return tb_str[-1].strip() if tb_str else str(e)


In [6]:
!pip install findspark



In [7]:
from spark_init import start_spark

spark = start_spark("Spark Deep Learning Example")

data = [("Apple", 10), ("Banana", 5), ("Orange", 8)]
df = spark.createDataFrame(data, ["Fruit", "Quantity"])
df.show()



✅ Spark session started: Spark Deep Learning Example
+------+--------+
| Fruit|Quantity|
+------+--------+
| Apple|      10|
|Banana|       5|
|Orange|       8|
+------+--------+



In [8]:
# ✅ Sample data
data = [
    ("Alice", "USA", 34),
    ("Bob", "USA", 45),
    ("Catherine", "UK", 29),
    ("David", "India", 34),
    ("Emily", "India", 21),
    ("Frank", "UK", 45)
]

columns = ["Name", "Country", "Age"]
df = spark.createDataFrame(data, columns)
# ✅ Show the data

print(type(df))

# ✅ Show the data
df.show()
df.printSchema()

<class 'pyspark.sql.dataframe.DataFrame'>
+---------+-------+---+
|     Name|Country|Age|
+---------+-------+---+
|    Alice|    USA| 34|
|      Bob|    USA| 45|
|Catherine|     UK| 29|
|    David|  India| 34|
|    Emily|  India| 21|
|    Frank|     UK| 45|
+---------+-------+---+

root
 |-- Name: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- Age: long (nullable = true)



In [9]:
from pyspark.sql.functions import col, upper, lower, concat_ws

df.select(
    col("Name"),
    upper(col("Country")).alias("Country_UPPER"),
    concat_ws(" - ", col("Name"), col("Country")).alias("Combined")
).show()


+---------+-------------+--------------+
|     Name|Country_UPPER|      Combined|
+---------+-------------+--------------+
|    Alice|          USA|   Alice - USA|
|      Bob|          USA|     Bob - USA|
|Catherine|           UK|Catherine - UK|
|    David|        INDIA| David - India|
|    Emily|        INDIA| Emily - India|
|    Frank|           UK|    Frank - UK|
+---------+-------------+--------------+



In [10]:
df.show()

+---------+-------+---+
|     Name|Country|Age|
+---------+-------+---+
|    Alice|    USA| 34|
|      Bob|    USA| 45|
|Catherine|     UK| 29|
|    David|  India| 34|
|    Emily|  India| 21|
|    Frank|     UK| 45|
+---------+-------+---+



In [11]:
df.filter(col("Age") >= 45).show()
df.filter((col("Age") >= 45) & (col("Country") == "USA")).show()


+-----+-------+---+
| Name|Country|Age|
+-----+-------+---+
|  Bob|    USA| 45|
|Frank|     UK| 45|
+-----+-------+---+

+----+-------+---+
|Name|Country|Age|
+----+-------+---+
| Bob|    USA| 45|
+----+-------+---+



In [12]:
# Grouping and Aggregation

df.groupBy("Country").count().show()
df.groupBy("Country").avg("Age").show()




+-------+-----+
|Country|count|
+-------+-----+
|    USA|    2|
|     UK|    2|
|  India|    2|
+-------+-----+

+-------+--------+
|Country|avg(Age)|
+-------+--------+
|    USA|    39.5|
|     UK|    37.0|
|  India|    27.5|
+-------+--------+



In [13]:
# show the top 2 oldest people in each country
from pyspark.sql import Window
windowSpec = Window.partitionBy("Country").orderBy(col("Age").desc())

#display(type(windowSpec))

from pyspark.sql.functions import row_number
df.withColumn("Rank", row_number().over(windowSpec)) \
  .filter(col("Rank") <= 2) \
  .select("Name", "Country", "Age") \
  .show()

+---------+-------+---+
|     Name|Country|Age|
+---------+-------+---+
|    David|  India| 34|
|    Emily|  India| 21|
|    Frank|     UK| 45|
|Catherine|     UK| 29|
|      Bob|    USA| 45|
|    Alice|    USA| 34|
+---------+-------+---+



In [14]:
# Grouping and Aggregation
df.groupBy("Country").agg(
    {"Age": "avg", "Name": "count"}
).show()

+-------+--------+-----------+
|Country|avg(Age)|count(Name)|
+-------+--------+-----------+
|    USA|    39.5|          2|
|     UK|    37.0|          2|
|  India|    27.5|          2|
+-------+--------+-----------+



In [15]:
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, dense_rank

windowSpec = Window.partitionBy("Country").orderBy(col("Age").desc())

df.withColumn("rank", rank().over(windowSpec)) \
  .withColumn("dense_rank", dense_rank().over(windowSpec)) \
  .show()


+---------+-------+---+----+----------+
|     Name|Country|Age|rank|dense_rank|
+---------+-------+---+----+----------+
|    David|  India| 34|   1|         1|
|    Emily|  India| 21|   2|         2|
|    Frank|     UK| 45|   1|         1|
|Catherine|     UK| 29|   2|         2|
|      Bob|    USA| 45|   1|         1|
|    Alice|    USA| 34|   2|         2|
+---------+-------+---+----+----------+



In [16]:
# adding, removing, and renaming columns

df = df.withColumn("AgePlus5", col("Age") + 5)
df.show()

df = df.drop("AgePlus5")
df = df.withColumnRenamed("Country", "Nation")
df.show()


+---------+-------+---+--------+
|     Name|Country|Age|AgePlus5|
+---------+-------+---+--------+
|    Alice|    USA| 34|      39|
|      Bob|    USA| 45|      50|
|Catherine|     UK| 29|      34|
|    David|  India| 34|      39|
|    Emily|  India| 21|      26|
|    Frank|     UK| 45|      50|
+---------+-------+---+--------+

+---------+------+---+
|     Name|Nation|Age|
+---------+------+---+
|    Alice|   USA| 34|
|      Bob|   USA| 45|
|Catherine|    UK| 29|
|    David| India| 34|
|    Emily| India| 21|
|    Frank|    UK| 45|
+---------+------+---+



In [17]:

# ordering the data and showing the top 3 oldest people
df.orderBy(col("Age").desc()).show(3)  # Top 3 oldest


+-----+------+---+
| Name|Nation|Age|
+-----+------+---+
|Frank|    UK| 45|
|  Bob|   USA| 45|
|David| India| 34|
+-----+------+---+
only showing top 3 rows



In [18]:
from pyspark.sql import Row
from pyspark.sql.functions import col

# Create initial DataFrame
data = [("Alice", 30), ("Bob", 25), ("Charlie", 40)]
df = spark.createDataFrame(data, ["Name", "Age"])

# Add 'Salary' column based on 'Age'
df_with_null = df.withColumn("Salary", col("Age") * 1000)
df_with_null = df_with_null.withColumn("Salary", col("Salary").cast("double"))

# ✅ Create new row with matching data types
new_row = [("George", None, None)]  # Ensure correct types: str, None (for int), None (for float)

# ✅ Use schema explicitly to cast the new row
df_new_row = spark.createDataFrame(new_row, schema=df_with_null.schema)

# ✅ Union the DataFrames
df_null = df_with_null.union(df_new_row)

df_null.show()

# # Fill missing salary values with 0
df_null.fillna({"Salary": 0}).show()

# # Drop rows with any nulls
df_null.dropna().show()


+-------+----+-------+
|   Name| Age| Salary|
+-------+----+-------+
|  Alice|  30|30000.0|
|    Bob|  25|25000.0|
|Charlie|  40|40000.0|
| George|NULL|   NULL|
+-------+----+-------+

+-------+----+-------+
|   Name| Age| Salary|
+-------+----+-------+
|  Alice|  30|30000.0|
|    Bob|  25|25000.0|
|Charlie|  40|40000.0|
| George|NULL|    0.0|
+-------+----+-------+

+-------+---+-------+
|   Name|Age| Salary|
+-------+---+-------+
|  Alice| 30|30000.0|
|    Bob| 25|25000.0|
|Charlie| 40|40000.0|
+-------+---+-------+



In [19]:

# deduping data

#Explain below code
# Create a DataFrame with duplicate rows
df_dup = df.union(df)
df_dup.show()

#quit()

# explain the deduplication process
# Deduplication in Spark can be done using the `dropDuplicates()` method.
# This method removes duplicate rows based on all columns by default.
# If you want to deduplicate based on specific columns, you can pass those column names as arguments.
# Example of deduplication

df_dup.dropDuplicates().show()


+-------+---+
|   Name|Age|
+-------+---+
|  Alice| 30|
|    Bob| 25|
|Charlie| 40|
|  Alice| 30|
|    Bob| 25|
|Charlie| 40|
+-------+---+

+-------+---+
|   Name|Age|
+-------+---+
|  Alice| 30|
|    Bob| 25|
|Charlie| 40|
+-------+---+



In [None]:
def run_country_count_query():
    """
    Creates a Spark DataFrame with people data, registers it as a temporary SQL view,
    and runs an SQL query to count the number of people per country.

    Purpose:
    - To demonstrate using SQL on Spark DataFrames via temporary views.
    - Helpful for learners who are comfortable with SQL syntax.
    """
    # Create sample DataFrame
    data = [
        ("Alice", "USA", 34),
        ("Bob", "USA", 45),
        ("Catherine", "UK", 29),
        ("David", "India", 34),
        ("Emily", "India", 21),
        ("Frank", "UK", 45)
    ]
    columns = ["Name", "Country", "Age"]
    df = spark.createDataFrame(data, columns)

    # Create a temporary SQL view
    df.createOrReplaceTempView("people")

    # Run SQL query to count by country
    spark.sql("SELECT Country, COUNT(*) as total FROM people GROUP BY Country").show()


In [24]:
try:
    run_country_count_query()
except Exception as e:
    concise = get_concise_error_message(e)
    #print(f"\n❌ Error: {concise}\n")
    print("🤖 Asking AI to help fix it...\n")
    ask_openai_fix(concise)  # 👈 This will print directly

🤖 Asking AI to help fix it...

🧠 AI Suggestion:

- 🛠️ Root cause: The SQL query is missing a comma between the `Country` column and the `COUNT(*)` function.  
- ✅ Fix: Add a comma after `Country`. The corrected query is:  
  ```sql
  SELECT Country, COUNT(*) as total FROM people GROUP BY Country
  ```  
- 💡 Tip: Always check for commas between selected columns in SQL queries to avoid syntax errors.


In [22]:
#spark.stop()