In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import pandas as pd
from matplotlib.colors import LinearSegmentedColormap

def enhanced_attention_matrix(attention_weights, seq_length, title="Enhanced Attention Matrix", 
                              layer_idx=0, head_idx=0, sample_idx=0,
                              feature_names=None, timestamp_data=None, 
                              predicted_class=None, true_class=None,
                              cmap='viridis'):
    """
    Creates an enhanced visualization of the attention matrix with better labeling and context.
    
    Args:
        attention_weights: List of attention weight tensors from the transformer model
        seq_length: Length of the input sequence
        title: Base title for the plot
        layer_idx: Index of the transformer layer to visualize
        head_idx: Index of the attention head to visualize
        sample_idx: Index of the sample in the batch to visualize
        feature_names: List of names for the input features (for context)
        timestamp_data: Time data for the sequence (if available)
        predicted_class: Model's prediction for this sample
        true_class: Ground truth label for this sample
        cmap: Colormap for the heatmap
    
    Returns:
        Path to the saved visualization
    """
    # Create a figure with two subplots - main heatmap and a zoomed-in view
    fig = plt.figure(figsize=(16, 12))
    
    # Define a custom colormap that better highlights the attention patterns
    custom_cmap = sns.color_palette("YlGnBu", as_cmap=True)
    
    # Create a grid for the subplots
    gs = fig.add_gridspec(nrows=2, ncols=2, height_ratios=[3, 1], width_ratios=[3, 1])
    
    # Main heatmap - upper left
    ax_main = fig.add_subplot(gs[0, 0])
    
    # Extract attention matrix - handle different return formats
    if isinstance(attention_weights[layer_idx], torch.Tensor):
        attn_shape = attention_weights[layer_idx].shape
        
        if len(attn_shape) == 4:  # [batch_size, nhead, seq_len, seq_len]
            attn_matrix = attention_weights[layer_idx][sample_idx, head_idx].cpu().numpy()
        elif len(attn_shape) == 3:  # [batch_size, seq_len, seq_len]
            attn_matrix = attention_weights[layer_idx][sample_idx].cpu().numpy()
        else:
            attn_matrix = np.ones((seq_length, seq_length)) / seq_length
    else:
        try:
            attn_matrix = attention_weights[layer_idx][0][sample_idx, head_idx].cpu().numpy()
        except (TypeError, IndexError):
            attn_matrix = np.ones((seq_length, seq_length)) / seq_length
    
    # Create tick labels based on position or timestamp if available
    if timestamp_data is not None:
        x_labels = [f"t{i}: {timestamp_data[i]:.2f}s" for i in range(seq_length)]
    else:
        x_labels = [f"t{i}" for i in range(seq_length)]
    
    # Plot the main heatmap
    sns.heatmap(attn_matrix, annot=False, cmap=custom_cmap, ax=ax_main, 
                cbar=True, cbar_kws={'label': 'Attention Weight'})
    
    # Add labels and title with model prediction info
    prediction_info = ""
    if predicted_class and true_class:
        match_status = "✓" if predicted_class == true_class else "✗"
        prediction_info = f"\nPredicted: {predicted_class} | Actual: {true_class} {match_status}"
    
    ax_main.set_title(f"{title} - Layer {layer_idx+1}, Head {head_idx+1}{prediction_info}", 
                     fontsize=14, fontweight='bold')
    ax_main.set_xlabel("Target Position in Sequence", fontsize=12)
    ax_main.set_ylabel("Source Position in Sequence", fontsize=12)
    
    # Improve tick labels - use nice formatting and rotation
    ax_main.set_xticks(np.arange(len(x_labels))+0.5)
    ax_main.set_yticks(np.arange(len(x_labels))+0.5)
    ax_main.set_xticklabels(x_labels, rotation=45, ha='right', fontsize=10)
    ax_main.set_yticklabels(x_labels, rotation=0, fontsize=10)
    
    # Add a grid to make it easier to follow
    ax_main.grid(False)
    
    # Highlight the diagonal if appropriate
    for i in range(seq_length):
        ax_main.add_patch(plt.Rectangle((i, i), 1, 1, fill=False, edgecolor='black', lw=1))
    
    # Plot row-wise attention sums - right panel
    ax_row_sum = fig.add_subplot(gs[0, 1])
    row_sums = attn_matrix.sum(axis=1)
    
    # Make a horizontal barplot
    row_bars = ax_row_sum.barh(np.arange(len(row_sums)), row_sums, 
                               color=plt.cm.get_cmap(cmap)(row_sums/row_sums.max()))
    ax_row_sum.set_yticks(np.arange(len(x_labels)))
    ax_row_sum.set_yticklabels([])  # No need to duplicate the labels
    ax_row_sum.set_xlim(0, max(row_sums)*1.1)
    ax_row_sum.set_title("Row Attention Sum", fontsize=12)
    ax_row_sum.set_xlabel("Sum of Weights", fontsize=10)
    
    # Plot column-wise attention sums - bottom panel
    ax_col_sum = fig.add_subplot(gs[1, 0])
    col_sums = attn_matrix.sum(axis=0)
    
    # Make a vertical barplot
    col_bars = ax_col_sum.bar(np.arange(len(col_sums)), col_sums,
                              color=plt.cm.get_cmap(cmap)(col_sums/col_sums.max()))
    ax_col_sum.set_xticks(np.arange(len(x_labels)))
    ax_col_sum.set_xticklabels([])  # No need to duplicate the labels
    ax_col_sum.set_ylim(0, max(col_sums)*1.1)
    ax_col_sum.set_title("Column Attention Sum", fontsize=12)
    ax_col_sum.set_ylabel("Sum of Weights", fontsize=10)
    
    # Add model info and matrix statistics in bottom right
    ax_info = fig.add_subplot(gs[1, 1])
    ax_info.axis('off')  # No axis for text box
    
    # Calculate some stats about the attention matrix
    stats_text = (
        f"Matrix Statistics:\n"
        f"  Min: {attn_matrix.min():.3f}\n"
        f"  Max: {attn_matrix.max():.3f}\n"
        f"  Mean: {attn_matrix.mean():.3f}\n"
        f"  Std: {attn_matrix.std():.3f}\n\n"
        f"Key Position Focus:\n"
        f"  Strongest source: t{np.argmax(row_sums)}\n"
        f"  Strongest target: t{np.argmax(col_sums)}"
    )
    
    # Add the stats text
    ax_info.text(0, 0.95, stats_text, fontsize=10, 
                verticalalignment='top', family='monospace')
    
    # Add a caption explaining what attention means
    caption = (
        "The attention matrix shows how each position in the sequence\n"
        "attends to other positions. Brighter colors indicate stronger\n"
        "attention weights. The row sums show which source positions\n"
        "are most influential, while column sums show which target\n"
        "positions receive the most attention."
    )
    ax_info.text(0, 0.3, caption, fontsize=9, style='italic',
                verticalalignment='top')
    
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.15, hspace=0.2)
    
    # Save the figure
    save_path = f'enhanced_attn_layer{layer_idx+1}_head{head_idx+1}.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    
    return save_path

def visualize_all_attention_heads_enhanced(attention_weights, seq_length, num_layers, num_heads, 
                                          sample_idx=0, predicted_class=None, true_class=None):
    """
    Creates an enhanced visualization of attention matrices for all heads and layers.
    
    Args:
        attention_weights: List of attention weight tensors
        seq_length: Length of input sequence
        num_layers: Number of transformer layers
        num_heads: Number of attention heads per layer
        sample_idx: Index of the sample in the batch
        predicted_class: Model's prediction for this sample
        true_class: Ground truth label for this sample
        
    Returns:
        Path to the saved visualization
    """
    # Use a larger figure size to accommodate all heads and layers
    fig_width = max(12, num_heads * 2.5)
    fig_height = max(8, num_layers * 2.5)
    
    fig, axes = plt.subplots(num_layers, num_heads, 
                            figsize=(fig_width, fig_height),
                            squeeze=False)
    
    # Use a consistent colormap
    cmap = "YlGnBu"  # This colormap works well for attention visualization
    
    # Add a super title with prediction information
    if predicted_class and true_class:
        match_status = "Correct ✓" if predicted_class == true_class else "Incorrect ✗"
        plt.suptitle(f"Attention Patterns Across All Layers and Heads\n" + 
                    f"Prediction: {predicted_class} | Actual: {true_class} ({match_status})",
                    fontsize=16, fontweight='bold', y=0.98)
    else:
        plt.suptitle("Attention Patterns Across All Layers and Heads", 
                    fontsize=16, fontweight='bold', y=0.98)

    # Create attention matrix for each layer and head
    max_value = 0  # Track the maximum attention value for consistent color scale
    matrices = []
    
    # First pass to collect all matrices and find global max for consistent coloring
    for layer in range(num_layers):
        layer_matrices = []
        for head in range(num_heads):
            if isinstance(attention_weights[layer], torch.Tensor):
                attn_shape = attention_weights[layer].shape
                
                if len(attn_shape) == 4:  # [batch_size, nhead, seq_len, seq_len]
                    attn_matrix = attention_weights[layer][sample_idx, head].cpu().numpy()
                elif len(attn_shape) == 3:  # [batch_size, seq_len, seq_len]
                    attn_matrix = attention_weights[layer][sample_idx].cpu().numpy()
                else:
                    attn_matrix = np.ones((seq_length, seq_length)) / seq_length
            else:
                try:
                    attn_matrix = attention_weights[layer][0][sample_idx, head].cpu().numpy()
                except (TypeError, IndexError):
                    attn_matrix = np.ones((seq_length, seq_length)) / seq_length
            
            layer_matrices.append(attn_matrix)
            max_value = max(max_value, attn_matrix.max())
        
        matrices.append(layer_matrices)
    
    # Second pass to plot with consistent coloring
    for layer in range(num_layers):
        for head in range(num_heads):
            ax = axes[layer, head]
            attn_matrix = matrices[layer][head]
            
            # Create heatmap with consistent vmin/vmax for color scaling
            sns.heatmap(attn_matrix, annot=False, cmap=cmap, ax=ax, 
                       cbar=False, vmin=0, vmax=max_value)
            
            # Configure the subplot
            ax.set_title(f"L{layer+1}H{head+1}", fontsize=10)
            
            # Only show ticks for edge subplots
            if layer == num_layers - 1:
                ax.set_xticks([0, seq_length-1])
                ax.set_xticklabels(['t0', f't{seq_length-1}'], fontsize=8)
            else:
                ax.set_xticks([])
            
            if head == 0:
                ax.set_yticks([0, seq_length-1])
                ax.set_yticklabels(['t0', f't{seq_length-1}'], fontsize=8)
            else:
                ax.set_yticks([])
            
            # Add row/column average magnitude indicators
            # These small bars help identify which positions are attended to most
            row_sum = attn_matrix.sum(axis=1)
            col_sum = attn_matrix.sum(axis=0)
            
            # Highlight the diagonal for reference
            for i in range(seq_length):
                ax.add_patch(plt.Rectangle((i, i), 1, 1, fill=False, edgecolor='black', lw=0.5))
            
            # Mark the most attended position with a star
            max_pos = np.unravel_index(np.argmax(attn_matrix), attn_matrix.shape)
            ax.add_patch(plt.Circle((max_pos[1] + 0.5, max_pos[0] + 0.5), 0.4, 
                                   facecolor='none', edgecolor='red', linewidth=1))
    
    # Add a colorbar for the entire figure
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(0, max_value))
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Attention Weight')
    
    # Add a text explanation of attention patterns
    fig.text(0.5, 0.02, 
            "Each cell shows how much attention is given from a source token (y-axis) to a target token (x-axis).\n" +
            "Red circles mark the strongest attention connection in each head.",
            ha='center', fontsize=10, style='italic')
    
    plt.tight_layout()
    plt.subplots_adjust(top=0.92, right=0.9, wspace=0.1, hspace=0.2)
    
    # Save the figure
    save_path = 'enhanced_all_attention_heads.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    
    return save_path

def create_attention_pattern_summary(attention_weights, seq_length, num_layers, num_heads, 
                                    sample_idx=0, predicted_class=None, true_class=None):
    """
    Creates a summary visualization of attention patterns focused on how they evolve
    through layers and what positions they emphasize.
    
    Args:
        attention_weights: List of attention weight tensors
        seq_length: Length of input sequence
        num_layers: Number of transformer layers
        num_heads: Number of attention heads per layer
        sample_idx: Index of the sample in the batch
        predicted_class: Model's prediction for this sample
        true_class: Ground truth label for this sample
        
    Returns:
        Path to the saved visualization
    """
    fig = plt.figure(figsize=(14, 10))
    gs = fig.add_gridspec(nrows=3, ncols=2, height_ratios=[1, 2, 1])
    
    # Create the first panel - mean attention across layers
    ax_layers = fig.add_subplot(gs[0, :])
    
    # Extract and aggregate attention data across layers
    layer_avg_attns = []
    position_focus = np.zeros((num_layers, seq_length))
    
    for layer in range(num_layers):
        layer_attn = None
        
        if isinstance(attention_weights[layer], torch.Tensor):
            attn_shape = attention_weights[layer].shape
            
            if len(attn_shape) == 4:  # [batch_size, nhead, seq_len, seq_len]
                # Average across all heads
                layer_attn = attention_weights[layer][sample_idx].mean(dim=0).cpu().numpy()
            elif len(attn_shape) == 3:  # [batch_size, seq_len, seq_len]
                layer_attn = attention_weights[layer][sample_idx].cpu().numpy()
        else:
            try:
                # Average across all heads
                head_attns = [attention_weights[layer][0][sample_idx, h].cpu().numpy() 
                            for h in range(num_heads)]
                layer_attn = np.mean(head_attns, axis=0)
            except (TypeError, IndexError):
                layer_attn = np.ones((seq_length, seq_length)) / seq_length
        
        layer_avg_attns.append(layer_attn)
        
        # Track which positions each layer focuses on (column sums)
        position_focus[layer] = layer_attn.sum(axis=0)
    
    # Plot how attention evolves through layers
    layer_labels = [f"Layer {i+1}" for i in range(num_layers)]
    sns.heatmap(position_focus, cmap="YlGnBu", 
               xticklabels=[f"t{i}" for i in range(seq_length)],
               yticklabels=layer_labels, ax=ax_layers)
    
    ax_layers.set_title("Position Focus Across Layers", fontsize=14)
    ax_layers.set_xlabel("Sequence Position", fontsize=12)
    ax_layers.set_ylabel("Layer", fontsize=12)
    
    # Second panel - Layer-wise attention maps
    ax_maps = fig.add_subplot(gs[1, :])
    
    # Create a grid of small attention maps
    inner_gs = gs[1, :].subgridspec(1, num_layers, wspace=0.1)
    axes_maps = [fig.add_subplot(inner_gs[0, i]) for i in range(num_layers)]
    
    for i, (ax, layer_attn) in enumerate(zip(axes_maps, layer_avg_attns)):
        sns.heatmap(layer_attn, cmap="YlGnBu", ax=ax, cbar=False)
        ax.set_title(f"Layer {i+1}", fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    
    # Third panel - Head-wise attention pattern
    ax_heads = fig.add_subplot(gs[2, 0])
    ax_prediction = fig.add_subplot(gs[2, 1])
    
    # Create a summary of head patterns from the final layer
    head_position_focus = np.zeros((num_heads, seq_length))
    
    layer_idx = num_layers - 1  # Use the final layer
    for head in range(num_heads):
        if isinstance(attention_weights[layer_idx], torch.Tensor):
            attn_shape = attention_weights[layer_idx].shape
            
            if len(attn_shape) == 4:  # [batch_size, nhead, seq_len, seq_len]
                head_attn = attention_weights[layer_idx][sample_idx, head].cpu().numpy()
            elif len(attn_shape) == 3:  # [batch_size, seq_len, seq_len]
                head_attn = attention_weights[layer_idx][sample_idx].cpu().numpy()
            else:
                head_attn = np.ones((seq_length, seq_length)) / seq_length
        else:
            try:
                head_attn = attention_weights[layer_idx][0][sample_idx, head].cpu().numpy()
            except (TypeError, IndexError):
                head_attn = np.ones((seq_length, seq_length)) / seq_length
                
        # Track which positions each head focuses on
        head_position_focus[head] = head_attn.sum(axis=0)
    
    # Plot head-wise position focus
    head_labels = [f"Head {i+1}" for i in range(num_heads)]
    sns.heatmap(head_position_focus, cmap="YlGnBu", 
               xticklabels=[f"t{i}" for i in range(seq_length)],
               yticklabels=head_labels, ax=ax_heads)
    
    ax_heads.set_title(f"Position Focus Across Heads (Layer {num_layers})", fontsize=12)
    ax_heads.set_xlabel("Sequence Position", fontsize=10)
    ax_heads.set_ylabel("Attention Head", fontsize=10)
    
    # Final panel - Add prediction information and interpretation
    ax_prediction.axis('off')
    
    # Create a text summary of the attention patterns
    # Find which positions were most attended to overall
    final_layer_focus = position_focus[-1]
    most_attended_pos = np.argmax(final_layer_focus)
    least_attended_pos = np.argmin(final_layer_focus)
    
    # Create the summary text
    if predicted_class and true_class:
        prediction_result = "correct" if predicted_class == true_class else "incorrect"
        prediction_text = (
            f"PREDICTION SUMMARY\n\n"
            f"Model prediction: {predicted_class}\n"
            f"Actual class: {true_class}\n"
            f"Result: {prediction_result.upper()}\n\n"
            f"ATTENTION ANALYSIS\n\n"
            f"• Most attended position: t{most_attended_pos}\n"
            f"• Least attended position: t{least_attended_pos}\n"
            f"• {num_heads} attention heads in {num_layers} layers\n\n"
            f"The model's attention is primarily focused on\n"
            f"the {'beginning' if most_attended_pos < seq_length/3 else 'middle' if most_attended_pos < 2*seq_length/3 else 'end'} "
            f"of the sequence, suggesting this region\n"
            f"contains the most discriminative features for\n"
            f"classifying this mudra gesture."
        )
    else:
        prediction_text = (
            f"ATTENTION ANALYSIS\n\n"
            f"• Most attended position: t{most_attended_pos}\n"
            f"• Least attended position: t{least_attended_pos}\n"
            f"• {num_heads} attention heads in {num_layers} layers\n\n"
            f"The model's attention is primarily focused on\n"
            f"the {'beginning' if most_attended_pos < seq_length/3 else 'middle' if most_attended_pos < 2*seq_length/3 else 'end'} "
            f"of the sequence."
        )
    
    ax_prediction.text(0.05, 0.95, prediction_text, fontsize=11,
                     verticalalignment='top', family='monospace')
    
    # Add a title for the entire figure
    if predicted_class and true_class:
        match_status = "Correct ✓" if predicted_class == true_class else "Incorrect ✗"
        fig.suptitle(f"Comprehensive Attention Analysis for {predicted_class} ({match_status})", 
                   fontsize=16, fontweight='bold')
    else:
        fig.suptitle("Comprehensive Attention Analysis", fontsize=16, fontweight='bold')
    
    plt.tight_layout()
    plt.subplots_adjust(top=0.92, hspace=0.3)
    
    # Save the figure
    save_path = 'attention_pattern_summary.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    
    return save_path

'\n# After model prediction\nfor features, labels in val_loader:\n    batch_size = features.size(0)\n    predicted_classes, probabilities, attention_weights = predict_with_attention(\n        model, features, label_encoder, device\n    )\n    \n    # Get sample information\n    sample_idx = 0  # First sample in batch\n    true_class = label_encoder.inverse_transform([labels[sample_idx].item()])[0]\n    predicted_class = predicted_classes[sample_idx]\n    \n    # Visualize with enhanced functions\n    enhanced_attn_file = enhanced_attention_matrix(\n        attention_weights, \n        seq_length,\n        title="Mudra Attention Pattern",\n        layer_idx=num_encoder_layers-1,  # Last layer\n        head_idx=0,  # First head\n        sample_idx=sample_idx,\n        predicted_class=predicted_class,\n        true_class=true_class\n    )\n    \n    # Create all-heads visualization\n    all_heads_file = visualize_all_attention_heads_enhanced(\n        attention_weights,\n        seq_lengt

In [35]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
import joblib

# Definição do dispositivo para treinamento
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Usando dispositivo: {device}")

# 1. Preparo dos dados

class MudraDataset(Dataset):
    def __init__(self, features, labels=None, seq_length=10, stride=1):
        """
        Dataset para dados de tracking de mãos.
        
        Args:
            features: Numpy array com as features (posição, rotação, curvatura)
            labels: Numpy array com os rótulos (tipos de mudra) se disponível
            seq_length: Tamanho da sequência para cada amostra
            stride: Passo para criar sequências
        """
        self.features = features
        self.labels = labels
        self.seq_length = seq_length
        self.stride = stride
        self.indices = self._create_indices()
        
    def _create_indices(self):
        """Cria índices para sequências válidas com o comprimento desejado"""
        indices = []
        for i in range(0, len(self.features) - self.seq_length + 1, self.stride):
            indices.append(i)
        return indices
    
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        start_idx = self.indices[idx]
        end_idx = start_idx + self.seq_length
        
        # Obter a sequência de features
        seq_features = self.features[start_idx:end_idx]
        
        # Se temos rótulos, retornar o rótulo correspondente (geralmente o último da sequência)
        if self.labels is not None:
            seq_label = self.labels[end_idx - 1]  # Último rótulo da sequência
            return torch.FloatTensor(seq_features), torch.LongTensor([seq_label])
        
        return torch.FloatTensor(seq_features)

def load_and_preprocess_data(file_path, seq_length=30, stride=5, test_size=0.2, random_state=42):
    """
    Carrega e pré-processa os dados de tracking de mãos.
    
    Args:
        file_path: Caminho para o arquivo CSV
        seq_length: Tamanho da sequência para cada amostra
        stride: Passo para criar sequências
        test_size: Proporção do conjunto de teste
        random_state: Semente para reprodutibilidade
        
    Returns:
        train_loader, val_loader, input_dim, num_classes, label_encoder
    """
    # Carregar os dados
    df = pd.read_csv(file_path)
    df = df.fillna(method='ffill')  # Forward fill para lidar com NaN
    # Separar features e rótulos
    features = df.drop(['Time', 'Mudra'], axis=1).values  # Remover colunas de tempo e rótulo
    labels = df['Mudra'].values
    
    # Normalizar as features
    scaler = StandardScaler()
    features = scaler.fit_transform(features)
    
    # Codificar os rótulos
    label_encoder = LabelEncoder()
    encoded_labels = label_encoder.fit_transform(labels)

    # Save the label encoder for later use
    joblib.dump(label_encoder, 'mudra_label_encoder.joblib')
    print(f"Label encoder saved with {len(label_encoder.classes_)} classes: {label_encoder.classes_}")
    
    # Dividir em treino e validação
    X_train, X_val, y_train, y_val = train_test_split(
        features, encoded_labels, test_size=test_size, random_state=random_state
    )
    
    # Criar datasets
    train_dataset = MudraDataset(X_train, y_train, seq_length=seq_length, stride=stride)
    val_dataset = MudraDataset(X_val, y_val, seq_length=seq_length, stride=stride)
    
    # Criar dataloaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    
    # Dimensão de entrada (número de features)
    input_dim = features.shape[1]
    
    # Número de classes
    num_classes = len(label_encoder.classes_)
    
    return train_loader, val_loader, input_dim, num_classes, label_encoder

# 2. Arquitetura do Transformer modificado para capturar matrizes de atenção

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        return x + self.pe[:, :x.size(1), :]

class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """Custom TransformerEncoderLayer that returns attention weights"""
    
    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # self attention
        src2, attn_weights = self._sa_block(src, src_mask, src_key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        
        # feedforward network
        src2 = self._ff_block(src)
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        
        return src, attn_weights
    
    def _sa_block(self, x, attn_mask, key_padding_mask):
        x2, attn_weights = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=True
        )
        return x2, attn_weights

class CustomMultiheadAttention(nn.MultiheadAttention):
    """Modified MultiheadAttention that returns attention weights"""
    
    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        return super().forward(query, key, value, key_padding_mask=key_padding_mask,
                               need_weights=need_weights, attn_mask=attn_mask)

class CustomTransformerEncoder(nn.Module):
    """Custom TransformerEncoder that returns attention weights"""
    
    def __init__(self, encoder_layer, num_layers):
        super(CustomTransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([encoder_layer for _ in range(num_layers)])
        self.num_layers = num_layers
    
    def forward(self, src, mask=None, src_key_padding_mask=None):
        output = src
        attention_weights = []
        
        for layer in self.layers:
            output, attn_weights = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
            attention_weights.append(attn_weights)
        
        return output, attention_weights

class MudraTransformer(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, num_classes, dropout=0.1):
        super(MudraTransformer, self).__init__()
        
        # Embedding da entrada
        self.embedding = nn.Linear(input_dim, d_model)
        
        # Codificação posicional
        self.positional_encoding = PositionalEncoding(d_model)
        
        # Camadas do Transformer customizadas para retornar matrizes de atenção
        encoder_layer = CustomTransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=False  # Importante para manter a compatibilidade
        )
        self.transformer_encoder = CustomTransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        
        # Camada de classificação (usamos apenas o último elemento da sequência)
        self.classifier = nn.Linear(d_model, num_classes)
        
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        
    def forward(self, x, src_mask=None, return_attention=False):
        # x: [batch_size, seq_len, input_dim]
        
        # Aplicar embedding
        x = self.embedding(x) * math.sqrt(self.d_model)
        
        # Aplicar codificação posicional
        x = self.positional_encoding(x)
        
        # Aplicar dropout
        x = self.dropout(x)
        
        # Transformar formato para o esperado pelo Transformer: [seq_len, batch_size, d_model]
        x = x.permute(1, 0, 2)
        
        # Aplicar Transformer e capturar pesos de atenção
        output, attention_weights = self.transformer_encoder(x, src_mask)
        
        # Transformar de volta: [batch_size, seq_len, d_model]
        output = output.permute(1, 0, 2)
        
        # Usar apenas o último elemento da sequência para classificação
        output_last = output[:, -1, :]
        
        # Camada de classificação
        logits = self.classifier(output_last)
        
        if return_attention:
            return logits, attention_weights
        return logits

# 3. Treinamento e avaliação

def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for features, labels in train_loader:
        features, labels = features.to(device), labels.squeeze().to(device)
        
        # Check for NaN values
        if torch.isnan(features).any():
            print("Warning: NaN values found in features, skipping batch")
            continue
            
        # Forward pass
        optimizer.zero_grad()
        outputs = model(features)
        loss = criterion(outputs, labels)
        
        # Skip batch if loss is NaN
        if torch.isnan(loss).any():
            print("Warning: NaN loss detected, skipping batch")
            continue
        
        # Backward pass e otimização
        loss.backward()
        
        # Add gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        # Estatísticas
        running_loss += loss.item() * features.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    # Avoid division by zero
    if total == 0:
        return float('nan'), float('nan')
        
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for features, labels in val_loader:
            features, labels = features.to(device), labels.squeeze().to(device)
            
            # Forward pass
            outputs = model(features)
            loss = criterion(outputs, labels)
            
            # Estatísticas
            running_loss += loss.item() * features.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, patience=10):
    best_val_acc = 0.0
    patience_counter = 0
    
    for epoch in range(num_epochs):
        # Treinamento
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        
        # Validação
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
        
        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            # Salvar o melhor modelo
            torch.save(model.state_dict(), 'best_mudra_transformer.pth')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'Early stopping após {epoch+1} épocas!')
                break
    
    # Carregar o melhor modelo
    model.load_state_dict(torch.load('best_mudra_transformer.pth'))
    return model

# 4. Função para previsão e visualização de atenção

def predict_with_attention(model, features, label_encoder, device):
    """
    Faz previsões com o modelo treinado e retorna os pesos de atenção.
    
    Args:
        model: Modelo treinado
        features: Tensor de features [batch_size, seq_len, input_dim]
        label_encoder: LabelEncoder usado para codificar as classes
        device: Dispositivo para execução
        
    Returns:
        Classe prevista, probabilidades e pesos de atenção
    """
    model.eval()
    features = features.to(device)
    
    with torch.no_grad():
        outputs, attention_weights = model(features, return_attention=True)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)
    
    # Converter índices previstos para classes originais
    predicted_classes = label_encoder.inverse_transform(predicted.cpu().numpy())
    
    return predicted_classes, probabilities, attention_weights

def plot_attention_matrix(attention_weights, seq_length, title="Matriz de Atenção", layer_idx=0, head_idx=0, sample_idx=0):
    """
    Plota a matriz de atenção para uma amostra específica.
    
    Args:
        attention_weights: Lista de tensores de atenção de cada camada
        seq_length: Comprimento da sequência
        title: Título do gráfico
        layer_idx: Índice da camada do transformer a ser visualizada
        head_idx: Índice da cabeça de atenção a ser visualizada
        sample_idx: Índice da amostra no batch a ser visualizada
    """
    plt.figure(figsize=(10, 8))
    
    # FIX: The attention matrix format is different than expected
    # Extract the attention matrix correctly based on how it's returned from the transformer
    # Attention weights is a list of matrices with shape [batch_size, nhead, seq_len, seq_len]
    # or [seq_len, seq_len] depending on the transformer implementation
    
    # Check the shape of attention_weights to handle appropriately
    if isinstance(attention_weights[layer_idx], torch.Tensor):
        # If it's a tensor, we need to extract the right dimensions
        attn_shape = attention_weights[layer_idx].shape
        
        if len(attn_shape) == 4:  # [batch_size, nhead, seq_len, seq_len]
            attn_matrix = attention_weights[layer_idx][sample_idx, head_idx].cpu().numpy()
        elif len(attn_shape) == 3:  # [batch_size, seq_len, seq_len] (no separate heads)
            attn_matrix = attention_weights[layer_idx][sample_idx].cpu().numpy()
        else:  # Other formats
            print(f"Unexpected attention weight shape: {attn_shape}")
            # Try to adapt to the shape we have
            if attn_shape[0] == seq_length and attn_shape[1] == seq_length:
                # It's already a matrix of the right shape
                attn_matrix = attention_weights[layer_idx].cpu().numpy()
            else:
                # Create a fallback matrix
                attn_matrix = np.ones((seq_length, seq_length)) / seq_length
                print("Using fallback attention matrix due to unexpected shape")
    else:
        # If it's not a tensor (e.g., a tuple or other structure)
        # Extract the attention part, assuming first element contains the attention weights
        # This is common in some transformer implementations
        try:
            # Try to get the first element if it's a tuple or list
            attn_matrix = attention_weights[layer_idx][0][sample_idx, head_idx].cpu().numpy()
        except (TypeError, IndexError):
            print(f"Unable to extract attention matrix from type {type(attention_weights[layer_idx])}")
            # Create a fallback matrix
            attn_matrix = np.ones((seq_length, seq_length)) / seq_length
            print("Using fallback attention matrix due to extraction error")
    
    # Criar um mapa de calor
    sns.heatmap(attn_matrix, annot=False, cmap='viridis')
    
    # Configurar o gráfico
    plt.title(f"{title} - Camada {layer_idx+1}, Cabeça {head_idx+1}")
    plt.xlabel("Posição de Sequência (Destino)")
    plt.ylabel("Posição de Sequência (Origem)")
    plt.tight_layout()
    
    # Salvar o gráfico
    plt.savefig(f'attention_matrix_layer{layer_idx+1}_head{head_idx+1}.png')
    plt.close()
    
    return f'attention_matrix_layer{layer_idx+1}_head{head_idx+1}.png'

def visualize_all_attention_heads(attention_weights, seq_length, num_layers, num_heads, sample_idx=0):
    """
    Plota matrizes de atenção para todas as cabeças em todas as camadas.
    
    Args:
        attention_weights: Lista de tensores de atenção de cada camada
        seq_length: Comprimento da sequência
        num_layers: Número de camadas do transformer
        num_heads: Número de cabeças de atenção
        sample_idx: Índice da amostra no batch a ser visualizada
    """
    fig, axes = plt.subplots(num_layers, num_heads, figsize=(num_heads*3, num_layers*3))
    
    for layer in range(num_layers):
        for head in range(num_heads):
            if num_layers > 1:
                ax = axes[layer, head] if num_heads > 1 else axes[layer]
            else:
                ax = axes[head] if num_heads > 1 else axes
            
            # FIX: Extract attention matrix correctly as in the plot_attention_matrix function
            if isinstance(attention_weights[layer], torch.Tensor):
                attn_shape = attention_weights[layer].shape
                
                if len(attn_shape) == 4:  # [batch_size, nhead, seq_len, seq_len]
                    attn_matrix = attention_weights[layer][sample_idx, head].cpu().numpy()
                elif len(attn_shape) == 3:  # [batch_size, seq_len, seq_len] (no separate heads)
                    # If no separate heads but trying to show multiple heads
                    # Just show the same matrix for all heads
                    attn_matrix = attention_weights[layer][sample_idx].cpu().numpy()
                else:
                    # Create a fallback matrix
                    attn_matrix = np.ones((seq_length, seq_length)) / seq_length
            else:
                try:
                    # Try to get the first element if it's a tuple or list
                    attn_matrix = attention_weights[layer][0][sample_idx, head].cpu().numpy()
                except (TypeError, IndexError):
                    # Create a fallback matrix
                    attn_matrix = np.ones((seq_length, seq_length)) / seq_length
            
            # Criar mapa de calor
            sns.heatmap(attn_matrix, annot=False, cmap='viridis', ax=ax, cbar=False)
            
            # Configurar o subplot
            ax.set_title(f"L{layer+1}H{head+1}")
            ax.set_xticks([])
            ax.set_yticks([])
    
    plt.tight_layout()
    plt.savefig('all_attention_heads.png')
    plt.close()
    
    return 'all_attention_heads.png'

# 5. Função principal

def main(file_path='combined_one_hand_data_with_classification.csv'):
    # Hiperparâmetros
    seq_length = 10        # Tamanho da sequência
    stride = 5             # Passo para criação de sequências
    d_model = 128          # Dimensão do modelo
    nhead = 8              # Número de cabeças de atenção
    num_encoder_layers = 4 # Número de camadas do encoder
    dim_feedforward = 512  # Dimensão da camada feed-forward
    dropout = 0.2          # Taxa de dropout
    learning_rate = 0.001  # Taxa de aprendizado
    num_epochs = 50        # Número de épocas
    patience = 10          # Paciência para early stopping
    
    # Carregar e pré-processar os dados
    train_loader, val_loader, input_dim, num_classes, label_encoder = load_and_preprocess_data(
        file_path, seq_length=seq_length, stride=stride
    )
    
    print(f"Dimensão de entrada: {input_dim}")
    print(f"Número de classes: {num_classes}")
    print(f"Classes: {label_encoder.classes_}")
    
    # Inicializar o modelo
    model = MudraTransformer(
        input_dim=input_dim,
        d_model=d_model,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        dim_feedforward=dim_feedforward,
        num_classes=num_classes,
        dropout=dropout
    ).to(device)
    
    # Mostrar resumo do modelo (número de parâmetros)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total de parâmetros do modelo: {total_params}")
    
    # Definir função de perda e otimizador
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Treinar o modelo
    print("Iniciando treinamento...")
    model = train_model(
        model, train_loader, val_loader, criterion, optimizer, 
        num_epochs, device, patience=patience
    )
    print("Treinamento concluído!")
    
    # Avaliar no conjunto de validação
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    print(f"Acurácia final de validação: {val_acc:.4f}")
    
    # Exemplo de como fazer previsões e visualizar matrizes de atenção
    for features, labels in val_loader:
        batch_size = features.size(0)
        predicted_classes, probabilities, attention_weights = predict_with_attention(
            model, features, label_encoder, device
        )
        
        # Get sample information
        sample_idx = 0  # First sample in batch
        true_class = label_encoder.inverse_transform([labels[sample_idx].item()])[0]
        predicted_class = predicted_classes[sample_idx]
        
        # Visualize with enhanced functions
        enhanced_attn_file = enhanced_attention_matrix(
            attention_weights, 
            seq_length,
            title="Mudra Attention Pattern",
            layer_idx=num_encoder_layers-1,  # Last layer
            head_idx=0,  # First head
            sample_idx=sample_idx,
            predicted_class=predicted_class,
            true_class=true_class
        )
        
        # Create all-heads visualization
        all_heads_file = visualize_all_attention_heads_enhanced(
            attention_weights,
            seq_length,
            num_encoder_layers,
            nhead,
            sample_idx=sample_idx,
            predicted_class=predicted_class,
            true_class=true_class
        )
        
        # Create a comprehensive summary
        summary_file = create_attention_pattern_summary(
            attention_weights,
            seq_length,
            num_encoder_layers,
            nhead,
            sample_idx=sample_idx,
            predicted_class=predicted_class,
            true_class=true_class
        )
        
        print(f"Enhanced attention visualization saved to: {enhanced_attn_file}")
        print(f"All heads visualization saved to: {all_heads_file}")
        print(f"Attention pattern summary saved to: {summary_file}")
        
        break  # Just process one batch for example

if __name__ == "__main__":
    # Substitua pelo caminho real do seu arquivo CSV
    main(file_path='combined_one_hand_data_with_classification.csv')

Usando dispositivo: cuda
Label encoder saved with 14 classes: ['Abhaya' 'Bhumisparsa' 'Dharmachakra' 'Dhyana' 'Jnana' 'Karana' 'Ksepana'
 'Namaskara' 'No Mudra/No Movement/Transition' 'Tarjani'
 'Unknown/Transition' 'Uttarabodhi' 'Varada' 'Vitarka']
Dimensão de entrada: 24
Número de classes: 14
Classes: ['Abhaya' 'Bhumisparsa' 'Dharmachakra' 'Dhyana' 'Jnana' 'Karana' 'Ksepana'
 'Namaskara' 'No Mudra/No Movement/Transition' 'Tarjani'
 'Unknown/Transition' 'Uttarabodhi' 'Varada' 'Vitarka']


  df = df.fillna(method='ffill')  # Forward fill para lidar com NaN


Total de parâmetros do modelo: 203278
Iniciando treinamento...
Epoch 1/50:
Train Loss: 0.9298, Train Acc: 0.7139
Val Loss: 0.6855, Val Acc: 0.8002
Epoch 2/50:
Train Loss: 0.6845, Train Acc: 0.7892
Val Loss: 0.5667, Val Acc: 0.8259
Epoch 3/50:
Train Loss: 0.6029, Train Acc: 0.8147
Val Loss: 0.5113, Val Acc: 0.8495
Epoch 4/50:
Train Loss: 0.5692, Train Acc: 0.8268
Val Loss: 0.4713, Val Acc: 0.8617
Epoch 5/50:
Train Loss: 0.5441, Train Acc: 0.8282
Val Loss: 0.5099, Val Acc: 0.8591
Epoch 6/50:
Train Loss: 0.5335, Train Acc: 0.8370
Val Loss: 0.4608, Val Acc: 0.8499
Epoch 7/50:
Train Loss: 0.5113, Train Acc: 0.8411
Val Loss: 0.4561, Val Acc: 0.8652
Epoch 8/50:
Train Loss: 0.4985, Train Acc: 0.8433
Val Loss: 0.4132, Val Acc: 0.8717
Epoch 9/50:
Train Loss: 0.4811, Train Acc: 0.8445
Val Loss: 0.4363, Val Acc: 0.8569
Epoch 10/50:
Train Loss: 0.4874, Train Acc: 0.8490
Val Loss: 0.4726, Val Acc: 0.8674
Epoch 11/50:
Train Loss: 0.4800, Train Acc: 0.8458
Val Loss: 0.4477, Val Acc: 0.8722
Epoch 12/50

  model.load_state_dict(torch.load('best_mudra_transformer.pth'))


Acurácia final de validação: 0.8975


  color=plt.cm.get_cmap(cmap)(row_sums/row_sums.max()))
  color=plt.cm.get_cmap(cmap)(col_sums/col_sums.max()))
  plt.tight_layout()


Enhanced attention visualization saved to: enhanced_attn_layer4_head1.png
All heads visualization saved to: enhanced_all_attention_heads.png
Attention pattern summary saved to: attention_pattern_summary.png


In [30]:
import joblib
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler


df = pd.read_csv('combined_one_hand_data_with_classification.csv')
df = df.fillna(method='ffill')  # Forward fill para lidar com NaN   


# Load the trained model and label encoder
model_path = 'best_mudra_transformer.pth'
label_encoder_path = 'mudra_label_encoder.joblib'

# Load the label encoder
label_encoder = joblib.load(label_encoder_path)

# Set parameters to match the model training
seq_length = 10
input_dim = len(df.columns) - 2  # All columns except Time and Mudra
d_model = 128
nhead = 8
num_encoder_layers = 4
dim_feedforward = 512
num_classes = len(label_encoder.classes_)

# Initialize the model with the same architecture
model = MudraTransformer(
    input_dim=input_dim,
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    dim_feedforward=dim_feedforward,
    num_classes=num_classes,
    dropout=0.2
).to(device)

# Load the trained weights
model.load_state_dict(torch.load(model_path))
model.eval()

# Sample a sequence from the dataframe
sample_idx = 3000 # Choose an arbitrary index
sample_features = df.drop(['Time', 'Mudra'], axis=1).values[sample_idx:sample_idx+seq_length]
sample_label = df['Mudra'].values[sample_idx+seq_length-1]

# Normalize the features (similar to training)
scaler = StandardScaler()
# Using multiple samples to get a better fit for the scaler
scaler.fit(df.drop(['Time', 'Mudra'], axis=1).values)
normalized_features = scaler.transform(sample_features)

# Convert to tensor and add batch dimension
input_tensor = torch.FloatTensor(normalized_features).unsqueeze(0).to(device)  # [1, seq_length, input_dim]

# Make prediction
with torch.no_grad():
    outputs = model(input_tensor)
    probabilities = torch.nn.functional.softmax(outputs, dim=1)
    _, predicted = torch.max(outputs, 1)

# Get predicted class
predicted_class = label_encoder.inverse_transform(predicted.cpu().numpy())[0]

# Print results
print(f"Sample sequence from index {sample_idx} to {sample_idx+seq_length-1}")
print(f"Actual label: {sample_label}")
print(f"Predicted label: {predicted_class}")
print("\nClass probabilities:")
for i, prob in enumerate(probabilities[0].cpu().numpy()):
    class_name = label_encoder.inverse_transform([i])[0]
    print(f"{class_name}: {prob:.4f}")

Sample sequence from index 3000 to 3009
Actual label: Dharmachakra
Predicted label: Dharmachakra

Class probabilities:
Abhaya: 0.0000
Bhumisparsa: 0.0004
Dharmachakra: 0.9554
Dhyana: 0.0001
Jnana: 0.0042
Karana: 0.0001
Ksepana: 0.0002
Namaskara: 0.0001
No Mudra/No Movement/Transition: 0.0378
Tarjani: 0.0000
Unknown/Transition: 0.0004
Uttarabodhi: 0.0001
Varada: 0.0001
Vitarka: 0.0011


  df = df.fillna(method='ffill')  # Forward fill para lidar com NaN
  model.load_state_dict(torch.load(model_path))
