In [None]:
import logging
import os

import sys
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

from shared_code.utility.spark.set_environ import set_azure_env

# Configure Azure-related environment settings before the RedditDataCollector import below.
# NOTE(review): this call is deliberately placed between imports, which suggests that module
# reads these settings at import time — confirm before moving it.
set_azure_env()

from shared_code.utility.scripts.reddit_collector import RedditDataCollector

# Root logger at INFO; quiet the chatty third-party libraries down to WARNING.
logging.basicConfig(level=logging.INFO)
for noisy_logger in ("azure.storage", "diffusers", "azure.core"):
	logging.getLogger(noisy_logger).setLevel(logging.WARNING)

# Point Spark's Python driver and worker processes at this kernel's interpreter.
for python_var in ("PYSPARK_PYTHON", "PYSPARK_DRIVER_PYTHON"):
	os.environ[python_var] = sys.executable

# Local-mode session settings, applied to the builder in this order.
# NOTE(review): in local mode executors run inside the driver JVM, so
# spark.executor.memory / spark.cores.max likely have no effect — confirm.
spark_settings = {
	"spark.cores.max": 12,
	"spark.executor.memory": "10g",
	"spark.driver.memory": "10g",
	"spark.memory.offHeap.enabled": True,
	"spark.memory.offHeap.size": "10g",
}
spark_builder = SparkSession.builder.appName("collect-reddit-data").master("local[12]")
for setting_key, setting_value in spark_settings.items():
	spark_builder = spark_builder.config(setting_key, setting_value)
spark_session = spark_builder.getOrCreate()

In [None]:
# Table name and image output directory handed to RedditDataCollector in the next cell.
table_name, image_out_dir = "training", "/data/images/"

In [None]:
# Subreddits to collect, written as a single "+"-joined string.
subreddits = "greentext+selfies+Faces+EarthPorn+CityPorn+sfwpetite+SFWNextDoorGirls+SFWRedheads"
# Split into individual names. Strip stray whitespace and drop empty entries
# (e.g. from "a++b" or a trailing "+") so the collector is never asked for a blank subreddit.
subs = [name.strip() for name in subreddits.split("+") if name.strip()]

In [None]:
from pyspark.sql.types import IntegerType

collector: RedditDataCollector = RedditDataCollector(image_out_dir=image_out_dir, table_name=table_name)


def download_subreddit_images(subreddit) -> int:
	"""Download images for one subreddit through the shared ``collector``.

	Any exception is logged and swallowed, and 0 is returned instead, so one failing
	subreddit does not fail the caller.

	:param subreddit: subreddit name, passed straight to ``collector.download_subreddit_images``.
	:return: the collector's return value (annotated as an image count), or 0 on error.
	"""
	try:
		return collector.download_subreddit_images(subreddit)
	except Exception as err:
		logging.error(f"Error downloading images for {subreddit}: {err}")
	return 0


# Spark UDF wrapper around the function above, typed as IntegerType.
# NOTE(review): the collection cell below loops on the driver and does not use this UDF.
# Running it on workers would require Spark to serialize `collector` — confirm it is picklable first.
download_subreddit_images_udf = udf(download_subreddit_images, IntegerType())

In [None]:
# Re-quiet azure.storage before downloading.
# NOTE(review): already set in the setup cell; presumably repeated because something in between resets it — confirm.
logging.getLogger("azure.storage").setLevel(logging.WARNING)

# One {"subreddit", "count"} record per subreddit whose download succeeded;
# failures are logged and skipped so a single bad subreddit does not abort the run.
data = []

for sub in subs:
	try:
		num_images_collected = collector.download_subreddit_images(sub)
	except Exception as e:
		logging.error(f"Error downloading images for {sub}: {e}")
		continue
	result = {"subreddit": sub, "count": num_images_collected}
	data.append(result)

# An explicit schema keeps this working when every download failed: Spark cannot infer
# a schema from an empty list. It also pins the column order (inference sorts dict keys).
spark_session.createDataFrame(data=data, schema="subreddit STRING, count BIGINT").show()