# 🎯 Improved Attention-KAN - 4 Strategies Combined
## Better Accuracy + Better Recall

**Previous Results:**
- ❌ Accuracy: 45% (too low!)
- ❌ Precision: 18% (too many false alarms)
- ✅ Recall: 81% (good, but can be better)

**4 Improvements:**
1. ✅ **Threshold Optimization**: Find the best decision threshold (not just 0.5)
2. ✅ **Focal Loss**: Better handling of class imbalance
3. ✅ **pos_weight Tuning**: Optimize the false-negative (FN) penalty via GWO (4th parameter)
4. ✅ **Mini Ensemble (2 models)**: More stable predictions

**Target:**
- Recall ≥ 85%
- Accuracy ≥ 70%
- Precision ≥ 40%
- F1 ≥ 50%

---

In [None]:
# ============================================================================
# IMPORTS
# ============================================================================

import os
import glob
import warnings
import numpy as np
import pandas as pd
from scipy.io import arff
from io import StringIO

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, 
    f1_score, roc_auc_score, fbeta_score, balanced_accuracy_score,
    confusion_matrix
)
from imblearn.over_sampling import SMOTE

import matplotlib.pyplot as plt
import seaborn as sns

# Silence every Python warning for the rest of the session and pick the
# seaborn grid theme used by all plots below.
warnings.filterwarnings('ignore')
sns.set_style('whitegrid')

# Single seed shared by NumPy, PyTorch (CPU) and, when a GPU is present,
# the current CUDA device. Also reused as `random_state` by the data splits.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)

# Environment banner: library version and which device training will use.
print("[INFO] All imports ready!")
print(f"[INFO] PyTorch: {torch.__version__}")
print(f"[INFO] Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

In [None]:
# ============================================================================
# ATTENTION-KAN (same as before)
# ============================================================================

class FeatureAttention(nn.Module):
    """Per-feature gate: rescales each input column by a weight in [0, 1].

    The gate is a two-layer bottleneck MLP (Linear -> ReLU -> Dropout ->
    Linear -> Sigmoid) fed with a batch-normalised copy of the input. The
    weights are applied to the *raw* input, not to the normalised copy.

    Args:
        in_features: Number of input features (also the number of gate
            weights produced per sample).
    """

    def __init__(self, in_features):
        super().__init__()
        # Bottleneck width: half the input width, but never below 8 units.
        bottleneck = max(in_features // 2, 8)
        gate_layers = [
            nn.Linear(in_features, bottleneck),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(bottleneck, in_features),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*gate_layers)
        self.bn = nn.BatchNorm1d(in_features)

    def forward(self, x):
        """Return ``(x * weights, weights)`` for a batch ``x``."""
        gate = self.attention(self.bn(x))
        gated = x * gate
        return gated, gate


class KANLinear(nn.Module):
    """KAN-style layer: a linear map plus a learnable per-edge basis expansion.

    For every (output, input) pair the layer holds ``grid_size`` centres
    (``grid``) and matching coefficients (``coef``). Each input value is
    expanded with ``exp(-|x - centre|^2 / 0.5)`` around those centres, the
    expansions are weighted by ``coef`` and summed over centres and inputs,
    and a plain linear term ``x @ base_weight.T`` is added on top.

    Args:
        in_features: Width of the input.
        out_features: Width of the output.
        grid_size: Number of centres per edge, initialised evenly on [-1, 1].
    """

    def __init__(self, in_features, out_features, grid_size=5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size

        # Every edge starts from the same evenly spaced centres; because the
        # grid is a Parameter, the centres are trained along with the rest.
        knots = torch.linspace(-1, 1, grid_size)
        self.grid = nn.Parameter(knots.repeat(out_features, in_features, 1))
        self.coef = nn.Parameter(torch.randn(out_features, in_features, grid_size) * 0.1)
        self.base_weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)

    def forward(self, x):
        """Map a ``(batch, in_features)`` tensor to ``(batch, out_features)``."""
        # (batch, 1, in, 1) broadcast against (1, out, in, grid).
        offsets = x.unsqueeze(1).unsqueeze(-1) - self.grid.unsqueeze(0)
        basis = torch.exp(-torch.abs(offsets) ** 2 / 0.5)
        weighted = basis * self.coef.unsqueeze(0)
        # Collapse the grid axis first, then the input-feature axis.
        spline_out = weighted.sum(dim=-1).sum(dim=-1)
        base_out = torch.matmul(x, self.base_weight.t())
        return spline_out + base_out


class AttentionKAN(nn.Module):
    """Binary classifier: feature attention followed by two KAN layers.

    Pipeline: FeatureAttention -> KANLinear -> BatchNorm -> ReLU -> Dropout
    -> KANLinear -> BatchNorm -> ReLU -> Dropout -> Linear -> Sigmoid.
    The output is a probability of shape ``(batch, 1)``.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of the first KAN layer; the second uses half of it.
        grid_size: Number of basis centres per edge in each KAN layer.
    """

    def __init__(self, input_dim, hidden_dim=64, grid_size=5):
        super().__init__()
        half_dim = hidden_dim // 2
        self.attention = FeatureAttention(input_dim)
        self.kan1 = KANLinear(input_dim, hidden_dim, grid_size)
        self.kan2 = KANLinear(hidden_dim, half_dim, grid_size)
        self.output = nn.Linear(half_dim, 1)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(half_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, return_attention=False):
        """Return probabilities, plus the attention weights when requested."""
        hidden, att_weights = self.attention(x)
        # Both KAN stages share the same norm -> ReLU -> dropout tail.
        for kan_layer, norm in ((self.kan1, self.bn1), (self.kan2, self.bn2)):
            hidden = self.dropout(torch.relu(norm(kan_layer(hidden))))
        prob = torch.sigmoid(self.output(hidden))
        return (prob, att_weights) if return_attention else prob

    def get_feature_importance(self, X):
        """Return the mean attention weight per feature over the rows of ``X``.

        Switches the model to eval mode (and leaves it there). ``X`` may be a
        tensor or anything ``torch.FloatTensor`` accepts.
        """
        self.eval()
        features = X if isinstance(X, torch.Tensor) else torch.FloatTensor(X)
        # Run on whatever device the model's parameters live on.
        features = features.to(next(self.parameters()).device)
        with torch.no_grad():
            _, att_weights = self.attention(features)
        return att_weights.cpu().numpy().mean(axis=0)

print("[INFO] Attention-KAN ready!")

In [None]:
# ============================================================================
# STRATEGY 2: FOCAL LOSS
# ============================================================================

class FocalLoss(nn.Module):
    """Focal loss on probabilities, with an extra multiplier for positives.

    Per element: ``w(y) * (1 - pt)^gamma * BCE(p, y)`` where ``pt`` is the
    probability assigned to the true class and ``w(y)`` is ``alpha`` for
    positives and ``1 - alpha`` for negatives. Elements whose target equals
    1 are additionally multiplied by ``pos_weight``. The result is the mean
    over all elements.

    Args:
        alpha: Balance factor applied to the positive class.
        gamma: Focusing exponent; larger values down-weight easy examples more.
        pos_weight: Extra multiplier for positive targets (false-negative penalty).
    """

    def __init__(self, alpha=0.25, gamma=2.0, pos_weight=3.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.pos_weight = pos_weight

    def forward(self, inputs, targets):
        """Compute the loss for probabilities ``inputs`` and 0/1 ``targets``."""
        per_sample_bce = nn.functional.binary_cross_entropy(inputs, targets, reduction='none')

        # exp(-BCE) recovers the probability given to the true class:
        # p when y = 1, and 1 - p when y = 0.
        prob_true_class = torch.exp(-per_sample_bce)
        modulated = (1 - prob_true_class) ** self.gamma * per_sample_bce

        class_balance = targets * self.alpha + (1 - targets) * (1 - self.alpha)
        weighted = class_balance * modulated

        # Positives get the additional FN penalty; negatives pass through.
        weighted = torch.where(targets == 1, weighted * self.pos_weight, weighted)
        return weighted.mean()

print("[INFO] Focal Loss ready!")

In [None]:
# ============================================================================
# IMPROVED TRAINING (with loss selection)
# ============================================================================

def train_model(model, X_train, y_train, X_val, y_val,
                lr=0.01, epochs=30, batch_size=32,
                pos_weight=3.0, loss_type='focal'):
    """
    Train ``model`` with a selectable loss and recall-based early stopping.

    Args:
        model: Network that maps a float batch to probabilities of shape
            ``(batch, 1)``.
        X_train, y_train: Training features and 0/1 labels.
        X_val, y_val: Validation features and 0/1 labels (used only for the
            early-stopping signal).
        lr: Adam learning rate.
        epochs: Maximum number of epochs.
        batch_size: Mini-batch size.
        pos_weight: Multiplier applied to the loss of positive samples.
        loss_type: ``'weighted_bce'`` for per-sample weighted BCE; any other
            value selects focal loss.

    Returns:
        The trained model (moved to CUDA when available).

    NOTE: training stops after 10 epochs without a validation-recall
    improvement (recall measured at a fixed 0.5 cut-off). The weights
    returned are those of the last epoch run, not of the best-recall epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    X_train_t = torch.FloatTensor(X_train).to(device)
    y_train_t = torch.FloatTensor(y_train).unsqueeze(1).to(device)
    X_val_t = torch.FloatTensor(X_val).to(device)

    dataset = TensorDataset(X_train_t, y_train_t)
    # BatchNorm1d (used by AttentionKAN in this file) raises in training mode
    # on a batch holding a single sample, so drop a trailing batch of size 1.
    drop_last = len(dataset) > batch_size and len(dataset) % batch_size == 1
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        drop_last=drop_last)

    # Select loss function
    if loss_type == 'weighted_bce':
        def criterion(outputs, targets):
            # The weights must be applied per element *before* averaging;
            # weighting an already-averaged scalar would only rescale the
            # loss and not favour the positive class at all.
            weights = torch.where(
                targets == 1,
                torch.full_like(targets, float(pos_weight)),
                torch.ones_like(targets),
            )
            return nn.functional.binary_cross_entropy(outputs, targets, weight=weights)

        print(f"  [LOSS] Weighted BCE (pos_weight={pos_weight:.2f})")
    else:  # focal
        criterion = FocalLoss(alpha=0.25, gamma=2.0, pos_weight=pos_weight)
        print(f"  [LOSS] Focal Loss (gamma=2.0, pos_weight={pos_weight:.2f})")

    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_recall = 0
    patience = 10
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        for batch_X, batch_y in loader:
            optimizer.zero_grad()
            loss = criterion(model(batch_X), batch_y)
            loss.backward()
            optimizer.step()

        # Validation: recall at the default 0.5 decision threshold.
        model.eval()
        with torch.no_grad():
            val_prob = model(X_val_t).cpu().numpy().ravel()
        val_pred = (val_prob > 0.5).astype(int)
        val_recall = recall_score(y_val, val_pred, zero_division=0)

        if val_recall > best_recall:
            best_recall = val_recall
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            break

    print(f"  [TRAINING] Best val recall: {best_recall:.4f}")
    return model

print("[INFO] Training functions ready!")

In [None]:
# ============================================================================
# STRATEGY 1: THRESHOLD OPTIMIZATION
# ============================================================================

def find_optimal_threshold(model, X_val, y_val, target_recall=0.85):
    """
    Find the decision threshold to use, based on the validation set.

    Strategy:
      1. Among thresholds 0.10, 0.15, ..., 0.65, pick the one with the best
         F1 whose recall is at least ``target_recall``.
      2. If none qualifies, pick the threshold with the highest recall
         (the lowest threshold wins ties).
      3. If every threshold yields zero recall, fall back to 0.5.

    Args:
        model: Trained network returning probabilities of shape ``(n, 1)``.
        X_val: Validation features.
        y_val: Validation 0/1 labels.
        target_recall: Minimum recall a threshold must reach in step 1.

    Returns:
        The selected threshold.
    """
    model.eval()

    # Run on the device the model actually lives on, instead of guessing
    # from CUDA availability (which breaks for a CPU model on a GPU host).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    X_val_t = torch.FloatTensor(X_val).to(device)

    with torch.no_grad():
        y_prob = model(X_val_t).cpu().numpy().flatten()

    def _metrics_at(threshold):
        y_pred = (y_prob >= threshold).astype(int)
        return {
            'threshold': threshold,
            'recall': recall_score(y_val, y_pred, zero_division=0),
            'precision': precision_score(y_val, y_pred, zero_division=0),
            'f1': f1_score(y_val, y_pred, zero_division=0),
            'accuracy': accuracy_score(y_val, y_pred)
        }

    print(f"\n  [THRESHOLD] Finding optimal threshold (target recall ≥{target_recall})...")

    # Score every candidate once; both selection passes reuse the results.
    candidates = [_metrics_at(t) for t in np.arange(0.1, 0.7, 0.05)]

    best_metrics = None
    best_f1 = 0
    for metrics in candidates:
        if metrics['recall'] >= target_recall and metrics['f1'] > best_f1:
            best_f1 = metrics['f1']
            best_metrics = metrics

    # If no threshold achieves the target, use the one with highest recall.
    if best_metrics is None:
        best_recall = 0
        for metrics in candidates:
            if metrics['recall'] > best_recall:
                best_recall = metrics['recall']
                best_metrics = metrics

    # Degenerate case (no positive ever recovered): report the default
    # threshold rather than failing on an empty metrics dict.
    if best_metrics is None:
        best_metrics = _metrics_at(0.5)

    best_threshold = best_metrics['threshold']

    print(f"  [THRESHOLD] Optimal: {best_threshold:.2f}")
    print(f"  [THRESHOLD] Val Recall: {best_metrics['recall']:.4f}, Precision: {best_metrics['precision']:.4f}, F1: {best_metrics['f1']:.4f}")

    return best_threshold

print("[INFO] Threshold optimization ready!")

In [None]:
# ============================================================================
# EVALUATION (with custom threshold)
# ============================================================================

def evaluate(model, X_test, y_test, threshold=0.5):
    """Evaluate a single model at a custom decision threshold.

    Prints the confusion matrix and returns a metrics dict.

    Args:
        model: Trained network returning probabilities of shape ``(n, 1)``.
        X_test: Feature matrix.
        y_test: True 0/1 labels.
        threshold: Probabilities greater than or equal to this are class 1.

    Returns:
        Dict with keys ``Threshold``, ``Accuracy``, ``Precision``,
        ``Recall``, ``F1``, ``F2`` and ``AUC`` (AUC is 0 when ``y_test``
        holds a single class).
    """
    model.eval()

    # Use the model's own device so CPU models also work on CUDA hosts.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    X_test_t = torch.FloatTensor(X_test).to(device)

    with torch.no_grad():
        y_prob = model(X_test_t).cpu().numpy().flatten()
    y_pred = (y_prob >= threshold).astype(int)

    # Pin the label order: without it the matrix collapses to 1x1 when only
    # one class occurs in y_test and y_pred, and the indexing below fails.
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
    print(f"\n  [CONFUSION MATRIX]")
    print(f"  TN: {cm[0,0]}, FP: {cm[0,1]}")
    print(f"  FN: {cm[1,0]}, TP: {cm[1,1]}")

    return {
        'Threshold': threshold,
        'Accuracy': accuracy_score(y_test, y_pred),
        'Precision': precision_score(y_test, y_pred, zero_division=0),
        'Recall': recall_score(y_test, y_pred, zero_division=0),
        'F1': f1_score(y_test, y_pred, zero_division=0),
        'F2': fbeta_score(y_test, y_pred, beta=2, zero_division=0),
        'AUC': roc_auc_score(y_test, y_prob) if len(np.unique(y_test)) > 1 else 0
    }

print("[INFO] Evaluation ready!")

In [None]:
# ============================================================================
# DATA LOADING
# ============================================================================

def load_arff(file_path):
    """Load an ARFF file into a DataFrame.

    First tries ``scipy.io.arff.loadarff``; byte-string columns are decoded
    to text on a best-effort basis. If SciPy cannot parse the file, the text
    after the ``@data`` marker is parsed as header-less CSV instead, with
    ``?`` read as a missing value.

    Args:
        file_path: Path to the ``.arff`` file.

    Returns:
        A ``pandas.DataFrame``. In the CSV fallback the columns are numbered
        rather than named.

    Raises:
        ValueError: If SciPy fails and the file has no ``@data`` section.
    """
    import re

    try:
        data, _meta = arff.loadarff(file_path)
    except Exception as parse_error:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()
        marker = re.search(r'^[ \t]*@data\b', content,
                           flags=re.IGNORECASE | re.MULTILINE)
        if marker is None:
            # Without the marker there is nothing sensible to parse; slicing
            # from a failed search would silently return garbage rows.
            raise ValueError(
                f"No @data section found in {file_path!r}"
            ) from parse_error
        data_section = content[marker.end():].strip()
        return pd.read_csv(StringIO(data_section), header=None, na_values=['?'])

    df = pd.DataFrame(data)
    for col in df.columns:
        if df[col].dtype == object:
            try:
                df[col] = df[col].str.decode('utf-8')
            except Exception:
                # Best effort: leave the column untouched if it cannot be
                # decoded (e.g. it does not hold byte strings).
                pass
    return df


def prepare_data(df):
    """Turn a raw DataFrame into scaled, oversampled train/test arrays.

    The last column is the label; all other columns are features. Steps:
    label-encode non-numeric labels, stratified 80/20 split, median
    imputation, min-max scaling, then SMOTE on the training part only.

    Args:
        df: DataFrame whose last column is the class label.

    Returns:
        ``(X_train, X_test, y_train, y_test)``.

    NOTE: the imputation medians and the scaler are fitted on the training
    split only, so no statistics leak from the test rows.
    """
    X = df.iloc[:, :-1].values.astype(np.float32)
    y = df.iloc[:, -1].values

    if y.dtype == object or y.dtype.name.startswith('str'):
        y = LabelEncoder().fit_transform(y)
    else:
        y = y.astype(np.int32)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=RANDOM_SEED
    )

    if np.isnan(X_train).any() or np.isnan(X_test).any():
        col_median = np.nanmedian(X_train, axis=0)
        # A column that is entirely NaN in training has no median; use 0.
        col_median = np.where(np.isnan(col_median), 0.0, col_median).astype(X.dtype)
        X_train = np.where(np.isnan(X_train), col_median, X_train)
        X_test = np.where(np.isnan(X_test), col_median, X_test)

    scaler = MinMaxScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    try:
        smote = SMOTE(sampling_strategy=0.8, random_state=RANDOM_SEED)
        X_train, y_train = smote.fit_resample(X_train, y_train)
        print(f"[INFO] After SMOTE: {X_train.shape[0]} samples")
    except Exception as exc:
        # Oversampling is optional: continue with the original training
        # split, but say why it was skipped.
        print(f"[WARNING] SMOTE failed: {exc}")

    return X_train, X_test, y_train, y_test

print("[INFO] Data loading ready!")

In [None]:
# ============================================================================
# STRATEGY 3+4: GWO WITH pos_weight + MINI ENSEMBLE
# ============================================================================

class ImprovedGWO:
    """Grey Wolf Optimizer that maximises ``fitness_func`` inside box bounds.

    The caller in this file searches 4 parameters (hidden_dim, grid_size,
    lr, pos_weight), but the class works for any number of dimensions.

    Args:
        bounds: Sequence of ``(low, high)`` pairs, one per dimension.
        fitness_func: Callable taking a position vector and returning a
            score; higher is better.
        n_wolves: Population size.
        n_iter: Number of iterations.
    """

    def __init__(self, bounds, fitness_func, n_wolves=6, n_iter=8):
        self.bounds = np.array(bounds)
        self.fitness_func = fitness_func
        self.n_wolves = n_wolves
        self.n_iter = n_iter
        self.dim = len(bounds)

        lower, upper = self.bounds[:, 0], self.bounds[:, 1]
        self.positions = np.random.uniform(lower, upper, size=(n_wolves, self.dim))

        # The three leaders start unset: zero position, worst possible score.
        worst = float('-inf')
        self.alpha_pos, self.alpha_score = np.zeros(self.dim), worst
        self.beta_pos, self.beta_score = np.zeros(self.dim), worst
        self.delta_pos, self.delta_score = np.zeros(self.dim), worst

    def _rank(self, fitness, position):
        """Insert a candidate into the alpha/beta/delta podium if it qualifies."""
        if fitness > self.alpha_score:
            self.delta_score, self.delta_pos = self.beta_score, self.beta_pos.copy()
            self.beta_score, self.beta_pos = self.alpha_score, self.alpha_pos.copy()
            self.alpha_score, self.alpha_pos = fitness, position.copy()
        elif fitness > self.beta_score:
            self.delta_score, self.delta_pos = self.beta_score, self.beta_pos.copy()
            self.beta_score, self.beta_pos = fitness, position.copy()
        elif fitness > self.delta_score:
            self.delta_score, self.delta_pos = fitness, position.copy()

    def optimize(self):
        """Run the search and return ``(best_position, best_score)``."""
        print(f"  [GWO] {self.n_wolves} wolves, {self.n_iter} iterations, {self.dim} parameters")

        lower, upper = self.bounds[:, 0], self.bounds[:, 1]

        for it in range(self.n_iter):
            # Score every wolf and refresh the three leaders.
            for wolf in range(self.n_wolves):
                self._rank(self.fitness_func(self.positions[wolf]), self.positions[wolf])

            # Exploration factor, decreasing linearly from 2 towards 0.
            a = 2 - it * (2.0 / self.n_iter)

            # Columns: alpha, beta, delta. Shape (dim, 3).
            leaders = np.stack([self.alpha_pos, self.beta_pos, self.delta_pos], axis=1)

            for wolf in range(self.n_wolves):
                # One (r1, r2) pair per dimension and leader, drawn in the
                # order dimension -> leader -> (r1, r2).
                draws = np.random.random((self.dim, 3, 2))
                A = 2 * a * draws[:, :, 0] - a
                C = 2 * draws[:, :, 1]
                D = np.abs(C * leaders - self.positions[wolf][:, None])
                pulls = leaders - A * D
                moved = (pulls[:, 0] + pulls[:, 1] + pulls[:, 2]) / 3.0
                self.positions[wolf] = np.clip(moved, lower, upper)

            print(f"  Iter {it+1}/{self.n_iter} | Best: {self.alpha_score:.4f}")

        return self.alpha_pos, self.alpha_score


def train_mini_ensemble(X_train, y_train, X_val, y_val, input_dim, 
                       hidden_dim, grid_size, lr, pos_weight, loss_type='focal'):
    """Train two AttentionKAN models that differ only in their random seed.

    Each member is built and trained (30 epochs max) after seeding torch and
    NumPy with 42 and 123 respectively. The global seeds are put back to
    ``RANDOM_SEED`` before returning.

    Returns:
        List with the two trained models, in seed order.
    """
    print(f"\n  [ENSEMBLE] Training 2 models...")

    members = []
    for position, seed in enumerate((42, 123), start=1):
        print(f"\n  Model {position}/2 (seed={seed})")
        torch.manual_seed(seed)
        np.random.seed(seed)

        net = AttentionKAN(input_dim, hidden_dim, grid_size)
        members.append(
            train_model(net, X_train, y_train, X_val, y_val,
                        lr=lr, epochs=30, pos_weight=pos_weight, loss_type=loss_type)
        )

    # Restore the notebook-wide seed so later cells stay reproducible.
    torch.manual_seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)

    return members


def ensemble_predict(models, X_test, threshold=0.5):
    """Soft-voting prediction: average member probabilities, then threshold.

    Args:
        models: Non-empty list of trained networks, each returning
            probabilities of shape ``(n, 1)``.
        X_test: Feature matrix.
        threshold: Averaged probabilities greater than or equal to this
            are class 1.

    Returns:
        ``(y_pred, avg_prob)``: integer 0/1 predictions and the averaged
        probabilities, both 1-D arrays of length ``n``.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if len(models) == 0:
        raise ValueError("ensemble_predict needs at least one model")

    X_cpu = torch.FloatTensor(X_test)

    predictions = []
    for model in models:
        model.eval()
        # Feed each member on its own device rather than assuming every
        # model sits on CUDA whenever CUDA is available.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        with torch.no_grad():
            predictions.append(model(X_cpu.to(device)).cpu().numpy().flatten())

    # Average probabilities
    avg_prob = np.mean(predictions, axis=0)
    y_pred = (avg_prob >= threshold).astype(int)

    return y_pred, avg_prob

print("[INFO] GWO and Ensemble ready!")

In [None]:
# ============================================================================
# VISUALIZATION
# ============================================================================

def plot_feature_importance(model, X_data, dataset_name, top_k=15):
    """Draw, save and show a bar chart of the strongest attention weights.

    Features are labelled ``F<index>`` by column position. The figure is
    written to ``<dataset_name>_improved_importance.png`` in the working
    directory.

    Args:
        model: Object exposing ``get_feature_importance(X_data)``.
        X_data: Data passed through to ``get_feature_importance``.
        dataset_name: Used in the title and the output file name.
        top_k: Maximum number of features to display.
    """
    importance = model.get_feature_importance(X_data)

    # Indices of the top_k largest weights, strongest first.
    ranking = np.argsort(importance)[::-1][:top_k]
    top_importance = importance[ranking]
    top_names = [f'F{idx}' for idx in ranking]
    rows = range(len(top_importance))

    fig, ax = plt.subplots(figsize=(10, 6))
    # Colour encodes the weight relative to the strongest feature shown.
    shades = plt.cm.viridis(top_importance / top_importance.max())
    ax.barh(rows, top_importance, color=shades)
    ax.set_yticks(rows)
    ax.set_yticklabels(top_names)
    ax.set_xlabel('Attention Weight', fontsize=12, fontweight='bold')
    ax.set_title(f'{dataset_name}: Top {top_k} Important Features', fontsize=14, fontweight='bold')
    ax.invert_yaxis()
    ax.grid(axis='x', alpha=0.3)

    # Numeric label just to the right of every bar.
    for row, weight in enumerate(top_importance):
        ax.text(weight + 0.01, row, f'{weight:.3f}', va='center', fontsize=9)

    out_path = f'{dataset_name}_improved_importance.png'
    plt.tight_layout()
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"[INFO] Saved: {out_path}")

print("[INFO] Visualization ready!")

In [None]:
# ============================================================================
# MAIN EXECUTION - ALL 4 STRATEGIES COMBINED
# ============================================================================

def _select_ensemble_threshold(y_true, avg_prob, target_recall=0.85):
    """Pick a threshold for averaged probabilities.

    Prefers the best-F1 threshold whose recall reaches ``target_recall``;
    if none does, falls back to the threshold with the highest recall
    (mirroring ``find_optimal_threshold``); 0.5 if recall is always zero.
    """
    best_threshold, best_f1 = 0.5, 0
    fallback_threshold, fallback_recall = 0.5, 0

    for threshold in np.arange(0.1, 0.7, 0.05):
        y_pred = (avg_prob >= threshold).astype(int)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)

        if recall >= target_recall and f1 > best_f1:
            best_threshold, best_f1 = threshold, f1
        if recall > fallback_recall:
            fallback_threshold, fallback_recall = threshold, recall

    return best_threshold if best_f1 > 0 else fallback_threshold


def run_improved_experiment(dataset_dir='/content/drive/MyDrive/nasa-defect-gwo-kan/dataset'):
    """Run the full pipeline on the PC1, CM1 and KC1 ARFF files.

    Per dataset: load and prepare the data, hold out a validation split,
    tune (hidden_dim, grid_size, lr, pos_weight) with GWO, train a 2-model
    ensemble, choose the ensemble threshold on validation data, evaluate on
    the test split and plot feature importance. A dataset that raises is
    reported and skipped.

    Args:
        dataset_dir: Directory searched for ``*.arff`` files.

    Returns:
        DataFrame with one row per completed dataset plus a final
        ``AVERAGE`` row (which carries no metric columns if nothing
        completed).

    Raises:
        FileNotFoundError: If no matching dataset file is found.

    NOTE(review): the validation split is taken after ``prepare_data`` has
    applied SMOTE, so validation rows may include synthetic samples.
    """
    target = ['PC1', 'CM1', 'KC1']
    # Sorted so datasets are processed in a reproducible order.
    files = sorted(glob.glob(os.path.join(dataset_dir, '*.arff')))
    files = [f for f in files if any(ds in os.path.basename(f).upper() for ds in target)]

    if not files:
        raise FileNotFoundError(f"Datasets not found in {dataset_dir!r} (expected {target})")

    print(f"\n[INFO] Found {len(files)} datasets\n")
    results = []

    for file_path in files:
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]

        print("="*70)
        print(f"DATASET: {dataset_name}")
        print("="*70)

        try:
            # Load
            print("[1/6] Loading data...")
            df = load_arff(file_path)
            X_train, X_test, y_train, y_test = prepare_data(df)
            input_dim = X_train.shape[1]

            # Val split
            X_train, X_val, y_train, y_val = train_test_split(
                X_train, y_train, test_size=0.2, stratify=y_train, random_state=RANDOM_SEED
            )

            print(f"[INFO] Features: {input_dim}, Train: {len(y_train)}, Val: {len(y_val)}, Test: {len(y_test)}")

            # GWO with 4 parameters
            print("\n[2/6] GWO optimization (4 params: hidden, grid, lr, pos_weight)...")

            def fitness(params):
                # GWO positions are continuous; the two architecture
                # parameters are truncated to integers.
                hidden_dim = int(params[0])
                grid_size = int(params[1])
                lr = params[2]
                pos_weight = params[3]

                try:
                    model = AttentionKAN(input_dim, hidden_dim, grid_size)
                    model = train_model(model, X_train, y_train, X_val, y_val,
                                      lr=lr, epochs=20, pos_weight=pos_weight, loss_type='focal')

                    # Find threshold
                    threshold = find_optimal_threshold(model, X_val, y_val, target_recall=0.85)

                    # Evaluate
                    metrics = evaluate(model, X_val, y_val, threshold=threshold)

                    # Fitness: 50% Recall + 30% F1 + 20% Accuracy
                    score = 0.5 * metrics['Recall'] + 0.3 * metrics['F1'] + 0.2 * metrics['Accuracy']
                    return score
                except Exception as e:
                    # A failed candidate scores 0 so the search can go on.
                    print(f"  [WARNING] Fitness failed: {e}")
                    return 0.0

            bounds = [
                (32, 96),      # hidden_dim
                (3, 7),        # grid_size
                (0.005, 0.02), # learning_rate
                (2.0, 5.0)     # pos_weight (STRATEGY 3!)
            ]

            gwo = ImprovedGWO(bounds, fitness, n_wolves=6, n_iter=8)
            best_params, best_score = gwo.optimize()

            hidden_dim = int(best_params[0])
            grid_size = int(best_params[1])
            lr = best_params[2]
            pos_weight = best_params[3]

            print(f"\n  [GWO] Best: hidden={hidden_dim}, grid={grid_size}, lr={lr:.4f}, pos_weight={pos_weight:.2f}")
            print(f"  [GWO] Score: {best_score:.4f}")

            # Train mini ensemble (STRATEGY 4)
            print("\n[3/6] Training mini ensemble (2 models)...")
            models = train_mini_ensemble(X_train, y_train, X_val, y_val, input_dim,
                                        hidden_dim, grid_size, lr, pos_weight, loss_type='focal')

            # Find optimal threshold on ensemble (STRATEGY 1)
            print("\n[4/6] Finding optimal threshold for ensemble...")

            # Averaged validation probabilities of the ensemble members.
            _, val_avg_prob = ensemble_predict(models, X_val)
            best_threshold = _select_ensemble_threshold(y_val, val_avg_prob, target_recall=0.85)

            print(f"  [THRESHOLD] Ensemble optimal: {best_threshold:.2f}")

            # Test ensemble
            print("\n[5/6] Testing ensemble...")
            y_pred, y_prob = ensemble_predict(models, X_test, threshold=best_threshold)

            # Fixed label order keeps the matrix 2x2 even if one class is
            # missing from both the truth and the predictions.
            cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
            print(f"\n  [CONFUSION MATRIX]")
            print(f"  TN: {cm[0,0]}, FP: {cm[0,1]}")
            print(f"  FN: {cm[1,0]}, TP: {cm[1,1]}")

            metrics = {
                'Accuracy': accuracy_score(y_test, y_pred),
                'Precision': precision_score(y_test, y_pred, zero_division=0),
                'Recall': recall_score(y_test, y_pred, zero_division=0),
                'F1': f1_score(y_test, y_pred, zero_division=0),
                'F2': fbeta_score(y_test, y_pred, beta=2, zero_division=0),
                'AUC': roc_auc_score(y_test, y_prob) if len(np.unique(y_test)) > 1 else 0
            }

            print(f"\n  [RESULTS]")
            for k, v in metrics.items():
                print(f"  {k}: {v:.4f}")

            # Visualize (attention weights of the first ensemble member)
            print("\n[6/6] Creating feature importance heatmap...")
            plot_feature_importance(models[0], X_test, dataset_name, top_k=15)

            results.append({
                'Dataset': dataset_name,
                'Features': input_dim,
                'Hidden_Dim': hidden_dim,
                'Grid_Size': grid_size,
                'Learning_Rate': lr,
                'Pos_Weight': pos_weight,
                'Threshold': best_threshold,
                **metrics
            })

        except Exception as e:
            # One broken dataset must not abort the remaining ones.
            print(f"\n  [ERROR] {e}")
            import traceback
            traceback.print_exc()

    # Summary
    results_df = pd.DataFrame(results)

    avg_row = {'Dataset': 'AVERAGE'}
    for col in ['Accuracy', 'Precision', 'Recall', 'F1', 'F2', 'AUC']:
        if col in results_df.columns:
            avg_row[col] = results_df[col].mean()

    results_df = pd.concat([results_df, pd.DataFrame([avg_row])], ignore_index=True)

    return results_df

print("[INFO] Main execution ready!")

In [None]:
# ============================================================================
# RUN!
# ============================================================================

# Banner describing what this run does and the metric targets.
print("\n" + "="*70)
print(" 🚀 IMPROVED ATTENTION-KAN - 4 STRATEGIES COMBINED")
print("="*70)
print("\n📋 IMPROVEMENTS:")
print("  1️⃣ Threshold Optimization (not fixed at 0.5)")
print("  2️⃣ Focal Loss (better than Weighted BCE)")
print("  3️⃣ pos_weight GWO tuning (4th parameter)")
print("  4️⃣ Mini Ensemble (2 models, soft voting)")
print("\n🎯 TARGET:")
print("  Recall ≥ 85% | Accuracy ≥ 70% | Precision ≥ 40% | F1 ≥ 50%")
print("\n" + "="*70 + "\n")

# Run
results = run_improved_experiment(
    dataset_dir='/content/drive/MyDrive/nasa-defect-gwo-kan/dataset'
)

# Display
print("\n" + "="*70)
print(" 📊 FINAL RESULTS")
print("="*70)
print(results.to_string(index=False))

# Save
results.to_excel('improved_attention_kan_results.xlsx', index=False)
print("\n[INFO] Saved: improved_attention_kan_results.xlsx")

# Summary
print("\n" + "="*70)
print(" 🎯 AVERAGE METRICS")
print("="*70)
avg_rows = results[results['Dataset'] == 'AVERAGE']
metric_cols = ['Accuracy', 'Precision', 'Recall', 'F1', 'F2', 'AUC']
if avg_rows.empty or any(col not in results.columns for col in metric_cols):
    # Every dataset failed, so the AVERAGE row has no metric columns.
    print("\n  [WARNING] No dataset finished successfully; no average metrics to report.")
else:
    avg = avg_rows.iloc[0]
    print(f"\n  Accuracy:  {avg['Accuracy']:.4f} {'✅' if avg['Accuracy'] >= 0.70 else '❌'}")
    print(f"  Precision: {avg['Precision']:.4f} {'✅' if avg['Precision'] >= 0.40 else '❌'}")
    print(f"  Recall:    {avg['Recall']:.4f} {'✅' if avg['Recall'] >= 0.85 else '❌'} ⭐")
    print(f"  F1-Score:  {avg['F1']:.4f} {'✅' if avg['F1'] >= 0.50 else '❌'}")
    print(f"  F2-Score:  {avg['F2']:.4f}")
    print(f"  AUC:       {avg['AUC']:.4f}")

print("\n" + "="*70)
print(" ✅ COMPLETE!")
print("="*70)

In [None]:
# ============================================================================
# COMPARISON VISUALIZATION
# ============================================================================

# One horizontal bar chart per metric (2 x 3 grid), one bar per dataset,
# with a dashed line marking the target where one is defined.
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
fig.suptitle('Improved Attention-KAN Results (4 Strategies)', fontsize=16, fontweight='bold')

metrics = ['Accuracy', 'Precision', 'Recall', 'F1', 'F2', 'AUC']
colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12', '#9b59b6', '#1abc9c']

# The AVERAGE row is a summary, not a dataset, so it is left out of the bars.
plot_data = results[results['Dataset'] != 'AVERAGE'].copy()

for idx, (metric, color) in enumerate(zip(metrics, colors)):
    ax = axes[idx // 3, idx % 3]

    if metric in plot_data.columns:
        ax.barh(plot_data['Dataset'], plot_data[metric], color=color, alpha=0.7)
        ax.set_xlabel(metric, fontsize=11, fontweight='bold')
        ax.set_xlim(0, 1)
        ax.grid(axis='x', alpha=0.3)

        # Add target line
        if metric == 'Recall':
            ax.axvline(x=0.85, color='red', linestyle='--', linewidth=2, label='Target')
            ax.set_facecolor('#ffe6e6')
            ax.set_title('⭐ PRIMARY ⭐', fontsize=10, color='red')
            # Without a legend the 'Target' label above is never displayed.
            ax.legend(loc='lower right')
        elif metric == 'Accuracy':
            ax.axvline(x=0.70, color='blue', linestyle='--', linewidth=2, alpha=0.5)
        elif metric == 'Precision':
            ax.axvline(x=0.40, color='orange', linestyle='--', linewidth=2, alpha=0.5)
        elif metric == 'F1':
            ax.axvline(x=0.50, color='purple', linestyle='--', linewidth=2, alpha=0.5)

plt.tight_layout()
plt.savefig('improved_results_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("[INFO] Saved: improved_results_comparison.png")