In [1]:
import os
import json
import pandas as pd

from matplotlib import pyplot as plt

In [9]:
def analyze_tokens(df, group_columns):
    """
    Analyze the #tokens for each combination of specified group columns,
    including the count of samples with #tokens < 15000.

    Args:
        df (pd.DataFrame): The input DataFrame containing the data.
        group_columns (list): List of columns to group by (e.g., ['dataset', 'graph_format']).

    Returns:
        pd.DataFrame: A DataFrame containing aggregated statistics for #tokens.
    """
    # Group by the specified columns and calculate stats for #tokens
    result = df.groupby(group_columns)['#tokens'].agg(
        count='count',
        min_tokens='min',
        max_tokens='max',
        mean_tokens='mean',
        median_tokens=lambda x: x.median(),
        std_dev_tokens='std',
        less_than_15000=lambda x: (x < 15000).sum()  # Count of #tokens < 15000
    ).reset_index()

    # Sort the result for better readability
    result = result.sort_values(by=group_columns)
    
    return result


In [14]:
DATASETS = ['bace',
 'chebi20',
 'connectivity',
 'cycle_checking',
 'degree_counting',
 'edge_attribute_retrieval',
 'edge_counting',
 'edge_existence',
 'esol',
 'explagraphs',
 'fb15k237',
 'fingerprint',
 'graph_automorphic',
 'graph_structure_detection',
 'hamilton_path',
 'movielens1m',
 'node_attribute_retrieval',
 'node_counting',
 'oag_scholar_interest',
 'ogbl_vessel',
 'ogbn_arxiv',
 're_europe',
 'shortest_path',
 'stack_elec',
 'twitch',
 'wikics',
 'yelp_review']

In [10]:
instruction_data = pd.read_json('../experiments/training_v1/instruction_dataset.json', orient='records')

In [11]:
stats_by_dataset_format = analyze_tokens(instruction_data, ['dataset','graph_format'])

In [21]:
stats_by_dataset_format[stats_by_dataset_format['graph_format'] == 'graphml'].sort_values(by='mean_tokens', ascending=False)

Unnamed: 0,dataset,graph_format,count,min_tokens,max_tokens,mean_tokens,median_tokens,std_dev_tokens,less_than_15000
105,yelp_review,graphml,1995,3072,68288,25500.580952,24729.0,10828.278299,258
1,bace,graphml,2000,3601,16206,10802.376,10615.0,2252.530241,1930
57,hamilton_path,graphml,500,1188,31206,10395.84,8369.0,7296.788991,365
5,chebi20,graphml,2000,1492,128317,10092.0045,8090.5,8843.180461,1678
61,movielens1m,graphml,2000,5651,12725,9771.949,9724.5,1188.824723,2000
41,fb15k237,graphml,2000,1638,23209,8181.199,7906.0,2491.822983,1985
89,shortest_path,graphml,500,1139,22407,8118.804,7458.0,4683.568645,454
93,stack_elec,graphml,2000,718,22830,7237.868,7370.5,3147.472305,1967
17,degree_counting,graphml,500,1170,15618,7079.024,6665.0,3249.62669,496
69,node_counting,graphml,500,1308,15298,6920.332,6727.0,3265.810398,498


In [20]:
# stats_by_dataset_format.dataset.unique().tolist()
for dataset in DATASETS:
    dataset = 'node_counting'
    instruction_data = pd.read_json(f'../experiments/langgfm_i/{dataset}/train/instruction_dataset.json', orient='records')
    stats_by_dataset_format = analyze_tokens(instruction_data, ['dataset','graph_format'])
    # stats_by_dataset_format[stats_by_dataset_format['graph_format'] == 'graphml'].sort_values(by='mean_tokens', ascending=False)
    print(stats_by_dataset_format)
    break

         dataset graph_format  count  min_tokens  max_tokens  mean_tokens  \
0  node_counting        table    800         526        3741    1833.5475   

   median_tokens  std_dev_tokens  less_than_15000  
0         1762.0      721.068134              800  


In [18]:
stats_by_dataset_format

Unnamed: 0,dataset,graph_format,count,min_tokens,max_tokens,mean_tokens,median_tokens,std_dev_tokens,less_than_15000
0,node_counting,table,800,526,3741,1833.5475,1762.0,721.068134,800
