In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
import re
from tqdm import tqdm
import random
import math
import copy

  from pandas.core import (


In [2]:
from Modules import read_data, collate_fn
from Modules import train_epoch, eval_epoch, train_model
from Modules import RuleSet, RuleBasedTPP, EventDataset
# from Modules import RuleNode, RuleMCTS
from Modules import optimize

In [3]:
# Load the stroke cohort and the variable-name mapping for the chosen prediction target.
file_path = "stroke_dataset_all.csv"
target_name = "Middle_to_Sever"
data, var_name_dict = read_data(
    file_path,
    target_varibles=target_name,  # keyword spelled exactly as read_data's call site uses it
    outliers=0.0,
)
n_samples = len(data)
print(f"The data have {n_samples} samples.")

The data have 1351 samples.


In [4]:
# Rule-search configuration.
max_order = 2  # highest rule order to consider
num_candidates = 10  # number of rules selected in each optimization
n_calls = 20  # maximum number of optimization iterations

# The previous expression `"cpu" if torch.cuda.is_available() else "cpu"` always
# resolved to CPU, so the CUDA check was dead code. Keep CPU as the default, and make
# GPU use an explicit opt-in.
# NOTE(review): confirm `optimize` supports CUDA tensors before enabling.
use_cuda = False
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

# Search for the best-scoring rule subset. This cell is long-running (the logged run
# took well over an hour).
best_rules, best_loss = optimize(
    data, var_name_dict, target_name, max_order=max_order,
    num_candidates=num_candidates, n_calls=n_calls, device=device
)

print("Best Rules:", best_rules)
print("Best Loss:", best_loss)

Optimizing rule selection...


100%|██████████| 48/48 [1:19:08<00:00, 98.92s/it] 


Number of selected first order rules: 6
Number of candidate rules: 1608
Iteration No: 1 started. Evaluating function at random point.
Score computation time: 82.1205563545227 seconds
Iteration No: 1 ended. Evaluation done at random point.
Time taken: 82.1246
Function value obtained: 15.9117
Current minimum: 15.9117
Iteration No: 2 started. Evaluating function at random point.
Score computation time: 108.81337189674377 seconds
Iteration No: 2 ended. Evaluation done at random point.
Time taken: 108.8163
Function value obtained: 15.8951
Current minimum: 15.8951
Iteration No: 3 started. Evaluating function at random point.
Score computation time: 111.99849081039429 seconds
Iteration No: 3 ended. Evaluation done at random point.
Time taken: 112.0015
Function value obtained: 15.8647
Current minimum: 15.8647
Iteration No: 4 started. Evaluating function at random point.
Score computation time: 111.05566763877869 seconds
Iteration No: 4 ended. Evaluation done at random point.
Time taken: 111.05

In [4]:
# Fit the rule-based TPP on a fixed, hard-coded set of rules and train it on CPU.
# NOTE(review): this list differs from the optimizer results printed in this
# notebook; confirm which run it was taken from.
rules = [
    'Arterial Blood Pressure mean Low and Arterial Blood Pressure systolic High',
    'Creatinine High equal PTT High',
    'Temperature Fahrenheit High equal Chloride High',
    'HCO3 Low equal Temperature Fahrenheit High',
    'Arterial Blood Pressure systolic High before Creatinine High',
    'Temperature Fahrenheit High',
    'Potassium Low equal Creatinine High',
    'Temperature Fahrenheit High and Hemoglobin High',
    'Glucose High before Respiratory Rate Low',
]
rule_set = RuleSet(data, var_name_dict)
for rule_text in rules:
    rule_set.add_rule(rule_text)

device = "cpu"
model = RuleBasedTPP(
    rule_set.var_name_dict,
    rule_set.rule_name_dict,
    rule_set.rule_var_ids,
    device=device,
)
model.to(device)
loss, output = train_model(
    model,
    data,
    rule_set.rule_event_data,
    target_name,
    device,
    num_epochs=100,
    lr=0.01,
    patience=5,
    if_print=True,
)

Epoch 0, Loss: 96.41525452490212
Eval NLL: 90.6661458173801, Eval MAE: 1.5251959198295422, Eval RMSE: 1.7702715123601511
Epoch 1, Loss: 83.40246126243358
Eval NLL: 79.32269318649244, Eval MAE: 1.7709893157561323, Eval RMSE: 2.0248115472495556
Epoch 2, Loss: 73.57882778379664
Eval NLL: 70.7193800778407, Eval MAE: 1.8040234064923881, Eval RMSE: 2.073168364068239
Epoch 3, Loss: 66.01683661076754
Eval NLL: 63.858525172810715, Eval MAE: 1.9292914615565986, Eval RMSE: 2.201806623667367
Epoch 4, Loss: 59.904766269304204
Eval NLL: 58.32974046842642, Eval MAE: 2.097292041921526, Eval RMSE: 2.3783999480224622
Epoch 5, Loss: 54.964303723639944
Eval NLL: 53.73080065549522, Eval MAE: 2.132470824071842, Eval RMSE: 2.4161078030332854
Epoch 6, Loss: 50.806143437601925
Eval NLL: 50.03549485892827, Eval MAE: 2.1041820616739697, Eval RMSE: 2.420421057109243
Epoch 7, Loss: 47.4661981953514
Eval NLL: 46.930027969209, Eval MAE: 2.3875887053039233, Eval RMSE: 2.683139619379685
Epoch 8, Loss: 44.6700766785277

Eval NLL: 7.4159391550547, Eval MAE: 6.2274567986250915, Eval RMSE: 6.469516489894602
Eval NLL: 7.409330251390839, Eval MAE: 6.2264036909705025, Eval RMSE: 6.480850927819507
Eval NLL: 7.4119570848767875, Eval MAE: 5.100745918413086, Eval RMSE: 5.40912512955235
Eval NLL: 7.415242275966593, Eval MAE: 5.083468017404172, Eval RMSE: 5.393620904039412

In [5]:
# Inspect the learned parameters.
# Exponentiated rule weights are indexed by the values of rule_name_dict.
print(torch.exp(model.rule_weights))
print(model.rule_name_dict)
# Remaining parameters are printed with a "<name>:" label. Attributes are read one
# at a time, immediately before each print.
for attr_name in ("beta", "meas_weights", "mu"):
    print(f"{attr_name}:", getattr(model, attr_name))

tensor([2.6146, 1.4559, 2.2141, 1.4785, 1.1452, 1.8543, 1.0374, 2.4063, 1.8094],
       grad_fn=<ExpBackward0>)
{'Arterial Blood Pressure mean Low and Arterial Blood Pressure systolic High': 0, 'Creatinine High equal PTT High': 1, 'Temperature Fahrenheit High equal Chloride High': 2, 'HCO3 Low equal Temperature Fahrenheit High': 3, 'Arterial Blood Pressure systolic High before Creatinine High': 4, 'Temperature Fahrenheit High': 5, 'Potassium Low equal Creatinine High': 6, 'Temperature Fahrenheit High and Hemoglobin High': 7, 'Glucose High before Respiratory Rate Low': 8}
beta: Parameter containing:
tensor(5.3095, requires_grad=True)
meas_weights: Parameter containing:
tensor([ 4.4662e-01,  8.7705e-01,  8.5516e-01,  7.1428e-01,  7.3018e-02,
         5.3840e-01,  4.2042e-01, -3.4226e-01, -2.1234e+00,  1.6787e-01,
         3.2238e-01,  9.7340e-02,  6.4284e-01,  1.7217e-01,  3.0516e-01,
         3.1567e-01, -2.5439e+00,  3.1711e-01, -6.2819e+00,  8.4135e-01,
        -8.7619e+00,  4.3233e-0

In [6]:
# Print one markdown-table row per learned rule:
#   | 1-based index | rule -> target | exp(rule weight), rounded to 4 places | empty column |
count = 0
for count, rule_name in enumerate(model.rule_name_dict, start=1):
    rule_idx = model.rule_name_dict[rule_name]
    weight = round(torch.exp(model.rule_weights[rule_idx]).item(), 4)
    print(f"|{count} | {rule_name} -> {target_name} | {weight}| |")

|1 | Arterial Blood Pressure mean Low and Arterial Blood Pressure systolic High -> Middle_to_Sever | 2.6146| |
|2 | Creatinine High equal PTT High -> Middle_to_Sever | 1.4559| |
|3 | Temperature Fahrenheit High equal Chloride High -> Middle_to_Sever | 2.2141| |
|4 | HCO3 Low equal Temperature Fahrenheit High -> Middle_to_Sever | 1.4785| |
|5 | Arterial Blood Pressure systolic High before Creatinine High -> Middle_to_Sever | 1.1452| |
|6 | Temperature Fahrenheit High -> Middle_to_Sever | 1.8543| |
|7 | Potassium Low equal Creatinine High -> Middle_to_Sever | 1.0374| |
|8 | Temperature Fahrenheit High and Hemoglobin High -> Middle_to_Sever | 2.4063| |
|9 | Glucose High before Respiratory Rate Low -> Middle_to_Sever | 1.8094| |


In [7]:
# Dump the per-epoch history returned by train_model, one record per line.
# Presumably each record is [train loss, eval NLL, eval MAE, eval RMSE]; the values
# line up with the training log, but confirm against train_model.
for epoch_record in output:
    print(epoch_record)

[96.41525452490212, 90.6661458173801, 1.5251959198295422, 1.7702715123601511]
[83.40246126243358, 79.32269318649244, 1.7709893157561323, 2.0248115472495556]
[73.57882778379664, 70.7193800778407, 1.8040234064923881, 2.073168364068239]
[66.01683661076754, 63.858525172810715, 1.9292914615565986, 2.201806623667367]
[59.904766269304204, 58.32974046842642, 2.097292041921526, 2.3783999480224622]
[54.964303723639944, 53.73080065549522, 2.132470824071842, 2.4161078030332854]
[50.806143437601925, 50.03549485892827, 2.1041820616739697, 2.420421057109243]
[47.4661981953514, 46.930027969209, 2.3875887053039233, 2.683139619379685]
[44.67007667852774, 44.348089028988376, 2.5649730339921253, 2.860573357541622]
[42.288973713583324, 42.236383310543204, 2.684030575725867, 2.9841840790990206]
[40.36004819483667, 40.356786745060866, 2.838696124148984, 3.1488053614472995]
[38.659332329255534, 38.76421365702723, 2.828310370005364, 3.1255490324593986]
[37.19663076610478, 37.398925633008204, 2.8529339214972476

Best Rules: ['WBC High and Glucoseserum High', 'AST High and Anion gap Low', 'AST High and Total Bilirubin High', 'Lactic Acid High and BUN High', 'Arterial CO2 Pressure High before Potassiumserum High', 'Lactic Acid Low and Arterial CO2 Pressure Low', 'C Reactive Protein CRP High before Arterial O2 Saturation Low', 'C Reactive Protein CRP High before BUN High', 'PTT Low equal PH Arterial Low', 'Potassiumserum Low before Anion gap Low']
Best Loss: 90.15578431586043

Best Rules: ['BUN High', 'Anion gap Low', 'Hemoglobin Low', 'Prothrombin time High', 'Heart Rate High', 'Anion gap High', 'Albumin Low', 'WBC High', 'Albumin High', 'Potassiumserum High']
Best Loss: 90.14582462717163

'Arterial CO2 Pressure Low', 'Heart Rate Low', 'O2 saturation pulseoxymetry Low', 'Sodium Low', 'Creatinine High', 'Respiratory Rate Low', 'Creatinine High', 'Heart Rate Low', 'HCO3 Low', 'Arterial Blood Pressure mean Low'

Best Rules: ['Heart Rate High', 'Respiratory Rate Low', 'Arterial Blood Pressure diastolic Low', 'Heart Rate Low', 'Anion gap High', 'Arterial Blood Pressure systolic Low', 'Temperature Fahrenheit High', 'O2 saturation pulseoxymetry Low', 'Heart Rate High', 'Anion gap High']
Best Loss: 58.30078836220586