In [1]:
from djangoapp.news.vertex.cloud.connections_based_on_docs import VertexAI


In [2]:
llm = VertexAI()


import pickle
import pandas as pd
# import data from pickle
df = pd.read_csv('../data/merged_titles_labels.csv')
df.head()


Unnamed: 0,title,is_clickbait
0,House Dem Aide: We Didn’t Even See Comey’s Let...,1
1,"FLYNN: Hillary Clinton, Big Woman on Campus - ...",0
2,Why the Truth Might Get You Fired,1
3,15 Civilians Killed In Single US Airstrike Hav...,1
4,Iranian woman jailed for fictional unpublished...,1


In [3]:
# select 2500 random samples - 1250 from each class
import random
n_samples = 1000
random.seed(42)
df = df.sample(frac=1).reset_index(drop=True)
df = df.groupby('is_clickbait').head(n_samples//2)
df = df.sample(frac=1).reset_index(drop=True)
df.is_clickbait.value_counts()

is_clickbait
0    500
1    500
Name: count, dtype: int64

In [4]:
df['llm_pred'] = None
df

Unnamed: 0,title,is_clickbait,llm_pred
0,Earthquake shakes western Mexico,0,
1,Brandon Marshall receives threatening letter; ...,0,
2,Visa Ban Amended to Allow Iraqi Interpreters I...,0,
3,Star Wars: What do fans make of the trailer fo...,1,
4,61 businesses Trump has targeted on Twitter,1,
...,...,...,...
995,Workplace reform campaign most expensive in Au...,0,
996,US Secret Service Struggles To Find Recruits W...,1,
997,Time to Unravel the Knot of Credit-Default Swaps,0,
998,"Which ""Shameless"" Character Are You Based On Y...",1,


In [6]:
from tqdm.auto import tqdm
import numpy as np
preds = np.array([])
for i in tqdm(range(0, n_samples)):
    try:
        print(df.iloc[i]['title'])
        pred = llm.run(title=df.iloc[i]['title'])
        preds = np.append(preds, pred)
    except:
        print("error when running LLM")
        preds = np.append(preds, np.nan)
        
    print("")
preds    
df['llm_pred'] = preds

  0%|          | 0/1000 [00:00<?, ?it/s]

Earthquake shakes western Mexico
Prediction: 0 for prompt: Is this title a clickbait: 'Earthquake shakes western Mexico'? Return 1 if yes, 0 if no.
Return value: False

Brandon Marshall receives threatening letter; Broncos investigating
Prediction: 1 for prompt: Is this title a clickbait: 'Brandon Marshall receives threatening letter; Broncos investigating'? Return 1 if yes, 0 if no.
Return value: True

Visa Ban Amended to Allow Iraqi Interpreters Into U.S. 
Prediction: 0 for prompt: Is this title a clickbait: 'Visa Ban Amended to Allow Iraqi Interpreters Into U.S. '? Return 1 if yes, 0 if no.
Return value: False

Star Wars: What do fans make of the trailer for The Last Jedi?
Prediction: 1 for prompt: Is this title a clickbait: 'Star Wars: What do fans make of the trailer for The Last Jedi?'? Return 1 if yes, 0 if no.
Return value: True

61 businesses Trump has targeted on Twitter
Prediction: 1 for prompt: Is this title a clickbait: '61 businesses Trump has targeted on Twitter'? Return

In [7]:
df

Unnamed: 0,title,is_clickbait,llm_pred
0,Earthquake shakes western Mexico,0,0.0
1,Brandon Marshall receives threatening letter; ...,0,1.0
2,Visa Ban Amended to Allow Iraqi Interpreters I...,0,0.0
3,Star Wars: What do fans make of the trailer fo...,1,1.0
4,61 businesses Trump has targeted on Twitter,1,1.0
...,...,...,...
995,Workplace reform campaign most expensive in Au...,0,1.0
996,US Secret Service Struggles To Find Recruits W...,1,1.0
997,Time to Unravel the Knot of Credit-Default Swaps,0,1.0
998,"Which ""Shameless"" Character Are You Based On Y...",1,1.0


In [9]:
from sklearn.metrics import roc_auc_score, f1_score

df.dropna(inplace=True)

print(f"ROC AUC: {roc_auc_score(df['is_clickbait'], df['llm_pred'])}")
print(f"F1 Score: {f1_score(df['is_clickbait'], df['llm_pred'] > 0.5)}")

ROC AUC: 0.6614309603440038
F1 Score: 0.7274247491638797


In [None]:
# plot ROC curve
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt

fpr, tpr, thresholds = roc_curve(df['is_clickbait'], df['llm_pred'])
# add random line
plt.plot([0, 1], [0, 1], linestyle='--')
plt.plot(fpr, tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
# make plot have white background
plt.gca().set_facecolor('w')

plt.title('ROC Curve')
