In [14]:
import ast
import operator
import os

import datasets
import jsonlines as jsonl
import numpy as np
import rich
import torch
import transformers
import wget


In [10]:
# Download the MBPP dataset once and load it as a list of dicts (one per JSON line).
# NOTE(review): the URL tracks `master`, so the data can change between runs;
# consider pinning a commit hash for reproducibility.
url = "https://raw.githubusercontent.com/google-research/google-research/master/mbpp/mbpp.jsonl"
output_file = "mbpp.jsonl"
if not os.path.exists(output_file):
    # wget.download does not overwrite an existing target; it saves a suffixed copy
    # (e.g. "mbpp (1).jsonl") instead, so unguarded re-runs only re-fetch the data
    # and pile up duplicate files.
    wget.download(url, output_file)
# JSON Lines is UTF-8 by definition; don't rely on the platform's default encoding.
with open(output_file, "r", encoding="utf-8") as f:
    data = list(jsonl.Reader(f))


In [79]:
# Arithmetic operators that may appear between literals inside a test's
# arguments or expected value (e.g. `f(2 * 3)`), mapped to their implementations.
_LITERAL_BINOPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}

# Errors meaning "this node is not a (foldable) literal": ast.literal_eval raises
# ValueError/TypeError on malformed nodes, and folding a BinOp can hit
# incompatible operand types or a division by zero.
_LITERAL_EVAL_ERRORS = (ValueError, TypeError, ZeroDivisionError)


def _eval_literal(node):
    """Evaluate an AST node built from literals and +, -, *, / between them.

    Extends ``ast.literal_eval`` so that expressions such as ``2 * 3`` or the
    nested ``1 + 2 + 3`` are folded to a value.

    Raises:
        ValueError: if the node is not a literal or uses an unsupported operator.
        TypeError, ZeroDivisionError: if the operands don't support the operator.
    """
    if isinstance(node, ast.BinOp):
        apply_op = _LITERAL_BINOPS.get(type(node.op))
        if apply_op is None:
            raise ValueError(f"unsupported operator: {type(node.op).__name__}")
        # Recurse so chained expressions (left operand is itself a BinOp) work.
        return apply_op(_eval_literal(node.left), _eval_literal(node.right))
    return ast.literal_eval(node)


def _skip(code, reason):
    """Log a test line that can't become a sample and why; always returns None."""
    print(code)
    rich.print(reason)
    return None


def _parse_test_case(code):
    """Parse one test line of the form ``assert fn(arg, ...) == expected``.

    Args:
        code: source of a single assert statement from an entry's ``test_list``.

    Returns:
        ``(args_vals, expected)`` when every positional argument and the expected
        value are literals (optionally combined with +, -, *, /). Returns ``None``
        after logging the reason when the line cannot be used.

    Raises:
        AssertionError: if ``code`` is not a single ``assert <call> <op> <value>``.
    """
    parsed = ast.parse(code)
    assert len(parsed.body) == 1, len(parsed.body)
    assert_obj = parsed.body[0]
    assert isinstance(assert_obj, ast.Assert), type(assert_obj)
    test = assert_obj.test
    # Checked explicitly so e.g. `assert f(x)` fails with a clear message instead
    # of an AttributeError on `.left`.
    assert isinstance(test, ast.Compare), type(test)
    assert isinstance(test.left, ast.Call), type(test.left)
    assert len(test.comparators) == 1, len(test.comparators)

    # Only an equality check makes the right-hand side the expected output;
    # `!=`, `in`, `<`, ... would yield a wrong sample.
    if not isinstance(test.ops[0], (ast.Eq, ast.Is)):
        return _skip(code, "[red] Skipping because the comparison is not an equality check")

    expected_node = test.comparators[0]
    if isinstance(expected_node, ast.Call):
        return _skip(code, "[red] Skipping because Call in comparator")
    if isinstance(expected_node, ast.Name):
        return _skip(code, "[red] Skipping because Name in comparator")
    try:
        expected = _eval_literal(expected_node)
    except _LITERAL_EVAL_ERRORS:
        return _skip(code, "[red] Skipping because the comparator wasn't a literal")

    call = test.left
    # Samples only store positional args, so keyword args would be silently lost.
    if call.keywords:
        return _skip(code, "[red]Ignored because the fn was called with keyword args")

    args_vals = []
    for arg in call.args:
        if isinstance(arg, ast.Name):
            return _skip(code, "[red]Ignored because there was a name in the args of the fn")
        if isinstance(arg, ast.Call):
            return _skip(code, "[red]Ignored because there was a call in the args of the fn")
        try:
            args_vals.append(_eval_literal(arg))
        except _LITERAL_EVAL_ERRORS:
            return _skip(code, "[red]Ignored because an arg wasn't a litteral")
    return args_vals, expected


def prep(entries=None):
    """Turn MBPP entries into ``(args, expected_output)`` samples.

    Args:
        entries: iterable of dicts with ``"test_setup_code"``, ``"test_list"`` and
            ``"code"`` keys. Defaults to the module-level ``data`` loaded by the
            download cell.

    Returns:
        ``(samples, lengths)``, two parallel lists. ``samples[i]`` is
        ``(args_vals, expected)`` for one usable test line. ``lengths[i]`` is the
        number of lines in the ``"code"`` field of the entry that test came from.
    """
    if entries is None:
        entries = data
    samples = []
    lengths = []
    for entry in entries:
        if entry["test_setup_code"]:
            rich.print("[red]Skipping because there is setup code")
            continue

        assert len(entry["test_list"]) == 3, len(entry["test_list"])
        # Identical for every test of this entry, so compute it once.
        n_code_lines = len(entry["code"].strip().split("\n"))
        for code in entry["test_list"]:
            sample = _parse_test_case(code)
            if sample is not None:
                samples.append(sample)
                lengths.append(n_code_lines)

    return samples, lengths

# Parse every MBPP test assertion into (args, expected_output) samples.
samples, lengths = prep()
assert len(samples) == len(lengths), (len(samples), len(lengths))

# Count test cases directly rather than assuming 3 per entry: prep() only checks
# that invariant for entries without setup code, which it skips.
n_test_cases = sum(len(entry["test_list"]) for entry in data)
print()
print(f"kept {len(samples)} of {n_test_cases} test cases")

assert remove_datatype((4, 5, 4, 7.7, 1.2), int) == [7.7, 1.2]


assert remove_datatype((7, 8, 9, "SR"), str) == [7, 8, 9]


assert remove_datatype((7, 1.1, 2, 2.2), float) == [7, 2]


assert tuple_size(("A", 1, "B", 2, "C", 3) ) == sys.getsizeof(("A", 1, "B", 2, "C", 3))


assert tuple_size((1, "Raju", 2, "Nikhil", 3, "Deepanshu") ) == sys.getsizeof((1, "Raju", 2, "Nikhil", 3, "Deepanshu"))


assert tuple_size(((1, "Lion"), ( 2, "Tiger"), (3, "Fox"), (4, "Wolf"))  ) == sys.getsizeof(((1, "Lion"), ( 2, "Tiger"), (3, "Fox"), (4, "Wolf")))


assert max_chain_length([Pair(5, 24), Pair(15, 25),Pair(27, 40), Pair(50, 60)], 4) == 3


assert max_chain_length([Pair(1, 2), Pair(3, 4),Pair(5, 6), Pair(7, 8)], 4) == 4


assert max_chain_length([Pair(19, 10), Pair(11, 12),Pair(13, 14), Pair(15, 16), Pair(31, 54)], 5) == 5


assert int(lobb_num(5, 3)) == 35


assert int(lobb_num(3, 2)) == 5


assert int(lobb_num(4, 2)) == 20


2904
2922
