In [None]:
# IMPORTS
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import torch.optim as optim
import matplotlib
import matplotlib.pyplot as plt
from timeit import default_timer as timer

In [None]:
# Do the imports from other files in the project
from Model.DepthLSTM import DepthLSTM
from Train.hyperparameters import *
from Train.train_epoch_J import train_epoch_J
from Train.test_epoch_J import test_epoch_J

#### Train with preprocessed data

In [None]:
path_videos = './Data/Preprocessed_J/'
# Sort the listing: os.listdir returns entries in arbitrary order, while the
# hard-coded test indices further down select videos by their position in this list.
# Hidden entries (e.g. .DS_Store, .ipynb_checkpoints) are skipped so that np.load
# is only handed data files.
list_videos = sorted(name for name in os.listdir(path_videos) if not name.startswith('.'))
videos_data = []

In [None]:
# Load every preprocessed video array. Rebinding videos_data (instead of appending
# to the list created in the previous cell) keeps this cell idempotent: re-running it
# no longer duplicates every video.
videos_data = [np.load(os.path.join(path_videos, video)) for video in list_videos]

In [None]:
# Inspect the dimensions of one loaded video array
np.shape(videos_data[15])

### Split the videos into train and test sets

In [None]:
testset_idx = [0, 4, 9, 12, 21]
# Guard against the directory holding fewer videos than the hard-coded indices expect.
assert all(0 <= i < len(videos_data) for i in testset_idx), "test index out of range"
# Every remaining video goes to training. The range is derived from the number of
# loaded videos (previously a hard-coded 24), and ascending order is explicit rather
# than depending on set iteration order.
trainset_idx = [i for i in range(len(videos_data)) if i not in testset_idx]

In [None]:
# Pick out the video arrays for each split, in the order the index lists give.
train_videos = [videos_data[train_pos] for train_pos in trainset_idx]
test_videos = [videos_data[test_pos] for test_pos in testset_idx]

### Build the batched (stateful) train and test datasets

In [None]:
def joinVideos(videos_data):
    """Concatenate per-video frame arrays along axis 0 and mark video boundaries.

    Args:
        videos_data: sequence of arrays, one per video, with frames on axis 0.

    Returns:
        (videos_data, span_videos): the concatenated frames, and a 1-D float
        array of the same length that is 0 at the first frame of each video
        and 1 elsewhere. (Presumably consumed downstream to detect where a new
        video starts, e.g. to reset recurrent state — TODO confirm.)

    Raises:
        ValueError: if ``videos_data`` is empty (raised by ``np.concatenate``).
    """
    span_videos = []
    for video in videos_data:
        span_vid = np.ones(len(video))
        # Slice assignment marks the first frame and, unlike span_vid[0] = 0,
        # does not raise IndexError for a zero-length video.
        span_vid[:1] = 0
        span_videos.append(span_vid)
    span_videos = np.concatenate(span_videos)
    videos_data = np.concatenate(videos_data)
    return videos_data, span_videos


def reshapeBatches(vid_data, BATCH_SIZE, SEQ_LEN):
    """Cut a long frame stream into stateful mini-batches.

    The first axis of ``vid_data`` is truncated to a whole multiple of
    BATCH_SIZE * SEQ_LEN frames, split into BATCH_SIZE contiguous streams,
    and each stream is cut into consecutive windows of SEQ_LEN frames.

    Returns:
        Array of shape (num_batches, SEQ_LEN, BATCH_SIZE, features), where
        ``features`` is the product of the remaining axes (1 for 1-D input).
        Batch ``k`` holds window ``k`` of every stream, so batch ``k + 1``
        continues exactly where batch ``k`` left off in each stream.
    """
    frames_per_step = BATCH_SIZE * SEQ_LEN
    n_batches = vid_data.shape[0] // frames_per_step
    kept = vid_data[:n_batches * frames_per_step]
    streams = kept.reshape(BATCH_SIZE, n_batches, SEQ_LEN, -1)
    # (stream, window, time, feature) -> (window, time, stream, feature)
    return streams.transpose(1, 2, 0, 3)


def splitInputOutput(vid_data, num_values=201, stride=3, ground_offset=2):
    """Split the value axis (axis 3) of batched data into model inputs and targets.

    Every ``stride``-th value starting at ``ground_offset`` (indices 2, 5, ..., 200
    with the defaults) is a ground-truth target; the other values among the first
    ``num_values`` are inputs, kept in ascending index order. With the defaults this
    presumably matches 67 keypoints stored as triplets whose third component is
    predicted — TODO confirm against the preprocessing code.

    Args:
        vid_data: array with at least 4 dims, e.g. the output of reshapeBatches;
            the split is taken along axis 3.
        num_values: number of leading values on axis 3 to consider (default 201).
        stride: spacing between target values (default 3).
        ground_offset: index of the first target value (default 2).

    Returns:
        (input_vid, output_vid): inputs (134 values with the defaults) and
        targets (67 values with the defaults).
    """
    ground_idx = list(range(ground_offset, num_values, stride))
    ground_set = set(ground_idx)
    # Ordered comprehension: feature order must not depend on set iteration order.
    input_idx = [i for i in range(num_values) if i not in ground_set]

    input_vid = vid_data[:, :, :, input_idx]
    output_vid = vid_data[:, :, :, ground_idx]

    return input_vid, output_vid

In [None]:
def create_stateful_dataset(videos, shuffle = True):
    """Build batched input, target and span-marker arrays from a list of videos.

    Args:
        videos: list of per-video arrays (frames on axis 0).
        shuffle: if True, the video order is randomly permuted before joining.
            The caller's list itself is left unmodified.

    Returns:
        (dataset, grounddataset, span_videos): float32 inputs, float32 targets
        and the video-start markers, each batched by reshapeBatches with the
        global BATCH_SIZE and SEQ_LEN.
    """
    if shuffle:
        # Shuffle a copy: random.shuffle works in place, and callers pass the same
        # module-level list (e.g. train_videos) every epoch.
        videos = list(videos)
        random.shuffle(videos)

    videos_set, span_videos = joinVideos(videos)

    dataset = reshapeBatches(videos_set, BATCH_SIZE, SEQ_LEN)
    span_videos = reshapeBatches(span_videos, BATCH_SIZE, SEQ_LEN)

    dataset, grounddataset = splitInputOutput(dataset)

    dataset = dataset.astype(np.float32)
    grounddataset = grounddataset.astype(np.float32)

    return dataset, grounddataset, span_videos

In [None]:
# Build both splits once up front: the training split with shuffled video order,
# the test split in its fixed order.
videos_trainset, groundtrainset, span_videos_train = create_stateful_dataset(
    train_videos, shuffle=True
)
videos_testset, groundtestset, span_videos_test = create_stateful_dataset(
    test_videos, shuffle=False
)

In [None]:
# Layout: [NUM_BATCHES, SEQ_LEN, BATCH_SIZE, NUM_INPUT_VALUES]
np.shape(videos_trainset)

### Train the model

In [None]:
# Build the evaluation set once from the held-out videos (unshuffled, so batches keep
# a fixed order). It previously used train_videos, which made the "test" loss a
# training loss. The training set is rebuilt, reshuffled, at every epoch.
testset, groundtestset, span_videos_test = create_stateful_dataset(test_videos, shuffle = False)

In [None]:
# Build a fresh model from the hyperparameters and move it onto the configured device.
# Training starts at epoch 1 unless a checkpoint is loaded below.
model = DepthLSTM(
    HIDDEN_SIZE, NUM_LAYERS, SEQ_LEN, SEQ_LEN_TRAIN, BATCH_SIZE, NUM_JOINTS
)
model.to(device)
initial_epoch = 1

In [None]:
# When True, the next cell replaces the freshly built model with the checkpoint
# ./Output/J_{epoch}.pt and resumes training from the epoch after it.
load_model = True

In [None]:
if load_model:
    # Epoch number of the checkpoint to resume from (file ./Output/J_{epoch}.pt).
    epoch = -1
    model_name = f"J_{epoch}.pt"
    model_path = "./Output/" + model_name
    # The checkpoint is a whole pickled module (written by torch.save(model, ...)),
    # so only load trusted files. map_location lets a checkpoint saved on a GPU be
    # loaded on a CPU-only machine (and vice versa).
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects whole
    # pickled modules; on those versions this call needs weights_only=False.
    model = torch.load(model_path, map_location=device)
    model.to(device)
    initial_epoch = epoch+1
    # Epoch budget used when resuming (overrides the value from the hyperparameters).
    NUM_EPOCHS = 1000

In [None]:
# Define the loss function and the optimizer
loss_function = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# One entry per epoch: the list of per-window losses returned by train_epoch_J / test_epoch_J
tr_loss = []
tst_loss = []
state = None
timer_beg = timer()
# Train from initial_epoch up to (but excluding) NUM_EPOCHS.
# NOTE(review): the optimizer is created fresh here, so Adam's state is not restored
# when resuming from a checkpoint.
for epoch in range(initial_epoch, NUM_EPOCHS):
    print('Starting epoch: ', epoch)
    # Rebuild the training set every epoch so the video order is reshuffled.
    trainset, groundtrainset, span_videos_train = create_stateful_dataset(train_videos, shuffle = True)
    train_epoch_loss, state = train_epoch_J(model, trainset, groundtrainset, span_videos_train, optimizer, loss_function)
    test_epoch_loss, predY = test_epoch_J(model, testset, groundtestset, span_videos_test, loss_function)
    timer_end = timer()

    # Epoch loss = average of the losses of every window
    mean_train_loss = sum(train_epoch_loss)/len(train_epoch_loss)
    mean_test_loss = sum(test_epoch_loss)/len(test_epoch_loss)

    if epoch % 10 == 0:
        print('Training loss in epoch {} is: {}'.format(epoch, mean_train_loss))
        print('Test loss in epoch {} is: {}'.format(epoch, mean_test_loss))
        # Wall-clock time spent rebuilding the train set, training and evaluating
        print('Epoch {} took {:.1f} s'.format(epoch, timer_end - timer_beg))

    # Save a checkpoint of the whole model every epoch
    name = f"J_{epoch}.pt"
    PATH = "./Output/" + name
    torch.save(model, PATH)

    # Append this epoch's mean losses to the running log files
    with open("./Output/j_train_loss.txt", 'a') as f:
        f.write(str(mean_train_loss) + '\n')

    with open("./Output/j_test_loss.txt", 'a') as f:
        f.write(str(mean_test_loss) + '\n')

    # Keep only the newest checkpoint. os.remove needs no shell and works on every OS;
    # the existence check skips a predecessor that was never written.
    if epoch > 1:
        previous_model = f"./Output/J_{epoch-1}.pt"
        if os.path.exists(previous_model):
            os.remove(previous_model)

    tr_loss.append(train_epoch_loss)
    tst_loss.append(test_epoch_loss)
    timer_beg = timer()

In [None]:
# Mean L1 loss per epoch for this run, plotted against the actual epoch numbers.
epochs_run = range(initial_epoch, initial_epoch + len(tr_loss))
fig, ax = plt.subplots()
ax.plot(epochs_run, np.array(tr_loss).mean(axis = 1), label='train')
ax.plot(epochs_run, np.array(tst_loss).mean(axis = 1), label='test')
ax.set(title='Mean L1 loss per epoch', xlabel='Epoch', ylabel='L1 loss')
ax.legend();

In [None]:
# Display the model (its module repr) as this cell's output.
model

In [None]:
# Manual checkpoint save outside the training loop, kept commented out:
# uncomment and set the file name to save the current model.
# name = "J_300.pt"
# PATH = "./Output/" + name
# torch.save(model, PATH)

-------------------