# Finding differential test cases

The goal of this notebook is to identify and extract all differential test cases and to store the functions that serve as their oracles. These functions can then be monitored whenever developers make a new commit, so that the tool's users are notified of changes they may want to know about.

## Imports

In [1]:
import ast
import sys

## Import source code and create Abstract Syntax Tree


An abstract syntax tree (AST) represents source code as a tree in which each node is a syntactic construct, e.g. a function definition, function call, if-else statement, while loop or variable assignment. It abstracts away details such as grouping parentheses, because the grouping is already implied by the tree's structure.
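As a quick check of the claim about parentheses, the sketch below uses Python's own `ast` module together with a small helper, `shape` (a name made up for this example), that dumps the tree of a single expression. Redundant parentheses leave the tree unchanged, whereas parentheses that override precedence change its shape.

```python
import ast


def shape(expr):
    """Return the AST dump of a single Python expression (illustrative helper)."""
    return ast.dump(ast.parse(expr, mode="eval"))


# Redundant grouping parentheses create no nodes: both strings yield the same tree.
print(shape("(a + b) * c") == shape("((a + b)) * (c)"))  # True

# Parentheses that override precedence are encoded in the tree's shape instead:
# here the addition becomes the left child of the multiplication.
print(shape("(a + b) * c") == shape("a + b * c"))  # False
print(type(ast.parse("(a + b) * c", mode="eval").body.left).__name__)  # BinOp
```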

<img style="float: right;" src="images/abstract_syntax_tree.png" width="450">

The tree on the right is the AST of the following code (Euclid's algorithm):
```
while b ≠ 0
  if a > b
    a := a − b
  else
    b := b − a
return a
```

<br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br>
<div style="text-align: right"> <i> Source https://en.wikipedia.org/wiki/Abstract_syntax_tree </i> </div>
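The same structure can be reproduced with Python's `ast` module. The sketch below writes the pseudocode above as a Python function (the name `gcd` is chosen only for this illustration) and inspects the nodes that correspond to the figure: a `While` loop containing an if-else, followed by a `Return`. Variable names avoid `tree` so that running this cell does not overwrite the notebook's own tree.

```python
import ast

euclid_src = """
def gcd(a, b):
    while b != 0:
        if a > b:
            a = a - b
        else:
            b = b - a
    return a
"""

euclid_tree = ast.parse(euclid_src)
gcd_def = euclid_tree.body[0]  # the FunctionDef node for gcd

# Top level of the function body: the while loop, then the return statement
print([type(stmt).__name__ for stmt in gcd_def.body])  # ['While', 'Return']

loop = gcd_def.body[0]
branch = loop.body[0]  # the if-else inside the loop
print(type(loop.test.ops[0]).__name__)    # NotEq  (b != 0)
print(type(branch.test.ops[0]).__name__)  # Gt     (a > b)

# The else branch is stored in the If node's orelse list
print(type(branch.orelse[0]).__name__)    # Assign (b = b - a)
```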

In [26]:
# Set tensorflow root folder
tensorflow_root = "A:/BachelorThesis/DLL_Testing_Tool/DL_Libraries/Tensorflow/tensorflow-master/tensorflow/python/" 

# set python source file
source = open(tensorflow_root + "kernel_tests/array_ops/gather_op_test.py")

# generate abstract syntax tree
tree = ast.parse(source.read())
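The cell above parses a single test file. To cover more of the test suite, the same step could be applied to a whole directory. The following is a minimal sketch, assuming test files follow TensorFlow's `*_test.py` naming; the helper `parse_test_files` is hypothetical and not part of the original notebook. Files that cannot be parsed are skipped and reported instead of aborting the run.

```python
import ast
import tempfile
from pathlib import Path


def parse_test_files(root, pattern="*_test.py"):
    """Hypothetical helper: parse every file below `root` whose name matches
    `pattern` and return a dict mapping each Path to its ast.Module."""
    trees = {}
    for path in sorted(Path(root).rglob(pattern)):
        try:
            trees[path] = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
        except (SyntaxError, ValueError) as err:
            # e.g. Python 2-only syntax, null bytes, or a non-UTF-8 file
            # (UnicodeDecodeError is a subclass of ValueError)
            print(f"skipped {path}: {err}")
    return trees


# Self-contained demo on a temporary directory instead of a TensorFlow checkout
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "gather_op_test.py").write_text("x = 1\n", encoding="utf-8")
    Path(tmp, "helper.py").write_text("y = 2\n", encoding="utf-8")
    demo_trees = parse_test_files(tmp)
    print([p.name for p in demo_trees])  # ['gather_op_test.py']
```

In the notebook, this could be called as `parse_test_files(tensorflow_root + "kernel_tests")`, and each resulting tree passed to the visitor defined further below.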

## Identify assert function calls and check if the arguments are differential test oracles

TODO:

* Check arguments of assert calls for differential testing oracles

*Resource https://greentreesnakes.readthedocs.io/en/latest/manipulating.html*
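One possible starting point for this TODO is sketched below under stated assumptions: an assert call is treated as a differential test if both of its first two positional arguments are produced by function calls, after unwrapping calls that merely evaluate a tensor (assumed here to be only `self.evaluate`, as seen in the printed output further below). The helper names and the `WRAPPER_CALLS` set are hypothetical. The heuristic does not follow variables, so a value such as `correct_params_grad` that was computed on an earlier line is not recognised; that would require data-flow analysis.

```python
import ast

# Hypothetical set of calls that only evaluate a value rather than compute
# an independent expected result (assumption, extend as needed).
WRAPPER_CALLS = {"evaluate"}


def _call_name(call):
    """Return the name of the called function for f(...) or obj.f(...), else None."""
    if isinstance(call.func, ast.Name):
        return call.func.id
    if isinstance(call.func, ast.Attribute):
        return call.func.attr
    return None


def _unwrap(arg):
    """Strip single-argument wrapper calls, e.g. self.evaluate(x) -> x."""
    while (isinstance(arg, ast.Call) and len(arg.args) == 1
           and _call_name(arg) in WRAPPER_CALLS):
        arg = arg.args[0]
    return arg


def oracle_candidates(assert_call):
    """Names of functions called directly in the first two positional arguments."""
    names = []
    for arg in assert_call.args[:2]:
        arg = _unwrap(arg)
        if isinstance(arg, ast.Call) and _call_name(arg) is not None:
            names.append(_call_name(arg))
    return names


def is_differential(assert_call):
    # Heuristic: both compared values come from (different) function calls.
    return len(oracle_candidates(assert_call)) == 2


example = ast.parse(
    "self.assertAllClose(np.take(p, i), self.evaluate(array_ops.gather(p, i)))",
    mode="eval").body
print(oracle_candidates(example))  # ['take', 'gather']
print(is_differential(example))    # True
```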

In [28]:
# Define which assert functions to look for
approximation_asserts = ["assertAlmostEqual", "assertAlmostEquals", "assertAllClose", "assertAllLessEqual",
                         "assertAllCloseAccordingToType", "assertArrayNear", "assert_list_pairwise",
                         "assertNear", "assertLess", "assertAllLess", "assertLessEqual",
                         "assertNDArrayNear", "assert_allclose", "assert_array_almost_equal",
                         "assert_almost_equal", "assert_array_less",
                         "isclose", "allclose", "gradcheck", "gradgradcheck"]

# TODO Check for these too?
bool_asserts = ["assertTrue", "assertFalse", "assertIs", "assertIsNot"]

other_asserts = ["assertAllEqual", "assertEquals", "assertEqual", "assertAllGreater",
                 "assertAllGreaterEqual", "assertAllInRange", "assertAllInSet",
                 "assertCountEqual", "assertDTypeEqual", "assertDictEqual",
                 "assertSequenceEqual", "assertShapeEqual",
                 "assertTupleEqual", "assert_array_equal"]
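The lists above are matched against the name of the called function, which a call can spell in two ways: as a bare name (`assert_allclose(...)`) or as an attribute (`self.assertAllClose(...)`, `np.testing.assert_allclose(...)`). The sketch below illustrates both cases with hypothetical helpers `called_name` and `classify` and deliberately shortened example lists, named differently so the lists defined above are not overwritten. It uses `ast.walk` instead of a visitor only to stay short.

```python
import ast

# Shortened example lists (illustrative; distinct names keep the notebook's lists intact)
example_approx = ["assertAllClose", "assertNear", "assert_allclose"]
example_other = ["assertAllEqual", "assertEqual", "assert_array_equal"]


def called_name(call):
    """Return the called function's name, or None for e.g. f()(...) or obj[i](...)."""
    if isinstance(call.func, ast.Name):       # assert_allclose(...)
        return call.func.id
    if isinstance(call.func, ast.Attribute):  # self.assertAllClose(...), np.testing.assert_allclose(...)
        return call.func.attr
    return None


def classify(call, approx=example_approx, other=example_other):
    """Label a Call node as 'approximation', 'other', or None (not a tracked assert)."""
    name = called_name(call)
    if name in approx:
        return "approximation"
    if name in other:
        return "other"
    return None


demo_src = """
self.assertAllClose(a, b)
np.testing.assert_allclose(a, b)
assert_array_equal(a, b)
print(a)
"""
for node in ast.walk(ast.parse(demo_src)):
    if isinstance(node, ast.Call):
        print(called_name(node), classify(node))
```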

### Traversing the tree via NodeVisitor

Since we are only interested in function calls within the library's testing code, we override `visit_Call` to check, for each function call node, (1) whether it is an assert call and (2) whether it uses a differential test oracle, i.e. another function call whose result is used for comparison.
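`ast.NodeVisitor.visit` dispatches each node to a method named `visit_<NodeClass>` if one exists, and otherwise to `generic_visit`, which visits the node's children. Overriding `visit_Call` therefore replaces the default traversal for calls: a call nested inside another call's arguments is reached only if the override itself calls `self.generic_visit(node)`. The small, illustrative `CallCollector` class below (not part of the original notebook) makes this difference visible.

```python
import ast


class CallCollector(ast.NodeVisitor):
    """Illustrative visitor that records the names of attribute calls, e.g. self.f(...)."""

    def __init__(self, recurse=True):
        self.recurse = recurse
        self.names = []

    def visit_Call(self, node):
        if isinstance(node.func, ast.Attribute):
            self.names.append(node.func.attr)
        if self.recurse:
            # Continue into func and args; without this, nested calls are skipped.
            self.generic_visit(node)


call_tree = ast.parse("self.assertAllClose(expected, self.evaluate(actual))")

with_recursion = CallCollector(recurse=True)
with_recursion.visit(call_tree)
print(with_recursion.names)     # ['assertAllClose', 'evaluate']

without_recursion = CallCollector(recurse=False)
without_recursion.visit(call_tree)
print(without_recursion.names)  # ['assertAllClose']
```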

In [39]:
class TreeTraverser(ast.NodeVisitor):
    def visit_Call(self, node):

        # (1) Identify if it is an assert call that we are looking for:

        # for nodes that call a plain function name, e.g. assert_allclose(...)
        # asserts are rarely called this way
        if isinstance(node.func, ast.Name):

            # if the name of the called function contains "assert"
            if 'assert' in node.func.id:
                print("NEW FOUND: " + node.func.id)

        # for nodes that call a function of an object, e.g. self.assertAllClose(...)
        elif isinstance(node.func, ast.Attribute):

            # if the name of the called function is one we are searching for
            if node.func.attr in approximation_asserts:

                # print the function name
                print("______________" + node.func.attr + "______________")

                # print the node and its structure (the indent argument requires Python 3.9+)
                print(ast.dump(node, indent='\t') + '\n')

                # TODO (2) Identify if the assert call uses a differential test oracle

        # visit child nodes: this is necessary, because overriding visit_Call replaces
        # the default traversal, so calls nested inside this call (e.g. in its
        # arguments) would otherwise never be visited
        self.generic_visit(node)

TreeTraverser().visit(tree)

______________assertAllClose______________
Call(
	func=Attribute(
		value=Name(id='self', ctx=Load()),
		attr='assertAllClose',
		ctx=Load()),
	args=[
		Name(id='correct_params_grad', ctx=Load()),
		Call(
			func=Attribute(
				value=Name(id='self', ctx=Load()),
				attr='evaluate',
				ctx=Load()),
			args=[
				Name(id='params_grad', ctx=Load())],
			keywords=[])],
	keywords=[
		keyword(
			arg='atol',
			value=Constant(value=2e-06)),
		keyword(
			arg='rtol',
			value=Constant(value=2e-06))])

______________assertAllClose______________
Call(
	func=Attribute(
		value=Name(id='self', ctx=Load()),
		attr='assertAllClose',
		ctx=Load()),
	args=[
		Name(id='correct_params_grad', ctx=Load()),
		Call(
			func=Attribute(
				value=Name(id='self', ctx=Load()),
				attr='evaluate',
				ctx=Load()),
			args=[
				Name(id='params_grad', ctx=Load())],
			keywords=[])],
	keywords=[
		keyword(
			arg='atol',
			value=Constant(value=2e-06)),
		keyword(
			arg='rtol',
			value=Constant(value=2e-06))