In [None]:
%pip install optuna
%pip install sparsemax # https://pypi.org/project/sparsemax/

Collecting optuna
  Downloading optuna-4.2.1-py3-none-any.whl.metadata (17 kB)
Collecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.15.1-py3-none-any.whl.metadata (7.2 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.9.0-py3-none-any.whl.metadata (10 kB)
Collecting Mako (from alembic>=1.5.0->optuna)
  Downloading Mako-1.3.9-py3-none-any.whl.metadata (2.9 kB)
Downloading optuna-4.2.1-py3-none-any.whl (383 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m383.6/383.6 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading alembic-1.15.1-py3-none-any.whl (231 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m231.8/231.8 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading colorlog-6.9.0-py3-none-any.whl (11 kB)
Downloading Mako-1.3.9-py3-none-any.whl (78 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.5/78.5 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: Ma

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sparsemax import Sparsemax
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.datasets import fetch_openml
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from scipy.stats import spearmanr
import matplotlib.pyplot as plt
from tqdm import tqdm
import optuna
import json
import os
from urllib.request import urlretrieve
import zipfile

# Prefer the GPU when one is available; data tensors and models below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shared sparsemax activation, normalising over the last dimension.
sparsemax = Sparsemax(dim=-1)

class TabularDataset(Dataset):
    """Map-style dataset over continuous tabular features.

    Features are held as a float32 tensor and labels (when supplied) as an
    int64 tensor. Without labels, indexing yields only the feature row, so
    the same class can serve inference-only data.
    """

    def __init__(self, X, y=None):
        self.X = torch.tensor(X, dtype=torch.float32)
        # Labels are optional; keep None as the "unlabelled" marker.
        self.y = None if y is None else torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        features = self.X[idx]
        if self.y is None:
            return features
        return features, self.y[idx]

'''
The code below has been adapted from the original codebase.

For the implementation of the FT Transformer, please check out this repository: https://github.com/yandex-research/rtdl-revisiting-models

For the implementation of the Piecewise Linear Embedding, please check out: https://github.com/yandex-research/rtdl-num-embeddings
'''

class LinearEmbedding(nn.Module):
    """Linear embedding for continuous features.

    Every feature owns an independent ``nn.Linear(1, d_token)`` that maps its
    scalar value to a ``d_token``-dimensional token.

    Args:
        num_features: number of per-feature embedding layers to create.
        d_token: size of each produced token.
    """
    def __init__(self, num_features, d_token):
        super().__init__()
        self.d_token = d_token
        self.embeddings = nn.ModuleList([
            nn.Linear(1, d_token) for _ in range(num_features)
        ])

    def forward(self, x):
        """Embed ``x`` of shape (batch_size, num_features) into (batch_size, num_features, d_token)."""
        batch_size = x.shape[0]
        num_features = x.shape[1]

        # Each column is embedded by its own layer. The (batch, d_token) output is
        # kept as-is: squeezing it would drop the token axis when d_token == 1.
        tokens = [
            self.embeddings[i](x[:, i].unsqueeze(-1))
            for i in range(num_features)
        ]

        if not tokens:
            # No feature columns: return an empty token sequence of the right rank.
            return x.new_zeros((batch_size, 0, self.d_token))

        # Stacking (instead of writing into a pre-allocated float32 buffer) keeps
        # the dtype of the layer outputs.
        return torch.stack(tokens, dim=1)  # (batch_size, num_features, d_token)

class PiecewiseLinearEmbedding(nn.Module):
    """Bin-indicator embedding for continuous features.

    For each feature the scalar value is compared against that feature's
    sorted boundaries, giving a ``num_bins``-long 0/1 vector: entry ``k`` is 0
    when the value is strictly greater than boundary ``k`` and 1 otherwise;
    the last entry is always 1. A per-feature ``nn.Linear(num_bins, d_token)``
    then maps that vector to a token.

    NOTE: the boundaries only enter through hard comparisons, so this forward
    pass sends no gradient to ``bin_boundaries`` even though it is registered
    as a parameter.

    Args:
        num_features: number of input features (one embedding layer each).
        d_token: size of each produced token.
        num_bins: length of the indicator vector per feature.
    """
    def __init__(self, num_features, d_token, num_bins=20):
        super().__init__()
        self.num_features = num_features
        self.d_token = d_token
        self.num_bins = num_bins

        # One linear map per feature, from bin indicators to a token
        self.embeddings = nn.ModuleList([
            nn.Linear(num_bins, d_token) for _ in range(num_features)
        ])

        # num_bins - 1 randomly initialised boundaries per feature
        self.bin_boundaries = nn.Parameter(torch.randn(num_features, num_bins-1))

    def forward(self, x):
        """Embed ``x`` of shape (batch_size, num_features) into (batch_size, num_features, d_token)."""
        batch_size = x.shape[0]
        # Follow the module's own dtype so the layer also works after .double()/.half().
        dtype = self.bin_boundaries.dtype

        # Sort every feature's boundaries once, outside the per-feature loop.
        sorted_boundaries = torch.sort(self.bin_boundaries, dim=1).values  # (num_features, num_bins-1)

        # The final bin has no boundary of its own and is always active.
        last_bin = torch.ones((batch_size, 1), dtype=dtype, device=x.device)

        tokens = []
        for i in range(self.num_features):
            feature_values = x[:, i].unsqueeze(1)  # (batch_size, 1)

            # Entry k is switched off exactly when value > boundary k. This is the
            # closed form of the original step-by-step shifting loop.
            not_above = ~(feature_values > sorted_boundaries[i])  # (batch_size, num_bins-1)
            bin_activations = torch.cat([not_above.to(dtype), last_bin], dim=1)

            tokens.append(self.embeddings[i](bin_activations))  # (batch_size, d_token)

        if not tokens:
            return torch.zeros((batch_size, 0, self.d_token), dtype=dtype, device=x.device)

        return torch.stack(tokens, dim=1)  # (batch_size, num_features, d_token)

# Custom attention module to capture attention weights
class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head attention that keeps its last attention map.

    After every forward pass the softmax weights are available on
    ``self.attention_weights`` with shape (batch_size, num_heads, q_len, k_len).
    The stored copy is detached so that keeping it around for inspection does
    not hold on to the autograd graph of the last training step.

    Args:
        d_model: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Ensure d_model is divisible by num_heads
        assert self.head_dim * num_heads == d_model, "d_model must be divisible by num_heads"

        # Linear projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Detached attention map of the most recent forward pass (None until then)
        self.attention_weights = None

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, head_dim)."""
        return projected.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, attn_mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query, key, value: tensors of shape (batch_size, seq_len, d_model).
            attn_mask: optional tensor broadcastable to the score shape;
                positions where it equals 0 are masked out.

        Returns:
            Tensor of shape (batch_size, q_len, d_model).
        """
        batch_size = query.shape[0]

        q = self._split_heads(self.q_proj(query), batch_size)
        k = self._split_heads(self.k_proj(key), batch_size)
        v = self._split_heads(self.v_proj(value), batch_size)

        # Scaled dot-product scores: (batch_size, num_heads, q_len, k_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Masked positions get a very negative score so softmax sends them to ~0
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, -1e9)

        attention_weights = F.softmax(scores, dim=-1)
        # Keep a graph-free copy for later inspection; the live tensor is still
        # used below so gradients flow as before.
        self.attention_weights = attention_weights.detach()

        out = torch.matmul(attention_weights, v)  # (batch_size, num_heads, q_len, head_dim)

        # Merge heads back into a single d_model axis
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(out)

# Custom transformer layer to capture attention weights
class TransformerEncoderLayer(nn.Module):
    """Encoder block: self-attention followed by a ReLU feed-forward network.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    The attention module is the custom ``MultiHeadAttention`` so its weights
    can be read back from ``self.self_attn`` after a forward pass.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super().__init__()
        # Attention sub-layer
        self.self_attn = MultiHeadAttention(d_model, nhead)

        # Position-wise feed-forward sub-layer (expand, dropout, project back)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One normalisation per sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # One residual-branch dropout per sub-layer
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None):
        # Attention branch with residual connection, then normalise
        attended = self.self_attn(src, src, src, attn_mask=src_mask)
        after_attn = self.norm1(src + self.dropout1(attended))

        # Feed-forward branch with residual connection, then normalise
        hidden = self.dropout(F.relu(self.linear1(after_attn)))
        projected = self.linear2(hidden)
        return self.norm2(after_attn + self.dropout2(projected))

class FTTransformer(nn.Module):
    """FT-Transformer style classifier over continuous features (softmax attention).

    Features are tokenised, offset by a learned per-feature embedding, prefixed
    with a learned CLS token, passed through the encoder stack, and classified
    from the CLS position.

    Args:
        num_features: number of input features.
        d_token: token width.
        num_heads: attention heads per layer.
        num_layers: number of encoder layers.
        d_ffn: hidden width of each feed-forward sub-layer.
        dropout: dropout probability used inside the encoder layers.
        embedding_type: 'linear' or 'piecewise'.
        n_bins: bin count forwarded to the piecewise tokenizer.
        num_classes: number of output logits.

    Raises:
        ValueError: if ``embedding_type`` is not one of the supported names.
    """

    def __init__(self, num_features, d_token=64, num_heads=8, num_layers=2,
                 d_ffn=128, dropout=0.1, embedding_type='linear', n_bins=20, num_classes=100):
        super().__init__()
        self.d_token = d_token
        self.num_features = num_features
        self.embedding_type = embedding_type
        self.num_classes = num_classes

        # Learned CLS token, shared across the batch
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_token))

        # Pick the feature tokenizer
        if embedding_type == 'piecewise':
            tokenizer = PiecewiseLinearEmbedding(num_features, d_token, num_bins=n_bins)
        elif embedding_type == 'linear':
            tokenizer = LinearEmbedding(num_features, d_token)
        else:
            raise ValueError(f"Unknown embedding type: {embedding_type}")
        self.feature_tokenizer = tokenizer

        # Learned per-feature offset added to every feature token
        self.feature_pos_embedding = nn.Parameter(torch.randn(1, num_features, d_token))

        # Encoder stack built from the custom layers that expose attention weights
        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(
                TransformerEncoderLayer(d_model=d_token, nhead=num_heads,
                                        dim_feedforward=d_ffn, dropout=dropout)
            )
        self.transformer_layers = nn.ModuleList(encoder_layers)

        # Classification head applied to the CLS position
        self.output_layer = nn.Linear(d_token, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch_size, num_classes) for ``x`` of shape (batch_size, num_features)."""
        # (batch_size, num_features, d_token), with the per-feature offset applied
        feature_tokens = self.feature_tokenizer(x) + self.feature_pos_embedding

        # Prepend one CLS token per sample -> (batch_size, num_features+1, d_token)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        hidden = torch.cat([cls_tokens, feature_tokens], dim=1)

        for encoder_layer in self.transformer_layers:
            hidden = encoder_layer(hidden)

        # Position 0 holds the CLS token
        return self.output_layer(hidden[:, 0])

    def get_cls_attention(self):
        """Return CLS-to-feature attention, averaged over heads and then over layers.

        Uses the weights cached by the most recent forward pass; the result has
        shape (batch_size, num_features).

        Raises:
            ValueError: if any layer has no cached weights yet.
        """
        per_layer = []
        for encoder_layer in self.transformer_layers:
            # Cached shape: (batch_size, num_heads, seq_len, seq_len)
            weights = encoder_layer.self_attn.attention_weights
            if weights is None:
                raise ValueError("Attention weights not available. Run forward first.")
            # Row 0 is the CLS query; columns 1: are the feature keys
            per_layer.append(weights[:, :, 0, 1:].mean(dim=1))

        return torch.stack(per_layer).mean(dim=0)

# Sparse attention variants
class sparseMultiHeadAttention(nn.Module):
    """Multi-head attention that normalises scores with sparsemax instead of softmax.

    After every forward pass the attention weights are available on
    ``self.attention_weights`` with shape (batch_size, num_heads, q_len, k_len).
    The stored copy is detached so that keeping it around for inspection does
    not hold on to the autograd graph of the last training step.

    Args:
        d_model: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Ensure d_model is divisible by num_heads
        assert self.head_dim * num_heads == d_model, "d_model must be divisible by num_heads"

        # Linear projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Detached attention map of the most recent forward pass (None until then)
        self.attention_weights = None

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, head_dim)."""
        return projected.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, attn_mask=None):
        """Attend from ``query`` to ``key``/``value`` using sparsemax weights.

        Args:
            query, key, value: tensors of shape (batch_size, seq_len, d_model).
            attn_mask: optional tensor broadcastable to the score shape;
                positions where it equals 0 are masked out.

        Returns:
            Tensor of shape (batch_size, q_len, d_model).
        """
        batch_size = query.shape[0]

        q = self._split_heads(self.q_proj(query), batch_size)
        k = self._split_heads(self.k_proj(key), batch_size)
        v = self._split_heads(self.v_proj(value), batch_size)

        # Scaled dot-product scores: (batch_size, num_heads, q_len, k_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Masked positions get a very negative score before normalisation
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, -1e9)

        # Module-level sparsemax, applied over the key axis
        attention_weights = sparsemax(scores)
        # Keep a graph-free copy for later inspection; the live tensor is still
        # used below so gradients flow as before.
        self.attention_weights = attention_weights.detach()

        out = torch.matmul(attention_weights, v)  # (batch_size, num_heads, q_len, head_dim)

        # Merge heads back into a single d_model axis
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(out)

class sparseTransformerEncoderLayer(nn.Module):
    """Encoder block whose attention sub-layer is ``sparseMultiHeadAttention``.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``; the
    feed-forward sub-layer is Linear -> ReLU -> Dropout -> Linear.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super().__init__()
        # Sparse attention sub-layer
        self.self_attn = sparseMultiHeadAttention(d_model, nhead)

        # Position-wise feed-forward sub-layer (expand, dropout, project back)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One normalisation per sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # One residual-branch dropout per sub-layer
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None):
        # Attention branch with residual connection, then normalise
        attended = self.self_attn(src, src, src, attn_mask=src_mask)
        after_attn = self.norm1(src + self.dropout1(attended))

        # Feed-forward branch with residual connection, then normalise
        hidden = self.dropout(F.relu(self.linear1(after_attn)))
        projected = self.linear2(hidden)
        return self.norm2(after_attn + self.dropout2(projected))

class sparseFTTransformer(nn.Module):
    """FT-Transformer style classifier whose encoder layers use sparsemax attention.

    Features are tokenised, offset by a learned per-feature embedding, prefixed
    with a learned CLS token, passed through the sparse encoder stack, and
    classified from the CLS position.

    Args:
        num_features: number of input features.
        d_token: token width.
        num_heads: attention heads per layer.
        num_layers: number of encoder layers.
        d_ffn: hidden width of each feed-forward sub-layer.
        dropout: dropout probability used inside the encoder layers.
        embedding_type: 'linear' or 'piecewise'.
        n_bins: bin count forwarded to the piecewise tokenizer.
        num_classes: number of output logits.

    Raises:
        ValueError: if ``embedding_type`` is not one of the supported names.
    """

    def __init__(self, num_features, d_token=64, num_heads=8, num_layers=2,
                 d_ffn=128, dropout=0.1, embedding_type='linear', n_bins=20, num_classes=100):
        super().__init__()
        self.d_token = d_token
        self.num_features = num_features
        self.embedding_type = embedding_type
        self.num_classes = num_classes

        # Learned CLS token, shared across the batch
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_token))

        # Pick the feature tokenizer
        if embedding_type == 'piecewise':
            tokenizer = PiecewiseLinearEmbedding(num_features, d_token, num_bins=n_bins)
        elif embedding_type == 'linear':
            tokenizer = LinearEmbedding(num_features, d_token)
        else:
            raise ValueError(f"Unknown embedding type: {embedding_type}")
        self.feature_tokenizer = tokenizer

        # Learned per-feature offset added to every feature token
        self.feature_pos_embedding = nn.Parameter(torch.randn(1, num_features, d_token))

        # Encoder stack built from the sparse-attention layers
        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(
                sparseTransformerEncoderLayer(d_model=d_token, nhead=num_heads,
                                              dim_feedforward=d_ffn, dropout=dropout)
            )
        self.transformer_layers = nn.ModuleList(encoder_layers)

        # Classification head applied to the CLS position
        self.output_layer = nn.Linear(d_token, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch_size, num_classes) for ``x`` of shape (batch_size, num_features)."""
        # (batch_size, num_features, d_token), with the per-feature offset applied
        feature_tokens = self.feature_tokenizer(x) + self.feature_pos_embedding

        # Prepend one CLS token per sample -> (batch_size, num_features+1, d_token)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        hidden = torch.cat([cls_tokens, feature_tokens], dim=1)

        for encoder_layer in self.transformer_layers:
            hidden = encoder_layer(hidden)

        # Position 0 holds the CLS token
        return self.output_layer(hidden[:, 0])

    def get_cls_attention(self):
        """Return CLS-to-feature attention, averaged over heads and then over layers.

        Uses the weights cached by the most recent forward pass; the result has
        shape (batch_size, num_features).

        Raises:
            ValueError: if any layer has no cached weights yet.
        """
        per_layer = []
        for encoder_layer in self.transformer_layers:
            # Cached shape: (batch_size, num_heads, seq_len, seq_len)
            weights = encoder_layer.self_attn.attention_weights
            if weights is None:
                raise ValueError("Attention weights not available. Run forward first.")
            # Row 0 is the CLS query; columns 1: are the feature keys
            per_layer.append(weights[:, :, 0, 1:].mean(dim=1))

        return torch.stack(per_layer).mean(dim=0)

def calculate_pfi(model, X_val, y_val, num_permutations=5):
    """Calculate Permutation Feature Importance (PFI) for multiclass classification.

    For every feature column, the column is shuffled across samples
    ``num_permutations`` times and the resulting drop in accuracy (relative to
    the unshuffled baseline) is averaged.

    Args:
        model: classifier returning logits of shape (n_samples, n_classes).
        X_val: array-like of shape (n_samples, n_features).
        y_val: array-like of integer class labels, shape (n_samples,).
        num_permutations: shuffles per feature.

    Returns:
        np.ndarray of shape (n_features,); higher means more important.

    Side effects:
        Leaves ``model`` in eval mode.
    """
    # Put the data where the model actually lives; fall back to the module-level
    # `device` only for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    X = torch.tensor(X_val, dtype=torch.float32).to(target_device)
    y = torch.tensor(y_val, dtype=torch.long).to(target_device)

    model.eval()

    def _accuracy(inputs):
        # Fraction of samples whose arg-max logit matches the label
        with torch.no_grad():
            predicted = torch.argmax(model(inputs), dim=1)
        return (predicted == y).float().mean().item()

    baseline_accuracy = _accuracy(X)

    # One working copy is reused; only the column under test is ever modified.
    X_permuted = X.clone()
    importances = []

    for feat_idx in range(X.shape[1]):
        accuracy_drops = []

        for _ in range(num_permutations):
            perm_idx = torch.randperm(X.shape[0], device=X.device)
            X_permuted[:, feat_idx] = X[perm_idx, feat_idx]

            # Feature importance is the decrease in accuracy
            accuracy_drops.append(baseline_accuracy - _accuracy(X_permuted))

        # Restore the column before moving on to the next feature
        X_permuted[:, feat_idx] = X[:, feat_idx]

        # Average over permutations (higher = more important)
        importances.append(np.mean(accuracy_drops))

    return np.array(importances)

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, otherwise evaluate.

    Returns:
        (mean per-batch loss, accuracy over all samples seen).
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0
    correct = 0
    seen = 0

    with torch.set_grad_enabled(training):
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == y_batch).sum().item()
            seen += y_batch.size(0)

    return total_loss / len(loader), correct / seen


def train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs=100, early_stopping=16):
    """Train the model with early stopping on validation loss.

    Args:
        model: module to train (moved to ``device``).
        train_loader, val_loader: iterables of (X_batch, y_batch) supporting ``len()``.
        criterion: loss callable taking (logits, labels).
        optimizer: optimizer over ``model``'s parameters.
        device: device the model and batches are moved to.
        epochs: maximum number of epochs.
        early_stopping: stop after this many consecutive epochs without a new
            best validation loss.

    Returns:
        The same ``model`` object, with the weights of the best-validation epoch loaded.
    """
    model.to(device)
    best_val_loss = float('inf')
    early_stop_counter = 0
    best_state = None

    for epoch in range(epochs):
        train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_accuracy = _run_epoch(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stop_counter = 0
            # state_dict() returns references to the live tensors, which later
            # optimizer steps overwrite in place. Clone them so this really is a
            # snapshot of the best epoch.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            early_stop_counter += 1
            if early_stop_counter >= early_stopping:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Restore the best-validation weights
    if best_state is not None:
        model.load_state_dict(best_state)
    return model

def evaluate_model(model, X_test, y_test, device):
    """Evaluate model performance for multiclass classification.

    Runs a single forward pass over all of ``X_test`` and scores the arg-max
    predictions against ``y_test``.

    Args:
        model: classifier returning logits of shape (n_samples, n_classes).
        X_test: array-like of shape (n_samples, n_features).
        y_test: array-like of true class labels.
        device: device the model and inputs are moved to.

    Returns:
        dict with 'accuracy', 'f1_macro', 'f1_weighted' and 'confusion_matrix'
        (a nested list, or a placeholder string when there are more than 10 classes).
    """
    model.to(device)
    model.eval()

    # Only the features go to the device; the labels are consumed on the host
    # by the metric functions below.
    X = torch.tensor(X_test, dtype=torch.float32).to(device)

    with torch.no_grad():
        predicted = model(X).argmax(dim=1).cpu().numpy()

    # Calculate metrics
    accuracy = accuracy_score(y_test, predicted)
    f1_macro = f1_score(y_test, predicted, average='macro')
    f1_weighted = f1_score(y_test, predicted, average='weighted')

    # The confusion matrix can be large for many classes, so only its shape is printed
    cm = confusion_matrix(y_test, predicted)

    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"Test Macro F1 Score: {f1_macro:.4f}")
    print(f"Test Weighted F1 Score: {f1_weighted:.4f}")
    print(f"Confusion Matrix shape: {cm.shape}")

    return {
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'confusion_matrix': cm.tolist() if cm.shape[0] <= 10 else "Too large to include"
    }

def analyze_pfi_attention_correlation(model, X_val, y_val, feature_names, device,
                                      save_path='pfi_attention_correlation_helena.png'):
    """Analyze correlation between PFI and attention scores for multiclass classification.

    Computes the mean CLS-to-feature attention over ``X_val``, the permutation
    feature importance on the same data, their Spearman rank correlation, and
    saves a scatter plot of one against the other.

    Args:
        model: model exposing ``get_cls_attention()`` after a forward pass.
        X_val: array-like of shape (n_samples, n_features).
        y_val: array-like of integer class labels.
        feature_names: one label per feature; used for point annotations when
            there are at most 30 of them.
        device: device the model and inputs are moved to.
        save_path: file the scatter plot is written to.

    Returns:
        dict with 'correlation', 'p_value', 'pfi_scores', 'attention_scores'
        and 'feature_names'.
    """
    model.to(device)
    model.eval()

    # Forward pass populates the cached attention weights
    X = torch.tensor(X_val, dtype=torch.float32).to(device)
    with torch.no_grad():
        _ = model(X)
        attention_scores = model.get_cls_attention().cpu().numpy()

    # Average attention scores across samples
    avg_attention = attention_scores.mean(axis=0)

    # Calculate PFI
    pfi_scores = calculate_pfi(model, X_val, y_val)

    # Calculate Spearman rank correlation
    correlation, p_value = spearmanr(pfi_scores, avg_attention)

    print(f"Spearman Rank Correlation: {correlation:.4f} (p-value: {p_value:.4f})")

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.scatter(pfi_scores, avg_attention, alpha=0.7)

    # Add feature labels (if not too many)
    if len(feature_names) <= 30:
        for i, name in enumerate(feature_names):
            ax.annotate(name, (pfi_scores[i], avg_attention[i]),
                       textcoords="offset points", xytext=(0,10), ha='center', fontsize=8)

    # Least-squares line through the points, drawn over the sorted x range
    trend = np.poly1d(np.polyfit(pfi_scores, avg_attention, 1))
    sorted_pfi = np.sort(pfi_scores)
    ax.plot(sorted_pfi, trend(sorted_pfi), "r--", alpha=0.7)

    # Add correlation information
    ax.text(0.05, 0.95, f"Spearman ρ: {correlation:.4f}\np-value: {p_value:.4f}",
            transform=ax.transAxes, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    ax.set_xlabel('Permutation Feature Importance')
    ax.set_ylabel('CLS Token Attention Score')
    ax.set_title('PFI vs CLS Token Attention Correlation')

    fig.tight_layout()
    fig.savefig(save_path)
    # Close this figure explicitly rather than whichever one is current
    plt.close(fig)

    return {
        'correlation': correlation,
        'p_value': p_value,
        'pfi_scores': pfi_scores.tolist(),
        'attention_scores': avg_attention.tolist(),
        'feature_names': feature_names
    }

def load_helena_dataset():
    """
    Load the Helena dataset (anonymized dataset from Guyon et al. 2019) from OpenML.

    The data is split 70/15/15 into train/validation/test (stratified by class),
    then median-imputed and standardised using statistics fitted on the
    training split only.

    Returns:
    - X_train, X_val, X_test, y_train, y_val, y_test, feature_names, num_classes
    """
    print("Fetching Helena dataset from OpenML...")
    data = fetch_openml(name="helena", version=1, as_frame=False)
    X = data.data
    y = data.target.astype(np.int32)  # Ensure target is of integer type
    # Map labels onto 0..num_classes-1 so they are always valid class indices for
    # a num_classes-wide output layer (a no-op when they are already contiguous).
    y = LabelEncoder().fit_transform(y)
    feature_names = data.feature_names if hasattr(data, "feature_names") else [f"feature_{i+1}" for i in range(X.shape[1])]

    # Determine number of classes
    num_classes = len(np.unique(y))

    # Split the data into training, validation, and test sets
    X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp)

    # Impute missing values using median imputation (fitted on train only)
    imputer = SimpleImputer(strategy='median')
    X_train = imputer.fit_transform(X_train)
    X_val = imputer.transform(X_val)
    X_test = imputer.transform(X_test)

    # Standardize features (fitted on train only)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)
    X_test = scaler.transform(X_test)

    print(f"Dataset loaded - Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")
    print(f"Number of classes: {num_classes}")

    return X_train, X_val, X_test, y_train, y_val, y_test, feature_names, num_classes



def tune_hyperparameters(X_train, y_train, X_val, y_val, num_classes, embedding_type='linear', n_trials=20, sparse=False):
    """Tune hyperparameters using Optuna for multiclass classification.

    Each trial trains a fresh model for at most 20 epochs (with a short
    patience window and median pruning) and is scored by its best validation
    loss.

    Args:
        X_train, y_train: training features and integer labels.
        X_val, y_val: validation features and integer labels.
        num_classes: number of output classes.
        embedding_type: 'linear' or 'piecewise', forwarded to the model.
        n_trials: number of Optuna trials.
        sparse: build ``sparseFTTransformer`` instead of ``FTTransformer``.

    Returns:
        The parameter dict of the best trial. ``d_token`` in that dict is the
        sampled value, before it is rounded down to a multiple of ``num_heads``.
    """

    # Create datasets
    train_dataset = TabularDataset(X_train, y_train)
    val_dataset = TabularDataset(X_val, y_val)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8192)

    num_features = X_train.shape[1]

    # Both variants share one constructor signature, so pick the class once
    model_cls = sparseFTTransformer if sparse else FTTransformer

    def objective(trial):
        # Define hyperparameters to tune
        d_token = trial.suggest_int('d_token', 32, 128)
        num_heads = trial.suggest_int('num_heads', 2, 8)
        num_layers = trial.suggest_int('num_layers', 1, 3)
        d_ffn = trial.suggest_int('d_ffn', 64, 256)
        lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
        dropout = trial.suggest_float('dropout', 0.0, 0.5)
        n_bins = trial.suggest_int('n_bins', 2, 50)

        # Ensure d_token is divisible by num_heads
        d_token = (d_token // num_heads) * num_heads

        # Create model with trial hyperparameters
        model = model_cls(
            num_features=num_features,
            d_token=d_token,
            num_heads=num_heads,
            num_layers=num_layers,
            d_ffn=d_ffn,
            dropout=dropout,
            embedding_type=embedding_type,
            n_bins=n_bins,
            num_classes=num_classes
        )

        # Define criterion and optimizer for multiclass classification
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

        model.to(device)
        best_val_loss = float('inf')

        patience = 5
        patience_counter = 0
        num_epochs = 20

        # Short training loop for hyperparameter search
        for epoch in range(num_epochs):
            # Training
            model.train()
            for X_batch, y_batch in train_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)

                optimizer.zero_grad()
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                loss.backward()
                optimizer.step()

            # Validation: only the loss drives the search
            model.eval()
            val_loss = 0

            with torch.no_grad():
                for X_batch, y_batch in val_loader:
                    X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                    outputs = model(X_batch)
                    val_loss += criterion(outputs, y_batch).item()

            val_loss /= len(val_loader)

            # Update best validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter > patience:
                    break

            # Report the intermediate value so the pruner can stop weak trials
            trial.report(val_loss, epoch)

            if trial.should_prune():
                raise optuna.TrialPruned()

        return best_val_loss

    # Create Optuna study
    study = optuna.create_study(
        direction="minimize",
        pruner=optuna.pruners.MedianPruner( # https://optuna.readthedocs.io/en/stable/reference/generated/optuna.pruners.MedianPruner.html
            n_startup_trials=5,
            n_warmup_steps=10,
            interval_steps=2
        )
    )
    study.optimize(objective, n_trials=n_trials)

    # Print best parameters
    print("Best trial:")
    trial = study.best_trial
    print(f"  Value (validation loss): {trial.value:.4f}")
    print("  Params:")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")

    # Return best parameters
    return trial.params

def train_with_best_params(X_train, y_train, X_val, y_val, X_test, y_test, best_params,
                         num_classes, embedding_type='linear', sparse=False):
    """Fit one FT-Transformer variant using tuned hyperparameters and score it on the test split.

    Args:
        X_train, y_train: training features / labels (wrapped in a TabularDataset).
        X_val, y_val: validation features / labels handed to `train_model`.
        X_test, y_test: held-out data passed to `evaluate_model`.
        best_params: mapping with the keys 'd_token', 'num_heads', 'num_layers',
            'd_ffn', 'dropout', 'n_bins' and 'lr'.
        num_classes: number of target classes.
        embedding_type: forwarded unchanged to the model constructor.
        sparse: when truthy build `sparseFTTransformer`, otherwise `FTTransformer`.

    Returns:
        Tuple `(model, results)`: the model returned by `train_model` and whatever
        `evaluate_model` returns for the test split.
    """
    # Wrap the splits and batch them (training batches are shuffled)
    train_ds, val_ds = TabularDataset(X_train, y_train), TabularDataset(X_val, y_val)
    train_loader = DataLoader(train_ds, batch_size=512, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=8192)

    # Round d_token down to the nearest multiple of num_heads
    raw_dim, heads = best_params['d_token'], best_params['num_heads']
    token_dim = (raw_dim // heads) * heads

    # Both architectures take the same constructor arguments; only the class differs
    architecture = sparseFTTransformer if sparse else FTTransformer
    model = architecture(
        num_features=X_train.shape[1],
        d_token=token_dim,
        num_heads=best_params['num_heads'],
        num_layers=best_params['num_layers'],
        d_ffn=best_params['d_ffn'],
        dropout=best_params['dropout'],
        embedding_type=embedding_type,
        n_bins=best_params['n_bins'],
        num_classes=num_classes
    )

    # Multiclass objective and optimizer (learning rate comes from the tuned params)
    loss_fn = nn.CrossEntropyLoss()
    adamw = torch.optim.AdamW(model.parameters(), lr=best_params['lr'], weight_decay=1e-5)

    # Delegate the training loop to train_model, capped at 100 epochs
    trained_model = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=loss_fn,
        optimizer=adamw,
        device=device,
        epochs=100
    )

    # Score the trained model on the test split
    test_results = evaluate_model(trained_model, X_test, y_test, device)
    return trained_model, test_results

def visualize_all_models(results, feature_names):
    """Plot and save comparisons of every model found in `results`.

    Two PNG files are written to the current working directory:

    * 'model_comparison_helena.png': a 2x2 grid with accuracy bars, macro/weighted
      F1 bars, PFI-attention Spearman correlation bars and a text panel listing
      each model's top 10 features by PFI and by attention.
    * 'feature_importance_comparison_helena.png': one row per model showing PFI
      and attention scores for that model's top features (at most 20) ranked by PFI.

    Args:
        results: dict mapping model name -> dict holding 'accuracy', 'f1_macro',
            'f1_weighted' and 'correlation_analysis'. The latter must provide
            'correlation', 'pfi_scores' and 'attention_scores'; its
            'feature_names' entry is used when present.
        feature_names: fallback feature names for any model whose
            'correlation_analysis' carries no 'feature_names' entry.

    Returns:
        None. If `results` is empty nothing is plotted or saved.
    """
    models = list(results.keys())
    if not models:
        # plt.subplots(0, 1) below would raise; there is nothing to draw anyway.
        print("\nNo results to visualize.")
        return

    def _importance(model):
        """Return (pfi_scores, attention_scores, names) for one model."""
        analysis = results[model]['correlation_analysis']
        pfi = np.asarray(analysis['pfi_scores'])
        attn = np.asarray(analysis['attention_scores'])
        # Prefer the names stored with the analysis; fall back to the argument
        # instead of silently shadowing it.
        names = analysis.get('feature_names', feature_names)
        return pfi, attn, names

    # ---- Figure 1: metric overview (2x2 grid) ----
    fig, axs = plt.subplots(2, 2, figsize=(20, 16))

    acc_values = [results[model]['accuracy'] for model in models]
    f1_macro_values = [results[model]['f1_macro'] for model in models]
    f1_weighted_values = [results[model]['f1_weighted'] for model in models]

    # Accuracy comparison
    axs[0, 0].bar(models, acc_values)
    axs[0, 0].set_title('Accuracy Comparison (Helena)')
    axs[0, 0].set_ylabel('Accuracy')
    axs[0, 0].tick_params(axis='x', rotation=45)

    # F1 score comparison (grouped bars: macro vs weighted)
    x = np.arange(len(models))
    width = 0.35
    axs[0, 1].bar(x - width/2, f1_macro_values, width, label='F1 Macro')
    axs[0, 1].bar(x + width/2, f1_weighted_values, width, label='F1 Weighted')
    axs[0, 1].set_title('F1 Score Comparison (Helena)')
    axs[0, 1].set_xticks(x)
    axs[0, 1].set_xticklabels(models, rotation=45)
    axs[0, 1].set_ylabel('F1 Score')
    axs[0, 1].legend()

    # Correlation comparison
    correlations = [results[model]['correlation_analysis']['correlation'] for model in models]
    axs[1, 0].bar(models, correlations)
    axs[1, 0].set_title('PFI-Attention Correlation Comparison (Helena)')
    axs[1, 0].set_ylabel('Spearman Correlation')
    axs[1, 0].tick_params(axis='x', rotation=45)

    # Feature importance summary (top 10 features) rendered as text
    axs[1, 1].axis('off')
    summary_text = "Top 10 Important Features Summary:\n\n"
    for model in models:
        pfi_scores, attn_scores, names = _importance(model)

        # argsort on the negated scores gives a descending ranking
        pfi_top_features = [str(names[i]) for i in np.argsort(-pfi_scores)[:10]]
        attn_top_features = [str(names[i]) for i in np.argsort(-attn_scores)[:10]]

        summary_text += f"{model}:\n"
        summary_text += f"  Top PFI features: {', '.join(pfi_top_features)}\n"
        summary_text += f"  Top attention features: {', '.join(attn_top_features)}\n\n"

    axs[1, 1].text(0.05, 0.95, summary_text, transform=axs[1, 1].transAxes,
                 verticalalignment='top', fontsize=12)

    plt.tight_layout()
    plt.savefig('model_comparison_helena.png')
    plt.close(fig)

    # ---- Figure 2: per-model PFI vs attention for the top features ----
    # squeeze=False always yields a 2-D axes array, so a single model needs no special case.
    fig, axs = plt.subplots(len(models), 1, figsize=(14, 5 * len(models)), squeeze=False)

    for ax, model in zip(axs[:, 0], models):
        pfi_scores, attn_scores, names = _importance(model)

        # Limit to 20 features to avoid cluttering; fewer if the model has fewer
        num_top_features = min(20, len(names))
        sorted_indices = np.argsort(-pfi_scores)[:num_top_features]
        sorted_features = [names[j] for j in sorted_indices]

        x = np.arange(len(sorted_features))
        width = 0.35

        ax.bar(x - width/2, pfi_scores[sorted_indices], width, label='PFI')
        ax.bar(x + width/2, attn_scores[sorted_indices], width, label='Attention')

        ax.set_title(f'Top {num_top_features} Feature Importance (Helena): {model}')
        ax.set_ylabel('Importance Score')
        ax.set_xticks(x)
        ax.set_xticklabels(sorted_features, rotation=45, ha='right')
        ax.legend()

    plt.tight_layout()
    plt.savefig('feature_importance_comparison_helena.png')
    plt.close(fig)

    print("\nVisualizations saved as 'model_comparison_helena.png' and 'feature_importance_comparison_helena.png'")

def save_model(model, filename):
    """Write the model's state dict to `filename` with torch.save and report the path."""
    weights = model.state_dict()
    torch.save(weights, filename)
    print(f"Model saved as {filename}")

def save_results(results, filename):
    """Serialize `results` to `filename` as indented JSON.

    NumPy values anywhere in the structure are converted during encoding
    (arrays -> nested lists, NumPy scalars -> Python scalars). The `results`
    object passed in is left untouched.

    Args:
        results: JSON-compatible nested structure, optionally containing
            numpy arrays / scalars as values.
        filename: path of the JSON file to write.

    Raises:
        TypeError: if a value is neither JSON-serializable nor a NumPy array/scalar.
    """
    def _to_jsonable(value):
        # Called by the json encoder only for objects it cannot encode natively.
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    # Encode first so a serialization error cannot leave a truncated file behind.
    payload = json.dumps(results, indent=2, default=_to_jsonable)

    with open(filename, 'w') as f:
        f.write(payload)
    print(f"Results saved as {filename}")

def main_with_tuning(n_trials=20):
    """Run the Helena experiments for four FT-Transformer variants.

    For each variant (dense/sparse x linear/piecewise embedding) the function
    tunes hyperparameters, trains with the best parameters, saves the weights to
    '<key>_helena.pth', and analyzes the PFI-vs-attention correlation on the
    validation split. It then prints a comparison, writes the comparison plots
    and stores all results in 'results_helena.json'.

    Args:
        n_trials: number of trials passed to `tune_hyperparameters` for every
            variant (defaults to 20).

    Returns:
        Tuple `(models, results)` of dicts keyed by 'ft_linear_tuned',
        'ft_piecewise_tuned', 'sparse_ft_linear_tuned' and
        'sparse_ft_piecewise_tuned'. Each results entry is the dict returned by
        `train_with_best_params` plus a 'correlation_analysis' entry.
    """
    # Set random seed for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    print(f"Using device: {device}")

    # Load Helena dataset
    X_train, X_val, X_test, y_train, y_val, y_test, feature_names, num_classes = load_helena_dataset()

    # One row per variant:
    # (results key, embedding_type, sparse,
    #  model name in tuning/training headers, embedding name in those headers,
    #  name in the analysis header, label in the final comparison)
    experiments = [
        ('ft_linear_tuned', 'linear', False,
         'FT Transformer', 'Linear Embedding',
         'Linear Embedding', 'FT Transformer (Linear Tuned)'),
        ('ft_piecewise_tuned', 'piecewise', False,
         'FT Transformer', 'Piecewise Linear Embedding',
         'Piecewise Embedding', 'FT Transformer (Piecewise Tuned)'),
        ('sparse_ft_linear_tuned', 'linear', True,
         'sparse FT Transformer', 'Linear Embedding',
         'Sparse Linear Embedding', 'Sparse FT Transformer (Linear Tuned)'),
        ('sparse_ft_piecewise_tuned', 'piecewise', True,
         'sparse FT Transformer', 'Piecewise Linear Embedding',
         'Sparse Piecewise Embedding', 'Sparse FT Transformer (Piecewise Tuned)'),
    ]

    models = {}
    results = {}

    for key, embedding_type, sparse, model_name, embedding_name, analysis_name, _ in experiments:
        # Forward `sparse` only for the sparse variants so the dense runs keep
        # relying on the callees' own defaults.
        variant_kwargs = {'sparse': True} if sparse else {}

        # Tune hyperparameters
        print(f"\n=== Tuning Hyperparameters for {model_name} with {embedding_name} ===")
        best_params = tune_hyperparameters(
            X_train=X_train,
            y_train=y_train,
            X_val=X_val,
            y_val=y_val,
            num_classes=num_classes,
            embedding_type=embedding_type,
            n_trials=n_trials,
            **variant_kwargs
        )

        # Train with the best parameters and evaluate on the test split
        print(f"\n=== Training {model_name} with {embedding_name} (Tuned) ===")
        model, model_results = train_with_best_params(
            X_train=X_train,
            y_train=y_train,
            X_val=X_val,
            y_val=y_val,
            X_test=X_test,
            y_test=y_test,
            best_params=best_params,
            num_classes=num_classes,
            embedding_type=embedding_type,
            **variant_kwargs
        )

        save_model(model, f'{key}_helena.pth')

        models[key] = model
        results[key] = model_results

        # Analyze PFI and attention correlation on the validation split
        print(f"\n=== Analyzing PFI vs Attention Correlation for {analysis_name} (Tuned) ===")
        results[key]['correlation_analysis'] = analyze_pfi_attention_correlation(
            model=model,
            X_val=X_val,
            y_val=y_val,
            feature_names=feature_names,
            device=device
        )

    # Compare the results
    print("\n=== Comparison of Tuned Models ===")
    for key, *_, label in experiments:
        print(f"{label}: Accuracy={results[key]['accuracy']:.4f}, F1 Macro={results[key]['f1_macro']:.4f}")

    print("\n=== Comparison of PFI-Attention Correlations (Tuned Models) ===")
    for key, *_, label in experiments:
        correlation = results[key]['correlation_analysis']
        print(f"{label}: ρ={correlation['correlation']:.4f}, p-value={correlation['p_value']:.4f}")

    # Create visualization comparing all models
    visualize_all_models(
        results=results,
        feature_names=feature_names
    )

    save_results(results, 'results_helena.json')

    return models, results

# Entry point: run the full Helena pipeline. The (models, results) tuple returned by
# main_with_tuning() is not kept here.
if __name__ == "__main__":
    main_with_tuning()

Using device: cuda
Fetching Helena dataset from OpenML...


[I 2025-03-05 11:32:15,346] A new study created in memory with name: no-name-5bafd14a-e45a-4b5e-baeb-ba3446770ba3


Dataset loaded - Train: (45637, 27), Val: (9779, 27), Test: (9780, 27)
Number of classes: 100

=== Tuning Hyperparameters for FT Transformer with Linear Embedding ===


[I 2025-03-05 11:32:53,419] Trial 0 finished with value: 2.8719139099121094 and parameters: {'d_token': 86, 'num_heads': 6, 'num_layers': 1, 'd_ffn': 212, 'lr': 0.0070009447743713685, 'dropout': 0.4187854322123039, 'n_bins': 45}. Best is trial 0 with value: 2.8719139099121094.
[I 2025-03-05 11:33:23,069] Trial 1 finished with value: 2.761746644973755 and parameters: {'d_token': 34, 'num_heads': 6, 'num_layers': 1, 'd_ffn': 240, 'lr': 0.0009297455475909623, 'dropout': 0.1888155081719124, 'n_bins': 43}. Best is trial 1 with value: 2.761746644973755.
[I 2025-03-05 11:33:52,523] Trial 2 finished with value: 2.977734923362732 and parameters: {'d_token': 36, 'num_heads': 2, 'num_layers': 1, 'd_ffn': 234, 'lr': 0.0005302899783373215, 'dropout': 0.43453178499032824, 'n_bins': 38}. Best is trial 1 with value: 2.761746644973755.
[I 2025-03-05 11:34:29,517] Trial 3 finished with value: 2.8524181842803955 and parameters: {'d_token': 97, 'num_heads': 2, 'num_layers': 2, 'd_ffn': 129, 'lr': 0.000142

Best trial:
  Value (validation loss): 2.5874
  Params:
    d_token: 107
    num_heads: 8
    num_layers: 2
    d_ffn: 148
    lr: 0.0011307753683893033
    dropout: 0.2404640502448573
    n_bins: 22

=== Training FT Transformer with Linear Embedding (Tuned) ===
Epoch 1/100, Train Loss: 3.5731, Train Acc: 0.1939, Val Loss: 3.0334, Val Acc: 0.2850
Epoch 2/100, Train Loss: 2.9930, Train Acc: 0.2902, Val Loss: 2.8764, Val Acc: 0.3082
Epoch 3/100, Train Loss: 2.8754, Train Acc: 0.3092, Val Loss: 2.7850, Val Acc: 0.3231
Epoch 4/100, Train Loss: 2.8094, Train Acc: 0.3206, Val Loss: 2.7592, Val Acc: 0.3319
Epoch 5/100, Train Loss: 2.7664, Train Acc: 0.3286, Val Loss: 2.7141, Val Acc: 0.3389
Epoch 6/100, Train Loss: 2.7316, Train Acc: 0.3352, Val Loss: 2.6960, Val Acc: 0.3472
Epoch 7/100, Train Loss: 2.7095, Train Acc: 0.3397, Val Loss: 2.6794, Val Acc: 0.3494
Epoch 8/100, Train Loss: 2.6890, Train Acc: 0.3445, Val Loss: 2.6657, Val Acc: 0.3536
Epoch 9/100, Train Loss: 2.6665, Train Acc: 0.348

[I 2025-03-05 11:47:11,502] A new study created in memory with name: no-name-d210ac2a-074b-4096-a308-78e34ff7a722



=== Tuning Hyperparameters for FT Transformer with Piecewise Linear Embedding ===


[I 2025-03-05 11:48:47,858] Trial 0 finished with value: 2.8392539024353027 and parameters: {'d_token': 52, 'num_heads': 4, 'num_layers': 2, 'd_ffn': 86, 'lr': 0.00045537765326472305, 'dropout': 0.20300343379917768, 'n_bins': 10}. Best is trial 0 with value: 2.8392539024353027.
[I 2025-03-05 11:51:09,484] Trial 1 finished with value: 2.6562435626983643 and parameters: {'d_token': 99, 'num_heads': 8, 'num_layers': 2, 'd_ffn': 252, 'lr': 0.0005349324830889021, 'dropout': 0.11359685896906307, 'n_bins': 16}. Best is trial 1 with value: 2.6562435626983643.
[I 2025-03-05 11:52:15,488] Trial 2 finished with value: 3.02019464969635 and parameters: {'d_token': 112, 'num_heads': 4, 'num_layers': 3, 'd_ffn': 86, 'lr': 0.00021016683024414445, 'dropout': 0.226254410044479, 'n_bins': 3}. Best is trial 1 with value: 2.6562435626983643.
[I 2025-03-05 11:57:14,206] Trial 3 finished with value: 2.7028138637542725 and parameters: {'d_token': 52, 'num_heads': 6, 'num_layers': 2, 'd_ffn': 191, 'lr': 0.0079

Best trial:
  Value (validation loss): 2.6251
  Params:
    d_token: 73
    num_heads: 7
    num_layers: 3
    d_ffn: 225
    lr: 0.0011903623409526525
    dropout: 0.3135770354416086
    n_bins: 32

=== Training FT Transformer with Piecewise Linear Embedding (Tuned) ===
Epoch 1/100, Train Loss: 3.8407, Train Acc: 0.1373, Val Loss: 3.2743, Val Acc: 0.2461
Epoch 2/100, Train Loss: 3.1646, Train Acc: 0.2607, Val Loss: 2.9696, Val Acc: 0.2940
Epoch 3/100, Train Loss: 2.9857, Train Acc: 0.2906, Val Loss: 2.8882, Val Acc: 0.3074
Epoch 4/100, Train Loss: 2.8992, Train Acc: 0.3031, Val Loss: 2.8021, Val Acc: 0.3204
Epoch 5/100, Train Loss: 2.8499, Train Acc: 0.3105, Val Loss: 2.7589, Val Acc: 0.3346
Epoch 6/100, Train Loss: 2.8049, Train Acc: 0.3203, Val Loss: 2.7326, Val Acc: 0.3385
Epoch 7/100, Train Loss: 2.7764, Train Acc: 0.3264, Val Loss: 2.7154, Val Acc: 0.3425
Epoch 8/100, Train Loss: 2.7493, Train Acc: 0.3325, Val Loss: 2.7072, Val Acc: 0.3470
Epoch 9/100, Train Loss: 2.7427, Train A

[I 2025-03-05 12:53:44,074] A new study created in memory with name: no-name-1df674ea-4436-452b-b0f7-4a56a5889f0f



=== Tuning Hyperparameters for sparse FT Transformer with Linear Embedding ===


[I 2025-03-05 12:55:22,090] Trial 0 finished with value: 2.626972198486328 and parameters: {'d_token': 97, 'num_heads': 7, 'num_layers': 2, 'd_ffn': 217, 'lr': 0.0033274412966155112, 'dropout': 0.008588973825872648, 'n_bins': 39}. Best is trial 0 with value: 2.626972198486328.
[I 2025-03-05 12:56:23,439] Trial 1 finished with value: 2.7228076457977295 and parameters: {'d_token': 80, 'num_heads': 4, 'num_layers': 2, 'd_ffn': 73, 'lr': 0.004401761424535928, 'dropout': 0.1614705086902602, 'n_bins': 48}. Best is trial 0 with value: 2.626972198486328.
[I 2025-03-05 12:56:58,829] Trial 2 finished with value: 2.867244839668274 and parameters: {'d_token': 125, 'num_heads': 2, 'num_layers': 1, 'd_ffn': 152, 'lr': 0.0026458101950654496, 'dropout': 0.006188480817093001, 'n_bins': 7}. Best is trial 0 with value: 2.626972198486328.
[I 2025-03-05 12:58:51,743] Trial 3 finished with value: 2.6743968725204468 and parameters: {'d_token': 116, 'num_heads': 5, 'num_layers': 3, 'd_ffn': 114, 'lr': 0.00011

Best trial:
  Value (validation loss): 2.6270
  Params:
    d_token: 97
    num_heads: 7
    num_layers: 2
    d_ffn: 217
    lr: 0.0033274412966155112
    dropout: 0.008588973825872648
    n_bins: 39

=== Training sparse FT Transformer with Linear Embedding (Tuned) ===
Epoch 1/100, Train Loss: 3.2306, Train Acc: 0.2502, Val Loss: 2.9353, Val Acc: 0.2936
Epoch 2/100, Train Loss: 2.8608, Train Acc: 0.3101, Val Loss: 2.8091, Val Acc: 0.3161
Epoch 3/100, Train Loss: 2.7828, Train Acc: 0.3254, Val Loss: 2.8022, Val Acc: 0.3227
Epoch 4/100, Train Loss: 2.7317, Train Acc: 0.3323, Val Loss: 2.7356, Val Acc: 0.3343
Epoch 5/100, Train Loss: 2.6943, Train Acc: 0.3408, Val Loss: 2.7337, Val Acc: 0.3331
Epoch 6/100, Train Loss: 2.6644, Train Acc: 0.3458, Val Loss: 2.6930, Val Acc: 0.3369
Epoch 7/100, Train Loss: 2.6241, Train Acc: 0.3534, Val Loss: 2.6969, Val Acc: 0.3467
Epoch 8/100, Train Loss: 2.6105, Train Acc: 0.3563, Val Loss: 2.6723, Val Acc: 0.3486
Epoch 9/100, Train Loss: 2.5907, Train Ac

[I 2025-03-05 13:21:20,199] A new study created in memory with name: no-name-261a8f5a-07a0-4957-bdc2-02ab6ced1bf5



=== Tuning Hyperparameters for sparse FT Transformer with Piecewise Linear Embedding ===


[I 2025-03-05 13:24:19,225] Trial 0 finished with value: 2.731738567352295 and parameters: {'d_token': 46, 'num_heads': 6, 'num_layers': 1, 'd_ffn': 184, 'lr': 0.001141667302185147, 'dropout': 0.15978389722580766, 'n_bins': 21}. Best is trial 0 with value: 2.731738567352295.
[I 2025-03-05 13:31:32,944] Trial 1 finished with value: 2.775459408760071 and parameters: {'d_token': 68, 'num_heads': 8, 'num_layers': 3, 'd_ffn': 107, 'lr': 0.00013472188613684137, 'dropout': 0.1815983911170434, 'n_bins': 49}. Best is trial 0 with value: 2.731738567352295.
[I 2025-03-05 13:36:55,302] Trial 2 finished with value: 2.7333595752716064 and parameters: {'d_token': 59, 'num_heads': 7, 'num_layers': 3, 'd_ffn': 71, 'lr': 0.0004986413713264766, 'dropout': 0.48763975891332956, 'n_bins': 34}. Best is trial 0 with value: 2.731738567352295.
[I 2025-03-05 13:39:19,771] Trial 3 finished with value: 2.8452706336975098 and parameters: {'d_token': 59, 'num_heads': 5, 'num_layers': 1, 'd_ffn': 193, 'lr': 0.0046287

Best trial:
  Value (validation loss): 2.6223
  Params:
    d_token: 102
    num_heads: 5
    num_layers: 3
    d_ffn: 255
    lr: 0.00074498713248751
    dropout: 0.2562146574077927
    n_bins: 37

=== Training sparse FT Transformer with Piecewise Linear Embedding (Tuned) ===
Epoch 1/100, Train Loss: 3.6953, Train Acc: 0.1750, Val Loss: 3.1780, Val Acc: 0.2652
Epoch 2/100, Train Loss: 3.1186, Train Acc: 0.2737, Val Loss: 2.9484, Val Acc: 0.2997
Epoch 3/100, Train Loss: 2.9588, Train Acc: 0.2973, Val Loss: 2.8645, Val Acc: 0.3095
Epoch 4/100, Train Loss: 2.8818, Train Acc: 0.3099, Val Loss: 2.7892, Val Acc: 0.3277
Epoch 5/100, Train Loss: 2.8283, Train Acc: 0.3207, Val Loss: 2.7633, Val Acc: 0.3288
Epoch 6/100, Train Loss: 2.7915, Train Acc: 0.3254, Val Loss: 2.7494, Val Acc: 0.3393
Epoch 7/100, Train Loss: 2.7546, Train Acc: 0.3333, Val Loss: 2.7124, Val Acc: 0.3445
Epoch 8/100, Train Loss: 2.7281, Train Acc: 0.3371, Val Loss: 2.7111, Val Acc: 0.3474
Epoch 9/100, Train Loss: 2.7134, T

In [None]:
# Mount Google Drive at /content/drive so files can be copied into it.
from google.colab import drive
drive.mount('/content/drive')

# NOTE(review): `drive` is already imported above; the second import below is redundant.
import os
from google.colab import drive

def save_all_files_to_drive(folder_name='Jannis_Files'):
  """Copies every regular file in the current directory to a folder in Google Drive.

  Sub-directories are skipped (the copy is not recursive). A file that cannot be
  copied is reported and the remaining files are still processed.

  Args:
    folder_name: The name of the folder to create (if missing) under
      '/content/drive/My Drive'. Defaults to 'Jannis_Files'.
  """
  import shutil  # local import: only this helper needs it

  # Mount Google Drive
  drive.mount('/content/drive')

  # Create the folder in Google Drive
  folder_path = os.path.join('/content/drive/My Drive', folder_name)
  os.makedirs(folder_path, exist_ok=True)

  # Copy each file of the current directory to the Google Drive folder
  for file in os.listdir('.'):
    source_path = os.path.join('.', file)
    if not os.path.isfile(source_path):
      # Directories were never copied by the previous non-recursive `cp` either.
      continue
    destination_path = os.path.join(folder_path, file)
    try:
      # shutil avoids building a shell command from file names, which broke on
      # names containing quotes or `$` and hid failures.
      shutil.copy(source_path, destination_path)
    except OSError as exc:
      print(f"Could not copy {file}: {exc}")

  print(f"All files saved to Google Drive: /content/drive/My Drive/{folder_name}")

# Copy the working-directory files to Drive using the default folder name ('Jannis_Files').
# NOTE(review): the experiments in this notebook use the Helena dataset; confirm that
# this folder name is the intended destination.
save_all_files_to_drive()

Mounted at /content/drive
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
All files saved to Google Drive: /content/drive/My Drive/Jannis_Files
