In [1]:
# @author Simon Stepputtis <sstepput@asu.edu>, Interactive Robotics Lab, Arizona State University

from __future__ import absolute_import, division, print_function, unicode_literals

from model_src.modelTorch import PolicyTranslationModelTorch
from utils.networkTorch import NetworkTorch
import hashids
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.convertTFDataToPytorchData import TorchDataset
from prettytable import PrettyTable
import sys



# Learning rate for the adam optimizer
LEARNING_RATE   = 0.0001
# Weight for the attention loss
WEIGHT_ATTN     = 1.0
# Weight for the motion primitive weight loss
WEIGHT_W        = 50.0
# Weight for the trajectroy generation loss
WEIGHT_TRJ      = 5.0
# Weight for the time progression loss
WEIGHT_DT       = 14.0
# Weight for the phase prediction loss
WEIGHT_PHS      = 1.0
# Number of epochs to train
TRAIN_EPOCHS    = 1000


def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params+=param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

def init_weights(network):
    nw_paras = network.named_parameters()
    for para_name, para in nw_paras:
        if 'bias' in para_name:
            para.data.fill_(0.01)

        elif 'weight' in para_name:
            torch.nn.init.orthogonal_(para)

def setupModel(device = 'cuda', batch_size = 16, path_dict = None, logname = None, model_path=None):
    print("  --> Running with default settings")
    model   = PolicyTranslationModelTorch(od_path="", glove_path=path_dict['GLOVE_PATH'], use_LSTM=False).to(device)
    train_data = TorchDataset(path = path_dict['TRAIN_DATA_TORCH'], device=device, on_device=False)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)


    eval_data = TorchDataset(path = path_dict['VAL_DATA_TORCH'], device=device)
    eval_loader = DataLoader(eval_data, batch_size=batch_size, shuffle=True)
    network = NetworkTorch(model, data_path=path_dict['DATA_PATH'],logname=logname, lr=LEARNING_RATE, lw_atn=WEIGHT_ATTN, lw_w=WEIGHT_W, lw_trj=WEIGHT_TRJ, lw_dt=WEIGHT_DT, lw_phs=WEIGHT_PHS, gamma_sl = 1, device=device)
    network.setDatasets(train_loader=train_loader, val_loader=eval_loader)
    network.setup_model()
    if model_path is not None:
        model.load_state_dict(torch.load(model_path))
    #init_weights(network)
    count_parameters(network)

    #print(f'number of param,eters in net: {len(list(network.parameters()))} and number of applied: {i}')
    #network.load_state_dict(torch.load(MODEL_PATH), strict=True)
    #network.train(epochs=TRAIN_EPOCHS)
    return model
import os
args = ['-path', '/home/hendrik/Documents/master_project/LokalData', '-model', '/home/hendrik/Documents/master_project/LokalData/Data/Model/r0xGg7EJvnk/best/policy_translation_h']


if '-path' not in args:
    print('no path given, not executing code')
else:    
    data_path = args[args.index('-path') + 1]
    path_dict = {
    'TRAIN_DATA_TORCH' : os.path.join(data_path, 'TorchDataset/train_data_torch.txt'),
    'VAL_DATA_TORCH' : os.path.join(data_path, 'TorchDataset/val_data_torch.txt'),
    'MODEL_PATH' : os.path.join(data_path, 'TorchDataset/test_model.pth'),
    'TRAIN_DATA' : os.path.join(data_path, 'GDrive/train.tfrecord'),
    'VAL_DATA' : os.path.join(data_path, 'GDrive/validate.tfrecord'),
    'GLOVE_PATH' : os.path.join(data_path, 'GDrive/glove.6B.50d.txt'),
    'DATA_PATH' : data_path
    }

    model_path = None
    if '-model' in args:
        model_path = args[args.index('-model') + 1]

    hid             = hashids.Hashids()
    logname         = hid.encode(int(time.time() * 1000000))
    #network = setupModel(device='cuda', batch_size = 1000, path_dict = path_dict, logname=logname, model_path=model_path)
    model = setupModel(device='cuda', batch_size = 1000, path_dict = path_dict, logname=logname, model_path=model_path)
    #print(f'end saving: {path_dict["MODEL_PATH"]}')
    #torch.save(network.state_dict(), path_dict['MODEL_PATH'])



  --> Running with default settings
log dir: /home/hendrik/Documents/master_project/LokalData/gboard/nZqADWwGABE/train/
+--------------------------------------------+------------+
|                  Modules                   | Parameters |
+--------------------------------------------+------------+
|        model.attention.w1.0.weight         |    2368    |
|         model.attention.w1.0.bias          |     64     |
|        model.attention.w2.0.weight         |    2368    |
|         model.attention.w2.0.bias          |     64     |
|         model.attention.wt.weight          |     64     |
| model.controller.Cell.robot_gru.weight_ih  |    672     |
| model.controller.Cell.robot_gru.weight_hh  |    3072    |
|  model.controller.Cell.robot_gru.bias_ih   |     96     |
|  model.controller.Cell.robot_gru.bias_hh   |     96     |
|  model.controller.Cell.kin_model.0.weight  |    5852    |
|   model.controller.Cell.kin_model.0.bias   |     77     |
|  model.controller.Cell.kin_model.2.wei

In [6]:
import pickle
with open('paras_tf.pkl', 'rb') as f:
        paras_tf = pickle.load(f)

In [38]:
for para in paras_tf:
    print(f'{para}, {paras_tf[para].shape}')
    print(paras_tf[para])

gru/kernel:0, (50, 96)
[[-0.17791894 -0.70764434 -0.08299291 ... -0.05358362 -0.00458363
  -0.1709208 ]
 [ 0.07102172  0.09653635 -0.16095    ... -0.10288757  0.13903658
  -0.09596796]
 [ 0.07963923 -0.24571653 -0.2802402  ... -0.16149086 -0.06638538
  -0.12625666]
 ...
 [ 0.04019414 -0.18111318  0.18511467 ...  0.19710933  0.23814878
  -0.21418212]
 [ 0.3339912   0.29496026  0.3878791  ...  0.17702861  0.20176657
  -0.02789401]
 [ 0.5140473   0.0228003   0.18072219 ...  0.01925989 -0.04115836
  -0.196385  ]]
gru/recurrent_kernel:0, (32, 96)
[[ 0.5205785  -0.05283815  0.04602512 ...  0.12861852  0.80711615
   0.47307903]
 [-0.38277477  0.11572853  0.48367453 ... -0.14653924 -0.12454208
  -0.14414895]
 [ 0.18111008  0.08894205 -0.06654326 ...  0.02363874 -0.20021297
  -0.15873484]
 ...
 [-0.04770469 -0.2069003  -0.2604183  ...  0.1593047  -0.1393731
  -0.01547904]
 [ 0.19437614 -0.30055627 -0.37357354 ...  0.09628218  0.25283226
   0.38039055]
 [ 0.25481355  0.10311014 -0.15027852 ...  

In [17]:
torch_state_dict = model.state_dict()

In [29]:
import pickle
def load_tf_statedict(model):
    with open('paras_tf.pkl', 'rb') as f:
        paras_tf = pickle.load(f)
    torch_state_dict = model.state_dict()
    #Attention
    torch_state_dict['attention.w1.0.weight'] = torch.tensor(paras_tf['attention/time_distributed/kernel:0']).T
    torch_state_dict['attention.w1.0.bias'] = torch.tensor(paras_tf['attention/time_distributed/bias:0']).T
    torch_state_dict['attention.w2.0.weight'] = torch.tensor(paras_tf['attention/time_distributed_1/kernel:0']).T
    torch_state_dict['attention.w1.0.bias'] = torch.tensor(paras_tf['attention/time_distributed_1/bias:0']).T
    torch_state_dict['attention.wt.weight'] = torch.tensor(paras_tf['attention/time_distributed_2/kernel:0']).T

    #LanguageGRU
    torch_state_dict['lng_gru.weight_ih_l0'] = torch.tensor(paras_tf['gru/kernel:0']).T
    torch_state_dict['lng_gru.weight_hh_l0'] = torch.tensor(paras_tf['gru/recurrent_kernel:0']).T
    torch_state_dict['lng_gru.bias_ih_l0'] = torch.tensor(paras_tf['gru/bias:0'][0]).T
    torch_state_dict['lng_gru.bias_hh_l0'] = torch.tensor(paras_tf['gru/bias:0'][1]).T

    #Controller GRU
    torch_state_dict['controller.Cell.robot_gru.weight_ih'] = torch.tensor(paras_tf['rnn/gru_cell/kernel:0']).T
    torch_state_dict['controller.Cell.robot_gru.weight_hh'] = torch.tensor(paras_tf['rnn/gru_cell/recurrent_kernel:0']).T
    torch_state_dict['controller.Cell.robot_gru.bias_ih'] = torch.tensor(paras_tf['rnn/gru_cell/bias:0'][0]).T
    torch_state_dict['controller.Cell.robot_gru.bias_hh'] = torch.tensor(paras_tf['rnn/gru_cell/bias:0'][1]).T

    #ControllerKinModel
    torch_state_dict['controller.Cell.kin_model.0.weight'] = torch.tensor(paras_tf['rnn/dense_3/kernel:0']).T
    torch_state_dict['controller.Cell.kin_model.0.bias'] = torch.tensor(paras_tf['rnn/dense_3/bias:0']).T
    torch_state_dict['controller.Cell.kin_model.2.weight'] = torch.tensor(paras_tf['rnn/dense_4/kernel:0']).T
    torch_state_dict['controller.Cell.kin_model.2.bias'] = torch.tensor(paras_tf['rnn/dense_4/bias:0']).T
    torch_state_dict['controller.Cell.kin_model.4.weight'] = torch.tensor(paras_tf['rnn/dense_5/kernel:0']).T
    torch_state_dict['controller.Cell.kin_model.4.bias'] = torch.tensor(paras_tf['rnn/dense_5/bias:0']).T

    #ControllerPhaseModel
    torch_state_dict['controller.Cell.phase_model.0.weight'] = torch.tensor(paras_tf['rnn/dense_6/kernel:0']).T
    torch_state_dict['controller.Cell.phase_model.0.bias'] = torch.tensor(paras_tf['rnn/dense_6/bias:0']).T
    torch_state_dict['controller.Cell.phase_model.2.weight'] = torch.tensor(paras_tf['rnn/dense_7/kernel:0']).T
    torch_state_dict['controller.Cell.phase_model.2.bias'] = torch.tensor(paras_tf['rnn/dense_7/bias:0']).T

    #dmpdt model
    torch_state_dict['dmp_dt_model_seq.0.weight'] = torch.tensor(paras_tf['dense/kernel:0']).T
    torch_state_dict['dmp_dt_model_seq.0.bias'] = torch.tensor(paras_tf['dense/bias:0']).T
    torch_state_dict['dmp_dt_model_seq.3.weight'] = torch.tensor(paras_tf['dense_1/kernel:0']).T
    torch_state_dict['dmp_dt_model_seq.3.bias'] = torch.tensor(paras_tf['dense_1/bias:0']).T
    torch_state_dict['dmp_dt_model_seq.5.weight'] = torch.tensor(paras_tf['dense_2/kernel:0']).T
    torch_state_dict['dmp_dt_model_seq.5.bias'] = torch.tensor(paras_tf['dense_2/bias:0']).T


    model.load_state_dict(torch_state_dict, strict=True)
    return model

In [31]:
model = load_tf_statedict(model)

In [37]:
for name in torch_state_dict:
    print(f'{name}, {torch_state_dict[name].shape}')
    print(torch_state_dict[name].T)
    

attention.w1.0.weight, torch.Size([64, 37])
tensor([[ 0.1265, -0.1342,  0.0061,  ..., -0.0392, -0.4519, -0.2788],
        [-0.0390,  0.4663,  0.0977,  ...,  0.3503,  0.3941,  0.3791],
        [ 0.3282, -0.4493,  0.1372,  ..., -0.0041, -0.4718,  0.0838],
        ...,
        [ 0.1637,  0.1312,  0.0923,  ..., -0.0059, -0.1302,  0.0372],
        [ 0.2239,  0.1140, -0.1892,  ...,  0.0350,  0.1322,  0.2279],
        [-0.0310, -0.1359,  0.0577,  ..., -0.0342,  0.1836, -0.1649]])
attention.w1.0.bias, torch.Size([64])
tensor([-0.4603,  0.0729,  0.6140, -0.4336, -0.1155,  0.7804,  0.7307,  0.0689,
         0.4769,  0.0824,  0.1243,  0.6022, -0.4338, -0.1350, -0.8306,  0.5305,
         0.4693,  0.4393,  0.4598, -0.0303,  0.5425, -0.4905, -0.4256,  0.2886,
         0.1306,  0.0497,  0.3747, -0.3712, -0.2596,  0.7701,  0.7730, -0.1385,
         1.2864, -0.1683,  0.0558,  0.5368, -0.2944,  0.4716,  0.0742, -0.4483,
         0.5224,  0.6562, -0.6833, -0.2498, -0.5926, -0.4330, -0.3485, -0.0227,
    

In [43]:
grucell = nn.GRUCell(input_size=7, hidden_size=32, bias= True)

In [44]:
inpt = torch.ones((2, 7))