# In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.preprocessing import StandardScaler
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import timedelta
import warnings
warnings.filterwarnings('ignore')

# === TFT COMPONENTS ===
class GatedLinearUnit(tf.keras.layers.Layer):
    """Linear projection modulated element-wise by a sigmoid gate.

    Both the projection and the gate are `Dense(units)` layers applied to
    the same input; the output is their element-wise product.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        make_dense = tf.keras.layers.Dense
        self.linear = make_dense(units)
        self.gate = make_dense(units, activation='sigmoid')

    def call(self, inputs):
        """Return `linear(inputs) * sigmoid_gate(inputs)`."""
        projected = self.linear(inputs)
        gate_values = self.gate(inputs)
        return projected * gate_values

class VariableSelectionNetwork(tf.keras.layers.Layer):
    """Learn per-time-step softmax weights over the input features.

    Each scalar feature is embedded by its own GatedResidualNetwork; a
    further GRN over the full feature vector produces softmax selection
    weights, and the embeddings are combined as a weighted sum.

    Args:
        num_features: Size of the last axis of the input.
        hidden_units: Width of every GRN and of the selected output.
        dropout_rate: Dropout rate handed to each GRN.
    """

    def __init__(self, num_features, hidden_units, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_features = num_features
        self.hidden_units = hidden_units

        # Layer components
        self.flattened_grn = GatedResidualNetwork(hidden_units, dropout_rate)
        self.single_variable_grns = [
            GatedResidualNetwork(hidden_units, dropout_rate)
            for _ in range(num_features)
        ]
        self.selection_dense = tf.keras.layers.Dense(num_features)
        self.softmax = tf.keras.layers.Softmax(axis=-1)

    def call(self, inputs, training=None):
        """Select and embed features.

        Args:
            inputs: Tensor of shape (batch, time, num_features).
            training: Optional training flag, forwarded to the GRNs so their
                dropout follows the caller's mode explicitly.

        Returns:
            Tuple `(selected, selection_weights)` where `selected` has shape
            (batch, time, hidden_units) and `selection_weights` has shape
            (batch * time, num_features).
        """
        batch_size = tf.shape(inputs)[0]
        time_steps = tf.shape(inputs)[1]

        # Treat every (batch, time) position as an independent row.
        flattened = tf.reshape(inputs, (batch_size * time_steps, self.num_features))

        # Softmax weights over features, one distribution per row.
        selection_weights = self.flattened_grn(flattened, training=training)
        selection_weights = self.selection_dense(selection_weights)
        selection_weights = self.softmax(selection_weights)

        # Embed each scalar feature with its own GRN and scale it by its
        # weight. The per-variable outputs are used directly; concatenating
        # and re-splitting them along the same axis would be a no-op.
        weighted_vars = [
            grn(flattened[:, i:i + 1], training=training) * selection_weights[:, i:i + 1]
            for i, grn in enumerate(self.single_variable_grns)
        ]
        selected = tf.add_n(weighted_vars)  # (batch*time, hidden_units)

        # Restore the (batch, time, hidden_units) layout.
        selected = tf.reshape(selected, (batch_size, time_steps, self.hidden_units))

        return selected, selection_weights

class GatedResidualNetwork(tf.keras.layers.Layer):
    """Two dense layers with dropout, a GLU, an optional skip and layer norm.

    Args:
        hidden_units: Width of both dense layers and of the GLU output.
        dropout_rate: Rate of the single Dropout layer, applied after each
            dense layer.
    """

    def __init__(self, hidden_units, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.hidden_units = hidden_units
        self.dropout_rate = dropout_rate

        layers = tf.keras.layers
        self.dense1 = layers.Dense(hidden_units, activation='relu')
        self.dense2 = layers.Dense(hidden_units)
        self.dropout = layers.Dropout(dropout_rate)
        self.glu = GatedLinearUnit(hidden_units)
        self.layer_norm = layers.LayerNormalization()

    def call(self, inputs, training=None):
        """Apply the block to `inputs` and return the normalized result."""
        hidden = self.dropout(self.dense1(inputs), training=training)
        hidden = self.dropout(self.dense2(hidden), training=training)
        gated = self.glu(hidden)

        # The skip connection is only added when the input's last dimension
        # already equals the block's output width; otherwise it is omitted.
        # NOTE(review): no projection is applied to a mismatched input, so
        # such inputs get no residual path - confirm this is intended.
        widths_match = inputs.shape[-1] == gated.shape[-1]
        combined = gated + inputs if widths_match else gated

        return self.layer_norm(combined)

class MultiHeadAttention(tf.keras.layers.Layer):
    """Self-attention with a residual connection and layer normalization.

    Wraps `tf.keras.layers.MultiHeadAttention`, using the same tensor as
    both query and value.
    """

    def __init__(self, num_heads, key_dim, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        attention_config = {
            "num_heads": num_heads,
            "key_dim": key_dim,
            "dropout": dropout_rate,
        }
        self.attention = tf.keras.layers.MultiHeadAttention(**attention_config)
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, inputs, training=None):
        """Return `LayerNorm(inputs + SelfAttention(inputs))`."""
        attended = self.attention(inputs, inputs, training=training)
        residual = inputs + attended
        return self.layer_norm(residual)

# === ENHANCED TFT MODEL WITH FEATURE IMPORTANCE EXTRACTION ===
class TFTModel(tf.keras.Model):
    """TFT-style regressor: variable selection -> LSTMs -> attention -> MLP.

    Args:
        sequence_length: Number of time steps per input sample (stored only).
        num_features: Number of input features per time step.
        feature_names: Names matching the feature axis, used to label the
            importance scores.
        hidden_units: Width of the variable selection network, the first
            LSTM and the temporal GLU; the second LSTM uses half of it.
        num_heads: Number of attention heads.
        dropout_rate: Dropout rate shared by all sub-layers.
    """

    def __init__(self, sequence_length, num_features, feature_names, hidden_units=128, num_heads=4, dropout_rate=0.2):
        super().__init__()
        self.sequence_length = sequence_length
        self.num_features = num_features
        self.feature_names = feature_names

        # Store VSN for feature importance extraction
        self.vsn = VariableSelectionNetwork(num_features, hidden_units, dropout_rate)

        # LSTM layers
        self.lstm1 = tf.keras.layers.LSTM(
            hidden_units,
            return_sequences=True,
            dropout=dropout_rate,
            recurrent_dropout=dropout_rate/2
        )

        self.lstm2 = tf.keras.layers.LSTM(
            hidden_units//2,
            return_sequences=True,
            dropout=dropout_rate,
            recurrent_dropout=dropout_rate/2
        )

        # Attention
        self.attention = MultiHeadAttention(num_heads, hidden_units//num_heads, dropout_rate)

        # Temporal processing
        self.temporal_gln = GatedLinearUnit(hidden_units)

        # Output layers
        self.dense1 = tf.keras.layers.Dense(32, activation='relu')
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dense2 = tf.keras.layers.Dense(16, activation='relu')
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.output_layer = tf.keras.layers.Dense(1, activation='linear')

    def call(self, inputs, training=None):
        """Run the forward pass and return one prediction per sample.

        Also stores the latest variable-selection weights on
        `self.feature_weights`.
        """
        # Variable Selection Network
        selected_features, self.feature_weights = self.vsn(inputs)

        # LSTM processing
        lstm1_out = self.lstm1(selected_features, training=training)
        lstm2_out = self.lstm2(lstm1_out, training=training)

        # Attention
        attention_out = self.attention(lstm2_out, training=training)

        # Temporal processing
        temporal_features = self.temporal_gln(attention_out)

        # Only the representation of the final time step feeds the head.
        last_step = temporal_features[:, -1, :]

        # Output layers
        x = self.dense1(last_step)
        x = self.dropout1(x, training=training)
        x = self.dense2(x)
        x = self.dropout2(x, training=training)
        output = self.output_layer(x)

        return output

    def get_feature_importance(self, X_sample):
        """Return the mean variable-selection weight of every feature.

        Args:
            X_sample: Batch of inputs shaped (batch, time, num_features).

        Returns:
            Dict mapping each name in `feature_names` to the selection weight
            averaged over all batch and time positions of `X_sample`.
        """
        # The weights come from the VSN alone, so evaluate just that layer in
        # inference mode rather than the full LSTM/attention stack.
        _, self.feature_weights = self.vsn(X_sample, training=False)

        # Weights are laid out as (batch*time, num_features); average the rows.
        importance = tf.reduce_mean(self.feature_weights, axis=0).numpy()

        return {
            name: float(imp) for name, imp in zip(self.feature_names, importance)
        }

def build_tft_model(sequence_length, num_features, feature_names, hidden_units=128, num_heads=4, dropout_rate=0.2):
    """Instantiate a TFTModel and create its weights with one dummy call.

    All arguments are forwarded unchanged to `TFTModel`.

    Returns:
        The model, already called once on a zero tensor of shape
        (1, sequence_length, num_features).
    """
    settings = {
        "sequence_length": sequence_length,
        "num_features": num_features,
        "feature_names": feature_names,
        "hidden_units": hidden_units,
        "num_heads": num_heads,
        "dropout_rate": dropout_rate,
    }
    tft = TFTModel(**settings)

    # A single forward pass on zeros forces every sub-layer to build.
    tft(tf.zeros((1, sequence_length, num_features)))

    return tft

# === VISUALIZATION FUNCTIONS ===
def plot_learning_curves(history, window_name):
    """Show training/validation loss and MAE per epoch, side by side.

    Args:
        history: Object whose `.history` dict holds the keys 'loss',
            'val_loss', 'mae' and 'val_mae'.
        window_name: Label prefixed to both panel titles.
    """
    plt.figure(figsize=(12, 4))

    # (history key, display label) for the left and right panel.
    panels = (('loss', 'Loss'), ('mae', 'MAE'))
    for position, (key, label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history.history[key], label=f'Training {label}', alpha=0.8)
        plt.plot(history.history[f'val_{key}'], label=f'Validation {label}', alpha=0.8)
        plt.title(f'{window_name} - Training {label}')
        plt.xlabel('Epoch')
        plt.ylabel(label)
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

def plot_feature_importance(importance_dict, window_name, top_n=15):
    """Draw a horizontal bar chart of the `top_n` highest-weighted features.

    Args:
        importance_dict: Mapping of feature name to importance score.
        window_name: Label prefixed to the chart title.
        top_n: Maximum number of features to show.
    """
    # Highest score first, truncated to the requested count.
    ranked = sorted(importance_dict.items(), key=lambda item: item[1], reverse=True)[:top_n]

    plt.figure(figsize=(10, 8))
    names, scores = zip(*ranked)
    rows = range(len(names))

    plt.barh(rows, scores, alpha=0.8)
    plt.yticks(rows, names)
    plt.xlabel('Feature Importance (VSN Weight)')
    plt.title(f'{window_name} - Top {top_n} Feature Importance')
    plt.gca().invert_yaxis()  # Most important at top

    # Print each score just to the right of its bar.
    for row, score in enumerate(scores):
        plt.text(score + 0.001, row, f'{score:.3f}', va='center', fontsize=9)

    plt.tight_layout()
    plt.show()

def print_feature_importance_ranking(all_importance):
    """Print features ranked by their mean importance across all windows.

    Args:
        all_importance: Mapping of window name to a
            `{feature name: importance}` dict.
    """
    # A dict of dicts already yields rows = features, columns = windows, so
    # it must not be transposed; otherwise windows would be ranked instead.
    importance_df = pd.DataFrame(all_importance)

    # Average over the window columns and sort, best feature first.
    importance_df['mean_importance'] = importance_df.mean(axis=1)
    sorted_df = importance_df.sort_values('mean_importance', ascending=False)

    # Print the ranked features.
    print("\n🔢 Feature Importance Ranking (avg. across windows):")
    for i, (feat, row) in enumerate(sorted_df.iterrows(), start=1):
        print(f"{i:>2}. {feat:<25} → {row['mean_importance']:.4f}")


# === SEQUENCE CREATION (SAME AS ORIGINAL) ===
def create_sequences(data, features, target, sequence_length):
    """Slice a frame into fixed-length look-back windows.

    For every row position `i >= sequence_length`, the sample consists of the
    feature rows `[i - sequence_length, i)`, the target value at row `i` and
    the index value at row `i`.

    Args:
        data: DataFrame holding the feature and target columns.
        features: List of feature column names.
        target: Name of the target column.
        sequence_length: Number of rows in each window.

    Returns:
        Tuple of arrays `(X, y, dates)`; all three are empty when `data` has
        no more than `sequence_length` rows.
    """
    feature_matrix = data[features].values
    labels = data[target].values
    index_values = data.index.values

    # Each position marks the row being predicted; its window ends just before it.
    positions = range(sequence_length, len(data))
    windows = [feature_matrix[pos - sequence_length:pos] for pos in positions]
    window_targets = [labels[pos] for pos in positions]
    window_dates = [index_values[pos] for pos in positions]

    return np.array(windows), np.array(window_targets), np.array(window_dates)

# === METRICS (SAME AS ORIGINAL) ===
def evaluate_metrics(y_true, y_pred):
    """Compute regression and directional metrics for a set of predictions.

    Returns:
        Dict with keys 'r2', 'mae', 'rmse', 'corr' (Pearson r) and
        'hit_rate' (share of samples where `np.sign` of the prediction equals
        `np.sign` of the true value).
    """
    return {
        "r2": r2_score(y_true, y_pred),
        "mae": mean_absolute_error(y_true, y_pred),
        "rmse": np.sqrt(mean_squared_error(y_true, y_pred)),
        "corr": pearsonr(y_true, y_pred)[0],
        "hit_rate": np.mean(np.sign(y_true) == np.sign(y_pred)),
    }

# === SETTINGS ===
# Column used as the prediction target.
TARGET = "future_return_7d"
# Number of past rows in each input window.
SEQUENCE_LENGTH = 30
# Upper bound on training epochs.
EPOCHS = 150
BATCH_SIZE = 32
LEARNING_RATE = 0.0001
# Days subtracted from each window's test_start to form the training cutoff.
PURGE_DAYS = 5
DROPOUT_RATE = 0.1
HIDDEN_UNITS = 128
NUM_HEADS = 2


# Input feature columns, in the order they are fed to the model.
top_features = [
    'open',
    'high',
    'low',
    'close',
    'volume',
    'rsi',
    'ema_short',
    'ema_long',
    'volatility_atr',
    'bb_width',
    'obv',
    'volume_norm',
    'macd',
    'macd_signal',
    'macd_hist',
    'return_1d',
    'return_3d',
    'return_7d',
    'adx',
    'hma_14',
    'vwap',
    'cmf',
    'sentiment_news_z',
    'sentiment_twitter_z',
]

# === WINDOWS ===
# Walk-forward evaluation windows; dates are ISO strings parsed later.
WINDOWS = [
    {
        "name": "W1",
        "train_end": "2021-09-30",
        "test_start": "2021-10-10",
        "test_end": "2021-12-31",
    },
    {
        "name": "W2",
        "train_end": "2021-11-30",
        "test_start": "2021-12-10",
        "test_end": "2022-02-28",
    },
]

# === METRICS (SAME AS ORIGINAL) ===
def evaluate_metrics(y_true, y_pred):
    """Compute regression and directional metrics for a set of predictions.

    NOTE(review): an identical `evaluate_metrics` is defined earlier in this
    file; this later definition is the one in effect at runtime, so one of
    the two copies could be removed.

    Returns:
        Dict with keys 'r2', 'mae', 'rmse', 'corr' (Pearson r) and
        'hit_rate' (share of samples where `np.sign` of the prediction equals
        `np.sign` of the true value).
    """
    return {
        "r2": r2_score(y_true, y_pred),
        "mae": mean_absolute_error(y_true, y_pred),
        "rmse": np.sqrt(mean_squared_error(y_true, y_pred)),
        "corr": pearsonr(y_true, y_pred)[0],
        "hit_rate": np.mean(np.sign(y_true) == np.sign(y_pred)),
    }

# === FEATURE IMPORTANCE RANKING FUNCTION ===
def get_feature_importance_ranking_df(all_importance):
    """Return per-feature importance for each window plus the cross-window mean.

    Args:
        all_importance: Mapping of window name to a
            `{feature name: importance}` dict.

    Returns:
        DataFrame with rows = features and columns = windows plus
        'mean_importance', sorted by 'mean_importance' descending.
    """
    # A dict of dicts maps outer keys to columns and inner keys to the index,
    # so this is already (features x windows); transposing it would average
    # across features and rank windows instead.
    df = pd.DataFrame(all_importance)
    df['mean_importance'] = df.mean(axis=1)
    return df.sort_values('mean_importance', ascending=False)




# === MAIN USAGE EXAMPLE ===
if __name__ == "__main__":

    # Load the feature table and index it by timestamp so the evaluation
    # windows can be sliced by date.
    df = pd.read_csv("/kaggle/input/final-features/BTCUSDT_features_final_sentiment.csv", parse_dates=["timestamp"])
    df.set_index("timestamp", inplace=True)

    # Rows missing any model input or the target cannot form a sample.
    df.dropna(subset=top_features + [TARGET], inplace=True)

    results = []
    all_feature_importance = {}
    window_names = []

    for w in WINDOWS:
        print(f"\n🚀 Enhanced TFT Training - {w['name']}")

        # Data splitting logic
        train_end = pd.to_datetime(w["train_end"])
        test_start = pd.to_datetime(w["test_start"])
        test_end = pd.to_datetime(w["test_end"])
        purge_start = test_start - timedelta(days=PURGE_DAYS)

        # Honour the configured train_end as well as the purge gap: training
        # stops at whichever comes first. Using purge_start alone ignored
        # train_end and let training run closer to the test period than the
        # window definition allows.
        train_cutoff = min(train_end, purge_start)
        train = df[df.index <= train_cutoff]
        test = df[(df.index >= test_start) & (df.index <= test_end)]

        # create_sequences yields len(frame) - SEQUENCE_LENGTH samples, so the
        # size check can run before scaling (StandardScaler raises on an
        # empty frame, which previously pre-empted this message).
        if len(train) - SEQUENCE_LENGTH < 50 or len(test) - SEQUENCE_LENGTH < 10:
            print("⚠️ Not enough data, skipping...")
            continue

        # Scaling: statistics come from the training rows only.
        scaler = StandardScaler()
        train_scaled = train.copy()
        test_scaled = test.copy()
        train_scaled[top_features] = scaler.fit_transform(train[top_features])
        test_scaled[top_features] = scaler.transform(test[top_features])

        # Create sequences
        X_train, y_train, _ = create_sequences(train_scaled, top_features, TARGET, SEQUENCE_LENGTH)
        X_test, y_test, test_dates = create_sequences(test_scaled, top_features, TARGET, SEQUENCE_LENGTH)

        # Scale target
        target_scaler = StandardScaler()
        y_train_scaled = target_scaler.fit_transform(y_train.reshape(-1, 1)).flatten()
        y_test_scaled = target_scaler.transform(y_test.reshape(-1, 1)).flatten()

        # Build TFT model
        tf.keras.backend.clear_session()
        model = build_tft_model(
            sequence_length=SEQUENCE_LENGTH,
            num_features=len(top_features),
            feature_names=top_features,
            hidden_units=HIDDEN_UNITS,
            num_heads=NUM_HEADS,
            dropout_rate=DROPOUT_RATE
        )

        # Compile
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
            loss='mse',
            metrics=['mae']
        )

        # Callbacks. The model is a subclassed tf.keras.Model, which cannot be
        # fully serialized to HDF5, so the checkpoint stores weights only.
        callbacks = [
            tf.keras.callbacks.EarlyStopping(patience=15, restore_best_weights=True),
            tf.keras.callbacks.ReduceLROnPlateau(patience=8, factor=0.5),
            tf.keras.callbacks.ModelCheckpoint(
                f"tft_model_{w['name']}.weights.h5",
                save_best_only=True,
                save_weights_only=True
            )
        ]

        # Training
        print("Training model...")
        history = model.fit(
            X_train, y_train_scaled,
            validation_split=0.1,
            epochs=EPOCHS,
            batch_size=BATCH_SIZE,
            verbose=1,
            callbacks=callbacks,
            shuffle=True
        )

        # Plot learning curves
        plot_learning_curves(history, w['name'])

        # Predictions, mapped back to the original target scale.
        preds_scaled = model.predict(X_test).flatten()
        preds = target_scaler.inverse_transform(preds_scaled.reshape(-1, 1)).flatten()

        # Evaluate
        metrics = evaluate_metrics(y_test, preds)
        metrics["ic"], _ = spearmanr(y_test, preds)

        # Extract feature importance
        print("Extracting feature importance...")
        sample_batch = X_test[:min(100, len(X_test))]  # Use sample for efficiency
        feature_importance = model.get_feature_importance(sample_batch)

        # Store results
        all_feature_importance[w['name']] = feature_importance
        window_names.append(w['name'])

        # Plot individual feature importance
        plot_feature_importance(feature_importance, w['name'])

        print(f"📊 Enhanced TFT Results - {w['name']}:")
        print(f"   R²: {metrics['r2']:.4f}")
        print(f"   Hit Rate: {metrics['hit_rate']:.4f}")
        print(f"   Correlation: {metrics['corr']:.4f}")
        print(f"   IC (Spearman): {metrics['ic']:.4f}")

        results.append({"window": w["name"], "model": "Enhanced TFT", **metrics})

    # === Feature Importance Summary Across Windows ===
    if all_feature_importance:
        print("\n Overall Feature Importance (Average Across All Windows):")
        print_feature_importance_ranking(all_feature_importance)

        # === Overall Feature Importance Summary Across Windows ===
        ranking_df = get_feature_importance_ranking_df(all_feature_importance)

        print("\n Overall Feature Importance (Average Across All Windows):")
        print(ranking_df.head(20))  # shows up to the first 20 rows of the ranking

        # The heading promises five entries, so list exactly five.
        print("\n Top 5 Features Across All Windows:")
        for feat, row in ranking_df.head(5).iterrows():
            print(f"   {feat}: {row['mean_importance']:.4f}")

    # Results summary
    results_df = pd.DataFrame(results)
    print("\n Enhanced TFT Summary:")
    print(results_df.round(4))

    # Feature importance summary
    print("\n🔍 Feature Importance Summary:")
    for window, importance in all_feature_importance.items():
        print(f"\n{window} - Top 5 Features:")
        sorted_features = sorted(importance.items(), key=lambda x: x[1], reverse=True)
        for feat, imp in sorted_features[:5]:
            print(f"   {feat}: {imp:.4f}")