In [1]:
import time
import os
import sys

In [2]:
os.environ['PYSPARK_PYTHON'] = "/home/group15/.conda/envs/group15/bin/python"
os.environ['PYSPARK_DRIVER_PYTHON'] = "/home/group15/.conda/envs/group15/bin/python"

# Spark Setup

In [3]:
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.ml.feature import VectorAssembler, StringIndexer, StandardScaler
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier, RandomForestClassificationModel
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

import json
from itertools import chain
import pycountry_convert as pc

from twitter.config import LANG_MAP

In [4]:
from pyspark.sql import SparkSession

# config the connector jar file
spark = (SparkSession.builder.appName("SimpleSparkJob").master("spark://34.142.194.212:7077")
         .config("spark.jars","/opt/spark/jars/gcs-connector-latest-hadoop2.jar")
         .config("spark.executor.memory", "1G")  #excutor excute only 2G
         .config("spark.driver.memory","1G") 
         .config("spark.executor.cores","1") # Cluster use only 3 cores to excute as it has 3 server
         .config("spark.python.worker.memory","1G") # each worker use 1G to excute
         .config("spark.driver.maxResultSize","1G") #Maximum size of result is 3G
         .config("spark.kryoserializer.buffer.max","1024M")
         .getOrCreate())

# datetime migration
spark.conf.set("spark.sql.legacy.timeParserPolicy","LEGACY")

# config the credential to identify the google cloud hadoop file 
spark.conf.set("google.cloud.auth.service.account.json.keyfile","/opt/bucket_connector/lucky-wall-393304-3fbad5f3943c.json")
spark._jsc.hadoopConfiguration().set('fs.gs.impl', 'com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem')
spark._jsc.hadoopConfiguration().set('fs.gs.auth.service.account.enable', 'true')

23/12/24 15:46:16 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/12/24 15:46:18 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
23/12/24 15:46:18 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


In [5]:
spark

# Data Cleaning

In [6]:
# load_profile
PROFILE_PATH = "gs://it4043e-it5384/it4043e/it4043e_group15_problem1/output/profiles.jsonl"
def load_profile(fp):
    profile_df = spark.read.json(fp)
    return profile_df
profile_df = load_profile(PROFILE_PATH)
profile_df.show(1)

23/12/24 15:47:27 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

+--------------------+--------------------+--------------------+--------------------+
|            follower|           following|               tweet|                user|
+--------------------+--------------------+--------------------+--------------------+
|[1700886965070209...|[241664456, 35731...|[{snscrape.module...|{snscrape.modules...|
+--------------------+--------------------+--------------------+--------------------+
only showing top 1 row



## User Cleaning

In [39]:
# _extract_users
def _extract_users(profile_df):
    user_df = profile_df.select(F.col("user.*"))
    return user_df

user_df = _extract_users(profile_df)
user_df.show(1)

+--------------------+----+--------+--------------------+--------------------+--------------------+---------------+--------------+------------+----------+----------+-----------+--------+----------+--------------------+--------------------+---------+--------------------+-------------+--------------------+--------------+--------+
|               _type|blue|blueType|             created|    descriptionLinks|         displayname|favouritesCount|followersCount|friendsCount|        id|    id_str|listedCount|location|mediaCount|    profileBannerUrl|     profileImageUrl|protected|      rawDescription|statusesCount|                 url|      username|verified|
+--------------------+----+--------+--------------------+--------------------+--------------------+---------------+--------------+------------+----------+----------+-----------+--------+----------+--------------------+--------------------+---------+--------------------+-------------+--------------------+--------------+--------+
|snscrape.

In [40]:
# _label_users
LOCATION_PATH = "output/location.json"
MISSING_CN_TO_CNT = {
    "Kosovo": "Europe",
    "Palestinian Territories": "Asia",
    "North Pole": "Arctic",
    "Ascension and Tristan da Cunha": "Africa",
    "Gornja Siga": "Europe"
}

def country_to_continent(country_name: str) -> str:
    try:
        country_alpha2 = pc.country_name_to_country_alpha2(country_name)
        country_continent_code = pc.country_alpha2_to_continent_code(country_alpha2)
        country_continent_name = pc.convert_continent_code_to_continent_name(country_continent_code)
        return country_continent_name
    except:
        if country_name in ["Europe", "Africa", "Antarctica"]:
            return country_name
        elif country_name in MISSING_CN_TO_CNT:
            return MISSING_CN_TO_CNT[country_name]
        return None

def _label_users(user_df):
    with open(LOCATION_PATH, "r") as f:
        location_dict = json.load(f)
    for k, v in location_dict.items():
        if isinstance(v , str):
            if v == "None":
                v = None
            else:
                v = v.strip()
        location_dict[k] = v        
    for unk_geo in ["nan", ""]:
        location_dict[unk_geo] = None
        
    country_encode_fn = F.create_map([F.lit(x) for x in chain(*location_dict.items())])
    user_df = user_df.withColumn("country", country_encode_fn[F.col("location")])
        
    for k, v in location_dict.items():
        if v is not None:
            location_dict[k] = country_to_continent(v)
        else:
            location_dict[k] = None
        
    label_encode_fn = F.create_map([F.lit(x) for x in chain(*location_dict.items())])
    user_df = user_df.withColumn("location", label_encode_fn[F.col("location")])
    
    return user_df

user_df = _label_users(user_df)
user_df.where(F.col("location").isNotNull()).show(1)

+--------------------+----+--------+--------------------+--------------------+--------------------+---------------+--------------+------------+-------------------+-------------------+-----------+-------------+----------+--------------------+--------------------+---------+--------------------+-------------+--------------------+-----------+--------+-------------+
|               _type|blue|blueType|             created|    descriptionLinks|         displayname|favouritesCount|followersCount|friendsCount|                 id|             id_str|listedCount|     location|mediaCount|    profileBannerUrl|     profileImageUrl|protected|      rawDescription|statusesCount|                 url|   username|verified|      country|
+--------------------+----+--------+--------------------+--------------------+--------------------+---------------+--------------+------------+-------------------+-------------------+-----------+-------------+----------+--------------------+--------------------+---------+

In [41]:
# clean_users
def clean_users(
    profile_df
):
    # extract users from crawled data
    user_df = _extract_users(profile_df)
    
    # consistent & unique
    user_df = user_df \
                .withColumnRenamed("id", "user_id") \
                .dropDuplicates(['user_id'])
    user_df = user_df.na.drop(subset=["user_id"])
    user_df = user_df.withColumn("created", F.to_timestamp(user_df.created, 'yyyy-MM-dd HH:mm:ss'))
    
    # handle target variable
    user_df = _label_users(user_df)
    
    # drop all-NULL rows
    user_df = user_df.na.drop("all")
    
    return user_df

user_df = clean_users(profile_df)
user_df.show(1)

[Stage 91:>                                                         (0 + 2) / 2]

+--------------------+----+--------+-------------------+--------------------+-----------+---------------+--------------+------------+-------+------+-----------+-------------+----------+--------------------+--------------------+---------+--------------------+-------------+--------------------+----------+--------+-------------+
|               _type|blue|blueType|            created|    descriptionLinks|displayname|favouritesCount|followersCount|friendsCount|user_id|id_str|listedCount|     location|mediaCount|    profileBannerUrl|     profileImageUrl|protected|      rawDescription|statusesCount|                 url|  username|verified|      country|
+--------------------+----+--------+-------------------+--------------------+-----------+---------------+--------------+------------+-------+------+-----------+-------------+----------+--------------------+--------------------+---------+--------------------+-------------+--------------------+----------+--------+-------------+
|snscrape.module

                                                                                

In [26]:
user_df.printSchema()

root
 |-- _type: string (nullable = true)
 |-- blue: boolean (nullable = true)
 |-- blueType: string (nullable = true)
 |-- created: timestamp (nullable = true)
 |-- descriptionLinks: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- tcourl: string (nullable = true)
 |    |    |-- text: string (nullable = true)
 |    |    |-- url: string (nullable = true)
 |-- displayname: string (nullable = true)
 |-- favouritesCount: long (nullable = true)
 |-- followersCount: long (nullable = true)
 |-- friendsCount: long (nullable = true)
 |-- user_id: long (nullable = true)
 |-- id_str: string (nullable = true)
 |-- listedCount: long (nullable = true)
 |-- location: string (nullable = true)
 |-- mediaCount: long (nullable = true)
 |-- profileBannerUrl: string (nullable = true)
 |-- profileImageUrl: string (nullable = true)
 |-- protected: string (nullable = true)
 |-- rawDescription: string (nullable = true)
 |-- statusesCount: long (nullable = true)
 |-- url: 

## Tweet Cleaning

In [27]:
# _extract_tweets
def _extract_tweets(profile_df):
    tweet_df = profile_df \
                .select(F.col("user.id").alias("user_id"), "tweet") \
                .withColumn("tweet", F.explode("tweet")) \
                .select("user_id", F.col("tweet.*"))
    
    return tweet_df

tweet_df = _extract_tweets(profile_df)
tweet_df.show(1)

[Stage 49:>                                                         (0 + 1) / 1]

+----------+--------------------+--------+-------------------+-----------+--------------------+---------+-------------------+-------------------+----------------+-------------+----+---------+-----+--------------------+--------------+-----+----------+-----------+--------------------+----------+------------+--------------+--------------------+-----------+-----------------+--------------------+---------+
|   user_id|               _type|cashtags|     conversationId|coordinates|                date| hashtags|                 id|             id_str|inReplyToTweetId|inReplyToUser|lang|likeCount|links|               media|mentionedUsers|place|quoteCount|quotedTweet|          rawContent|replyCount|retweetCount|retweetedTweet|              source|sourceLabel|        sourceUrl|                 url|viewCount|
+----------+--------------------+--------+-------------------+-----------+--------------------+---------+-------------------+-------------------+----------------+-------------+----+---------

                                                                                

In [28]:
# clean_tweets
def clean_tweets(
    profile_df
):
    # extract tweets from crawled data
    tweet_df = _extract_tweets(profile_df) # <-- id -> user_id here
    
    # consistent & duplicate
    tweet_df = tweet_df.dropDuplicates(["id"])
    tweet_df = tweet_df.na.drop(subset=["id"])
    
    # drop all-NULL rows
    tweet_df = tweet_df.na.drop("all")

    return tweet_df

tweet_df = clean_tweets(profile_df)
tweet_df.show(1)



+---------+--------------------+--------+--------------+-----------+--------------------+--------+----------+----------+----------------+-------------+----+---------+-----+------------+--------------+-----+----------+-----------+---------------+----------+------------+--------------+--------------------+------------------+------------------+--------------------+---------+
|  user_id|               _type|cashtags|conversationId|coordinates|                date|hashtags|        id|    id_str|inReplyToTweetId|inReplyToUser|lang|likeCount|links|       media|mentionedUsers|place|quoteCount|quotedTweet|     rawContent|replyCount|retweetCount|retweetedTweet|              source|       sourceLabel|         sourceUrl|                 url|viewCount|
+---------+--------------------+--------+--------------+-----------+--------------------+--------+----------+----------+----------------+-------------+----+---------+-----+------------+--------------+-----+----------+-----------+---------------+-----

                                                                                

In [29]:
tweet_df.drop(*["retweetedTweet", "quotedTweet"]).printSchema()

root
 |-- user_id: long (nullable = true)
 |-- _type: string (nullable = true)
 |-- cashtags: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- conversationId: long (nullable = true)
 |-- coordinates: string (nullable = true)
 |-- date: string (nullable = true)
 |-- hashtags: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- id: long (nullable = true)
 |-- id_str: string (nullable = true)
 |-- inReplyToTweetId: long (nullable = true)
 |-- inReplyToUser: struct (nullable = true)
 |    |-- _type: string (nullable = true)
 |    |-- displayname: string (nullable = true)
 |    |-- id: long (nullable = true)
 |    |-- username: string (nullable = true)
 |-- lang: string (nullable = true)
 |-- likeCount: long (nullable = true)
 |-- links: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- tcourl: string (nullable = true)
 |    |    |-- text: string (nullable = true)
 |    |    |-- url: string (nullable =

# Data Processing

## Tweet Processing

In [30]:
# preprocess_tweets
def preprocess_tweets(
    tweet_df,
    keep_cols = ["user_id", "replyCount", "retweetCount", 
                 "likeCount", "quoteCount", "viewCount", "lang"]
):
    # drop unecessary fields for ML tasks
    tweet_df = tweet_df.select(*keep_cols)
    
    # drop any-NULL rows
    tweet_df = tweet_df.na.drop()
    
    # agg. by user
    tweet_df = tweet_df.groupBy("user_id").agg(
        F.mean("replyCount").alias("replyCount"),
        F.mean("retweetCount").alias("retweetCount"),
        F.mean("likeCount").alias("likeCount"),
        F.mean("quoteCount").alias("quoteCount"),
        F.mean("viewCount").alias("viewCount"),
        F.mode("lang").alias("lang")
    )
    
    return tweet_df

tweet_df = preprocess_tweets(tweet_df)
tweet_df.show(1)



+-------+------------------+-----------------+------------------+----------+------------------+----+
|user_id|        replyCount|     retweetCount|         likeCount|quoteCount|         viewCount|lang|
+-------+------------------+-----------------+------------------+----------+------------------+----+
| 816653|2.4761904761904763|5.380952380952381|12.666666666666666|       1.0|22106.761904761905|  en|
+-------+------------------+-----------------+------------------+----------+------------------+----+
only showing top 1 row



                                                                                

In [31]:
tweet_df.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- replyCount: double (nullable = true)
 |-- retweetCount: double (nullable = true)
 |-- likeCount: double (nullable = true)
 |-- quoteCount: double (nullable = true)
 |-- viewCount: double (nullable = true)
 |-- lang: string (nullable = true)



## User Processing

In [42]:
# preprocess_users
def preprocess_users(
    user_df,
    neglect_cols = ['profileImageUrl','profileBannerUrl','descriptionLinks','_type', 'verified',
                    'id_str','url','created', 'rawDescription', 'protected', 'blueType', 'displayname', 'country']
):
    # drop unecessary fields for ML tasks
    #user_df = user_df.drop(*drop_cols)
    
    # drop rows with NaN values in columns other than "location"
    process_cols = [col_name for col_name in user_df.columns if col_name not in (["location"] + neglect_cols)]
    user_df = user_df.na.drop(subset=process_cols)
    
    return user_df

user_df = preprocess_users(user_df)
user_df.show(1)



+--------------------+----+--------+-------------------+--------------------+-----------+---------------+--------------+------------+-------+------+-----------+-------------+----------+--------------------+--------------------+---------+--------------------+-------------+--------------------+----------+--------+-------------+
|               _type|blue|blueType|            created|    descriptionLinks|displayname|favouritesCount|followersCount|friendsCount|user_id|id_str|listedCount|     location|mediaCount|    profileBannerUrl|     profileImageUrl|protected|      rawDescription|statusesCount|                 url|  username|verified|      country|
+--------------------+----+--------+-------------------+--------------------+-----------+---------------+--------------+------------+-------+------+-----------+-------------+----------+--------------------+--------------------+---------+--------------------+-------------+--------------------+----------+--------+-------------+
|snscrape.module

                                                                                

In [43]:
user_df.printSchema()

root
 |-- _type: string (nullable = true)
 |-- blue: boolean (nullable = true)
 |-- blueType: string (nullable = true)
 |-- created: timestamp (nullable = true)
 |-- descriptionLinks: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- tcourl: string (nullable = true)
 |    |    |-- text: string (nullable = true)
 |    |    |-- url: string (nullable = true)
 |-- displayname: string (nullable = true)
 |-- favouritesCount: long (nullable = true)
 |-- followersCount: long (nullable = true)
 |-- friendsCount: long (nullable = true)
 |-- user_id: long (nullable = true)
 |-- id_str: string (nullable = true)
 |-- listedCount: long (nullable = true)
 |-- location: string (nullable = true)
 |-- mediaCount: long (nullable = true)
 |-- profileBannerUrl: string (nullable = true)
 |-- profileImageUrl: string (nullable = true)
 |-- protected: string (nullable = true)
 |-- rawDescription: string (nullable = true)
 |-- statusesCount: long (nullable = true)
 |-- url: 

## Data Integration

In [44]:
def merge_data(
    df_user, df_tweet, 
    fillna_num=0, fillna_cat="unk", left_join_on_cols=['user_id'],
):
    # perform a left join on specified columns
    data_spark = df_user.join(df_tweet, on=left_join_on_cols, how='left')

    # replace NaN values in columns for tweets
    num_cols = [col_name for col_name in df_tweet.columns if col_name != "location" and col_name != "lang"]
    cat_cols = ["lang"]
    data_spark = data_spark.fillna(fillna_num, subset=num_cols)
    data_spark = data_spark.fillna(fillna_cat, subset=cat_cols)
                            
    return data_spark

data = merge_data(user_df, tweet_df)
data.show(1)

                                                                                

+-------+--------------------+----+--------+-------------------+--------------------+-----------+---------------+--------------+------------+------+-----------+-------------+----------+--------------------+--------------------+---------+--------------------+-------------+--------------------+----------+--------+-------------+------------------+-----------------+------------------+----------+------------------+----+
|user_id|               _type|blue|blueType|            created|    descriptionLinks|displayname|favouritesCount|followersCount|friendsCount|id_str|listedCount|     location|mediaCount|    profileBannerUrl|     profileImageUrl|protected|      rawDescription|statusesCount|                 url|  username|verified|      country|        replyCount|     retweetCount|         likeCount|quoteCount|         viewCount|lang|
+-------+--------------------+----+--------+-------------------+--------------------+-----------+---------------+--------------+------------+------+-----------+--

In [45]:
data.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- _type: string (nullable = true)
 |-- blue: boolean (nullable = true)
 |-- blueType: string (nullable = true)
 |-- created: timestamp (nullable = true)
 |-- descriptionLinks: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- tcourl: string (nullable = true)
 |    |    |-- text: string (nullable = true)
 |    |    |-- url: string (nullable = true)
 |-- displayname: string (nullable = true)
 |-- favouritesCount: long (nullable = true)
 |-- followersCount: long (nullable = true)
 |-- friendsCount: long (nullable = true)
 |-- id_str: string (nullable = true)
 |-- listedCount: long (nullable = true)
 |-- location: string (nullable = true)
 |-- mediaCount: long (nullable = true)
 |-- profileBannerUrl: string (nullable = true)
 |-- profileImageUrl: string (nullable = true)
 |-- protected: string (nullable = true)
 |-- rawDescription: string (nullable = true)
 |-- statusesCount: long (nullable = true)
 |-- url: 

In [46]:
data.count()

                                                                                

2149

In [47]:
data.where(F.col("location").isNotNull()).count()

                                                                                

1000

# Data Transformation

In [48]:
# load_data
def extract_data(data):
    # drop identity columns
    #data = data.drop("user_id", "username")

    # split data into labeled and unlabeled based on the "location" column
    labeled = data.filter(F.col("location").isNotNull())
    unlabeled = data.filter(F.col("location").isNull())
    
    return labeled, unlabeled

labeled, unlabeled = extract_data(data)
labeled.show(1)

                                                                                

+-------+--------------------+----+--------+-------------------+--------------------+-----------+---------------+--------------+------------+------+-----------+-------------+----------+--------------------+--------------------+---------+--------------------+-------------+--------------------+----------+--------+-------------+------------------+-----------------+------------------+----------+------------------+----+
|user_id|               _type|blue|blueType|            created|    descriptionLinks|displayname|favouritesCount|followersCount|friendsCount|id_str|listedCount|     location|mediaCount|    profileBannerUrl|     profileImageUrl|protected|      rawDescription|statusesCount|                 url|  username|verified|      country|        replyCount|     retweetCount|         likeCount|quoteCount|         viewCount|lang|
+-------+--------------------+----+--------+-------------------+--------------------+-----------+---------------+--------------+------------+------+-----------+--

In [50]:
# encode_data
LABEL_MAP = {'Africa': 0, 'Antarctica': 1, 'Asia': 2, 'Europe': 3, 'North America': 4, 'Oceania': 5, 'South America': 6, 'Arctic': 7}
def transform_data(
    df,
    neglect_cols = ['profileImageUrl','profileBannerUrl','descriptionLinks','_type', 'verified',
                    'id_str','url','created', 'rawDescription', 'protected', 'blueType', 'displayname', 'country']
):
    # encode categorical features
    bin_len = 6
    label_map = F.create_map([F.lit(x) for x in chain(*LABEL_MAP.items())])
    lang_map = F.create_map([F.lit(x) for x in chain(*LANG_MAP.items())])
    df = df \
            .withColumn("location", label_map[F.col("location")]) \
            .withColumn("blue_tmp", F.when(F.col("blue") == True, 1).otherwise(0)) \
            .withColumn("lang_tmp", lang_map[F.col("lang")])
    df = df.withColumn("lang_tmp", F.lpad(F.bin(df["lang_tmp"]), bin_len, '0')) # binary encoding
    
    split_lang = F.split(df['lang_tmp'], "")
    for i in range(bin_len):
        df = df \
                .withColumn(f'lang_{i}', split_lang.getItem(i).cast(T.IntegerType()))
    df = df.drop("lang_tmp")
    
    # enumerate types of features
    num_cols = [name for name, datatype in df.dtypes if (not datatype.startswith('string')) and (not name.startswith('lang')) 
                and (not name.startswith('blue')) and name not in (["location", "user_id"] + neglect_cols)]
    cat_cols = ["blue_tmp"] + [name for name, datatype in df.dtypes if name.startswith('lang') and name not in (neglect_cols + ["lang"])]
    
    # define processing ops
    features = num_cols + cat_cols
    
    num_assembler = VectorAssembler(inputCols=num_cols, outputCol="num_feats")
    scaler = StandardScaler(inputCol="num_feats", outputCol="num_feats_scaled", withStd=True, withMean=True)
    
    cat_assembler = VectorAssembler(inputCols=cat_cols, outputCol="cat_feats", 
                                    handleInvalid ="keep")
    
    all_assembler = VectorAssembler(inputCols= ["num_feats_scaled", "cat_feats"], outputCol="all_feats_scaled", 
                                    handleInvalid ="keep")

    # assemble a pipeline
    pipeline = Pipeline(stages=[num_assembler, scaler, cat_assembler, all_assembler])

    # fit transform on pipeline
    encode_data = pipeline.fit(df).transform(df)

    # select features + targets
    selected_cols = ["all_feats_scaled"] + df.columns
    encode_data = encode_data.select(*selected_cols)

    return encode_data

labeled_prep = transform_data(labeled)
labeled_prep.show(1)

                                                                                

+--------------------+-------+--------------------+----+--------+-------------------+--------------------+-----------+---------------+--------------+------------+------+-----------+--------+----------+--------------------+--------------------+---------+--------------------+-------------+--------------------+----------+--------+-------------+------------------+-----------------+------------------+----------+------------------+----+--------+------+------+------+------+------+------+
|    all_feats_scaled|user_id|               _type|blue|blueType|            created|    descriptionLinks|displayname|favouritesCount|followersCount|friendsCount|id_str|listedCount|location|mediaCount|    profileBannerUrl|     profileImageUrl|protected|      rawDescription|statusesCount|                 url|  username|verified|      country|        replyCount|     retweetCount|         likeCount|quoteCount|         viewCount|lang|blue_tmp|lang_0|lang_1|lang_2|lang_3|lang_4|lang_5|
+--------------------+------

# Modelling

In [36]:
# initialize RandomForestClassifier
rf_clf = RandomForestClassifier(labelCol="location", featuresCol="all_feats_scaled", 
                                numTrees=95, maxDepth=10, seed=42)
scoreboard = {}

# set up cross-validation
for metric in ["weightedFMeasure", "weightedPrecision", "weightedRecall", "accuracy"]:
    print(f"Validating model using {metric}")
    paramGrid = ParamGridBuilder().build()
    cross_validator = CrossValidator(
        estimator=rf_clf,
        estimatorParamMaps=paramGrid, 
        evaluator=MulticlassClassificationEvaluator(labelCol="location", metricName=metric), 
        numFolds=2
    )

    # cross-validate on labeled data
    model = cross_validator.fit(labeled_prep)

    # avg. score from the cross-validated model
    scoreboard[metric] = model.avgMetrics[0]
    
# show results
scoreboard

Validating model using weightedFMeasure


23/12/23 18:03:20 WARN DAGScheduler: Broadcasting large task binary with size 1031.1 KiB
23/12/23 18:03:20 WARN DAGScheduler: Broadcasting large task binary with size 1311.8 KiB
23/12/23 18:03:20 WARN DAGScheduler: Broadcasting large task binary with size 1572.8 KiB
23/12/23 18:03:21 WARN DAGScheduler: Broadcasting large task binary with size 1794.2 KiB
23/12/23 18:03:21 WARN DAGScheduler: Broadcasting large task binary with size 1929.8 KiB
23/12/23 18:03:29 WARN DAGScheduler: Broadcasting large task binary with size 1677.3 KiB
23/12/23 18:03:37 WARN DAGScheduler: Broadcasting large task binary with size 1233.5 KiB
23/12/23 18:03:38 WARN DAGScheduler: Broadcasting large task binary with size 1483.9 KiB
23/12/23 18:03:38 WARN DAGScheduler: Broadcasting large task binary with size 1715.1 KiB
23/12/23 18:03:38 WARN DAGScheduler: Broadcasting large task binary with size 1878.5 KiB
23/12/23 18:03:45 WARN DAGScheduler: Broadcasting large task binary with size 1621.0 KiB
23/12/23 18:04:02 WAR

Validating model using weightedPrecision


23/12/23 18:04:18 WARN DAGScheduler: Broadcasting large task binary with size 1049.4 KiB
23/12/23 18:04:18 WARN DAGScheduler: Broadcasting large task binary with size 1347.5 KiB
23/12/23 18:04:19 WARN DAGScheduler: Broadcasting large task binary with size 1628.1 KiB
23/12/23 18:04:19 WARN DAGScheduler: Broadcasting large task binary with size 1842.1 KiB
23/12/23 18:04:20 WARN DAGScheduler: Broadcasting large task binary with size 1915.3 KiB
23/12/23 18:04:26 WARN DAGScheduler: Broadcasting large task binary with size 1716.0 KiB
23/12/23 18:04:35 WARN DAGScheduler: Broadcasting large task binary with size 1218.1 KiB
23/12/23 18:04:36 WARN DAGScheduler: Broadcasting large task binary with size 1448.7 KiB
23/12/23 18:04:36 WARN DAGScheduler: Broadcasting large task binary with size 1644.5 KiB
23/12/23 18:04:36 WARN DAGScheduler: Broadcasting large task binary with size 1780.9 KiB
23/12/23 18:04:43 WARN DAGScheduler: Broadcasting large task binary with size 1574.1 KiB
23/12/23 18:04:59 WAR

Validating model using weightedRecall


23/12/23 18:05:14 WARN DAGScheduler: Broadcasting large task binary with size 1049.4 KiB
23/12/23 18:05:14 WARN DAGScheduler: Broadcasting large task binary with size 1347.5 KiB
23/12/23 18:05:15 WARN DAGScheduler: Broadcasting large task binary with size 1628.1 KiB
23/12/23 18:05:15 WARN DAGScheduler: Broadcasting large task binary with size 1842.1 KiB
23/12/23 18:05:16 WARN DAGScheduler: Broadcasting large task binary with size 1915.3 KiB
23/12/23 18:05:22 WARN DAGScheduler: Broadcasting large task binary with size 1716.0 KiB
23/12/23 18:05:30 WARN DAGScheduler: Broadcasting large task binary with size 1233.5 KiB
23/12/23 18:05:30 WARN DAGScheduler: Broadcasting large task binary with size 1483.9 KiB
23/12/23 18:05:30 WARN DAGScheduler: Broadcasting large task binary with size 1715.1 KiB
23/12/23 18:05:31 WARN DAGScheduler: Broadcasting large task binary with size 1878.5 KiB
23/12/23 18:05:37 WARN DAGScheduler: Broadcasting large task binary with size 1621.0 KiB
23/12/23 18:05:53 WAR

Validating model using accuracy


23/12/23 18:06:10 WARN DAGScheduler: Broadcasting large task binary with size 1031.1 KiB
23/12/23 18:06:10 WARN DAGScheduler: Broadcasting large task binary with size 1311.8 KiB
23/12/23 18:06:10 WARN DAGScheduler: Broadcasting large task binary with size 1572.8 KiB
23/12/23 18:06:11 WARN DAGScheduler: Broadcasting large task binary with size 1794.2 KiB
23/12/23 18:06:11 WARN DAGScheduler: Broadcasting large task binary with size 1929.8 KiB
23/12/23 18:06:17 WARN DAGScheduler: Broadcasting large task binary with size 1677.3 KiB
23/12/23 18:06:25 WARN DAGScheduler: Broadcasting large task binary with size 1218.1 KiB
23/12/23 18:06:25 WARN DAGScheduler: Broadcasting large task binary with size 1448.7 KiB
23/12/23 18:06:26 WARN DAGScheduler: Broadcasting large task binary with size 1644.5 KiB
23/12/23 18:06:26 WARN DAGScheduler: Broadcasting large task binary with size 1780.9 KiB
23/12/23 18:06:32 WARN DAGScheduler: Broadcasting large task binary with size 1574.1 KiB
23/12/23 18:06:48 WAR

{'weightedFMeasure': 0.9631500001432405,
 'weightedPrecision': 0.9678116168626494,
 'weightedRecall': 0.9701770177017702,
 'accuracy': 0.9661966196619662}

In [26]:
# re-fit model on full data
rf_clf = RandomForestClassifier(labelCol="location", featuresCol="all_feats_scaled", 
                                numTrees=95, maxDepth=10, seed=42)
rf_clf = rf_clf.fit(labeled_prep)

# save the model to bucket
rf_clf.save("gs://it4043e-it5384/it4043e/it4043e_group15_problem1/output/models/spark_rf-n_estimators_95-max_depth_11_noscale")

23/12/24 07:35:39 WARN DAGScheduler: Broadcasting large task binary with size 1396.5 KiB
23/12/24 07:35:40 WARN DAGScheduler: Broadcasting large task binary with size 1951.0 KiB
23/12/24 07:35:40 WARN DAGScheduler: Broadcasting large task binary with size 2.6 MiB
23/12/24 07:35:41 WARN DAGScheduler: Broadcasting large task binary with size 3.3 MiB
23/12/24 07:35:42 WARN DAGScheduler: Broadcasting large task binary with size 4.1 MiB
23/12/24 07:35:56 WARN TaskSetManager: Stage 203 contains a task of very large size (1501 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

# Inference

In [51]:
# load the model
rf_clf = RandomForestClassificationModel.load("gs://it4043e-it5384/it4043e/it4043e_group15_problem1/output/models/spark_rf-n_estimators_95-max_depth_11_noscale")
rf_clf

                                                                                

RandomForestClassificationModel: uid=RandomForestClassifier_a08dd87bace2, numTrees=95, numClasses=8, numFeatures=18

In [52]:
# Infer the unlabeled
unlabeled_prep = transform_data(unlabeled)
unlabeled_pred = rf_clf.transform(unlabeled_prep)
unlabeled_pred.show(1)

23/12/24 16:30:50 WARN DAGScheduler: Broadcasting large task binary with size 3.4 MiB


+--------------------+-------+--------------------+-----+--------+-------------------+--------------------+------------------------+---------------+--------------+------------+-------+-----------+--------+----------+--------------------+--------------------+---------+-----------------------+-------------+--------------------+--------+--------+-------+----------+------------+---------+----------+---------+----+--------+------+------+------+------+------+------+--------------------+--------------------+----------+
|    all_feats_scaled|user_id|               _type| blue|blueType|            created|    descriptionLinks|             displayname|favouritesCount|followersCount|friendsCount| id_str|listedCount|location|mediaCount|    profileBannerUrl|     profileImageUrl|protected|         rawDescription|statusesCount|                 url|username|verified|country|replyCount|retweetCount|likeCount|quoteCount|viewCount|lang|blue_tmp|lang_0|lang_1|lang_2|lang_3|lang_4|lang_5|       rawPredicti

In [53]:
# invert the prediction
INV_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
inv_label_map = F.create_map([F.lit(x) for x in chain(*INV_LABEL_MAP.items())])
unlabeled_show = unlabeled_pred \
                    .withColumn("prediction", inv_label_map[F.col("prediction")]) \
                    .select(*unlabeled.columns, "prediction") \
                    .drop("location") \
                    .withColumnRenamed("prediction", "location") \
                    .withColumn("blue", F.when(F.col("blue") == 1, True).otherwise(False))
unlabeled_show.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- _type: string (nullable = true)
 |-- blue: boolean (nullable = false)
 |-- blueType: string (nullable = true)
 |-- created: timestamp (nullable = true)
 |-- descriptionLinks: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- tcourl: string (nullable = true)
 |    |    |-- text: string (nullable = true)
 |    |    |-- url: string (nullable = true)
 |-- displayname: string (nullable = true)
 |-- favouritesCount: long (nullable = true)
 |-- followersCount: long (nullable = true)
 |-- friendsCount: long (nullable = true)
 |-- id_str: string (nullable = true)
 |-- listedCount: long (nullable = true)
 |-- mediaCount: long (nullable = true)
 |-- profileBannerUrl: string (nullable = true)
 |-- profileImageUrl: string (nullable = true)
 |-- protected: string (nullable = true)
 |-- rawDescription: string (nullable = true)
 |-- statusesCount: long (nullable = true)
 |-- url: string (nullable = true)
 |-- username:

In [54]:
labeled.select(*unlabeled_show.columns).printSchema()

root
 |-- user_id: long (nullable = true)
 |-- _type: string (nullable = true)
 |-- blue: boolean (nullable = true)
 |-- blueType: string (nullable = true)
 |-- created: timestamp (nullable = true)
 |-- descriptionLinks: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- tcourl: string (nullable = true)
 |    |    |-- text: string (nullable = true)
 |    |    |-- url: string (nullable = true)
 |-- displayname: string (nullable = true)
 |-- favouritesCount: long (nullable = true)
 |-- followersCount: long (nullable = true)
 |-- friendsCount: long (nullable = true)
 |-- id_str: string (nullable = true)
 |-- listedCount: long (nullable = true)
 |-- mediaCount: long (nullable = true)
 |-- profileBannerUrl: string (nullable = true)
 |-- profileImageUrl: string (nullable = true)
 |-- protected: string (nullable = true)
 |-- rawDescription: string (nullable = true)
 |-- statusesCount: long (nullable = true)
 |-- url: string (nullable = true)
 |-- username: 

In [55]:
# union data for report
processed_data = labeled.select(*unlabeled_show.columns).union(unlabeled_show)
processed_data.count()

                                                                                

2149

# Visualization

In [56]:
from elasticsearch import Elasticsearch
from elasticsearch import helpers

def create_es_index(
    df, 
    idx_name
):
    es = Elasticsearch(['http://34.143.255.36:9200/'], http_auth=('elastic', 'elastic2023'))
    
    def index_data_to_elasticsearch(df):
        actions = []
        for row in df.rdd.toLocalIterator():
            action = {
                "_index": idx_name,
                "_source": {
                    col: row[col] for col in df.columns
                }
            }
            actions.append(action)
        helpers.bulk(es, actions)
        
    index_data_to_elasticsearch(df)

create_es_index(processed_data, "group15-spark-ml")

  es = Elasticsearch(['http://34.143.255.36:9200/'], http_auth=('elastic', 'elastic2023'))
23/12/24 16:32:04 WARN DAGScheduler: Broadcasting large task binary with size 3.5 MiB
23/12/24 16:32:05 WARN DAGScheduler: Broadcasting large task binary with size 3.5 MiB
[Stage 217:>                                                        (0 + 1) / 1]