In [10]:

from itertools import compress
import dataclasses
import re
import sys
import os
# Make the parent of the current working directory importable, so the project-level
# modules imported below (e.g. ``cross_runs_TF_planes``) can be found.
current_dir = os.path.dirname(os.path.abspath('./'))
if current_dir not in sys.path:
    sys.path.append(current_dir)
from cross_runs_TF_planes import CrossRunsTFScorer
import console as csl
import utils.data_management as dm
import inspect
import numpy as np
import numpy as nmp
from scipy import exp, tanh
import time
from utils.console.colored import ColoredText, clean_styles, bold
import utils.console.colored as ucc
from utils.data_management import dict2str
from numpy import ones, arange
import random as r
import itertools as it
from typing import *
import scipy as sp
from scipy.signal import bessel
from numpy import cos
from utils.structures import Linked
import ast, astunparse
from abc import ABC


class EntitiesContainer(object):
    """Ordered collection of ``(name, entity)`` pairs kept in lock-step.

    ``names[i]`` always describes ``entities[i]``. Names are strings; entities may be
    arbitrary (including unhashable) objects. The container owns private copies of the
    sequences it is given, so it never mutates the caller's lists.
    """

    @staticmethod
    def __check_mergeability(names: Union[str, list[str]], entities: Union[object, list[object]]) -> tuple[list[str], list[object]]:
        """Normalise ``names``/``entities`` into two fresh, equally long lists.

        A list/tuple of names must come with a list/tuple of entities of the same length.
        A single name must come with a single (non-list) entity.

        Raises:
            ValueError: if the two arguments cannot be paired.
        """
        if isinstance(names, (list, tuple)):
            if not isinstance(entities, (list, tuple)) or len(names) != len(entities):
                raise ValueError('Impossible to merge given names and entities (maybe because of different types or lengths)')
            # Fresh lists: never alias the caller's sequences, and tuples become mutable.
            return list(names), list(entities)

        if isinstance(entities, list):
            raise ValueError('Impossible to merge given names and entities (maybe because of different types or lengths)')

        return [names], [entities]

    def __init__(self, names: Union[str, list[str]], entities: Union[object, list[object]]):
        """Create the container from one name/entity or from parallel sequences.

        Raises:
            ValueError: if ``names`` and ``entities`` cannot be paired.
        """
        self._names, self._entities = self.__check_mergeability(names, entities)

    def __bool__(self):
        return bool(self.names) and bool(self.entities)

    def __getitem__(self, i: int) -> tuple[str, object]:
        return self.names[i], self.entities[i]

    def __contains__(self, item: Union[str, object]) -> bool:
        """Strings are looked up among names; anything else among entities."""
        if isinstance(item, str):
            return item in self.names
        return item in self.entities

    def __iter__(self):
        yield from zip(self.names, self.entities)

    @property
    def names(self):
        return self._names
    @names.setter
    def names(self, value):
        raise AttributeError('Impossible to set new values for names not setting corresponding entities (use add method)')

    @property
    def entities(self):
        return self._entities
    @entities.setter
    def entities(self, value):
        raise AttributeError('Impossible to set new values for entities not setting corresponding names (use add method)')

    def add(self, names: Union[str, list[str]], entities: Union[object, list[object]]):
        """Append new pairs, skipping any whose name is already present.

        Deduplication is done per *pair* (keyed on the name) so that names and entities
        stay aligned. Deduplicating the two lists independently desynchronises them
        whenever different names share an equal entity.

        Raises:
            ValueError: if ``names`` and ``entities`` cannot be paired.
        """
        names, entities = self.__check_mergeability(names, entities)

        for name, entity in zip(names, entities):
            if name not in self._names:
                self._names.append(name)
                self._entities.append(entity)

    def remove(self, item: Union[str, object]):
        """Remove the pair identified by a name (str) or by an entity; no-op if absent."""
        pool = self._names if isinstance(item, str) else self._entities

        try:
            index = pool.index(item)
        except ValueError:
            return

        del self._names[index]
        del self._entities[index]



class TypeHintRemover(ast.NodeTransformer):
    """AST transformer that strips type hints from function signatures.

    Return annotations are removed, together with the annotations of every parameter
    kind: positional-only, regular, ``*args``, keyword-only and ``**kwargs``. This applies
    to both ``def`` and ``async def``. Nested functions and methods are visited as well.
    """

    def visit_FunctionDef(self, node):
        node.returns = None

        arguments = node.args
        params = [*arguments.posonlyargs, *arguments.args, *arguments.kwonlyargs]
        # ``*args`` / ``**kwargs`` are single optional ``arg`` nodes, not lists.
        for special in (arguments.vararg, arguments.kwarg):
            if special is not None:
                params.append(special)

        for param in params:
            param.annotation = None

        # Recurse so that functions defined inside this one are stripped too.
        self.generic_visit(node)
        return node

    visit_AsyncFunctionDef = visit_FunctionDef


class CodeInspectorHelper(ABC):
    """Shared machinery for ``CodeInspector``.

    Classifies the module a global comes from, and builds/merges the ``import`` lines,
    package lists and source snippets that the inspector collects.
    :meth:`get_module_type` returns one of ``'__main__'``, ``'built-in'``, ``'pip'``,
    ``'python'`` or ``'custom'``. Names coming from ``'__main__'``/``'custom'`` modules are
    recorded under the placeholder module name ``'locals'``.
    """

    @staticmethod
    def __protect_source_module(source_module: object, member: object) -> object:
        """Return the module that ``member`` should be imported from.

        * If ``source_module`` is ``None`` (``inspect.getmodule`` failed), the top-level
          package is guessed from the repr of ``type(member)``.
        * If the module name contains a private sub-module (``pkg._impl``), the public
          parent ``pkg`` is returned instead.

        Raises:
            ModuleNotFoundError: if the guessed/parent module is not in ``sys.modules``.
        """
        if source_module is None:
            try:
                # "<class 'numpy.ufunc'>" -> 'numpy.ufunc' -> top-level package 'numpy'.
                return sys.modules[re.findall(r'\'.*\'', str(type(member)))[0][1:-1].split('.')[0]]
            except KeyError:
                # Not every object has ``__name__`` (e.g. plain instances); fall back to
                # repr so the intended error is raised instead of an AttributeError.
                label = getattr(member, '__name__', repr(member))
                raise ModuleNotFoundError(f'Can not determine module of {label}') from None
        elif '._' in source_module.__name__:
            try:
                return sys.modules[source_module.__name__.split('._')[0]]
            except KeyError:
                raise ModuleNotFoundError(f'Can not determine correct name of {source_module.__name__}') from None
        else:
            return source_module

    @staticmethod
    def check_manual_imports(imports):
        """Normalise user-supplied import line(s) to a list of strings.

        ``None`` becomes ``[]``, a single string becomes a one-element list, a tuple is
        converted to a list, and a list is returned as is.

        Raises:
            TypeError: for any other type (this previously returned ``None`` and crashed later).
        """
        if imports is None:
            return []
        elif isinstance(imports, str):
            return [imports]
        elif isinstance(imports, tuple):
            return list(imports)
        elif isinstance(imports, list):
            return imports
        raise TypeError(f'Imports must be None, str, tuple or list, got {type(imports).__name__}')

    @staticmethod
    def _imported_module_root(import_: str) -> str:
        """Top-level package of an ``import x`` / ``from x import y`` line ('' if unrecognised)."""
        tokens = import_.split()
        if len(tokens) >= 2 and tokens[0] in ('import', 'from'):
            return tokens[1].rstrip(',').split('.')[0]
        return ''

    @staticmethod
    def clean_locals_related(code: str, imports: list[str]) -> tuple[str, list[str]]:
        """Remove placeholder ``locals``/``__main__`` imports and strip their aliases from ``code``.

        For every ``import locals as X`` line, occurrences of ``X.`` that are not part of a
        longer name or attribute chain are removed from ``code``. Every import whose
        top-level module is ``locals`` or ``__main__`` is dropped.

        Returns:
            ``(code, imports)`` with the cleaned code and the remaining imports.
        """
        names_for_locals = list()
        new_imports = list()

        for import_ in imports:
            if import_.startswith('import locals as '):
                names_for_locals.append(import_[len('import locals as '):].strip())
            elif CodeInspectorHelper._imported_module_root(import_) in ('locals', '__main__'):
                continue
            else:
                new_imports.append(import_)

        for locals_name in names_for_locals:
            # Lookbehind keeps e.g. ``xL.`` or ``obj.L.`` intact when stripping ``L.``.
            code = re.sub(rf'(?<![\w.]){re.escape(locals_name)}\.', '', code)

        return code, new_imports

    def get_module_type(self, member: object) -> str:
        """Classify where the module of ``member`` lives.

        Returns:
            ``'__main__'`` for modules without a spec, ``'built-in'`` for compiled-in modules,
            ``'python'`` for frozen or standard-library modules (path contains ``lib``),
            ``'pip'`` for ``site-packages``/``dist-packages``, otherwise ``'custom'``.
        """
        module = self.__protect_source_module(inspect.getmodule(member), member)
        spec = getattr(module, '__spec__', None)

        if spec is None:
            return '__main__'

        # Namespace packages may have ``origin is None``.
        origin = spec.origin or ''

        if origin == 'built-in':
            return 'built-in'
        # Python 3.11+ freezes part of the stdlib (e.g. ``os``, ``abc``): origin is 'frozen'.
        if origin == 'frozen':
            return 'python'
        if 'site-packages' in origin or 'dist-packages' in origin:
            return 'pip'
        # Backslash paths only occur on Windows, where path case is not significant.
        if '/lib/' in origin or '\\lib\\' in origin.lower():
            return 'python'
        return 'custom'

    def import_member(self, member_name, member: object, source_module: object, imports: list[str], packages: list[str], locals: list[str], entities: list[object]) -> NoReturn:
        """Record how to import the non-module global ``member_name`` (mutates the lists in place).

        Members of ``__main__``/custom modules are appended to ``locals``/``entities`` and
        imported from the placeholder ``locals``. For other modules, the top-level package
        is added to ``packages``. The import is merged into an existing
        ``from <module> import ...`` line for the same module when there is one.
        """
        source_module_name = source_module.__name__ if self.get_module_type(source_module) not in ['__main__', 'custom'] else 'locals'
        raw_module_name = source_module_name.split('.')[0]

        if source_module_name == 'locals' and member_name not in locals:
            locals.append(member_name)
            entities.append(member)
        elif source_module_name != 'locals' and raw_module_name not in packages:
            packages.append(raw_module_name)

        # Exact prefix match: ``from sys import`` must not match ``from sysconfig import``.
        prefix = f'from {source_module_name} import '
        same_source = [i for i, import_ in enumerate(imports) if import_.startswith(prefix)]

        if same_source:
            imports[same_source[0]] += f', {member_name}'
        else:
            imports.append(f'{prefix}{member_name}')

    def import_module(self, member: object, member_name: str, imports: list[str], packages: list[str]) -> NoReturn:
        """Record an ``import module [as alias]`` line for a module global (mutates the lists in place).

        Custom/``__main__`` modules are imported from the placeholder ``locals``. Only
        ``'pip'`` modules have their top-level package added to ``packages``.
        """
        module_type = self.get_module_type(member)
        module_name = member.__name__ if module_type not in ['custom', '__main__'] else 'locals'
        raw_module_name = module_name.split('.')[0]

        if module_name != 'locals' and module_type == 'pip' and raw_module_name not in packages:
            packages.append(raw_module_name)

        if module_name == member_name:
            imports.append(f'import {module_name}')
        else:
            imports.append(f'import {module_name} as {member_name}')

    @staticmethod
    def concatenate_code(*args: str, sep: Optional[str] = '\n') -> str:
        """Join snippets with ``sep``, skipping any already contained (as a substring) in the result.

        Each kept snippet is preceded by ``sep``, so a non-empty result starts with ``sep``.
        """
        code = ''

        for arg in args:
            if arg not in code:
                code = f'{code}{sep}{arg}'

        return code

    @staticmethod
    def concatenate_members(*args: list[Any]) -> list[Any]:
        """Concatenate the arguments (non-lists count as one-element lists) and drop duplicates.

        The order of first occurrence is kept. An equality-based ``in`` check is used
        instead of a set, so unhashable items are supported.
        """
        unique = list()

        for arg in args:
            for item in (arg if isinstance(arg, list) else [arg]):
                if item not in unique:
                    unique.append(item)

        return unique

    def concatenate_imports(self, *args: Union[list[str], str]) -> list[str]:
        """Merge several import lists (or single import strings) into one deduplicated list.

        Plain ``import ...`` lines are deduplicated. ``from X import a, b`` lines are grouped
        per module ``X`` into a single line whose names keep the order of first occurrence.
        The result lists plain imports first, then ``from`` imports.
        """
        groups = [[arg] if isinstance(arg, str) else list(arg) for arg in args]

        plain_imports = list()
        merged_from = dict()  # module -> ordered list of imported names

        for group in groups:
            for item in group:
                if not item.startswith('from '):
                    plain_imports.append(item)
                    continue

                package, _, names = item[len('from '):].partition(' import ')
                bucket = merged_from.setdefault(package, [])

                for name in names.split(', '):
                    if name not in bucket:
                        bucket.append(name)

        complex_imports = [
            f'from {package} import {", ".join(names)}'
            for package, names in merged_from.items()
            if names
        ]

        return self.concatenate_members(self.concatenate_members(plain_imports), complex_imports)

    @staticmethod
    def concatenate_code_and_imports(
        source_code: str,
        imports: list[str],
        line_break: Optional[str] = '\n',
        postimports_stride: Optional[int] = 1,
        add_to_end: Optional[str] = ''
        ) -> str:
        """Render the imports, ``postimports_stride`` line breaks, then ``source_code``.

        If ``add_to_end`` is given, it is appended, surrounded by the same number of line breaks.
        """
        out = f'{line_break.join(imports)}{line_break*postimports_stride}{source_code}'

        if add_to_end:
            out = f'{out}{line_break*postimports_stride}{add_to_end}{line_break*postimports_stride}'

        return out

    def concatenate(
        self,
        *,
        code: Optional[list[str]] = None,
        imports: Optional[list[list[str]]] = None,
        packages: Optional[list[list[str]]] = None,
        locals: Optional[list[list[str]]] = None,
        entities: Optional[list[list[object]]] = None
        ) -> tuple[Union[list[str], list[object]], ...]:
        """Merge each provided category and return the results as a tuple.

        Only categories that were passed are included, in the order code, imports,
        packages, locals, entities.
        """
        out = list()

        if code is not None:
            out.append(self.concatenate_code(*code, sep=''))

        if imports is not None:
            out.append(self.concatenate_imports(*imports))

        if packages is not None:
            out.append(self.concatenate_members(*packages))

        if locals is not None:
            out.append(self.concatenate_members(*locals))

        if entities is not None:
            out.append(self.concatenate_members(*entities))

        return tuple(out)
        return tuple(out)


class Source(object):
    """A piece of generated source code plus the import lines it needs.

    ``code`` is read-only. Assigning to ``imports`` does not replace the list: the assigned
    string (or sequence of strings) is *appended* to the existing imports.
    """

    def __init__(self, code: str, imports: list[str]):
        self._code = code
        # Own a private list. The ``imports`` setter extends it in place, which would
        # otherwise mutate the caller's list (or fail outright if a tuple was passed).
        self._imports = list(imports)

    @property
    def code(self):
        return self._code
    @code.setter
    def code(self, value):
        raise AttributeError('Impossible to set new code')

    @property
    def imports(self):
        return self._imports
    @imports.setter
    def imports(self, new_imports: Union[str, list[str]]):
        """Append ``new_imports`` (a single line or a sequence of lines) to the imports."""
        if isinstance(new_imports, str):
            new_imports = [new_imports]
        self._imports.extend(new_imports)


class CodeInspector(CodeInspectorHelper):
    """Collect the source and imports needed to re-create ``func`` elsewhere.

    On construction the function itself is inspected. Its source (with type hints
    stripped) and the imports for the globals it references become :attr:`runcode`.
    Globals coming from ``__main__``/custom modules that were not declared in
    ``known_locals`` are recorded as :attr:`unknown_entities`. :meth:`prepare_code` then
    recursively gathers their source into :attr:`helpcode`. Top-level packages found along
    the way, excluding ``known_packages``, accumulate in :attr:`unknown_packages`.
    """

    def __init__(
            self,
            func: Callable,
            known_packages: Optional[list[str]] = None,
            known_locals: Optional[list[str]] = None,
            *,
            import_to_runcode: Optional[Union[str, list[str]]] = None,
            import_to_locals: Optional[Union[str, list[str]]]  = None
        ):
        """Inspect ``func``.

        Args:
            func: function to inspect.
            known_packages: top-level package names not to report in :attr:`unknown_packages`.
            known_locals: local names whose source should not be collected. They are
                registered in :attr:`known_entities` with the placeholder entity
                ``'Server-Side Object'``.
            import_to_runcode: extra import line(s) merged into :attr:`runcode` imports.
            import_to_locals: extra import line(s) merged into the helper-code imports.

        Raises:
            ValueError: if the collected local names and entities differ in number.
        """
        self.known_packages = known_packages if known_packages else []
        self._unknown_packages = []
        initially_known_locals = known_locals if known_locals else []
        self.__initially_known_locals = initially_known_locals.copy()

        self._import_to_runcode = self.check_manual_imports(import_to_runcode)
        self._import_to_locals = self.check_manual_imports(import_to_locals)

        source_code, imports, packages, locals, entities = self.inspect_function(func)

        imports = self.concatenate_imports(imports, self._import_to_runcode)

        source_code = astunparse.unparse(TypeHintRemover().visit(ast.parse(source_code)))

        self._runcode = Source(source_code, imports)
        self._helpcode = None

        self._unknown_entities = None

        self.unknown_packages = packages

        # Only locals that were not declared as known need their source collected.
        unknown_locals_map = [local not in self.__initially_known_locals for local in locals]
        unknown_locals = list(compress(locals, unknown_locals_map))
        unknown_entities = list(compress(entities, unknown_locals_map))

        # Pass a copy so later additions to ``known_entities`` cannot alter
        # ``__initially_known_locals``.
        self._known_entities = EntitiesContainer(
            self.__initially_known_locals.copy(),
            ['Server-Side Object' for _ in self.__initially_known_locals]
        )

        if len(unknown_locals) != len(unknown_entities):

            raise ValueError(
                    f'Found local entities and their names are inconsistent:\n'
                    f'{len(unknown_locals)} names: {unknown_locals}\n'
                    f'{len(unknown_entities)} entities: {unknown_entities}'
                )

        if unknown_locals and unknown_entities:
            self.unknown_entities = unknown_locals, unknown_entities

    @property
    def unknown_entities(self):
        return self._unknown_entities
    @unknown_entities.setter
    def unknown_entities(self, entities_info: Union[tuple[str, object], tuple[list[str], list[object]]]):
        """Create the container on first assignment; afterwards ``add`` to it."""
        if self._unknown_entities is None:
            self._unknown_entities = EntitiesContainer(*entities_info)
        else:
            self._unknown_entities.add(*entities_info)

    @property
    def unknown_packages(self):
        return self._unknown_packages
    @unknown_packages.setter
    def unknown_packages(self, packages: list[str]):
        """Append packages that are neither in ``known_packages`` nor already reported."""
        already_known = self.known_packages + self._unknown_packages
        self._unknown_packages += list(
                            filter(
                                lambda item: item not in already_known,
                                packages
                            )
                        )

    @property
    def runcode(self):
        return self._runcode
    @runcode.setter
    def runcode(self, value):
        raise AttributeError('Impossible to set a code to run directly')

    @property
    def helpcode(self):
        return self._helpcode
    @helpcode.setter
    def helpcode(self, value):
        raise AttributeError('Impossible to set a helping code directly')

    @property
    def known_entities(self):
        return self._known_entities
    @known_entities.setter
    def known_entities(self, entities_info: Union[tuple[str, object], tuple[list[str], list[object]]]):
        """Create the container on first assignment; afterwards ``add`` to it."""
        if self._known_entities is None:
            self._known_entities = EntitiesContainer(*entities_info)
        else:
            self._known_entities.add(*entities_info)

    def inspect_function(self, func: Callable) -> tuple[str, list[str], list[str], list[str], list[object]]:
        """Collect the source of ``func`` and the imports for the globals it references.

        Returns:
            ``(source_code, imports, packages, locals, entities)``. ``locals``/``entities``
            are the names/objects of referenced globals that come from ``__main__``/custom
            modules. If the source cannot be retrieved, ``source_code`` is a
            ``# code not found`` comment.
        """
        import textwrap

        imports, packages, locals, entities = list(), list(), list(), list()

        for member_name, member in inspect.getclosurevars(func).globals.items():
            source_module = self._CodeInspectorHelper__protect_source_module(inspect.getmodule(member), member)

            if inspect.ismodule(member):
                self.import_module(member, member_name, imports, packages)
            else:
                self.import_member(member_name, member, source_module, imports, packages, locals, entities)

        try:
            # Dedent so nested functions / methods can be re-parsed by ``ast.parse``.
            source_code = textwrap.dedent(inspect.getsource(func))
        except OSError:
            source_code = f'# code not found: {func}'

        return source_code, imports, packages, locals, entities

    def check_entities(self, locals: list[str], entities: list[object]) -> tuple[list[str], list[object]]:
        """Drop the ``(name, entity)`` pairs whose name is already in :attr:`known_entities`.

        Pairs are filtered together, so the two returned lists stay aligned even when
        several names refer to equal entities.

        Raises:
            ValueError: if ``locals`` and ``entities`` differ in length.
        """
        if len(locals) != len(entities):
            raise ValueError(f'Entities and its names are inconsistent:\n{locals}\n{entities}')

        pending = [
            (local, entity)
            for local, entity in zip(locals, entities)
            if local not in self.known_entities
        ]

        return [local for local, _ in pending], [entity for _, entity in pending]

    def inspect_unknown_entities(self):
        """Recursively collect the source of all unknown local entities.

        Each unknown entity is moved to :attr:`known_entities`. Its source is gathered:
        whole classes, including custom base classes and the globals used by their methods
        and property accessors, or plain functions. Newly discovered locals are then
        inspected recursively.

        Returns:
            ``(source_code, imports, packages, locals, entities)`` for the helper code.
        """
        import textwrap

        def concatenate_found_members(member: object, scode: str, imports: list[str], packages: list[str], locals: list[str], entities: list[object]):
            """Inspect ``member`` and merge its results into the running collections."""
            scode_, imports_, packages_, locals_, entities_ = self.inspect_function(member)

            return self.concatenate(
                        code=[scode, scode_],
                        imports=[imports, imports_],
                        packages=[packages, packages_],
                        locals=[locals, locals_],
                        entities=[entities, entities_]
                    )

        # Nothing may have been registered at all (the inspected function had no unknown
        # locals); treat that as an empty set instead of failing on ``None``.
        if self.unknown_entities is None:
            unknown_entities, unknown_entities_names = list(), list()
        else:
            unknown_entities, unknown_entities_names = self.unknown_entities.entities.copy(), self.unknown_entities.names.copy()
        scode, imports, packages, locals, entities = '', list(), list(), list(), list()

        for entity_name, entity in zip(unknown_entities_names, unknown_entities):
            self.unknown_entities.remove(entity_name)
            self.known_entities.add(entity_name, entity)

            if inspect.isclass(entity):
                scode_ = textwrap.dedent(inspect.getsource(entity))

                # Custom base classes must be shipped along with the class itself.
                for mem in inspect.getmro(entity):

                    if self.get_module_type(mem) in ['custom', '__main__']:
                        entities = self.concatenate_members(entities, mem)
                        locals = self.concatenate_members(locals, mem.__name__)

                if dataclasses.is_dataclass(entity):
                    imports = self.concatenate_imports(imports, ['from dataclasses import dataclass', 'from typing import Any'])
                    cs = ast.parse(scode_)

                    # Field annotations cannot be stripped from a dataclass (they define
                    # its fields), so they are replaced with ``Any``.
                    for mem in cs.body[0].body:

                        if isinstance(mem, ast.AnnAssign):
                            mem.annotation = ast.Name(
                                            id='Any',
                                            lineno=mem.annotation.lineno,
                                            col_offset=mem.annotation.col_offset,
                                            end_lineno=mem.annotation.end_lineno
                                        )
                    scode_ = astunparse.unparse(cs)

                if inspect.isabstract(entity):
                    imports = self.concatenate_imports(imports, ['from abc import ABC, ABCMeta'])

                scode = self.concatenate_code(scode_, scode, sep='')

                # Only the imports/locals of methods are needed; their code is in the class source.
                for _, member in inspect.getmembers(entity):

                    if inspect.isfunction(member):
                        _, imports, packages, locals, entities = concatenate_found_members(member, '', imports, packages, locals, entities)

                    elif inspect.isdatadescriptor(member):

                        for _, accessor in inspect.getmembers(member):

                            if inspect.isfunction(accessor):
                                _, imports, packages, locals, entities = concatenate_found_members(accessor, '', imports, packages, locals, entities)

            elif inspect.isfunction(entity):
                scode, imports, packages, locals, entities = concatenate_found_members(entity, scode, imports, packages, locals, entities)

        imports = self.concatenate_imports(imports, self._import_to_locals)
        scode = astunparse.unparse(TypeHintRemover().visit(ast.parse(scode)))

        locals, entities = self.check_entities(locals, entities)
        if locals and entities:
            self._unknown_entities = EntitiesContainer(locals, entities)
        if self.unknown_entities:
            new_scode, new_imports, new_packages, new_locals, new_entities = self.inspect_unknown_entities()
            scode, imports, packages, locals, entities = self.concatenate(
                code=[scode, new_scode],
                imports=[imports, new_imports],
                packages = [packages, new_packages],
                locals=[locals, new_locals],
                entities=[entities, new_entities]
            )

        scode, imports = self.clean_locals_related(scode, imports)

        return scode, imports, packages, locals, entities

    def prepare_code(self):
        """Build :attr:`helpcode` from all unknown entities and extend :attr:`unknown_packages`.

        Raises:
            ValueError: if some entities are still unknown after inspection.
        """
        scode, imports, packages, _, _ = self.inspect_unknown_entities()
        if self.unknown_entities:
            raise ValueError(f'Some entities are still unknown: {self.unknown_entities.names}')

        self._helpcode = Source(scode, imports)
        self.unknown_packages = packages
        

In [23]:
from cross_runs_TF_planes import CrossRunsTFScorer
from typing import _GenericAlias
import typing
import ast
import astunparse


def bold(msg, **kwargs):
    """Print ``msg`` wrapped by ``ucc.ColoredText`` with the ``'b'`` style (presumably bold).

    Extra keyword arguments are forwarded to ``print``. Note that this definition shadows
    the ``bold`` imported from ``utils.console.colored`` in the first cell.
    """
    painter = ucc.ColoredText().color().style('b')
    print(painter(msg), **kwargs)

def func(a: np.ndarray) -> CrossRunsTFScorer:
    """Demo target for ``CodeInspector``: the visible cells only inspect it, never call it.

    The body exists only to reference a mix of globals for the inspector to classify:
    module aliases (``re``, ``np``, ``sp``, ``r``, ``dm``), names imported from modules
    (``exp``, ``tanh``, ``bessel``, ``cos``, ``clean_styles``, ``Linked``,
    ``CrossRunsTFScorer``), and ``bold``, which is defined in this cell.
    """
    re.findall(r's', 'asdsasds')
    print(1)
    Linked('name')
    e = exp(2)
    tanh(e)
    bessel(10)
    cos(2)
    r.randint(1, 2)
    val = sp.sin(2)
    np.ones(100)
    dm.dict2str(dict(a=1, b=2))
    # NOTE(review): the printed helper output shows CrossRunsTFScorer as a dataclass with
    # three fields and no defaults, so this call would presumably raise TypeError if
    # func were ever executed — confirm if func is meant to be runnable.
    cs = CrossRunsTFScorer()
    text = f'text {val}'
    clean_styles(text)
    bold(text)

# Inspect `func`. 'typing' is passed as a known package, so it is filtered out of
# `unknown_packages`; no locals are pre-declared as known.
cinspector = CodeInspector(func, known_packages=['typing'], known_locals=[])

# Gather the source of every locally defined dependency into `cinspector.helpcode`.
cinspector.prepare_code()

# Show the generated helper module: imports first, then the collected source.
print(
    cinspector.concatenate_code_and_imports(
        source_code=cinspector.helpcode.code,
        imports=cinspector.helpcode.imports,
    )
)

import _thread
import numpy as np
import re
from typing import Callable, Iterable, Any
from dataclasses import dataclass


@dataclass
class CrossRunsTFScorer():
    tf_scores: Any
    accuracy_cache: Any
    csp: Any

    def mean(self):
        return np.mean(self.tf_scores, axis=0)

    def std(self):
        return np.std(self.tf_scores, axis=0)

    def tf_windows_mean(self):
        return {freq: {time: np.mean(np.array(self.accuracy_cache[freq][time])) for time in self.accuracy_cache[freq]} for freq in self.accuracy_cache}

    def tf_windows_std(self):
        return {freq: {time: np.std(np.array(self.accuracy_cache[freq][time])) for time in self.accuracy_cache[freq]} for freq in self.accuracy_cache}

class Linked(object):

    def __init__(self, name='Linked', parent=None, meta=None, options=None, child_options_generator=None, children=None):
        self.name = name
        self.__siblings = None
        self.parent = parent
        if ((not isinstance(children, list)) and iss

In [37]:
import subprocess
import sys

arg1 = 1
arg2 = 2
arg3 = 3

# `python -m` expects a dotted *module name*, not a file path, so the script is run
# directly. An argument list (no shell) avoids quoting/injection problems with the
# interpolated values, and `sys.executable` targets this kernel's own interpreter.
completed = subprocess.run(
    [sys.executable, '/path/to/file.py', str(arg1), str(arg2), '--somekeywordarg', str(arg3)],
    check=False,
)
completed.returncode

[2, 3]
