In [None]:
# Show the visible GPUs and their current usage (CUDA_VISIBLE_DEVICES is set in the next cell).
!nvidia-smi

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import joblib
import plotly.graph_objects as go
import numpy as np
import random
import torch
import transformers
# Set random seeds for reproducibility. NumPy's global RNG drives the epoch
# shuffling and all frame masking below, so it is seeded alongside `random`
# and torch.
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed);

In [None]:
# Make the bundled PyMO utilities importable (they are used below to save results as .bvh files)

import sys

# Make the bundled PyMO code importable: presumably `writers` is found via the
# first entry and the `pymo` package via the second. Entries already present
# are skipped so re-running this cell does not keep growing sys.path.
# NOTE(review): the two roots differ (utils/PyMO vs utils/Poses/PyMO) — confirm both exist.
for _path in ('utils/PyMO/pymo/', 'utils/Poses/PyMO/'):
    if _path not in sys.path:
        sys.path.append(_path)
del _path

In [None]:
# Presumably an object array of per-sample dicts (keys used below: 'pose', 'type').
# NOTE(review): allow_pickle=True can execute arbitrary code — only load trusted files.
data = np.load(file='data/ubisoft.rotation_matrix.300.npy', allow_pickle=True)

In [None]:
# Keep only the samples tagged as belonging to the training split.
data_train = np.array(list(filter(lambda sample: sample['type'] == 'train', data)))

In [None]:
# Per-column mean/std/min/max of pose features 198-200, consumed by
# norm_train_poses / unnorm_train_poses (which use min and max).
# NOTE(review): statistics are computed over all of `data` (train AND test
# splits), which leaks test information into preprocessing; left as-is so
# existing checkpoints keep the same normalisation.
data_normalization = {'pose': {}}

# Concatenate every sample's frames once instead of once per column.
all_poses = np.concatenate([sample['pose'] for sample in data], 0)
for col in range(198, 201):
    values = all_poses[:, col]
    data_normalization['pose'][col] = {
        'mean': values.mean(),
        'std': values.std(),
        'min': values.min(),
        'max': values.max(),
    }
# Free the (large) concatenated copy; only the statistics are needed later.
del all_poses, values

In [None]:
def norm_train_poses(sample, stats=None, columns=range(198, 201)):
    """Min-max scale selected pose feature columns of ``sample``.

    Args:
        sample: torch tensor or ndarray indexed as ``sample[:, :, column]``
            (presumably (batch, frames, features) — TODO confirm).
        stats: mapping column -> {'min': ..., 'max': ...}; defaults to
            ``data_normalization['pose']``.
        columns: feature columns to scale (default 198-200).

    Returns:
        A scaled copy of ``sample``; the input is not modified. Columns whose
        max equals min are shifted by min but not divided (avoids 0/0 -> NaN).
    """
    if stats is None:
        stats = data_normalization['pose']
    out = sample.clone() if hasattr(sample, 'clone') else np.array(sample, copy=True)
    for col in columns:
        low = stats[col]['min']
        span = stats[col]['max'] - low
        if span == 0:
            span = 1.0
        out[:, :, col] = (out[:, :, col] - low) / span
    return out

def norm_train_music(sample, stats=None, num_features=35):
    """Min-max scale the first ``num_features`` music features of ``sample``.

    Args:
        sample: torch tensor or ndarray indexed as ``sample[:, :, feature]``
            (presumably (batch, frames, features) — TODO confirm).
        stats: mapping feature index -> {'min': ..., 'max': ...}; defaults to
            ``data_normalization['music']``.
        num_features: number of leading features to scale (default 35).

    Returns:
        A scaled copy of ``sample``; the input is not modified. Features whose
        max equals min are shifted by min but not divided (avoids 0/0 -> NaN).

    Raises:
        KeyError: if ``stats`` is not given and ``data_normalization`` has no
            'music' entry (nothing in this notebook populates it).
    """
    if stats is None:
        if 'music' not in data_normalization:
            raise KeyError("data_normalization['music'] is not populated; "
                           "compute music statistics first or pass `stats`")
        stats = data_normalization['music']
    out = sample.clone() if hasattr(sample, 'clone') else np.array(sample, copy=True)
    for feature in range(num_features):
        low = stats[feature]['min']
        span = stats[feature]['max'] - low
        if span == 0:
            span = 1.0
        out[:, :, feature] = (out[:, :, feature] - low) / span
    return out

In [None]:
def unnorm_train_poses(sample, stats=None, columns=range(198, 201)):
    """Invert the min-max scaling applied by ``norm_train_poses``.

    Args:
        sample: torch tensor or ndarray indexed as ``sample[:, :, column]``
            (presumably (batch, frames, features) — TODO confirm).
        stats: mapping column -> {'min': ..., 'max': ...}; defaults to
            ``data_normalization['pose']``.
        columns: feature columns to rescale (default 198-200).

    Returns:
        A rescaled copy of ``sample``. The input is not modified, which matters
        for slices such as ``pose_batch[:, :100]`` that are views of a larger
        tensor.
    """
    if stats is None:
        stats = data_normalization['pose']
    out = sample.clone() if hasattr(sample, 'clone') else np.array(sample, copy=True)
    for col in columns:
        low = stats[col]['min']
        out[:, :, col] = out[:, :, col] * (stats[col]['max'] - low) + low
    return out

In [None]:
# Release cached CUDA memory held by this kernel. Both calls need a CUDA
# runtime, so they are skipped on machines without a GPU.
if torch.cuda.is_available():
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()

In [None]:
# With CUDA_VISIBLE_DEVICES="1", physical GPU 1 is exposed as cuda:0.
# Fall back to CPU so model construction still works without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
# ConvBERT encoder settings: 10 layers, 8 heads, 512-wide hidden and
# embedding space, and 300 positions (matching the 300-frame sequences used
# below). All other fields keep the ConvBertConfig defaults.
bert_config = transformers.ConvBertConfig(
    num_hidden_layers=10,
    num_attention_heads=8,
    output_hidden_states=True,
    hidden_size=512,
    embedding_size=512,
    max_position_embeddings=300,
)

In [None]:
class PoseBERT(torch.nn.Module):
    """ConvBERT encoder plus a small convolutional decoder for masked pose frames.

    Input frames are zero-padded on the feature axis up to the encoder's
    embedding size and fed to ConvBERT as ``inputs_embeds``. The last hidden
    state is decoded by a kernel-3 and a kernel-1 ``Conv1d`` and then passed
    through a ``LayerNorm`` over (frames, pose_dim). Because of that LayerNorm,
    the frame axis must have exactly ``bert_config.max_position_embeddings``
    entries.

    Args:
        bert_config: ``transformers.ConvBertConfig`` for the encoder.
        pose_dim: features per frame, for both input and output (default 201).
        decoder_channels: width of the intermediate convolution (default 256).
    """

    def __init__(self, bert_config, pose_dim=201, decoder_channels=256):
        super(PoseBERT, self).__init__()
        self.bert = transformers.ConvBertModel(bert_config)

        # Inputs are padded up to the embedding width; the decoder reads the
        # encoder's hidden width.
        self.embedding_size = bert_config.embedding_size
        self.hidden_size = bert_config.hidden_size

        self.conv1 = torch.nn.Conv1d(self.hidden_size, decoder_channels, 3, padding=(1,))
        self.conv2 = torch.nn.Conv1d(decoder_channels, pose_dim, 1)
        self.norm = torch.nn.LayerNorm((bert_config.max_position_embeddings, pose_dim))
        # NOTE(review): not applied anywhere in the forward pass.
        self.activation = torch.nn.Tanh()

    def single_pass(self, xp, attention_mask):
        """Encode and decode one batch.

        Args:
            xp: (batch, frames, features) tensor with features <= embedding size.
            attention_mask: (batch, frames) mask, 1 for valid frames, 0 for padding.

        Returns:
            ``(None, prediction, None)`` where prediction is (batch, frames, pose_dim).
        """
        xp = torch.nn.functional.pad(xp, (0, self.embedding_size - xp.shape[-1]))

        hidden = self.bert(inputs_embeds=xp, attention_mask=attention_mask)['last_hidden_state']
        # Conv1d expects channels first: (batch, hidden, frames).
        decoded = self.conv2(self.conv1(hidden.transpose(1, -1)))
        decoded = self.norm(decoded.transpose(1, -1))

        return None, decoded, None

    def forward(self, input_xp, attention_mask, device='cpu', noise=None, train=True):
        """Return the reconstructed poses for ``input_xp``.

        ``device``, ``noise`` and ``train`` are accepted for backward
        compatibility and do not change the computation. Previously
        ``train=False`` silently returned ``None``. Use ``model.eval()`` to
        disable dropout.
        """
        _, predict, _ = self.single_pass(input_xp, attention_mask)
        return predict

In [None]:
# Build the model and move its parameters to `device` (Module.to returns the module itself).
poseBert = PoseBERT(bert_config)
poseBert = poseBert.to(device)

In [None]:
# poseBert.load_state_dict(torch.load('/ess_storage/storage/home/skhn2/Poses/models/G.40/generator.step36427.G0.066.pt'))

### Loss, optimizer and training hyperparameters

In [None]:
# history arrays
G_losses = []
D_losses = []

iters = 0
epochs = 1000
batch_size = 128

In [None]:
# Reconstruction objective: plain MSE (a FrechetDistance criterion was an alternative, unused).
criterion = torch.nn.MSELoss()

lrG, beta1 = 1e-4, 0.9

# Label constants from an adversarial setup; not referenced by the MSE training cells.
real_label, fake_label = 1., 0.

optimizerG = torch.optim.Adam(
    params=poseBert.parameters(),
    lr=lrG,
    betas=(beta1, 0.999),
)

In [None]:
# Drop the trailing partial batch so every batch holds exactly batch_size samples.
usable_samples = len(data_train) - len(data_train) % batch_size
if usable_samples == 0:
    raise ValueError(
        'data_train has %d samples, fewer than batch_size=%d' % (len(data_train), batch_size))
data_train = data_train[:usable_samples]

In [None]:
# Sequence geometry, in frames.
frames_seed, frames_music, frames_min_length = 300, 300, 120

# BERT-style corruption: each valid frame is selected with masked_probability;
# a selected frame is then zeroed / replaced by a random frame / kept as-is with
# the three probabilities below (np.random.choice requires them to sum to 1).
masked_probability = 0.25
zero_probability, replace_probability, same_probability = 0.80, 0.10, 0.10

# Alternative, almost-no-masking setting:
# masked_probability = 0.001; zero_probability = 1.0
# replace_probability = 0.0; same_probability = 0.0

In [None]:
def generate_masked_batch(pose_batch):
    """Corrupt a batch of pose sequences for BERT-style masked-frame training.

    For each sample, a valid length is drawn uniformly from
    [frames_min_length, frames_seed). Frames past it are zeroed and marked as
    padding in the attention mask. Each valid frame is selected with
    probability ``masked_probability``. A selected frame is then:

    * zeroed, with probability ``zero_probability``;
    * overwritten by a random valid frame (with a different frame index) of a
      random sample in the batch, with probability ``replace_probability``;
    * left unchanged, with probability ``same_probability``.

    Args:
        pose_batch: tensor or ndarray indexed as (sample, frame, feature). The
            frame axis is assumed to have length ``frames_seed``. The input is
            not modified.

    Returns:
        ``(masked_pose_batch, masks, attention_mask)``:

        * masked_pose_batch: corrupted torch copy of ``pose_batch`` (same device).
        * masks: float ndarray (samples, frames_seed, features); 0 on selected
          frames, 1 everywhere else (including padding).
        * attention_mask: float ndarray (samples, frames_seed); 1 on valid
          frames, 0 on padding.
    """
    # The source is only read (it supplies replacement frames). All corruption
    # happens on an independent copy. clone().detach() avoids the
    # torch.tensor(tensor) copy-construct warning.
    original_pose_batch = torch.as_tensor(pose_batch).detach()
    masked_pose_batch = original_pose_batch.clone()

    n_samples = pose_batch.shape[0]
    samples_lengths = np.random.randint(frames_min_length, frames_seed, (n_samples))
    masks = np.ones((n_samples, frames_seed))
    attention_mask = np.ones((n_samples, frames_seed))
    mask_type_probs = [zero_probability, replace_probability, same_probability]

    for j, sample_length in enumerate(samples_lengths):
        attention_mask[j][sample_length:] *= 0
        mask = np.random.choice([1, 0], p=[1 - masked_probability, masked_probability], size=(sample_length))
        masks[j][:len(mask)] *= mask
        selected = np.where(mask == 0)[0]
        mask_type = np.random.choice([0, 1, 2], p=mask_type_probs, size=len(selected))

        masked_pose_batch[j][sample_length:] *= 0
        for index, kind in zip(selected, mask_type):
            if kind == 0:
                masked_pose_batch[j][index] *= 0
            elif kind == 1:
                donor = np.random.randint(n_samples)
                candidates = np.arange(samples_lengths[donor])
                candidates = candidates[candidates != index]
                random_replace = np.random.choice(candidates)
                masked_pose_batch[j][index] = original_pose_batch[donor][random_replace]
            # kind == 2: keep the original frame

    return masked_pose_batch, masks[:, :, None].repeat(pose_batch.shape[-1], -1), attention_mask

In [None]:
def generate_masked_batch_patched(pose_batch, pattern=0, sample_length=100):
    """Corrupt a batch using a fixed frame-selection pattern (for inspection/inference).

    Every sample gets the same valid length ``sample_length`` (default 100).
    Frames past it are zeroed and marked as padding. The frames selected for
    masking are fixed by ``pattern``:

    * 0 -- only the first frame;
    * 1 -- repeating blocks of 5 kept frames followed by 10 selected frames;
    * 2 -- every frame except 0, 10, 20, ...;
    * 3 -- frames 88-97.

    Selected frames are then zeroed, replaced or kept according to the random
    ``zero_probability`` / ``replace_probability`` / ``same_probability``
    scheme used by ``generate_masked_batch``.

    Returns:
        ``(masked_pose_batch, masks, attention_mask)`` with the same layout as
        ``generate_masked_batch``. The input is not modified.

    Raises:
        ValueError: if ``pattern`` is not 0, 1, 2 or 3.
    """
    # The selection pattern is identical for every sample, so build it once.
    if pattern == 0:
        mask = np.ones((sample_length))
        mask[0] = 0
    elif pattern == 1:
        mask = np.resize([1] * 5 + [0] * 10, sample_length)
    elif pattern == 2:
        mask = np.zeros((sample_length))
        mask[::10] = 1
    elif pattern == 3:
        mask = np.ones((sample_length))
        mask[88:98] = 0
    else:
        raise ValueError('unknown mask pattern %r; expected 0, 1, 2 or 3' % (pattern,))
    selected = np.where(mask == 0)[0]

    # Read-only source for replacement frames; corruption happens on a copy.
    original_pose_batch = torch.as_tensor(pose_batch).detach()
    masked_pose_batch = original_pose_batch.clone()

    n_samples = pose_batch.shape[0]
    masks = np.ones((n_samples, frames_seed))
    attention_mask = np.ones((n_samples, frames_seed))
    mask_type_probs = [zero_probability, replace_probability, same_probability]

    for j in range(n_samples):
        attention_mask[j][sample_length:] *= 0
        masks[j][:len(mask)] *= mask
        mask_type = np.random.choice([0, 1, 2], p=mask_type_probs, size=len(selected))

        masked_pose_batch[j][sample_length:] *= 0
        for index, kind in zip(selected, mask_type):
            if kind == 0:
                masked_pose_batch[j][index] *= 0
            elif kind == 1:
                donor = np.random.randint(n_samples)
                candidates = np.arange(sample_length)
                candidates = candidates[candidates != index]
                random_replace = np.random.choice(candidates)
                masked_pose_batch[j][index] = original_pose_batch[donor][random_replace]
            # kind == 2: keep the original frame

    return masked_pose_batch, masks[:, :, None].repeat(pose_batch.shape[-1], -1), attention_mask

In [None]:
# All-zero tensor of shape (batch_size, 1, 512) on `device` (presumably a CLS-token placeholder).
# NOTE(review): not referenced by the training or inference cells in this notebook.
cls = torch.zeros((batch_size, 1, 512), device=device)

In [None]:
# Masked-frame training over data_train for `epochs` epochs with the current
# optimizerG. `iters` and `G_losses` keep accumulating across training cells.
if len(data_train) < batch_size:
    raise ValueError(
        'data_train has %d samples, fewer than batch_size=%d' % (len(data_train), batch_size))

# Ensure dropout is active even if an inference cell left the model in eval mode.
poseBert.train()
for epoch in range(epochs):
    # New shuffle every epoch (permutation(n) draws exactly like shuffling arange(n)).
    indexes = np.random.permutation(len(data_train))
    _data_train = data_train[indexes]

    data_batches = np.array_split(_data_train, len(_data_train) // batch_size)

    for i, batch in enumerate(data_batches):
        pose_batch = norm_train_poses(
            torch.tensor(np.concatenate([sample['pose'][None] for sample in batch], 0))
        ).to(device).float()

        masked_pose_batch, mask, attention_mask = generate_masked_batch(pose_batch.clone().detach())

        fake = poseBert(masked_pose_batch, torch.tensor(attention_mask).to(device), device=device, train=True)

        # `mask` is 1 on frames that were not selected for masking. Zeroing those
        # in both tensors means only selected frames contribute error (the
        # mean still runs over all elements).
        unselected = torch.as_tensor(mask, device=device).bool()
        errG = criterion(fake.masked_fill(unselected, 0), pose_batch.masked_fill(unselected, 0))

        optimizerG.zero_grad()
        errG.backward()
        optimizerG.step()

        iters += 1
        loss_value = errG.item()
        G_losses.append(loss_value)

        if iters % 10 == 0:
            print('[%d/%d][%d/%d]\tLoss_G: %.5f'
                  % (epoch, epochs, i, len(data_batches), loss_value))

In [None]:
# Checkpoint: the file name records the global step and the loss of the last batch.
# NOTE(review): absolute, user-specific path — consider a configurable models directory.
torch.save(
    obj=poseBert.state_dict(),
    f='/ess_storage/storage/home/skhn2/Poses/models/WORK.MaskedBERT.Ubidata.v2.generator.step{}.G{}.pt'.format(
        iters, round(errG.item(), 6)
    ),
)

In [None]:
# poseBert.load_state_dict(torch.load('/ess_storage/storage/home/skhn2/Poses/models/WORK.MaskedBERT.Ubidata.generator.step10000.G0.000257.pt'))

In [None]:
# Fresh Adam optimiser (moment estimates start from zero) at learning rate 1e-4.
lrG = 1e-4
optimizerG = torch.optim.Adam(
    params=poseBert.parameters(),
    lr=lrG,
    betas=(beta1, 0.999),
)

In [None]:
# Masked-frame training over data_train for `epochs` epochs with the current
# optimizerG. `iters` and `G_losses` keep accumulating across training cells.
if len(data_train) < batch_size:
    raise ValueError(
        'data_train has %d samples, fewer than batch_size=%d' % (len(data_train), batch_size))

# Ensure dropout is active even if an inference cell left the model in eval mode.
poseBert.train()
for epoch in range(epochs):
    # New shuffle every epoch (permutation(n) draws exactly like shuffling arange(n)).
    indexes = np.random.permutation(len(data_train))
    _data_train = data_train[indexes]

    data_batches = np.array_split(_data_train, len(_data_train) // batch_size)

    for i, batch in enumerate(data_batches):
        pose_batch = norm_train_poses(
            torch.tensor(np.concatenate([sample['pose'][None] for sample in batch], 0))
        ).to(device).float()

        masked_pose_batch, mask, attention_mask = generate_masked_batch(pose_batch.clone().detach())

        fake = poseBert(masked_pose_batch, torch.tensor(attention_mask).to(device), device=device, train=True)

        # `mask` is 1 on frames that were not selected for masking. Zeroing those
        # in both tensors means only selected frames contribute error (the
        # mean still runs over all elements).
        unselected = torch.as_tensor(mask, device=device).bool()
        errG = criterion(fake.masked_fill(unselected, 0), pose_batch.masked_fill(unselected, 0))

        optimizerG.zero_grad()
        errG.backward()
        optimizerG.step()

        iters += 1
        loss_value = errG.item()
        G_losses.append(loss_value)

        if iters % 10 == 0:
            print('[%d/%d][%d/%d]\tLoss_G: %.5f'
                  % (epoch, epochs, i, len(data_batches), loss_value))

In [None]:
# Fresh Adam optimiser (moment estimates start from zero) at the lower learning rate 1e-5.
lrG = 1e-5
optimizerG = torch.optim.Adam(
    params=poseBert.parameters(),
    lr=lrG,
    betas=(beta1, 0.999),
)

In [None]:
# Masked-frame training over data_train for `epochs` epochs with the current
# optimizerG. `iters` and `G_losses` keep accumulating across training cells.
if len(data_train) < batch_size:
    raise ValueError(
        'data_train has %d samples, fewer than batch_size=%d' % (len(data_train), batch_size))

# Ensure dropout is active even if an inference cell left the model in eval mode.
poseBert.train()
for epoch in range(epochs):
    # New shuffle every epoch (permutation(n) draws exactly like shuffling arange(n)).
    indexes = np.random.permutation(len(data_train))
    _data_train = data_train[indexes]

    data_batches = np.array_split(_data_train, len(_data_train) // batch_size)

    for i, batch in enumerate(data_batches):
        pose_batch = norm_train_poses(
            torch.tensor(np.concatenate([sample['pose'][None] for sample in batch], 0))
        ).to(device).float()

        masked_pose_batch, mask, attention_mask = generate_masked_batch(pose_batch.clone().detach())

        fake = poseBert(masked_pose_batch, torch.tensor(attention_mask).to(device), device=device, train=True)

        # `mask` is 1 on frames that were not selected for masking. Zeroing those
        # in both tensors means only selected frames contribute error (the
        # mean still runs over all elements).
        unselected = torch.as_tensor(mask, device=device).bool()
        errG = criterion(fake.masked_fill(unselected, 0), pose_batch.masked_fill(unselected, 0))

        optimizerG.zero_grad()
        errG.backward()
        optimizerG.step()

        iters += 1
        loss_value = errG.item()
        G_losses.append(loss_value)

        if iters % 10 == 0:
            print('[%d/%d][%d/%d]\tLoss_G: %.5f'
                  % (epoch, epochs, i, len(data_batches), loss_value))

In [None]:
# Masked-frame training over data_train for `epochs` epochs with the current
# optimizerG. `iters` and `G_losses` keep accumulating across training cells.
if len(data_train) < batch_size:
    raise ValueError(
        'data_train has %d samples, fewer than batch_size=%d' % (len(data_train), batch_size))

# Ensure dropout is active even if an inference cell left the model in eval mode.
poseBert.train()
for epoch in range(epochs):
    # New shuffle every epoch (permutation(n) draws exactly like shuffling arange(n)).
    indexes = np.random.permutation(len(data_train))
    _data_train = data_train[indexes]

    data_batches = np.array_split(_data_train, len(_data_train) // batch_size)

    for i, batch in enumerate(data_batches):
        pose_batch = norm_train_poses(
            torch.tensor(np.concatenate([sample['pose'][None] for sample in batch], 0))
        ).to(device).float()

        masked_pose_batch, mask, attention_mask = generate_masked_batch(pose_batch.clone().detach())

        fake = poseBert(masked_pose_batch, torch.tensor(attention_mask).to(device), device=device, train=True)

        # `mask` is 1 on frames that were not selected for masking. Zeroing those
        # in both tensors means only selected frames contribute error (the
        # mean still runs over all elements).
        unselected = torch.as_tensor(mask, device=device).bool()
        errG = criterion(fake.masked_fill(unselected, 0), pose_batch.masked_fill(unselected, 0))

        optimizerG.zero_grad()
        errG.backward()
        optimizerG.step()

        iters += 1
        loss_value = errG.item()
        G_losses.append(loss_value)

        if iters % 10 == 0:
            print('[%d/%d][%d/%d]\tLoss_G: %.5f'
                  % (epoch, epochs, i, len(data_batches), loss_value))

In [None]:
# Generate a prediction and save it as a .bvh file

In [None]:
from writers import BVHWriter
from pymo.parsers import BVHParser
from scipy.spatial.transform import Rotation as R

In [None]:
p = BVHParser()
# Load an existing .bvh file to use as a template. Later cells overwrite its
# motion values with the generated clip and write it back out.
# NOTE(review): frames past the generated clip's length are never overwritten,
# so they keep the template's original motion.
placeholder = p.parse('data/placeholders/placeholder.bvh')

In [None]:
# Normalised poses of the first 10 test-split samples, as float32 on `device`.
pose_batch = norm_train_poses(
    torch.tensor(np.stack([sample['pose'] for sample in data if sample['type'] == 'test'][:10]))
).to(device).float()

In [None]:
# Deterministic selection pattern 2: every frame except 0, 10, 20, ... is selected for masking.
_masked_pose_batch, _mask, _attention_mask = generate_masked_batch_patched(pose_batch, pattern=2)

In [None]:
# Reconstruct one test sample from its masked input and put its ground-truth
# frames in front, so the exported clip shows the original, then the reconstruction.
sample_idx = 7
n_frames = 100  # matches the 100-frame valid length used by generate_masked_batch_patched

# Run inference without dropout, then restore the previous mode so later
# training cells are unaffected.
was_training = poseBert.training
poseBert.eval()
try:
    with torch.no_grad():
        fake = poseBert(
            _masked_pose_batch[sample_idx:sample_idx + 1].clone().detach(),
            torch.as_tensor(_attention_mask[sample_idx:sample_idx + 1]).to(device),
            device=device,
            train=True,
        )
finally:
    poseBert.train(was_training)

# unnorm_train_poses may write into its argument. Passing copies keeps
# pose_batch and fake normalised, so this cell can be re-run safely.
ground_truth = unnorm_train_poses(pose_batch[:, :n_frames].clone())[sample_idx]
reconstruction = unnorm_train_poses(fake.clone())[:, :n_frames][0]
_fake = torch.cat([ground_truth, reconstruction]).cpu().numpy()

In [None]:
# Split each frame vector: the last 3 values are the global coordinates; the
# preceding 198 values are 22 joints x one 3x3 rotation matrix each.
trace = _fake[..., -3:]
pose = _fake[..., :-3].reshape(len(_fake), 22, 3, 3)

In [None]:
# Convert rotation matrices to Euler angles in degrees ('zyx' sequence), giving
# one flat row of joints*3 angles per frame. A single batched scipy call over
# every (frame, joint) matrix replaces the per-frame loop; the row width is
# derived from the data instead of being hard-coded.
# NOTE(review): lowercase 'zyx' is an extrinsic sequence in scipy. BVH files
# usually use intrinsic (uppercase) rotations, so confirm this against the
# placeholder's CHANNELS order.
pose_euler = (
    R.from_matrix(pose.reshape(-1, 3, 3))
    .as_euler('zyx', degrees=True)
    .reshape(pose.shape[0], -1)
)

In [None]:
# Write the joint rotations into the template's motion table, starting at
# column 3 (columns 0-2 hold the global coordinates). Assigning through .iloc
# updates the DataFrame itself. Writing into `.values.values` only took effect
# while pandas happened to return a writable view.
# NOTE(review): assumes the template has at least len(pose_euler) frames and
# that its per-joint channel order matches the 'zyx' angle order.
n_frames_out, n_angles = pose_euler.shape
placeholder.values.iloc[:n_frames_out, 3:3 + n_angles] = pose_euler

In [None]:
# Write the global coordinates into the first columns of the template's motion
# table. Assigning through .iloc updates the DataFrame itself rather than
# relying on `.values.values` being a writable view.
placeholder.values.iloc[:len(trace), :trace.shape[1]] = trace

In [None]:
# Writer used to serialise the edited template back to .bvh.
p2 = BVHWriter()

In [None]:
# Serialise the edited template. Create the output directory first so the cell
# also works on a fresh checkout.
bvh_out_path = 'generated/test.bvh'
os.makedirs(os.path.dirname(bvh_out_path), exist_ok=True)
with open(bvh_out_path, 'w') as f:
    p2.write(placeholder, f)