In [1]:
import os
from pathlib import Path
from omegaconf import OmegaConf
from hydra.utils import instantiate

from emg2qwerty import transforms, utils
from hydra import initialize, compose
from emg2qwerty.data import WindowedEMGDataset
from torch.utils.data import DataLoader, Dataset

from emg2qwerty.transforms import Compose, ToTensor, LogSpectrogram, TemporalAlignmentJitter, IdentityTransform, RandomBandRotation, SlidingCovariance, SpecAugment, ZNormalizeTime

import pytorch_lightning as pl
import logging
from emg2qwerty.lightning import WindowedEMGDataModule
import numpy as np

import torch
import pickle

In [None]:
""" Train the emg2qwerty model. 

Code taken from: 
Sivakumar, Viswanath, et al. 
"emg2qwerty: A large dataset with baselines for touch typing using surface electromyography." 
Advances in Neural Information Processing Systems 37 (2024): 91373-91389.

https://github.com/facebookresearch/emg2qwerty
"""

In [None]:
""" 
According to your setup, you need to change,

dataset:
  root: /mnt/dataDrive/qwertyData/data in base.yaml

and 
lm_path: /mnt/dataDrive/emgFullCorpora/toUpload/emg2qwerty/models/wikitext-103-6gram-charlm.bin in ctc_beam.yaml

in config files.
"""


"""
All codes are same except that we use SlidingCovariance instead of log spectrogams.
We also train for 250 epochs instead of 150 epochs since our model took longer to converge. 

The rest of the things are same as in https://github.com/facebookresearch/emg2qwerty.
This is a controlled experiment to show that covariance features rae better than spectral features. 

We train two different variations: one where covariance matrices are approximately diagonalized, and one where they are not.
"""

In [208]:
"""deocder = ctc_beam or ctc_greedy
approxDiag = True or False
"""

user = "user2"
decoder = "ctc_beam" 
approxDiag = True

In [209]:
with initialize(config_path = "config", version_base = None):
    config = compose(
        config_name = "base",
        overrides=[
            f"user={user}",
            f"decoder={decoder}"
        ]
    )

In [210]:
with open("../DATA/emg2qwerty/frechetMean/" + user + "Mean.pkl", "rb") as f:
    userMean = pickle.load(f)
meanLeft = userMean["left"]
meanRight = userMean["right"]

In [211]:
eigenvaluesL, eigenvectorsL = np.linalg.eig(meanLeft)

eigenvaluesR, eigenvectorsR = np.linalg.eig(meanRight)

In [212]:
eigenvectorsLeft = torch.tensor(eigenvectorsL, dtype = torch.float32)
eigenvectorsRight = torch.tensor(eigenvectorsR, dtype = torch.float32)

In [213]:
def fullSessionPaths(dataset, root):
    sessions = [s["session"] for s in dataset]
    return [Path(root) / f"{session}.hdf5" for session in sessions]

In [214]:
print(config)

{'user': 'user2', 'dataset': {'train': [{'user': 71786456, 'session': '2020-12-17-1608243953-keystrokes-dca-study@1-d09c3c47-708c-4b37-9e12-a0218ae9e5c4'}, {'user': 71786456, 'session': '2020-12-17-1608263300-keystrokes-dca-study@1-d09c3c47-708c-4b37-9e12-a0218ae9e5c4'}, {'user': 71786456, 'session': '2020-12-18-1608332019-keystrokes-dca-study@1-d09c3c47-708c-4b37-9e12-a0218ae9e5c4'}, {'user': 71786456, 'session': '2020-12-18-1608335464-keystrokes-dca-study@1-d09c3c47-708c-4b37-9e12-a0218ae9e5c4'}, {'user': 71786456, 'session': '2020-12-15-1608082012-keystrokes-dca-study@1-d09c3c47-708c-4b37-9e12-a0218ae9e5c4'}, {'user': 71786456, 'session': '2020-12-15-1608079307-keystrokes-dca-study@1-d09c3c47-708c-4b37-9e12-a0218ae9e5c4'}, {'user': 71786456, 'session': '2020-12-16-1608169994-keystrokes-dca-study@1-d09c3c47-708c-4b37-9e12-a0218ae9e5c4'}, {'user': 71786456, 'session': '2020-12-16-1608164146-keystrokes-dca-study@1-d09c3c47-708c-4b37-9e12-a0218ae9e5c4'}, {'user': 71786456, 'session': '2

In [215]:
trainTransform = Compose([
    ToTensor(fields=['emg_left', 'emg_right'], stack_dim = 1),
    ZNormalizeTime(),
    RandomBandRotation(),
    TemporalAlignmentJitter(max_offset = 120, stack_dim = 1),
    SlidingCovariance(eigenvectorsL = eigenvectorsLeft, eigenvectorsR = eigenvectorsRight, approxDiag = approxDiag)
])

valTransform = Compose([
    ToTensor(fields=['emg_left', 'emg_right'], stack_dim = 1),
    ZNormalizeTime(),
    SlidingCovariance(eigenvectorsL = eigenvectorsLeft, eigenvectorsR = eigenvectorsRight, approxDiag = approxDiag)
])

testTransform = Compose([
    ToTensor(fields=['emg_left', 'emg_right'], stack_dim = 1),
    ZNormalizeTime(),
    SlidingCovariance(eigenvectorsL = eigenvectorsLeft, eigenvectorsR = eigenvectorsRight, approxDiag = approxDiag)
])

trainSessions = fullSessionPaths(config.dataset.train, config.dataset.root)
valSessions = fullSessionPaths(config.dataset.val, config.dataset.root)
testSessions = fullSessionPaths(config.dataset.test, config.dataset.root)

In [216]:
datamodule = instantiate(
    config.datamodule,
    batch_size = config.batch_size,
    num_workers = config.num_workers,
    train_sessions = trainSessions,
    val_sessions = valSessions,
    test_sessions = testSessions,
    train_transform = trainTransform,
    val_transform = valTransform,
    test_transform = testTransform,
    _convert_ = "object"
)


datamodule.setup()

In [217]:
from emg2qwerty.lightning import TDSConvCTCModule
from emg2qwerty.charset import charset
import torch.optim as optim
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

In [218]:
model = TDSConvCTCModule(
    in_features = 16 * 16,
    mlp_features = config.module.mlp_features,
    block_channels = config.module.block_channels,
    kernel_width = config.module.kernel_width,
    optimizer = config.optimizer,
    lr_scheduler = config.lr_scheduler,
    decoder = config.decoder,
    share_hand_weights = False
)

checkpointCB = ModelCheckpoint(
    monitor = 'val/loss',
    save_top_k = 1,
    mode = 'min',
    filename = 'best-model',
    save_last = True,
    dirpath = './checkpoints'
)

In [219]:
lrMonitor = LearningRateMonitor(logging_interval = 'epoch')

trainer = pl.Trainer(
    accelerator = 'gpu',
    devices = 1,
    max_epochs = config.trainer.max_epochs,
    callbacks = [checkpointCB, lrMonitor],
    default_root_dir = './outputs',
    log_every_n_steps = 1
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
""" Train the model. Skip this cell if you simpy want simply test with the given checkpoints. """

trainer.fit(model, datamodule = datamodule)

In [220]:
if approxDiag:
    where = "withApproxDiag"
else: 
    where = "withoutApproxDiag"

checkpoint = torch.load("../DATA/emg2qwerty/" + where + "/" + user + ".ckpt")
model.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

In [221]:
trainer.validate(model, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Validation DataLoader 0: 100%|██████████| 19/19 [29:03<00:00, 91.75s/it]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         val/CER             7.057211399078369
         val/DER            2.0098206996917725
         val/IER            2.7635035514831543
         val/SER            2.2838871479034424
        val/loss            0.5897980332374573
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val/loss': 0.5897980332374573,
  'val/CER': 7.057211399078369,
  'val/IER': 2.7635035514831543,
  'val/DER': 2.0098206996917725,
  'val/SER': 2.2838871479034424}]

In [222]:
trainer.test(model, datamodule = datamodule)         

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing DataLoader 0: 100%|██████████| 2/2 [35:02<00:00, 1051.49s/it]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/CER             5.546200752258301
        test/DER            1.9469585418701172
        test/IER             1.715428352355957
        test/SER            1.8838139772415161
        test/loss           0.46622198820114136
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test/loss': 0.46622198820114136,
  'test/CER': 5.546200752258301,
  'test/IER': 1.715428352355957,
  'test/DER': 1.9469585418701172,
  'test/SER': 1.8838139772415161}]