In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import skeletonDefMixamo as skd
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

from PhaseFunctionedNetwork import PhaseFunctionedNetwork

# set seeds for reproduceability
torch.manual_seed(42)
np.random.seed(42)
rng = np.random.RandomState(42)

In [2]:
# load data
# X = np.float32(np.loadtxt('./data/Input.txt'))
# Y = np.float32(np.loadtxt('./data/Output.txt'))
# P = np.float32(np.loadtxt('./data/Phases.txt'))
# print(X.shape, Y.shape, P.shape)

(124610, 342) (124610, 311) (124610,)


In [4]:
# C:\Users\Ana\Desktop\dev\pfnn-dev\Export
X = np.float32(np.loadtxt('C:/Users/Ana/Desktop/dev/pfnn-dev/Export/Input.txt'))
Y = np.float32(np.loadtxt('C:/Users/Ana/Desktop/dev/pfnn-dev/Export/Output.txt'))
print(X.shape, Y.shape)

(281, 864) (281, 819)


In [5]:
# calculate mean and std
Xmean, Xstd = X.mean(axis=0), X.std(axis=0)
Ymean, Ystd = Y.mean(axis=0), Y.std(axis=0)

j = skd.JOINT_NUM
w = ((60*2)//10)

Xstd[w*0:w* 1] = Xstd[w*0:w* 1].mean() # Trajectory Past Positions
Xstd[w*1:w* 2] = Xstd[w*1:w* 2].mean() # Trajectory Future Positions
Xstd[w*2:w* 3] = Xstd[w*2:w* 3].mean() # Trajectory Past Directions
Xstd[w*3:w* 4] = Xstd[w*3:w* 4].mean() # Trajectory Future Directions
Xstd[w*4:w*10] = Xstd[w*4:w*10].mean() # Trajectory Gait

# mask out unused joints in input
joint_weights = np.array(skd.JOINT_WEIGHTS).repeat(3)

Xstd[w*10+j*3*0:w*10+j*3*1] = Xstd[w*10+j*3*0:w*10+j*3*1].mean() / (joint_weights * 0.1) # Pos
Xstd[w*10+j*3*1:w*10+j*3*2] = Xstd[w*10+j*3*1:w*10+j*3*2].mean() / (joint_weights * 0.1) # Vel
Xstd[w*10+j*3*2:          ] = Xstd[w*10+j*3*2:          ].mean() # Terrain

Ystd[0:2] = Ystd[0:2].mean() # Translational Velocity
Ystd[2:3] = Ystd[2:3].mean() # Rotational Velocity
Ystd[3:4] = Ystd[3:4].mean() # Change in Phase
Ystd[4:8] = Ystd[4:8].mean() # Contacts

Ystd[8+w*0:8+w*1] = Ystd[8+w*0:8+w*1].mean() # Trajectory Future Positions
Ystd[8+w*1:8+w*2] = Ystd[8+w*1:8+w*2].mean() # Trajectory Future Directions

Ystd[8+w*2+j*3*0:8+w*2+j*3*1] = Ystd[8+w*2+j*3*0:8+w*2+j*3*1].mean() # Pos
Ystd[8+w*2+j*3*1:8+w*2+j*3*2] = Ystd[8+w*2+j*3*1:8+w*2+j*3*2].mean() # Vel
Ystd[8+w*2+j*3*2:8+w*2+j*3*3] = Ystd[8+w*2+j*3*2:8+w*2+j*3*3].mean() # Rot

# save mean / std / min / max

Xmean.astype(np.float32).tofile('./weights/Xmean.bin')
Ymean.astype(np.float32).tofile('./weights/Ymean.bin')
Xstd.astype(np.float32).tofile('./weights/Xstd.bin')
Ystd.astype(np.float32).tofile('./weights/Ystd.bin')

# normalize data
X = (X - Xmean) / Xstd
Y = (Y - Ymean) / Ystd

  Y = (Y - Ymean) / Ystd


In [4]:
# append phase as additional feature
input = torch.tensor(np.concatenate([X, P [..., np.newaxis]], axis=-1))
target = torch.tensor(Y)

dataset = TensorDataset(input[:100000], target[:100000])

BATCH_SIZE = 32
train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [5]:
# ensure GPU is available
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")
device = "cuda"

# custom loss function
def loss_func(output, target, model):
    loss = torch.mean((output - target)**2) + model.cost()
    return loss

In [6]:
model = PhaseFunctionedNetwork(input_shape=input.shape[1], output_shape=311)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)

epochs=1

for epoch in range(epochs):
    model.train()
    loss_list = []
    for i, batch in enumerate(tqdm(train_dataloader)):
        input, target = batch
        input, target = input.to(device), target.to(device)

        # forward pass
        output = model(input)
        loss = loss_func(output, target, model)
        loss_list.append(loss.item())

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {np.average(loss_list)}')


100%|██████████| 32/32 [00:01<00:00, 24.06it/s]

Epoch [1/1], Loss: 2.8743961080908775





In [7]:
# save weights
model.precompute_and_save_weights()