In [1]:
from transformers import VivitConfig

from model.vivit import VivitPose
from preprocess_data import get_video_datasets, getSMPLXParams
from torch.utils.data import DataLoader
from euler_to_rot import euler_to_rotation_matrix_zyz_tensor
import torch
import time
device = 'cuda'
configuration = VivitConfig()
vivit_num_frames = 3
configuration.num_labels = 69
configuration.num_frames=vivit_num_frames
model = VivitPose(configuration).to(device)
train_data, test_data = get_video_datasets(slice_len=vivit_num_frames)
train_loader = DataLoader(train_data, batch_size=10, shuffle=True)


In [2]:
#clear gc, cuda cache
import gc
gc.collect()
torch.cuda.empty_cache()

In [3]:
history = []
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
model.train()
epochs = 3
start_time = time.time()
last_time = time.time()
len_iter = len(train_loader)
checkpoint_dir = 'checkpoints'
for epoch in range(epochs):
    for i, v in enumerate(train_loader):
        # cur = time.time()
        frames, labels, _ = v
        frames = frames.to(device)
        labels = labels.to(device)
        #print('Time for data transfer: %d seconds' % (time.time() - cur))
        optimizer.zero_grad()
        # cur = time.time()
        outputs = model(frames)
        batch_size = outputs.shape[0]
        global_orient, body_pose, transl = getSMPLXParams(outputs)
        global_orient = euler_to_rotation_matrix_zyz_tensor(global_orient)
        body_pose = euler_to_rotation_matrix_zyz_tensor(body_pose)
        global_orient = global_orient.reshape(batch_size, -1)
        body_pose = body_pose.reshape(batch_size, -1)
        outputs_rot = torch.cat((global_orient, body_pose, transl), dim=1)
        #print('Time for forward pass: %d seconds' % (time.time() - cur))
        # cur = time.time()
        loss = criterion(outputs_rot, labels)
        loss.backward()
        # print('Time for backward pass: %d seconds' % (time.time() - cur))
        # cur = time.time()
        optimizer.step()
        # print('Time for optimizer step: %d seconds' % (time.time() - cur))
        if i % 10 == 0:
            print('Epoch: %d, Iteration: %d/%d, Loss: %f' % (epoch, i,len_iter, loss.item()))
            history.append(loss.item())
        if i % 100 == 0:
            time_elapsed = time.time() - last_time
            print('Time elapsed: %d seconds' % time_elapsed)
            last_time = time.time()
            #estimate time remaining
            time_remaining = time_elapsed * (len_iter - i)/100
            print('Estimated time remaining for epoch: %d seconds' % time_remaining)
        if i % 2000 == 0: # vivit_epoch_iteration_rot_MSE.pth
            torch.save(model.state_dict(), '%s/vivit_epoch_%d_iteration_%d_rot_MSE.pth' % (checkpoint_dir, epoch, i))

Epoch: 0, Iteration: 0/122857, Loss: 0.225386
Time elapsed: 0 seconds
Estimated time remaining for epoch: 621 seconds
Epoch: 0, Iteration: 10/122857, Loss: 0.063345
Epoch: 0, Iteration: 20/122857, Loss: 0.047326
Epoch: 0, Iteration: 30/122857, Loss: 0.055381
Epoch: 0, Iteration: 40/122857, Loss: 0.048625
Epoch: 0, Iteration: 50/122857, Loss: 0.031556
Epoch: 0, Iteration: 60/122857, Loss: 0.063781
Epoch: 0, Iteration: 70/122857, Loss: 0.044184
Epoch: 0, Iteration: 80/122857, Loss: 0.044803
Epoch: 0, Iteration: 90/122857, Loss: 0.030572
Epoch: 0, Iteration: 100/122857, Loss: 0.047886
Time elapsed: 17 seconds
Estimated time remaining for epoch: 21217 seconds
Epoch: 0, Iteration: 110/122857, Loss: 0.044007
Epoch: 0, Iteration: 120/122857, Loss: 0.050441
Epoch: 0, Iteration: 130/122857, Loss: 0.049348
Epoch: 0, Iteration: 140/122857, Loss: 0.038448
Epoch: 0, Iteration: 150/122857, Loss: 0.055801
Epoch: 0, Iteration: 160/122857, Loss: 0.040762
Epoch: 0, Iteration: 170/122857, Loss: 0.034749
