In [1]:
#pip install pykan -i https://pypi.tuna.tsinghua.edu.cn/simple


In [1]:
import warnings
from rdkit import RDLogger

# 屏蔽 RDKit 警告
RDLogger.DisableLog('rdApp.*')

# 或屏蔽所有 Python 警告
warnings.filterwarnings("ignore")


In [3]:
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import KFold
from sklearn.metrics import precision_recall_curve, auc, f1_score, recall_score, precision_score, matthews_corrcoef, accuracy_score
from kan import KAN  # Assuming the KAN model is defined in a module called kan
from tqdm import tqdm  # Import tqdm for progress bars
from sklearn.metrics import roc_curve




# 数据预处理
df = pd.read_csv('imputed_selected_features_Toxcity.csv')
labels = df['Toxicity'].values
smiles_list = df['SMILES'].tolist()

# 函数：将SMILES转换为分子描述符和指纹
def smiles_to_features(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    # 提取描述符
    descriptors = [
        Descriptors.MolWt(mol),  # 分子量
        Descriptors.MolLogP(mol),  # LogP
        Descriptors.NumHDonors(mol),  # 氢键供体数量
        Descriptors.NumHAcceptors(mol)  # 氢键受体数量
    ]
    # 生成Morgan指纹
    fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
    fingerprint_array = np.zeros((2048,))
    Chem.DataStructs.ConvertToNumpyArray(fingerprint, fingerprint_array)
    # 合并描述符和指纹
    features = np.concatenate([descriptors, fingerprint_array])
    return features

# 将SMILES转换为特征
features = []
for smiles in smiles_list:
    feature = smiles_to_features(smiles)
    if feature is not None:
        features.append(feature)

# 转换为numpy数组
features = np.array(features)
X = np.array(features)
y = labels

# 设置设备为 CUDA 或 CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 将数据转换为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
y_tensor = torch.tensor(y, dtype=torch.float32).to(device)

# 创建DataLoader函数
def get_dataloader(X, y, batch_size):
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

# 训练 KAN 模型
def train_kan(model, train_dataloader, optimizer, criterion, steps=20, patience=1):
    model.train()
    best_val_auc = -np.inf  # 最佳验证集AUC
    epochs_without_improvement = 0  # 连续多少个epoch验证集性能没有提升
    for epoch in range(steps):
        running_loss = 0.0
        for X_batch, y_batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{steps}", leave=False):
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs.squeeze(), y_batch)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{steps}], Loss: {running_loss/len(train_dataloader):.4f}")
        
        # 每个epoch结束后，进行验证集评估
        val_auc = evaluate_kan(model, val_dataloader)  # AUC评估
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

# 评估模型
def evaluate_kan(model, val_dataloader):
    model.eval()
    val_auc_scores = []
    with torch.no_grad():
        for X_batch, y_batch in val_dataloader:
            outputs = model(X_batch)
            probs = torch.sigmoid(outputs).squeeze().cpu().numpy()
            y_batch_cpu = y_batch.cpu().numpy()
            precision, recall, _ = precision_recall_curve(y_batch_cpu, probs)
            pr_auc = auc(recall, precision)
            val_auc_scores.append(pr_auc)
    return np.mean(val_auc_scores)


# 5折交叉验证
kf = KFold(n_splits=5, shuffle=True, random_state=42)
pr_auc_train_list, auc_train_list = [], []
f1_train_list, recall_train_list = [], []
precision_train_list, mcc_train_list = [], []
accuracy_train_list, pr_auc_val_list = [], []
auc_val_list, f1_val_list = [], []
recall_val_list, precision_val_list = [], []
mcc_val_list, accuracy_val_list = [], []

# 交叉验证训练与评估
for fold, (train_idx, val_idx) in enumerate(kf.split(X_tensor)):
    print(f"\nFold {fold + 1}")

    # 获取当前fold的训练集和验证集
    X_train_fold, X_val_fold = X_tensor[train_idx], X_tensor[val_idx]
    y_train_fold, y_val_fold = y_tensor[train_idx], y_tensor[val_idx]

    # 创建 DataLoader
    train_dataloader = get_dataloader(X_train_fold, y_train_fold, batch_size=128)  # 增加batch size
    val_dataloader = get_dataloader(X_val_fold, y_val_fold, batch_size=128)

    # 创建 KAN 模型
    model = KAN(width=[2048, 100, 1], grid=3, k=3, seed=42, device=device)

    # 使用 AdamW 优化器
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    criterion = torch.nn.BCEWithLogitsLoss()

    # 训练 KAN 模型
    train_kan(model, train_dataloader, optimizer, criterion, steps=20, patience=5)

    # 评估模型（训练集和验证集）
    for data, target in [(X_train_fold, y_train_fold), (X_val_fold, y_val_fold)]:
        model.eval()
        with torch.no_grad():
            outputs = model(data)
            probs = torch.sigmoid(outputs).squeeze().cpu().numpy()
            target_cpu = target.cpu().numpy()
            
            # 计算 AUC 和其他指标
            precision, recall, _ = precision_recall_curve(target_cpu, probs)
            pr_auc = auc(recall, precision)
            pr_auc_train_list.append(pr_auc) if data is X_train_fold else pr_auc_val_list.append(pr_auc)
            
            fpr, tpr, _ = roc_curve(target_cpu, probs)
            auc_val = auc(fpr, tpr)
            auc_train_list.append(auc_val) if data is X_train_fold else auc_val_list.append(auc_val)
            
            f1 = f1_score(target_cpu, (probs > 0.5).astype(int))
            f1_train_list.append(f1) if data is X_train_fold else f1_val_list.append(f1)
            
            recall_val = recall_score(target_cpu, (probs > 0.5).astype(int))
            recall_train_list.append(recall_val) if data is X_train_fold else recall_val_list.append(recall_val)
            
            precision_val = precision_score(target_cpu, (probs > 0.5).astype(int))
            precision_train_list.append(precision_val) if data is X_train_fold else precision_val_list.append(precision_val)
            
            mcc_val = matthews_corrcoef(target_cpu, (probs > 0.5).astype(int))
            mcc_train_list.append(mcc_val) if data is X_train_fold else mcc_val_list.append(mcc_val)
            
            accuracy_val = accuracy_score(target_cpu, (probs > 0.5).astype(int))
            accuracy_train_list.append(accuracy_val) if data is X_train_fold else accuracy_val_list.append(accuracy_val)

# 计算五折平均的评估指标（训练集和验证集）
print(f"\nAverage PR-AUC (Train): {np.mean(pr_auc_train_list):.4f}")
print(f"Average AUC (Train): {np.mean(auc_train_list):.4f}")
print(f"Average F1 (Train): {np.mean(f1_train_list):.4f}")
print(f"Average Recall (Train): {np.mean(recall_train_list):.4f}")
print(f"Average Precision (Train): {np.mean(precision_train_list):.4f}")
print(f"Average MCC (Train): {np.mean(mcc_train_list):.4f}")
print(f"Average Accuracy (Train): {np.mean(accuracy_train_list):.4f}")

print(f"\nAverage PR-AUC (Validation): {np.mean(pr_auc_val_list):.4f}")
print(f"Average AUC (Validation): {np.mean(auc_val_list):.4f}")
print(f"Average F1 (Validation): {np.mean(f1_val_list):.4f}")
print(f"Average Recall (Validation): {np.mean(recall_val_list):.4f}")
print(f"Average Precision (Validation): {np.mean(precision_val_list):.4f}")
print(f"Average MCC (Validation): {np.mean(mcc_val_list):.4f}")
print(f"Average Accuracy (Validation): {np.mean(accuracy_val_list):.4f}")



Fold 1
checkpoint directory created: ./model
saving model version 0.0


                                                         

Epoch [1/20], Loss: 0.6102


                                                          

Epoch [2/20], Loss: 0.4692


                                                         

Epoch [3/20], Loss: 0.4185


                                                         

Epoch [4/20], Loss: 0.3719


                                                         

Epoch [5/20], Loss: 0.3613


                                                         

Epoch [6/20], Loss: 0.3534


                                                          

Epoch [7/20], Loss: 0.3375


                                                         

Epoch [8/20], Loss: 0.3318


                                                          

Epoch [9/20], Loss: 0.3219


                                                          

Epoch [10/20], Loss: 0.3516


                                                          

Epoch [11/20], Loss: 0.3155


                                                           

Epoch [12/20], Loss: 0.3070


                                                          

Epoch [13/20], Loss: 0.3082


                                                          

Epoch [14/20], Loss: 0.2983


                                                          

Epoch [15/20], Loss: 0.2833


                                                          

Epoch [16/20], Loss: 0.2777


                                                          

Epoch [17/20], Loss: 0.2758


                                                          

Epoch [18/20], Loss: 0.2665


                                                          

Epoch [19/20], Loss: 0.2701
Early stopping at epoch 19

Fold 2
checkpoint directory created: ./model
saving model version 0.0


                                                         

Epoch [1/20], Loss: 0.5807


                                                         

Epoch [2/20], Loss: 0.4435


                                                         

Epoch [3/20], Loss: 0.4072


                                                         

Epoch [4/20], Loss: 0.3529


                                                         

Epoch [5/20], Loss: 0.3521


                                                         

Epoch [6/20], Loss: 0.3231


                                                         

Epoch [7/20], Loss: 0.3231


                                                         

Epoch [8/20], Loss: 0.3002


                                                          

Epoch [9/20], Loss: 0.2985


                                                          

Epoch [10/20], Loss: 0.2939


                                                          

Epoch [11/20], Loss: 0.2862


                                                           

Epoch [12/20], Loss: 0.2836


                                                           

Epoch [13/20], Loss: 0.2711


                                                           

Epoch [14/20], Loss: 0.2655


                                                           

Epoch [15/20], Loss: 0.2481


                                                           

Epoch [16/20], Loss: 0.2516


                                                           

Epoch [17/20], Loss: 0.2496


                                                           

Epoch [18/20], Loss: 0.2445
Early stopping at epoch 18

Fold 3
checkpoint directory created: ./model
saving model version 0.0


                                                         

Epoch [1/20], Loss: 0.5822


                                                         

Epoch [2/20], Loss: 0.4589


                                                         

Epoch [3/20], Loss: 0.4193


                                                         

Epoch [4/20], Loss: 0.3731


                                                         

Epoch [5/20], Loss: 0.3542


                                                         

Epoch [6/20], Loss: 0.3396


                                                          

Epoch [7/20], Loss: 0.2995


                                                          

Epoch [8/20], Loss: 0.2973


                                                          

Epoch [9/20], Loss: 0.3140


                                                          

Epoch [10/20], Loss: 0.2918


                                                           

Epoch [11/20], Loss: 0.2933


                                                          

Epoch [12/20], Loss: 0.2793


                                                           

Epoch [13/20], Loss: 0.2845


                                                          

Epoch [14/20], Loss: 0.2679


                                                           

Epoch [15/20], Loss: 0.2749


                                                          

Epoch [16/20], Loss: 0.2579


                                                          

Epoch [17/20], Loss: 0.2534


                                                          

Epoch [18/20], Loss: 0.2614


                                                          

Epoch [19/20], Loss: 0.2472
Early stopping at epoch 19

Fold 4
checkpoint directory created: ./model
saving model version 0.0


                                                         

Epoch [1/20], Loss: 0.6078


                                                         

Epoch [2/20], Loss: 0.3965


                                                          

Epoch [3/20], Loss: 0.4008


                                                          

Epoch [4/20], Loss: 0.3518


                                                         

Epoch [5/20], Loss: 0.3331


                                                          

Epoch [6/20], Loss: 0.3154


                                                         

Epoch [7/20], Loss: 0.3167


                                                         

Epoch [8/20], Loss: 0.3028


                                                          

Epoch [9/20], Loss: 0.3145


                                                          

Epoch [10/20], Loss: 0.2924


                                                           

Epoch [11/20], Loss: 0.2932
Early stopping at epoch 11

Fold 5
checkpoint directory created: ./model
saving model version 0.0


                                                         

Epoch [1/20], Loss: 0.5864


                                                         

Epoch [2/20], Loss: 0.5043


                                                         

Epoch [3/20], Loss: 0.4545


                                                         

Epoch [4/20], Loss: 0.3762


                                                         

Epoch [5/20], Loss: 0.3864


                                                         

Epoch [6/20], Loss: 0.3806


                                                         

Epoch [7/20], Loss: 0.3460


                                                         

Epoch [8/20], Loss: 0.3450


                                                         

Epoch [9/20], Loss: 0.3338
Early stopping at epoch 9

Average PR-AUC (Train): 0.9843
Average AUC (Train): 0.8997
Average F1 (Train): 0.9409
Average Recall (Train): 0.9956
Average Precision (Train): 0.8920
Average MCC (Train): 0.2785
Average Accuracy (Train): 0.8902

Average PR-AUC (Validation): 0.9420
Average AUC (Validation): 0.7008
Average F1 (Validation): 0.9306
Average Recall (Validation): 0.9888
Average Precision (Validation): 0.8799
Average MCC (Validation): 0.0307
Average Accuracy (Validation): 0.8715


In [None]:
Average PR-AUC (Train): 0.9843
Average AUC (Train): 0.8997
Average F1 (Train): 0.9409
Average Recall (Train): 0.9956
Average Precision (Train): 0.8920
Average MCC (Train): 0.2785
Average Accuracy (Train): 0.8902

Average PR-AUC (Validation): 0.9420
Average AUC (Validation): 0.7008
Average F1 (Validation): 0.9306
Average Recall (Validation): 0.9888
Average Precision (Validation): 0.8799
Average MCC (Validation): 0.0307
Average Accuracy (Validation): 0.8715