## Imports

In [None]:
import wandb
import roma
import torch

## Generate Dataset
Note: you don't need to (and probably can't) run this cell; a pre-generated dataset is provided under `./dataset`.

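We convert all Euler angles from degrees to radians here; they are turned into quaternions later, when the camera transforms are computed, since quaternions are more stable for training.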

In [None]:
import json
import os
import numpy as np

# Root folder of the raw data; every entry inside it is handed to process() as a folder of shot JSON files.
mmd_path = '/mnt/c/research/avdc/mmd_processed2'
# Number of consecutive shots that make up one sample.
sliding_window_size = 10
# Probability with which a window is routed to the validation split instead of the training split.
val_ratio = 0.05

# Accumulators filled by process(); each entry is a (sliding_window_size, 29) array
# of which only columns 0..24 are written.
train_data = []
val_data = []

# Set to True to regenerate ./dataset from mmd_path; otherwise generation is skipped.
will_generate = False

def _xyz(vec):
    """Return the ``x``, ``y``, ``z`` entries of a JSON vector dict as a tuple."""
    return vec['x'], vec['y'], vec['z']


def _deg2rad(value):
    """Convert degrees to radians through torch, returned as a Python float."""
    return torch.deg2rad(torch.tensor(value)).item()


def process(folder_path):
    """Turn every shot JSON file in ``folder_path`` into sliding-window samples.

    Each ``*.json`` file is expected to hold a list of shot dicts. Every shot
    becomes one row with this column layout (angles stored in radians):

        0       duration
        1:4     character_pos (x, y, z)
        4:7     character_pos_offset
        7, 8    character_rot_y, character_rot_y_offset
        9:12    camera_pos
        12:15   camera_pos_offset
        15:18   camera_rot
        18:21   camera_rot_offset
        21, 22  camera_distance, camera_distance_offset
        23, 24  camera_fov, camera_fov_offset

    Rows are allocated 29 columns wide (``1 + 3 * 4 + 4 * 4``); columns 25..28
    are never written and stay zero. The width is kept so regenerated data has
    the same shape as before.

    Every window of ``sliding_window_size`` consecutive rows is appended to the
    module-level ``val_data`` with probability ``val_ratio`` and to
    ``train_data`` otherwise. Files with fewer shots than the window size
    contribute nothing.

    Args:
        folder_path: Directory to scan (non-recursively) for ``.json`` files.

    Returns:
        True when ``folder_path`` is a directory and was scanned, False when it
        is not a directory (stray files next to the data folders are skipped
        instead of raising).
    """
    if not os.path.isdir(folder_path):
        return False

    for file in os.listdir(folder_path):
        if not file.endswith('.json'):
            continue

        # Only hold the file open while parsing it.
        with open(os.path.join(folder_path, file), 'r') as f:
            shots = json.load(f)

        array = np.zeros((len(shots), 1 + 3 * 4 + 4 * 4))
        for i, shot in enumerate(shots):
            array[i, 0] = shot['duration']
            array[i, 1:4] = _xyz(shot['character_pos'])
            array[i, 4:7] = _xyz(shot['character_pos_offset'])
            array[i, 7] = _deg2rad(shot['character_rot_y'])
            array[i, 8] = _deg2rad(shot['character_rot_y_offset'])

            array[i, 9:12] = _xyz(shot['camera_pos'])
            array[i, 12:15] = _xyz(shot['camera_pos_offset'])
            array[i, 15:18] = [_deg2rad(angle) for angle in _xyz(shot['camera_rot'])]
            array[i, 18:21] = [_deg2rad(angle) for angle in _xyz(shot['camera_rot_offset'])]
            array[i, 21] = shot['camera_distance']
            array[i, 22] = shot['camera_distance_offset']
            array[i, 23] = shot['camera_fov']
            array[i, 24] = shot['camera_fov_offset']

        # One random draw per window decides which split receives it.
        for i in range(len(array) - sliding_window_size + 1):
            data = train_data if np.random.rand() > val_ratio else val_data
            data.append(array[i:i+sliding_window_size])

    return True

if not will_generate:
    print("Skipped generation.")
else:
    # Count the entries of mmd_path for which process() reports success.
    processed = 0
    for folder in os.listdir(mmd_path):
        if process(mmd_path + '/' + folder):
            processed += 1

    print("Generated ", len(train_data), "training samples and ", len(val_data), "validation samples from ", processed, "folders.")

    # Stack each split into a single array and store it under ./dataset.
    np_data = np.array(train_data)
    np.save('./dataset/train.npy', np_data)

    np_data = np.array(val_data)
    np.save('./dataset/val.npy', np_data)

## Dataset Definition

The data format is as follows:
- Inputs: shot duration, character position, character position delta, character rotation Y, character rotation Y delta
- Outputs: camera look at position, camera look at position delta, camera local rotation, camera local rotation delta, camera distance from look at, camera distance from look at delta

The output parameters can then be used to generate the camera trajectory. Please refer to the paper for more details.

Note camera FoV is included but not used in this paper.

In [None]:
import torch.nn as nn
import torch.optim as optim

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class AVDCDataset(Dataset):
    """Dataset of shot windows loaded from a single ``.npy`` file.

    The file is read eagerly and kept in memory as one float32 tensor. Each
    item is one entry along the first axis, split column-wise into model
    inputs (columns 0..8) and labels (columns 9..24).
    """

    def __init__(self, dataset_path):
        self.shots = []
        self.load_dataset(dataset_path)

    def load_dataset(self, dataset_path):
        """Read ``dataset_path`` with NumPy and store it as a float32 tensor."""
        raw = np.load(dataset_path)
        self.shots = torch.from_numpy(raw).float()

    def __len__(self):
        # Number of windows, i.e. the size of the leading axis.
        return self.shots.shape[0]

    def __getitem__(self, idx):
        """Return ``(inputs, labels)`` for window ``idx``."""
        window = self.shots[idx]
        return window[:, 0:9], window[:, 9:25]

# Device selection: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Data loading: both splits are read fully into memory by AVDCDataset.
train_dataset = AVDCDataset("./dataset/train.npy")
val_dataset = AVDCDataset("./dataset/val.npy")

In [None]:
# Print data for inspection
# NOTE(review): a bare expression is presumably only echoed when it is the last
# statement of a notebook cell, so the next line likely displays nothing here;
# wrap it in print() if the sample should be shown.
train_dataset[0]
# Number of windows in each split.
print(len(train_dataset))
print(len(val_dataset))

## Utilities

Below we implement the mathematical utilities, most importantly the `batch_get_camera_position_and_rotation` function, which computes the final camera transform from the output parameters. For more details, please refer to the paper.

In [None]:
import math

def angle_normalize(angle):
    """Wrap an angle given in radians into the half-open range [-pi, pi)."""
    full_turn = 2 * math.pi
    shifted = angle + math.pi
    return shifted % full_turn - math.pi

def batch_euler_xyz_to_zxy(euler):
    """Reorder the last axis from (x, y, z) to (z, x, y); returns a new tensor."""
    components = [euler.select(-1, axis) for axis in (2, 0, 1)]
    return torch.stack(components, dim=-1)

def batch_euler_zxy_to_xyz(euler):
    """Reorder the last axis from (z, x, y) back to (x, y, z); returns a new tensor."""
    components = [euler.select(-1, axis) for axis in (1, 2, 0)]
    return torch.stack(components, dim=-1)

def batch_euler_normalize(euler):
    """Apply ``angle_normalize`` to the first three components of the last axis."""
    wrapped = [angle_normalize(euler.select(-1, axis)) for axis in range(3)]
    return torch.stack(wrapped, dim=-1)

def batch_get_camera_position_and_rotation(character_pos, character_rot_y, camera_pos, camera_rot, distance):
    """Compute the final camera position and rotation from the camera parameters.

    The character's rotation is a pure rotation about Y by ``character_rot_y``.
    The camera's look-at point is ``character_pos`` plus ``camera_pos`` rotated
    by the character rotation. The combined rotation is
    ``character_rotation * local_camera_rotation * (pi about Y)``, where the
    local camera rotation is ``camera_rot`` with its Y component negated. The
    camera is then placed by adding the combined rotation applied to
    ``(0, 0, -distance)`` to the look-at point.

    Args:
        character_pos: Character position, (..., 3).
        character_rot_y: Character rotation about Y in radians, (...).
        camera_pos: Look-at offset in the character's frame, (..., 3).
        camera_rot: Local camera Euler angles (x, y, z) in radians, (..., 3).
            This tensor is not modified.
        distance: Camera distance from the look-at point, (...).

    Returns:
        Tuple ``(final_camera_position, final_camera_rotation)``; the rotation
        is given as Euler angles in (x, y, z) order.
    """
    convention = 'zxy'
    zeros = torch.zeros_like(character_rot_y)
    pis = torch.ones_like(character_rot_y) * math.pi
    character_rot = torch.stack([zeros, character_rot_y, zeros], dim=-1)

    # Negate the Y angle on a private copy. Writing into the argument would
    # silently flip the caller's tensor (and anything it is a view of).
    camera_rot = camera_rot.clone()
    camera_rot[..., 1] = -camera_rot[..., 1]

    character_rotation = roma.euler_to_unitquat(convention, batch_euler_xyz_to_zxy(character_rot))
    local_camera_rotation = roma.euler_to_unitquat(convention, batch_euler_xyz_to_zxy(camera_rot))
    extra_rotation = roma.euler_to_unitquat(convention, batch_euler_xyz_to_zxy(torch.stack([zeros, pis, zeros], dim=-1)))

    combined_rotation = roma.quat_product(roma.quat_product(character_rotation, local_camera_rotation), extra_rotation)

    # Look-at point in world space, then step back along the camera's axis.
    world_camera_pos = character_pos + roma.quat_action(character_rotation, camera_pos)
    camera_backward = roma.quat_action(combined_rotation, torch.stack([zeros, zeros, -distance], dim=-1))
    final_camera_position = world_camera_pos + camera_backward

    final_camera_rotation = batch_euler_zxy_to_xyz(roma.unitquat_to_euler(convention, combined_rotation))

    return final_camera_position, final_camera_rotation

Below we implement the `calculate_angle` function, which computes the angle between the vector `character - cam` and the camera's forward vector. Essentially, it measures how much the camera has to rotate to center the character in the view.

In [None]:
# Y coordinate used as the character's chest when measuring how well the camera is aimed
# (calculate_angle overwrites the character's Y with this value).
character_chest_height = 1.115

def angle_v2(a, b, dim=-1, eps=1e-5):
    """Return the angle in radians between ``a`` and ``b`` along ``dim``.

    Uses ``2 * atan2(|a*|b| - |a|*b|, |a*|b| + |a|*b|)``. When both norms fall
    below ``eps`` (for example when a vector is zero), ``eps`` is added to both
    so that ``atan2`` never receives ``(0, 0)``.
    """
    len_a = a.norm(dim=dim, keepdim=True)
    len_b = b.norm(dim=dim, keepdim=True)

    # Scale each vector by the other's length so both have equal magnitude.
    a_scaled = a * len_b
    b_scaled = len_a * b
    diff_norm = (a_scaled - b_scaled).norm(dim=dim)
    sum_norm = (a_scaled + b_scaled).norm(dim=dim)

    degenerate = (diff_norm < eps) & (sum_norm < eps)
    numerator = torch.where(degenerate, diff_norm + eps, diff_norm)
    denominator = torch.where(degenerate, sum_norm + eps, sum_norm)

    return 2 * torch.atan2(numerator, denominator)

def calculate_angle(camera_pos, camera_rot, character_pos):
    """Angle between the camera's forward axis and the direction to the character's chest.

    The chest is ``character_pos`` with its Y component replaced by
    ``character_chest_height``. The forward axis is (0, 0, 1) rotated by
    ``camera_rot`` (Euler x, y, z). ``character_pos`` is not modified.
    """
    batch_shape = character_pos.shape[:-1]
    zeros = torch.zeros(batch_shape, device=character_pos.device)
    ones = torch.ones(batch_shape, device=character_pos.device)

    # Aim point: work on a copy so the caller's positions stay untouched.
    target = character_pos.clone()
    target[..., 1] = character_chest_height

    # Unit direction from the camera to the aim point; the epsilon avoids division by zero.
    to_target = target - camera_pos
    to_target = to_target / (torch.norm(to_target, dim=-1, keepdim=True) + 1e-5)

    camera_quat = roma.euler_to_unitquat('zxy', batch_euler_xyz_to_zxy(camera_rot))
    forward = torch.stack([zeros, zeros, ones], dim=-1)
    camera_dir = roma.quat_action(camera_quat, forward)

    return angle_v2(camera_dir, to_target, dim=-1)
      

# Define the test inputs
camera_pos_test = torch.tensor([[-3.11, 1.15, 2.01], [-1.77, 1.34, 3.77]], dtype=torch.float32)
camera_rot_test = torch.deg2rad(torch.tensor([[2.30, 112.65, 0.00], [354.39, 37.95, 0.00]], dtype=torch.float32))  # Euler angles listed in degrees, converted to radians
character_pos_test = torch.tensor([[0.00, 1.11, 0.06], [0.00, 1.11, 0.06]], dtype=torch.float32)
size = 2
# Repeat the two samples along a new leading axis so every input is (size, 2, 3).
camera_pos_test = camera_pos_test.unsqueeze(0).expand(size, -1, -1).contiguous()
camera_rot_test = camera_rot_test.unsqueeze(0).expand(size, -1, -1).contiguous()
character_pos_test = character_pos_test.unsqueeze(0).expand(size, -1, -1).contiguous()

# Run the test
calculate_angle(camera_pos_test, camera_rot_test, character_pos_test)

# Expected: 0.1675, 2.0369

Below we implement the `calculate_look_angle` function, which computes the Y rotation difference between the camera orientation and the character orientation.

In [None]:
def calculate_look_angle(camera_rot, character_rot_y):
    """Angle between the camera's and the character's forward directions.

    Both directions are (0, 0, 1) rotated by the respective orientation: the
    camera by ``camera_rot`` (Euler x, y, z), the character by a rotation of
    ``character_rot_y`` about Y only.
    """
    zeros = torch.zeros_like(character_rot_y)
    ones = torch.ones_like(character_rot_y)

    def forward_direction(euler_xyz):
        # Rotate the +Z axis by the given Euler angles.
        quat = roma.euler_to_unitquat('zxy', batch_euler_xyz_to_zxy(euler_xyz))
        return roma.quat_action(quat, torch.stack([zeros, zeros, ones], dim=-1))

    character_rot = torch.stack([zeros, character_rot_y, zeros], dim=-1)
    camera_dir = forward_direction(camera_rot)
    character_dir = forward_direction(character_rot)

    return angle_v2(camera_dir, character_dir, dim=-1)

# Sanity check: unrotated camera against character yaws of 190 and 170 degrees.
# Both forward vectors are 170 degrees away from +Z, so both results are expected to be about 170.
torch.rad2deg(calculate_look_angle(torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float32), torch.deg2rad(torch.tensor([190.0, 170.0], dtype=torch.float32))))

## Model Definition

In [None]:
import math
import matplotlib.pyplot as plt
import json

class PositionalEncoding(torch.nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence, then apply dropout.

    The table has shape ``(1, max_len, d_model)`` and is registered as the
    buffer ``pe``: even feature columns hold sines, odd columns hold cosines.
    Odd ``d_model`` is supported; the last column is then a sine column.

    Args:
        d_model: Feature size of the sequence elements.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length the table covers.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angles to the number of odd-indexed slots.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:, :seq_len])`` for ``x`` of shape (batch, seq_len, d_model)."""
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

class AVDCEncoderOnlyModel(nn.Module):
    """Encoder-only transformer mapping per-shot inputs to per-shot camera parameters.

    Pipeline: linear embedding -> positional encoding -> transformer encoder
    stack (batch-first) -> linear output head.
    """

    def __init__(self, input_dim, output_dim, d_model, dim_feedforward, num_heads, num_encoder_layers, dropout=0.1):
        super().__init__()

        self.input_embedding = nn.Linear(input_dim, d_model)
        # NOTE(review): `dropout` is not forwarded here, so the positional
        # encoding keeps its own default dropout rate - confirm this is intended.
        self.input_positional_encoding = PositionalEncoding(d_model)
        layer_template = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_encoder_layers)
        self.out = nn.Linear(d_model, output_dim)

        self.init_weights()

    def init_weights(self):
        """Xavier-initialise all weight matrices and set both linear biases to 0.01."""
        linear_layers = (self.input_embedding, self.out)
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
        for layer in linear_layers:
            layer.bias.data.fill_(0.01)

        # Only parameters with more than one dimension are re-initialised in
        # the encoder stack; vectors such as biases keep their defaults.
        for param in self.transformer_encoder.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq, input_dim) to (batch, seq, output_dim)."""
        embedded = self.input_positional_encoding(self.input_embedding(x))
        encoded = self.transformer_encoder(embedded)
        return self.out(encoded)

## Loss Function

We design a heuristic loss function that penalizes deviations in the stylistic features of the camera trajectory. For more details, please refer to the paper.

In [None]:

def compute_loss(inputs, outputs, labels, step, total_step, inference=False):
    """Heuristic style loss between predicted and ground-truth camera parameters.

    All three tensors are indexed as (batch, shot, feature).

    ``inputs`` features: 0 duration, 1:4 character position, 4:7 character
    position delta, 7 character Y rotation, 8 character Y rotation delta.

    ``outputs`` / ``labels`` features: 0:3 camera look-at position, 3:6 its
    delta, 6:9 local camera rotation, 9:12 its delta, 12 camera distance,
    13 its delta. Any further features are ignored.

    The loss is a weighted sum of eight terms: position-delta magnitude,
    rotation-delta magnitude, mean distance parameters, camera offset from the
    character, aiming angle, look angle relative to the character, bounding-box
    size of the camera positions, and camera height. "Start" and "end" refer to
    the parameters without and with their delta added.

    Args:
        inputs: Model inputs.
        outputs: Model predictions. Not modified.
        labels: Ground-truth camera parameters. Not modified.
        step: Unused; kept for call compatibility.
        total_step: Global step, logged to wandb as ``"step"``.
        inference: When True, nothing is logged to wandb.

    Returns:
        Scalar tensor with the total loss.
    """
    pos_weight = 2
    y_weight = 4
    rot_weight = 1
    distance_weight = 2
    d_weight = 2
    angle_weight = 1
    y_angle_weight = 1
    bb_weight = 2

    mse = torch.nn.functional.mse_loss

    # Character state at the start (0) and end (1) of every shot.
    character_pos0 = inputs[:, :, 1:4]
    character_pos1 = character_pos0 + inputs[:, :, 4:7]
    character_rot_y0 = inputs[:, :, 7]
    character_rot_y1 = character_rot_y0 + inputs[:, :, 8]

    def world_cameras(params):
        """Return world-space (start_pos, start_rot, end_pos, end_rot) for ``params``."""
        pos, pos_offset = params[:, :, 0:3], params[:, :, 3:6]
        rot, rot_offset = params[:, :, 6:9], params[:, :, 9:12]
        d, d_offset = params[:, :, 12], params[:, :, 13]
        # `rot` is a view of `params`. Hand over a copy so the caller's
        # tensors can never be written through it, regardless of how the
        # helper treats its rotation argument. The end-of-shot arguments are
        # freshly computed tensors already.
        start_pos, start_rot = batch_get_camera_position_and_rotation(
            character_pos0, character_rot_y0, pos, rot.clone(), d)
        end_pos, end_rot = batch_get_camera_position_and_rotation(
            character_pos1, character_rot_y1, pos + pos_offset, rot + rot_offset, d + d_offset)
        return start_pos, start_rot, end_pos, end_rot

    def bounding_box_size(tensor):
        """Per-sample extent of ``tensor`` along the shot axis."""
        lower = tensor.min(dim=1).values
        upper = tensor.max(dim=1).values
        return upper - lower

    outputs_start_camera_pos, outputs_start_camera_rot, outputs_end_camera_pos, outputs_end_camera_rot = world_cameras(outputs)
    labels_start_camera_pos, labels_start_camera_rot, labels_end_camera_pos, labels_end_camera_rot = world_cameras(labels)

    # NOTE(review): an earlier version also computed an "actual position delta"
    # term (distance between world-space start and end cameras) but never added
    # it to the total; it is omitted here, which leaves the loss value unchanged.

    # ensures predicted position delta is similar to true position delta (L2 loss)
    pos_delta_loss = mse(outputs[:, :, 3:6].norm(dim=2).mean(dim=1), labels[:, :, 3:6].norm(dim=2).mean(dim=1))

    # ensures predicted rotation delta is similar to true rotation delta (L2 loss)
    rot_delta_loss = mse(outputs[:, :, 9:12].norm(dim=2).mean(dim=1), labels[:, :, 9:12].norm(dim=2).mean(dim=1))

    # ensures the predicted camera sits at a similar offset from the character
    start_distance_loss = mse(outputs_start_camera_pos - character_pos0, labels_start_camera_pos - character_pos0)
    end_distance_loss = mse(outputs_end_camera_pos - character_pos1, labels_end_camera_pos - character_pos1)

    # ensures predicted d is similar to true d (compared through their means)
    d_loss = mse(outputs[:, :, 12].mean(), labels[:, :, 12].mean()) + mse(outputs[:, :, 13].mean(), labels[:, :, 13].mean())

    # ensures the predicted camera is aimed at the character by a similar angle
    start_angle_loss = mse(
        calculate_angle(outputs_start_camera_pos, outputs_start_camera_rot, character_pos0),
        calculate_angle(labels_start_camera_pos, labels_start_camera_rot, character_pos0))
    end_angle_loss = mse(
        calculate_angle(outputs_end_camera_pos, outputs_end_camera_rot, character_pos1),
        calculate_angle(labels_end_camera_pos, labels_end_camera_rot, character_pos1))

    # ensures camera y is overall similar (mean height over the window)
    start_y_loss = mse(outputs_start_camera_pos[:, :, 1].mean(dim=1), labels_start_camera_pos[:, :, 1].mean(dim=1))
    end_y_loss = mse(outputs_end_camera_pos[:, :, 1].mean(dim=1), labels_end_camera_pos[:, :, 1].mean(dim=1))

    # ensures the angle between camera and character facing is similar
    start_y_rot_diff_loss = (
        calculate_look_angle(outputs_start_camera_rot, character_rot_y0)
        - calculate_look_angle(labels_start_camera_rot, character_rot_y0)) ** 2
    end_y_rot_diff_loss = (
        calculate_look_angle(outputs_end_camera_rot, character_rot_y1)
        - calculate_look_angle(labels_end_camera_rot, character_rot_y1)) ** 2

    # ensures predicted positions span a similarly sized bounding box
    start_bounding_box_size_loss = mse(
        bounding_box_size(outputs_start_camera_pos).norm(dim=1),
        bounding_box_size(labels_start_camera_pos).norm(dim=1))
    end_bounding_box_size_loss = mse(
        bounding_box_size(outputs_end_camera_pos).norm(dim=1),
        bounding_box_size(labels_end_camera_pos).norm(dim=1))

    pos_loss = pos_weight * pos_delta_loss
    rot_loss = rot_weight * rot_delta_loss
    d_loss = d_weight * d_loss
    distance_loss = distance_weight * (start_distance_loss + end_distance_loss)
    angle_loss = angle_weight * (start_angle_loss + end_angle_loss)
    y_angle_loss = y_angle_weight * (start_y_rot_diff_loss.mean() + end_y_rot_diff_loss.mean())
    bb_loss = bb_weight * (start_bounding_box_size_loss + end_bounding_box_size_loss)
    y_loss = y_weight * (start_y_loss + end_y_loss)

    loss = pos_loss + rot_loss + d_loss + distance_loss + angle_loss + y_angle_loss + bb_loss + y_loss

    if use_wandb and not inference:
        wandb.log({"step": total_step, "loss": loss, "pos_loss": pos_loss, "rot_loss": rot_loss, "d_loss": d_loss, "distance_loss": distance_loss, "angle_loss": angle_loss, "y_angle_loss": y_angle_loss, "bb_loss": bb_loss,
        "y_loss": y_loss})

    return loss

## Training Code & Visualization

In [None]:
def visualize(inputs, outputs, labels):
    """Plot the character and camera paths of one random sample on the XZ plane.

    One sample is drawn at random from the batch. For every shot in it, the
    character's straight-line move is drawn in green, the predicted camera
    path in red and the ground-truth camera path in blue. Camera paths are
    sampled at 100 evenly spaced interpolation steps between the start and end
    parameters. Line opacity grows with the shot index.

    When the module-level ``write_test_data`` flag is set, the figure and the
    sample's raw tensors are also written to ``test_data/test_<ms>.png`` and
    ``.json``.

    Args:
        inputs: Batch of model inputs, indexed (batch, shot, feature).
        outputs: Batch of predictions, same leading shape.
        labels: Batch of ground-truth parameters, same leading shape.

    Returns:
        The matplotlib figure; the caller is responsible for closing it.
    """
    rand_idx = np.random.randint(0, outputs.shape[0])
    inputs = inputs[rand_idx]
    outputs = outputs[rand_idx]
    labels = labels[rand_idx]

    # Draw camera path on XZ plane
    fig, ax = plt.subplots()
    plt.ioff()

    character_pos = inputs[:, 1:4]
    character_pos_offset = inputs[:, 4:7]
    character_rot_y = inputs[:, 7]
    character_rot_y_offset = inputs[:, 8]

    character_pos_x_from = inputs[:, 1].cpu().detach().numpy()
    character_pos_z_from = inputs[:, 3].cpu().detach().numpy()
    character_pos_x_to = character_pos_x_from + inputs[:, 4].cpu().detach().numpy()
    character_pos_z_to = character_pos_z_from + inputs[:, 6].cpu().detach().numpy()

    # Interpolation parameter shared by every sampled curve, shape (100, 1).
    # Created on the sample's own device so it always matches the tensors it
    # is combined with.
    t = torch.linspace(0, 1, steps=100).unsqueeze(1).to(inputs.device)

    def plot_camera_path(i, params, color):
        """Sample and draw the camera path of shot ``i`` described by ``params``."""
        cp = character_pos[i] + t * character_pos_offset[i]
        cry = (character_rot_y[i] + t * character_rot_y_offset[i]).squeeze(-1)
        p = params[i, 0:3] + t * params[i, 3:6]
        r = params[i, 6:9] + t * params[i, 9:12]
        d = (params[i, 12] + t * params[i, 13]).squeeze(-1)
        out_pos, _ = batch_get_camera_position_and_rotation(cp, cry, p, r, d)
        out_pos = out_pos.cpu().detach().numpy()
        ax.plot(out_pos[:, 0], out_pos[:, 2], color=color)

    num_paths = outputs.shape[0]
    for i in range(num_paths):
        alpha = i / num_paths
        ax.plot([character_pos_x_from[i], character_pos_x_to[i]], [character_pos_z_from[i], character_pos_z_to[i]], label="Character path", color=np.array([0, 1, 0, alpha]))
        plot_camera_path(i, outputs, np.array([1, 0, 0, alpha]))
        plot_camera_path(i, labels, np.array([0, 0, 1, alpha]))

    ax.set_xlabel('X')
    ax.set_ylabel('Z')
    ax.grid(True)

    plt.show()

    if write_test_data:
        import os
        import time

        # savefig/open do not create missing directories.
        os.makedirs("test_data", exist_ok=True)
        # timestamp as filename
        filename = f"test_data/test_{(time.time() * 1000):.0f}"
        fig.savefig(filename + ".png")

        json_obj = {
            "inputs": inputs.cpu().detach().numpy().tolist(),
            "outputs": outputs.cpu().detach().numpy().tolist(),
            "labels": labels.cpu().detach().numpy().tolist()
        }
        with open(f"{filename}.json", "w") as f:
            json.dump(json_obj, f)

    return fig

def evaluate(epoch, model, dataloader, optimizer):
    """Run one pass over ``dataloader`` and return the mean batch loss.

    With an optimizer the pass trains the model: train mode, gradients
    enabled, one optimizer step per batch. With ``optimizer=None`` the pass is
    a validation run: the model is switched to eval mode (dropout disabled)
    and no autograd graph is recorded.

    Args:
        epoch: Zero-based epoch index, used for logging and visualisation cadence.
        model: Model to run.
        dataloader: Yields ``(inputs, labels)`` batches.
        optimizer: Optimizer to step, or None for validation.

    Returns:
        Mean of the per-batch losses, or 0.0 when the dataloader is empty.
    """
    training = optimizer is not None
    model.train(training)

    length = len(dataloader)
    total_loss = 0

    with torch.set_grad_enabled(training):
        for i, (inputs, labels) in enumerate(dataloader):
            step = epoch * length + i
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)

            loss = compute_loss(inputs, outputs, labels, i, step, not training)
            l = loss.item()
            total_loss += l

            # NOTE(review): `first_batch` applies to training AND validation
            # passes. That is how the original un-parenthesised and/or
            # condition evaluated; it may have been meant for training only.
            periodic_validation = (not training) and epoch % visualize_epochs == 0
            periodic_training = training and i % 100 == 0 and i > 0
            first_batch = epoch % 10 == 0 and i == 0
            if periodic_validation or periodic_training or first_batch or write_test_data:
                fig = visualize(inputs, outputs, labels)
                if use_wandb:
                    wandb.log({"epoch": epoch, "path": wandb.Image(fig)})
                plt.close(fig)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print(f"Epoch {epoch + 1}/{epochs}, Step {i + 1}/{length}, Loss: {l}")

    return total_loss / length if length else 0.0

# Release cached CUDA memory left over from earlier cells and make tensor printouts readable.
torch.cuda.empty_cache()
torch.set_printoptions(sci_mode=False, threshold=10000)

import os
import glob

# Hyperparameters
d_model = 512
dim_feedforward = 1024
num_heads = 2
num_encoder_layers = 6
# Only recorded in the wandb config below; the encoder-only model takes no decoder-layer argument.
num_decoder_layers = 1
dropout = 0.2
batch_size = 1024
epochs = 100000
learning_rate = 5e-5

# The validation loader uses a fixed batch size of 1024, independent of `batch_size`.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=32, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=1024, shuffle=False, num_workers=32, pin_memory=True)

# Pull one batch to read the feature sizes from the data instead of hard-coding them.
test_input, test_label = next(iter(train_dataloader))
input_dim = test_input.shape[2]
output_dim = test_label.shape[2]

model = AVDCEncoderOnlyModel(input_dim=input_dim,
                              output_dim=output_dim,
                              d_model=d_model,
                              dim_feedforward=dim_feedforward,
                              num_heads=num_heads,
                              num_encoder_layers=num_encoder_layers,
                              dropout=dropout).to(device)

print("Number of parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
# NOTE(review): patience equals the total epoch count, so this scheduler presumably never lowers the learning rate in practice.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=100000)

# Run switches.
will_train = True            # run the training pass each epoch
will_validate = True         # run the validation pass every `validate_epochs` epochs
will_load_checkpoint = False # resume from the newest file in ./checkpoints
write_test_data = False      # make visualize() dump figures and tensors to test_data/
visualize_epochs = 20        # validation passes visualise when epoch % visualize_epochs == 0
validate_epochs = 10

# When dumping test data or only validating, visualise on every epoch.
if write_test_data or not will_train:
    visualize_epochs = 1

# wandb logging is only used for training runs.
use_wandb = True
if not will_train:
    use_wandb = False
if use_wandb:
    wandb.init(
        # set the wandb project where this run will be logged
        project="avdc",
        # track hyperparameters and run metadata
        config={
            "input_dim": input_dim,
            "d_model": d_model,
            "dim_feedforward": dim_feedforward,
            "num_heads": num_heads,
            "num_encoder_layers": num_encoder_layers,
            "num_decoder_layers": num_decoder_layers,
            "dropout": dropout,
            "batch_size": batch_size,
            "learning_rate": learning_rate
        }
    )

# Load checkpoint
starting_epoch = 0
if will_load_checkpoint:
    # Find the latest checkpoint in ./checkpoints, ranked by os.path.getctime
    # NOTE(review): getctime is the creation or metadata-change time depending on the platform,
    # not necessarily the last modification time - confirm this picks the intended file.
    list_of_files = glob.glob('./checkpoints/*')
    if len(list_of_files) == 0:
        print("No checkpoints found.")
    else:
        latest_file = max(list_of_files, key=os.path.getctime)
        print(f"Loading checkpoint: {latest_file}")
        # The epoch index is parsed from the file name, which is expected to end in "_<epoch>.<ext>".
        # NOTE(review): checkpoint_<epoch>.pt is written after that epoch finishes, and training
        # resumes AT `starting_epoch`, so the epoch named in the file is run a second time.
        starting_epoch = int(latest_file.split('_')[-1].split('.')[0])
        # NOTE(review): no map_location is passed; presumably the checkpoint is loaded on the
        # same kind of device it was saved from.
        model.load_state_dict(torch.load(latest_file))

# Training loop
# Starts at `starting_epoch` so that epochs covered by a loaded checkpoint are skipped.
for epoch in range(starting_epoch, epochs):
    avg_train_loss = 0
    if will_train:
        avg_train_loss = evaluate(epoch, model, train_dataloader, optimizer)

    avg_val_loss = 0
    validated = will_validate and epoch % validate_epochs == 0 and epoch > 0
    if validated:
        avg_val_loss = evaluate(epoch, model, val_dataloader, None)
        # Feed the plateau scheduler real validation losses only. Stepping it
        # with the placeholder 0 would register 0 as the best loss ever seen,
        # so no genuine validation loss could count as an improvement.
        scheduler.step(avg_val_loss)

    # Read the learning rate from the optimizer; this works whether or not
    # the scheduler has been stepped yet.
    current_lr = [group['lr'] for group in optimizer.param_groups]
    print(f'Epoch [{epoch + 1}/{epochs}], LR: {current_lr}, '
          f'Train Loss: {avg_train_loss:.8f}, Val Loss: {avg_val_loss:.8f}')

    # Save a checkpoint after every epoch; torch.save does not create the directory.
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(model.state_dict(), f'checkpoints/checkpoint_{epoch}.pt')

    if use_wandb:
        metrics = {"epoch": epoch, "train_loss": avg_train_loss}
        # Log val_loss only for epochs that were validated, so the chart is
        # not filled with placeholder zeros.
        if validated:
            metrics["val_loss"] = avg_val_loss
        wandb.log(metrics)

print("Training complete.")