# 🎯 Simple Attention-KAN for NASA Defect Prediction
## Clean, Focused, Explainable

**Strategy:** Simple and effective instead of complex

**Key Features:**
- ✅ **Attention-KAN**: Built-in feature importance (XAI)
- ✅ **Weighted BCE Loss**: Simple but effective (FN cost = 3x)
- ✅ **Fast GWO**: Only 3 hyperparameters
- ✅ **Heatmaps**: Visualize what the model learns

**Datasets:** PC1, CM1, KC1

---

In [None]:
# ============================================================================
# IMPORTS
# ============================================================================

import os
import glob
import warnings
import numpy as np
import pandas as pd
from scipy.io import arff
from io import StringIO

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, 
    f1_score, roc_auc_score, fbeta_score, balanced_accuracy_score,
    confusion_matrix
)
from imblearn.over_sampling import SMOTE

import matplotlib.pyplot as plt
import seaborn as sns

# Suppress every Python warning for the rest of the session and pick the seaborn grid theme.
warnings.filterwarnings('ignore')
sns.set_style('whitegrid')

# Random seeds: one shared constant seeds NumPy and PyTorch (and CUDA when available).
# RANDOM_SEED is also reused by the data-splitting and SMOTE code further down.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)

# Report the environment the notebook is running in.
print("[INFO] Dependencies loaded!")
print(f"[INFO] PyTorch: {torch.__version__}")
print(f"[INFO] Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

In [None]:
# ============================================================================
# ATTENTION-KAN ARCHITECTURE
# ============================================================================

class FeatureAttention(nn.Module):
    """Per-sample feature gate that doubles as the model's importance signal.

    A small bottleneck MLP ending in a sigmoid produces one weight per input
    feature; the (un-normalised) input is then scaled element-wise by it.
    """

    def __init__(self, in_features):
        super().__init__()
        # Bottleneck is half the input width, but never narrower than 8 units.
        bottleneck = max(in_features // 2, 8)

        gate_layers = [
            nn.Linear(in_features, bottleneck),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(bottleneck, in_features),
            nn.Sigmoid(),  # squashes each weight into the 0..1 range
        ]
        self.attention = nn.Sequential(*gate_layers)
        self.bn = nn.BatchNorm1d(in_features)

    def forward(self, x):
        """Return ``(x * weights, weights)``.

        The weights are computed from the batch-normalised input, while the
        gating itself is applied to the raw ``x``.
        """
        gate = self.attention(self.bn(x))
        gated = x * gate
        return gated, gate


class KANLinear(nn.Module):
    """KAN-style layer: one learnable univariate function per input->output edge.

    Each edge function is a weighted sum of ``grid_size`` Gaussian radial
    basis functions (RBFs, not B-splines) plus a plain linear term.  The RBF
    centres start evenly spaced on [-1, 1] and are themselves trainable.

    Parameters
    ----------
    in_features : int
        Width of the input vectors.
    out_features : int
        Width of the output vectors.
    grid_size : int, default 5
        Number of RBF centres per edge.

    Attributes
    ----------
    grid : nn.Parameter, shape (out_features, in_features, grid_size)
        RBF centres.
    coef : nn.Parameter, shape (out_features, in_features, grid_size)
        RBF mixing coefficients.
    base_weight : nn.Parameter, shape (out_features, in_features)
        Weights of the linear branch (no bias).
    """

    # Denominator of the Gaussian exponent: basis = exp(-(x - c)^2 / _BASIS_SCALE).
    _BASIS_SCALE = 0.5

    def __init__(self, in_features, out_features, grid_size=5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size

        # Learnable parameters (creation order kept so seeded init is reproducible).
        centres = torch.linspace(-1, 1, grid_size).unsqueeze(0).unsqueeze(0)
        self.grid = nn.Parameter(centres.repeat(out_features, in_features, 1))
        self.coef = nn.Parameter(torch.randn(out_features, in_features, grid_size) * 0.1)
        self.base_weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)

    def forward(self, x):
        """Map ``x`` of shape (batch, in_features) to (batch, out_features).

        Memory note: the intermediate basis tensor has shape
        (batch, out_features, in_features, grid_size).
        """
        # (batch, 1, in, 1) - (1, out, in, grid) broadcasts to (batch, out, in, grid).
        offset = x.unsqueeze(1).unsqueeze(-1) - self.grid.unsqueeze(0)

        # Gaussian RBF basis; squaring already removes the sign, so no abs() is needed.
        basis = torch.exp(-(offset ** 2) / self._BASIS_SCALE)

        # Weight each basis response, then sum over grid points and input features.
        rbf_out = (basis * self.coef.unsqueeze(0)).sum(dim=-1).sum(dim=-1)

        # Linear branch.
        base_out = torch.matmul(x, self.base_weight.t())

        return rbf_out + base_out


class AttentionKAN(nn.Module):
    """Binary classifier: feature attention -> two KAN layers -> sigmoid output.

    Parameters
    ----------
    input_dim : int
        Number of input features.
    hidden_dim : int, default 64
        Width of the first KAN layer; the second layer uses ``hidden_dim // 2``.
    grid_size : int, default 5
        Number of basis functions per edge in each KAN layer.
    """

    def __init__(self, input_dim, hidden_dim=64, grid_size=5):
        super().__init__()
        half_dim = hidden_dim // 2

        # Feature attention (source of the XAI importance scores)
        self.attention = FeatureAttention(input_dim)

        # KAN layers
        self.kan1 = KANLinear(input_dim, hidden_dim, grid_size)
        self.kan2 = KANLinear(hidden_dim, half_dim, grid_size)

        # Output
        self.output = nn.Linear(half_dim, 1)

        # Batch norm & dropout
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(half_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, return_attention=False):
        """Return class-1 probabilities of shape (batch, 1).

        The output has already passed through a sigmoid, so it is a
        probability, not a logit.  With ``return_attention=True`` the
        attention weights are returned as a second element.
        """
        # Attention
        x, att_weights = self.attention(x)

        # KAN block 1: layer -> batch norm -> ReLU -> dropout
        x = self.kan1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.dropout(x)

        # KAN block 2: same ordering
        x = self.kan2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        x = self.dropout(x)

        # Output
        x = self.output(x)
        x = torch.sigmoid(x)

        if return_attention:
            return x, att_weights
        return x

    def get_feature_importance(self, X):
        """Return the mean attention weight per feature over all rows of ``X``.

        ``X`` may be a tensor or anything ``torch.as_tensor`` accepts (NumPy
        array, nested list); it is cast to float32 and moved to the model's
        device.  The computation runs in eval mode without gradients, and the
        module's previous train/eval mode is restored afterwards so calling
        this in the middle of training does not silently disable dropout and
        batch-norm updates.

        Returns
        -------
        numpy.ndarray
            One importance value per input feature.
        """
        was_training = self.training
        self.eval()
        try:
            device = next(self.parameters()).device
            X = torch.as_tensor(X, dtype=torch.float32, device=device)

            with torch.no_grad():
                _, att_weights = self.attention(X)
            importance = att_weights.cpu().numpy().mean(axis=0)
        finally:
            # Put the module back into whatever mode the caller had it in.
            self.train(was_training)

        return importance

print("[INFO] Attention-KAN architecture ready!")

In [None]:
# ============================================================================
# XAI VISUALIZATION
# ============================================================================

def plot_feature_importance(model, X_data, dataset_name, top_k=15, feature_names=None):
    """Plot and save a horizontal bar chart of the most important features.

    Parameters
    ----------
    model
        Object exposing ``get_feature_importance(X)`` that returns one score
        per feature.
    X_data
        Samples passed straight to ``model.get_feature_importance``.
    dataset_name : str
        Used in the plot title and as the prefix of the saved PNG.
    top_k : int, default 15
        Maximum number of features to show.
    feature_names : sequence of str, optional
        Labels for the features, in column order.  Defaults to ``F0, F1, ...``.

    Raises
    ------
    ValueError
        If the model returns no scores, or ``feature_names`` has the wrong length.

    Side effects: writes ``{dataset_name}_importance.png`` to the working
    directory and shows the figure.
    """
    importance = np.asarray(model.get_feature_importance(X_data))
    if importance.size == 0:
        raise ValueError("Model returned no feature importances")

    if feature_names is None:
        feature_names = [f'F{i}' for i in range(len(importance))]
    elif len(feature_names) != len(importance):
        raise ValueError(
            f"feature_names has {len(feature_names)} entries "
            f"but the model reports {len(importance)} features"
        )

    # Sort by importance, highest first, and keep at most top_k entries.
    sorted_idx = np.argsort(importance)[::-1][:top_k]
    top_importance = importance[sorted_idx]
    top_names = [feature_names[i] for i in sorted_idx]
    shown = len(top_importance)  # may be smaller than top_k

    # Plot
    fig, ax = plt.subplots(figsize=(10, 6))

    # Colour bars relative to the strongest feature; avoid 0/0 if all scores are zero.
    peak = top_importance.max()
    color_scale = top_importance / peak if peak > 0 else np.zeros_like(top_importance)
    colors = plt.cm.viridis(color_scale)
    ax.barh(range(shown), top_importance, color=colors)

    ax.set_yticks(range(shown))
    ax.set_yticklabels(top_names)
    ax.set_xlabel('Attention Weight', fontsize=12, fontweight='bold')
    # Report the number of bars actually drawn rather than the requested top_k.
    ax.set_title(f'{dataset_name}: Top {shown} Important Features', fontsize=14, fontweight='bold')
    ax.invert_yaxis()  # most important feature at the top
    ax.grid(axis='x', alpha=0.3)

    # Add values next to each bar
    for i, v in enumerate(top_importance):
        ax.text(v + 0.01, i, f'{v:.3f}', va='center', fontsize=9)

    plt.tight_layout()
    plt.savefig(f'{dataset_name}_importance.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"[INFO] Saved: {dataset_name}_importance.png")

print("[INFO] XAI visualization ready!")

In [None]:
# ============================================================================
# DATA LOADING
# ============================================================================

def _decode_if_bytes(value):
    """Turn a ``bytes`` cell into ``str``; leave every other value untouched."""
    if isinstance(value, bytes):
        return value.decode('utf-8', errors='replace')
    return value


def _load_arff_data_section(file_path, parse_error):
    """Fallback reader: parse everything after ``@data`` as header-less CSV.

    ``%`` comment lines are skipped and ``?`` cells (ARFF's missing-value
    marker) become NaN.  Raises ``ValueError`` chained to *parse_error* when
    the file has no ``@data`` marker at all.
    """
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        content = f.read()

    marker = '@data'
    data_start = content.lower().find(marker)
    if data_start == -1:
        # Without the marker there is nothing sensible to parse.
        raise ValueError(f"No @data section found in {file_path}") from parse_error

    data_section = content[data_start + len(marker):].strip()
    return pd.read_csv(
        StringIO(data_section),
        header=None,
        comment='%',
        na_values='?',
        skipinitialspace=True,
    )


def load_arff(file_path):
    """Load an ARFF file into a DataFrame.

    ``scipy.io.arff.loadarff`` is tried first; nominal columns it returns as
    ``bytes`` are decoded to ``str``.  If SciPy cannot parse the file, the
    ``@data`` section is read as plain CSV instead, in which case columns are
    numbered rather than named.

    Raises
    ------
    ValueError
        If SciPy fails and the file contains no ``@data`` section.
    OSError
        If the file cannot be opened.
    """
    try:
        data, _meta = arff.loadarff(file_path)
    except Exception as parse_error:  # SciPy raises several unrelated error types here
        return _load_arff_data_section(file_path, parse_error)

    df = pd.DataFrame(data)
    for col in df.columns:
        if df[col].dtype == object:
            df[col] = df[col].map(_decode_if_bytes)
    return df


def prepare_data(df):
    """Turn a raw dataset frame into scaled, resampled train/test arrays.

    The last column of ``df`` is the label; all other columns are features
    and must be convertible to float32.

    Steps: encode labels -> stratified 80/20 split -> median-impute NaNs ->
    min-max scale -> SMOTE on the training part only.  Imputation and scaling
    statistics come from the training split alone, so no information from the
    test rows leaks into preprocessing.

    Returns
    -------
    X_train, X_test, y_train, y_test
        ``X_train``/``y_train`` include synthetic SMOTE samples when
        resampling succeeds.
    """
    X = df.iloc[:, :-1].values.astype(np.float32)
    y = df.iloc[:, -1].values

    # Encode labels. LabelEncoder assigns integer codes in sorted order of the
    # class names (presumably 'N'/'false' -> 0 and 'Y'/'true' -> 1; verify for new datasets).
    if y.dtype == object or y.dtype.name.startswith('str'):
        y = LabelEncoder().fit_transform(y)
    else:
        y = y.astype(np.int32)

    # Split first, so every statistic below is fitted on training rows only.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=RANDOM_SEED
    )

    # Handle NaN: fill with the per-column median of the training split.
    if np.isnan(X_train).any() or np.isnan(X_test).any():
        col_median = np.nanmedian(X_train, axis=0)
        # A column that is entirely NaN in training has no median; fall back to 0.
        col_median = np.nan_to_num(col_median, nan=0.0)
        for part in (X_train, X_test):
            rows, cols = np.where(np.isnan(part))
            part[rows, cols] = col_median[cols]

    # Normalize
    scaler = MinMaxScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # SMOTE (simple): best effort, fall back to the unbalanced data on failure.
    try:
        smote = SMOTE(sampling_strategy=0.8, random_state=RANDOM_SEED)
        X_train, y_train = smote.fit_resample(X_train, y_train)
        print(f"[INFO] After SMOTE: {X_train.shape[0]} samples, {np.bincount(y_train)}")
    except Exception as exc:
        # Report the reason instead of hiding it.
        print(f"[WARNING] SMOTE failed ({exc}), using original data")

    return X_train, X_test, y_train, y_test

print("[INFO] Data loading ready!")

In [None]:
# ============================================================================
# TRAINING (SIMPLE WEIGHTED BCE)
# ============================================================================

def _forward_logits(model, x):
    """Run the model's forward pipeline but stop before the final sigmoid.

    The layer order (KAN -> batch norm -> ReLU -> dropout, twice, then the
    output layer) mirrors ``AttentionKAN.forward`` exactly, so the network
    optimised here is the same one used at inference time.
    """
    h, _ = model.attention(x)
    h = model.dropout(torch.relu(model.bn1(model.kan1(h))))
    h = model.dropout(torch.relu(model.bn2(model.kan2(h))))
    return model.output(h)


def train_model(model, X_train, y_train, X_val, y_val,
                lr=0.01, epochs=30, batch_size=32, fn_weight=3.0):
    """Train with a positive-class-weighted BCE loss and recall-based early stopping.

    Parameters
    ----------
    model
        An ``AttentionKAN``-like module (needs ``attention``, ``kan1``,
        ``bn1``, ``kan2``, ``bn2``, ``dropout`` and ``output`` attributes).
    X_train, y_train, X_val, y_val
        Array-likes accepted by ``torch.FloatTensor``; labels are 0/1.
    lr : float
        Adam learning rate.
    epochs : int
        Maximum number of epochs.
    batch_size : int
        Mini-batch size.
    fn_weight : float
        ``pos_weight`` of the loss, i.e. how much more a missed positive
        costs than a false alarm.

    Returns
    -------
    The model (moved to CUDA when available) in its state after the last
    epoch that was run.

    NOTE(review): early stopping tracks the best validation recall but the
    best weights are not restored; the printed "best" value may therefore
    belong to an earlier epoch than the returned weights.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    X_train_t = torch.FloatTensor(X_train).to(device)
    y_train_t = torch.FloatTensor(y_train).unsqueeze(1).to(device)
    X_val_t = torch.FloatTensor(X_val).to(device)

    dataset = TensorDataset(X_train_t, y_train_t)
    # BatchNorm cannot normalise a single-sample batch in training mode, so
    # drop the trailing batch when it would contain exactly one sample.
    n_samples = len(dataset)
    drop_last = n_samples > batch_size and n_samples % batch_size == 1
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)

    # Weighted BCE Loss (FN cost = fn_weight)
    pos_weight = torch.tensor([fn_weight], dtype=torch.float32, device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    print(f"  [TRAINING] Weighted BCE Loss (FN weight={fn_weight})")

    best_recall = 0
    patience = 10
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        for batch_X, batch_y in loader:
            optimizer.zero_grad()
            # model(...) returns probabilities; BCEWithLogitsLoss needs raw logits.
            logits = _forward_logits(model, batch_X)
            loss = criterion(logits, batch_y)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        with torch.no_grad():
            val_out = model(X_val_t)
            val_pred = (val_out > 0.5).float().cpu().numpy().ravel()
            val_recall = recall_score(y_val, val_pred, zero_division=0)

        if val_recall > best_recall:
            best_recall = val_recall
            patience_counter = 0
        else:
            patience_counter += 1

        # Stop after `patience` consecutive epochs without a new best recall.
        if patience_counter >= patience:
            break

    print(f"  [TRAINING] Best val recall: {best_recall:.4f}")
    return model


def evaluate(model, X_test, y_test):
    """Score a trained model on held-out data and print its confusion matrix.

    Predictions are thresholded at 0.5 on the model's output probabilities.
    Inputs are moved to whichever device the model's parameters live on.

    Returns
    -------
    dict
        Keys ``Accuracy``, ``Precision``, ``Recall``, ``F1``, ``F2`` and
        ``AUC``.  ``AUC`` is 0 when ``y_test`` contains a single class.
    """
    model.eval()

    # Follow the model's own device instead of guessing from CUDA availability.
    try:
        device = next(model.parameters()).device
    except StopIteration:  # parameter-free model
        device = torch.device('cpu')

    X_test_t = torch.FloatTensor(X_test).to(device)

    with torch.no_grad():
        y_prob = model(X_test_t).cpu().numpy().ravel()
    y_pred = (y_prob > 0.5).astype(int)

    # Fixing the label set keeps the matrix 2x2 even if one class is absent.
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
    print(f"\n  [CONFUSION MATRIX]")
    print(f"  TN: {cm[0,0]}, FP: {cm[0,1]}")
    print(f"  FN: {cm[1,0]}, TP: {cm[1,1]}")

    return {
        'Accuracy': accuracy_score(y_test, y_pred),
        'Precision': precision_score(y_test, y_pred, zero_division=0),
        'Recall': recall_score(y_test, y_pred, zero_division=0),
        'F1': f1_score(y_test, y_pred, zero_division=0),
        'F2': fbeta_score(y_test, y_pred, beta=2, zero_division=0),
        'AUC': roc_auc_score(y_test, y_prob) if len(np.unique(y_test)) > 1 else 0
    }

print("[INFO] Training functions ready!")

In [None]:
# ============================================================================
# SIMPLE GWO (3 PARAMETERS ONLY)
# ============================================================================

class SimpleGWO:
    """Small Grey Wolf Optimizer that *maximises* a fitness function.

    Parameters
    ----------
    bounds : sequence of (low, high)
        Search interval for each dimension.
    fitness_func : callable
        Receives one position (1-D array of length ``len(bounds)``) and
        returns a score; larger is better.
    n_wolves : int, default 5
        Pack size.
    n_iter : int, default 6
        Number of evaluate-then-move iterations.

    Randomness comes from NumPy's global generator, so seed it with
    ``np.random.seed`` for repeatable runs.
    """

    def __init__(self, bounds, fitness_func, n_wolves=5, n_iter=6):
        self.bounds = np.array(bounds, dtype=float)
        self.fitness_func = fitness_func
        self.n_wolves = n_wolves
        self.n_iter = n_iter
        self.dim = len(bounds)

        # Initialize wolves uniformly inside the bounds
        self.positions = np.random.uniform(
            self.bounds[:, 0],
            self.bounds[:, 1],
            size=(n_wolves, self.dim)
        )

        # Three best solutions seen so far; -inf marks "not found yet".
        self.alpha_pos = np.zeros(self.dim)
        self.alpha_score = float('-inf')
        self.beta_pos = np.zeros(self.dim)
        self.beta_score = float('-inf')
        self.delta_pos = np.zeros(self.dim)
        self.delta_score = float('-inf')

    def _update_leaders(self, fitness, position):
        """Insert a scored position into the alpha/beta/delta ranking."""
        if fitness > self.alpha_score:
            self.delta_score, self.delta_pos = self.beta_score, self.beta_pos.copy()
            self.beta_score, self.beta_pos = self.alpha_score, self.alpha_pos.copy()
            self.alpha_score, self.alpha_pos = fitness, position.copy()
        elif fitness > self.beta_score:
            self.delta_score, self.delta_pos = self.beta_score, self.beta_pos.copy()
            self.beta_score, self.beta_pos = fitness, position.copy()
        elif fitness > self.delta_score:
            self.delta_score, self.delta_pos = fitness, position.copy()

    def _leader_positions(self):
        """Return (alpha, beta, delta) positions, or ``None`` if no leader exists.

        A leader that has not been found yet (fewer than three scored wolves)
        falls back to the next-better one instead of its zero-vector
        placeholder, which would otherwise pull the pack toward the origin.
        """
        unset = float('-inf')
        if self.alpha_score == unset:
            return None
        beta = self.beta_pos if self.beta_score != unset else self.alpha_pos
        delta = self.delta_pos if self.delta_score != unset else beta
        return self.alpha_pos, beta, delta

    def _move_pack(self, a):
        """Move every wolf toward the three leaders (standard GWO update)."""
        low, high = self.bounds[:, 0], self.bounds[:, 1]
        leaders = self._leader_positions()
        if leaders is None:
            # Nothing learned yet (e.g. all scores were NaN): restart randomly.
            self.positions = np.random.uniform(low, high, size=(self.n_wolves, self.dim))
            return

        # One (r1, r2) pair per wolf, dimension and leader. The array is filled
        # in the same order a per-element loop would draw the numbers.
        draws = np.random.random((self.n_wolves, self.dim, 3, 2))
        candidates = []
        for k, leader in enumerate(leaders):
            A = 2 * a * draws[:, :, k, 0] - a
            C = 2 * draws[:, :, k, 1]
            D = np.abs(C * leader - self.positions)
            candidates.append(leader - A * D)

        x1, x2, x3 = candidates
        self.positions = np.clip((x1 + x2 + x3) / 3.0, low, high)

    def optimize(self):
        """Run the search and return ``(best_position, best_score)``."""
        print(f"  [GWO] {self.n_wolves} wolves, {self.n_iter} iterations")

        for it in range(self.n_iter):
            # Evaluate fitness
            for i in range(self.n_wolves):
                fitness = self.fitness_func(self.positions[i])
                self._update_leaders(fitness, self.positions[i])

            # Update positions; `a` decays linearly from 2 toward 0,
            # shrinking the step range from exploration to exploitation.
            a = 2 - it * (2.0 / self.n_iter)
            self._move_pack(a)

            print(f"  Iter {it+1}/{self.n_iter} | Best: {self.alpha_score:.4f}")

        return self.alpha_pos, self.alpha_score

print("[INFO] Simple GWO ready!")

In [None]:
# ============================================================================
# MAIN EXECUTION - 3 DATASETS
# ============================================================================

def _find_dataset_files(dataset_dir, targets):
    """Return sorted ``*.arff`` paths in *dataset_dir* whose file name contains a target id."""
    candidates = sorted(glob.glob(os.path.join(dataset_dir, '*.arff')))
    return [
        path for path in candidates
        if any(ds in os.path.basename(path).upper() for ds in targets)
    ]


def _build_fitness(input_dim, X_train, y_train, X_val, y_val):
    """Create the GWO objective: train one candidate model and score it on validation data."""

    def fitness(params):
        hidden_dim = int(params[0])
        grid_size = int(params[1])
        lr = float(params[2])

        try:
            model = AttentionKAN(input_dim, hidden_dim, grid_size)
            model = train_model(model, X_train, y_train, X_val, y_val,
                                lr=lr, epochs=20, fn_weight=3.0)
            metrics = evaluate(model, X_val, y_val)
        except Exception as exc:
            # A failed candidate still gets the worst score, but the reason is
            # reported so a systematic failure is not mistaken for a bad setting.
            print(f"  [WARNING] Candidate hidden={hidden_dim}, grid={grid_size}, "
                  f"lr={lr:.4f} failed: {exc}")
            return 0.0

        # Fitness: 60% Recall + 30% F1 + 10% Acc
        return 0.6 * metrics['Recall'] + 0.3 * metrics['F1'] + 0.1 * metrics['Accuracy']

    return fitness


def _run_single_dataset(file_path, dataset_name):
    """Run load -> GWO search -> final training -> evaluation -> plot for one dataset.

    Returns the result row (dataset info, chosen hyperparameters, test metrics).
    """
    # Load data
    print("[1/5] Loading data...")
    df = load_arff(file_path)
    X_train, X_test, y_train, y_test = prepare_data(df)

    input_dim = X_train.shape[1]
    print(f"[INFO] Features: {input_dim}, Train: {len(y_train)}, Test: {len(y_test)}")

    # Validation split
    # NOTE(review): prepare_data has already applied SMOTE to the training
    # data, so this validation set contains synthetic samples; validation
    # scores are therefore likely optimistic. Confirm whether that is intended.
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.2, stratify=y_train, random_state=RANDOM_SEED
    )

    # GWO optimization (3 params only: hidden_dim, grid_size, lr)
    print("\n[2/5] GWO optimization...")
    bounds = [
        (32, 96),      # hidden_dim
        (3, 7),        # grid_size
        (0.005, 0.02)  # learning_rate
    ]
    fitness = _build_fitness(input_dim, X_train, y_train, X_val, y_val)
    gwo = SimpleGWO(bounds, fitness, n_wolves=5, n_iter=6)
    best_params, best_score = gwo.optimize()

    # Positions are continuous; the two integer hyperparameters are truncated.
    hidden_dim = int(best_params[0])
    grid_size = int(best_params[1])
    lr = float(best_params[2])

    print(f"\n  [GWO] Best params: hidden={hidden_dim}, grid={grid_size}, lr={lr:.4f}")
    print(f"  [GWO] Best score: {best_score:.4f}")

    # Train final model
    print("\n[3/5] Training final model...")
    model = AttentionKAN(input_dim, hidden_dim, grid_size)
    model = train_model(model, X_train, y_train, X_val, y_val,
                        lr=lr, epochs=30, fn_weight=3.0)

    # Evaluate
    print("\n[4/5] Evaluating...")
    metrics = evaluate(model, X_test, y_test)

    print(f"\n  [RESULTS]")
    for k, v in metrics.items():
        print(f"  {k}: {v:.4f}")

    # Visualize feature importance
    print("\n[5/5] Creating heatmap...")
    plot_feature_importance(model, X_test, dataset_name, top_k=15)

    return {
        'Dataset': dataset_name,
        'Features': input_dim,
        'Hidden_Dim': hidden_dim,
        'Grid_Size': grid_size,
        'Learning_Rate': lr,
        **metrics
    }


def _summarize_results(results):
    """Build the results table and append an ``AVERAGE`` row.

    The average row always carries every metric column (NaN when no dataset
    produced that metric), so downstream code can index it unconditionally.
    """
    results_df = pd.DataFrame(results)

    avg_row = {'Dataset': 'AVERAGE'}
    for col in ['Accuracy', 'Precision', 'Recall', 'F1', 'F2', 'AUC']:
        avg_row[col] = results_df[col].mean() if col in results_df.columns else float('nan')

    return pd.concat([results_df, pd.DataFrame([avg_row])], ignore_index=True)


def run_experiment(dataset_dir='/content/drive/MyDrive/nasa-defect-gwo-kan/dataset'):
    """Run the full pipeline on the PC1, CM1 and KC1 datasets found in *dataset_dir*.

    Files are processed in sorted order so repeated runs visit datasets (and
    consume random numbers) in the same sequence.  A dataset that fails is
    reported with its traceback and skipped; the remaining ones still run.

    Returns
    -------
    pandas.DataFrame
        One row per successfully processed dataset plus a final ``AVERAGE`` row.

    Raises
    ------
    FileNotFoundError
        If no matching ``.arff`` file exists in *dataset_dir*.
    """
    # Find datasets
    files = _find_dataset_files(dataset_dir, ('PC1', 'CM1', 'KC1'))

    if not files:
        raise FileNotFoundError(f"Datasets not found in {dataset_dir}")

    print(f"\n[INFO] Found {len(files)} datasets\n")

    results = []

    for file_path in files:
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]

        print("="*70)
        print(f"DATASET: {dataset_name}")
        print("="*70)

        try:
            results.append(_run_single_dataset(file_path, dataset_name))
        except Exception as e:
            print(f"\n  [ERROR] {e}")
            import traceback
            traceback.print_exc()

    return _summarize_results(results)

print("[INFO] Main execution ready!")

In [None]:
# ============================================================================
# RUN!
# ============================================================================

# Banner. The emoji in the messages below were stored as mis-decoded bytes
# (mojibake) and are restored here to the characters those bytes encode.
print("\n" + "="*70)
print(" 🚀 SIMPLE ATTENTION-KAN - NASA DEFECT PREDICTION")
print("="*70)
print("\n📋 APPROACH:")
print("  ✅ Attention-KAN (built-in XAI)")
print("  ✅ Weighted BCE Loss (FN cost=3x)")
print("  ✅ Fast GWO (3 hyperparameters)")
print("  ✅ Feature importance heatmaps")
print("\n" + "="*70 + "\n")

# Run experiment (`results` is reused by the comparison-plot cell below)
results = run_experiment(
    dataset_dir='/content/drive/MyDrive/nasa-defect-gwo-kan/dataset'
)

# Display results
print("\n" + "="*70)
print(" 📊 FINAL RESULTS")
print("="*70)
print(results.to_string(index=False))

# Save (pandas needs an Excel writer engine such as openpyxl for .xlsx output)
results.to_excel('simple_attention_kan_results.xlsx', index=False)
print("\n[INFO] Results saved: simple_attention_kan_results.xlsx")

# Summary: the AVERAGE row is the one appended by run_experiment
print("\n" + "="*70)
print(" 🎯 AVERAGE METRICS")
print("="*70)
avg = results[results['Dataset'] == 'AVERAGE'].iloc[0]
print(f"\n  Accuracy:  {avg['Accuracy']:.4f}")
print(f"  Precision: {avg['Precision']:.4f}")
print(f"  Recall:    {avg['Recall']:.4f} ⭐")
print(f"  F1-Score:  {avg['F1']:.4f}")
print(f"  F2-Score:  {avg['F2']:.4f}")
print(f"  AUC:       {avg['AUC']:.4f}")

print("\n" + "="*70)
print(" ✅ COMPLETE!")
print("="*70)

In [None]:
# ============================================================================
# VISUALIZATION - COMPARISON
# ============================================================================

# One panel per metric in a 2x3 grid, comparing the datasets side by side.
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
fig.suptitle('Simple Attention-KAN Results', fontsize=16, fontweight='bold')

metrics = ['Accuracy', 'Precision', 'Recall', 'F1', 'F2', 'AUC']
colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12', '#9b59b6', '#1abc9c']

# The AVERAGE row is a summary, not a dataset, so it is left out of the bars.
plot_data = results[results['Dataset'] != 'AVERAGE'].copy()

for idx, (metric, color) in enumerate(zip(metrics, colors)):
    ax = axes[idx // 3, idx % 3]

    if metric in plot_data.columns:
        ax.barh(plot_data['Dataset'], plot_data[metric], color=color, alpha=0.7)
        ax.set_xlabel(metric, fontsize=11, fontweight='bold')
        ax.set_xlim(0, 1)
        ax.grid(axis='x', alpha=0.3)

        # Recall is the primary metric, so its panel is highlighted.
        # (The title emoji was stored as mojibake and is restored here.)
        if metric == 'Recall':
            ax.set_facecolor('#ffe6e6')
            ax.set_title('⭐ PRIMARY ⭐', fontsize=10, color='red')

plt.tight_layout()
plt.savefig('simple_results_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("[INFO] Saved: simple_results_comparison.png")