In [None]:
from datasets import load_dataset
import json
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
# Model name handed to SentenceTransformer; used as the default argument of load_vectorizer().
SBERT_MODEL = "all-MiniLM-L6-v2"
from collections import Counter
import nltk
import re
from nltk import sent_tokenize

In [None]:
# Load the "train" split of both datasets.
# nsfw_dataset: the posts that get labelled below (rows are read via "post_id" / "title").
nsfw_dataset = load_dataset("jjmachan/NSFW-questions",split="train")
# pro_social_dataset: source of the "rots" and "safety_label" fields used to build the lookup.
pro_social_dataset = load_dataset("allenai/prosocial-dialog",split="train")

In [None]:
def match_rot_safetylabels(dataset):
    """Map every rule-of-thumb (RoT) string to a safety label.

    Args:
        dataset: Iterable of mapping-like rows. Each row must provide a
            ``"rots"`` entry (an iterable of RoT strings) and a
            ``"safety_label"`` entry.

    Returns:
        dict: RoT -> ``safety_label`` of the first row in which that RoT
        appears. Later rows never overwrite an earlier label.

    Raises:
        KeyError: If a row lacks ``"rots"`` or ``"safety_label"``.
    """
    results = {}
    # Walk the rows exactly once: iterating twice doubles the decoding cost on
    # a large dataset and silently yields an empty result for one-shot iterators.
    for item in dataset:
        label = item["safety_label"]
        for rot in item["rots"]:
            # setdefault keeps the label from the first occurrence of the RoT.
            results.setdefault(rot, label)
    return results

In [None]:
# Build the RoT -> safety-label lookup from the prosocial dialogues.
rot_sfty = match_rot_safetylabels(pro_social_dataset)
# Dict keys are already unique, so no set() is needed. Listing them directly
# keeps insertion order, which makes the index of each RoT identical on every
# kernel restart (a set of strings is ordered by a per-process hash seed).
all_rots = list(rot_sfty)

In [None]:
def load_vectorizer(model=SBERT_MODEL):
    """Instantiate a ``SentenceTransformer`` for the given model name.

    Args:
        model: Value passed straight to ``SentenceTransformer``. Defaults to
            ``SBERT_MODEL`` (bound when this function is defined).

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    encoder = SentenceTransformer(model)
    return encoder


def vectorize_text(model, texts):
    """Encode ``texts`` with ``model`` and return the encoder's output as-is.

    Args:
        model: Object exposing ``encode(texts, show_progress_bar=...)``.
        texts: The texts to encode; forwarded unchanged to ``model.encode``.

    Returns:
        Whatever ``model.encode`` returns. The progress bar is always enabled.
    """
    embeddings = model.encode(texts, show_progress_bar=True)
    return embeddings


In [None]:
# Build the sentence encoder once; match_rot_post reads this module-level `model`.
model = load_vectorizer()


In [None]:
# Embed every RoT. match_rot_post relies on entry i of `rot_vector` belonging to
# all_rots[i], because it maps matched indices back through `all_rots`.
rot_vector = vectorize_text(model,all_rots)


In [None]:
import scipy.spatial as sp
from collections import defaultdict
from tqdm import tqdm

# Minimum cosine similarity (inclusive) for a pair to be reported by match_query_rot.
THRESHOLD = 0.65


def match_query_rot(q, m, threshold=None):
    """Find (query, reference) pairs whose cosine similarity reaches a threshold.

    Args:
        q: 2-D array of query vectors, one vector per row.
        m: 2-D array of reference vectors, one vector per row.
        threshold: Minimum cosine similarity (inclusive) for a pair to be
            returned. ``None`` (the default) uses the module-level
            ``THRESHOLD``, looked up at call time.

    Returns:
        numpy.ndarray of shape ``(k, 2)``. Each row is
        ``[row index in q, row index in m]`` for one pair at or above the
        threshold, in the row-major order produced by ``np.argwhere``.
    """
    if threshold is None:
        threshold = THRESHOLD
    # cdist yields cosine *distance*; similarity is its complement.
    cosine_sim = 1 - sp.distance.cdist(q, m, "cosine")
    return np.argwhere(cosine_sim >= threshold)
        


In [None]:
# Number of post vectors compared against all RoT vectors per similarity call.
BATCH_SIZE = 100


def match_rot_post(dataset):
    """Attach a matching RoT and its safety label to every post that has one.

    Post titles are embedded with the module-level ``model`` and compared, in
    slices of ``BATCH_SIZE``, against the module-level ``rot_vector`` through
    ``match_query_rot``.

    Args:
        dataset: Iterable of mapping-like rows with ``"title"`` and
            ``"post_id"`` entries.

    Returns:
        dict: ``post_id -> {"rots": [rot], "safety_label": label}`` for every
        post with at least one match. ``label`` is ``rot_sfty.get(rot)``.

    NOTE(review): when a post matches several RoTs, each match overwrites the
    previous one, so only the last pair reported by ``match_query_rot`` for
    that post survives. Confirm that keeping a single RoT is intended.
    """
    # Collect titles and ids in one pass so the match loop below can use a
    # plain list lookup instead of fetching a whole dataset row per match.
    titles, post_ids = [], []
    for item in dataset:
        titles.append(item["title"])
        post_ids.append(item["post_id"])

    post_vector = vectorize_text(model, titles)

    matches = {}
    for start in tqdm(range(0, len(post_vector), BATCH_SIZE)):
        batch = post_vector[start:start + BATCH_SIZE]
        for post_idx, rot_idx in match_query_rot(batch, rot_vector):
            rot = all_rots[rot_idx]
            # post_idx is relative to the batch; add the batch offset.
            matches[post_ids[int(post_idx) + start]] = {
                "rots": [rot],
                "safety_label": rot_sfty.get(rot),
            }
    return matches
           
    
    

In [None]:
# Match every post against the RoT embeddings. `result_dict` maps
# post_id -> {"rots": [...], "safety_label": ...} and is read again by
# add_rot_label, so it has to be assigned before either is used.
result_dict = match_rot_post(nsfw_dataset)
# Percentage of posts that received a match (assuming post_id values are unique).
print("Turaround perc",len(result_dict)/len(nsfw_dataset) * 100)

In [None]:
def filter_stopwords(example):
    """Remove generic audience words from the example's title.

    The words "Ladies", "Women", "Gals", "Men" and "guys" are deleted
    case-insensitively, together with a comma that directly follows them.
    Surrounding whitespace is left untouched.

    Args:
        example: Mapping with a string ``"title"`` entry; modified in place.

    Returns:
        The same ``example`` object with the cleaned title.
    """
    stopwords = ["Ladies","Women","Gals","Men","guys"]
    alternatives = "|".join(re.escape(word) for word in stopwords)
    # \b on both sides restricts the match to whole words; without it "Men"
    # would also be cut out of words such as "moment" or "recommend".
    pattern = rf"\b(?:{alternatives})\b,?"
    example['title'] = re.sub(pattern, '', example['title'], flags=re.IGNORECASE)
    return example

In [None]:
def add_rot_label(example):
    """Copy the matched RoTs and safety label onto an example, if it has a match.

    Looks up ``example['post_id']`` in the module-level ``result_dict``. On a
    hit, the entry's ``'rots'`` and ``'safety_label'`` values are written to
    the example; on a miss the example is returned untouched.

    Args:
        example: Mapping with a ``'post_id'`` entry; modified in place.

    Returns:
        The same ``example`` object.
    """
    key = example['post_id']
    if key not in result_dict:
        return example

    matched = result_dict[key]
    example["rots"] = matched['rots']
    example['safety_label'] = matched['safety_label']
    return example

In [None]:
def select_response(example):
    """Pick one two-sentence comment as the example's response.

    Candidates are the non-None values of ``example["C1"]`` and
    ``example["C2"]`` that ``sent_tokenize`` splits into exactly two
    sentences. If there is at least one candidate, one is drawn with
    ``np.random.choice`` (no seed is set here) and stored under
    ``"response"``; otherwise the example is left unchanged.

    Args:
        example: Mapping with ``"C1"`` and ``"C2"`` entries; modified in place.

    Returns:
        The same ``example`` object.
    """
    candidates = []
    for key in ("C1", "C2"):
        comment = example[key]
        if comment is None:
            continue
        # Tokenize once per comment; "more than 1 and fewer than 3" is exactly 2.
        if len(sent_tokenize(comment)) == 2:
            candidates.append(comment)

    # No per-row debug print: under Dataset.map it would emit one line per example.
    if candidates:
        example["response"] = np.random.choice(candidates, 1)[0]
    return example
        
        
    

In [None]:
# Placeholder columns written later by add_rot_label ("rots", "safety_label")
# and select_response ("response"): an empty list per row for "rots", None for
# the other two.
for column_name, placeholder in (("rots", []), ("safety_label", None), ("response", None)):
    # Skip columns that already exist so this cell can be re-run without
    # trying to add the same column a second time.
    if column_name in nsfw_dataset.column_names:
        continue
    new_column = [placeholder] * len(nsfw_dataset)
    nsfw_dataset = nsfw_dataset.add_column(column_name, new_column)



In [None]:
# Clean the titles first, then attach the matched RoTs / safety labels row by row.
nsfw_dataset = nsfw_dataset.map(filter_stopwords).map(add_rot_label)