## About

A proper "About" section can be written later.
A classifier that predicts which niche category a given Reddit post falls into.


### To-do/Ideas for the future
- Need to determine a workflow that cleans all the data we scrape from Reddit via PRAW.
- Can use LLMs for data augmentation as well, not just weak supervision. I.e., we can pass our existing Reddit posts into an LLM as inspiration and have it generate more Reddit stories that are likely to go viral within a niche of our choice.
    - Additionally, instead of passing known good stories into a general-purpose LLM (like Gemini or GPT-based LLMs) as we do right now, we could train or fine-tune a domain-specific LLM dedicated to this task (generating Reddit posts within a specific niche that are likely to go viral).

First, we need to collect data.
There aren't many good existing datasets for this, so we need to create our own.
This will be done through data scraping via PRAW and weak supervision via a chosen LLM (I am using Gemini for this).

First, scraping data via PRAW.

In [25]:
# Install all required dependencies

%pip install -r requirements.txt --user # --user flag is needed because one of the dependencies (google-genai) needs to access a script that is hidden in non-administrator environments.

Collecting psaw (from -r requirements.txt (line 2))
  Downloading psaw-0.1.0-py3-none-any.whl.metadata (10 kB)
ERROR: Could not find a version that satisfies the requirement distutils (from versions: none)

[notice] A new release of pip is available: 24.3.1 -> 25.1.1
[notice] To update, run: pip3 install --upgrade pip
ERROR: No matching distribution found for distutils
Note: you may need to restart the kernel to use updated packages.


In [10]:
#!which pip

In [1]:
# Make your necessary imports
import praw
import pandas as pd
import time
from google import genai
import numpy as np

In [2]:
# Initialize reddit client session

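# Credentials are read from environment variables so that secrets are not stored in the notebook.
# The variable names below are placeholders; use whatever names you export in your own environment.
import os

CLIENT_ID = os.environ["REDDIT_CLIENT_ID"]
CLIENT_SECRET = os.environ["REDDIT_CLIENT_SECRET"]
USER_AGENT = "sestee 1.0"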
cli = praw.Reddit(
        client_id=CLIENT_ID,
        client_secret=CLIENT_SECRET,
        user_agent=USER_AGENT
)


### Scraping Data
We need to get some Reddit posts.
We can choose between the "top" and "hot" listings.
This isn't inherently an AI/ML issue, but more of a data issue. I am mentioning it because it's extremely relevant to what we're doing (it directly involves our source of data).
You are likely to run into data issues like this and need to be ready to make decisions like this yourself too.
#### Hot
- Hot gives very good up-to-the-minute viral content, but due to Reddit API limitations (**not** PRAW limitations), it can only return up to 1000 posts. For certain subreddits, such as r/relationships, it sometimes returns far fewer than that, like 200-400.
- You can overcome this by scraping from multiple subreddits to maximise the total post count.
#### Top
- Top gets more sustainably viral content: content that isn't necessarily up-to-the-minute, but that still lands and is more likely to achieve steady virality.
- It generally returns a lot more posts from a single query (you are unlikely to get only 200-400 posts) because Reddit ranks "top" differently from "hot".
- It also allows you to use time_filters (like day, week, month, year), i.e. the date range within which you want to look for top-rated posts.
    - Hot does not offer this option.

##### Verdict
- I am going to go with top, as using time_filters allows us to retrieve many more posts than hot, and also allows us to prioritize the subreddits that I consider more relevant.
- You are free to use either option, as I will write a function for both.
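
One consequence of querying several time filters is that the same post can come back more than once (a post in the top of the week is often in the top of the month too). Below is a minimal sketch of removing those repeats by post ID before building the dataframe. It assumes each post is a dict with an `id` key like the ones the scraping functions build; `dedupe_posts` is a hypothetical helper, not part of the notebook.

```python
def dedupe_posts(posts):
    """Keep the first occurrence of each post ID, preserving order."""
    seen_ids = set()
    unique_posts = []
    for post in posts:
        if post["id"] in seen_ids:
            continue  # already collected under an earlier time filter
        seen_ids.add(post["id"])
        unique_posts.append(post)
    return unique_posts


# Made-up sample: the same post returned by two different time filters,
# with a score that changed between the two requests.
sample = [
    {"id": "abc", "title": "First", "score": 10},
    {"id": "xyz", "title": "Second", "score": 5},
    {"id": "abc", "title": "First", "score": 12},
]
unique = dedupe_posts(sample)
print(len(unique))  # 2
```

Comparing whole rows would treat the first and third entries as different because the score changed between requests; keying on the ID avoids that.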


In [44]:

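# Function to scrape posts from reddit using the "top" option, once per time filter.
# Note: sort_by is accepted so the call further down works, but it is not used here; this version always fetches "top".
def scrape_popular_posts(subreddits, time_filters=("day", "week", "month", "year", "all"), limit=None, sort_by="top"):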
    all_posts = []
    for sub_name in subreddits:
        subreddit = cli.subreddit(sub_name)
        for tf in time_filters:
            print(f"Fetching r/{sub_name} - top ({tf})")
            for post in subreddit.top(time_filter=tf, limit=limit):
                post_data = {
                    "title": post.title,
                    "selftext": post.selftext,
                    "subreddit": post.subreddit.display_name,
                    "flair": post.link_flair_text,
                    "score": post.score,
                    "num_comments": post.num_comments,
                    "upvote_ratio": post.upvote_ratio,
                    "created_utc": post.created_utc,
                    "id": post.id,
                    "url": post.url
                }
                all_posts.append(post_data)

    return all_posts
# Tends to fetch ~17k posts.

In [23]:
# Ignore this cell for now.
# Alternative version: scrapes posts from reddit using the "top", "hot", or "new" option (no time filters).
def scrape_popular_posts(subreddits, limit=None, sort_by="top"):
    posts = []
    
    for sub_name in subreddits:
        subreddit = cli.subreddit(sub_name)
        count = 0
        if sort_by == "top":
            submissions = subreddit.top(time_filter="all", limit=limit)
        elif sort_by == "hot":
            submissions = subreddit.hot(limit=limit)
        elif sort_by == "new":
            submissions = subreddit.new(limit=limit)
        else:
            raise ValueError("Invalid sort_by value. Use 'top', 'hot', or 'new'.")
        
        for post in submissions:
            count += 1
            post_data = {
                "title": post.title,
                "selftext": post.selftext, # For reference, selftext is the ACTUAL body text of the post
                "subreddit": post.subreddit.display_name,
                "flair": post.link_flair_text,
                "score": post.score,
                "num_comments": post.num_comments,
                "upvote_ratio": post.upvote_ratio,
                "created_utc": post.created_utc,
                "id": post.id,
                "url": post.url
            }
            posts.append(post_data)
        print(sub_name, count)
    return posts

In [35]:
# Function to scrape posts from reddit using "hot" option.
def scrape_hot_posts(subreddits, limit=1000):
    posts = []
    for sub_name in subreddits:
        subreddit = cli.subreddit(sub_name)
        count = 0
        for post in subreddit.hot(limit=limit):
            count += 1
            post_data = {
                "title": post.title,
                "selftext": post.selftext,
                "subreddit": post.subreddit.display_name,
                "flair": post.link_flair_text,
                "score": post.score,
                "num_comments": post.num_comments,
                "upvote_ratio": post.upvote_ratio,
                "created_utc": post.created_utc,
                "id": post.id,
                "url": post.url
            }
            posts.append(post_data)
        print(f"{sub_name}: Retrieved {count} hot posts")
    return posts
# Tends to fetch ~10K posts.


In [45]:
# Figure out what subreddits you want to scrape from
subreddits = ["AskReddit", "relationships", "AmItheAsshole", "TrueOffMyChest", "TIFU"]
subreddits_for_hot = ["AskReddit", "relationships", "AmItheAsshole", "TrueOffMyChest", "TIFU", "confession", "offmychest", "dating_advice", "TwoHotTakes", "relationship_advice", "UnpopularOpinion", "PettyRevenge", "prorevenge", "AITAH", "cheating_stories", "breakups"] # subreddits to use if scraping reddit posts using the "hot" option.

# Scrape the data from the subreddits
data = scrape_popular_posts(subreddits, limit=None, sort_by="top")
#data = scrape_hot_posts(subreddits_for_hot, limit=1000) 

# Save the data in a pandas dataframe
df = pd.DataFrame(data)

# Can save the dataframe to a CSV file too!
df.to_csv("reddit_posts.csv", index=False)

df["niche"] = None # Adding a new column to the dataframe for the niche

# Display the first few rows of the dataframe
df.head(20)

Fetching r/AskReddit - top (day)
Fetching r/AskReddit - top (week)
Fetching r/AskReddit - top (month)
Fetching r/AskReddit - top (year)
Fetching r/AskReddit - top (all)
Fetching r/relationships - top (day)
Fetching r/relationships - top (week)
Fetching r/relationships - top (month)
Fetching r/relationships - top (year)
Fetching r/relationships - top (all)
Fetching r/AmItheAsshole - top (day)
Fetching r/AmItheAsshole - top (week)
Fetching r/AmItheAsshole - top (month)
Fetching r/AmItheAsshole - top (year)
Fetching r/AmItheAsshole - top (all)
Fetching r/TrueOffMyChest - top (day)
Fetching r/TrueOffMyChest - top (week)
Fetching r/TrueOffMyChest - top (month)
Fetching r/TrueOffMyChest - top (year)
Fetching r/TrueOffMyChest - top (all)
Fetching r/TIFU - top (day)
Fetching r/TIFU - top (week)
Fetching r/TIFU - top (month)
Fetching r/TIFU - top (year)
Fetching r/TIFU - top (all)


Unnamed: 0,title,selftext,subreddit,flair,score,num_comments,upvote_ratio,created_utc,id,url,niche
0,"People over 35, what's something you genuinely...",,AskReddit,,7620,9723,0.94,1747601000.0,1kptz1u,https://www.reddit.com/r/AskReddit/comments/1k...,
1,What’s the worst city you’ve ever visited?,,AskReddit,,5263,7335,0.92,1747580000.0,1kplv0m,https://www.reddit.com/r/AskReddit/comments/1k...,
2,What is the most surreal “this can’t be real” ...,,AskReddit,,4457,2807,0.96,1747594000.0,1kpr4d4,https://www.reddit.com/r/AskReddit/comments/1k...,
3,What's the grossest thing you've seen someone ...,,AskReddit,,3572,2778,0.92,1747575000.0,1kpjupd,https://www.reddit.com/r/AskReddit/comments/1k...,
4,What was a don’t get paid enough for this sh*t...,,AskReddit,,2836,430,0.97,1747592000.0,1kpqi0f,https://www.reddit.com/r/AskReddit/comments/1k...,
5,(SERIOUS) What’s the worst way you know someon...,,AskReddit,Serious Replies Only,2059,2910,0.91,1747616000.0,1kpz8n7,https://www.reddit.com/r/AskReddit/comments/1k...,
6,Forget elephants in the room. What’s a blue wh...,,AskReddit,,1585,335,0.92,1747617000.0,1kpzhtw,https://www.reddit.com/r/AskReddit/comments/1k...,
7,What's your worst cheating story?,,AskReddit,,1202,848,0.88,1747565000.0,1kph4mb,https://www.reddit.com/r/AskReddit/comments/1k...,
8,"what's the best ""its not a bug it's a feature""...",,AskReddit,,942,223,0.95,1747573000.0,1kpjdmj,https://www.reddit.com/r/AskReddit/comments/1k...,
9,What has become so expensive that it's not wor...,,AskReddit,,1009,1198,0.92,1747601000.0,1kptxsd,https://www.reddit.com/r/AskReddit/comments/1k...,


Now that we have a good chunk of the data we need, we have to clean it.
- This means removing any bad records (missing values, etc.)
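
Beyond empty strings, Reddit replaces the body of removed or deleted posts with the placeholder text `[removed]` or `[deleted]`, so those rows carry no usable text either. Here is a small sketch of a per-post check, assuming posts are dicts with `title` and `selftext` keys; `is_usable_post` and `PLACEHOLDER_BODIES` are hypothetical names used only for illustration.

```python
PLACEHOLDER_BODIES = {"[removed]", "[deleted]"}


def is_usable_post(post):
    """Return True if the post has a non-empty title and real body text."""
    title = (post.get("title") or "").strip()
    body = (post.get("selftext") or "").strip()
    if not title or not body:
        return False
    return body not in PLACEHOLDER_BODIES


# Made-up sample covering the cases we want to reject.
sample_posts = [
    {"title": "A real story", "selftext": "It happened last week..."},
    {"title": "Question only", "selftext": ""},
    {"title": "Gone", "selftext": "[removed]"},
    {"title": "   ", "selftext": "Body without a title"},
    {"title": "Missing body", "selftext": None},
]
kept = [p for p in sample_posts if is_usable_post(p)]
print([p["title"] for p in kept])
```

The same check could be applied to the scraped list before it is turned into a dataframe, or translated into an equivalent dataframe filter.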

In [47]:
# Cleaning the data
# One of the ways we can clean the data is by removing any rows that have empty string values in the 'selftext' or 'title' columns. 
# If you take a look at the dataframe output above, you'll see this is the case for some of them.
original_entries = len(df)
df = df[df["selftext"].str.strip() != ""] # dropping empty 'selftext' rows
df = df[df["title"].str.strip() != ""] # dropping empty 'title' rows
df = df.drop_duplicates()
new_entries = len(df)
print("Original amount of entries: ", original_entries)
print("Amount of entries removed: ", original_entries - new_entries)
print("Amount of entries after cleaning: ", new_entries)
# As you can see, cleaning your data removes a surprising number of entries (check the output of this cell).

#df = df.dropna(subset=["selftext", "title"])  # Drop rows with NaN in 'selftext' or 'title'

# Now if we inspect the dataframe you'll see it doesn't have empty strings anymore at all.
df.head(20)

Original amount of entries:  18487
Amount of entries removed:  5420
Amount of entries after cleaning:  13067


Unnamed: 0,title,selftext,subreddit,flair,score,num_comments,upvote_ratio,created_utc,id,url,niche
3992,Stan Lee has passed away at 95 years old,As many of you know today is day that many of ...,AskReddit,Breaking News,175368,27635,0.87,1542052000.0,9whgf4,https://www.reddit.com/r/AskReddit/comments/9w...,
4018,Professor Stephen Hawking has passed away at t...,We have lost one of the greatest minds in hist...,AskReddit,Breaking News,117199,2688,0.84,1521002000.0,84anfy,https://www.reddit.com/r/AskReddit/comments/84...,
4035,Suicide Prevention Megathread,With the news today of the passing of the amaz...,AskReddit,Modpost,104344,15803,0.82,1528472000.0,8pks1u,https://www.reddit.com/r/AskReddit/comments/8p...,
4047,"Ruth Bader Ginsburg, US Supreme Court Justice,...","As many of you know, today [Ruth Bader Ginsbur...",AskReddit,Breaking News,99514,10307,0.82,1600476000.0,ivici8,https://www.reddit.com/r/AskReddit/comments/iv...,
4058,I can’t breathe. Black lives matter.,As the gap of the political divide in our worl...,AskReddit,Modpost,96756,6705,0.79,1591143000.0,gvj9a9,https://www.reddit.com/r/AskReddit/comments/gv...,
4064,How do you feel about ’how would you feel?’ po...,We’ve gotten a lot of feedback on this one. A ...,AskReddit,Modpost,95634,3420,0.82,1617475000.0,mje9y7,https://www.reddit.com/r/AskReddit/comments/mj...,
4069,[Breaking News] Orlando Nightclub mass-shooting.,**Update 3:19PM EST:** Updated links below\n\n...,AskReddit,Breaking News,94451,39241,0.86,1465745000.0,4nqnrm,https://www.reddit.com/r/AskReddit/comments/4n...,
4139,Australian Bushfire Crisis,"In response to breaking and ongoing news, AskR...",AskReddit,Breaking News,84182,5620,0.91,1578685000.0,emvveb,https://www.reddit.com/r/AskReddit/comments/em...,
4145,PSA: You did not win a gift card,"Recently, users in r/AskReddit have received m...",AskReddit,Modpost,83467,0,0.93,1583650000.0,ff8y60,https://www.reddit.com/r/AskReddit/comments/ff...,
4223,Moratorium on questions related to US Politics,"Effective immediately until a further notice, ...",AskReddit,Modpost,77170,2,0.87,1601679000.0,j44ppb,https://www.reddit.com/r/AskReddit/comments/j4...,


Now that we have our data, we will create a pipeline that labels every entry via weak supervision, filling in the "niche" column.
These are the categories we plan to classify our posts into.

| Label         | Description                                |
|---------------|--------------------------------------------|
| `advice`      | Help-seeking posts, questions, dilemmas    |
| `story`       | Personal anecdotes with a beginning, middle, end |
| `drama`       | High-stakes conflict, betrayal, gossip      |
| `rant`        | Emotional venting or unfiltered frustration |
| `humor`       | Meme-like, comedic, shitpost-style content  |
| `informative` | Tips, how-tos, PSAs, educational content    |
| `confession`  | Vulnerable personal reveals or identity-based confessions |
| `unknown`     | Doesn’t fit confidently into other categories|

Note: We can use the `unknown` category to find the biggest weaknesses of our LLM, and we can then possibly fine-tune it later very efficiently by specifically targeting the weaknesses we've detected here.
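
Since the model is asked for free-text output, its reply is not guaranteed to be exactly one of these labels (it may differ in case, carry a trailing newline, or be a full sentence). A hedged sketch of mapping a reply onto the table above follows. `normalize_label` is a hypothetical helper, and anything unrecognised falls back to `unknown`.

```python
CATEGORIES = ["advice", "story", "drama", "rant", "humor", "informative", "confession", "unknown"]


def normalize_label(raw_reply):
    """Map a raw model reply onto one of CATEGORIES, defaulting to 'unknown'."""
    if raw_reply is None:
        return "unknown"
    # Lowercase, trim whitespace, and drop surrounding punctuation or quotes.
    cleaned = raw_reply.strip().lower().strip(".,!\"'`* ")
    return cleaned if cleaned in CATEGORIES else "unknown"


print(normalize_label("Story\n"))
print(normalize_label("I think this is a rant"))
```

Replies that end up as `unknown` only because of formatting would otherwise inflate that category and hide the posts the model genuinely could not place.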

In [46]:
# Create an instance of the Google GenAI API client
import os

# The API key is read from an environment variable (the variable name is a placeholder) rather than hard-coded in the notebook.
client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
# gemini-2.0-flash is also a really good option, but it has a lower RPD (requests per day) and other tighter limits.
model = "gemma-3-27b-it" # There are a LOT of models to choose from. In my experience, 2.0-flash is the one I am most comfortable with and use the most. Will look into the 2.5 series once it reaches a stable release.
# The Gemma 3 model here can process 10K+ requests a day, which is really good for this
# specific context because, as you saw, our dataset has 10K+ entries, which equates to 10K+ requests for this dataset.

template_prompt = """I want to train a transformer-based classifier that takes in the text of a reddit post and then classifies it into labels [personal advice, story, drama]. I only have a partial dataset for this. Can you help fill the rest for me?
It should JUST classify the post into one niche category. The niche categories I want you to choose from are [advice, story, drama, rant, humor, informative, confession, unknown]. unknown is for when you really are not sure what category the post belongs to.
I don't want anything else in your response aside from the 1-word niche category. I don't want any explanations or anything else. Just the 1-word niche category.
Here is the post's data:

"""

# This above is the main template prompt that will be used with the rest of the reddit post data to create full proper prompts for every single reddit post data entry that we will classify via the API.


In [16]:
# Store the name of the file that's going to contain the full, labelled dataset.
data_filename = "reddit_posts_with_niches_hot_large.csv"

In [52]:
# Pipeline to classify each post: we call the API with the full prompt we built and read the niche category from the response.
# We are essentially using an LLM to label every single one of our data points (reddit posts) with the relevant niche category.
# This technique is called weak supervision.
# This usage of an LLM is purely for generating the dataset labels.
count = 0
for index, row in df.iterrows():
    post_data_prompt = f"Title: {row['title']}\nSelftext: {row['selftext']}\n\n"
    print("Post that will be classified:")
    print(f"Title: {row['title']}")
    print(f"Body text: {row['selftext']}")
    print("Classifying the post...")
    prompt = template_prompt + post_data_prompt
    try:
        response = client.models.generate_content(
            model=model, contents=prompt
        )
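    except Exception as e:
        # On any API error (e.g. rate limit or overloaded model), wait just over a minute and move on.
        # Note that the current post is skipped and stays unlabelled.
        print(e)
        time.sleep(61)
        continue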
    model_niche_guess = response.text
    count += 1
    # It is possible that the model will give NO response (so response.text is None) because our prompt may contain NSFW language (outside our control).
    # In this case we have to either set the niche to "unknown" or skip the post. I prefer to set it to "unknown" because it is still a valid category.
    if model_niche_guess is None:
        print("Model returned no response. Setting niche to 'unknown'.")
        model_niche_guess = "unknown"
    model_niche_guess = model_niche_guess.strip()  # Remove stray whitespace/newlines around the label
    print("MODEL'S GUESS:", model_niche_guess + "\n")
    print(f"{count} posts have now been classified! {(count / new_entries) * 100:.2f}% done.")
    time.sleep(5)  # Sleep for 5 seconds to avoid hitting Google's RPM (requests per minute) limit
    # Now we need to add the model's guess to the dataframe
    df.loc[index, "niche"] = model_niche_guess
df.to_csv("reddit_posts_with_niches_large.csv", index=False)  # Save the dataframe with the new column to a CSV file

# Please do NOT run this cell unless you are happy to spend hours on having your data labelled. This process takes a VERY long time, especially with a 10K+ dataset.

Post that will be classified:
Title: Stan Lee has passed away at 95 years old
Body text: As many of you know today is day that many of us have dreaded. Stan Lee has passed away at the age of 95. He leaves behind a legacy of superheroes and stories that have touched many people's lives for decades. We wanted to make this thread to honor and remember this wonderful man, so please use it discuss his life, his work, [his cameos](https://thumbs.gfycat.com/RapidClearDungenesscrab-small.gif), etc and what they meant to you. 

Excelsior!

-The AskReddit mods
Classifying the post...
MODEL'S GUESS: informative

1 posts have now been classified! 0.01% done.
Post that will be classified:
Title: Professor Stephen Hawking has passed away at the age of 76
Body text: We have lost one of the greatest minds in history today as Professor Stephen William Hawking has passed away on March 14, 2018 at the age of 76.

It is a terrible loss and we wanted to create this thread for people to share their thoughts

KeyboardInterrupt: 

And as you can see, models can get overloaded too. There is only so much optimism we can have about Google's LLMs sometimes (and about free AI services in general).
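
One common way to cope with an overloaded or rate-limited service is to retry the same request with a growing delay instead of skipping the post. The following is a minimal sketch of that idea, not the notebook's own method. `call_with_retries`, its parameters, and the fake request used in the demo are all hypothetical.

```python
import time


def call_with_retries(make_request, max_attempts=4, base_delay=2.0, sleep=time.sleep):
    """Call make_request(), retrying with exponential backoff on any exception."""
    for attempt in range(1, max_attempts + 1):
        try:
            return make_request()
        except Exception as error:
            if attempt == max_attempts:
                raise  # give up and let the caller decide what to do
            delay = base_delay * 2 ** (attempt - 1)  # 2s, 4s, 8s, ...
            print(f"Attempt {attempt} failed ({error}); retrying in {delay:.0f}s")
            sleep(delay)


# Stand-in for an API call that fails twice before succeeding.
attempts = {"count": 0}


def flaky_request():
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise RuntimeError("model overloaded")
    return "story"


waits = []  # record the delays instead of actually sleeping in this demo
result = call_with_retries(flaky_request, sleep=waits.append)
print(result, waits)
```

In the labelling loop, `make_request` could be a small function or lambda wrapping the `client.models.generate_content(model=model, contents=prompt)` call.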

In [15]:
# Save the dataframe with the new column to a CSV file
df.to_csv("reddit_posts_with_niches_hot.csv", index=False) 
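
Labelling runs for hours: at one request every 5 seconds, 13,067 posts need at least 13067 × 5 s ≈ 18 hours, not counting the time each request takes. An interruption part-way through therefore loses every label that has not been written to disk. Below is a minimal sketch of saving progress periodically and resuming on the next run, assuming labels are stored as a JSON mapping of post ID to niche. `label_with_checkpoints`, `classify`, and the checkpoint file name are hypothetical.

```python
import json
import os
import tempfile


def label_with_checkpoints(posts, classify, checkpoint_path, save_every=50):
    """Label posts, saving an {id: niche} mapping so a rerun can resume."""
    labels = {}
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path) as f:
            labels = json.load(f)  # resume from the previous run

    newly_labelled = 0
    for post in posts:
        if post["id"] in labels:
            continue  # labelled in an earlier run
        labels[post["id"]] = classify(post)
        newly_labelled += 1
        if newly_labelled % save_every == 0:
            with open(checkpoint_path, "w") as f:
                json.dump(labels, f)

    with open(checkpoint_path, "w") as f:
        json.dump(labels, f)  # final save
    return labels


# Demo with a stand-in classifier that records which posts it was asked about.
posts = [{"id": "a1"}, {"id": "b2"}, {"id": "c3"}]
calls = []


def fake_classify(post):
    calls.append(post["id"])
    return "story"


path = os.path.join(tempfile.mkdtemp(), "labels_checkpoint.json")
first_run = label_with_checkpoints(posts[:2], fake_classify, path, save_every=1)
second_run = label_with_checkpoints(posts, fake_classify, path, save_every=1)
print(calls)
```

With this shape, an interrupted run loses at most the labels produced since the last checkpoint, and the second call only asks the classifier about posts it has not seen yet.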

In [16]:
# This is an example of a model REFUSING to generate a response, because it detected that the content that was passed in was explicit/NSFW.

prompt_temp = template_prompt + """Title: TIFU by thinking a woman was a boy, and groping her boob. (kind of NSFW, though it happened at work)
Body text: Obligatory this actually happened a little over a year ago, and throwaway because I don't want people on my main account to know what I do for a living.

So, I work for the TSA, and have for a few years now. It's a good job overall. I'm underpaid, but the benefits are nice, and I get overtime when I want it.

A little over a year ago, during the week leading up to Christmas, we had some really bad weather that delayed all the flights. I volunteered to stay late so that my coworkers could go home to their families. Most of the work was done anyway, so it was mostly just standing around waiting for the odd latecomer

I was working the AIT (the space tube thingy), when three passengers came up together, a middle-aged man, a middle-aged woman, and a teenage boy. I figure it's a family traveling together for the holidays, and go about my work.

Mom goes through, all is fine. Dad goes through, all is fine.

Kid comes up, I get a good look at him. Hoodie, sweatpants, shortish hair, smooth face. I figure he's about 13, maybe 14.

I hit the button, direct him to wait with me for a moment, and then gesture to the screen, which lit up on his chest area.

I tell him that I have to pat that area down. He's a little nervous, I figure that because he's so young, this is probably his first time getting a pat down, but he says okay, and I start the patdown.

I do the left side of the chest, and feel some moob, which catches me off guard because he didn't look chubby at all.

I move to the right side of the chest, read what's on the hoodie, and it all clicks at once. The hoodie has the name of the local college on it. This is an adult, not a child. He's not wearing sweatpants, \*she\* is wearing yoga pants. She doesn't even know the couple that just came through.

I look at her face, which is bright red, my hand is still on her boob, and I pull it back like I just got bit by a snake.

I immediately call for my supervisor, who comes over and asks what's wrong, and I explain the situation to her.

My supervisor covers her mouth, and at first I thought she was absolutely mortified, but then I realized she's trying not to laugh.

She takes a minute to pull herself together, tells me to go take a break, and finishes screening the passenger herself.

Once that was done, I apologize to the passenger, she tells me it's fine, that it wasn't the first time she was mistaken for a boy, and she probably should have said something before I started touching her. I leave her alone, and go talk to my supervisor to figure out exactly how fired I am.

She tells me to calm down, that it was just an honest mistake, and that she has my back if the passenger files an official complaint, but that probably won't happen, and I shouldn't be worried.

That reassured me a little, but I still groped a woman and ruined Christmas, so I feel like an absolute monster.

I swallow my shame, and finish my shift, then I go into the airport proper to find some food, because I just finished a twelve hour shift and there's no way I have the energy to cook dinner.

I saw my hapless victim sitting at her gate, waiting for her flight. I went up to her to apologize again, and saw that the flight had been delayed until morning (it was about eleven at night).

I apologize again, she says it's fine, and I ask her if she's planning to stay the whole night. She says she has to, all the hotels in the area are book.

I tell her that I'm getting some dinner, and offer to get her some food as well. After all, I already got to second base, I think it's only fair that I buy her dinner.

She agrees, and we go to one of the restaurants that is open late, get some food, and start eating.

She said she gets mistaken for a boy a lot, and it's not a big deal. I told her about how I had long hair and no beard in college, and at the gym people would frequently walk into the men's bathroom, see me, and do a double take to make sure they didn't walk into the ladies' room.

She laughed, and we ended up talking for a few hours, before I finally told her that I had to get home, and apologized again for the accidental molestation.

She said that all is forgiven, if I promise to take her on a real date when she gets back.

I agreed, she gave me her phone number, and I went home, and immediately started texting her. We kept talking until her flight finally left, and when she got back I picked her up at the airport, and a few days later took her on that date that I promised her.

We just celebrated our one year anniversary.

She has long hair now.

&#x200B;

tl;dr: Thought an adult woman was a teenage boy, touched her on the boob, everything worked out better than expected."""

response = client.models.generate_content(
        model=model, contents=prompt_temp
    )
print(response)  




candidates=None create_time=None response_id=None model_version='gemini-2.0-flash' prompt_feedback=GenerateContentResponsePromptFeedback(block_reason=<BlockedReason.PROHIBITED_CONTENT: 'PROHIBITED_CONTENT'>, block_reason_message=None, safety_ratings=None) usage_metadata=GenerateContentResponseUsageMetadata(cache_tokens_details=None, cached_content_token_count=None, candidates_token_count=None, candidates_tokens_details=None, prompt_token_count=1339, prompt_tokens_details=[ModalityTokenCount(modality=<MediaModality.TEXT: 'TEXT'>, token_count=1339)], thoughts_token_count=None, tool_use_prompt_token_count=None, tool_use_prompt_tokens_details=None, total_token_count=1339, traffic_type=None) automatic_function_calling_history=[] parsed=None
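
The printed response above shows `candidates=None` together with a `prompt_feedback.block_reason`, which is why `response.text` comes back empty for this post. A hedged sketch of checking those two fields before reading the text is shown below. `extract_label` is a hypothetical helper, and the demo objects are plain stand-ins that only mimic the attribute names visible in that output, not real SDK responses.

```python
from types import SimpleNamespace


def extract_label(response):
    """Return (label, block_reason); label falls back to 'unknown' when blocked."""
    feedback = getattr(response, "prompt_feedback", None)
    block_reason = getattr(feedback, "block_reason", None)
    if block_reason is not None or not getattr(response, "candidates", None):
        return "unknown", block_reason
    text = getattr(response, "text", None)
    return (text.strip() if text else "unknown"), None


# Stand-in for a blocked response like the one printed above.
blocked = SimpleNamespace(
    candidates=None,
    prompt_feedback=SimpleNamespace(block_reason="PROHIBITED_CONTENT"),
    text=None,
)
# Stand-in for a normal response.
answered = SimpleNamespace(
    candidates=["<candidate>"],
    prompt_feedback=None,
    text="story\n",
)
print(extract_label(blocked))
print(extract_label(answered))
```

Recording the block reason alongside the `unknown` label makes it possible to separate posts the model refused from posts it genuinely could not place.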


We now have cleaned, labelled data.
We can now proceed to create and train our model.
First we need to choose a model architecture and create our model, before we actually start training it.

In [15]:
# Import relevant AI/ML libraries
import tensorflow as tf

2025-05-17 22:10:04.938804: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [17]:

# Now we need to create arrays storing our features and possible labels (column names and categories respectively).
COLUMN_NAMES = ["title", "selftext", "subreddit", "flair", "score", "num_comments", "upvote_ratio", "created_utc", "id", "url"]
CATEGORIES = ["advice", "story", "drama", "rant", "humor", "informative", "confession", "unknown"]
# data_filename contains the training and testing data. We will use the first 8000 entries for training and the rest for testing to evaluate our model.
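
The split described above, plus turning niche names into integer class labels, can be sketched in plain Python as follows. This is an illustration under assumptions: posts are dicts with a `niche` key, and `split_and_encode` and `LABEL_TO_INDEX` are hypothetical names. The toy data splits after 3 entries instead of 8000.

```python
CATEGORIES = ["advice", "story", "drama", "rant", "humor", "informative", "confession", "unknown"]
LABEL_TO_INDEX = {name: index for index, name in enumerate(CATEGORIES)}


def split_and_encode(labelled_posts, train_size=8000):
    """Split posts into train/test and convert niche names to integer labels."""
    train_posts = labelled_posts[:train_size]
    test_posts = labelled_posts[train_size:]

    def encode(posts):
        # Any label outside CATEGORIES is treated as "unknown".
        return [LABEL_TO_INDEX.get(p["niche"], LABEL_TO_INDEX["unknown"]) for p in posts]

    return (train_posts, encode(train_posts)), (test_posts, encode(test_posts))


# Tiny made-up dataset, split after the first 3 entries instead of 8000.
toy = [
    {"title": "t1", "niche": "story"},
    {"title": "t2", "niche": "advice"},
    {"title": "t3", "niche": "rant"},
    {"title": "t4", "niche": "not-a-category"},
]
(train_x, train_y), (test_x, test_y) = split_and_encode(toy, train_size=3)
print(train_y, test_y)
```

If the rows are still ordered by subreddit, as they are straight after scraping, a positional split would put whole subreddits in only one of the two sets, so shuffling the rows first is worth considering.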


In [None]:
# Now we have to create an input function
# This function is used to create/reorganize our data into a format that can be used by the model for training or testing.
def input_fn(features, labels, training=True, batch_size=256):
    # Convert the inputs to a Dataset
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels)) # This creates the dataset from the features and labels in TF's internal format.
   
    # If we're in training mode, we shuffle the data and repeat it indefinitely, so that the model doesn't just learn the order of the data and training can run for as many steps as needed.
    if training:
        dataset = dataset.shuffle(1000).repeat() # shuffle(1000) uses a buffer of 1000 elements: each element is drawn at random from that buffer, which is then refilled from the dataset. repeat() with no argument repeats the dataset indefinitely.
   
    # You now batch the data, which is basically where the dataset gets put into groups of a set size, where each batch is a subset of the dataset.
    # Each batch contains batch_size number of samples/examples.
    # This is done to speed up the training process, because it allows the model to process multiple samples at once.
    # The batch size is a hyperparameter that you can tune to find the best value for your model.
    return dataset.batch(batch_size)
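
To make the `shuffle(1000)` argument concrete: the number is a buffer size, not a repeat count. The following pure-Python sketch shows buffer-based shuffling in the same spirit. It is an illustration only, not TensorFlow's implementation, and `buffered_shuffle` is a hypothetical name.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in an order produced by sampling from a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        pick = rng.randrange(buffer_size)
        yield buffer[pick]
        buffer[pick] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


data = list(range(10))
print(list(buffered_shuffle(data, buffer_size=1)))  # buffer of 1: order is unchanged
print(list(buffered_shuffle(data, buffer_size=4, seed=0)))  # some permutation of 0..9
```

A buffer much smaller than the dataset only mixes nearby elements, so for data that is grouped (for example by subreddit) a buffer at least as large as the dataset gives a more thorough shuffle.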
