## About

Proper about me can be made later
Classifier that classifies what niche category a certain reddit post falls into.


### To-do/Ideas for the future.
- Need to find/determine a workflow that cleans all the data that we scrape/get from reddit via PRAW.
- Can use LLMs for data-augmentation as well, not just weak supervision. I.e, we can pass our actual existing reddit posts' data into an LLM to give it some ideas and show it some inspiration, and use that to get it to generate more reddit stories that are likely to be viral within a specific chosen niche of our choice.
    - Additionally, instead of just passing good known stories into a general-purpose LLM (like Gemini or GPT-based LLMs) like we are right now, we could train or fine-tune a domain-specific LLM that is dedicated for this task (generating reddit posts within a specific niche that are likely to go viral).

First, we need to collect data.
There aren't many very good datasets, so we need to create our own.
This will be done through data scraping via PRAW and weak supervision via a chosen LLM (I am using Gemini for this).

First, scraping data via PRAW.

In [None]:
# Install all required dependencies

%pip install -r requirements.txt --user # --user flag is needed because one of the dependencies (google-genai) needs to access a script that is hidden in non-administrator environments.

In [20]:
# Make your necessary imports
import praw
import pandas as pd
import time
from google import genai
import numpy as np

In [None]:
# Initialize reddit client session

CLIENT_ID = "0xeiOSktNDiHBw"
CLIENT_SECRET = "c-bNB_P5wRjHZmaD1eaJnx0D3mlr8Q"
USER_AGENT = "sestee 1.0"
cli = praw.Reddit(
        client_id=CLIENT_ID,
        client_secret=CLIENT_SECRET,
        user_agent=USER_AGENT
)


In [None]:
# Declare a way for you to scrape posts from a subreddit of your choice.
def scrape_popular_posts(subreddits, limit=100, sort_by="top"):
    posts = []

    for sub_name in subreddits:
        subreddit = cli.subreddit(sub_name)

        if sort_by == "top":
            submissions = subreddit.top(limit=limit)
        elif sort_by == "hot":
            submissions = subreddit.hot(limit=limit)
        elif sort_by == "new":
            submissions = subreddit.new(limit=limit)
        else:
            raise ValueError("Invalid sort_by value. Use 'top', 'hot', or 'new'.")

        for post in submissions:
            post_data = {
                "title": post.title,
                "selftext": post.selftext, # For reference, selftext is the ACTUAL body text of the post
                "subreddit": post.subreddit.display_name,
                "flair": post.link_flair_text,
                "score": post.score,
                "num_comments": post.num_comments,
                "upvote_ratio": post.upvote_ratio,
                "created_utc": post.created_utc,
                "id": post.id,
                "url": post.url
            }
            posts.append(post_data)

    return posts

In [27]:
# Figure out what subreddits you want to scrape from
subreddits = ["AskReddit", "relationships", "AmItheAsshole", "TrueOffMyChest", "TIFU"]
# Scrape the data from the subreddits
data = scrape_popular_posts(subreddits, limit=50, sort_by="top")
# Save the data in a pandas dataframe
df = pd.DataFrame(data)

# Can save the dataframe to a CSV file too!
#df.to_csv("reddit_posts.csv", index=False)

# Display the first few rows of the dataframe
df.head(200)

Unnamed: 0,title,selftext,subreddit,flair,score,num_comments,upvote_ratio,created_utc,id,url
0,"People who haven't pooped in 2019 yet, why are...",,AskReddit,,221991,7925,0.91,1.546377e+09,ablzuq,https://www.reddit.com/r/AskReddit/comments/ab...
1,How would you feel about Reddit adding 3 NSFW ...,,AskReddit,,217920,2886,0.87,1.611860e+09,l7530r,https://www.reddit.com/r/AskReddit/comments/l7...
2,Would you watch a show where a billionaire CEO...,,AskReddit,,197595,13327,0.90,1.581069e+09,f08dxb,https://www.reddit.com/r/AskReddit/comments/f0...
3,"What if God came down one day and said ""It's p...",,AskReddit,,195912,10227,0.92,1.600611e+09,iwedc5,https://www.reddit.com/r/AskReddit/comments/iw...
4,How would you feel about a feature where if so...,,AskReddit,,186435,2772,0.90,1.572833e+09,draola,https://www.reddit.com/r/AskReddit/comments/dr...
...,...,...,...,...,...,...,...,...,...,...
195,Everytime somebody messes up at a fast food jo...,,TrueOffMyChest,,40517,1845,0.88,1.600715e+09,ix6r5j,https://www.reddit.com/r/TrueOffMyChest/commen...
196,I've been lying to my husband for the past 8 y...,"I met my husband on April 1, 2011. When his bi...",TrueOffMyChest,,40387,1271,0.89,1.583271e+09,fd1sej,https://www.reddit.com/r/TrueOffMyChest/commen...
197,If you've declawed your cat or debarked your d...,Declawing a cat is equivalant to cutting off y...,TrueOffMyChest,,40130,3982,0.79,1.600521e+09,ivs91v,https://www.reddit.com/r/TrueOffMyChest/commen...
198,I’m not getting my kid anything for Christmas.,UPDATE- I had several one on one talks with hi...,TrueOffMyChest,,39933,7538,0.89,1.639715e+09,ri8z2y,https://www.reddit.com/r/TrueOffMyChest/commen...


Now we have a good chunk of all the data that we need. We need to clean it.

In [None]:
# Cleaning the data
# Check and do later
#df = df.dropna(subset=["selftext", "title"])  # Drop rows with NaN in 'selftext' or 'title'
#df = df[df["selftext"].str.strip() != ""]  # Drop empty 'selftext' rows

Now that we have our data, we will create a pipeline that allows us to label all the data entries and add a "niche" column via weak supervision. All entries will then be classified.
These are the post classification categories we are planning to classify our posts into.

| Label         | Description                                |
|---------------|--------------------------------------------|
| `advice`      | Help-seeking posts, questions, dilemmas    |
| `story`       | Personal anecdotes with a beginning, middle, end |
| `drama`       | High-stakes conflict, betrayal, gossip      |
| `rant`        | Emotional venting or unfiltered frustration |
| `humor`       | Meme-like, comedic, shitpost-style content  |
| `informative` | Tips, how-tos, PSAs, educational content    |
| `confession`  | Vulnerable personal reveals or identity-based confessions |
| `unknown`     | Doesn’t fit confidently into other categories|

Note: We can use the `unknown` category to find the biggest weaknesses of our LLM, and we can then possibly fine-tune our LLM later very efficiently by especially targetting its weaknesses that we've detected here.

In [22]:
# Create an instance of the Google GenAI API client
client = genai.Client(api_key="AIzaSyDSyIBzIJ9yVnXYd6sJaE7oZ0Vqnc4kEPM")
model = "gemini-2.0-flash" # There are a LOT of models to choose from. But in my experience, I feel comfortable with AND use 2.0-flash the most. Will look into 2.5 series once they go through stable release.
template_prompt = f"""I want to train a transformer-based classifer that takes in the text of a reddit post and then classifes them into labels [personal advice, story, drama]. I only have a partial dataset for this. Can you help fill the rest for me?
It should JUST classify the post into one niche category. The niche categories I want you to choose from are [advice, story, drama, rant, humor, informative, confession, unknown]. unknown is for when you really are not sure what category the post belongs to.
I don't want anything else in your response aside from the 1-word niche category. I don't want any explanations or anything else. Just the 1-word niche category.
Here is the post's data:

"""

# This above is the main template prompt that will be used with the rest of the reddit post data to create full proper prompts for every single reddit post data entry that we will classify via the API.


In [31]:
# Pipeline to classify each one of the posts, making a call to the API and using the full prompt we made to get the response that contains the niche category we want.
for index, row in df.iterrows():
    if row["selftext"] == "":
        #print("Skipping empty selftext post.")
        continue
    post_data_prompt = f"Title: {row['title']}\nSelftext: {row['selftext']}\n\n"
    #print("Post that will be classified:")
    print(f"Title: {row['title']}")
    print(f"Body text: {row['selftext']}")
    #print("Classifying the post...")
    prompt = template_prompt + post_data_prompt
    #print(prompt)
    response = client.models.generate_content(
        model=model, contents=prompt
    )
    model_niche_guess = response.text
    print(model_niche_guess + "\n")
    time.sleep(5)  # Sleep for 5 seconds to avoid hitting googles rpm limit
    # Now we need to add the model's guess to the dataframe

Title: Stan Lee has passed away at 95 years old
Body text: As many of you know today is day that many of us have dreaded. Stan Lee has passed away at the age of 95. He leaves behind a legacy of superheroes and stories that have touched many people's lives for decades. We wanted to make this thread to honor and remember this wonderful man, so please use it discuss his life, his work, [his cameos](https://thumbs.gfycat.com/RapidClearDungenesscrab-small.gif), etc and what they meant to you. 

Excelsior!

-The AskReddit mods
informative


Title: Without saying what the category is, what are your top five?
Body text:  
advice


Title: Professor Stephen Hawking has passed away at the age of 76
Body text: We have lost one of the greatest minds in history today as Professor Stephen William Hawking has passed away on March 14, 2018 at the age of 76.

It is a terrible loss and we wanted to create this thread for people to share their thoughts about Professor Hawking, from favorite quotes, to the

KeyboardInterrupt: 

In [None]:
# Add the model's guess to the dataframe