In [1]:
!pip install pyspark

Collecting pyspark
  Downloading pyspark-3.5.1.tar.gz (317.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.0/317.0 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- \ done
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l- \ | done
[?25h  Created wheel for pyspark: filename=pyspark-3.5.1-py2.py3-none-any.whl size=317488493 sha256=1e2f8ace6d862240b3b42c2bc2c8168880fef011d1d26723e082edd9bcf37698
  Stored in directory: /root/.cache/pip/wheels/80/1d/60/2c256ed38dddce2fdd93be545214a63e02fbd8d74fb0b7f3a6
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.1


In [2]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import GBTRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.sql.types import DoubleType

In [3]:
spark = SparkSession.builder \
    .appName("Twitch Streamers ML") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/25 12:05:40 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:
data_path = "/kaggle/input/top-1000-twitch-streamers-data-may-2024"
df = spark.read.option("header", "true").csv(data_path)

In [5]:
# Convert data types for numeric columns

numeric_columns = ['RANK', 'AVERAGE_STREAM_DURATION', 'FOLLOWERS_GAINED_PER_STREAM',
                   'AVG_GAMES_PER_STREAM', 'TOTAL_TIME_STREAMED',
                   'TOTAL_FOLLOWERS', 'TOTAL_VIEWS', 'TOTAL_GAMES_STREAMED',
                   'AVG_VIEWERS_PER_STREAM']  # Include the column AVG_VIEWERS_PER_STREAM in the list

for column in numeric_columns:
    df = df.withColumn(column, df[column].cast(DoubleType()))  # Convert all numeric columns to Double type

In [6]:
# Drop rows with empty values
df = df.dropna()

In [7]:
# Create a feature vector for the MLlib model

feature_columns = ['RANK', 'AVERAGE_STREAM_DURATION', 'FOLLOWERS_GAINED_PER_STREAM',
                   'AVG_GAMES_PER_STREAM', 'TOTAL_TIME_STREAMED',
                   'TOTAL_FOLLOWERS', 'TOTAL_VIEWS', 'TOTAL_GAMES_STREAMED']

assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")

df = assembler.transform(df)

In [8]:
# Use the Gradient-Boosted Trees (GBT) model with cross-validation
gbt = GBTRegressor(labelCol="AVG_VIEWERS_PER_STREAM", featuresCol="features")

In [9]:
# Configure the parameter grid for cross-validation
paramGrid = ParamGridBuilder() \
    .addGrid(gbt.maxDepth, [3, 5, 7]) \
    .addGrid(gbt.maxIter, [10, 20, 30]) \
    .build()

In [10]:
# Evaluate the root mean square error in cross-validation
evaluator = RegressionEvaluator(labelCol="AVG_VIEWERS_PER_STREAM", predictionCol="prediction", metricName="rmse")

In [11]:
# Initialize the cross-validator

crossval = CrossValidator(estimator=gbt,
                          estimatorParamMaps=paramGrid,
                          evaluator=evaluator,
                          numFolds=3)  # Number of folds for cross-validation

In [12]:
# Train the model using cross-validation
cvModel = crossval.fit(df)

24/06/25 12:06:01 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS


In [13]:
# Get the best model from cross-validation
bestModel = cvModel.bestModel

# Output the best model parameters
print(f"Best maxDepth: {bestModel._java_obj.getMaxDepth()}")
print(f"Best maxIter: {bestModel._java_obj.getMaxIter()}")

Best maxDepth: 3
Best maxIter: 20


In [14]:
# Evaluate the model's performance on the test data
predictions = cvModel.transform(df)
rmse = evaluator.evaluate(predictions)
r2 = evaluator.evaluate(predictions, {evaluator.metricName: "r2"})

# Output the results
print(f"Root Mean Squared Error (RMSE): {rmse}")
print(f"R-squared (R2): {r2}")

Root Mean Squared Error (RMSE): 17057.320144510668
R-squared (R2): 0.8291146886945882


In [15]:
# Release SparkSession resources
spark.stop()