# Finding differential test cases

The goal of this notebook is to identify and extract all differential test cases and to store the functions they use as oracles. Whenever developers make a new commit, these stored functions can then be checked for changes that the tool's user might want to know about.

## Imports

In [1]:
import ast
import astor
import sys

## Import source code and create Abstract Syntax Tree


An abstract syntax tree represents source code as a tree in which each node is an element of the source code, e.g. a function definition, function call, if-else statement, while-loop or variable assignment. It abstracts away purely syntactic details such as grouping parentheses, because the grouping is already implied by the tree's structure.

<img style="float: right;" src="images/abstract_syntax_tree.png" width="450">

The following code produces the tree shown in the image:
```
while b ≠ 0
  if a > b
    a := a − b
  else
    b := b − a
return a
```

<br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br>
<div style="text-align: right"> <i> Source https://en.wikipedia.org/wiki/Abstract_syntax_tree </i> </div>
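
As a small illustration of the two points above, the sketch below parses a Python rendering of the pseudocode with the standard `ast` module and checks that grouping parentheses leave no trace in the resulting tree. The function name `gcd` and the variable names are chosen here for illustration only.

```python
import ast

# Python rendering of the pseudocode above (names chosen for illustration)
gcd_source = """
def gcd(a, b):
    while b != 0:
        if a > b:
            a = a - b
        else:
            b = b - a
    return a
"""

gcd_tree = ast.parse(gcd_source)

# Each element of the source code becomes a node of a specific type
node_types = [type(n).__name__ for n in ast.walk(gcd_tree)]
print(sorted(set(node_types)))

# Grouping parentheses are not stored: both expressions yield the same tree
with_parens = ast.dump(ast.parse("(a - b)"))
without_parens = ast.dump(ast.parse("a - b"))
print(with_parens == without_parens)
```

The node type list contains `While`, `If` and `Return`, which correspond to the statement nodes in the image.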

In [27]:
# Set tensorflow root folder
tensorflow_root = "A:/BachelorThesis/DLL_Testing_Tool/DL_Libraries/Tensorflow/tensorflow-master/tensorflow/python/" 

# set python source file
source_path = tensorflow_root + "kernel_tests/distributions/gamma_test.py"
#source_path = tensorflow_root + "kernel_tests/matrix_solve_op_test.py"

# generate abstract syntax tree (the with-block closes the file after reading)
with open(source_path) as source:
    tree = ast.parse(source.read())

## Identify assert function calls and check if the arguments are differential test oracles

TODO:

* Check arguments of assert calls for differential testing oracles

*Resource ast https://greentreesnakes.readthedocs.io/en/latest/manipulating.html*  
*Resource astor https://astor.readthedocs.io/en/latest/*

In [3]:
# Define which assert functions to look for
approximation_asserts = ["assertAlmostEqual", "assertAlmostEquals", "assertAllClose", "assertAllLessEqual",
                         "assertAllCloseAccordingToType", "assertArrayNear", "assert_list_pairwise",
                         "assertNear", "assertLess", "assertAllLess", "assertLessEqual", 
                         "assertNDArrayNear", "assert_allclose", "assert_array_almost_equal",
                         "assert_almost_equal","assert_array_less", 
                         "isclose", "allclose", "gradcheck", "gradgradcheck"]

# TODO Check for these too?
bool_asserts = ["assertTrue", "assertFalse","assertIs","assertIsNot"]

other_asserts = ["assertAllEqual", "assertEquals", "assertEqual", "assertAllGreater",
                 "assertAllGreaterEqual", "assertAllInRange", "assertAllInSet",
                 "assertCountEqual", "assertDTypeEqual", "assertDictEqual",
                 "assertSequenceEqual", "assertShapeEqual",
                 "assertTupleEqual", "assert_array_equal"]

### Traversing the tree via NodeVisitor

Since we are only interested in function calls within the library's testing code, we use `visit_Call` to check for each function call node whether (1) it is an assert call and (2) it uses a differential test oracle, i.e. another function call for comparison. 
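
Step (1) can be tried out in isolation on a small inline source string. The following is a minimal sketch, not the traverser used in this notebook: the test source, the class name `AssertCallCollector` and the simplified `startswith("assert")` check are assumptions made for illustration.

```python
import ast

# Hypothetical test source, used only for illustration
example_source = '''
class ExampleTest:
    def testSomething(self):
        expected = reference.compute(1.0)
        self.assertAllClose(self.evaluate(result), expected)
        self.assertTrue(flag)
'''


class AssertCallCollector(ast.NodeVisitor):
    """Collect (name, line number) of method-style assert calls."""

    def __init__(self):
        self.found = []

    def visit_Call(self, node):
        # self.assertX(...) is an ast.Call whose func is an ast.Attribute
        if isinstance(node.func, ast.Attribute) and node.func.attr.startswith("assert"):
            self.found.append((node.func.attr, node.lineno))

        # keep descending so that calls nested in the arguments are visited too
        self.generic_visit(node)


collector = AssertCallCollector()
collector.visit(ast.parse(example_source))
print(collector.found)
```

The calls to `reference.compute` and `self.evaluate` are visited as well, but they are not recorded because their attribute name does not start with `assert`.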

TODO Change back to current version.
As an example we will use the `kernel_tests/distributions/gamma_test.py` file of TensorFlow version 1.12.0 (https://github.com/tensorflow/tensorflow/releases/tag/v1.12.0), which tests the gamma distribution functions.
This file contains many differential test cases, two of which are picked out here; the corresponding assert statements are in lines 118 and 219.  

The `gamma_test` file uses the SciPy library for differential testing. Many test cases in this file, for example the function `testGammaLogPDF`, compare some part of TensorFlow's own `tensorflow.python.ops.distributions.gamma` to the corresponding part of `scipy.stats.gamma`, e.g. by comparing both log probability density functions at multiple points.
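
The shape of such a test can be sketched with the standard library alone. The example below is a simplified stand-in, not the TensorFlow test: both `gamma_log_pdf` (playing the role of the implementation under test) and `reference_gamma_log_pdf` (playing the role of the oracle) are hypothetical functions written for this illustration.

```python
import math
import unittest


def gamma_log_pdf(x, alpha, beta):
    """Hypothetical implementation under test: log-density of Gamma(alpha, rate=beta)."""
    return alpha * math.log(beta) + (alpha - 1) * math.log(x) - beta * x - math.lgamma(alpha)


def reference_gamma_log_pdf(x, alpha, beta):
    """Hypothetical oracle: compute the density directly, then take its logarithm."""
    pdf = beta ** alpha * x ** (alpha - 1) * math.exp(-beta * x) / math.gamma(alpha)
    return math.log(pdf)


class GammaTest(unittest.TestCase):
    def testGammaLogPDF(self):
        for x in [0.5, 1.0, 2.5, 4.0]:
            log_pdf = gamma_log_pdf(x, alpha=2.0, beta=3.0)
            # the expected value comes from a second, independent implementation
            expected_log_pdf = reference_gamma_log_pdf(x, alpha=2.0, beta=3.0)
            self.assertAlmostEqual(log_pdf, expected_log_pdf)
```

The assert call compares a value under test with a variable whose definition is a call to another implementation. That call is what the traverser has to trace back and record as the oracle.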

TODO Use self.evaluate to identify which argument is the one to test. The other one is the oracle then.  
TODO Explain how we can detect a differential oracle in the AST (i.e. by finding the stats function used as comparison)  
TODO Figure out how to get the definition of expected_log_pdf, then identify stats as scipy.stats there  
TODO Trail back the oracle to its definition (this might be outside the function call, because the oracle might be an argument in the function definition)

Other examples: `matrix_solve_op_test` line 73 uses `np.linalg.solve`
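
One possible approach to the TODO about identifying `stats` as `scipy.stats` is to collect the names bound by import statements and use them to qualify an extracted oracle name. The sketch below assumes the module is bound by a plain `import` or `from ... import` statement; the helper names `collect_import_aliases` and `qualify` and the example source are hypothetical. A module that is bound through an assignment instead would additionally need the assignment lookup.

```python
import ast

# Hypothetical import header, used only for illustration
example_source = '''
import numpy as np
from scipy import stats
from scipy.special import gammaln as lgamma

expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
'''


def collect_import_aliases(tree):
    """Map each name bound by an import statement to its fully qualified name."""
    aliases = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.asname:
                    aliases[alias.asname] = alias.name
                else:
                    # "import a.b" binds only the top-level name "a"
                    top_level = alias.name.split(".")[0]
                    aliases[top_level] = top_level
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            # absolute "from module import name [as alias]"
            for alias in node.names:
                aliases[alias.asname or alias.name] = node.module + "." + alias.name
    return aliases


def qualify(dotted_name, aliases):
    """Replace the first component of a dotted name by its import target, if known."""
    head, _, rest = dotted_name.partition(".")
    resolved = aliases.get(head, head)
    return resolved + "." + rest if rest else resolved


aliases = collect_import_aliases(ast.parse(example_source))
print(qualify("stats.gamma.logpdf", aliases))
print(qualify("np.linalg.solve", aliases))
```

Relative imports (`node.level > 0`) are skipped here because resolving them requires knowing the package the file belongs to.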

In [43]:
class TreeTraverser(ast.NodeVisitor):
    def __init__(self):
        
        # always stores the last function definition node that was visited
        self.last_function_definition_node = None
        
    def visit_Call(self, node):
        
        # (1) Identify if it is an assert call that we are looking for:
        
        # for nodes that call a function directly (statically), e.g. assert(...)
        # this is usually not the case for asserts
        if isinstance(node.func, ast.Name):
            
             # if the name of the called function contains "assert"
            if 'assert' in node.func.id:
                print("NEW FOUND: " + node.func.id)
            
        
        # for nodes that call a function of an object, e.g. self.assert(...)
        elif isinstance(node.func, ast.Attribute):
            
            # if the name of the called function is one we are searching for
            if node.func.attr in approximation_asserts:
                
                # print the assert function name and line number
                print("______________" + node.func.attr + "__(line " + str(node.lineno) + ")______________")
                
                # print the node and its structure
                print(ast.dump(node, indent='\t') + '\n')
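                # an assert call outside of any function has no enclosing definition
                enclosing_function = self.last_function_definition_node
                print("Found in: " + (enclosing_function.name if enclosing_function is not None else "<module level>") + "\n")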
                
                # (2) Identify if the assert call uses a differential test oracle:
                
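                # Store the first two positional arguments of the assert call;
                # calls with fewer than two positional arguments cannot be compared here
                if len(node.args) < 2:
                    self.generic_visit(node)
                    return
                assert_argument_position1 = node.args[0]
                assert_argument_position2 = node.args[1]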
                
                # return the definitions of the arguments to check if they are differential testing functions
                definition_arg1 = self._getDefinition(assert_argument_position1)
                definition_arg2 = self._getDefinition(assert_argument_position2)
                
                # check which node is the oracle
                if definition_arg1 is not None:
                    # print the ast structure of the definitions
                    print("\n(line " + str(definition_arg1.lineno) + ")\narg1 = " + ast.dump(definition_arg1, indent='\t') + '\n')
                    print("Extracted name arg1: " + self._getDifferentialOracleName(definition_arg1))
                
                if definition_arg2 is not None:
                    # print the ast structure of the definitions
                    print("\n(line " + str(definition_arg2.lineno) + ")\narg2 = " + ast.dump(definition_arg2, indent='\t') + '\n')
                    print("Extracted name arg2: " + self._getDifferentialOracleName(definition_arg2))
                
        
        # visit child nodes (necessary so that calls nested inside this call are checked too)
        self.generic_visit(node)
    
    def _getDefinition(self, node):
        """Check if the argument node is the oracle and get its definition.
        
        Returns None if the node is not an oracle.
        """
            
        # if the argument uses self.evaluate, we can be sure that this is the argument to test
        # therefore the other argument is the oracle and we can stop analyzing this one
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and node.func.attr == 'evaluate':
            print("Argument is an evaluate function call!")
            return None
        
        # if the argument is a named variable we need to trace it back to its definition to see if it uses another
        # function or library for differential testing
        if isinstance(node, ast.Name):
            print(node.id + " is a named variable!")
            
            # store name of the variable to search for
            variable_name = node.id
            
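            # without an enclosing function definition there is nothing to search
            if self.last_function_definition_node is None:
                return None
            
            # iterate through each variable assignment in the function that the assert is called from 
            # and return the first matching assigned value that ast.walk finds
            # (ast.walk is breadth-first, so this is not necessarily the first one in source order)
            for child in ast.walk(self.last_function_definition_node):
                if isinstance(child, ast.Assign):
                    for target in child.targets:
                        
                        # for unpacking assignments, e.g. [a, b] = func(...) or a, b = func(...)
                        if isinstance(target, (ast.List, ast.Tuple)):
                            for element in target.elts:
                                # elements may also be attributes or subscripts, which have no id
                                if isinstance(element, ast.Name) and element.id == variable_name:
                                    return child.value
                        
                        # for regular variable assignments
                        if isinstance(target, ast.Name) and target.id == variable_name:
                            return child.value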

            
        # if the argument is a function call, extract the function name and trace its arguments back to their definitions
        if isinstance(node, ast.Call):
            print("Call extracted name: " + self._getDifferentialOracleName(node) + "\n")
            
    def _getDifferentialOracleName(self, node):
        """Get the dotted name of the function used in the differential test case.
        
        Returns an empty string if the node is not a function call.
        """
        
        # only function calls can name an oracle
        if not isinstance(node, ast.Call):
            return ""
        
        oracle_attributes = []
        
        # move into the function of the call node (ignoring the arguments of the function call)
        node = node.func
        
        # if the function call has multiple attributes, e.g. np.linalg.solve()
        while isinstance(node, ast.Attribute):
            # collect the attribute names from the outermost (solve) to the innermost (linalg)
            oracle_attributes.append(node.attr)
            
            # move another layer deeper into the node
            node = node.value
        
        # the deepest level of the call node is a plain name, e.g. np
        if isinstance(node, ast.Name):
            oracle_attributes.append(node.id)
        
        # the names were collected from right to left, so reverse them before joining
        return ".".join(reversed(oracle_attributes))
                
        
        
    def visit_FunctionDef(self, node):
        self.last_function_definition_node = node
        
        # visit child nodes
        self.generic_visit(node)

TreeTraverser().visit(tree)

______________assertAllClose__(line 77)______________
Call(
	func=Attribute(
		value=Name(id='self', ctx=Load()),
		attr='assertAllClose',
		ctx=Load()),
	args=[
		Call(
			func=Attribute(
				value=Name(id='self', ctx=Load()),
				attr='evaluate',
				ctx=Load()),
			args=[
				Name(id='log_pdf', ctx=Load())],
			keywords=[]),
		Name(id='expected_log_pdf', ctx=Load())],
	keywords=[])

Found in: testGammaLogPDF

Argument is an evaluate function call!
expected_log_pdf is a named variable!

(line 76)
arg2 = Call(
	func=Attribute(
		value=Attribute(
			value=Name(id='stats', ctx=Load()),
			attr='gamma',
			ctx=Load()),
		attr='logpdf',
		ctx=Load()),
	args=[
		Name(id='x', ctx=Load()),
		Name(id='alpha_v', ctx=Load())],
	keywords=[
		keyword(
			arg='scale',
			value=BinOp(
				left=Constant(value=1),
				op=Div(),
				right=Name(id='beta_v', ctx=Load())))])

Extracted name arg2: stats.gamma.logpdf
______________assertAllClose__(line 78)______________
Call(
	func=Attribute(
		value=Name(