In [0]:
import numpy as np
import pandas as pd
from pyspark.sql import functions
from pyspark.sql.functions import col, lit, regexp_extract, struct, trim, when
from pyspark.sql.functions import regexp_replace
from pyspark.sql.types import IntegerType


# 1. Replace empty entries and entries with no relevant data in each column with Nones.
def clean_struct_fields(df, struct_col, fields, null_values=None):
    """Return ``df`` with ``struct_col`` rebuilt so "empty" sub-fields become null.

    A sub-field counts as empty when its trimmed value is ``""``, or when it
    equals one of ``null_values`` (compared against the trimmed value). Any
    other value is kept exactly as it was (untrimmed).

    Note: the struct is rebuilt from ``fields`` only, so any sub-field of
    ``struct_col`` that is not listed is dropped from the result.

    Args:
        df: Input DataFrame containing the struct column.
        struct_col: Name of the struct column to clean.
        fields: Names of the sub-fields to keep and clean, in output order.
        null_values: Optional iterable of extra placeholder strings that
            should also be treated as "no relevant data" and replaced by
            null. Defaults to ``None`` (only blank strings are nulled).

    Returns:
        A new DataFrame with ``struct_col`` replaced by the cleaned struct.
    """
    placeholders = list(null_values) if null_values is not None else []

    def _clean_field(field):
        # Build the null-if-empty expression for one sub-field.
        source = col(f"{struct_col}.{field}")
        trimmed = trim(source)
        is_missing = trimmed == ""
        if placeholders:
            is_missing = is_missing | trimmed.isin(placeholders)
        return when(is_missing, lit(None)).otherwise(source).alias(field)

    return df.withColumn(
        struct_col,
        struct(*[_clean_field(field) for field in fields]),
    )

# Every sub-field of the nested `data` struct, in its original order. All of
# them are kept by the rebuild below; blank values are replaced with null.
fields_to_clean = [
    "index", "unique_id", "title", "description",
    "poster_name", "follower_count", "tag_list", "is_image_or_video",
    "image_src", "downloaded", "save_location", "category",
]

df_pin = clean_struct_fields(df_pin, struct_col="data", fields=fields_to_clean)

# 2. Ensure every entry for follower_count is a number. Make sure the data type of this column is an int.
def process_follower_count(follower_count):
    """Convert a raw follower-count column into an IntegerType column.

    Accepted inputs (after trimming whitespace) are a plain number ("1234"),
    or a number followed by ``k`` (thousand) or ``M`` (million). The number
    may have a fractional part, e.g. "1.5k" -> 1500. Anything else, and any
    result larger than the 32-bit int maximum, becomes null.

    Args:
        follower_count: Column holding the raw follower count.

    Returns:
        Column of IntegerType (null for unparseable values).
    """
    # Group 1: the numeric part; group 2: an optional thousand/million suffix.
    pattern = r"^([0-9]+(?:\.[0-9]+)?)([kM]?)$"
    value = trim(follower_count)
    is_valid = value.rlike(pattern)
    suffix = regexp_extract(value, pattern, 2)
    multiplier = when(suffix == "k", 1000).when(suffix == "M", 1000000).otherwise(1)
    # Multiply numerically instead of substituting "000" into the string: the
    # substitution approach turned "1.5k" into "1.5000", which was then nulled.
    # The cast only runs for rows that matched the pattern.
    amount = functions.round(
        when(is_valid, regexp_extract(value, pattern, 1).cast("double")) * multiplier
    )
    # Keep the original's "null on overflow" behavior for IntegerType.
    return when(amount <= 2147483647, amount.cast(IntegerType())).otherwise(None)

# 3. Ensure that each column containing numeric data has a numeric data type.
def process_downloaded(downloaded):
    """Cast the ``downloaded`` column to IntegerType.

    Values made up solely of ASCII digits are cast to int; everything else
    (including null) becomes null.
    """
    digits_only = downloaded.rlike('^[0-9]+$')
    as_int = downloaded.cast(IntegerType())
    return when(digits_only, as_int).otherwise(None)

# Rebuild `data` in one pass: rename `index`, convert the numeric fields,
# strip the save-location prefix, and pass the remaining sub-fields through.
df_pin = df_pin.withColumn(
    "data",
    struct(
        # 5. Rename the index column to ind.
        col("data.index").alias("ind"),
        *[
            col(f"data.{name}").alias(name)
            for name in ("unique_id", "title", "description", "poster_name")
        ],
        process_follower_count(col("data.follower_count")).alias("follower_count"),
        *[
            col(f"data.{name}").alias(name)
            for name in ("tag_list", "is_image_or_video", "image_src")
        ],
        process_downloaded(col("data.downloaded")).alias("downloaded"),
        # 4. Clean the data in the save_location column to include only the save location path.
        regexp_replace(col("data.save_location"), "Local save in ", "").alias("save_location"),
        col("data.category").alias("category"),
    ),
)

# 6. Reorder the DataFrame columns.
# 6. Reorder the DataFrame columns: flatten the `data` struct into top-level
# columns in the final output order.
df_pin = df_pin.select(
    *[
        col(f"data.{name}").alias(name)
        for name in (
            "ind",
            "unique_id",
            "title",
            "description",
            "follower_count",
            "poster_name",
            "tag_list",
            "is_image_or_video",
            "image_src",
            "save_location",
            "category",
            "downloaded",
        )
    ]
)

display(df_pin)