1. For each date, get the gics level (sector, industry group, or industry) with the highest returns over the next day.
2. Construct a label for each FOMC statement with the label being the gics level with the highest and lowest returns over the next day.

In [41]:
from collections import defaultdict
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tiktoken

In [67]:
sp = pd.read_pickle('data/sp500_constituents.pkl')
fomc = pd.read_pickle('data/fomc_statements.pkl')
fomc = fomc.drop(columns='statement')
fomc = fomc.rename(columns={'cleaned_statement':'statement'})
fomc = fomc[fomc['date']>='2000-01-03'].reset_index(drop=True)
fomc = fomc[fomc['date'] <= '2024-01-01'].reset_index(drop=True)

In [68]:
gics_level = 'sector'
#gics_level = 'group'

### Get Realized Returns of Next Day for Each Sector

In [69]:
# get return for next day of each stock 
sp['next_ret'] = sp.groupby('gvkey')['ret'].shift(-1)

# for each date, get returns of each gics group
gics_returns = pd.DataFrame(sp.groupby(['date', gics_level])['next_ret'].mean().reset_index())

gics_returns = gics_returns.dropna().reset_index(drop=True)
gics_returns

Unnamed: 0,date,sector,next_ret
0,2000-01-03,Communication Services,-0.037238
1,2000-01-03,Consumer Discretionary,-0.021774
2,2000-01-03,Consumer Staples,-0.019177
3,2000-01-03,Energy,-0.019377
4,2000-01-03,Financials,-0.034273
...,...,...,...
66391,2023-12-28,Industrials,-0.002743
66392,2023-12-28,Information Technology,-0.005385
66393,2023-12-28,Materials,-0.005865
66394,2023-12-28,Real Estate,-0.011880


In [70]:
gics_returns_pivot = gics_returns.pivot(index='date', columns='sector', values='next_ret')
gics_returns_pivot

sector,Communication Services,Consumer Discretionary,Consumer Staples,Energy,Financials,Health Care,Industrials,Information Technology,Materials,Real Estate,Utilities
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
2000-01-03,-0.037238,-0.021774,-0.019177,-0.019377,-0.034273,-0.025955,-0.023178,-0.058937,-0.018511,-0.032652,0.006866
2000-01-04,-0.002932,0.000018,0.004089,0.010810,-0.006724,0.019272,0.000941,-0.023925,0.032018,0.043004,0.033996
2000-01-05,-0.009655,0.002784,0.006458,0.039145,0.029920,0.026320,0.018839,-0.027577,0.023137,0.033019,0.004239
2000-01-06,0.006351,0.019603,0.027310,0.021895,0.017336,0.055065,0.022011,0.039293,0.003805,-0.010203,0.010110
2000-01-07,0.040937,0.018814,-0.010323,-0.011609,-0.019246,-0.004656,0.001914,0.042415,-0.006447,-0.010630,-0.003023
...,...,...,...,...,...,...,...,...,...,...,...
2023-12-21,-0.001633,-0.004706,0.006339,0.001779,0.003671,0.004200,0.004818,0.006176,0.006091,0.003073,0.004074
2023-12-22,0.005353,0.001771,0.004943,0.011957,0.005275,0.004883,0.005826,0.008374,0.006006,0.008687,0.006829
2023-12-26,0.000710,0.001273,0.002847,-0.005782,0.002801,0.002958,0.001465,-0.000784,0.002130,0.004458,-0.001413
2023-12-27,0.005436,0.002040,0.003188,-0.014100,0.003644,0.001494,0.000675,0.000226,-0.004506,0.006977,0.007503


### Reduce Portfolio to Only FOMC Days

In [71]:
# merge realized returns of each sector for each day following FOMC statements.
fomc_sector_returns = pd.merge(fomc, gics_returns_pivot, on='date').drop(columns=['statement'])
fomc_sector_returns = fomc_sector_returns.set_index('date')

In [72]:
fomc_sector_returns.median()

Communication Services   -0.000156
Consumer Discretionary   -0.000347
Consumer Staples         -0.000026
Energy                   -0.001692
Financials               -0.000760
Health Care              -0.000112
Industrials               0.000089
Information Technology    0.002362
Materials                 0.000412
Real Estate              -0.000523
Utilities                 0.002158
dtype: float64

In [73]:
fomc_sector_returns.mean()

Communication Services   -0.000671
Consumer Discretionary   -0.001596
Consumer Staples         -0.000235
Energy                   -0.001714
Financials               -0.003067
Health Care              -0.000318
Industrials              -0.001081
Information Technology   -0.000055
Materials                -0.001801
Real Estate              -0.002627
Utilities                -0.000064
dtype: float64

In [74]:
fomc_sector_returns.describe()

Unnamed: 0,Communication Services,Consumer Discretionary,Consumer Staples,Energy,Financials,Health Care,Industrials,Information Technology,Materials,Real Estate,Utilities
count,159.0,159.0,159.0,159.0,159.0,159.0,159.0,159.0,159.0,159.0,159.0
mean,-0.000671,-0.001596,-0.000235,-0.001714,-0.003067,-0.000318,-0.001081,-5.5e-05,-0.001801,-0.002627,-6.4e-05
std,0.017386,0.018154,0.010477,0.022413,0.02,0.013553,0.016387,0.024546,0.017635,0.019756,0.013842
min,-0.054035,-0.074885,-0.041591,-0.10396,-0.089648,-0.056904,-0.07007,-0.070644,-0.083624,-0.080684,-0.059797
25%,-0.009343,-0.010051,-0.006248,-0.013349,-0.012689,-0.008337,-0.010077,-0.014481,-0.011187,-0.011034,-0.007524
50%,-0.000156,-0.000347,-2.6e-05,-0.001692,-0.00076,-0.000112,8.9e-05,0.002362,0.000412,-0.000523,0.002158
75%,0.008787,0.010143,0.005901,0.010639,0.008936,0.007819,0.008286,0.010221,0.009774,0.007374,0.00749
max,0.068751,0.045034,0.029414,0.06856,0.042505,0.040327,0.047343,0.117393,0.045207,0.049228,0.043914


### Check Efficacy of Strategy: If the correct long and short were selected, how does the strategy perform?

In [75]:
fomc_sector_returns_melted = pd.melt(fomc_sector_returns.reset_index(), id_vars=['date'], var_name='sector', value_name='next_ret')
fomc_sector_returns_melted = fomc_sector_returns_melted.sort_values(by='date').reset_index(drop=True)
fomc_sector_returns_melted

Unnamed: 0,date,sector,next_ret
0,2000-02-02,Communication Services,0.005600
1,2000-02-02,Information Technology,0.039761
2,2000-02-02,Financials,0.000359
3,2000-02-02,Energy,-0.003988
4,2000-02-02,Materials,0.001109
...,...,...,...
1744,2023-12-13,Consumer Staples,-0.011217
1745,2023-12-13,Industrials,0.015539
1746,2023-12-13,Energy,0.029266
1747,2023-12-13,Materials,0.021260


In [76]:
# create decile portfolios
df_sorted = fomc_sector_returns_melted.groupby('date').apply(lambda x: x.sort_values('next_ret', ascending=False)).reset_index(drop=True)

# 9 is best, 0 is worst
df_sorted['decile'] = df_sorted.groupby('date')['next_ret'].transform(
    lambda x: pd.qcut(x, q=10, labels=False)
)

decile_returns = df_sorted.groupby(['date', 'decile'])['next_ret'].mean().reset_index()
decile_returns

  df_sorted = fomc_sector_returns_melted.groupby('date').apply(lambda x: x.sort_values('next_ret', ascending=False)).reset_index(drop=True)


Unnamed: 0,date,decile,next_ret
0,2000-02-02,0,-0.009902
1,2000-02-02,1,-0.000828
2,2000-02-02,2,0.000359
3,2000-02-02,3,0.001109
4,2000-02-02,4,0.002784
...,...,...,...
1585,2023-12-13,5,0.018370
1586,2023-12-13,6,0.021260
1587,2023-12-13,7,0.025120
1588,2023-12-13,8,0.029266


In [77]:
decile_returns_pivot = decile_returns.pivot(index='date', columns='decile', values='next_ret')
decile_returns_pivot

decile,0,1,2,3,4,5,6,7,8,9
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
2000-02-02,-0.009902,-0.000828,0.000359,0.001109,0.002784,0.005600,0.007064,0.017985,0.019954,0.039761
2000-03-21,-0.035456,-0.020109,-0.012721,-0.009040,-0.008748,0.001349,0.003103,0.009054,0.009669,0.024574
2000-05-16,-0.016457,-0.013647,-0.012513,-0.012053,-0.009743,-0.009383,-0.009028,-0.007225,-0.004783,0.007353
2000-06-28,-0.022569,-0.006976,-0.005632,-0.005238,-0.002841,-0.000347,0.001703,0.007033,0.010705,0.016429
2000-08-22,-0.018075,-0.011101,-0.008640,-0.008342,-0.001150,-0.000156,0.005894,0.005927,0.010353,0.030438
...,...,...,...,...,...,...,...,...,...,...
2023-06-14,0.005554,0.008406,0.009056,0.009407,0.010699,0.010954,0.013540,0.014816,0.015211,0.017536
2023-07-26,-0.020516,-0.013514,-0.010225,-0.007912,-0.007699,-0.006888,-0.006503,-0.003813,-0.003108,0.001459
2023-09-20,-0.027796,-0.018342,-0.017214,-0.017000,-0.016930,-0.015914,-0.014279,-0.011863,-0.011823,-0.006527
2023-11-01,0.017088,0.018714,0.019409,0.019517,0.024507,0.024855,0.025770,0.027718,0.028985,0.031288


In [78]:
# Ensure the DataFrame is sorted by date
decile_returns_pivot = decile_returns_pivot.sort_index()

# Calculate cumulative returns
cumulative_returns = (1 + decile_returns_pivot).cumprod() - 1

cumulative_returns

decile,0,1,2,3,4,5,6,7,8,9
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
2000-02-02,-0.009902,-0.000828,0.000359,0.001109,0.002784,0.005600,0.007064,0.017985,0.019954,0.039761
2000-03-21,-0.045007,-0.020921,-0.012367,-0.007941,-0.005988,0.006957,0.010189,0.027202,0.029816,0.065312
2000-05-16,-0.060723,-0.034283,-0.024725,-0.019898,-0.015673,-0.002491,0.001069,0.019781,0.024891,0.073146
2000-06-28,-0.081922,-0.041020,-0.030217,-0.025032,-0.018469,-0.002837,0.002774,0.026953,0.035862,0.090777
2000-08-22,-0.098516,-0.051665,-0.038596,-0.033165,-0.019598,-0.002993,0.008684,0.033040,0.046587,0.123978
...,...,...,...,...,...,...,...,...,...,...
2023-06-14,-0.891485,-0.684532,-0.526541,-0.377781,-0.199041,0.066152,0.442987,1.068769,2.244633,8.003069
2023-07-26,-0.893712,-0.688795,-0.531382,-0.382704,-0.205207,0.058808,0.433604,1.060880,2.234548,8.016207
2023-09-20,-0.896666,-0.694504,-0.539449,-0.393199,-0.218663,0.041958,0.413134,1.036431,2.196308,7.957363
2023-11-01,-0.894900,-0.688787,-0.530510,-0.381355,-0.199515,0.067857,0.449550,1.092877,2.288951,8.237616


In [79]:
decile_returns.groupby('decile')['next_ret'].describe()

Unnamed: 0_level_0,count,mean,std,min,25%,50%,75%,max
decile,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
0,159.0,-0.013974,0.018519,-0.093792,-0.022605,-0.009902,-0.002083,0.017088
1,159.0,-0.007141,0.015572,-0.079769,-0.01409,-0.004633,0.00334,0.030792
2,159.0,-0.004544,0.014976,-0.074885,-0.011204,-0.00272,0.005145,0.032283
3,159.0,-0.002813,0.014526,-0.07007,-0.00893,-0.001735,0.006825,0.032564
4,159.0,-0.001185,0.014541,-0.06757,-0.007346,-0.000426,0.007743,0.03383
5,159.0,0.000632,0.014468,-0.065422,-0.005447,0.001654,0.008427,0.040327
6,159.0,0.002574,0.014419,-0.056904,-0.004237,0.003357,0.01062,0.043914
7,159.0,0.004914,0.014335,-0.050761,-0.003297,0.00523,0.013372,0.046932
8,159.0,0.007816,0.015418,-0.042239,0.000287,0.007285,0.015565,0.066265
9,159.0,0.014467,0.019382,-0.041591,0.003901,0.012098,0.022311,0.117393


### Construct Labels for Train Data

In [80]:
# get the gics sector with the highest and lowest return over the next day

def get_high_low(df):
    df = df.dropna(subset=['next_ret'])
    
    # find the group with the highest and lowest return for each date
    highest_return_gics = df.loc[df['next_ret'].idxmax(), gics_level]
    highest_return = df['next_ret'].max()
    
    lowest_return_gics = df.loc[df['next_ret'].idxmin(), gics_level]
    lowest_return = df['next_ret'].min()


    return pd.Series({
        f'highest_return_{gics_level}': highest_return_gics,
        'highest_return': highest_return,
        f'lowest_return_{gics_level}': lowest_return_gics,
        'lowest_return': lowest_return,
    })

labels = gics_returns.groupby('date').apply(get_high_low, include_groups=False).reset_index()

In [81]:
labels

Unnamed: 0,date,highest_return_sector,highest_return,lowest_return_sector,lowest_return
0,2000-01-03,Utilities,0.006866,Information Technology,-0.058937
1,2000-01-04,Real Estate,0.043004,Information Technology,-0.023925
2,2000-01-05,Energy,0.039145,Information Technology,-0.027577
3,2000-01-06,Health Care,0.055065,Real Estate,-0.010203
4,2000-01-07,Information Technology,0.042415,Financials,-0.019246
...,...,...,...,...,...
6031,2023-12-21,Consumer Staples,0.006339,Consumer Discretionary,-0.004706
6032,2023-12-22,Energy,0.011957,Consumer Discretionary,0.001771
6033,2023-12-26,Real Estate,0.004458,Energy,-0.005782
6034,2023-12-27,Utilities,0.007503,Energy,-0.014100


In [82]:
# format the correct label in terms of how the API wants it
labels['strategy'] = labels.apply(
    lambda row: f"long: {row[f'highest_return_{gics_level}']}, short: {row[f'lowest_return_{gics_level}']}", axis=1
)

labels


Unnamed: 0,date,highest_return_sector,highest_return,lowest_return_sector,lowest_return,strategy
0,2000-01-03,Utilities,0.006866,Information Technology,-0.058937,"long: Utilities, short: Information Technology"
1,2000-01-04,Real Estate,0.043004,Information Technology,-0.023925,"long: Real Estate, short: Information Technology"
2,2000-01-05,Energy,0.039145,Information Technology,-0.027577,"long: Energy, short: Information Technology"
3,2000-01-06,Health Care,0.055065,Real Estate,-0.010203,"long: Health Care, short: Real Estate"
4,2000-01-07,Information Technology,0.042415,Financials,-0.019246,"long: Information Technology, short: Financials"
...,...,...,...,...,...,...
6031,2023-12-21,Consumer Staples,0.006339,Consumer Discretionary,-0.004706,"long: Consumer Staples, short: Consumer Discre..."
6032,2023-12-22,Energy,0.011957,Consumer Discretionary,0.001771,"long: Energy, short: Consumer Discretionary"
6033,2023-12-26,Real Estate,0.004458,Energy,-0.005782,"long: Real Estate, short: Energy"
6034,2023-12-27,Utilities,0.007503,Energy,-0.014100,"long: Utilities, short: Energy"


In [83]:
display(labels['highest_return_sector'].value_counts())
display(labels['lowest_return_sector'].value_counts())

highest_return_sector
Energy                    1368
Information Technology     777
Utilities                  768
Real Estate                718
Consumer Staples           440
Financials                 402
Health Care                396
Communication Services     378
Materials                  340
Consumer Discretionary     324
Industrials                125
Name: count, dtype: int64

lowest_return_sector
Energy                    1331
Information Technology     782
Utilities                  778
Real Estate                741
Consumer Staples           426
Financials                 417
Communication Services     399
Health Care                359
Consumer Discretionary     336
Materials                  324
Industrials                143
Name: count, dtype: int64

In [84]:
# create the labeled dataset (input: FOMC and date of FOMC, output: desired output from GPT)
labeled_fomcs = pd.merge(fomc, labels[['date', 'strategy']], on='date', how='left')

In [85]:
labeled_fomcs.isna().sum()

date         0
statement    0
strategy     0
dtype: int64

In [86]:
labeled_fomcs

Unnamed: 0,date,statement,strategy
0,2000-02-02,immediate release federal open market committe...,"long: Information Technology, short: Real Estate"
1,2000-03-21,immediate release federal open market committe...,"long: Information Technology, short: Real Estate"
2,2000-05-16,immediate release federal open market committe...,"long: Energy, short: Utilities"
3,2000-06-28,immediate release federal open market committe...,"long: Health Care, short: Information Technology"
4,2000-08-22,immediate release federal open market committe...,"long: Energy, short: Real Estate"
...,...,...,...
154,2023-06-14,recent indicators suggest economic activity co...,"long: Health Care, short: Real Estate"
155,2023-07-26,recent indicators suggest economic activity ex...,"long: Information Technology, short: Real Estate"
156,2023-09-20,recent indicators suggest economic activity ex...,"long: Communication Services, short: Real Estate"
157,2023-11-01,recent indicators suggest economic activity ex...,"long: Real Estate, short: Health Care"


In [87]:
# properly format above data for training gpt model

# create list to hold each formatted conversation
chat_data = []

for index, row in labeled_fomcs.iterrows():
    date = row['date']
    statement = row['statement']
    if pd.notna(row['strategy']):
        chat_data.append({
            "messages": [{
            # system message to describe the chatbot
            "role": "system",
            "content": f"""As of {date.strftime('%Y-%m-%d')}, you are a financial analyst specializing in 
            interpreting FOMC statements to predict GICS sector returns in the stock market."""
        },
        {   
            # system to describe what we are asking the
            "role": "user",
            "content": f"""Based on the FOMC statement released on {date.strftime('%Y-%m-%d')}, please identify:

        - The sector that will have the highest returns over the next day.
        - The sector that will have the lowest returns over the next day.

        Provide your answer in the following format:

        'long: sector, short: sector'

        Recall the list of sector to choose from are:
        'Energy', 'Materials', 'Industrials', 'Consumer Discretionary', 'Consumer Staples', 'Health Care',
        'Financials', 'Information Technology', 'Communication Services', 'Utilities', 'Real Estate' 

        Here is the FOMC Statement:
        \"\"\"
        {statement}
        \"\"\"
        """
        },
        {
            "role": "assistant", 
            "content": row['strategy']
        }
        ]})

#  path to save the JSONL file
output_file = "data/fine_tuning_chat_data.jsonl"

# write data to a JSONL file
with open(output_file, 'w') as f:
    for entry in chat_data:
        f.write(json.dumps(entry) + "\n")

print(f"Data successfully saved to {output_file}")


Data successfully saved to data/fine_tuning_chat_data.jsonl


In [88]:
labeled_fomcs.dtypes

date         datetime64[ns]
statement            object
strategy             object
dtype: object

In [89]:
# split dataset into train, val, and test sets
train_df = labeled_fomcs[labeled_fomcs['date'].between('2000-01-01', '2015-12-31')]
val_df = labeled_fomcs[labeled_fomcs['date'].between('2016-01-01', '2018-12-31')]
test_df = labeled_fomcs[labeled_fomcs['date'] >= '2019-01-01']

# format the data into the required structure
def format_chat_data(df):
    # create list to hold each formatted conversation
    chat_data = []

    for index, row in labeled_fomcs.iterrows():
        date = row['date']
        statement = row['statement']
        if pd.notna(row['strategy']):
            chat_data.append({
                "messages": [{
                # system message to describe the chatbot
                "role": "system",
                "content": f"""As of {date.strftime('%Y-%m-%d')}, you are a financial analyst specializing in 
                interpreting FOMC statements to predict GICS sector returns in the stock market."""
            },
            {   
                # system to describe what we are asking the
                "role": "user",
                "content": f"""Based on the FOMC statement released on {date.strftime('%Y-%m-%d')}, please identify:

            - The sector that will have the highest returns over the next day.
            - The sector that will have the lowest returns over the next day.

            Provide your answer in the following format:

            'long: sector, short: sector'


            Recall the list of sector to choose from are:
            'Energy', 'Materials', 'Industrials', 'Consumer Discretionary', 'Consumer Staples', 'Health Care',
            'Financials', 'Information Technology', 'Communication Services', 'Utilities', 'Real Estate' 

            Here is the FOMC Statement:
            \"\"\"
            {statement}
            \"\"\"
            """
            },
            {
                "role": "assistant", 
                "content": row['strategy']
            }
            ]})


    return chat_data


In [90]:
# format each subset of the data
train_data = format_chat_data(train_df)
val_data = format_chat_data(val_df)
test_data = format_chat_data(test_df)

# save data to JSONL
def save_to_jsonl(data, filename):
    with open(filename, 'w') as f:
        for entry in data:
            f.write(json.dumps(entry) + "\n")
    print(f"Data successfully saved to {filename}")

save_to_jsonl(train_data, "data/train_data.jsonl")
save_to_jsonl(val_data, "data/validation_data.jsonl")
save_to_jsonl(test_data, "data/test_data.jsonl")

Data successfully saved to data/train_data.jsonl
Data successfully saved to data/validation_data.jsonl
Data successfully saved to data/test_data.jsonl


### Ensure Data in Proper Format  
[Data Formatting](https://cookbook.openai.com/examples/chat_finetuning_data_prep)

In [91]:
# Format error checks
format_errors = defaultdict(int)

for ex in train_data:
    if not isinstance(ex, dict):
        format_errors["data_type"] += 1
        continue
        
    messages = ex.get("messages", None)
    if not messages:
        format_errors["missing_messages_list"] += 1
        continue
        
    for message in messages:
        if "role" not in message or "content" not in message:
            format_errors["message_missing_key"] += 1
        
        if any(k not in ("role", "content", "name", "function_call", "weight") for k in message):
            format_errors["message_unrecognized_key"] += 1
        
        if message.get("role", None) not in ("system", "user", "assistant", "function"):
            format_errors["unrecognized_role"] += 1
            
        content = message.get("content", None)
        function_call = message.get("function_call", None)
        
        if (not content and not function_call) or not isinstance(content, str):
            format_errors["missing_content"] += 1
    
    if not any(message.get("role", None) == "assistant" for message in messages):
        format_errors["example_missing_assistant_message"] += 1

if format_errors:
    print("Found errors:")
    for k, v in format_errors.items():
        print(f"{k}: {v}")
else:
    print("No errors found")

No errors found


In [92]:
encoding = tiktoken.get_encoding("cl100k_base")

# not exact!
# simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3
    return num_tokens

def num_assistant_tokens_from_messages(messages):
    num_tokens = 0
    for message in messages:
        if message["role"] == "assistant":
            num_tokens += len(encoding.encode(message["content"]))
    return num_tokens

def print_distribution(values, name):
    print(f"\n#### Distribution of {name}:")
    print(f"min / max: {min(values)}, {max(values)}")
    print(f"mean / median: {np.mean(values)}, {np.median(values)}")
    print(f"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}")

In [93]:
# Warnings and tokens counts
n_missing_system = 0
n_missing_user = 0
n_messages = []
convo_lens = []
assistant_message_lens = []

for ex in train_data:
    messages = ex["messages"]
    if not any(message["role"] == "system" for message in messages):
        n_missing_system += 1
    if not any(message["role"] == "user" for message in messages):
        n_missing_user += 1
    n_messages.append(len(messages))
    convo_lens.append(num_tokens_from_messages(messages))
    assistant_message_lens.append(num_assistant_tokens_from_messages(messages))
    
print("Num examples missing system message:", n_missing_system)
print("Num examples missing user message:", n_missing_user)
print_distribution(n_messages, "num_messages_per_example")
print_distribution(convo_lens, "num_total_tokens_per_example")
print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
n_too_long = sum(l > 16385 for l in convo_lens)
print(f"\n{n_too_long} examples may be over the 16,385 token limit, they will be truncated during fine-tuning")

Num examples missing system message: 0
Num examples missing user message: 0

#### Distribution of num_messages_per_example:
min / max: 3, 3
mean / median: 3.0, 3.0
p5 / p95: 3.0, 3.0

#### Distribution of num_total_tokens_per_example:
min / max: 247, 787
mean / median: 450.0314465408805, 434.0
p5 / p95: 340.8, 571.2

#### Distribution of num_assistant_tokens_per_example:
min / max: 7, 11
mean / median: 8.377358490566039, 8.0
p5 / p95: 8.0, 9.0

0 examples may be over the 16,385 token limit, they will be truncated during fine-tuning


In [94]:
# Pricing and default n_epochs estimate
MAX_TOKENS_PER_EXAMPLE = 16385

TARGET_EPOCHS = 3
MIN_TARGET_EXAMPLES = 100
MAX_TARGET_EXAMPLES = 25000
MIN_DEFAULT_EPOCHS = 1
MAX_DEFAULT_EPOCHS = 25

n_epochs = TARGET_EPOCHS
n_train_examples = len(train_data)
if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
    n_epochs = min(MAX_DEFAULT_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
    n_epochs = max(MIN_DEFAULT_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)

n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)
print(f"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training")
print(f"By default, you'll train for {n_epochs} epochs on this dataset")
print(f"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens")

Dataset has ~71555 tokens that will be charged for during training
By default, you'll train for 3 epochs on this dataset
By default, you'll be charged for ~214665 tokens


In [95]:
test_df.to_pickle('data/test_df.pkl')

In [96]:
train_df.to_pickle('data/train_df.pkl')