In [1]:
import os
import sys
sys.path.insert(1, os.path.realpath(os.path.pardir))

import torch

from train import TrainConfig, run_train_model
from utils.augmentations import get_default_transform
from models import hvatnet
from utils import creating_dataset
import wandb

In [2]:
def count_parameters(model):
    """Print and return the parameter counts of ``model``.

    Iterates over ``model.parameters()`` and adds up ``numel()`` of every
    parameter, keeping a separate tally for those with ``requires_grad``
    set. A one-line summary (in millions, two decimals) is printed.

    Args:
        model: object exposing ``parameters()`` that yields tensors.

    Returns:
        tuple: ``(n_total, n_trainable)`` as raw element counts.
    """
    n_total = 0
    n_trainable = 0
    for param in model.parameters():
        size = param.numel()
        n_total += size
        # Only parameters that receive gradients count as trainable.
        if param.requires_grad:
            n_trainable += size
    print(f"Total: {n_total/1e6:.2f}M, Trainable: {n_trainable/1e6:.2f}M")
    return n_total, n_trainable

In [3]:
# Training settings for this experiment.
train_config = TrainConfig(exp_name='test_2_run_fedya', p_augs=0.3, batch_size=64, eval_interval=150)

## Data preparation
# Augmentation pipeline, parameterised by the probability from the train config.
transform = get_default_transform(train_config.p_augs)

# Root folder holding the processed datasets. It can be overridden with the
# ALVI_DATA_ROOT environment variable so the notebook is not tied to a single
# machine; when the variable is unset the original location is used.
data_root = os.environ.get("ALVI_DATA_ROOT", r"D:\Work\alvi_labs\code\data\processed")
dataset_names = ["dataset_v1_big", "dataset_v2_blocks"]

data_paths = dict(
    datasets=[os.path.join(data_root, name) for name in dataset_names],
    hand_type=['left', 'right'],  # presumably any subset of 'left' / 'right' - confirm in DataConfig
    human_type=['health', 'amputant'],  # presumably any subset of 'health' / 'amputant' - confirm in DataConfig
    test_dataset_list=['fedya_tropin_standart_elbow_left'],  # dataset name(s) routed to the test split
)
data_config = creating_dataset.DataConfig(**data_paths)
train_dataset, test_dataset = creating_dataset.get_datasets(data_config, transform=transform)

Getting val datasets
Number of moves: 72 | Dataset: fedya_tropin_standart_elbow_left
Reorder this dataset fedya_tropin_standart_elbow_left True
Getting train datasets
Number of moves: 72 | Dataset: fedya_tropin_standart_elbow_left
Reorder this dataset fedya_tropin_standart_elbow_left True
Number of trainining sessions: 1
Number of validation sessions: 1
Size of the input (8, 256) || Size of the output (20, 32)


In [4]:
## Init model
# Single source for the electrode count so the model config and the dummy
# input used for the smoke test below cannot drift apart.
n_electrodes = 8

model_config = hvatnet.Config(n_electrodes=n_electrodes, n_channels_out=20,
                              n_res_blocks=3, n_blocks_per_layer=3,
                              n_filters=128, kernel_size=3,
                              strides=(2, 2, 2), dilation=2,
                              small_strides=(2, 2))
model = hvatnet.HVATNetv3(model_config)
count_parameters(model)

# Smoke-test the forward pass on an all-zero batch of 4 items.
# The result is only inspected for its size, so autograd is switched off to
# avoid building a computation graph that would never be used.
x = torch.zeros(4, n_electrodes, 256)
with torch.no_grad():
    y = model(x)

print(y.size())

Number of parameters: 4210788
Total: 4.21M, Trainable: 4.21M
torch.Size([4, 20, 32])


In [5]:
# Use the first CUDA device when one is present; otherwise fall back to the
# CPU so the cell does not fail on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Flatten both configs into one dict for logging. On duplicate keys the later
# dict wins, i.e. model_config values override train_config values.
merged_config = {**train_config.__dict__, **model_config.__dict__}

# The run object is used as a context manager so the W&B run is closed even
# when training is stopped early (e.g. by interrupting the kernel).
with wandb.init(project='hvatnet-hackathon',
                entity='koval_alvi',
                config=merged_config,
                name=train_config.exp_name):
    run_train_model(model, (train_dataset, test_dataset), train_config, device)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkoval_alvi[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

Completed initialization of scheduler
******************************************************************************************************************************************************

overall_steps 150: 0.24392080307006836
val loss: 0.3187941908836365
saved model:  step_150_loss_0.3188.safetensors


******************************************************************************************************************************************************

overall_steps 300: 0.2432033270597458
val loss: 0.2892138659954071
saved model:  step_300_loss_0.2892.safetensors


******************************************************************************************************************************************************

overall_steps 450: 0.23250322043895721
val loss: 0.2859196364879608
saved model:  step_450_loss_0.2859.safetensors


******************************************************************************************************************************************************

over

KeyboardInterrupt: 