In [None]:
# Environment sanity check: interpreter path, PyTorch version and CUDA availability.
import sys

import wandb

# Path of the Python interpreter backing this kernel.
print(sys.executable)

import torch
print(torch.__version__)

# Last expression of the cell, so its boolean result is rendered as the cell output.
torch.cuda.is_available()

## Generate Dataset
(No need to run this cell; the generated dataset is already available under `./dataset`.)

In [None]:
import json
import os
import numpy as np

# Root folder whose entries are scanned for shot JSON files.
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
mmd_path = '/mnt/c/research/avdc/mmd_processed2'
# Number of consecutive shots in each sample window.
sliding_window_size = 10
# Probability with which a window is assigned to the validation set.
val_ratio = 0.05

# Sample windows accumulated by process(); each entry is a
# (sliding_window_size, 29) slice of a per-file feature array.
train_data = []
val_data = []

def process(folder_path):
    """Turn every shot-list JSON file in ``folder_path`` into sliding-window samples.

    Each ``*.json`` file is loaded as a list of shot dicts. Every shot is
    flattened into one feature row (layout below) and every run of
    ``sliding_window_size`` consecutive rows is appended to the module-level
    ``train_data`` or ``val_data`` list; a window goes to ``val_data`` when
    ``np.random.rand() <= val_ratio``.

    Row layout (column indices):
        0       duration
        1:4     character_pos (x, y, z)
        4:7     character_pos_offset (x, y, z)
        7       character_rot_y, degrees converted to radians
        8       character_rot_y_offset, stored as-is
        9:12    camera_pos
        12:15   camera_pos_offset
        15:18   camera_rot, degrees converted to radians
        18:21   camera_rot_offset, degrees converted to radians
        21, 22  camera_distance, camera_distance_offset
        23, 24  camera_fov, camera_fov_offset
        25:29   never written, left at zero

    Args:
        folder_path: Path of a directory to scan (not recursive).

    Returns:
        True when ``folder_path`` is a directory and was scanned; False when
        it is not a directory (such entries are skipped instead of raising).
    """
    # Stray files next to the data folders must not abort the whole run.
    if not os.path.isdir(folder_path):
        return False

    def xyz(vec):
        return vec['x'], vec['y'], vec['z']

    for file_name in os.listdir(folder_path):
        if not file_name.endswith('.json'):
            continue

        with open(os.path.join(folder_path, file_name), 'r') as f:
            shots = json.load(f)

        # NOTE(review): 29 columns are allocated but only 0..24 are filled;
        # the width is kept so the saved array shape does not change.
        features = np.zeros((len(shots), 1 + 3 * 4 + 4 * 4))
        for i, shot in enumerate(shots):
            features[i, 0] = shot['duration']
            features[i, 1:4] = xyz(shot['character_pos'])
            features[i, 4:7] = xyz(shot['character_pos_offset'])
            features[i, 7] = np.deg2rad(shot['character_rot_y'])
            # NOTE(review): unlike the other rotation fields this offset is not
            # converted to radians -- confirm whether that is intended.
            features[i, 8] = shot['character_rot_y_offset']

            features[i, 9:12] = xyz(shot['camera_pos'])
            features[i, 12:15] = xyz(shot['camera_pos_offset'])
            features[i, 15:18] = np.deg2rad(xyz(shot['camera_rot']))
            features[i, 18:21] = np.deg2rad(xyz(shot['camera_rot_offset']))
            features[i, 21] = shot['camera_distance']
            features[i, 22] = shot['camera_distance_offset']
            features[i, 23] = shot['camera_fov']
            features[i, 24] = shot['camera_fov_offset']

        # Files with fewer shots than the window size yield no samples.
        for start in range(len(features) - sliding_window_size + 1):
            target = train_data if np.random.rand() > val_ratio else val_data
            target.append(features[start:start + sliding_window_size])

    return True

# Run process() on every entry of mmd_path and count the entries for which it
# returned a truthy value.
processed = 0
for folder in os.listdir(mmd_path):
    if process(os.path.join(mmd_path, folder)):
        processed += 1

print("Generated ", len(train_data), "training samples and ", len(val_data), "validation samples from ", processed, "folders.")

# np.save does not create missing parent directories, so make sure the
# output folder exists before writing.
os.makedirs('./dataset', exist_ok=True)

np_data = np.array(train_data)
np.save('./dataset/train.npy', np_data)
np_data = np.array(val_data)
np.save('./dataset/val.npy', np_data)

## Dataset Definition

In [None]:
import torch.nn as nn
import torch.optim as optim

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class AVDCDataset(Dataset):
    """Windows of shots loaded from one ``.npy`` file.

    The file is read once with ``np.load`` and kept in memory as a float32
    tensor. Indexing returns an ``(inputs, labels)`` pair taken from columns
    ``0:9`` and ``9:25`` of the selected window.
    """

    def __init__(self, dataset_path):
        self.shots = []
        self.load_dataset(dataset_path)

    def load_dataset(self, dataset_path):
        """Replace ``self.shots`` with the contents of ``dataset_path`` as float32."""
        raw = np.load(dataset_path)
        self.shots = torch.from_numpy(raw).float()

        # The data is used as stored. Earlier experiments are disabled
        # (TODO: decide whether they are really necessary): expressing camera
        # coordinates relative to the character position, and min/max scaling
        # of the position columns. Rotations need no changes.

    def __len__(self):
        # Number of windows, i.e. the size of the leading axis.
        return self.shots.size(0)

    def __getitem__(self, idx):
        window = self.shots[idx]
        return window[:, 0:9], window[:, 9:25]

# Pick the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Load the training and validation windows from the .npy files under ./dataset.
train_dataset = AVDCDataset("./dataset/train.npy")
val_dataset = AVDCDataset("./dataset/val.npy")

In [None]:
# Smoke test: indexing must work. The value itself is not shown because it is
# not the last expression of the cell.
train_dataset[0]
# Number of windows in each split.
print(len(train_dataset))
print(len(val_dataset))

In [8]:
def delta_pos(rot, d):
    """Rotate the vector ``(0, 0, d)`` by Euler angles ``rot`` (radians).

    The rotation applied is ``Rz @ Ry @ Rx`` built from ``rot[..., 2]``,
    ``rot[..., 1]`` and ``rot[..., 0]`` respectively, i.e. the X rotation acts
    on the vector first and the Z rotation last.

    Args:
        rot: Tensor whose last axis holds the (x, y, z) angles; the leading
            axes must have the same shape as ``d``.
        d: Tensor of distances, one per rotation.

    Returns:
        Tensor of shape ``d.shape + (3,)`` with the rotated vectors.
    """
    sin_all = torch.sin(rot)
    cos_all = torch.cos(rot)
    sx, sy, sz = sin_all[..., 0], sin_all[..., 1], sin_all[..., 2]
    cx, cy, cz = cos_all[..., 0], cos_all[..., 1], cos_all[..., 2]

    # Constant matrix entries, shaped like one batch of scalars.
    zero = torch.zeros_like(d)
    one = torch.ones_like(d)

    def matrix(*rows):
        # Each row is a triple of batch-shaped tensors; result is (..., 3, 3).
        return torch.stack([torch.stack(list(row), dim=-1) for row in rows], dim=-2)

    rot_z = matrix((cz, -sz, zero),
                   (sz, cz, zero),
                   (zero, zero, one))
    rot_y = matrix((cy, zero, sy),
                   (zero, one, zero),
                   (-sy, zero, cy))
    rot_x = matrix((one, zero, zero),
                   (zero, cx, -sx),
                   (zero, sx, cx))

    # Combined rotation Rz @ Ry @ Rx, composed left to right.
    combined = torch.einsum('...ij,...jk->...ik', rot_z, rot_y)
    combined = torch.einsum('...ij,...jk->...ik', combined, rot_x)

    # Distance laid along the +Z axis, then rotated.
    along_z = torch.stack([zero, zero, d], dim=-1)
    return torch.einsum('...ij,...j->...i', combined, along_z)

# Regression check for delta_pos on a (3, 2) batch of rotations and distances.
batched_camera_rot = torch.Tensor(np.deg2rad(np.array([
    [[-22.69, 0.00, 0.00], [-22.69, 0.00, 0.00]],
    [[0.00, 0.00, -5.41], [-22.69, 0.00, 0.00]],
    [[-22.69, 0.00, 0.00], [0.00, 0.00, 0.00]]
])))
batched_camera_distance = torch.Tensor(np.array([
    [4.559982, 4.559982],
    [2.159997, 4.559982],
    [4.559982, 1.199999]
]))

print(batched_camera_rot.shape, batched_camera_distance.shape)

batched_result = delta_pos(batched_camera_rot, batched_camera_distance)

# A bare expression in the middle of a cell is never displayed, so the
# reference values are compared explicitly instead of living in a comment.
batched_expected = torch.tensor([
    [[0.0, 1.75899036, 4.20706415],
     [0.0, 1.75899036, 4.20706415]],

    [[0.0, 0.0, 2.159997],
     [0.0, 1.75899036, 4.20706415]],

    [[0.0, 1.75899036, 4.20706415],
     [0.0, 0.0, 1.199999]],
])
assert torch.allclose(batched_result, batched_expected, atol=1e-4), batched_result

# Height (Y coordinate) of the point on the character the camera is measured against.
character_chest_height = 1.115


def calculate_angle(camera_pos, camera_rot, character_pos, chest_height=None):
    """Angle between the camera's forward axis and the direction to the character's chest.

    The chest point is ``character_pos`` with its Y coordinate replaced by
    ``chest_height``. The camera forward axis is the +Z axis rotated by
    ``Rz @ Ry @ Rx`` built from ``camera_rot`` (radians, (x, y, z) on the
    last axis).

    Args:
        camera_pos: Camera positions, last axis of size 3.
        camera_rot: Camera Euler angles in radians, same leading shape.
        character_pos: Character positions, same shape as ``camera_pos``;
            it is cloned, not modified.
        chest_height: Optional override for the chest Y coordinate. Defaults
            to the module-level ``character_chest_height``.

    Returns:
        Tensor of angles in radians with the leading shape of the inputs.
        Because the cosine is clamped to [-0.9999, 0.9999] before ``acos``,
        the result never reaches exactly 0 or pi.
    """
    if chest_height is None:
        chest_height = character_chest_height

    chest_pos = character_pos.clone()
    chest_pos[..., 1] = chest_height

    dir_vec = chest_pos - camera_pos
    dir_norm = torch.norm(dir_vec, dim=-1, keepdim=True)
    dir_vec = dir_vec / (dir_norm + 1e-6)  # Add epsilon to avoid division by zero

    cos_rot = torch.cos(camera_rot)
    sin_rot = torch.sin(camera_rot)
    zeros = torch.zeros_like(cos_rot[..., 0])
    ones = torch.ones_like(cos_rot[..., 0])

    Rz = torch.stack([
        torch.stack([cos_rot[..., 2], -sin_rot[..., 2], zeros], dim=-1),
        torch.stack([sin_rot[..., 2], cos_rot[..., 2], zeros], dim=-1),
        torch.stack([zeros, zeros, ones], dim=-1)
    ], dim=-2)

    Ry = torch.stack([
        torch.stack([cos_rot[..., 1], zeros, sin_rot[..., 1]], dim=-1),
        torch.stack([zeros, ones, zeros], dim=-1),
        torch.stack([-sin_rot[..., 1], zeros, cos_rot[..., 1]], dim=-1)
    ], dim=-2)

    Rx = torch.stack([
        torch.stack([ones, zeros, zeros], dim=-1),
        torch.stack([zeros, cos_rot[..., 0], -sin_rot[..., 0]], dim=-1),
        torch.stack([zeros, sin_rot[..., 0], cos_rot[..., 0]], dim=-1)
    ], dim=-2)

    R = torch.einsum('...ij,...jk->...ik', Rz, Ry)
    R = torch.einsum('...ij,...jk->...ik', R, Rx)

    # R applied to (0, 0, 1) is the third column of R; reading it directly
    # keeps the computation in the dtype and on the device of the inputs.
    forward_vec = R[..., 2]
    forward_norm = torch.norm(forward_vec, dim=-1, keepdim=True)
    forward_vec = forward_vec / (forward_norm + 1e-6)  # Add epsilon to avoid division by zero

    cos_angle = torch.sum(forward_vec * dir_vec, dim=-1)
    cos_angle = torch.clamp(cos_angle, -0.9999, 0.9999)  # Clamp to avoid invalid values for acos
    angle = torch.acos(cos_angle)  # radians

    return angle
    

# Define the test inputs: two camera / character pairs
camera_pos_test = torch.tensor([[-3.11, 1.15, 2.01], [-1.77, 1.34, 3.77]], dtype=torch.float32)
camera_rot_test = torch.tensor(np.deg2rad([[2.30, 112.65, 0.00], [354.39, 37.95, 0.00]]), dtype=torch.float32)  # (x, y, z) angles given in degrees, converted to radians
character_pos_test = torch.tensor([[0.00, 1.11, 0.06], [0.00, 1.11, 0.06]], dtype=torch.float32)

# Run the test; as the last expression of the cell the result is displayed
calculate_angle(camera_pos_test, camera_rot_test, character_pos_test)

# Expected: 0.1675, 2.0369

torch.Size([3, 2, 3]) torch.Size([3, 2])


tensor([0.1675, 2.0369])

RuntimeError: stack expects each tensor to be equal size, but got [2] at entry 0 and [2, 3] at entry 2

In [10]:
 # "camera_pos": {
 #    "x": -0.193416685,
 #    "y": 1.41547632,
 #    "z": 0.180194452
 #  },
 #  "camera_pos_offset": {
 #    "x": -0.303891361,
 #    "y": -0.9445602,
 #    "z": -0.524982154
 #  },
 #  "camera_rot": {
 #    "x": 87.1031952,
 #    "y": 288.31192,
 #    "z": 383.871033
 #  },
 #  "camera_rot_offset": {
 #    "x": -359.080658,
 #    "y": -4.89764547,
 #    "z": 300.791779
 #  },
 #  "camera_distance": 5.56078148,
 #  "camera_distance_offset": -2.1778245,
 
# Hand-picked values: print the camera offset for the start pose and for the
# pose after applying the rotation and distance offsets.
camera_rot = torch.tensor(np.deg2rad([[10, 20, 45]]), dtype=torch.float32)
camera_rot_offset = torch.tensor(np.deg2rad([[20, 10, -10]]), dtype=torch.float32)
camera_distance = torch.tensor([2], dtype=torch.float32)
camera_distance_offset = torch.tensor([4], dtype=torch.float32)

print(delta_pos(camera_rot, camera_distance))
print(delta_pos(camera_rot + camera_rot_offset, camera_distance + camera_distance_offset))

# Same two prints for the sample shot quoted in the comment at the top of this cell.
# camera_pos and camera_pos_offset are defined but not used by the prints below.
camera_pos = torch.tensor([[-0.193416685, 1.41547632, 0.180194452]], dtype=torch.float32)
camera_pos_offset = torch.tensor([[-0.303891361, -0.9445602, -0.524982154]], dtype=torch.float32)
camera_rot = torch.tensor(np.deg2rad([[87.1031952, 288.31192, 383.871033]]), dtype=torch.float32)
camera_rot_offset = torch.tensor(np.deg2rad([[-359.080658, -4.89764547, 300.791779]]), dtype=torch.float32)
camera_distance = torch.tensor([5.56078148], dtype=torch.float32)
camera_distance_offset = torch.tensor([-2.1778245], dtype=torch.float32)

print(delta_pos(camera_rot, camera_distance))
print(delta_pos(camera_rot + camera_rot_offset, camera_distance + camera_distance_offset))

tensor([[0.7219, 0.2308, 1.8508]])
tensor([[ 3.8489, -0.9673,  4.5000]])
tensor([[ 2.0035, -5.1866,  0.0883]])
tensor([[-2.0481, -2.6924,  0.0271]])


## Actual Training

In [None]:
import math
import matplotlib.pyplot as plt
import json

class PositionalEncoding(torch.nn.Module):
    """Add fixed sinusoidal position information to a batch-first sequence, then apply dropout.

    A table ``pe`` of shape ``(1, max_len, d_model)`` is precomputed and stored
    as a buffer: even feature columns hold sines, odd columns hold cosines of
    ``position * div_term``.

    Args:
        d_model: Feature width of the sequence. Both even and odd widths are
            supported.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed; sequences passed to
            ``forward`` must not be longer than this.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For an odd d_model there is one cosine column fewer than sine
        # columns, so the angle table is trimmed to match.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:, :seq_len])`` for ``x`` of shape (batch, seq_len, d_model)."""
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

class AVDCEncoderOnlyModel(nn.Module):
    """Transformer-encoder model mapping per-shot input features to per-shot outputs.

    Pipeline: linear embedding to ``d_model`` -> sinusoidal positional
    encoding -> ``num_encoder_layers`` batch-first ``TransformerEncoderLayer``
    blocks -> linear projection to ``output_dim``.

    Args:
        input_dim: Number of features per input shot.
        output_dim: Number of values predicted per shot.
        d_model: Internal feature width of the transformer.
        dim_feedforward: Width of the feed-forward sub-layer.
        num_heads: Number of attention heads.
        num_encoder_layers: Number of stacked encoder layers.
        dropout: Dropout probability, used both in the positional encoding
            and in the encoder layers.
    """

    def __init__(self, input_dim, output_dim, d_model, dim_feedforward, num_heads, num_encoder_layers, dropout=0.1):
        super().__init__()

        self.input_embedding = nn.Linear(input_dim, d_model)
        # Forward the configured dropout instead of silently falling back to
        # PositionalEncoding's own default.
        self.input_positional_encoding = PositionalEncoding(d_model, dropout)
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=num_heads, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)
        self.out = nn.Linear(d_model, output_dim)

        self.init_weights()

    def init_weights(self):
        """Xavier-initialise the two linear layers and set their biases to 0.01."""
        nn.init.xavier_uniform_(self.input_embedding.weight)
        nn.init.xavier_uniform_(self.out.weight)
        self.input_embedding.bias.data.fill_(0.01)
        self.out.bias.data.fill_(0.01)

    def forward(self, x, labels=None, inference=False):
        """Predict outputs for ``x`` of shape (batch, seq_len, input_dim).

        ``labels`` and ``inference`` are accepted for call-site compatibility
        but are not used by this encoder-only model.

        Returns:
            Tensor of shape (batch, seq_len, output_dim).
        """
        x = self.input_embedding(x)
        x = self.input_positional_encoding(x)
        x = self.transformer_encoder(x)
        x = self.out(x)
        return x

def compute_loss(inputs, outputs, labels):
    """Composite loss comparing predicted and true camera parameters.

    ``outputs`` and ``labels`` share one layout on the last axis: 0:3 camera
    anchor position, 3:6 its offset, 6:9 rotation, 9:12 rotation offset,
    12 distance, 13 distance offset. ``inputs`` holds the character position
    in 1:4 and its offset in 4:7. Channels 14 and 15 are not penalised here
    (an earlier FOV term is disabled).

    The camera location is ``anchor + delta_pos(rotation, distance)`` and is
    evaluated at the start, middle (half of every offset) and end (full
    offsets) of each shot. The returned scalar is the sum of the means of:

    * squared difference of the position-offset norms;
    * squared difference of the camera-to-character distance at start,
      middle and end;
    * ``angle_weight`` x squared difference of ``calculate_angle`` at start,
      middle and end;
    * ``y_weight`` x squared difference of the anchor Y at start and end;
    * ``var_loss_weight`` x squared difference of the per-window variance of
      the start camera location and of the position offset.
    """
    var_loss_weight = 0.05
    angle_weight = 10
    y_weight = 5

    # Character position at the start, middle and end of every shot.
    char_start = inputs[:, :, 1:4]
    char_offset = inputs[:, :, 4:7]
    char_keys = (char_start, char_start + char_offset * 0.5, char_start + char_offset)

    def camera_keys(params):
        """(location, rotation) of the camera at start, middle and end."""
        anchor, anchor_off = params[:, :, 0:3], params[:, :, 3:6]
        rot, rot_off = params[:, :, 6:9], params[:, :, 9:12]
        dist, dist_off = params[:, :, 12], params[:, :, 13]
        poses = (
            (anchor, rot, dist),
            (anchor + anchor_off * 0.5, rot + rot_off * 0.5, dist + dist_off * 0.5),
            (anchor + anchor_off, rot + rot_off, dist + dist_off),
        )
        return [(a + delta_pos(r, s), r) for a, r, s in poses]

    def window_variance(t):
        """Population variance along the window axis, averaged over the last axis."""
        return torch.mean(torch.var(t, dim=1, unbiased=False), dim=-1)

    pred_keys = camera_keys(outputs)
    true_keys = camera_keys(labels)

    # Predicted position delta should have the same length as the true one.
    pos_delta_loss = (torch.norm(outputs[:, :, 3:6], dim=2) - torch.norm(labels[:, :, 3:6], dim=2)) ** 2

    # Distance-to-character and viewing-angle terms at the three key times.
    distance_losses = []
    angle_losses = []
    for (pred_cam, pred_rot), (true_cam, true_rot), target in zip(pred_keys, true_keys, char_keys):
        pred_dist = torch.norm(pred_cam - target, dim=2)
        true_dist = torch.norm(true_cam - target, dim=2)
        distance_losses.append((pred_dist - true_dist) ** 2)

        pred_angle = calculate_angle(pred_cam, pred_rot, target)
        true_angle = calculate_angle(true_cam, true_rot, target)
        angle_losses.append((pred_angle - true_angle) ** 2)

    start_distance_loss, middle_distance_loss, end_distance_loss = distance_losses
    start_angle_loss, middle_angle_loss, end_angle_loss = angle_losses

    # Anchor height at the start and at the end of the shot.
    start_y_loss = (outputs[:, :, 1] - labels[:, :, 1]) ** 2
    end_y_loss = ((outputs[:, :, 1] + outputs[:, :, 4]) - (labels[:, :, 1] + labels[:, :, 4])) ** 2

    # Keep the spread of start locations and of position deltas within a
    # window similar to the ground truth.
    start_variance_gap = window_variance(pred_keys[0][0]) - window_variance(true_keys[0][0])
    delta_variance_gap = window_variance(outputs[:, :, 3:6]) - window_variance(labels[:, :, 3:6])
    variance_loss = start_variance_gap ** 2 + delta_variance_gap ** 2

    return (pos_delta_loss.mean() +
            start_distance_loss.mean() +
            middle_distance_loss.mean() +
            end_distance_loss.mean() +
            angle_weight * (start_angle_loss.mean() + middle_angle_loss.mean() + end_angle_loss.mean()) +
            y_weight * (start_y_loss.mean() + end_y_loss.mean()) +
            var_loss_weight * variance_loss.mean())

def compute_mse_loss(inputs, outputs, labels):
    """Mean squared error over the first 12 feature channels of every sample.

    The slice is taken on the last (feature) axis so the whole batch
    contributes; slicing the leading axis would instead keep only the first
    12 samples with all of their channels.

    Args:
        inputs: Unused; kept so the signature matches ``compute_loss``.
        outputs: Predictions with features on the last axis.
        labels: Targets with the same shape as ``outputs``.

    Returns:
        Scalar tensor with the MSE of ``outputs[..., 0:12]`` against
        ``labels[..., 0:12]``.
    """
    return nn.MSELoss()(outputs[..., 0:12], labels[..., 0:12])

def visualize(inputs, outputs, labels):
    """Plot character and camera paths of one random window on the XZ plane.

    One sample is picked at random from the batch. For every shot in it the
    character's start-to-end segment is drawn in green, the predicted camera
    path in red and the ground-truth camera path in blue. Opacity grows with
    the shot index, from ``1 / num_paths`` up to 1, so every shot is visible.
    Camera paths are sampled at 100 interpolation steps between the start pose
    and the start pose plus its offsets, using ``delta_pos``.

    When the module-level flag ``write_test_data`` is truthy, the figure and
    the selected sample are also written to
    ``test_data/test_<millisecond timestamp>.png`` / ``.json``.

    Args:
        inputs: Batch of input windows; columns 1:4 are read as the character
            position and 4:7 as its offset.
        outputs: Batch of predictions; columns 0:14 are read (anchor position,
            its offset, rotation, rotation offset, distance, distance offset).
        labels: Batch of targets with the same layout as ``outputs``.

    Returns:
        The matplotlib figure that was drawn.
    """
    # Local imports so the function does not depend on which cells ran before.
    import os
    import time

    rand_idx = np.random.randint(0, outputs.shape[0])
    inputs = inputs[rand_idx]
    outputs = outputs[rand_idx]
    labels = labels[rand_idx]

    # Draw camera path on XZ plane
    fig, ax = plt.subplots()
    plt.ioff()

    character_pos_x_from = inputs[:, 1].cpu().detach().numpy()
    character_pos_z_from = inputs[:, 3].cpu().detach().numpy()
    character_pos_x_to = character_pos_x_from + inputs[:, 4].cpu().detach().numpy()
    character_pos_z_to = character_pos_z_from + inputs[:, 6].cpu().detach().numpy()

    def plot_camera_path(shot, color):
        """Sample one shot's camera curve at 100 steps and draw it."""
        pos, pos_offset = shot[0:3], shot[3:6]
        rot, rot_offset = shot[6:9], shot[9:12]
        dist, dist_offset = shot[12], shot[13]

        # Interpolation parameter lives on the same device as the shot data.
        t = torch.linspace(0, 1, steps=100, device=shot.device).unsqueeze(1)
        anchor_t = pos + t * pos_offset
        rot_t = rot + t * rot_offset
        dist_t = (dist + t * dist_offset).squeeze(-1)
        path = (anchor_t + delta_pos(rot_t, dist_t)).cpu().detach().numpy()
        ax.plot(path[:, 0], path[:, 2], color=color)

    num_paths = outputs.shape[0]
    for i in range(num_paths):
        # (i + 1) keeps the first shot from being drawn fully transparent.
        alpha = (i + 1) / num_paths
        ax.plot([character_pos_x_from[i], character_pos_x_to[i]],
                [character_pos_z_from[i], character_pos_z_to[i]],
                label="Character path", color=(0, 1, 0, alpha))
        plot_camera_path(outputs[i], (1, 0, 0, alpha))
        plot_camera_path(labels[i], (0, 0, 1, alpha))

    ax.set_xlabel('X')
    ax.set_ylabel('Z')
    ax.grid(True)

    plt.show()

    if write_test_data:
        # savefig / open do not create missing directories.
        os.makedirs("test_data", exist_ok=True)
        # timestamp as filename
        filename = f"test_data/test_{(time.time() * 1000):.0f}"
        fig.savefig(filename + ".png")

        json_obj = {
            "inputs": inputs.cpu().detach().numpy().tolist(),
            "outputs": outputs.cpu().detach().numpy().tolist(),
            "labels": labels.cpu().detach().numpy().tolist()
        }
        with open(f"{filename}.json", "w") as f:
            json.dump(json_obj, f)

    return fig

def evaluate(epoch, model, dataloader, optimizer):
    """Run one pass over ``dataloader`` and return the mean batch loss.

    With an optimizer the model is put in train mode and updated after every
    batch; with ``optimizer=None`` the model is put in eval mode and gradients
    are disabled.

    In evaluation mode, ``visualize`` is called for every batch when
    ``epoch % visualize_epochs == 0`` and either ``epoch > 0`` or the
    module-level ``will_train`` flag is falsy; the figure is logged to wandb
    when ``use_wandb`` is set. Per-step losses are printed and, while
    training, logged to wandb.

    Args:
        epoch: Zero-based epoch index (used for logging and visualisation).
        model: Module called as ``model(inputs, labels, inference_flag)``.
        dataloader: Sized iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer to train with, or None to only evaluate.

    Returns:
        Mean of the per-batch losses, or 0.0 when the dataloader yields no
        batches.
    """
    is_training = optimizer is not None
    if is_training:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    num_batches = len(dataloader)

    with torch.set_grad_enabled(is_training):
        for step, (inputs, labels) in enumerate(dataloader, start=1):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs, labels, not is_training)

            if not is_training and (epoch > 0 or not will_train) and epoch % visualize_epochs == 0:
                fig = visualize(inputs, outputs, labels)
                if use_wandb:
                    wandb.log({"epoch": epoch, "path": wandb.Image(fig)})
                plt.close(fig)

            loss = compute_loss(inputs, outputs, labels)
            batch_loss = loss.item()
            total_loss += batch_loss

            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print(f"Epoch {epoch + 1}/{epochs}, Step {step}/{num_batches}, Loss: {batch_loss}")
            if use_wandb and is_training:
                wandb.log({"step": epoch * num_batches + step, "loss": batch_loss})

    # An empty split (e.g. a tiny validation set) must not crash the run.
    if num_batches == 0:
        return 0.0

    return total_loss / num_batches

# Release cached GPU memory and make printed tensors easier to read
# (no scientific notation, up to 10000 elements before summarising).
torch.cuda.empty_cache()
torch.set_printoptions(sci_mode=False, threshold=10000)

import os
import glob

# Hyperparameters
d_model = 512
dim_feedforward = 1024
num_heads = 2
num_encoder_layers = 6
# Only recorded in the wandb config below; the encoder-only model does not take it.
num_decoder_layers = 1
dropout = 0.2
batch_size = 1024
epochs = 100000
learning_rate = 5e-5

# NOTE(review): num_workers=32 is hard-coded and the validation loader uses a
# literal 1024 rather than batch_size -- confirm both are intended.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=32, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=1024, shuffle=False, num_workers=32, pin_memory=True)

# Pull one training batch only to read the input / output feature widths.
test_input, test_label = next(iter(train_dataloader))
input_dim = test_input.shape[2]
output_dim = test_label.shape[2]

model = AVDCEncoderOnlyModel(input_dim=input_dim,
                              output_dim=output_dim,
                              d_model=d_model,
                              dim_feedforward=dim_feedforward,
                              num_heads=num_heads,
                              num_encoder_layers=num_encoder_layers,
                              dropout=dropout).to(device)

print("Number of parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
# NOTE(review): patience equals the number of epochs, so the learning-rate
# reduction presumably never triggers -- confirm this is deliberate.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=100000)

# Run-mode switches read by evaluate(), visualize() and the training loop.
will_train = False
will_validate = True
will_load_checkpoint = True
# When set, visualize() also writes each figure and sample under test_data/.
write_test_data = True
# Visualise on every epoch whose index is a multiple of this value.
visualize_epochs = 25

# Writing test data forces a visualisation on every epoch.
if write_test_data:
    visualize_epochs = 1

# wandb logging is only used for training runs.
use_wandb = True
if not will_train:
    use_wandb = False
if use_wandb:
    wandb.init(
        # set the wandb project where this run will be logged
        project="avdc",
        # track hyperparameters and run metadata
        config={
            "input_dim": input_dim,
            "d_model": d_model,
            "dim_feedforward": dim_feedforward,
            "num_heads": num_heads,
            "num_encoder_layers": num_encoder_layers,
            "num_decoder_layers": num_decoder_layers,
            "dropout": dropout,
            "batch_size": batch_size,
            "learning_rate": learning_rate
        }
    )

# Load checkpoint
starting_epoch = 0
if will_load_checkpoint:
    # Find the latest checkpoint in ./checkpoints by modified time
    list_of_files = glob.glob('./checkpoints/*')
    if len(list_of_files) == 0:
        print("No checkpoints found.")
    else:
        # getmtime is the modification time on every platform; getctime is
        # creation time on Windows and metadata-change time on Unix.
        latest_file = max(list_of_files, key=os.path.getmtime)
        print(f"Loading checkpoint: {latest_file}")
        # File names look like checkpoint_<epoch>.pt; parse the epoch from the
        # file name only so underscores or dots in the directory cannot interfere.
        checkpoint_stem = os.path.splitext(os.path.basename(latest_file))[0]
        starting_epoch = int(checkpoint_stem.split('_')[-1])
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(latest_file, map_location=device))

# Training loop: resume at the epoch of the loaded checkpoint (0 when none was loaded).
for epoch in range(starting_epoch, epochs):
    avg_train_loss = 0
    if will_train:
        avg_train_loss = evaluate(epoch, model, train_dataloader, optimizer)

    avg_val_loss = 0
    if will_validate:
        avg_val_loss = evaluate(epoch, model, val_dataloader, None)

    scheduler.step(avg_val_loss)

    print(f'Epoch [{epoch + 1}/{epochs}], LR: {scheduler.get_last_lr()}, '
          f'Train Loss: {avg_train_loss:.8f}, Val Loss: {avg_val_loss:.8f}')

    # Save checkpoint every 10 epochs, but only when the weights can have
    # changed; a validation-only run must not write (possibly untrained)
    # weights that a later run would pick up as the latest checkpoint.
    if will_train and epoch % 10 == 0:
        os.makedirs('checkpoints', exist_ok=True)
        torch.save(model.state_dict(), f'checkpoints/checkpoint_{epoch}.pt')

    if use_wandb:
        wandb.log({"epoch": epoch,
                   "train_loss": avg_train_loss, "val_loss": avg_val_loss})

print("Training complete.")