# Deep Knowledge Tracing Using a Transformer Model

Dataset: Assistments 2017

# Data Layer

Import Dataset from Drive

In [None]:
import pandas as pd

from google.colab import drive
# Mount Google Drive inside the Colab runtime so the dataset path below resolves.
drive.mount('/content/drive')

# Read the raw interaction log into a DataFrame used by the preprocessing cells.
assistments = pd.read_csv('/content/drive/MyDrive/DeepKT/assistments_2017.csv')

**Assistments 2017**

We mainly use two columns from the dataframe, `skill` and `correct`. The other two, `studentId` and `action_num`, are used only during preprocessing, to group interactions by student and to order them by action number.

In [None]:
# Preview the four columns used below (first 15000 rows; the notebook truncates the display).
assistments[['studentId', 'skill', 'correct', 'action_num']].head(15000)

Unnamed: 0,studentId,skill,correct,action_num
0,8,properties-of-geometric-figures,0,9950
1,8,properties-of-geometric-figures,1,9951
2,8,sum-of-interior-angles-more-than-3-sides,0,9952
3,8,sum-of-interior-angles-more-than-3-sides,0,9953
4,8,sum-of-interior-angles-more-than-3-sides,1,9954
...,...,...,...,...
14995,337,interpreting-numberline,1,269095
14996,337,interpreting-numberline,1,269096
14997,337,interpreting-numberline,0,269097
14998,337,inequality-solving,0,269098


# Preprocess

In [None]:
import pandas as pd
import numpy as np

from dataclasses import dataclass
from typing import Tuple, List, Dict, Any

@dataclass
class SequenceConfig:
    """Parameters controlling how interaction logs are cut into windows.

    Raises:
        ValueError: if ``seq_length`` or ``sliding_window_step`` is smaller
            than 1, or if ``max_students`` is given and smaller than 1.
    """

    # Number of consecutive interactions in each input window.
    seq_length: int
    # Distance (in interactions) between the starts of consecutive windows.
    sliding_window_step: int = 1
    # Upper bound on the number of distinct students kept; it is used as a
    # slice bound, so ``None`` keeps every student.
    max_students: int = 100

    def __post_init__(self) -> None:
        # Fail early with a clear message instead of an obscure range()/slice
        # error deep inside sequence generation.
        if self.seq_length < 1:
            raise ValueError(f"seq_length must be >= 1, got {self.seq_length}")
        if self.sliding_window_step < 1:
            raise ValueError(
                f"sliding_window_step must be >= 1, got {self.sliding_window_step}"
            )
        if self.max_students is not None and self.max_students < 1:
            raise ValueError(f"max_students must be >= 1, got {self.max_students}")

class SequenceGenerator:
    """Cut per-student interaction logs into fixed-length one-hot windows.

    Each window holds ``config.seq_length`` consecutive interactions of one
    student; its label describes the interaction that immediately follows.

    NOTE(review): this class emits one-hot vectors of width ``2 * num_skills``
    and a single next-step label per window, whereas ``TransformerModel.call``
    in this notebook documents inputs of shape (batch, seq_len, 2) holding
    skill IDs and correctness values. Confirm which format the saved arrays
    are expected to have.
    """

    def __init__(self, config: "SequenceConfig", skill_to_id: Dict):
        """Store the configuration and an optional pre-built skill mapping.

        Args:
            config: Window length, step and student limit.
            skill_to_id: Existing skill -> id mapping, or an empty dict.
                ``load_and_process`` replaces it with a mapping built from
                the data.
        """
        self.config = config
        # Keep a copy of the caller's mapping instead of silently dropping it.
        self.skill_to_id = dict(skill_to_id) if skill_to_id else {}
        # Overwritten by load_and_process(); until then it matches the mapping.
        self.num_skills = len(self.skill_to_id)

    def load_and_process(self, file) -> Tuple[pd.DataFrame, int]:
        """Sort the interactions, keep the first students and build the skill map.

        Args:
            file: DataFrame with at least the columns ``studentId``, ``skill``
                and ``action_num``.

        Returns:
            The filtered, sorted DataFrame and the number of distinct skills.
        """
        data = file

        # Counted on the full frame, before the student filter, so the
        # encoding width does not depend on config.max_students.
        self.num_skills = data['skill'].nunique()

        # Order each student's interactions by action number.
        data = data.sort_values(by=['studentId', 'action_num'])

        # Keep only the first max_students distinct students.
        selected_students = data['studentId'].unique()[:self.config.max_students]
        data = data[data['studentId'].isin(selected_students)]

        self.skill_to_id = self.skill_map(data)

        return data, self.num_skills

    def skill_map(self, data: pd.DataFrame) -> Dict[str, int]:
        """Assign consecutive ids to skills in order of first appearance."""
        return {skill: idx for idx, skill in enumerate(data['skill'].unique())}

    def encode_interaction(self, skill: int, correctness: int) -> List[int]:
        """One-hot encode a (skill, correctness) pair.

        Returns a vector of length ``num_skills * 2``:
        - positions ``[0, num_skills)`` mark a correct response for a skill;
        - positions ``[num_skills, 2 * num_skills)`` mark any other response.
        """
        encoded = np.zeros(self.num_skills * 2, dtype=int)
        offset = 0 if correctness == 1 else self.num_skills
        encoded[skill + offset] = 1
        return encoded.tolist()

    def generate_label(self, skill: int, correctness: int) -> List[int]:
        """Return a length-``num_skills`` vector with ``correctness`` at ``skill``."""
        label = np.zeros(self.num_skills, dtype=int)
        label[skill] = correctness
        return label.tolist()

    def prepare_student_sequences(self, student_data: pd.DataFrame) -> Tuple[List[List[List[int]]], List[List[int]]]:
        """Build all windows and next-step labels for a single student.

        Args:
            student_data: One student's interactions, already in order.

        Returns:
            ``(sequences, labels)``; both are empty when the student has fewer
            than ``seq_length + 1`` interactions.
        """
        sequences: List[List[List[int]]] = []
        labels: List[List[int]] = []

        seq_length = self.config.seq_length
        total = len(student_data)
        # One extra interaction is needed to serve as the label.
        if total < seq_length + 1:
            return sequences, labels

        # Pull the two columns out once; iterating rows with iterrows() for
        # every window is far slower and gives the same values.
        skill_ids = [self.skill_to_id[name] for name in student_data['skill']]
        outcomes = student_data['correct'].tolist()

        # The range stops at total - seq_length, so start + seq_length is
        # always a valid index for the label interaction.
        for start in range(0, total - seq_length, self.config.sliding_window_step):
            stop = start + seq_length
            sequences.append([
                self.encode_interaction(skill_id, outcome)
                for skill_id, outcome in zip(skill_ids[start:stop], outcomes[start:stop])
            ])
            labels.append(self.generate_label(skill_ids[stop], outcomes[stop]))

        return sequences, labels

    def prepare_sequences(self, df: pd.DataFrame) -> Tuple[List[List[List[int]]], List[List[int]]]:
        """Build windows and labels for every student in ``df``.

        Students are processed in order of first appearance and their rows
        keep the order they have in ``df``.
        """
        all_sequences: List[List[List[int]]] = []
        all_labels: List[List[int]] = []

        # A single groupby pass replaces one full-frame boolean filter per student.
        for _, student_data in df.groupby('studentId', sort=False):
            student_seq, student_lab = self.prepare_student_sequences(student_data)
            all_sequences.extend(student_seq)
            all_labels.extend(student_lab)

        return all_sequences, all_labels


# Windows of 10 interactions, advancing 5 interactions at a time, limited to 200 students.
config = SequenceConfig(seq_length=10, sliding_window_step=5, max_students=200)
generator = SequenceGenerator(config=config, skill_to_id={})

# Sort/filter the raw log, then cut it into encoded windows and next-step labels.
df, num_skills = generator.load_and_process(file=assistments)
seq, lab = generator.prepare_sequences(df=df)


# **Save and Load Functions**

In [None]:
import pickle
import json
import os

def save_preprocessed_data(sequences, labels, skill_to_id, config, save_dir='/content/drive/MyDrive/DeepKT/preprocessed_data', file_suffix='10_2'):
    """Save preprocessed sequences, labels and metadata to ``save_dir``.

    Three files are written: ``sequences_<suffix>.npy``,
    ``labels_<suffix>.npy`` and ``metadata_<suffix>.json``.

    Args:
        sequences: Encoded windows (list or array).
        labels: One label per window (list or array).
        skill_to_id: Skill -> id mapping stored in the metadata.
        config: Object exposing ``seq_length``, ``sliding_window_step`` and
            ``max_students``.
        save_dir: Target directory; created if missing.
        file_suffix: Suffix shared by the three file names. The default keeps
            the file names this notebook has always used.
    """
    drive_root = '/content/drive'
    # Only mount Google Drive when the target actually lives on it, so the
    # function also works with a local directory outside Colab.
    if str(save_dir).startswith(drive_root) and not os.path.exists(drive_root):
        from google.colab import drive
        drive.mount(drive_root)

    os.makedirs(save_dir, exist_ok=True)

    sequences_name = f'sequences_{file_suffix}.npy'
    labels_name = f'labels_{file_suffix}.npy'
    metadata_name = f'metadata_{file_suffix}.json'
    sequences_path = os.path.join(save_dir, sequences_name)
    labels_path = os.path.join(save_dir, labels_name)
    metadata_path = os.path.join(save_dir, metadata_name)

    # Save sequences and labels
    np.save(sequences_path, np.array(sequences))
    np.save(labels_path, np.array(labels))

    # Save skill mapping and configuration
    metadata = {
        'skill_to_id': skill_to_id,
        'config': {
            'seq_length': config.seq_length,
            'sliding_window_step': config.sliding_window_step,
            'num_students': config.max_students
        },
        'dataset_stats': {
            'num_sequences': len(sequences),
            # len() instead of truthiness so NumPy arrays are accepted too.
            'sequence_length': len(sequences[0]) if len(sequences) else 0,
            'num_skills': len(skill_to_id)
        }
    }

    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"Data saved successfully to {save_dir}")
    print("Files saved:")
    print(f"- {sequences_name}: {os.path.getsize(sequences_path)/1024/1024:.2f} MB")
    print(f"- {labels_name}: {os.path.getsize(labels_path)/1024/1024:.2f} MB")
    print(f"- {metadata_name}: {os.path.getsize(metadata_path)/1024:.2f} KB")

# Save the preprocessed data. The generator instance created above is named
# `generator`; `gen` is not defined anywhere in this notebook.
save_preprocessed_data(seq, lab, generator.skill_to_id, generator.config)

# **If the data has already been preprocessed, import the packages below and start here**

In [None]:
import pickle
import json
import os
from google.colab import drive
import numpy as np

def load_preprocessed_data(load_dir='/content/drive/MyDrive/DeepKT/preprocessed_data', file_suffix='10_2'):
    """Load preprocessed sequences, labels and metadata from ``load_dir``.

    Args:
        load_dir: Directory containing ``sequences_<suffix>.npy``,
            ``labels_<suffix>.npy`` and ``metadata_<suffix>.json``.
        file_suffix: Suffix shared by the three file names. The default keeps
            the file names this notebook has always used.

    Returns:
        ``(sequences, labels, metadata)``: two NumPy arrays and the decoded
        metadata dictionary.
    """
    drive_root = '/content/drive'
    # Only mount Google Drive when the data actually lives on it, so the
    # function also works with a local directory outside Colab.
    if str(load_dir).startswith(drive_root) and not os.path.exists(drive_root):
        from google.colab import drive
        drive.mount(drive_root)

    # Load sequences and labels
    sequences = np.load(os.path.join(load_dir, f'sequences_{file_suffix}.npy'))
    labels = np.load(os.path.join(load_dir, f'labels_{file_suffix}.npy'))

    # Load metadata
    with open(os.path.join(load_dir, f'metadata_{file_suffix}.json'), 'r') as f:
        metadata = json.load(f)

    stats = metadata['dataset_stats']
    print("Data loaded successfully")
    print(f"Loaded {stats['num_sequences']} sequences")
    print(f"Sequence length: {stats['sequence_length']}")
    print(f"Number of skills: {stats['num_skills']}")

    return sequences, labels, metadata

# Restore the cached arrays and their metadata from disk.
sequences, labels, metadata = load_preprocessed_data()

# Show the first 50 windows followed by the first 50 labels.
print(sequences[:50], labels[:50], sep="\n")

# Data Transformation

In [None]:
import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split

def prepare_data(sequences, labels, batch_size = 64, train_ratio = 0.7, val_ratio = 0.15):
    """Split the arrays into train/validation/test ``tf.data`` datasets.

    The test share is whatever remains after ``train_ratio`` and ``val_ratio``.
    Only the training dataset is shuffled and prefetched.

    NOTE(review): the split is done on individual windows, so windows cut
    from the same student (which may overlap) can land in different splits.
    Confirm whether a per-student split is required for a fair evaluation.

    Args:
        sequences: NumPy array of input windows.
        labels: NumPy array of labels, aligned with ``sequences``.
        batch_size: Number of windows per batch.
        train_ratio: Fraction of windows used for training.
        val_ratio: Fraction of windows used for validation.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``.

    Raises:
        ValueError: if the ratios do not leave a non-empty test share.
    """
    if not (0 < train_ratio < 1 and 0 < val_ratio < 1 and train_ratio + val_ratio < 1):
        raise ValueError(
            "train_ratio and val_ratio must be in (0, 1) and sum to less than 1, "
            f"got train_ratio={train_ratio}, val_ratio={val_ratio}"
        )

    sequences = sequences.astype(np.int32)
    labels = labels.astype(np.float32)

    train_sequences, temp_sequences, train_labels, temp_labels = train_test_split(
        sequences, labels, train_size=train_ratio, random_state=42
    )

    # Express the validation share relative to what is left after the train split.
    val_ratio_adjusted = val_ratio / (1 - train_ratio)

    val_sequences, test_sequences, val_labels, test_labels = train_test_split(
        temp_sequences, temp_labels, train_size=val_ratio_adjusted, random_state=42
    )

    def create_dataset(sequences, labels, batch_size, training=False):
        """Wrap arrays in a batched dataset; shuffle and prefetch when training."""
        dataset = tf.data.Dataset.from_tensor_slices((sequences, labels))

        if training:
            dataset = dataset.shuffle(len(sequences))  # Shuffle over the whole split

        dataset = dataset.batch(batch_size)

        if training:
            dataset = dataset.prefetch(tf.data.AUTOTUNE)  # Overlap input prep with training

        return dataset

    # training=True was never passed before, which left the shuffle/prefetch
    # branch unreachable and the training data in a fixed order every epoch.
    train_dataset = create_dataset(train_sequences, train_labels, batch_size, training=True)
    val_dataset = create_dataset(val_sequences, val_labels, batch_size)
    test_dataset = create_dataset(test_sequences, test_labels, batch_size)

    return train_dataset, val_dataset, test_dataset

# Build the three batched datasets with the default batch size and split ratios.
train_dataset, val_dataset, test_dataset = prepare_data(sequences=sequences, labels=labels)

def inspect_dataset(dataset, name="Dataset"):
    """Print shapes, dtypes and the first sample of the dataset's first batch.

    Nothing is printed when the dataset yields no batch.
    """
    first_batch = next(iter(dataset.take(1)), None)
    if first_batch is None:
        return

    batch_inputs, batch_targets = first_batch

    print(f"\n{name} inspection:")
    print(f"Sequences shape: {batch_inputs.shape}")
    print(f"Labels shape: {batch_targets.shape}")
    print(f"Sequences dtype: {batch_inputs.dtype}")
    print(f"Labels dtype: {batch_targets.dtype}")

    # Show the first element of the batch as a concrete example.
    print("\nSample sequence (first in batch):")
    print("Encoded interactions:", batch_inputs[0])
    print("Correctness labels:", batch_targets[0])

# Print a summary of the first batch of each split.
inspect_dataset(train_dataset, name="Training")
inspect_dataset(val_dataset, name="Validation")
inspect_dataset(test_dataset, name="Test")

# Transformer Implementation

In [None]:
import tensorflow as tf
import time
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, MultiHeadAttention, LayerNormalization, Embedding, Concatenate, Input

# ================================================
# Transformer Block with Causal Masking
# ================================================
class TransformerBlock(tf.keras.layers.Layer):
    """Self-attention block: attention and feed-forward, each with residual + LayerNorm.

    Args:
        embed_dim: Width of the block's input and output features.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward sub-network.
        rate: Dropout rate applied after attention and after the feed-forward net.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): key_dim is the per-head size in Keras, so every head
        # here projects to embed_dim (not embed_dim // num_heads). Confirm
        # this is intended.
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim)
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None, mask=None):
        """Apply the block.

        Args:
            inputs: Feature tensor whose last dimension is ``embed_dim``.
            training: Enables dropout when true. Defaults to ``None`` so the
                layer can also be called without the argument.
            mask: Optional attention mask passed to the attention layer as
                ``attention_mask``.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # Self-attention: query, value and key are all the block input.
        attn_output = self.att(inputs, inputs, inputs, attention_mask=mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

# ================================================
# Transformer Model for Separate Skill and Correctness Encoding
# ================================================
class TransformerModel(tf.keras.Model):
    """Causal Transformer that outputs one probability per time step.

    Skill IDs and correctness values are embedded separately (each to
    ``embed_dim // 2``), concatenated, summed with a learned positional
    embedding, passed through the Transformer blocks and an MLP, and finally
    squashed by a sigmoid.

    NOTE(review): ``call`` reads ``inputs[:, :, 0]`` as skill IDs and
    ``inputs[:, :, 1]`` as correctness, while ``SequenceGenerator`` in this
    notebook produces one-hot vectors. Confirm the saved arrays match the
    layout expected here.
    """

    def __init__(self, num_skills, seq_len, embed_dim, num_heads, ff_dim, num_blocks, mlp_units, dropout_rate=0.1):
        """
        Args:
          num_skills: Number of unique skills in the dataset.
          seq_len: Length of the sequences.
          embed_dim: Embedding dimension (split evenly between skill and
                     correctness embeddings, so it should be even).
          num_heads: Number of attention heads.
          ff_dim: Feed-forward network hidden layer size.
          num_blocks: Number of transformer blocks.
          mlp_units: List with the number of units for each MLP Dense layer.
          dropout_rate: Dropout rate.
        """
        super().__init__()
        self.seq_len = seq_len

        # Separate embeddings for skills and correctness
        self.skill_embedding = Embedding(num_skills, embed_dim // 2)
        self.correctness_embedding = Embedding(2, embed_dim // 2)  # 2 possible values: 0 or 1

        # Learned positional encoding, one vector per position
        self.pos_encoding = Embedding(seq_len, embed_dim)

        self.transformer_blocks = [
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout_rate)
            for _ in range(num_blocks)
        ]
        self.mlp_layers = [Dense(mlp_dim, activation="relu") for mlp_dim in mlp_units]
        self.dropout = Dropout(dropout_rate)
        self.final_layer = Dense(1)  # No activation here; sigmoid is applied in call().

    def causal_attention_mask(self, batch_size, seq_len):
        """
        Create a boolean causal mask of shape (batch_size, seq_len, seq_len)
        where True indicates allowed (i.e. non-masked) positions.
        """
        # Lower-triangular ones: position i may attend to positions <= i.
        lower_triangle = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        mask = tf.cast(lower_triangle, dtype=tf.bool)
        # Repeat the same mask for every sample in the batch.
        mask = tf.tile(tf.expand_dims(mask, 0), [batch_size, 1, 1])
        return mask

    def call(self, inputs, training=None):
        """
        Args:
          inputs: Tensor of shape (batch_size, seq_len, 2) where [:,:,0] contains skill IDs
                 and [:,:,1] contains correctness values.
          training: Boolean, for dropout behavior. Defaults to None so that the
                 model can also be called without the argument.
        Returns:
          predictions: Tensor of shape (batch_size, seq_len) with values in [0,1].
        """
        batch_size = tf.shape(inputs)[0]

        # Split the input into skill IDs and correctness values
        skill_ids = tf.cast(inputs[:, :, 0], tf.int32)     # (batch_size, seq_len)
        correctness = tf.cast(inputs[:, :, 1], tf.int32)   # (batch_size, seq_len)

        # Create positional indices
        positions = tf.range(start=0, limit=self.seq_len, delta=1)
        positions = tf.expand_dims(positions, 0)           # (1, seq_len)
        positions = tf.tile(positions, [batch_size, 1])    # (batch_size, seq_len)

        # Get embeddings for skills and correctness
        skill_emb = self.skill_embedding(skill_ids)           # (batch_size, seq_len, embed_dim//2)
        correctness_emb = self.correctness_embedding(correctness)  # (batch_size, seq_len, embed_dim//2)

        # Concatenate the skill and correctness embeddings
        x = tf.concat([skill_emb, correctness_emb], axis=-1)  # (batch_size, seq_len, embed_dim)

        # Add positional encoding
        x = x + self.pos_encoding(positions)

        # Create the causal mask
        mask = self.causal_attention_mask(batch_size, self.seq_len)  # (batch_size, seq_len, seq_len)

        # Pass through transformer blocks
        for block in self.transformer_blocks:
            x = block(x, training=training, mask=mask)

        # Pass through the MLP layers, with dropout after each one
        for layer in self.mlp_layers:
            x = layer(x)
            x = self.dropout(x, training=training)

        x = self.final_layer(x)  # (batch_size, seq_len, 1)
        x = tf.sigmoid(x)        # Map outputs to [0, 1]
        return tf.squeeze(x, -1) # (batch_size, seq_len)

# Model Training

In [None]:
# ================================================
# Training, Evaluation, and Testing
# ================================================
@tf.function
def train_step(model, optimizer, batch_sequences, batch_labels):
    """
    Run a single optimisation step and return the mean loss.

    The model output at time t is scored against the label at time t+1, so
    the last prediction and the first label of every sequence are dropped.

    Args:
        model: The transformer model.
        optimizer: Optimizer instance.
        batch_sequences: Model input batch; documented by the model as
            (batch_size, seq_len, 2) with skill IDs and correctness.
        batch_labels: Label batch, sliced as [:, 1:] along the second axis.
    """
    # Labels shifted one step into the future.
    next_step_targets = tf.cast(batch_labels[:, 1:], tf.float32)

    with tf.GradientTape() as tape:
        probabilities = model(batch_sequences, training=True)
        per_sample_loss = tf.keras.losses.binary_crossentropy(
            next_step_targets, probabilities[:, :-1]
        )
        loss = tf.reduce_mean(per_sample_loss)

    variables = model.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))
    return loss

def evaluate_model(model, dataset):
    """
    Score the model on every batch of ``dataset``.

    Predictions at time t are paired with labels at time t+1, then everything
    is flattened into one vector before computing the metrics.

    Returns:
      auc: ROC-AUC score (0.0 if it cannot be computed).
      accuracy: Binary accuracy with a 0.5 decision threshold.
    """
    shifted_preds, shifted_targets = [], []
    for batch_sequences, batch_labels in dataset:
        probabilities = model(batch_sequences, training=False)
        shifted_preds.append(probabilities[:, :-1])
        shifted_targets.append(batch_labels[:, 1:])

    # Merge all batches, then flatten to 1-D NumPy vectors.
    scores = tf.concat(shifted_preds, axis=0).numpy().flatten()
    truth = tf.concat(shifted_targets, axis=0).numpy().flatten()

    try:
        auc = roc_auc_score(truth, scores)
    except Exception as e:
        print("Error computing AUC:", e)
        auc = 0.0

    accuracy = accuracy_score(truth, (scores > 0.5).astype(int))
    return auc, accuracy

def train_model(model, train_dataset, val_dataset, test_dataset, epochs=50, patience=10, learning_rate=0.001):
    """
    Train the Transformer model with early stopping on validation AUC.

    After every epoch the model is evaluated on ``val_dataset``; the weights
    with the best validation AUC are kept and restored at the end. The test
    set is evaluated only once, after training, so it does not influence
    model selection.

    Args:
        model: Model to train (modified in place).
        train_dataset: Batched training dataset.
        val_dataset: Batched validation dataset used for early stopping.
        test_dataset: Batched test dataset used for the final report.
        epochs: Maximum number of epochs.
        patience: Number of epochs without validation improvement before stopping.
        learning_rate: Adam learning rate.

    Returns:
        The same ``model``, holding the best weights found (if any epoch improved).
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    best_val_auc = 0.0
    patience_counter = 0
    best_weights = None

    for epoch in range(epochs):
        start_time = time.time()
        train_loss_metric = tf.keras.metrics.Mean()
        print(f"\nEpoch {epoch+1}/{epochs}")

        # Training loop
        for batch_idx, (batch_sequences, batch_labels) in enumerate(train_dataset):
            loss = train_step(model, optimizer, batch_sequences, batch_labels)
            train_loss_metric.update_state(loss)
            if (batch_idx + 1) % 50 == 0:
                print(f"  Batch {batch_idx+1} - Loss: {float(loss):.4f}")

        epoch_loss = train_loss_metric.result().numpy()
        print(f"Epoch {epoch+1} - Average Training Loss: {epoch_loss:.4f}")

        # Model selection uses the validation split; selecting on the test
        # split would make the final test metrics optimistically biased.
        val_auc, val_accuracy = evaluate_model(model, val_dataset)
        print(f"  Validation Metrics - AUC: {val_auc:.4f} | Accuracy: {val_accuracy:.4f}")

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_weights = model.get_weights()
            patience_counter = 0
            print("  New best model found!")
        else:
            patience_counter += 1

        # Reported before the early-stop check so the last epoch is timed too.
        epoch_time = time.time() - start_time
        print(f"Epoch {epoch+1} took {epoch_time:.2f}s")

        if patience_counter >= patience:
            print("  Early stopping triggered!")
            break

    if best_weights is not None:
        model.set_weights(best_weights)
        print(f"\nTraining completed. Best Validation AUC: {best_val_auc:.4f}")

    final_test_auc, final_test_accuracy = evaluate_model(model, test_dataset)
    print(f"\nFinal Test Metrics - AUC: {final_test_auc:.4f} | Accuracy: {final_test_accuracy:.4f}")
    return model

# ================================================
# Main Training Call
# ================================================

# Size the model from the loaded metadata rather than from `num_skills`,
# which only exists if the preprocessing cell was run in this session.
# NOTE(review): assumes every skill ID in the data is below
# len(metadata['skill_to_id']); verify against the saved arrays.
model = TransformerModel(
    num_skills=len(metadata['skill_to_id']),    # Number of unique skills in the dataset
    seq_len=metadata['config']['seq_length'],   # Sequence length used during preprocessing
    embed_dim=64,
    num_heads=4,
    ff_dim=64,
    num_blocks=4,
    mlp_units=[128, 64],
    dropout_rate=0.1
)

# Train the model with the prepared datasets
trained_model = train_model(
    model=model,
    train_dataset=train_dataset,  # Prepared training dataset
    val_dataset=val_dataset,      # Prepared validation dataset
    test_dataset=test_dataset,    # Prepared test dataset
    epochs=50,
    patience=10,
    learning_rate=0.001
)

# Save Model

In [None]:
save_dir = "/content/drive/MyDrive/DeepKT/saved_models"
# exist_ok=True creates the directory if needed without a separate
# check-then-create step.
os.makedirs(save_dir, exist_ok=True)

# Only the weights are stored; the architecture must be rebuilt before loading.
model_save_path = os.path.join(save_dir, "seq60_trained_model.weights.h5")
trained_model.save_weights(model_save_path)
print(f"Model weights saved to {model_save_path}")

# Load Model

In [None]:
# ====================================================
# Load the model for evaluation & visualization
# ====================================================
# Recreate the model architecture with the same configuration.

model_save_path = "/content/drive/MyDrive/DeepKT/saved_models/seq60_trained_model.weights.h5"

# TransformerModel takes `num_skills`; it has no `num_items` parameter, and
# its skill embedding is indexed by skill ID (not by skill/correctness pair).
# NOTE(review): num_skills must equal the value used when the saved weights
# were trained, otherwise load_weights will fail on the embedding shape.
loaded_model = TransformerModel(
    num_skills=len(metadata['skill_to_id']),
    seq_len=metadata['config']['seq_length'],
    embed_dim=64,
    num_heads=4,
    ff_dim=64,
    num_blocks=4,
    mlp_units=[128, 64],
    dropout_rate=0.1
)

# Build the model by passing a dummy input. call() reads inputs[:, :, 0] and
# inputs[:, :, 1], so the dummy must be rank 3 with a last dimension of 2.
dummy_input = tf.zeros((1, metadata['config']['seq_length'], 2))
_ = loaded_model(dummy_input, training=False)

# Load the saved weights.
loaded_model.load_weights(model_save_path)
print(f"Model loaded from {model_save_path}")


# Visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_curve, auc

def plot_roc_curve(model, dataset):
    """
    Plot the ROC curve of ``model`` over every batch in ``dataset``.

    Predictions at time t are paired with labels at time t+1 (the last
    prediction and the first label of each sequence are dropped), flattened,
    and passed to ``roc_curve``; the resulting AUC is shown in the legend.
    """
    score_batches, truth_batches = [], []
    for batch_sequences, batch_labels in dataset:
        probabilities = model(batch_sequences, training=False)
        # Drop the final prediction / first label to align t with t+1.
        score_batches.append(probabilities[:, :-1])
        truth_batches.append(batch_labels[:, 1:])

    # One flat vector of scores and one of ground-truth labels.
    y_scores = tf.concat(score_batches, axis=0).numpy().flatten()
    y_true = tf.concat(truth_batches, axis=0).numpy().flatten()

    fpr, tpr, _ = roc_curve(y_true, y_scores)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, lw=2, label=f"ROC curve (AUC = {roc_auc:.4f})")
    # Diagonal reference line for a random classifier.
    plt.plot([0, 1], [0, 1], color="gray", linestyle="--", lw=2)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver Operating Characteristic (ROC)")
    plt.legend(loc="lower right")
    plt.grid(alpha=0.3)
    plt.show()


import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf  # assuming tensorflow is used as in your original code

def visualize_sample_heatmap_gradient_balanced(model, dataset, bin_size=5):
    """
    Plot a two-row heatmap of predictions vs. ground truth for one sample.

    The sample is taken from the first batch of ``dataset``: the one whose
    mean label value is closest to 0.5 (the first such sample on ties).
    Predictions and labels are averaged over consecutive bins of ``bin_size``
    time steps:
      - Row 0: average predicted probability per bin.
      - Row 1: average ground truth per bin.

    NOTE(review): training and evaluation in this notebook compare the
    prediction at time t with the label at time t+1, while this plot compares
    them at the same index. Confirm whether a one-step shift is wanted here.

    Args:
      model: The trained model that outputs predictions with shape (batch_size, seq_len).
      dataset: A tf.data.Dataset yielding (batch_sequences, batch_labels).
      bin_size: Number of time steps to aggregate into each bin.

    Raises:
      ValueError: if ``bin_size`` is smaller than 1 or the dataset yields no batch.
    """
    if bin_size < 1:
        raise ValueError(f"bin_size must be >= 1, got {bin_size}")

    # Retrieve one batch; fail clearly instead of hitting an unbound variable.
    first_batch = next(iter(dataset.take(1)), None)
    if first_batch is None:
        raise ValueError("dataset is empty; there is no sample to visualize")
    batch_sequences, batch_labels = first_batch

    # Convert tensors to NumPy arrays.
    batch_sequences_np = batch_sequences.numpy()
    batch_labels_np = batch_labels.numpy()

    # Mean label per sample; argmin returns the first index on ties.
    per_sample_mean = batch_labels_np.reshape(batch_labels_np.shape[0], -1).mean(axis=1)
    sample_index = int(np.argmin(np.abs(per_sample_mean - 0.5)))

    # Keep the batch dimension for the model input.
    sample_sequence = batch_sequences_np[sample_index:sample_index + 1]
    sample_labels = batch_labels_np[sample_index]

    # Model output for the single selected sample.
    pred_probs = model(sample_sequence, training=False).numpy()[0]

    seq_len = len(pred_probs)
    bin_starts = range(0, seq_len, bin_size)
    num_bins = len(bin_starts)
    # Slices past the end are clipped, so the last bin may be shorter.
    avg_pred = [np.mean(pred_probs[start:start + bin_size]) for start in bin_starts]
    avg_true = [np.mean(sample_labels[start:start + bin_size]) for start in bin_starts]

    # Row 0: average predicted probabilities, Row 1: average ground truth.
    heatmap_data = np.vstack([avg_pred, avg_true])

    # Plot the heatmap with a lower vertical height.
    fig, ax = plt.subplots(figsize=(10, 2))
    cax = ax.imshow(heatmap_data, aspect='auto', cmap='Blues', origin='upper', interpolation='nearest')
    fig.colorbar(cax, ax=ax, label='Value')

    # Set y-axis labels.
    ax.set_yticks([0, 1])
    ax.set_yticklabels(['Avg Predicted', 'Avg Ground Truth'])
    # One tick per bin, labelled with its 1-based time-step range.
    ax.set_xticks(np.arange(num_bins))
    x_labels = [f"{start + 1}-{min(start + bin_size, seq_len)}" for start in bin_starts]
    ax.set_xticklabels(x_labels, rotation=45)
    ax.set_xlabel("Time Bins")
    ax.set_title("Smoothed Heatmap (Balanced Test Vector)")
    plt.tight_layout()
    plt.show()

# Example: heatmap for the most balanced sample in the first test batch.
visualize_sample_heatmap_gradient_balanced(model=loaded_model, dataset=test_dataset, bin_size=5)

# plot_roc_curve(loaded_model, test_dataset)

