# GIK Character Prediction Model

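Train a model to predict typed keyboard characters from IMU sensor data (left and/or right IMU streams). Settings are read from `train_config.yaml`. The notebook then:

- preprocesses the paired keyboard + IMU recordings;
- optionally reduces the feature dimensionality and applies class weights;
- trains and evaluates the model;
- visualises its predictions as confusion-matrix keyboard heatmaps.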


In [None]:
import os
# Work around the Intel OpenMP "libiomp already initialized" abort that occurs when more
# than one library ships its own OpenMP runtime. Set before importing torch/numpy so the
# runtime sees it when it initialises — keep this line at the top of the notebook.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import sys
import yaml
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

# Add project root to path
# NOTE: '__file__' below is a string literal, not the dunder (notebooks have no __file__),
# so abspath() yields "<cwd>/__file__" and dirname() strips it back to the current working
# directory. The notebook must therefore be launched from the project root.
PROJECT_ROOT = os.path.dirname(os.path.abspath('__file__'))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Project-local modules: these imports only resolve once PROJECT_ROOT is on sys.path.
from pretraining import preprocess_multiple_sources, load_preprocessed_dataset, export_dataset_to_csv, get_class_weights
from src.Constants.char_to_key import INDEX_TO_CHAR, CHAR_TO_INDEX, NUM_CLASSES
from src.pre_processing.reduce_dim import reduce_dim
from src.visualisation.visualisation import (
    compute_confusion_matrix_40x40,
    plot_confusion_matrix_40x40,
    plot_anchor_with_closest_neighbours,
    plot_virtual_keyboard_heatmap,
    show_predictions as viz_show_predictions,
    show_predictions_coordinate as viz_show_predictions_coordinate,
)
from ml.models.gik_model import create_model_auto_input_dim, GIKTrainer, decode_predictions
from ml.models.loss_functions.custom_losses import FocalLoss, CoordinateLoss, CoordinateLossClassification

# Seed torch and numpy so weight initialisation and shuffling repeat across runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Pick the best available backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'
print(f"PyTorch {torch.__version__} | Device: {DEVICE}")

## Configuration

In [None]:
# Keyboard-layout constants. FULL_COORDS is used later both as a key-mapping option and as
# the coordinate lookup for regression-mode evaluation; ALL_CHARS enumerates the classes.
from src.Constants.char_to_key import (
    ALL_CHARS,
    FULL_COORDS,
    KEY_COORDS,
    SPACE_ANCHORS,
    SPECIAL_COORDS,
)

num_layout_classes = len(ALL_CHARS)
print(f"Loaded keyboard coordinate constants for {num_layout_classes} classes")

In [None]:
# Load experiment config from project-root YAML.
# safe_load only builds plain YAML types; classes/objects are resolved via the registries below.
CONFIG_PATH = os.path.join(PROJECT_ROOT, "train_config.yaml")
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config_data = yaml.safe_load(f)

# Data paths (paired keyboard + IMU files). Left/right IMU lists are optional (None if absent).
DATA_CONFIG = config_data["data"]
DATA_DIR = DATA_CONFIG["data_dir"]
KEYBOARD_FILES = DATA_CONFIG["keyboard_files"]
LEFT_FILES = DATA_CONFIG.get("left_files")
RIGHT_FILES = DATA_CONFIG.get("right_files")
PROCESSED_DATA_PATH = os.path.join(DATA_DIR, "processed_dataset.pt")

# Build CONFIG from shared + mode-specific sections. Later sections override earlier keys:
# explicit experiment keys < model < train < modes[MODE].
EXPERIMENT = config_data["experiment"]
MODE = EXPERIMENT["mode"]
MODE_CONFIG = config_data["modes"][MODE]

CONFIG = {
    "max_seq_length": EXPERIMENT["max_seq_length"],
    "reduce_dim": EXPERIMENT["use_dim_reduction"],
    "enable_class_weights": EXPERIMENT["use_class_weights"],
    "run_preprocess": EXPERIMENT["run_preprocess"],
    "export_dataset_csv": EXPERIMENT["export_dataset_csv"],
    **config_data["model"],
    **config_data["train"],
    **MODE_CONFIG,
}

# Resolve object references encoded as strings in YAML
KEY_MAPPING_REGISTRY = {
    "FULL_COORDS": FULL_COORDS,
    "CHAR_TO_INDEX": CHAR_TO_INDEX,
}
LOSS_REGISTRY = {
    "CoordinateLossClassification": CoordinateLossClassification,
    "CoordinateLoss": CoordinateLoss,
    "FocalLoss": FocalLoss,
}
OUTPUT_LOGITS_REGISTRY = {
    "NUM_CLASSES": NUM_CLASSES,
}


def _resolve_config_ref(key, registry):
    """Return the object that the name stored in CONFIG[key] refers to in *registry*.

    Args:
        key: CONFIG entry holding the YAML name (e.g. "loss").
        registry: mapping from YAML name to the Python object it stands for.

    Raises:
        KeyError: if the name is not registered. The message names the setting, the
            config file and the valid choices, so a typo in the YAML is easy to spot.
    """
    name = CONFIG[key]
    try:
        return registry[name]
    except KeyError:
        raise KeyError(
            f"{key}={name!r} in {CONFIG_PATH} is not recognised; "
            f"expected one of {sorted(registry)}"
        ) from None


CONFIG["key_mapping_dict"] = _resolve_config_ref("key_mapping_dict", KEY_MAPPING_REGISTRY)
CONFIG["loss"] = _resolve_config_ref("loss", LOSS_REGISTRY)
# output_logits may also be given directly as a literal value in the YAML; only names are resolved.
if isinstance(CONFIG["output_logits"], str):
    CONFIG["output_logits"] = _resolve_config_ref("output_logits", OUTPUT_LOGITS_REGISTRY)

print(f"Loaded config: {CONFIG_PATH}")
print(f"Mode: {MODE}")
print(f"Data dir: {DATA_DIR}")
print(f"Keyboard files: {KEYBOARD_FILES}")
print(f"Left IMU files: {LEFT_FILES}")
print(f"Right IMU files: {RIGHT_FILES}")
print(f"Model: {CONFIG['model_type']}")
print(f"Seq length: {CONFIG['max_seq_length']}")
print(f"Loss: {CONFIG['loss'].__name__}")
print(f"Run preprocess: {CONFIG['run_preprocess']}")
print(f"Export CSV: {CONFIG['export_dataset_csv']}")

## Preprocess Data

In [None]:
if CONFIG["run_preprocess"]:
    # Preprocess and combine multiple data sources; the result is written to
    # PROCESSED_DATA_PATH (output_path) and the metadata is returned.
    metadata = preprocess_multiple_sources(
        data_dir=DATA_DIR,
        keyboard_files=KEYBOARD_FILES,
        left_files=LEFT_FILES,
        right_files=RIGHT_FILES,
        output_path=PROCESSED_DATA_PATH,
        max_seq_length=CONFIG['max_seq_length'],
        normalize=True,
        apply_filtering=True
    )
else:
    # Reuse a dataset produced by an earlier run with run_preprocess enabled.
    if not os.path.exists(PROCESSED_DATA_PATH):
        raise FileNotFoundError(
            f"Preprocessed dataset not found at {PROCESSED_DATA_PATH}; set "
            f"experiment.run_preprocess: true in {CONFIG_PATH} to build it."
        )
    # weights_only=False because the saved dict holds non-tensor entries such as "metadata".
    # This unpickles arbitrary objects — only load .pt files this project produced.
    preprocessed = torch.load(PROCESSED_DATA_PATH, weights_only=False)
    metadata = preprocessed["metadata"]
    # Only the metadata is needed here; drop the full payload so it can be freed
    # (the dataset is reloaded later by load_preprocessed_dataset).
    del preprocessed
    print(f"Using existing preprocessed dataset: {PROCESSED_DATA_PATH}")

In [None]:
# One-line summary of the preprocessed dataset.
summary_line = (
    f"\nTotal Samples: {metadata['num_samples']} | Feat dim: {metadata['feat_dim']} "
    f"| Sources: {metadata['num_sources']}"
)
print(summary_line)

In [None]:
# Optionally dump the processed dataset to CSV for manual inspection. Export is tied to a
# fresh preprocessing run; report when the flag is set but the export is skipped, rather
# than ignoring the setting silently.
if CONFIG["export_dataset_csv"]:
    if CONFIG["run_preprocess"]:
        export_dataset_to_csv(PROCESSED_DATA_PATH, DATA_DIR, include_features=True)
    else:
        print("Skipping CSV export: export_dataset_csv is set but run_preprocess is False")

## Dimensionality Reduction

In [None]:
if CONFIG["reduce_dim"]:
    DIM_RED_OUTPUT = os.path.join(DATA_DIR, "dim_red_output.pt")

    # Which IMU sides the processed dataset contains. Hard-coded rather than read from the
    # preprocessing metadata so the cached dataset can be reused without re-preprocessing.
    # NOTE(review): keep in sync with LEFT_FILES / RIGHT_FILES in train_config.yaml.
    HAS_LEFT = False
    HAS_RIGHT = True

    # PCA with dims_ratio=0.4 (presumably keeps 40% of the feature dimension — confirm in
    # reduce_dim). Alternative tried earlier: method="active-imu", called without
    # dims_ratio and root_dir.
    pca_kwargs = dict(
        data_source=PROCESSED_DATA_PATH,
        method="pca",
        dims_ratio=0.4,
        has_left=HAS_LEFT,
        has_right=HAS_RIGHT,
        normalize=True,
        output_path=DIM_RED_OUTPUT,
        root_dir=PROJECT_ROOT,
    )
    dims = reduce_dim(**pca_kwargs)

    dim_before, dim_after = dims['dim_bef'], dims['dim_aft']
    print(f"Feature dimension reduced from {dim_before} to {dim_after}")

## Balance Dataset (Class Weights)

In [None]:
if CONFIG["enable_class_weights"]:
    # Compute weights from the same file the dataset will be loaded from (reduced or full).
    class_weights = get_class_weights(DIM_RED_OUTPUT if CONFIG["reduce_dim"] else PROCESSED_DATA_PATH)
    # NOTE(review): the weights are not moved to DEVICE here. Tensor.to() is not in-place,
    # so if that is needed (assuming a tensor is returned) use
    # `class_weights = class_weights.to(DEVICE)`; presumably the loss/trainer handles placement.

    # Keyword under which each supported loss receives the weights via loss_params
    # (passed to GIKTrainer as loss_kwargs).
    weight_param = {
        FocalLoss: "alpha",
        CoordinateLossClassification: "class_weights",
    }.get(CONFIG["loss"])

    if weight_param is None:
        print(f"Class weights not applied: no weight argument is wired up for {CONFIG['loss'].__name__}")
    else:
        # loss_params may be null or absent in the YAML; create the dict before writing into it.
        if CONFIG.get("loss_params") is None:
            CONFIG["loss_params"] = {}
        CONFIG["loss_params"][weight_param] = class_weights

## Load Dataset & Create Model

In [None]:
# Load the dataset the model trains on: the dimension-reduced file when enabled,
# otherwise the full preprocessed one.
dataset_path = DIM_RED_OUTPUT if CONFIG["reduce_dim"] else PROCESSED_DATA_PATH
dataset = load_preprocessed_dataset(
    dataset_path,
    is_one_hot_labels=CONFIG["is_one_hot"],
    char_to_index=CONFIG["key_mapping_dict"],
    return_class_id=CONFIG["return_class_id"],
)
print(f"Dataset: {len(dataset)} samples | Input dim: {dataset.input_dim}")

# Create model; create_model_auto_input_dim presumably sizes the input layer from the dataset.
model = create_model_auto_input_dim(
    dataset,
    model_type=CONFIG['model_type'],
    hidden_dim_inner_model=CONFIG['hidden_dim_inner_model'],
    hidden_dim_classification_head=CONFIG['hidden_dim_classification_head'],
    no_layers_classification_head=CONFIG['num_layers'],
    dropout_inner_layers=CONFIG['dropout'],
    inner_model_kwargs=CONFIG['inner_model_prams'],  # key is spelled "prams" (sic) in the YAML
    output_logits=CONFIG['output_logits'],
)

# Print model architecture
print("\n" + "=" * 60)
print("Model architecture")
print("=" * 60)
print(model)
print("=" * 60)
# Report total and trainable counts (previously the total was printed twice).
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Parameters: {num_params:,}")
print(f"Trainable parameters: {num_trainable:,}")

## Train Model

In [None]:
# Trainer settings. Optional keys are read with .get, so they pass through as None when absent.
trainer_kwargs = {
    'batch_size': CONFIG['batch_size'],
    'learning_rate': CONFIG['learning_rate'],
    'device': DEVICE,
    'loss': CONFIG.get('loss'),
    'loss_kwargs': CONFIG.get('loss_params'),
    'regression': CONFIG.get('regression'),
}
trainer = GIKTrainer(model=model, dataset=dataset, **trainer_kwargs)

# Run training; CONFIG['early_stopping'] is the patience handed to the trainer.
history = trainer.train(
    epochs=CONFIG['epochs'],
    early_stopping_patience=CONFIG['early_stopping'],
)

In [None]:
# Plot training history: loss on the left, accuracy on the right, train vs. validation.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

history_panels = (
    (axes[0], 'loss', 'Loss'),
    (axes[1], 'acc', 'Accuracy'),
)
for panel_ax, metric_key, metric_name in history_panels:
    panel_ax.plot(history[f'train_{metric_key}'], label='Train')
    panel_ax.plot(history[f'val_{metric_key}'], label='Val')
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(metric_name)
    panel_ax.legend()
    panel_ax.set_title(metric_name)

plt.tight_layout()
plt.show()

## Evaluate Model

In [None]:
# Evaluate the trained model on the validation and test splits.
def _print_split_metrics(split_name, loss_value, acc_value):
    """Print the loss, and the accuracy when one is reported, for a data split.

    Args:
        split_name: label used in the printed lines ("Validation" / "Test").
        loss_value: scalar loss returned by the trainer.
        acc_value: accuracy returned by the trainer (presumably a fraction, so it is
            shown as a percentage), or None if the trainer reports none.
    """
    print(f"{split_name} Loss: {loss_value:.4f}")
    if acc_value is not None:
        print(f"{split_name} Accuracy: {acc_value:.2%}")


val_loss, val_acc = trainer.validate()
_print_split_metrics("Validation", val_loss, val_acc)
print()

test_loss, test_acc = trainer.evaluate_test()
_print_split_metrics("Test", test_loss, test_acc)

In [None]:
# Show validation predictions: coordinate view in regression mode, class view otherwise.
show_val_fn = viz_show_predictions_coordinate if CONFIG["regression"] else viz_show_predictions
show_val_fn(trainer.val_dataset, model, DEVICE, 'Validation')

In [None]:
# Show test predictions with the viewer matching the training mode.
if not CONFIG["regression"]:
    show_test_fn = viz_show_predictions
else:
    show_test_fn = viz_show_predictions_coordinate
show_test_fn(trainer.test_dataset, model, DEVICE, 'Test')

## Keyboard Heatmaps

### Test Set Visualisation

In [None]:
# Test-set 40x40 confusion matrix. In regression mode (coordinate outputs) FULL_COORDS is
# supplied so predictions can be mapped back to key classes; otherwise no mapping is passed.
if CONFIG["regression"]:
    coord_dict = FULL_COORDS
else:
    coord_dict = None
test_split = trainer.test_dataset
cm_orig = compute_confusion_matrix_40x40(test_split, model, DEVICE, coord_dict=coord_dict)

# Keyboard heatmap for anchor key 'd'.
plot_virtual_keyboard_heatmap(cm_orig, 'd', 'Test')

# Optional matrix view:
# plot_confusion_matrix_40x40(cm_orig, 'Test')

# Optional closest-neighbour views around an anchor key:
# neighbours_a, cm_a = plot_anchor_with_closest_neighbours(cm_orig, 'i', 'Test', k_neighbours=5)
# neighbours_g, cm_g = plot_anchor_with_closest_neighbours(cm_orig, 'l', 'Test', k_neighbours=5)
# neighbours_l, cm_l = plot_anchor_with_closest_neighbours(cm_orig, 'l', 'Test', k_neighbours=13)

In [None]:
# Test-set keyboard heatmap anchored on key 's'.
heatmap_anchor, heatmap_split = 's', 'Test'
plot_virtual_keyboard_heatmap(cm_orig, heatmap_anchor, heatmap_split)

In [None]:
# Validation-set 40x40 confusion matrix, using FULL_COORDS as the coordinate lookup in regression mode.
val_coord_dict = FULL_COORDS if CONFIG["regression"] else None
cm_orig_val = compute_confusion_matrix_40x40(
    trainer.val_dataset, model, DEVICE, coord_dict=val_coord_dict
)

# Keyboard heatmap for anchor key 'c'.
plot_virtual_keyboard_heatmap(cm_orig_val, 'c', 'Validation')

# Optional matrix view:
# plot_confusion_matrix_40x40(cm_orig_val, 'Validation')

# Optional closest-neighbour views around an anchor key:
# neighbours_a_val, cm_a_val = plot_anchor_with_closest_neighbours(cm_orig_val, 'd', 'Validation', k_neighbours=12)
# neighbours_g_val, cm_g_val = plot_anchor_with_closest_neighbours(cm_orig_val, 'g', 'Validation', k_neighbours=20)
# neighbours_l_val, cm_l_val = plot_anchor_with_closest_neighbours(cm_orig_val, 'l', 'Validation', k_neighbours=13)

## Save Model

In [None]:
# Checkpoint saving is disabled; uncomment the lines below to write gik_model.pt into DATA_DIR.
# NOTE: CONFIG holds live Python objects, not just plain values: the loss class resolved
# from LOSS_REGISTRY, and any class weights written into loss_params. Reloading the
# checkpoint therefore needs torch.load(..., weights_only=False) with the project modules
# importable.
# MODEL_PATH = os.path.join(DATA_DIR, "gik_model.pt")
# torch.save({
#     'model_state_dict': model.state_dict(),
#     'config': CONFIG,
#     'input_dim': dataset.input_dim,
#     'metadata': metadata
# }, MODEL_PATH)
# print(f"Model saved to {MODEL_PATH}")