# Install and Import

In [None]:
#%pip install google-api-python-client


In [None]:
import math
import os

import googleapiclient.discovery
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.functions import lit

# Setup and initialise

In [None]:
# YouTube Data API service name and version handed to discovery.build.
api_service_name = "youtube"
api_version = "v3"
# Prefer the key from the environment so a real key never has to be pasted
# into (and committed with) the notebook. The placeholder remains the
# fallback, so behaviour is unchanged when YOUTUBE_API_KEY is not set.
DEVELOPER_KEY = os.environ.get("YOUTUBE_API_KEY", "Your API Key here")


In [None]:
# YouTube Data API client shared by the request helpers in this notebook.
youtube = googleapiclient.discovery.build(
    api_service_name,
    api_version,
    developerKey=DEVELOPER_KEY,
)


# API Requests

In [None]:
def convertResponseItemsToDataframe(items):
    """Build a Spark DataFrame from commentThreads API ``items``.

    One row is produced per item: the thread ``id`` followed by the
    ``textDisplay``, ``publishedAt`` and ``likeCount`` fields of the
    thread's top-level comment snippet.

    Args:
        items: List of commentThread resources, e.g. the ``items`` field of
            a ``commentThreads().list`` response.

    Returns:
        DataFrame with nullable StringType columns id, textDisplay,
        publishedAt and likeCount. likeCount is kept as a string here; the
        notebook casts it to IntegerType after all videos are collected.
    """
    columnNames = ["id", "textDisplay", "publishedAt", "likeCount"]
    commentSchema = StructType(
        [StructField(name, StringType(), True) for name in columnNames]
    )

    rows = []
    for thread in items:
        threadId = thread['id']
        topComment = thread['snippet']['topLevelComment']['snippet']
        rows.append((
            threadId,
            topComment['textDisplay'],
            topComment['publishedAt'],
            topComment['likeCount'],
        ))

    return spark.createDataFrame(rows, schema=commentSchema)


In [None]:
def getCommentThreads(videoId, commentCount, batchSize=100):
    """Fetch the comment threads of one video as a Spark DataFrame.

    Pages through ``youtube.commentThreads().list`` (published comments,
    ordered by relevance) until ``ceil(commentCount / batchSize)`` requests
    have been made (at least one), or until the API stops returning a
    ``nextPageToken``.

    Args:
        videoId: YouTube video id whose comment threads are fetched.
        commentCount: Desired number of comment threads. Requests are made in
            whole batches, so up to ``batchSize - 1`` extra threads may be
            returned.
        batchSize: ``maxResults`` sent with each API request.

    Returns:
        DataFrame with the columns produced by
        ``convertResponseItemsToDataframe`` plus a ``videoId`` column holding
        ``videoId`` on every row.
    """
    maxRequests = max(1, math.ceil(commentCount / batchSize))
    baseParams = {
        "part": "id,snippet",
        "maxResults": batchSize,
        "moderationStatus": "published",
        "order": "relevance",
        "videoId": videoId,
    }

    # Collect raw items first and build a single DataFrame at the end rather
    # than unioning one small DataFrame per page.
    items = []
    pageToken = None
    for _ in range(maxRequests):
        params = dict(baseParams)
        if pageToken:
            params["pageToken"] = pageToken
        response = youtube.commentThreads().list(**params).execute()
        items.extend(response.get("items", []))

        # The last page (including a first page that already holds every
        # comment) carries no nextPageToken, so stop instead of indexing it.
        pageToken = response.get("nextPageToken")
        if not pageToken:
            break

    df = convertResponseItemsToDataframe(items)
    return df.withColumn("videoId", lit(videoId))

In [None]:
# Every video id stored in Raw.videos (the query selects a single column).
videoIds = [row[0] for row in spark.sql("SELECT Id FROM Raw.videos").collect()]

# Column layout returned by getCommentThreads: the comment fields followed by
# videoId. DataFrame.union matches columns by position, so the order matters.
schema = StructType(
    [
        StructField(column, StringType(), True)
        for column in ["id", "textDisplay", "publishedAt", "likeCount", "videoId"]
    ]
)

# Start from an empty frame so the result has this schema even with no videos.
df = spark.createDataFrame([], schema=schema)

for videoId in videoIds:
    print(f"Getting comments for video: {videoId}")
    # NOTE(review): an API error for a single video (for example, comments
    # disabled) will propagate and stop this loop. Confirm that is intended.
    newDf = getCommentThreads(videoId, 10000)
    print(f"Obtained {newDf.count()} comments for video {videoId}")
    df = df.union(newDf)


### Convert data types

In [None]:
# likeCount arrives as a string column; store it as an integer.
df = df.withColumn("likeCount", df["likeCount"].cast(IntegerType()))

# Write data

In [None]:
# Replace the Raw.comments Delta table, allowing its schema to change too.
(
    df.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable("Raw.comments")
)