In [1]:
import os
import sys
import time
import numpy as np
import torch
import random
import wandb
from datetime import datetime
from torch import optim
from torch.optim import lr_scheduler
sys.path.append(os.getcwd())
from torch.utils.data import DataLoader
from data.dataloader_dance import DanceDataset, seq_collate
from model.GroupNet_dance import GroupNet
import math

In [2]:
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

class Args:
    def __init__(self):
        # Training parameters
        self.seed = 1
        self.dataset = 'dance'
        self.wandb_enabled = True  # Set to True to enable WandB logging
        self.wandb_project = 'dance-generation'  # WandB project name
        self.wandb_entity = 'vikhyat3'  # Your WandB username
        self.batch_size = 64 # batch size
        self.past_length = 8 # number of frames to condition on
        self.future_length = 10 # number of frames to predict into the future 
        self.traj_scale = 1 # scale factor applied to trajectory coordinates
        self.learn_prior = False # whether to learn prior distribution vs using fixed prior
        self.lr = 3e-4 # learning rate
        self.weight_decay = 0.0001
        self.sample_k = 20 # number of samples to generate during testing for diverse predictions
        self.num_epochs = 5 # MAKE 100
        self.decay_step = 10 # number of epochs before applying learning rate decay
        self.decay_gamma = 0.9 # learning rate decay factor
        self.print_every_it = 18 # print training stats every N iterations
        # self.test_every_it = 27 # test model every N iterations

        # Model parameters
        self.ztype = 'gaussian' # type of latent distribution: 'gaussian' or 'vmf'
        self.zdim = 32 # dimension of latent variable
        self.hidden_dim = 64 # dimension of hidden layers
        self.hyper_scales = [15,53] # scales for hyperprior ([5,11] for nba)
        self.num_decompose = 2 # number of decomposed distributions
        self.min_clip = 2.0

        # Save/load parameters
        self.model_save_dir = 'saved_models/dance'
        self.model_save_epoch = 5 # save model every N epochs
        self.epoch_continue = 0 # epoch to continue training from, 0 if training from scratch
    
    def __str__(self):
        return str(self.__dict__)

args = Args()

In [3]:
datetime_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

In [4]:
datetime_str

'2025-04-22_20-44-27'

In [5]:
""" setup """
np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.set_default_dtype(torch.float32)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device:',device)
print(args)

# Initialize WandB
run = None
if args.wandb_enabled:
    try:
        run = wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            config=args.__dict__,
        )
    except Exception as e:
        print(f"Warning: Could not initialize WandB: {e}")
        print("Training will continue without WandB logging")
        args.wandb_enabled = False

if not os.path.isdir(args.model_save_dir):
    os.makedirs(args.model_save_dir)

device: cuda
{'seed': 1, 'dataset': 'dance', 'wandb_enabled': True, 'wandb_project': 'dance-generation', 'wandb_entity': 'vikhyat3', 'batch_size': 64, 'past_length': 8, 'future_length': 10, 'traj_scale': 1, 'learn_prior': False, 'lr': 0.0003, 'weight_decay': 0.0001, 'sample_k': 20, 'num_epochs': 5, 'decay_step': 10, 'decay_gamma': 0.9, 'print_every_it': 18, 'ztype': 'gaussian', 'zdim': 32, 'hidden_dim': 64, 'hyper_scales': [15, 53], 'num_decompose': 2, 'min_clip': 2.0, 'model_save_dir': 'saved_models/dance', 'model_save_epoch': 5, 'epoch_continue': 0}


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mvikhyat3[0m. Use [1m`wandb login --relogin`[0m to force relogin


Problem at: <ipython-input-5-b1f881d4516b> 18 <module>
Training will continue without WandB logging


In [6]:
def train(train_loader, epoch):
    model.train()
    total_iter_num = len(train_loader)
    iter_num = 0
    epoch_start_time = time.time()
    epoch_loss = 0

    for data in train_loader:
        total_loss, loss_pred, loss_recover, loss_kl, loss_diverse = model(data)
        epoch_loss += total_loss.item()

        """ optimize """
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if iter_num % args.print_every_it == 0:
            print('Epochs: {:02d}/{:02d}| It: {:04d}/{:04d} | Train loss: {:6.3f} (pred: {:6.3f}| recover: {:6.3f}| kl: {:6.3f}| diverse: {:6.3f})'
                    .format(epoch,args.num_epochs,iter_num,total_iter_num,total_loss.item(),loss_pred,loss_recover,loss_kl,loss_diverse))
            # Log metrics to WandB if enabled
            if args.wandb_enabled:
                wandb.log({
                    'train_loss': total_loss.item(),
                    'pred_loss': loss_pred,
                    'recover_loss': loss_recover,
                    'kl_loss': loss_kl,
                    'diverse_loss': loss_diverse,
                    'epoch': epoch,
                    'iteration': iter_num
                })

        iter_num += 1

    epoch_time = time.time() - epoch_start_time
    avg_loss = epoch_loss / total_iter_num
    print(f'Epoch {epoch} completed in {epoch_time:.2f} seconds. Average loss: {avg_loss:.3f}')
    
    scheduler.step()
    model.step_annealer()

In [7]:
foo = np.load('datasets/dance/train.npy') 
foo.shape

(1702, 18, 53, 6)

In [8]:
# for ct, data_test in enumerate(test_loader):
#     print(model(data_test))

In [10]:
""" model & optimizer """
model = GroupNet(args,device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
scheduler = lr_scheduler.StepLR(optimizer, step_size=args.decay_step, gamma=args.decay_gamma)

train_size = 10

""" dataloader """
train_set = DanceDataset(
    obs_len=args.past_length,
    pred_len=args.future_length,
    n_samples=train_size,
    training=True)

train_loader = DataLoader(
    train_set,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=seq_collate,
    pin_memory=True)

test_set = DanceDataset(
    obs_len=args.past_length,
    pred_len=args.future_length,
    n_samples=5,
    training=False)

test_loader = DataLoader(
    test_set,
    batch_size=len(test_set),
    shuffle=True,
    num_workers=4,
    collate_fn=seq_collate,
    pin_memory=True)

Loaded Train data with shape: (10, 18, 53, 3)
10
Loaded Test data with shape: (5, 18, 53, 3)
5


In [11]:
train_sample = train_set[0]
train_sample[0].shape, train_sample[1].shape

(torch.Size([53, 8, 3]), torch.Size([53, 10, 3]))

In [12]:
""" Loading if needed """
if args.epoch_continue > 0:
    checkpoint_path = os.path.join(args.model_save_dir,str(args.epoch_continue)+'.p')
    print('load model from: {checkpoint_path}')
    model_load = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(model_load['model_dict'])
    if 'optimizer' in model_load:
        optimizer.load_state_dict(model_load['optimizer'])
    if 'scheduler' in model_load:
        scheduler.load_state_dict(model_load['scheduler'])

In [13]:
# show all parameter shapes for GroupNet, and count total number of parameters
total_params = 0
for name, param in model.named_parameters():
	print(name, param.data.shape)
	total_params += np.prod(param.data.shape)
print('Total number of parameters:', total_params)


past_encoder.input_fc.weight torch.Size([64, 6])
past_encoder.input_fc.bias torch.Size([64])
past_encoder.input_fc2.weight torch.Size([64, 512])
past_encoder.input_fc2.bias torch.Size([64])
past_encoder.interaction.nmp_mlp_start.MLP_distribution.layers.0.weight torch.Size([128, 64])
past_encoder.interaction.nmp_mlp_start.MLP_distribution.layers.0.bias torch.Size([128])
past_encoder.interaction.nmp_mlp_start.MLP_distribution.layers.1.weight torch.Size([6, 128])
past_encoder.interaction.nmp_mlp_start.MLP_distribution.layers.1.bias torch.Size([6])
past_encoder.interaction.nmp_mlp_start.MLP_factor.layers.0.weight torch.Size([128, 64])
past_encoder.interaction.nmp_mlp_start.MLP_factor.layers.0.bias torch.Size([128])
past_encoder.interaction.nmp_mlp_start.MLP_factor.layers.1.weight torch.Size([1, 128])
past_encoder.interaction.nmp_mlp_start.MLP_factor.layers.1.bias torch.Size([1])
past_encoder.interaction.nmp_mlp_start.init_MLP.layers.0.weight torch.Size([128, 64])
past_encoder.interaction.n

In [14]:
""" start training """
model.set_device(device)
for epoch in range(args.epoch_continue, args.num_epochs):
    train(train_loader,epoch)
    """ save model """
    if  (epoch + 1) % args.model_save_epoch == 0:
        model_saved = {'model_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'epoch': epoch + 1,'model_cfg': args}
        saved_path = os.path.join(args.model_save_dir,str(epoch+1)+'.p')
        torch.save(model_saved, saved_path)

  soft_max_1d = F.softmax(trans_input)


Epochs: 00/05| It: 0000/0001 | Train loss: 19.629 (pred:  1.573| recover: 15.814| kl:  2.000| diverse:  0.243)
Epoch 0 completed in 0.57 seconds. Average loss: 19.629
Epochs: 01/05| It: 0000/0001 | Train loss: 16.309 (pred:  0.991| recover: 13.167| kl:  2.000| diverse:  0.151)
Epoch 1 completed in 0.33 seconds. Average loss: 16.309
Epochs: 02/05| It: 0000/0001 | Train loss: 13.782 (pred:  0.743| recover: 10.927| kl:  2.000| diverse:  0.112)
Epoch 2 completed in 0.36 seconds. Average loss: 13.782
Epochs: 03/05| It: 0000/0001 | Train loss: 11.807 (pred:  0.701| recover:  9.000| kl:  2.000| diverse:  0.106)
Epoch 3 completed in 0.34 seconds. Average loss: 11.807
Epochs: 04/05| It: 0000/0001 | Train loss: 10.083 (pred:  0.770| recover:  7.199| kl:  2.000| diverse:  0.113)
Epoch 4 completed in 0.34 seconds. Average loss: 10.083


In [20]:
all_num = 0
l2error_overall = 0
l2error_dest = 0
l2error_avg_04s = 0
l2error_dest_04s = 0
l2error_avg_08s = 0
l2error_dest_08s = 0
l2error_avg_12s = 0
l2error_dest_12s = 0
l2error_avg_16s = 0
l2error_dest_16s = 0
l2error_avg_20s = 0
l2error_dest_20s = 0
l2error_avg_24s = 0
l2error_dest_24s = 0
l2error_avg_28s = 0
l2error_dest_28s = 0
l2error_avg_32s = 0
l2error_dest_32s = 0
l2error_avg_36s = 0
l2error_dest_36s = 0

for data in test_loader:
	future_traj = np.array(data['future_traj']) * args.traj_scale # B,N,T,2
	with torch.no_grad():
		prediction = model.inference(data)
	prediction = prediction * args.traj_scale
	prediction = np.array(prediction.cpu()) #(BN,20,T,2)
	batch = future_traj.shape[0]
	actor_num = future_traj.shape[1]

	y = np.reshape(future_traj,(batch*actor_num,args.future_length, 3))
	y = y[None].repeat(20,axis=0)
	l2error_avg_04s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,:1,:] - prediction[:,:,:1,:], axis = 3),axis=2),axis=0))*batch
	l2error_dest_04s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,0:1,:] - prediction[:,:,0:1,:], axis = 3),axis=2),axis=0))*batch
	l2error_avg_08s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,:2,:] - prediction[:,:,:2,:], axis = 3),axis=2),axis=0))*batch
	l2error_dest_08s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,1:2,:] - prediction[:,:,1:2,:], axis = 3),axis=2),axis=0))*batch
	l2error_avg_12s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,:3,:] - prediction[:,:,:3,:], axis = 3),axis=2),axis=0))*batch
	l2error_dest_12s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,2:3,:] - prediction[:,:,2:3,:], axis = 3),axis=2),axis=0))*batch
	l2error_avg_16s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,:4,:] - prediction[:,:,:4,:], axis = 3),axis=2),axis=0))*batch
	l2error_dest_16s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,3:4,:] - prediction[:,:,3:4,:], axis = 3),axis=2),axis=0))*batch
	l2error_avg_20s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,:5,:] - prediction[:,:,:5,:], axis = 3),axis=2),axis=0))*batch
	l2error_dest_20s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,4:5,:] - prediction[:,:,4:5,:], axis = 3),axis=2),axis=0))*batch
	l2error_avg_24s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,:6,:] - prediction[:,:,:6,:], axis = 3),axis=2),axis=0))*batch
	l2error_dest_24s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,5:6,:] - prediction[:,:,5:6,:], axis = 3),axis=2),axis=0))*batch
	l2error_avg_28s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,:7,:] - prediction[:,:,:7,:], axis = 3),axis=2),axis=0))*batch
	l2error_dest_28s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,6:7,:] - prediction[:,:,6:7,:], axis = 3),axis=2),axis=0))*batch
	l2error_avg_32s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,:8,:] - prediction[:,:,:8,:], axis = 3),axis=2),axis=0))*batch
	l2error_dest_32s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,7:8,:] - prediction[:,:,7:8,:], axis = 3),axis=2),axis=0))*batch
	l2error_avg_36s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,:9,:] - prediction[:,:,:9,:], axis = 3),axis=2),axis=0))*batch
	l2error_dest_36s += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,8:9,:] - prediction[:,:,8:9,:], axis = 3),axis=2),axis=0))*batch
	l2error_overall += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,:10,:] - prediction[:,:,:10,:], axis = 3),axis=2),axis=0))*batch
	l2error_dest += np.mean(np.min(np.mean(np.linalg.norm(y[:,:,9:10,:] - prediction[:,:,9:10,:], axis = 3),axis=2),axis=0))*batch
	all_num += batch

print(all_num)
l2error_overall /= all_num
l2error_dest /= all_num

l2error_avg_04s /= all_num
l2error_dest_04s /= all_num
l2error_avg_08s /= all_num
l2error_dest_08s /= all_num
l2error_avg_12s /= all_num
l2error_dest_12s /= all_num
l2error_avg_16s /= all_num
l2error_dest_16s /= all_num
l2error_avg_20s /= all_num
l2error_dest_20s /= all_num
l2error_avg_24s /= all_num
l2error_dest_24s /= all_num
l2error_avg_28s /= all_num
l2error_dest_28s /= all_num
l2error_avg_32s /= all_num
l2error_dest_32s /= all_num
l2error_avg_36s /= all_num
l2error_dest_36s /= all_num
print('##################')
print('ADE 1.0s:',(l2error_avg_08s+l2error_avg_12s)/2)
print('ADE 2.0s:',l2error_avg_20s)
print('ADE 3.0s:',(l2error_avg_32s+l2error_avg_28s)/2)
print('ADE 4.0s:',l2error_overall)

print('FDE 1.0s:',(l2error_dest_08s+l2error_dest_12s)/2)
print('FDE 2.0s:',l2error_dest_20s)
print('FDE 3.0s:',(l2error_dest_28s+l2error_dest_32s)/2)
print('FDE 4.0s:',l2error_dest)
print('##################')

5
##################
ADE 1.0s: 0.032786876894533634
ADE 2.0s: 0.05441300570964813
ADE 3.0s: 0.06626186519861221
ADE 4.0s: 0.07543409615755081
FDE 1.0s: 0.025778746232390404
FDE 2.0s: 0.05313604697585106
FDE 3.0s: 0.06723129749298096
FDE 4.0s: 0.0723644569516182
##################


In [17]:
265*10*2*3/2

7950.0