In [113]:
import datetime as dt
import json
import os
import sys
import re
import random
import csv

In [114]:
def clean_comment(comment):
    return re.sub('\s+', ' ', comment)

def sample_valid(sample):
    return not sample["text"] == "[removed]"

In [115]:
downloads_dir = "downloads/"

kw_files = os.listdir(downloads_dir)

In [116]:
data = {}
for fname in kw_files:
    with open(os.path.join(downloads_dir, fname), "r") as f:
        keyword = fname.replace(".json", "")
        
        data[keyword] = json.load(f)

In [117]:
for key, comments in data.items():
    print(key, len(comments))

hydrocodone 9390
vicodin 3950
isotonitazene 506
heroin 84379
oxycodone 15299
codeine 21046
brorphine 487
purple heroin 175
238 675
carfentanil 950
percocet 5088
fentanyl 28187
morphine 22763
eutylone 408
oxymorphone 2600


In [118]:
def make_sample(comment, key, label):
    return {
        "text": clean_comment(comment['body']),
        "subreddit_label": comment['subreddit'],
        "keyword_label": key,
        "label": label
    }

In [119]:
datasets = {
    "train": {
        "keys": ["morphine", "fentanyl", "codeine", "heroin", "hydrocodone", "oxycodone", "oxymorphone"],
        "samples": []
    },
    "test": {
        "keys": ["percocet", "vicodin", "eutylone", "carfentanil", "brorphine", "isotonitazene"],
        "samples": []
    }
}

for name, dataset in datasets.items():
    for key in dataset["keys"]:
        for comment in data[key]:
            sample = make_sample(comment, key, label=1)
            if not sample_valid(sample):
                continue
            dataset["samples"].append(sample)

In [120]:
scrapes_dir = "scrapes/"

for scrape in os.listdir(scrapes_dir):
    print(scrape)

askreddit_1603615242371.json
askreddit_1603632069744.json


In [121]:
scrape_files = ["askreddit_1603632069744.json"]

In [122]:
scrape_samples = []
for scrape_file in scrape_files:
    path = os.path.join(scrapes_dir, scrape_file)
    with open(path, "r") as f:
        for comment in json.load(f):
            sample = make_sample(comment, "", label=0)
            if not sample_valid(sample):
                continue
            scrape_samples.append(sample)

In [123]:
dataset_lengths = [len(datasets[name]["samples"]) for name in datasets]
dataset_count = sum(dataset_lengths)
dataset_ratios = [length / dataset_count for length in dataset_lengths]
scrape_lengths = list(map(lambda x: int(x * len(scrape_samples)), dataset_ratios))
scrape_lengths[-1] = len(scrape_samples) - sum(scrape_lengths[:-1])

print(dataset_lengths, scrape_lengths)

[181281, 11144] [186242, 11450]


In [124]:
random.shuffle(scrape_samples)
sample_idx = 0
for length, (name, dataset) in zip(scrape_lengths, datasets.items()):
    dataset["samples"].extend(scrape_samples[sample_idx:sample_idx+length])

In [125]:
for name, dataset in datasets.items():
    random.shuffle(dataset["samples"])

In [126]:
out_folder = "datasets/"
dataset_name = "classify-drug-discussion"

for name, dataset in datasets.items():
    out_file = os.path.join(out_folder, dataset_name + "_" + name + ".tsv")
    with open(out_file, "w") as f:
        writer = csv.writer(f, delimiter='\t')

        header = dataset["samples"][0].keys()
        writer.writerow(header)

        for sample in dataset["samples"]:
            writer.writerow(sample.values())