# In [None]:

import os
import math
import gc
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import confusion_matrix, roc_auc_score
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.base import BaseEstimator, ClassifierMixin


def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU and
    CUDA generators, and asks cuDNN for deterministic kernels, so repeated
    runs with the same seed reproduce the same results.

    Args:
        seed: Integer seed applied to all generators.
    """
    import random  # local import keeps the file-level import block unchanged

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: PYTHONHASHSEED is only read at interpreter start-up, so this only
    # influences child processes started afterwards, not the current one.
    os.environ['PYTHONHASHSEED'] = str(seed)

# Seed all RNGs once when the module/cell is executed.
set_seed(42)

# Release cached CUDA memory and run the garbage collector (presumably to free
# memory left over from earlier notebook runs).
torch.cuda.empty_cache()
gc.collect()

# Module-level default device string.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


class TorchTransformer(nn.Module):
    """Single-token Transformer block followed by a two-logit classifier.

    Each input row is projected to ``hidden_dim`` features, treated as a
    sequence of length one, and passed through one post-norm self-attention
    + feed-forward block before the final linear head.

    Args:
        in_dim: Number of input features per sample.
        hidden_dim: Width of the projected representation; must be divisible
            by ``num_heads`` (required by ``nn.MultiheadAttention``).
        num_heads: Number of attention heads.
        dropout: Dropout rate used in attention and in the feed-forward net.
    """

    def __init__(self, in_dim, hidden_dim=32, num_heads=4, dropout=0.3):
        super().__init__()
        width = hidden_dim
        expanded = 2 * width
        # Submodules are created in a fixed order: it determines the
        # state_dict keys and the order _init_weights visits Linear layers.
        self.input_proj = nn.Linear(in_dim, width)
        self.ln1 = nn.LayerNorm(width)
        self.attention = nn.MultiheadAttention(
            embed_dim=width,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.ln2 = nn.LayerNorm(width)
        ffn_layers = [
            nn.Linear(width, expanded),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expanded, width),
        ]
        self.ffn = nn.Sequential(*ffn_layers)
        self.ln3 = nn.LayerNorm(width)
        self.classifier = nn.Linear(width, 2)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every ``nn.Linear`` submodule: Xavier-uniform weights, zero bias.

        NOTE(review): presumably this also covers the attention block's output
        projection if it subclasses ``nn.Linear`` — confirm for your torch version.
        """
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is None:
                continue
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Map ``x`` of shape (batch, in_dim) to logits of shape (batch, 2)."""
        # NOTE(review): with a one-token sequence the attention softmax is over
        # a single key, so (in eval mode) attention is just its value/output
        # projections.
        tokens = self.ln1(self.input_proj(x)).unsqueeze(1)  # (batch, 1, hidden)
        attended, _ = self.attention(tokens, tokens, tokens)
        tokens = self.ln2(tokens + attended)
        ff = self.ffn(tokens)
        tokens = self.ln3(tokens + ff)
        logits = self.classifier(tokens.squeeze(1))
        return logits


class TorchTransformerClassifier(BaseEstimator, ClassifierMixin):
    """scikit-learn compatible binary classifier built on ``TorchTransformer``.

    ``fit`` holds out a stratified 15% of the training data for early stopping
    on validation cross-entropy and restores the best weights seen. Labels may
    be any two distinct values: they are encoded to 0/1 internally and
    ``predict`` maps them back via ``classes_``.

    Args:
        in_dim: Number of input features; ``None`` infers it from ``X`` in ``fit``.
        hidden_dim: Width of the transformer representation.
        num_heads: Number of attention heads (must divide ``hidden_dim``).
        dropout: Dropout rate inside the network.
        epochs: Maximum number of training epochs.
        lr: AdamW learning rate.
        batch_size: Mini-batch size for training and inference.
        patience: Epochs without validation improvement before stopping early.
        device: Torch device string; defaults to CUDA when available.
        verbose: Print training progress when true.
        random_state: Seed for the train/validation split and all RNGs.
    """

    def __init__(self, in_dim=None, hidden_dim=32, num_heads=4, dropout=0.3,
                 epochs=200, lr=5e-4, batch_size=32, patience=30,
                 device=None, verbose=True, random_state=42):
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.epochs = epochs
        self.lr = lr
        self.batch_size = batch_size
        self.patience = patience
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.verbose = verbose
        self.random_state = random_state
        self.model = None
        # Placeholder until ``fit`` records the labels actually seen in ``y``.
        self.classes_ = np.array([0, 1])

    def fit(self, X, y):
        """Train the network with early stopping on a held-out validation split.

        Args:
            X: Array-like of shape (n_samples, n_features).
            y: Array-like of shape (n_samples,) with exactly two distinct labels.

        Returns:
            self

        Raises:
            ValueError: If ``X`` is not 2-D, if ``y`` does not hold exactly two
                classes, or if ``X`` disagrees with an explicitly set ``in_dim``.
        """
        set_seed(self.random_state)

        # Accept pandas objects / lists as well as ndarrays.
        X = np.asarray(X, dtype=np.float32)
        y = np.asarray(y).ravel()
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D, got shape {X.shape}.")

        # Encode arbitrary binary labels as 0/1 indices for CrossEntropyLoss.
        classes, y_enc = np.unique(y, return_inverse=True)
        if classes.shape[0] != 2:
            raise ValueError(
                f"TorchTransformerClassifier is binary; got {classes.shape[0]} classes."
            )

        n_features = X.shape[1]
        # Resolve in_dim locally instead of mutating the constructor parameter,
        # so refitting on data of a different width works.
        in_dim = n_features if self.in_dim is None else self.in_dim
        if in_dim != n_features:
            raise ValueError(f"in_dim={in_dim} but X has {n_features} features.")
        self.classes_ = classes
        self.n_features_in_ = n_features

        X_tr, X_val, y_tr, y_val = train_test_split(
            X, y_enc, test_size=0.15, random_state=self.random_state, stratify=y_enc
        )

        train_dataset = TensorDataset(
            torch.tensor(X_tr, dtype=torch.float32),
            torch.tensor(y_tr, dtype=torch.long),
        )
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)

        # The whole validation split is evaluated in a single forward pass.
        X_val_t = torch.tensor(X_val, dtype=torch.float32).to(self.device)
        y_val_t = torch.tensor(y_val, dtype=torch.long).to(self.device)

        self.model = TorchTransformer(
            in_dim,
            hidden_dim=self.hidden_dim,
            num_heads=self.num_heads,
            dropout=self.dropout,
        ).to(self.device)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=0.01)
        # `verbose` is not passed: it is deprecated on LR schedulers in recent
        # PyTorch releases (passing it warns) and defaults to off anyway.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=10
        )

        best_val_loss = float('inf')
        patience_counter = 0
        best_state = None

        if self.verbose:
            print(" Training Transformer...")

        for epoch in range(self.epochs):
            train_loss = self._train_epoch(train_loader, criterion, optimizer)
            val_loss = self._validation_loss(X_val_t, y_val_t, criterion)
            scheduler.step(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # Copy the weights (on CPU) so later updates don't overwrite the snapshot.
                best_state = {k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()}
            else:
                patience_counter += 1
                if patience_counter >= self.patience:
                    if self.verbose:
                        print(f"  Early stopping at epoch {epoch+1}, best val_loss: {best_val_loss:.4f}")
                    break

            if self.verbose and (epoch + 1) % 50 == 0:
                print(f"  Epoch {epoch+1}, Train: {train_loss:.4f}, Val: {val_loss:.4f}")

        if best_state is not None:
            # load_state_dict copies into the existing (on-device) parameters.
            self.model.load_state_dict(best_state)

        if self.verbose:
            print(f" Transformer done! Best val_loss: {best_val_loss:.4f}\n")

        return self

    def _train_epoch(self, loader, criterion, optimizer):
        """Run one optimisation pass over ``loader``; return the mean per-sample loss."""
        self.model.train()
        total_loss, n_seen = 0.0, 0
        for batch_X, batch_y in loader:
            batch_X = batch_X.to(self.device)
            batch_y = batch_y.to(self.device)

            optimizer.zero_grad()
            loss = criterion(self.model(batch_X), batch_y)
            loss.backward()
            # Clip the global gradient norm to 1.0 before stepping.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item() * batch_X.size(0)
            n_seen += batch_X.size(0)
        return total_loss / max(n_seen, 1)

    def _validation_loss(self, X_val_t, y_val_t, criterion):
        """Return the loss on the validation tensors with the model in eval mode."""
        self.model.eval()
        with torch.no_grad():
            return criterion(self.model(X_val_t), y_val_t).item()

    def predict_proba(self, X):
        """Return class probabilities of shape (n_samples, 2).

        Column ``j`` is the probability of label ``self.classes_[j]``.

        Raises:
            NotFittedError: If called before ``fit`` (subclasses AttributeError,
                which is what an unfitted call raised previously).
        """
        if self.model is None:
            from sklearn.exceptions import NotFittedError
            raise NotFittedError(
                "This TorchTransformerClassifier instance is not fitted yet. "
                "Call 'fit' before using this estimator."
            )

        X = np.asarray(X, dtype=np.float32)
        if X.shape[0] == 0:
            return np.empty((0, len(self.classes_)), dtype=np.float32)

        self.model.eval()
        dataset = TensorDataset(torch.tensor(X, dtype=torch.float32))
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        all_proba = []
        with torch.no_grad():
            for (batch_X,) in loader:
                logits = self.model(batch_X.to(self.device))
                all_proba.append(torch.softmax(logits, dim=1).cpu().numpy())

        return np.vstack(all_proba)

    def predict(self, X):
        """Return the predicted label (taken from ``classes_``) for each row of ``X``."""
        return self.classes_[np.argmax(self.predict_proba(X), axis=1)]


def _binary_metrics(y_true, y_pred, y_score, labels=(0, 1)):
    """Compute accuracy, sensitivity, specificity, MCC and ROC-AUC.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        y_score: Score/probability for the positive class ``labels[1]``.
        labels: (negative, positive) label pair; fixes the confusion-matrix
            layout even when a class is absent from ``y_true`` or ``y_pred``.

    Returns:
        Dict with keys ``accuracy``, ``sensitivity``, ``specificity``, ``mcc``
        and ``auc``. Ratios with a zero denominator are ``nan``; ``mcc`` is
        0.0 in that case; ``auc`` is ``nan`` when ``y_true`` has one class.
    """
    # Convert to Python ints so the MCC denominator product cannot overflow.
    tn, fp, fn, tp = (
        int(v) for v in confusion_matrix(y_true, y_pred, labels=list(labels)).ravel()
    )

    def ratio(num, den):
        return num / den if den else float("nan")

    mcc_den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # roc_auc_score raises when only one class is present in y_true.
    auc = roc_auc_score(y_true, y_score) if len(np.unique(y_true)) > 1 else float("nan")

    return {
        "accuracy": ratio(tp + tn, tp + tn + fp + fn),
        "sensitivity": ratio(tp, tp + fn),
        "specificity": ratio(tn, tn + fp),
        "mcc": (tp * tn - fp * fn) / mcc_den if mcc_den else 0.0,
        "auc": auc,
    }


if __name__ == "__main__":
    # Machine-specific paths; raw strings keep Windows backslashes literal and
    # avoid invalid-escape warnings.
    os.chdir(r'G:\lhy\work')

    X_new = pd.read_csv(r'G:\lhy\work\train_feature\feature\ESMC300+PAAC+APAAC.csv', header=None)
    y_new = pd.read_csv(r'data/label_train.csv', header=None)
    X_new1 = pd.read_csv(r"G:\lhy\work\test_feature\ESMC300+PAAC+APAAC.csv", header=None)
    y_new1 = pd.read_csv(r'data/label_test.csv', header=None)

    X_train = np.array(X_new)
    y_train = np.array(y_new).ravel()
    X_test = np.array(X_new1)
    y_test = np.array(y_new1).ravel()

    # Fit the scaler on training data only to avoid test-set leakage.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    print(f" {X_train.shape}")
    print(f" {X_test.shape}")

    input_dim = X_train.shape[1]

    print("=" * 60)
    print("=" * 60)

    transformer_classifier = TorchTransformerClassifier(
        in_dim=input_dim,
        hidden_dim=32,
        num_heads=4,
        dropout=0.3,
        epochs=200,
        lr=5e-4,
        batch_size=32,
        patience=30,
        verbose=True,
        random_state=42
    )

    transformer_classifier.fit(X_train, y_train)

    # Single forward pass over the test set; hard predictions follow from it.
    classes = transformer_classifier.classes_
    y_proba = transformer_classifier.predict_proba(X_test)
    y_pred = classes[np.argmax(y_proba, axis=1)]

    metrics = _binary_metrics(y_test, y_pred, y_proba[:, 1], labels=tuple(classes))

    print("\n" + "=" * 60)
    print("=" * 60)
    print(f"Accuracy:    {metrics['accuracy']:.4f}")
    print(f"Sensitivity: {metrics['sensitivity']:.4f}")
    print(f"Specificity: {metrics['specificity']:.4f}")
    print(f"MCC:         {metrics['mcc']:.4f}")
    print(f"AUC:         {metrics['auc']:.4f}")
    print("=" * 60)
