In [1]:
import os
import logging
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from models.gan import Generator, Discriminator
from utils.dataset import BabyMotionDataset

# Train GAN

In [None]:
# Params

# --- optimisation ---
num_epochs = 2000
batch_size = 64
lr_G, lr_D = 2e-4, 1e-4  # Adam learning rates: generator, discriminator

# --- model / data sizes ---
seq_len = 100        # maximum sequence length requested from the dataset
latent_dim = 96      # width of the noise vector fed to the generator
embedding_dim = 32   # forwarded to both Generator and Discriminator

# --- data location ---
origin_dir = "./data_origin"

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Action name -> integer class index; the index is the position in this tuple.
label2idx = {
    name: index
    for index, name in enumerate((
        'crawl', 'walk',
        'sit-floor', 'sit-high-chair', 'sit-low-chair', 'stand',
        'hold-horizontal', 'hold-vertical', 'piggyback',
        'baby-food', 'bottle', 'breast',
        'face-down', 'face-side', 'face-up', 'roll-over',
    ))
}

# logger
# Every execution of this cell gets its own timestamped run directory that holds
# the text log and the model checkpoints.
now = datetime.now().strftime("%Y%m%d_%H%M%S")
log_root = f"./logs/aug_GAN/{now}"
checkpoints_dir = os.path.join(log_root, "checkpoints")
os.makedirs(log_root, exist_ok=True)
os.makedirs(checkpoints_dir, exist_ok=True)

# logging.basicConfig() does nothing when the root logger already has handlers,
# which is the case whenever this cell is re-run in the same kernel (or when
# something else configured logging first). Without this cleanup the new run
# directory would be created but records would keep going to the previous file.
_root_logger = logging.getLogger()
for _handler in list(_root_logger.handlers):
    _root_logger.removeHandler(_handler)
    _handler.close()

logging.basicConfig(
    filename=os.path.join(log_root, "train.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

In [4]:
# Dataset and Dataloader

class GANWrapperDataset(Dataset):
    """Adapt a base dataset to fixed-length ``(sequence, class_index)`` pairs.

    Each item of ``base_dataset`` is unpacked as a 3-tuple
    ``(sequence, action, _)``; the third element is ignored. ``sequence`` must
    be a 2-D tensor whose first axis is truncated or zero-padded to exactly
    ``seq_len`` rows.

    Args:
        base_dataset: Indexable, sized dataset yielding ``(sequence, action, _)``.
        label2idx: Mapping from the ``action`` value to the label returned.
        seq_len: Number of rows every returned sequence has.
    """

    def __init__(self, base_dataset, label2idx, seq_len=100):
        self.dataset = base_dataset
        self.label2idx = label2idx
        self.seq_len = seq_len

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(sequence, label)`` for item ``idx``.

        The sequence keeps its first ``seq_len`` rows when it is long enough;
        otherwise zero rows are appended at the end. The label is
        ``label2idx[action]``, so an action missing from the mapping makes the
        lookup fail (``KeyError`` for a dict).
        """
        sequence, action, _ = self.dataset[idx]
        T, D = sequence.shape
        if T >= self.seq_len:
            sequence = sequence[:self.seq_len]
        else:
            # new_zeros inherits dtype and device from `sequence`, so padded
            # samples have the same dtype as truncated ones and concatenation
            # never mixes devices.
            padding = sequence.new_zeros(self.seq_len - T, D)
            sequence = torch.cat([sequence, padding], dim=0)

        return sequence, self.label2idx[action]


# Training split of the raw recordings; max_len is tied to seq_len.
origin_dataset = BabyMotionDataset(
    origin_dir=origin_dir,
    max_len=seq_len,
    min_len=10,
    is_train=True
)
# Pass seq_len explicitly: otherwise the wrapper silently uses its own default
# of 100 and would disagree with the dataset whenever seq_len is changed.
dataset = GANWrapperDataset(origin_dataset, label2idx, seq_len=seq_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [5]:
# Models
# Both networks are conditioned on the action class, so each one is told how
# many classes exist; they are moved to the selected device right away.
G = Generator(
    latent_dim=latent_dim,
    embedding_dim=embedding_dim,
    num_classes=len(label2idx),
).to(device)

D = Discriminator(
    embedding_dim=embedding_dim,
    num_classes=len(label2idx),
).to(device)

In [None]:
# Loss and optimizers
# Binary cross-entropy on the discriminator output, with one Adam optimizer per
# network; the two use separate learning rates but share the same betas.
criterion = nn.BCELoss()

optimizer_G = optim.Adam(
    G.parameters(),
    lr=lr_G,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    D.parameters(),
    lr=lr_D,
    betas=(0.5, 0.999),
)

In [7]:
# BCE targets: 1 marks a real sequence, 0 a generated one.
real_label = 1.
fake_label = 0.

for epoch in range(num_epochs):
    for i, (real_seq, class_idx) in enumerate(dataloader):
        real_seq = real_seq.to(device)
        class_idx = class_idx.to(device)
        # Size of *this* batch. It gets its own name so the notebook-level
        # `batch_size` hyperparameter is not overwritten; otherwise re-running
        # the DataLoader cell after training would silently use whatever size
        # the last batch happened to have.
        cur_batch_size = real_seq.size(0)

        # ================= Train Discriminator =================
        optimizer_D.zero_grad()

        # Real data: D should output `real_label`.
        label_real = torch.full((cur_batch_size,), real_label, dtype=torch.float, device=device)
        output_real = D(real_seq, class_idx).view(-1)
        loss_D_real = criterion(output_real, label_real)

        # Fake data: generate sequences for the same class labels; D should
        # output `fake_label`.
        noise = torch.randn(cur_batch_size, latent_dim, device=device)
        fake_seq = G(noise, class_idx)
        label_fake = torch.full((cur_batch_size,), fake_label, dtype=torch.float, device=device)
        # detach(): the discriminator step must not back-propagate into G.
        output_fake = D(fake_seq.detach(), class_idx).view(-1)
        loss_D_fake = criterion(output_fake, label_fake)

        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()

        # ================= Train Generator =================
        optimizer_G.zero_grad()
        # Try to fool the discriminator: the generator is scored against the
        # "real" target.
        label_gen = torch.full((cur_batch_size,), real_label, dtype=torch.float, device=device)
        # Re-score the same fakes with the freshly updated D, this time without
        # detach so gradients reach G.
        output_gen = D(fake_seq, class_idx).view(-1)
        loss_G = criterion(output_gen, label_gen)
        loss_G.backward()
        optimizer_G.step()

    # The values reported here are those of the final batch of the epoch, not
    # an average over the epoch.
    log_msg = f"[Epoch {epoch+1}] Loss_D: {loss_D.item():.9f}, Loss_G: {loss_G.item():.9f}"
    print(log_msg)
    logging.info(log_msg)

    # Save checkpoint: both state dicts every 10 epochs, one sub-directory per epoch.
    if (epoch + 1) % 10 == 0:
        epoch_dir = os.path.join(checkpoints_dir, f"epoch{epoch+1}")
        os.makedirs(epoch_dir, exist_ok=True)
        torch.save(G.state_dict(), os.path.join(epoch_dir, f"generator_epoch{epoch+1}.pt"))
        torch.save(D.state_dict(), os.path.join(epoch_dir, f"discriminator_epoch{epoch+1}.pt"))

[Epoch 1] Loss_D: 1.274890423, Loss_G: 0.680369318
[Epoch 2] Loss_D: 1.250694633, Loss_G: 0.659478009
[Epoch 3] Loss_D: 1.285191536, Loss_G: 0.783086658
[Epoch 4] Loss_D: 1.278347135, Loss_G: 0.818741918
[Epoch 5] Loss_D: 1.180890560, Loss_G: 0.861262560
[Epoch 6] Loss_D: 1.036827326, Loss_G: 0.978592038
[Epoch 7] Loss_D: 0.955897927, Loss_G: 1.006511331
[Epoch 8] Loss_D: 0.854010880, Loss_G: 1.156296730
[Epoch 9] Loss_D: 0.826131344, Loss_G: 1.221798658
[Epoch 10] Loss_D: 0.965777397, Loss_G: 1.247191906
[Epoch 11] Loss_D: 0.849196494, Loss_G: 1.404745579
[Epoch 12] Loss_D: 0.872894287, Loss_G: 1.587828517
[Epoch 13] Loss_D: 0.630417585, Loss_G: 1.703842640
[Epoch 14] Loss_D: 0.749438941, Loss_G: 1.746307373
[Epoch 15] Loss_D: 0.876568198, Loss_G: 1.797334909
[Epoch 16] Loss_D: 0.607800961, Loss_G: 2.011391640
[Epoch 17] Loss_D: 0.689593792, Loss_G: 1.869740605
[Epoch 18] Loss_D: 0.620718837, Loss_G: 1.821591735
[Epoch 19] Loss_D: 0.602943003, Loss_G: 2.102064848
[Epoch 20] Loss_D: 0.

# Generation