In [1]:
import datasets

In [28]:
# Load the text-formatted dataset from disk, holding it fully in memory.
ds = datasets.load_from_disk(
    "/mnt/disks/persist/user_comments_text_filtered/",
    keep_in_memory=True,
)

In [16]:
from transformers import AutoTokenizer
# Window length, in tokens, of every chunk produced by `preprocess_function`.
max_seq_length = 4096
# Tokenizer matching the "google/bigbird-roberta-base" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base")

# Collapse the three-way `depressed_label` into a binary target:
#   0 (depressed)                -> 1
#   1 (control 1, not depressed) -> 0
#   2 (control 2, random)        -> 0
TO_BINARY = {raw_label: int(raw_label == 0) for raw_label in (0, 1, 2)}

def preprocess_function(examples):
    """Tokenize a batch of texts into fixed-length chunks with binary labels.

    Each text in ``examples["text"]`` is split into one or more chunks of at
    most ``max_seq_length`` tokens (``return_overflowing_tokens=True``), so
    the returned batch can contain more rows than the input batch. Every
    chunk is padded to exactly ``max_seq_length`` and inherits the label of
    the sample it was cut from, mapped through ``TO_BINARY``.

    Args:
        examples: A batched mapping with a ``"text"`` list and a parallel
            ``"depressed_label"`` list.

    Returns:
        The tokenizer output (one row per chunk) with an added ``"labels"``
        list and with ``"overflow_to_sample_mapping"`` removed.

    Raises:
        KeyError: If a ``depressed_label`` value is not a key of ``TO_BINARY``.
    """
    # `truncation=True` is spelled out because chunking into overflow windows
    # is part of truncation; passing only `max_length` relies on an implicit
    # fallback that emits a warning. `padding="max_length"` lets the tokenizer
    # pad every field it returns consistently instead of padding
    # `input_ids` / `attention_mask` by hand.
    result = tokenizer(
        examples["text"],
        max_length=max_seq_length,
        truncation=True,
        padding="max_length",
        return_overflowing_tokens=True,
    )

    # `overflow_to_sample_mapping[k]` is the index, within this batch, of the
    # sample that chunk k came from; use it to propagate that sample's label.
    source_labels = examples["depressed_label"]
    result["labels"] = [
        TO_BINARY[source_labels[sample_index]]
        for sample_index in result["overflow_to_sample_mapping"]
    ]

    # The mapping is only bookkeeping, so it is dropped from the returned batch.
    del result["overflow_to_sample_mapping"]
    return result

In [26]:
# Sanity check: run the preprocessing on the first three training samples and
# collect (label, first 200 decoded characters) for every resulting chunk.
# NOTE(review): in the recorded output the newlines of the text format appear
# to decode as `<unk>` -- confirm this is acceptable for the tokenizer in use.
l = []
for i in (
    ds["train"]
    .select(range(3))
    .map(
        preprocess_function,
        batched=True,
        remove_columns=ds["train"].column_names,
    )
):
    l.append((i["labels"], tokenizer.decode(i["input_ids"])[:200]))

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

[(1,
  '[CLS] Post from /r/bookbinding:<unk>you might have better luck posting this in r/bujo, r/bulletjournal, r/notebooks, or r/stationery<unk>Post from /r/Anticonsumption:<unk>my family got the ones from t'),
 (1,
  '[CLS] have fun and actually express yourself).<unk>Post from /r/AskWomen:<unk>[https://www.youtube.com/watch?v=PE66HEZBZYE](https://www.youtube.com/watch?v=PE66HEZBZYE) a handy dandy adam ruins everyt'),
 (1,
  "[CLS] whale's trying not to move its tail too much and hurt the dolphins on porpoise <unk>Post from /r/learnart:<unk>you could scan and print a copy or two, and then experiment on the copies with more"),
 (1,
  '[CLS] color](https://gd.image-gmkt.com/CUTE-PINK-ORANGE-BAG-FLAT-COMB-TEETH-SQUARE-PLATE-COMB-HAIR-HAIRBRUSH/li/610/936/819936610.g_400-w_g.jpg). less than 10$) so that it gently convinces the tangles'),
 (1,
  '[CLS] how active unions are there and what the society treats as normal amounts of work. for example: a lot of nations have a lot more maternity

In [2]:
# Load the user_comments dataset into memory and keep only the rows that have
# at least one post.
ds = datasets.load_from_disk(
    "/mnt/disks/persist/user_comments",
    keep_in_memory=True,
).filter(
    lambda row: len(row["posts"]) > 0,
    num_proc=64,
)
#ds.save_to_disk("/mnt/disks/persist/user_comments_filtered")

Filter (num_proc=64):   0%|          | 0/256852 [00:00<?, ? examples/s]

In [4]:
# Display the dataset summary (features and row count) after filtering.
ds

Dataset({
    features: ['user', 'posts', 'depressed_label'],
    num_rows: 251210
})

In [6]:
def format_post_as_text(posts):
    """Render a sequence of posts as one text blob.

    Each post is a mapping with ``"subreddit"`` and ``"body"`` keys. It is
    rendered as a ``Post from /r/<subreddit>:`` header line followed by the
    body, and consecutive posts are separated by a blank line.

    Returns:
        The concatenated text; an empty string when ``posts`` is empty.
    """
    rendered = []
    for entry in posts:
        subreddit = entry['subreddit']
        body = entry['body']
        rendered.append(f"Post from /r/{subreddit}:\n{body}")
    return "\n\n".join(rendered)

def add_text_format(batch):
    """Add a ``"text"`` column to a batch, one formatted string per sample.

    Every entry of ``batch["posts"]`` is passed through
    ``format_post_as_text``. The batch is modified in place and also returned.
    """
    batch["text"] = list(map(format_post_as_text, batch["posts"]))
    return batch

# Build the "text" column in batches across 64 worker processes, dropping the
# "user" and "posts" columns from the result.
new_ds = ds.map(
    add_text_format,
    batched=True,
    num_proc=64,
    remove_columns=["user", "posts"],
)

In [None]:
# Hold out 15% of the rows, then halve that hold-out into validation and test
# (roughly 85% / 7.5% / 7.5%). Both splits use a fixed seed for repeatability.
split = new_ds.train_test_split(test_size=0.15, seed=42)
test_validation = split["test"].train_test_split(test_size=0.5, seed=42)
ds_dict = datasets.DatasetDict(
    train=split["train"],
    validation=test_validation["train"],
    test=test_validation["test"],
)