In [1]:
from datasets import load_dataset
import json
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
SBERT_MODEL = "all-MiniLM-L6-v2"
from collections import Counter
import nltk
import re
from nltk import sent_tokenize

In [2]:
# Load the train split of both source corpora: NSFW Reddit questions, then ProsocialDialog.
nsfw_dataset, pro_social_dataset = (
    load_dataset(repo_id, split="train")
    for repo_id in ("jjmachan/NSFW-questions", "allenai/prosocial-dialog")
)

Downloading readme:   0%|          | 0.00/663 [00:00<?, ?B/s]

Using custom data configuration jjmachan--NSFW-questions-d386c5d23f1fe7f9


Downloading and preparing dataset None/None to /home/shahul/.cache/huggingface/datasets/jjmachan___parquet/jjmachan--NSFW-questions-d386c5d23f1fe7f9/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/886k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/12858 [00:00<?, ? examples/s]

Dataset parquet downloaded and prepared to /home/shahul/.cache/huggingface/datasets/jjmachan___parquet/jjmachan--NSFW-questions-d386c5d23f1fe7f9/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


Using custom data configuration allenai--prosocial-dialog-ebbad39ca08b6d44
Found cached dataset json (/home/shahul/.cache/huggingface/datasets/allenai___json/allenai--prosocial-dialog-ebbad39ca08b6d44/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


In [3]:
# Display the dataset summary (feature names and row count) as a schema check.
nsfw_dataset

Dataset({
    features: ['title', 'subreddit', 'post_id', 'score', 'link_flair_text', 'is_self', 'over_18', 'upvote_ratio', 'is_question', 'C1', 'C2', 'C3', 'C4', 'C5'],
    num_rows: 12858
})

In [4]:
def match_rot_safetylabels(dataset):
    """Map each rule-of-thumb (RoT) string to the safety label of the first row citing it.

    Args:
        dataset: Iterable of rows, each with a ``"rots"`` list of strings and a
            ``"safety_label"`` value (e.g. the ProsocialDialog train split).

    Returns:
        dict mapping RoT text -> safety label. When the same RoT appears in
        several rows, the label from the first such row is kept.
    """
    results = {}
    # Single pass over the rows: materialising a row of a Hugging Face
    # Dataset is not free, so avoid iterating the dataset twice.
    for item in dataset:
        label = item["safety_label"]
        for rot in item["rots"]:
            # setdefault keeps the first label seen for a given RoT.
            results.setdefault(rot, label)
    return results

In [5]:
# Lookup table: RoT text -> safety label, built from the ProsocialDialog train split.
rot_sfty = match_rot_safetylabels(pro_social_dataset)

In [6]:
# Dict keys are already unique, so set() is unnecessary. Keeping dict insertion
# order (rather than set iteration order, which changes with string-hash
# randomisation between interpreter runs) makes RoT indices reproducible.
all_rots = list(rot_sfty)

In [7]:
def load_vectorizer(model=SBERT_MODEL):
    """Build and return a SentenceTransformer encoder for the given model name."""
    encoder = SentenceTransformer(model)
    return encoder


def vectorize_text(model, texts):
    """Embed ``texts`` with ``model.encode``, showing a progress bar.

    Returns whatever ``model.encode`` returns; presumably a 2-D array with one
    embedding row per input text (confirm for the encoder in use).
    """
    embeddings = model.encode(texts, show_progress_bar=True)
    return embeddings


In [8]:
# Sentence encoder shared by the RoT and post-title embedding steps (default: SBERT_MODEL).
model = load_vectorizer()


In [9]:
# Embed every unique RoT once; row i of rot_vector corresponds to all_rots[i].
rot_vector = vectorize_text(model,all_rots)


Batches:   0%|          | 0/3630 [00:00<?, ?it/s]

In [10]:
import scipy.spatial as sp
from collections import defaultdict
from tqdm import tqdm

# Minimum cosine similarity for a post title and a RoT to count as a match.
THRESHOLD = 0.65


def match_query_rot(q, m, threshold=None):
    """Return all (query, candidate) index pairs whose cosine similarity reaches a threshold.

    Args:
        q: 2-D array of query embeddings (one row per query).
        m: 2-D array of candidate embeddings with the same number of columns as ``q``.
        threshold: Minimum cosine similarity for a pair to be returned.
            Defaults to the module-level ``THRESHOLD``, read at call time.

    Returns:
        Integer array of shape ``(n_pairs, 2)`` from ``np.argwhere``; each row
        is ``(query_index, candidate_index)``, in row-major order.
    """
    if threshold is None:
        threshold = THRESHOLD
    cosine_sim = 1 - sp.distance.cdist(q, m, "cosine")
    return np.argwhere(cosine_sim >= threshold)
        


In [11]:
# Number of post embeddings compared against all RoT embeddings at once.
BATCH_SIZE = 100


def match_rot_post(dataset):
    """Attach the best-matching rule-of-thumb (RoT) to each post title.

    Every title is embedded with the global ``model``. The embeddings are then
    compared, ``BATCH_SIZE`` rows at a time, against the precomputed
    ``rot_vector`` via ``match_query_rot``.

    Args:
        dataset: Rows with ``"title"`` and ``"post_id"`` fields.

    Returns:
        dict ``post_id -> {"rots": [rot], "safety_label": label}``. It contains
        only posts with at least one RoT at or above the similarity threshold.
        ``rot`` is the most similar such RoT, and ``label`` is looked up in
        ``rot_sfty``.
    """
    # Collect titles and ids in one pass instead of fetching a row per match.
    titles, post_ids = [], []
    for item in dataset:
        titles.append(item["title"])
        post_ids.append(item["post_id"])

    post_vector = vectorize_text(model, titles)

    # post_id -> (similarity, rot). Keep only the most similar RoT per post.
    # Letting each match overwrite the previous one would keep whichever
    # above-threshold RoT happens to come last in all_rots, which is arbitrary.
    best = {}
    for start in tqdm(range(0, len(post_vector), BATCH_SIZE)):
        batch = post_vector[start:start + BATCH_SIZE]
        for post_idx, rot_idx in match_query_rot(batch, rot_vector):
            similarity = 1 - sp.distance.cosine(batch[post_idx], rot_vector[rot_idx])
            post_id = post_ids[start + int(post_idx)]
            if post_id not in best or similarity > best[post_id][0]:
                best[post_id] = (similarity, all_rots[rot_idx])

    return {
        post_id: {"rots": [rot], "safety_label": rot_sfty.get(rot)}
        for post_id, (_, rot) in best.items()
    }
           
    
    

In [15]:
# post_id -> {"rots": [...], "safety_label": ...} for posts that matched a RoT.
# Note: this matching uses the raw titles; filter_stopwords is applied later.
result_dict = match_rot_post(nsfw_dataset)

Batches:   0%|          | 0/402 [00:00<?, ?it/s]

100%|█████████████████████████| 129/129 [04:03<00:00,  1.89s/it]


In [16]:
# Percentage of NSFW posts that matched at least one RoT.
print("Turnaround perc", len(result_dict) / len(nsfw_dataset) * 100)

Turaround perc 11.214807901695442


In [17]:
# Placeholder columns: add_rot_label fills rots/safety_label for matched posts,
# and response is the field select_response writes to.
num_rows = len(nsfw_dataset)
nsfw_dataset = nsfw_dataset.add_column("rots", [[]] * num_rows)
for column_name in ("safety_label", "response"):
    nsfw_dataset = nsfw_dataset.add_column(column_name, [None] * num_rows)



In [18]:
# Confirm the three placeholder columns were appended.
nsfw_dataset

Dataset({
    features: ['title', 'subreddit', 'post_id', 'score', 'link_flair_text', 'is_self', 'over_18', 'upvote_ratio', 'is_question', 'C1', 'C2', 'C3', 'C4', 'C5', 'rots', 'safety_label', 'response'],
    num_rows: 12858
})

In [19]:
def filter_stopwords(example, stopwords=("Ladies", "Women", "Gals", "Men", "guys")):
    """Remove audience-addressing words (e.g. "Ladies,") from a post title.

    Matching is case-insensitive and limited to whole words, each optionally
    followed by a comma. Leading/trailing whitespace is trimmed afterwards.

    Args:
        example: Row dict with a ``"title"`` string. It is modified in place and
            returned, as expected by ``Dataset.map``.
        stopwords: Words to remove.

    Returns:
        The same ``example`` with its cleaned ``"title"``.
    """
    title = example["title"]
    if stopwords:
        # \b anchors keep substrings of other words intact; without them "men"
        # would be cut out of "moments" or "comment".
        pattern = r"\b(?:" + "|".join(re.escape(word) for word in stopwords) + r")\b,?"
        title = re.sub(pattern, "", title, flags=re.IGNORECASE)
    example["title"] = title.strip()
    return example

In [20]:
def add_rot_label(example):
    """Copy the matched RoT list and safety label from ``result_dict`` into ``example``.

    Rows whose ``post_id`` is absent from ``result_dict`` are returned untouched.
    """
    post_id = example['post_id']
    if post_id not in result_dict:
        return example

    match = result_dict[post_id]
    example["rots"] = match['rots']
    example['safety_label'] = match['safety_label']
    return example

In [21]:
def select_response(example, tokenizer=None, rng=None):
    """Pick a short reply for ``example`` from its top two comments.

    Candidates are the non-null ``"C1"`` and ``"C2"`` comments that split into
    exactly two sentences. One of them is chosen uniformly at random and stored
    in ``example["response"]``. With no candidate, ``example`` is unchanged.

    Args:
        example: Row dict with ``"C1"`` and ``"C2"`` keys (values may be None).
        tokenizer: Callable mapping text to a list of sentences. Defaults to
            NLTK's ``sent_tokenize``.
        rng: Object with a numpy-style ``choice`` method, e.g.
            ``np.random.default_rng(seed)``, for reproducible picks. Defaults to
            the global ``np.random`` state.

    Returns:
        The same ``example``, possibly with ``"response"`` set.
    """
    if tokenizer is None:
        tokenizer = sent_tokenize
    picker = np.random if rng is None else rng

    # Tokenize each comment once and keep those of exactly two sentences.
    comments = [
        comment
        for comment in (example[key] for key in ("C1", "C2"))
        if comment is not None and len(tokenizer(comment)) == 2
    ]
    if comments:
        example["response"] = comments[int(picker.choice(len(comments)))]
    return example
        
        
    

In [22]:
# Strip audience-addressing words from the stored titles. RoT matching above
# already ran on the raw titles, so this only changes the text kept in the dataset.
nsfw_dataset = nsfw_dataset.map(filter_stopwords)

  0%|          | 0/12858 [00:00<?, ?ex/s]

In [23]:
# Fill rots / safety_label for posts found in result_dict; other rows keep the placeholders.
nsfw_dataset = nsfw_dataset.map(add_rot_label)

  0%|          | 0/12858 [00:00<?, ?ex/s]

In [28]:
def _has_safety_label(example):
    """Keep rows that received a truthy safety label from add_rot_label."""
    return bool(example['safety_label'])


nsfw_dataset = nsfw_dataset.filter(_has_safety_label)

Loading cached processed dataset at /home/shahul/.cache/huggingface/datasets/jjmachan___parquet/jjmachan--NSFW-questions-d386c5d23f1fe7f9/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-3198ff58649f4e1a.arrow


In [32]:
# Expose the post title as the "user" column and drop columns not kept in the final dataset.
nsfw_dataset = nsfw_dataset.rename_columns({"title": "user"}).remove_columns(
    ["score", "is_self", "upvote_ratio", 'C1', 'C2', 'C3', 'C4', 'C5', 'response']
)

In [39]:
# Fixed seed so re-running the notebook produces (and pushes) the same row order.
SHUFFLE_SEED = 42
nsfw_dataset = nsfw_dataset.shuffle(seed=SHUFFLE_SEED)

In [47]:
import os

# Never hardcode access tokens in a notebook. Read the token from the
# environment; if HF_TOKEN is unset, token=None makes the Hub client fall back
# to the locally stored `huggingface-cli login` credentials.
# SECURITY: revoke/rotate any token that was ever written literally in this notebook.
nsfw_dataset.push_to_hub("shahules786/prosocial-nsfw-reddit", token=os.environ.get("HF_TOKEN"))

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Reload the pushed dataset as a sanity check. split="train" is required:
# without it load_dataset returns a DatasetDict, iterating which yields split
# names (strings), so item["post_id"] would raise TypeError.
nsfw_dataset = load_dataset("shahules786/prosocial-nsfw-reddit", split="train")
nsfw_dataset["post_id"]