In [None]:
import openai
import os
import traceback
from dotenv import load_dotenv

# Load environment variables from a local .env file so the secret never has to
# appear in the notebook itself.
load_dotenv()
# os.getenv() takes the NAME of the variable, not the key value; OPENAI_API_KEY
# is the conventional name the openai library also looks for.
openai.api_key = os.getenv("OPENAI_API_KEY")

In [None]:
def get_concise_error_message(e):
    """
    Extract the most relevant single-line error message for display and AI use.

    Args:
        e: The exception instance to summarise.

    Returns:
        str: The first non-blank line of the last entry produced by
        ``traceback.format_exception_only`` (normally ``"ExceptionType: message"``),
        stripped of surrounding whitespace. Falls back to ``str(e)`` when
        nothing was formatted.
    """
    tb_lines = traceback.format_exception_only(type(e), e)
    if not tb_lines:
        return str(e)

    final_entry = tb_lines[-1]
    # An exception message may itself span several lines; keep only its
    # headline so the result really is a single line.
    for line in final_entry.splitlines():
        line = line.strip()
        if line:
            return line
    return final_entry.strip()


In [99]:
from openai import OpenAI

# Read the key from the environment (e.g. populated from .env by load_dotenv())
# instead of embedding the secret in the notebook source.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


def ask_openai_fix(error_message, model="gpt-4o-mini-2024-07-18", temperature=0.3):
    """
    Ask an OpenAI chat model to explain and fix a PySpark error.

    Prints the suggestion directly. Uses openai>=1.0.0 syntax and the
    notebook-level ``client`` object.

    Args:
        error_message: Error text that is embedded in the prompt.
        model: Chat model identifier. Defaults to the snapshot this notebook
            was written against.
        temperature: Sampling temperature forwarded to the API call.

    Returns:
        None. The suggestion (or a warning if the API call fails) is printed.
    """
    prompt = f"""
I encountered this PySpark error in a notebook:

Error:
{error_message}

Please explain what went wrong and how to fix it clearly and concisely. Use a point-wise format.

Respond only with:
- 🛠️ Root cause (1 short sentence)
- ✅ Fix (1–2 steps max, include code if needed)
- 💡 Tip (1 short sentence to avoid it in future)

Keep your answer under 6 lines.
"""

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
        )
        print("🧠 AI Suggestion:\n")
        print(response.choices[0].message.content)

    except Exception as err:
        # Best-effort helper: a failed API call is reported, not re-raised,
        # so the surrounding notebook cell keeps running.
        print(f"⚠️ OpenAI API call failed: {err}")


In [None]:
!pip install findspark

In [None]:
from spark_init import start_spark

# start_spark comes from the project-local spark_init module; the object it
# returns is used as the Spark session by every later cell.
spark = start_spark("Spark Deep Learning Example")

# Smoke test: build a tiny two-column DataFrame and display it.
data = [
    ("Apple", 10),
    ("Banana", 5),
    ("Orange", 8),
]
df = spark.createDataFrame(data, schema=["Fruit", "Quantity"])
df.show()



In [None]:
# Column names for the people DataFrame
columns = ["Name", "Country", "Age"]

# Sample rows: (name, country, age)
data = [
    ("Alice", "USA", 34),
    ("Bob", "USA", 45),
    ("Catherine", "UK", 29),
    ("David", "India", 34),
    ("Emily", "India", 21),
    ("Frank", "UK", 45),
]

df = spark.createDataFrame(data, schema=columns)

# Confirm the object type, then look at the rows and the schema
print(type(df))
df.show()
df.printSchema()

In [None]:
from pyspark.sql.functions import col, upper, lower, concat_ws

# Project three columns: Name as-is, Country upper-cased (aliased Country_UPPER),
# and Name and Country joined with " - " via concat_ws (aliased Combined).
df.select(
    col("Name"),
    upper(col("Country")).alias("Country_UPPER"),
    concat_ws(" - ", col("Name"), col("Country")).alias("Combined")
).show()


In [None]:
# Display df again: the select() call above was not assigned back to df,
# so df still has its original columns.
df.show()

In [None]:
# Rows where Age is at least 45
df.filter(col("Age") >= 45).show()
# Two conditions combined with & (each comparison in its own parentheses):
# Age at least 45 AND Country equal to "USA"
df.filter((col("Age") >= 45) & (col("Country") == "USA")).show()


In [None]:
# Grouping and Aggregation

# Number of rows per Country
df.groupBy("Country").count().show()
# Average Age per Country
df.groupBy("Country").avg("Age").show()




In [None]:
# Top 2 oldest people in each country
from pyspark.sql import Window
from pyspark.sql.functions import row_number

# Per-country window, ordered oldest first
windowSpec = Window.partitionBy("Country").orderBy(col("Age").desc())

# Number the rows inside each window, keep positions 1 and 2, then show the
# original three columns only.
(
    df.withColumn("Rank", row_number().over(windowSpec))
    .filter(col("Rank") <= 2)
    .select("Name", "Country", "Age")
    .show()
)

In [None]:
# Grouping and Aggregation
# Dict form of agg(): each key is a column name and each value the aggregate
# applied to it — here the average of Age and the count of Name per Country.
df.groupBy("Country").agg(
    {"Age": "avg", "Name": "count"}
).show()

In [None]:
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, dense_rank

# Per-country window, ordered oldest first
windowSpec = Window.partitionBy("Country").orderBy(col("Age").desc())

# Add both ranking columns over the same window so they can be compared
# side by side.
(
    df.withColumn("rank", rank().over(windowSpec))
    .withColumn("dense_rank", dense_rank().over(windowSpec))
    .show()
)


In [None]:
# adding, removing, and renaming columns
#
# Each step is bound to its own name instead of overwriting `df`, so cells that
# reference the "Country" column keep working when they are re-run.

# Add a derived column
df_with_age_plus5 = df.withColumn("AgePlus5", col("Age") + 5)
df_with_age_plus5.show()

# Drop the derived column again and rename Country -> Nation
df_renamed = df_with_age_plus5.drop("AgePlus5").withColumnRenamed("Country", "Nation")
df_renamed.show()


In [None]:

# ordering the data and showing the top 3 oldest people
# (sort by Age in descending order, then print only the first 3 rows)
df.orderBy(col("Age").desc()).show(3)  # Top 3 oldest


In [None]:
from pyspark.sql import Row
from pyspark.sql.functions import col

# Base DataFrame: three people with their ages
data = [
    ("Alice", 30),
    ("Bob", 25),
    ("Charlie", 40),
]
df = spark.createDataFrame(data, schema=["Name", "Age"])

# Derive Salary from Age and store it as a double
df_with_null = df.withColumn("Salary", (col("Age") * 1000).cast("double"))

# A row with no Age and no Salary; it is built with the existing schema so the
# column types do not have to be inferred from the None values
new_row = [("George", None, None)]
df_new_row = spark.createDataFrame(new_row, schema=df_with_null.schema)

# Append the incomplete row and look at the combined data
df_null = df_with_null.union(df_new_row)
df_null.show()

# Replace missing Salary values with 0
df_null.fillna({"Salary": 0}).show()

# Remove every row that contains a null
df_null.dropna().show()


In [None]:

# deduping data

# Create a DataFrame with duplicate rows: union() does not remove duplicates,
# so unioning df with itself yields every row twice.
df_dup = df.union(df)
df_dup.show()

# Deduplication in Spark can be done using the `dropDuplicates()` method.
# This method removes duplicate rows based on all columns by default.
# To deduplicate on specific columns only, pass a list of those column names.
# Example of deduplication

df_dup.dropDuplicates().show()


In [103]:
def run_country_count_query():
    """
    Count people per country with Spark SQL and print the result.

    Builds a small people DataFrame (Name, Country, Age), exposes it to SQL as
    the temporary view ``people`` and shows the outcome of a GROUP BY query.

    Notes:
    - Uses the notebook-level ``spark`` session.
    - Returns nothing; the result table is printed by ``show()``.
    - Creates or replaces the temporary view named ``people``.
    """
    people = spark.createDataFrame(
        [
            ("Alice", "USA", 34),
            ("Bob", "USA", 45),
            ("Catherine", "UK", 29),
            ("David", "India", 34),
            ("Emily", "India", 21),
            ("Frank", "UK", 45),
        ],
        schema=["Name", "Country", "Age"],
    )

    # Make the DataFrame queryable by name from SQL
    people.createOrReplaceTempView("people")

    # One output row per country with its number of people
    counts_per_country = spark.sql(
        "SELECT Country, COUNT(*) as total FROM people GROUP BY Country"
    )
    counts_per_country.show()


In [104]:
try:
    run_country_count_query()
except Exception as exc:
    # Reduce the exception to a short summary before handing it to the model
    concise = get_concise_error_message(exc)
    # print(f"\n❌ Error: {concise}\n")  # uncomment to also show the raw error
    print("🤖 Asking AI to help fix it...\n")
    ask_openai_fix(concise)  # prints the suggestion itself

+-------+-----+
|Country|total|
+-------+-----+
|    USA|    2|
|     UK|    2|
|  India|    2|
+-------+-----+



In [None]:
# Uncomment to shut down the Spark session once you are finished with the notebook.
# spark.stop()