In [1]:
#!/usr/bin/env python
# coding: utf-8

import sys
import random
import pickle
from pathlib import Path

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange

from conv_nets.wesad_tries.utils import (
    set_seed,
    val,
    save_data,
    train_epoch,
)
from conv_nets.wesad_tries.models import get_model
from ds.wesad.datasets import subjects_data
from ds.wesad.datasets_users import SubjectDataset

In [2]:
seed = 1337

random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)

<torch._C.Generator at 0x7fe8457ed3d0>

In [20]:
subjects_was_skipped_while_training = True
data_key = "rr_intervals"
numeric_derivative = True
agg_type = "max"

epoch_count = 50
signal_len = 30
optim_lr = 1e-6
each_user_rate_history = True
ds_step_size = 5
train_batch_size = 8
test_batch_size = 1
device = "cuda"
Model = get_model(signal_len, agg_type)
Optimizer = optim.ASGD
net_with_optim_name = f"{Model.__name__}_{Optimizer.__name__}_lr_{optim_lr}"

In [11]:
base_path = Path("/home/dmo/Documents/human_func_state/human_func_state")
wesad = base_path.joinpath("models_dumps", "wesad")

In [5]:
subject_id = 2

In [6]:
def get_ds_and_dl(subject_id):
    ds_subj_train = SubjectDataset(
        subjects_data.get(subject_id),
        ds_type="train",
        window_size=signal_len,
        step=ds_step_size,
        key=data_key,
        numeric_derivative=numeric_derivative,
    )
    ds_subj_test = SubjectDataset(
        subjects_data.get(subject_id),
        ds_type="test",
        window_size=signal_len,
        step=ds_step_size,
        key=data_key,
        numeric_derivative=numeric_derivative,
    )
    return {
        "ds": {
            "train": ds_subj_train,
            "test": ds_subj_test,
        },
        "dl": {
            "train": DataLoader(
                ds_subj_train,
                batch_size=train_batch_size,
                shuffle=True,
                num_workers=1,
                pin_memory=False,
                drop_last=True,
            ),
            "test": DataLoader(
                ds_subj_test,
                batch_size=test_batch_size,
                shuffle=True,
                num_workers=1,
                pin_memory=False,
                drop_last=True,
            ),
        },
    }

In [21]:
def get_out_and_common_base_paths(
    skip_user_on_train=False,
) -> tuple[Path, Path]:
    global net_with_optim_name
    return (
        wesad.joinpath(
            *[
                "subjects_related",
                *(["derivative"] if numeric_derivative else []),
                *(["skip_users"] if skip_user_on_train else []),
                net_with_optim_name,
            ]
        ),
        wesad.joinpath(
            *[
                "steps",
                *(["derivative"] if numeric_derivative else []),
                *(["skip_users"] if skip_user_on_train else []),
                net_with_optim_name,
            ]
        ),
    )


out_base_path, common_base_net_home = get_out_and_common_base_paths(
    subjects_was_skipped_while_training
)

In [22]:
list(common_base_net_home.iterdir())

[PosixPath('/home/dmo/Documents/human_func_state/human_func_state/models_dumps/wesad/steps/derivative/skip_users/NetUpDownCoder3MP_30_ASGD_lr_1e-06/subj_05_NetUpDownCoder3MP_30_ASGD_lr_1e-06'),
 PosixPath('/home/dmo/Documents/human_func_state/human_func_state/models_dumps/wesad/steps/derivative/skip_users/NetUpDownCoder3MP_30_ASGD_lr_1e-06/subj_06_NetUpDownCoder3MP_30_ASGD_lr_1e-06'),
 PosixPath('/home/dmo/Documents/human_func_state/human_func_state/models_dumps/wesad/steps/derivative/skip_users/NetUpDownCoder3MP_30_ASGD_lr_1e-06/subj_15_NetUpDownCoder3MP_30_ASGD_lr_1e-06'),
 PosixPath('/home/dmo/Documents/human_func_state/human_func_state/models_dumps/wesad/steps/derivative/skip_users/NetUpDownCoder3MP_30_ASGD_lr_1e-06/subj_14_NetUpDownCoder3MP_30_ASGD_lr_1e-06'),
 PosixPath('/home/dmo/Documents/human_func_state/human_func_state/models_dumps/wesad/steps/derivative/skip_users/NetUpDownCoder3MP_30_ASGD_lr_1e-06/subj_08_NetUpDownCoder3MP_30_ASGD_lr_1e-06'),
 PosixPath('/home/dmo/Document

In [23]:
def get_net_state_dict(subject_id: int | None = None):
    """
    Get net state dict. If `subject_id` specified then returns network
    state with skipped subject while training

    :param subject_id: skipped subject while training
    :return: Common trained
    """
    global net_with_optim_name
    net_home = common_base_net_home / net_with_optim_name
    if subject_id is not None:
        net_home = common_base_net_home.joinpath(
            f"subj_{str(subject_id).zfill(2)}_{net_with_optim_name}"
        )
    net_state_file = next(net_home.glob("*_best.pkl"))
    with open(net_state_file, "rb") as f:
        return pickle.load(f).get("net_state_dict")

In [44]:
stored = get_net_state_dict(2)

In [45]:
new_d = net.state_dict()

In [49]:
net.seq

Sequential(
  (0): ConvX(
    (conv): Conv1d(1, 4, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
    (bn): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (1): ConvX(
    (conv): Conv1d(4, 8, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
    (bn): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): ConvX(
    (conv): Conv1d(8, 16, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
    (bn): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (4): ConvX(
    (conv): Conv1d(16, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
    (bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (5): MaxPool1d(kern

In [48]:
new_d.keys()

odict_keys(['seq.0.conv.weight', 'seq.0.bn.weight', 'seq.0.bn.bias', 'seq.0.bn.running_mean', 'seq.0.bn.running_var', 'seq.0.bn.num_batches_tracked', 'seq.1.conv.weight', 'seq.1.bn.weight', 'seq.1.bn.bias', 'seq.1.bn.running_mean', 'seq.1.bn.running_var', 'seq.1.bn.num_batches_tracked', 'seq.3.conv.weight', 'seq.3.bn.weight', 'seq.3.bn.bias', 'seq.3.bn.running_mean', 'seq.3.bn.running_var', 'seq.3.bn.num_batches_tracked', 'seq.4.conv.weight', 'seq.4.bn.weight', 'seq.4.bn.bias', 'seq.4.bn.running_mean', 'seq.4.bn.running_var', 'seq.4.bn.num_batches_tracked', 'seq.6.conv.weight', 'seq.6.bn.weight', 'seq.6.bn.bias', 'seq.6.bn.running_mean', 'seq.6.bn.running_var', 'seq.6.bn.num_batches_tracked', 'seq.7.conv.weight', 'seq.7.bn.weight', 'seq.7.bn.bias', 'seq.7.bn.running_mean', 'seq.7.bn.running_var', 'seq.7.bn.num_batches_tracked', 'seq.9.conv.weight', 'seq.9.bn.weight', 'seq.9.bn.bias', 'seq.9.bn.running_mean', 'seq.9.bn.running_var', 'seq.9.bn.num_batches_tracked', 'seq.10.conv.weight', 

In [46]:
for s, n in zip(stored.values(), new_d.values()):
    print(s.shape, n.shape)

torch.Size([4, 2, 3]) torch.Size([4, 1, 3])
torch.Size([4]) torch.Size([4])
torch.Size([4]) torch.Size([4])
torch.Size([4]) torch.Size([4])
torch.Size([4]) torch.Size([4])
torch.Size([]) torch.Size([])
torch.Size([8, 4, 3]) torch.Size([8, 4, 3])
torch.Size([8]) torch.Size([8])
torch.Size([8]) torch.Size([8])
torch.Size([8]) torch.Size([8])
torch.Size([8]) torch.Size([8])
torch.Size([]) torch.Size([])
torch.Size([16, 8, 3]) torch.Size([16, 8, 3])
torch.Size([16]) torch.Size([16])
torch.Size([16]) torch.Size([16])
torch.Size([16]) torch.Size([16])
torch.Size([16]) torch.Size([16])
torch.Size([]) torch.Size([])
torch.Size([32, 16, 3]) torch.Size([32, 16, 3])
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
torch.Size([]) torch.Size([])
torch.Size([16, 32, 3]) torch.Size([16, 32, 3])
torch.Size([16]) torch.Size([16])
torch.Size([16]) torch.Size([16])
torch.Size([16]) torch.Size([16])
torch.Size([16]) tor

In [40]:
net = Model().to(device=device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.ASGD(net.parameters())

In [25]:
p = Path.home().joinpath(
    "Documents",
    "human_func_state",
    "human_func_state",
    "models_dumps",
    "wesad",
    "steps",
    "derivative",
    "NetUpDownCoder3MP_30_ASGD_lr_1e-06",
    "NetUpDownCoder3MP_30_ASGD_lr_1e-06",
    "NetUpDownCoder3MP_30_ASGD_lr_1e-06_best.pkl",
)

In [26]:
with open(p, "rb") as f:
    d = pickle.load(f)

In [34]:
d.get("net_state_dict")

OrderedDict([('seq.0.conv.weight',
              tensor([[[-1.4543e-01,  2.9138e-02, -2.3321e-02],
                       [-2.8214e-01, -3.0476e-01, -9.3200e-01]],
              
                      [[ 7.8600e-03,  6.6469e-02, -5.7823e-01],
                       [-5.4734e-01,  1.0128e-01, -2.6428e-01]],
              
                      [[-4.5640e-01,  8.6116e-02,  2.7025e-01],
                       [-4.3448e-01,  2.7172e-02,  2.3120e-01]],
              
                      [[-4.8941e-05, -9.7109e-02,  3.9938e-01],
                       [-1.6571e-01,  2.5700e-01,  3.7950e-01]]], device='cuda:0')),
             ('seq.0.bn.weight',
              tensor([0.9719, 1.0134, 1.0203, 0.9118], device='cuda:0')),
             ('seq.0.bn.bias',
              tensor([ 0.2658, -0.1680,  0.2002, -0.0864], device='cuda:0')),
             ('seq.0.bn.running_mean',
              tensor([-103.5825, -362.1432,  -94.1253,  221.6618], device='cuda:0')),
             ('seq.0.bn.running_var',
     

In [33]:
mod_name = (
    f"_{net.__class__.__name__}"
    f"_{optimizer.__class__.__name__}"
    f"_lr_{optim_lr}"
)
subj_mod_name = f"subj_{subject_id}_{mod_name}"

write_path = base_path.joinpath(
    "models_dumps", "wesad", "subjects_related", mod_name, subj_mod_name
)
writer = SummaryWriter(log_dir=write_path / "hist")
dump_name = (write_path / f"{mod_name}_last").with_suffix(".pkl")
best_name = (write_path / f"{mod_name}_best").with_suffix(".pkl")

In [34]:
def val(model: Model, dl):
    with torch.no_grad():
        sum_ = 0
        for i, data in enumerate(dl):
            inputs, labels = data
            inputs = (
                torch.unsqueeze(inputs, 1).to(torch.float32).to(device=device)
            )

            outputs = model(inputs).cpu()
            eq = outputs.max(1).indices == labels
            sum_ += eq.sum()
    return sum_

In [35]:
v = val(net, dl_test)
v / len(ds_subj_test)

tensor(0.6618)

In [38]:
v.item()

45

In [14]:
def save_data(path, net_state_dict, epoch, rate, rate_subject=None):
    with open(path, "wb") as f:
        pickle.dump(
            {
                "net_state_dict": net_state_dict,
                "current_epoch": epoch,
                "rate": rate,
                "rate_subject": rate_subject,
                "optimizer": optimizer.__class__.__name__,
                "optimizer_params": optimizer.param_groups,
                "train_batch_size": train_batch_size,
                "test_batch_size": test_batch_size,
                "device": device,
            },
            f,
        )

In [15]:
def train(
    model,
    epoch_count=50,
    print_step=50,
    start_epoch=0,
    min_loss=(torch.tensor(torch.inf), 0),
    best_rate=(torch.tensor(0.0), 0),
    worst_rate=(torch.tensor(1.0), 0),
):
    for epoch in trange(
        start_epoch, epoch_count
    ):  # loop over the dataset multiple times
        epoch_loss = 0
        running_loss = 0.0
        model.train()
        for i, data in enumerate(dl_train):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = (
                torch.unsqueeze(inputs, 1).to(torch.float32).to(device=device)
            )

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs.cpu(), labels)

            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % print_step == print_step - 1:
                mean_loss = running_loss / print_step
                if mean_loss < min_loss[0]:
                    min_loss = (mean_loss, (epoch, i))
                print(f"[{epoch:3d}, {i:4d}] loss: {mean_loss:.3f}")
                epoch_loss += running_loss
                running_loss = 0.0
        epoch_loss += running_loss
        mean_loss = epoch_loss / (len(ds_subj_train) / ds_step_size)
        if mean_loss < min_loss[0]:
            min_loss = (mean_loss, epoch)
        model.eval()
        common_acc = val(model, dl_test) / len(ds_subj_test)
        writer.add_scalar("Loss/train", epoch_loss, epoch)
        writer.add_scalar("Accuracy/train", common_acc, epoch)
        if common_acc > best_rate[0]:
            best_rate = (common_acc, epoch)
            save_data(best_name, net.state_dict(), epoch, common_acc)
        if common_acc < worst_rate[0]:
            worst_rate = (common_acc, epoch)
        save_data(dump_name, net.state_dict(), epoch, common_acc)

        print(
            f"[{epoch:3d}] rate: {common_acc:.4f}; {best_rate = }, {worst_rate = }"
        )
    print("Finished Training. Min_loss:", min_loss)
    return worst_rate, best_rate, min_loss

In [16]:
train(net, epoch_count=epoch_count)

  0%|          | 0/30 [00:00<?, ?it/s]

[  0] rate: 0.6618; best_rate = (tensor(0.6618), 0), worst_rate = (tensor(0.6618), 0)
[  1] rate: 0.6765; best_rate = (tensor(0.6765), 1), worst_rate = (tensor(0.6618), 0)
[  2] rate: 0.7059; best_rate = (tensor(0.7059), 2), worst_rate = (tensor(0.6618), 0)
[  3] rate: 0.7059; best_rate = (tensor(0.7059), 2), worst_rate = (tensor(0.6618), 0)
[  4] rate: 0.7647; best_rate = (tensor(0.7647), 4), worst_rate = (tensor(0.6618), 0)
[  5] rate: 0.7206; best_rate = (tensor(0.7647), 4), worst_rate = (tensor(0.6618), 0)
[  6] rate: 0.7647; best_rate = (tensor(0.7647), 4), worst_rate = (tensor(0.6618), 0)
[  7] rate: 0.7794; best_rate = (tensor(0.7794), 7), worst_rate = (tensor(0.6618), 0)
[  8] rate: 0.8235; best_rate = (tensor(0.8235), 8), worst_rate = (tensor(0.6618), 0)
[  9] rate: 0.8088; best_rate = (tensor(0.8235), 8), worst_rate = (tensor(0.6618), 0)
[ 10] rate: 0.7941; best_rate = (tensor(0.8235), 8), worst_rate = (tensor(0.6618), 0)
[ 11] rate: 0.8971; best_rate = (tensor(0.8971), 11), 

((tensor(0.6618), 0), (tensor(0.9412), 22), (inf, 0))

In [17]:
v = val(net, dl_test)
v / len(ds_subj_test)

tensor(0.9412)