In [1]:
import warnings
from rdkit import RDLogger

# 屏蔽 RDKit 警告
RDLogger.DisableLog('rdApp.*')

# 或屏蔽所有 Python 警告
warnings.filterwarnings("ignore")


In [2]:
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import KFold
from sklearn.metrics import precision_recall_curve, auc, f1_score, recall_score, precision_score, matthews_corrcoef, accuracy_score
from kan import KAN  # Assuming the KAN model is defined in a module called kan
from tqdm import tqdm  # Import tqdm for progress bars
from sklearn.metrics import roc_curve




# 数据预处理
df = pd.read_csv('selected_features_Flam.csv')
labels = df['Flammability'].values
smiles_list = df['SMILES'].tolist()

# 函数：将SMILES转换为分子描述符和指纹
def smiles_to_features(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    # 提取描述符
    descriptors = [
        Descriptors.MolWt(mol),  # 分子量
        Descriptors.MolLogP(mol),  # LogP
        Descriptors.NumHDonors(mol),  # 氢键供体数量
        Descriptors.NumHAcceptors(mol)  # 氢键受体数量
    ]
    # 生成Morgan指纹
    fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
    fingerprint_array = np.zeros((2048,))
    Chem.DataStructs.ConvertToNumpyArray(fingerprint, fingerprint_array)
    # 合并描述符和指纹
    features = np.concatenate([descriptors, fingerprint_array])
    return features

# 将SMILES转换为特征
features = []
for smiles in smiles_list:
    feature = smiles_to_features(smiles)
    if feature is not None:
        features.append(feature)

# 转换为numpy数组
features = np.array(features)
X = np.array(features)
y = labels

# 设置设备为 CUDA 或 CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 将数据转换为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
y_tensor = torch.tensor(y, dtype=torch.float32).to(device)

# 创建DataLoader函数
def get_dataloader(X, y, batch_size):
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

# 训练 KAN 模型
def train_kan(model, train_dataloader, optimizer, criterion, steps=20, patience=1):
    model.train()
    best_val_auc = -np.inf  # 最佳验证集AUC
    epochs_without_improvement = 0  # 连续多少个epoch验证集性能没有提升
    for epoch in range(steps):
        running_loss = 0.0
        for X_batch, y_batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{steps}", leave=False):
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs.squeeze(), y_batch)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{steps}], Loss: {running_loss/len(train_dataloader):.4f}")
        
        # 每个epoch结束后，进行验证集评估
        val_auc = evaluate_kan(model, val_dataloader)  # AUC评估
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

# 评估模型
def evaluate_kan(model, val_dataloader):
    model.eval()
    val_auc_scores = []
    with torch.no_grad():
        for X_batch, y_batch in val_dataloader:
            outputs = model(X_batch)
            probs = torch.sigmoid(outputs).squeeze().cpu().numpy()
            y_batch_cpu = y_batch.cpu().numpy()
            precision, recall, _ = precision_recall_curve(y_batch_cpu, probs)
            pr_auc = auc(recall, precision)
            val_auc_scores.append(pr_auc)
    return np.mean(val_auc_scores)


In [3]:


# 5折交叉验证
kf = KFold(n_splits=5, shuffle=True, random_state=42)
pr_auc_train_list, auc_train_list = [], []
f1_train_list, recall_train_list = [], []
precision_train_list, mcc_train_list = [], []
accuracy_train_list, pr_auc_val_list = [], []
auc_val_list, f1_val_list = [], []
recall_val_list, precision_val_list = [], []
mcc_val_list, accuracy_val_list = [], []

# 交叉验证训练与评估
for fold, (train_idx, val_idx) in enumerate(kf.split(X_tensor)):
    print(f"\nFold {fold + 1}")

    # 获取当前fold的训练集和验证集
    X_train_fold, X_val_fold = X_tensor[train_idx], X_tensor[val_idx]
    y_train_fold, y_val_fold = y_tensor[train_idx], y_tensor[val_idx]

    # 创建 DataLoader
    train_dataloader = get_dataloader(X_train_fold, y_train_fold, batch_size=128)  # 增加batch size
    val_dataloader = get_dataloader(X_val_fold, y_val_fold, batch_size=128)

    # 创建 KAN 模型
    model = KAN(width=[2048, 100, 1], grid=3, k=3, seed=42, device=device)

    # 使用 AdamW 优化器
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    criterion = torch.nn.BCEWithLogitsLoss()

    # 训练 KAN 模型
    train_kan(model, train_dataloader, optimizer, criterion, steps=20, patience=5)

    # 评估模型（训练集和验证集）
    for data, target in [(X_train_fold, y_train_fold), (X_val_fold, y_val_fold)]:
        model.eval()
        with torch.no_grad():
            outputs = model(data)
            probs = torch.sigmoid(outputs).squeeze().cpu().numpy()
            target_cpu = target.cpu().numpy()
            
            # 计算 AUC 和其他指标
            precision, recall, _ = precision_recall_curve(target_cpu, probs)
            pr_auc = auc(recall, precision)
            pr_auc_train_list.append(pr_auc) if data is X_train_fold else pr_auc_val_list.append(pr_auc)
            
            fpr, tpr, _ = roc_curve(target_cpu, probs)
            auc_val = auc(fpr, tpr)
            auc_train_list.append(auc_val) if data is X_train_fold else auc_val_list.append(auc_val)
            
            f1 = f1_score(target_cpu, (probs > 0.5).astype(int))
            f1_train_list.append(f1) if data is X_train_fold else f1_val_list.append(f1)
            
            recall_val = recall_score(target_cpu, (probs > 0.5).astype(int))
            recall_train_list.append(recall_val) if data is X_train_fold else recall_val_list.append(recall_val)
            
            precision_val = precision_score(target_cpu, (probs > 0.5).astype(int))
            precision_train_list.append(precision_val) if data is X_train_fold else precision_val_list.append(precision_val)
            
            mcc_val = matthews_corrcoef(target_cpu, (probs > 0.5).astype(int))
            mcc_train_list.append(mcc_val) if data is X_train_fold else mcc_val_list.append(mcc_val)
            
            accuracy_val = accuracy_score(target_cpu, (probs > 0.5).astype(int))
            accuracy_train_list.append(accuracy_val) if data is X_train_fold else accuracy_val_list.append(accuracy_val)

# 计算五折平均的评估指标（训练集和验证集）
print(f"\nAverage PR-AUC (Train): {np.mean(pr_auc_train_list):.4f}")
print(f"Average AUC (Train): {np.mean(auc_train_list):.4f}")
print(f"Average F1 (Train): {np.mean(f1_train_list):.4f}")
print(f"Average Recall (Train): {np.mean(recall_train_list):.4f}")
print(f"Average Precision (Train): {np.mean(precision_train_list):.4f}")
print(f"Average MCC (Train): {np.mean(mcc_train_list):.4f}")
print(f"Average Accuracy (Train): {np.mean(accuracy_train_list):.4f}")

print(f"\nAverage PR-AUC (Validation): {np.mean(pr_auc_val_list):.4f}")
print(f"Average AUC (Validation): {np.mean(auc_val_list):.4f}")
print(f"Average F1 (Validation): {np.mean(f1_val_list):.4f}")
print(f"Average Recall (Validation): {np.mean(recall_val_list):.4f}")
print(f"Average Precision (Validation): {np.mean(precision_val_list):.4f}")
print(f"Average MCC (Validation): {np.mean(mcc_val_list):.4f}")
print(f"Average Accuracy (Validation): {np.mean(accuracy_val_list):.4f}")



Fold 1
checkpoint directory created: ./model
saving model version 0.0


                                                          

Epoch [1/20], Loss: 0.5172


                                                          

Epoch [2/20], Loss: 0.1651


                                                          

Epoch [3/20], Loss: 0.1943


                                                          

Epoch [4/20], Loss: 0.2363


                                                          

Epoch [5/20], Loss: 0.1850


                                                          

Epoch [6/20], Loss: 0.1937


                                                          

Epoch [7/20], Loss: 0.1636


                                                          

Epoch [8/20], Loss: 0.1659


                                                          

Epoch [9/20], Loss: 0.1462


                                                           

Epoch [10/20], Loss: 0.1400


                                                           

Epoch [11/20], Loss: 0.1239


                                                           

Epoch [12/20], Loss: 0.1244


                                                           

Epoch [13/20], Loss: 0.1158


                                                           

Epoch [14/20], Loss: 0.1058


                                                           

Epoch [15/20], Loss: 0.1012
Early stopping at epoch 15

Fold 2
checkpoint directory created: ./model
saving model version 0.0


                                                          

Epoch [1/20], Loss: 0.4843


                                                          

Epoch [2/20], Loss: 0.1803


                                                          

Epoch [3/20], Loss: 0.1822


                                                          

Epoch [4/20], Loss: 0.1658


                                                          

Epoch [5/20], Loss: 0.1748


                                                          

Epoch [6/20], Loss: 0.1647


                                                          

Epoch [7/20], Loss: 0.1759


                                                          

Epoch [8/20], Loss: 0.1643


                                                          

Epoch [9/20], Loss: 0.1442


                                                           

Epoch [10/20], Loss: 0.1423


                                                           

Epoch [11/20], Loss: 0.1255


                                                           

Epoch [12/20], Loss: 0.1081


                                                           

Epoch [13/20], Loss: 0.1060


                                                           

Epoch [14/20], Loss: 0.1055


                                                           

Epoch [15/20], Loss: 0.0964


                                                           

Epoch [16/20], Loss: 0.0987


                                                           

Epoch [17/20], Loss: 0.0900


Epoch 18/20:  80%|████████  | 4/5 [09:38<02:25, 145.81s/it]

Early stopping at epoch 18

Fold 3
checkpoint directory created: ./model
saving model version 0.0


                                                          

Epoch [1/20], Loss: 0.5178


                                                          

Epoch [2/20], Loss: 0.1593


                                                          

Epoch [3/20], Loss: 0.1498


                                                          

Epoch [4/20], Loss: 0.1394


                                                          

Epoch [5/20], Loss: 0.1392


                                                          

Epoch [6/20], Loss: 0.1629


                                                          

Epoch [7/20], Loss: 0.1464


                                                          

Epoch [8/20], Loss: 0.1357


                                                          

Epoch [9/20], Loss: 0.1338


                                                           

Epoch [10/20], Loss: 0.1179
Early stopping at epoch 10

Fold 4
checkpoint directory created: ./model
saving model version 0.0


                                                          

Epoch [1/20], Loss: 0.5307


                                                          

Epoch [2/20], Loss: 0.1700


                                                          

Epoch [3/20], Loss: 0.1865


                                                          

Epoch [4/20], Loss: 0.1762


                                                          

Epoch [5/20], Loss: 0.1812


                                                          

Epoch [6/20], Loss: 0.1788


                                                          

Epoch [7/20], Loss: 0.1656


                                                          

Epoch [8/20], Loss: 0.1756


                                                          

Epoch [9/20], Loss: 0.1595


                                                           

Epoch [10/20], Loss: 0.1498


                                                           

Epoch [11/20], Loss: 0.1290


                                                           

Epoch [12/20], Loss: 0.1260


                                                           

Epoch [13/20], Loss: 0.1180


                                                           

Epoch [14/20], Loss: 0.1049


                                                           

Epoch [15/20], Loss: 0.1140


                                                           

Epoch [16/20], Loss: 0.1057


                                                           

Epoch [17/20], Loss: 0.1110


                                                           

Epoch [18/20], Loss: 0.1167


                                                           

Epoch [19/20], Loss: 0.0934


                                                           

Epoch [20/20], Loss: 0.0939
Early stopping at epoch 20

Fold 5
checkpoint directory created: ./model
saving model version 0.0


                                                          

Epoch [1/20], Loss: 0.4982


                                                          

Epoch [2/20], Loss: 0.1568


                                                          

Epoch [3/20], Loss: 0.1889


                                                          

Epoch [4/20], Loss: 0.1801


                                                          

Epoch [5/20], Loss: 0.1685


                                                          

Epoch [6/20], Loss: 0.2347


                                                          

Epoch [7/20], Loss: 0.1708


                                                          

Epoch [8/20], Loss: 0.1493


                                                          

Epoch [9/20], Loss: 0.1418


                                                           

Epoch [10/20], Loss: 0.1422


                                                           

Epoch [11/20], Loss: 0.1435


                                                           

Epoch [12/20], Loss: 0.1304


                                                           

Epoch [13/20], Loss: 0.1247


                                                           

Epoch [14/20], Loss: 0.1067


                                                           

Epoch [15/20], Loss: 0.1087


                                                           

Epoch [16/20], Loss: 0.1138


                                                           

Epoch [17/20], Loss: 0.0961


                                                           

Epoch [18/20], Loss: 0.1029


                                                           

Epoch [19/20], Loss: 0.1179


                                                           

Epoch [20/20], Loss: 0.1010

Average PR-AUC (Train): 0.9930
Average AUC (Train): 0.8403
Average F1 (Train): 0.9862
Average Recall (Train): 1.0000
Average Precision (Train): 0.9728
Average MCC (Train): 0.1468
Average Accuracy (Train): 0.9728

Average PR-AUC (Validation): 0.9748
Average AUC (Validation): 0.6798
Average F1 (Validation): 0.9849
Average Recall (Validation): 0.9986
Average Precision (Validation): 0.9718
Average MCC (Validation): -0.0020
Average Accuracy (Validation): 0.9704


In [4]:
    # 保存模型
torch.save(model.state_dict(), f"kan_model_fold_{fold + 1}.pth")
print(f"Model for Fold {fold + 1} saved as kan_model_fold_{fold + 1}.pth")


Model for Fold 5 saved as kan_model_fold_5.pth


In [None]:
Fold 1
Train Set - Class 0: 17, Class 1: 549
Val Set   - Class 0: 3, Class 1: 139


Fold 2
Train Set - Class 0: 15, Class 1: 551
Val Set   - Class 0: 5, Class 1: 137


Fold 3
Train Set - Class 0: 14, Class 1: 552
Val Set   - Class 0: 6, Class 1: 136


Fold 4
Train Set - Class 0: 18, Class 1: 549
Val Set   - Class 0: 2, Class 1: 139


Fold 5
Train Set - Class 0: 16, Class 1: 551
Val Set   - Class 0: 4, Class 1: 137

In [None]:
Average PR-AUC (Train): 0.9930
Average AUC (Train): 0.8403
Average F1 (Train): 0.9862
Average Recall (Train): 1.0000
Average Precision (Train): 0.9728
Average MCC (Train): 0.1468
Average Accuracy (Train): 0.9728

Average PR-AUC (Validation): 0.9748
Average AUC (Validation): 0.6798
Average F1 (Validation): 0.9849
Average Recall (Validation): 0.9986
Average Precision (Validation): 0.9718
Average MCC (Validation): -0.0020
Average Accuracy (Validation): 0.9704

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay

# 用于存储训练和验证集的评估数据
fpr_train_list, tpr_train_list, fpr_val_list, tpr_val_list = [], [], [], []
y_true_train_all, y_true_val_all, y_pred_train_all, y_pred_val_all = [], [], [], []

# 5折交叉验证
kf = KFold(n_splits=5, shuffle=True, random_state=42)
pr_auc_train_list, auc_train_list = [], []
f1_train_list, recall_train_list = [], []
precision_train_list, mcc_train_list = [], []
accuracy_train_list, pr_auc_val_list = [], []
auc_val_list, f1_val_list = [], []
recall_val_list, precision_val_list = [], []
mcc_val_list, accuracy_val_list = [], []
fold_distributions = []
# 交叉验证训练与评估
for fold, (train_idx, val_idx) in enumerate(kf.split(X_tensor)):
    print(f"\nFold {fold + 1}")

    # 获取当前fold的训练集和验证集
    X_train_fold, X_val_fold = X_tensor[train_idx], X_tensor[val_idx]
    y_train_fold, y_val_fold = y_tensor[train_idx], y_tensor[val_idx]

    y_train_labels = y[train_idx]
    y_val_labels = y[val_idx]
    
    # 统计类别0和1的数量
    train_class_counts = np.bincount(y_train_labels.astype(int), minlength=2)
    val_class_counts = np.bincount(y_val_labels.astype(int), minlength=2)
    
    # 记录分布信息
    fold_distributions.append({
        "Fold": fold + 1,
        "Train_Class0": train_class_counts[0],
        "Train_Class1": train_class_counts[1],
        "Val_Class0": val_class_counts[0],
        "Val_Class1": val_class_counts[1]
    })
    
    # 打印当前fold分布
    print(f"Train Set - Class 0: {train_class_counts[0]}, Class 1: {train_class_counts[1]}")
    print(f"Val Set   - Class 0: {val_class_counts[0]}, Class 1: {val_class_counts[1]}\n")