In [None]:
#|default_exp core

# API

In [None]:
#| export
from fastcore.utils import *
from fastcore.meta import delegates
from ast_grep_py import SgRoot

import ast

In [None]:
from fastcore.test import *

In [None]:
#| export
def get_docstring(node, lines):
    "Get docstring from source lines if present"
    if not (node.body and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Constant)): return None
    doc_node = node.body[0]
    return '\n'.join(lines[doc_node.lineno-1:doc_node.end_lineno])

def _node_sig(node, lines):
    body_start = max(node.body[0].lineno - 1, node.lineno)
    sig = '\n'.join(lines[node.lineno-1:body_start])
    doc = get_docstring(node, lines)
    return (f"{sig}\n{doc}" if doc else sig).strip('\r\n') + ' ...'

def py_sigs(src):
    "Extract class/function/method signatures from Python source"
    tree,lines = ast.parse(src),src.splitlines()
    def _collect(nodes):
        sigs = []
        for n in nodes:
            if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                sigs.append(_node_sig(n, lines))
                sigs.extend(_collect(n.body))
        return sigs
    return _collect(tree.body)

In [None]:
#| export
def _get_sigs(src, lang, kinds, name_kind, params_kind, fmt):
    root = SgRoot(src, lang).root()
    sigs = []
    for kind in kinds:
        for n in root.find_all(kind=kind):
            name,params = n.find(kind=name_kind),n.find(kind=params_kind)
            if name and params: sigs.append(fmt(n, name.text(), params.text()))
    return sigs

In [None]:
test_py = """
def greet(name, age=10):
    "Say hello"
    return f"Hello {name}"

class Foo:
    def __init__(self, x): self.x = x
    def bar(self, y, z): return y + z
"""
py_sigs(test_py)

['def greet(name, age=10):\n    "Say hello" ...',
 'class Foo: ...',
 '    def __init__(self, x): self.x = x ...',
 '    def bar(self, y, z): return y + z ...']

In [None]:
#| export
def js_sigs(src, lang="javascript"):
    "Extract function signatures from JS/TS source"
    sigs = _get_sigs(src, lang, ["function_declaration"], "identifier", "formal_parameters",
                     lambda n,nm,ps: f"function {nm}{ps} {{...}}")
    sigs += _get_sigs(src, lang, ["method_definition"], "property_identifier", "formal_parameters",
                      lambda n,nm,ps: f"{nm}{ps} {{...}}")
    return sigs

In [None]:
test_js = """
function greet(name, age) { return `Hello ${name}`; }
const add = (a, b) => a + b;
class Foo {
    constructor(x) { this.x = x; }
    bar(y, z) { return y + z; }
}
"""
js_sigs(test_js)

['function greet(name, age) {...}', 'constructor(x) {...}', 'bar(y, z) {...}']

In [None]:
#| export
def java_sigs(src):
    "Extract method signatures from Java source"
    def fmt(n, nm, ps):
        typ = n.find(kind="type_identifier") or n.find(kind="void_type") or n.find(kind="integral_type")
        return f"{typ.text() if typ else 'void'} {nm}{ps};"
    return _get_sigs(src, "java", ["method_declaration"], "identifier", "formal_parameters", fmt)

In [None]:
test_java = """
public class Calculator {
    public int add(int a, int b) { return a + b; }
    private void reset() { this.value = 0; }
    public String format(String template, Object... args) { return String.format(template, args); }
}
"""
java_sigs(test_java)

['int add(int a, int b);',
 'void reset();',
 'String format(String template, Object... args);']

In [None]:
#| export
def rust_sigs(src):
    "Extract function signatures from Rust source"
    return _get_sigs(src, "rust", ["function_item"], "identifier", "parameters",
                     lambda n,nm,ps: f"fn {nm}{ps} {{...}}")

In [None]:
test_rust = """
fn greet(name: &str) -> String { format!("Hello {}", name) }
fn add(a: i32, b: i32) -> i32 { a + b }
pub fn process(items: Vec<Item>, filter: impl Fn(&Item) -> bool) -> Vec<Item> { items.into_iter().filter(filter).collect() }
"""
rust_sigs(test_rust)

['fn greet(name: &str) {...}',
 'fn add(a: i32, b: i32) {...}',
 'fn process(items: Vec<Item>, filter: impl Fn(&Item) -> bool) {...}']

In [None]:
#| export
def csharp_sigs(src):
    "Extract method signatures from C# source"
    def fmt(n, nm, ps):
        typ = n.find(kind="predefined_type") or n.find(kind="identifier")
        return f"{typ.text() if typ else 'void'} {nm}{ps};"
    return _get_sigs(src, "csharp", ["method_declaration"], "identifier", "parameter_list", fmt)

In [None]:
test_csharp = """
public class Service {
    public string GetName(int id) { return "test"; }
    private void Initialize() { }
    public async Task<List<Item>> FetchItems(string query, int limit) { return new List<Item>(); }
}
"""
csharp_sigs(test_csharp)

['string GetName(int id);',
 'void Initialize();',
 'string Task(string query, int limit);']

In [None]:
#| export
def css_selectors(src):
    "Extract CSS selectors from source"
    root = SgRoot(src, "css").root()
    return [f"{n.text()} {{...}}" for n in root.find_all(kind="selectors")]

In [None]:
test_css = """
.container { margin: 0; }
#header, .nav { display: flex; }
body > main p { color: red; }
"""
css_selectors(test_css)

['.container {...}', '#header, .nav {...}', 'body > main p {...}']

In [None]:
#| export
def go_sigs(src):
    "Extract function signatures from Go source"
    root = SgRoot(src, "go").root()
    sigs = []
    for n in root.find_all(kind="function_declaration"):
        name,params = n.find(kind="identifier"),n.find(kind="parameter_list")
        if name and params: sigs.append(f"func {name.text()}{params.text()} {{...}}")
    for n in root.find_all(kind="method_declaration"):
        recv,name = n.find(kind="parameter_list"),n.find(kind="field_identifier")
        params = n.find_all(kind="parameter_list")
        if name and len(params) > 1: sigs.append(f"func {recv.text()} {name.text()}{params[1].text()} {{...}}")
    return sigs

In [None]:
test_go = """
func greet(name string) string { return "Hello " + name }
func add(a, b int) int { return a + b }
func (s *Server) Start(port int) error { return nil }
func (c Client) Get(url string, timeout time.Duration) (*Response, error) { return nil, nil }
"""
go_sigs(test_go)

['func greet(name string) {...}',
 'func add(a, b int) {...}',
 'func (s *Server) Start(port int) {...}',
 'func (c Client) Get(url string, timeout time.Duration) {...}']

In [None]:
#| export
def kotlin_sigs(src):
    "Extract function signatures from Kotlin source"
    return _get_sigs(src, "kotlin", ["function_declaration"], "simple_identifier", "function_value_parameters",
                     lambda n,nm,ps: f"fun {nm}{ps} {{...}}")

In [None]:
test_kotlin = """
fun greet(name: String, age: Int = 10): String { return "Hello $name" }

class Foo(val x: Int) {
    fun bar(y: Int, z: Int): Int { return y + z }
}
"""
kotlin_sigs(test_kotlin)

['fun greet(name: String, age: Int = 10) {...}',
 'fun bar(y: Int, z: Int) {...}']

In [None]:
#| export
def swift_sigs(src):
    "Extract function signatures from Swift source"
    return _get_sigs(src, "swift", ["function_declaration"], "simple_identifier", "parameter",
                     lambda n,nm,ps: f"func {nm}({ps}) {{...}}")

In [None]:
test_swift = """
func greet(name: String, age: Int = 10) -> String { return "Hello \\(name)" }

class Foo {
    var x: Int
    init(x: Int) { self.x = x }
    func bar(y: Int, z: Int) -> Int { return y + z }
}
"""
swift_sigs(test_swift)

['func greet(name: String) {...}', 'func bar(y: Int) {...}']

In [None]:
#| export
def lua_sigs(src):
    "Extract function signatures from Lua source"
    return _get_sigs(src, "lua", ["function_declaration"], "identifier", "parameters",
                     lambda n,nm,ps: f"function {nm}{ps} ... end")

In [None]:
test_lua = """
function greet(name, age)
    return "Hello " .. name
end

function add(a, b) return a + b end

local function helper(x) return x * 2 end
"""
lua_sigs(test_lua)

['function greet(name, age) ... end',
 'function add(a, b) ... end',
 'function helper(x) ... end']

In [None]:
#| export
def php_sigs(src):
    "Extract function signatures from PHP source"
    return _get_sigs(src, "php", ["function_definition", "method_declaration"], "name", "formal_parameters",
                     lambda n,nm,ps: f"function {nm}{ps} {{...}}")

In [None]:
test_php = """<?php
function greet($name, $age = 10) { return "Hello $name"; }

class Foo {
    public function __construct($x) { $this->x = $x; }
    public function bar($y, $z) { return $y + $z; }
}
"""
php_sigs(test_php)

['function greet($name, $age = 10) {...}',
 'function __construct($x) {...}',
 'function bar($y, $z) {...}']

In [None]:
#| export
def ruby_sigs(src):
    "Extract method signatures from Ruby source"
    return _get_sigs(src, "ruby", ["method"], "identifier", "method_parameters",
                     lambda n,nm,ps: f"def {nm}{ps} ... end")

In [None]:
test_ruby = """
def greet(name, age = 10)
  "Hello #{name}"
end

class Foo
  def initialize(x)
    @x = x
  end
  def bar(y, z) = y + z
end
"""
ruby_sigs(test_ruby)

['def greet(name, age = 10) ... end',
 'def initialize(x) ... end',
 'def bar(y, z) ... end']

In [None]:
#| export
_sigs_fns = {'.py': py_sigs, '.js': js_sigs, '.ts': lambda s: js_sigs(s, "typescript"), '.jsx': js_sigs, 
             '.tsx': lambda s: js_sigs(s, "typescript"), '.java': java_sigs, '.rs': rust_sigs, 
             '.cs': csharp_sigs, '.css': css_selectors, '.go': go_sigs, '.rb': ruby_sigs,
             '.php': php_sigs, '.kt': kotlin_sigs, '.kts': kotlin_sigs, '.swift': swift_sigs, '.lua': lua_sigs}

In [None]:
#| export
def ext_sigs(src, ext):
    "Read retrieve signatures for `src` based on suitable langage for `ext`"
    if not ext.startswith('.'): ext = '.'+ext
    fn = _sigs_fns.get(ext)
    return [] if fn is None else fn(src)

In [None]:
ext_sigs(test_ruby, 'rb')

['def greet(name, age = 10) ... end',
 'def initialize(x) ... end',
 'def bar(y, z) ... end']

In [None]:
#| export
def file_sigs(fname):
    "Read file content and retrieve signatures"
    fname = Path(fname).expanduser()
    try: s = fname.read_text()
    except UnicodeDecodeError: return []
    return ext_sigs(s, fname.suffix)

In [None]:
for o in file_sigs('../codesigs/core.py'): print(o)

def get_docstring(node, lines):
    "Get docstring from source lines if present" ...
def _node_sig(node, lines): ...
def py_sigs(src):
    "Extract class/function/method signatures from Python source" ...
    def _collect(nodes): ...
def _get_sigs(src, lang, kinds, name_kind, params_kind, fmt): ...
def js_sigs(src, lang="javascript"):
    "Extract function signatures from JS/TS source" ...
def java_sigs(src):
    "Extract method signatures from Java source" ...
    def fmt(n, nm, ps): ...
def rust_sigs(src):
    "Extract function signatures from Rust source" ...
def csharp_sigs(src):
    "Extract method signatures from C# source" ...
    def fmt(n, nm, ps): ...
def css_selectors(src):
    "Extract CSS selectors from source" ...
def go_sigs(src):
    "Extract function signatures from Go source" ...
def kotlin_sigs(src):
    "Extract function signatures from Kotlin source" ...
def swift_sigs(src):
    "Extract function signatures from Swift source" ...
def lua_sigs(src):
    "Extract fun

## Export -

In [None]:
#|hide
#|eval: false
from nbdev.doclinks import nbdev_export
nbdev_export()