In [None]:
from transformers import pipeline

import json
import numpy as np

import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

import torch

In [None]:
# Sentiment classifier used to score every generation file below; the model
# and its tokenizer are loaded from the same Hugging Face checkpoint.
sentiment_model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest"
sentiment_task = pipeline(
    "sentiment-analysis",
    model=sentiment_model_name,
    tokenizer=sentiment_model_name,
    device_map="auto",
)

In [3]:
def get_mean_classifications(generation_file_name, classifier=None):
    """Return the mean positive / neutral / negative sentiment scores of a generation file.

    Args:
        generation_file_name: Path to a JSON file holding a list of generated
            texts. A single JSON string is also accepted and treated as one text.
        classifier: Optional text-classification callable to use instead of the
            global ``sentiment_task`` pipeline (useful for testing or swapping
            models). It is called as ``classifier(texts, top_k=None, truncation=True)``
            and must return, for each text, a list of ``{"label", "score"}`` dicts.

    Returns:
        Tuple ``(mean_positive, mean_neutral, mean_negative)``. A label with no
        scores (e.g. an empty generation file) yields ``nan`` for that entry.
    """
    if classifier is None:
        classifier = sentiment_task

    with open(generation_file_name, encoding="utf-8") as f:
        generations = json.load(f)

    # A bare string would make the pipeline return a flat list of dicts instead
    # of one list per text, which breaks the per-text iteration below.
    if isinstance(generations, str):
        generations = [generations]

    label_order = ("positive", "neutral", "negative")
    if not generations:
        return tuple(np.nan for _ in label_order)

    # top_k=None returns a score for every label of every text. truncation=True
    # keeps generations longer than the model's maximum input length from
    # failing inside the RoBERTa encoder. Pipelines already run inference
    # without gradient tracking, so no torch.no_grad() wrapper is needed.
    per_text_scores = classifier(generations, top_k=None, truncation=True)

    scores_by_label = {label: [] for label in label_order}
    for text_scores in per_text_scores:
        for entry in text_scores:
            if entry["label"] in scores_by_label:
                scores_by_label[entry["label"]].append(entry["score"])

    return tuple(
        np.mean(scores_by_label[label]) if scores_by_label[label] else np.nan
        for label in label_order
    )
    

In [9]:
# Chat models whose generations are scored. Keep entries unique: a repeated
# name re-runs the classifier on the same files and silently overwrites the
# earlier entry in `classifications`.
model_names = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf",
]

# One generation file per mode is read from <output_path><model_name>/<mode>.json.
generation_modes = ["default", "steered_alpha_0.25", "steered_alpha_-0.25"]

output_path = "generations/"

# Filled by the next cell: classifications[model_name][mode] -> (pos, neutral, neg) mean scores.
classifications = dict()

In [None]:
for model_name in model_names:
    # Each model's generations live in their own sub-directory of output_path.
    model_dir = f"{output_path}{model_name}/"
    mode_means = {}
    classifications[model_name] = mode_means
    for mode in generation_modes:
        mode_means[mode] = get_mean_classifications(f"{model_dir}{mode}.json")

In [11]:
def draw_bar_diagrams(classifications, model_name, generation_modes, subplot_titles=None):
    """Draw one bar chart of mean sentiment scores per generation mode for one model.

    Args:
        classifications: Nested mapping ``classifications[model_name][mode]`` ->
            ``(mean_positive, mean_neutral, mean_negative)``.
        model_name: Key of the model to plot.
        generation_modes: Modes to plot, one subplot each, left to right.
        subplot_titles: Optional list of titles, one per mode. By default the
            known modes get readable labels ("Default", "+Happiness",
            "-Happiness") and any other mode is titled with its raw name, so
            titles always line up with the panels actually drawn.

    Returns:
        A plotly ``Figure`` with ``len(generation_modes)`` side-by-side subplots.
    """
    categories = ["positive", "neutral", "negative"]
    # One colour per sentiment category, in the order of `categories`.
    colors = ["#1f77b4", "#7f7f7f", "#ff7f0e"]

    if subplot_titles is None:
        readable_titles = {
            "default": "Default",
            "steered_alpha_0.25": "+Happiness",
            "steered_alpha_-0.25": "-Happiness",
        }
        subplot_titles = [readable_titles.get(mode, mode) for mode in generation_modes]

    fig = make_subplots(rows=1, cols=len(generation_modes), subplot_titles=subplot_titles)

    for col, mode in enumerate(generation_modes, start=1):
        fig.add_trace(
            go.Bar(x=categories, y=classifications[model_name][mode], name=mode, marker_color=colors),
            row=1,
            col=col,
        )
        # A fixed [0, 1] range keeps the panels visually comparable.
        fig.update_yaxes(range=[0, 1], row=1, col=col)

    fig.update_layout(
        yaxis_title="mean score",  # labels only the left-most panel's y-axis
        showlegend=False,
        width=450,
        height=250,
        font_family="Serif",
        font_size=14,
        margin_l=5,
        margin_t=25,
        margin_b=5,
        margin_r=5,
        template="plotly_white",
    )
    return fig


In [18]:
# Mean sentiment scores of Llama-3-8B-Instruct under each generation mode.
fig1 = draw_bar_diagrams(
    classifications,
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
    generation_modes=generation_modes,
)
fig1.show()