In [1]:
import os
import json

import numpy as np
import pandas as pd

import matplotlib as mp
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors

import utils as ut

In [2]:
# Run from the repository root so the relative paths used below resolve.
# Set the DPC_REPO_DIR environment variable to use a checkout in another location;
# without it, the original default location is used.
REPO_DIR = os.path.expanduser(
    os.environ.get('DPC_REPO_DIR', '~/github/dpi/Data-Provenance-Collection')
)
os.chdir(REPO_DIR)

# Prepare data

## List of datasets

In [3]:
# Collection-level metadata, indexed by collection name, restricted to
# fine-tuning collections; the 'modality' and 'cites' columns are not needed.
dat = (
    pd.read_csv('notebooks/papers.csv')
    .rename(columns={'collection': 'Collection'})
    .set_index('Collection')
    .loc[lambda frame: frame['modality'] == 'finetune']
    .drop(columns=['modality', 'cites'])
)

In [4]:
# Every '|'-separated summary key listed in papers.csv must have a matching file in
# data_summaries/ (compared without the extension; template files are ignored).
files = [key for keys in dat['summary_keys'].str.split('|') for key in keys]

available_summary_keys = {
    os.path.splitext(name)[0]
    for name in os.listdir('data_summaries')
    if not name.startswith('_template')
}
missing_summary_keys = set(files) - available_summary_keys
# Report which keys are missing instead of failing with a bare AssertionError.
assert not missing_summary_keys, \
    f'summary keys without a data_summaries file: {sorted(missing_summary_keys)}'

In [5]:
# One row per (collection, summary key) pair, indexed by summary key, so that
# per-summary-file results can be attached and grouped back by collection.
short_names = pd.DataFrame(
    [
        (collection, key)
        for collection, keys in dat['summary_keys'].str.split('|').items()
        for key in keys
    ],
    columns=['short_name', 'summary_key'],
).set_index('summary_key')

## Dimension data

In [6]:
# Invert the {group: [domain, ...]} JSON mapping into a {domain: group} lookup.
with open('constants/domain_groups.json', 'rt') as f:
    grouped_domains = json.load(f)

domain_groups = {}
for group, members in grouped_domains.items():
    for member in members:
        domain_groups[member] = group

In [7]:
# Invert the {group: [task category, ...]} JSON mapping into a {task category: group} lookup.
with open('constants/task_groups.json', 'rt') as f:
    grouped_tasks = json.load(f)

task_groups = {}
for group, members in grouped_tasks.items():
    for member in members:
        task_groups[member] = group

In [8]:
# Map each license name to the first class listed for it in the JSON file;
# 'Custom' licenses are treated as 'Unspecified'.
with open('constants/license_classes.json', 'rt') as f:
    license_classes = {name: classes[0] for name, classes in json.load(f).items()}
license_classes['Custom'] = 'Unspecified'

## Load summary files

In [9]:
# Load every dataset summary JSON in data_summaries/, keyed by file name without
# its extension. Template files, the 'audio'/'video' entries and any other
# sub-directory are skipped (opening a directory would raise).
summaries = {}
for file in os.listdir('data_summaries'):
    path = os.path.join('data_summaries', file)
    if file.startswith('_template') or file in ('audio', 'video') or os.path.isdir(path):
        continue

    # JSON is UTF-8; don't depend on the platform's default encoding.
    with open(path, 'rt', encoding='utf-8') as f:
        # Use os.path.splitext, matching the summary-key check, so names with
        # extra dots map to the same key there and here.
        summaries[os.path.splitext(file)[0]] = json.load(f)

# Make tables

## Licenses

In [10]:
# Collect, for every summary file, the distinct license names declared by its datasets.
license_sets = {}
for key, datasets in summaries.items():
    for ds_info in datasets.values():
        for lic in ds_info['Licenses']:
            license_sets.setdefault(key, set()).add(lic['License'])
licenses = pd.Series({key: list(names) for key, names in license_sets.items()})

# Roll per-file licenses up to collections (a collection may span several summary files).
licenses_by_key = short_names.copy()
licenses_by_key['licenses'] = licenses
dat['License'] = licenses_by_key.groupby('short_name')['licenses'].apply(
    lambda per_file: list({name for names in per_file for name in names})
)

# 'OpenAI' / 'OANC' entries get their own check-mark column and are removed from the license list.
# Raw strings: '\e' is an invalid escape sequence in a normal string literal.
dat['OAI'] = dat['License'] \
    .apply(lambda names: 'OpenAI' in names or 'OANC' in names) \
    .map({True: r'\greencheck', False: r'\emojiblank'})
dat['License'] = dat['License'].apply(lambda names: [n for n in names if n not in ('OpenAI', 'OANC')])

In [11]:
# Map each collection's licenses to their distinct license classes (per
# constants/license_classes.json) and render them with ut.color_license_classes.
# (The original computed the class list twice on identical lines.)
dat['Use'] = (
    dat['License']
    .apply(lambda names: list({license_classes[name] for name in names}))
    .apply(ut.color_license_classes)
)
# The raw license list is not part of the table; reassign instead of mutating in place.
dat = dat.drop(columns=['License'])

## Property counts and text lengths

In [12]:
def _list_or_none(value):
    """Return ``value`` unchanged when it is a list or tuple, otherwise None."""
    return value if isinstance(value, (list, tuple)) else None


# One record per dataset entry of every summary file that belongs to a collection.
raw = []
for collection in dat.index:
    for file in dat.loc[collection, 'summary_keys'].split('|'):
        for k, entry in summaries[file].items():
            # Text metrics may be absent, '' or {}; all three mean "unknown".
            metrics = entry.get('Text Metrics', None)
            if metrics in (None, '', {}):
                num_dialogs = mean_inputs_length = mean_targets_length = np.nan
            else:
                num_dialogs = metrics['Num Dialogs']
                mean_inputs_length = metrics['Mean Inputs Length']
                mean_targets_length = metrics['Mean Targets Length']

            # Text sources and task categories are mapped to their coarse groups;
            # a missing or non-list field yields None.
            domains = _list_or_none(entry.get('Text Sources'))
            if domains is not None:
                domains = [domain_groups[source] for source in domains]

            tasks = _list_or_none(entry.get('Task Categories'))
            if tasks is not None:
                tasks = [task_groups[category] for category in tasks]

            inf_metadata = entry.get('Inferred Metadata', None)
            if inf_metadata in (None, '', {}):
                topics = None
            else:
                topics = _list_or_none(inf_metadata.get('Text Topics'))

            raw.append({
                'collection': collection,
                'summary_key': file,
                'sub': k,

                'num_dialogs': num_dialogs,
                'mean_inputs_length': mean_inputs_length,
                'mean_targets_length': mean_targets_length,

                'langs': entry.get('Languages'),
                'topics': topics,
                'domains': domains,
                'tasks': tasks,
                'datasets': 1,
            })
raw = pd.DataFrame(raw)

by_collection = raw.groupby('collection')

# Per-dataset total length = dialog count x mean length (NaN if either is unknown).
total_input_length = raw['num_dialogs'] * raw['mean_inputs_length']
total_targets_length = raw['num_dialogs'] * raw['mean_targets_length']


def _dialog_weighted_mean(totals):
    """Per-collection mean length, weighting each dataset by its dialog count.

    Only datasets whose total is known contribute dialogs to the denominator, so a
    dataset that has a dialog count but no length metric does not drag the average
    toward zero. Collections with no known lengths yield NaN.
    """
    weights = raw['num_dialogs'].where(totals.notna())
    return totals.groupby(raw['collection']).sum() / weights.groupby(raw['collection']).sum()


def _per_collection_count(series, minimum=None):
    """Align a per-collection series to ``dat``'s rows as ints.

    Collections without any records (or with NaN) become 0 rather than NaN, and
    ``minimum`` (if given) is applied as a lower bound.
    """
    return series.reindex(dat.index).fillna(0).astype(int).clip(lower=minimum)


num_dialogs = by_collection['num_dialogs'].sum()
mean_inputs_length = _dialog_weighted_mean(total_input_length)
mean_targets_length = _dialog_weighted_mean(total_targets_length)

num_langs = by_collection['langs'].apply(ut.count_unique_with_none)
num_topics = by_collection['topics'].apply(ut.count_unique_with_none)
num_domains = by_collection['domains'].apply(ut.count_unique_with_none)
num_tasks = by_collection['tasks'].apply(ut.count_unique_with_none)
num_datasets = by_collection['datasets'].sum()

dat['Num Langs'] = _per_collection_count(num_langs)
dat['Num Dialogs'] = _per_collection_count(num_dialogs)
dat['Mean Inputs Length'] = _per_collection_count(mean_inputs_length)
dat['Mean Targets Length'] = _per_collection_count(mean_targets_length)
dat['Num Topics'] = _per_collection_count(num_topics)
dat['Num Datasets'] = _per_collection_count(num_datasets)
dat['Num Tasks'] = _per_collection_count(num_tasks)
# Every collection is reported with at least one domain.
dat['Num Domains'] = _per_collection_count(num_domains, minimum=1)

## Source

In [13]:
# For each summary file, record which "model generated" states occur among its
# datasets; a dataset counts as model-generated if it lists at least one non-empty
# 'Model Generated' entry. Files without datasets get no entry.
mgen = {}
for k, datasets in summaries.items():
    states = {
        any(model != '' for model in info['Model Generated'])
        for info in datasets.values()
    }
    if states:
        mgen[k] = list(states)
mgen = pd.Series(mgen)

mgen_by_key = short_names.copy()
mgen_by_key['mgen'] = mgen

# Fraction of the per-file states that are False, per collection:
# 1 -> no model-generated data, 0 -> only model-generated data, in between -> both.
sources = (
    mgen_by_key.groupby('short_name')['mgen']
    .agg(lambda per_file: [state for states in per_file for state in states])
    .apply(lambda states: 1 - sum(states) / len(states))
)

source_labels = [
    ((sources > 0) & (sources < 1), r'\emoji{globe-with-meridians}\emoji{robot}'),
    (sources == 1, r'\emoji{globe-with-meridians}\emojiblank'),
    (sources == 0, r'\emojiblank\emoji{robot}'),
]
for mask, label in source_labels:
    dat.loc[mask, 'Source'] = label

## Format

In [14]:
# Every format tag used in any summary must have a short column name in ut.FORMATS_MAP.
# Use .get('Format', []) here too, consistent with the per-collection loop below.
found_formats = {
    fmt
    for file in summaries
    for entry in summaries[file].values()
    for fmt in entry.get('Format', [])
}
unmapped_formats = found_formats - set(ut.FORMATS_MAP.keys())
assert not unmapped_formats, f'formats missing from FORMATS_MAP: {sorted(unmapped_formats)}'

# Short format names present in each collection (union over its summary files' datasets).
fmts = []
for collection in dat.index:
    for file in dat.loc[collection, 'summary_keys'].split('|'):
        for k, entry in summaries[file].items():
            fmts.append({
                'collection': collection,
                'summary_key': file,
                'sub': k,

                'formats': [ut.FORMATS_MAP[f] for f in entry.get('Format', [])],
            })
fmts = pd.DataFrame(fmts)
fmts = fmts.groupby('collection')['formats'] \
           .apply(lambda s: list({y for x in s for y in x})) \
           .rename('formats')

# One check-mark column per short format name; map booleans straight to the LaTeX
# symbols instead of replacing True/False values in a bool Series.
for fmt in ut.FORMATS_MAP.values():
    dat[fmt] = fmts.apply(
        lambda found, fmt=fmt: r'\greencheck' if fmt in found else r'\emojiblank'
    )

## Format for LaTeX output

In [15]:
# Columns shown in the LaTeX table, in display order.
# Also computed but currently omitted: 'Num Dialogs', 'Num Topics',
# 'Mean Inputs Length', 'Mean Targets Length'.
table_columns = (
    ['Num Datasets', 'Num Tasks', 'Num Langs', 'Num Domains']  # property counts
    + ['Source']
    + ['ZS', 'FS', 'CT', 'RR', 'MD']  # format check-mark columns
    + ['Use', 'OAI']
)
dat = dat[table_columns]

In [16]:
# Internal column name -> (header group, short header label) for the LaTeX table.
# Entries for columns not selected for the table are simply unused.
_count_labels = ['Datasets', 'Dialogs', 'Tasks', 'Langs', 'Topics', 'Cites', 'Domains']
column_mapping = {'Num ' + label: ('Property Counts', label) for label in _count_labels}
column_mapping['Mean Inputs Length'] = ('Text Lens', 'Inpt')
column_mapping['Mean Targets Length'] = ('Text Lens', 'Tgt')
column_mapping['Source'] = ('Dataset Types', 'Source')
column_mapping.update(
    CT=('Dataset Types', 'C'),
    ZS=('Dataset Types', 'Z'),
    RR=('Dataset Types', 'R'),
    MD=('Dataset Types', 'M'),
    FS=('Dataset Types', 'F'),
    Use=('Dataset Types', 'Use'),
    OAI=('Dataset Types', 'O'),
)

# Two-level LaTeX headers from column_mapping: the group in small caps, and the
# label in small caps wrapped in \thead.
header_pairs = [column_mapping[col] for col in dat.columns]
dat.columns = pd.MultiIndex.from_tuples([
    (r'\textsc{' + group + r'}', r'\textsc{\thead{' + label + r'}}')
    for group, label in header_pairs
])

dat.index.name = r'\textsc{' + dat.index.name + r'}'

def color_map(value, cmap='BrBG', vmin=None, vmax=None):
    """Return the colour of ``value`` in a matplotlib colormap as an "R,G,B" string.

    The value is linearly normalised with ``matplotlib.colors.Normalize(vmin, vmax)``
    before the colormap lookup; the alpha channel is dropped and each channel is
    scaled to 0-255 and truncated to an int (the format xcolor's RGB model takes).
    """
    rgba = mp.colormaps[cmap](mcolors.Normalize(vmin=vmin, vmax=vmax)(value))
    return ','.join(str(int(255 * channel)) for channel in rgba[:3])

color_def = ''
formatters = {}
tmp_val_color = {}
# Numeric columns get log-scaled background colours; all other columns already hold LaTeX.
num_cols = [
    c
    for c in dat.columns
    if 'Property Counts' in c[0] or 'Text Lens' in c[0]
]


def _latex_color_name(row, col):
    """Build an xcolor name from a row label and a (group, label) header tuple.

    Spaces and the textsc/thead LaTeX wrappers are stripped from both parts.
    """
    row_part = row.replace(' ', '').replace(r'\textsc{', '').replace('}', '')
    col_part = '_'.join(col).replace(' ', '') \
        .replace(r'\textsc{\thead{', '').replace('}}', '') \
        .replace(r'\textsc{', '').replace('}', '')
    return f"color{row_part}{col_part}"


def _make_count_formatter(col, below_1k_label=None):
    """Return a Styler formatter for the numeric column ``col``.

    Zero and missing values render as '-'. Values >= 1000 render in thousands
    (e.g. '12k'); smaller values render as the integer, or as ``below_1k_label``
    when given. Other cells are wrapped in a cellcolor command using the colour
    registered for that value in ``tmp_val_color[col]``.
    """
    def formatter(v):
        # Check for missing/zero first: NaN never gets a colour entry, so looking it
        # up would raise KeyError.
        if pd.isnull(v) or v == 0:
            return '-'
        color_name = tmp_val_color[col][v]
        if v >= 1000:
            text = f'{v / 1000:,.0f}k'
        elif below_1k_label is not None:
            text = below_1k_label
        else:
            text = f'{v:,.0f}'
        return f'\\cellcolor{{{color_name}}}{{{text}}}'
    return formatter


for col in num_cols:
    tmp_val_color[col] = {}

    # Log-scale the column and centre it on the midpoint of its log-range, so that
    # vmin == -vmax after the shift.
    vmin = np.log(dat[col].min() + 1e-6)
    vmax = np.log(dat[col].max() + 1e-6)
    midpt = (vmax + vmin) / 2
    vmin, vmax = vmin - midpt, vmax - midpt

    for row in dat.index:
        cell_value = dat.loc[row, col]
        value = np.log(cell_value + 1e-6) - midpt

        if pd.notnull(value):
            color_name = _latex_color_name(row, col)
            # value / 4 pulls colours toward the centre of the colormap
            # (presumably to keep the cell shading light).
            color_def += f"\\definecolor{{{color_name}}}{{RGB}}{{{color_map(value / 4, vmin=vmin, vmax=vmax)}}}\n"
            tmp_val_color[col][cell_value] = color_name

    # Dialog counts below 1000 are shown as '<1k' rather than exactly.
    formatters[col] = _make_count_formatter(col, '<1k' if 'Dialogs' in col[1] else None)

In [17]:
# Styler.to_latex options. column_format has one spec per output column: collection
# name, 4 property counts, Source, 5 format columns (Z, F, C, R, M), Use, and O.
# The caption now lists only the dialog-type columns actually in the table, in table order.
kwargs = {
    'environment': 'longtable',
    
    'label': 'tab:collections-text',
    'column_format': 'l|cccc|c|p{0.35cm}p{0.35cm}p{0.35cm}p{0.35cm}p{0.35cm}cp{0.35cm}',
    'multicol_align': 'c',
    
    'caption': (r'''
    \textbf{Alignment tuning (text) collections and properties}. Collection properties include numbers of datasets, tasks, languages, and text domains. The \textsc{Source} column indicates whether a collection contains human-generated web text (\emoji{globe-with-meridians}), language model outputs (\emoji{robot}) or both (\emoji{globe-with-meridians}\emoji{robot}). Several columns indicate the type of dialogs, with some collections having more than one: zero-shot (Z), few-shot (F), chain-of-thought (C), response ranking (R), and multi-turn dialog (M). Finally, the \textsc{Use} column indicates whether a collection includes data freely usable even for commercial purposes (\protect\CommercialDataCircle), data usable only for noncommercial purposes or academic research (\protect\NCDataCircle) and data whose license status is not specified precisely enough to allow us to determine commercial use permissions (\protect\UnspecifiedDataCircle). Note that each collection may have different datasets with one, two, or all three of these statuses. The O column indicates collections which include OpenAI model generations.
    '''.strip(), r'\textbf{Alignment tuning (text) collections and properties}'),
    
    'hrules': True,
    'convert_css': True,
}

# Sort rows by collection name, turn the index into a regular column, hide the
# Styler's own row index, and apply the numeric-cell formatters.
latex = (
    dat
    .sort_index()
    .reset_index()
    .style
    .hide()
    .format(formatter=formatters)
    .to_latex(**kwargs)
)

# Printed for copy-paste into the paper: tabcolsep setting, colour definitions, then the table.
print('\n'.join([
    r'\setlength{\tabcolsep}{1.9pt}',
    color_def,
    latex,
]))

\setlength{\tabcolsep}{1.9pt}
\definecolor{color10kPromptRankedPropertyCounts_Datasets}{RGB}{240,223,178}
\definecolor{colorAgentInstructPropertyCounts_Datasets}{RGB}{245,241,232}
\definecolor{colorAyaPropertyCounts_Datasets}{RGB}{187,229,223}
\definecolor{colorBactrian-XPropertyCounts_Datasets}{RGB}{245,241,232}
\definecolor{colorCOBRAFramesPropertyCounts_Datasets}{RGB}{240,223,178}
\definecolor{colorCOIGPropertyCounts_Datasets}{RGB}{227,240,239}
\definecolor{colorCapybaraPropertyCounts_Datasets}{RGB}{240,243,243}
\definecolor{colorChatDoctorPropertyCounts_Datasets}{RGB}{245,236,212}
\definecolor{colorChatbotArenaPropertyCounts_Datasets}{RGB}{240,223,178}
\definecolor{colorCidarPropertyCounts_Datasets}{RGB}{240,223,178}
\definecolor{colorCollectiveCognitionPropertyCounts_Datasets}{RGB}{240,223,178}
\definecolor{colorConiferPropertyCounts_Datasets}{RGB}{240,223,178}
\definecolor{colorDeita10KPropertyCounts_Datasets}{RGB}{245,234,202}
\definecolor{colorDialogStudioPropertyCounts_Dataset