# LSTM

In [1]:
import os
import logging
from datetime import datetime

import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

from utils.dataset import BabyMotionDataset
from models.lstm import LSTMClassifier
from utils.utils_func import collate_fn

In [2]:
# params
num_epoch = 5000
lr = 5e-4

use_aug = True
aug_method = 'CS' 
val_method = 'lstm'

In [3]:
origin_dir = "./data_origin"
aug_dir = os.path.join("./data_aug", aug_method) if use_aug else None

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

label2idx = {
    'crawl': 0, 'walk': 1,
    'sit-floor': 2, 'sit-high-chair': 3, 'sit-low-chair': 4, 'stand': 5, 
    'hold-horizontal': 6, 'hold-vertical': 7, 'piggyback': 8, 
    'baby-food': 9, 'bottle': 10, 'breast': 11, 
    'face-down': 12, 'face-side': 13, 'face-up':14, 'roll-over': 15
}

# logger
now = datetime.now().strftime("%Y%m%d_%H%M%S")
aug_dir_name = aug_method if use_aug else "origin"
log_root = f"./logs/{val_method}/{aug_dir_name}/{now}"
checkpoints_root = os.path.join(log_root, "checkpoints")
os.makedirs(log_root, exist_ok=True)
os.makedirs(checkpoints_root, exist_ok=True)
logging.basicConfig(
    filename=os.path.join(log_root, "train.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

In [4]:
train_dataset = BabyMotionDataset(
    origin_dir=origin_dir,
    aug_dirs=[aug_dir],
    max_len=100,
    min_len=10,
    is_train=True
)
val_dataset = BabyMotionDataset(
    origin_dir=origin_dir,
    aug_dirs=[aug_dir],
    max_len=100,
    min_len=10,
    is_train=False
)

# DataLoader
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=lambda b: collate_fn(b, label2idx)
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=lambda b: collate_fn(b, label2idx)
)

In [5]:
# 验证函数
def evaluate(model, dataloader):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for x, lengths, y in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x, lengths)
            preds = out.argmax(dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(y.cpu().tolist())
    acc = accuracy_score(all_labels, all_preds)
    return acc

In [6]:
model = LSTMClassifier(input_dim=3, hidden_dim=64, num_classes=len(label2idx)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

best_val_acc = 0.0
for epoch in range(num_epoch):
    model.train()
    all_preds, all_labels = [], []

    for batch in train_loader:
        x, lengths, y = batch
        x, y = x.to(device), y.to(device)

        output = model(x, lengths)
        loss = criterion(output, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = output.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(y.cpu().numpy())

    train_acc = accuracy_score(all_labels, all_preds)
    val_acc = evaluate(model, val_loader)

    log_msg = f"[Epoch {epoch}] Loss: {loss.item():.9f} | Train Acc: {train_acc:.9f} | Val Acc: {val_acc:.9f}"
    print(log_msg)
    logging.info(log_msg)

    # save model at best val acc.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), os.path.join(checkpoints_root, f"best_model_epoch{epoch}.pt"))
        logging.info(f"[Epoch {epoch}] Saved new best model with Val Acc: {val_acc:.9f}")

[Epoch 0] Loss: 2.641304493 | Train Acc: 0.092769441 | Val Acc: 0.179487179
[Epoch 1] Loss: 2.487344742 | Train Acc: 0.140518417 | Val Acc: 0.188034188
[Epoch 2] Loss: 2.038975954 | Train Acc: 0.231923602 | Val Acc: 0.333333333
[Epoch 3] Loss: 1.695190668 | Train Acc: 0.286493861 | Val Acc: 0.393162393
[Epoch 4] Loss: 1.744086981 | Train Acc: 0.351978172 | Val Acc: 0.410256410
[Epoch 5] Loss: 1.810969234 | Train Acc: 0.384720327 | Val Acc: 0.478632479
[Epoch 6] Loss: 1.822838902 | Train Acc: 0.416098226 | Val Acc: 0.495726496
[Epoch 7] Loss: 1.556724191 | Train Acc: 0.442019100 | Val Acc: 0.478632479
[Epoch 8] Loss: 1.568375707 | Train Acc: 0.458390177 | Val Acc: 0.487179487
[Epoch 9] Loss: 1.658686638 | Train Acc: 0.466575716 | Val Acc: 0.495726496
[Epoch 10] Loss: 1.712798834 | Train Acc: 0.458390177 | Val Acc: 0.547008547
[Epoch 11] Loss: 1.351862669 | Train Acc: 0.463847203 | Val Acc: 0.538461538
[Epoch 12] Loss: 1.415567636 | Train Acc: 0.492496589 | Val Acc: 0.538461538
[Epoch 13