In [105]:
import os

import dotenv
import openai
import pandas as pd
import numpy as np
from datasets import load_dataset
from utils.prompt import get_smcdel_prompt
from utils.preprocess import preprocess
from utils.persistent import save_fine_tuning_data
import warnings

warnings.simplefilter(action='ignore', category=DeprecationWarning)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [106]:
np.random.seed(1)

In [107]:
dotenv.load_dotenv(dotenv.find_dotenv())
client = openai.Client(api_key=os.environ["OPENAI_KEY"])

# Data

## 1. Load Data

In [108]:
dataset = load_dataset("sileod/mindgames", cache_dir='./data')
data = dataset['train']
data = pd.DataFrame(data)
examples = data.groupby('setup').apply(lambda x: x.sample(100)).reset_index(drop=True)
data = data.drop(examples.index)
data = data.dropna()
print(f"Number of data examples: {len(data)}")
print(f"Number of examples: {len(examples)}")
data.head(5)

Number of data examples: 10774
Number of examples: 400


Unnamed: 0,premise,smcdel_problem,n_announcements,pbcheck,hypothesis,setup,hypothesis_depth,n_agents,label,names,index,s-l,deberta_pred,deberta_confidence,difficulty
400,There are three persons. Everyone is visible t...,"VARS 1,2,3 LAW Top OBS Agenta:2,3 Agentb:1,3 A...",2,"VARS 0,1,2,3 LAW Top OBS Agenta:2,3 Agentb:1,3...",Nancy can now know whether or not nobody's for...,forehead,0,3,entailment,"[Jillian, Angel, Nancy]",30175,forehead-1,1,0.999686,0.000314
401,There are three persons. Everyone is visible t...,"VARS 1,2,3 LAW Top OBS Agenta:2,3 Agentb:1,3 A...",2,"VARS 0,1,2,3 LAW Top OBS Agenta:2,3 Agentb:1,3...",Carrie can now know that Carrie's forehead is ...,forehead,0,3,not_entailment,"[Richard, Carrie, Danielle]",4555,forehead-0,0,0.847983,0.152017
402,There are four persons. Everyone is visible to...,"VARS 1,2,3,4 LAW Top OBS Agenta:1 Agentb:2 Age...",2,"VARS 0,1,2,3,4 LAW Top OBS Agenta:1 Agentb:2 A...",Christina can now know whether Terry is thirsty.,internal,0,4,not_entailment,"[Christina, Daniel, Michael, Terry]",55696,internal-0,0,0.993097,0.006903
403,There are four persons. Everyone is visible to...,"VARS 1,2,3,4 LAW Top OBS Agenta:3 Agentb:4 Age...",0,"VARS 0,1,2,3,4 LAW Top OBS Agenta:3 Agentb:4 A...",Dan can now know that Samantha can know that J...,explicit,1,4,not_entailment,"[Dan, Janet, Samantha, John]",23424,explicit-0,0,0.998091,0.001909
404,There are four persons. Everyone is visible to...,"VARS 1,2,3,4 LAW Top OBS Agenta:1,2,3,4 Agentb...",1,"VARS 0,1,2,3,4 LAW Top OBS Agenta:1,2,3,4 Agen...",Harold can now know that Leslie can know wheth...,forehead_mirror,1,4,entailment,"[Tanya, Leslie, Simone, Harold]",7808,forehead_mirror-1,1,0.999613,0.000387


## 2. Data Preprocessing

In [109]:
data = preprocess(data)
examples = preprocess(examples)
data.head(1)

Unnamed: 0,setup,context,hypothesis,target_sf,target_label,n_announcements,n_agents,hypothesis_depth
400,forehead,There are three persons. Their names are Jilli...,Nancy can now know whether or not nobody's for...,"VARS 1,2,3 LAW Top OBS Agenta:2,3 Agentb:1,3 A...",1,2,3,0


## 3. Prepare Examples

In [110]:
def choose_example(setup: str) -> pd.Series:
    """
    Choose an example from the train set based on the setup
    :param setup: 
    :return: 
    """
    if setup == 'internal':
        return examples[examples['setup'] == 'internal'].sample(1).iloc[0]
    elif setup == 'forehead':
        return examples[examples['setup'] == 'forehead'].sample(1).iloc[0]
    elif setup == 'explicit':
        return examples[examples['setup'] == 'explicit'].sample(1).iloc[0]
    elif setup == 'forehead_mirror':
        return examples[examples['setup'] == 'forehead'].sample(1).iloc[0]
    else:
        raise ValueError(f"Invalid setup: {setup}")

## 3. Prepare Data

In [111]:
fine_tune_data = []
fine_tune_samples = data.sample(300)
for item in fine_tune_samples.iterrows():
    example = choose_example(item[1]['setup'])
    prompt = get_smcdel_prompt(
        example_context=example['context'],
        example_hypothesis=example['hypothesis'],
        example_sf=example['target_sf'],
        problem_context=item[1]['context'],
        problem_hypothesis=item[1]['hypothesis'],
    )
    fine_tune_item = {
        "messages": [
            {
                "role": "user",
                "content": prompt
            },
            {
                "role": "assistant",
                "content": item[1]['target_sf']
            }
        ]
    }
    fine_tune_data.append(fine_tune_item)

print(f"Number of fine-tune examples: {len(fine_tune_data)}")

Number of fine-tune examples: 300


In [112]:
save_fine_tuning_data(fine_tune_data)

Fine tuning data saved to /Users/weizhitang/Local/Research/ToM-LM/results/fine_tuning_data.jsonl


## 4. Upload Data

In [113]:
client.files.create(
    file=open("results/fine_tuning_data.jsonl", "rb"),
    purpose="fine-tune"
)

FileObject(id='file-VCmbbiIjof6FG3Z63RD6gc2P', bytes=501006, created_at=1712160482, filename='fine_tuning_data.jsonl', object='file', purpose='fine-tune', status='processed', status_details=None)

# Fine Tuning

In [116]:
client.fine_tuning.jobs.create(
    training_file="file-VCmbbiIjof6FG3Z63RD6gc2P",
    model="gpt-3.5-turbo"
)

FineTuningJob(id='ftjob-6FsPU08BIwJICK8ajmUPG6Z2', created_at=1712160584, error=Error(code=None, message=None, param=None, error=None), fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs='auto', batch_size='auto', learning_rate_multiplier='auto'), model='gpt-3.5-turbo-0125', object='fine_tuning.job', organization_id='org-4TmdiuDx8xWi3uHVqMbkkztQ', result_files=[], status='validating_files', trained_tokens=None, training_file='file-VCmbbiIjof6FG3Z63RD6gc2P', validation_file=None, user_provided_suffix=None, seed=1139575617)