In [None]:
import polars as pl
import src.settings as settings
from typing import Optional
from enum import Enum

In [None]:
# Input/cache locations, all resolved relative to settings.DATA_DIR.
# Each NDJSON source has a Parquet counterpart that the loading cells create on
# first run and read from afterwards.
SUBMISSIONS_NDJSON = settings.DATA_DIR / 'amitheasshole_submissions.ndjson'
SUBMISSIONS_PARQUET = settings.DATA_DIR / 'amitheasshole_submissions.parquet'
COMMENTS_NDJSON = settings.DATA_DIR / 'amitheasshole_comments.ndjson'
COMMENTS_PARQUET = settings.DATA_DIR / 'amitheasshole_comments.parquet'

class Flairs(Enum):
    """Submission verdict flairs.

    Member names are the verdict acronyms used throughout the notebook; member
    values are the strings compared against the `link_flair_text` column.
    """

    # Verdicts that the submission filter keeps as ground truth.
    YTA = 'Asshole'
    NTA = 'Not the A-hole'

    # Verdicts that are folded into YTA / NTA when SHOULD_MAP_OTHER_FLAIRS is set.
    ESH = 'Everyone Sucks'
    NAH = 'No A-holes here'

# --- Diagnostics -------------------------------------------------------------
# Print row counts before/after filtering.
SHOULD_PRINT_INFO = False
# Display the first N_HEAD rows of intermediate frames.
SHOULD_PRINT_HEAD = False
N_HEAD = 10

# --- Data selection ----------------------------------------------------------
# Keep only comments whose parent is the submission itself.
USE_TOP_LEVEL_COMMENTS_ONLY = True
# Number of most-commented submissions to keep; None keeps every submission.
N_BIGGEST_SUBMISSIONS: Optional[int] = 2

# --- Label handling ----------------------------------------------------------
# Map YTA_LIKE to YTA and NTA_LIKE to NTA
SHOULD_MAP_OTHER_FLAIRS = True
YTA_LIKE = ('YTA', 'YWBTA', 'ESH')
NTA_LIKE = ('NTA', 'YWNBTA', 'NAH')

# Convert the NDJSON dumps to Parquet (first run only) and load them as lazy frames

In [None]:
# One-time conversion: NDJSON -> Parquet, keeping only the columns used downstream.
if not SUBMISSIONS_PARQUET.exists():
    (
        pl.scan_ndjson(
            SUBMISSIONS_NDJSON,
            # Read 'edited' as a string so it can be compared with the text 'false'
            # below (presumably the raw field mixes booleans and timestamps — TODO confirm).
            schema_overrides={
                'edited': pl.Utf8,
            },
        )
        # TODO:
        .with_columns(
            # TODO:
            # Collapse 'edited' to a boolean: null or 'false' -> False, anything else -> True.
            pl.when(pl.col('edited').is_null() | (pl.col('edited') == 'false'))
            .then(False)
            .otherwise(True)
            .alias('edited')
        )
        .select([
            'author',
            'edited',
            'link_flair_text',
            'name',
            'num_comments',
            'over_18',
            'score',
            'selftext',
            'title',
            'upvote_ratio',
        ])
        .sink_parquet(SUBMISSIONS_PARQUET, engine='streaming')
    )

# Always read from the Parquet file, also right after converting. The lazy frame
# is collected several times further down, and binding it to the NDJSON scan
# would re-parse the raw dump on every one of those collects.
submissions_lf = pl.scan_parquet(SUBMISSIONS_PARQUET)

In [None]:
# One-time conversion: NDJSON -> Parquet, keeping only the columns used downstream.
if not COMMENTS_PARQUET.exists():
    (
        pl.scan_ndjson(
            COMMENTS_NDJSON,
            # Read 'edited' as a string so it can be compared with the text 'false'
            # below (presumably the raw field mixes booleans and timestamps — TODO confirm).
            schema_overrides={
                'edited': pl.Utf8,
            },
        )
        # TODO:
        .with_columns(
            # Collapse 'edited' to a boolean: null or 'false' -> False, anything else -> True.
            pl.when(pl.col('edited').is_null() | (pl.col('edited') == 'false'))
            .then(False)
            .otherwise(True)
            .alias('edited')
        )
        .select([
            'author',
            'body',
            'edited',
            'is_submitter',
            'link_id',
            'parent_id',
            'score',
        ]) # TODO: controversiality?
        .sink_parquet(COMMENTS_PARQUET, engine='streaming')
    )

# Always read from the Parquet file, also right after converting. The lazy frame
# is collected several times further down, and binding it to the NDJSON scan
# would re-parse the raw dump on every one of those collects.
comments_lf = pl.scan_parquet(COMMENTS_PARQUET)

# Submissions filtering

In [None]:
# Keep only submissions that are usable as labelled examples:
# SFW, text still present, written by a real author, flaired NTA/YTA, never
# edited, and with all numeric/text fields populated.
# FIXME: extract these values to global constant?
filtered_submissions_lf = submissions_lf.filter(
    pl.col('over_18').is_not_null()
    & ~pl.col('over_18')
    & pl.col('selftext').is_not_null()
    & ~pl.col('selftext').is_in(('[deleted]', '[removed]'))
    & pl.col('author').is_not_null()
    & ~pl.col('author').is_in(('Judgement_Bot_AITA', 'AutoModerator')) # TODO: often, deleted authors still have post up and not deleted, so don't filter on [deleted]?
    & pl.col('link_flair_text').is_in((Flairs.NTA.value, Flairs.YTA.value))
    & pl.col('edited').is_not_null()
    & ~pl.col('edited') # TODO: perhaps too aggressive, you might lose interesting posts
    & pl.all_horizontal(
        pl.col('num_comments', 'score', 'title', 'upvote_ratio').is_not_null(),
    )
)

if SHOULD_PRINT_INFO:
    # Compute the counts outside the f-strings: reusing the enclosing quote
    # character inside an f-string expression is a SyntaxError before Python 3.12.
    n_submissions_original = submissions_lf.select(pl.len()).collect(engine='streaming').item()
    n_submissions_filtered = filtered_submissions_lf.select(pl.len()).collect(engine='streaming').item()
    print(f'row count submissions original: {n_submissions_original}')
    print(f'row count submissions filtered: {n_submissions_filtered}')

In [None]:
# Rank the filtered submissions by comment count, most-commented first.
sorted_submissions_lf = filtered_submissions_lf.sort('num_comments', descending=True)

# N_BIGGEST_SUBMISSIONS = None keeps every filtered submission (unsorted);
# otherwise only the N most-commented ones are kept.
if N_BIGGEST_SUBMISSIONS is None:
    selected_submissions_lf = filtered_submissions_lf
else:
    selected_submissions_lf = sorted_submissions_lf.head(N_BIGGEST_SUBMISSIONS)

In [None]:
# Optional preview of the selected submissions (first N_HEAD rows).
if SHOULD_PRINT_HEAD:
    selected_submissions_preview = selected_submissions_lf.collect(engine='streaming').head(N_HEAD)
    _ = display(selected_submissions_preview)

# Comments filtering

In [None]:
# Keep only comments that can carry an independent vote: text still present,
# not written by a bot or by the submission's author, and never edited.
# `filter` keeps only rows where the predicate is true, so rows with a null in
# any of these columns are dropped as well.
# FIXME: extract these values to global constant?
# TODO: null checks
filtered_comments_lf = comments_lf.filter(
    ~pl.col('body').is_in(('[deleted]', '[removed]'))
    & ~pl.col('author').is_in(('Judgement_Bot_AITA', 'AutoModerator')) # TODO: is this the same case as for submissions?
    & ~pl.col('edited') # TODO: perhaps too aggressive, you might lose interesting posts
    & ~pl.col('is_submitter')
)

if USE_TOP_LEVEL_COMMENTS_ONLY:
    filtered_comments_lf = filtered_comments_lf.filter(
        # Comments always start with t1_, submissions with t3_, and subreddits with t5_
        # TODO: is this correct?
        pl.col('parent_id').str.starts_with('t3_')
    )

if SHOULD_PRINT_INFO:
    # Compute the counts outside the f-strings: reusing the enclosing quote
    # character inside an f-string expression is a SyntaxError before Python 3.12.
    n_comments_original = comments_lf.select(pl.len()).collect(engine='streaming').item()
    n_comments_filtered = filtered_comments_lf.select(pl.len()).collect(engine='streaming').item()
    print(f'row count comments original: {n_comments_original}')
    print(f'row count comments filtered: {n_comments_filtered}')

In [None]:
# Keep only the comments that belong to one of the selected submissions.
# The join key is `link_id` (the submission a comment belongs to), not
# `parent_id`: `parent_id` only equals the submission's `name` for top-level
# comments, so joining on it would silently drop every nested reply when
# USE_TOP_LEVEL_COMMENTS_ONLY is False. For top-level comments both ids are the
# same, so the result is unchanged in that mode.
# A semi join keeps the comment columns only and never duplicates rows.
# TODO: necessary if we take all submissions?
matched_comments_lf = filtered_comments_lf.join(
    selected_submissions_lf,
    left_on='link_id',
    right_on='name',
    how='semi',
)

# Calculate submission features

In [None]:
# TODO: debug
# Mark the matched-comments node for caching when the plan is executed.
# NOTE(review): presumably meant to avoid redoing the filter + semi join for the
# queries derived from this frame — confirm it has an effect across separate
# collect() calls, otherwise this line can be removed.
matched_comments_lf = matched_comments_lf.cache()

In [None]:
# Per-submission features: the ground-truth verdict derived from the flair, the
# combined length of title + body, and the raw score / upvote ratio.
submission_features_lf = (
    selected_submissions_lf
    .with_columns(
        # Translate the flair text into its verdict acronym. Every branch goes
        # through the Flairs enum so the flair strings are defined in one place.
        ground_truth_majority_vote=(
            pl.when(pl.col('link_flair_text') == Flairs.YTA.value).then(pl.lit('YTA'))
            .when(pl.col('link_flair_text') == Flairs.ESH.value).then(pl.lit('ESH'))
            .when(pl.col('link_flair_text') == Flairs.NTA.value).then(pl.lit('NTA'))
            .when(pl.col('link_flair_text') == Flairs.NAH.value).then(pl.lit('NAH'))
            .otherwise(None)
        ),
        text_length=(
            pl.col('title').str.len_chars()
            + pl.col('selftext').str.len_chars()
        ),
    )
    # Submissions with any other flair have no usable ground truth.
    .filter(pl.col('ground_truth_majority_vote').is_not_null())
    .select(
        'name',
        'ground_truth_majority_vote',
        'text_length',
        'link_flair_text', # TODO: DEBUG
        'score',
        'upvote_ratio',
    )
)

if SHOULD_MAP_OTHER_FLAIRS:
    # Collapse to a binary label: ESH counts as YTA, NAH counts as NTA.
    submission_features_lf = (
        submission_features_lf
        .with_columns(
            pl.when(pl.col('ground_truth_majority_vote') == 'ESH').then(pl.lit('YTA'))
            .when(pl.col('ground_truth_majority_vote') == 'NAH').then(pl.lit('NTA'))
            .otherwise(pl.col('ground_truth_majority_vote'))
            .alias('ground_truth_majority_vote')
        )
    )

if SHOULD_PRINT_HEAD:
    _ = display(submission_features_lf.collect(engine='streaming').head(N_HEAD))

# Extract vote from each comment

In [None]:
# Extract one vote label per comment by searching the body for verdict acronyms.
body_lowercase = pl.col('body').str.to_lowercase()
comment_votes_lf = (
    matched_comments_lf
    # Drop comments containing the upper-case word INFO. The pattern uses regex
    # word boundaries, so it must be evaluated as a regex (literal=False); with
    # literal=True it would look for the characters `\bINFO\b` verbatim and
    # never match. The check runs on the original body, so it is case-sensitive.
    .filter(~pl.col('body').str.contains(r'\bINFO\b', literal=False))
    .with_columns(
        # The first matching branch wins, so this order is the priority when a
        # comment mentions several acronyms. Matching is case-insensitive
        # because it runs on the lower-cased body.
        extracted_vote=(
            pl.when(body_lowercase.str.contains(r'\bywbta\b', literal=False)).then(pl.lit('YWBTA'))
            .when(body_lowercase.str.contains(r'\bywnbta\b', literal=False)).then(pl.lit('YWNBTA'))
            .when(body_lowercase.str.contains(r'\byta\b', literal=False)).then(pl.lit('YTA'))
            .when(body_lowercase.str.contains(r'\bnta\b', literal=False)).then(pl.lit('NTA'))
            .when(body_lowercase.str.contains(r'\besh\b', literal=False)).then(pl.lit('ESH'))
            .when(body_lowercase.str.contains(r'\bnah\b', literal=False)).then(pl.lit('NAH'))
            .otherwise(None)
        )
    )
    # Comments without any recognisable acronym carry no vote.
    .filter(pl.col('extracted_vote').is_not_null())
    .select(['link_id', 'score', 'extracted_vote', 'body']) # TODO: remove body
)

if SHOULD_MAP_OTHER_FLAIRS:
    # Collapse to a binary label using the YTA_LIKE / NTA_LIKE groups.
    comment_votes_lf = comment_votes_lf.with_columns(
        pl.when(pl.col('extracted_vote').is_in(YTA_LIKE)).then(pl.lit('YTA'))
        .when(pl.col('extracted_vote').is_in(NTA_LIKE)).then(pl.lit('NTA'))
        # Every label produced above is in one of the two groups; keep the value
        # as-is rather than nulling it should a new label ever be added.
        .otherwise(pl.col('extracted_vote'))
        .alias('extracted_vote')
    )

if SHOULD_PRINT_HEAD:
    _ = display(comment_votes_lf.sort('score').collect(engine='streaming').head(N_HEAD))

# Calculate (weighted) majority vote based on extracted comment votes

In [None]:
# Aggregate the extracted comment votes per submission into a plain majority
# vote, a score-weighted majority vote and a polarity in [-1, 1].
per_submission_extracted_majority_votes_lf = (
    comment_votes_lf
    .group_by('link_id')
    .agg(
        # Vote counts (sums of booleans, hence unsigned integers).
        n_yta=(pl.col('extracted_vote') == 'YTA').sum(),
        n_nta=(pl.col('extracted_vote') == 'NTA').sum(),
        # Vote counts weighted by the comment score.
        weighted_yta=pl.when(pl.col('extracted_vote') == 'YTA').then(pl.col('score')).otherwise(0).sum(),
        weighted_nta=pl.when(pl.col('extracted_vote') == 'NTA').then(pl.col('score')).otherwise(0).sum(),
    )
    .with_columns(
        # Cast to a signed type before subtracting: the counts are unsigned, so
        # `n_yta - n_nta` would wrap around to a huge positive number whenever
        # NTA wins, and every such submission would be labelled YTA.
        delta=pl.col('n_yta').cast(pl.Int64) - pl.col('n_nta').cast(pl.Int64),
        weighted_delta=pl.col('weighted_yta') - pl.col('weighted_nta'),
        n_comments=pl.col('n_yta') + pl.col('n_nta')
    )
    .with_columns(
        extracted_majority_vote=(
            pl.when(pl.col('delta') > 0).then(pl.lit('YTA'))
             .when(pl.col('delta') < 0).then(pl.lit('NTA'))
             .otherwise(pl.lit('UNDECIDED'))
        ),
        extracted_weighted_majority_vote=(
            pl.when(pl.col('weighted_delta') > 0).then(pl.lit('YTA'))
             .when(pl.col('weighted_delta') < 0).then(pl.lit('NTA'))
             .otherwise(pl.lit('UNDECIDED'))
        ),
        # (n_yta - n_nta) / n_comments: +1 = unanimous YTA, -1 = unanimous NTA.
        # Uses the signed `delta` so the numerator cannot wrap around.
        polarity=(
            pl.when(pl.col('n_comments') > 0)
            .then(pl.col('delta') / pl.col('n_comments'))
            .otherwise(None) # TODO: filter out all nulls?
        ),
    )
    .select(
        'link_id',
        'n_yta',
        'n_nta',
        'n_comments',
        'extracted_majority_vote',
        'extracted_weighted_majority_vote',
        'polarity',
    )
)

if SHOULD_PRINT_HEAD:
    _ = display(per_submission_extracted_majority_votes_lf.collect(engine='streaming').head(N_HEAD))

In [None]:
# Estimate raw up-/downvote counts from the score and upvote ratio.
# Assuming score S = up - down and upvote_ratio R = up / (up + down):
#     down = S * (1 - R) / (2R - 1)        up = S + down
S = pl.col('score')
R = pl.col('upvote_ratio')
denominator = 2 * R - 1
n_downvotes = S * (1 - R) / denominator

# The estimate is undefined when R == 0.5 (division by zero gives inf/NaN).
# Null those rows out *before* casting: all when/then branches are evaluated on
# all rows, so a strict integer cast inside the `then` branch would still see
# the inf/NaN values and raise, despite the guard.
estimated_downvotes = (
    pl.when(denominator != 0)
    .then(n_downvotes)
    .otherwise(None)
)

joined_lf = (
    submission_features_lf
    # Attach the comment-based votes; submissions without any extracted vote are dropped.
    .join(
        per_submission_extracted_majority_votes_lf,
        left_on='name',
        right_on='link_id',
        how='inner',
    )
    .with_columns(
        # strict=False: estimates that do not fit an unsigned integer (e.g. a
        # negative value caused by an inconsistent score/ratio pair) become
        # null instead of aborting the whole query.
        n_downvotes=estimated_downvotes.round(0).cast(pl.UInt32, strict=False),
        n_upvotes=(S + estimated_downvotes).round(0).cast(pl.UInt32, strict=False),
    )
)

# Cache the result on disk and continue from the cached file.
# NOTE: an existing file is reused as-is, so delete it after changing any option
# or filter above, otherwise stale results are loaded.
SUBMISSION_PATTERN_DATA_PARQUET = settings.DATA_DIR / 'submission_pattern_data.parquet'
if not SUBMISSION_PATTERN_DATA_PARQUET.exists():
    joined_lf.sink_parquet(SUBMISSION_PATTERN_DATA_PARQUET, engine='streaming')
joined_lf = pl.scan_parquet(SUBMISSION_PATTERN_DATA_PARQUET)

_ = display(joined_lf.collect(engine='streaming'))

# if SHOULD_PRINT_HEAD:
#     _ = display(joined_lf.collect(engine='streaming').head(N_HEAD))