In [None]:
# Install the correct PySpark runtime: Java 8, Spark 3.4.1 (Hadoop 3 build), and findspark.
# Java 8 JDK (headless) for the Spark JVM; apt output is silenced.
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
# Download the prebuilt Spark 3.4.1 / Hadoop 3 tarball from the Apache archive.
!wget -q http://archive.apache.org/dist/spark/spark-3.4.1/spark-3.4.1-bin-hadoop3.tgz
# Unpack into the current directory; SPARK_HOME below assumes this is /content (Colab default) — confirm if not on Colab.
!tar xf spark-3.4.1-bin-hadoop3.tgz
# findspark is used below (findspark.init()) to make the unpacked Spark's pyspark importable.
!pip install -q findspark

In [None]:
!pip install pyspark-ai langchain

Set environment variables (Java, Spark, and the OpenAI API key):

In [None]:
import os
from getpass import getpass

# Expected locations of the apt-installed Java 8 JDK and the Spark tarball unpacked in the setup cell.
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.4.1-bin-hadoop3"

# Never hardcode the OpenAI key in the notebook: it gets saved (and shared) with the file.
# Reuse a key already exported in the environment instead of clobbering it;
# otherwise prompt for it without echoing it to the output.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

In [None]:
# Sanity check: the downloaded spark-3.4.1-bin-hadoop3.tgz and its extracted directory should be listed.
!ls

In [None]:
import findspark

# Locate the Spark install (SPARK_HOME) and make its pyspark importable before importing it.
findspark.init()

from pyspark.sql import SparkSession

# Local session using every available core.
spark = (
    SparkSession.builder
    .master("local[*]")
    .getOrCreate()
)

# Render DataFrames as formatted tables when they are the last expression in a cell.
spark.conf.set("spark.sql.repl.eagerEval.enabled", True)

spark

In [None]:
from pyspark_ai import SparkAI
from langchain.chat_models import ChatOpenAI

# temperature=0 keeps the model's generated code as repeatable as possible.
# Swap 'gpt-4' for 'gpt-3.5-turbo' if your key lacks GPT-4 access (output quality may drop).
llm = ChatOpenAI(temperature=0, model_name='gpt-4')

# Hand the model to pyspark-ai; verbose=True presumably makes it log what it generates.
spark_ai = SparkAI(verbose=True, llm=llm)

# Enable the `df.ai` accessor used by the cells below (verify / transform / explain / plot).
spark_ai.activate()

## Read the DataFrame

In [None]:
from pyspark.sql.types import StructType, StructField, StringType, DoubleType

# All nine columns of the housing sample are numeric measurements (coordinates, counts,
# income, house value), so read every one as a nullable DoubleType.
# Declaring them as StringType makes ordering lexicographic (e.g. '99999.0' sorts above
# '500001.0'), which breaks the "top 10 by median house value" transform and the
# histograms later on, and forces implicit casts in numeric comparisons.
schema = StructType([
    StructField(column_name, DoubleType(), True)
    for column_name in (
        'longitude',
        'latitude',
        'housing_median_age',
        'total_rooms',
        'total_bedrooms',
        'population',
        'households',
        'median_income',
        'median_house_value',
    )
])

In [None]:
# Load the sample CSV (sample_data/ ships with Colab) using the explicit schema, so no type inference pass.
df = spark.read.csv('sample_data/california_housing_test.csv', header=True, schema=schema)

In [None]:
# Inspect the StructType applied at read time (should match the schema defined above).
df.schema

In [None]:
# First record as a single pyspark Row (head() with no argument returns one Row, not a list).
df.head()

In [None]:
# Ask pyspark-ai to turn the English expectation into a check on df and run it;
# presumably the generated code resembles check_housing_median_age in the next cell.
df.ai.verify("expect housing median age to be above 0")

In [None]:
from pyspark.sql.functions import col

def check_housing_median_age(df) -> bool:
    """Return True if no row of ``df`` has 'housing_median_age' <= 0.

    Null ages are not treated as violations: ``null <= 0`` evaluates to null in
    Spark SQL, and ``filter`` drops rows whose predicate is null.
    """
    # One offending row is enough to fail, so stop after the first match
    # instead of counting every violation across the whole DataFrame.
    violations = df.filter(col('housing_median_age') <= 0).limit(1)
    return violations.count() == 0

## Create UDFs

In [None]:
# pyspark-ai is expected to generate this function's body with the LLM: the `...` body is
# intentional, and the signature plus docstring act as the prompt, so the docstring text
# affects the generated code.
# NOTE(review): the next cell redefines `convert_population` by hand, so that version
# (not this generated one) is what the UDF registration further down wraps.
@spark_ai.udf
def convert_population(population: float) -> str:
    """Convert the population to a three bucket tiers"""
    ...

In [None]:
from typing import Optional


def convert_population(
    population, *, small_max: float = 100, medium_max: float = 500
) -> Optional[str]:
    """Bucket a population value into 'small', 'medium' or 'large'.

    Args:
        population: numeric population value, or None (a null Spark cell).
        small_max: values strictly below this are 'small' (default 100).
        medium_max: values from ``small_max`` up to, but excluding, this are
            'medium' (default 500); everything else is 'large'.

    Returns:
        The tier name, or None when ``population`` is None so the UDF emits null.
    """
    if population is None:
        return None
    if population < small_max:
        return 'small'
    # Reaching here already implies population >= small_max.
    if population < medium_max:
        return 'medium'
    return 'large'

In [None]:
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType

# Wrap the plain-Python tiering function as a Spark UDF that returns a string column.
convert_populationUDF = udf(
    lambda population_value: convert_population(population_value),
    StringType(),
)

# Add the tier column and print the rows without truncating long cell values.
(
    df.withColumn("population tier", convert_populationUDF(col("population")))
    .show(truncate=False)
)

## Transformations

In [None]:
# The LLM turns this English request into a query over df and returns the resulting DataFrame.
# NOTE(review): the schema has no 'location' column — presumably the model uses longitude/latitude; check the generated query.
top_10_house_value = df.ai.transform("find me the top 10 location with the highest median house value")

In [None]:
# Ask the LLM for a plain-English explanation of how top_10_house_value was derived.
top_10_house_value.ai.explain()

In [None]:
# LLM-generated bar chart of the top-10 result; axes and styling are specified in the prompt text.
top_10_house_value.ai.plot("bar with no grid and background white; x value is location; y value is median home value")

## Plotting

In [None]:
# LLM-generated histogram of median house value over the full dataset (bucket count left to the model).
df.ai.plot("histogram of the median house value")

In [None]:
# Same histogram, with the bucket count (20) and styling spelled out in the prompt.
df.ai.plot("histogram of the median house value into 20 buckets with no grid and background white")