In [1]:
from src.main_utils import Configuration, ModelConfiguration
from src.motion_visualization_tools import test_scaling_method,test_generation_method, check_original_reconstructed_generated, compare_motion_data_plots
from src.cvae_networks import CVAE_CNN
from src.train_utils import train_model
from src.thirdHand_data_loader import get_min_max_from_dataset
import torch
import numpy as np

In [2]:
def create_the_model(device='cuda',
                     csv_folder_path=None,
                     tresh_l=0.289,
                     tresh_h_normal=0.4,
                     tresh_h_riz=0.27,
                     dist=15,
                     peak_dist=30,
                     motion_fixed_length=20,
                     data_item="X_centered_scaled",
                     batch_size=128,
                     kernel_size=5,  # 3 and 5 both work
                     first_filter_size=9,  # keep between 5 and 10
                     depth=2,  # supported depths: 2, 3, 4
                     dropout=0.1,
                     epochs=100,
                     latent_dim=8,
                     rec_loss="L1",
                     reduction="sum",
                     kld_weight=1e-1,
                     model_name_to_save="c_vae_model"):
    """Build the project and model configurations and instantiate the CVAE_CNN.

    The arguments from ``device`` through ``batch_size`` are forwarded
    positionally, in signature order, to ``Configuration``. The arguments from
    ``kernel_size`` through ``model_name_to_save`` are forwarded positionally to
    ``ModelConfiguration``, preceded by the ``device`` attribute of the freshly
    built project configuration. The meaning of each argument is documented in
    src/main_utils.py.

    Returns:
        tuple: (model, project_config, model_config)
    """
    # Positional order matters: it has to match the Configuration signature.
    project_settings = (
        device,
        csv_folder_path,
        tresh_l,
        tresh_h_normal,
        tresh_h_riz,
        dist,
        peak_dist,
        motion_fixed_length,
        data_item,
        batch_size,
    )
    project_config = Configuration(*project_settings)

    # The model configuration takes the device stored on the project
    # configuration rather than the raw `device` argument.
    network_settings = (
        project_config.device,
        kernel_size,
        first_filter_size,
        depth,
        dropout,
        epochs,
        latent_dim,
        rec_loss,
        reduction,
        kld_weight,
        model_name_to_save,
    )
    model_config = ModelConfiguration(*network_settings)

    return CVAE_CNN(project_config, model_config), project_config, model_config

In [3]:
# Switch between training from scratch (True) and loading the saved model (False).
# The flag deliberately does NOT reuse the name `train_model`: that name is the
# training function imported from src.train_utils, and rebinding it to a bool
# would make the training call below fail with
# "TypeError: 'bool' object is not callable".
should_train_model = False
model, project_config, model_config = create_the_model()

# Single source of truth for the checkpoint location used in the messages and the load.
model_path = "./models/{}.pt".format(model_config.model_name_to_save)

print("---------------------------------------------------------------------------------")
if should_train_model:
    print("Trainig the model from scratch, the model will be saved in {}".format(model_path))
    train_model(model, project_config, model_config.epochs, model_config.model_name_to_save)
else:
    print("Loading trained model from: {}".format(model_path))
    try:
        # NOTE(security): torch.load unpickles arbitrary Python objects, so only
        # load checkpoints that come from a trusted source.
        model = torch.load(model_path)
    except FileNotFoundError:
        # `model` stays the freshly initialised network returned by create_the_model().
        print("A trained model does not exist in the provided path, please train the model first.")
    except Exception as load_error:
        # The file exists but could not be deserialised; report the real cause
        # instead of claiming the file is missing. KeyboardInterrupt and
        # SystemExit are intentionally not caught here.
        print("The saved model at {} could not be loaded: {!r}".format(model_path, load_error))

Data loaded from 11 filse, stored in a dataframe with shape (248486, 10)
Dataframe headers are: ['px', 'py', 'pz', 'v1x', 'v1y', 'v1z', 'v2x', 'v2y', 'v2z', 'hand']
Data loaded from 2 filse, stored in a dataframe with shape (54402, 10)
Dataframe headers are: ['px', 'py', 'pz', 'v1x', 'v1y', 'v1z', 'v2x', 'v2y', 'v2z', 'hand']
---------------------------------------------------------------------------------
Loading trained model from: ./models/c_vae_model.pt


In [4]:
# Run the generation check imported from src.motion_visualization_tools on the
# model prepared in the previous cell.
# NOTE(review): generation_size=16 is presumably the number of motions to
# generate; confirm against the helper's definition.
test_generation_method(model, project_config, generation_size= 16) 

In [5]:
# Call the helper imported from src.motion_visualization_tools with the current
# model and project configuration.
# NOTE(review): judging by its name it compares original, reconstructed and
# generated motions; confirm the exact output in the helper's definition.
check_original_reconstructed_generated(model, project_config)

In [6]:
# Sanity check of the scaling / centering transforms on a single training batch:
#   1. min-max scale the raw motions and verify that un-scaling restores them,
#   2. centre + scale the motions with the min/max values stored in the batch,
#      invert both steps and plot the result next to the raw motions.
INSPECTED_BATCH_INDEX = 7  # 0-based index of the batch used for the check

for counter, d in enumerate(project_config.train_iterator):
    if counter < INSPECTED_BATCH_INDEX:
        continue

    # Only the first three items of the batch are inspected.
    X = d["X"][:3, :, :]

    min_val, max_val = get_min_max_from_dataset(X)

    # Centre every motion on its touching point: the first three features at
    # index 9 of the second axis are subtracted from every step of that motion.
    # NOTE(review): assumes X is laid out as (motion, time step, feature) and
    # that the first three features are the position; confirm in the data loader.
    center_points = torch.zeros_like(X[:, 0:1, :])
    center_points[:, 0, :3] = X[:, 9, :3]
    X_centered = X - center_points

    # Min-max scaling and its inverse.
    X_scaled = (X - min_val) / (max_val - min_val)
    X_unscaled = X_scaled * (max_val - min_val) + min_val

    # Compare with the largest ABSOLUTE element-wise error: a signed sum lets
    # positive and negative errors cancel and accepts any negative total.
    print(torch.max(torch.abs(X - X_unscaled)) < 0.001)

    # Forward pass on the full batch. The outputs are not consumed in this cell,
    # so gradient tracking is switched off to avoid building an autograd graph.
    x_samples = d[project_config.data_item]
    y_samples = d["Y"]
    with torch.no_grad():
        x_rec, _second_output, _third_output = model(x_samples, y_samples)

    # Use the min/max values shipped with the batch (first entry) for the
    # centred data, then undo scaling and centering.
    centered_min_val = d["centered_min_val"][0]
    centered_max_val = d["centered_max_val"][0]
    centered_range = centered_max_val - centered_min_val

    X_centered_scaled = (X_centered - centered_min_val) / centered_range
    X_back_to_orig = (X_centered_scaled * centered_range + centered_min_val) + center_points

    fig_01 = compare_motion_data_plots([X, X_back_to_orig], 0)
    break
else:
    # Reached only when the loop ends without `break`, i.e. the iterator never
    # produced the batch we wanted to inspect.
    print("project_config.train_iterator yielded fewer than {} batches; nothing was checked.".format(
        INSPECTED_BATCH_INDEX + 1))

tensor(True, device='cuda:0')
