# 1. Initial setting

In [1]:
import os
import datetime as dt
import json
import textwrap
import logging

from IPython.display import Markdown
import polars as pl
from praw import Reddit
from pydantic import BaseModel, TypeAdapter, AwareDatetime

# Extract data from Reddit

We use PRAW https://pypi.org/project/praw/, "Python Reddit API Wrapper".

Login to Reddit and create an app at https://www.reddit.com/prefs/apps

![Reddit app](img/reddit-app.png)

(Image from https://www.jcchouinard.com/get-reddit-api-credentials-with-praw/)

In [2]:
%load_ext dotenv

In [3]:
%dotenv

In [4]:
reddit_client_id = os.environ["REDDIT_CLIENT_ID"]
reddit_client_secret = os.environ["REDDIT_SECRET"]
reddit_username = os.environ["REDDIT_USERNAME"]  # Just for assembling the user agent

In [5]:
# Read-only Reddit connection https://praw.readthedocs.io/en/stable/getting_started/quick_start.html#read-only-reddit-instances
reddit = Reddit(
    client_id=reddit_client_id,
    client_secret=reddit_client_secret,
    user_agent=f"a-hole predictor by u/{reddit_username}",
)
reddit.user.me(), reddit.read_only

(None, True)

In [6]:
target_subreddit_name = "r/AmItheAsshole"

In [7]:
subreddit = reddit.subreddit(target_subreddit_name.removeprefix("r/"))
print(subreddit.title)
Markdown(subreddit.description[:1_000] + "...")

Am I the Asshole? 


#Welcome to r/AmITheAsshole!

A catharsis for the frustrated moral philosopher in all of us, and a place to finally find out if you were wrong in a real-world argument that's been bothering you. Tell us about any non-violent conflict you have experienced; give us both sides of the story, and find out if you're right, or you're the asshole.

This is the sub to lay out your  actions and conflicts and get impartial judgment rendered against you.  Were you the asshole in that situation or not? Post should reflect real situations, and abide by the rules below.

After 18 hours, your post will be given a flair representing the final judgment on your matter.  This flair is determined by the subscribers who have both rendered judgment and voted on which judgment is best.  ***The power of the crowd will judge you***.  If your top level comment has the highest number of upvotes in a thread, you will get a flair point. More details are listed in [our FAQ](https://www.reddit.com/r/AmItheAssho...

In [8]:
# https://praw.readthedocs.io/en/v7.7.1/getting_started/logging.html
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
for logger_name in ("praw", "prawcore"):
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(handler)

> "To page through a listing, start by fetching the first page without specifying values for after and count.
> The response will contain an after value which you can pass in the next request.
> It is a good idea, but not required, to send an updated value for count which should be the number of items already fetched.

https://old.reddit.com/dev/api#listings

> We enforce rate limits for those eligible for free access usage of our Data API. The limit is:
> - 100 queries per minute (QPM) per OAuth client id
> QPM limits will be an average over a time window (currently 10 minutes) to support bursting requests.

https://support.reddithelp.com/hc/en-us/articles/16160319875092-Reddit-Data-API-Wiki

In [9]:
submissions = []
for submission in subreddit.new(limit=1_000, params=dict(after="t3_1dwignu")):
    submissions.append(
        {
            "title": submission.title,
            "author_name": submission.author.name if submission.author else None,
            "creation_datetime": dt.datetime.fromtimestamp(
                submission.created_utc, tz=dt.timezone.utc,
            ).isoformat(),
            "subreddit_name": submission.subreddit_name_prefixed,
            "num_comments": submission.num_comments,
            "sfw": not submission.over_18,
            "score": submission.score,
            "upvote_ratio": submission.upvote_ratio,
            "is_self": submission.is_self,
            "permalink": submission.permalink,
            "selftext": submission.selftext,
            "flair_text": submission.link_flair_text,  # Target
        }
    )

Fetching: GET https://oauth.reddit.com/r/AmItheAsshole/new at 1720395504.189213
Data: None
Params: {'after': 't3_1dwignu', 'limit': 1000, 'raw_json': 1}
Response: 200 (114184 bytes) (rst-95:rem-981.0:used-19 ratelimit) at 1720395505.4814022
Fetching: GET https://oauth.reddit.com/r/AmItheAsshole/new at 1720395505.5022268
Data: None
Params: {'after': 't3_1dvv09v', 'limit': 1000, 'raw_json': 1}
Response: 200 (109113 bytes) (rst-94:rem-980.0:used-20 ratelimit) at 1720395507.268327
Fetching: GET https://oauth.reddit.com/r/AmItheAsshole/new at 1720395507.283094
Data: None
Params: {'after': 't3_1dv8tz6', 'limit': 1000, 'raw_json': 1}
Response: 200 (112015 bytes) (rst-92:rem-979.0:used-21 ratelimit) at 1720395508.988001
Fetching: GET https://oauth.reddit.com/r/AmItheAsshole/new at 1720395509.003242
Data: None
Params: {'after': 't3_1duoubx', 'limit': 1000, 'raw_json': 1}
Response: 200 (108062 bytes) (rst-90:rem-978.0:used-22 ratelimit) at 1720395510.574493
Fetching: GET https://oauth.reddit.com

In [10]:
len(submissions)

668

In [11]:
print("\n".join(textwrap.wrap((json.dumps(submissions)[:1_000] + "..."), width=70)))

[{"title": "AITA for kicking guests out of a theme park restroom at
12:40 AM", "author_name": "TheHylind", "creation_datetime":
"2024-07-06T05:36:33+00:00", "subreddit_name": "r/AmItheAsshole",
"num_comments": 7, "sfw": true, "score": 1, "upvote_ratio": 1.0,
"is_self": true, "permalink": "/r/AmItheAsshole/comments/1dwif5p/aita_
for_kicking_guests_out_of_a_theme_park/", "selftext": "For context, I
work at a local theme park that, for reasons beyond my comprehension,
closes at midnight. Obviously I can't disclose the name of the park or
anything like that, but it's notably smaller than Disney or even most
Six Flags parks. Most of our guests are either locals that come here
every day, or tourists who came here for another reason and just
wanted to check out what was in the area.\n\nBecause of this, we often
have a lot of stragglers even when we're already closed, which makes
it really frustrating when I'm assigned to clean the bathrooms,
especially because my bathrooms are literally right

Notice the `creation_datetime` is `str`, for easier serialization!

In [12]:
submissions[-1]["creation_datetime"]

'2024-07-01T13:15:44+00:00'

# Model the data

We use Pydantic https://pypi.org/project/pydantic/, a popular Python library for data validation

In [13]:
class RedditSubmission(BaseModel):
    title: str
    author_name: str | None
    creation_datetime: AwareDatetime
    subreddit_name: str
    num_comments: int
    sfw: bool
    score: int
    upvote_ratio: float
    is_self: bool
    permalink: str
    selftext: str | None
    flair_text: str | None

In [14]:
adapter = TypeAdapter(list[RedditSubmission])

In [15]:
objects = adapter.validate_python(submissions)
objects[:5]

[RedditSubmission(title='AITA for kicking guests out of a theme park restroom at 12:40 AM', author_name='TheHylind', creation_datetime=datetime.datetime(2024, 7, 6, 5, 36, 33, tzinfo=TzInfo(UTC)), subreddit_name='r/AmItheAsshole', num_comments=7, sfw=True, score=1, upvote_ratio=1.0, is_self=True, permalink='/r/AmItheAsshole/comments/1dwif5p/aita_for_kicking_guests_out_of_a_theme_park/', selftext='For context, I work at a local theme park that, for reasons beyond my comprehension, closes at midnight. Obviously I can\'t disclose the name of the park or anything like that, but it\'s notably smaller than Disney or even most Six Flags parks. Most of our guests are either locals that come here every day, or tourists who came here for another reason and just wanted to check out what was in the area.\n\nBecause of this, we often have a lot of stragglers even when we\'re already closed, which makes it really frustrating when I\'m assigned to clean the bathrooms, especially because my bathrooms 

Pydantic automatically converted the str datetime to an actual `datetime.datatime` object, as specified in the model:

In [16]:
objects[-1].creation_datetime

datetime.datetime(2024, 7, 1, 13, 15, 44, tzinfo=TzInfo(UTC))

We now use Polars https://pypi.org/project/polars/, a dataframe library with an expressive API and blazing fast performance:

In [17]:
df = (
    pl.from_dicts(objects)
    .with_columns(
        # Otherwise the timezone is lost
        pl.col("creation_datetime").dt.replace_time_zone("UTC"),
    )
)
df.head()

title,author_name,creation_datetime,subreddit_name,num_comments,sfw,score,upvote_ratio,is_self,permalink,selftext,flair_text
str,str,"datetime[μs, UTC]",str,i64,bool,i64,f64,bool,str,str,str
"""AITA for kicking guests out of…","""TheHylind""",2024-07-06 05:36:33 UTC,"""r/AmItheAsshole""",7,True,1,1.0,True,"""/r/AmItheAsshole/comments/1dwi…","""For context, I work at a local…","""Not enough info"""
"""AITA for reporting coworker's …","""Allethiia""",2024-07-06 05:21:10 UTC,"""r/AmItheAsshole""",2,True,3,0.6,True,"""/r/AmItheAsshole/comments/1dwi…","""I (28) have been working at my…","""TL;DR"""
"""AITA for cancelling my birthda…","""Lis_wj""",2024-07-06 05:14:00 UTC,"""r/AmItheAsshole""",6,True,1,0.6,True,"""/r/AmItheAsshole/comments/1dwi…","""I (26F) have been really stres…","""Not the A-hole"""
"""AITA: I told my sister she has…","""dswizzle2""",2024-07-06 05:09:53 UTC,"""r/AmItheAsshole""",14,True,18,0.8,True,"""/r/AmItheAsshole/comments/1dwh…","""For context, I(27F) and my sis…","""Not the A-hole"""
"""WIBTA for calling out my frien…","""gremlinoverlord_420""",2024-07-06 04:57:21 UTC,"""r/AmItheAsshole""",9,True,0,0.33,True,"""/r/AmItheAsshole/comments/1dwh…","""I (f) have gotten fed up with …","""Not the A-hole"""


In [18]:
(
    df["flair_text"]
    .value_counts()
    .sort("count", descending=True)
    .with_columns(
        (pl.col("count") / pl.sum("count") * 100).round(1).alias("rel_pct"),
    )
)

flair_text,count,rel_pct
str,u32,f64
"""Not the A-hole""",480,71.9
"""Asshole""",84,12.6
"""No A-holes here""",32,4.8
"""Everyone Sucks""",26,3.9
"""TL;DR""",18,2.7
…,…,…
,5,0.7
"""Not the A-hole POO Mode""",4,0.6
"""Asshole POO Mode""",3,0.4
"""UPDATE""",1,0.1


Approximately 85 % of the posts have a definitive result.

# Store data

Let's write it to a local Parquet file to avoid having to retrieve the data from the API again:

In [19]:
df.write_parquet("submissions.pq")

# Load raw data and do feature engineering

We load the data again from the same Parquet file:

In [20]:
import polars as pl

df = pl.read_parquet("submissions.pq").sort("creation_datetime", descending=True)
df.head()

title,author_name,creation_datetime,subreddit_name,num_comments,sfw,score,upvote_ratio,is_self,permalink,selftext,flair_text
str,str,"datetime[μs, UTC]",str,i64,bool,i64,f64,bool,str,str,str
"""AITA for kicking guests out of…","""TheHylind""",2024-07-06 05:36:33 UTC,"""r/AmItheAsshole""",7,True,1,1.0,True,"""/r/AmItheAsshole/comments/1dwi…","""For context, I work at a local…","""Not enough info"""
"""AITA for reporting coworker's …","""Allethiia""",2024-07-06 05:21:10 UTC,"""r/AmItheAsshole""",2,True,3,0.6,True,"""/r/AmItheAsshole/comments/1dwi…","""I (28) have been working at my…","""TL;DR"""
"""AITA for cancelling my birthda…","""Lis_wj""",2024-07-06 05:14:00 UTC,"""r/AmItheAsshole""",6,True,1,0.6,True,"""/r/AmItheAsshole/comments/1dwi…","""I (26F) have been really stres…","""Not the A-hole"""
"""AITA: I told my sister she has…","""dswizzle2""",2024-07-06 05:09:53 UTC,"""r/AmItheAsshole""",14,True,18,0.8,True,"""/r/AmItheAsshole/comments/1dwh…","""For context, I(27F) and my sis…","""Not the A-hole"""
"""WIBTA for calling out my frien…","""gremlinoverlord_420""",2024-07-06 04:57:21 UTC,"""r/AmItheAsshole""",9,True,0,0.33,True,"""/r/AmItheAsshole/comments/1dwh…","""I (f) have gotten fed up with …","""Not the A-hole"""


In [21]:
# Two types of posts: AITA and WIBTA https://www.reddit.com/r/AmItheAsshole/wiki/howtopost/
df = (
    df.with_columns(
        pl.col("title").str.extract(r"^(AITA|WIBTA)", 1).alias("post_type"),
        pl.col("selftext").str.len_chars().alias("text_length"),
    )
)
df.head(5)

title,author_name,creation_datetime,subreddit_name,num_comments,sfw,score,upvote_ratio,is_self,permalink,selftext,flair_text,post_type,text_length
str,str,"datetime[μs, UTC]",str,i64,bool,i64,f64,bool,str,str,str,str,u32
"""AITA for kicking guests out of…","""TheHylind""",2024-07-06 05:36:33 UTC,"""r/AmItheAsshole""",7,True,1,1.0,True,"""/r/AmItheAsshole/comments/1dwi…","""For context, I work at a local…","""Not enough info""","""AITA""",1546
"""AITA for reporting coworker's …","""Allethiia""",2024-07-06 05:21:10 UTC,"""r/AmItheAsshole""",2,True,3,0.6,True,"""/r/AmItheAsshole/comments/1dwi…","""I (28) have been working at my…","""TL;DR""","""AITA""",4890
"""AITA for cancelling my birthda…","""Lis_wj""",2024-07-06 05:14:00 UTC,"""r/AmItheAsshole""",6,True,1,0.6,True,"""/r/AmItheAsshole/comments/1dwi…","""I (26F) have been really stres…","""Not the A-hole""","""AITA""",1169
"""AITA: I told my sister she has…","""dswizzle2""",2024-07-06 05:09:53 UTC,"""r/AmItheAsshole""",14,True,18,0.8,True,"""/r/AmItheAsshole/comments/1dwh…","""For context, I(27F) and my sis…","""Not the A-hole""","""AITA""",2996
"""WIBTA for calling out my frien…","""gremlinoverlord_420""",2024-07-06 04:57:21 UTC,"""r/AmItheAsshole""",9,True,0,0.33,True,"""/r/AmItheAsshole/comments/1dwh…","""I (f) have gotten fed up with …","""Not the A-hole""","""WIBTA""",2284


Now let's perform some quick sentiment analysis using NLTK:

In [22]:
import nltk
from nltk.sentiment.vader import SentimentIntensityAnalyzer

nltk.download('vader_lexicon')

[nltk_data] Downloading package vader_lexicon to
[nltk_data]     /Users/juan_cano/nltk_data...
[nltk_data]   Package vader_lexicon is already up-to-date!


True

In [23]:
sia = SentimentIntensityAnalyzer()

In [24]:
for index in range(5):
    print(sia.polarity_scores(df.item(index, "selftext")))

{'neg': 0.076, 'neu': 0.828, 'pos': 0.096, 'compound': 0.7268}
{'neg': 0.087, 'neu': 0.843, 'pos': 0.07, 'compound': -0.9138}
{'neg': 0.071, 'neu': 0.849, 'pos': 0.081, 'compound': 0.2894}
{'neg': 0.136, 'neu': 0.768, 'pos': 0.096, 'compound': -0.9803}
{'neg': 0.153, 'neu': 0.722, 'pos': 0.125, 'compound': -0.9398}


Better to do it by sentence:

In [25]:
sentences = (
    df.with_columns(
        pl.col("selftext").str.split(".").list.eval(pl.element().str.strip_chars()).alias("sentences")
    )
    .select(pl.col("permalink", "sentences"))
    .explode("sentences")
    .with_columns(
        pl.col("sentences").map_elements(
            lambda s: sia.polarity_scores(s),
            return_dtype=pl.Struct({"neg": pl.Float64, "neu": pl.Float64, "pos": pl.Float64, "compound": pl.Float64}),
        ).alias("sentiment_scores"),
    )
)
sentences.head()

permalink,sentences,sentiment_scores
str,str,struct[4]
"""/r/AmItheAsshole/comments/1dwi…","""For context, I work at a local…","{0.0,1.0,0.0,0.0}"
"""/r/AmItheAsshole/comments/1dwi…","""Obviously I can't disclose the…","{0.0,0.929,0.071,0.1901}"
"""/r/AmItheAsshole/comments/1dwi…","""Most of our guests are either …","{0.0,1.0,0.0,0.0}"
"""/r/AmItheAsshole/comments/1dwi…","""Because of this, we often have…","{0.076,0.859,0.064,-0.1263}"
"""/r/AmItheAsshole/comments/1dwi…","""I'll be practically done with …","{0.138,0.769,0.092,-0.4215}"


In [26]:
df = df.join(
    (
        sentences
        .group_by("permalink")
        .agg(
            pl.col("sentiment_scores").struct.field("compound").mean().alias("compound_sentiment"),
        )
    ),
    on="permalink",
    how="left",
)
df.head()

title,author_name,creation_datetime,subreddit_name,num_comments,sfw,score,upvote_ratio,is_self,permalink,selftext,flair_text,post_type,text_length,compound_sentiment
str,str,"datetime[μs, UTC]",str,i64,bool,i64,f64,bool,str,str,str,str,u32,f64
"""AITA for kicking guests out of…","""TheHylind""",2024-07-06 05:36:33 UTC,"""r/AmItheAsshole""",7,True,1,1.0,True,"""/r/AmItheAsshole/comments/1dwi…","""For context, I work at a local…","""Not enough info""","""AITA""",1546,-0.0193
"""AITA for reporting coworker's …","""Allethiia""",2024-07-06 05:21:10 UTC,"""r/AmItheAsshole""",2,True,3,0.6,True,"""/r/AmItheAsshole/comments/1dwi…","""I (28) have been working at my…","""TL;DR""","""AITA""",4890,-0.094389
"""AITA for cancelling my birthda…","""Lis_wj""",2024-07-06 05:14:00 UTC,"""r/AmItheAsshole""",6,True,1,0.6,True,"""/r/AmItheAsshole/comments/1dwi…","""I (26F) have been really stres…","""Not the A-hole""","""AITA""",1169,0.040277
"""AITA: I told my sister she has…","""dswizzle2""",2024-07-06 05:09:53 UTC,"""r/AmItheAsshole""",14,True,18,0.8,True,"""/r/AmItheAsshole/comments/1dwh…","""For context, I(27F) and my sis…","""Not the A-hole""","""AITA""",2996,-0.056525
"""WIBTA for calling out my frien…","""gremlinoverlord_420""",2024-07-06 04:57:21 UTC,"""r/AmItheAsshole""",9,True,0,0.33,True,"""/r/AmItheAsshole/comments/1dwh…","""I (f) have gotten fed up with …","""Not the A-hole""","""WIBTA""",2284,-0.037443


## Predict result

Now, the interesting part: can we create a machine learning model that can predict, _before the vote happens_, who is the a-hole of the story?

Let's start with an extremely simple model and very few features:

In [27]:
from sklearn.pipeline import Pipeline
from sklearn.compose import make_column_transformer
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.preprocessing import MinMaxScaler

In [28]:
classifier = Pipeline([
    ("transformer", make_column_transformer(
        ("passthrough", ["text_length"]),
        (MinMaxScaler((-1, 1)), ["compound_sentiment"]),
        (OneHotEncoder(sparse_output=False, handle_unknown="ignore"), ["sfw", "post_type"]),
        remainder="drop",  # Important to avoid accidental data leakage
    ).set_output(transform="polars")),
    ("classifier", DecisionTreeClassifier(class_weight="balanced")),
])
classifier

We will only show the model entries with a definitive result:

In [29]:
df_train = (
    df
    .with_columns(
        pl.when(pl.col("author_name").is_null())
        .then(pl.lit("<UNKNOWN>"))
        .otherwise(pl.col("author_name"))
        .alias("author_name"),
    )
    .filter(pl.col("flair_text").is_in(["Not the A-hole", "Asshole"]))
)
len(df_train)

564

In [30]:
classifier[:-1].fit_transform(df_train).head()

passthrough__text_length,minmaxscaler__compound_sentiment,onehotencoder__sfw_False,onehotencoder__sfw_True,onehotencoder__post_type_AITA,onehotencoder__post_type_WIBTA
u32,f64,f64,f64,f64,f64
1169,0.019746,0.0,1.0,1.0,0.0
2996,-0.081782,0.0,1.0,1.0,0.0
2284,-0.061768,0.0,1.0,0.0,1.0
2710,-0.292353,0.0,1.0,1.0,0.0
1651,-0.470861,0.0,1.0,1.0,0.0


In [31]:
X_train, y_train = df_train.drop("flair_text"), df_train["flair_text"]

cv = GridSearchCV(
    classifier,
    param_grid={
        "classifier__max_depth": range(1, 6),
        "transformer__minmaxscaler": [MinMaxScaler((-1, 1)), "drop"],
    },
    cv=StratifiedKFold(5)
)
cv.fit(X_train, y_train)

In [32]:
results = (
    df_train
    .select(pl.col("upvote_ratio", "post_type", "flair_text"))
    .with_columns(
        pl.Series(name="flair_text_predicted", values=cv.predict(df_train))
    )
)
results

upvote_ratio,post_type,flair_text,flair_text_predicted
f64,str,str,str
0.6,"""AITA""","""Not the A-hole""","""Not the A-hole"""
0.8,"""AITA""","""Not the A-hole""","""Not the A-hole"""
0.33,"""WIBTA""","""Not the A-hole""","""Not the A-hole"""
0.91,"""AITA""","""Not the A-hole""","""Not the A-hole"""
0.75,"""AITA""","""Not the A-hole""","""Not the A-hole"""
…,…,…,…
0.29,"""AITA""","""Not the A-hole""","""Asshole"""
0.92,"""AITA""","""Not the A-hole""","""Not the A-hole"""
0.95,"""AITA""","""Not the A-hole""","""Not the A-hole"""
0.95,"""AITA""","""Not the A-hole""","""Not the A-hole"""


In [33]:
results["flair_text_predicted"].value_counts().sort("count", descending=True)

flair_text_predicted,count
str,u32
"""Not the A-hole""",450
"""Asshole""",114


In [34]:
cv.best_params_

{'classifier__max_depth': 4, 'transformer__minmaxscaler': 'drop'}

In [35]:
importances = pl.DataFrame({
    "feature_names_out": cv.best_estimator_[:-1].get_feature_names_out(),
    "importances": cv.best_estimator_[-1].feature_importances_,
})
importances

feature_names_out,importances
str,f64
"""passthrough__text_length""",0.945209
"""onehotencoder__sfw_False""",0.0
"""onehotencoder__sfw_True""",0.0
"""onehotencoder__post_type_AITA""",0.054791
"""onehotencoder__post_type_WIBTA""",0.0


In [36]:
cv.score(df_train.drop("flair_text"), df_train["flair_text"])

0.7553191489361702

Well, we have a prediction... but there's lots of work to do!

- The score is okay, but probably can be improved
- We only used a handful of features
- We did very little feature engineering
- We would like to try out other models
- What if we have more data?
- How do we keep track of experiments?

We could complicate this notebook more and more, but it's already getting out of hand...

Let's start taking steps towards a more structured approach! ✨