<a href="https://colab.research.google.com/github/Sidhtang/implementation-of-research-papers/blob/main/constitutional_Ai_with_human_feedback.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
from typing import List, Dict, Tuple, Optional, Union
import numpy as np
from dataclasses import dataclass
import json
from tqdm import tqdm
import wandb
import re
from scipy.special import softmax

@dataclass
class ConversationTurn:
    """
    Represents a single turn in a conversation, including context from previous turns.
    """
    prompt: str
    response: str
    context: List[Tuple[str, str]]

class AdvancedConstitutionalRule:
    """
    Represents a constitutional rule for AI behavior with evaluation logic.
    """
    def __init__(self, rule_type: str, threshold: float = 0.8):
        self.rule_type = rule_type
        self.threshold = threshold

        # Define rule-specific criteria and violations
        self.criteria = {
            "safety": [
                "Contains no harmful content",
                "Respects user safety",
                "Avoids dangerous advice",
                "Maintains appropriate boundaries"
            ],
            "ethics": [
                "Respects ethical principles",
                "Promotes fairness and equality",
                "Avoids bias and discrimination",
                "Considers moral implications"
            ]
        }

        self.violations = {
            "safety": [
                "harmful content",
                "dangerous advice",
                "personal information request",
                "inappropriate content"
            ],
            "ethics": [
                "biased language",
                "discriminatory content",
                "unethical suggestions",
                "manipulation tactics"
            ]
        }

    def evaluate(self, text: str, evaluator_model: nn.Module, tokenizer) -> Tuple[bool, float, List[str]]:
        """
        Evaluate text against the constitutional rule.
        """
        # Prepare criteria checking prompts
        prompts = [
            f"Does the following text violate {criterion}? Text: {text}"
            for criterion in self.criteria[self.rule_type]
        ]

        # Tokenize prompts
        inputs = tokenizer(
            prompts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512
        ).to(evaluator_model.device)

        # Get model predictions
        with torch.no_grad():
            outputs = evaluator_model(**inputs)
            scores = torch.sigmoid(outputs.logits).mean(dim=1)

        # Check for violations
        passes_rule = scores.mean().item() >= self.threshold
        confidence_score = scores.mean().item()

        # Identify specific violations
        violation_reasons = []
        for violation in self.violations[self.rule_type]:
            if re.search(violation, text.lower()):
                violation_reasons.append(f"Contains {violation}")

        return passes_rule, confidence_score, violation_reasons

class ConversationDataset(Dataset):
    """
    Dataset class for handling conversation data.
    """
    def __init__(self, conversations: List[ConversationTurn], tokenizer):
        self.conversations = conversations
        self.tokenizer = tokenizer

        # Set up padding token if it doesn't exist
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __len__(self):
        return len(self.conversations)

    def __getitem__(self, idx):
        turn = self.conversations[idx]
        context_text = " ".join([f"{p} {r}" for p, r in turn.context])
        full_prompt = f"{context_text} {turn.prompt}" if context_text else turn.prompt

        inputs = self.tokenizer(
            full_prompt,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=512
        )
        labels = self.tokenizer(
            turn.response,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=512
        )

        return {
            "input_ids": inputs.input_ids[0],
            "attention_mask": inputs.attention_mask[0],
            "labels": labels.input_ids[0]
        }

class AdvancedConstitutionalAI:
    """
    Advanced Constitutional AI model with rule enforcement and distributed training capability.
    """
    def __init__(
        self,
        base_model_name: str,
        evaluator_model_name: str,
        rules: List[AdvancedConstitutionalRule],
        distributed: bool = False,
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        self.device = device
        self.distributed = distributed

        # Initialize base model and tokenizer
        self.base_model = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        # Set up padding token for tokenizer
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.base_model.config.pad_token_id = self.tokenizer.pad_token_id

        if distributed:
            self.base_model = DDP(self.base_model)

        # Initialize evaluator model and tokenizer
        self.evaluator_model = AutoModelForSequenceClassification.from_pretrained(
            evaluator_model_name,
            num_labels=1  # Binary classification
        ).to(device)
        self.evaluator_tokenizer = AutoTokenizer.from_pretrained(evaluator_model_name)

        # Set up padding token for evaluator tokenizer
        if self.evaluator_tokenizer.pad_token is None:
            self.evaluator_tokenizer.pad_token = self.evaluator_tokenizer.eos_token

        # Store rules
        self.rules = rules

        # Initialize optimizer
        self.optimizer = Adam(self.base_model.parameters(), lr=1e-5)

    def evaluate_response(self, response: str) -> Tuple[bool, List[Dict]]:
        """
        Evaluate a response against all constitutional rules.
        """
        rule_results = []
        passes_all_rules = True

        for rule in self.rules:
            passes_rule, confidence, violations = rule.evaluate(
                response,
                self.evaluator_model,
                self.evaluator_tokenizer
            )

            rule_results.append({
                "rule_type": rule.rule_type,
                "passes": passes_rule,
                "confidence": confidence,
                "violations": violations
            })

            passes_all_rules = passes_all_rules and passes_rule

        return passes_all_rules, rule_results

    def generate_response(
        self,
        prompt: str,
        max_length: int = 100,
        num_attempts: int = 3
    ) -> Tuple[str, List[Dict]]:
        """
        Generate a response that complies with constitutional rules.
        """
        for _ in range(num_attempts):
            # Generate response
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            outputs = self.base_model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.pad_token_id
            )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Evaluate response
            passes_rules, rule_results = self.evaluate_response(response)

            if passes_rules:
                return response, rule_results

        # If all attempts fail, return the last response with its evaluation results
        return response, rule_results

    def save_checkpoint(self, path: str, epoch: int, metrics: Dict):
        """Save model checkpoint and training state."""
        if not self.distributed or (self.distributed and dist.get_rank() == 0):
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': self.base_model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'metrics': metrics
            }
            torch.save(checkpoint, path)

    def load_checkpoint(self, path: str) -> Dict:
        """Load model checkpoint and training state."""
        checkpoint = torch.load(path, map_location=self.device)
        self.base_model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['metrics']

def create_sample_dataset():
    """Create sample conversation data for testing"""
    sample_conversations = [
        ConversationTurn(
            prompt="What is artificial intelligence?",
            response="Artificial intelligence is the simulation of human intelligence by machines. It involves creating systems that can learn, reason, and solve problems.",
            context=[]
        ),
        ConversationTurn(
            prompt="How do neural networks work?",
            response="Neural networks are computing systems inspired by biological neural networks. They learn from examples to recognize patterns in data.",
            context=[("What is artificial intelligence?", "Artificial intelligence is the simulation of human intelligence by machines.")]
        ),
        ConversationTurn(
            prompt="Explain machine learning.",
            response="Machine learning is a subset of AI that enables systems to learn and improve from experience without being explicitly programmed.",
            context=[
                ("What is artificial intelligence?", "Artificial intelligence is the simulation of human intelligence by machines."),
                ("How do neural networks work?", "Neural networks are computing systems inspired by biological neural networks.")
            ]
        )
    ]
    return sample_conversations

def train_epoch(model: AdvancedConstitutionalAI,
                dataloader: DataLoader,
                rank: Optional[int] = None) -> float:
    """Train for one epoch"""
    model.base_model.train()
    total_loss = 0

    progress_bar = tqdm(dataloader, disable=rank is not None and rank != 0)
    for batch in progress_bar:
        model.optimizer.zero_grad()

        # Move batch to device
        batch = {k: v.to(model.device) for k, v in batch.items()}

        # Forward pass
        outputs = model.base_model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )

        loss = outputs.loss
        total_loss += loss.item()

        # Backward pass
        loss.backward()
        model.optimizer.step()

        # Update progress bar
        progress_bar.set_description(f"Loss: {loss.item():.4f}")

        # Log to wandb
        if rank is None or rank == 0:
            wandb.log({"batch_loss": loss.item()})

    avg_loss = total_loss / len(dataloader)
    return avg_loss

def train_distributed(rank: int,
                     world_size: int,
                     model_config: Dict,
                     training_config: Dict,
                     dataset: ConversationDataset):
    """Handle distributed training across multiple GPUs"""
    # Initialize process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # Initialize model with DDP
    model = AdvancedConstitutionalAI(**model_config, distributed=True, device=f"cuda:{rank}")

    # Create distributed sampler and dataloader
    sampler = torch.utils.data.DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank
    )

    dataloader = DataLoader(
        dataset,
        batch_size=training_config.get('batch_size', 4),
        sampler=sampler
    )

    # Training loop
    for epoch in range(training_config['num_epochs']):
        sampler.set_epoch(epoch)  # Important for proper shuffling
        avg_loss = train_epoch(model, dataloader, rank)

        if rank == 0:  # Save checkpoint only on main process
            model.save_checkpoint(
                f"checkpoint_epoch_{epoch}.pt",
                epoch,
                {"loss": avg_loss}
            )

    dist.destroy_process_group()

def main():
    """Main entry point for training"""
    # Initialize wandb
    wandb.init(project="constitutional-ai")

    # Initialize rules
    rules = [
        AdvancedConstitutionalRule("safety"),
        AdvancedConstitutionalRule("ethics")
    ]

    # Model configuration
    model_config = {
        "base_model_name": "gpt2-medium",
        "evaluator_model_name": "bert-base-uncased",
        "rules": rules
    }

    # Training configuration
    training_config = {
        "num_epochs": 10,
        "batch_size": 4,
        "learning_rate": 1e-5
    }

    # Initialize dataset with sample data
    conversations = create_sample_dataset()
    tokenizer = AutoTokenizer.from_pretrained(model_config["base_model_name"])
    dataset = ConversationDataset(conversations, tokenizer)

    # Launch distributed training if multiple GPUs available
    world_size = torch.cuda.device_count()
    if world_size > 1:
        mp.spawn(
            train_distributed,
            args=(world_size, model_config, training_config, dataset),
            nprocs=world_size,
            join=True
        )
    else:
        # Single GPU training
        model = AdvancedConstitutionalAI(**model_config)
        dataloader = DataLoader(
            dataset,
            batch_size=training_config['batch_size'],
            shuffle=True
        )

        # Training loop
        for epoch in range(training_config['num_epochs']):
            avg_loss = train_epoch(model, dataloader)
            print(f"Epoch {epoch + 1}/{training_config['num_epochs']}, Loss: {avg_loss:.4f}")

            # Save checkpoint
            model.save_checkpoint(
                f"checkpoint_epoch_{epoch}.pt",
                epoch,
                {"loss": avg_loss}
            )

if __name__ == "__main__":
    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Loss: 11.4402: 100%|██████████| 1/1 [01:39<00:00, 99.97s/it]


Epoch 1/10, Loss: 11.4402


  0%|          | 0/1 [00:00<?, ?it/s]

In [2]:
!pip install gradio

Collecting gradio
  Downloading gradio-5.4.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.4-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.4.0-py3-none-any.whl.metadata (2.9 kB)
Collecting gradio-client==1.4.2 (from gradio)
  Downloading gradio_client-1.4.2-py3-none-any.whl.metadata (7.1 kB)
Collecting httpx>=0.24.1 (from gradio)
  Downloading httpx-0.27.2-py3-none-any.whl.metadata (7.1 kB)
Collecting huggingface-hub>=0.25.1 (from gradio)
  Downloading huggingface_hub-0.26.1-py3-none-any.whl.metadata (13 kB)
Collecting markupsafe~=2.0 (from gradio)
  Downloading MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Collecting orjson~=3.0 (from gradio)
  Downloading orjson-3.10.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.w

In [4]:
import gradio as gr
import torch
from typing import List, Dict, Tuple
import json
from dataclasses import dataclass
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import re

class AdvancedConstitutionalRule:
    """
    Represents a constitutional rule for AI behavior with evaluation logic.
    """
    def __init__(self, rule_type: str, threshold: float = 0.8):
        self.rule_type = rule_type
        self.threshold = threshold

        # Define rule-specific criteria and violations
        self.criteria = {
            "safety": [
                "Contains no harmful content",
                "Respects user safety",
                "Avoids dangerous advice",
                "Maintains appropriate boundaries"
            ],
            "ethics": [
                "Respects ethical principles",
                "Promotes fairness and equality",
                "Avoids bias and discrimination",
                "Considers moral implications"
            ]
        }

        self.violations = {
            "safety": [
                "harmful content",
                "dangerous advice",
                "personal information request",
                "inappropriate content"
            ],
            "ethics": [
                "biased language",
                "discriminatory content",
                "unethical suggestions",
                "manipulation tactics"
            ]
        }

    def evaluate(self, text: str, model: AutoModelForSequenceClassification,
                tokenizer: AutoTokenizer) -> Tuple[bool, float, List[str]]:
        """
        Evaluates text against the rule using the provided model.

        Args:
            text: Text to evaluate
            model: Pre-trained model for evaluation
            tokenizer: Tokenizer for the model

        Returns:
            Tuple containing:
            - Boolean indicating if text passes the rule
            - Confidence score
            - List of detected violations
        """
        # Tokenize input text
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Get model prediction
        with torch.no_grad():
            outputs = model(**inputs)
            scores = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()

        # Calculate confidence score (average of relevant criteria scores)
        confidence = float(np.mean(scores))

        # Check for violations
        detected_violations = []
        for violation in self.violations[self.rule_type]:
            # Simple keyword checking - could be enhanced with more sophisticated detection
            if violation in text.lower():
                detected_violations.append(violation)

        # Determine if text passes the rule
        passes_rule = confidence >= self.threshold and len(detected_violations) == 0

        return passes_rule, confidence, detected_violations

class ConstitutionalAIInterface:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialize models
        self.base_model_name = "gpt2-medium"
        self.evaluator_model_name = "bert-base-uncased"

        self.base_model = AutoModelForCausalLM.from_pretrained(self.base_model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)

        self.evaluator_model = AutoModelForSequenceClassification.from_pretrained(
            self.evaluator_model_name,
            num_labels=1
        ).to(self.device)
        self.evaluator_tokenizer = AutoTokenizer.from_pretrained(self.evaluator_model_name)

        # Initialize rules
        self.rules = [
            AdvancedConstitutionalRule("safety"),
            AdvancedConstitutionalRule("ethics")
        ]

        # Set up padding tokens
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.base_model.config.pad_token_id = self.tokenizer.pad_token_id

        if self.evaluator_tokenizer.pad_token is None:
            self.evaluator_tokenizer.pad_token = self.evaluator_tokenizer.eos_token

        # Initialize conversation history
        self.conversation_history = []

        # Define forbidden patterns and safe responses
        self.forbidden_patterns = [
            r"how to .*?(steal|hack|fraud|scam)",
            r"(steal|hack|fraud|scam).*?money",
            r"illegal.*?(activity|way|method)",
            # Add more patterns as needed
        ]

        self.safe_responses = {
            "harmful_intent": "I apologize, but I cannot assist with harmful or illegal activities. If you need financial assistance, I'd be happy to suggest legal resources and legitimate ways to earn or manage money.",
            "unsafe_query": "I cannot provide information about harmful or dangerous activities. Instead, I'd be happy to suggest safe and legal alternatives.",
            "default": "I aim to be helpful while ensuring safety and ethics. Could you please rephrase your request in a way that doesn't involve potential harm?"
        }

    def is_safe_query(self, text: str) -> bool:
        """Check if the query is safe based on patterns"""
        text = text.lower()
        return not any(re.search(pattern, text) for pattern in self.forbidden_patterns)

    def get_safe_response(self, query: str) -> str:
        """Return appropriate safe response based on query type"""
        if any(word in query.lower() for word in ["steal", "hack", "fraud", "scam"]):
            return self.safe_responses["harmful_intent"]
        elif not self.is_safe_query(query):
            return self.safe_responses["unsafe_query"]
        return None

    def evaluate_response(self, response: str) -> Tuple[bool, List[Dict]]:
        rule_results = []
        passes_all_rules = True

        for rule in self.rules:
            passes_rule, confidence, violations = rule.evaluate(
                response,
                self.evaluator_model,
                self.evaluator_tokenizer
            )

            rule_results.append({
                "rule_type": rule.rule_type,
                "passes": passes_rule,
                "confidence": confidence,
                "violations": violations
            })

            passes_all_rules = passes_all_rules and passes_rule

        return passes_all_rules, rule_results

    def generate_and_evaluate(self, prompt: str, max_length: int = 100) -> Tuple[str, str, str]:
        # First, check if the prompt is safe
        safe_response = self.get_safe_response(prompt)
        if safe_response:
            response = safe_response
            passes_rules, rule_results = True, [
                {
                    "rule_type": "safety",
                    "passes": True,
                    "confidence": 1.0,
                    "violations": []
                },
                {
                    "rule_type": "ethics",
                    "passes": True,
                    "confidence": 1.0,
                    "violations": []
                }
            ]
        else:
            # Generate response for safe queries
            context_text = " ".join([f"{p} {r}" for p, r in self.conversation_history[-3:]])
            full_prompt = f"{context_text} {prompt}" if context_text else prompt

            inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.device)
            outputs = self.base_model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.pad_token_id,
                do_sample=True,
                temperature=0.7,
                top_p=0.9
            )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Evaluate response
            passes_rules, rule_results = self.evaluate_response(response)

            # If response doesn't pass rules, use safe response
            if not passes_rules:
                response = self.safe_responses["default"]
                passes_rules, rule_results = self.evaluate_response(response)

        # Format evaluation results
        evaluation_text = ""
        for result in rule_results:
            evaluation_text += f"\n{result['rule_type'].upper()} Evaluation:\n"
            evaluation_text += f"Passes: {result['passes']}\n"
            evaluation_text += f"Confidence: {result['confidence']:.2f}\n"
            if result['violations']:
                evaluation_text += f"Violations: {', '.join(result['violations'])}\n"

        # Update conversation history
        self.conversation_history.append((prompt, response))

        # Format conversation history
        history_text = "\nConversation History:\n"
        for i, (p, r) in enumerate(self.conversation_history[-3:], 1):
            history_text += f"\nTurn {i}:\nUser: {p}\nAI: {r}\n"

        return response, evaluation_text, history_text

    def clear_history(self) -> Tuple[str, str, str]:
        self.conversation_history = []
        return "", "Conversation history cleared.", ""

def create_gradio_interface():
    ai_interface = ConstitutionalAIInterface()

    with gr.Blocks(title="Constitutional AI Interface") as demo:
        gr.Markdown("# Constitutional AI Demo")
        gr.Markdown("This interface demonstrates an AI model with constitutional rules for safety and ethics.")

        with gr.Row():
            with gr.Column(scale=2):
                prompt_input = gr.Textbox(
                    label="Enter your prompt",
                    placeholder="Type your message here...",
                    lines=3
                )
                with gr.Row():
                    submit_btn = gr.Button("Submit", variant="primary")
                    clear_btn = gr.Button("Clear History", variant="secondary")

            with gr.Column(scale=3):
                response_output = gr.Textbox(
                    label="AI Response",
                    lines=4,
                    interactive=False
                )
                evaluation_output = gr.Textbox(
                    label="Rule Evaluation Results",
                    lines=6,
                    interactive=False
                )
                history_output = gr.Textbox(
                    label="Conversation History",
                    lines=10,
                    interactive=False
                )

        submit_btn.click(
            ai_interface.generate_and_evaluate,
            inputs=[prompt_input],
            outputs=[response_output, evaluation_output, history_output]
        )

        clear_btn.click(
            ai_interface.clear_history,
            inputs=[],
            outputs=[prompt_input, evaluation_output, history_output]
        )

    return demo

if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://699ab819fc32585ce6.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
