# Transformer Training on Pose Data

In [1]:
!nvidia-smi -L

GPU 0: NVIDIA GeForce RTX 4090 (UUID: GPU-59ba7a4d-461d-6c44-7eea-a4200c322183)


In [17]:
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.distributions as D
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler


## Loading the Data

In [3]:
def find_line(lines, prefix):
  for line in lines:
    if line.startswith('Frames:'):
      return line

In [4]:
def read_bvh_file(file_path):
    # Read file contents
    with open(file_path, 'r') as f:
        file_contents = f.read()
        
    # Split file contents by newline characters
    lines = file_contents.split('\n')

    # Find the channel names
    channel_names = []
    joint_name = None
    for line in lines:
      line = line.strip()
      if line.startswith('JOINT') or line.startswith('ROOT'):
        # Joint line looks like this:
        # JOINT Spine2
        joint_data = line.split(' ')
        joint_name = joint_data[1]
      if line.startswith('CHANNELS'):
        # Channels line looks like this:
        # CHANNELS 3 Yrotation Xrotation Zrotation
        channel_data = line.split(' ')
        for channel in channel_data[2:]:
          channel_names.append(f'{joint_name}_{channel}')
    
    # Find the number of frames and the start of the motion data
    num_frames_line = find_line(lines, 'Frames:')
    num_frames = int(num_frames_line.split(' ')[1])
    motion_data_index = lines.index('MOTION') + 3
    header = '\n'.join(lines[:motion_data_index])
    
    # Find the number of channels in the file
    first_frame_data = lines[motion_data_index].strip().split(' ')
    num_channels = len(first_frame_data)
    print('Channels:', num_channels)

    # Extract the motion data as a string
    motion_data_str = ''.join(lines[motion_data_index:motion_data_index+num_frames])
    
    # Convert the motion data to a numpy array
    motion_data = np.fromstring(motion_data_str, sep=' ')
    motion_data = motion_data.reshape((num_frames, -1))
    
    # Convert the numpy array to a PyTorch tensor
    motion_tensor = torch.tensor(motion_data, dtype=torch.float32)
    
    return motion_tensor, header, channel_names

In [5]:
raw_data, header, channel_names = read_bvh_file('train_data/flute2.bvh')
print(raw_data.shape)
#print(channel_names)

Channels: 183
torch.Size([19298, 183])


## Building a PyTorch Dataset

In [6]:
class BVHDataset(Dataset):
    def __init__(self, file_path, input_size, output_size, seq_length, future_delta):
        self.file_path = file_path
        self.input_size = input_size
        self.output_size = output_size
        self.seq_length = seq_length
        self.future_delta = future_delta

        # Read BVH file
        self.motion_tensor, self.header, self.channel_names = read_bvh_file(file_path)

        # Compute the total number of sequences in the file
        self.total_sequences = len(self.motion_tensor) - self.seq_length - self.future_delta

        # Compute input_mean and input_std
        self.input_mean = torch.mean(self.motion_tensor, dim=(0,))
        self.input_std = torch.std(self.motion_tensor, dim=(0,))
        self.input_std = torch.where(self.input_std == 0, torch.tensor(1e-7), self.input_std)

    def __len__(self):
        return self.total_sequences

    def __getitem__(self, idx):
        # Compute the sequence index for the given index
        seq_idx = idx + self.seq_length

        # Get the sequence of length seq_length as input x
        input_tensor = self.motion_tensor[seq_idx - self.seq_length:seq_idx, :self.input_size]

        # Get the frame future_delta frames into the future as output y
        output_tensor = self.motion_tensor[seq_idx + self.future_delta, :self.output_size]

        # Normalize input and output tensors
        input_tensor = (input_tensor - self.input_mean) / self.input_std
        output_tensor = (output_tensor - self.input_mean[:self.output_size]) / self.input_std[:self.output_size]

        return input_tensor, output_tensor


In [15]:
# Create data loaders

num_channels = raw_data.shape[-1]
input_size = num_channels
output_size = num_channels
seq_length = 100
future_delta = 200
batch_size = 32


dataset = BVHDataset('train_data/flute2.bvh', input_size, output_size, seq_length, future_delta)

# Split the dataset into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Channels: 183


## Defining the Model

In [29]:
import torch
import torch.nn as nn

class PoseTransformer(nn.Module):
    def __init__(self, input_dim, model_dim, num_heads, num_layers, output_dim, sequence_length):
        super(PoseTransformer, self).__init__()

        self.model_dim = model_dim
        self.sequence_length = sequence_length

        self.embedding = nn.Linear(input_dim, model_dim)
        self.position_encoding = nn.Parameter(torch.randn(sequence_length, 1, model_dim))

        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            dim_feedforward=model_dim * 4,
            dropout=0.1
        )
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

        self.fc_out = nn.Linear(model_dim, output_dim)

    def forward(self, x):
        # x shape: (batch_size, sequence_length, input_dim)
        x = self.embedding(x)
        # x shape: (sequence_length, batch_size, model_dim)
        x = x.permute(1, 0, 2)
        x = x + self.position_encoding
        x = self.transformer_encoder(x)
        # x shape: (batch_size, sequence_length, model_dim)
        x = x.permute(1, 0, 2)

        # Use only the last frame for prediction
        x = x[:, -1, :]
        x = self.fc_out(x)

        return x

# Hyperparameters
input_dim = 183
output_dim = 183
model_dim = 512
num_heads = 8
num_layers = 6
sequence_length = 100

# Create the model
model = PoseTransformer(input_dim, model_dim, num_heads, num_layers, output_dim, sequence_length)


In [30]:
# Loss function
criterion = nn.MSELoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler (optional)
scheduler = StepLR(optimizer, step_size=10, gamma=0.9)

# Training parameters
num_epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Training on {device}')

# Move the model to the device
model.to(device)

# Main training loop
for epoch in range(num_epochs):
    # Training
    model.train()
    train_loss = 0
    for batch_idx, (input_seq, target_seq) in enumerate(tqdm(train_loader)):
        input_seq, target_seq = input_seq.to(device), target_seq.to(device)

        optimizer.zero_grad()

        output_seq = model(input_seq)
        loss = criterion(output_seq, target_seq)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.6f}")

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch_idx, (input_seq, target_seq) in enumerate(val_loader):
            input_seq, target_seq = input_seq.to(device), target_seq.to(device)

            output_seq = model(input_seq)
            loss = criterion(output_seq, target_seq)

            val_loss += loss.item()

        val_loss /= len(val_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss:.6f}")

    # Update learning rate
    scheduler.step()


Training on cuda


100%|██████████| 475/475 [00:08<00:00, 54.17it/s]


Epoch 1/50, Train Loss: 0.765806
Epoch 1/50, Validation Loss: 0.775133


100%|██████████| 475/475 [00:08<00:00, 53.71it/s]


Epoch 2/50, Train Loss: 0.775260
Epoch 2/50, Validation Loss: 0.773287


100%|██████████| 475/475 [00:08<00:00, 53.44it/s]


Epoch 3/50, Train Loss: 0.772943
Epoch 3/50, Validation Loss: 0.771553


100%|██████████| 475/475 [00:08<00:00, 54.04it/s]


Epoch 4/50, Train Loss: 0.772199
Epoch 4/50, Validation Loss: 0.771681


100%|██████████| 475/475 [00:08<00:00, 54.49it/s]


Epoch 5/50, Train Loss: 0.771967
Epoch 5/50, Validation Loss: 0.770930


100%|██████████| 475/475 [00:08<00:00, 53.78it/s]


Epoch 6/50, Train Loss: 0.771752
Epoch 6/50, Validation Loss: 0.771028


100%|██████████| 475/475 [00:08<00:00, 53.78it/s]


Epoch 7/50, Train Loss: 0.771739
Epoch 7/50, Validation Loss: 0.770969


100%|██████████| 475/475 [00:08<00:00, 54.22it/s]


Epoch 8/50, Train Loss: 0.771655
Epoch 8/50, Validation Loss: 0.770765


100%|██████████| 475/475 [00:08<00:00, 54.18it/s]


Epoch 9/50, Train Loss: 0.771561
Epoch 9/50, Validation Loss: 0.770582


100%|██████████| 475/475 [00:08<00:00, 54.32it/s]


Epoch 10/50, Train Loss: 0.771546
Epoch 10/50, Validation Loss: 0.770415


100%|██████████| 475/475 [00:08<00:00, 54.36it/s]


Epoch 11/50, Train Loss: 0.771775
Epoch 11/50, Validation Loss: 0.770447


100%|██████████| 475/475 [00:08<00:00, 53.76it/s]


Epoch 12/50, Train Loss: 0.771458
Epoch 12/50, Validation Loss: 0.770595


100%|██████████| 475/475 [00:08<00:00, 53.67it/s]


Epoch 13/50, Train Loss: 0.771534
Epoch 13/50, Validation Loss: 0.770394


100%|██████████| 475/475 [00:08<00:00, 54.20it/s]


Epoch 14/50, Train Loss: 0.771367
Epoch 14/50, Validation Loss: 0.770696


100%|██████████| 475/475 [00:08<00:00, 53.75it/s]


Epoch 15/50, Train Loss: 0.771528
Epoch 15/50, Validation Loss: 0.770463


100%|██████████| 475/475 [00:08<00:00, 54.34it/s]


Epoch 16/50, Train Loss: 0.771353
Epoch 16/50, Validation Loss: 0.770816


100%|██████████| 475/475 [00:08<00:00, 54.09it/s]


Epoch 17/50, Train Loss: 0.771398
Epoch 17/50, Validation Loss: 0.770471


100%|██████████| 475/475 [00:08<00:00, 54.06it/s]


Epoch 18/50, Train Loss: 0.771301
Epoch 18/50, Validation Loss: 0.770826


100%|██████████| 475/475 [00:09<00:00, 52.70it/s]


Epoch 19/50, Train Loss: 0.771382
Epoch 19/50, Validation Loss: 0.770512


100%|██████████| 475/475 [00:08<00:00, 53.32it/s]


Epoch 20/50, Train Loss: 0.771239
Epoch 20/50, Validation Loss: 0.770423


100%|██████████| 475/475 [00:08<00:00, 54.17it/s]


Epoch 21/50, Train Loss: 0.771299
Epoch 21/50, Validation Loss: 0.770426


100%|██████████| 475/475 [00:08<00:00, 54.01it/s]


Epoch 22/50, Train Loss: 0.771380
Epoch 22/50, Validation Loss: 0.770493


100%|██████████| 475/475 [00:08<00:00, 53.14it/s]


Epoch 23/50, Train Loss: 0.771241
Epoch 23/50, Validation Loss: 0.770660


100%|██████████| 475/475 [00:08<00:00, 53.76it/s]


Epoch 24/50, Train Loss: 0.771133
Epoch 24/50, Validation Loss: 0.771266


100%|██████████| 475/475 [00:08<00:00, 54.33it/s]


Epoch 25/50, Train Loss: 0.771144
Epoch 25/50, Validation Loss: 0.770687


100%|██████████| 475/475 [00:08<00:00, 55.40it/s]


Epoch 26/50, Train Loss: 0.771245
Epoch 26/50, Validation Loss: 0.770379


100%|██████████| 475/475 [00:08<00:00, 54.26it/s]


Epoch 27/50, Train Loss: 0.771190
Epoch 27/50, Validation Loss: 0.771007


100%|██████████| 475/475 [00:08<00:00, 53.58it/s]


Epoch 28/50, Train Loss: 0.771243
Epoch 28/50, Validation Loss: 0.770407


100%|██████████| 475/475 [00:08<00:00, 55.19it/s]


Epoch 29/50, Train Loss: 0.771167
Epoch 29/50, Validation Loss: 0.770765


100%|██████████| 475/475 [00:08<00:00, 55.69it/s]


Epoch 30/50, Train Loss: 0.771187
Epoch 30/50, Validation Loss: 0.770550


100%|██████████| 475/475 [00:08<00:00, 55.14it/s]


Epoch 31/50, Train Loss: 0.771150
Epoch 31/50, Validation Loss: 0.770327


100%|██████████| 475/475 [00:08<00:00, 54.65it/s]


Epoch 32/50, Train Loss: 0.771098
Epoch 32/50, Validation Loss: 0.770446


100%|██████████| 475/475 [00:08<00:00, 55.24it/s]


Epoch 33/50, Train Loss: 0.771105
Epoch 33/50, Validation Loss: 0.770378


100%|██████████| 475/475 [00:08<00:00, 55.33it/s]


Epoch 34/50, Train Loss: 0.771209
Epoch 34/50, Validation Loss: 0.770746


100%|██████████| 475/475 [00:08<00:00, 55.40it/s]


Epoch 35/50, Train Loss: 0.771161
Epoch 35/50, Validation Loss: 0.770714


100%|██████████| 475/475 [00:08<00:00, 55.33it/s]


Epoch 36/50, Train Loss: 0.771111
Epoch 36/50, Validation Loss: 0.770425


100%|██████████| 475/475 [00:08<00:00, 55.25it/s]


Epoch 37/50, Train Loss: 0.771076
Epoch 37/50, Validation Loss: 0.770495


100%|██████████| 475/475 [00:08<00:00, 55.24it/s]


Epoch 38/50, Train Loss: 0.771122
Epoch 38/50, Validation Loss: 0.770496


100%|██████████| 475/475 [00:08<00:00, 55.30it/s]


Epoch 39/50, Train Loss: 0.771139
Epoch 39/50, Validation Loss: 0.770400


100%|██████████| 475/475 [00:08<00:00, 55.15it/s]


Epoch 40/50, Train Loss: 0.771155
Epoch 40/50, Validation Loss: 0.770451


100%|██████████| 475/475 [00:08<00:00, 55.26it/s]


Epoch 41/50, Train Loss: 0.771119
Epoch 41/50, Validation Loss: 0.770496


100%|██████████| 475/475 [00:08<00:00, 55.31it/s]


Epoch 42/50, Train Loss: 0.771122
Epoch 42/50, Validation Loss: 0.770558


100%|██████████| 475/475 [00:08<00:00, 55.29it/s]


Epoch 43/50, Train Loss: 0.771108
Epoch 43/50, Validation Loss: 0.770438


100%|██████████| 475/475 [00:08<00:00, 55.17it/s]


Epoch 44/50, Train Loss: 0.771097
Epoch 44/50, Validation Loss: 0.770557


100%|██████████| 475/475 [00:08<00:00, 55.03it/s]


Epoch 45/50, Train Loss: 0.771117
Epoch 45/50, Validation Loss: 0.770545


100%|██████████| 475/475 [00:08<00:00, 55.20it/s]


Epoch 46/50, Train Loss: 0.771047
Epoch 46/50, Validation Loss: 0.770353


100%|██████████| 475/475 [00:08<00:00, 55.18it/s]


Epoch 47/50, Train Loss: 0.771355
Epoch 47/50, Validation Loss: 0.770442


100%|██████████| 475/475 [00:08<00:00, 54.96it/s]


Epoch 48/50, Train Loss: 0.771160
Epoch 48/50, Validation Loss: 0.770539


100%|██████████| 475/475 [00:08<00:00, 55.06it/s]


Epoch 49/50, Train Loss: 0.771128
Epoch 49/50, Validation Loss: 0.770350


100%|██████████| 475/475 [00:08<00:00, 55.27it/s]


Epoch 50/50, Train Loss: 0.771068
Epoch 50/50, Validation Loss: 0.770419


## Inference

In [32]:
def generate_sequence(model, seed_sequence, num_frames_to_generate):
    model.eval()
    generated_sequence = seed_sequence.clone()

    with torch.no_grad():
        for _ in range(num_frames_to_generate):
            # Get the last sequence_length frames from the generated_sequence
            input_seq = generated_sequence[:, -sequence_length:, :]

            # Predict the next frame
            next_frame = model(input_seq)

            # Reshape the predicted frame to have the same dimensions as input_seq
            next_frame = next_frame.view(1, 1, -1)

            # Append the predicted frame to the generated_sequence
            generated_sequence = torch.cat((generated_sequence, next_frame), dim=1)

    return generated_sequence


In [58]:
#seed_sequence = torch.randn(1, sequence_length, input_dim).to(device)  # Random seed sequence

# Use the first sequence_length frames from the dataset as the seed sequence
seed_sequence = raw_data[:sequence_length]
seed_sequence = (seed_sequence - dataset.input_mean) / dataset.input_std 
seed_sequence = seed_sequence.unsqueeze(0).to(device)  
#print(raw_data.shape)
num_frames_to_generate = 1000

generated_sequence = generate_sequence(model, seed_sequence, num_frames_to_generate)
generated_sequence = generated_sequence.squeeze(0)
generated_sequence = generated_sequence.cpu() * dataset.input_std + dataset.input_mean
generated_sequence = generated_sequence.numpy()

In [59]:
generated_sequence.shape

(1100, 183)

In [60]:
def write_bvh_file(file_name, header, predicted_motion):
    # Open the file for writing
    with open(file_name, 'w') as f:
        # Write the header
        f.write(header)
        f.write('\n')
        # Write the motion data
        for frame in predicted_motion:
            frame_str = ' '.join(str(x) for x in frame)
            f.write(frame_str + '\n')

In [61]:
write_bvh_file('out_06.bvh', dataset.header, generated_sequence)
