In [None]:
import os
import gc
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import (
    roc_auc_score, precision_recall_curve, roc_curve, auc,
    confusion_matrix, cohen_kappa_score, matthews_corrcoef, f1_score
)
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
import psutil

import warnings
from rdkit import RDLogger
# Turn off all RDKit logging (including warnings)
RDLogger.DisableLog('rdApp.*')  # disable every 'rdApp.*' RDKit log channel
# Plotting libraries for high-resolution SVG figures (re-imported; already imported above)
import seaborn as sns
import matplotlib.pyplot as plt

# Use the SimHei font so Matplotlib can render Chinese (CJK) text
plt.rcParams['font.sans-serif'] = ['SimHei']  # SimHei: a Chinese sans-serif ("Hei") font
plt.rcParams['axes.unicode_minus'] = False  # draw minus signs as ASCII '-' so they display with this font

# IPython magic (only valid in Jupyter/IPython): render inline figures as SVG
%config InlineBackend.figure_format = 'svg'

# Raise the DPI to improve figure sharpness
plt.rcParams['figure.dpi'] = 300
# NOTE(review): this silences *all* Python warnings, including sklearn/torch ones that may matter
warnings.filterwarnings("ignore")

# ---------------------- Memory Protection ----------------------
def memory_safe(func):
    """Decorate ``func`` so a GC pass runs first whenever system RAM usage is high.

    Before each call, if ``psutil`` reports more than 80% of virtual memory in
    use, Python garbage collection runs (plus ``torch.cuda.empty_cache()`` when
    CUDA is available) and a warning with the pre-collection usage is printed.
    The wrapped function's result is returned unchanged.
    """
    from functools import wraps  # local import keeps this notebook cell self-contained

    @wraps(func)  # preserve __name__/__doc__/__wrapped__ of the decorated function
    def wrapper(*args, **kwargs):
        mem = psutil.virtual_memory()
        if mem.percent > 80:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print(f"⚠️ Memory warning: Usage {mem.percent}%, performed garbage collection")
        return func(*args, **kwargs)
    return wrapper

# ---------------------- SMILES Feature Extraction ----------------------
class SMILESFeatureExtractor:
    """Convert one SMILES string into a fixed-length numeric feature vector.

    The vector is ``[descriptor values..., Morgan fingerprint bits...]`` with
    length ``len(desc_list) + fp_size``. Input that cannot be featurized yields
    an all-NaN vector of the same length.
    """

    def __init__(self, fp_size=1024, desc_list=None):
        """
        Args:
            fp_size: Number of bits of the radius-2 Morgan fingerprint.
            desc_list: Names of functions in ``rdkit.Chem.Descriptors`` to
                evaluate; ``None``/empty selects six default descriptors.

        Raises:
            ValueError: if a name in ``desc_list`` is not an RDKit descriptor.
                Checked up front because the per-molecule error handling would
                otherwise turn such a typo into all-NaN features.
        """
        self.fp_size = fp_size
        self.desc_list = desc_list or [
            'MolWt', 'NumHAcceptors', 'NumHDonors',
            'MolLogP', 'TPSA', 'NumRotatableBonds'
        ]
        unknown = [d for d in self.desc_list if not callable(getattr(Descriptors, d, None))]
        if unknown:
            raise ValueError(f"Unknown RDKit descriptor(s): {unknown}")

    def _nan_features(self):
        """Return an all-NaN vector of the full feature length (marks unusable input)."""
        return np.full(len(self.desc_list) + self.fp_size, np.nan)

    @memory_safe
    def smiles_to_features(self, smiles):
        """Convert SMILES to numerical features.

        Returns a float array ``[descriptors..., fingerprint bits...]``, or an
        all-NaN array if ``smiles`` cannot be parsed or featurized.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:  # RDKit signals an invalid SMILES by returning None
                return self._nan_features()

            # Calculate descriptors
            desc_values = [getattr(Descriptors, desc)(mol) for desc in self.desc_list]

            # Calculate fingerprints
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=self.fp_size)
            fp_values = np.array(fp, dtype=np.float32)

            return np.concatenate([desc_values, fp_values])
        except Exception:
            # Best effort: non-string input (e.g. a pandas NaN) or an RDKit failure
            # becomes a NaN row; KeyboardInterrupt/SystemExit still propagate.
            return self._nan_features()

# ---------------------- Data Preparation ----------------------
def prepare_features(X_smiles):
    """Convert SMILES pairs to a numeric feature matrix.

    Args:
        X_smiles: Iterable of ``(drug1_smiles, drug2_smiles)`` pairs.

    Returns:
        Array with one row per pair: ``[features(drug1), features(drug2)]`` as
        produced by ``SMILESFeatureExtractor``. NaNs (unparseable SMILES) are
        replaced with 0. Empty input gives a ``(0, n_features)`` array.
    """
    fe = SMILESFeatureExtractor()
    cache = {}

    def featurize(smiles):
        # The same drug appears in many pairs; extract each distinct SMILES once.
        if smiles not in cache:
            cache[smiles] = fe.smiles_to_features(smiles)
        return cache[smiles]

    features = [
        np.concatenate([featurize(drug1), featurize(drug2)])
        for drug1, drug2 in tqdm(X_smiles, desc="Extracting features")
    ]
    if not features:  # np.stack rejects an empty list
        return np.empty((0, 2 * (len(fe.desc_list) + fe.fp_size)))

    # Handle NaN values
    return np.nan_to_num(np.stack(features))

# ---------------------- Deep Learning Components ----------------------
class DrugInteractionDataset(Dataset):
    """Map-style dataset of tokenized SMILES pairs with their class label.

    Args:
        drug1_smiles, drug2_smiles: Indexable sequences holding one SMILES per pair.
        labels: Indexable sequence of integer labels; its length is the dataset length.
        tokenizer: Hugging Face-style tokenizer callable.
        max_length: Length every SMILES is padded/truncated to.
    """

    def __init__(self, drug1_smiles, drug2_smiles, labels, tokenizer, max_length=128):
        self.drug1_smiles = drug1_smiles
        self.drug2_smiles = drug2_smiles
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def _encode(self, smiles):
        """Tokenize one SMILES into 1-D ``(input_ids, attention_mask)`` tensors."""
        enc = self.tokenizer(
            str(smiles),
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        return enc['input_ids'].flatten(), enc['attention_mask'].flatten()

    def __getitem__(self, idx):
        ids_a, mask_a = self._encode(self.drug1_smiles[idx])
        ids_b, mask_b = self._encode(self.drug2_smiles[idx])
        item = {}
        item['drug1_input_ids'] = ids_a
        item['drug1_attention_mask'] = mask_a
        item['drug2_input_ids'] = ids_b
        item['drug2_attention_mask'] = mask_b
        item['label'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

class CoAttentionModel(nn.Module):
    """Drug-pair classifier: shared transformer encoder + bidirectional co-attention.

    Both drugs go through the same encoder. Each drug's first-token embedding
    attends over the *other* drug's token embeddings (padding excluded). The
    two first-token embeddings and the two attended vectors are concatenated
    and mapped to 2 logits by an MLP.
    """

    def __init__(self, bert_model_name="DeepChem/ChemBERTa-77M-MLM", hidden_size=384,
                 num_heads=8, dropout=0.1):
        """
        Args:
            bert_model_name: Model id/path for ``AutoModel.from_pretrained``, or an
                already-built ``nn.Module`` encoder whose output has ``last_hidden_state``.
            hidden_size: Encoder hidden width; must match the encoder's output size.
            num_heads: Co-attention heads (must divide ``hidden_size``).
            dropout: Dropout probability in the classifier head.
        """
        super().__init__()
        if isinstance(bert_model_name, nn.Module):
            self.bert = bert_model_name
        else:
            self.bert = AutoModel.from_pretrained(bert_model_name)
        # batch_first=True: tensors here are (batch, seq, hidden). With the default
        # (seq, batch, hidden) layout the batch axis is read as the sequence, so
        # each sample would attend to the other samples in its batch.
        self.co_attention = nn.MultiheadAttention(hidden_size, num_heads=num_heads, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 4, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 2),
        )

    def _cross_attend(self, query_vec, other_states, other_mask):
        """Attend from ``query_vec`` (batch, hidden) over ``other_states`` (batch, seq, hidden).

        Positions where ``other_mask == 0`` (padding) are ignored. Returns (batch, hidden).
        """
        out, _ = self.co_attention(
            query_vec.unsqueeze(1), other_states, other_states,
            key_padding_mask=other_mask == 0,
        )
        return out.squeeze(1)

    def forward(self, drug1_input_ids, drug1_attention_mask, drug2_input_ids, drug2_attention_mask):
        """Return (batch, 2) logits for the drug pairs."""
        states1 = self.bert(drug1_input_ids, attention_mask=drug1_attention_mask).last_hidden_state
        states2 = self.bert(drug2_input_ids, attention_mask=drug2_attention_mask).last_hidden_state
        drug1, drug2 = states1[:, 0, :], states2[:, 0, :]

        # Co-attention: each drug's summary vector queries the other drug's tokens
        attn1 = self._cross_attend(drug1, states2, drug2_attention_mask)
        attn2 = self._cross_attend(drug2, states1, drug1_attention_mask)

        combined = torch.cat([drug1, drug2, attn1, attn2], dim=1)
        return self.classifier(combined)

# ---------------------- Training and Evaluation ----------------------
@memory_safe
def train_model(model, train_loader, val_loader, optimizer, criterion, device, epochs=10,
                save_path="best_model.pth"):
    """Train ``model`` and checkpoint the epoch with the best validation AUC.

    Args:
        model: Network called as ``model(**inputs)`` with each batch's non-label tensors.
        train_loader, val_loader: Yield dicts of tensors that include a ``'label'`` key.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss applied to ``(logits, labels)``.
        device: Device each batch is moved to.
        epochs: Number of passes over ``train_loader``.
        save_path: File the best epoch's ``state_dict`` is written to.

    Returns:
        ``(model, history)``. ``model`` keeps the weights of the *last* epoch
        (reload ``save_path`` for the best one). ``history`` maps
        ``'train_loss'``, ``'val_loss'`` and ``'val_auc'`` to per-epoch lists.
    """
    best_val_auc = -float("inf")  # ensures the first finite AUC always produces a checkpoint
    history = {'train_loss': [], 'val_loss': [], 'val_auc': []}

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        n_batches = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in progress_bar:
            optimizer.zero_grad()
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'}
            outputs = model(**inputs)
            loss = criterion(outputs, batch['label'].to(device))
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            n_batches += 1
            # Count batches ourselves: tqdm refreshes ``progress_bar.n`` lazily
            # (throttled display updates), so it lags the true iteration count.
            progress_bar.set_postfix({'loss': train_loss / n_batches})

        # Validation
        val_loss, val_metrics = evaluate_model(model, val_loader, criterion, device)
        print(f"\nValidation - Loss: {val_loss:.4f}, AUC: {val_metrics['AUC']:.4f}, F1: {val_metrics['F1']:.4f}")

        # Save best model
        if val_metrics['AUC'] > best_val_auc:
            best_val_auc = val_metrics['AUC']
            torch.save(model.state_dict(), save_path)
            print("✅ Saved best model")

        history['train_loss'].append(train_loss / max(n_batches, 1))
        history['val_loss'].append(val_loss)
        history['val_auc'].append(val_metrics['AUC'])

    return model, history

@memory_safe
def evaluate_model(model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Args:
        model: Network returning 2-class logits; called with every non-label batch entry.
        data_loader: Yields dicts of tensors that include a ``'label'`` key.
        criterion: Loss applied to ``(logits, labels)``.
        device: Device inputs and labels are moved to.

    Returns:
        ``(mean_batch_loss, metrics)``. ``metrics`` holds AUC, Sensitivity,
        Specificity, Kappa, MCC and F1, with P(class 1) > 0.5 as the decision
        rule. AUC, sensitivity and specificity are NaN (instead of raising or
        dividing by zero) when the data lacks one of the classes.
    """
    model.eval()
    total_loss = 0.0
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in data_loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'}
            outputs = model(**inputs)
            loss = criterion(outputs, batch['label'].to(device))
            total_loss += loss.item()

            probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
            all_labels.extend(batch['label'].cpu().numpy())
            all_probs.extend(probs)

    y_true = np.asarray(all_labels)
    y_prob = np.asarray(all_probs)
    y_pred = (y_prob > 0.5).astype(int)

    # ROC AUC is undefined for a single class; roc_auc_score would raise
    auc_score = roc_auc_score(y_true, y_prob) if np.unique(y_true).size > 1 else float("nan")
    # labels=[0, 1] keeps the matrix 2x2 so the 4-way unpack cannot fail
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    metrics = {
        "AUC": auc_score,
        "Sensitivity": tp / (tp + fn) if tp + fn else float("nan"),
        "Specificity": tn / (tn + fp) if tn + fp else float("nan"),
        "Kappa": cohen_kappa_score(y_true, y_pred),
        "MCC": matthews_corrcoef(y_true, y_pred),
        "F1": f1_score(y_true, y_pred),
    }

    return total_loss / len(data_loader), metrics

# ---------------------- Main Execution ----------------------
# Colour of each model's mean ROC/PR curve; unlisted names use Matplotlib's colour cycle.
_CURVE_COLORS = {
    "ChemCoBERT": '#E41A1C',
    "GBDT": '#FF7F00',
    "Decision Tree": '#A65628',
    "Naive Bayes": '#377EB8',
    "K-NN": '#984EA3',
    "AdaBoost": '#4DAF4A',
}


def _binary_metrics(y_true, y_prob, threshold=0.5):
    """Compute the per-fold summary metrics from positive-class scores.

    Predictions are ``y_prob > threshold``. AUC, sensitivity and specificity
    are NaN when the fold lacks one of the classes instead of raising.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob, dtype=float)
    y_pred = (y_prob > threshold).astype(int)
    # labels=[0, 1] keeps the matrix 2x2 even when a class is missing
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return {
        "AUC": roc_auc_score(y_true, y_prob) if np.unique(y_true).size > 1 else float("nan"),
        "Sensitivity": tp / (tp + fn) if tp + fn else float("nan"),
        "Specificity": tn / (tn + fp) if tn + fp else float("nan"),
        "Kappa": cohen_kappa_score(y_true, y_pred),
        "MCC": matthews_corrcoef(y_true, y_pred),
        "F1": f1_score(y_true, y_pred),
    }


def _fit_predict_classic(clf, X_train, y_train, X_val):
    """Fit a scikit-learn classifier and return positive-class scores for ``X_val``.

    Uses ``predict_proba`` when available; otherwise min-max rescales
    ``decision_function`` to [0, 1] so the 0.5 threshold is meaningful.
    """
    clf.fit(X_train, y_train)
    if hasattr(clf, "predict_proba"):
        return clf.predict_proba(X_val)[:, 1]
    scores = clf.decision_function(X_val)
    span = scores.max() - scores.min()
    if span == 0:  # constant scores: avoid 0/0 -> NaN
        return np.zeros_like(scores, dtype=float)
    return (scores - scores.min()) / span


def _predict_positive_proba(model, data_loader, device):
    """Return the softmax probability of class 1 for every sample in ``data_loader``."""
    model.eval()
    probs = []
    with torch.no_grad():
        for batch in data_loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'}
            probs.append(torch.softmax(model(**inputs), dim=1)[:, 1].cpu().numpy())
    return np.concatenate(probs) if probs else np.empty(0)


def _run_chemcobert_fold(tokenizer, train_smiles, y_train, val_smiles, y_val, fold, device):
    """Train ChemCoBERT on one fold, save its best weights and return validation scores."""
    train_dataset = DrugInteractionDataset(train_smiles[:, 0], train_smiles[:, 1], y_train, tokenizer)
    val_dataset = DrugInteractionDataset(val_smiles[:, 0], val_smiles[:, 1], y_val, tokenizer)

    model = CoAttentionModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    criterion = nn.CrossEntropyLoss()
    train_model(
        model,
        DataLoader(train_dataset, batch_size=64, shuffle=True),
        DataLoader(val_dataset, batch_size=64),
        optimizer, criterion, device, epochs=10,
    )

    # train_model writes the best-validation-AUC weights to best_model.pth
    model.load_state_dict(torch.load("best_model.pth", map_location=device))
    # One inference pass feeds both the metrics and the curves
    y_prob = _predict_positive_proba(model, DataLoader(val_dataset, batch_size=128), device)

    os.makedirs("best_models", exist_ok=True)  # torch.save does not create directories
    model_path = f"best_models/fold_{fold+1}_best.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Saved model to {model_path}")

    # Release this fold's model/optimizer before the next fold allocates new ones
    del model, optimizer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return y_prob


def _record_curves(name, y_true, y_prob, pr_curves, roc_curves):
    """Append this fold's PR and ROC curves (and their AUCs) for model ``name``."""
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    pr_curves[name]["precision"].append(precision)
    pr_curves[name]["recall"].append(recall)
    pr_curves[name]["auc"].append(auc(recall, precision))

    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_curves[name]["fpr"].append(fpr)
    roc_curves[name]["tpr"].append(tpr)
    roc_curves[name]["auc"].append(auc(fpr, tpr))


def _plot_mean_curves(names, roc_curves, pr_curves):
    """Plot fold-averaged ROC (left) and PR (right) curves for every model in ``names``."""
    grid = np.linspace(0, 1, 100)
    plt.figure(figsize=(15, 6))

    # Receiver Operating Characteristic Curve
    plt.subplot(121)
    for name in names:
        curves = zip(roc_curves[name]["fpr"], roc_curves[name]["tpr"])
        mean_tpr = np.mean([np.interp(grid, fpr, tpr) for fpr, tpr in curves], axis=0)
        color_kw = {"color": _CURVE_COLORS[name]} if name in _CURVE_COLORS else {}
        plt.plot(grid, mean_tpr, linewidth=2.5, **color_kw,
                 label=f"{name} (AUC={np.mean(roc_curves[name]['auc']):.4f})")
    plt.xlabel("False Positive Rate", fontweight='bold')
    plt.ylabel("True Positive Rate", fontweight='bold')
    plt.title("Receiver Operating Characteristic Curve", fontweight='bold')
    plt.legend(loc="lower right")

    # PR Curve
    plt.subplot(122)
    for name in names:
        curves = zip(pr_curves[name]["precision"], pr_curves[name]["recall"])
        # precision_recall_curve returns recall in decreasing order; np.interp needs increasing x
        mean_precision = np.mean(
            [np.interp(grid, recall[::-1], precision[::-1]) for precision, recall in curves], axis=0
        )
        color_kw = {"color": _CURVE_COLORS[name]} if name in _CURVE_COLORS else {}
        plt.plot(grid, mean_precision, linewidth=3 if name == "ChemCoBERT" else 2.5, **color_kw,
                 label=f"{name} (AUC={np.mean(pr_curves[name]['auc']):.4f})")
    plt.xlabel("Recall", fontweight='bold')
    plt.ylabel("Precision", fontweight='bold')
    plt.title("Precision Recall Curve", fontweight='bold')
    plt.legend()

    plt.tight_layout()
    plt.show()


def main():
    """Cross-validate ChemCoBERT against classic classifiers; print metrics and plot curves.

    Reads a headerless tab-separated file whose first two columns are the two
    drugs' SMILES and whose last column is the label, runs 3-fold CV for every
    model, prints per-fold and mean ± std metrics, and plots fold-averaged ROC
    and PR curves.
    """
    # Load data
    file_path = "/kaggle/working/1/dat.txt"
    df = pd.read_csv(file_path, sep='\t', header=None)
    X_smiles = df.iloc[:, :2].values
    Y = df.iloc[:, -1].values

    # Numerical features for the classic models. The scaler is fitted inside
    # each fold on the training split only, so validation data cannot leak.
    X_num = prepare_features(X_smiles)

    # Define classifiers
    classifiers = {
        "ChemCoBERT": None,
        "Decision Tree": DecisionTreeClassifier(random_state=42),
        "AdaBoost": AdaBoostClassifier(n_estimators=100, random_state=42),
        "GBDT": GradientBoostingClassifier(n_estimators=100, random_state=42),
        "K-NN": KNeighborsClassifier(n_neighbors=5),
        "Naive Bayes": GaussianNB(),
    }

    # Initialize metrics storage
    all_metrics = {name: [] for name in classifiers}
    pr_curves = {name: {"precision": [], "recall": [], "auc": []} for name in classifiers}
    roc_curves = {name: {"fpr": [], "tpr": [], "auc": []} for name in classifiers}

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MLM")  # loaded once, shared by all folds

    # 3-fold CV
    kf = KFold(n_splits=3, shuffle=True, random_state=42)
    for fold, (train_idx, val_idx) in enumerate(kf.split(X_num)):
        print(f"\nFold {fold + 1}")

        # Split data
        scaler = StandardScaler().fit(X_num[train_idx])
        X_train_num = scaler.transform(X_num[train_idx])
        X_val_num = scaler.transform(X_num[val_idx])
        X_train_smiles, X_val_smiles = X_smiles[train_idx], X_smiles[val_idx]
        y_train, y_val = Y[train_idx], Y[val_idx]

        for name, clf in classifiers.items():
            print(f"\nTraining {name}...")

            if name == "ChemCoBERT":
                y_pred_prob = _run_chemcobert_fold(
                    tokenizer, X_train_smiles, y_train, X_val_smiles, y_val, fold, device
                )
            else:
                y_pred_prob = _fit_predict_classic(clf, X_train_num, y_train, X_val_num)

            val_metrics = _binary_metrics(y_val, y_pred_prob)
            all_metrics[name].append(val_metrics)
            _record_curves(name, y_val, y_pred_prob, pr_curves, roc_curves)

            # Print fold results
            print(f"{name} Fold {fold+1} Results:")
            for metric, value in val_metrics.items():
                print(f"{metric}: {value:.4f}")

    # Print final metrics
    print("\nFinal Metrics:")
    for name in classifiers:
        print(f"\n{name}:")
        for metric in all_metrics[name][0]:
            values = [m[metric] for m in all_metrics[name]]
            print(f"{metric}: {np.mean(values):.4f} ± {np.std(values):.4f}")

    _plot_mean_curves(classifiers, roc_curves, pr_curves)

# Entry point: run the full cross-validation experiment defined in main().
if __name__ == "__main__":
    main()


## 5-fold cross-validation

In [None]:
import os
import gc
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import (
    roc_auc_score, precision_recall_curve, roc_curve, auc,
    confusion_matrix, cohen_kappa_score, matthews_corrcoef, f1_score
)
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
import psutil

import warnings
from rdkit import RDLogger
# Turn off all RDKit logging (including warnings)
RDLogger.DisableLog('rdApp.*')  # disable every 'rdApp.*' RDKit log channel
# Plotting libraries for high-resolution SVG figures (re-imported; already imported above)
import seaborn as sns
import matplotlib.pyplot as plt

# Use the SimHei font so Matplotlib can render Chinese (CJK) text
plt.rcParams['font.sans-serif'] = ['SimHei']  # SimHei: a Chinese sans-serif ("Hei") font
plt.rcParams['axes.unicode_minus'] = False  # draw minus signs as ASCII '-' so they display with this font

# IPython magic (only valid in Jupyter/IPython): render inline figures as SVG
%config InlineBackend.figure_format = 'svg'

# Raise the DPI to improve figure sharpness
plt.rcParams['figure.dpi'] = 300
# NOTE(review): this silences *all* Python warnings, including sklearn/torch ones that may matter
warnings.filterwarnings("ignore")

# ---------------------- Memory Protection ----------------------
def memory_safe(func):
    """Decorate ``func`` so a GC pass runs first whenever system RAM usage is high.

    Before each call, if ``psutil`` reports more than 80% of virtual memory in
    use, Python garbage collection runs (plus ``torch.cuda.empty_cache()`` when
    CUDA is available) and a warning with the pre-collection usage is printed.
    The wrapped function's result is returned unchanged.
    """
    from functools import wraps  # local import keeps this notebook cell self-contained

    @wraps(func)  # preserve __name__/__doc__/__wrapped__ of the decorated function
    def wrapper(*args, **kwargs):
        mem = psutil.virtual_memory()
        if mem.percent > 80:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print(f"⚠️ Memory warning: Usage {mem.percent}%, performed garbage collection")
        return func(*args, **kwargs)
    return wrapper

# ---------------------- SMILES Feature Extraction ----------------------
class SMILESFeatureExtractor:
    """Convert one SMILES string into a fixed-length numeric feature vector.

    The vector is ``[descriptor values..., Morgan fingerprint bits...]`` with
    length ``len(desc_list) + fp_size``. Input that cannot be featurized yields
    an all-NaN vector of the same length.
    """

    def __init__(self, fp_size=1024, desc_list=None):
        """
        Args:
            fp_size: Number of bits of the radius-2 Morgan fingerprint.
            desc_list: Names of functions in ``rdkit.Chem.Descriptors`` to
                evaluate; ``None``/empty selects six default descriptors.

        Raises:
            ValueError: if a name in ``desc_list`` is not an RDKit descriptor.
                Checked up front because the per-molecule error handling would
                otherwise turn such a typo into all-NaN features.
        """
        self.fp_size = fp_size
        self.desc_list = desc_list or [
            'MolWt', 'NumHAcceptors', 'NumHDonors',
            'MolLogP', 'TPSA', 'NumRotatableBonds'
        ]
        unknown = [d for d in self.desc_list if not callable(getattr(Descriptors, d, None))]
        if unknown:
            raise ValueError(f"Unknown RDKit descriptor(s): {unknown}")

    def _nan_features(self):
        """Return an all-NaN vector of the full feature length (marks unusable input)."""
        return np.full(len(self.desc_list) + self.fp_size, np.nan)

    @memory_safe
    def smiles_to_features(self, smiles):
        """Convert SMILES to numerical features.

        Returns a float array ``[descriptors..., fingerprint bits...]``, or an
        all-NaN array if ``smiles`` cannot be parsed or featurized.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:  # RDKit signals an invalid SMILES by returning None
                return self._nan_features()

            # Calculate descriptors
            desc_values = [getattr(Descriptors, desc)(mol) for desc in self.desc_list]

            # Calculate fingerprints
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=self.fp_size)
            fp_values = np.array(fp, dtype=np.float32)

            return np.concatenate([desc_values, fp_values])
        except Exception:
            # Best effort: non-string input (e.g. a pandas NaN) or an RDKit failure
            # becomes a NaN row; KeyboardInterrupt/SystemExit still propagate.
            return self._nan_features()

# ---------------------- Data Preparation ----------------------
def prepare_features(X_smiles):
    """Convert SMILES pairs to a numeric feature matrix.

    Args:
        X_smiles: Iterable of ``(drug1_smiles, drug2_smiles)`` pairs.

    Returns:
        Array with one row per pair: ``[features(drug1), features(drug2)]`` as
        produced by ``SMILESFeatureExtractor``. NaNs (unparseable SMILES) are
        replaced with 0. Empty input gives a ``(0, n_features)`` array.
    """
    fe = SMILESFeatureExtractor()
    cache = {}

    def featurize(smiles):
        # The same drug appears in many pairs; extract each distinct SMILES once.
        if smiles not in cache:
            cache[smiles] = fe.smiles_to_features(smiles)
        return cache[smiles]

    features = [
        np.concatenate([featurize(drug1), featurize(drug2)])
        for drug1, drug2 in tqdm(X_smiles, desc="Extracting features")
    ]
    if not features:  # np.stack rejects an empty list
        return np.empty((0, 2 * (len(fe.desc_list) + fe.fp_size)))

    # Handle NaN values
    return np.nan_to_num(np.stack(features))

# ---------------------- Deep Learning Components ----------------------
class DrugInteractionDataset(Dataset):
    """Map-style dataset of tokenized SMILES pairs with their class label.

    Args:
        drug1_smiles, drug2_smiles: Indexable sequences holding one SMILES per pair.
        labels: Indexable sequence of integer labels; its length is the dataset length.
        tokenizer: Hugging Face-style tokenizer callable.
        max_length: Length every SMILES is padded/truncated to.
    """

    def __init__(self, drug1_smiles, drug2_smiles, labels, tokenizer, max_length=128):
        self.drug1_smiles = drug1_smiles
        self.drug2_smiles = drug2_smiles
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def _encode(self, smiles):
        """Tokenize one SMILES into 1-D ``(input_ids, attention_mask)`` tensors."""
        enc = self.tokenizer(
            str(smiles),
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        return enc['input_ids'].flatten(), enc['attention_mask'].flatten()

    def __getitem__(self, idx):
        ids_a, mask_a = self._encode(self.drug1_smiles[idx])
        ids_b, mask_b = self._encode(self.drug2_smiles[idx])
        item = {}
        item['drug1_input_ids'] = ids_a
        item['drug1_attention_mask'] = mask_a
        item['drug2_input_ids'] = ids_b
        item['drug2_attention_mask'] = mask_b
        item['label'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

class CoAttentionModel(nn.Module):
    """Drug-pair classifier: shared transformer encoder + bidirectional co-attention.

    Both drugs go through the same encoder. Each drug's first-token embedding
    attends over the *other* drug's token embeddings (padding excluded). The
    two first-token embeddings and the two attended vectors are concatenated
    and mapped to 2 logits by an MLP.
    """

    def __init__(self, bert_model_name="DeepChem/ChemBERTa-77M-MLM", hidden_size=384,
                 num_heads=8, dropout=0.1):
        """
        Args:
            bert_model_name: Model id/path for ``AutoModel.from_pretrained``, or an
                already-built ``nn.Module`` encoder whose output has ``last_hidden_state``.
            hidden_size: Encoder hidden width; must match the encoder's output size.
            num_heads: Co-attention heads (must divide ``hidden_size``).
            dropout: Dropout probability in the classifier head.
        """
        super().__init__()
        if isinstance(bert_model_name, nn.Module):
            self.bert = bert_model_name
        else:
            self.bert = AutoModel.from_pretrained(bert_model_name)
        # batch_first=True: tensors here are (batch, seq, hidden). With the default
        # (seq, batch, hidden) layout the batch axis is read as the sequence, so
        # each sample would attend to the other samples in its batch.
        self.co_attention = nn.MultiheadAttention(hidden_size, num_heads=num_heads, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 4, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 2),
        )

    def _cross_attend(self, query_vec, other_states, other_mask):
        """Attend from ``query_vec`` (batch, hidden) over ``other_states`` (batch, seq, hidden).

        Positions where ``other_mask == 0`` (padding) are ignored. Returns (batch, hidden).
        """
        out, _ = self.co_attention(
            query_vec.unsqueeze(1), other_states, other_states,
            key_padding_mask=other_mask == 0,
        )
        return out.squeeze(1)

    def forward(self, drug1_input_ids, drug1_attention_mask, drug2_input_ids, drug2_attention_mask):
        """Return (batch, 2) logits for the drug pairs."""
        states1 = self.bert(drug1_input_ids, attention_mask=drug1_attention_mask).last_hidden_state
        states2 = self.bert(drug2_input_ids, attention_mask=drug2_attention_mask).last_hidden_state
        drug1, drug2 = states1[:, 0, :], states2[:, 0, :]

        # Co-attention: each drug's summary vector queries the other drug's tokens
        attn1 = self._cross_attend(drug1, states2, drug2_attention_mask)
        attn2 = self._cross_attend(drug2, states1, drug1_attention_mask)

        combined = torch.cat([drug1, drug2, attn1, attn2], dim=1)
        return self.classifier(combined)

# ---------------------- Training and Evaluation ----------------------
@memory_safe
def train_model(model, train_loader, val_loader, optimizer, criterion, device, epochs=10,
                save_path="best_model.pth"):
    """Train ``model`` and checkpoint the epoch with the best validation AUC.

    Args:
        model: Network called as ``model(**inputs)`` with each batch's non-label tensors.
        train_loader, val_loader: Yield dicts of tensors that include a ``'label'`` key.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss applied to ``(logits, labels)``.
        device: Device each batch is moved to.
        epochs: Number of passes over ``train_loader``.
        save_path: File the best epoch's ``state_dict`` is written to.

    Returns:
        ``(model, history)``. ``model`` keeps the weights of the *last* epoch
        (reload ``save_path`` for the best one). ``history`` maps
        ``'train_loss'``, ``'val_loss'`` and ``'val_auc'`` to per-epoch lists.
    """
    best_val_auc = -float("inf")  # ensures the first finite AUC always produces a checkpoint
    history = {'train_loss': [], 'val_loss': [], 'val_auc': []}

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        n_batches = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in progress_bar:
            optimizer.zero_grad()
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'}
            outputs = model(**inputs)
            loss = criterion(outputs, batch['label'].to(device))
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            n_batches += 1
            # Count batches ourselves: tqdm refreshes ``progress_bar.n`` lazily
            # (throttled display updates), so it lags the true iteration count.
            progress_bar.set_postfix({'loss': train_loss / n_batches})

        # Validation
        val_loss, val_metrics = evaluate_model(model, val_loader, criterion, device)
        print(f"\nValidation - Loss: {val_loss:.4f}, AUC: {val_metrics['AUC']:.4f}, F1: {val_metrics['F1']:.4f}")

        # Save best model
        if val_metrics['AUC'] > best_val_auc:
            best_val_auc = val_metrics['AUC']
            torch.save(model.state_dict(), save_path)
            print("✅ Saved best model")

        history['train_loss'].append(train_loss / max(n_batches, 1))
        history['val_loss'].append(val_loss)
        history['val_auc'].append(val_metrics['AUC'])

    return model, history

@memory_safe
def evaluate_model(model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Args:
        model: Network returning 2-class logits; called with every non-label batch entry.
        data_loader: Yields dicts of tensors that include a ``'label'`` key.
        criterion: Loss applied to ``(logits, labels)``.
        device: Device inputs and labels are moved to.

    Returns:
        ``(mean_batch_loss, metrics)``. ``metrics`` holds AUC, Sensitivity,
        Specificity, Kappa, MCC and F1, with P(class 1) > 0.5 as the decision
        rule. AUC, sensitivity and specificity are NaN (instead of raising or
        dividing by zero) when the data lacks one of the classes.
    """
    model.eval()
    total_loss = 0.0
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in data_loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'}
            outputs = model(**inputs)
            loss = criterion(outputs, batch['label'].to(device))
            total_loss += loss.item()

            probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
            all_labels.extend(batch['label'].cpu().numpy())
            all_probs.extend(probs)

    y_true = np.asarray(all_labels)
    y_prob = np.asarray(all_probs)
    y_pred = (y_prob > 0.5).astype(int)

    # ROC AUC is undefined for a single class; roc_auc_score would raise
    auc_score = roc_auc_score(y_true, y_prob) if np.unique(y_true).size > 1 else float("nan")
    # labels=[0, 1] keeps the matrix 2x2 so the 4-way unpack cannot fail
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    metrics = {
        "AUC": auc_score,
        "Sensitivity": tp / (tp + fn) if tp + fn else float("nan"),
        "Specificity": tn / (tn + fp) if tn + fp else float("nan"),
        "Kappa": cohen_kappa_score(y_true, y_pred),
        "MCC": matthews_corrcoef(y_true, y_pred),
        "F1": f1_score(y_true, y_pred),
    }

    return total_loss / len(data_loader), metrics

# ---------------------- Main Execution ----------------------
# Fixed colour per model for both curve panels; names not listed fall back to
# the Matplotlib colour cycle (color=None).
_CURVE_COLORS = {
    "ChemCoBERT": "#E41A1C",
    "GBDT": "#FF7F00",
    "Decision Tree": "#A65628",
    "Naive Bayes": "#377EB8",
    "K-NN": "#984EA3",
    "AdaBoost": "#4DAF4A",
}


def _fold_metrics(y_true, y_prob, threshold=0.5):
    """Return the fold metric dict for class-1 probabilities ``y_prob``.

    Same keys and 0.5 threshold as ``evaluate_model`` so every model is
    scored identically.
    """
    y_pred = (np.asarray(y_prob) > threshold).astype(int)
    # labels=[0, 1] pins the 2x2 layout so the unpack order is always tn, fp, fn, tp.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return {
        "AUC": roc_auc_score(y_true, y_prob),
        "Sensitivity": tp / (tp + fn),
        "Specificity": tn / (tn + fp),
        "Kappa": cohen_kappa_score(y_true, y_pred),
        "MCC": matthews_corrcoef(y_true, y_pred),
        "F1": f1_score(y_true, y_pred),
    }


def _predict_positive_proba(model, loader, device):
    """Return the class-1 softmax probability for every sample in ``loader``."""
    model.eval()
    probs = []
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'}
            probs.extend(torch.softmax(model(**inputs), dim=1)[:, 1].cpu().numpy())
    return np.array(probs)


def _run_chemcobert_fold(tokenizer, smiles_train, y_train, smiles_val, y_val,
                         fold, device, model_dir):
    """Fine-tune a fresh CoAttentionModel on one fold; return val probabilities."""
    train_dataset = DrugInteractionDataset(
        smiles_train[:, 0], smiles_train[:, 1], y_train, tokenizer
    )
    val_dataset = DrugInteractionDataset(
        smiles_val[:, 0], smiles_val[:, 1], y_val, tokenizer
    )

    model = CoAttentionModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    criterion = nn.CrossEntropyLoss()
    model, _ = train_model(
        model,
        DataLoader(train_dataset, batch_size=64, shuffle=True),
        DataLoader(val_dataset, batch_size=64),
        optimizer, criterion, device, epochs=10
    )

    # train_model writes the best-validation-AUC epoch to "best_model.pth";
    # reload it so the fold is scored on its best epoch, not its last one.
    # NOTE(review): the same validation fold is used for epoch selection and
    # for scoring, so ChemCoBERT's fold metrics are optimistically biased
    # relative to the sklearn models.
    model.load_state_dict(torch.load("best_model.pth", map_location=device))
    y_prob = _predict_positive_proba(model, DataLoader(val_dataset, batch_size=128), device)

    model_path = f"{model_dir}/fold_{fold+1}_best.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Saved model to {model_path}")

    # Free this fold's weights and optimizer state before the next fold
    # allocates a new model.
    del model, optimizer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return y_prob


def _fit_sklearn_fold(clf, X_train, y_train, X_val):
    """Fit ``clf`` on the training fold and return validation scores in [0, 1]."""
    clf.fit(X_train, y_train)
    if hasattr(clf, "predict_proba"):
        return clf.predict_proba(X_val)[:, 1]
    # Margin-only models: min-max rescale so the shared 0.5 threshold applies
    # (ranking, and hence AUC, is unchanged). Constant scores would divide by 0.
    scores = clf.decision_function(X_val)
    span = scores.max() - scores.min()
    if span == 0:
        return np.zeros_like(scores, dtype=float)
    return (scores - scores.min()) / span


def _plot_mean_curves(classifiers, pr_curves, roc_curves, n_points=100):
    """Draw fold-averaged ROC (left panel) and precision-recall (right panel) curves."""
    grid = np.linspace(0, 1, n_points)
    plt.figure(figsize=(15, 6))

    # Left: mean TPR interpolated onto a common FPR grid.
    plt.subplot(121)
    for name in classifiers:
        curves = roc_curves[name]
        mean_tpr = np.mean(
            [np.interp(grid, fpr, tpr) for fpr, tpr in zip(curves["fpr"], curves["tpr"])],
            axis=0,
        )
        plt.plot(grid, mean_tpr, color=_CURVE_COLORS.get(name), linewidth=2.5,
                 label=f"{name} (AUC={np.mean(curves['auc']):.4f})")
    plt.xlabel("False Positive Rate", fontweight='bold')
    plt.ylabel("True Positive Rate", fontweight='bold')
    plt.title("Receiver Operating Characteristic Curve", fontweight='bold')
    plt.legend(loc="lower right")

    # Right: mean precision on a common recall grid. precision_recall_curve
    # returns recall in decreasing order, so both arrays are reversed to give
    # np.interp the increasing x it requires.
    plt.subplot(122)
    for name in classifiers:
        curves = pr_curves[name]
        mean_precision = np.mean(
            [np.interp(grid, recall[::-1], precision[::-1])
             for precision, recall in zip(curves["precision"], curves["recall"])],
            axis=0,
        )
        # ChemCoBERT is drawn slightly thicker on the PR panel.
        plt.plot(grid, mean_precision, color=_CURVE_COLORS.get(name),
                 linewidth=3 if name == "ChemCoBERT" else 2.5,
                 label=f"{name} (AUC={np.mean(curves['auc']):.4f})")
    plt.xlabel("Recall", fontweight='bold')
    plt.ylabel("Precision", fontweight='bold')
    plt.title("Precision Recall Curve", fontweight='bold')
    plt.legend()

    plt.tight_layout()
    plt.show()


def main(file_path="/kaggle/working/1/dat.txt", n_splits=5, model_dir="best_models"):
    """Benchmark ChemCoBERT against classic classifiers with K-fold CV.

    Reads a tab-separated, header-less file whose first two columns are the
    SMILES of a drug pair and whose last column is the label. For each fold
    it trains every model, prints the fold metrics, and finally prints
    mean ± std per metric and plots fold-averaged ROC and PR curves.

    Args:
        file_path: Input TSV path.
        n_splits: Number of cross-validation folds.
        model_dir: Directory for the per-fold ChemCoBERT checkpoints; created
            if missing.
    """
    # Load data
    df = pd.read_csv(file_path, sep='\t', header=None)
    X_smiles = df.iloc[:, :2].values
    Y = df.iloc[:, -1].values

    # Numerical features for the traditional models. Scaling is fitted per
    # fold on the training split only, so validation statistics never leak.
    X_num = prepare_features(X_smiles)

    classifiers = {
        "ChemCoBERT": None,
        "Decision Tree": DecisionTreeClassifier(random_state=42),
        "AdaBoost": AdaBoostClassifier(n_estimators=100, random_state=42),
        "GBDT": GradientBoostingClassifier(n_estimators=100, random_state=42),
        "K-NN": KNeighborsClassifier(n_neighbors=5),
        "Naive Bayes": GaussianNB(),
    }

    all_metrics = {name: [] for name in classifiers}
    pr_curves = {name: {"precision": [], "recall": [], "auc": []} for name in classifiers}
    roc_curves = {name: {"fpr": [], "tpr": [], "auc": []} for name in classifiers}

    # torch.save does not create parent directories.
    os.makedirs(model_dir, exist_ok=True)
    # Loaded once instead of once per fold.
    tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MLM")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # K-fold CV
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    for fold, (train_idx, val_idx) in enumerate(kf.split(X_num)):
        print(f"\nFold {fold + 1}")

        scaler = StandardScaler()
        X_train_num = scaler.fit_transform(X_num[train_idx])
        X_val_num = scaler.transform(X_num[val_idx])
        X_train_smiles, X_val_smiles = X_smiles[train_idx], X_smiles[val_idx]
        y_train, y_val = Y[train_idx], Y[val_idx]

        for name, clf in classifiers.items():
            print(f"\nTraining {name}...")

            if name == "ChemCoBERT":
                y_pred_prob = _run_chemcobert_fold(
                    tokenizer, X_train_smiles, y_train, X_val_smiles, y_val,
                    fold, device, model_dir,
                )
            else:
                y_pred_prob = _fit_sklearn_fold(clf, X_train_num, y_train, X_val_num)

            val_metrics = _fold_metrics(y_val, y_pred_prob)
            all_metrics[name].append(val_metrics)

            precision, recall, _ = precision_recall_curve(y_val, y_pred_prob)
            pr_curves[name]["precision"].append(precision)
            pr_curves[name]["recall"].append(recall)
            pr_curves[name]["auc"].append(auc(recall, precision))

            fpr, tpr, _ = roc_curve(y_val, y_pred_prob)
            roc_curves[name]["fpr"].append(fpr)
            roc_curves[name]["tpr"].append(tpr)
            roc_curves[name]["auc"].append(auc(fpr, tpr))

            print(f"{name} Fold {fold+1} Results:")
            for metric, value in val_metrics.items():
                print(f"{metric}: {value:.4f}")

    print("\nFinal Metrics:")
    for name in classifiers:
        print(f"\n{name}:")
        for metric in all_metrics[name][0]:
            values = [m[metric] for m in all_metrics[name]]
            print(f"{metric}: {np.mean(values):.4f} ± {np.std(values):.4f}")

    _plot_mean_curves(classifiers, pr_curves, roc_curves)

# Entry point: run the full cross-validation benchmark defined by main() above.
if __name__ == "__main__":
    main()


## 10-fold cross-validation

In [None]:
import os
import gc
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import (
    roc_auc_score, precision_recall_curve, roc_curve, auc,
    confusion_matrix, cohen_kappa_score, matthews_corrcoef, f1_score
)
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
import psutil

import warnings
from rdkit import RDLogger
# Silence all RDKit logging (including warnings).
RDLogger.DisableLog('rdApp.*')  # disable every RDKit log channel
# NOTE(review): seaborn and pyplot are already imported above; these re-imports are redundant but harmless.
import seaborn as sns
import matplotlib.pyplot as plt

# Configure Matplotlib to render Chinese characters.
plt.rcParams['font.sans-serif'] = ['SimHei']  # SimHei font; presumably must be installed on the host — confirm
plt.rcParams['axes.unicode_minus'] = False  # render minus signs as ASCII '-' so they display correctly with this font

# Render inline notebook figures as SVG (IPython magic — this line only runs under IPython/Jupyter).
%config InlineBackend.figure_format = 'svg'

# Raise the figure DPI for sharper output.
plt.rcParams['figure.dpi'] = 300
# Suppress all Python warnings from here on.
warnings.filterwarnings("ignore")

# ---------------------- Memory Protection ----------------------
def memory_safe(func):
    """Decorate ``func`` with a pre-call memory check.

    Before every call, if ``psutil.virtual_memory().percent`` is above 80,
    run the garbage collector, empty the CUDA cache when CUDA is available,
    and print a warning. The wrapped function is then called with the same
    arguments and its result returned unchanged.

    ``functools.wraps`` preserves the wrapped function's ``__name__``,
    ``__doc__`` and ``__wrapped__`` so tracebacks, tqdm/logging labels and
    introspection still show the real function instead of ``wrapper``.
    """
    from functools import wraps  # local import keeps this notebook cell self-contained

    @wraps(func)
    def wrapper(*args, **kwargs):
        mem = psutil.virtual_memory()
        if mem.percent > 80:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print(f"⚠️ Memory warning: Usage {mem.percent}%, performed garbage collection")
        return func(*args, **kwargs)
    return wrapper

# ---------------------- SMILES Feature Extraction ----------------------
class SMILESFeatureExtractor:
    """Turn a SMILES string into a fixed-length numeric feature vector.

    The vector is the RDKit descriptors named in ``desc_list`` (in order)
    followed by a radius-2 Morgan fingerprint of ``fp_size`` bits. Inputs
    that cannot be featurized yield an all-NaN vector of the same length.
    """

    def __init__(self, fp_size=1024, desc_list=None):
        self.fp_size = fp_size
        # None (or an empty list) selects the default descriptor set.
        self.desc_list = desc_list or [
            'MolWt', 'NumHAcceptors', 'NumHDonors',
            'MolLogP', 'TPSA', 'NumRotatableBonds'
        ]

    def _nan_features(self):
        """Return an all-NaN vector of the normal output length (failure marker)."""
        return np.full(len(self.desc_list) + self.fp_size, np.nan)

    @memory_safe
    def smiles_to_features(self, smiles):
        """Convert one SMILES string to ``len(desc_list) + fp_size`` numbers.

        Returns the all-NaN failure vector when RDKit cannot parse ``smiles``
        or when any descriptor/fingerprint computation raises.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if not mol:  # MolFromSmiles returns None for unparsable input
                return self._nan_features()

            desc_values = [getattr(Descriptors, desc)(mol) for desc in self.desc_list]
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=self.fp_size)
            fp_values = np.array(fp, dtype=np.float32)

            return np.concatenate([desc_values, fp_values])
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit still
        # stop a long featurization loop instead of being turned into NaN rows.
        except Exception:
            return self._nan_features()

# ---------------------- Data Preparation ----------------------
def prepare_features(X_smiles):
    """Featurize every (drug1, drug2) SMILES pair into one numeric row.

    Each row is drug1's feature vector followed by drug2's, as produced by
    ``SMILESFeatureExtractor().smiles_to_features``. The stacked matrix is
    passed through ``np.nan_to_num``: NaN becomes 0.0 and +/-inf become the
    largest finite floats.
    """
    featurize = SMILESFeatureExtractor().smiles_to_features
    rows = [
        np.concatenate([featurize(first), featurize(second)])
        for first, second in tqdm(X_smiles, desc="Extracting features")
    ]
    return np.nan_to_num(np.stack(rows))

# ---------------------- Deep Learning Components ----------------------
class DrugInteractionDataset(Dataset):
    """Dataset of drug pairs, tokenized lazily on each item access.

    Each drug's SMILES is tokenized independently, padded/truncated to
    ``max_length`` tokens, and returned as 1-D id and attention-mask tensors
    alongside the pair's label as a ``torch.long`` scalar.
    """

    def __init__(self, drug1_smiles, drug2_smiles, labels, tokenizer, max_length=128):
        self.drug1_smiles = drug1_smiles
        self.drug2_smiles = drug2_smiles
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def _encode(self, smiles):
        """Tokenize one SMILES string; return flattened (input_ids, attention_mask)."""
        enc = self.tokenizer(
            str(smiles),
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        return enc['input_ids'].flatten(), enc['attention_mask'].flatten()

    def __getitem__(self, idx):
        ids_a, mask_a = self._encode(self.drug1_smiles[idx])
        ids_b, mask_b = self._encode(self.drug2_smiles[idx])
        return {
            'drug1_input_ids': ids_a,
            'drug1_attention_mask': mask_a,
            'drug2_input_ids': ids_b,
            'drug2_attention_mask': mask_b,
            'label': torch.tensor(self.labels[idx], dtype=torch.long),
        }

class CoAttentionModel(nn.Module):
    """Drug-pair classifier built on one shared pretrained encoder.

    Both drugs are encoded by the same ``AutoModel``; the first-token hidden
    state of each is used as that drug's embedding. A multi-head attention
    layer lets each drug embedding attend to the other, and the two
    embeddings plus the two attended vectors are concatenated and fed to an
    MLP that outputs two logits.

    Args:
        bert_model_name: Hugging Face model id/path for ``AutoModel.from_pretrained``.
        hidden_size: Embedding width. It must equal the encoder's hidden size
            (assumed to be 384 for ChemBERTa-77M-MLM — confirm if the encoder
            changes).
    """

    def __init__(self, bert_model_name="DeepChem/ChemBERTa-77M-MLM", hidden_size=384):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        # batch_first=True because forward() feeds (batch, 1, hidden) tensors.
        # With the default (seq, batch, hidden) layout the batch axis was read
        # as the sequence axis, so each sample attended over the whole
        # mini-batch and predictions depended on which samples shared a batch.
        # (batch_first adds no parameters, so state-dict keys are unchanged.)
        self.co_attention = nn.MultiheadAttention(hidden_size, num_heads=8, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size*4, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 2)
        )

    def forward(self, drug1_input_ids, drug1_attention_mask, drug2_input_ids, drug2_attention_mask):
        """Return (batch, 2) logits for a batch of drug pairs."""
        drug1 = self.bert(drug1_input_ids, attention_mask=drug1_attention_mask).last_hidden_state[:, 0, :]
        drug2 = self.bert(drug2_input_ids, attention_mask=drug2_attention_mask).last_hidden_state[:, 0, :]

        # Co-attention over a length-1 sequence per sample: (batch, 1, hidden).
        # NOTE: with a single key per query the attention weight is trivially
        # 1, so each attended vector is a learned projection of the other drug.
        seq1, seq2 = drug1.unsqueeze(1), drug2.unsqueeze(1)
        attn1, _ = self.co_attention(seq1, seq2, seq2)
        attn2, _ = self.co_attention(seq2, seq1, seq1)

        combined = torch.cat([drug1, drug2, attn1.squeeze(1), attn2.squeeze(1)], dim=1)
        return self.classifier(combined)

# ---------------------- Training and Evaluation ----------------------
@memory_safe
def train_model(model, train_loader, val_loader, optimizer, criterion, device, epochs=10,
                checkpoint_path="best_model.pth"):
    """Train ``model`` for ``epochs`` epochs, checkpointing the best-AUC epoch.

    After each epoch the model is scored on ``val_loader`` with
    ``evaluate_model``. Whenever validation AUC improves, the state dict is
    written to ``checkpoint_path``.

    Returns:
        ``(model, history)``. ``model`` is the LAST-epoch model, not the best
        one; reload ``checkpoint_path`` to get the best. ``history`` holds the
        per-epoch mean training batch loss (``train_loss``), ``val_loss`` and
        ``val_auc``.
    """
    # Start at -inf (not 0) so the first epoch with a real AUC is always
    # checkpointed and a later load never picks up a stale file.
    best_val_auc = float("-inf")
    history = {'train_loss': [], 'val_loss': [], 'val_auc': []}

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        # Count steps explicitly: tqdm only syncs progress_bar.n on display
        # refreshes, so it can lag the true batch count and skew the average.
        for step, batch in enumerate(progress_bar, start=1):
            optimizer.zero_grad()
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'}
            outputs = model(**inputs)
            loss = criterion(outputs, batch['label'].to(device))
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            progress_bar.set_postfix({'loss': train_loss / step})

        # Validation
        val_loss, val_metrics = evaluate_model(model, val_loader, criterion, device)
        print(f"\nValidation - Loss: {val_loss:.4f}, AUC: {val_metrics['AUC']:.4f}, F1: {val_metrics['F1']:.4f}")

        # Save best model
        if val_metrics['AUC'] > best_val_auc:
            best_val_auc = val_metrics['AUC']
            torch.save(model.state_dict(), checkpoint_path)
            print("✅ Saved best model")

        history['train_loss'].append(train_loss / max(len(train_loader), 1))
        history['val_loss'].append(val_loss)
        history['val_auc'].append(val_metrics['AUC'])

    return model, history

@memory_safe
def evaluate_model(model, data_loader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels = []
    all_probs = []
    
    with torch.no_grad():
        for batch in data_loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'}
            outputs = model(**inputs)
            loss = criterion(outputs, batch['label'].to(device))
            total_loss += loss.item()
            
            probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
            all_labels.extend(batch['label'].cpu().numpy())
            all_probs.extend(probs)
    
    # Calculate metrics
    auc_score = roc_auc_score(all_labels, all_probs)
    y_pred = (np.array(all_probs) > 0.5).astype(int)
    tn, fp, fn, tp = confusion_matrix(all_labels, y_pred).ravel()
    
    metrics = {
        "AUC": auc_score,
        "Sensitivity": tp / (tp + fn),
        "Specificity": tn / (tn + fp),
        "Kappa": cohen_kappa_score(all_labels, y_pred),
        "MCC": matthews_corrcoef(all_labels, y_pred),
        "F1": f1_score(all_labels, y_pred),
    }
    
    return total_loss / len(data_loader), metrics

# ---------------------- Main Execution ----------------------
# Fixed colour per model for both curve panels; names not listed fall back to
# the Matplotlib colour cycle (color=None).
_CURVE_COLORS = {
    "ChemCoBERT": "#E41A1C",
    "GBDT": "#FF7F00",
    "Decision Tree": "#A65628",
    "Naive Bayes": "#377EB8",
    "K-NN": "#984EA3",
    "AdaBoost": "#4DAF4A",
}


def _fold_metrics(y_true, y_prob, threshold=0.5):
    """Return the fold metric dict for class-1 probabilities ``y_prob``.

    Same keys and 0.5 threshold as ``evaluate_model`` so every model is
    scored identically.
    """
    y_pred = (np.asarray(y_prob) > threshold).astype(int)
    # labels=[0, 1] pins the 2x2 layout so the unpack order is always tn, fp, fn, tp.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return {
        "AUC": roc_auc_score(y_true, y_prob),
        "Sensitivity": tp / (tp + fn),
        "Specificity": tn / (tn + fp),
        "Kappa": cohen_kappa_score(y_true, y_pred),
        "MCC": matthews_corrcoef(y_true, y_pred),
        "F1": f1_score(y_true, y_pred),
    }


def _predict_positive_proba(model, loader, device):
    """Return the class-1 softmax probability for every sample in ``loader``."""
    model.eval()
    probs = []
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'}
            probs.extend(torch.softmax(model(**inputs), dim=1)[:, 1].cpu().numpy())
    return np.array(probs)


def _run_chemcobert_fold(tokenizer, smiles_train, y_train, smiles_val, y_val,
                         fold, device, model_dir):
    """Fine-tune a fresh CoAttentionModel on one fold; return val probabilities."""
    train_dataset = DrugInteractionDataset(
        smiles_train[:, 0], smiles_train[:, 1], y_train, tokenizer
    )
    val_dataset = DrugInteractionDataset(
        smiles_val[:, 0], smiles_val[:, 1], y_val, tokenizer
    )

    model = CoAttentionModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    criterion = nn.CrossEntropyLoss()
    model, _ = train_model(
        model,
        DataLoader(train_dataset, batch_size=64, shuffle=True),
        DataLoader(val_dataset, batch_size=64),
        optimizer, criterion, device, epochs=10
    )

    # train_model writes the best-validation-AUC epoch to "best_model.pth";
    # reload it so the fold is scored on its best epoch, not its last one.
    # NOTE(review): the same validation fold is used for epoch selection and
    # for scoring, so ChemCoBERT's fold metrics are optimistically biased
    # relative to the sklearn models.
    model.load_state_dict(torch.load("best_model.pth", map_location=device))
    y_prob = _predict_positive_proba(model, DataLoader(val_dataset, batch_size=128), device)

    model_path = f"{model_dir}/fold_{fold+1}_best.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Saved model to {model_path}")

    # Free this fold's weights and optimizer state before the next fold
    # allocates a new model.
    del model, optimizer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return y_prob


def _fit_sklearn_fold(clf, X_train, y_train, X_val):
    """Fit ``clf`` on the training fold and return validation scores in [0, 1]."""
    clf.fit(X_train, y_train)
    if hasattr(clf, "predict_proba"):
        return clf.predict_proba(X_val)[:, 1]
    # Margin-only models: min-max rescale so the shared 0.5 threshold applies
    # (ranking, and hence AUC, is unchanged). Constant scores would divide by 0.
    scores = clf.decision_function(X_val)
    span = scores.max() - scores.min()
    if span == 0:
        return np.zeros_like(scores, dtype=float)
    return (scores - scores.min()) / span


def _plot_mean_curves(classifiers, pr_curves, roc_curves, n_points=100):
    """Draw fold-averaged ROC (left panel) and precision-recall (right panel) curves."""
    grid = np.linspace(0, 1, n_points)
    plt.figure(figsize=(15, 6))

    # Left: mean TPR interpolated onto a common FPR grid.
    plt.subplot(121)
    for name in classifiers:
        curves = roc_curves[name]
        mean_tpr = np.mean(
            [np.interp(grid, fpr, tpr) for fpr, tpr in zip(curves["fpr"], curves["tpr"])],
            axis=0,
        )
        plt.plot(grid, mean_tpr, color=_CURVE_COLORS.get(name), linewidth=2.5,
                 label=f"{name} (AUC={np.mean(curves['auc']):.4f})")
    plt.xlabel("False Positive Rate", fontweight='bold')
    plt.ylabel("True Positive Rate", fontweight='bold')
    plt.title("Receiver Operating Characteristic Curve", fontweight='bold')
    plt.legend(loc="lower right")

    # Right: mean precision on a common recall grid. precision_recall_curve
    # returns recall in decreasing order, so both arrays are reversed to give
    # np.interp the increasing x it requires.
    plt.subplot(122)
    for name in classifiers:
        curves = pr_curves[name]
        mean_precision = np.mean(
            [np.interp(grid, recall[::-1], precision[::-1])
             for precision, recall in zip(curves["precision"], curves["recall"])],
            axis=0,
        )
        # ChemCoBERT is drawn slightly thicker on the PR panel.
        plt.plot(grid, mean_precision, color=_CURVE_COLORS.get(name),
                 linewidth=3 if name == "ChemCoBERT" else 2.5,
                 label=f"{name} (AUC={np.mean(curves['auc']):.4f})")
    plt.xlabel("Recall", fontweight='bold')
    plt.ylabel("Precision", fontweight='bold')
    plt.title("Precision Recall Curve", fontweight='bold')
    plt.legend()

    plt.tight_layout()
    plt.show()


def main(file_path="/kaggle/working/1/dat.txt", n_splits=10, model_dir="best_models"):
    """Benchmark ChemCoBERT against classic classifiers with K-fold CV.

    Reads a tab-separated, header-less file whose first two columns are the
    SMILES of a drug pair and whose last column is the label. For each fold
    it trains every model, prints the fold metrics, and finally prints
    mean ± std per metric and plots fold-averaged ROC and PR curves.

    Args:
        file_path: Input TSV path.
        n_splits: Number of cross-validation folds.
        model_dir: Directory for the per-fold ChemCoBERT checkpoints; created
            if missing.
    """
    # Load data
    df = pd.read_csv(file_path, sep='\t', header=None)
    X_smiles = df.iloc[:, :2].values
    Y = df.iloc[:, -1].values

    # Numerical features for the traditional models. Scaling is fitted per
    # fold on the training split only, so validation statistics never leak.
    X_num = prepare_features(X_smiles)

    classifiers = {
        "ChemCoBERT": None,
        "Decision Tree": DecisionTreeClassifier(random_state=42),
        "AdaBoost": AdaBoostClassifier(n_estimators=100, random_state=42),
        "GBDT": GradientBoostingClassifier(n_estimators=100, random_state=42),
        "K-NN": KNeighborsClassifier(n_neighbors=5),
        "Naive Bayes": GaussianNB(),
    }

    all_metrics = {name: [] for name in classifiers}
    pr_curves = {name: {"precision": [], "recall": [], "auc": []} for name in classifiers}
    roc_curves = {name: {"fpr": [], "tpr": [], "auc": []} for name in classifiers}

    # torch.save does not create parent directories.
    os.makedirs(model_dir, exist_ok=True)
    # Loaded once instead of once per fold.
    tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MLM")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # K-fold CV
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    for fold, (train_idx, val_idx) in enumerate(kf.split(X_num)):
        print(f"\nFold {fold + 1}")

        scaler = StandardScaler()
        X_train_num = scaler.fit_transform(X_num[train_idx])
        X_val_num = scaler.transform(X_num[val_idx])
        X_train_smiles, X_val_smiles = X_smiles[train_idx], X_smiles[val_idx]
        y_train, y_val = Y[train_idx], Y[val_idx]

        for name, clf in classifiers.items():
            print(f"\nTraining {name}...")

            if name == "ChemCoBERT":
                y_pred_prob = _run_chemcobert_fold(
                    tokenizer, X_train_smiles, y_train, X_val_smiles, y_val,
                    fold, device, model_dir,
                )
            else:
                y_pred_prob = _fit_sklearn_fold(clf, X_train_num, y_train, X_val_num)

            val_metrics = _fold_metrics(y_val, y_pred_prob)
            all_metrics[name].append(val_metrics)

            precision, recall, _ = precision_recall_curve(y_val, y_pred_prob)
            pr_curves[name]["precision"].append(precision)
            pr_curves[name]["recall"].append(recall)
            pr_curves[name]["auc"].append(auc(recall, precision))

            fpr, tpr, _ = roc_curve(y_val, y_pred_prob)
            roc_curves[name]["fpr"].append(fpr)
            roc_curves[name]["tpr"].append(tpr)
            roc_curves[name]["auc"].append(auc(fpr, tpr))

            print(f"{name} Fold {fold+1} Results:")
            for metric, value in val_metrics.items():
                print(f"{metric}: {value:.4f}")

    print("\nFinal Metrics:")
    for name in classifiers:
        print(f"\n{name}:")
        for metric in all_metrics[name][0]:
            values = [m[metric] for m in all_metrics[name]]
            print(f"{metric}: {np.mean(values):.4f} ± {np.std(values):.4f}")

    _plot_mean_curves(classifiers, pr_curves, roc_curves)

# Entry point: run the full cross-validation benchmark defined by main() above.
if __name__ == "__main__":
    main()
