In [None]:
import logging
import os
from datetime import datetime

import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from utils.dataset import BabyMotionDataset
from models.diffusion import DiffusionModel, GaussianDiffusion

# Train

In [None]:
# Params
# Everything a reader might want to tune for a training run lives here.

num_epochs = 1000      # total number of training epochs
batch_size = 64        # sequences per optimizer step
lr = 2e-4              # Adam learning rate

seq_len = 100          # passed to BabyMotionDataset as max_len
save_interval = 10     # epochs between saved checkpoints

origin_dir = "./data_origin"  # root directory of the raw motion data

# Seed torch's global RNG (on all devices) so torch-based randomness such as
# DataLoader shuffling is reproducible across "Restart & Run All".
seed = 42
torch.manual_seed(seed)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

label2idx = {
    'crawl': 0, 'walk': 1,
    'sit-floor': 2, 'sit-high-chair': 3, 'sit-low-chair': 4, 'stand': 5, 
    'hold-horizontal': 6, 'hold-vertical': 7, 'piggyback': 8, 
    'baby-food': 9, 'bottle': 10, 'breast': 11, 
    'face-down': 12, 'face-side': 13, 'face-up':14, 'roll-over': 15
}

# logger
now = datetime.now().strftime("%Y%m%d_%H%M%S")
log_root = f"./logs/aug_Diffusion/{now}"
checkpoints_dir = os.path.join(log_root, "checkpoints")
os.makedirs(log_root, exist_ok=True)
os.makedirs(checkpoints_dir, exist_ok=True)
logging.basicConfig(
    filename=os.path.join(log_root, "train.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

In [None]:
class IndexedMotionDataset(torch.utils.data.Dataset):
    """Dataset adapter that replaces string action labels with class indices.

    Each item of ``base_dataset`` must unpack into exactly three values
    ``(sequence, action, extra)``. ``extra`` is discarded and ``action`` is
    translated through ``action2idx``.

    Args:
        base_dataset: indexable object supporting ``len()`` and ``[idx]``.
        action2idx: mapping from action label to integer class index.

    Items are ``(sequence, class_index)``. A label missing from
    ``action2idx`` raises ``KeyError``.
    """

    def __init__(self, base_dataset, action2idx):
        self.base_dataset = base_dataset
        self.action2idx = action2idx

    def __len__(self):
        # One item per item of the wrapped dataset.
        n_items = len(self.base_dataset)
        return n_items

    def __getitem__(self, idx):
        seq, label, _unused = self.base_dataset[idx]
        return seq, self.action2idx[label]
    
# dataset & Dataloader
# Wrap the raw motion dataset so each sample becomes (sequence, class index)
# instead of (sequence, action label, extra).
origin_dataset = BabyMotionDataset(
    origin_dir=origin_dir,
    min_len=10,  # presumably the minimum sequence length kept — confirm in BabyMotionDataset
    max_len=seq_len,
    is_train=True,
)
dataset = IndexedMotionDataset(origin_dataset, action2idx=label2idx)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [None]:
# Models
# Denoising network sized for one class per entry in label2idx. input_dim=3
# presumably matches the per-frame feature size of the sequences — TODO confirm.
model = DiffusionModel(
    input_dim=3,
    hidden_dim=128,
    num_classes=len(label2idx),
    n_layers=2,
).to(device)
# Gaussian diffusion process wrapped around the network, with 100 timesteps.
diff = GaussianDiffusion(model, timesteps=100)

# optimizers
# Adam over the network's parameters only.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Training loop: one full pass over the DataLoader per epoch. `epoch` is
# already 1-based, so it is used as-is in progress, log and checkpoint names.
for epoch in range(1, num_epochs + 1):
    model.train()
    total_loss = 0.0

    for real_seq, class_idx in tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        real_seq = real_seq.to(device)  # shape: [B, T, 3]
        class_idx = class_idx.to(device)  # shape: [B]

        loss = diff.p_losses(model, real_seq, class_idx)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields a Python float, so the autograd graph is not retained.
        total_loss += loss.item()

    # total_loss is a plain Python float, so avg_loss is too (no .item()).
    avg_loss = total_loss / len(dataloader)

    log_msg = f"[Epoch {epoch}/{num_epochs}] Loss: {avg_loss:.9f}"
    print(log_msg)
    logging.info(log_msg)

    # Checkpoint every `save_interval` epochs, and always after the last epoch.
    if epoch % save_interval == 0 or epoch == num_epochs:
        epoch_dir = os.path.join(checkpoints_dir, f"epoch{epoch}")
        os.makedirs(epoch_dir, exist_ok=True)
        torch.save(diff.state_dict(), os.path.join(epoch_dir, f"generator_epoch{epoch}.pt"))


# Gen