In [1]:
import os
from pathlib import Path
from omegaconf import OmegaConf
from hydra.utils import instantiate

from emg2qwerty import transforms, utils
from hydra import initialize, compose
from emg2qwerty.data import WindowedEMGDataset
from torch.utils.data import DataLoader, Dataset

from emg2qwerty.transforms import Compose, ToTensor, LogSpectrogram, TemporalAlignmentJitter, IdentityTransform, RandomBandRotation, SlidingCovariance, SpecAugment, ZNormalizeTime

import pytorch_lightning as pl
import logging
from emg2qwerty.lightning import WindowedEMGDataModule
import numpy as np

import torch
import pickle

In [None]:
""" Train the emg2qwerty model. 

Code taken from: 
Sivakumar, Viswanath, et al. 
"emg2qwerty: A large dataset with baselines for touch typing using surface electromyography." 
Advances in Neural Information Processing Systems 37 (2024): 91373-91389.

https://github.com/facebookresearch/emg2qwerty
"""

In [None]:
""" 
According to your setup, you need to change,

dataset:
  root: /mnt/dataDrive/qwertyData/data in base.yaml

and 
lm_path: /mnt/dataDrive/emgFullCorpora/toUpload/emg2qwerty/models/wikitext-103-6gram-charlm.bin in ctc_beam.yaml

in config files.
"""


"""
All codes are same except that we use SlidingCovariance instead of log spectrogams.
We also train for 250 epochs instead of 150 epochs since our model took longer to converge. 

The rest of the things are same as in https://github.com/facebookresearch/emg2qwerty.
This is a controlled experiment to show that covariance features rae better than spectral features. 

We train two different variations: one where covariance matrices are approximately diagonalized, and one where they are not.
"""

In [2]:
"""deocder = ctc_beam or ctc_greedy
approxDiag = True or False
"""

user = "user0"
decoder = "ctc_greedy" 
approxDiag = False

In [3]:
with initialize(config_path = "config", version_base = None):
    config = compose(
        config_name = "base",
        overrides=[
            f"user={user}",
            f"decoder={decoder}"
        ]
    )

In [4]:
with open("../DATA/emg2qwerty/frechetMean/" + user + "Mean.pkl", "rb") as f:
    userMean = pickle.load(f)
meanLeft = userMean["left"]
meanRight = userMean["right"]

In [5]:
eigenvaluesL, eigenvectorsL = np.linalg.eig(meanLeft)

eigenvaluesR, eigenvectorsR = np.linalg.eig(meanRight)

In [6]:
eigenvectorsLeft = torch.tensor(eigenvectorsL, dtype = torch.float32)
eigenvectorsRight = torch.tensor(eigenvectorsR, dtype = torch.float32)

In [7]:
def fullSessionPaths(dataset, root):
    sessions = [s["session"] for s in dataset]
    return [Path(root) / f"{session}.hdf5" for session in sessions]

In [8]:
print(config)

{'user': 'user0', 'dataset': {'train': [{'user': 43037958, 'session': '2020-12-17-1608244656-keystrokes-dca-study@1-e041d7d9-a53b-40f3-aabc-e9714072ca46'}, {'user': 43037958, 'session': '2020-12-17-1608255062-keystrokes-dca-study@1-e041d7d9-a53b-40f3-aabc-e9714072ca46'}, {'user': 43037958, 'session': '2020-12-17-1608257601-keystrokes-dca-study@1-e041d7d9-a53b-40f3-aabc-e9714072ca46'}, {'user': 43037958, 'session': '2020-12-17-1608268481-keystrokes-dca-study@1-e041d7d9-a53b-40f3-aabc-e9714072ca46'}, {'user': 43037958, 'session': '2020-12-18-1608304463-keystrokes-dca-study@1-e041d7d9-a53b-40f3-aabc-e9714072ca46'}, {'user': 43037958, 'session': '2020-12-18-1608314177-keystrokes-dca-study@1-e041d7d9-a53b-40f3-aabc-e9714072ca46'}, {'user': 43037958, 'session': '2020-12-18-1608311446-keystrokes-dca-study@1-e041d7d9-a53b-40f3-aabc-e9714072ca46'}, {'user': 43037958, 'session': '2020-12-17-1608220409-keystrokes-dca-study@1-e041d7d9-a53b-40f3-aabc-e9714072ca46'}, {'user': 43037958, 'session': '2

In [9]:
trainTransform = Compose([
    ToTensor(fields=['emg_left', 'emg_right'], stack_dim = 1),
    ZNormalizeTime(),
    RandomBandRotation(),
    TemporalAlignmentJitter(max_offset = 120, stack_dim = 1),
    SlidingCovariance(eigenvectorsL = eigenvectorsLeft, eigenvectorsR = eigenvectorsRight, approxDiag = approxDiag)
])

valTransform = Compose([
    ToTensor(fields=['emg_left', 'emg_right'], stack_dim = 1),
    ZNormalizeTime(),
    SlidingCovariance(eigenvectorsL = eigenvectorsLeft, eigenvectorsR = eigenvectorsRight, approxDiag = approxDiag)
])

testTransform = Compose([
    ToTensor(fields=['emg_left', 'emg_right'], stack_dim = 1),
    ZNormalizeTime(),
    SlidingCovariance(eigenvectorsL = eigenvectorsLeft, eigenvectorsR = eigenvectorsRight, approxDiag = approxDiag)
])

trainSessions = fullSessionPaths(config.dataset.train, config.dataset.root)
valSessions = fullSessionPaths(config.dataset.val, config.dataset.root)
testSessions = fullSessionPaths(config.dataset.test, config.dataset.root)

In [10]:
datamodule = instantiate(
    config.datamodule,
    batch_size = config.batch_size,
    num_workers = config.num_workers,
    train_sessions = trainSessions,
    val_sessions = valSessions,
    test_sessions = testSessions,
    train_transform = trainTransform,
    val_transform = valTransform,
    test_transform = testTransform,
    _convert_ = "object"
)


datamodule.setup()

In [11]:
from emg2qwerty.lightning import TDSConvCTCModule
from emg2qwerty.charset import charset
import torch.optim as optim
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

In [12]:
model = TDSConvCTCModule(
    in_features = 16 * 16,
    mlp_features = config.module.mlp_features,
    block_channels = config.module.block_channels,
    kernel_width = config.module.kernel_width,
    optimizer = config.optimizer,
    lr_scheduler = config.lr_scheduler,
    decoder = config.decoder,
    share_hand_weights = False
)

checkpointCB = ModelCheckpoint(
    monitor = 'val/loss',
    save_top_k = 1,
    mode = 'min',
    filename = 'best-model',
    save_last = True,
    dirpath = './checkpoints'
)

In [13]:
lrMonitor = LearningRateMonitor(logging_interval = 'epoch')

trainer = pl.Trainer(
    accelerator = 'gpu',
    devices = 1,
    max_epochs = config.trainer.max_epochs,
    callbacks = [checkpointCB, lrMonitor],
    default_root_dir = './outputs',
    log_every_n_steps = 1
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
""" Train the model. Skip this cell if you simpy want simply test with the given checkpoints. """

trainer.fit(model, datamodule = datamodule)

In [14]:
if approxDiag:
    where = "withApproxDiag"
else: 
    where = "withoutApproxDiag"

checkpoint = torch.load("../DATA/emg2qwerty/" + where + "/" + user + ".ckpt")
model.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

In [15]:
trainer.validate(model, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Validation DataLoader 0:  15%|█▌        | 3/20 [00:00<00:02,  5.94it/s]

  return F.conv2d(input, weight, bias, self.stride,


Validation DataLoader 0: 100%|██████████| 20/20 [00:01<00:00, 15.66it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         val/CER            20.588546752929688
         val/DER            3.3767333030700684
         val/IER             5.186831951141357
         val/SER            12.024981498718262
        val/loss            1.1117563247680664
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val/loss': 1.1117563247680664,
  'val/CER': 20.588546752929688,
  'val/IER': 5.186831951141357,
  'val/DER': 3.3767333030700684,
  'val/SER': 12.024981498718262}]

In [16]:
trainer.test(model, datamodule = datamodule)         

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing DataLoader 0: 100%|██████████| 2/2 [00:01<00:00,  1.02it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/CER            22.537473678588867
        test/DER             1.970021367073059
        test/IER             6.937901496887207
        test/SER            13.629549980163574
        test/loss           1.1046416759490967
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test/loss': 1.1046416759490967,
  'test/CER': 22.537473678588867,
  'test/IER': 6.937901496887207,
  'test/DER': 1.970021367073059,
  'test/SER': 13.629549980163574}]