# Sampling the Reddit Dataset
The original reddit dataset is hosted on the HuggingFace Dataset Repository, an open-source directory where any contributor can upload and share datasets for model training. This dataset is exceptionally large (over 20 GiB), so we created this notebook to transparently document the steps we took to sample it down to something small enough to include with the Applied Machine Learning Prototype.

**This notebook is NOT intended to be run from start to finish.** 

In fact, in its current state that is not even possible. Several of the initial steps were completed in a different Colab notebook, and other steps require intermediate files that are stored on a local machine. However, we hope this notebook gives a starting point to anyone who wants to explore this dataset in more detail. We also hope it explains the motivation for the sample we ultimately include with the repo.

The original reddit dataset has over 3.8 million posts from thousands of subreddits. After several pre-processing steps, it is sampled down to about 637K posts:

1. I first identify the most popular subreddits by number of posts. The 30 most popular subreddits are listed in the code cell below.
2. I select two slightly different subsamples: 
    * **Top10**: select only the posts from the top 10 subreddits by number of posts
    * **Curated10**: select 10 subreddits from the top 30 whose names are semantically meaningful
3. Finally, I subsample within some subreddits. The top 3 subreddits (AskReddit, relationships, leagueoflegends) each have an order of magnitude more posts than any other subreddit. To keep the subset manageable, I include only the first 60K posts from each of these subreddits.

I combine **Top10** and **Curated10** into a single dataset. The two lists overlap, so the result contains just under 640K posts from 16 popular subreddits.
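The selection logic described above (rank by post count, take the union of the two lists, cap the largest subreddits) can be sketched on toy data. This is a minimal, hypothetical helper: `build_subset_mask` is not part of the original code. It assumes posts arrive as a flat list of subreddit names in their original order. It keeps at most `cap` posts for each capped subreddit, which means exactly the first `cap` posts when that subreddit has enough of them.

```python
from collections import Counter

def build_subset_mask(subreddits, top_k=10, curated=(), caps=None):
    """Select posts from the top-k and curated subreddits, keeping at most
    caps[name] posts (in original order) for each capped subreddit."""
    caps = caps or {}
    # Step 1: rank subreddits by number of posts, most popular first
    ranked = [name for name, _ in Counter(subreddits).most_common()]
    # Step 2: take the union of the top-k list and the curated list
    relevant = set(ranked[:top_k]) | set(curated)
    # Step 3: keep each relevant post until its subreddit reaches its cap
    kept = Counter()
    mask = []
    for name in subreddits:
        keep = name in relevant and kept[name] < caps.get(name, float("inf"))
        if keep:
            kept[name] += 1
        mask.append(keep)
    return mask

# Toy data: the subreddit name of each post, in original order
posts = ["a"] * 5 + ["b"] * 3 + ["c"] * 2 + ["d"]
mask = build_subset_mask(posts, top_k=2, curated=["d"], caps={"a": 2})
print(sum(mask))  # 6
```

In this example, "c" is dropped because it is neither in the top 2 nor curated. "a" is truncated to its first 2 posts.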

The cell below contains the original code from the Colab notebook that was used to perform this initial subsampling.

```python
import numpy as np
import pandas as pd
from datasets import load_dataset

dataset = load_dataset("reddit")

# Output (progress bars omitted):
# Using custom data configuration default
# Downloading and preparing dataset reddit/default (download: 2.93 GiB, generated: 17.64 GiB, post-processed: Unknown size, total: 20.57 GiB) to /root/.cache/huggingface/datasets/reddit/default/1.0.0/98ba5abea674d3178f7588aa6518a5510dc0c6fa8176d9653a3546d5afcb3969...
# 3848330/0 [06:29<00:00, 10409.42 examples/s]
# Dataset reddit downloaded and prepared to /root/.cache/huggingface/datasets/reddit/default/1.0.0/98ba5abea674d3178f7588aa6518a5510dc0c6fa8176d9653a3546d5afcb3969. Subsequent calls will reuse this data.

dataset = dataset['train']

# identify unique subreddits and count the posts in each
unique_subreddits, counts = np.unique(dataset['subreddit'], return_counts=True)

# sort subreddits by post count, most popular first
top_subreddits = [sr for c, sr in sorted(zip(counts, unique_subreddits), reverse=True)]

top_subreddits[:30]
#['AskReddit', 'relationships', 'leagueoflegends', 'tifu', 'relationship_advice', 'trees', 'gaming', 'atheism', 'AdviceAnimals', 'funny', 'politics', 'pics', 'sex', 'WTF', 'explainlikeimfive', 'todayilearned', 'Fitness', 'IAmA', 'worldnews', 'DotA2', 'TwoXChromosomes', 'videos', 'DestinyTheGame', 'reddit.com', 'offmychest', 'buildapc', 'AskMen', 'personalfinance', 'summonerschool', 'technology']

# top10: select the top 10 subreddits according to number of posts
top10_subreddits = top_subreddits[:10]
# curated10: select 10 popular subreddits (in the top 30) that have semantically meaningful subreddit names
curated_subreddits = ['relationships', 'trees', 'gaming', 'funny', 'politics', 'sex', 'Fitness', 'worldnews', 'personalfinance', 'technology']

relevant_subreddits = list(set(top10_subreddits).union(set(curated_subreddits)))

# next we create a mask that will select out posts from the relevant_subreddits
subreddit_mask = np.zeros(len(dataset))

# The top 3 subreddits each contain between 100K and 600K posts -- an order of magnitude more
# than any other popular subreddit.
# Truncate these 3 subreddits to a max of ~60K posts each
# (note: the `<=` comparisons below admit max_examples + 1 = 60,001 posts per subreddit)
askreddit = 0
relationships = 0
lol = 0
max_examples = 60000

for i, sub in enumerate(dataset['subreddit']):
  if sub in relevant_subreddits:
    if sub == 'AskReddit' and askreddit <= max_examples:
      subreddit_mask[i] = 1
      askreddit += 1
    elif sub == 'relationships' and relationships <= max_examples:
      subreddit_mask[i] = 1
      relationships += 1
    elif sub == 'leagueoflegends' and lol <= max_examples:
      subreddit_mask[i] = 1
      lol += 1
    elif sub not in ['AskReddit', 'relationships', 'leagueoflegends']:    
      subreddit_mask[i] = 1
    else:
      continue

subreddit_mask = subreddit_mask == 1

np.sum(subreddit_mask)
#636695 (reduced from over 3.8M!)

# create the subset (Dataset.select expects integer indices, so convert the boolean mask)
subset = dataset.select(np.flatnonzero(subreddit_mask))

# this file lives on a local machine
subset = subset.to_pandas()
subset.to_csv("reddit_subset.csv")
```

## Create Categorical columns for classification

The fewshot library expects datasets to contain `category` and `label` columns. The `category` column should contain string descriptions. The `label` column should contain integers, with each integer uniquely identifying one category.
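A small illustration of how `pd.Categorical` plus `.cat.codes` produces the integer labels, shown on a hypothetical toy frame. `pd.Categorical` sorts the unique strings, so labels are assigned in alphabetical order of the category names. As a result, the labels change whenever the set of categories changes. This is why the 10-class test set later in this notebook recomputes them.

```python
import pandas as pd

# Hypothetical toy frame standing in for the reddit subset
toy = pd.DataFrame({"category": ["world news", "fitness", "gaming", "fitness"]})

# Categories are sorted, so codes follow alphabetical order of the names
toy["category"] = pd.Categorical(toy["category"])
toy["label"] = toy["category"].cat.codes

# Recover the label -> category lookup table
label_to_category = dict(enumerate(toy["category"].cat.categories))
print(toy["label"].tolist())  # [2, 0, 1, 0]
print(label_to_category)      # {0: 'fitness', 1: 'gaming', 2: 'world news'}
```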

In [None]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

filename = "../my_data/reddit/reddit_subset.csv"

In [None]:
df = pd.read_csv(filename)

In [None]:
df.head()

In [None]:
# map the subreddit names to a standardized format to create category names
df['category'] = df['subreddit'].map({
    'atheism': 'atheism',
    'funny': 'funny',
    'sex': 'sex',
    'Fitness': 'fitness',
    'AdviceAnimals': 'advice animals',
    'trees': 'trees',
    'personalfinance': 'personal finance', 
    'relationships': 'relationships',
    'relationship_advice': 'relationship advice',
    'tifu': 'tifu',
    'politics': 'politics',
    'gaming': 'gaming', 
    'worldnews': 'world news',
    'technology': 'technology',
    'leagueoflegends': 'league of legends',
    'AskReddit': 'ask reddit'
    })

df['category'] = pd.Categorical(df.category)
df['label'] = df.category.cat.codes

In [None]:
df.dropna(inplace=True)
df.head()

## Create train/test splits

In [None]:
from sklearn.model_selection import train_test_split

# here we perform a stratified split on `subreddit` though
# we could just as easily split on `category`
X_train, X_test, y_train, y_test = train_test_split(df, df['subreddit'], 
                                                    test_size=.1, 
                                                    random_state=42, 
                                                    stratify=df['subreddit'])

In [None]:
# The number of training examples per subreddit varies: the dataset is imbalanced.
np.unique(X_train.subreddit, return_counts=True)

In [None]:
np.unique(X_test.subreddit, return_counts=True)

In [None]:
len(y_train), len(y_test)

In [None]:
X_train.to_csv("../data/reddit/reddit_subset_train.csv")
X_test.to_csv("../data/reddit/reddit_subset_test.csv")

## Which text should we use? 
You might have noticed earlier that there are three different text fields associated with a reddit post: the **content**, **body**, and **summary** columns. Through exploration I've noticed that the **content** and **body** columns are often very similar to each other. 

Below I plot the distribution of character counts in each of these three columns for all reddit posts in the test set. I ultimately classify on the **summary** column for two reasons:
1. fewer characters per post means faster processing and inference with SentenceBERT
2. the **summary** column is usually a TL;DR, which may better capture the main idea of the post and be more semantically meaningful.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# overlay histograms of character counts for the three text fields
for column in ['body', 'content', 'summary']:
    plt.hist(X_test[column].str.len().values, alpha=.6, range=(0, 6000), label=column)
plt.xlabel('number of characters')
plt.ylabel('number of posts')
plt.legend()

summary_lengths = X_test.summary.str.len().values

In [None]:
print(np.mean(summary_lengths))
print(np.median(summary_lengths))
print(np.max(summary_lengths))

## Most frequent words in reddit dataset
In our on-the-fly classification regime, we used the most frequent words, as measured over the word2vec corpus, to create a mapping between SentenceBERT's embedding space and word2vec's embedding space. Another option is to use the most frequent words as measured over the Reddit dataset instead! The idea is that words most common in our own dataset could improve the mapping between embedding spaces.
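To see why the choice of word list matters, it helps to look at how such a mapping could be fit. The sketch below *assumes* the mapping is a plain linear least-squares projection. It is fit on the embeddings of the frequent words in both spaces, so each word in the list contributes one row of training data. The fewshot library's actual procedure may differ, for example by adding regularization. `fit_linear_map` is a hypothetical name, and the random matrices stand in for real SentenceBERT and word2vec vectors.

```python
import numpy as np

def fit_linear_map(src, tgt):
    """Least-squares W minimizing ||src @ W - tgt||^2 (src: n x d_src, tgt: n x d_tgt)."""
    W, *_ = np.linalg.lstsq(src, tgt, rcond=None)
    return W

rng = np.random.default_rng(0)
n_words, d_sbert, d_w2v = 500, 32, 16  # toy sizes, not the real model dimensions

# Hypothetical stand-ins: row i is the embedding of the i-th most frequent word
sbert_vecs = rng.normal(size=(n_words, d_sbert))
true_W = rng.normal(size=(d_sbert, d_w2v))
w2v_vecs = sbert_vecs @ true_W + 0.01 * rng.normal(size=(n_words, d_w2v))

W = fit_linear_map(sbert_vecs, w2v_vecs)

# Project a new (toy) sentence embedding into the word2vec space
new_sentence = rng.normal(size=(1, d_sbert))
projected = new_sentence @ W
print(projected.shape)  # (1, 16)
```

Swapping the word2vec-frequent list for a Reddit-frequent list changes only which rows go into `sbert_vecs` and `w2v_vecs`.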

In the following blocks, we perform this analysis and save that output. 

In [None]:
import string
from collections import Counter
from nltk import FreqDist, word_tokenize

import nltk
nltk.download('stopwords')
nltk.download('punkt')

from nltk.corpus import stopwords

In [None]:
df.dropna(inplace=True)

In [None]:
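# join the summaries with a space so words from adjacent posts are not fused together;
# skip any non-string values
corpus = ' '.join(summary for summary in df.summary if isinstance(summary, str))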

In [None]:
tokens = word_tokenize(corpus)

In [None]:
len(tokens)

In [None]:
# tokens to filter out: punctuation only (stop1), or stopwords plus punctuation (stop2)
# (matching is case-sensitive, and NLTK's English stopwords are lowercase)
stop1 = set(string.punctuation) | {"``", "''", "..."}
stop2 = set(stopwords.words("english")) | stop1
words1 = [word for word in tokens if word not in stop1]
words2 = [word for word in tokens if word not in stop2]

In [None]:
print(len(words1))
print(len(words2))

In [None]:
# keep the 100,000 most frequent tokens, ordered from most to least frequent
most_common_words1 = [word for word, _ in Counter(words1).most_common(100000)]
most_common_words2 = [word for word, _ in Counter(words2).most_common(100000)]

In [None]:
most_common_words = {"no_punc": most_common_words1, "no_punc_no_stop": most_common_words2}

In [None]:
import pickle

with open("../data/reddit/most_common_words.pkl", "wb") as f:
    pickle.dump(most_common_words, f)

## Subsample the train set because it's too big
Training set is *too big*?? Bet you've never heard that one before. However, in the few-shot analysis we're trying to explore regimes in which we don't have a lot of labeled examples (if any). So, in this case -- our training set is WAY too big!

In the following cells we load the train set we created earlier, keep only the **curated10** subreddits, and sample 2,000 examples per category. We then split this sample evenly into two sets of 10,000 examples (1,000 for each of the 10 categories). We save these as the official train set and a validation set.
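The cells below use `groupby(...).apply(... .sample(...))` followed by a stratified `train_test_split` with `test_size=.5`. Ten categories × 2,000 rows gives 20,000 rows, which split into 10,000 + 10,000. As a hedged, pandas-only alternative (not the code that produced the saved files), the same balanced two-way split can be written with `GroupBy.sample`. One difference is that `GroupBy.sample` raises an error if any group has fewer than `per_group` rows. The original `min(len(x), 2000)` instead tolerates small groups. `balanced_two_way_split` and the toy data are hypothetical.

```python
import pandas as pd

def balanced_two_way_split(df, by, per_group, seed=42):
    """Draw `per_group` rows for each value of `by`, then split every group
    evenly into two disjoint sets (each group needs >= per_group rows)."""
    # GroupBy.sample draws rows in random order within each group
    sampled = df.groupby(by).sample(n=per_group, random_state=seed)
    # Position of each row within its group after sampling: 0, 1, ..., per_group - 1
    position = sampled.groupby(by).cumcount()
    first = sampled[position < per_group // 2]
    second = sampled[position >= per_group // 2]
    return first, second

# Toy data: 3 hypothetical subreddits with 50 posts each
toy = pd.DataFrame({
    "subreddit": [name for name in ["fitness", "gaming", "politics"] for _ in range(50)],
    "summary": [f"post {i}" for i in range(150)],
})
train, valid = balanced_two_way_split(toy, "subreddit", per_group=20)

# 10 rows per subreddit in each set
print(train["subreddit"].value_counts().sort_index().to_dict())
```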

In [None]:
df_train = pd.read_csv("../my_data/reddit/reddit_subset_train.csv")

In [None]:
len(df_train)

In [None]:
curated_subreddits = ['relationships', 'trees', 'gaming', 'funny', 'politics',
                      'sex', 'Fitness', 'worldnews', 'personalfinance', 'technology']

In [None]:
sample = (
    df_train[df_train.subreddit.isin(curated_subreddits)]
    .groupby('category', group_keys=False)
    .apply(lambda x: x.sample(min(len(x), 2000), random_state=42))
)

In [None]:
X_train, X_valid, y_train, y_valid = train_test_split(sample, sample['subreddit'], 
                                                      test_size=.5, 
                                                      random_state=42, 
                                                      stratify=sample['subreddit'])

In [None]:
X_train.groupby('subreddit')['subreddit'].count()

In [None]:
len(X_train), len(X_valid)

In [None]:
X_train.to_csv("../data/reddit/reddit_subset_train1000.csv")
X_valid.to_csv("../data/reddit/reddit_subset_valid1000.csv")

## Create a subset of the reddit test set that has exactly 1300 examples per category

Earlier we created a test set that was slightly imbalanced across the 16 classes. We sampled it down to make error analysis easier. We also focus solely on the **curated10** subreddits, because classification performed better with these categories than with the **top10** categories (as expected).

The final test set produced in this cell is included as an artifact in the fewshot repository. 

In [None]:
df_test = pd.read_csv("../my_data/reddit/reddit_subset_test.csv")

# In our experiments we'll work with just 10 of the 16 most popular subreddits
curated_subreddits = ['relationships', 'trees', 'gaming', 'funny', 'politics',
                      'sex', 'Fitness', 'worldnews', 'personalfinance', 'technology']

df_reddit_test = (
  df_test[df_test.subreddit.isin(curated_subreddits)]
  .groupby('category', group_keys=False)
  .apply(lambda x: x.sample(min(len(x), 1300), random_state=42))
  .assign(
      category = lambda df: pd.Categorical(df.category),
      label = lambda df: df.category.cat.codes
      )
  )

# save the .csv version of this test set
df_reddit_test.to_csv("../data/reddit/reddit_subset_test1300.csv")

Now we can create a Dataset object that will also compute and store SentenceBERT embeddings for each example in the test set. 

In [None]:
# NOTE: only run this cell if you have a GPU or some time on your hands

from fewshot.data.loaders import _create_dataset_from_df
from fewshot.utils import pickle_save

filename = "../data/reddit/reddit_dataset_1300.pkl"

# Cast the pandas df as a Dataset object
reddit_test_data = _create_dataset_from_df(df_reddit_test, 'summary')

# Compute SentenceBERT embeddings for each example
reddit_test_data.calc_sbert_embeddings()
pickle_save(reddit_test_data, filename)

In [None]:
from fewshot.utils import pickle_load, torch_load

# load a previously saved dataset (AG News here), trying the torch loader first
filename = "../data/agnews/agnews_train_dataset.pkl"
# filename = "../data/agnews/agnews_dataset.pt"

try:
    data = torch_load(filename)
    print("loaded with torch")
except Exception:
    data = pickle_load(filename)
    print("loaded with pickle")

***If this documentation includes code, including but not limited to, code examples, Cloudera makes this available to you under the terms of the Apache License, Version 2.0, including any required notices.  A copy of the Apache License Version 2.0 can be found [here](https://opensource.org/licenses/Apache-2.0).***