In [8]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf, when, isnan, isnull
from pyspark.sql.types import IntegerType, StringType, FloatType
from pyspark.ml.feature import Tokenizer, StopWordsRemover, HashingTF, IDF, StringIndexer
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [9]:
import os
os.environ['PYSPARK_PYTHON'] = r'C:\miniconda3\envs\v\python.exe'  # 将其替换为Python 3.10的路径
os.environ['PYSPARK_DRIVER_PYTHON'] = r'C:\miniconda3\envs\v\python.exe'  # 将其替换为Python 3.10的路径

spark = SparkSession.builder.appName(
            "SentimentAnalysisBaseline"
        ).config(
            "spark.executor.memory", "8g"
        ).config(
            "spark.driver.memory", "16g"
        ).getOrCreate()

In [10]:
df = spark.read.csv("../ggg_sg.csv", header=True, inferSchema=True, multiLine=True, escape='"')

In [11]:
df.show()

+-------------------+--------------------+--------------------+--------------------+--------+-----------------+-----------------+---------+------+-----+-----------+--------+--------+-------+--------------------+-------------------+
|           DateTime|                 URL|               Title|        SharingImage|LangCode|          DocTone|DomainCountryCode| Location|   Lat|  Lon|CountryCode|Adm1Code|Adm2Code|GeoType|      ContextualText|           the_geom|
+-------------------+--------------------+--------------------+--------------------+--------+-----------------+-----------------+---------+------+-----+-----------+--------+--------+-------+--------------------+-------------------+
|2022-08-18 19:15:00|https://www.goal....|How to watch Manc...|https://assets.go...|     eng|0.764331210191083|               SP|Singapore|1.3667|103.8|         SN|      SN|    NULL|      1|nunez added to kl...|POINT(103.8 1.3667)|
|2021-03-09 01:30:00|https://www.labma...|Vetter Announces ...|https://w

In [12]:
df = df.filter(df.ContextualText.isNotNull())
df = df.filter(df.DocTone.isNotNull())

In [13]:
df = df.withColumn("DocTone", df["DocTone"].cast(FloatType()))

In [14]:
df.show(10)

+-------------------+--------------------+--------------------+--------------------+--------+-----------+-----------------+---------+------+-----+-----------+--------+--------+-------+--------------------+-------------------+
|           DateTime|                 URL|               Title|        SharingImage|LangCode|    DocTone|DomainCountryCode| Location|   Lat|  Lon|CountryCode|Adm1Code|Adm2Code|GeoType|      ContextualText|           the_geom|
+-------------------+--------------------+--------------------+--------------------+--------+-----------+-----------------+---------+------+-----+-----------+--------+--------+-------+--------------------+-------------------+
|2022-08-18 19:15:00|https://www.goal....|How to watch Manc...|https://assets.go...|     eng|  0.7643312|               SP|Singapore|1.3667|103.8|         SN|      SN|    NULL|      1|nunez added to kl...|POINT(103.8 1.3667)|
|2021-03-09 01:30:00|https://www.labma...|Vetter Announces ...|https://www.labma...|     eng|   

# Exploratory Data Analysis (DocTone)

In [15]:
from pyspark.sql.functions import col, min, max, avg, stddev, percentile_approx

stats = df.select(
    min(col("DocTone")).alias("min_DocTone"),
    max(col("DocTone")).alias("max_DocTone"),
    avg(col("DocTone")).alias("avg_DocTone"),
    stddev(col("DocTone")).alias("stddev_DocTone")
).collect()[0]

print("DocTone Statistics:")
print(f"Min: {stats['min_DocTone']}")
print(f"Max: {stats['max_DocTone']}")
print(f"Average: {stats['avg_DocTone']}")
print(f"Standard Deviation: {stats['stddev_DocTone']}")


DocTone Statistics:
Min: -28.947368621826172
Max: 34.78260803222656
Average: -0.03639324925228623
Standard Deviation: 3.1345220942204715


In [16]:
# Calculate quantiles
quantiles = [0.0, 0.25, 0.5, 0.75, 1.0]

# Calculate quantiles using approxQuantile
quantile_values = df.approxQuantile("DocTone", quantiles, relativeError=0.001)

# Print quantiles
for q, value in zip(quantiles, quantile_values):
    print(f"{int(q*100)}th percentile: {value}")


0th percentile: -28.947368621826172
25th percentile: -2.0202019214630127
50th percentile: 0.0
75th percentile: 1.9910084009170532
100th percentile: 34.78260803222656


In [17]:
# Sample data for visualization (100,000 samples)
sample_df = df.select("DocTone").sample(False, 0.01, seed=42).limit(100000)

# Collect DocTone values
doc_tone_values = sample_df.rdd.map(lambda row: row['DocTone']).collect()

import matplotlib.pyplot as plt
import seaborn as sns

# Plot distribution of DocTone scores
plt.figure(figsize=(10, 6))
sns.histplot(doc_tone_values, bins=50, kde=True)
plt.title('Distribution of DocTone Scores (Sampled Data)')
plt.xlabel('DocTone Score')
plt.ylabel('Frequency')
plt.show()


Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 8.0 failed 1 times, most recent failure: Lost task 0.0 in stage 8.0 (TID 7) (DESKTOP-LVRFT6Q executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "C:\miniconda3\envs\v\Lib\site-packages\pyspark\python\lib\pyspark.zip\pyspark\worker.py", line 1100, in main
    raise PySparkRuntimeError(
pyspark.errors.exceptions.base.PySparkRuntimeError: [PYTHON_VERSION_MISMATCH] Python in worker has different version (3, 11) than that in driver 3.10, PySpark cannot run with different minor versions.
Please check environment variables PYSPARK_PYTHON and PYSPARK_DRIVER_PYTHON are correctly set.

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:784)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:766)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1049)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2433)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:829)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2393)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2414)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2433)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2458)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1049)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1048)
	at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:195)
	at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "C:\miniconda3\envs\v\Lib\site-packages\pyspark\python\lib\pyspark.zip\pyspark\worker.py", line 1100, in main
    raise PySparkRuntimeError(
pyspark.errors.exceptions.base.PySparkRuntimeError: [PYTHON_VERSION_MISMATCH] Python in worker has different version (3, 11) than that in driver 3.10, PySpark cannot run with different minor versions.
Please check environment variables PYSPARK_PYTHON and PYSPARK_DRIVER_PYTHON are correctly set.

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:784)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:766)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1049)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2433)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more


DocTone Statistics:

Min: -28.947368621826172

Max: 34.78260803222656

Average: -0.03639324925228623

Standard Deviation: 3.1345220942204715

0th percentile: -28.947368621826172

25th percentile: -2.0202019214630127

50th percentile: 0.0

75th percentile: 1.9910084009170532

100th percentile: 34.78260803222656

# TF-IDF + Logistic Regression

In [28]:
# Create sentiment label: Positive (2), Neutral (1), Negative (0)
def sentiment_label(score):
    if score > 1.9910:
        return 2
    elif score < -2.0202:
        return 0
    else:
        return 1
    
sentiment_udf = udf(sentiment_label, IntegerType())

df = df.withColumn("label", sentiment_udf(col("DocTone")))

label_counts = df.groupBy("label").count().orderBy("label")
label_counts.show()

[Stage 14:>                 (0 + 1) / 1][Stage 24:>                 (0 + 1) / 1]

+-----+-------+
|label|  count|
+-----+-------+
|    0|2297409|
|    1|4586529|
|    2|2301367|
+-----+-------+



                                                                                

In [29]:
# Tokenize
tokenizer = Tokenizer(inputCol="ContextualText", outputCol="words")

# Remove stop words
remover = StopWordsRemover(inputCol="words", outputCol="filtered_words")

# Feature extraction (TF-IDF)
hashingTF = HashingTF(inputCol="filtered_words", outputCol="rawFeatures", numFeatures=10000)
idf = IDF(inputCol="rawFeatures", outputCol="features")

pipeline = Pipeline(stages=[tokenizer, remover, hashingTF, idf])

In [30]:
(trainingData, testData) = df.randomSplit([0.8, 0.2], seed=42)

pipelineModel = pipeline.fit(trainingData)
trainingData = pipelineModel.transform(trainingData)
testData = pipelineModel.transform(testData)

# Train logistic regression model
lr = LogisticRegression(featuresCol='features', labelCol='label', maxIter=10, family='multinomial')
lrModel = lr.fit(trainingData)


24/09/27 01:17:56 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
24/09/27 01:18:50 WARN MemoryStore: Not enough space to cache rdd_164_0 in memory! (computed 17.0 MiB so far)
24/09/27 01:18:50 WARN BlockManager: Persisting block rdd_164_0 to disk instead.
24/09/27 01:26:14 WARN MemoryStore: Not enough space to cache rdd_164_0 in memory! (computed 419.2 MiB so far)
24/09/27 01:26:21 WARN MemoryStore: Not enough space to cache rdd_164_0 in memory! (computed 419.2 MiB so far)
24/09/27 01:26:31 WARN MemoryStore: Not enough space to cache rdd_164_0 in memory! (computed 419.2 MiB so far)
24/09/27 01:26:40 WARN MemoryStore: Not enough space to cache rdd_164_0 in memory! (computed 419.2 MiB so far)
24/09/27 01:26:52 WARN MemoryStore: Not enough space to cache rdd_164_0 in memory! (computed 419.2 MiB so far)
24/09/27 01:27:03 WARN MemoryStore: Not enough space to cache rdd_164_0 in memory! (computed 419.2 MiB so far)
24/09/27 01:27:15 WARN MemoryStore: 

In [32]:
# Make predictions
predictions = lrModel.transform(testData)

In [33]:
# Evaluate model
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Test Accuracy = %g " % accuracy)

[Stage 40:>                                                         (0 + 1) / 1]

Test Accuracy = 0.729166 


                                                                                

In [35]:
# Detailed evaluation
evaluator_precision = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="weightedPrecision")
precision = evaluator_precision.evaluate(predictions)

evaluator_recall = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="weightedRecall")
recall = evaluator_recall.evaluate(predictions)

evaluator_f1 = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="f1")
f1 = evaluator_f1.evaluate(predictions)

print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1}")

[Stage 52:>                                                         (0 + 1) / 1]

Precision: 0.7310097965622986
Recall: 0.7291663150352512
F1 Score: 0.728252326514687


                                                                                