In [7]:
import os
import logging
from datetime import datetime

import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from utils.dataset import BabyMotionDataset
from models.diffusion import DiffusionModel, GaussianDiffusion

# Train

In [8]:
# Params

num_epochs = 2000
batch_size = 64
lr = 2e-4

seq_len = 100        # max sequence length; passed to the dataset as max_len
save_interval = 10   # save a checkpoint every `save_interval` epochs

origin_dir = "./data_origin"

# Seed torch's RNG so DataLoader shuffling, timestep sampling and diffusion noise
# are reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Action label -> integer class index; len(label2idx) is used as the model's num_classes.
label2idx = {
    'crawl': 0, 'walk': 1,
    'sit-floor': 2, 'sit-high-chair': 3, 'sit-low-chair': 4, 'stand': 5,
    'hold-horizontal': 6, 'hold-vertical': 7, 'piggyback': 8,
    'baby-food': 9, 'bottle': 10, 'breast': 11,
    'face-down': 12, 'face-side': 13, 'face-up': 14, 'roll-over': 15,
}

# logger: every run writes into its own timestamped directory.
now = datetime.now().strftime("%Y%m%d_%H%M%S")
log_root = f"./logs/aug_Diffusion/{now}"
checkpoints_dir = os.path.join(log_root, "checkpoints")
os.makedirs(log_root, exist_ok=True)
os.makedirs(checkpoints_dir, exist_ok=True)
logging.basicConfig(
    filename=os.path.join(log_root, "train.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    # Without force=True, basicConfig silently does nothing when the root logger already
    # has handlers (e.g. this cell is re-run), so logs would not reach this run's file.
    force=True,
)

In [10]:
class IndexedMotionDataset(torch.utils.data.Dataset):
    """Wrap a motion dataset so each item's action label becomes an integer class index.

    Args:
        base_dataset: dataset whose items unpack as ``(sequence, action, extra)``;
            the third element is discarded.
        action2idx: mapping from action label to integer class index. An action
            missing from this mapping raises ``KeyError`` on access.
    """

    def __init__(self, base_dataset, action2idx):
        self.base_dataset = base_dataset
        self.action2idx = action2idx

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        # Only the sequence and its label are needed; the third field is dropped.
        sequence, action, _unused = self.base_dataset[idx]
        return sequence, self.action2idx[action]
    
def my_collate_fn(batch):
    """Collate ``(sequence, class_idx)`` pairs into one padded batch.

    Sequences are zero-padded at the end to the length of the longest sequence in
    this batch (``batch_first=True``), giving shape ``[B, T_max, ...]``. Labels are
    stacked into a 1-D tensor of length ``B``.

    NOTE(review): no padding mask is returned, so downstream code cannot tell
    padded frames from real ones — confirm the loss accounts for this.
    """
    seqs, label_list = zip(*batch)
    return pad_sequence(seqs, batch_first=True), torch.tensor(label_list)
    
# dataset & Dataloader
origin_dataset = BabyMotionDataset(
    origin_dir=origin_dir,
    max_len=seq_len,
    min_len=10,  # presumably drops sequences shorter than 10 frames — confirm in BabyMotionDataset
    is_train=True
)
dataset = IndexedMotionDataset(origin_dataset, label2idx)
# Reshuffle every epoch. Without shuffle=True each epoch sees identical batches in a
# fixed order, which biases SGD (especially if samples are stored grouped by action).
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=my_collate_fn,
)

In [11]:
# Models
# Use the configured seq_len (also the dataset's max_len) rather than a literal,
# so changing the config cannot silently desynchronise model and data.
model = DiffusionModel(
    input_dim=3,
    hidden_dim=128,
    num_classes=len(label2idx),
    seq_len=seq_len,
    num_layers=2,
).to(device)
diff = GaussianDiffusion(model, timesteps=100).to(device)

# optimizers
# NOTE(review): only model.parameters() are optimised; this assumes GaussianDiffusion
# adds no trainable parameters of its own — confirm.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [12]:
for epoch in range(1, num_epochs + 1):
    model.train()
    total_loss = 0.0

    for real_seq, class_idx in dataloader:
        real_seq = real_seq.to(device)    # [B, T, 3]; T is the longest sequence in the batch after padding
        class_idx = class_idx.to(device)  # [B] integer class indices

        # One diffusion timestep per sequence, sampled uniformly from [0, diff.timesteps).
        t = torch.randint(0, diff.timesteps, (real_seq.size(0),), device=device)
        loss = diff.p_losses(real_seq, t, class_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Guard against an empty dataloader so averaging cannot raise ZeroDivisionError.
    num_batches = len(dataloader)
    avg_loss = total_loss / num_batches if num_batches > 0 else float("nan")

    log_msg = f"[Epoch {epoch}/{num_epochs}] Loss: {avg_loss:.9f}"
    print(log_msg)
    logging.info(log_msg)

    if epoch % save_interval == 0:
        # `epoch` is already 1-based (range starts at 1), so use it directly;
        # adding 1 would label the weights saved after epoch 10 as "epoch11".
        epoch_dir = os.path.join(checkpoints_dir, f"epoch{epoch}")
        os.makedirs(epoch_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(epoch_dir, f"generator_epoch{epoch}.pt"))


[Epoch 1/2000] Loss: 1.008886090
[Epoch 2/2000] Loss: 1.001794568
[Epoch 3/2000] Loss: 1.000373789
[Epoch 4/2000] Loss: 0.995101741
[Epoch 5/2000] Loss: 1.003138883
[Epoch 6/2000] Loss: 0.988924801
[Epoch 7/2000] Loss: 0.985116678
[Epoch 8/2000] Loss: 0.986889737
[Epoch 9/2000] Loss: 0.990123153
[Epoch 10/2000] Loss: 0.984946268
[Epoch 11/2000] Loss: 0.982514109
[Epoch 12/2000] Loss: 0.976451022
[Epoch 13/2000] Loss: 0.974366384
[Epoch 14/2000] Loss: 0.976397123
[Epoch 15/2000] Loss: 0.957790775
[Epoch 16/2000] Loss: 0.935042850
[Epoch 17/2000] Loss: 0.926011707
[Epoch 18/2000] Loss: 0.900855090
[Epoch 19/2000] Loss: 0.871506018
[Epoch 20/2000] Loss: 0.839336327
[Epoch 21/2000] Loss: 0.789447486
[Epoch 22/2000] Loss: 0.759034455
[Epoch 23/2000] Loss: 0.712820939
[Epoch 24/2000] Loss: 0.647992492
[Epoch 25/2000] Loss: 0.597133126
[Epoch 26/2000] Loss: 0.573228819
[Epoch 27/2000] Loss: 0.507528160
[Epoch 28/2000] Loss: 0.476589275
[Epoch 29/2000] Loss: 0.399872461
[Epoch 30/2000] Loss: 0

# Gen