# LSTM

In [1]:
import os
import logging
from datetime import datetime

import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

from utils.dataset import BabyMotionDataset
from models.lstm import LSTMClassifier
from utils.utils_func import collate_fn

In [2]:
# params
num_epoch = 10000
lr = 5e-4

use_aug = True
aug_method = 'ChatGPT-o4-instructed' 
val_method = 'lstm'

In [3]:
origin_dir = "./data_origin"
aug_dir = os.path.join("./data_aug", aug_method)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

label2idx = {
    'crawl': 0, 'walk': 1,
    'sit-floor': 2, 'sit-high-chair': 3, 'sit-low-chair': 4, 'stand': 5, 
    'hold-horizontal': 6, 'hold-vertical': 7, 'piggyback': 8, 
    'baby-food': 9, 'bottle': 10, 'breast': 11, 
    'face-down': 12, 'face-side': 13, 'face-up':14, 'roll-over': 15
}

# logger
now = datetime.now().strftime("%Y%m%d_%H%M%S")
aug_dir_name = aug_method if use_aug else "origin"
log_root = f"./logs/{val_method}/{aug_dir_name}/{now}"
checkpoints_root = os.path.join(log_root, "checkpoints")
os.makedirs(log_root, exist_ok=True)
os.makedirs(checkpoints_root, exist_ok=True)
logging.basicConfig(
    filename=os.path.join(log_root, "train.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

In [4]:
train_dataset = BabyMotionDataset(
    origin_dir=origin_dir,
    aug_dirs=[aug_dir] if use_aug else [],
    max_len=100,
    min_len=10,
    is_train=True
)
val_dataset = BabyMotionDataset(
    origin_dir=origin_dir,
    aug_dirs=[aug_dir] if use_aug else [],
    max_len=100,
    min_len=10,
    is_train=False
)

# DataLoader
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=lambda b: collate_fn(b, label2idx)
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=lambda b: collate_fn(b, label2idx)
)

In [5]:
# 验证函数
def evaluate(model, dataloader):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for x, lengths, y in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x, lengths)
            preds = out.argmax(dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(y.cpu().tolist())
    acc = accuracy_score(all_labels, all_preds)
    return acc

In [6]:
model = LSTMClassifier(input_dim=3, hidden_dim=64, num_classes=len(label2idx)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

best_val_acc = 0.0
for epoch in range(num_epoch):
    model.train()
    all_preds, all_labels = [], []

    for batch in train_loader:
        x, lengths, y = batch
        x, y = x.to(device), y.to(device)

        output = model(x, lengths)
        loss = criterion(output, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = output.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(y.cpu().numpy())

    train_acc = accuracy_score(all_labels, all_preds)
    val_acc = evaluate(model, val_loader)

    log_msg = f"[Epoch {epoch}] Loss: {loss.item():.9f} | Train Acc: {train_acc:.9f} | Val Acc: {val_acc:.9f}"
    print(log_msg)
    logging.info(log_msg)

    # save model at best val acc.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), os.path.join(checkpoints_root, f"best_model_epoch{epoch}.pt"))
        logging.info(f"[Epoch {epoch}] Saved new best model with Val Acc: {val_acc:.9f}")

[Epoch 0] Loss: 2.736076355 | Train Acc: 0.156089194 | Val Acc: 0.179487179
[Epoch 1] Loss: 2.701134443 | Train Acc: 0.159519726 | Val Acc: 0.179487179
[Epoch 2] Loss: 2.687011242 | Train Acc: 0.178387650 | Val Acc: 0.290598291
[Epoch 3] Loss: 2.592462778 | Train Acc: 0.279588336 | Val Acc: 0.358974359
[Epoch 4] Loss: 1.994645357 | Train Acc: 0.313893654 | Val Acc: 0.401709402
[Epoch 5] Loss: 2.185788155 | Train Acc: 0.367066895 | Val Acc: 0.418803419
[Epoch 6] Loss: 2.183508396 | Train Acc: 0.391080617 | Val Acc: 0.470085470
[Epoch 7] Loss: 2.273966551 | Train Acc: 0.399656947 | Val Acc: 0.444444444
[Epoch 8] Loss: 1.844173431 | Train Acc: 0.409948542 | Val Acc: 0.427350427
[Epoch 9] Loss: 1.183115005 | Train Acc: 0.432246998 | Val Acc: 0.495726496
[Epoch 10] Loss: 1.921888590 | Train Acc: 0.468267581 | Val Acc: 0.470085470
[Epoch 11] Loss: 1.436494708 | Train Acc: 0.449399657 | Val Acc: 0.487179487
[Epoch 12] Loss: 1.779504180 | Train Acc: 0.433962264 | Val Acc: 0.504273504
[Epoch 13