# Constitutional AI: Exploring Self-Critique in Uncensored Models

1. Install the required packages

In [None]:
!pip install -r requirements.txt

2. Run the experiments

In [None]:
import json
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the experiment inputs. The encoding is given explicitly so the files
# decode identically regardless of the platform's default locale encoding.
# principles: expected to be a list of dicts with "CritiqueRequest" and
#   "RevisionRequest" keys (that is how the experiment loop reads it).
# prompts: expected shape {critique_type: {category: [prompt, ...]}}.
with open('Constitutional_AI_Principles.json', 'r', encoding='utf-8') as f:
    principles = json.load(f)

with open('prompts.json', 'r', encoding='utf-8') as f:
    prompts = json.load(f)

# Checkpoint under evaluation; swap in the commented line to try the Llama 3 variant.
model_name = "georgesung/llama2_7b_chat_uncensored"
# model_name = "georgesung/llama3_8b_chat_uncensored"

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load tokenizer and weights from the Hugging Face Hub (or the local cache),
# then place the model on the selected device.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

# Generate one sampled completion, optionally preceded by earlier conversation turns.
def get_response(prompt, context=None, max_length=128):
    """Sample a completion from the loaded model for ``prompt``.

    Args:
        prompt: Text placed last in the model input, after any history.
        context: Optional dict of earlier turns. Only these keys are read,
            and they are emitted in this order, one per line:
            ``"context"`` (verbatim), ``"initial_response"`` (prefixed with
            ``"Assistant: "``) and ``"critique"`` (verbatim). ``None`` or an
            empty dict means ``prompt`` is used on its own.
        max_length: Maximum number of newly generated tokens. Prompt tokens
            do not count against this budget.

    Returns:
        Only the generated continuation, decoded without special tokens and
        stripped of surrounding whitespace.
    """
    # Format context as a natural conversation
    if context:
        conversation = ""
        if "context" in context:
            conversation += f"{context['context']}\n"
        if "initial_response" in context:
            conversation += f"Assistant: {context['initial_response']}\n"
        if "critique" in context:
            conversation += f"{context['critique']}\n"
        formatted_prompt = f"{conversation}{prompt}"
    else:
        formatted_prompt = prompt

    # Calling the tokenizer directly (instead of encode) also returns an
    # attention mask, so generate() does not have to infer one.
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    prompt_len = input_ids.shape[-1]

    # max_new_tokens bounds only the continuation. generate()'s max_length
    # would count the prompt as well, so a long history (critique/revision
    # calls) could exceed it and leave no room for any reply.
    output_ids = model.generate(
        input_ids,
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=max_length,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Drop the prompt at the token level. Decoding everything and cutting
    # len(formatted_prompt) characters is unreliable because the decoded text
    # need not reproduce the prompt character-for-character.
    new_tokens = output_ids[0, prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Run prompt -> critique -> revision for every prompt, checkpointing all results
# to interactions.json after each completed interaction.

# We currently only use the first principle's critique and revision requests.
# Looked up once, so a malformed principles file fails before any generation.
critique_prompt = principles[0]["CritiqueRequest"]
revision_prompt = principles[0]["RevisionRequest"]

# Pre-seed the two expected critique types. Any other type found in
# prompts.json gets its own bucket instead of raising KeyError.
interactions = {"neutral": {}, "negative": {}}

for critique_type, categories in tqdm(prompts.items(), desc="Critique Types"):
    type_bucket = interactions.setdefault(critique_type, {})
    # Named category_prompts so the global `prompts` mapping is not rebound.
    for category, category_prompts in tqdm(categories.items(), desc="Categories", leave=False):
        type_bucket[category] = []

        for prompt in tqdm(category_prompts, desc="Prompts", leave=False):
            print(f"📝 Generating interactions for {category} with {critique_type} critique")
            print(f"📝 Prompt: {prompt}")

            # get_response returns the generated text as a plain string.
            initial_response = get_response(prompt)
            print(f"🤖 Initial response: {initial_response}")

            # Ask the model to critique its own answer to the prompt.
            conversation_history = {
                "context": prompt,
                "initial_response": initial_response,
            }
            critique_response = get_response(critique_prompt, context=conversation_history)
            print(f"🤖 Critique: {critique_response}")

            # Ask for a revision given the prompt, the answer, and the critique.
            conversation_history["critique"] = critique_response
            revised_response = get_response(revision_prompt, context=conversation_history)
            print(f"🤖 Revised response: {revised_response}")

            type_bucket[category].append({
                "prompt": prompt,
                "initial_response": initial_response,
                "critique": critique_response,
                "revised_response": revised_response,
            })

            # Rewrite the whole file after each interaction so partial progress
            # survives an interrupted run.
            with open("interactions.json", 'w') as f:
                json.dump(interactions, f, indent=4)
            print("✅ Interaction saved\n")