In [31]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, year, when, udf, from_json, max, rand
from pyspark.sql.types import ArrayType, FloatType, StructType, StructField, StringType, IntegerType, DoubleType
import json
import boto3

In [32]:
# test the client
#client = boto3.client('sagemaker')
#response = client.list_endpoints()
#print(response)

In [33]:
spark = SparkSession.builder \
.appName("training_data_prep") \
.getOrCreate()

In [34]:
df = spark.read.json("datasets/movies_9858.json", multiLine=True)

                                                                                

In [35]:
df.printSchema()

root
 |-- adult: boolean (nullable = true)
 |-- backdrop_path: string (nullable = true)
 |-- genre_ids: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- id: long (nullable = true)
 |-- original_language: string (nullable = true)
 |-- original_title: string (nullable = true)
 |-- overview: string (nullable = true)
 |-- popularity: double (nullable = true)
 |-- poster_path: string (nullable = true)
 |-- release_date: string (nullable = true)
 |-- title: string (nullable = true)
 |-- video: boolean (nullable = true)
 |-- vote_average: double (nullable = true)
 |-- vote_count: long (nullable = true)



In [36]:
df.count()

9858

## Drop Unnecessary Columns

In [37]:
drop_cols = ['adult', 'backdrop_path', 'original_language', 'original_title', 'poster_path', 'video']

df_dropped_cols = df.drop(*drop_cols)

## Drop duplicates and records where any of the ['vote_average', 'release_date', 'overview'] is NA

In [38]:
df_na_drop = df_dropped_cols.dropna(subset=['vote_average', 'release_date', 'overview'])
#df_na_drop.dropDuplicates()
df_na_drop.count()

9858

## Fix column order

In [39]:
df = df_na_drop.select('id', 'title', 'overview', 'release_date', 'genre_ids', 'popularity', 'vote_average', 'vote_count')
df.printSchema()

root
 |-- id: long (nullable = true)
 |-- title: string (nullable = true)
 |-- overview: string (nullable = true)
 |-- release_date: string (nullable = true)
 |-- genre_ids: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- popularity: double (nullable = true)
 |-- vote_average: double (nullable = true)
 |-- vote_count: long (nullable = true)



## Fix column types

In [40]:
# Define the schema for genre_ids
'''
schema = ArrayType(IntegerType())

df = df.withColumn('id', col('id').cast(IntegerType()))
df = df.withColumn('popularity', col('popularity').cast(DoubleType()))
df = df.withColumn('vote_average', col('vote_average').cast(DoubleType()))
df = df.withColumn('vote_count', col('vote_count').cast(IntegerType()))
df = df.withColumn('genre_ids', from_json(col('genre_ids'), schema))  # Parse JSON string to array
df.printSchema()
'''

"\nschema = ArrayType(IntegerType())\n\ndf = df.withColumn('id', col('id').cast(IntegerType()))\ndf = df.withColumn('popularity', col('popularity').cast(DoubleType()))\ndf = df.withColumn('vote_average', col('vote_average').cast(DoubleType()))\ndf = df.withColumn('vote_count', col('vote_count').cast(IntegerType()))\ndf = df.withColumn('genre_ids', from_json(col('genre_ids'), schema))  # Parse JSON string to array\ndf.printSchema()\n"

## Extract year from relase date, add 'is_sci_fi' and 'sci_fi_appeal' columns

In [41]:
df = df.withColumn('release_year', year(col('release_date')))
df = df.withColumn('is_sci_fi', when(col('genre_ids').cast('string').contains('878'), 1).otherwise(0))
df = df.withColumn('sci_fi_appeal', when((col('vote_average') > 6) & (col('is_sci_fi') == 1), 1).otherwise(0))

df = df.drop(col('genre_ids')) # we do not need this anymore for training
                 
df.printSchema()

root
 |-- id: long (nullable = true)
 |-- title: string (nullable = true)
 |-- overview: string (nullable = true)
 |-- release_date: string (nullable = true)
 |-- popularity: double (nullable = true)
 |-- vote_average: double (nullable = true)
 |-- vote_count: long (nullable = true)
 |-- release_year: integer (nullable = true)
 |-- is_sci_fi: integer (nullable = false)
 |-- sci_fi_appeal: integer (nullable = false)



## Filter out low-quality data and empty overview records 

In [42]:
df = df.filter(col('vote_count') >= 10)
df = df.filter(col('overview') != '')  # Filter out empty overview strings, sagemaker endpoint requires non-empty values
df.count()
# from 9858 to 9016

9016

## Check the distribution of 'sci_fi_appeal' column

In [43]:
df.groupBy('sci_fi_appeal').count().show()

+-------------+-----+
|sci_fi_appeal|count|
+-------------+-----+
|            1|  833|
|            0| 8183|
+-------------+-----+



In [44]:
# Create balanced dataset with 1000 records
df_positive = df.filter(col('sci_fi_appeal') == 1).limit(500)
df_negative = df.filter(col('sci_fi_appeal') == 0).limit(500)

# Combine to get 1,000 records
df_balanced = df_positive.union(df_negative)

# Verify the new distribution
df_balanced.groupBy('sci_fi_appeal').count().show()
# Should show: 500 for sci_fi_appeal = 1, 500 for sci_fi_appeal = 0

+-------------+-----+
|sci_fi_appeal|count|
+-------------+-----+
|            1|  500|
|            0|  500|
+-------------+-----+



In [45]:
df_balanced.printSchema()

root
 |-- id: long (nullable = true)
 |-- title: string (nullable = true)
 |-- overview: string (nullable = true)
 |-- release_date: string (nullable = true)
 |-- popularity: double (nullable = true)
 |-- vote_average: double (nullable = true)
 |-- vote_count: long (nullable = true)
 |-- release_year: integer (nullable = true)
 |-- is_sci_fi: integer (nullable = false)
 |-- sci_fi_appeal: integer (nullable = false)



In [46]:
df_balanced.count()

1000

## Get embeddings from sentence-transformers/all-MiniLM-L6-v2 model hosted on SageMaker 

In [47]:
# Function to call SageMaker endpoint
def get_embeddings(overview):
    client = boto3.client('sagemaker-runtime')
    response = client.invoke_endpoint(
        EndpointName='sentence-embedding-endpoint',
        ContentType='application/json',
        Body=json.dumps({'inputs': overview})
    )
    embeddings = json.loads(response['Body'].read().decode())[0]  # Shape: (384,)
    return embeddings

# Register UDF
get_embeddings_udf = udf(get_embeddings, ArrayType(FloatType()))

# Apply to DataFrame
df_balanced = df_balanced.withColumn('overview_embedding', get_embeddings_udf(col('overview')))

In [48]:
df_balanced.printSchema()

root
 |-- id: long (nullable = true)
 |-- title: string (nullable = true)
 |-- overview: string (nullable = true)
 |-- release_date: string (nullable = true)
 |-- popularity: double (nullable = true)
 |-- vote_average: double (nullable = true)
 |-- vote_count: long (nullable = true)
 |-- release_year: integer (nullable = true)
 |-- is_sci_fi: integer (nullable = false)
 |-- sci_fi_appeal: integer (nullable = false)
 |-- overview_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)



In [49]:
df_balanced.limit(5).toPandas()

                                                                                

Unnamed: 0,id,title,overview,release_date,popularity,vote_average,vote_count,release_year,is_sci_fi,sci_fi_appeal,overview_embedding
0,822119,Captain America: Brave New World,After meeting with newly elected U.S. Presiden...,2025-02-12,285.4807,6.116,1275,2025,1,1,"[-0.07376226782798767, -0.04669279232621193, -..."
1,696506,Mickey 17,Unlikely hero Mickey Barnes finds himself in t...,2025-02-28,236.5855,6.94,1441,2025,1,1,"[-0.014793258160352707, 0.007633930072188377, ..."
2,1165067,Cosmic Chaos,"Battles in virtual reality, survival in a post...",2023-08-03,170.0328,6.022,45,2023,1,1,"[0.03736811503767967, 0.026995297521352768, -0..."
3,939243,Sonic the Hedgehog 3,"Sonic, Knuckles, and Tails reunite against a p...",2024-12-19,154.1998,7.74,2394,2024,1,1,"[-0.028525710105895996, -0.04687212035059929, ..."
4,950396,The Gorge,Two highly trained operatives grow close from ...,2025-02-13,92.0002,7.727,2326,2025,1,1,"[-0.045688461512327194, 0.004364991094917059, ..."


## Select features

In [65]:
df_balanced = df_balanced.select('sci_fi_appeal', 'popularity', 'vote_count', 'vote_average', 'release_year', 'overview_embedding')
df_balanced.printSchema()

embedding_cols = [col('overview_embedding')[i].alias(f'embedding_{i}') for i in range(384)]

# Select the columns, including the flattened embeddings
df_balanced = df_balanced.select(
    'sci_fi_appeal', 'popularity', 'vote_count', 'vote_average', 'release_year', *embedding_cols
)

root
 |-- sci_fi_appeal: integer (nullable = false)
 |-- popularity: double (nullable = true)
 |-- vote_count: long (nullable = true)
 |-- vote_average: double (nullable = true)
 |-- release_year: integer (nullable = true)
 |-- overview_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)



In [66]:
df_balanced.printSchema()

root
 |-- sci_fi_appeal: integer (nullable = false)
 |-- popularity: double (nullable = true)
 |-- vote_count: long (nullable = true)
 |-- vote_average: double (nullable = true)
 |-- release_year: integer (nullable = true)
 |-- embedding_0: float (nullable = true)
 |-- embedding_1: float (nullable = true)
 |-- embedding_2: float (nullable = true)
 |-- embedding_3: float (nullable = true)
 |-- embedding_4: float (nullable = true)
 |-- embedding_5: float (nullable = true)
 |-- embedding_6: float (nullable = true)
 |-- embedding_7: float (nullable = true)
 |-- embedding_8: float (nullable = true)
 |-- embedding_9: float (nullable = true)
 |-- embedding_10: float (nullable = true)
 |-- embedding_11: float (nullable = true)
 |-- embedding_12: float (nullable = true)
 |-- embedding_13: float (nullable = true)
 |-- embedding_14: float (nullable = true)
 |-- embedding_15: float (nullable = true)
 |-- embedding_16: float (nullable = true)
 |-- embedding_17: float (nullable = true)
 |-- embeddin

## Prepare train and validation sets

In [67]:
count = df_balanced.count()
train_count = int(count * 0.8)

df_balanced = df_balanced.orderBy(rand())  # Shuffle the data

train = df_balanced.limit(train_count)
validation = df_balanced.subtract(train)

                                                                                

## Take a glance at each

In [68]:
train.count()

800

In [69]:
#train.limit(5).toPandas()

In [70]:
train.groupBy('sci_fi_appeal').count().show()

+-------------+-----+
|sci_fi_appeal|count|
+-------------+-----+
|            0|  401|
|            1|  399|
+-------------+-----+



In [71]:
validation.count()

                                                                                

200

In [72]:
#validation.limit(5).toPandas()

In [73]:
validation.groupBy('sci_fi_appeal').count().show()

[Stage 252:>                                                        (0 + 1) / 1]

+-------------+-----+
|sci_fi_appeal|count|
+-------------+-----+
|            1|  101|
|            0|   99|
+-------------+-----+



                                                                                

In [74]:
train.coalesce(1).write.csv('train.csv', header=False, mode='overwrite')

                                                                                

In [75]:
validation.coalesce(1).write.csv('validation.csv', header=False, mode='overwrite')

25/04/13 20:58:54 WARN DAGScheduler: Broadcasting large task binary with size 1150.9 KiB


# Final Training Data Format

The training data (train.csv, validation.csv) will look like this (CSV format, no header, target as first column for SageMaker XGBoost):

1,72.5,1500,8.2,2023,0.12,...,0.45 <br />
0,45.2,300,6.5,2019,-0.23,...,0.67

* First column: sci_fi_appeal (0 or 1).
* Next 4 columns: popularity, vote_count, vote_average, release_year.
* Last 384 columns: overview_embedding (384-dimensional vector).

## What Each Column Does for Training
* sci_fi_appeal: The target variable—XGBoost learns to predict this based on the features.
* popularity: Helps the model understand if a movie’s buzz correlates with Sci-Fi appeal (e.g., popular Sci-Fi movies might have higher appeal).
* vote_count: Acts as a weighting factor for vote_average—movies with more votes have more reliable ratings, which might influence appeal.
* vote_average: A key feature since it’s part of the target definition (sci_fi_appeal = 1 if vote_average > 7). The model learns how this interacts with other features.
* release_year: Captures temporal patterns (e.g., modern Sci-Fi movies might appeal more due to better effects).
* overview_embedding: Provides semantic context—movies with overview texts containing Sci-Fi themes (e.g., “space”, “aliens”) might have higher appeal.