In [10]:
import os
import importlib
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from LaFan import LaFan1
from models import StateEncoder, OffsetEncoder, TargetEncoder, LSTM, Decoder, ShortMotionDiscriminator, LongMotionDiscriminator
from skeleton import Skeleton
from functions import gen_ztta
import config

In [20]:
# Re-execute config.py so edits to it take effect without restarting the kernel,
# then bind, explicitly, the four settings objects this notebook reads
# (explicit names instead of `import *` make their origin obvious to readers).
importlib.reload(config)
from config import data, device, model, train

# Initializing models

In [35]:
# --- Generator ---
# Each sub-network is constructed and moved to `device` in a single expression.
# Input widths (22 joints), matching the vectors assembled in the training loop:
#   state  = 22 * 4 quaternions + 3 root velocity + 4 contacts = 95
#   offset = 3 root position + 22 * 4 quaternions              = 91
#   target = 22 * 4 quaternions                                = 88
state_encoder = StateEncoder(in_dim=model["state_input_dim"]).to(device)
offset_encoder = OffsetEncoder(in_dim=model["offset_input_dim"]).to(device)
target_encoder = TargetEncoder(in_dim=model["target_input_dim"]).to(device)

# The recurrent core is twice as wide as lstm_dim; the decoder consumes that hidden size.
lstm_hidden_dim = model["lstm_dim"] * 2
lstm = LSTM(in_dim=model["lstm_dim"], hidden_dim=lstm_hidden_dim).to(device)

# NOTE(review): decoder_output_dim was annotated as 95 originally; the training loop
# reads quaternion deltas + root velocity from the first decoder output and contacts
# from the second — confirm how Decoder splits out_dim.
decoder = Decoder(in_dim=lstm_hidden_dim, out_dim=model["decoder_output_dim"]).to(device)

In [36]:
# --- Discriminators ---
# Both critics receive, per frame, the flattened joint positions concatenated with
# their frame-to-frame velocities: num_joints * 3 coordinates * 2.
disc_in_dim = model['num_joints'] * 3 * 2

short_discriminator = ShortMotionDiscriminator(in_dim=disc_in_dim)
short_discriminator.to(device)

long_discriminator = LongMotionDiscriminator(in_dim=disc_in_dim)
long_discriminator.to(device)

LongMotionDiscriminator(
  (fc0): Conv1d(132, 512, kernel_size=(10,), stride=(1,))
  (fc1): Conv1d(512, 256, kernel_size=(1,), stride=(1,))
  (fc2): Conv1d(256, 1, kernel_size=(1,), stride=(1,))
)

In [31]:
# --- Skeleton ---
# Forward-kinematics helper: the training loop calls
# skeleton_mocap.forward_kinematics(local quaternions, root position) to get joint positions.
skeleton_mocap = Skeleton(offsets=data["offsets"], parents=data["parents"])
# NOTE(review): the return value of `.to` is discarded, so this relies on
# Skeleton.to moving its tensors in place — confirm in skeleton.py.
skeleton_mocap.to(device)
# Drop the joints listed in config; presumably this leaves the 22 joints the
# models are sized for — verify against data["joints_to_remove"].
skeleton_mocap.remove_joints(data["joints_to_remove"])

# Loading data

In [23]:
# Training dataset built from data["path_small_flipped"].
lafan = LaFan1(data["path_small_flipped"], seq_len=data["seq_length"], offset=data["offset"], train=True, debug=False)

# Position statistics from the dataset. x_std normalises the position and root
# losses in the training loop; it is reshaped to (1, 1, num_joints, 3) so it
# broadcasts against (batch, joint, 3) positions. The joint count comes from the
# config instead of a hard-coded 22, so the cell follows the configured skeleton.
# NOTE(review): x_mean is not used anywhere in this notebook.
x_mean = lafan.x_mean.to(device)
x_std = lafan.x_std.to(device).view(1, 1, model["num_joints"], 3)

Building the data set... ['subject1', 'subject2', 'subject3', 'subject4']
Processing file dance2_subject2.bvh
Processing file dance2_subject4.bvh
Processing file dance2_subject1.bvh
Processing file dance2_subject3.bvh
Nb of sequences : 448



In [25]:
# Shuffled training mini-batches; batch size and worker count come from the config.
lafan_loader = DataLoader(
    dataset=lafan,
    batch_size=train["batch_size"],
    shuffle=True,
    num_workers=data["num_workers"],
)

# Optimizers

In [37]:
# --- Optimizer ---
# One Adam instance drives all generator sub-networks. Parameters are gathered
# module by module in the order listed below.
generator_params = [
    param
    for module in (state_encoder, offset_encoder, target_encoder, lstm, decoder)
    for param in module.parameters()
]
optimizer_g = optim.Adam(
    generator_params,
    lr=train["lr"],
    betas=(train['beta1'], train['beta2']),
    weight_decay=train['weight_decay'],
)

In [38]:
# A separate Adam instance updates both discriminators with the same hyper-parameters
# as the generator optimizer.
discriminator_params = [
    param
    for module in (short_discriminator, long_discriminator)
    for param in module.parameters()
]
optimizer_d = optim.Adam(
    discriminator_params,
    lr=train['lr'],
    betas=(train['beta1'], train['beta2']),
    weight_decay=train['weight_decay'],
)

# Training

In [34]:
# Run name: used for the TensorBoard log directory and the checkpoint folder.
experiment = "small_dataset_test_00"
writer = SummaryWriter(log_dir=f"logs/{experiment}")
# Global step shared by every scalar logged in the training loop.
log_i = 0

In [39]:
# Loop settings (previously inline magic numbers).
NUM_EPOCHS = 2
LOG_EVERY = 1      # print the total loss every LOG_EVERY batches
SAVE_EVERY = 1     # checkpoint every SAVE_EVERY epochs (epoch 0 is never saved)
GRAD_CLIP = 0.5    # max gradient norm, applied per generator sub-network
NOISE_SCALE = 0.5  # multiplier of the scheduled target noise


def _target_noise_scale(tta, lo=5, hi=30):
    """Scheduled-noise weight for `tta` steps to arrival: 0 below `lo`, 1 from `hi`, linear in between."""
    if tta < lo:
        return 0.0
    if tta < hi:
        return (tta - lo) / float(hi - lo)
    return 1.0


def _stack_frames(frames):
    """Flatten each (batch, ...) frame and stack them on a new last axis -> (batch, features, time)."""
    return torch.cat([f.reshape(f.size(0), -1).unsqueeze(-1) for f in frames], -1)


def _with_velocity(seq):
    """Append frame-to-frame differences (zero for the last frame) along the feature axis (dim 1)."""
    vel = torch.cat([seq[:, :, 1:] - seq[:, :, :-1], torch.zeros_like(seq[:, :, 0:1])], -1)
    return torch.cat([seq, vel], 1)


def _mean_logits(discriminator, inputs):
    """Average the discriminator's first output channel over time -> one score per sample."""
    return torch.mean(discriminator(inputs)[:, 0], 1)


def _save_checkpoint(folder_name, objects):
    """Save `obj.state_dict()` of every entry as `<folder_name>/<name>.pkl`."""
    os.makedirs(folder_name, exist_ok=True)
    for name, obj in objects.items():
        torch.save(obj.state_dict(), f"{folder_name}/{name}.pkl")


generator_modules = (state_encoder, offset_encoder, target_encoder, lstm, decoder)
quat_dim = model["target_input_dim"]  # width of the flattened quaternion block

for epoch in range(NUM_EPOCHS):
    print(f"\n--- EPOCH {epoch} ---")

    for module in generator_modules:
        module.train()
    short_discriminator.train()
    long_discriminator.train()

    # Z-time to arrival embedding, indexed per step below
    ztta = gen_ztta(length=data["seq_length"]).to(device)

    for i_batch, sampled_batch in enumerate(lafan_loader):
        seq_len = lafan.cur_seq_length

        # Loss accumulators (become tensors after the first step)
        loss_pos = 0
        loss_quat = 0
        loss_contact = 0
        loss_root = 0

        # State inputs
        local_q = sampled_batch["local_q"].to(device)
        root_v = sampled_batch["root_v"].to(device)
        contact = sampled_batch["contact"].to(device)

        # Offset inputs: flattened to (batch, joint * quaternion)
        root_p_offset = sampled_batch["root_p_offset"].to(device)
        local_q_offset = sampled_batch["local_q_offset"].to(device)
        local_q_offset = local_q_offset.view(local_q_offset.size(0), -1)

        # Target inputs: flattened to (batch, joint * quaternion)
        target = sampled_batch["target"].to(device)
        target = target.view(target.size(0), -1)

        # Ground-truth root positions and joint positions
        root_p = sampled_batch["root_p"].to(device)
        X = sampled_batch["X"].to(device)

        lstm.init_hidden(local_q.size(0))
        # Predicted position sequence, seeded with the ground-truth first frame
        pred_list = [X[:, 0]]

        root_pred = None
        local_q_pred = None
        contact_pred = None
        root_v_pred = None

        for t in range(seq_len - 1):
            if t == 0:
                # Start the autoregressive rollout from ground truth
                root_p_t = root_p[:, t]
                local_q_t = local_q[:, t].view(local_q.size(0), -1)
                contact_t = contact[:, t]
                root_v_t = root_v[:, t]
            else:
                # Afterwards feed back the previous prediction (drop the leading LSTM axis)
                root_p_t = root_pred[0]
                local_q_t = local_q_pred[0]
                contact_t = contact_pred[0]
                root_v_t = root_v_pred[0]

            state_input = torch.cat([local_q_t, root_v_t, contact_t], -1)
            offset_input = torch.cat([root_p_offset - root_p_t, local_q_offset - local_q_t], -1)
            target_input = target

            h_state = state_encoder(state_input)
            h_offset = offset_encoder(offset_input)
            h_target = target_encoder(target_input)

            # Kept in-place on purpose: `+=` preserves each embedding's dtype even if
            # ztta's dtype differs (an out-of-place add could promote it).
            ztta_t = ztta[:, t]
            h_state += ztta_t
            h_offset += ztta_t
            h_target += ztta_t

            # Scheduled target noise: full strength far from the target, none for the last 5 steps.
            # randn_like samples directly on the embedding's device with its dtype.
            lambda_target = _target_noise_scale(seq_len - 2 - t)
            h_offset += NOISE_SCALE * lambda_target * torch.randn_like(h_offset)
            h_target += NOISE_SCALE * lambda_target * torch.randn_like(h_target)

            lstm_input = torch.cat([h_state, h_offset, h_target], -1).unsqueeze(0)
            h_out = lstm(lstm_input)

            # Decoder predicts quaternion deltas + root velocity, plus contacts
            h_pred, contact_pred = decoder(h_out)
            local_q_v_pred = h_pred[:, :, :quat_dim]
            local_q_pred = local_q_v_pred + local_q_t

            # Unit quaternions (…, joint, 4) are used for forward kinematics only;
            # the quaternion loss below uses the un-normalised prediction.
            local_q_pred_ = local_q_pred.view(local_q_pred.size(0), local_q_pred.size(1), -1, 4)
            local_q_pred_ = local_q_pred_ / torch.norm(local_q_pred_, dim=-1, keepdim=True)

            root_v_pred = h_pred[:, :, quat_dim:]
            root_pred = root_v_pred + root_p_t
            pos_pred = skeleton_mocap.forward_kinematics(local_q_pred_, root_pred)

            pos_next = X[:, t + 1]
            local_q_next = local_q[:, t + 1].view(local_q.size(0), -1)
            root_p_next = root_p[:, t + 1]
            contact_next = contact[:, t + 1]

            loss_pos += torch.mean(torch.abs(pos_pred[0] - pos_next) / x_std) / seq_len
            loss_quat += torch.mean(torch.abs(local_q_pred[0] - local_q_next)) / seq_len
            loss_root += torch.mean(torch.abs(root_pred[0] - root_p_next) / x_std[:, :, 0]) / seq_len
            loss_contact += torch.mean(torch.abs(contact_pred[0] - contact_next)) / seq_len
            pred_list.append(pos_pred[0])

        # --- Discriminator update (least-squares: real -> 1, fake -> 0) ---
        pred_seq = _stack_frames(pred_list)  # (batch, joints * 3, time)
        fake_input = _with_velocity(pred_seq)
        real_input = _with_velocity(_stack_frames([X[:, i] for i in range(seq_len)]))

        optimizer_d.zero_grad()
        short_fake_logits = _mean_logits(short_discriminator, fake_input.detach())
        short_real_logits = _mean_logits(short_discriminator, real_input)
        short_d_fake_loss = torch.mean(short_fake_logits ** 2)
        short_d_real_loss = torch.mean((short_real_logits - 1) ** 2)
        short_d_loss = (short_d_fake_loss + short_d_real_loss) / 2.0

        long_fake_logits = _mean_logits(long_discriminator, fake_input.detach())
        long_real_logits = _mean_logits(long_discriminator, real_input)
        long_d_fake_loss = torch.mean(long_fake_logits ** 2)
        long_d_real_loss = torch.mean((long_real_logits - 1) ** 2)
        long_d_loss = (long_d_fake_loss + long_d_real_loss) / 2.0

        total_d_loss = train['loss_adv_weight'] * long_d_loss + \
                       train['loss_adv_weight'] * short_d_loss
        total_d_loss.backward()
        optimizer_d.step()

        # --- Generator update ---
        optimizer_g.zero_grad()
        # Foot-sliding penalty: predicted foot velocity wherever ground-truth contact is active.
        # The number of contact points is taken from the contact tensor instead of a hard-coded 4.
        pred_vel = pred_seq[:, data["foot_index"], 1:] - pred_seq[:, data["foot_index"], :-1]
        pred_vel = pred_vel.view(pred_vel.size(0), contact.size(-1), 3, pred_vel.size(-1))
        loss_slide = torch.mean(torch.abs(pred_vel * contact[:, :-1].permute(0, 2, 1).unsqueeze(2)))
        loss_total = train["loss_pos_weight"] * loss_pos + \
                     train["loss_quat_weight"] * loss_quat + \
                     train["loss_root_weight"] * loss_root + \
                     train["loss_slide_weight"] * loss_slide + \
                     train["loss_contact_weight"] * loss_contact

        # Adversarial term: the generator wants both critics to score its output as real (1).
        # This backward also fills discriminator grads; optimizer_d.zero_grad() clears them next batch.
        short_fake_logits = _mean_logits(short_discriminator, fake_input)
        short_g_loss = torch.mean((short_fake_logits - 1) ** 2)
        long_fake_logits = _mean_logits(long_discriminator, fake_input)
        long_g_loss = torch.mean((long_fake_logits - 1) ** 2)
        total_g_loss = train['loss_adv_weight'] * long_g_loss + \
                       train['loss_adv_weight'] * short_g_loss
        loss_total += total_g_loss

        loss_total.backward()
        for module in generator_modules:
            torch.nn.utils.clip_grad_norm_(module.parameters(), GRAD_CLIP)
        optimizer_g.step()

        # Logging loss: every scalar is written at the shared global step `log_i`
        # (the adversarial scalars were previously logged without a step).
        if i_batch % LOG_EVERY == 0:
            print(f"EPOCH {epoch} BATCH {i_batch} LOSS TOTAL {loss_total.item()}")
        scalars = {
            "Loss/Pos": loss_pos,
            "Loss/Quat": loss_quat,
            "Loss/Root": loss_root,
            "Loss/Slide": loss_slide,
            "Loss/Contact": loss_contact,
            "Loss/Total": loss_total,
            "Adversarial Loss/Short Generator": short_g_loss,
            "Adversarial Loss/Long Generator": long_g_loss,
            "Adversarial Loss/Short Discriminator Real": short_d_real_loss,
            "Adversarial Loss/Short Discriminator Fake": short_d_fake_loss,
            "Adversarial Loss/Long Discriminator Real": long_d_real_loss,
            "Adversarial Loss/Long Discriminator Fake": long_d_fake_loss,
        }
        for tag, value in scalars.items():
            writer.add_scalar(tag, value.item(), log_i)
        log_i += 1

    # Saving models
    if epoch != 0 and epoch % SAVE_EVERY == 0:
        _save_checkpoint(f"./models/{experiment}/epoch_{epoch}", {
            "state_encoder": state_encoder,
            "target_encoder": target_encoder,
            "offset_encoder": offset_encoder,
            "lstm": lstm,
            "decoder": decoder,
            "optimizer_g": optimizer_g,
            "short_discriminator": short_discriminator,
            "long_discriminator": long_discriminator,
            "optimizer_d": optimizer_d,
        })


--- EPOCH 0 ---
EPOCH 0 BATCH 0 LOSS TOTAL 3.8840823482018854
EPOCH 0 BATCH 1 LOSS TOTAL 3.9107837293412846
EPOCH 0 BATCH 2 LOSS TOTAL 3.7407431861661125
EPOCH 0 BATCH 3 LOSS TOTAL 3.5848414633618315

--- EPOCH 1 ---
EPOCH 1 BATCH 0 LOSS TOTAL 3.505505016690905
EPOCH 1 BATCH 1 LOSS TOTAL 3.2336309432831696
EPOCH 1 BATCH 2 LOSS TOTAL 2.9134361854355633
EPOCH 1 BATCH 3 LOSS TOTAL 2.9686267368547856
