In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
import logging

torch.manual_seed(42)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class CNNFeatureExtractor(nn.Module):
    """Two-stage convolutional encoder mapping one frame to a feature vector.

    Each stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2, 2).
    The first stage widens the 6 input channels to 32, the second widens 32 to 64.
    The pooled feature map is flattened and projected to ``output_dim`` values
    by a linear layer followed by a ReLU.

    The linear layer expects exactly ``64 * 2 * 1`` flattened values, so the
    input frame must shrink to a 2x1 spatial grid after both pooling steps.

    Args:
        output_dim: Number of features produced per frame.
    """

    def __init__(self, output_dim=128):
        super().__init__()

        # Stage 1: 6 -> 32 channels; pooling halves both spatial dimensions.
        # (Original author's note: 32x5x2 after pooling for the expected input.)
        self.conv1 = nn.Conv2d(6, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 32 -> 64 channels; pooling halves the spatial size again.
        # (Original author's note: 64x2x1 after pooling for the expected input.)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Head: flatten the 64x2x1 map and project it to the requested width.
        self.flatten = nn.Flatten()
        flattened_size = 64 * 2 * 1
        self.fc = nn.Linear(flattened_size, output_dim)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        """Encode a batch of frames.

        Args:
            x: Tensor of shape (N, 6, H, W).

        Returns:
            Tensor of shape (N, output_dim); values are non-negative because
            the last operation is a ReLU.
        """
        stage1 = self.conv1(x)
        stage1 = self.bn1(stage1)
        stage1 = self.relu1(stage1)
        stage1 = self.pool1(stage1)

        stage2 = self.conv2(stage1)
        stage2 = self.bn2(stage2)
        stage2 = self.relu2(stage2)
        stage2 = self.pool2(stage2)

        flat = self.flatten(stage2)
        features = self.fc(flat)
        return self.relu3(features)
    
class AirKeyboardModel(nn.Module):
    """CNN + LSTM model that emits a 6x15 grid of log-probabilities per time step.

    Every frame of the input sequence is encoded independently by
    ``CNNFeatureExtractor``. The per-frame features are fed through a
    batch-first LSTM, and a linear classifier maps each LSTM output to
    ``6 * 15 = 90`` scores. A log-softmax is taken over those 90 scores
    (i.e. over the whole grid, not per row) before they are reshaped to 6x15.

    Args:
        cnn_output_dim: Feature size produced by the CNN for each frame; also
            the LSTM input size.
        lstm_hidden_size: Hidden size of the LSTM.
        lstm_layers: Number of stacked LSTM layers.
    """

    def __init__(self, cnn_output_dim=128, lstm_hidden_size=256, lstm_layers=2):
        super(AirKeyboardModel, self).__init__()
        self.cnn = CNNFeatureExtractor(output_dim=cnn_output_dim)
        self.lstm = nn.LSTM(
            input_size=cnn_output_dim,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True
        )

        # Size of the output grid produced for every time step.
        self.output_matrix_rows = 6
        self.output_matrix_cols = 15
        output_features = self.output_matrix_rows * self.output_matrix_cols

        self.classifier = nn.Linear(lstm_hidden_size, output_features)
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x: torch.Tensor):
        """Run the model on a batch of frame sequences.

        Args:
            x: Tensor of shape (batch_size, seq_len, C, H, W).

        Returns:
            Tensor of shape (batch_size, seq_len, 6, 15) holding
            log-probabilities; for every time step the 90 entries of the grid
            exponentiate to a distribution that sums to 1.

        Raises:
            ValueError: If ``x`` does not have exactly five dimensions.
        """
        if x.dim() != 5:
            raise ValueError(
                f"Expected input of shape (batch_size, seq_len, C, H, W), got {tuple(x.shape)}"
            )
        batch_size, seq_len, c, h, w = x.shape

        # Fold time into the batch dimension so the CNN sees one frame per row.
        # reshape() is used instead of view() so that non-contiguous inputs
        # (e.g. permuted or sliced tensors) are accepted instead of raising.
        cnn_in = x.reshape(batch_size * seq_len, c, h, w)
        cnn_out = self.cnn(cnn_in)

        # Restore the time dimension for the LSTM.
        lstm_in = cnn_out.reshape(batch_size, seq_len, -1)
        lstm_out, _ = self.lstm(lstm_in)  # Shape: (batch_size, seq_len, lstm_hidden_size)

        # Map every LSTM output to one flat score per grid cell.
        raw_output = self.classifier(lstm_out)  # Shape: (batch_size, seq_len, rows * cols)

        # Normalise over the flat grid first, then reshape into rows x cols.
        normalized_output = self.log_softmax(raw_output)
        output = normalized_output.reshape(
            batch_size, seq_len, self.output_matrix_rows, self.output_matrix_cols
        )

        return output


class JointSequenceDataset(Dataset):
    """Sliding-window dataset over a table of frames.

    Item ``idx`` is a window of ``seq_length`` consecutive rows of ``data``
    starting at row ``idx``. If the ``frame_index`` column is not strictly
    increasing inside that window, the window is shifted forward (see
    ``_get_sequence_offset``). The ``X`` and ``y`` columns of the selected
    rows are returned as float32 tensors.

    Args:
        data: Table with ``frame_index``, ``X`` and ``y`` columns.
        seq_length: Number of rows in every returned sequence.
    """

    def __init__(self, data: pd.DataFrame, seq_length=30):
        self.data = data
        self.seq_length = seq_length
        self._logger = logger.getChild(self.__class__.__name__)

    def __len__(self):
        """Number of possible start rows for a full-length window (never negative)."""
        # Clamp at zero: a negative value would make len() raise ValueError
        # when the table holds fewer than seq_length - 1 rows.
        return max(0, len(self.data) - self.seq_length + 1)

    def _get_sequence_offset(self, sequence: pd.Series):
        """Return how far to shift a window whose values are not strictly increasing.

        Returns 0 when every value is greater than its predecessor. Otherwise
        returns the position (0-based, within ``sequence``) of the first value
        that is less than or equal to its predecessor, plus one.

        The position is computed from row order rather than index labels, so
        the result does not depend on the index of the underlying table.
        """
        # NaN differences (including the first element) count as increasing.
        non_increasing = (sequence.diff().fillna(1) <= 0).to_numpy(dtype=bool)
        break_positions = np.flatnonzero(non_increasing)
        if break_positions.size == 0:
            return 0
        # NOTE(review): the "+ 1" starts the shifted window one row after the
        # first non-increasing row; kept as in the original - confirm intended.
        return int(break_positions[0]) + 1

    @staticmethod
    def _column_to_tensor(column: pd.Series):
        """Stack the per-row values of ``column`` into one float32 tensor."""
        # Going through a single NumPy array avoids building the tensor from
        # a Python list of per-row objects.
        return torch.as_tensor(np.asarray(column.to_list(), dtype=np.float32), dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``(X_seq, y_seq)`` for the window starting at (or shifted from) ``idx``."""
        seq_data = self.data.iloc[idx : idx + self.seq_length]
        offset = self._get_sequence_offset(seq_data['frame_index'])

        self._logger.debug(f"Fetching sequence at {idx} with fixed length {self.seq_length}")
        if offset > 0:
            self._logger.debug(f"Applying offset of {offset} to ensure monotonic increase")
            final_idx_start = idx + offset
            last_full_start = max(len(self.data) - self.seq_length, 0)
            if final_idx_start > last_full_start:
                # Shifting past the end would silently yield a truncated (or
                # empty) sequence; fall back to the last full-length window.
                self._logger.warning(
                    f"Shifted window for index {idx} runs past the end of the data; "
                    f"using the last full window starting at {last_full_start}"
                )
                final_idx_start = last_full_start
            final_idx_end = final_idx_start + self.seq_length
            seq_data = self.data.iloc[final_idx_start:final_idx_end]

        X_seq = self._column_to_tensor(seq_data['X'])
        y_seq = self._column_to_tensor(seq_data['y'])

        return X_seq, y_seq

### The loss function: Manhattan distance over keyboard layout
Layout
```plaintext
ESC    F1     F2     F3     F4     F5     F6     F7     F8     F9     F10    F11    F12    INS    DEL
`      1      2      3      4      5      6      7      8      9      0      -      =      BKSP   BKSP
TAB    Q      W      E      R      T      Y      U      I      O      P      [      ]      ENTR   ENTR
CPSL   A      S      D      F      G      H      J      K      L      ;      '      \      ENTR   ENTR
LSHFT  \      Z      X      C      V      B      N      M      ,      .      /      RSHFT  RSHFT  RSHFT
LCTRL  FN     META   LALT   SPC    SPC    SPC    SPC    SPC    RALT   RCTRL  LEFT   UP     DOWN   RIGHT
```


In [12]:


# He (Kaiming) Weight Initialization function
def weights_init(m):
    """Re-initialise the parameters of a single module in place.

    Intended to be passed to ``nn.Module.apply``. Behaviour by module type:

    * ``nn.Conv2d`` / ``nn.Linear``: Kaiming-normal weights (fan-in, ReLU gain)
      and, when a bias exists, a zero bias.
    * ``nn.LSTM``: Xavier-uniform input-to-hidden weights, orthogonal
      hidden-to-hidden weights and zero biases.
    * Any other module is left untouched.
    """
    if isinstance(m, nn.LSTM):
        for param_name, tensor in m.named_parameters():
            if 'weight_ih' in param_name:
                nn.init.xavier_uniform_(tensor.data)
            elif 'weight_hh' in param_name:
                nn.init.orthogonal_(tensor.data)
            elif 'bias' in param_name:
                tensor.data.zero_()
        return

    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return

    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
    if m.bias is not None:
        nn.init.zeros_(m.bias)
                
# HYPERPARAMETERS
SEQ_LENGTH = 15
BATCH_SIZE = 1
NUM_EPOCHS = 1
LEARNING_RATE = 0.001

# LOAD THE PREPROCESSED DATA
df = pd.read_hdf("processed_data_5055492_8350647.h5")

# SETUP DATASET AND DATALOADERS
full_dataset = JointSequenceDataset(df, seq_length=SEQ_LENGTH)

# Split data into training and validation sets (80/20 split)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)


# INITIALIZE MODEL, LOSS, AND OPTIMIZER
model = AirKeyboardModel(lstm_hidden_size=64, lstm_layers=2)
# model.apply(weights_init)

# KLDivLoss takes log-probabilities as input (the model ends in LogSoftmax)
# and, by default, probabilities as target.
# NOTE(review): presumably every label grid is a valid probability
# distribution (non-negative, no NaN) - verify; invalid targets yield NaN.
loss_function = nn.KLDivLoss(reduction='batchmean') 
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Filled with the offending batch if a NaN loss shows up, for later inspection.
problematic_outputs = None
problematic_labels = None
# TRAINING AND VALIDATION LOOP
for epoch in range(NUM_EPOCHS):
    # Training Phase
    model.train()  # Set model to training mode
    running_train_loss = 0.0
    num_train_batches = 0
    for sequences, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(sequences)
        loss = loss_function(outputs, labels)

        # Check BEFORE backpropagating: stepping the optimizer on a NaN loss
        # would overwrite every model weight with NaN.
        if loss.isnan():
            print("NaN loss encountered during training.")
            # Detach so the saved copy does not keep the autograd graph alive.
            problematic_outputs = outputs.detach()
            problematic_labels = labels
            break

        loss.backward()  # Backpropagation
        optimizer.step()

        running_train_loss += loss.item()
        num_train_batches += 1

    # Average over the batches that were actually processed; the guard also
    # avoids a ZeroDivisionError when no batch completed.
    avg_train_loss = running_train_loss / max(num_train_batches, 1)

    # Stop training altogether once a NaN batch has been captured.
    if problematic_outputs is not None:
        break

    # # Validation Phase
    # model.eval()  # Set model to evaluation mode
    # running_val_loss = 0.0
    # with torch.no_grad():  # Disable gradient calculation
    #     for sequences, labels in val_loader:
    #         outputs = model(sequences)
    #         loss = loss_function(outputs, labels)
    #         running_val_loss += loss.item()

    # avg_val_loss = running_val_loss / len(val_loader)

    # print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

print("\nTraining finished.")


NaN loss encountered during training.

Training finished.


In [None]:
class CompositeLoss(nn.Module):
    """Placeholder for a composite loss; nothing is implemented yet.

    The class previously had no body, which made the whole cell fail with an
    IndentationError. TODO: implement ``forward``. Until then, calling an
    instance raises ``NotImplementedError`` (inherited from ``nn.Module``).
    """


# Recompute the loss on the batch captured by the training loop so the
# offending outputs and labels can be inspected next to the resulting value.
kldiv = nn.KLDivLoss(reduction='batchmean')
# Raise the summarisation threshold so the tensors below are printed in full.
torch.set_printoptions(threshold=10_000)
loss = kldiv(problematic_outputs, problematic_labels)

problematic_outputs, problematic_labels, loss

(tensor([[[[-4.4821, -4.5004, -4.4909, -4.4917, -4.5017, -4.5015, -4.5099,
            -4.5009, -4.5091, -4.5027, -4.5024, -4.5021, -4.5029, -4.5033,
            -4.4951],
           [-4.4983, -4.5012, -4.5101, -4.4930, -4.5011, -4.5135, -4.4926,
            -4.4866, -4.4894, -4.4999, -4.4938, -4.4979, -4.4948, -4.4945,
            -4.5040],
           [-4.4952, -4.4958, -4.5015, -4.4950, -4.5019, -4.4964, -4.5130,
            -4.4966, -4.4876, -4.5088, -4.5012, -4.5079, -4.4890, -4.4934,
            -4.5082],
           [-4.5009, -4.5085, -4.4974, -4.5024, -4.4980, -4.5086, -4.4981,
            -4.5134, -4.5129, -4.5112, -4.5250, -4.5019, -4.5026, -4.4993,
            -4.4975],
           [-4.5131, -4.4907, -4.5081, -4.5065, -4.4892, -4.4925, -4.5018,
            -4.4892, -4.5116, -4.5021, -4.4969, -4.5022, -4.5038, -4.4952,
            -4.4993],
           [-4.4957, -4.5027, -4.4985, -4.4840, -4.5045, -4.4783, -4.4812,
            -4.5006, -4.4913, -4.5009, -4.4962, -4.5077, -4.5124,