In [1]:
import sys
import os 
sys.path.insert(1, os.path.realpath(os.path.pardir))

import numpy as np
from pathlib import Path

import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_model_summary import summary
from natsort import natsorted

from utils import data_utils
from utils import common, losses, hand_visualize
from models import HVATNet_v3, HVATNet_v3_FineTune
import audiomentations as A
import time

%load_ext autoreload
%autoreload 2

## Data

In [2]:
# Seed the random number generators up front; the recorded cell output reports
# that everything random was fixed with seed 42.
common.make_it_reproducible()


class TrainConfig:
    """Central place for the data, windowing and training settings used by this notebook."""

    # Free-text note attached to the W&B run (passed as ``notes`` to WandbLogger).
    WANDB_NOTES = "HVATNet v3 FT: large bs + augs + wd"

    # datasets = ["../../data/processed/dataset_v1_big"]

    # Subject filters; not referenced by the cells visible in this notebook.
    hand_type = ["left"]  # ['left', 'right']
    human_type = ["amputant"]  # ['health', 'amputant']

    use_preproc_data = True  # use preprocessed data (faster preparation)
    use_angles = True  # use angles as the target.

    original_fps = 250  # presumably the sampling rate of the recordings - TODO confirm and describe
    delay_ms = 0  # Shift between VR and EMG. Does not work with preprocessed data. Fix it!!
    start_crop_ms = 0  # bad values at the beginning of a recording, in ms, to delete.
    window_size = 256  # forwarded to create_dataset; equals the last dimension of the dummy model input
    down_sample_target = 8  # None. original_fps // down_sample_target is passed on as the prediction fps

    max_epochs = 3000
    samples_per_epoch = 1000 * 256  # forwarded unchanged to every training folder's dataset
    train_bs = 512
    # train_bs = 256
    val_bs = 512  # intended batch size for the validation DataLoader
    # val_bs = 256
    device = [0]  # [0]. Not used by the Trainer cell, which sets accelerator="cpu"
    # Unpacked as keyword arguments (lr, wd) into LitHVATNet_v3.
    optimizer_params = dict(lr=1e-4, wd=1e-6)


# Single shared instance read by the cells below.
config = TrainConfig()

Fixed all random things with randow seed 42


In [34]:
# Every entry directly inside the training folder is passed to
# data_utils.create_dataset as one data folder.
# NOTE(review): machine-specific absolute path - point it at your local copy of the data.
rootdir = Path("C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Train")
# glob() yields entries in an arbitrary, filesystem-dependent order. Sorting keeps
# the dataset order (and therefore the seeded sampling) stable between runs and machines.
files = sorted(rootdir.glob("*"))
train_paths = files

In [35]:
# Every entry directly inside the validation folder is passed to
# data_utils.create_dataset as one data folder.
# NOTE(review): machine-specific absolute path - point it at your local copy of the data.
rootdir = Path(
    "C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Validation"
)

# glob() yields entries in an arbitrary, filesystem-dependent order. Sorting keeps
# the per-folder validation printouts and the concatenated dataset order stable.
files = sorted(rootdir.glob("*"))

val_paths = files

In [None]:
# Augmentations are currently switched off; uncomment the Compose block to enable them.
# transform = A.Compose([
#     A.AddGaussianNoise(min_amplitude=0.01, max_amplitude=0.1, p=0.3),
#     data_utils.SpatialRotation(min_angle=1, max_angle=10, p=0.5),
# ])
transform = None


def _create_dataset(data_folder, random_sampling, samples_per_epoch, transform):
    """Build the dataset for one data folder with the shared settings from ``config``.

    Only the arguments that differ between training and validation are parameters;
    everything else is read from the notebook-level ``config``.

    Args:
        data_folder: folder handed to ``data_utils.create_dataset``.
        random_sampling: forwarded as-is (True for training, False for validation).
        samples_per_epoch: forwarded as-is (``None`` for validation).
        transform: augmentation callable or ``None``.

    Returns:
        Whatever ``data_utils.create_dataset`` returns for this folder.
    """
    return data_utils.create_dataset(
        data_folder=data_folder,
        original_fps=config.original_fps,
        delay_ms=config.delay_ms,
        start_crop_ms=config.start_crop_ms,
        window_size=config.window_size,
        down_sample_target=config.down_sample_target,
        use_preproc_data=config.use_preproc_data,
        use_angles=config.use_angles,
        random_sampling=random_sampling,
        samples_per_epoch=samples_per_epoch,
        transform=transform,
    )


# Training datasets: one per folder. Each folder receives the full
# config.samples_per_epoch (it is not divided by the number of folders).
train_datasets = []
for train_folder in train_paths:
    folder_dataset = _create_dataset(
        train_folder,
        random_sampling=True,
        samples_per_epoch=config.samples_per_epoch,
        transform=transform,
    )

    if len(folder_dataset) == 0:
        # Skip only the broken folder; aborting the loop here would silently
        # drop every folder that comes after it.
        print(f"WWWWW: Problem with dataset {train_folder}")
        continue
    train_datasets.append(folder_dataset)

# Validation datasets: deterministic order, no augmentation.
val_datasets = [
    _create_dataset(
        val_folder,
        random_sampling=False,
        samples_per_epoch=None,
        transform=None,
    )
    for val_folder in val_paths
]

# ConcatDataset rejects an empty list with a bare assertion; fail with a message
# that points at the likely cause instead.
if not train_datasets:
    raise RuntimeError("No usable training datasets were found - check train_paths")
if not val_datasets:
    raise RuntimeError("No validation datasets were found - check val_paths")

train_dataset = torch.utils.data.ConcatDataset(train_datasets)
val_dataset = torch.utils.data.ConcatDataset(val_datasets)
print(f"Size of the train dataset {len(train_dataset)}")
for single_val_dataset in val_datasets:
    print(f"Size of the val dataset {len(single_val_dataset)}")

print("-")
print(
    f"Size of the train dataset {len(train_dataset)} || Size of the val dataset {len(val_dataset)}"
)
# Fetch the sample once so that the input and the output shown belong to the
# same item (the training datasets are created with random_sampling=True).
first_sample = train_dataset[0]
print(
    f"Size of the input {first_sample[0].shape} || Size of the output {first_sample[1].shape}"
)

## Define the Lightning module and callback

In [26]:
class LitHVATNet_v3(pl.LightningModule):
    """Lightning wrapper that adds loss computation, metric logging and the optimizer to a model.

    Args:
        model: network called as ``model(x)`` on the first element of every batch.
        lr: learning rate handed to AdamW.
        wd: weight decay handed to AdamW.
    """

    def __init__(self, model, lr, wd):
        super().__init__()
        self.model = model
        self.lr = lr
        self.wd = wd
        self.mae_loss = nn.L1Loss()
        # Per-batch validation errors (in degrees) gathered during one validation run.
        self.validation_step_outputs = []
        # Mean validation error of the latest validation run; callbacks read it.
        # Starts at +inf so it always exists and can never look like a new best score.
        self.val_current_loss = float("inf")

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def configure_optimizers(self):
        """Create the AdamW optimizer over all parameters of this module."""
        return torch.optim.AdamW(
            self.parameters(), lr=self.lr, weight_decay=self.wd
        )

    def step(self, train_batch):
        """Run the model on one ``(x, y)`` batch and compute the loss and metrics.

        Shared by the training and the validation step.

        Returns:
            dict with
            ``total_loss``: L1 loss between prediction and target (the optimised value);
            ``angle_degree``: the same loss converted from radians to degrees
            (assumes the targets are angles in radians - TODO confirm);
            ``cosine_sim``: cosine similarity along the last axis, averaged.
        """
        x, y = train_batch
        full_size_pred = self.model(x)
        mae_loss = self.mae_loss(full_size_pred, y)

        cosine_sim = F.cosine_similarity(full_size_pred, y, dim=-1, eps=1e-8).mean()

        return {
            "total_loss": mae_loss,
            # Exact radians -> degrees conversion. The former ``* 180 / 3.14``
            # overstated every reported value by roughly 0.05 %.
            "angle_degree": torch.rad2deg(mae_loss),
            "cosine_sim": cosine_sim,
        }

    def training_step(self, train_batch, batch_idx):
        """Log every metric with a ``train_`` prefix and return the loss to optimise."""
        loss_dict = self.step(train_batch)

        for name, value in loss_dict.items():
            self.log("train_" + str(name), value, on_step=True)
        # One dot per training batch as a lightweight progress indicator.
        print(".", end="")
        return loss_dict["total_loss"]

    def validation_step(self, val_batch, batch_idx):
        """Log every metric with a ``val_`` prefix (epoch-level) and return the degree error."""
        # Use the trainer attached to this module rather than a notebook-level
        # ``trainer`` variable, so the class does not depend on cell execution order.
        # The wandb.run check keeps the module usable when no W&B run is active.
        if self.trainer.global_step == 0 and wandb.run is not None:
            wandb.define_metric("val_angle_degree", summary="min")

        loss_dict = self.step(val_batch)

        for name, value in loss_dict.items():
            self.log("val_" + str(name), value, on_step=False, on_epoch=True)

        angle_error = loss_dict["angle_degree"]
        self.validation_step_outputs.append(angle_error.detach())

        return angle_error

    def on_validation_epoch_end(self):
        """Average the collected per-batch errors into ``val_current_loss`` and reset the buffer."""
        if not self.validation_step_outputs:
            # No validation batch ran; torch.stack would fail on an empty list.
            return

        # Plain mean over batches (batches are not weighted by their size).
        self.val_current_loss = torch.stack(self.validation_step_outputs).mean()

        self.validation_step_outputs.clear()
        print(
            f"current epoch {self.current_epoch} val_current_loss {self.val_current_loss}"
        )

In [27]:
class MovementsWandb(Callback):
    """Tracks the best validation error seen so far and reports every improvement.

    Args:
        pred_fps: frame rate of the predictions; only needed by the hand-movement
            visualisation, which is currently disabled.
    """

    def __init__(self, pred_fps=25):
        self.pred_fps = pred_fps
        # +inf guarantees that the first real validation result counts as an improvement.
        self.best_val_loss = float("inf")

    def on_validation_end(self, trainer, pl_module):
        """Compare the module's latest validation error with the best one so far.

        This runs in ``on_validation_end`` on purpose: Lightning invokes the
        callbacks' ``on_validation_epoch_end`` before the LightningModule's own
        ``on_validation_epoch_end``, which is where ``val_current_loss`` is
        refreshed. Reading it in that earlier hook therefore yields the value of
        the previous validation run.
        """
        # Scores produced by the pre-training sanity check are not tracked.
        if trainer.sanity_checking:
            return

        if pl_module.val_current_loss < self.best_val_loss:
            print("new best val score", pl_module.val_current_loss)

            self.best_val_loss = pl_module.val_current_loss

            # VLAD: commented this
            # hand_visualize.visualize_val_moves(
            #     model=pl_module.model,
            #     val_exps_data=trainer.val_dataloaders.dataset.datasets[0].exps_data,
            #     epoch=pl_module.current_epoch,
            #     device=pl_module.device,
            #     pred_fps=self.pred_fps,
            # )

## Init model

In [None]:
# NOTE(review): loading of pretrained weights is disabled, so the network
# starts from its default initialisation.
# PRETRAIN_PATH = Path('../weights/hvatnet_v3_angles_full_data.pt')

# Constructor arguments of the network.
hvatnet_v3_params = {
    "n_electrodes": 8,
    "n_channels_out": 20,
    "n_res_blocks": 3,
    "n_blocks_per_layer": 3,
    "n_filters": 128,
    "kernel_size": 3,
    "strides": (2, 2, 2),
    "dilation": 2,
    "use_angles": config.use_angles,
}

model = HVATNet_v3_FineTune.HVATNetv3(**hvatnet_v3_params)
# model.load_state_dict(torch.load(PRETRAIN_PATH, map_location='cpu'))

# Wrap the network for Lightning; optimizer_params supplies lr and wd.
model_pl = LitHVATNet_v3(model, **config.optimizer_params)
# model_pl = torch.compile(model_pl)

# Shape smoke test on an all-zero dummy batch of two items. The shape is
# derived from the settings instead of repeating the literals 8 and 256
# (assumes the input layout is (batch, electrodes, window) - TODO confirm).
x = torch.zeros([2, hvatnet_v3_params["n_electrodes"], config.window_size])
# Only the output shape is of interest, so no autograd graph needs to be built.
with torch.no_grad():
    y = model(x)
print(summary(model.to("cpu"), x, show_input=True))
print("Input shape: ", x.shape)
print("Output shape: ", y.shape)

### Train the model

In [29]:
# Training loader: order reshuffled by the DataLoader.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=config.train_bs, shuffle=True, num_workers=3
)


# Validation loader: fixed order. Uses the dedicated validation batch size;
# it previously reused config.train_bs, so changing val_bs had no effect.
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=config.val_bs, shuffle=False, num_workers=3
)

In [30]:
# Export the notebook location through the WANDB_NOTEBOOK_NAME environment variable.
# NOTE(review): the path is machine-specific - presumably wandb reads it to find
# this notebook when saving code; confirm and adjust for your checkout.
os.environ.update(
    WANDB_NOTEBOOK_NAME=r"D:\study\emg\machine-learning\notebooks\important_HVATNet_v3_FT.ipynb"
)

In [31]:
# wandb.init(
#     # set the wandb project where this run will be logged
#     project="myo_prost")

In [None]:
# Experiment logger: W&B run in project "myo_prost" with code saving and
# model logging switched on.
wandb_logger = WandbLogger(
    project="myo_prost",
    entity="vlad-aksiotis",
    notes=config.WANDB_NOTES,
    tags=["data_v1"],
    dir="lightning_logs",
    save_code=True,
    log_model=True,
)

# Keep the five checkpoints with the lowest val_angle_degree plus the latest one.
checkpoint_callback = ModelCheckpoint(
    filename="{epoch:02d}_{val_angle_degree:.3f}",
    monitor="val_angle_degree",
    mode="min",
    save_top_k=5,
    save_last=True,
    verbose=True,
)

# Frame rate of the predictions after the target has been downsampled.
prediction_fps = config.original_fps // config.down_sample_target
training_callbacks = [
    MovementsWandb(pred_fps=prediction_fps),
    checkpoint_callback,
]

# NOTE(review): training is pinned to the CPU; config.device is not used here.
trainer = pl.Trainer(
    logger=wandb_logger,
    callbacks=training_callbacks,
    accelerator="cpu",
    max_epochs=config.max_epochs,
)

trainer.fit(model_pl, train_dataloader, val_dataloader)