From d5eb36d5076ef0debd645002c1605564e231cec6 Mon Sep 17 00:00:00 2001
From: Tomasz Kornuta
Date: Fri, 3 May 2019 15:10:33 -0700
Subject: [PATCH 1/3] Reordered dimensions of hidden states passed between RNN
 components (RNN, AttDecGRU): batch first; C4 now works in DataParallel

---
 .../mnist/mnist_classification_softmax.yml    |  5 ---
 .../c4_enc_attndec_resnet152_ewm_cat_is.yml   | 12 ++++---
 ptp/components/models/attn_decoder_rnn.py     | 18 ++++++++---
 .../models/recurrent_neural_network.py        | 32 +++++++++++++++----
 ptp/components/models/sentence_embeddings.py  |  4 +--
 ptp/components/utils/word_mappings.py         |  5 +--
 6 files changed, 51 insertions(+), 25 deletions(-)

diff --git a/configs/mnist/mnist_classification_softmax.yml b/configs/mnist/mnist_classification_softmax.yml
index adf6da7..eaf749a 100644
--- a/configs/mnist/mnist_classification_softmax.yml
+++ b/configs/mnist/mnist_classification_softmax.yml
@@ -1,11 +1,6 @@
 # Load config defining MNIST problems for training, validation and testing.
 default_configs: mnist/default_mnist.yml
 
-training:
-  optimizer:
-    name: SGD
-
-
 pipeline:
   name: mnist_softmax_classifier
 
diff --git a/configs/vqa_med_2019/c4_classification/c4_enc_attndec_resnet152_ewm_cat_is.yml b/configs/vqa_med_2019/c4_classification/c4_enc_attndec_resnet152_ewm_cat_is.yml
index 6101f2e..a7804cf 100644
--- a/configs/vqa_med_2019/c4_classification/c4_enc_attndec_resnet152_ewm_cat_is.yml
+++ b/configs/vqa_med_2019/c4_classification/c4_enc_attndec_resnet152_ewm_cat_is.yml
@@ -49,7 +49,7 @@ pipeline:
       pretrained_embeddings_file: glove.6B.100d.txt
       data_folder: ~/data/vqa-med
       word_mappings_file: questions.all.word.mappings.csv
-      fixed_padding: 10
+      fixed_padding: 10 # The longest question!
      additional_tokens: <PAD>,<EOS>
      streams:
        inputs: questions
@@ -66,7 +66,7 @@ pipeline:
      export_pad_mapping_to_globals: True
      additional_tokens: <PAD>,<EOS>
      eos_token: True
-      fixed_padding: 10
+      fixed_padding: 10 # The longest question!
      streams:
        inputs: answers
        outputs: indexed_answers
@@ -88,9 +88,11 @@ pipeline:
 
    # Single layer GRU Encoder
    encoder:
+      priority: 3
      type: RecurrentNeuralNetwork
+      # Do not wrap this model with DataDictParallel!
+      parallelize: False
      cell_type: GRU
-      priority: 3
      initial_state: Trainable
      hidden_size: 100
      num_layers: 1
@@ -110,7 +112,7 @@ pipeline:
    reshaper_1:
      priority: 3.01
      type: ReshapeTensor
-      input_dims: [1, -1, 100]
+      input_dims: [-1, 1, 100]
      output_dims: [-1, 100]
      streams:
        inputs: s2s_state_output
@@ -148,7 +150,7 @@ pipeline:
      priority: 3.3
      type: ReshapeTensor
      input_dims: [-1, 100]
-      output_dims: [1, -1, 100]
+      output_dims: [-1, 1, 100]
      streams:
        inputs: question_image_activations
        outputs: question_image_activations_reshaped
diff --git a/ptp/components/models/attn_decoder_rnn.py b/ptp/components/models/attn_decoder_rnn.py
index 4d558ed..3cf12af 100644
--- a/ptp/components/models/attn_decoder_rnn.py
+++ b/ptp/components/models/attn_decoder_rnn.py
@@ -82,7 +82,7 @@ def __init__(self, name, config):
        # Create dropout layer.
        self.dropout = torch.nn.Dropout(dropout_rate)
 
-        # Create rnn cell.
+        # Create rnn cell: a hardcoded single-layer GRU.
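+        # Note: with batch_first=True the GRU consumes and produces
+        # [BATCH_SIZE x SEQ_LEN x ...] tensors, but its hidden state keeps the
+        # [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE] layout - hence the explicit
+        # transposes added around every state import/export in this patch.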
        self.rnn_cell = getattr(torch.nn, "GRU")(self.input_size, self.hidden_size, 1, dropout=dropout_rate, batch_first=True)
 
        # Create layers for the attention
@@ -149,7 +149,7 @@ def input_data_definitions(self):
        d[self.key_inputs] = DataDefinition([-1, -1, self.hidden_size], [torch.Tensor], "Batch of encoder outputs [BATCH_SIZE x SEQ_LEN x INPUT_SIZE]")
 
        # Input hidden state
-        d[self.key_input_state] = DataDefinition([1, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states")
+        d[self.key_input_state] = DataDefinition([-1, 1, self.hidden_size], [torch.Tensor], "Batch of RNN last hidden states passed from another RNN, used as the initial state [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE]")
 
        return d
 
@@ -167,9 +167,9 @@ def output_data_definitions(self):
        # Only last prediction.
        d[self.key_predictions] = DataDefinition([-1, self.prediction_size], [torch.Tensor], "Batch of predictions, each represented as probability distribution over classes [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]")
 
-        # Output hidden state stream
+        # Output hidden state stream. TODO: why do we need that?
        if self.output_last_state:
-            d[self.key_output_state] = DataDefinition([1, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states")
+            d[self.key_output_state] = DataDefinition([-1, 1, self.hidden_size], [torch.Tensor], "Batch of RNN final hidden states [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE]")
 
        return d
 
@@ -185,9 +185,14 @@ def forward(self, data_dict):
        inputs = data_dict[self.key_inputs]
        batch_size = inputs.shape[0]
+        print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device))
 
-        # Initialize hidden state.
+        # Initialize the hidden state from the input stream - the last hidden state of an external component.
        hidden = data_dict[self.key_input_state]
+        # For RNNs (aside from LSTM): [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE] -> [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE]
+        hidden = hidden.transpose(0,1)
+        print("{}: hidden shape: {}, device: {}\n".format(self.name, hidden.shape, hidden.device))
+
        # List that will contain the output sequence
        activations = []
 
@@ -232,4 +237,7 @@ def forward(self, data_dict):
 
        # Output last hidden state, if requested
        if self.output_last_state:
+            # For others: [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE] -> [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE]
+            hidden = hidden.transpose(0,1)
+            # Export last hidden state.
            data_dict.extend({self.key_output_state: hidden})
diff --git a/ptp/components/models/recurrent_neural_network.py b/ptp/components/models/recurrent_neural_network.py
index 8ef2dc3..03bc404 100644
--- a/ptp/components/models/recurrent_neural_network.py
+++ b/ptp/components/models/recurrent_neural_network.py
@@ -118,7 +118,7 @@ def __init__(self, name, config):
        except KeyError:
            raise ConfigurationError( "Invalid RNN type, available options for 'cell_type' are ['LSTM', 'GRU', 'RNN_TANH', 'RNN_RELU'] (currently '{}')".format(self.cell_type))
 
-        # Parameters - for a single sample.
+        # Parameters - for a single sample: 2 x [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE]
        h0 = torch.zeros(self.num_layers, 1, self.hidden_size)
        c0 = torch.zeros(self.num_layers, 1, self.hidden_size)
 
@@ -228,9 +228,9 @@ def input_data_definitions(self):
        # Input hidden state
        if self.initial_state == "Input":
            if self.cell_type == "LSTM":
-                d[self.key_input_state] = DataDefinition([2, self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of LSTM initial hidden states (h0/c0) passed from another LSTM [2 x NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")
+                d[self.key_input_state] = DataDefinition([-1, 2, self.num_layers, self.hidden_size], [torch.Tensor], "Batch of LSTM last hidden states (h0/c0) passed from another LSTM, used as the initial state [BATCH_SIZE x 2 x NUM_LAYERS x HIDDEN_SIZE]")
            else:
-                d[self.key_input_state] = DataDefinition([self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of RNN initial hidden states passed from another RNN [NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")
+                d[self.key_input_state] = DataDefinition([-1, self.num_layers, self.hidden_size], [torch.Tensor], "Batch of RNN last hidden states passed from another RNN, used as the initial state [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE]")
 
        return d
 
@@ -253,9 +253,9 @@ def output_data_definitions(self):
        # Output: hidden state stream.
        if self.output_last_state:
            if self.cell_type == "LSTM":
-                d[self.key_output_state] = DataDefinition([2, self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of LSTM final hidden states (h0/c0) [2 x NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")
+                d[self.key_output_state] = DataDefinition([-1, 2, self.num_layers, self.hidden_size], [torch.Tensor], "Batch of LSTM final hidden states (h0/c0) [BATCH_SIZE x 2 x NUM_LAYERS x HIDDEN_SIZE]")
            else:
-                d[self.key_output_state] = DataDefinition([self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of RNN final hidden states [NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")
+                d[self.key_output_state] = DataDefinition([-1, self.num_layers, self.hidden_size], [torch.Tensor], "Batch of RNN final hidden states [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE]")
 
        return d
 
@@ -285,15 +285,27 @@ def forward(self, data_dict):
        if inputs.dim() == 2:
            inputs = inputs.unsqueeze(1)
        batch_size = inputs.shape[0]
+
+        print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device))
+
 
        # Get initial state, depending on the settings.
        if self.initial_state == "Input":
-            # Initialize hidden state.
+            # Initialize the hidden state from the input stream - the last hidden state of an external component.
            hidden = data_dict[self.key_input_state]
+            # Flip the batch and num_layers dims so batch ends up third (LSTM) or second (other cells)!
+            if self.cell_type == 'LSTM':
+                # For LSTM: [BATCH_SIZE x NUM_LAYERS x 2 x HIDDEN_SIZE] -> [2 x NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE]
+                hidden = hidden.transpose(0,2)
+            else:
+                # For others: [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE] -> [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE]
+                hidden = hidden.transpose(0,1)
        else:
            hidden = self.initialize_hiddens_state(batch_size)
 
+        print("{}: hidden shape: {}, device: {}\n".format(self.name, hidden.shape, hidden.device))
+
        activations = []
 
        # Check out operation mode.
@@ -355,4 +367,12 @@ def forward(self, data_dict):
            pass
 
        if self.output_last_state:
+            # Flip batch and num_layers dims so batch will be first!
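+            # Rationale: torch.nn.DataParallel scatters and gathers every
+            # stream tensor along dim 0, so with batch first the state is
+            # split per sample; otherwise it would be chunked across layers.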
+            if self.cell_type == 'LSTM':
+                # For LSTM: [2 x NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE] -> [BATCH_SIZE x NUM_LAYERS x 2 x HIDDEN_SIZE]
+                hidden = hidden.transpose(0,2)
+            else:
+                # For others: [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE] -> [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE]
+                hidden = hidden.transpose(0,1)
+            # Export last hidden state.
            data_dict.extend({self.key_output_state: hidden})
diff --git a/ptp/components/models/sentence_embeddings.py b/ptp/components/models/sentence_embeddings.py
index 466ecb8..869a2dc 100644
--- a/ptp/components/models/sentence_embeddings.py
+++ b/ptp/components/models/sentence_embeddings.py
@@ -111,7 +111,7 @@ def forward(self, data_dict):
        # Unpack DataDict.
        inputs = data_dict[self.key_inputs]
 
-        #print("{}: input len: {}, device: {}\n".format(self.name, len(inputs), "-"))
+        print("{}: input len: {}, device: {}\n".format(self.name, len(inputs), "-"))
 
        # Get index of padding.
        pad_index = self.word_to_ix['<PAD>']
@@ -163,7 +163,7 @@ def forward(self, data_dict):
        padded_indices = torch.nn.utils.rnn.pad_sequence(indices_list, batch_first=True, padding_value=pad_index)
        embedds = self.embeddings(padded_indices)
 
-        #print("{}: embedds shape: {}, device: {}\n".format(self.name, embedds.shape, embedds.device))
+        print("{}: embedds shape: {}, device: {}\n".format(self.name, embedds.shape, embedds.device))
 
        # Add embeddings to datadict.
        data_dict.extend({self.key_outputs: embedds})
diff --git a/ptp/components/utils/word_mappings.py b/ptp/components/utils/word_mappings.py
index 61f5e52..d74c810 100644
--- a/ptp/components/utils/word_mappings.py
+++ b/ptp/components/utils/word_mappings.py
@@ -146,13 +146,14 @@ def pad_trunc_list(l: list, length: int, padding_value = 0, eos_value = None):
 
    :return: None
    """
-
    if len(l) < length:
        if eos_value is not None:
            l.append(eos_value)
        l.extend([padding_value]*(length-len(l)))
 
    elif len(l) > length:
+        #print("pad_trunc_list to cat!: {}".format(len(l)))
        del l[length:]
        if eos_value is not None:
-            l[length-1] = eos_value
\ No newline at end of file
+            l[length-1] = eos_value
+            #exit(1)

From 0a55cf22cbc11bfe2b8291f539b2ad85acbc50eb Mon Sep 17 00:00:00 2001
From: Tomasz Kornuta
Date: Fri, 3 May 2019 15:23:38 -0700
Subject: [PATCH 2/3] Cleanups; got max question length (19), but then the
 encoder also has to generate 19 symbols, which makes it really slow

---
 .../c4_enc_attndec_resnet152_ewm_cat_is.yml   |  8 +++---
 ptp/components/models/attn_decoder_rnn.py     |  5 ++--
 .../models/recurrent_neural_network.py        |  8 ++----
 ptp/components/models/sentence_embeddings.py  | 25 ++-----------------
 ptp/components/utils/word_mappings.py         |  2 +-
 5 files changed, 11 insertions(+), 37 deletions(-)

diff --git a/configs/vqa_med_2019/c4_classification/c4_enc_attndec_resnet152_ewm_cat_is.yml b/configs/vqa_med_2019/c4_classification/c4_enc_attndec_resnet152_ewm_cat_is.yml
index a7804cf..69312f7 100644
--- a/configs/vqa_med_2019/c4_classification/c4_enc_attndec_resnet152_ewm_cat_is.yml
+++ b/configs/vqa_med_2019/c4_classification/c4_enc_attndec_resnet152_ewm_cat_is.yml
@@ -49,7 +49,7 @@ pipeline:
      pretrained_embeddings_file: glove.6B.100d.txt
      data_folder: ~/data/vqa-med
      word_mappings_file: questions.all.word.mappings.csv
-      fixed_padding: 10 # The longest question!
+      fixed_padding: 10 # The longest question! max is 19!
      additional_tokens: <PAD>,<EOS>
      streams:
        inputs: questions
@@ -66,7 +66,7 @@ pipeline:
      export_pad_mapping_to_globals: True
      additional_tokens: <PAD>,<EOS>
      eos_token: True
-      fixed_padding: 10 # The longest question!
+      fixed_padding: 10 # The longest question! max is 19!
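+      # Note: pad_trunc_list() pads shorter sequences with <PAD> and truncates
+      # longer ones to fixed_padding (keeping <EOS> as the final token when
+      # eos_token is set).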
      streams:
        inputs: answers
        outputs: indexed_answers
@@ -163,7 +163,7 @@ pipeline:
      priority: 4
      hidden_size: 100
      use_logsoftmax: False
-      autoregression_length: 10
+      autoregression_length: 10 # The current implementation requires this value to equal fixed_padding in SentenceEmbeddings/Indexer...
      prediction_mode: Dense
      dropout_rate: 0.1
      streams:
@@ -200,8 +200,8 @@ pipeline:
 
    # Prediction decoding.
    prediction_decoder:
-      type: SentenceIndexer
      priority: 10
+      type: SentenceIndexer
      # Reverse mode.
      reverse: True
      # Use distributions as inputs.
diff --git a/ptp/components/models/attn_decoder_rnn.py b/ptp/components/models/attn_decoder_rnn.py
index 3cf12af..32a2c14 100644
--- a/ptp/components/models/attn_decoder_rnn.py
+++ b/ptp/components/models/attn_decoder_rnn.py
@@ -185,14 +185,13 @@ def forward(self, data_dict):
        inputs = data_dict[self.key_inputs]
        batch_size = inputs.shape[0]
-        print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device))
+        #print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device))
 
        # Initialize the hidden state from the input stream - the last hidden state of an external component.
        hidden = data_dict[self.key_input_state]
        # For RNNs (aside from LSTM): [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE] -> [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE]
        hidden = hidden.transpose(0,1)
-        print("{}: hidden shape: {}, device: {}\n".format(self.name, hidden.shape, hidden.device))
-
+        #print("{}: hidden shape: {}, device: {}\n".format(self.name, hidden.shape, hidden.device))
 
        # List that will contain the output sequence
        activations = []
diff --git a/ptp/components/models/recurrent_neural_network.py b/ptp/components/models/recurrent_neural_network.py
index 03bc404..26105b6 100644
--- a/ptp/components/models/recurrent_neural_network.py
+++ b/ptp/components/models/recurrent_neural_network.py
@@ -285,10 +285,7 @@ def forward(self, data_dict):
        if inputs.dim() == 2:
            inputs = inputs.unsqueeze(1)
        batch_size = inputs.shape[0]
-
-        print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device))
-
-
+        #print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device))
 
        # Get initial state, depending on the settings.
        if self.initial_state == "Input":
@@ -303,8 +300,7 @@ def forward(self, data_dict):
                hidden = hidden.transpose(0,1)
        else:
            hidden = self.initialize_hiddens_state(batch_size)
-
-        print("{}: hidden shape: {}, device: {}\n".format(self.name, hidden.shape, hidden.device))
+        #print("{}: hidden shape: {}, device: {}\n".format(self.name, hidden.shape, hidden.device))
 
        activations = []
 
diff --git a/ptp/components/models/sentence_embeddings.py b/ptp/components/models/sentence_embeddings.py
index 869a2dc..caaf96a 100644
--- a/ptp/components/models/sentence_embeddings.py
+++ b/ptp/components/models/sentence_embeddings.py
@@ -110,8 +110,7 @@ def forward(self, data_dict):
        # Unpack DataDict.
        inputs = data_dict[self.key_inputs]
-
-        print("{}: input len: {}, device: {}\n".format(self.name, len(inputs), "-"))
+        #print("{}: input len: {}, device: {}\n".format(self.name, len(inputs), "-"))
 
        # Get index of padding.
        pad_index = self.word_to_ix['<PAD>']
@@ -137,33 +136,13 @@ def forward(self, data_dict):
            #indices_list.append(self.app_state.FloatTensor(output_sample))
            indices_list.append(self.app_state.LongTensor(output_sample))
 
-        # Create list of (index,len) tuples.
-        #order_len = [(i, len(indices_list[i])) for i in range(len(indices_list))]
-
-        # Sort by seq_length.
-        #sorted_order_len = sorted(order_len, key=lambda x: x[1], reverse=True)
-
-        # Now sort indices list.
-        #sorted_indices_list = [indices_list[sorted_order_len[i][0]] for i in range(len(indices_list))]
-
-        # Pad the indices list.
-        #padded_indices = torch.nn.utils.rnn.pad_sequence(sorted_indices_list, batch_first=True)
-
-        # Revert to the original batch order!!!
-        #new_old_order = [(i, sorted_order_len[i][0]) for i in range(len(indices_list))]
-        #sorted_new_old_order = sorted(new_old_order, key=lambda x: x[1])
-        #unsorted_padded_indices = [padded_indices[sorted_new_old_order[i][0]] for i in range(len(indices_list))]
-
-        # Change to tensor.
-        #unsorted_padded_indices_tensor = torch.stack( [self.app_state.FloatTensor(lst) for lst in unsorted_padded_indices] )
-
        # Embedd indices.
        #embedds = self.embeddings(unsorted_padded_indices_tensor)
 
        padded_indices = torch.nn.utils.rnn.pad_sequence(indices_list, batch_first=True, padding_value=pad_index)
        embedds = self.embeddings(padded_indices)
 
-        print("{}: embedds shape: {}, device: {}\n".format(self.name, embedds.shape, embedds.device))
+        #print("{}: embedds shape: {}, device: {}\n".format(self.name, embedds.shape, embedds.device))
 
        # Add embeddings to datadict.
        data_dict.extend({self.key_outputs: embedds})
diff --git a/ptp/components/utils/word_mappings.py b/ptp/components/utils/word_mappings.py
index d74c810..014c7a9 100644
--- a/ptp/components/utils/word_mappings.py
+++ b/ptp/components/utils/word_mappings.py
@@ -153,7 +153,7 @@ def pad_trunc_list(l: list, length: int, padding_value = 0, eos_value = None):
 
    elif len(l) > length:
        #print("pad_trunc_list to cat!: {}".format(len(l)))
+        #exit(1)
        del l[length:]
        if eos_value is not None:
            l[length-1] = eos_value
-            #exit(1)

From 6233df5f4fd9c58164aee019f767d0d469308dab Mon Sep 17 00:00:00 2001
From: Tomasz Kornuta
Date: Fri, 3 May 2019 17:09:33 -0700
Subject: [PATCH 3/3] Cleanups; hyperparams for C4 enc_dec

---
 .../c4_enc_attndec_resnet152_ewm_cat_is.yml   |  6 ++----
 ptp/components/models/sentence_embeddings.py  | 14 +++++++-------
 2 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/configs/vqa_med_2019/c4_classification/c4_enc_attndec_resnet152_ewm_cat_is.yml b/configs/vqa_med_2019/c4_classification/c4_enc_attndec_resnet152_ewm_cat_is.yml
index 69312f7..2b603a1 100644
--- a/configs/vqa_med_2019/c4_classification/c4_enc_attndec_resnet152_ewm_cat_is.yml
+++ b/configs/vqa_med_2019/c4_classification/c4_enc_attndec_resnet152_ewm_cat_is.yml
@@ -4,12 +4,11 @@ default_configs: vqa_med_2019/default_vqa_med_2019.yml
 # Training parameters:
 training:
  problem:
-    batch_size: 64
+    batch_size: 200 # requires 4 GPUs!
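+    # With DataParallel splitting along the batch dim, 200 samples across
+    # 4 GPUs means 50 samples per device.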
    categories: C4
    question_preprocessing: lowercase, remove_punctuation, tokenize #, random_remove_stop_words #,random_shuffle_words
    answer_preprocessing: lowercase, remove_punctuation, tokenize
    export_sample_weights: ~/data/vqa-med/answers.c4.weights.csv
-    batch_size: 32
  sampler:
    weights: ~/data/vqa-med/answers.c4.weights.csv
  dataloader:
@@ -23,11 +22,10 @@ training:
 # Validation parameters:
 validation:
  problem:
-    batch_size: 64
+    batch_size: 200
    categories: C4
    question_preprocessing: lowercase, remove_punctuation, tokenize
    answer_preprocessing: lowercase, remove_punctuation, tokenize
-    batch_size: 32
  dataloader:
    num_workers: 4
 
diff --git a/ptp/components/models/sentence_embeddings.py b/ptp/components/models/sentence_embeddings.py
index caaf96a..29671c6 100644
--- a/ptp/components/models/sentence_embeddings.py
+++ b/ptp/components/models/sentence_embeddings.py
@@ -73,6 +73,10 @@ def __init__(self, name, config):
            emb_vectors = emb.load_pretrained_embeddings(self.logger, self.data_folder, self.config["pretrained_embeddings_file"], self.word_to_ix, self.embeddings_size)
            self.embeddings.weight = torch.nn.Parameter(emb_vectors)
 
+        # Get index of <PAD> from vocabulary.
+        self.pad_index = self.word_to_ix['<PAD>']
+
+
    def input_data_definitions(self):
        """
@@ -112,9 +116,6 @@ def forward(self, data_dict):
        inputs = data_dict[self.key_inputs]
        #print("{}: input len: {}, device: {}\n".format(self.name, len(inputs), "-"))
 
-        # Get index of padding.
-        pad_index = self.word_to_ix['<PAD>']
-
        indices_list = []
        # Process samples 1 by one.
        for sample in inputs:
@@ -131,15 +132,14 @@ def forward(self, data_dict):
            # Apply fixed padding to all sequences if requested
            # Otherwise let torch.nn.utils.rnn.pad_sequence handle it and choose a dynamic padding
            if self.fixed_padding > 0:
-                pad_trunc_list(output_sample, self.fixed_padding, padding_value=pad_index)
+                pad_trunc_list(output_sample, self.fixed_padding, padding_value=self.pad_index)
 
            #indices_list.append(self.app_state.FloatTensor(output_sample))
            indices_list.append(self.app_state.LongTensor(output_sample))
 
+        # Pad indices using the pad index retrieved from the vocabulary.
+        padded_indices = torch.nn.utils.rnn.pad_sequence(indices_list, batch_first=True, padding_value=self.pad_index)
        # Embedd indices.
-        #embedds = self.embeddings(unsorted_padded_indices_tensor)
-
-        padded_indices = torch.nn.utils.rnn.pad_sequence(indices_list, batch_first=True, padding_value=pad_index)
        embedds = self.embeddings(padded_indices)
 
        #print("{}: embedds shape: {}, device: {}\n".format(self.name, embedds.shape, embedds.device))
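
A minimal sketch of the handoff pattern this series converges on (illustrative names only, not part of the patches; it assumes nothing beyond the stock PyTorch API): hidden states are exported batch-first so torch.nn.DataParallel can scatter them along dim 0, then transposed back to the layout torch.nn RNNs expect before being consumed.

    import torch

    # Sizes mirroring the C4 config above: hidden_size=100, num_layers=1.
    num_layers, batch_size, seq_len, hidden_size = 1, 4, 10, 100

    encoder = torch.nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)
    decoder = torch.nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)

    inputs = torch.randn(batch_size, seq_len, hidden_size)  # [BATCH_SIZE x SEQ_LEN x INPUT_SIZE]

    # h_n leaves the encoder as [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE],
    # regardless of batch_first.
    _, h_n = encoder(inputs)

    # Export batch-first: DataParallel scatters along dim 0, so this makes it
    # chunk the state per sample instead of per layer.
    state_stream = h_n.transpose(0, 1)  # [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE]

    # Import on the receiving component: restore the layout torch.nn.GRU expects.
    h_0 = state_stream.transpose(0, 1).contiguous()  # [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE]
    outputs, _ = decoder(torch.randn(batch_size, 1, hidden_size), h_0)
    print(outputs.shape)  # torch.Size([4, 1, 100])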