In [1]:
import torch
import torch.nn as nn
import copy

from utils import config
from utils import utils
from utils.models import Regressor

# from utils import mmfit_data as mmfit
import utils.mhad_data as data
# Select the number of samples per input window for the configured dataset.
# NOTE(review): the data module imported above is always `utils.mhad_data`,
# even when config.dataset is MMFIT — confirm the MMFIT branch is still usable.
if config.dataset == config.Dataset.MMFIT:
    window_length = config.sensor_window_length
elif config.dataset == config.Dataset.MHAD:
    # Product of the configured window length and sampling rate, truncated to int
    # (presumably seconds * Hz -> samples; verify against config).
    window_length = int(config.mhad_window_length * config.mhad_sampling_rate)
else:
    # Fail fast here instead of hitting a NameError on `window_length` when the
    # model is built in a later cell.
    raise ValueError(f"Unsupported dataset for window_length: {config.dataset!r}")

In [2]:
# # --- TEST data --- #
# d = mmfit_data.train_datasets.datasets[0]
# pose, acc, label, sf, ef = d.pose, d.acc, d.labels , d.start_frames[0], d.end_frames[0]

# print(f'activity window start: {sf}, end: {ef}')
# start_i = torch.where(pose[0, :, 0] == sf)[0][0]
# end_i = torch.where(pose[0, :, 0] == ef)[0][0]
# print(f'start index: {start_i}, end index: {end_i}')
# print(pose[0, start_i, 0], pose[0, end_i, 0])

# print(len(end_i) != 0)
# activity_window = ef - sf
# for i in range(0, 200000, 500):
#     sp, sa, sl = d.__getitem__(i)
#     print(sl)
#     break

In [3]:
# >>> pose2imu models <<< #
# >>> CNN <<< #
# Build the regressor from the configured channel/joint counts and the window
# length selected earlier, then move it to the configured device.
pose2imu_model = Regressor(
    in_ch=config.in_ch,
    num_joints=config.num_joints,
    window_length=window_length,
)
pose2imu_model = pose2imu_model.to(config.device)

# Define loss function and optimizer
# Mean-squared-error objective, optimised with Adam at the configured learning rate.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=pose2imu_model.parameters(), lr=config.lr)

# Scale the learning rate by `factor` when the value passed to scheduler.step()
# stops decreasing (mode="min") for longer than `patience` steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode="min",
    factor=0.1,
    patience=10,
)

In [4]:
# # --- test model --- #
# with torch.no_grad():
#         for pose, acc, label in data.train_loader:
#                 """
#                 pose: (batch_size, ch, num_joints, sensor_window_length)
#                 acc: (batch_size, ch, sensor_window_length)
#                 label: (batch_size)
#                 """
#                 pred_acc = pose2imu_model(pose)
#                 print(pred_acc.shape)
#                 loss = criterion(pred_acc, acc)
#                 break


In [None]:
# >>> Training <<< #
# Trains pose2imu_model for up to config.epochs epochs, tracking the mean
# per-batch train/validation loss of every epoch. The weights with the lowest
# validation loss are kept in `best_model_state`; training stops early once the
# validation loss has failed to improve for `patience` consecutive epochs.
epochs = config.epochs
train_loss_history, val_loss_history = [], []
best_val_loss = float("inf")
best_model_state = None  # stays None if the validation loss never improves
best_epoch = float('inf')  # replaced by the 0-based epoch index of the best model
patience = config.patience
epochs_no_improve = 0
log = ''

for epoch in range(epochs):
    # - Train
    total_train_loss = 0

    pose2imu_model.train()
    # Batch layout as described by the original author (not checked here):
    #   pose:  (batch_size, ch, num_joints, sensor_window_length)
    #   acc:   (batch_size, ch, sensor_window_length)
    #   label: (batch_size)
    # The label is not used by the regression loss, so it is not copied to the device.
    for pose, acc, label in data.train_loader:
        # -- Move to GPU
        pose = pose.to(config.device, non_blocking=True)
        acc = acc.to(config.device, non_blocking=True)

        # -- Forward pass
        pred_acc = pose2imu_model(pose) # (batch_size, 3, sensor_window_length)

        # -- Calculate loss
        loss = criterion(pred_acc, acc)
        total_train_loss += loss.item()

        # -- Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean of the per-batch losses (not weighted by batch size).
    average_train_loss = total_train_loss / len(data.train_loader)
    train_loss_history.append(average_train_loss)

    # - Validation
    total_val_loss = 0

    pose2imu_model.eval()
    with torch.no_grad():
        for pose, acc, label in data.val_loader:
            # -- Move to GPU
            pose = pose.to(config.device, non_blocking=True)
            acc = acc.to(config.device, non_blocking=True)
            # forward pass
            pred_acc = pose2imu_model(pose)

            # calculate loss
            loss = criterion(pred_acc, acc)
            total_val_loss += loss.item()

    average_val_loss = total_val_loss / len(data.val_loader)
    val_loss_history.append(average_val_loss)

    out = f"Epoch {epoch+1}/{epochs}, Train Loss: {average_train_loss:.4f}, Val Loss: {average_val_loss:.4f}" + \
            f'\n----------------------------------------------------\n'

    print(out)

    if average_val_loss < best_val_loss:
        best_val_loss = average_val_loss
        # Deep copy: state_dict() returns references to the live parameters,
        # which later optimizer steps would otherwise overwrite.
        best_model_state = copy.deepcopy(pose2imu_model.state_dict())
        best_epoch = epoch
        epochs_no_improve = 0

        log = f"Seed [{utils.args.seed}]" \
            f"\nVal Best Loss: {best_val_loss:.4f}" + \
            f'\n-----------------------------------------------------------\n'
    else:
        epochs_no_improve += 1

    if epochs_no_improve == patience:
        print("Early stopping triggered.")
        break

    # Only reached when training continues; the scheduler monitors validation loss.
    scheduler.step(average_val_loss)

In [None]:
# name = '20.01_0_pose2imu[1](regression).pth'
# latest_model = utils.find_latest_model(name)
# print(f"loading: {latest_model}")
# best_model_state = torch.load(latest_model, map_location=config.device)
# log = ''

loading: ../train_out/29.02/20.01_0_pose2imu[1](regression).pth


In [None]:
# >>> TEST best model <<< #
# Evaluates the best checkpoint on data.test_loader and appends the mean
# per-batch test loss to `log`.
# NOTE: `log` is extended in place, so re-running this cell appends another
# "Test Loss" entry.

# - Load best model
if best_model_state is None:
    # Happens when the validation loss never improved during training and no
    # saved checkpoint was loaded instead.
    raise RuntimeError(
        "best_model_state is None: run the training cell or load a saved "
        "checkpoint before testing."
    )
pose2imu_model.load_state_dict(best_model_state)

# - Test
total_loss = 0

pose2imu_model.eval()
with torch.no_grad():
    # Batch layout as described by the original author (not checked here):
    #   pose:  (batch_size, ch, num_joints, sensor_window_length)
    #   acc:   (batch_size, ch, sensor_window_length)
    #   label: (batch_size)
    # The label is not used by the regression loss, so it is not copied to the device.
    for pose, acc, label in data.test_loader:
        # -- Move to GPU
        pose = pose.to(config.device, non_blocking=True)
        acc = acc.to(config.device, non_blocking=True)

        pred_acc = pose2imu_model(pose)
        loss = criterion(pred_acc, acc)

        total_loss += loss.item()

# Mean of the per-batch losses (not weighted by batch size).
average_test_loss = total_loss / len(data.test_loader)

log += f"Test Loss: {average_test_loss:.4f}" + \
    f'\n----------------------------------------------------\n'

print(log)

Test Loss: 0.4746
----------------------------------------------------



In [None]:
# >>> SAVE <<< #
# Every artifact name is built from "<model name>[<seed>]" with a numeric
# prefix (0_, 1_, 2_) and a parenthesised kind suffix.
prefix = f"{config.pose2imu_model_name}[{utils.args.seed}]"

# - Best Weights
file_name = f"0_{prefix}(regression)"
utils.save_model(best_model_state, file_name)  # saving the best model

# SAVE loss plot
# `epoch` is the last loop index left over from the training cell.
metric = "Loss"
file_name = f"1_{prefix}({metric})"
utils.save_plot(
    epochs=epoch,
    best_epoch=best_epoch,
    train_metric_history=train_loss_history,
    val_metric_history=val_loss_history,
    metric=metric,
    file_name=file_name,
)

# Save Test Log
# The prefix is appended to the end of the accumulated log text before writing.
log = log + prefix
file_name = f"2_{prefix}(Log)"
utils.save_log(log=log, file_name=file_name)