In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="5"

import joblib
import numpy as np
import random
import torch
import transformers

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
# The masking and shuffling code in this notebook draws from NumPy's global
# RNG (np.random.*), so it has to be seeded as well for runs to be repeatable.
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

In [None]:
# Load the dataset array from disk. allow_pickle=True lets NumPy unpickle
# arbitrary Python objects stored in the file (the cells below index each
# entry by string keys such as 'type', 'pose' and 'music').
# NOTE(review): unpickling can execute arbitrary code - only load this file
# from a trusted source.
data = np.load('data/rotation_matrix.300.npy', allow_pickle=True)

In [None]:
# Keep only the entries whose 'type' field is 'train'.
data_train = np.array([record for record in data if record['type'] == 'train'])

# Poses and music normalization

In [None]:
# Global coordinates (pose columns 216-218) are not in the 0-1 range, so they are min-max scaled below.

In [None]:
# Per-column statistics (mean/std/min/max) for pose columns 216-218.
# NOTE(review): the statistics are computed over every entry of `data`, not
# only the 'train' entries - confirm this is intended.
data_normalization = {}

# Stack all pose frames once; the concatenation does not depend on the column,
# so rebuilding it on every loop iteration was wasted work.
all_pose_frames = np.concatenate([sample['pose'] for sample in data], 0)

pose_stats = data_normalization.setdefault('pose', {})
for i in range(216, 219):
    tmp = all_pose_frames[:, i]
    pose_stats[i] = {
        'mean': tmp.mean(),
        'std': tmp.std(),
        'min': tmp.min(),
        'max': tmp.max(),
    }

# Drop the large temporary so it does not linger in kernel memory.
del all_pose_frames

In [None]:

# Per-column statistics (mean/std/min/max) for the 35 music feature columns.
# NOTE(review): the statistics are computed over every entry of `data`, not
# only the 'train' entries - confirm this is intended.
if data_normalization.get('music', None) is None:
    data_normalization['music'] = {}

# Stack all music frames once instead of once per feature column; the
# comprehension variable no longer reuses the loop index name.
all_music_frames = np.concatenate([sample['music'] for sample in data], 0)

for i in range(35):
    tmp = all_music_frames[:, i]
    data_normalization['music'][i] = {
        'mean': tmp.mean(),
        'std': tmp.std(),
        'min': tmp.min(),
        'max': tmp.max(),
    }

# Drop the large temporary so it does not linger in kernel memory.
del all_music_frames

In [None]:
def norm_train_poses(sample):
    """Min-max scale pose columns 216-218 using `data_normalization['pose']`.

    Args:
        sample: 3-D torch tensor (or NumPy array) indexed as
            ``sample[:, :, column]``; it must have at least 219 columns on the
            last axis.

    Returns:
        A new tensor/array of the same type in which columns 216-218 hold
        ``(value - min) / (max - min)``. All other columns are copied
        unchanged. The input is not modified.
    """
    stats = data_normalization['pose']
    # Work on a copy so the caller's data is never altered behind its back.
    scaled = sample.clone() if isinstance(sample, torch.Tensor) else np.array(sample)

    for i in range(216, 219):
        low = stats[i]['min']
        span = stats[i]['max'] - low
        if span == 0:
            # Constant column: map it to 0 instead of dividing by zero.
            scaled[:, :, i] = 0
        else:
            scaled[:, :, i] = (sample[:, :, i] - low) / span

    return scaled

def norm_train_music(sample):
    """Min-max scale the 35 music columns using `data_normalization['music']`.

    Args:
        sample: 3-D torch tensor (or NumPy array) indexed as
            ``sample[:, :, column]``; it must have at least 35 columns on the
            last axis.

    Returns:
        A new tensor/array of the same type in which columns 0-34 hold
        ``(value - min) / (max - min)``. The input is not modified.
    """
    stats = data_normalization['music']
    # Work on a copy so the caller's data is never altered behind its back.
    scaled = sample.clone() if isinstance(sample, torch.Tensor) else np.array(sample)

    for i in range(35):
        low = stats[i]['min']
        span = stats[i]['max'] - low
        if span == 0:
            # Constant feature: map it to 0 instead of dividing by zero.
            scaled[:, :, i] = 0
        else:
            scaled[:, :, i] = (sample[:, :, i] - low) / span

    return scaled

In [None]:
def unnorm_train_poses(sample):
    """Undo the min-max scaling of pose columns 216-218.

    Args:
        sample: 3-D torch tensor (or NumPy array) indexed as
            ``sample[:, :, column]``; it must have at least 219 columns on the
            last axis.

    Returns:
        A new tensor/array of the same type in which columns 216-218 hold
        ``value * (max - min) + min``, with min/max taken from
        `data_normalization['pose']`. The input is not modified, so passing a
        slice/view of a larger tensor no longer alters that tensor.
    """
    stats = data_normalization['pose']
    # Work on a copy: writing into a view would silently change the source.
    restored = sample.clone() if isinstance(sample, torch.Tensor) else np.array(sample)

    for i in range(216, 219):
        low = stats[i]['min']
        restored[:, :, i] = sample[:, :, i] * (stats[i]['max'] - low) + low

    return restored

In [None]:
# Release GPU memory held by PyTorch's caching allocator before the models are
# built. Guarded so the cell also runs on a machine without a usable CUDA
# device, where initialising CUDA would otherwise raise.
if torch.cuda.is_available():
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()

In [None]:
# Use the first visible GPU when one is available; otherwise fall back to the
# CPU so the rest of the notebook can still run.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Configure models

In [None]:
# ConvBERT configuration for the pose stream.
# Start from the library defaults and override only the fields listed below.
bert_config = transformers.ConvBertConfig().to_dict()
config = transformers.PretrainedConfig().to_dict()

bert_config.update(
    num_hidden_layers=5,
    num_attention_heads=10,
    output_hidden_states=True,
    hidden_size=800,
    embedding_size=800,
    max_position_embeddings=300,
)

# from_dict is a classmethod, so it is called on the class itself.
bert_config = transformers.ConvBertConfig.from_dict(bert_config)

In [None]:
# ConvBERT configuration for the music stream.
# Start from the library defaults and override only the fields listed below.
bert_config2 = transformers.ConvBertConfig().to_dict()
config = transformers.PretrainedConfig().to_dict()

bert_config2.update(
    num_hidden_layers=5,
    num_attention_heads=10,
    output_hidden_states=True,
    hidden_size=800,
    embedding_size=800,
    max_position_embeddings=300,
)

# from_dict is a classmethod, so it is called on the class itself.
bert_config2 = transformers.ConvBertConfig.from_dict(bert_config2)

In [None]:
# ConvBERT configuration for the cross-modal model. Its width (1600) is twice
# the 800 used by the pose and music configurations.
bert_config3 = transformers.ConvBertConfig().to_dict()
config = transformers.PretrainedConfig().to_dict()

bert_config3.update(
    num_hidden_layers=5,
    num_attention_heads=10,
    output_hidden_states=True,
    hidden_size=1600,
    embedding_size=1600,
    max_position_embeddings=300,
)

# from_dict is a classmethod, so it is called on the class itself.
bert_config3 = transformers.ConvBertConfig.from_dict(bert_config3)

In [None]:
class PoseBERT(torch.nn.Module):
    """Two-stream ConvBERT model that predicts pose frames from poses and music.

    Poses and music are each encoded by their own ConvBERT. The two encoded
    sequences are concatenated along the feature axis, passed through a third
    (cross-modal) ConvBERT, and projected per frame by a linear layer followed
    by tanh.

    Args:
        bert_config: ConvBertConfig of the pose encoder.
        bert_config2: ConvBertConfig of the music encoder.
        bert_config3: ConvBertConfig of the cross-modal encoder. Its
            ``embedding_size`` must equal the sum of the two encoders'
            ``hidden_size`` values, because it consumes their concatenation.
        output_size: number of features predicted per frame (default 219).

    Raises:
        ValueError: if the cross-modal embedding size does not match the
            concatenated width of the two encoders.
    """

    def __init__(self, bert_config, bert_config2, bert_config3, output_size=219):
        super(PoseBERT, self).__init__()

        fused_size = bert_config.hidden_size + bert_config2.hidden_size
        if bert_config3.embedding_size != fused_size:
            raise ValueError(
                'bert_config3.embedding_size (%d) must equal the concatenated '
                'encoder width (%d)' % (bert_config3.embedding_size, fused_size))

        self.bert = transformers.ConvBertModel(bert_config)
        self.bert2 = transformers.ConvBertModel(bert_config2)
        self.bert3 = transformers.ConvBertModel(bert_config3)

        # Widths the raw inputs are zero-padded to; read from the configs
        # instead of being hard-coded so the model follows config changes.
        self.pose_input_size = bert_config.embedding_size
        self.music_input_size = bert_config2.embedding_size

        self.linear = torch.nn.Linear(bert_config3.hidden_size, output_size)
        self.activation = torch.nn.Tanh()

    @staticmethod
    def _pad_features(x, width):
        """Zero-pad the last axis of `x` up to `width` features."""
        missing = width - x.shape[-1]
        if missing < 0:
            # A negative pad would silently crop features, so refuse instead.
            raise ValueError(
                'input has %d features but the encoder accepts at most %d'
                % (x.shape[-1], width))
        return torch.nn.functional.pad(x, (0, missing))

    def single_pass(self, xp, xm, attention_mask):
        """Run one full pass: encode both streams, fuse them, project to poses.

        Args:
            xp: pose features; the last axis is zero-padded to the pose
                encoder's embedding size.
            xm: music features; the last axis is zero-padded to the music
                encoder's embedding size.
            attention_mask: mask passed unchanged to all three ConvBERT models.

        Returns:
            Tensor of tanh-activated per-frame predictions.
        """
        # simply pad input data to the encoder embedding size
        xp = self._pad_features(xp, self.pose_input_size)
        xm = self._pad_features(xm, self.music_input_size)

        poses_output_seed = self.bert(inputs_embeds=xp, attention_mask=attention_mask)['last_hidden_state']
        music_output_seed = self.bert2(inputs_embeds=xm, attention_mask=attention_mask)['last_hidden_state']

        # concatenate processed music and poses together
        predict = torch.cat([poses_output_seed, music_output_seed], -1)

        _predict = self.bert3(inputs_embeds=predict, attention_mask=attention_mask)['last_hidden_state']

        decoder_output_pose = self.activation(self.linear(_predict))

        return decoder_output_pose

    def forward(self, input_xp, input_xm, attention_mask, device='cpu', train=None):
        """Predict pose frames for a batch.

        `device` and `train` are accepted only so existing call sites that
        pass them keep working; neither value is used. Train/eval behaviour is
        controlled through ``module.train()`` / ``module.eval()``.
        """
        predict = self.single_pass(input_xp, input_xm, attention_mask)

        return predict

In [None]:
# Build the model from the three configurations, then move it to `device`.
poseBert = PoseBERT(bert_config, bert_config2, bert_config3)
poseBert = poseBert.to(device)

### loss

In [None]:
# Training schedule
epochs = 950
batch_size = 32

# Running state: number of optimizer steps taken and the loss after each one
iters = 0
G_losses = []

In [None]:
# Optimizer hyperparameters
lrG = 1e-4
beta1 = 0.9

# Mean squared error between predicted and reference pose frames
criterion = torch.nn.MSELoss()

# Adam over all model parameters
optimizerG = torch.optim.Adam(
    params=poseBert.parameters(),
    lr=lrG,
    betas=(beta1, 0.999),
)

In [None]:
# Sequence lengths, in frames
frames_seed, frames_music = 300, 300
frames_min_length = 120

# probabilities according to BERT paper:
# share of frames selected for masking, then how a selected frame is treated
# (zeroed / replaced by another frame / left unchanged)
masked_probability = 0.15
zero_probability, replace_probability, same_probability = 0.80, 0.10, 0.10

In [None]:
def generate_masked_batch(pose_batch):
    """Randomly truncate and mask a batch of pose sequences (BERT-style).

    For every sequence a random length in ``[frames_min_length, frames_seed)``
    is drawn and all frames past it are zeroed. Within that length, each frame
    is selected with probability `masked_probability`; a selected frame is
    zeroed, replaced by a random frame of a random sequence in the batch, or
    left unchanged, according to `zero_probability`, `replace_probability`
    and `same_probability`.

    Args:
        pose_batch: tensor (or NumPy array) indexed as ``[sequence, frame,
            feature]`` with `frames_seed` frames per sequence. Not modified.

    Returns:
        Tuple ``(masked_pose_batch, masks, attention_mask)``:
        - masked_pose_batch: tensor copy of `pose_batch` with masking applied;
        - masks: NumPy array shaped like `pose_batch`, 0 on selected frames
          and 1 everywhere else (including frames past the drawn length);
        - attention_mask: NumPy array ``(sequence, frames_seed)``, 1 up to the
          drawn length and 0 after it.
    """
    # `source` is only read from; `masked_pose_batch` is the single working
    # copy. clone() replaces torch.tensor(tensor), which PyTorch warns about.
    source_pose_batch = torch.as_tensor(pose_batch).detach()
    masked_pose_batch = source_pose_batch.clone()

    n_samples = pose_batch.shape[0]

    # assume all sequences should be different
    samples_lengths = np.random.randint(frames_min_length, frames_seed, (n_samples))
    masks = np.ones((n_samples, frames_seed))
    attention_mask = np.ones((n_samples, frames_seed))

    # cut sequences and apply masking
    for j, sample_length in enumerate(samples_lengths):
        sample_length = int(sample_length)
        attention_mask[j, sample_length:] = 0

        # mask or not (0 marks a selected frame)
        mask = np.random.choice([1, 0], p=[1 - masked_probability, masked_probability], size=(sample_length))
        masks[j, :sample_length] = mask
        masked_indices = np.where(mask == 0)[0]

        # replace with zero, random other element or dont change
        mask_type = np.random.choice([0, 1, 2], p=[zero_probability, replace_probability, same_probability], size=len(masked_indices))

        masked_pose_batch[j, sample_length:] = 0
        for index, kind in zip(masked_indices, mask_type):
            index = int(index)
            if kind == 0:
                masked_pose_batch[j, index] = 0

            elif kind == 1:
                donor = int(np.random.randint(n_samples))
                # any valid frame of the donor sequence except the same position
                candidates = np.arange(samples_lengths[donor])
                candidates = candidates[candidates != index]
                donor_frame = int(np.random.choice(candidates))

                masked_pose_batch[j, index] = source_pose_batch[donor, donor_frame]

    return masked_pose_batch, masks[:, :, None].repeat(pose_batch.shape[-1], -1), attention_mask

In [None]:
# used for evaluation with fixed setting

def generate_masked_batch_patched(pose_batch, pattern=0, sample_length=100):
    """Mask a batch with a fixed, repeatable frame pattern (for evaluation).

    Every sequence is cut to `sample_length` frames (frames past it are
    zeroed) and the same selection pattern is applied to each sequence:

    - ``pattern == 0``: only the first frame is selected;
    - ``pattern == 1``: alternating runs of 5 kept / 5 selected frames;
    - ``pattern == 2``: every frame is selected except one in five
      (frames 0, 5, 10, ... are kept);
    - ``pattern == 3``: frames 88-97 are selected.

    A selected frame is zeroed, replaced by a random frame of a random
    sequence in the batch, or left unchanged, according to `zero_probability`,
    `replace_probability` and `same_probability`.
    NOTE(review): that last step is still random, so masked inputs are not
    fully deterministic even though the selection pattern is.

    Args:
        pose_batch: tensor (or NumPy array) indexed as ``[sequence, frame,
            feature]`` with `frames_seed` frames per sequence. Not modified.
        pattern: which selection pattern to use (0-3).
        sample_length: number of leading frames treated as valid (default 100).

    Returns:
        Tuple ``(masked_pose_batch, masks, attention_mask)`` with the same
        layout as `generate_masked_batch`: `masks` is 0 on selected frames and
        1 elsewhere; `attention_mask` is 1 for the first `sample_length`
        frames and 0 after them.

    Raises:
        ValueError: if `pattern` is not one of 0-3 or `sample_length` does not
            fit into `frames_seed`.
    """
    if not 0 < sample_length <= frames_seed:
        raise ValueError('sample_length must be in [1, %d], got %r' % (frames_seed, sample_length))

    # The selection pattern is identical for every sequence, so build it once.
    if pattern == 0:
        # recover first pose of the sequence
        mask = np.ones((sample_length))
        mask[0] = 0
    elif pattern == 1:
        # 5 over 5
        mask = np.resize([1] * 5 + [0] * 5, sample_length)
    elif pattern == 2:
        # recover using only 1 per 5
        mask = np.zeros((sample_length))
        mask[::5] = 1
    elif pattern == 3:
        # recover set of blank poses
        mask = np.ones((sample_length))
        mask[88:98] = 0
    else:
        # Previously an unknown pattern left `mask` undefined.
        raise ValueError('unknown pattern %r; expected 0, 1, 2 or 3' % (pattern,))

    # `source` is only read from; `masked_pose_batch` is the single working
    # copy. clone() replaces torch.tensor(tensor), which PyTorch warns about.
    source_pose_batch = torch.as_tensor(pose_batch).detach()
    masked_pose_batch = source_pose_batch.clone()

    n_samples = pose_batch.shape[0]
    masks = np.ones((n_samples, frames_seed))
    attention_mask = np.ones((n_samples, frames_seed))
    masked_indices = np.where(mask == 0)[0]

    for j in range(n_samples):
        attention_mask[j, sample_length:] = 0
        masks[j, :sample_length] = mask

        # replace with zero, random other element or dont change
        mask_type = np.random.choice([0, 1, 2], p=[zero_probability, replace_probability, same_probability], size=len(masked_indices))

        masked_pose_batch[j, sample_length:] = 0
        for index, kind in zip(masked_indices, mask_type):
            index = int(index)
            if kind == 0:
                masked_pose_batch[j, index] = 0

            elif kind == 1:
                donor = int(np.random.randint(n_samples))
                # any valid frame of the donor sequence except the same position
                candidates = np.arange(sample_length)
                candidates = candidates[candidates != index]
                donor_frame = int(np.random.choice(candidates))

                masked_pose_batch[j, index] = source_pose_batch[donor, donor_frame]

    return masked_pose_batch, masks[:, :, None].repeat(pose_batch.shape[-1], -1), attention_mask

# Train

In [None]:
# Masked-frame reconstruction training: for every batch, poses are randomly
# masked and the model is trained to reproduce the original frames at the
# masked positions.
epoch = 0

n_batches = len(data_train) // batch_size
if n_batches == 0:
    raise ValueError('need at least batch_size=%d training samples, got %d'
                     % (batch_size, len(data_train)))

# Make sure dropout etc. are in training mode, whatever ran before this cell.
poseBert.train()

for epoch in range(epochs):
    # shuffle training data
    indexes = np.arange(len(data_train))
    np.random.shuffle(indexes)
    _data_train = data_train[indexes]

    # the trailing partial batch is dropped so every batch has batch_size samples
    data_batches = np.array_split(_data_train[:n_batches * batch_size], n_batches)

    for i, batch in enumerate(data_batches):
        pose_batch = norm_train_poses(torch.tensor(np.stack([sample['pose'] for sample in batch]))).to(device).float()
        music_batch = norm_train_music(torch.tensor(np.stack([sample['music'] for sample in batch]))).to(device).float()

        # generate_masked_batch works on its own copy of the poses
        masked_pose_batch, mask, attention_mask = generate_masked_batch(pose_batch)

        # forward() takes no `train` argument; the mode is set by poseBert.train() above
        fake = poseBert(masked_pose_batch, music_batch, torch.tensor(attention_mask).to(device), device=device)

        # apply mask to real and generated sequences in order to assess loss only for masked frames
        # (mask is 1 on frames that must be ignored; it is moved to the device once)
        ignore_mask = torch.tensor(mask).to(device).bool()
        errG = criterion(fake.masked_fill(ignore_mask, 0), pose_batch.masked_fill(ignore_mask, 0))

        optimizerG.zero_grad()
        errG.backward()
        optimizerG.step()

        iters += 1
        G_losses.append(errG.item())

        if iters % 10 == 0:
            print('[%d/%d][%d/%d]\tLoss_G: %.5f'
                  % (epoch, epochs, i, len(data_batches),
                     errG.item()))

In [None]:
# Display the total number of optimizer steps taken so far.
iters

# Test

In [None]:
# Build normalized pose and music batches from the first ten 'test' entries.
test_samples = [sample for sample in data if sample['type'] == 'test'][:10]

pose_batch = torch.tensor(np.concatenate([sample['pose'][None] for sample in test_samples], 0))
pose_batch = norm_train_poses(pose_batch).to(device).float()

music_batch = torch.tensor(np.concatenate([sample['music'][None] for sample in test_samples], 0))
music_batch = norm_train_music(music_batch).to(device).float()

In [None]:
# Mask the test poses with fixed pattern 2.
_masked_pose_batch, _mask, _attention_mask = generate_masked_batch_patched(
    pose_batch,
    pattern=2,
)

In [None]:
# sequence 7 is a good demonstration of complex motion
# Run inference in eval mode so dropout is disabled, then restore whatever
# mode the model was in. forward() takes no `train` argument.
_was_training = poseBert.training
poseBert.eval()
try:
    with torch.no_grad():
        fake = poseBert(_masked_pose_batch[7:8], music_batch[7:8], torch.as_tensor(_attention_mask[7:8]).to(device), device=device)
finally:
    poseBert.train(_was_training)

# [:, :100] - select only first 100 frames
# [7] - select sequence 7
# [:, :-3] - remove global coordinates since our current visualization method does not support them
# The reference slice is cloned before un-normalizing: slicing returns a view,
# so un-normalizing it in place would alter pose_batch and make this cell give
# different values every time it is re-run.
_fake = torch.cat([unnorm_train_poses(pose_batch[7:8, :100].clone())[0][:, :-3], unnorm_train_poses(fake)[:, :100][0][:, :-3]])