In [2]:
file_path = './gmdh_model_expression.txt'

In [5]:
import sympy as sp
import re
from collections import defaultdict

# 1. 读取文件内容
with open('./gmdh_model_expression.txt', 'r') as f:
    content = f.read()

# 2. 解析所有表达式并按层排序
expressions = {}
pattern = r'# 层 (\d+), 模型 (\d+)\nexpr_(\d+)\s*=\s*(.*?)(?=\n#|\n$|$)'
for match in re.finditer(pattern, content, re.DOTALL):
    layer = int(match.group(1))
    model = int(match.group(2))
    expr_id = int(match.group(3))
    expr_text = match.group(4).strip()
    expressions[expr_id] = expr_text

# 3. 提取所有基因名称
gene_pattern = r"'([A-Za-z0-9_]+)'"
genes = set()
for expr in expressions.values():
    genes.update(re.findall(gene_pattern, expr))

print(f"找到 {len(expressions)} 个表达式")
print(f"找到 {len(genes)} 个基因")

# 4. 创建符号变量
gene_symbols = {gene: sp.Symbol(gene) for gene in genes}
expr_symbols = {}

# 5. 处理表达式（从底层向上）
def process_expression(expr_str):
    """解析表达式字符串并转换为sympy表达式对象"""
    # 处理常数项和线性项
    result = 0
    
    # 提取常数项
    constant_pattern = r'^([-+]?\d+\.\d+)'
    constant_match = re.match(constant_pattern, expr_str)
    if constant_match:
        constant = float(constant_match.group(1))
        result = sp.Float(constant)
    
    # 提取线性项和交互项
    terms = []
    
    # 处理基因项: 数字*('基因名')
    gene_pattern = r'([-+]?\d+\.\d+)\*\(\'([A-Za-z0-9_]+)\'\)'
    for match in re.finditer(gene_pattern, expr_str):
        coef = sp.Float(match.group(1))
        gene = match.group(2)
        term = coef * gene_symbols[gene]
        terms.append(term)
    
    # 处理交互项: 数字*('基因1'*'基因2')
    interaction_pattern = r'([-+]?\d+\.\d+)\*\(\'([A-Za-z0-9_]+)\'\*\'([A-Za-z0-9_]+)\'\)'
    for match in re.finditer(interaction_pattern, expr_str):
        coef = sp.Float(match.group(1))
        gene1 = match.group(2)
        gene2 = match.group(3)
        term = coef * gene_symbols[gene1] * gene_symbols[gene2]
        terms.append(term)
    
    # 处理子表达式项: 数字*(expr_X)
    subexpr_pattern = r'([-+]?\d+\.\d+)\*\(expr_(\d+)\)'
    for match in re.finditer(subexpr_pattern, expr_str):
        coef = sp.Float(match.group(1))
        expr_id = int(match.group(2))
        if expr_id in expr_symbols and expr_symbols[expr_id] is not None:
            term = coef * expr_symbols[expr_id]
            terms.append(term)
    
    # 处理子表达式交互项: 数字*(expr_X*expr_Y)
    interact_expr_pattern = r'([-+]?\d+\.\d+)\*\(expr_(\d+)\*expr_(\d+)\)'
    for match in re.finditer(interact_expr_pattern, expr_str):
        coef = sp.Float(match.group(1))
        expr_id1 = int(match.group(2))
        expr_id2 = int(match.group(3))
        if expr_id1 in expr_symbols and expr_id2 in expr_symbols and \
           expr_symbols[expr_id1] is not None and expr_symbols[expr_id2] is not None:
            term = coef * expr_symbols[expr_id1] * expr_symbols[expr_id2]
            terms.append(term)
    
    # 处理子表达式与基因交互项: 数字*(expr_X*'基因')
    gene_expr_pattern = r'([-+]?\d+\.\d+)\*\(expr_(\d+)\*\'([A-Za-z0-9_]+)\'\)'
    for match in re.finditer(gene_expr_pattern, expr_str):
        coef = sp.Float(match.group(1))
        expr_id = int(match.group(2))
        gene = match.group(3)
        if expr_id in expr_symbols and expr_symbols[expr_id] is not None:
            term = coef * expr_symbols[expr_id] * gene_symbols[gene]
            terms.append(term)
    
    # 合并所有项
    for term in terms:
        result += term
    
    return result

# 6. 自底向上计算表达式
for expr_id in sorted(expressions.keys(), reverse=True):
    print(f"处理表达式 {expr_id}")
    expr_symbols[expr_id] = process_expression(expressions[expr_id])
    # 每处理完一个表达式尝试简化，以节省内存
    if expr_symbols[expr_id] is not None:
        expr_symbols[expr_id] = sp.expand(expr_symbols[expr_id])

# 7. 获取最终结果并展开
try:
    final_result = expr_symbols[0]  # 假设最终结果是expr_0
    expanded_result = sp.expand(final_result)
    print("成功展开最终表达式")
except Exception as e:
    print(f"展开最终表达式失败: {e}")
    # 尝试找到非空的最低层表达式
    for expr_id in sorted(expressions.keys()):
        if expr_symbols[expr_id] is not None:
            expanded_result = sp.expand(expr_symbols[expr_id])
            print(f"使用表达式 {expr_id} 作为替代")
            break

# 8. 提取原始权重 - 直接从第0层表达式中获取
original_weights = defaultdict(list)
layer0_pattern = r'# 层 0, 模型 \d+\nexpr_\d+\s*=\s*(.*?)(?=\n#|\n$|$)'
for match in re.finditer(layer0_pattern, content, re.DOTALL):
    expr = match.group(1).strip()
    
    # 提取线性权重
    linear_pattern = r'([-+]?\d+\.\d+)\*\(\'([A-Za-z0-9_]+)\'\)'
    for w_match in re.finditer(linear_pattern, expr):
        weight = float(w_match.group(1))
        gene = w_match.group(2)
        original_weights[gene].append(weight)
    
    # 提取交互项权重
    interaction_pattern = r'([-+]?\d+\.\d+)\*\(\'([A-Za-z0-9_]+)\'\*\'([A-Za-z0-9_]+)\'\)'
    for w_match in re.finditer(interaction_pattern, expr):
        weight = float(w_match.group(1))
        gene1 = w_match.group(2)
        gene2 = w_match.group(3)
        interaction = f"{gene1}*{gene2}"
        original_weights[interaction].append(weight)

# 计算平均权重
avg_weights = {feature: sum(weights)/len(weights) for feature, weights in original_weights.items()}

# 9. 输出结果
print("\n原始特征权重(从第0层模型提取):")
sorted_features = sorted(avg_weights.items(), key=lambda x: abs(x[1]), reverse=True)

# 分类显示
print("\n线性项权重:")
for feature, weight in [(f,w) for f,w in sorted_features if '*' not in f]:
    print(f"{feature}: {weight}")

print("\n交互项权重(前20个):")
interaction_weights = [(f,w) for f,w in sorted_features if '*' in f]
for feature, weight in interaction_weights[:20]:
    print(f"{feature}: {weight}")

print(f"\n共有 {len([f for f,w in sorted_features if '*' not in f])} 个基因的线性权重")
print(f"共有 {len([f for f,w in sorted_features if '*' in f])} 个交互项权重")

找到 77 个表达式
找到 64 个基因
处理表达式 76
处理表达式 75
处理表达式 74
处理表达式 73
处理表达式 72
处理表达式 71
处理表达式 70
处理表达式 69
处理表达式 68
处理表达式 67
处理表达式 66
处理表达式 65
处理表达式 64
处理表达式 63
处理表达式 62
处理表达式 61
处理表达式 60
处理表达式 59
处理表达式 58
处理表达式 57
处理表达式 56
处理表达式 55
处理表达式 54
处理表达式 53
处理表达式 52
处理表达式 51
处理表达式 50
处理表达式 49
处理表达式 48
处理表达式 47
处理表达式 46
处理表达式 45
处理表达式 44
处理表达式 43


MemoryError: 