In [1]:
import numpy as np
import pandas as pd
import os
from scipy.fftpack import dct



import unicodedata
import string
import re
import random
import time
import math


from utils import token_regularizer
from utils import pair_files
from utils import load_encode
from torch.utils.data import Dataset, DataLoader, random_split

from collections import defaultdict

import json
import random

from dataset import Seq2SeqDataset, seq2seq_collate_fn


import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F

*The path of the data

In [2]:
the_path = '/afs/inf.ed.ac.uk/group/cstr/projects/galatea/d02'
data_path_motion = the_path + '/Recordings_October_2014/DOF-hiroshi/'
data_path_text = the_path + '/Recordings_October_2014/Transcriptions/transcriptions_phrase_tables/'

paired_file_paths = pair_files(data_path_text, '.TABLE', data_path_motion,'.qtn')

In [3]:
    def is_number(s: str) -> bool:
        try:
            float(s)
            return True
        except ValueError:
            return False

In [4]:
total_words = defaultdict(dict)
for i in range(len(paired_file_paths)):
    with open(paired_file_paths[i][1],'r') as f:
        for line in f.readlines():
            line = line.strip().split()
            for j in range(3,len(line)):
                words = token_regularizer(line[j])
                for w in words:
                    if(w not in total_words):
                        total_words[w] = 1
                    else:
                        total_words[w] += 1
                
                
# print(total_words)

##所有文件中的词个数

In [5]:
vocab_to_int = {}
value = 0

for word in total_words.keys():
    if '-' not in word:
        vocab_to_int[word] = value
        value += 1
codes = ["<UNK>","<PAD>","<EOS>","<GO>","<Stammer>","<Long>"]

for code in codes:
    vocab_to_int[code] = len(vocab_to_int)

int_to_vocab = {}
for word,value in vocab_to_int.items():
    int_to_vocab[value] = word


In [23]:
a = ['Soph_02_e_Soph,2.21,2.6,COOL,221,260']
a[0].split(',')

['Soph_02_e_Soph', '2.21', '2.6', 'COOL', '221', '260']

In [6]:
def convert_to_ints(text):
    ints = []
    ints.append(vocab_to_int["<GO>"])
    for sentense in text:
        sentense_ints = 0
        for word in sentense.split():
            if word in vocab_to_int:
                sentense_ints = vocab_to_int[word]
            elif word[len(word)-1] == '-':
                sentense_ints = vocab_to_int["<Stammer>"]
            elif '-' in word:
                sentense_ints = vocab_to_int["<Long>"]
            else:
                sentense_ints = vocab_to_int["<UNK>"]
#         if eos:
#             sentense_ints.append(vocab_to_int["<EOS>"])
        ints.append(sentense_ints)
    ints.append(vocab_to_int["<EOS>"])
    return ints
    

In [16]:
def seq2seq_preprocess(transcript_path: str, motion_path: str, 
                       dictionary: dict) -> (np.ndarray, np.ndarray):
    transcripts = []
    intervals = []

    prev_sentense = ''
    prev_interval = [0,0]
    with open(transcript_path, 'r') as f:
        for line in f.readlines():
            line = line.strip().split()
            if not is_number(line[1]) or not is_number(line[2]):
                continue
            start_time = int(float(line[1]) * 100)
            end_time = int(float(line[2])*100)
            text = line[3:]
            
#             if len(text) < 5:
#                 continue 
#                 去掉过短句子
            
            if(float(line[1]) - prev_interval[1]/100.0 <= 0.5):
                prev_sentense += text
                prev_interval[1] = end_time
            else:
                transcripts.append(convert_to_ints(prev_sentense))
                intervals.append(prev_interval)
                prev_sentense = text
                prev_interval = [start_time,end_time]

                
                

    motions = np.loadtxt(motion_path, usecols=range(4), skiprows=17, 
                         dtype='float')
    
    num_dof = 4
    targets = []
    for period in intervals[1:]:
        start_time = period[0]
        end_time = period[1]

        temp_motion = motions[start_time:end_time]
        temp_motion = np.array(temp_motion)

        targets.append(temp_motion)
        
    inputs =  np.array(transcripts[1:])
    print(inputs)
    print(targets)
#     print(len(inputs))
#     print(len(targets))
    return inputs, targets

# make_dataset(seq2seq_preprocess, paired_file_paths, 'n',
#                 './data/extro_seq2seq_dataset')
# print('extroverted data finished')

In [8]:
word2idx = vocab_to_int

In [17]:
def make_dataset(process_function, paired_file_paths: list, ptype: str,
                 save_path: str) -> None:
    """"Calls preprocess methods and save processed data.

    Args:
        process_function: the function for proprecessing, it can be
            'dct_preprocess' or 'seq2seq_preprocess'.
        paired_file_paths: the output from 'pair_files' function. Its format is
            [NAME_NUM_[ine], actual text path, actual motion path]
        ptype: type of personality, can be '(e)xtroverted', '(i)ntroverted' and
            '(n)atural'.
        save_path: path to save the dataset
    """
    assert ptype in {
        'e', 'i', 'n'
    }, ('personality type should be "e","i" or "n", not %r' % ptype)

    all_input = []
    all_target = []

    input_train = []
    input_valid = []
    input_test = []

    target_train = []
    target_valid = []
    target_test = []
    # apply 'preprocess' method to for each text/motion pair
    for pair in paired_file_paths:
        # the last letter of file name means speaker's personality

        if pair[0][-1] == ptype:
            inputs, targets = process_function(pair[1], pair[2], word2idx)
            all_input.append(inputs)
            all_target.append(targets)

    train_size = int(0.8 * len(all_input))
    valid_size = int(0.1 * len(all_input))
    test_size = len(all_input) - train_size - valid_size

    for i in range(0, train_size):
        for j in range(0, len(all_input[i])):
            input_train.append(all_input[i][j])


    for i in range(train_size, train_size + valid_size):
        for j in range(0, len(all_input[i])):
            input_valid.append(all_input[i][j])

    for i in range(train_size+valid_size, len(all_input)):
        for j in range(0, len(all_input[i])):
            input_test.append(all_input[i][j])



    for i in range(0, train_size):
        for j in range(0, len(all_target[i])):
            target_train.append(all_target[i][j])


    for i in range(train_size, train_size + valid_size):
        for j in range(0, len(all_target[i])):
            target_valid.append(all_target[i][j])

    for i in range(train_size+valid_size, len(all_target)):
        for j in range(0, len(all_target[i])):
            target_test.append(all_target[i][j])

    print(len(input_train))
    np.savez(save_path+"_train.npz", input=input_train, target=target_train)
    np.savez(save_path+"_valid.npz", input=input_valid, target=target_valid)
    np.savez(save_path+"_test.npz", input=input_test, target=target_test)
    
# make three kinds of dataset
make_dataset(seq2seq_preprocess, paired_file_paths, 'e',
                './data/extro_seq2seq_dataset')

[list([3973, 0, 1, 2, 3, 2, 4, 5, 6, 7, 8, 3972])
 list([3973, 9, 2, 3974, 2, 11, 12, 2, 11, 12, 13, 11, 14, 15, 16, 17, 3972])
 list([3973, 3970, 19, 3970, 22, 23, 24, 25, 3970, 20, 27, 2, 11, 12, 28, 15, 29, 7, 30, 16, 23, 31, 32, 23, 33, 3972])
 list([3973, 3970, 34, 35, 36, 3970, 38, 3970, 11, 40, 41, 42, 11, 43, 44, 3970, 3970, 46, 2, 47, 32, 11, 42, 48, 49, 32, 50, 51, 23, 52, 3970, 53, 54, 32, 55, 56, 32, 57, 32, 11, 58, 23, 31, 2, 19, 58, 23, 59, 31, 60, 32, 50, 51, 61, 62, 21, 35, 11, 5, 27, 3972])
 list([3973, 27, 63, 64, 2, 2, 11, 12, 14, 15, 16, 50, 65, 66, 3972])
 list([3973, 9, 64, 64, 3970, 65, 25, 12, 3970, 65, 25, 12, 3972])
 list([3973, 9, 3970, 5, 3970, 5, 3972])
 list([3973, 67, 23, 68, 3974, 9, 3970, 3970, 72, 73, 2, 74, 75, 23, 5, 76, 77, 78, 5, 79, 3972])
 list([3973, 3974, 81, 2, 82, 83, 3972])
 list([3973, 9, 5, 84, 3970, 85, 3974, 3970, 85, 81, 2, 81, 2, 86, 87, 88, 71, 2, 89, 51, 23, 90, 91, 51, 92, 3972])
 list([3973, 93, 63, 3970, 85, 58, 27, 3970, 85, 94, 

[array([[ 0.992979  , -0.0322815 ,  0.0246271 ,  0.111101  ],
       [ 0.992685  , -0.0322274 ,  0.0267859 ,  0.113224  ],
       [ 0.992564  , -0.0320446 ,  0.0282164 ,  0.113991  ],
       ...,
       [ 0.996237  ,  0.0655421 ,  0.00639988, -0.0563455 ],
       [ 0.996245  ,  0.0649604 ,  0.0067312 , -0.0568417 ],
       [ 0.996245  ,  0.0643916 ,  0.00703811, -0.0574505 ]]), array([[ 0.997679  , -0.00396967,  0.0255913 ,  0.0629802 ],
       [ 0.997576  , -0.00474889,  0.0250169 ,  0.0647527 ],
       [ 0.997483  , -0.00549194,  0.0244312 ,  0.066343  ],
       ...,
       [ 0.99004   , -0.129474  , -0.0502541 ,  0.0230771 ],
       [ 0.989889  , -0.130362  , -0.0502551 ,  0.024494  ],
       [ 0.989758  , -0.131091  , -0.0503125 ,  0.0257292 ]]), array([[ 0.986983  , -0.116026  , -0.0039279 ,  0.111298  ],
       [ 0.986962  , -0.115807  , -0.0030584 ,  0.111737  ],
       [ 0.986968  , -0.115366  , -0.00227497,  0.112161  ],
       [ 0.987025  , -0.114699  , -0.00173339,  0.112347

[array([[ 0.999151  ,  0.00443318, -0.0216118 ,  0.0347805 ],
       [ 0.999112  ,  0.00297189, -0.0222101 ,  0.0356767 ],
       [ 0.999443  ,  0.0136    , -0.0211011 ,  0.022004  ],
       ...,
       [ 0.998636  ,  0.0420457 ,  0.0287345 , -0.0115073 ],
       [ 0.998596  ,  0.0431509 ,  0.0281922 , -0.0121978 ],
       [ 0.998628  ,  0.0427413 ,  0.0278678 , -0.0118176 ]]), array([[ 9.98900e-01,  4.27109e-02,  1.37491e-02, -1.36004e-02],
       [ 9.98861e-01,  4.36102e-02,  1.28115e-02, -1.44876e-02],
       [ 9.98883e-01,  4.33092e-02,  1.24425e-02, -1.42311e-02],
       [ 9.98875e-01,  4.39625e-02,  1.31794e-02, -1.19030e-02],
       [ 9.98908e-01,  4.33007e-02,  1.24970e-02, -1.23484e-02],
       [ 9.98956e-01,  4.24011e-02,  1.12405e-02, -1.27318e-02],
       [ 9.99101e-01,  3.95392e-02,  1.12268e-02, -1.03626e-02],
       [ 9.99249e-01,  3.63123e-02,  1.09290e-02, -7.90608e-03],
       [ 9.99265e-01,  3.59116e-02,  1.00385e-02, -8.92902e-03],
       [ 9.99405e-01,  3.27266e-02

[list([3973, 196, 1, 71, 2, 3972])
 list([3973, 9, 63, 63, 3970, 523, 36, 20, 168, 76, 3970, 462, 20, 20, 188, 133, 270, 2, 3970, 39, 134, 27, 63, 168, 3972])
 list([3973, 63, 3970, 188, 76, 3972])
 list([3973, 495, 495, 164, 7, 1062, 474, 20, 188, 168, 3972])
 list([3973, 168, 3972])
 list([3973, 27, 3970, 85, 2, 2, 19, 3970, 34, 1063, 76, 168, 63, 3974, 168, 3972])
 list([3973, 27, 211, 21, 21, 11, 3970, 26, 72, 23, 1064, 1065, 3972])
 list([3973, 3974, 270, 3972])
 list([3973, 63, 63, 63, 3970, 151, 3970, 26, 65, 1066, 134, 3970, 58, 3972])
 list([3973, 2, 19, 2, 19, 302, 2, 194, 2, 194, 23, 1062, 35, 3970, 26, 3972])
 list([3973, 270, 3970, 3970, 58, 99, 115, 20, 58, 35, 134, 3972])
 list([3973, 3970, 85, 3970, 58, 95, 3970, 248, 2, 95, 3970, 248, 2, 3972])
 list([3973, 117, 3970, 930, 270, 3974, 117, 199, 2, 191, 115, 3972])
 list([3973, 27, 3972])
 list([3973, 3974, 270, 3970, 182, 72, 1067, 134, 168, 95, 95, 3970, 3970, 151, 3970, 11, 2, 19, 1064, 1063, 32, 3970, 194, 23, 139, 1

[list([3973, 1237, 3972])
 list([3973, 9, 133, 99, 283, 2, 9, 495, 164, 625, 202, 3970, 435, 20, 98, 11, 259, 652, 35, 1238, 124, 77, 3970, 435, 9, 11, 499, 202, 3974, 9, 470, 3972])
 list([3973, 9, 3974, 9, 208, 3970, 77, 1240, 32, 3970, 77, 1241, 51, 120, 32, 187, 2, 499, 9, 470, 3972])
 list([3973, 64, 257, 2, 11, 1034, 20, 1242, 3970, 11, 9, 64, 3970, 77, 23, 133, 811, 9, 3970, 77, 272, 1243, 3970, 452, 3970, 191, 34, 20, 3972])
 list([3973, 93, 81, 2, 19, 117, 3970, 224, 58, 99, 81, 23, 421, 1244, 302, 3970, 75, 3974, 66, 499, 202, 3970, 11, 3970, 248, 370, 117, 3970, 128, 370, 32, 11, 113, 370, 283, 95, 131, 436, 878, 3970, 11, 27, 27, 2, 49, 164, 35, 151, 174, 64, 3972])
 list([3973, 27, 156, 3972])
 list([3973, 27, 174, 64, 3970, 1113, 61, 76, 1245, 174, 156, 61, 21, 65, 1014, 155, 174, 156, 76, 64, 3970, 75, 23, 47, 103, 35, 283, 117, 2, 179, 75, 23, 1246, 75, 23, 1246, 1030, 51, 35, 3972])
 list([3973, 27, 27, 27, 133, 35, 2, 58, 20, 1247, 3972])
 list([3973, 9, 63, 63, 75, 7

[list([3973, 9, 365, 3970, 3970, 2, 98, 26, 1250, 7, 25, 3970, 180, 3970, 12, 204, 858, 3970, 11, 377, 15, 3970, 64, 854, 260, 3970, 77, 99, 3974, 58, 3970, 11, 26, 12, 101, 99, 1180, 29, 32, 3970, 250, 3970, 185, 1279, 20, 29, 3970, 11, 194, 23, 1280, 164, 20, 64, 458, 3970, 11, 388, 3972])
 list([3973, 3970, 11, 276, 3974, 3972])
 list([3973, 3970, 12, 23, 1281, 3972])
 list([3973, 3974, 1282, 1281, 3970, 11, 50, 3970, 276, 99, 81, 174, 458, 21, 11, 578, 32, 11, 1283, 365, 118, 1284, 3970, 11, 276, 99, 75, 23, 3970, 51, 7, 3970, 1287, 32, 11, 1288, 3970, 204, 1289, 174, 458, 3972])
 list([3973, 258, 3970, 5, 164, 11, 58, 1290, 99, 721, 1291, 3970, 11, 58, 3970, 471, 100, 1292, 16, 61, 3970, 228, 36, 1293, 793, 3970, 72, 11, 3970, 276, 99, 81, 61, 1294, 3972])
 list([3973, 258, 3972])
 list([3973, 3970, 219, 1282, 1281, 3970, 11, 11, 64, 249, 131, 618, 1295, 153, 342, 32, 3970, 11, 3970, 3970, 920, 23, 631, 3970, 3970, 618, 705, 3970, 3970, 618, 810, 3970, 661, 64, 249, 3970, 436, 81,

[array([[ 0.993248  , -0.00679037,  0.0151165 ,  0.11482   ],
       [ 0.993182  , -0.00769794,  0.0154666 ,  0.115289  ],
       [ 0.993063  , -0.00859377,  0.0160741 ,  0.116165  ],
       [ 0.992922  , -0.00953516,  0.0167812 ,  0.117188  ],
       [ 0.992773  , -0.0104171 ,  0.0177002 ,  0.11824   ],
       [ 0.992609  , -0.0112663 ,  0.0188747 ,  0.119347  ],
       [ 0.992453  , -0.0120596 ,  0.0201497 ,  0.120356  ],
       [ 0.992392  , -0.0128129 ,  0.0210327 ,  0.120633  ],
       [ 0.992378  , -0.0134301 ,  0.0218131 ,  0.120542  ],
       [ 0.992413  , -0.0138652 ,  0.0224569 ,  0.120083  ],
       [ 0.992559  , -0.0142339 ,  0.0225837 ,  0.1188    ],
       [ 0.992761  , -0.0145491 ,  0.0224965 ,  0.117077  ],
       [ 0.992918  , -0.0147941 ,  0.0226563 ,  0.115677  ],
       [ 0.993056  , -0.0149627 ,  0.0228754 ,  0.114425  ],
       [ 0.993191  , -0.0151905 ,  0.0229746 ,  0.113197  ],
       [ 0.993289  , -0.0154037 ,  0.0232291 ,  0.112254  ],
       [ 0.993358  , -0

[array([[ 0.9977    , -0.0127803 ,  0.0658021 , -0.0100999 ],
       [ 0.997542  , -0.0133107 ,  0.0677806 , -0.0117833 ],
       [ 0.997436  , -0.0113403 ,  0.0701027 , -0.00882465],
       ...,
       [ 0.999724  ,  0.00241301, -0.00465546,  0.0228812 ],
       [ 0.999687  ,  0.00252047, -0.00610516,  0.0241246 ],
       [ 0.999668  ,  0.00253028, -0.00663242,  0.0247498 ]]), array([[ 9.96407e-01, -1.04754e-02, -8.32130e-02,  1.18295e-02],
       [ 9.96755e-01, -8.02108e-03, -7.89695e-02,  1.33696e-02],
       [ 9.96931e-01, -6.49403e-03, -7.64186e-02,  1.57244e-02],
       [ 9.96898e-01, -6.73946e-03, -7.67373e-02,  1.61373e-02],
       [ 9.96773e-01, -9.42892e-03, -7.85324e-02,  1.36841e-02],
       [ 9.97106e-01, -9.09919e-03, -7.43868e-02,  1.28136e-02],
       [ 9.97115e-01, -8.75659e-03, -7.43337e-02,  1.26442e-02],
       [ 9.97110e-01, -8.71225e-03, -7.42811e-02,  1.33210e-02],
       [ 9.96951e-01, -9.33416e-03, -7.60022e-02,  1.50029e-02],
       [ 9.96958e-01, -6.53957e-03

[array([[ 0.991729  , -0.0144912 ,  0.00647217,  0.127364  ],
       [ 0.991691  , -0.0142846 ,  0.00677085,  0.127669  ],
       [ 0.991659  , -0.0142264 ,  0.00687161,  0.127921  ],
       [ 0.991574  , -0.0140852 ,  0.00711952,  0.12858   ],
       [ 0.991575  , -0.0140051 ,  0.00715379,  0.128572  ],
       [ 0.991547  , -0.0137979 ,  0.00690673,  0.128827  ],
       [ 0.991665  , -0.0137347 ,  0.00684873,  0.127924  ],
       [ 0.991743  , -0.0134731 ,  0.00672209,  0.127355  ],
       [ 0.991872  , -0.0131779 ,  0.00643369,  0.126391  ],
       [ 0.992014  , -0.0127448 ,  0.00635702,  0.125321  ],
       [ 0.992184  , -0.0122649 ,  0.00638596,  0.124016  ],
       [ 0.992309  , -0.0117105 ,  0.00645517,  0.123063  ],
       [ 0.992522  , -0.0112106 ,  0.00667486,  0.121367  ],
       [ 0.992804  , -0.0108368 ,  0.00702918,  0.119048  ],
       [ 0.993008  , -0.0104803 ,  0.00731424,  0.117354  ],
       [ 0.993227  , -0.0101755 ,  0.00743178,  0.115501  ],
       [ 0.993402  , -0

[array([[ 9.99915e-01, -6.60716e-03, -7.77823e-03, -8.08634e-03],
       [ 9.99914e-01, -6.65142e-03, -7.86420e-03, -8.09415e-03],
       [ 9.99914e-01, -6.64260e-03, -7.85233e-03, -8.13813e-03],
       ...,
       [ 9.99876e-01,  1.48980e-02,  4.60735e-03,  2.30722e-03],
       [ 9.99881e-01,  1.47656e-02,  4.34108e-03,  7.42823e-04],
       [ 9.99868e-01,  1.54965e-02,  4.61406e-03, -1.56346e-03]]), array([[ 9.99583e-01, -1.08591e-02, -9.51224e-04,  2.67290e-02],
       [ 9.99589e-01, -1.06145e-02, -9.97955e-04,  2.66009e-02],
       [ 9.99568e-01, -1.10583e-02, -1.08590e-03,  2.72214e-02],
       [ 9.99598e-01, -1.01712e-02, -7.96595e-04,  2.64577e-02],
       [ 9.99585e-01, -1.04220e-02, -7.26101e-04,  2.68620e-02],
       [ 9.99542e-01, -1.15091e-02, -8.39827e-04,  2.79620e-02],
       [ 9.99544e-01, -1.16306e-02, -7.47757e-04,  2.78708e-02],
       [ 9.99569e-01, -1.13130e-02, -4.55655e-04,  2.70986e-02],
       [ 9.99544e-01, -1.23042e-02, -4.58894e-04,  2.75841e-02],
       [ 9

[array([[ 0.999432  , -0.0125676 , -0.0175037 ,  0.0258954 ],
       [ 0.999419  , -0.012411  , -0.0175266 ,  0.0264563 ],
       [ 0.999387  , -0.0104052 , -0.0158399 ,  0.0294316 ],
       ...,
       [ 0.999814  , -0.0110693 , -0.00668057, -0.0143113 ],
       [ 0.999837  , -0.0111124 , -0.00671104, -0.0125402 ],
       [ 0.999845  , -0.0102281 , -0.0074048 , -0.0122794 ]]), array([[ 0.995491 ,  0.031054 , -0.0174164,  0.0879227],
       [ 0.995386 ,  0.031779 , -0.0177041,  0.0887935],
       [ 0.995201 ,  0.0330533, -0.0179465,  0.090332 ],
       ...,
       [ 0.999457 ,  0.0141463, -0.0163579,  0.0248441],
       [ 0.99948  ,  0.0130575, -0.0163241,  0.0245432],
       [ 0.999472 ,  0.0127878, -0.01589  ,  0.0253053]]), array([[ 0.998448  , -0.0164577 , -0.0259769 ,  0.0464323 ],
       [ 0.998532  , -0.0165238 , -0.0267171 ,  0.0441173 ],
       [ 0.998606  , -0.016166  , -0.0275929 ,  0.0419842 ],
       ...,
       [ 0.99913   , -0.02507   , -0.00925378,  0.0320235 ],
       

[list([3973, 64, 3970, 188, 254, 117, 3970, 100, 326, 16, 61, 76, 3970, 61, 245, 117, 81, 2, 179, 3970, 188, 254, 36, 11, 1012, 20, 1914, 239, 99, 7, 1123, 2, 19, 253, 72, 875, 11, 67, 3974, 93, 7, 3970, 40, 306, 461, 78, 120, 15, 3972])
 list([3973, 88, 3972])
 list([3973, 801, 89, 51, 219, 1616, 67, 239, 2245, 136, 15, 51, 7, 469, 117, 81, 2, 179, 3972])
 list([3973, 93, 2, 19, 2, 231, 223, 99, 100, 136, 153, 132, 93, 2, 262, 115, 117, 81, 2, 179, 3970, 1036, 3972])
 list([3973, 172, 3970, 188, 350, 3972])
 list([3973, 3970, 228, 36, 2246, 49, 16, 985, 3972])
 list([3973, 99, 100, 1882, 3972]) list([3973, 27, 3972])
 list([3973, 64, 3970, 3974, 3970, 11, 1024, 99, 179, 2, 19, 81, 3970, 1070, 985, 23, 851, 861, 81, 3970, 3972])
 list([3973, 2, 19, 566, 20, 23, 72, 2247, 105, 32, 187, 270, 2, 199, 239, 99, 115, 27, 3972])
 list([3973, 93, 3972])
 list([3973, 93, 3970, 66, 3970, 23, 814, 186, 164, 78, 2248, 3970, 85, 2, 81, 179, 51, 114, 330, 23, 814, 174, 3970, 2, 3972])
 list([3973, 2

[array([[ 0.988394 , -0.127391 , -0.0214585, -0.079932 ],
       [ 0.988634 , -0.12602  , -0.0260471, -0.0777363],
       [ 0.988804 , -0.124883 , -0.0304698, -0.075784 ],
       ...,
       [ 0.991624 ,  0.0306091,  0.0325986,  0.121174 ],
       [ 0.991512 ,  0.0316093,  0.0325138,  0.121848 ],
       [ 0.991398 ,  0.0325308,  0.0324535,  0.122551 ]]), array([[ 9.91452e-01, -5.43839e-04,  3.72205e-02,  1.25051e-01],
       [ 9.91463e-01,  2.46275e-04,  3.66563e-02,  1.25128e-01],
       [ 9.91481e-01,  1.08278e-03,  3.62475e-02,  1.25104e-01],
       ...,
       [ 9.82946e-01, -8.97368e-02, -1.60492e-01, -2.42635e-03],
       [ 9.82746e-01, -8.73987e-02, -1.63004e-01, -8.77719e-04],
       [ 9.82609e-01, -8.48450e-02, -1.65169e-01,  5.19713e-04]]), array([[ 0.993775  ,  0.0116686 , -0.0953626 ,  0.0563928 ],
       [ 0.993775  ,  0.0118417 , -0.0952728 ,  0.0565166 ],
       [ 0.993776  ,  0.0120667 , -0.0951356 ,  0.0566813 ],
       [ 0.993777  ,  0.0123272 , -0.0949708 ,  0.056890

[array([[ 0.999093 , -0.0125545, -0.0160133,  0.0374154],
       [ 0.999066 , -0.0124584, -0.0158437,  0.038212 ],
       [ 0.999037 , -0.0122703, -0.0156926,  0.0391036],
       ...,
       [ 0.994762 , -0.0356318, -0.0812014, -0.0508509],
       [ 0.994742 , -0.0341014, -0.0822874, -0.0505441],
       [ 0.994786 , -0.0322064, -0.0830732, -0.0496269]]), array([[ 0.999137  ,  0.018894  , -0.002727  , -0.0368846 ],
       [ 0.999141  ,  0.018846  , -0.00280039, -0.0368004 ],
       [ 0.999145  ,  0.0187909 , -0.00283288, -0.0367134 ],
       ...,
       [ 0.998389  ,  0.0105508 ,  0.0362142 , -0.0423963 ],
       [ 0.998391  ,  0.0106235 ,  0.0362058 , -0.0423174 ],
       [ 0.998395  ,  0.0106897 ,  0.0362212 , -0.0422084 ]]), array([[ 0.999399  ,  0.0112663 ,  0.0278038 , -0.0173693 ],
       [ 0.999466  ,  0.011646  ,  0.0265944 , -0.0150009 ],
       [ 0.999532  ,  0.0119659 ,  0.0252198 , -0.0124986 ],
       ...,
       [ 0.99876   , -0.0119848 , -0.0482433 ,  0.00260209],
       

[list([3973, 0, 3972]) list([3973, 3970, 93, 1, 188, 65, 3974, 3972])
 list([3973, 9, 72, 3972]) list([3973, 257, 3972])
 list([3973, 27, 125, 51, 7, 245, 3972]) list([3973, 64, 3972])
 list([3973, 11, 191, 2540, 20, 15, 39, 2541, 3974, 3974, 3970, 12, 899, 3970, 2543, 88, 2544, 168, 76, 64, 3972])
 list([3973, 63, 63, 63, 9, 365, 3972])
 list([3973, 3970, 1444, 9, 3970, 19, 32, 3970, 270, 35, 65, 1550, 77, 2545, 6, 7, 3970, 3970, 3974, 3970, 3970, 194, 20, 3970, 20, 3970, 1695, 116, 3972])
 list([3973, 3970, 20, 3972])
 list([3973, 27, 2547, 93, 39, 188, 35, 168, 156, 63, 63, 63, 63, 63, 3972])
 list([3973, 156, 64, 63, 63, 63, 63, 1504, 63, 1504, 3972])
 list([3973, 27, 2548, 2549, 2550, 3972])
 list([3973, 2551, 26, 1680, 9, 3974, 3970, 156, 93, 3970, 100, 721, 531, 3972])
 list([3973, 99, 100, 491, 51, 7, 525, 3972])
 list([3973, 29, 7, 1150, 51, 365, 2552, 3970, 3972])
 list([3973, 63, 63, 9, 63, 63, 63, 63, 3970, 539, 168, 3970, 21, 681, 3972])
 list([3973, 1242, 681, 3972])
 lis

[list([3973, 0, 9, 3970, 75, 7, 144, 753, 1918, 168, 3972])
 list([3973, 9, 499, 153, 71, 2, 1424, 3970, 7, 144, 2587, 531, 3970, 157, 255, 29, 365, 2588, 731, 3972])
 list([3973, 63, 63, 63, 63, 3970, 2589, 168, 63, 64, 2, 19, 1, 3970, 188, 125, 36, 2, 19, 119, 51, 1, 2587, 3970, 188, 32, 1, 3970, 1595, 279, 99, 303, 7, 2590, 35, 3970, 470, 2591, 32, 2, 19, 124, 199, 125, 36, 56, 88, 23, 2422, 88, 23, 2592, 88, 546, 2, 19, 119, 51, 212, 3972])
 list([3973, 1712, 330, 124, 747, 29, 22, 23, 340, 994, 134, 3972])
 list([3973, 168, 64, 3970, 188, 11, 254, 1, 2, 435, 36, 219, 2593, 3972])
 list([3973, 63, 63, 3974, 2594, 71, 119, 51, 58, 23, 2464, 2595, 51, 58, 23, 2596, 3972])
 list([3973, 3970, 216, 2597, 32, 131, 3974, 9, 3970, 3974, 3970, 11, 2600, 131, 3972])
 list([3973, 3970, 180, 61, 2601, 51, 2, 2602, 16, 219, 2281, 2593, 35, 7, 2593, 188, 2603, 134, 168, 27, 58, 124, 262, 3974, 81, 200, 1209, 2604, 32, 200, 1209, 2605, 32, 200, 1209, 2606, 124, 262, 566, 757, 214, 985, 3972])
 li

[array([[ 9.99319e-01, -2.71239e-03,  2.74006e-02, -2.45497e-02],
       [ 9.99391e-01, -2.55586e-03,  2.65169e-02, -2.25422e-02],
       [ 9.99459e-01, -2.40983e-03,  2.56170e-02, -2.04940e-02],
       [ 9.99519e-01, -2.28798e-03,  2.47290e-02, -1.85529e-02],
       [ 9.99574e-01, -2.15338e-03,  2.38785e-02, -1.66291e-02],
       [ 9.99627e-01, -1.92055e-03,  2.30653e-02, -1.44874e-02],
       [ 9.99673e-01, -1.80805e-03,  2.22686e-02, -1.24026e-02],
       [ 9.99714e-01, -1.76598e-03,  2.15481e-02, -1.01742e-02],
       [ 9.99749e-01, -1.84581e-03,  2.09032e-02, -7.82471e-03],
       [ 9.99777e-01, -2.00000e-03,  2.03437e-02, -5.34524e-03],
       [ 9.99797e-01, -2.25369e-03,  1.98618e-02, -2.67319e-03],
       [ 9.99808e-01, -2.49333e-03,  1.94375e-02,  5.35605e-05],
       [ 9.99810e-01, -2.77562e-03,  1.90748e-02,  2.80124e-03],
       [ 9.99803e-01, -3.09947e-03,  1.87815e-02,  5.65141e-03],
       [ 9.99792e-01, -3.41145e-03,  1.83730e-02,  8.14512e-03],
       [ 9.99774e-01, -3

[array([[ 0.999111  , -0.00594123,  0.0412929 ,  0.00603109],
       [ 0.999138  , -0.00553412,  0.0408597 ,  0.0048884 ],
       [ 0.999145  , -0.00554672,  0.0406412 ,  0.00509883],
       ...,
       [ 0.973705  , -0.0372377 , -0.0244909 , -0.223411  ],
       [ 0.97309   , -0.0373033 , -0.0233509 , -0.226185  ],
       [ 0.972473  , -0.0369973 , -0.0220015 , -0.229005  ]]), array([[ 0.999694  ,  0.00841188,  0.0227161 , -0.00507117],
       [ 0.999706  ,  0.00892385,  0.022083  , -0.00460474],
       [ 0.999716  ,  0.00923247,  0.021642  , -0.00374563],
       ...,
       [ 0.994984  , -0.0812448 ,  0.0152996 , -0.0563281 ],
       [ 0.994794  , -0.083646  ,  0.0159293 , -0.0559915 ],
       [ 0.994636  , -0.0854979 ,  0.0163411 , -0.055884  ]]), array([[ 9.96988e-01, -4.69736e-02,  4.18112e-02, -4.53947e-02],
       [ 9.97309e-01, -4.53451e-02,  4.20066e-02, -3.94276e-02],
       [ 9.97572e-01, -4.39204e-02,  4.24605e-02, -3.34251e-02],
       [ 9.97806e-01, -4.23084e-02,  4.27746

[list([3973, 3975, 3972])
 list([3973, 64, 117, 27, 64, 117, 81, 2, 179, 3972])
 list([3973, 27, 3972])
 list([3973, 124, 318, 851, 414, 16, 78, 1660, 257, 188, 22, 5, 105, 134, 3970, 3970, 11, 2747, 46, 3970, 191, 47, 3970, 7, 211, 3970, 85, 3972])
 list([3973, 27, 3972]) list([3973, 27, 3972]) list([3973, 27, 3972])
 list([3973, 3970, 145, 12, 39, 35, 524, 100, 132, 122, 3974, 124, 262, 47, 39, 88, 88, 3972])
 list([3973, 3970, 85, 3974, 75, 2, 77, 75, 2, 3974, 524, 194, 78, 1701, 72, 88, 2, 19, 3972])
 list([3973, 27, 27, 27, 3972])
 list([3973, 27, 3970, 77, 20, 1271, 124, 318, 851, 414, 3970, 3970, 85, 3970, 3970, 177, 75, 99, 259, 3972])
 list([3973, 2748, 49, 365, 2749, 168, 3972])
 list([3973, 27, 3970, 179, 131, 177, 920, 3972])
 list([3973, 27, 27, 27, 3970, 447, 3970, 262, 3974, 3970, 58, 99, 47, 23, 139, 2750, 474, 134, 27, 27, 524, 100, 122, 3972])
 list([3973, 3970, 26, 265, 164, 3974, 93, 851, 414, 124, 265, 20, 134, 27, 3972])
 list([3973, 32, 20, 2117, 3972])
 list([39

[list([3973, 499, 153, 152, 694, 458, 3970, 534, 2772, 49, 3972])
 list([3973, 2773, 458, 39, 124, 71, 458, 2, 578, 152, 32, 75, 78, 51, 35, 458, 102, 49, 1886, 230, 3970, 3970, 11, 128, 20, 3970, 72, 40, 32, 55, 3972])
 list([3973, 384, 133, 3974, 3974, 3972]) list([3973, 3970, 266, 3972])
 list([3973, 156, 93, 93, 102, 49, 71, 2, 239, 15, 1518, 2230, 3972])
 list([3973, 151, 3970, 133, 93, 3970, 3970, 239, 99, 100, 2022, 29, 230, 3970, 77, 224, 23, 395, 51, 252, 99, 81, 3972])
 list([3973, 9, 72, 93, 3970, 224, 40, 384, 224, 23, 395, 51, 2774, 3972])
 list([3973, 257, 9, 1575, 27, 3970, 456, 124, 3970, 75, 23, 611, 365, 883, 3972])
 list([3973, 93, 3970, 85, 2, 19, 2774, 50, 202, 7, 297, 32, 50, 7, 2775, 35, 131, 2776, 9, 365, 883, 32, 187, 2, 480, 364, 7, 611, 3970, 3974, 963, 3972])
 list([3973, 64, 456, 124, 661, 75, 23, 2777, 384, 3970, 249, 1452, 99, 1308, 16, 3972])
 list([3973, 458, 102, 49, 1886, 3972])
 list([3973, 95, 124, 180, 23, 2778, 93, 124, 3970, 75, 23, 2778, 64, 124

[list([3973, 9, 495, 27, 3970, 168, 3970, 222, 252, 76, 224, 779, 153, 61, 146, 3970, 168, 3972])
 list([3973, 168, 64, 3970, 1471, 7, 7, 2814, 2815, 61, 105, 3970, 76, 3970, 3970, 12, 2816, 2486, 2397, 554, 2817, 1061, 2324, 164, 1988, 555, 3970, 758, 12, 265, 168, 12, 265, 2818, 99, 7, 2814, 2819, 421, 139, 51, 23, 3970, 2821, 29, 39, 134, 76, 134, 26, 600, 249, 230, 2, 19, 117, 3970, 58, 64, 3970, 3970, 276, 99, 3970, 3974, 3970, 3970, 276, 99, 194, 600, 2822, 35, 188, 76, 168, 134, 3970, 75, 896, 7, 2814, 2, 19, 7, 7, 7, 2823, 2815, 164, 61, 64, 20, 168, 3970, 168, 3970, 3970, 3970, 3972])
 list([3973, 3970, 63, 63, 63, 63, 63, 63, 63, 3970, 23, 2247, 1115, 2, 19, 124, 480, 76, 3970, 3970, 3974, 3970, 133, 99, 283, 2, 2, 2260, 330, 93, 81, 20, 174, 230, 3970, 225, 365, 211, 21, 2, 19, 95, 3970, 191, 81, 212, 81, 20, 174, 64, 3970, 2824, 3970, 12, 326, 78, 2825, 164, 61, 146, 32, 3970, 3970, 265, 20, 2486, 32, 3970, 3970, 276, 99, 2826, 365, 1209, 2827, 532, 244, 230, 51, 155, 3970,

[list([3973, 196, 3970, 27, 63, 133, 133, 462, 3972])
 list([3973, 27, 27, 452, 117, 21, 20, 3972])
 list([3973, 851, 3970, 27, 27, 27, 27, 258, 3972])
 list([3973, 93, 27, 63, 3970, 180, 365, 795, 411, 3972])
 list([3973, 27, 63, 20, 188, 153, 153, 3970, 27, 908, 2, 19, 3972])
 list([3973, 27, 63, 221, 311, 969, 124, 180, 1627, 32, 187, 233, 32, 3970, 318, 15, 164, 23, 1284, 3972])
 list([3973, 27, 63, 27, 3970, 3974, 3970, 3970, 222, 282, 2, 1352, 2, 222, 194, 409, 174, 3972])
 list([3973, 122, 122, 122, 27, 3972]) list([3973, 9, 5, 122, 3972])
 list([3973, 27, 27, 63, 3970, 168, 3972])
 list([3973, 76, 3974, 365, 795, 95, 3970, 204, 1882, 95, 3974, 3970, 3970, 76, 3970, 2935, 7, 1240, 3972])
 list([3973, 27, 3972]) list([3973, 27, 3972])
 list([3973, 27, 3970, 19, 117, 3970, 496, 3972])
 list([3973, 27, 27, 3972]) list([3973, 9, 5, 122, 122, 3972])
 list([3973, 27, 32, 153, 23, 3970, 3970, 72, 133, 27, 3972])
 list([3973, 9, 27, 51, 155, 50, 409, 3970, 2937, 27, 27, 3972])
 list([39

[list([3973, 3970, 748, 1, 71, 2, 3972])
 list([3973, 9, 117, 21, 20, 47, 3974, 47, 153, 248, 114, 248, 114, 1968, 1968, 3972])
 list([3973, 3970, 1789, 153, 16, 20, 3974, 3974, 3972])
 list([3973, 3970, 1634, 3972]) list([3973, 2, 75, 26, 3972])
 list([3973, 9, 1740, 2, 64, 249, 3970, 3970, 515, 99, 47, 219, 283, 20, 524, 124, 47, 458, 3972])
 list([3973, 156, 3970, 47, 283, 20, 156, 3972]) list([3973, 258, 3972])
 list([3973, 168, 3972]) list([3973, 3970, 495, 3972])
 list([3973, 495, 72, 249, 164, 168, 3972])
 list([3973, 495, 216, 249, 164, 7, 2976, 3972])
 list([3973, 1740, 2, 164, 3972])
 list([3973, 50, 7, 779, 252, 35, 3970, 452, 318, 322, 61, 134, 3972])
 list([3973, 1175, 3972])
 list([3973, 3970, 3970, 72, 179, 35, 3970, 114, 3972])
 list([3973, 168, 3972]) list([3973, 93, 7, 211, 21, 3972])
 list([3973, 3970, 179, 35, 3970, 89, 51, 23, 119, 51, 2, 1019, 58, 3970, 225, 249, 89, 51, 23, 2898, 3970, 328, 168, 3972])
 list([3973, 63, 3970, 2977, 3972])
 list([3973, 9, 63, 63, 4

[list([3973, 3090, 1, 2, 326, 3972])
 list([3973, 258, 3970, 222, 1993, 3972])
 list([3973, 139, 1338, 27, 3972]) list([3973, 117, 81, 2, 276, 3972])
 list([3973, 23, 3091, 3972])
 list([3973, 21, 35, 117, 2, 2162, 20, 911, 3972]) list([3973, 9, 3972])
 list([3973, 258, 3972]) list([3973, 258, 3972])
 list([3973, 32, 46, 2, 239, 99, 3974, 1313, 202, 21, 221, 191, 100, 29, 65, 1657, 3972])
 list([3973, 27, 3972]) list([3973, 29, 65, 1104, 3972])
 list([3973, 27, 27, 32, 2, 2162, 20, 23, 3091, 71, 2, 191, 100, 29, 65, 3092, 3972])
 list([3973, 76, 174, 117, 81, 2, 179, 3970, 524, 115, 3972])
 list([3973, 258, 3972]) list([3973, 280, 3972])
 list([3973, 3970, 250, 280, 280, 524, 3970, 115, 384, 3972])
 list([3973, 75, 2, 77, 83, 2935, 164, 3972])
 list([3973, 75, 2, 77, 83, 1006, 75, 2, 77, 3093, 184, 2, 100, 1012, 3094, 32, 740, 3095, 490, 3972])
 list([3973, 384, 3972]) list([3973, 258, 258, 174, 3972])
 list([3973, 32, 2, 19, 2, 19, 117, 3970, 100, 326, 3970, 85, 2, 3970, 920, 546, 135

[list([3973, 196, 3970, 3972]) list([3973, 196, 3970, 20, 1789, 3972])
 list([3973, 3970, 1789, 151, 3970, 239, 151, 71, 2, 1206, 152, 3972])
 list([3973, 3970, 122, 3970, 122, 35, 3970, 5, 35, 3970, 84, 168, 3972])
 list([3973, 3970, 188, 11, 254, 76, 75, 2, 374, 75, 2, 374, 3970, 1913, 3972])
 list([3973, 27, 27, 27, 27, 27, 27, 27, 27, 3970, 3970, 365, 366, 3970, 27, 3972])
 list([3973, 27, 27, 58, 221, 188, 221, 221, 250, 35, 221, 188, 239, 969, 99, 283, 168, 2, 19, 3970, 233, 3084, 168, 58, 2, 19, 99, 58, 942, 36, 2, 19, 7, 7, 3111, 168, 3972])
 list([3973, 134, 76, 221, 250, 35, 3970, 661, 100, 39, 58, 3970, 781, 58, 3972])
 list([3973, 219, 557, 88, 64, 32, 187, 221, 3970, 499, 319, 7, 309, 644, 168, 3972])
 list([3973, 32, 76, 3970, 188, 11, 76, 254, 58, 3972])
 list([3973, 230, 58, 3970, 179, 23, 7, 211, 21, 58, 3972])
 list([3973, 3970, 3974, 99, 99, 802, 1001, 32, 3112, 1523, 2, 32, 114, 2, 19, 230, 3970, 23, 133, 814, 51, 1825, 2, 19, 3970, 2, 19, 58, 124, 47, 124, 47, 8, 3

[list([3973, 64, 436, 2, 248, 114, 46, 2, 199, 153, 7, 644, 51, 7, 3138, 3972])
 list([3973, 258, 3972])
 list([3973, 174, 32, 32, 65, 3975, 188, 39, 16, 2, 3972])
 list([3973, 436, 3139, 3972])
 list([3973, 156, 32, 2, 199, 39, 164, 7, 309, 716, 7, 3972])
 list([3973, 174, 3972])
 list([3973, 258, 81, 2, 75, 546, 436, 2, 58, 115, 2538, 656, 35, 2, 3012, 32, 295, 98, 100, 72, 3140, 3972])
 list([3973, 174, 3972]) list([3973, 64, 411, 3972])
 list([3973, 554, 156, 3972]) list([3973, 3970, 32, 2, 295, 3972])
 list([3973, 1119, 7, 644, 174, 150, 168, 124, 75, 2, 153, 3970, 350, 103, 2283, 3970, 13, 7, 76, 13, 7, 1099, 35, 188, 3141, 3972])
 list([3973, 168, 81, 2, 3972])
 list([3973, 75, 277, 99, 115, 36, 35, 546, 850, 280, 280, 2, 3970, 2081, 15, 99, 7, 3142, 187, 88, 76, 32, 1667, 36, 20, 3972])
 list([3973, 27, 20, 188, 2283, 20, 188, 7, 7, 3143, 153, 7, 574, 7, 702, 1423, 557, 3972])
 list([3973, 174, 258, 3972]) list([3973, 174, 3972])
 list([3973, 156, 3972]) list([3973, 156, 3972])

[list([3973, 196, 27, 3970, 239, 3970, 239, 3970, 132, 3170, 2, 19, 3970, 3970, 137, 3970, 137, 72, 498, 2, 19, 3970, 137, 72, 58, 72, 168, 3972])
 list([3973, 72, 3970, 878, 58, 195, 58, 760, 49, 3970, 294, 58, 3970, 3970, 3970, 58, 3970, 191, 2503, 20, 58, 3970, 11, 278, 58, 3970, 259, 23, 139, 3171, 3970, 259, 499, 23, 139, 139, 1574, 27, 27, 27, 168, 3972])
 list([3973, 63, 3970, 26, 239, 1790, 153, 2, 470, 63, 63, 63, 63, 3970, 151, 3972])
 list([3973, 3970, 27, 27, 27, 27, 27, 27, 27, 27, 23, 59, 644, 2949, 27, 35, 524, 100, 72, 133, 71, 2, 191, 58, 811, 3172, 32, 58, 1120, 152, 99, 58, 3970, 3970, 1732, 2, 19, 637, 2, 47, 15, 2914, 352, 3970, 77, 23, 969, 350, 177, 524, 81, 3970, 3970, 1014, 32, 187, 3972])
 list([3973, 3970, 781, 3972])
 list([3973, 27, 134, 325, 29, 3970, 225, 325, 29, 3970, 470, 3972])
 list([3973, 27, 257, 470, 58, 851, 105, 2, 199, 58, 3176, 3970, 188, 58, 3177, 32, 187, 7, 3178, 188, 58, 151, 1775, 656, 27, 257, 470, 117, 117, 23, 3179, 325, 117, 23, 3179,

[list([3973, 156, 257, 3970, 3972]) list([3973, 174, 452, 27, 3970, 3972])
 list([3973, 9, 365, 3970, 3972]) list([3973, 3970, 858, 3972])
 list([3973, 222, 2, 81, 61, 83, 222, 2, 1481, 279, 99, 81, 61, 3972])
 list([3973, 9, 365, 3970, 3972]) list([3973, 3974, 3972])
 list([3973, 3970, 216, 1347, 216, 216, 1347, 3972])
 list([3973, 27, 3970, 85, 3970, 58, 687, 134, 61, 21, 687, 61, 21, 72, 687, 61, 21, 3972])
 list([3973, 63, 3970, 3974, 3972])
 list([3973, 3970, 3970, 115, 3970, 81, 3970, 270, 3970, 72, 3970, 115, 3970, 58, 61, 1019, 3972])
 list([3973, 151, 270, 3970, 63, 920, 99, 194, 3244, 58, 3972])
 list([3973, 151, 93, 2, 19, 3970, 135, 500, 35, 2, 2421, 114, 35, 3970, 3970, 135, 500, 134, 3970, 85, 2, 19, 7, 1019, 21, 23, 139, 1050, 3970, 11, 496, 58, 3970, 2, 23, 3245, 88, 212, 3970, 2, 3972])
 list([3973, 27, 3972])
 list([3973, 3970, 763, 146, 51, 65, 3246, 1164, 3970, 85, 3970, 85, 3970, 63, 920, 99, 194, 3244, 88, 277, 3970, 11, 496, 3970, 3970, 58, 7, 1019, 3970, 85, 2, 

[array([[ 0.99122   , -0.0369983 ,  0.00263872,  0.126912  ],
       [ 0.991281  , -0.0369878 ,  0.00238599,  0.126441  ],
       [ 0.991302  , -0.0371138 ,  0.00224159,  0.126242  ],
       ...,
       [ 0.998488  ,  0.00800217,  0.0488901 ,  0.0238188 ],
       [ 0.998535  ,  0.00757327,  0.0483462 ,  0.0230902 ],
       [ 0.998599  ,  0.0069746 ,  0.0474411 ,  0.0223756 ]]), array([[ 0.99932   , -0.0224495 , -0.00163798,  0.0292144 ],
       [ 0.999311  , -0.0224909 , -0.00171415,  0.0294711 ],
       [ 0.999299  , -0.0225406 , -0.00189098,  0.0298339 ],
       ...,
       [ 0.981099  ,  0.00624886,  0.181356  , -0.067206  ],
       [ 0.981272  ,  0.00594292,  0.180488  , -0.0670382 ],
       [ 0.981437  ,  0.00563427,  0.179668  , -0.0668579 ]]), array([[ 9.79829e-01, -4.13593e-04,  1.87481e-01, -6.91753e-02],
       [ 9.79687e-01, -2.32115e-04,  1.88137e-01, -6.94125e-02],
       [ 9.79531e-01, -2.33224e-04,  1.88865e-01, -6.96395e-02],
       ...,
       [ 9.97761e-01, -6.59848e-

[list([3973, 156, 64, 76, 222, 2, 117, 81, 2, 179, 117, 81, 2, 179, 36, 117, 3970, 250, 99, 2, 3381, 3970, 85, 3970, 49, 164, 20, 3972])
 list([3973, 222, 2, 47, 164, 20, 3972]) list([3973, 76, 3972])
 list([3973, 76, 3972])
 list([3973, 27, 3970, 19, 3970, 23, 3382, 134, 2, 19, 3970, 103, 727, 23, 816, 51, 23, 3383, 50, 39, 11, 3384, 3972])
 list([3973, 3970, 85, 3970, 3970, 77, 382, 29, 297, 507, 860, 71, 3970, 3972])
 list([3973, 3970, 3972])
 list([3973, 134, 2, 19, 76, 2, 19, 3974, 76, 2, 19, 2, 920, 7, 757, 330, 93, 3974, 11, 58, 3970, 503, 32, 3385, 3974, 3972])
 list([3973, 27, 134, 2, 19, 3970, 76, 2, 3970, 75, 99, 81, 249, 2, 11, 75, 99, 100, 39, 3974, 3970, 100, 3156, 882, 2, 184, 75, 23, 1875, 153, 3970, 253, 99, 523, 36, 3972])
 list([3973, 3970, 75, 23, 1875, 2, 3970, 509, 75, 99, 115, 277, 3974, 3970, 85, 11, 179, 36, 20, 3974, 330, 23, 816, 3383, 32, 179, 36, 35, 3387, 35, 1423, 800, 499, 153, 3972])
 list([3973, 134, 3970, 3972])
 list([3973, 93, 2, 19, 61, 21, 280, 20

[array([[ 0.999661  ,  0.00369607, -0.00674049,  0.0248819 ],
       [ 0.999602  ,  0.00373388, -0.00637344,  0.0272352 ],
       [ 0.999528  ,  0.00396116, -0.00584869,  0.0299009 ],
       ...,
       [ 0.999822  ,  0.00479925,  0.00516317, -0.017525  ],
       [ 0.999801  ,  0.00541783,  0.00369196, -0.0188626 ],
       [ 0.999769  ,  0.00601327,  0.0022696 , -0.0205271 ]]), array([[ 9.99708e-01,  1.23837e-02,  1.27942e-03, -2.07136e-02],
       [ 9.99711e-01,  1.22646e-02,  1.11288e-03, -2.06369e-02],
       [ 9.99715e-01,  1.21656e-02,  9.00254e-04, -2.05346e-02],
       ...,
       [ 9.97526e-01, -1.25235e-02,  1.90384e-02, -6.65040e-02],
       [ 9.97580e-01, -1.23948e-02,  1.88298e-02, -6.57739e-02],
       [ 9.97617e-01, -1.23294e-02,  1.85417e-02, -6.53094e-02]]), array([[ 9.97973e-01, -1.94150e-03,  4.66257e-02,  4.32625e-02],
       [ 9.98133e-01, -2.01549e-03,  4.49272e-02,  4.13341e-02],
       [ 9.98273e-01, -1.90673e-03,  4.30752e-02,  3.98975e-02],
       [ 9.98380e-01

[list([3973, 9, 766, 4, 103, 61, 297, 3970, 205, 4, 103, 7, 129, 3732, 3972])
 list([3973, 9, 4, 103, 7, 2243, 153, 7, 1342, 32, 50, 7, 50, 7, 3733, 61, 21, 208, 3970, 34, 20, 126, 3972])
 list([3973, 9, 93, 3974, 81, 46, 81, 2, 464, 578, 3972])
 list([3973, 117, 36, 9, 9, 4, 103, 409, 409, 3735, 153, 35, 3296, 3970, 578, 39, 3970, 3736, 9, 3970, 11, 208, 3970, 34, 20, 3972])
 list([3973, 9, 133, 75, 23, 1450, 3972])
 list([3973, 32, 3970, 7, 476, 2298, 365, 883, 32, 4, 103, 7, 2539, 9, 365, 883, 3970, 145, 180, 23, 180, 23, 23, 166, 64, 746, 288, 61, 21, 208, 3970, 145, 72, 12, 29, 23, 3737, 90, 288, 3970, 64, 1841, 3970, 3738, 3972])
 list([3973, 9, 134, 3970, 34, 121, 3970, 3970, 72, 81, 3970, 2, 58, 121, 3972])
 list([3973, 117, 211, 3972])
 list([3973, 9, 3970, 452, 3970, 11, 29, 65, 3739, 3970, 452, 3970, 452, 3970, 3970, 72, 58, 20, 3970, 85, 3970, 219, 2526, 61, 21, 23, 844, 211, 3970, 326, 239, 99, 23, 121, 90, 75, 124, 157, 12, 99, 146, 288, 61, 3974, 3972])
 list([3973, 9, 3

[list([3973, 196, 3970, 1105, 3972])
 list([3973, 3974, 3974, 117, 81, 2, 85, 117, 81, 3970, 179, 3972])
 list([3973, 3974, 63, 20, 3970, 515, 75, 2, 75, 2, 3268, 2201, 3972])
 list([3973, 388, 3972]) list([3973, 93, 3970, 85, 3972])
 list([3973, 3970, 26, 133, 3972]) list([3973, 168, 3972])
 list([3973, 3970, 85, 3970, 835, 3972])
 list([3973, 156, 58, 3974, 388, 3972])
 list([3973, 394, 98, 1355, 2018, 2291, 330, 23, 1019, 164, 23, 840, 1657, 3970, 85, 61, 21, 61, 21, 3783, 3972])
 list([3973, 122, 1019, 3970, 394, 250, 35, 3972])
 list([3973, 63, 3970, 26, 117, 124, 180, 288, 188, 188, 188, 1813, 280, 3974, 280, 81, 2, 920, 99, 1590, 20, 3972])
 list([3973, 600, 1347, 3970, 334, 164, 3970, 334, 164, 2018, 2291, 3970, 58, 840, 2201, 23, 3784, 61, 21, 3783, 3972])
 list([3973, 501, 71, 2, 1151, 61, 21, 3974, 61, 21, 3783, 3972])
 list([3973, 3970, 3785, 61, 1657, 21, 1328, 1, 249, 222, 2, 372, 153, 1354, 61, 21, 61, 21, 3974, 58, 61, 21, 3783, 3972])
 list([3973, 63, 3970, 2905, 35, 3

[list([3973, 196, 952, 196, 470, 3970, 20, 239, 3972])
 list([3973, 27, 452, 196, 3974, 470, 81, 2, 464, 47, 164, 23, 47, 164, 23, 529, 88, 212, 470, 3972])
 list([3973, 151, 151, 3972])
 list([3973, 27, 452, 3974, 3974, 3974, 288, 2, 81, 11, 3974, 436, 3970, 1812, 702, 699, 3970, 77, 58, 3970, 480, 1481, 279, 319, 32, 3970, 11, 435, 2, 19, 124, 262, 3974, 3972])
 list([3973, 3970, 3974, 1481, 3972]) list([3973, 27, 452, 470, 3972])
 list([3973, 9, 1247, 27, 289, 146, 63, 3970, 11, 3901, 134, 27, 47, 153, 3974, 3972])
 list([3973, 27, 27, 3972]) list([3973, 9, 1247, 3972])
 list([3973, 27, 27, 3972])
 list([3973, 3974, 3974, 117, 2, 496, 470, 3972])
 list([3973, 27, 27, 3972])
 list([3973, 940, 699, 3974, 279, 3903, 940, 699, 214, 2, 1640, 3972])
 list([3973, 3970, 781, 3974, 515, 2, 496, 279, 3903, 940, 699, 230, 3970, 3974, 3970, 47, 99, 23, 2207, 174, 458, 3970, 465, 99, 7, 92, 93, 3974, 93, 3970, 11, 465, 16, 7, 92, 283, 3972])
 list([3973, 93, 3970, 2487, 2, 319, 35, 940, 699, 35,

In [10]:
#%%
# this code is to make dataset for seq2seq baseline model
the_path = '/afs/inf.ed.ac.uk/group/cstr/projects/galatea/d02'
data_path_motion = the_path + '/Recordings_October_2014/DOF-hiroshi/'
data_path_text = the_path + '/Recordings_October_2014/Transcriptions/transcriptions_phrase_tables/'

paired_file_paths = pair_files(data_path_text, '.TABLE', data_path_motion,
                                '.qtn')




# make three kinds of dataset
make_dataset(seq2seq_preprocess, paired_file_paths, 'e',
                './data/extro_seq2seq_dataset')
print('extroverted data finished')
make_dataset(seq2seq_preprocess, paired_file_paths, 'i',
                './data/intro_seq2seq_dataset')
print('introverted data finished')
make_dataset(seq2seq_preprocess, paired_file_paths, 'n',
                './data/natural_seq2seq_dataset')
print('natural data finished')

22
22
50
50
46
46
47
47
34
34
56
56
50
50
49
49
48
48
33
33


KeyboardInterrupt: 

## Building the models

In [None]:
MAX_LENGTH = 100


USE_CUDA = True



In [None]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=1):
        super(EncoderRNN, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers)
        
    def forward(self, word_inputs, hidden):
        # Note: we run this all at once (over the whole input sequence)
        seq_len = len(word_inputs)
        embedded = self.embedding(word_inputs).view(seq_len, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def init_hidden(self):
        hidden = Variable(torch.zeros(self.n_layers, 1, self.hidden_size))
        if USE_CUDA: hidden = hidden.cuda()
        return hidden

In [None]:
class BahdanauAttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, n_layers=1, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        
        # Define parameters
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p
        self.max_length = max_length
        
        # Define layers
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.dropout = nn.Dropout(dropout_p)
        self.attn = GeneralAttn(hidden_size)
        self.gru = nn.GRU(hidden_size * 2, hidden_size, n_layers, dropout=dropout_p)
        self.out = nn.Linear(hidden_size, output_size)
    
    def forward(self, word_input, last_hidden, encoder_outputs):
        # Note that we will only be running forward for a single decoder time step, but will use all encoder outputs
        
        # Get the embedding of the current input word (last output word)
        word_embedded = self.embedding(word_input).view(1, 1, -1) # S=1 x B x N
        word_embedded = self.dropout(word_embedded)
        
        # Calculate attention weights and apply to encoder outputs
        attn_weights = self.attn(last_hidden[-1], encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) # B x 1 x N
        
        # Combine embedded input word and attended context, run through RNN
        rnn_input = torch.cat((word_embedded, context), 2)
        output, hidden = self.gru(rnn_input, last_hidden)
        
        # Final output layer
        output = output.squeeze(0) # B x N
        output = F.log_softmax(self.out(torch.cat((output, context), 1)))
        
        # Return final output, hidden state, and attention weights (for visualization)
        return output, hidden, attn_weights

In [None]:
class Attn(nn.Module):
    def __init__(self, method, hidden_size, max_length=MAX_LENGTH):
        super(Attn, self).__init__()
        
        self.method = method
        self.hidden_size = hidden_size
        
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)

        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.other = nn.Parameter(torch.FloatTensor(1, hidden_size))

    def forward(self, hidden, encoder_outputs):
        seq_len = len(encoder_outputs)

        # Create variable to store attention energies
        attn_energies = Variable(torch.zeros(seq_len)) # B x 1 x S
        if USE_CUDA: attn_energies = attn_energies.cuda()

        # Calculate energies for each encoder output
        for i in range(seq_len):
            attn_energies[i] = self.score(hidden, encoder_outputs[i])

        # Normalize energies to weights in range 0 to 1, resize to 1 x 1 x seq_len
        return F.softmax(attn_energies).unsqueeze(0).unsqueeze(0)
    
    def score(self, hidden, encoder_output):
        
        if self.method == 'dot':
            energy = hidden.dot(encoder_output)
            return energy
        
        elif self.method == 'general':
            energy = self.attn(encoder_output)
            energy = hidden.dot(energy)
            return energy
        
        elif self.method == 'concat':
            energy = self.attn(torch.cat((hidden, encoder_output), 1))
            energy = self.other.dot(energy)
            return energy

In [None]:
class AttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, hidden_size, output_size, n_layers=1, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        
        # Keep parameters for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p
        
        # Define layers
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size * 2, hidden_size, n_layers, dropout=dropout_p)
        self.out = nn.Linear(hidden_size * 2, output_size)
        
        # Choose attention model
        if attn_model != 'none':
            self.attn = Attn(attn_model, hidden_size)
    
    def forward(self, word_input, last_context, last_hidden, encoder_outputs):
        # Note: we run this one step at a time
        
        # Get the embedding of the current input word (last output word)
        word_embedded = self.embedding(word_input).view(1, 1, -1) # S=1 x B x N
        
        # Combine embedded input word and last context, run through RNN
        rnn_input = torch.cat((word_embedded, last_context.unsqueeze(0)), 2)
        rnn_output, hidden = self.gru(rnn_input, last_hidden)

        # Calculate attention from current RNN state and all encoder outputs; apply to encoder outputs
        attn_weights = self.attn(rnn_output.squeeze(0), encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) # B x 1 x N
        
        # Final output layer (next word prediction) using the RNN hidden state and context vector
        rnn_output = rnn_output.squeeze(0) # S=1 x B x N -> B x N
        context = context.squeeze(1)       # B x S=1 x N -> B x N
        output = F.log_softmax(self.out(torch.cat((rnn_output, context), 1)))
        
        # Return final output, hidden state, and attention weights (for visualization)
        return output, context, hidden, attn_weights

In [None]:
# encoder_test = EncoderRNN(10, 10, 2)
# decoder_test = AttnDecoderRNN('general', 10, 10, 2)
# print(encoder_test)
# print(decoder_test)

# encoder_hidden = encoder_test.init_hidden()
# word_input = Variable(torch.LongTensor([1, 2, 3]))
# if USE_CUDA:
#     encoder_test.cuda()
#     word_input = word_input.cuda()
# encoder_outputs, encoder_hidden = encoder_test(word_input, encoder_hidden)

# word_inputs = Variable(torch.LongTensor([1, 2, 3]))
# decoder_attns = torch.zeros(1, 3, 3)
# decoder_hidden = encoder_hidden
# decoder_context = Variable(torch.zeros(1, decoder_test.hidden_size))

# if USE_CUDA:
#     decoder_test.cuda()
#     word_inputs = word_inputs.cuda()
#     decoder_context = decoder_context.cuda()

# for i in range(3):
#     decoder_output, decoder_context, decoder_hidden, decoder_attn = decoder_test(word_inputs[i], decoder_context, decoder_hidden, encoder_outputs)
#     print(decoder_output.size(), decoder_hidden.size(), decoder_attn.size())
#     decoder_attns[0, i] = decoder_attn.squeeze(0).cpu().data


## Train

In [None]:
batch_size = 50

extro_data_train_path = './data/extro_seq2seq_dataset_train.npz'
extro_data_valid_path = './data/extro_seq2seq_dataset_valid.npz'
extro_data_test_path = './data/extro_seq2seq_dataset_test.npz'

intro_data_train_path = './data/intro_seq2seq_dataset_train.npz'
intro_data_valid_path = './data/intro_seq2seq_dataset_valid.npz'
intro_data_test_path = './data/intro_seq2seq_dataset_test.npz'

natural_data_train_path = './data/natural_seq2seq_dataset_train.npz'
natural_data_valid_path = './data/natural_seq2seq_dataset_valid.npz'
natural_data_test_path = './data/natural_seq2seq_dataset_test.npz'
# data_path = './data/extro_seq2seq_dataset.npz'
word2idx = vocab_to_int  # load word map
#
# # load and split dataset
# seq2seq_dataset = Seq2SeqDataset(data_path, word2idx)
# train_size = int(0.8 * len(seq2seq_dataset))
# valid_size = int(0.1 * len(seq2seq_dataset))
# test_size = len(seq2seq_dataset) - train_size - valid_size

# train_set, valid_set, test_set = random_split(
    # seq2seq_dataset, [train_size, valid_size, test_size])

train_set = Seq2SeqDataset(extro_data_train_path, word2idx)
valid_set = Seq2SeqDataset(extro_data_valid_path, word2idx)
test_set = Seq2SeqDataset(extro_data_test_path, word2idx)
# print(len(valid_set))
# print(test_set[0])

input_test_set = []
target_test_set = []
for i in range(0, len(test_set)):
    input_test_set.append(test_set[i][0])
    target_test_set.append(test_set[i][1])

np.savez("test_data.npz", input=input_test_set, target=target_test_set)


train_dataloader = DataLoader(train_set,
                              batch_size,
                              shuffle=True,
                              collate_fn=seq2seq_collate_fn)
valid_dataloader = DataLoader(valid_set,
                              batch_size,
                              shuffle=True,
                              collate_fn=seq2seq_collate_fn)
test_dataloader = DataLoader(test_set,
                             batch_size=1,
                             collate_fn=seq2seq_collate_fn)



In [None]:
teacher_forcing_ratio = 0.5
clip = 5.0

def train(input_variable, target_variable, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):

    # Zero gradients of both optimizers
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    loss = 0 # Added onto for each word

    # Get size of input and target sentences
    input_length = input_variable.size()[0]
    target_length = target_variable.size()[0]

    # Run words through encoder
    encoder_hidden = encoder.init_hidden()
    encoder_outputs, encoder_hidden = encoder(input_variable, encoder_hidden)
    
    # Prepare input and output variables
    decoder_input = Variable(torch.LongTensor([[SOS_token]]))
    decoder_context = Variable(torch.zeros(1, decoder.hidden_size))
    decoder_hidden = encoder_hidden # Use last hidden state from encoder to start decoder
    if USE_CUDA:
        decoder_input = decoder_input.cuda()
        decoder_context = decoder_context.cuda()

    # Choose whether to use teacher forcing
    use_teacher_forcing = random.random() < teacher_forcing_ratio
    if use_teacher_forcing:
        
        # Teacher forcing: Use the ground-truth target as the next input
        for di in range(target_length):
            decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_context, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target_variable[di])
            decoder_input = target_variable[di] # Next target is next input

    else:
        # Without teacher forcing: use network's own prediction as the next input
        for di in range(target_length):
            decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_context, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target_variable[di])
            
            # Get most likely word index (highest value) from output
            topv, topi = decoder_output.data.topk(1)
            ni = topi[0][0]
            
            decoder_input = Variable(torch.LongTensor([[ni]])) # Chosen word is next input
            if USE_CUDA: decoder_input = decoder_input.cuda()

            # Stop at end of sentence (not necessary when using known targets)
            if ni == EOS_token: break

    # Backpropagation
    loss.backward()
    torch.nn.utils.clip_grad_norm(encoder.parameters(), clip)
    torch.nn.utils.clip_grad_norm(decoder.parameters(), clip)
    encoder_optimizer.step()
    decoder_optimizer.step()
    
    return loss.data[0] / target_length

In [None]:
def as_minutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def time_since(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (as_minutes(s), as_minutes(rs))

In [None]:
attn_model = 'general'

hidden_size = 500
n_layers = 2
dropout_p = 0.05
embed_dim = 100

# Initialize models
encoder = EncoderRNN(embed_dim, hidden_size, n_layers)
decoder = AttnDecoderRNN(attn_model, hidden_size, embed_dim, n_layers, dropout_p=dropout_p)

# Move models to GPU
if USE_CUDA:
    encoder.cuda()
    decoder.cuda()

# Initialize optimizers and criterion
learning_rate = 0.0001
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
criterion = nn.NLLLoss()


In [None]:


# Configuring training
n_epochs = 50000
plot_every = 200
print_every = 1000

# Keep track of time elapsed and running averages
start = time.time()
plot_losses = []
print_loss_total = 0 # Reset every print_every
plot_loss_total = 0 # Reset every plot_every



In [None]:
for idx, (in_seq, tgt_seq, target, in_len, tgt_len) in enumerate(train_dataloader):

#         in_seq = in_seq.to(device)
#         tgt_seq = tgt_seq.to(device)
#         target = target.to(device)
#         in_len = in_len.to(device)
#         tgt_len = tgt_len.to(device)

In [None]:
# Begin!
for epoch in range(1, n_epochs + 1):
    
    # Get training data for this cycle
    training_pair = train_dataloader
    input_variable = training_pair[0]
    target_variable = training_pair[1]

    # Run the train function
    loss = train(input_variable, target_variable, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)

    # Keep track of loss
    print_loss_total += loss
    plot_loss_total += loss

    if epoch == 0: continue

    if epoch % print_every == 0:
        print_loss_avg = print_loss_total / print_every
        print_loss_total = 0
        print_summary = '%s (%d %d%%) %.4f' % (time_since(start, epoch / n_epochs), epoch, epoch / n_epochs * 100, print_loss_avg)
        print(print_summary)

    if epoch % plot_every == 0:
        plot_loss_avg = plot_loss_total / plot_every
        plot_losses.append(plot_loss_avg)
        plot_loss_total = 0