In [1]:
import argparse
from datetime import datetime
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
from pyspark.sql import SparkSession
import re
from pyspark.conf import SparkConf

In [2]:


# Load data from json file and return RDD
def load_data_set(spark, path):
    """Read a JSON data set with Spark and return it as an RDD.

    As a side effect, the schema of the loaded DataFrame is appended to a
    file named ``schema`` in the current working directory.

    Args:
        spark: Object exposing ``read.json(path)`` (a SparkSession in this
            notebook).
        path: Location of the JSON file(s) to load.

    Returns:
        The ``rdd`` attribute of the loaded DataFrame.
    """
    print("**Loading data from json file**")
    df = spark.read.json(path)
    # Use a context manager so the handle is flushed and closed right away
    # instead of being left open until garbage collection.
    with open("schema", "a") as schema_file:
        print(df.schema, file=schema_file)
    # Report success only after the schema has actually been written.
    print("Schema written to file\n")
    return df.rdd


def get_most_frequent_annotations(rdd, threshold, max_length=9):
    """Rank context-annotation entity names by the number of tweets using them.

    Each tweet contributes at most one count per distinct entity name.
    Names whose count is strictly greater than ``threshold`` are ordered by
    descending count, the top-ranked name is dropped, and the slice
    ``[1:max_length]`` of the ranking is kept (so at most
    ``max_length - 1`` names are returned).

    Args:
        rdd: RDD of rows whose ``"data"`` field holds a list of tweets; each
            tweet exposes ``"context_annotations"`` (a list of items with
            ``["entity"]["name"]``, or ``None``).
        threshold: Minimum count (exclusive) for a name to be kept.
        max_length: Upper bound of the slice applied to the ranked names.

    Returns:
        A tuple ``(most_frequent_annotations, annotation_dict)`` where the
        first element is the list of kept names and the second maps each
        kept name to its index in that list.
    """
    print("**Getting most frequent annotations** \n")
    annotation_counts = (
        # Skip rows without a "data" payload (same guard as
        # extract_relevant_fields); flatMap would fail on a missing list.
        rdd.filter(lambda row: row["data"])
        .flatMap(lambda row: row["data"])
        .map(lambda tweet: tweet["context_annotations"])
        .filter(lambda annotations: annotations is not None)
        # A set de-duplicates names within a single tweet.
        .flatMap(lambda annotations: {a["entity"]["name"] for a in annotations})
        .map(lambda name: (name, 1))
        .reduceByKey(lambda left, right: left + right)
        # Apply the threshold before sorting so only the survivors are sorted.
        .filter(lambda pair: pair[1] > threshold)
        .sortBy(lambda pair: pair[1], ascending=False)
    )

    ranked_names = [name for name, _count in annotation_counts.collect()]

    # remove the first one which is 'Politics' (present in nearly all tweets)
    # TODO: check if this is also the case for celebrities
    most_frequent_annotations = ranked_names[1:max_length]

    annotation_dict = {
        annotation: index for index, annotation in enumerate(most_frequent_annotations)
    }

    return most_frequent_annotations, annotation_dict


def extract_relevant_fields(rdd):
    """Flatten the loaded rows into one plain dict per usable tweet.

    Rows with an empty/missing ``"data"`` field are skipped, and so are
    tweets with falsy ``"entities"`` or with ``"context_annotations"`` set
    to ``None``.

    Args:
        rdd: RDD of rows whose ``"data"`` field holds a list of tweets.

    Returns:
        An RDD of dicts with the keys ``tweet_text``, ``tweet_date``,
        ``tweet_hashtags``, ``tweet_mentions``, ``tweet_urls``, ``user_id``,
        ``tweet_id``, ``context_annotations`` and ``impression_count``.
    """

    def _is_usable(tweet):
        # Entities are checked first; annotations are only inspected for
        # tweets that passed the entities check.
        return bool(tweet["entities"]) and tweet["context_annotations"] is not None

    def _to_record(tweet):
        entities = tweet["entities"]
        return {
            "tweet_text": tweet["text"],
            "tweet_date": tweet["created_at"],
            "tweet_hashtags": entities["hashtags"],
            "tweet_mentions": entities["mentions"],
            "tweet_urls": entities["urls"],
            "user_id": tweet["author_id"],
            "tweet_id": tweet["id"],
            "context_annotations": tweet["context_annotations"],
            "impression_count": tweet["public_metrics"]["impression_count"],
        }

    rows_with_data = rdd.filter(lambda row: row["data"])
    tweets = rows_with_data.flatMap(lambda row: row["data"])
    usable_tweets = tweets.filter(_is_usable)
    return usable_tweets.map(_to_record)



In [3]:
from pyspark.sql.types import StringType, IntegerType, ArrayType, StructType, StructField, TimestampType

from pyspark.sql.functions import udf

from transformers import pipeline



In [4]:

def apply_processing_pipeline(
    json_rdd,
    json_rdd_data_fields,
    most_frequent_annotations,
    annotation_dict,
    output_name,
    zero_shot_classification
):
    """Replace each record's raw context annotations with frequent-annotation labels.

    For every record, the entity names of its ``"context_annotations"`` are
    compared with ``most_frequent_annotations``. If they overlap, the
    overlapping names become the record's labels, in the order they appear
    in ``most_frequent_annotations``. Otherwise the entity names are joined
    into one text (sorted, so the text is reproducible) and handed to
    ``zero_shot_classification`` together with the candidate labels.

    Args:
        json_rdd: Accepted for interface compatibility; not used here.
        json_rdd_data_fields: RDD of dict records, each with a
            ``"context_annotations"`` list of items exposing
            ``["entity"]["name"]``.
        most_frequent_annotations: Iterable of candidate label names.
        annotation_dict: Accepted for interface compatibility; not used here.
        output_name: Accepted for interface compatibility; not used here.
        zero_shot_classification: Callable ``(text, labels) -> labels`` used
            when a record has no frequent annotation.

    Returns:
        An RDD of new dict records whose ``"context_annotations"`` field
        holds the selected labels; the input records are left untouched.
    """
    # Ordered, de-duplicated candidates: gives matched labels a stable order
    # instead of relying on arbitrary set iteration order.
    ranked_labels = list(dict.fromkeys(most_frequent_annotations))

    def group_context_annotations(record):
        annotation_names = {
            annotation["entity"]["name"] for annotation in record["context_annotations"]
        }

        # First see whether the tweet's annotations overlap the frequent ones.
        matched = [label for label in ranked_labels if label in annotation_names]

        if matched:
            labels = matched
        else:
            # No overlap: let the zero-shot model pick among the candidates.
            prepared_text = " ".join(sorted(annotation_names))
            labels = zero_shot_classification(prepared_text, most_frequent_annotations)

        # Build a copy rather than mutating the element of the source RDD,
        # which could corrupt it if that RDD is cached or reused.
        updated = dict(record)
        updated["context_annotations"] = labels
        return updated

    return json_rdd_data_fields.map(group_context_annotations)


In [5]:
def pre_processing_pipeline(
    path,
    output_name,
    workers,
    annotation_threshold,
    model_name="valhalla/distilbart-mnli-12-9",
    score_threshold=0.3,
    driver_memory="15g",
):
    """Load the tweets and attach frequent-annotation labels to each one.

    Creates (or reuses) a local Spark session, loads the JSON data set,
    computes the most frequent annotations, extracts the relevant tweet
    fields, and maps every tweet onto those annotations, using a zero-shot
    classifier for tweets that carry none of them.

    Args:
        path: Location of the JSON data set.
        output_name: Forwarded to ``apply_processing_pipeline``.
        workers: Number of local Spark worker threads (``local[workers]``).
        annotation_threshold: Minimum tweet count (exclusive) for an
            annotation to count as frequent.
        model_name: Model identifier passed to ``transformers.pipeline``.
        score_threshold: Zero-shot labels scoring strictly above this value
            are kept.
        driver_memory: Value for the ``spark.driver.memory`` setting.

    Returns:
        The RDD produced by ``apply_processing_pipeline``.
    """
    spark = (
        SparkSession.builder.appName("tweet_loader")
        .master(f"local[{workers}]")
        .config("spark.driver.memory", driver_memory)
        .getOrCreate()
    )

    print("**SparkContext created**")
    print(f"GUI: {spark.sparkContext.uiWebUrl}")
    print(f"AppName: {spark.sparkContext.appName}\n")

    rdd = load_data_set(spark, path)

    most_frequent_annotations, annotation_dict = get_most_frequent_annotations(
        rdd, annotation_threshold
    )

    print(f"Most frequent annotations: {most_frequent_annotations}\n")

    print(f"Annotation dictionary: {annotation_dict}\n")

    rdd_subset = extract_relevant_fields(rdd)

    # Broadcast the classifier so the closure below reads it via pipe.value.
    pipe = spark.sparkContext.broadcast(pipeline(model=model_name))

    def zero_shot_classification(text, labels):
        """Return the candidate labels scoring above ``score_threshold``."""
        resp = pipe.value(text, labels, multi_label=False)

        predicted_labels = [
            label
            for label, score in zip(resp["labels"], resp["scores"])
            if score > score_threshold
        ]

        # Guarantee at least one label: fall back to the first label returned
        # (the pipeline documents its labels as sorted by likelihood).
        return predicted_labels or [resp["labels"][0]]

    regression_df = apply_processing_pipeline(
        rdd, rdd_subset, most_frequent_annotations, annotation_dict, output_name, zero_shot_classification
    )

    return regression_df



In [6]:
# Run the pipeline on the celebrities data set with 8 local workers and an
# annotation threshold of 1000, then materialise the first 500 records.
pre_processing_pipeline(
    path="../data/american_celebrities/tweets.jsonl",
    output_name="test",
    workers=8,
    annotation_threshold=1000,
).take(500)

# starts ~25s

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/05/25 11:30:19 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
**SparkContext created**
GUI: http://tsf-428-wpa-1-215.epfl.ch:4040
AppName: tweet_loader

**Loading data from json file**


                                                                                

Schema written to file

**Getting most frequent annotations** 



                                                                                

Most frequent annotations: ['Sports', 'Entertainment & Leisure Business', 'Basketball', 'NBA', 'NBA Basketball', 'Music', 'NBA players', 'Politics']

Annotation dictionary: {'Sports': 0, 'Entertainment & Leisure Business': 1, 'Basketball': 2, 'NBA': 3, 'NBA Basketball': 4, 'Music': 5, 'NBA players': 6, 'Politics': 7}

