In [1]:
import os
import logging
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from models.gan import Generator, Discriminator
from utils.dataset import BabyMotionDataset

# Train GAN

In [2]:
# Params
batch_size = 64
num_epochs = 1000
lr_G = 2e-4
lr_D = 5e-5

latent_dim = 96
embedding_dim = 32
seq_len = 100

critic_iters = 10
lambda_gp = 20

origin_dir = "./data_origin"

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

label2idx = {
    'crawl': 0, 'walk': 1,
    'sit-floor': 2, 'sit-high-chair': 3, 'sit-low-chair': 4, 'stand': 5, 
    'hold-horizontal': 6, 'hold-vertical': 7, 'piggyback': 8, 
    'baby-food': 9, 'bottle': 10, 'breast': 11, 
    'face-down': 12, 'face-side': 13, 'face-up':14, 'roll-over': 15
}

# logger
now = datetime.now().strftime("%Y%m%d_%H%M%S")
log_root = f"./logs/aug_GAN/{now}"
checkpoints_dir = os.path.join(log_root, "checkpoints")
os.makedirs(log_root, exist_ok=True)
os.makedirs(checkpoints_dir, exist_ok=True)
logging.basicConfig(
    filename=os.path.join(log_root, "train.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

In [4]:
# Dataset and Dataloader

class GANWrapperDataset(Dataset):
    def __init__(self, base_dataset, label2idx, seq_len=100):
        self.dataset = base_dataset
        self.label2idx = label2idx
        self.seq_len = seq_len

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sequence, action, _ = self.dataset[idx]
        T, D = sequence.shape
        if T >= self.seq_len:
            sequence = sequence[:self.seq_len]
        else:
            padding = torch.zeros(self.seq_len - T, D)
            sequence = torch.cat([sequence, padding], dim=0)

        return sequence, self.label2idx[action]


origin_dataset = BabyMotionDataset(
    origin_dir=origin_dir,
    max_len=seq_len,
    min_len=10,
    is_train=True
)
dataset = GANWrapperDataset(origin_dataset, label2idx)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [5]:
# Models
G = Generator(latent_dim=latent_dim, embedding_dim=embedding_dim, num_classes=len(label2idx)).to(device)
D = Discriminator(embedding_dim=embedding_dim, num_classes=len(label2idx)).to(device)

# optimizers
optimizer_G = optim.Adam(G.parameters(), lr=lr_G, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr_D, betas=(0.5, 0.999))

# Loss
def compute_gradient_penalty(D, real_samples, fake_samples, labels):
    alpha = torch.rand(real_samples.size(0), 1, 1, device=device)
    alpha = alpha.expand_as(real_samples)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)

    with torch.backends.cudnn.flags(enabled=False):
        d_interpolates = D(interpolates, labels)
    fake = torch.ones(d_interpolates.size(), device=device, requires_grad=False)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [6]:
for epoch in range(num_epochs):
    for i, (real_seq, class_idx) in enumerate(dataloader):
        real_seq = real_seq.to(device)
        class_idx = class_idx.to(device)
        batch_size = real_seq.size(0)

        # ================= Train Discriminator =================
        for _ in range(critic_iters):
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_seq = G(noise, class_idx)

            real_validity = D(real_seq, class_idx)
            fake_validity = D(fake_seq.detach(), class_idx)
            gradient_penalty = compute_gradient_penalty(D, real_seq, fake_seq.detach(), class_idx)

            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

        # ================= Train Generator =================
        noise = torch.randn(batch_size, latent_dim, device=device)
        fake_seq = G(noise, class_idx)
        g_loss = -torch.mean(D(fake_seq, class_idx))

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    log_msg = f"[Epoch {epoch+1}/{num_epochs}] Discriminator_Loss: {d_loss.item():.9f}, Generator_Loss: {g_loss.item():.9f}"
    print(log_msg)
    logging.info(log_msg)

    # Save checkpoint
    if (epoch + 1) % 10 == 0:
        epoch_dir = os.path.join(checkpoints_dir, f"epoch{epoch+1}")
        os.makedirs(epoch_dir, exist_ok=True)
        torch.save(G.state_dict(), os.path.join(epoch_dir, f"generator_epoch{epoch+1}.pt"))
        torch.save(D.state_dict(), os.path.join(epoch_dir, f"discriminator_epoch{epoch+1}.pt"))

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


[Epoch 1/1000] Discriminator_Loss: 18.513252258, Generator_Loss: -0.078471169
[Epoch 2/1000] Discriminator_Loss: 14.820059776, Generator_Loss: -0.773604870
[Epoch 3/1000] Discriminator_Loss: 17.043272018, Generator_Loss: -2.369840860
[Epoch 4/1000] Discriminator_Loss: 7.478017807, Generator_Loss: -1.204199910
[Epoch 5/1000] Discriminator_Loss: 8.684754372, Generator_Loss: -0.711095989
[Epoch 6/1000] Discriminator_Loss: 3.765950680, Generator_Loss: -0.035151586
[Epoch 7/1000] Discriminator_Loss: 4.395219803, Generator_Loss: 0.173131868
[Epoch 8/1000] Discriminator_Loss: 4.539766312, Generator_Loss: -0.180454448
[Epoch 9/1000] Discriminator_Loss: 6.794099808, Generator_Loss: -0.647410393
[Epoch 10/1000] Discriminator_Loss: 6.636375427, Generator_Loss: -0.916483700
[Epoch 11/1000] Discriminator_Loss: 7.073828697, Generator_Loss: -0.633870363
[Epoch 12/1000] Discriminator_Loss: 6.931130409, Generator_Loss: -0.862657666
[Epoch 13/1000] Discriminator_Loss: 6.619785309, Generator_Loss: -0.785

# Generation