# 2 - AST Experiments

The aim here is to understand if AST is useful for finding e.g. which type of block we're in (function/class/params...).

In [1]:
!pip install git+https://github.com/neelnanda-io/Easy-Transformer ipython

# Clear output
from IPython.display import clear_output
clear_output()

In [2]:
import ast
from transformers import AutoTokenizer, GPT2Tokenizer
from typing import List

# Get the "inside-function" code blocks

In [17]:
simple_source = '''
def hello():
    a = 10
    return "hi"
'''

def split_source_by_tokens(source: str, tokenizer = GPT2Tokenizer.from_pretrained("gpt2")) -> List[str]:
    encoded = tokenizer.encode(source)
    decoded = []
    for token in encoded:
        decoded_token = tokenizer.decode(token)
        decoded.append(decoded_token)
    return decoded


def token_index_from_line_col(source_tokens: List[str], line: int, col: int) -> int:
    # Find indices of all lines
    line_indices = [0] # Start with first line index
    for idx, token in enumerate(source_tokens):
        if '\n' in token:
            # Add every token index after a newline
            line_indices.append(idx + 1)
    
    line_tokens = source_tokens[line_indices[line - 1]:line_indices[line]-2]
    character_increment = 0
    for idx, token in enumerate(source_tokens):
        for character in token:
            if col == character_increment:
                return line_indices[line] + idx
            else:    
                character_increment += 1


class FuncLister(ast.NodeVisitor):
    function_locations: List[List[int]] = []
    
    def visit(self, *args) -> List[List[int]]:
        super().visit(*args)
        return self.function_locations
    
    def visit_FunctionDef(self, node):
        start_line = node.lineno
        start_offset = node.col_offset
        end_line = node.end_lineno
        end_offset = node.end_col_offset
        
        self.function_locations = [[start_line, start_offset, end_line, end_offset]]
        
        # inside_function = node.body[0]
        
        # Generic visitor
        self.generic_visit(node)

def get_function_token_indices(source: str) -> List[str]:
    source_tokens = split_source_by_tokens(source)
    
    tree = ast.parse(source)
    function_locations = FuncLister().visit(tree)
    
    function_token_indices = []
    
    for location in function_locations:
        [start_line, start_offset, end_line, end_offset] = location
        print(end_line, end_offset)
        start = token_index_from_line_col(source_tokens, start_line, start_offset)
        end = token_index_from_line_col(source_tokens, end_line, end_offset)
        
        for token_index in range(start, end):
            function_token_indices.append(token_index)
    
    return function_token_indices

# function_tokens = get_function_token_indices(simple_source)

source_tokens = split_source_by_tokens(simple_source)
# source_tokens

token_index_from_line_col(source_tokens, 4, 15)


26