In [17]:
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification 
import torch.nn.functional as F
import pandas as pd

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained tokenizer and our trained model
model_checkpoint = "EleutherAI/gpt-neo-125m"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
tokenizer.pad_token = tokenizer.eos_token

# Load the trained weights
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=5)
model_weights_path = "model_weights/conversation_classifier.pth"
model.load_state_dict(torch.load(model_weights_path, map_location=DEVICE))
model.to(DEVICE)  # Move the model to the appropriate device
model.eval()

Some weights of GPTNeoForTokenClassification were not initialized from the model checkpoint at EleutherAI/gpt-neo-125m and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  model.load_state_dict(torch.load(model_weights_path, map_location=DEVICE))


GPTNeoForTokenClassification(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPTNeoBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): L

In [18]:
MAX_SEQ_LENGTH = 512

# These weights are the inverse of the class frequencies in the training data
# Using to try to correct for the imbalance in the training data
def calculate_class_weights():
    # Calculate class weights
    class_frequencies = {
        0: 0.2546,  # Very Negative
        1: 0.3903,  # Negative
        2: 0.2477,  # Neutral
        3: 0.0456,  # Positive
        4: 0.0620   # Very Positive
    }
    # Calculate inverse of frequencies
    inverse_frequencies = {class_id: 1/freq for class_id, freq in class_frequencies.items()}
    
    # Normalize weights so they sum to 1
    total = sum(inverse_frequencies.values())
    normalized_weights = {class_id: inv_freq/total for class_id, inv_freq in inverse_frequencies.items()}
    
    # Convert to tensor
    class_weights = torch.tensor([normalized_weights[i] for i in range(len(class_frequencies))]).float()
    
    return class_weights

CLASS_WEIGHTS = calculate_class_weights().to(DEVICE)


def create_confidence_histogram(probabilities):
    labels = ["Very Negative", "Negative", "Neutral", "Positive", "Very Positive"]
    colors = ['red', 'orange', 'gray', 'lightblue', 'blue']
    
    # Get the probabilities for the last token
    last_token_probs = probabilities[0, -1, :].cpu().numpy()
    
    histogram_html = "<div style='display: flex; align-items: center; font-size: 12px; margin-bottom: 5px;'>"
    for i, (label, prob, color) in enumerate(zip(labels, last_token_probs, colors)):
        bar_width = int(prob * 100)
        histogram_html += f"""
        <div style='margin-right: 5px; text-align: center;'>
            <div style='background-color: {color}; width: {bar_width}px; height: 15px;'></div>
            <div>{label}: {prob:.2f}</div>
        </div>
        """
    histogram_html += "</div>"

    return histogram_html

# Load and sample the dataset
def load_and_sample_data(file_path, n_dialogues_per_rating=5):
    df = pd.read_csv(file_path)
    df.dropna()

    # Get unique ratings
    unique_ratings = df['rating'].unique()

    sampled_dialogues = []
    for rating in unique_ratings:
        if rating == 'nan':
            # Skip NaN ratings
            continue
        # Get dialogues for this rating
        rating_dialogues = df[df['rating'] == rating]['dialogueID'].drop_duplicates()
        
        # Determine how many dialogues to sample
        n_available = len(rating_dialogues)
        n_to_sample = min(n_dialogues_per_rating, n_available)
        
        # Sample dialogues, or take all if fewer than requested
        if n_to_sample > 0:
            sampled = rating_dialogues.sample(n=n_to_sample, replace=True)
            sampled_dialogues.extend(sampled)
    
    return df[df['dialogueID'].isin(sampled_dialogues)].sort_values(['dialogueID', 'date'])

def highlight_text(text, probabilities, predicted_classes):
    tokens = tokenizer.convert_ids_to_tokens(tokenizer(text, return_tensors="pt", max_length=MAX_SEQ_LENGTH).input_ids[0])
    sentiment_map = {0: "Very Negative", 1: "Negative", 2: "Neutral", 3: "Positive", 4: "Very Positive"}
    
    highlighted_text = ""
    for i, token in enumerate(tokens):
        if token in [tokenizer.pad_token, '<|endoftext|>']:
            continue
        
        sentiment_class = predicted_classes[0, i].item()
        sentiment = sentiment_map[sentiment_class]
        confidence_score = probabilities[0, i, sentiment_class].item()
        
        if sentiment in ["Very Positive", "Positive"]:
            color = f"rgba(0, 0, 255, {confidence_score})"  # Blue with varying opacity
        elif sentiment in ["Very Negative", "Negative"]:
            color = f"rgba(255, 165, 0, {confidence_score})"  # Orange with varying opacity
        else:
            color = "transparent"
        
        highlighted_text += f'<span style="background-color:{color}">{token.replace("Ġ", " ").replace("Ċ", "")}</span>'
    
    return highlighted_text


# Map the predicted classes to sentiment labels
sentiment_map = {0: "Very Negative", 1: "Negative", 2: "Neutral", 3: "Positive", 4: "Very Positive"}
def predict_rating(text):
    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    inputs = inputs.to(DEVICE)

    # Set the model to evaluation mode
    model.eval()

    # Disable gradient calculations
    with torch.no_grad():
        # Get the model outputs
        outputs = model(**inputs)
        logits = outputs.logits

    # Apply class weights to logits
    weighted_logits = logits * CLASS_WEIGHTS.unsqueeze(0).unsqueeze(0)

    # Convert logits to probabilities
    probabilities = F.softmax(weighted_logits, dim=-1)
    last_token_probs = probabilities[0, -1, :]

    # Get the predicted classes for each token
    predicted_classes = torch.argmax(weighted_logits, dim=-1)

    # Get the overall sentiment (sentiment of the last token)
    overall_sentiment_class = predicted_classes[0, -1].item()
    overall_sentiment = sentiment_map[overall_sentiment_class]

    highlighted_text = highlight_text(text, probabilities, predicted_classes)

    # Add this at the end of the function
    histogram_html = create_confidence_histogram(last_token_probs.unsqueeze(0).unsqueeze(0))

    return overall_sentiment, highlighted_text, histogram_html

# Process the sampled data
def process_samples(samples):
    results = []
    for dialogue_id, dialogue in samples.groupby('dialogueID'):
        dialogue_text = ""
        dialogue_highlighted = ""
        for _, turn in dialogue.iterrows():
            turn_text = f"{turn['from']}: {turn['text']}\n"
            dialogue_text += turn_text
            
            sentiment, highlighted_text, histogram_html = predict_rating(turn_text)
            dialogue_highlighted += f"{highlighted_text}<br>{histogram_html}<br>"

        overall_sentiment, _, overall_histogram_html = predict_rating(dialogue_text)
        results.append({
            'DialogueID': dialogue_id,
            'Dialogue': dialogue_highlighted,
            'Predicted Rating': overall_sentiment,
            'Overall Histogram': overall_histogram_html,
            'GPT Labeled Rating': sentiment_map[dialogue['rating'].mode().values[0]]  # Most common rating in the dialogue
        })
    return pd.DataFrame(results)

# Handled pasted in conversation
def process_conversation(conversation):
    turns = conversation.split('\n')
    processed_turns = []
    
    for turn in turns:
        if ':' in turn:
            name, text = turn.split(':', 1)
            sentiment, highlighted_text, histogram_html = predict_rating(text)
            processed_turns.append(f"<strong>{name.strip()}:</strong> {highlighted_text}<br>{histogram_html}<br>")
    
    return "".join(processed_turns)

# Create the Gradio interface
def create_interface():
    # Load and process the samples
    samples = load_and_sample_data('data/10k-dialogues-sample-labeled.csv')
    results = process_samples(samples)

    # Create HTML table for results
    table_html = "<table style='width:100%; border-collapse: collapse;'>"
    table_html += "<tr><th>Dialogue</th><th>Predicted Rating</th><th>GPT Labeled Rating</th></tr>"
    
    for _, row in results.iterrows():
        predicted_sentiment = row['Predicted Rating']
        gpt_sentiment = row['GPT Labeled Rating']
        
        def get_sentiment_color(sentiment):
            if sentiment in ["Positive", "Very Positive"]:
                return "rgba(0, 0, 255, 1)"  # Blue
            elif sentiment in ["Negative", "Very Negative"]:
                return "rgba(255, 165, 0, 1)"  # Orange
            else:
                return "transparent"
        
        predicted_color = get_sentiment_color(predicted_sentiment)
        gpt_color = get_sentiment_color(gpt_sentiment)
        
        table_html += f"<tr><td style='border: 1px solid #ddd; padding: 8px; word-wrap: break-word;'>{row['Dialogue']}</td>"
        table_html += f"<td style='border: 1px solid #ddd; padding: 8px; background-color: {predicted_color}; color: black;'>{predicted_sentiment}</td>"
        table_html += f"<td style='border: 1px solid #ddd; padding: 8px; background-color: {gpt_color}; color: black;'>{gpt_sentiment}</td></tr>"
    
    table_html += "</table>"

    # Add the results table
    with gr.Blocks() as demo:
        gr.Markdown("# Conversation Turn Prediction")
        gr.Markdown(f"""
        ## How to Read the Colors
        - Blue: Positive sentiment (darker blue indicates higher confidence)
        - Orange: Negative sentiment (darker orange indicates higher confidence)
        - No color: Neutral sentiment
        You'll see the model change its opinions as it reads the text
        """)
        gr.Interface(
            fn=process_conversation,
            inputs=gr.Textbox(lines=5, label="Enter conversation (format: 'name: text' for each turn)"),
            outputs=gr.HTML(label="Processed Conversation"),
            title="Conversation Turn Prediction",
            description="Enter a conversation with the format 'name: text' for each turn. The model will highlight the text based on its predicted rating.",
        )
        gr.Markdown("## Sample Results")
        gr.HTML(table_html)

    return demo

# Launch the interface
demo = create_interface()
demo.launch(share=True)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


* Running on local URL:  http://127.0.0.1:7870
* Running on public URL: https://74ffdb736a8b26cec4.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


