# Setup

In [1]:
import os
import pandas as pd

from pprint import pprint

from sklearn.metrics import accuracy_score

from skllm import MultiLabelZeroShotGPTClassifier
from skllm.config import SKLLMConfig

In [2]:
# Read the OpenAI API key from the environment (see INSTALL.md for how to set it).
# DO NOT HARD-CODE YOUR API KEY HERE: if this repository is public, someone will
# find the key and run up charges on your account.
openai_api_key = os.environ.get('OPENAI_API_KEY')
if not openai_api_key:
    # Fail here with an actionable message instead of handing None to the library.
    raise RuntimeError(
        "OPENAI_API_KEY is not set; see INSTALL.md for how to configure it."
    )
SKLLMConfig.set_openai_key(openai_api_key)

# Collect a Dataset

In [3]:
# Read the CSV of Reddit life-tip posts into a DataFrame
data = pd.read_csv(filepath_or_buffer='data/helpfulRedditPosts.csv')

# Understand Your Data

In [4]:
# Peek at the first five rows to see which columns the file provides
data.head(n=5)

Unnamed: 0,id,author,isOver18,postUrl,subreddit,postTitle,hasPostBody,postBody,score,numComments,Unnamed: 10,Unnamed: 11,Unnamed: 12,Unnamed: 13,Unnamed: 14,HumanLabel
0,f6jt5e,w2555,False,https://reddit.com/r/LifeProTips/comments/f6jt5e/,LifeProTips,"LPT: keep your mouth shut, and don't volunteer...",True,I had a phone interview scheduled this morning...,147296,4730,,,,,,['Work']
1,lq1jn7,this1tyme,False,https://reddit.com/r/LifeProTips/comments/lq1jn7/,LifeProTips,"LPT: Texans, you are about to experience the w...",False,,134320,4121,,,,,,['Other']
2,j2mm1b,raviji22,False,https://reddit.com/r/LifeProTips/comments/j2mm1b/,LifeProTips,"LPT: When you sign up for anything online, put...",False,,129513,1971,,,,,,['Other']
3,fqkkke,[deleted],False,https://reddit.com/r/LifeProTips/comments/fqkkke/,LifeProTips,"LPT: First rule of family gatherings, always b...",True,[deleted],124219,2762,,,,,,['Family']
4,gmmiah,AlphaSyncz,False,https://reddit.com/r/YouShouldKnow/comments/gm...,YouShouldKnow,"YSK That there is a Youtuber called ""Dad, how ...",True,It's just basic stuff but I know friends of mi...,120038,1532,,,,,,['Family']


In [5]:
# Summary statistics (count, mean, std, min, quartiles, max) for the numeric columns
data.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,score,numComments,Unnamed: 10,Unnamed: 11,Unnamed: 12,Unnamed: 13,Unnamed: 14
count,9.0,9.0,0.0,0.0,0.0,0.0,0.0
mean,125228.444444,2604.222222,,,,,
std,10248.934592,1217.272954,,,,,
min,115511.0,964.0,,,,,
25%,118582.0,1792.0,,,,,
50%,120038.0,2633.0,,,,,
75%,129513.0,2933.0,,,,,
max,147296.0,4730.0,,,,,


In [6]:
# The post titles are the text handed to the classifier
X = data.loc[:, 'postTitle']

In [7]:
# Show the post titles that will be classified
X

0    LPT: keep your mouth shut, and don't volunteer...
1    LPT: Texans, you are about to experience the w...
2    LPT: When you sign up for anything online, put...
3    LPT: First rule of family gatherings, always b...
4    YSK That there is a Youtuber called "Dad, how ...
5    LPT: Try not to be mean or toxic in online gam...
6    LPT: Always tell a child who is wearing a helm...
7    LPT: If you want a smarter kid, teach your chi...
8    LPT: Don't be fooled by the "working for a dre...
Name: postTitle, dtype: object

# "Develop" a Model (a.k.a. Just Use OpenAI's API)

In [8]:
# Categories the model may assign to a post
candidate_labels = ["Work", "Family", "Other"]

# Build the classifier and register the label set. No training examples are
# passed (the first argument is None), only the list of allowed labels.
# max_labels=2 presumably caps the labels returned per post; confirm against
# the scikit-llm documentation.
clf = MultiLabelZeroShotGPTClassifier(max_labels=2)
clf.fit(None, [candidate_labels])

In [9]:
# Predict the labels
# Runs the classifier over every post title in X; the progress bar in the
# output advances once per title.
labels = clf.predict(X)

100%|████████████████████████████████████████████████████████████| 9/9 [00:13<00:00,  1.46s/it]


In [10]:
# Store the predictions next to the original columns, then write the
# labelled dataset to disk (without the row index)
data['ChatGPTLabel'] = labels
data.to_csv(path_or_buf='data/classified_tips.csv', index=False)

# Choose a Measure of Success and an Evaluation Protocol, Then Evaluate

In [11]:
def _first_label(value):
    """Return the first label of a predicted label list, or None if it is empty.

    A non-empty string is treated as an already-cleaned label and returned
    unchanged, so applying this function more than once is harmless.
    """
    if not value:
        return None
    if isinstance(value, str):
        return value
    return value[0]


# HumanLabel is stored as text such as "['Work']"; pull out the bare label.
# Rows without the bracketed form keep their current value, so re-running this
# cell does not blank out labels that were already cleaned.
human_label_raw = data['HumanLabel']
human_label_clean = human_label_raw.str.extract(r"\['(.*?)'\]", expand=False)
data['HumanLabel'] = human_label_clean.fillna(human_label_raw)

# ChatGPTLabel holds a list of labels per post; keep only the first one so it
# can be compared directly with the single human label.
data['ChatGPTLabel'] = data['ChatGPTLabel'].apply(_first_label)

In [12]:
# Side-by-side view of each title with the model's label and the human label
data.loc[:, ["postTitle", "ChatGPTLabel", "HumanLabel"]]

Unnamed: 0,postTitle,ChatGPTLabel,HumanLabel
0,"LPT: keep your mouth shut, and don't volunteer...",Work,Work
1,"LPT: Texans, you are about to experience the w...",Other,Other
2,"LPT: When you sign up for anything online, put...",Other,Other
3,"LPT: First rule of family gatherings, always b...",Family,Family
4,"YSK That there is a Youtuber called ""Dad, how ...",Family,Family
5,LPT: Try not to be mean or toxic in online gam...,Other,Family
6,LPT: Always tell a child who is wearing a helm...,Family,Family
7,"LPT: If you want a smarter kid, teach your chi...",Family,Family
8,"LPT: Don't be fooled by the ""working for a dre...",Work,Work


In [13]:
# Fraction of posts where the model's label equals the human label.
# accuracy_score takes (y_true, y_pred); pass each as a 1-D column rather than
# a single-column DataFrame, matching the documented 1-D array-like inputs.
accuracy = accuracy_score(data["HumanLabel"], data["ChatGPTLabel"])
print(accuracy)

0.8888888888888888


# Skipped Steps
* Beat a baseline
* Overfit, regularize and tune
* Communicate with stakeholders
* Ship an inference model
* Monitor and maintain