In [307]:
from pyspark.sql import SparkSession
# Create the notebook's Spark session, or reuse one that is already running.
# The driver host is pinned to localhost and the application is named "Demo".
spark = (
    SparkSession.builder
    .config("spark.driver.host", "localhost")
    .appName("Demo")
    .getOrCreate()
)

In [308]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StringIndexer
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [309]:
# Read the CSV, treating its first row as column names and letting Spark
# infer each column's type from the data.
dataset_df = spark.read.csv(
    'youtubestat.csv',
    header=True,
    inferSchema=True,
)

In [310]:
# Display the version string of the running Spark session.
spark.version

'3.4.1'

In [311]:
# Print the column names, types and nullability of the loaded DataFrame.
dataset_df.printSchema()

root
 |-- rank: integer (nullable = true)
 |-- Youtuber: string (nullable = true)
 |-- subscribers: integer (nullable = true)
 |-- video views: double (nullable = true)
 |-- category: string (nullable = true)
 |-- Title: string (nullable = true)
 |-- uploads: integer (nullable = true)
 |-- Country: string (nullable = true)
 |-- Abbreviation: string (nullable = true)
 |-- channel_type: string (nullable = true)
 |-- video_views_rank: integer (nullable = true)
 |-- country_rank: string (nullable = true)
 |-- channel_type_rank: string (nullable = true)
 |-- video_views_for_the_last_30_days: string (nullable = true)
 |-- lowest_monthly_earnings: double (nullable = true)
 |-- highest_monthly_earnings: double (nullable = true)
 |-- lowest_yearly_earnings: double (nullable = true)
 |-- highest_yearly_earnings: double (nullable = true)
 |-- subscribers_for_last_30_days: string (nullable = true)
 |-- created_year: string (nullable = true)
 |-- created_month: string (nullable = true)
 |-- created

In [312]:
# Display the column names of the loaded DataFrame as a Python list.
dataset_df.columns

['rank',
 'Youtuber',
 'subscribers',
 'video views',
 'category',
 'Title',
 'uploads',
 'Country',
 'Abbreviation',
 'channel_type',
 'video_views_rank',
 'country_rank',
 'channel_type_rank',
 'video_views_for_the_last_30_days',
 'lowest_monthly_earnings',
 'highest_monthly_earnings',
 'lowest_yearly_earnings',
 'highest_yearly_earnings',
 'subscribers_for_last_30_days',
 'created_year',
 'created_month',
 'created_date',
 'Gross tertiary education enrollment (%)',
 'Population',
 'Unemployment rate',
 'Urban_population',
 'Latitude',
 'Longitude']

In [313]:
# Remove the columns that are not used by the rest of the notebook.
# Each name is listed exactly once, one per line: the previous single-line call
# repeated 'video_views_rank' and 'created_month', which made it hard to see
# what was really being dropped. The resulting DataFrame is unchanged.
dataset_df = dataset_df.drop(
    'Abbreviation',
    'channel_type_rank',
    'video_views_for_the_last_30_days',
    'created_year',
    'created_date',
    'Population',
    'video_views_rank',
    'lowest_monthly_earnings',
    'highest_monthly_earnings',
    'lowest_yearly_earnings',
    'highest_yearly_earnings',
    'subscribers_for_last_30_days',
    'created_month',
    'Gross tertiary education enrollment (%)',
    'Unemployment rate',
    'Urban_population',
    'Latitude',
    'Longitude',
)
# Preview the first 10 rows without truncating long cell values.
dataset_df.show(10, truncate=False)

+----+--------------------------+-----------+---------------+----------------+--------------------------+-------+-------------+-------------+------------+
|rank|Youtuber                  |subscribers|video views    |category        |Title                     |uploads|Country      |channel_type |country_rank|
+----+--------------------------+-----------+---------------+----------------+--------------------------+-------+-------------+-------------+------------+
|1   |T-Series                  |245000000  |2.28E11        |Music           |T-Series                  |20082  |India        |Music        |1           |
|2   |YouTube Movies            |170000000  |0.0            |Film & Animation|youtubemovies             |1      |United States|Games        |7670        |
|3   |MrBeast                   |166000000  |2.836884187E10 |Entertainment   |MrBeast                   |741    |United States|Entertainment|1           |
|4   |Cocomelon - Nursery Rhymes|162000000  |1.64E11        |Education

In [314]:

# Combine the three numeric channel statistics into one vector column, 'stats',
# which is later used as the classifier's feature column.
assemblerInputs = ['subscribers', 'video views', 'uploads']
vector_assembler = VectorAssembler(
    outputCol='stats',
    inputCols=assemblerInputs,
)
assembler_temp = vector_assembler.transform(dataset_df)
# Preview the first 10 rows, including the new 'stats' column, untruncated.
assembler_temp.show(10, truncate=False)

+----+--------------------------+-----------+---------------+----------------+--------------------------+-------+-------------+-------------+------------+-------------------------------+
|rank|Youtuber                  |subscribers|video views    |category        |Title                     |uploads|Country      |channel_type |country_rank|stats                          |
+----+--------------------------+-----------+---------------+----------------+--------------------------+-------+-------------+-------------+------------+-------------------------------+
|1   |T-Series                  |245000000  |2.28E11        |Music           |T-Series                  |20082  |India        |Music        |1           |[2.45E8,2.28E11,20082.0]       |
|2   |YouTube Movies            |170000000  |0.0            |Film & Animation|youtubemovies             |1      |United States|Games        |7670        |[1.7E8,0.0,1.0]                |
|3   |MrBeast                   |166000000  |2.836884187E10 |Ente

In [315]:
# The raw numeric columns are now packed inside the 'stats' vector, so drop them.
# Reusing assemblerInputs (instead of re-typing the three names) keeps this list
# in sync with the columns that were actually assembled.
assembler = assembler_temp.drop(*assemblerInputs)

In [316]:
# Inspect the distinct values present in the Country column.
(
    assembler
    .select("Country")
    .distinct()
    .show()
)

+-------------+
|      Country|
+-------------+
|       Russia|
|       Sweden|
|  Philippines|
|    Singapore|
|     Malaysia|
|       Turkey|
|         Iraq|
|      Germany|
|  Afghanistan|
|       Jordan|
|       France|
|    Argentina|
|      Ecuador|
|      Finland|
|         Peru|
|        India|
|United States|
|        China|
|       Kuwait|
|        Chile|
+-------------+
only showing top 20 rows



In [317]:
# Encode the Country string column as a numeric 'country_index' column so it
# can serve as the classification label.
country_index = StringIndexer(
    outputCol="country_index",
    inputCol="Country",
)
newdataset = (
    country_index
    .fit(assembler)
    .transform(assembler)
)
# Preview the first 10 rows with the new index column, untruncated.
newdataset.show(10, truncate=False)

+----+--------------------------+----------------+--------------------------+-------------+-------------+------------+-------------------------------+-------------+
|rank|Youtuber                  |category        |Title                     |Country      |channel_type |country_rank|stats                          |country_index|
+----+--------------------------+----------------+--------------------------+-------------+-------------+------------+-------------------------------+-------------+
|1   |T-Series                  |Music           |T-Series                  |India        |Music        |1           |[2.45E8,2.28E11,20082.0]       |1.0          |
|2   |YouTube Movies            |Film & Animation|youtubemovies             |United States|Games        |7670        |[1.7E8,0.0,1.0]                |0.0          |
|3   |MrBeast                   |Entertainment   |MrBeast                   |United States|Entertainment|1           |[1.66E8,2.836884187E10,741.0]  |0.0          |
|4   |Coco

In [318]:
# Split 80/20 into train and test sets with a fixed seed.
training_data, testing_data = newdataset.randomSplit([0.8, 0.2], seed=42)

# Fit a decision tree that predicts the country index from the 'stats' vector,
# then score the held-out rows.
dt_classifier = DecisionTreeClassifier(
    labelCol="country_index",
    featuresCol="stats",
)
dt_model = dt_classifier.fit(training_data)
predictions = dt_model.transform(testing_data)

In [319]:
# Compare predicted and actual country indices for the first 25 test rows.
(
    predictions
    .select("prediction", "country_index")
    .show(25, truncate=False)
)

+----------+-------------+
|prediction|country_index|
+----------+-------------+
|0.0       |0.0          |
|0.0       |0.0          |
|0.0       |10.0         |
|0.0       |9.0          |
|0.0       |11.0         |
|0.0       |3.0          |
|0.0       |0.0          |
|0.0       |0.0          |
|0.0       |0.0          |
|0.0       |10.0         |
|0.0       |1.0          |
|0.0       |0.0          |
|2.0       |2.0          |
|0.0       |3.0          |
|2.0       |2.0          |
|0.0       |37.0         |
|2.0       |2.0          |
|0.0       |2.0          |
|0.0       |0.0          |
|0.0       |14.0         |
|0.0       |6.0          |
|0.0       |1.0          |
|0.0       |6.0          |
|0.0       |1.0          |
|0.0       |0.0          |
+----------+-------------+
only showing top 25 rows



In [320]:
# Measure the share of test rows whose predicted country index equals the label.
evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    predictionCol="prediction",
    labelCol="country_index",
)
accuracy = evaluator.evaluate(predictions)
print("Accuracy:", accuracy)


Accuracy: 0.38509316770186336


In [321]:
# Test error is the complement of the accuracy computed above.
test_error = 1.0 - accuracy
print("Test Error = %g " % test_error)

Test Error = 0.614907 
