In [1]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import numpy as np
import altair as alt

In [2]:
# Load a pre-trained model
model = SentenceTransformer('all-MiniLM-L6-v2')

In [24]:
data = pd.read_json('base_input_output_pairs.json')
data.rename(columns={'output': 'output_base'}, inplace=True)

In [25]:
data2 = pd.read_json('input_output_pairs.json')
data2.rename(columns={'output': 'output_fine_tuned'}, inplace=True)

In [26]:
s_class = ["Zero-shot",
           "Zero-shot",
           "Zero-shot",
           "One-shot",
           "One-shot",
           "One-shot",
           "Few-shot",
           "Few-shot",
           "Few-shot",
           "Negative",
           "Negative",
           "Negative",
           "COT",
           "COT",
           "COT"]
s_class = pd.Series(s_class, name="prompt_class")

In [27]:
data_final = pd.merge(data, data2)

In [28]:
data_final['prompt_type'] = s_class

In [29]:
data_final.head()

Unnamed: 0,input,output_base,output_fine_tuned,prompt_type
0,List the most pressing topics regarding regula...,List the most pressing topics regarding regula...,List the most pressing topics regarding regula...,Zero-shot
1,Who owns material generated by a company’s lar...,Who owns material generated by a company’s lar...,Who owns material generated by a company’s lar...,Zero-shot
2,Describe how China and the United States are a...,Describe how China and the United States are a...,Describe how China and the United States are a...,Zero-shot
3,Task: You are a chat bot called AI PolicyChat....,Task: You are a chat bot called AI PolicyChat....,Task: You are a chat bot called AI PolicyChat....,One-shot
4,Task: You are a chat bot called AI PolicyChat....,Task: You are a chat bot called AI PolicyChat....,Task: You are a chat bot called AI PolicyChat....,One-shot


### Calculate Similarity

In [30]:
def similarity_score(row):
    embedding1 = model.encode([row['output_base']])
    embedding2 = model.encode([row['output_fine_tuned']])
    #print("Embedding 1: ", embedding1)
    #print("Embedding 2: ", embedding2)
    similarity = cosine_similarity([embedding1[0]], [embedding2[0]])
    return similarity[0][0]

In [31]:
data_final['similarity'] = data_final.apply(similarity_score, axis=1)
data_final = data_final.reset_index()

In [32]:
data_final.head()

Unnamed: 0,index,input,output_base,output_fine_tuned,prompt_type,similarity
0,0,List the most pressing topics regarding regula...,List the most pressing topics regarding regula...,List the most pressing topics regarding regula...,Zero-shot,0.892532
1,1,Who owns material generated by a company’s lar...,Who owns material generated by a company’s lar...,Who owns material generated by a company’s lar...,Zero-shot,0.68693
2,2,Describe how China and the United States are a...,Describe how China and the United States are a...,Describe how China and the United States are a...,Zero-shot,0.84477
3,3,Task: You are a chat bot called AI PolicyChat....,Task: You are a chat bot called AI PolicyChat....,Task: You are a chat bot called AI PolicyChat....,One-shot,0.955902
4,4,Task: You are a chat bot called AI PolicyChat....,Task: You are a chat bot called AI PolicyChat....,Task: You are a chat bot called AI PolicyChat....,One-shot,0.960199


In [33]:
data_final.describe()

Unnamed: 0,index,similarity
count,15.0,15.0
mean,7.0,0.910794
std,4.472136,0.086604
min,0.0,0.68693
25%,3.5,0.860675
50%,7.0,0.92605
75%,10.5,0.978024
max,14.0,1.0


In [34]:
alt.Chart(data_final).mark_bar().encode(
    y=alt.X("index:N").title("Prompt Number"),
    x=alt.Y("similarity:Q").title("Cosine Similarity"),
    color="prompt_type"
).properties(
    title="Model Output Cosine Similarity by Input Prompt"
).interactive()