In [237]:
import polars as pl
import src.settings as settings
from typing import Optional
from enum import Enum

In [238]:
COMMENTS_NDJSON = settings.DATA_DIR / 'amitheasshole_comments.ndjson'
COMMENTS_PARQUET = settings.DATA_DIR / 'amitheasshole_comments.parquet'
SUBMISSIONS_NDJSON = settings.DATA_DIR / 'amitheasshole_submissions.ndjson'
SUBMISSIONS_PARQUET = settings.DATA_DIR / 'amitheasshole_submissions.parquet'

class Flairs(Enum):
    YTA = 'Asshole'
    NTA = 'Not the A-hole'
    ESH = 'Everyone Sucks'
    NAH = 'No A-holes here'

SHOULD_PRINT_INFO = True
SHOULD_PRINT_HEAD = True
N_HEAD = 10
USE_TOP_LEVEL_COMMENTS_ONLY = True
SHOULD_SORT_SUBMISSIONS = False # Sort submissions in descending order on the scores. Combined with N_SUBMISSIONS will give only the top N submissions with the highest scores
N_SUBMISSIONS: Optional[int] = None # Limit the number of submissions if not None
SHOULD_MAP_OTHER_FLAIRS = True # Map YTA_LIKE to YTA and NTA_LIKE to NTA
YTA_LIKE = ('YTA', 'YWBTA', 'ESH')
NTA_LIKE = ('NTA', 'YWNBTA', 'NAH')

# (Convert to) and load from parquet into lazy frames

In [239]:
if not SUBMISSIONS_PARQUET.exists():
    submissions_lf = (
        pl.scan_ndjson(
            SUBMISSIONS_NDJSON,
            schema_overrides={
                'edited': pl.Utf8,
            },
        )
        .with_columns(
            pl.when(pl.col('edited').is_null() | (pl.col('edited') == 'false'))
            .then(False)
            .otherwise(True)
            .alias('edited')
        )
        .select(
            'author',
            'edited',
            'link_flair_text',
            'name',
            'num_comments',
            'over_18',
            'score',
            'selftext',
            'title',
            'upvote_ratio',
        )
        .drop_nulls()
        .drop_nans()
    )

    submissions_lf.sink_parquet(SUBMISSIONS_PARQUET, engine='streaming')

submissions_lf = pl.scan_parquet(SUBMISSIONS_PARQUET)

In [240]:
if not COMMENTS_PARQUET.exists():
    comments_lf = (
        pl.scan_ndjson(
            COMMENTS_NDJSON,
            schema_overrides={
                'edited': pl.Utf8,
            },
        )
        .with_columns(
            pl.when(pl.col('edited').is_null() | (pl.col('edited') == 'false'))
            .then(False)
            .otherwise(True)
            .alias('edited')
        )
        .select(
            'author',
            'body',
            'edited',
            'is_submitter',
            'link_id',
            'parent_id',
            'score',
        ) # TODO: controversiality?
        .drop_nulls()
        .drop_nans()
    )

    comments_lf.sink_parquet(COMMENTS_PARQUET, engine='streaming')
    
comments_lf = pl.scan_parquet(COMMENTS_PARQUET)

# Submissions filtering

In [241]:
# FIXME: extract these values to global constant?
filtered_submissions_lf = submissions_lf.filter(
    ~pl.col('over_18')
    & ~pl.col('selftext').is_in(('[deleted]', '[removed]'))
    & ~pl.col('author').is_in(('Judgement_Bot_AITA', 'AutoModerator')) # TODO: often, deleted authors still have post up and not deleted, so don't filter on [deleted]?
    & pl.col('link_flair_text').is_in((Flairs.NTA.value, Flairs.YTA.value))
    & ~pl.col('edited') # TODO: perhaps too aggresive, you might lose interesting posts
)

if SHOULD_PRINT_INFO:
    print(f'row count submissions original: {submissions_lf.select(pl.len()).collect(engine='streaming').item()}')
    print(f'row count submissions filtered: {filtered_submissions_lf.select(pl.len()).collect(engine='streaming').item()}')

row count submissions original: 79223
row count submissions filtered: 43000


In [242]:
selected_submissions_lf = filtered_submissions_lf.sort(pl.col('num_comments'), descending=True) if SHOULD_SORT_SUBMISSIONS else filtered_submissions_lf
selected_submissions_lf = filtered_submissions_lf if N_SUBMISSIONS is None else selected_submissions_lf.limit(N_SUBMISSIONS)

In [243]:
if SHOULD_PRINT_HEAD:
    _ = display(selected_submissions_lf.collect(engine='streaming').head(N_HEAD))

author,edited,link_flair_text,name,num_comments,over_18,score,selftext,title,upvote_ratio
str,bool,str,str,i64,bool,i64,str,str,f64
"""Stock_Pizza_1721""",False,"""Asshole""","""t3_v2fhu4""",32,False,20,"""AITA for telling my mom someth…","""AITA for telling something abo…",0.77
"""NotDobbythrwy123""",False,"""Not the A-hole""","""t3_v2foac""",165,False,933,"""15m here, USA. I'm on the spec…","""AITA For Telling My Mom Chores…",0.94
"""Relevant_Hair_9259""",False,"""Not the A-hole""","""t3_v2fs2w""",11,False,1,"""Apologizes for any spelling mi…","""AITA for not letting my friend…",0.56
"""weddingcakeaita""",False,"""Not the A-hole""","""t3_v2g05m""",344,False,2046,"""Some context: my (26f) dad has…","""AITA for refusing to make my c…",0.97
"""throwaway778548""",False,"""Not the A-hole""","""t3_v2gl84""",754,False,7833,"""I 29f have 1 baby girl 3f and …","""AITA for telling my mil it’s n…",0.96
"""ScrantonStranger""",False,"""Not the A-hole""","""t3_v2gm82""",64,False,120,"""My boyfriend (30) is moving ne…","""AITA for refusing to pack my b…",0.89
"""1234random789""",False,"""Not the A-hole""","""t3_v2gq1z""",103,False,295,"""I’ll be very upfront, I knew w…","""AITA for giving a group of ele…",0.94
"""OutlandishnessNo7623""",False,"""Not the A-hole""","""t3_v2gtfj""",33,False,21,"""My (16) mother (43) has a bf w…","""AITA for asking my mum to brin…",0.84
"""Queenstantinople""",False,"""Not the A-hole""","""t3_v2h8g7""",12,False,4,"""I 25f just rejected my moms re…","""AITA for rejecting my moms req…",0.83
"""joonjooniejoon""",False,"""Not the A-hole""","""t3_v2h9ez""",9,False,1,"""I had this guy that I was talk…","""AITA for not contacting my ""fr…",0.67


# Comments filtering

In [244]:
# FIXME: extract these values to global constant?
filtered_comments_lf = comments_lf.filter(
    ~pl.col('body').is_in(('[deleted]', '[removed]'))
    & ~pl.col('author').is_in(('Judgement_Bot_AITA', 'AutoModerator')) # TODO: is this the same case as for submissions?
    & ~pl.col('edited') # TODO: perhaps too aggresive, you might lose interesting posts
    & ~pl.col('is_submitter')
)

if USE_TOP_LEVEL_COMMENTS_ONLY:
    filtered_comments_lf = filtered_comments_lf.filter(
        pl.col('parent_id').str.starts_with('t3_') # Comments always start with t1_, submissions with t3_, and subreddits with t5_
    )

if SHOULD_PRINT_INFO:
    print(f'row count comments original: {comments_lf.select(pl.len()).collect(engine='streaming').item()}')
    print(f'row count comments filtered: {filtered_comments_lf.select(pl.len()).collect(engine='streaming').item()}')

row count comments original: 14560130
row count comments filtered: 7389891


In [245]:
# TODO: necessary if we take all submissions?
matched_comments_lf = filtered_comments_lf.join(
    selected_submissions_lf,
    left_on='parent_id',
    right_on='name',
    how='semi',
)

# Calculate submission features

In [246]:
submission_features_lf = (
    selected_submissions_lf
    .with_columns(
        ground_truth_majority_vote=(
            pl.when(pl.col('link_flair_text') == 'Asshole').then(pl.lit('YTA'))
            .when(pl.col('link_flair_text') == Flairs.ESH.value).then(pl.lit('ESH'))
            .when(pl.col('link_flair_text') == Flairs.NTA.value).then(pl.lit('NTA'))
            .when(pl.col('link_flair_text') == Flairs.NAH.value).then(pl.lit('NAH'))
            .otherwise(None)
        ),
        text_length=(
            pl.col('title').str.len_chars()
            + pl.col('selftext').str.len_chars()
        ),
    )
    .filter(pl.col('ground_truth_majority_vote').is_not_null())
    .select(
        'name',
        'ground_truth_majority_vote',
        'text_length',
        'link_flair_text',
        'score',
        'upvote_ratio',
    )
)

if SHOULD_MAP_OTHER_FLAIRS:
    submission_features_lf = (
        submission_features_lf
        .with_columns(
            pl.when(pl.col('ground_truth_majority_vote') == 'ESH').then(pl.lit('YTA'))
            .when(pl.col('ground_truth_majority_vote') == 'NAH').then(pl.lit('NTA'))
            .otherwise(pl.col('ground_truth_majority_vote'))
            .alias('ground_truth_majority_vote')
        )
    )

if SHOULD_PRINT_HEAD:
    _ = display(submission_features_lf.collect(engine='streaming').head(N_HEAD))

name,ground_truth_majority_vote,text_length,link_flair_text,score,upvote_ratio
str,str,u32,str,i64,f64
"""t3_v2fhu4""","""YTA""",471,"""Asshole""",20,0.77
"""t3_v2foac""","""NTA""",1607,"""Not the A-hole""",933,0.94
"""t3_v2fs2w""","""NTA""",1545,"""Not the A-hole""",1,0.56
"""t3_v2g05m""","""NTA""",2050,"""Not the A-hole""",2046,0.97
"""t3_v2gl84""","""NTA""",1934,"""Not the A-hole""",7833,0.96
"""t3_v2gm82""","""NTA""",1446,"""Not the A-hole""",120,0.89
"""t3_v2gq1z""","""NTA""",1940,"""Not the A-hole""",295,0.94
"""t3_v2gtfj""","""NTA""",1289,"""Not the A-hole""",21,0.84
"""t3_v2h8g7""","""NTA""",2248,"""Not the A-hole""",4,0.83
"""t3_v2h9ez""","""NTA""",713,"""Not the A-hole""",1,0.67


# Extract vote from each comment

In [247]:
body_lowercase = pl.col('body').str.to_lowercase()
comment_votes_lf = (
    matched_comments_lf
    .filter(~pl.col('body').str.contains(r'\bINFO\b'))
    .with_columns(
        extracted_vote=(
            pl.when(body_lowercase.str.contains(r'\bywbta\b')).then(pl.lit('YWBTA'))
            .when(body_lowercase.str.contains(r'\bywnbta\b')).then(pl.lit('YWNBTA'))
            .when(body_lowercase.str.contains(r'\byta\b')).then(pl.lit('YTA'))
            .when(body_lowercase.str.contains(r'\bnta\b')).then(pl.lit('NTA'))
            .when(body_lowercase.str.contains(r'\besh\b')).then(pl.lit('ESH'))
            .when(body_lowercase.str.contains(r'\bnah\b')).then(pl.lit('NAH'))
            .otherwise(None)
        )
    )
    .filter(pl.col('extracted_vote').is_not_null())
    .select(['link_id', 'score', 'extracted_vote', 'body'])
    .drop_nulls()
    .drop_nans()
)

if SHOULD_MAP_OTHER_FLAIRS:
    comment_votes_lf = comment_votes_lf.with_columns(
        pl.when(pl.col('extracted_vote').is_in(YTA_LIKE)).then(pl.lit('YTA'))
        .when(pl.col('extracted_vote').is_in(NTA_LIKE)).then(pl.lit('NTA'))
        .alias('extracted_vote')
    )

if SHOULD_PRINT_HEAD:
    _ = display(comment_votes_lf.sort('score').collect(engine='streaming').head(N_HEAD))

link_id,score,extracted_vote,body
str,i64,str,str
"""t3_zc6sf7""",-2670,"""YTA""","""I'm going to say ESH Your sis…"
"""t3_zy5buk""",-2309,"""NTA""","""NAH - pain is pain. Suffering…"
"""t3_xmundm""",-1886,"""NTA""","""Do you guys usually start a me…"
"""t3_vq0d0d""",-1712,"""NTA""","""NTA. She lied about it not bei…"
"""t3_wcnh0v""",-1543,"""NTA""","""You're NTA for including your …"
"""t3_yu8tm5""",-1471,"""YTA""","""YTA because once you caught on…"
"""t3_yk6gn2""",-1462,"""NTA""","""NTA because of the fact that i…"
"""t3_vo8n7y""",-1341,"""NTA""","""NTA, but I think that this is …"
"""t3_vq0d0d""",-1184,"""NTA""","""NTA...you don't owe anyone aff…"
"""t3_vq0d0d""",-1069,"""YTA""","""ESH. She's TA for lying but yo…"


# Calculate (weighted) majority vote based on extracted comment votes

In [248]:
per_submission_extracted_majority_votes_lf = (
    comment_votes_lf
    .group_by('link_id')
    .agg(
        # Just use signed ints instead of unsigned for simplicity so that when you subtract these values, you can have negative values
        n_yta=(pl.col('extracted_vote') == 'YTA').cast(pl.Int32).sum(),
        n_nta=(pl.col('extracted_vote') == 'NTA').cast(pl.Int32).sum(),
        weighted_yta=pl.when(pl.col('extracted_vote') == 'YTA').then(pl.col('score')).otherwise(0).cast(pl.Int64, strict=False).sum(),
        weighted_nta=pl.when(pl.col('extracted_vote') == 'NTA').then(pl.col('score')).otherwise(0).cast(pl.Int64, strict=False).sum(),
    )
    .with_columns(
        delta=pl.col('n_yta') - pl.col('n_nta'),
        weighted_delta=pl.col('weighted_yta') - pl.col('weighted_nta'),
        n_comments=pl.col('n_yta') + pl.col('n_nta')
    )
    .with_columns(
        extracted_majority_vote=(
            pl.when(pl.col('delta') > 0).then(pl.lit('YTA'))
             .when(pl.col('delta') < 0).then(pl.lit('NTA'))
             .otherwise(pl.lit('UNDECIDED'))
        ),
        extracted_weighted_majority_vote=(
            pl.when(pl.col('weighted_delta') > 0).then(pl.lit('YTA'))
             .when(pl.col('weighted_delta') < 0).then(pl.lit('NTA'))
             .otherwise(pl.lit('UNDECIDED'))
        ),
        polarity=(
            pl.when(pl.col('n_comments') > 0)
            .then(
                (pl.col('n_yta') - pl.col('n_nta'))
                / pl.col('n_comments')
            )
            .otherwise(None) # TODO: filter out all nulls?
        ),
    )
    .select(
        'link_id',
        'n_yta',
        'n_nta',
        'n_comments',
        'extracted_majority_vote',
        'extracted_weighted_majority_vote',
        'polarity',
    )
    .drop_nulls()
    .drop_nans()
)

if SHOULD_PRINT_HEAD:
    _ = display(per_submission_extracted_majority_votes_lf.collect(engine='streaming').head(N_HEAD))

link_id,n_yta,n_nta,n_comments,extracted_majority_vote,extracted_weighted_majority_vote,polarity
str,i32,i32,i32,str,str,f64
"""t3_vcehxn""",0,12,12,"""NTA""","""NTA""",-1.0
"""t3_vci8lj""",1,6,7,"""NTA""","""YTA""",-0.714286
"""t3_vcqpfl""",0,4,4,"""NTA""","""NTA""",-1.0
"""t3_vbf5mr""",0,1,1,"""NTA""","""NTA""",-1.0
"""t3_vcunhd""",6,6,12,"""UNDECIDED""","""NTA""",0.0
"""t3_vdcdrr""",13,3,16,"""YTA""","""YTA""",0.625
"""t3_vdg538""",0,12,12,"""NTA""","""NTA""",-1.0
"""t3_vdk5xo""",10,916,926,"""NTA""","""NTA""",-0.978402
"""t3_vdw56d""",19,1,20,"""YTA""","""YTA""",0.9
"""t3_vmh01y""",0,244,244,"""NTA""","""NTA""",-1.0


In [249]:
S = pl.col('score')
R = pl.col('upvote_ratio')
denominator = pl.when((2 * R - 1).abs() > 0).then(2 * R - 1).otherwise(None)

joined_lf = (
    submission_features_lf
    .join(
        per_submission_extracted_majority_votes_lf,
        left_on='name',
        right_on='link_id',
        how='inner',
    )
    .with_columns(
        n_downvotes=(
            (S * (1 - R) / denominator)
            .round(0)
            .cast(pl.Int64, strict=False)
        ),
    )
    .with_columns(
        n_upvotes=(
            pl.when(pl.col('n_downvotes').is_not_null())
            .then(S + pl.col('n_downvotes'))
            .otherwise(None)
            .round(0)
            .cast(pl.Int64, strict=False)
        ),
    )
)

SUBMISSION_PATTERN_DATA_PARQUET = settings.DATA_DIR / 'submission_pattern_data.parquet'
if not SUBMISSION_PATTERN_DATA_PARQUET.exists():
    joined_lf.sink_parquet(SUBMISSION_PATTERN_DATA_PARQUET, engine='streaming')
joined_lf = pl.scan_parquet(SUBMISSION_PATTERN_DATA_PARQUET)

submission_patterns_df = joined_lf.collect(engine='streaming')
_ = display(submission_patterns_df.sort('score', descending=True))

name,ground_truth_majority_vote,text_length,link_flair_text,score,upvote_ratio,n_yta,n_nta,n_comments,extracted_majority_vote,extracted_weighted_majority_vote,polarity,n_downvotes,n_upvotes
str,str,u32,str,i64,f64,i32,i32,i32,str,str,f64,i64,i64
"""t3_zvmflw""","""NTA""",2919,"""Not the A-hole""",53700,0.91,43,1457,1500,"""NTA""","""NTA""",-0.942667,5894,59594
"""t3_wyjbjs""","""NTA""",2398,"""Not the A-hole""",47773,0.97,32,2713,2745,"""NTA""","""NTA""",-0.976685,1525,49298
"""t3_xe62pq""","""NTA""",2192,"""Not the A-hole""",36730,0.96,91,1358,1449,"""NTA""","""NTA""",-0.874396,1597,38327
"""t3_v5r6pf""","""NTA""",1891,"""Not the A-hole""",35429,0.97,6,1023,1029,"""NTA""","""NTA""",-0.988338,1131,36560
"""t3_wcnh0v""","""NTA""",1831,"""Not the A-hole""",33462,0.97,9,1892,1901,"""NTA""","""NTA""",-0.990531,1068,34530
…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""t3_wvylh6""","""NTA""",2809,"""Not the A-hole""",0,0.5,0,2,2,"""NTA""","""NTA""",-1.0,,
"""t3_wvzn06""","""NTA""",2115,"""Not the A-hole""",0,0.25,0,3,3,"""NTA""","""NTA""",-1.0,0,0
"""t3_ww3xpx""","""NTA""",3024,"""Not the A-hole""",0,0.5,1,8,9,"""NTA""","""NTA""",-0.777778,,
"""t3_wwgrwh""","""NTA""",842,"""Not the A-hole""",0,0.5,1,10,11,"""NTA""","""NTA""",-0.818182,,
