# Finding differential test cases

The goal of this notebook is to identify and extract all differential test cases in a library's test suite and to store the functions that are used as oracles. These functions can then be checked for changes whenever developers make a new commit, so that the tool's user can be notified of changes they might want to know about.

## Imports

In [100]:
import ast
import sys
import numpy as np
import os

## Import source code and create Abstract Syntax Tree


An abstract syntax tree (AST) represents source code as a tree in which each node is an element of the source code, e.g. a function definition, function call, if-else statement, while-loop or variable assignment. It abstracts away syntactic details such as grouping parentheses, because these groupings are implicitly defined by the tree's structure.

<img style="float: right;" src="images/abstract_syntax_tree.png" width="450">

The following pseudocode produces the tree shown in the figure:
```
while b ≠ 0
  if a > b
    a := a − b
  else
    b := b − a
return a
```

<br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br>
<div style="text-align: right"> <i> Source https://en.wikipedia.org/wiki/Abstract_syntax_tree </i> </div>
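As a minimal sketch of how this looks in Python, we can write the pseudocode above as a Python function and parse it with the `ast` module. The function name `gcd` and the variable names are chosen for illustration only.

```python
import ast

# Python version of the pseudocode above (Euclid's algorithm by subtraction)
euclid_source = """
def gcd(a, b):
    while b != 0:
        if a > b:
            a = a - b
        else:
            b = b - a
    return a
"""

euclid_tree = ast.parse(euclid_source)

# The module contains a single function definition ...
gcd_node = euclid_tree.body[0]

# ... whose body consists of the while-loop and the return statement
statement_types = [type(statement).__name__ for statement in gcd_node.body]
print(statement_types)

# ast.dump shows the full structure of a subtree, here the loop condition `b != 0`
print(ast.dump(gcd_node.body[0].test))
```

The parentheses and the colon of the source text do not appear in the dump; only the nesting of the nodes remains.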

### Setup: Edit this code cell and replace the library root with the equivalent file path on your machine
Uncomment the settings for the library you want to extract the test cases from.  
The individual Python files listed below are only used for debugging on a single file.  

In [101]:
# Set library root folder:

# For TensorFlow
#library_root = "A:/BachelorThesis/DLL_Testing_Tool/DL_Libraries/Tensorflow/tensorflow-master/tensorflow/python/" 
#tests_root = "kernel_tests"
#save_data_to = "extracted_data/tensorflow_data.csv"

# For Pytorch
library_root = "A:/BachelorThesis/DLL_Testing_Tool/DL_Libraries/PyTorch/pytorch-master/" 
tests_root = "test"
save_data_to = "extracted_data/pytorch_data.csv"

# set python source file (for single file debugging):

# TensorFlow:
#source = open(library_root + "kernel_tests/distributions/gamma_test.py")
#source = open(library_root + "kernel_tests/matrix_solve_op_test.py")
#source = open(library_root + "kernel_tests/distributions/student_t_test.py")
#source = open(library_root + "kernel_tests/reduction_ops_test.py")

# PyTorch:
source = open(library_root + tests_root + "/distributions/test_distributions.py")

# generate abstract syntax tree
tree = ast.parse(source.read())

## Identify assert function calls and check if the arguments are differential test oracles

Here we define which assert statements are relevant for us. 


* TODO: Explain why we look for only these specific statements 
* TODO: Check for other asserts too? 



In [102]:
# Define which assert functions to look for
approximation_asserts = ["assertAlmostEqual", "assertAlmostEquals", "assertAllClose", "assertAllLessEqual",
                         "assertAllCloseAccordingToType", "assertArrayNear", "assert_list_pairwise",
                         "assertNear", "assertLess", "assertAllLess", "assertLessEqual", 
                         "assertNDArrayNear", "assert_allclose", "assert_array_almost_equal",
                         "assert_almost_equal","assert_array_less", 
                         "isclose", "allclose", "gradcheck", "gradgradcheck"]

# TODO Check for these too?
bool_asserts = ["assertTrue", "assertFalse","assertIs","assertIsNot"]

other_asserts = ["assertAllEqual","assertEquals", "assertEqual", "assertAllGreater",
                                                          "assertAllGreaterEqual", "assertAllInRange", "assertAllInSet",
                                                          "assertCountEqual", "assertDTypeEqual", "assertDictEqual",
                                                          "assertSequenceEqual", "assertShapeEqual",
                                                          "assertTupleEqual", "assert_array_equal"]

### Create a custom printer/logger for later debugging

In [103]:
LOG_ALL = 0
LOG_FINAL = 1
LOG_NONE = 2

class CustomLogger():
    def __init__(self, print_mode=LOG_FINAL):
        """Open Log and set print mode. LOG_FINAL prints only output for users. Other log modes are for debugging."""
        self.print_mode = print_mode
        
        # for creating a csv with the collected data
        self.file_path = ""
        #self.data = np.array(["File Path", "Line Number", "Found in Function", "Function Definition Line Number" "Assert Statement Type", "Oracle Argument Position", "Differential Function Line Number", "Differential Test Function"])
        self.data = np.array(["File_Path", "Line_Number", "Found_in_Function", "Function_Definition_Line_Number", "Assert_Statement_Type", "Oracle_Argument_Position", "Differential_Function_Line_Number", "Differential_Test_Function"])
    
    def add(self, string, mode=LOG_ALL):
        """Add text to the log and print if it matches the chosen print_mode"""
        #TODO Write log to txt file (for the full tool later) 
        
        # print the string if print mode matches
        if mode >= self.print_mode:
            print(string)
    
    def createEntry(self, line_no, found_in_function, function_def_line_no, assert_statement_type, oracle_arg_pos, diff_func_line_no, diff_test_function_name):
        """Add an entry to the data."""
        self.data = np.vstack((self.data, [self.file_path, line_no, found_in_function, function_def_line_no, assert_statement_type, oracle_arg_pos, diff_func_line_no, diff_test_function_name]))
    
    def set_file_path_variable(self, file_path):
        self.file_path = file_path
        
    def get_data(self):
        return self.data
    
    def save_data_to_csv(self, path="extracted_data/data.csv"):
        np.savetxt(path, self.data, fmt='%s', delimiter=",")
        print("Data saved to " + path)

### Traversing the tree via NodeVisitor

Since we are only interested in function calls within the library's testing code, we use `visit_Call` to check for each function call node whether (1) it is an assert call and (2) it uses a differential test oracle, i.e. another function call for comparison. 
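The following sketch illustrates the mechanism on a small scale. `ast.NodeVisitor` calls `visit_Call` for every call node, and the type of `node.func` tells us whether a function is called directly (`ast.Name`) or on an object (`ast.Attribute`). The test snippet and the class name `CallCollector` are hypothetical and only serve as an illustration.

```python
import ast

# Hypothetical test snippet, used only for illustration
example_source = """
class ExampleTest(TestCase):
    def test_mean(self):
        expected = np.mean(values)
        self.assertAlmostEqual(compute_mean(values), expected)
"""

class CallCollector(ast.NodeVisitor):
    """Collect the names of all called functions together with their line numbers."""

    def __init__(self):
        self.direct_calls = []     # e.g. compute_mean(...)
        self.attribute_calls = []  # e.g. self.assertAlmostEqual(...)

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name):
            self.direct_calls.append((node.func.id, node.lineno))
        elif isinstance(node.func, ast.Attribute):
            self.attribute_calls.append((node.func.attr, node.lineno))
        # continue with the child nodes, otherwise nested calls would be skipped
        self.generic_visit(node)

collector = CallCollector()
collector.visit(ast.parse(example_source))
print(collector.direct_calls)
print(collector.attribute_calls)
```

Note that `compute_mean(values)` is only found because `generic_visit` descends into the arguments of the assert call.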

As an example we will use the `kernel_tests/distributions/gamma_test.py` file of TensorFlow version [2.5.0](https://github.com/tensorflow/tensorflow/releases/tag/v2.5.0), which tests the gamma distribution functions.
This file contains many differential test cases. For example, the assert statements in lines 77 and 143 are both differential tests.  

The `gamma_test` file uses the SciPy library for differential testing. Many test cases in this file, for example the function `testGammaLogPDF`, compare some part of TensorFlow's own `tensorflow.python.ops.distributions.gamma` to the corresponding part of `scipy.stats.gamma`, e.g. by comparing both log probability density functions at multiple points.
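A rough sketch of how such an oracle can be found in the AST: we take an argument of the assert call, trace the variable back to the assignment inside the test function, and read off the dotted name of the function that was called. The test source below is a hypothetical, simplified snippet in the style of `gamma_test.py`, not the actual file content, and the helper `dotted_name` is introduced only for this example.

```python
import ast

# Hypothetical, simplified test in the style of gamma_test.py
gamma_like_source = """
def testLogPDF(self):
    expected_log_pdf = stats.gamma.logpdf(x, alpha, scale=1 / beta)
    log_pdf = gamma.log_prob(x)
    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
"""

def dotted_name(node):
    """Return the dotted name of a call target, e.g. 'stats.gamma.logpdf'."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if isinstance(node, ast.Name):
        parts.append(node.id)
    return ".".join(reversed(parts))

test_function_node = ast.parse(gamma_like_source).body[0]

# find the assert call and take its second argument (a named variable here)
assert_call = next(n for n in ast.walk(test_function_node)
                   if isinstance(n, ast.Call)
                   and isinstance(n.func, ast.Attribute)
                   and n.func.attr == "assertAllClose")
oracle_variable = assert_call.args[1].id

# trace the variable back to the call that was assigned to it
oracle_names = [dotted_name(n.value.func) for n in ast.walk(test_function_node)
                if isinstance(n, ast.Assign)
                and isinstance(n.value, ast.Call)
                and any(isinstance(t, ast.Name) and t.id == oracle_variable
                        for t in n.targets)]
print(oracle_names)
```

The first argument is wrapped in `self.evaluate`, which marks it as the value under test, so in this sketch only the second argument is treated as the oracle candidate.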

TODO Explain how we can detect a differential oracle in the AST (i.e. by finding the stats function used as comparison)  
TODO Trace back the oracle to a definition outside the function call, because the oracle might be an argument in the function definition

*Resource ast https://greentreesnakes.readthedocs.io/en/latest/manipulating.html*  
*Resource astor https://astor.readthedocs.io/en/latest/*

In [104]:
class TreeTraverser(ast.NodeVisitor):
    def __init__(self):
        
        # always stores the last function definition node that was visited
        self.last_function_definition_node = None
        
    def visit_Call(self, node):
        
        # (1) Identify if it is an assert call that we are looking for:
        
        # for nodes that call a function directly (statically), e.g. assert(...)
        # this is usually not the case for asserts
        if isinstance(node.func, ast.Name):
            
             # if the name of the called function contains "assert"
            if 'assert' in node.func.id:
                log.add("NEW FOUND: " + node.func.id)
            
        
        # for nodes that call a function of an object, e.g. self.assert(...)
        elif isinstance(node.func, ast.Attribute):
            
            # if the name of the called function is one we are searching for
            if node.func.attr in approximation_asserts:
                
                # print the assert function name and line number
                log.add("______________" + node.func.attr + "__(line " + str(node.lineno) + ")______________", mode=LOG_FINAL)
                
                # print the node and its structure
                log.add(ast.dump(node, indent='\t') + '\n')
                log.add("Found in: " + self.last_function_definition_node.name + "\n")
                
                # (2) Identify if the assert call uses a differential test oracle:
                
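                # Store arguments of the assert call (None if an argument is missing or passed by keyword)
                assert_argument_position1 = node.args[0] if len(node.args) > 0 else None
                assert_argument_position2 = node.args[1] if len(node.args) > 1 else None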
                
                # return the definitions of the arguments to check if they are differential testing functions
                definitions_arg1 = self._getDefinition(assert_argument_position1)
                definitions_arg2 = self._getDefinition(assert_argument_position2)
                
                # check which node is the oracle
                if definitions_arg1 is not None:
                    for definition in definitions_arg1:
                        # print the ast structure of the definitions
                        log.add("\n(line " + str(definition.lineno) + ")\narg1 = " + ast.dump(definition, indent='\t') + '\n')
                        
                        # extract name
                        diff_test_function_name = self._getDifferentialOracleName(definition)
                        log.add("Extracted name arg1: " + diff_test_function_name + "\n", mode=LOG_FINAL)
                        
                        # create data entry
                        log.createEntry(node.lineno, self.last_function_definition_node.name, self.last_function_definition_node.lineno, node.func.attr, 1, definition.lineno, diff_test_function_name)
                        
                        
                if definitions_arg2 is not None:
                    for definition in definitions_arg2:
                        # print the ast structure of the definitions
                        log.add("\n(line " + str(definition.lineno) + ")\narg2 = " + ast.dump(definition, indent='\t') + '\n')
                        
                        # extract name
                        diff_test_function_name = self._getDifferentialOracleName(definition)
                        log.add("Extracted name arg2: " + diff_test_function_name + "\n", mode=LOG_FINAL)
                        
                        # create data entry
                        log.createEntry(node.lineno, self.last_function_definition_node.name, self.last_function_definition_node.lineno, node.func.attr, 2, definition.lineno, diff_test_function_name)
       
    
        # visit child nodes so that calls nested in the arguments are inspected as well
        self.generic_visit(node)
    
    def _getDefinition(self, node):
        """Check if the argument node is the oracle and return its definitions, 
        i.e. a list of each value that was assigned to it.
        
        Returns None if the node is not an oracle.
        """
        
        definitions = []
            
        # if the argument uses self.evaluate, we can be sure that this is the argument to test
        # therefore the other argument is the oracle and we can stop analyzing this one
        if self._isSelf_Evaluate(node):
            return None
        
        # if the argument is a named variable we need to trace it back to its definition to see if it uses another
        # function or library for differential testing
        if isinstance(node, ast.Name):
            log.add(node.id + " is a named variable!")
            
            # store name of the variable to search for
            variable_name = node.id
            
            # iterate through each variable assignment in the function that the assert is called from 
            # and return the first assigned value of our assert argument variable
            for child in ast.walk(self.last_function_definition_node):
                if isinstance(child, ast.Assign):
                    for target in child.targets:
                        
                        # for list assignments of the form e.g. [a,b] = func([a,b])
                        if isinstance(target, ast.List):
                            for list_target in target.elts:
                                if isinstance(list_target, ast.Name) and list_target.id == variable_name:
                                    
                                    # check if the list uses self.evaluate as a function, i.e. self.evaluate([a,b])
                                    if self._isSelf_Evaluate(child.value):
                                        return None
                                    
                                    definitions.append(child.value)
                        
                        # for regular variable assignments
                        if isinstance(target, ast.Name) and target.id == variable_name:
                            
                            # if the assignment includes a self.evaluate statement, return None
                            if self._isSelf_Evaluate(child.value):
                                return None
                            
                            # otherwise return the value of the variable 
                            definitions.append(child.value)

            
        # if the argument is a function call, extract the function name and trace its arguments back to their definitions
        if isinstance(node, ast.Call):
            log.add("Argument is a function call!")
            log.add(ast.dump(node, indent='\t'))
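            # Check variable that the function is called on, e.g. check a when oracle is a.mean():
            # (stays None if none of the cases below applies, e.g. for a called lambda)
            variable_definition = None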
            
            # for regular calls e.g. a.mean()
            if hasattr(node, 'func') and hasattr(node.func, 'value'):
                # get all definitions of the variable
                variable_definition = self._getDefinition(node.func.value)
                
            # for calls to defined functions
            elif hasattr(node, 'func') and hasattr(node.func, 'id'):
                # node.func will contain a ast.Name, where the Name.id is the name of the defined function
                variable_definition = [node.func]
                
            # for double function calls, e.g. func(a)(b)
            elif hasattr(node, 'func') and hasattr(node.func, 'func'):
                variable_definition = self._getDefinition(node.func.func)
            
            # for subscripts, e.g. a[:2].mean()
            elif hasattr(node, 'value'):
                variable_definition = self._getDefinition(node.value)
                
            
                
            # check if it is the oracle
            if variable_definition is not None:
                # append definitions 
                definitions.extend(variable_definition)
            
            # if it is not the oracle
            else:
                return None
            
            # TODO Check arguments of the function?       
        
             
        return definitions
            
        
    def _getDifferentialOracleName(self, node):
        """Get the name of the function used in the differential test case."""
        
        oracle_attributes = []
        oracle_name = ""
        
        # if it is a function call
        if isinstance(node, ast.Call):
            
            # move into the function of the call node (ignoring the arguments of the function call)
            _node = node.func
            
            # if the function call has multiple attributes, i.e. np.linalg.solve()
            while isinstance(_node, ast.Attribute):
                # extract the name of the function
                oracle_attributes.append(_node.attr)
                
                # move another layer deeper into the node
                _node = _node.value
            
            # if this is the last attribute, i.e. we are at the deepest level of the call node
            if isinstance(_node, ast.Name):
                oracle_attributes.append(_node.id)
            
            # construct the oracle name by going through the reversed attributes array
            for attr in oracle_attributes[::-1]:
                oracle_name += attr + "."
            
            # remove the dot at the end
            oracle_name = oracle_name[:-1]
        
        
        # if it is a list comprehension, find func when node is e.g. [func(a) for a in b]
        if isinstance(node, ast.ListComp):
            # recursively call this function to extract the function name
            oracle_name = self._getDifferentialOracleName(node.elt)               
                
                
        # CURRENTLY UNSUPPORTED TYPES: 
        
        # TODO Check the variables of binary operations? e.g. arg1 = a/b
        if isinstance(node, ast.BinOp):
            oracle_name = "UNSUPPORTED Binary Operation"
            
        # TODO Can constants be derived from oracles?
        if isinstance(node, ast.Constant):
            oracle_name = "UNSUPPORTED Constant"
            
        if isinstance(node, ast.Name):
            oracle_name = "UNSUPPORTED Name (named variable or defined function: " + node.id + ")"
            
        
        # return the oracle function name
        return oracle_name
                
        
    def _isSelf_Evaluate(self, node):
        """Check if the node is a self.evaluate call. If so, the node is not an oracle."""
        
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and node.func.attr == 'evaluate':
            log.add("Argument is an evaluate function call!")
            return True
        return False
        
    def visit_FunctionDef(self, node):
        
        # remember the last class-level function definition seen,
        # i.e. a function whose first positional argument is self
        if node.args.args and node.args.args[0].arg == 'self':
            self.last_function_definition_node = node
            
        # visit child nodes
        self.generic_visit(node)


# For debugging on a single file:

# Initialize custom logger
# Use the print mode LOG_FINAL for final output and LOG_ALL for debugging
log = CustomLogger(LOG_FINAL)

# start the tree traversal
TreeTraverser().visit(tree)

______________assertLess__(line 792)______________
______________assertLess__(line 793)______________
Extracted name arg2: UNSUPPORTED Binary Operation

______________assertLess__(line 1496)______________
Extracted name arg1: UNSUPPORTED Name (named variable or defined function: abs)

______________assertLess__(line 2346)______________
Extracted name arg1: max

Extracted name arg1: max

______________assertLess__(line 2352)______________
Extracted name arg1: max

Extracted name arg1: max

______________allclose__(line 2609)______________
Extracted name arg1: vec_to_tril_matrix

Extracted name arg1: UNSUPPORTED Binary Operation

______________assertLess__(line 2671)______________
Extracted name arg1: UNSUPPORTED Binary Operation

______________assertLess__(line 2897)______________
______________assertLess__(line 2925)______________
______________assertLess__(line 2955)______________
______________assertLess__(line 2985)______________
______________assertLess__(line 3015)______________
_

## Applying the function extraction to all test files of a library

Now that our function extraction via tree traversal works on single source files, we can create a list of potential differential testing functions for all test files of the selected library (TensorFlow or PyTorch). This list takes the form of a data table with the following entries:

| File | Assert line number | Found in Function | Assert Statement Type | Oracle argument position | Differential Function line number | Extracted Differential Test Function | 
|:----:|:-----------:|:-----------------:|:---------------------:|:-------------------:|:-------------------------------:|:----:|
|kernel_tests/distributions/student_t_test.py| 75 |testStudentPDFAndLogPDF|assertAllClose|1|70|stats.t.logpdf|

To do this we will first write code that will go through each subdirectory and file of `tests_root` and return the file paths of all `.py` files:

In [105]:
# fill test_files with the file paths to all python files in tests_root relative to library_root
test_files = []
for subdir, _, files in os.walk(library_root + tests_root):
    #print(subdir)
    for file in files:
        if file.endswith(".py"):
            relative_dir = subdir.replace(library_root, '')
            filepath = relative_dir + os.sep + file
            test_files.append(filepath) 
            #print("\t"+filepath)

In [106]:
# Initialize custom logger
log = CustomLogger(LOG_NONE)

error_list = []
# go through each file in 'tests_root' and extract data
for file in test_files:
    # set file path that appears in the data entries for this file 
    log.set_file_path_variable(file)
    log.add("<<<<<<<<<<<<" + file + ">>>>>>>>>>>>", LOG_FINAL)
    
    try:
        # generate abstract syntax tree and start the tree traversal as before
        with open(library_root + file) as source:
            tree = ast.parse(source.read())
        TreeTraverser().visit(tree)
    
    except Exception:
        # record every file that could not be read, parsed or traversed
        error_list.append(file)

# print collected data and save to csv
#print(log.get_data())
print("Remaining errors in files: " + str(len(error_list)) + " " + str(error_list))
log.save_data_to_csv(save_data_to)

Remaining errors in files: 18 ['test\\test_fx_experimental.py', 'test\\test_jit.py', 'test\\custom_backend\\test_custom_backend.py', 'test\\custom_operator\\test_custom_ops.py', 'test\\distributed\\test_c10d_spawn_gloo.py', 'test\\distributed\\test_data_parallel.py', 'test\\distributed\\optim\\test_zero_redundancy_optimizer.py', 'test\\distributed\\pipeline\\sync\\test_checkpoint.py', 'test\\distributed\\pipeline\\sync\\test_copy.py', 'test\\distributed\\pipeline\\sync\\test_deferred_batch_norm.py', 'test\\distributed\\pipeline\\sync\\test_transparency.py', 'test\\distributed\\pipeline\\sync\\skip\\test_portal.py', 'test\\distributions\\test_constraints.py', 'test\\distributions\\test_transforms.py', 'test\\distributions\\test_utils.py', 'test\\onnx\\test_caffe2_common.py', 'test\\onnx\\test_pytorch_onnx_onnxruntime.py', 'test\\package\\test_model.py']
Data saved to extracted_data/pytorch_data.csv
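To get a first overview of the saved table, the csv file can be read back with Python's `csv` module. The following is only a sketch: the rows are made up for illustration and kept in memory via `io.StringIO`, and only a subset of the columns is shown. In practice one would open the file at `save_data_to` instead. Because `np.savetxt` writes the fields without quoting, an entry that itself contains a comma would shift the columns of its row.

```python
import csv
import io
from collections import Counter

# Hypothetical excerpt of a saved csv file; in practice open `save_data_to` instead
sample_csv = io.StringIO(
    "File_Path,Line_Number,Assert_Statement_Type,Differential_Test_Function\n"
    "a_test.py,10,assertAllClose,stats.t.logpdf\n"
    "a_test.py,25,assertAllClose,stats.t.logpdf\n"
    "b_test.py,7,assertLess,UNSUPPORTED Binary Operation\n"
)

rows = list(csv.DictReader(sample_csv))

# count how often each extracted function name occurs, ignoring unsupported entries
oracle_counts = Counter(
    row["Differential_Test_Function"] for row in rows
    if not row["Differential_Test_Function"].startswith("UNSUPPORTED")
)
print(oracle_counts.most_common())
```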
