In [69]:
import sys
from importlib import reload
import numpy as np

import polars as pl
from datatools import plotting as dtplot
from datatools import tabular as dttab

sys.path.append("..")

import plotting
import src.util as util
from src import text_process
from src.data_functions import make_example_groups, data_split

reload(util)
reload(plotting)
reload(text_process)
dtplot.set_plotly_template()

In [70]:
examples = util.load_examples()
examples.head(5)

Loaded 187 examples


difficulty,tokens,tags,name,lang,length
str,list[str],list[str],str,str,u32
"""easy""","[""x"", ""="", ""1""]","[""va"", ""opas"", ""nu""]","""shrt_pseudo""","""pseudo""",3
"""easy""","[""["", ""2"", … ""]""]","[""brop"", ""nu"", … ""brcl""]","""smplrr_json""","""json""",9
"""normal""","[""say"", "" "", """"Hello world""""]","[""kwio"", ""ws"", ""st""]","""hllwrld_natural""","""natural""",3
"""normal""","[""puts"", "" "", """"Hello World""""]","[""kwio"", ""ws"", ""st""]","""hllwrld_ruby""","""ruby""",3
"""normal""","[""if"", "" "", … "")""]","[""kwfl"", ""ws"", … ""brcl""]","""smplndnttb_pseudo""","""pseudo""",8


## tokens


In [71]:
token_counts = dttab.value_counts(examples["tokens"].explode())
token_counts

tokens,tokens_count
str,u32
""" """,1558
""" """,461
""")""",371
"""(""",371
""",""",289
…,…
"""""$x + $y = $z\n""""",1
"""""$a, $b""""",1
"""""# questions loaded""""",1
"""!""",1


## which tags can each token have?


In [72]:
token_to_tags = (
    examples.select(pl.col("tokens", "tags").explode())
    .group_by("tokens")
    .agg(pl.col("tags").explode().unique())  # what tags can the token have?
).join(token_counts, on="tokens")

single_tagged = (
    token_to_tags.filter(pl.col("tags").list.len() == 1)
    .sort("tokens_count", descending=True)
    .with_columns(pl.col("tags").list[0])
    .filter(pl.col("tags").is_in(text_process.DET_TAGS).not_())
)  # exclude DET TAGS

multi_tagged = token_to_tags.filter(pl.col("tags").list.len() != 1).sort(
    "tokens_count", descending=True
)
print(f"found {len(single_tagged)} tokens with a single tag")
print(f"found {len(multi_tagged)} tokens with multiple tags")

display(single_tagged.head(10))
# display(multi_tagged.head(10))


found 732 tokens with a single tag
found 62 tokens with multiple tags


tokens,tags,tokens_count
str,str,u32
"""=""","""opas""",247
""".""","""sy""",187
""":""","""sy""",96
"""i""","""va""",45
"""if""","""kwfl""",38
"""return""","""kwfl""",35
"""for""","""kwfl""",31
"""==""","""opcm""",30
"""import""","""kwim""",29
"""in""","""kwop""",27


## which tokens can each tag have?


In [73]:
tag_counts = dttab.value_counts(examples["tags"].explode())

tags_to_tokens = (
    (
        examples.select(pl.col("tokens", "tags").explode())
        .group_by("tags")
        .agg(pl.col("tokens").explode().unique())  # what tags can the token have?
    )
    .join(tag_counts, on="tags")
    .filter(pl.col("tags").is_in(text_process.DET_TAGS).not_())
    .sort(pl.col("tokens").list.len())
)
display(tags_to_tokens.head(10))

tags,tokens,tags_count
str,list[str],u32
"""coml""","[""/**  * Get thse thing  */""]",1
"""coil""","[""# array([2, 5], dtype=int32)"", ""# array([[1, 9, 3], [4, 5, 6]], dtype=int32)"", ""# array([9, 5], dtype=int32)""]",3
"""kwva""","[""const"", ""let"", … ""val""]",32
"""bo""","[""true"", ""false"", … ""False""]",28
"""kwfn""","[""func"", ""function"", … ""end""]",23
"""li""","[""all"", ""xy"", … ""tight""]",9
"""opas""","[""/="", ""+="", … ""*=""]",265
"""kwio""","[""puts"", ""clear"", … ""echo""]",14
"""kwde""","[""namespace"", ""class"", … ""extends""]",13
"""kwim""","[""crate"", ""as"", … ""use""]",61


In [74]:
token_to_tags.filter(pl.col("tokens") == "System")

tokens,tags,tokens_count
str,list[str],u32
"""System""","[""cl"", ""mo""]",3


## Language specific rules!


In [75]:
# in java, "System" is a class

## Split: train/val/test


In [76]:
# Stratify by lang and approx length

import json


examples = make_example_groups(examples, min_group_count=6)

print(dttab.value_counts(examples["group"], as_dict=True))

ratios = [0.6, 0.25, 0.15]
splits = data_split(examples, ratios, seed=None)
result_splits = [len(df) / len(examples) for df in splits]

print("sMAPE=", util.MAPE(result_splits, ratios, symmetric=True))
print("group counts:")
for s in splits:
    print("  ", dttab.value_counts(s["group"], as_dict=True, sort_by="value"))
# display(*splits)

splitnames = ["train", "val", "test"]
split_index = {}
for split, splitname in zip(splits, splitnames):
    split_index.update(dict.fromkeys(split["name"].to_list(), splitname))

split_index
with open("../data/split_index.json", "w") as f:
    json.dump(split_index, f)

{'other': 104, 'long_python_ambiguous': 19, 'long_python_normal': 14, 'short_php_normal': 12, 'medium_python_normal': 12, 'medium_pseudo_normal': 8, 'short_python_normal': 6, 'medium_matlab_normal': 6, 'long_dart_normal': 6}
sMAPE= 0.02740794820565061
group counts:
   {'long_dart_normal': 3, 'long_python_ambiguous': 11, 'long_python_normal': 8, 'medium_matlab_normal': 3, 'medium_pseudo_normal': 4, 'medium_python_normal': 7, 'other': 62, 'short_php_normal': 7, 'short_python_normal': 3}
   {'long_dart_normal': 2, 'long_python_ambiguous': 5, 'long_python_normal': 3, 'medium_matlab_normal': 2, 'medium_pseudo_normal': 2, 'medium_python_normal': 3, 'other': 26, 'short_php_normal': 3, 'short_python_normal': 2}
   {'long_dart_normal': 1, 'long_python_ambiguous': 3, 'long_python_normal': 3, 'medium_matlab_normal': 1, 'medium_pseudo_normal': 2, 'medium_python_normal': 2, 'other': 16, 'short_php_normal': 2, 'short_python_normal': 1}
