In [None]:
!pip install transformers torch huggingface_hub accelerate -q
import transformers
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch

In [None]:
!pip install datasets

In [None]:
from datasets import load_dataset

# Load the "train" split of the allenai/real-toxicity-prompts dataset.
toxicity_prompts = load_dataset(
    "allenai/real-toxicity-prompts",
    split="train",
)

In [None]:
# Bare expression: the notebook shows the dataset object's repr as the cell output.
toxicity_prompts

In [None]:
import random

def get_prompts(dataset, num=200, seed=42):
    """Draw reproducible random samples of toxic and non-toxic prompt records.

    A record is "toxic" when its ``prompt.toxicity`` score is above 0.5 and
    "non-toxic" when it is below 0.5. Records whose score is missing (``None``)
    or exactly 0.5 belong to neither pool and are never returned.

    Args:
        dataset: Iterable of records shaped like
            ``{'prompt': {'toxicity': <float or None>, ...}, ...}``.
        num: Number of records to draw from each pool.
        seed: Seed of the private random generator. The default draws the same
            records as seeding the global generator with 42 would.

    Returns:
        A ``(toxic, non_toxic)`` tuple of lists, each holding ``num`` records.

    Raises:
        ValueError: If either pool holds fewer than ``num`` records.
    """
    toxic_beginnings = []
    non_toxic_beginnings = []
    # Single pass: walking the dataset record by record is the expensive part,
    # so both pools are filled in the same loop.
    for record in dataset:
        toxicity = record['prompt']['toxicity']
        if toxicity is None:
            continue
        if toxicity > 0.5:
            toxic_beginnings.append(record)
        elif toxicity < 0.5:
            non_toxic_beginnings.append(record)

    # A private generator keeps the sampling reproducible without reseeding the
    # process-wide `random` state as a side effect.
    rng = random.Random(seed)
    return rng.sample(toxic_beginnings, num), rng.sample(non_toxic_beginnings, num)

# toxic_samples, non_toxic_samples = get_prompts(toxicity_prompts)
# for sample in non_toxic_samples:
#     print(sample['prompt']['text'])

In [None]:
# Sample the records once, then keep only the prompt text of each sampled record.
toxic_sample, non_toxic_sample = get_prompts(toxicity_prompts)
toxic_prompts = [record['prompt']['text'] for record in toxic_sample]
non_toxic_prompts = [record['prompt']['text'] for record in non_toxic_sample]

In [None]:
# Model checkpoint handed to the text-generation pipeline.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# Build the pipeline with bfloat16 weights; device placement is left to
# device_map="auto".
pipeline = transformers.pipeline(
    task="text-generation",
    model=model_id,
    device_map="auto",
    model_kwargs=dict(torch_dtype=torch.bfloat16),
)

# Token ids that stop a generation: the tokenizer's EOS id followed by the id
# of the "<|eot_id|>" token.
terminators = [pipeline.tokenizer.eos_token_id]
terminators.append(pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

In [None]:
def get_model_continuations(toxic_prompts):
    """Generate one greedy continuation per prompt with the global `pipeline`.

    Despite the parameter name, any prompts may be passed (toxic or not).
    Relies on the module-level `pipeline` and `terminators` objects.

    Args:
        toxic_prompts: Iterable of prompt strings.

    Returns:
        A list with one string per prompt, in input order. Each string is the
        generated text with the echoed prompt removed from its start; if the
        generated text does not start with the prompt it is returned unchanged.
    """
    # Pad with the tokenizer's own EOS id instead of a hard-coded 50256 (the
    # GPT-2 end-of-text id), which has no padding meaning for this model.
    pad_token_id = pipeline.tokenizer.eos_token_id

    model_continuations = []
    for i, prompt in enumerate(toxic_prompts):
        print(i)  # progress indicator

        generation = pipeline(
            [prompt],
            max_new_tokens=20,
            do_sample=False,
            pad_token_id=pad_token_id,
            eos_token_id=terminators,
        )
        generated_text = generation[0][0]['generated_text']

        # Strip only the leading echo of the prompt. A blanket str.replace would
        # also delete the prompt wherever the model repeats it in the continuation.
        if generated_text.startswith(prompt):
            continuation = generated_text[len(prompt):]
        else:
            continuation = generated_text

        model_continuations.append(continuation)

    return model_continuations



In [None]:
# Generate from the prompt lists directly: `df` is only built in a later cell
# (and itself needs these continuations), so reading the prompts from it fails
# on a fresh kernel.
llama_continuations_non_toxic = get_model_continuations(non_toxic_prompts)
llama_continuations_toxic = get_model_continuations(toxic_prompts)

In [None]:
import pandas as pd

# Row i holds the i-th toxic prompt, the i-th non-toxic prompt and the
# continuation generated for each of them.
df = pd.DataFrame({
    'toxic_prompts': toxic_prompts,
    'non_toxic_prompts': non_toxic_prompts,
    'llama_continuations_toxic': llama_continuations_toxic,
    # Previously labelled 'gpt_continuations_non_toxic'; the values come from the
    # Llama pipeline, so the column is named consistently with its sibling.
    'llama_continuations_non_toxic': llama_continuations_non_toxic,
})

In [None]:
# Write the prompts and continuations to disk; with default options the
# DataFrame index is written as the first CSV column.
df.to_csv(path_or_buf='toxic_continuations_llama.csv')