# Setup

In [1]:
import os
import pandas as pd

from pprint import pprint

from sklearn.metrics import accuracy_score

from skllm import MultiLabelZeroShotGPTClassifier
from skllm.config import SKLLMConfig

In [2]:
#See notes in INSTALL.md for how to set this, DO NOT HARD CODE YOUR API KEY HERE, 
# if your repository is public, then someone will steal your API key and make you pay for their shit 
SKLLMConfig.set_openai_key(os.environ.get('OPENAI_API_KEY'))

# Collect a Dataset

In [16]:
popular_beers = [
    "Budweiser",
    "Bud Light"
]

human_labels = [
    "repub",
    "woke"
]



In [17]:
# Creating a DataFrame
data = pd.DataFrame({
    'Popular Beers': popular_beers,
    'Human Labels': human_labels
})

In [18]:
data

Unnamed: 0,Popular Beers,Human Labels
0,Budweiser,repub
1,Bud Light,woke


In [19]:
X = data["Popular Beers"]  # Get the text data from the DataFrame

In [20]:
X

0    Budweiser
1    Bud Light
Name: Popular Beers, dtype: object

# "Develop" a Model aka just use OpenAI's API

In [21]:
# Define candidate labels
candidate_labels = [
    "woke",
    "repub"
]

# Create and fit the classifier
clf = MultiLabelZeroShotGPTClassifier(max_labels=2) 
clf.fit(None, [candidate_labels])

In [8]:
# Predict the labels
labels = clf.predict(X)

100%|██████████████████████████████████████████████| 11/11 [00:47<00:00,  4.36s/it]


In [9]:
print(labels)

[['Hipster'], ['Hipster'], ['Hipster'], ['Woke'], ['Hipster'], ['Hipster'], ['Hipster'], ['Hipster'], ['Hipster'], ['Hipster'], ['Hipster']]


In [11]:
# Add labels to the dataset and save
data['ChatGPTLabel'] = labels
data.to_csv('../data/classified_beers.csv', index=False)

In [12]:
data

Unnamed: 0,Beer,ChatGPTLabel
0,Budweiser,[Hipster]
1,Bud Light,[Hipster]
2,Heineken,[Hipster]
3,Corona,[Woke]
4,Guinness,[Hipster]
5,Stella Artois,[Hipster]
6,Samuel Adams,[Hipster]
7,Sierra Nevada Pale Ale,[Hipster]
8,Coors Light,[Hipster]
9,Miller Lite,[Hipster]


# Skipped Steps
* Beat a baseline
* Overfit, regularize and tune
* Communicate with stakeholders
* Ship an inference model
* Monitor and maintain