# Finding differential test cases

The goal of this notebook is to identify and extract all differential test cases and to store the functions that are used as their oracles. These functions can then be checked for changes whenever developers make a new commit, so that the tool's user is informed about changes they might want to know about.

## Imports

In [2]:
import ast
import sys

## Import source code and create Abstract Syntax Tree


An abstract syntax tree represents source code as a tree in which each node is an element of the source code, e.g. a function definition, function call, if-else statement, while loop or variable assignment. It abstracts away details such as grouping parentheses, because the grouping is already implied by the tree's structure.
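Python's built-in `ast` module makes this concrete. In the small illustration below (our own example, not part of the original analysis), two expressions that differ only in redundant parentheses parse to identical trees, while parentheses that change the grouping produce a differently nested tree:

```python
import ast

# The parentheses in the second expression are redundant, those in the third change the meaning
plain = ast.parse("a + b * c", mode="eval")
grouped = ast.parse("a + (b * c)", mode="eval")
regrouped = ast.parse("(a + b) * c", mode="eval")

# Grouping is encoded by how the BinOp nodes are nested, so no parenthesis node appears
print(ast.dump(plain))
print(ast.dump(plain) == ast.dump(grouped))    # True
print(ast.dump(plain) == ast.dump(regrouped))  # False
```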

<img style="float: right;" src="images/abstract_syntax_tree.png" width="450">

The tree shown on the right is generated from the following code, which implements Euclid's algorithm:
```
while b ≠ 0
  if a > b
    a := a − b
  else
    b := b − a
return a
```

<br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br>
<div style="text-align: right"> <i> Source https://en.wikipedia.org/wiki/Abstract_syntax_tree </i> </div>
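The same algorithm can be written in Python; the translation below is our own and the function name `gcd` is chosen for illustration. The sketch parses it with `ast.parse` and inspects the nodes that correspond to the boxes in the figure. It assumes Python 3.9 or newer for `ast.unparse`, which is the same requirement as the `indent` argument of `ast.dump` used later in this notebook.

```python
import ast

# Hypothetical Python translation of the pseudocode above
gcd_source = """
def gcd(a, b):
    while b != 0:
        if a > b:
            a = a - b
        else:
            b = b - a
    return a
"""

gcd_tree = ast.parse(gcd_source)
gcd_def = gcd_tree.body[0]  # the FunctionDef node for gcd

# The While node's test is the comparison b != 0; its body holds the If/else node
loop = gcd_def.body[0]
print(type(loop).__name__, ast.unparse(loop.test))                  # While b != 0
print(type(loop.body[0]).__name__, ast.unparse(loop.body[0].test))  # If a > b

# All node types that occur anywhere in the function
print(sorted({type(node).__name__ for node in ast.walk(gcd_def)}))
```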

In [42]:
# Set the TensorFlow root folder
tensorflow_root = "A:/BachelorThesis/DLL_Testing_Tool/DL_Libraries/Tensorflow/tensorflow-master/tensorflow/python/"

# Set the Python test file to analyze
#test_file = "kernel_tests/distributions/gamma_test.py"
#test_file = "kernel_tests/matrix_solve_op_test.py"
test_file = "kernel_tests/distributions/student_t_test.py"
#test_file = "kernel_tests/reduction_ops_test.py"

# Read the source file (closing it afterwards) and generate its abstract syntax tree
with open(tensorflow_root + test_file) as source:
    tree = ast.parse(source.read())
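Instead of commenting file names in and out, all test files below a folder could be parsed in one go. The sketch below uses a hypothetical helper `parse_test_files` (our own name, not part of the tool) and assumes that test files follow the `*_test.py` naming seen above:

```python
import ast
from pathlib import Path

def parse_test_files(root, pattern="*_test.py"):
    """Parse every file below root whose name matches pattern; return {relative path: AST}."""
    trees = {}
    root = Path(root)
    for path in sorted(root.rglob(pattern)):
        # read_text opens and closes the file; UTF-8 is assumed for the test sources
        source_text = path.read_text(encoding="utf-8")
        trees[path.relative_to(root).as_posix()] = ast.parse(source_text, filename=str(path))
    return trees
```

For example, `parse_test_files(tensorflow_root + "kernel_tests/distributions")` would return one tree per distribution test file, keyed by its path relative to that folder.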

## Identify assert function calls and check if the arguments are differential test oracles

Here we define which assert functions are relevant for us.


* TODO: Explain why we look for only these specific statements 
* TODO: Check for other asserts too? 



In [4]:
# Define which assert functions to look for
approximation_asserts = ["assertAlmostEqual", "assertAlmostEquals", "assertAllClose", "assertAllLessEqual",
                         "assertAllCloseAccordingToType", "assertArrayNear", "assert_list_pairwise",
                         "assertNear", "assertLess", "assertAllLess", "assertLessEqual",
                         "assertNDArrayNear", "assert_allclose", "assert_array_almost_equal",
                         "assert_almost_equal", "assert_array_less",
                         "isclose", "allclose", "gradcheck", "gradgradcheck"]

# TODO Check for these too?
bool_asserts = ["assertTrue", "assertFalse", "assertIs", "assertIsNot"]

other_asserts = ["assertAllEqual", "assertEquals", "assertEqual", "assertAllGreater",
                 "assertAllGreaterEqual", "assertAllInRange", "assertAllInSet",
                 "assertCountEqual", "assertDTypeEqual", "assertDictEqual",
                 "assertSequenceEqual", "assertShapeEqual",
                 "assertTupleEqual", "assert_array_equal"]

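Such lists can be matched directly against call nodes. The minimal sketch below looks up the name of each called function in abbreviated copies of the lists; the helper `called_name`, the `example_*` lists (renamed so they do not overwrite the full lists in this notebook) and the parsed snippet are our own illustrations:

```python
import ast

# Abbreviated, hypothetical copies of the lists above so that this sketch runs on its own
example_approximation_asserts = ["assertAllClose", "assertAlmostEqual", "assertNear"]
example_bool_asserts = ["assertTrue", "assertFalse", "assertIs", "assertIsNot"]

def called_name(call):
    """Return the called function's name for obj.name(...) or name(...), else None."""
    if isinstance(call.func, ast.Attribute):
        return call.func.attr
    if isinstance(call.func, ast.Name):
        return call.func.id
    return None

snippet = """
self.assertAllClose(expected, actual)
self.assertTrue(flag)
print(actual)
"""

calls = [node for node in ast.walk(ast.parse(snippet)) if isinstance(node, ast.Call)]
approximation_calls = [(called_name(c), c.lineno) for c in calls
                       if called_name(c) in example_approximation_asserts]
ignored_bool_calls = [called_name(c) for c in calls if called_name(c) in example_bool_asserts]
print(approximation_calls)  # [('assertAllClose', 2)]
print(ignored_bool_calls)   # ['assertTrue']
```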
### Create a custom printer/logger for later debugging

In [5]:
LOG_ALL = 0
LOG_FINAL = 1

class CustomLogger():
    def __init__(self, print_mode=LOG_FINAL):
        """Open the log and set the print mode. LOG_FINAL prints only output meant for users; LOG_ALL also prints debugging output."""
        self.print_mode = print_mode
    
    def add(self, string, mode=LOG_ALL):
        """Add text to the log and print it if its mode is at least the chosen print_mode."""
        #TODO Write string to txt file 
        
        # print the string if its mode is at or above the print mode
        if mode >= self.print_mode:
            print(string)
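The `#TODO Write string to txt file` in `add` could be handled by Python's standard `logging` module, which can print to the screen and write to a file with separate thresholds. A minimal sketch, assuming `logging.DEBUG` stands in for `LOG_ALL` and `logging.INFO` for `LOG_FINAL`; the helper `make_logger` and the logger name `differential_tests` are hypothetical:

```python
import io
import logging

def make_logger(print_mode=logging.INFO, stream=None, logfile=None):
    """Print messages at or above print_mode; optionally write every message to logfile."""
    logger = logging.getLogger("differential_tests")
    logger.setLevel(logging.DEBUG)
    logger.propagate = False  # do not pass messages on to the root logger

    # remove handlers from a previous call so messages are not printed twice
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    console = logging.StreamHandler(stream)  # stream=None means sys.stderr
    console.setLevel(print_mode)
    logger.addHandler(console)

    if logfile is not None:
        to_file = logging.FileHandler(logfile, encoding="utf-8")
        to_file.setLevel(logging.DEBUG)  # the file receives debugging output as well
        logger.addHandler(to_file)
    return logger

buffer = io.StringIO()
demo_log = make_logger(logging.INFO, stream=buffer)
demo_log.debug("Argument is a function call!")        # debugging only, filtered out
demo_log.info("Extracted name arg1: stats.t.logpdf")  # shown to the user
print(buffer.getvalue(), end="")
```

Passing `stream=sys.stdout` would make the console output behave like the `print` call in `CustomLogger`, while `logfile="log.txt"` would keep the full debugging trace on disk.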

### Traversing the tree via NodeVisitor

Since we are only interested in function calls within the library's testing code, we use `visit_Call` to check for each function call node whether (1) it is an assert call and (2) it uses a differential test oracle, i.e. another function call used for comparison.
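The mechanism can be seen in isolation with a stripped-down visitor. This is a sketch, not the `TreeTraverser` defined below: `AssertCollector` is a hypothetical name, and it only records in which function each matching assert call sits. `visit_FunctionDef` remembers the enclosing function, and `generic_visit` keeps the traversal going into child nodes:

```python
import ast

class AssertCollector(ast.NodeVisitor):
    """Record (enclosing function, assert name, line) for obj.assertX(...) calls."""

    def __init__(self, wanted):
        self.wanted = set(wanted)
        self.current_function = None
        self.found = []

    def visit_FunctionDef(self, node):
        # remember the enclosing function while its body is visited, then restore the outer one
        outer = self.current_function
        self.current_function = node.name
        self.generic_visit(node)
        self.current_function = outer

    def visit_Call(self, node):
        if isinstance(node.func, ast.Attribute) and node.func.attr in self.wanted:
            self.found.append((self.current_function, node.func.attr, node.lineno))
        # keep visiting so that calls nested inside the arguments are checked too
        self.generic_visit(node)

test_source = """
class ExampleTest:
    def testLogPDF(self):
        self.assertAllClose(expected, self.evaluate(actual))
        self.assertEqual(a, b)
"""

collector = AssertCollector(["assertAllClose"])
collector.visit(ast.parse(test_source))
print(collector.found)  # [('testLogPDF', 'assertAllClose', 4)]
```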

As an example we will use the `kernel_tests/distributions/gamma_test.py` file of TensorFlow version [2.5.0](https://github.com/tensorflow/tensorflow/releases/tag/v2.5.0), which tests the gamma distribution functions.
This file contains many differential test cases; for example, the assert statements in lines 77 and 143 are both differential tests.  

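The `gamma_test` file uses the scipy library for differential testing. Many test cases in this file, for example the function `testGammaLogPDF`, compare part of TensorFlow's own `tensorflow.python.ops.distributions.gamma` module to the corresponding part of `scipy.stats.gamma`, e.g. by evaluating both log probability density functions at multiple points and comparing the results. The notebook currently loads `kernel_tests/distributions/student_t_test.py` instead, so the output further below comes from that file, where the extracted oracles are the corresponding `stats.t` functions.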
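To see what such a test looks like in the AST, the sketch below parses a simplified, hypothetical test in the style of `testGammaLogPDF` (the variable names are made up, and `scipy` only appears inside the parsed text, so it does not need to be installed). The oracle call hides behind a named variable, and its function is a chain of `ast.Attribute` nodes ending in the `ast.Name` `stats`, which can be joined into a dotted name, similar to what `_getDifferentialOracleName` does further below:

```python
import ast

test_source = """
def testGammaLogPDF(self):
    expected_log_pdf = stats.gamma.logpdf(x, a=alpha, scale=1 / beta)
    log_pdf = gamma.log_prob(x)
    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
"""

def dotted_name(func):
    """Turn Attribute(Attribute(Name('stats'), 'gamma'), 'logpdf') into 'stats.gamma.logpdf'."""
    parts = []
    while isinstance(func, ast.Attribute):
        parts.append(func.attr)
        func = func.value
    if isinstance(func, ast.Name):
        parts.append(func.id)
    return ".".join(reversed(parts))

function_def = ast.parse(test_source).body[0]
assignments = {
    stmt.targets[0].id: stmt.value
    for stmt in function_def.body
    if isinstance(stmt, ast.Assign)
}
# The second assert argument is a plain name; look up the call it was assigned from
oracle_call = assignments["expected_log_pdf"]
print(dotted_name(oracle_call.func))  # stats.gamma.logpdf
```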

TODO Explain how we can detect a differential oracle in the AST (i.e. by finding the stats function used as comparison)  
TODO Trace the oracle back to a definition outside the function call, because the oracle might be an argument of the function definition
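For the second TODO, a first step could be to check whether a traced name is a parameter of the enclosing function rather than a local assignment; in that case the oracle would have to be looked up at the places where the function is called. A minimal sketch with a hypothetical helper `is_parameter` and a made-up test function:

```python
import ast

def is_parameter(function_def, name):
    """Return True if name is declared as a parameter of function_def."""
    args = function_def.args
    all_args = args.posonlyargs + args.args + args.kwonlyargs
    if args.vararg is not None:
        all_args.append(args.vararg)
    if args.kwarg is not None:
        all_args.append(args.kwarg)
    return any(arg.arg == name for arg in all_args)

source_code = """
def _testPDF(self, expected_fn, x):
    self.assertAllClose(expected_fn(x), self.evaluate(pdf(x)))
"""

function_def = ast.parse(source_code).body[0]
print(is_parameter(function_def, "expected_fn"))  # True
print(is_parameter(function_def, "pdf"))          # False
```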

*Resource ast https://greentreesnakes.readthedocs.io/en/latest/manipulating.html*  
*Resource astor https://astor.readthedocs.io/en/latest/*
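On Python 3.9 and newer, `ast.unparse` from the standard library turns a node back into source text, which covers a common use of `astor.to_source` and can make log output easier to read than `ast.dump`. A short illustration on a made-up assert call:

```python
import ast

call = ast.parse("self.assertAllClose(stats.t.pdf(x, df), self.evaluate(pdf))").body[0].value
oracle_argument = call.args[0]

print(ast.dump(oracle_argument))     # nested node structure
print(ast.unparse(oracle_argument))  # stats.t.pdf(x, df)
```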

In [43]:
class TreeTraverser(ast.NodeVisitor):
    def __init__(self):
        
        # always stores the last function definition node that was visited
        self.last_function_definition_node = None
        
    def visit_Call(self, node):
        
        # (1) Identify if it is an assert call that we are looking for:
        
        # for nodes that call a function directly (statically), e.g. assert(...)
        # this is usually not the case for asserts
        if isinstance(node.func, ast.Name):
            
            # if the name of the called function contains "assert"
            if 'assert' in node.func.id:
                log.add("NEW FOUND: " + node.func.id)
            
        
        # for nodes that call a function of an object, e.g. self.assert(...)
        elif isinstance(node.func, ast.Attribute):
            
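            # if the name of the called function is one we are searching for; calls outside a function
            # definition or with fewer than two positional arguments (nothing to compare) are skipped
            if (node.func.attr in approximation_asserts
                    and self.last_function_definition_node is not None
                    and len(node.args) >= 2):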
                
                # print the assert function name and line number
                log.add("______________" + node.func.attr + "__(line " + str(node.lineno) + ")______________", mode=LOG_FINAL)
                
                # print the node and its structure
                log.add(ast.dump(node, indent='\t') + '\n')
                log.add("Found in: " + self.last_function_definition_node.name + "\n")
                
                # (2) Identify if the assert call uses a differential test oracle:
                
                # Store arguments of the assert call
                assert_argument_position1 = node.args[0]
                assert_argument_position2 = node.args[1]
                
                # return the definitions of the arguments to check if they are differential testing functions
                definitions_arg1 = self._getDefinition(assert_argument_position1)
                definitions_arg2 = self._getDefinition(assert_argument_position2)
                
                # check which node is the oracle
                if definitions_arg1 is not None:
                    for definition in definitions_arg1:
                        # print the ast structure of the definitions
                        log.add("\n(line " + str(definition.lineno) + ")\narg1 = " + ast.dump(definition, indent='\t') + '\n')
                        log.add("Extracted name arg1: " + self._getDifferentialOracleName(definition) + "\n", mode=LOG_FINAL)

                if definitions_arg2 is not None:
                    for definition in definitions_arg2:
                        # print the ast structure of the definitions
                        log.add("\n(line " + str(definition.lineno) + ")\narg2 = " + ast.dump(definition, indent='\t') + '\n')
                        log.add("Extracted name arg2: " + self._getDifferentialOracleName(definition) + "\n", mode=LOG_FINAL)
        
        # visit child nodes so that calls nested inside the arguments are checked too
        self.generic_visit(node)
    
    def _getDefinition(self, node):
        """Check if the argument node is the oracle and return its definitions, 
        i.e. a list of each value that was assigned to it.
        
        Returns None if the node is not an oracle.
        """
        
        definitions = []
            
        # if the argument uses self.evaluate, we can be sure that this is the argument to test
        # therefore the other argument is the oracle and we can stop analyzing this one
        if self._isSelf_Evaluate(node):
            return None
        
        # if the argument is a named variable we need to trace it back to its definition to see if it uses another
        # function or library for differential testing
        if isinstance(node, ast.Name):
            log.add(node.id + " is a named variable!")
            
            # store name of the variable to search for
            variable_name = node.id
            
            # iterate through each variable assignment in the function that the assert is called from
            # and collect every value assigned to our assert argument variable
            for child in ast.walk(self.last_function_definition_node):
                if isinstance(child, ast.Assign):
                    for target in child.targets:
                        
                        # for list assignments of the form e.g. [a,b] = func([a,b])
                        if isinstance(target, ast.List):
                            for list_target in target.elts:
                                if isinstance(list_target, ast.Name) and list_target.id == variable_name:
                                    
                                    # check if the list uses self.evaluate as a function, i.e. self.evaluate([a,b])
                                    if self._isSelf_Evaluate(child.value):
                                        return None
                                    
                                    definitions.append(child.value)
                        
                        # for regular variable assignments
                        if isinstance(target, ast.Name) and target.id == variable_name:
                            
                            # if the assignment includes a self.evaluate statement, return None
                            if self._isSelf_Evaluate(child.value):
                                return None
                            
                            # otherwise store the assigned value as a definition
                            definitions.append(child.value)

            
        # if the argument is a function call, extract the function name and trace its arguments back to their definitions
        if isinstance(node, ast.Call):
            log.add("Argument is a function call!")
            
            # Check the variable that the function is called on, e.g. check a when the oracle is a.mean()
            variable_definition = []
            if isinstance(node.func, ast.Attribute):
                receiver = node.func.value
                
                # for subscripts, e.g. a[:2].mean(), check the subscripted variable a
                if isinstance(receiver, ast.Subscript):
                    receiver = receiver.value
                
                # get all definitions of the variable (None if it is not the oracle)
                variable_definition = self._getDefinition(receiver)
            # direct calls such as func(a) have no receiver variable to trace
                
            # check if it is the oracle
            if variable_definition is not None:
                # append definitions
                definitions.extend(variable_definition)
            
            # if it is not the oracle
            else:
                return None
            
            
            # TODO Check arguments of the function?       
        
             
        return definitions
            
        
    def _getDifferentialOracleName(self, node):
        """Get the name of the function used in the differential test case."""
        
        oracle_attributes = []
        oracle_name = ""
        
        # if it is a function call
        if isinstance(node, ast.Call):
            
            # move into the function of the call node (ignoring the arguments of the function call)
            _node = node.func
            
            # if the function call has multiple attributes, i.e. np.linalg.solve()
            while isinstance(_node, ast.Attribute):
                # extract the name of the function
                oracle_attributes.append(_node.attr)
                
                # move another layer deeper into the node
                _node = _node.value
            
            # if this is the last attribute, i.e. we are at the deepest level of the call node
            if isinstance(_node, ast.Name):
                oracle_attributes.append(_node.id)
            
            # construct the oracle name by going through the reversed attributes array
            for attr in oracle_attributes[::-1]:
                oracle_name += attr + "."
            
            # remove the dot at the end
            oracle_name = oracle_name[:-1]
        
        
        # if it is a list comprehension, find func when node is e.g. [func(a) for a in b]
        if isinstance(node, ast.ListComp):
            # recursively call this function to extract the function name
            oracle_name = self._getDifferentialOracleName(node.elt)               
                
                
        # CURRENTLY UNSUPPORTED TYPES: 
        
        # TODO Check the variables of binary operations? e.g. arg1 = a/b
        if isinstance(node, ast.BinOp):
            oracle_name = "UNSUPPORTED Binary Operation"
            
        # TODO Can constants be derived from oracles?
        if isinstance(node, ast.Constant):
            oracle_name = "UNSUPPORTED Constant"
            
        if isinstance(node, ast.Name):
            oracle_name = "UNSUPPORTED Name (named variable: " + node.id + ")"
            
        
        # return the oracle function name
        return oracle_name
                
        
    def _isSelf_Evaluate(self, node):
        """Check if the node is a self.evaluate call. If so, the node is not an oracle."""
        
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and node.func.attr == 'evaluate':
            log.add("Argument is an evaluate function call!")
            return True
        return False
        
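    def visit_FunctionDef(self, node):
        # remember the enclosing function (if any) so it can be restored after visiting this one
        outer_function_definition_node = self.last_function_definition_node
        self.last_function_definition_node = node
        
        # visit child nodes
        self.generic_visit(node)
        
        # restore the enclosing function so that asserts after a nested function
        # are traced within the correct function
        if outer_function_definition_node is not None:
            self.last_function_definition_node = outer_function_definition_node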


# Initialize custom logger
# Use the print mode LOG_FINAL for final output and LOG_ALL for debugging
log = CustomLogger(LOG_FINAL)

# start the tree traversal
TreeTraverser().visit(tree)

______________assertAllClose__(line 75)______________
Extracted name arg1: stats.t.logpdf

______________assertAllClose__(line 76)______________
______________assertAllClose__(line 77)______________
Extracted name arg1: stats.t.pdf

______________assertAllClose__(line 78)______________
______________assertAllClose__(line 102)______________
Extracted name arg1: stats.t.logpdf

______________assertAllClose__(line 103)______________
______________assertAllClose__(line 104)______________
Extracted name arg1: stats.t.pdf

______________assertAllClose__(line 105)______________
______________assertAllClose__(line 129)______________
Extracted name arg1: stats.t.logcdf

______________assertAllClose__(line 130)______________
______________assertAllClose__(line 132)______________
Extracted name arg1: stats.t.cdf

______________assertAllClose__(line 133)______________
______________assertAllClose__(line 156)______________
Extracted name arg1: stats.t.entropy

Extracted name arg1: np.reshape

_____