In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

from torch.autograd import Variable

from speech2gesture import *
from pats_master.data import Data
from funcs import pos_to_motion

MODEL_PATH = './save'
lr = 10e-4
n_epochs = 10
lambda_d = 1.
lambda_gan = 1.

In [4]:
common_kwargs = dict(path2data = '../pats/data',
                     speaker = ['lec_cosmic'],
                     modalities = ['pose/data', 'audio/log_mel_512'],
                     fs_new = [15, 15],
                     batch_size = 4,
                     window_hop = 5)

dataloader = Data(**common_kwargs)

100%|██████████| 65/65 [00:33<00:00,  1.93it/s]
100%|██████████| 9/9 [00:05<00:00,  1.58it/s]
100%|██████████| 7/7 [00:04<00:00,  1.61it/s]


In [3]:
for batch in data.train:
    break

for key in batch.keys():
    if key != 'meta':
        print('{}: {}'.format(key, batch[key].shape))

pose = batch['pose/data']
pose = pose.reshape(pose.shape[0], pose.shape[1], 2, -1)
print(pose.shape)

print('first pose: ', pose[0, 0, :, :])

NameError: name 'data' is not defined

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cuda = True if torch.cuda.is_available() else False

# Loss function
motion_regloss = torch.nn.L1Loss()
g_loss = torch.nn.MSELoss()
d_loss1 = torch.nn.MSELoss()
d_loss2 = torch.nn.MSELoss()

# Initialize generator and discriminator
generator = Speech2Gesture_G()
discriminator = Speech2Gesture_D()
# generator.to(device)
# discriminator.to(device)

if cuda:
    generator.cuda()
    discriminator.cuda()
    motion_regloss.cuda()
    g_loss.cuda()
    d_loss1.cuda()
    d_loss2.cuda()
    

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr)

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [None]:
# ----------
#  Training
# ----------

for epoch in range(n_epochs):
    for i, batch in enumerate(dataloader.train):

        # Configure input
        # audio = Variable(batch['audio/log_mel_512'].type(Tensor))
        # real_pose = Variable(batch['pose/data'].type(Tensor))
        audio = batch['audio/log_mel_512']
        audio = audio.to(device)
        audio = audio.type(torch.cuda.FloatTensor)
        real_pose = batch['pose/data']
        real_pose = real_pose.to(device)
        real_pose = real_pose.type(torch.cuda.FloatTensor)

        # Adversarial ground truths
        valid = Variable(Tensor(real_pose.size(0), 12).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(real_pose.size(0), 12).fill_(0.0), requires_grad=False)


        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Using audio as generator input
        fake_pose, _ = generator(audio)

        # Generate motions
        real_motion = pos_to_motion(real_pose)
        fake_motion = pos_to_motion(fake_pose)

        # discriminator
        fake_d, _ = discriminator(fake_pose)
        print('fake_d size: ', fake_d.size())

        # Loss measures generator's ability to fool the discriminator
        G_loss = motion_regloss(real_motion, fake_motion) + lambda_gan * g_loss(fake_d, valid)
        # G_loss = motion_regloss(real_pose, fake_pose) + lambda_gan * g_loss(discriminator(fake_pose), valid)

        G_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        real_d, _ = discriminator(real_pose)
        fake_d, _ = discriminator(fake_pose.detach())

        # Measure discriminator's ability to classify real from generated samples
        real_loss = d_loss1(real_d, valid)
        fake_loss = d_loss2(fake_d, fake)
        D_loss = real_loss + lambda_d * fake_loss

        D_loss.backward()
        optimizer_D.step()

        if i % 200 == 199:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, n_epochs, i, len(dataloader), D_loss.item(), G_loss.item())
            )

print('saving generators')
torch.save(generator.state_dict(), MODEL_PATH)
print('saving discriminators')
torch.save(discriminator.state_dict(), MODEL_PATH)