Dependencies

In [81]:
import pandas as pd
import re

data = pd.read_csv('diffusion_db_unaltered.csv')
nsfw_words = pd.read_csv('nsfw_words.csv')
sfw_words = pd.read_csv('sfw_words.csv')

This cell handles preliminary variables such as p(NSFW) and p(SFW) which determine the probability a prompt is nsfw or sfw using the image_nsfw score

In [None]:
prompt_count = len(data)
nsfw_count = len(data[data['image_nsfw'] >= .5])
sfw_count = len(data[data['image_nsfw'] < .5])

p_nsfw = nsfw_count / prompt_count
p_sfw = sfw_count / prompt_count

379
91189
0.0850015
0.9149985


In [None]:
commonwords = [
    "the", "a", "an", "is", "are", "and", "to", "in", "that", "have", "has", 
    "with", "this", "these", "those", "it", "for", "on", "be", "was", "were", 
    "can", "will", "would", "should", "may", "might", "do", "does", "did", "of", "its", "their"
]

In [128]:
def cleanPrompt(prompt):
    prompt_list = prompt.lower().split(' ')
    prompt_list = [word for word in prompt_list if word not in commonwords]
    prompt = ' '.join(prompt_list)
    return re.sub(r'[^a-zA-Z0-9 ]', '', prompt)



def classifyMessage(message, alpha):
    safe = p_sfw
    unsafe = p_nsfw
    
    # This sections gets the p(word|N) and p(word|S) for each word in the message.
    word_probs_nsfw = {}
    word_probs_sfw = {}

    for word in message.split():
        if word not in word_probs_nsfw.keys():
            word_probs_nsfw[word] = [alpha, 1] # First index is the probability, the second index is the amount of times that the word appears in the prompt
            word_probs_sfw[word] = [alpha, 1]

            # check if word is in nsfw_words dataframe
            if word in nsfw_words.values:
                word_probs_nsfw[word] = [(nsfw_words[nsfw_words['word'] == word]['count'].values[0]) / len(nsfw_words), 1]
            else:
                word_probs_nsfw[word] = [alpha / len(nsfw_words), 1]

            if word in sfw_words.values:
                word_probs_sfw[word] = [(sfw_words[sfw_words['word'] == word]['count'].values[0])  / len(sfw_words), 1]
            else:
                word_probs_sfw[word] = [alpha / len(sfw_words), 1]
        else:
            word_probs_nsfw[word][1]+=1
            word_probs_sfw[word][1]+=1

    for word in word_probs_nsfw.keys():
        # multiply safe and unsafe probabilities by the word probabilities but each probability is to the exponent of the amount of occurances
        safe *= word_probs_sfw[word][0] ** word_probs_sfw[word][1]
        unsafe *= word_probs_nsfw[word][0] ** word_probs_nsfw[word][1]

    # for word, score in word_probs_nsfw.items():
    #     print(f"{word}: {score[0]}")

    # print("\nSAFE FOR WORK WORDS:")

    # for word, score in word_probs_sfw.items():
    #     print(f"{word}: {score[0]}")

    if safe > unsafe:
        return "SFW"
    else:
        return "NSFW"



In [129]:
examples = pd.read_csv('examples.csv')['Prompt']

results = {
    "Prompt": [],
    "Results": []
}

for example in examples:
    result = classifyMessage(cleanPrompt(example), 1)
    results['Prompt'].append(example)
    if result == "SFW":
        results['Results'].append(1)
    else:
        results['Results'].append(0)

results = pd.DataFrame(results)
results.to_csv('results.csv', index=False)

In [130]:
results = pd.read_csv('results.csv')['Results']
examples = pd.read_csv('examples.csv')['Results']

score = 0
for i in range(len(results)):
    if results[i] == examples[i]:
        score += 1

print(f"Accuracy: {score / len(results)}")

Accuracy: 0.4878048780487805
