In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
import logging

# Configure the root logger so INFO-level messages emitted by the notebook are shown.
logging.basicConfig(level=logging.INFO)
# Named logger for this notebook/module; it propagates to the root logger configured above.
logger = logging.getLogger(__name__)


##### CNN Feature Extractor
* Input of size 6x4x10 (cxhxw)
* -> zero pad of 1 pixel -> 6x6x12
* -> conv1 32x3x2 s 1x1 -> 32x4x10 (remove column 5)
* -> bn1
* -> relu1
* -> horizontal (left-right) zero pad of 1 pixel -> 32x4x12
* -> conv2 64x2x2 s 2x1 -> 64x2x10
* -> bn2
* -> relu2
* -> conv3 1x1x1 s 1x1 -> 1x2x10 (bottleneck)
* -> bn3
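* -> relu3

**Note:** this outline does not match the `CNNFeatureExtractor` code below. The code takes a 6x10x4 input, splits it into two 6x5x4 halves that share weights, and applies conv1 (32 filters, 3x3, padding 1), conv2 (64 filters, 3x3, padding 1) and conv3 (5 filters, 5x2 kernel, stride 1x2), yielding 20 features per frame.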

In [None]:

class CNNFeatureExtractor(nn.Module):
    """Per-frame CNN that encodes two halves of a frame with shared weights.

    A frame of shape ``(6, 10, 4)`` is cut in two along dim 2 (indices 0-4 and
    5-9). The second half is reversed along that axis (the "right hand" is
    mirrored) so both halves can go through the same conv stack as a single
    batch of size ``2 * N``.

    ``output_dim`` is accepted for interface compatibility but is not used by
    this class; ``forward`` always returns 20 features per frame.
    """

    def __init__(self, output_dim=128):
        super().__init__()
        # Stage 1: 6 -> 32 channels, 3x3 kernel with padding 1 (spatial size preserved).
        self.conv1 = nn.Conv2d(6, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        # Stage 2: 32 -> 64 channels, again spatial size preserved.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        # Stage 3: 64 -> 5 channels; the (5, 2) kernel with stride (1, 2) reduces
        # a 5x4 map to 1x2.
        self.conv3 = nn.Conv2d(64, 5, kernel_size=(5, 2), stride=(1, 2))
        self.bn3 = nn.BatchNorm2d(5)
        self.relu3 = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """Encode frames of shape ``(N, 6, 10, 4)`` into features of shape ``(N, 20)``.

        Raises:
            AssertionError: if ``x`` does not have shape ``(N, 6, 10, 4)``.
        """
        assert x.shape == (x.size(0), 6, 10, 4), f"Expected input shape (batch_size, 6, 10, 4), got {x.shape}"

        n_frames = x.size(0)

        # Dim 2 has size 10, so splitting in chunks of 5 yields exactly two halves.
        first_half, second_half = torch.split(x, 5, dim=2)
        # Mirror the second half so it is laid out like the first one.
        second_half = second_half.flip(dims=(2,))

        # Stack both halves on the batch axis: (2 * N, 6, 5, 4).
        stacked = torch.cat((first_half, second_half), dim=0)

        stages = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, norm, activation in stages:
            stacked = activation(norm(conv(stacked)))

        # Undo the batch stacking and mirror the second half back.
        first_feat = stacked[:n_frames]
        second_feat = stacked[n_frames:].flip(dims=(2,))

        # Join the halves on dim 2 -> (N, 5, 2, 2), move channels last
        # -> (N, 2, 2, 5), then flatten to (N, 20).
        merged = torch.cat((first_feat, second_feat), dim=2)
        return merged.permute(0, 3, 2, 1).reshape(n_frames, -1)

class AirKeyboardModel(nn.Module):
    """CNN + LSTM sequence model producing per-frame log-probabilities over a key grid.

    Every frame of the input sequence is encoded by ``CNNFeatureExtractor``
    (20 features per frame), the feature sequence is fed to an LSTM, and a
    linear layer followed by a log-softmax over all ``6 * 15 = 90`` cells
    produces the output, reshaped to a ``6 x 15`` matrix per frame.

    Args:
        lstm_hidden_size: hidden size of the LSTM.
        lstm_layers: number of stacked LSTM layers.
    """

    def __init__(self, lstm_hidden_size=256, lstm_layers=2):
        super(AirKeyboardModel, self).__init__()
        self.cnn = CNNFeatureExtractor()
        self.lstm = nn.LSTM(
            input_size=20,  # must match the feature size returned by the CNN
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            # Inter-layer dropout only exists for stacked LSTMs; passing a
            # non-zero value with a single layer makes PyTorch emit a warning.
            dropout=0.2 if lstm_layers > 1 else 0.0,
        )

        self.output_matrix_rows = 6
        self.output_matrix_cols = 15
        output_features = self.output_matrix_rows * self.output_matrix_cols

        self.classifier = nn.Linear(lstm_hidden_size, output_features)
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x: torch.Tensor):
        """Run the model on a batch of frame sequences.

        Args:
            x: tensor of shape ``(batch_size, seq_len, C, W, H)``.

        Returns:
            Log-probabilities of shape ``(batch_size, seq_len, 6, 15)``; for
            each frame the exponentials sum to 1 over the whole 6x15 grid.
        """
        batch_size, seq_len, c, w, h = x.shape

        # Fold time into the batch axis for the per-frame CNN. ``reshape`` is
        # used instead of ``view`` so non-contiguous inputs (e.g. a transposed
        # batch) are accepted rather than raising a RuntimeError.
        cnn_in = x.reshape(batch_size * seq_len, c, w, h)
        cnn_out = self.cnn(cnn_in)

        # Restore the time axis: (batch_size, seq_len, features)
        lstm_in = cnn_out.reshape(batch_size, seq_len, -1)
        lstm_out, _ = self.lstm(lstm_in)  # Shape: (batch_size, seq_len, lstm_hidden_size)
        raw_output = self.classifier(lstm_out)  # Shape: (batch_size, seq_len, rows * cols)
        # Normalise over the flattened grid before reshaping to a matrix.
        normalized_output = self.log_softmax(raw_output)
        output = normalized_output.reshape(
            batch_size, seq_len, self.output_matrix_rows, self.output_matrix_cols
        )

        return output

class JointSequenceDataset(Dataset):
    """Dataset of fixed-length row windows taken from a DataFrame.

    Item ``i`` is the window of ``seq_length`` consecutive rows (positional)
    starting at ``valid_indices[i]``. The ``'X'`` and ``'y1'`` columns of that
    window are each stacked into a float32 tensor.

    Args:
        data: frame holding one array-like per row in columns ``'X'`` and ``'y1'``.
        valid_indices: pre-computed positional start rows, one per item.
        seq_length: number of rows per window.
    """

    _FEATURE_COLUMN = 'X'
    _TARGET_COLUMN = 'y1'

    def __init__(self, data: pd.DataFrame, valid_indices: list, seq_length=30):
        self.valid_indices = valid_indices
        self.seq_length = seq_length
        self.data = data

    def __len__(self):
        # One item per pre-computed start row.
        return len(self.valid_indices)

    def __getitem__(self, idx):
        # ``idx`` addresses ``valid_indices``, not the DataFrame directly.
        first_row = self.valid_indices[idx]
        window = self.data.iloc[first_row:first_row + self.seq_length]

        features, targets = (
            torch.as_tensor(np.array(window[column].tolist()), dtype=torch.float32)
            for column in (self._FEATURE_COLUMN, self._TARGET_COLUMN)
        )
        return features, targets
    

### The loss function: Manhattan distance over keyboard layout
Layout
```plaintext
ESC    F1     F2     F3     F4     F5     F6     F7     F8     F9     F10    F11    F12    INS    DEL
`      1      2      3      4      5      6      7      8      9      0      -      =      BKSP   BKSP
TAB    Q      W      E      R      T      Y      U      I      O      P      [      ]      ENTR   ENTR
CPSL   A      S      D      F      G      H      J      K      L      ;      '      \      ENTR   ENTR
LSHFT  \      Z      X      C      V      B      N      M      ,      .      /      RSHFT  RSHFT  RSHFT
LCTRL  FN     META   LALT   SPC    SPC    SPC    SPC    SPC    RALT   RCTRL  LEFT   UP     DOWN   RIGHT
```


In [None]:


# Weight initialisation hook, intended for use with ``model.apply(weights_init)``
def weights_init(m):
    """Initialise the parameters of a single module in place.

    * ``nn.Conv2d`` / ``nn.Linear``: Kaiming-normal weights (fan-in, ReLU gain)
      and zero bias.
    * ``nn.LSTM``: Xavier-uniform input-hidden weights, orthogonal
      hidden-hidden weights and zero biases.
    * Any other module is left untouched.
    """
    if isinstance(m, nn.LSTM):
        for param_name, tensor in m.named_parameters():
            if 'weight_ih' in param_name:
                nn.init.xavier_uniform_(tensor.data)
            elif 'weight_hh' in param_name:
                nn.init.orthogonal_(tensor.data)
            elif 'bias' in param_name:
                nn.init.constant_(tensor.data, 0)
        return

    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return

    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
    if m.bias is None:
        return
    nn.init.constant_(m.bias, 0)
                
# HYPERPARAMETERS
SEQ_LENGTH = 120
BATCH_SIZE = 128
NUM_EPOCHS = 100
LEARNING_RATE = 0.001
SEED = 42  # makes majority downsampling, index shuffling and torch RNG reproducible
TRAIN_FRACTION = 0.8  # the first 80% of rows (in file order) train, the rest validate
MAJORITY_PER_MINORITY = 8  # keep at most this many no-key windows per key-press window
DATA_PATH = "processed_data_17551910634500783_17578807586533651.h5"

# Seed every RNG this notebook draws from so Restart & Run All gives the same split.
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

df = pd.read_hdf(DATA_PATH)
# 1 for every row whose label is an actual key, 0 for the "NO_KEY" placeholder.
df["key_present"] = df["pressed_label"].ne("NO_KEY").astype(int)

# Split by row position so no training window overlaps the validation rows.
split_point = int(TRAIN_FRACTION * len(df))
train_df = df.iloc[:split_point]
val_df = df.iloc[split_point:]

# --- Apply downsampling logic ONLY to train_df ---
train_start_indices = range(len(train_df) - SEQ_LENGTH + 1)
num_train_windows = len(train_start_indices)  # 0 when train_df is shorter than SEQ_LENGTH

logger.info("Categorizing training sequences...")
# Count key presses per window with prefix sums instead of slicing the frame
# once per start index: sum(rows i .. i+SEQ_LENGTH-1) = prefix[i + SEQ_LENGTH] - prefix[i].
key_present = train_df["key_present"].to_numpy()
prefix = np.concatenate(([0], np.cumsum(key_present)))
window_key_counts = (
    prefix[SEQ_LENGTH:SEQ_LENGTH + num_train_windows] - prefix[:num_train_windows]
)
window_has_key = window_key_counts > 0

# "Minority" = windows containing at least one key press; "majority" = windows with none.
minority_train_indices = np.flatnonzero(window_has_key).tolist()
majority_train_indices = np.flatnonzero(~window_has_key).tolist()
assert len(minority_train_indices) + len(majority_train_indices) == num_train_windows

# Downsample the majority indices from the training set
num_majority_to_keep = min(
    len(majority_train_indices),
    len(minority_train_indices) * MAJORITY_PER_MINORITY,
)
if num_majority_to_keep > 0:
    majority_indices_downsampled = rng.choice(
        majority_train_indices,
        size=num_majority_to_keep,
        replace=False,
    ).tolist()
else:
    # Nothing to sample (no majority windows, or no minority windows to scale from).
    majority_indices_downsampled = []

# Final training indices: every key-press window plus at most
# MAJORITY_PER_MINORITY no-key windows per key-press window.
final_train_indices = minority_train_indices + majority_indices_downsampled
rng.shuffle(final_train_indices)

# The validation indices are just all possible sequences in the imbalanced val_df
all_val_indices = list(range(len(val_df) - SEQ_LENGTH + 1))

# Create the two separate datasets; indices are positions within each split's own frame.
train_dataset = JointSequenceDataset(train_df, final_train_indices, seq_length=SEQ_LENGTH)
val_dataset = JointSequenceDataset(val_df, all_val_indices, seq_length=SEQ_LENGTH)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

len(df), len(train_dataset), len(val_dataset), len(majority_train_indices), len(minority_train_indices)



Categorizing training sequences...


(114132, 91186, 22708, 0, 91186)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from tqdm import tqdm


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# INITIALIZE MODEL, LOSS, AND OPTIMIZER
model = AirKeyboardModel(lstm_hidden_size=256, lstm_layers=2)
model.to(device)
model.apply(weights_init)

# NOTE(review): the model returns log-probabilities (values <= 0). An L1 loss
# between those and `labels` is only meaningful if the labels live in the same
# log domain -- confirm what the 'y1' column holds. The markdown above describes
# a Manhattan distance over the keyboard layout, which this loss does not compute.
loss_function = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# gamma=1 keeps the learning rate constant; lower it to enable exponential decay.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=1)

wandb.init(
    project="air-keyboard-project",
    config={
        "learning_rate": LEARNING_RATE,
        "epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE,
        "seq_length": SEQ_LENGTH,
        "optimizer": "Adam",
        "loss_function": "L1Loss",
        "scheduler": "ExponentialLR"
    }
)
wandb.watch(model, log='all', log_freq=100) # Log every 100 batches

# try/finally guarantees the W&B run is closed even if training raises or is
# interrupted from the notebook.
try:
    for epoch in range(NUM_EPOCHS):
        model.train()
        running_train_loss = 0.0

        for sequences, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [T]", leave=False):
            sequences = sequences.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(sequences)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item()

        # max(..., 1) avoids a ZeroDivisionError when a loader yields no batches.
        avg_train_loss = running_train_loss / max(len(train_loader), 1)

        # Read the learning rate that was used during this epoch *before*
        # stepping the scheduler; after step() it reports the next epoch's value.
        lr = scheduler.get_last_lr()[0]
        scheduler.step()

        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for sequences, labels in val_loader:
                sequences, labels = sequences.to(device), labels.to(device)
                outputs = model(sequences)
                loss = loss_function(outputs, labels)
                running_val_loss += loss.item()
        avg_val_loss = running_val_loss / max(len(val_loader), 1)

        # Log everything to W&B in one step
        # The gradient/weight norms will be logged automatically by wandb.watch
        log_dict = {
            "Loss/Train": avg_train_loss,
            "Loss/Validation": avg_val_loss,
            "Learning Rate": lr,
            "epoch": epoch # Custom x-axis
        }

        # Global L2 norm of gradients vs. weights. The gradients are those left
        # by the last training batch of the epoch (validation runs under
        # no_grad and does not touch them), so this is a snapshot, not an
        # epoch average.
        total_grad_norm = 0.0
        total_weight_norm = 0.0
        for param in model.parameters():
            if param.grad is not None:
                total_grad_norm += param.grad.detach().norm(2).item() ** 2
                total_weight_norm += param.detach().norm(2).item() ** 2

        total_grad_norm = total_grad_norm ** 0.5
        total_weight_norm = total_weight_norm ** 0.5

        # Gradient-to-weight ratio; the epsilon guards against a zero weight norm.
        log_dict["Gradient-Weight Ratio"] = total_grad_norm / (total_weight_norm + 1e-8)

        wandb.log(log_dict)

        logger.info(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {lr:.6f}")

    logger.info("\nTraining finished.")
finally:
    wandb.finish()

Using device: cuda


[34m[1mwandb[0m: Currently logged in as: [33msamanbzg[0m ([33msamanbzg-tu-darmstadt[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


                                                                  

Epoch [1/100] | Train Loss: 3.1037 | Val Loss: 3.1199 | LR: 0.001000


                                                                  

Epoch [2/100] | Train Loss: 2.9006 | Val Loss: 3.1103 | LR: 0.001000


                                                                  

Epoch [3/100] | Train Loss: 2.8984 | Val Loss: 3.1132 | LR: 0.001000


                                                                  

Epoch [4/100] | Train Loss: 2.8917 | Val Loss: 3.0884 | LR: 0.001000


                                                                  

Epoch [5/100] | Train Loss: 2.5538 | Val Loss: 2.8824 | LR: 0.001000


                                                                  

Epoch [6/100] | Train Loss: 2.1307 | Val Loss: 2.9307 | LR: 0.001000


                                                                  

Epoch [7/100] | Train Loss: 1.9615 | Val Loss: 3.0146 | LR: 0.001000


                                                                  

Epoch [8/100] | Train Loss: 1.8951 | Val Loss: 2.8997 | LR: 0.001000


                                                                  

Epoch [9/100] | Train Loss: 1.8290 | Val Loss: 2.9053 | LR: 0.001000


                                                                   

Epoch [10/100] | Train Loss: 1.7998 | Val Loss: 2.8505 | LR: 0.001000


                                                                   

Epoch [11/100] | Train Loss: 1.7900 | Val Loss: 2.8891 | LR: 0.001000


                                                                   

Epoch [12/100] | Train Loss: 1.7356 | Val Loss: 2.9305 | LR: 0.001000


                                                                   

Epoch [13/100] | Train Loss: 1.7456 | Val Loss: 2.9178 | LR: 0.001000


                                                                   

Epoch [14/100] | Train Loss: 1.7052 | Val Loss: 2.9220 | LR: 0.001000


                                                                   

Epoch [15/100] | Train Loss: 1.7149 | Val Loss: 2.9641 | LR: 0.001000


                                                                   

Epoch [16/100] | Train Loss: 1.6664 | Val Loss: 2.8623 | LR: 0.001000


                                                                   

Epoch [17/100] | Train Loss: 1.6506 | Val Loss: 2.8807 | LR: 0.001000


                                                                   