In [1]:
import torch
from RPNbuilder import *
from OP import *
import re
from RPN import *
from OrganAbstractClass import *

In [2]:
# Instantiate the RPN producer, run it, and keep its `tree` attribute as `rpns`.
# Later cells index this object as `rpns[0]`.
# NOTE(review): RPN_Producer comes from one of the star imports above; what run()
# generates is not visible in this notebook — confirm against its module.
producer=RPN_Producer()
producer.run()
rpns=producer.tree

In [5]:
# Wrap the first entry of `rpns` in an RPN_Parser, then call get_tree_structure()
# and parse_tree() on it.
# NOTE(review): what these two calls return or display is defined outside this
# notebook — confirm against the parser's module.
parser=RPN_Parser(rpns[0])
parser.get_tree_structure()
parser.parse_tree()

In [10]:
class RPN_Compiler:
    """Evaluate an RPN expression string against day-level and minute-level data.

    The expression is parsed with ``gp.PrimitiveTree.from_string`` and folded
    from its leaves upward.  Every time a primitive whose name starts with
    ``'D'`` is reached, the sub-expression built so far is evaluated and
    replaced by a ``D_tensor`` placeholder:

    * ``compile_module1`` evaluates sub-expressions that only reference
      placeholders (no minute-level input name appears in them);
    * ``compile_module2`` evaluates, one day at a time, sub-expressions that
      reference at least one of the first five input names.

    NOTE(review): expressions are executed with ``eval``.  Only feed this class
    expressions produced by trusted code.
    """

    def __init__(self, year_list, device=torch.device("cuda")):
        """Store the configuration and load the day-level data for ``year_list``.

        Args:
            year_list: Years handed to the data reader.
            device: Torch device on which expressions are evaluated.
        """
        self.year_list = year_list
        self.general_pset = general_pset.pset
        self.device = device
        self.__init_data(self.year_list)

    def __init_data(self, year_list):
        """Create the data reader and cache the day list, day-level data and industry data."""
        self.data_reader = MmapReader()
        self.day_list = self.data_reader.get_daylist(year_list)
        self.D_O, self.D_H, self.D_L, self.D_C, self.D_V = list(self.data_reader.get_Day_data(year_list))
        # Columns 10 onward of the last axis are kept as the industry data.
        # NOTE(review): the meaning of the first 10 columns is not visible here.
        self.industry = [self.data_reader.get_Barra(year_list)[:, :, 10:].to(self.device)]

    def extract_op(self, expression):
        """Return the operator names used in ``expression``, outermost first.

        ``expression`` is expected to look like ``op(arg, arg, ...)`` where each
        argument may itself be such a call.  Names are returned exactly as they
        appear in the text (surrounding whitespace included); an expression
        without parentheses yields an empty list.
        """
        op_list = []
        # Indices of the outermost '(' , its matching ')' and the commas that
        # separate the outermost call's arguments.
        positions = []
        depth = 0
        for index, char in enumerate(expression):
            if char == '(':
                depth += 1
                if depth == 1:
                    positions.append(index)
            elif char == ')':
                depth -= 1
                if depth == 0:
                    positions.append(index)
            elif char == ',' and depth == 1:
                positions.append(index)

        # No recorded position means there is no operator call in this text.
        if not positions:
            return op_list

        # Everything before the first '(' is the operator name.
        op_list.append(expression[:positions[0]])

        # Each pair of consecutive positions delimits one argument; recurse into it.
        if len(positions) > 1:
            for start, end in zip(positions[:-1], positions[1:]):
                op_list.extend(self.extract_op(expression[start + 1:end]))
        return op_list

    def add_op_class(self, op):
        """Return ``op`` prefixed with its ``OP_<属>2<目>.`` class qualifier.

        The qualifier is built from the operator's ``op_info`` entry; the last
        character of both the ``'属'`` and ``'目'`` values is dropped.

        Raises:
            KeyError: if the (stripped) operator name has no entry in ``op_info``.
        """
        name = op.strip()
        try:
            entry = op_info[name]
        except KeyError:
            raise KeyError(f"operator {name!r} is not registered in op_info") from None
        interface = entry['classification']['interface']
        return f"OP_{interface['属'][:-1]}2{interface['目'][:-1]}.{op}"

    def replace_primities(self, rpn):
        """Qualify every operator name in ``rpn`` with its class prefix.

        All names are substituted in a single pass over whole identifiers, so an
        operator whose name is contained in another operator's name (or in an
        already qualified name) is never prefixed twice.
        """
        names = [op.strip() for op in self.extract_op(rpn)]
        # De-duplicate while keeping order; drop empty names (e.g. a bare "(x)").
        names = [name for name in dict.fromkeys(names) if name]
        if not names:
            return rpn

        qualified = {name: self.add_op_class(name) for name in names}
        # Longest names first so the alternation prefers the most specific match.
        alternation = '|'.join(re.escape(name) for name in sorted(names, key=len, reverse=True))
        pattern = re.compile(rf"(?<![\w.])(?:{alternation})(?!\w)")
        return pattern.sub(lambda match: qualified[match.group(0)], rpn)

    def replace_D_tensor(self, rpn):
        """Number each ``D_tensor`` occurrence in ``rpn``: ``D_tensor0``, ``D_tensor1``, ..."""
        count = 0

        def replacer(match):
            nonlocal count  # shared counter across all matches
            numbered = f"D_tensor{count}"
            count += 1
            return numbered

        return re.sub(r"D_tensor", replacer, rpn)

    def compile_module1(self, rpn, D_tensor: [torch.Tensor]):
        """Evaluate ``rpn`` once, binding the i-th placeholder to ``D_tensor[i]``.

        Each tensor is moved to ``self.device`` before evaluation.
        """
        rpn = self.replace_D_tensor(rpn)
        rpn = self.replace_primities(rpn)

        # Bind the placeholders through an explicit namespace.  Writing into
        # locals() is not a supported way to create variables: from Python 3.13
        # on it returns an independent snapshot, so eval() would not see them.
        namespace = {
            f'D_tensor{i}': tensor.to(self.device)
            for i, tensor in enumerate(D_tensor)
        }
        # NOTE(review): eval executes arbitrary code; rpn must come from a trusted source.
        return eval(rpn, globals(), namespace)

    def compile_module2(self, rpn, D_tensor: [torch.Tensor]):
        """Evaluate ``rpn`` day by day and stack the per-day results.

        For every day in ``self.day_list`` the minute-level data is loaded and
        exposed to the expression as ``M_O, M_H, M_L, M_C, M_V``; the j-th
        placeholder is bound to row ``i`` of ``D_tensor[j]``.  Results are
        written into a NaN-initialised tensor with one row per day and one
        column per stock code.
        """
        rpn = self.replace_D_tensor(rpn)
        rpn = self.replace_primities(rpn)

        template = torch.full(
            (len(self.day_list), len(self.data_reader.DailyDataReader.StockCodes)),
            float('nan'),
        )
        # Wrapping the list (not the enumerate object) lets tqdm show the total.
        for i, day in enumerate(tqdm(self.day_list)):
            M_O, M_H, M_L, M_C, M_V = self.data_reader.get_Minute_data_daily(day)
            namespace = {
                'M_O': M_O.to(self.device),
                'M_H': M_H.to(self.device),
                'M_L': M_L.to(self.device),
                'M_C': M_C.to(self.device),
                'M_V': M_V.to(self.device),
            }
            for j, daily in enumerate(D_tensor):
                namespace[f'D_tensor{j}'] = daily[i].to(self.device)
            # NOTE(review): eval executes arbitrary code; rpn must come from a trusted source.
            template[i] = eval(rpn, globals(), namespace)

        return template

    def adjust_memorizer(self, deap_primitive, string_memorizer):
        """Fold the first ``arity`` entries of ``string_memorizer`` into one call string.

        Returns a new list whose first entry is ``name(arg, ...)`` followed by
        the entries that were not consumed.
        """
        arity = deap_primitive.arity
        expr = f"{deap_primitive.name}({', '.join(string_memorizer[:arity])})"
        return [expr] + string_memorizer[arity:]

    def compile(self, rpn):
        """Evaluate the expression string ``rpn`` and return the resulting tensor.

        The parsed node list is walked in reverse.  Terminals are pushed onto a
        string stack (and, for ``ARG`` indices >= 5, the matching attribute of
        this object onto a tensor stack).  Primitives starting with ``'M'`` only
        fold the string stack; primitives starting with ``'D'`` fold it and then
        evaluate the folded sub-expression immediately.
        """
        name = general_pset.input1 + general_pset.input2 + general_pset.input3 + general_pset.input4 + general_pset.input5
        minute_names = name[:5]
        deap_code = gp.PrimitiveTree.from_string(rpn, self.general_pset)
        deap_code.reverse()
        D_tensor_memorizer = []
        string_memorizer = []
        for code in deap_code:
            if isinstance(code, gp.Terminal):
                if code.name.startswith('ARG'):
                    arg_index = int(code.name[3:])
                    if arg_index >= 5:
                        # Input already available on this object: keep the tensor
                        # and refer to it through a placeholder.
                        D_tensor_memorizer.insert(0, getattr(self, name[arg_index]))
                        string_memorizer.insert(0, 'D_tensor')
                    else:
                        # One of the first five inputs: keep its name in the string.
                        string_memorizer.insert(0, name[arg_index])
                else:
                    string_memorizer.insert(0, code.name)

            elif isinstance(code, gp.Primitive):
                if code.name.startswith('D'):
                    string_memorizer = self.adjust_memorizer(code, string_memorizer)
                    expression = string_memorizer[0]
                    uses_minute_data = any(item in expression for item in minute_names)
                    evaluate = self.compile_module2 if uses_minute_data else self.compile_module1

                    # The expression consumes as many stacked tensors as it has placeholders.
                    count = expression.count("D_tensor")
                    result = evaluate(expression, D_tensor_memorizer[:count])
                    D_tensor_memorizer = D_tensor_memorizer[count:]
                    D_tensor_memorizer.insert(0, result)
                    string_memorizer[0] = 'D_tensor'

                elif code.name.startswith('M'):
                    string_memorizer = self.adjust_memorizer(code, string_memorizer)

        return D_tensor_memorizer[0]

In [11]:
# Build a compiler for range(2016, 2017) — i.e. the single year 2016 — and
# evaluate the first entry of `rpns` with it.
# The output recorded after this cell is KeyError: 'D_cs_industry_neutra'.
compiler=RPN_Compiler(range(2016,2017))
factor=compiler.compile(rpns[0])

KeyError: 'D_cs_industry_neutra'