In [71]:
# Must be included at the beginning of each new notebook. Remember to change the app name.
import findspark
findspark.init('/home/ubuntu/spark-3.2.1-bin-hadoop2.7')
import pyspark
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('basics').getOrCreate()

In [72]:
df = spark.read.csv('data/transformed.csv',inferSchema=True)

In [73]:
df.show()

+---------+----------+----------+---+-----+---------------+----------+------------+----+----+
|      _c0|       _c1|       _c2|_c3|  _c4|            _c5|       _c6|         _c7| _c8| _c9|
+---------+----------+----------+---+-----+---------------+----------+------------+----+----+
|300000000|139.082615| 961000000|169| 4500|         Action|2007-05-19|united_state|high| low|
|245000000|107.376788| 880674609|148| 4466|         Action|2015-10-26|      others|high| low|
|250000000| 112.31295|1084939099|165| 9106|         Action|2012-07-16|united_state|high| low|
|260000000| 43.926995| 284139100|132| 2124|         Action|2012-03-07|united_state|high| low|
|260000000| 48.681969| 591794936|100| 3330|      Animation|2010-11-24|united_state|high| low|
|280000000|134.279229|1405403694|141| 6767|         Action|2015-04-22|united_state|high| low|
|250000000| 98.885637| 933959197|153| 5293|      Advanture|2009-07-07|      others|high| low|
|250000000|155.790452| 873260194|151| 7004|         Action|2

In [74]:
# Optionally, rename columns for better clarity
df = df.withColumnRenamed("_c0", "budget") \
       .withColumnRenamed("_c1", "popularity") \
       .withColumnRenamed("_c2", "revenue") \
       .withColumnRenamed("_c3", "runtime") \
       .withColumnRenamed("_c4", "vote_count") \
       .withColumnRenamed("_c5", "genre") \
       .withColumnRenamed("_c6", "release_date") \
       .withColumnRenamed("_c7", "production_country")\
       .withColumnRenamed("_c8", "popularity_rank")\
       .withColumnRenamed("_c9", "risk")
# Let's get an idea of what the data looks like. 
df.printSchema()
df.show()

root
 |-- budget: integer (nullable = true)
 |-- popularity: double (nullable = true)
 |-- revenue: integer (nullable = true)
 |-- runtime: integer (nullable = true)
 |-- vote_count: integer (nullable = true)
 |-- genre: string (nullable = true)
 |-- release_date: string (nullable = true)
 |-- production_country: string (nullable = true)
 |-- popularity_rank: string (nullable = true)
 |-- risk: string (nullable = true)

+---------+----------+----------+-------+----------+---------------+------------+------------------+---------------+----+
|   budget|popularity|   revenue|runtime|vote_count|          genre|release_date|production_country|popularity_rank|risk|
+---------+----------+----------+-------+----------+---------------+------------+------------------+---------------+----+
|300000000|139.082615| 961000000|    169|      4500|         Action|  2007-05-19|      united_state|           high| low|
|245000000|107.376788| 880674609|    148|      4466|         Action|  2015-10-26|       

In [75]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Initialize Spark Session
spark = SparkSession.builder.appName("DecisionTreeExample").getOrCreate()


# Convert categorical columns into numerical representations
indexer_genre = StringIndexer(inputCol="genre", outputCol="genre_index")
indexer_country = StringIndexer(inputCol="production_country", outputCol="country_index")
indexer_risk = StringIndexer(inputCol="risk", outputCol="risk_index")
indexer_rank = StringIndexer(inputCol="popularity_rank", outputCol="rank_index")

# Apply StringIndexer transformations
df = indexer_genre.fit(df).transform(df)
df = indexer_country.fit(df).transform(df)
df = indexer_risk.fit(df).transform(df)
df = indexer_rank.fit(df).transform(df)

In [76]:
df.show()

+---------+----------+----------+-------+----------+---------------+------------+------------------+---------------+----+-----------+-------------+----------+----------+
|   budget|popularity|   revenue|runtime|vote_count|          genre|release_date|production_country|popularity_rank|risk|genre_index|country_index|risk_index|rank_index|
+---------+----------+----------+-------+----------+---------------+------------+------------------+---------------+----+-----------+-------------+----------+----------+
|300000000|139.082615| 961000000|    169|      4500|         Action|  2007-05-19|      united_state|           high| low|        1.0|          0.0|       0.0|       0.0|
|245000000|107.376788| 880674609|    148|      4466|         Action|  2015-10-26|            others|           high| low|        1.0|          1.0|       0.0|       0.0|
|250000000| 112.31295|1084939099|    165|      9106|         Action|  2012-07-16|      united_state|           high| low|        1.0|          0.0|   

In [77]:
from pyspark.sql.functions import col, when, lit

# Get unique genre numbers
unique_genre_numbers = df.select("genre_index").distinct().count()

# Iterate over each unique genre number and create a new column with 1 for true and 0 for false
for i in range(1, unique_genre_numbers + 1):
    genre_col = f"genre_{i}"
    df = df.withColumn(genre_col, when(col("genre_index") == i, lit(1)).otherwise(lit(0)))

# Show the DataFrame with encoded genres
df.show()

+---------+----------+----------+-------+----------+---------------+------------+------------------+---------------+----+-----------+-------------+----------+----------+-------+-------+-------+-------+-------+-------+-------+-------+-------+--------+--------+
|   budget|popularity|   revenue|runtime|vote_count|          genre|release_date|production_country|popularity_rank|risk|genre_index|country_index|risk_index|rank_index|genre_1|genre_2|genre_3|genre_4|genre_5|genre_6|genre_7|genre_8|genre_9|genre_10|genre_11|
+---------+----------+----------+-------+----------+---------------+------------+------------------+---------------+----+-----------+-------------+----------+----------+-------+-------+-------+-------+-------+-------+-------+-------+-------+--------+--------+
|300000000|139.082615| 961000000|    169|      4500|         Action|  2007-05-19|      united_state|           high| low|        1.0|          0.0|       0.0|       0.0|      1|      0|      0|      0|      0|      0|   

In [78]:
df.printSchema()
# Check unique values in the genre_1 column
df.select('genre_2').distinct().show()

root
 |-- budget: integer (nullable = true)
 |-- popularity: double (nullable = true)
 |-- revenue: integer (nullable = true)
 |-- runtime: integer (nullable = true)
 |-- vote_count: integer (nullable = true)
 |-- genre: string (nullable = true)
 |-- release_date: string (nullable = true)
 |-- production_country: string (nullable = true)
 |-- popularity_rank: string (nullable = true)
 |-- risk: string (nullable = true)
 |-- genre_index: double (nullable = false)
 |-- country_index: double (nullable = false)
 |-- risk_index: double (nullable = false)
 |-- rank_index: double (nullable = false)
 |-- genre_1: integer (nullable = false)
 |-- genre_2: integer (nullable = false)
 |-- genre_3: integer (nullable = false)
 |-- genre_4: integer (nullable = false)
 |-- genre_5: integer (nullable = false)
 |-- genre_6: integer (nullable = false)
 |-- genre_7: integer (nullable = false)
 |-- genre_8: integer (nullable = false)
 |-- genre_9: integer (nullable = false)
 |-- genre_10: integer (nullable

In [80]:

# List of feature column names
all_feature_columns = ['genre_1', 'genre_2', 'genre_3', 'genre_4', 'genre_5','genre_6', 'genre_7', 'genre_8', 'genre_9', 'genre_10', 'genre_11']

assembler = VectorAssembler(inputCols=all_feature_columns, outputCol="features")
data = assembler.transform(df)

# Select features and label
data = data.select("features", "rank_index")
data.show()
# Split the data
train_data, test_data = data.randomSplit([0.8, 0.2], seed=1234)

# Create and train the Decision Tree model
dt = DecisionTreeClassifier(featuresCol="features", labelCol="rank_index", predictionCol="prediction")
model = dt.fit(train_data)

# Make predictions
predictions = model.transform(test_data)

# Get the string representation of the decision tree
tree_str = model.toDebugString

# Print the decision tree
print("Decision Tree Model:")
print(tree_str)


# Evaluate the model
evaluator = MulticlassClassificationEvaluator(labelCol="rank_index", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print(f"Test Accuracy = {accuracy}")

# Stop the Spark session
spark.stop()

+--------------+----------+
|      features|rank_index|
+--------------+----------+
|(11,[0],[1.0])|       0.0|
|(11,[0],[1.0])|       0.0|
|(11,[0],[1.0])|       0.0|
|(11,[0],[1.0])|       0.0|
|(11,[5],[1.0])|       0.0|
|(11,[0],[1.0])|       0.0|
|(11,[2],[1.0])|       0.0|
|(11,[0],[1.0])|       0.0|
|(11,[0],[1.0])|       0.0|
|(11,[2],[1.0])|       0.0|
|(11,[2],[1.0])|       0.0|
|(11,[0],[1.0])|       0.0|
|(11,[0],[1.0])|       0.0|
|(11,[2],[1.0])|       0.0|
|(11,[3],[1.0])|       0.0|
|(11,[2],[1.0])|       0.0|
|(11,[0],[1.0])|       0.0|
|(11,[0],[1.0])|       0.0|
|(11,[0],[1.0])|       0.0|
|(11,[0],[1.0])|       0.0|
+--------------+----------+
only showing top 20 rows

Decision Tree Model:
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_9b887282946d, depth=1, numNodes=3, numClasses=2, numFeatures=11
  If (feature 8 <= 0.5)
   Predict: 0.0
  Else (feature 8 > 0.5)
   Predict: 1.0

Test Accuracy = 0.9953379953379954
