In [5]:
from transformers import VivitConfig

from model.vivit import VivitPose
from preprocess_data import get_video_datasets
from torch.utils.data import DataLoader

import torch
import time
device = 'cuda'
configuration = VivitConfig()
vivit_num_frames = 3
configuration.num_labels = 69
configuration.num_frames=vivit_num_frames
model = VivitPose(configuration).to(device)
train_data, test_data = get_video_datasets(slice_len=vivit_num_frames)
train_loader = DataLoader(train_data, batch_size=12, shuffle=True)


In [6]:
#clear gc, cuda cache
import gc
gc.collect()
torch.cuda.empty_cache()

In [7]:
history = []
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
model.train()
epochs = 1
start_time = time.time()
last_time = time.time()
len_iter = len(train_loader)
checkpoint_dir = 'checkpoints'
for epoch in range(epochs):
    for i, v in enumerate(train_loader):
        # cur = time.time()
        frames, labels, _ = v
        frames = frames.to(device)
        labels = labels.to(device)
        #print('Time for data transfer: %d seconds' % (time.time() - cur))
        optimizer.zero_grad()
        # cur = time.time()
        outputs = model(frames)
        #print('Time for forward pass: %d seconds' % (time.time() - cur))
        # cur = time.time()
        loss = criterion(outputs, labels)
        loss.backward()
        # print('Time for backward pass: %d seconds' % (time.time() - cur))
        # cur = time.time()
        optimizer.step()
        # print('Time for optimizer step: %d seconds' % (time.time() - cur))
        if i % 10 == 0:
            print('Epoch: %d, Iteration: %d/%d, Loss: %f' % (epoch, i,len_iter, loss.item()))
            history.append(loss.item())
        if i % 100 == 0:
            time_elapsed = time.time() - last_time
            print('Time elapsed: %d seconds' % time_elapsed)
            last_time = time.time()
            #estimate time remaining
            time_remaining = time_elapsed * (len_iter - i)/100
            print('Estimated time remaining for epoch: %d seconds' % time_remaining)
        if i % 2000 == 0:
            torch.save(model.state_dict(), '%s/vivit_%d.pth' % (checkpoint_dir, i))

Epoch: 0, Iteration: 0/102381, Loss: 0.504296
Time elapsed: 0 seconds
Estimated time remaining for epoch: 550 seconds
Epoch: 0, Iteration: 10/102381, Loss: 0.216484
Epoch: 0, Iteration: 20/102381, Loss: 0.205415
Epoch: 0, Iteration: 30/102381, Loss: 0.221483
Epoch: 0, Iteration: 40/102381, Loss: 0.273071
Epoch: 0, Iteration: 50/102381, Loss: 0.203067
Epoch: 0, Iteration: 60/102381, Loss: 0.220044
Epoch: 0, Iteration: 70/102381, Loss: 0.180068
Epoch: 0, Iteration: 80/102381, Loss: 0.176031
Epoch: 0, Iteration: 90/102381, Loss: 0.171300
Epoch: 0, Iteration: 100/102381, Loss: 0.227132
Time elapsed: 19 seconds
Estimated time remaining for epoch: 19515 seconds
Epoch: 0, Iteration: 110/102381, Loss: 0.240493
Epoch: 0, Iteration: 120/102381, Loss: 0.166484
Epoch: 0, Iteration: 130/102381, Loss: 0.201197
Epoch: 0, Iteration: 140/102381, Loss: 0.213970
Epoch: 0, Iteration: 150/102381, Loss: 0.244240
Epoch: 0, Iteration: 160/102381, Loss: 0.190421
Epoch: 0, Iteration: 170/102381, Loss: 0.268864
