In [35]:
import tensorflow as tf
import numpy as np
import os
from tensorflow.keras.models import load_model
from sklearn.model_selection import train_test_split
import tensorflow as tf
import pandas as pd
from imblearn.over_sampling import SMOTE


In [36]:
# Column-wise min-max normalization
def normalization(data):
    """Scale every column of ``data`` into the [0, 1] range.

    Parameters
    ----------
    data : numpy.ndarray or pandas.DataFrame
        Numeric data; minimum and maximum are taken along axis 0.

    Returns
    -------
    ``(data - min) / (max - min)`` computed per column. Columns whose
    values are all identical (max == min) come back as zeros instead of
    the NaN values a plain 0/0 division would produce.
    """
    min_vals = np.min(data, axis=0)
    max_vals = np.max(data, axis=0)
    _range = max_vals - min_vals
    # A constant column has zero range; dividing by it gives NaN/inf.
    # Using 1 as the divisor leaves the (all-zero) numerator unchanged.
    safe_range = np.where(_range == 0, 1, _range)
    return (data - min_vals) / safe_range

# Load the dataset
original_data = pd.read_csv('./datasets/AEEEM/JDT.csv')

# Fill missing values with False, but only when the frame actually has any
# (the check used to be evaluated and its result thrown away).
if original_data.isnull().values.any():
    original_data = original_data.fillna(value=False)

# Label column; other datasets name it Defective / isDefective / defects / label
original_Y = pd.DataFrame(original_data['class'])

# Scale the feature columns only: the label is dropped *before* normalization
# so it is never min-max scaled and does not have to be numeric.
original_X = normalization(original_data.drop(['class'], axis=1))

# NOTE(review): min/max are computed on the full dataset before the split
# below, so test-set statistics leak into the scaling of the training data.
# Consider fitting the scaling on the training split only.

# Hold out 10% of the samples as the test set
x_train, x_test, y_train, y_test = train_test_split(original_X, original_Y, test_size=.1, random_state=12)
print(x_train.shape, y_train.shape,x_test.shape, y_test.shape)

# Oversample the training split with SMOTE to balance the classes
sm = SMOTE(random_state=12, sampling_strategy=1.0)
x, y = sm.fit_resample(x_train, y_train)
y_train = pd.DataFrame(y, columns=['class'])    # Defective / class / isDefective / defects
x_train = pd.DataFrame(x, columns=original_X.columns)


(897, 21) (897, 1) (100, 21) (100, 1)


In [37]:
# Carve a validation subset (10%) out of the training data
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, test_size=.1, random_state=12
)

# Labels: plain NumPy arrays
y_train = y_train.values
y_val = y_val.values
y_test = y_test.values

# Features: NumPy arrays with a trailing axis of length 1 appended, giving
# the (samples, features, 1) layout used for the 1-D convolution input
x_train = x_train.values[:, :, np.newaxis]
x_val = x_val.values[:, :, np.newaxis]
x_test = x_test.values[:, :, np.newaxis]

# print(x_train.shape, y_train.shape, x_val.shape, y_val.shape, x_test.shape, y_test.shape)


In [38]:
n_classes = 2

# One-hot encode the labels of all three splits
y_train, y_val, y_test = (
    tf.keras.utils.to_categorical(labels, n_classes)
    for labels in (y_train, y_val, y_test)
)

print(y_train.shape, y_val.shape, y_test.shape)

(1276, 2) (142, 2) (100, 2)


In [39]:
# Cast the input arrays of every split to 32-bit floats
x_train, x_val, x_test = (
    split.astype(np.float32) for split in (x_train, x_val, x_test)
)

In [40]:
# Per-sample input shape: second dimension of x_train followed by a 1
input_shape = x_train.shape[1:2] + (1,)

print(f'input_shape: {input_shape}')

input_shape: (21, 1)


In [41]:
import os
# Show the directory the kernel is currently running in
print(os.getcwd())

# Switch to the project root (presumably so that the project-local packages
# imported in the following cells resolve — confirm). The location can be
# overridden with the DISTILL_PROJECT_ROOT environment variable; the default
# keeps the original hard-coded path so existing setups behave as before.
project_root = os.environ.get(
    'DISTILL_PROJECT_ROOT',
    '/home/user/xgf/Disstill_defect_interpretation/Disstill_defect_interpretation',
)
os.chdir(project_root)

/home/user/xgf/Disstill_defect_interpretation/Disstill_defect_interpretation


In [42]:
from models.convnet import ConvNet
from models.tree import SoftDecisionTree

In [43]:
# Instantiate the project's ConvNet with the per-sample input shape and the
# number of classes. NOTE(review): presumably this is the network that is
# later trained jointly with the soft decision tree — see models/convnet.py.
nn = ConvNet(input_shape, n_classes=n_classes)  

In [44]:
# Run the freshly constructed network on the training inputs (no training
# call precedes this in the visible cells).
# NOTE(review): `y_train_soft` is not referenced again in the visible cells;
# this call may only be here to build the model's weights — confirm before
# removing it.
y_train_soft = nn.predict(x_train)
# y_train_soft.shape

In [45]:
# Collapse each sample into a single flat vector, giving 2-D arrays
x_train_flat = x_train.reshape(len(x_train), -1)
x_val_flat = x_val.reshape(len(x_val), -1)
x_test_flat = x_test.reshape(len(x_test), -1)

(x_train_flat.shape, x_val_flat.shape, x_test_flat.shape)

((1276, 21), (142, 21), (100, 21))

In [46]:

# Hyper-parameters of the soft decision tree.
# `n_classes` is reused from the label-encoding cell and `epochs` is set next
# to `batch_size` in the training cell, so neither is re-assigned here.
max_depth = 4
# Take the feature count from the flattened inputs, which are the arrays
# passed to train() alongside the tree (presumably the tree's input — confirm),
# rather than from the 3-D convolution tensor.
n_features = x_train_flat.shape[1]
penalty_strength = 1e+1
penalty_decay = 0.25
inv_temp = 0.01  # inverse temperature
ema_win_size = 100


# Interpreter model g_model
g_model = SoftDecisionTree(max_depth=max_depth, n_features=n_features, n_classes=n_classes, 
                          penalty_strength=penalty_strength, penalty_decay=penalty_decay, 
                          inv_temp=inv_temp, ema_win_size=ema_win_size)


In [47]:
from joint import analyze, train, evaluate

In [48]:
epochs = 40
batch_size = 16

# Path prefixes under which the two jointly trained models are stored
f_model_path = 'assets/JDT/f_model_joint'
g_model_path = 'assets/JDT/g_model_joint'

# A saved model is considered present when its ".index" file exists
f_model_exists = os.path.exists(f_model_path + ".index")
g_model_exists = os.path.exists(g_model_path + ".index")

data_test = (x_test, x_test_flat, y_test)
data_val = (x_val, x_val_flat, y_val)

if f_model_exists and g_model_exists:
    # Both saved models found: rebuild fresh instances and restore weights
    g_model_joint = SoftDecisionTree(
        max_depth=max_depth,
        n_features=n_features,
        n_classes=n_classes,
        penalty_strength=penalty_strength,
        penalty_decay=penalty_decay,
        inv_temp=inv_temp,
        ema_win_size=ema_win_size,
    )
    f_model_joint = ConvNet(input_shape, n_classes=n_classes)

    f_model_joint.load_weights(f_model_path)
    g_model_joint.load_weights(g_model_path)
else:
    # At least one saved model is missing: train both, then save the weights
    f_model_joint, g_model_joint = train(
        nn, g_model, x_train, x_train_flat, y_train, data_val, epochs,
        batch_size=batch_size,
    )
    f_model_joint.save_weights(f_model_path)
    g_model_joint.save_weights(g_model_path)
    

Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40


In [None]:
from joint import analyze

# Evaluate both models on the test split: accuracy of f, fidelity, accuracy of g
f_joint_acc, fidelity, g_joint_acc = analyze(f_model_joint, g_model_joint, x_test, x_test_flat, y_test)

# Report each score as a percentage with two decimals
for template, value in (
    ("Accuracy of f (in %): {:.2f}", f_joint_acc),
    ("Accuracy of g (in %): {:.2f}", g_joint_acc),
    ("Fidelity (in %): {:.2f}", fidelity),
):
    print(template.format(value * 100))


### 计算其他指标（MCC、AUC、F1-score）

In [50]:
# Recover integer class labels from the one-hot encoded test labels
y_test_ = y_test.argmax(axis=1)

In [None]:
from sklearn import metrics

# Per-class scores predicted by the network for the test inputs, and the
# hard label (index of the highest score) for each sample
y_pred = f_model_joint.predict(x_test)
y_pred_ = np.argmax(y_pred, axis=1)

accuracy_score=metrics.accuracy_score(y_test_,y_pred_)

print(f"accuracy_score={accuracy_score: .4f}")

fscore=metrics.f1_score(y_test_,y_pred_,average='macro')
print(f"f-score={fscore: .4f}")

# ROC-AUC is a ranking metric: it must be given the continuous score of the
# positive class (column 1), not the thresholded labels — with hard labels it
# collapses to a single operating point.
y_score = np.asarray(y_pred)[:, 1]
auc = metrics.roc_auc_score(y_test_, y_score, average='macro')
print(f"auc={auc: .4f}")

mcc = metrics.matthews_corrcoef(y_test_, y_pred_)
print(f"MCC={mcc: .4f}")
