From c307e64f6a8d109bb3ac3075e2ae778a7fdc2ff4 Mon Sep 17 00:00:00 2001
From: tkornut
Date: Wed, 10 Apr 2019 11:09:04 -0700
Subject: [PATCH 01/11] initial version of pipeline for testing VFLenet5,
 masks not working yet

---
 .../models/variational_flow_lenet5.yml        |  46 +++++
 .../mnist/mnist_classification_vf_lenet5.yml  |  68 +++++++
 ptp/components/models/__init__.py             |   2 +
 .../models/variational_flow_lenet5.py         | 170 ++++++++++++++++++
 4 files changed, 286 insertions(+)
 create mode 100644 configs/default/components/models/variational_flow_lenet5.yml
 create mode 100644 configs/mnist/mnist_classification_vf_lenet5.yml
 create mode 100644 ptp/components/models/variational_flow_lenet5.py

diff --git a/configs/default/components/models/variational_flow_lenet5.yml b/configs/default/components/models/variational_flow_lenet5.yml
new file mode 100644
index 0000000..11b31a5
--- /dev/null
+++ b/configs/default/components/models/variational_flow_lenet5.yml
@@ -0,0 +1,46 @@
+# This file defines the default values for the Variational Flow LeNet5 model.
+
+####################################################################
+# 1. CONFIGURATION PARAMETERS that will be LOADED by the component.
+####################################################################
+
+streams:
+  ####################################################################
+  # 2. Keymappings associated with INPUT and OUTPUT streams.
+  ####################################################################
+
+  # Stream containing batch of images (INPUT)
+  inputs: inputs
+
+  # Stream containing batch of targets (used for masks) (INPUT)
+  targets: targets
+
+  # Streams containing predictions (OUTPUT)
+  flow1_predictions: flow1_predictions
+  flow2_predictions: flow2_predictions
+
+  # Streams containing masks (OUTPUT)
+  flow1_masks: flow1_masks
+  flow2_masks: flow2_masks
+
+globals:
+  ####################################################################
+  # 3. Keymappings of variables that will be RETRIEVED from GLOBALS.
+  ####################################################################
+
+  # Size of the prediction (RETRIEVED)
+  flow1_prediction_size: flow1_prediction_size
+  flow2_prediction_size: flow2_prediction_size
+
+  # Word mappings used for filtering.
+  flow1_word_mappings: flow1_word_mappings
+  flow2_word_mappings: flow2_word_mappings
+
+  ####################################################################
+  # 4. Keymappings associated with GLOBAL variables that will be SET.
+  ####################################################################
+
+  ####################################################################
+  # 5. Keymappings associated with statistics that will be ADDED.
+  ####################################################################
+
diff --git a/configs/mnist/mnist_classification_vf_lenet5.yml b/configs/mnist/mnist_classification_vf_lenet5.yml
new file mode 100644
index 0000000..ddb7841
--- /dev/null
+++ b/configs/mnist/mnist_classification_vf_lenet5.yml
@@ -0,0 +1,68 @@
+# Load config defining MNIST problems for training, validation and testing.
+default_configs: mnist/default_mnist.yml
+
+# Training parameters - overwrite defaults:
+training:
+  problem:
+    resize_image: [32, 32]
+
+# Validation parameters - overwrite defaults:
+validation:
+  problem:
+    resize_image: [32, 32]
+
+# Testing parameters - overwrite defaults:
+testing:
+  problem:
+    resize_image: [32, 32]
+
+# Definition of the pipeline.
+pipeline:
+  name: mnist_variational_flow_lenet5
+
+  # Add global variables.
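+  # (Note: the values below split the ten MNIST classes between the two flows - digits 0-2 go to flow 1 and 3-9 to flow 2.)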
+  global_publisher:
+    type: GlobalVariablePublisher
+    priority: 0
+    keys: [num_classes1, num_classes2, word_to_ix1, word_to_ix2]
+    values: [3, 7, {"Zero": 0, "One": 1, "Two": 2}, {"Three": 3, "Four": 4, "Five": 5, "Six": 6, "Seven": 7, "Eight": 8, "Nine": 9}]
+
+  # Image classifier.
+  image_classifier:
+    type: VariationalFlowLeNet5
+    priority: 1
+    globals:
+      flow1_prediction_size: num_classes1
+      flow2_prediction_size: num_classes2
+      flow1_word_mappings: word_to_ix1
+      flow2_word_mappings: word_to_ix2
+
+  # Masked loss.
+  nllloss:
+    type: NLLLoss
+    use_masks: True
+    streams:
+      predictions: flow1_predictions
+      masks: flow1_masks
+
+  # Statistics.
+  batch_size:
+    type: BatchSizeStatistics
+    streams:
+      predictions: flow1_predictions
+
+  accuracy:
+    type: AccuracyStatistics
+    streams:
+      predictions: flow1_predictions
+
+  precision_recall:
+    type: PrecisionRecallStatistics
+    priority: 100.3
+    use_word_mappings: True
+    show_class_scores: True
+    streams:
+      predictions: flow1_predictions
+    globals:
+      word_mappings: label_word_mappings
+#: pipeline
diff --git a/ptp/components/models/__init__.py b/ptp/components/models/__init__.py
index 32e95b0..8551125 100644
--- a/ptp/components/models/__init__.py
+++ b/ptp/components/models/__init__.py
@@ -6,6 +6,7 @@ from .model import Model
 from .recurrent_neural_network import RecurrentNeuralNetwork
 from .sentence_embeddings import SentenceEmbeddings
+from .variational_flow_lenet5 import VariationalFlowLeNet5
 
 __all__ = [
     'ConvNetEncoder',
@@ -16,4 +17,5 @@
     'Model',
     'RecurrentNeuralNetwork',
     'SentenceEmbeddings',
+    'VariationalFlowLeNet5',
     ]
diff --git a/ptp/components/models/variational_flow_lenet5.py b/ptp/components/models/variational_flow_lenet5.py
new file mode 100644
index 0000000..e46926b
--- /dev/null
+++ b/ptp/components/models/variational_flow_lenet5.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) IBM Corporation 2018
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__author__ = "Tomasz Kornuta & Vincent Marois"
+
+
+import torch
+
+from ptp.components.models.model import Model
+from ptp.data_types.data_definition import DataDefinition
+
+
+class VariationalFlowLeNet5(Model):
+    """
+    A proof of concept of variational flow with LeNet-5 model for MNIST digits classification.
+
+    Uses masks (here depending on the targets) to control where each particular sample "flows" during backpropagation.
+
+    For that purpose it has two flows: the first for the first subset of classes (0-2) and the second for the remainder (3-9).
+
+    """
+    def __init__(self, name, config):
+        """
+        Initializes the model, retrieves key mappings, creates two flows.
+
+        :param name: Name of the model (taken from the configuration file).
+
+        :param config: Parameters read from configuration file.
+        :type config: ``ptp.configuration.ConfigInterface``
+
+        """
+        super(VariationalFlowLeNet5, self).__init__(name, VariationalFlowLeNet5, config)
+
+        # Get key mappings.
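+        # (stream_keys maps the component's internal stream names onto the stream names defined in the pipeline configuration.)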
+        self.key_inputs = self.stream_keys["inputs"]
+        self.key_targets = self.stream_keys["targets"]
+
+        self.key_flow1_predictions = self.stream_keys["flow1_predictions"]
+        self.key_flow1_masks = self.stream_keys["flow1_masks"]
+        self.key_flow2_predictions = self.stream_keys["flow2_predictions"]
+        self.key_flow2_masks = self.stream_keys["flow2_masks"]
+
+        # Retrieve prediction sizes from globals.
+        self.flow1_prediction_size = self.globals["flow1_prediction_size"]
+        self.flow2_prediction_size = self.globals["flow2_prediction_size"]
+
+        # Retrieve word mappings from globals.
+        self.flow1_word_mappings = self.globals["flow1_word_mappings"]
+        self.flow2_word_mappings = self.globals["flow2_word_mappings"]
+
+
+        # Create flow 1.
+        self.flow1_image_encoder = torch.nn.Sequential(
+            torch.nn.Conv2d(1, 6, kernel_size=(5, 5)),
+            torch.nn.ReLU(inplace=True),
+            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2),
+
+            torch.nn.Conv2d(6, 16, kernel_size=(5, 5)),
+            torch.nn.ReLU(inplace=True),
+            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2),
+
+            torch.nn.Conv2d(16, 120, kernel_size=(5, 5)),
+            torch.nn.ReLU(inplace=True)
+            )
+
+        self.flow1_classifier = torch.nn.Sequential(
+            torch.nn.Linear(120, 84),
+            torch.nn.ReLU(inplace=True),
+            torch.nn.Linear(84, 10), # FOR NOW
+            torch.nn.LogSoftmax(dim=1)
+            )
+
+        # Create flow 2.
+        self.flow2_image_encoder = torch.nn.Sequential(
+            torch.nn.Conv2d(1, 6, kernel_size=(5, 5)),
+            torch.nn.ReLU(inplace=True),
+            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2),
+
+            torch.nn.Conv2d(6, 16, kernel_size=(5, 5)),
+            torch.nn.ReLU(inplace=True),
+            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2),
+
+            torch.nn.Conv2d(16, 120, kernel_size=(5, 5)),
+            torch.nn.ReLU(inplace=True)
+            )
+
+        self.flow2_classifier = torch.nn.Sequential(
+            torch.nn.Linear(120, 84),
+            torch.nn.ReLU(inplace=True),
+            torch.nn.Linear(84, self.flow2_prediction_size),
+            torch.nn.LogSoftmax(dim=1)
+            )
+
+
+    def input_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of input data that are required by the component.
+
+        :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
+        """
+        return {
+            self.key_inputs: DataDefinition([-1, 1, 32, 32], [torch.Tensor], "Batch of images [BATCH_SIZE x IMAGE_DEPTH x IMAGE_HEIGHT x IMAGE_WIDTH]"),
+            self.key_targets: DataDefinition([-1], [torch.Tensor], "Batch of targets [BATCH_SIZE]"),
+            }
+
+
+    def output_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of output data produced by the component.
+
+        :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
+        """
+        return { # 10 for now!
+            self.key_flow1_predictions: DataDefinition([-1, 10], [torch.Tensor], "Batch of flow 1 predictions, each represented as probability distribution over classes [BATCH_SIZE x FLOW1_PREDICTION_SIZE]"),
+            self.key_flow1_masks: DataDefinition([-1], [torch.Tensor], "Batch of masks for flow 1 [BATCH_SIZE]"),
+            self.key_flow2_predictions: DataDefinition([-1, self.flow2_prediction_size], [torch.Tensor], "Batch of flow 2 predictions, each represented as probability distribution over classes [BATCH_SIZE x FLOW2_PREDICTION_SIZE]"),
+            self.key_flow2_masks: DataDefinition([-1], [torch.Tensor], "Batch of masks for flow 2 [BATCH_SIZE]"),
+            }
+
+
+    def forward(self, data_dict):
+        """
+        Main forward pass of the model.
+        In fact it performs two passes, using masks generated on the fly from the targets.
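+        The produced masks are intended for downstream components (e.g. a masked loss), so that each sample contributes only to its own flow.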
+
+        :param data_dict: DataDict({'inputs',**}), where:
+
+            - inputs: [batch_size, num_channels, width, height]
+
+        :type data_dict: ``miprometheus.utils.DataDict``
+
+        :return: None (the predictions [batch_size, num_classes] are added to the data_dict instead)
+
+        """
+        # Produce masks.
+        # TODO.
+
+        # Get images.
+        img = data_dict[self.key_inputs]
+
+        # Pass inputs through flow 1.
+        x1 = self.flow1_image_encoder(img)
+        x1 = x1.view(-1, 120)
+        x1 = self.flow1_classifier(x1)
+
+        # Pass inputs through flow 2.
+        x2 = self.flow2_image_encoder(img)
+        x2 = x2.view(-1, 120)
+        x2 = self.flow2_classifier(x2)
+
+
+        # Add predictions to datadict.
+        data_dict.extend({
+            self.key_flow1_predictions: x1,
+            self.key_flow2_predictions: x2,
+            })

From c347296911b1c49595e5b5f94276bbfe9663eab9 Mon Sep 17 00:00:00 2001
From: tkornut
Date: Wed, 10 Apr 2019 14:26:37 -0700
Subject: [PATCH 02/11] variational flow 1

---
 .../default/components/losses/nll_loss.yml   |  6 ++++
 .../mnist/mnist_classification_vf_lenet5.yml |  8 +++--
 ptp/components/losses/nll_loss.py            | 31 ++++++++++++++++---
 .../models/variational_flow_lenet5.py        | 19 ++++++++++--
 4 files changed, 56 insertions(+), 8 deletions(-)

diff --git a/configs/default/components/losses/nll_loss.yml b/configs/default/components/losses/nll_loss.yml
index 12cb4d9..858d86b 100644
--- a/configs/default/components/losses/nll_loss.yml
+++ b/configs/default/components/losses/nll_loss.yml
@@ -12,6 +12,9 @@ num_targets_dims: 1
 # Options: NLLLoss | CrossEntropyLoss (NOT OPERATIONAL YET!)
 # loss_function: NLLLoss
 
+# When set to True, performs masking of selected predictions (LOADED)
+use_masking: False
+
 streams:
   ####################################################################
   # 2. Keymappings associated with INPUT and OUTPUT streams.
   ####################################################################
@@ -23,6 +26,9 @@ streams:
   # Stream containing batch of predictions (INPUT)
   predictions: predictions
 
+  # Stream containing masks used for masking of selected predictions (INPUT)
+  masks: masks
+
   # Stream containing loss (OUTPUT)
   loss: loss
 
diff --git a/configs/mnist/mnist_classification_vf_lenet5.yml b/configs/mnist/mnist_classification_vf_lenet5.yml
index ddb7841..40877f5 100644
--- a/configs/mnist/mnist_classification_vf_lenet5.yml
+++ b/configs/mnist/mnist_classification_vf_lenet5.yml
@@ -5,6 +5,10 @@ default_configs: mnist/default_mnist.yml
 training:
   problem:
     resize_image: [32, 32]
+    #batch_size: 32
+  optimizer:
+    #name: Adam
+    lr: 0.001
 
 # Validation parameters - overwrite defaults:
 validation:
@@ -25,7 +29,7 @@ pipeline:
     type: GlobalVariablePublisher
     priority: 0
     keys: [num_classes1, num_classes2, word_to_ix1, word_to_ix2]
-    values: [3, 7, {"Zero": 0, "One": 1, "Two": 2}, {"Three": 3, "Four": 4, "Five": 5, "Six": 6, "Seven": 7, "Eight": 8, "Nine": 9}]
+    values: [9, 1, {"Zero": 0, "One": 1, "Two": 2, "Three": 3, "Four": 4, "Five": 5, "Six": 6, "Seven": 7, "Eight": 8}, {"Nine": 9}]
 
   # Image classifier.
   image_classifier:
@@ -40,7 +44,7 @@ pipeline:
   # Masked loss.
   nllloss:
     type: NLLLoss
-    use_masks: True
+    use_masking: True
     streams:
       predictions: flow1_predictions
       masks: flow1_masks
diff --git a/ptp/components/losses/nll_loss.py b/ptp/components/losses/nll_loss.py
index 25c3ec4..07281f8 100644
--- a/ptp/components/losses/nll_loss.py
+++ b/ptp/components/losses/nll_loss.py
@@ -36,12 +36,18 @@ def __init__(self, name, config):
         # Call constructors of parent classes.
         Loss.__init__(self, name, NLLLoss, config)
 
+        self.key_masks = self.stream_keys["masks"]
+
         # Set loss.
         self.loss_function = nn.NLLLoss()
 
         # Get number of targets dimensions.
self.num_targets_dims = self.config["num_targets_dims"] + # Get masking flag. + self.use_masking = self.config["use_masking"] + + def input_data_definitions(self): """ @@ -49,10 +55,13 @@ def input_data_definitions(self): :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`). """ - return { + input_defs = { self.key_targets: DataDefinition([-1]*self.num_targets_dims, [torch.Tensor], "Batch of targets (indices) [DIM 1 x DIM 2 x ... ]"), self.key_predictions: DataDefinition([-1]*(self.num_targets_dims+1), [torch.Tensor], "Batch of predictions, represented as tensor with probability distribution over classes [DIM 1 x DIM x ... x NUM_CLASSES]") } + if self.use_masking: + input_defs[self.key_masks] = DataDefinition([-1], [torch.Tensor], "Batch of masks [BATCH_SIZE]") + return input_defs def output_data_definitions(self): """ @@ -81,10 +90,24 @@ def __call__(self, data_dict): targets = data_dict[self.key_targets] predictions = data_dict[self.key_predictions] - if isinstance(targets, (list,)): - # Change to long tensor, as expected by nllloss. - targets = torch.LongTensor(targets) + #print("targets = ",targets) + #print("predictions = ",predictions) + + #if isinstance(targets, (list,)): + # # Change to long tensor, as expected by nllloss. + # targets = torch.LongTensor(targets) + + # Mask predictions if option set. + + if self.use_masking: + masks = data_dict[self.key_masks] + targets = targets * masks.type(self.app_state.LongTensor) + #print("unsqueezed masks = ", masks.unsqueeze(1)) + predictions = predictions * masks.unsqueeze(1).type(self.app_state.FloatTensor) + #print("masked targets = ",targets) + #print("masked predictions = ",predictions) + # reshape. last_dim = predictions.size(-1) diff --git a/ptp/components/models/variational_flow_lenet5.py b/ptp/components/models/variational_flow_lenet5.py index e46926b..a059812 100644 --- a/ptp/components/models/variational_flow_lenet5.py +++ b/ptp/components/models/variational_flow_lenet5.py @@ -146,8 +146,21 @@ def forward(self, data_dict): :return: Predictions [batch_size, num_classes] """ - # Produce masks. - # TODO. + targets = data_dict[self.key_targets] + #print("targets = \n", targets) + + # Produce masks - for flow 1. + flow1_masks = torch.zeros(targets.size(0), requires_grad=False).type(self.app_state.ByteTensor) + for _, val in self.flow1_word_mappings.items(): + flow1_masks = flow1_masks + (targets == val) + #print("flow1_masks = \n", flow1_masks) + + # Produce masks - for flow 2. + flow2_masks = torch.zeros(targets.size(0), requires_grad=False).type(self.app_state.ByteTensor) + for _, val in self.flow2_word_mappings.items(): + flow2_masks = flow2_masks + (targets == val) + #print("flow2_masks = \n", flow2_masks) + #exit(1) # Get images. 
         img = data_dict[self.key_inputs]
 
@@ -167,4 +180,6 @@
         data_dict.extend({
             self.key_flow1_predictions: x1,
             self.key_flow2_predictions: x2,
+            self.key_flow1_masks: flow1_masks,
+            self.key_flow2_masks: flow2_masks,
             })

From 148f76d9d5a02a64270703780840e87428d4b989 Mon Sep 17 00:00:00 2001
From: tkornut
Date: Wed, 10 Apr 2019 15:58:30 -0700
Subject: [PATCH 03/11] VF Lenet5: two flows

---
 .../components/masking/string_to_mask.yml     |  40 +++++
 configs/mnist/default_mnist.yml               |   4 +-
 .../mnist/mnist_classification_vf_2lenet5.yml | 156 ++++++++++++++++++
 .../mnist/mnist_classification_vf_lenet5.yml  |  18 +-
 ptp/__init__.py                               |   2 +
 ptp/application/pipeline_manager.py           |   2 +-
 ptp/components/masking/__init__.py            |   5 +
 ptp/components/masking/string_to_mask.py      | 118 +++++++++++++
 .../problems/image_to_class/mnist.py          |   9 +-
 .../publishers/precision_recall_statistics.py |   7 +-
 10 files changed, 346 insertions(+), 15 deletions(-)
 create mode 100644 configs/default/components/masking/string_to_mask.yml
 create mode 100644 configs/mnist/mnist_classification_vf_2lenet5.yml
 create mode 100644 ptp/components/masking/__init__.py
 create mode 100644 ptp/components/masking/string_to_mask.py

diff --git a/configs/default/components/masking/string_to_mask.yml b/configs/default/components/masking/string_to_mask.yml
new file mode 100644
index 0000000..b42b4d9
--- /dev/null
+++ b/configs/default/components/masking/string_to_mask.yml
@@ -0,0 +1,40 @@
+# This file defines the default values for the String To Mask component.
+
+####################################################################
+# 1. CONFIGURATION PARAMETERS that will be LOADED by the component.
+####################################################################
+
+# Value that will be used when word is out of vocabulary (LOADED)
+# (Mask for that element will be 0 as well)
+out_of_vocabulary_value: -1
+
+streams:
+  ####################################################################
+  # 2. Keymappings associated with INPUT and OUTPUT streams.
+  ####################################################################
+
+  # Stream containing input strings (INPUT)
+  strings: strings
+
+  # Stream containing output masks (OUTPUT)
+  masks: masks
+
+  # Stream containing output indices (OUTPUT)
+  string_indices: string_indices
+
+globals:
+  ####################################################################
+  # 3. Keymappings of variables that will be RETRIEVED from GLOBALS.
+  ####################################################################
+
+  ####################################################################
+  # 4. Keymappings associated with GLOBAL variables that will be SET.
+  ####################################################################
+
+  # Vocabulary used to produce masks and indices (GET)
+  word_mappings: word_mappings
+
+  ####################################################################
+  # 5. Keymappings associated with statistics that will be ADDED.
+  ####################################################################
+
diff --git a/configs/mnist/default_mnist.yml b/configs/mnist/default_mnist.yml
index f601a73..c0fef34 100644
--- a/configs/mnist/default_mnist.yml
+++ b/configs/mnist/default_mnist.yml
@@ -45,12 +45,12 @@ pipeline:
   # Loss
   nllloss:
     type: NLLLoss
-    priority: 10
+    priority: 10.0
 
   # Statistics.
batch_size: type: BatchSizeStatistics - priority: 100.1 + priority: 100.0 accuracy: type: AccuracyStatistics diff --git a/configs/mnist/mnist_classification_vf_2lenet5.yml b/configs/mnist/mnist_classification_vf_2lenet5.yml new file mode 100644 index 0000000..0767142 --- /dev/null +++ b/configs/mnist/mnist_classification_vf_2lenet5.yml @@ -0,0 +1,156 @@ +# Load config defining MNIST problems for training, validation and testing. +default_configs: mnist/default_mnist.yml + +# Training parameters - overwrite defaults: +training: + problem: + resize_image: [32, 32] + #batch_size: 32 + optimizer: + #name: Adam + lr: 0.001 + +# Validation parameters - overwrite defaults: +validation: + problem: + resize_image: [32, 32] + +# Testing parameters - overwrite defaults: +testing: + problem: + resize_image: [32, 32] + +# Definition of the pipeline. +pipeline: + name: mnist_variational_flow_2lenet5 + # Disable components for "default" flow. + disable: nllloss, accuracy, precision_recall + + # Add global variables. + global_publisher: + type: GlobalVariablePublisher + priority: 0 + keys: [num_classes1, num_classes2, word_to_ix1, word_to_ix2] + values: [4, 6, {"Zero": 0, "One": 1, "Two": 2, "Three": 3}, {"Four": 0, "Five": 1, "Six": 2, "Seven": 3, "Eight": 4, "Nine": 5}] + + ######### Flow 1 ################# + flow1_string_to_mask: + type: StringToMask + priority: 1.1 + globals: + word_mappings: word_to_ix1 + streams: + strings: labels + string_indices: flow1_targets + masks: flow1_masks + + # Image classifier. + flow1_image_classifier: + type: LeNet5 + priority: 1.2 + globals: + prediction_size: num_classes1 + word_mappings: word_to_ix1 + streams: + inputs: inputs + predictions: flow1_predictions + + # Masked loss. + flow1_nllloss: + type: NLLLoss + priority: 1.31 + use_masking: True + streams: + targets: flow1_targets + predictions: flow1_predictions + masks: flow1_masks + loss: flow1_loss + + # Statistics. + flow1_accuracy: + type: AccuracyStatistics + priority: 1.32 + streams: + predictions: flow1_predictions + targets: flow1_targets + statistics: + accuracy: flow1_accuracy + + flow1_precision_recall: + type: PrecisionRecallStatistics + priority: 1.33 + use_word_mappings: True + show_class_scores: True + show_confusion_matrix: True + globals: + word_mappings: word_to_ix1 + num_classes: num_classes1 + streams: + targets: flow1_targets + predictions: flow1_predictions + statistics: + precision: flow1_precision + recall: flow1_recall + f1score: flow1_f1score + + ######### Flow 2 ################# + flow2_string_to_mask: + type: StringToMask + priority: 2.1 + globals: + word_mappings: word_to_ix2 + streams: + strings: labels + string_indices: flow2_targets + masks: flow2_masks + + # Image classifier. + flow2_image_classifier: + type: LeNet5 + priority: 2.2 + globals: + prediction_size: num_classes2 + word_mappings: word_to_ix2 + streams: + inputs: inputs + predictions: flow2_predictions + + # Masked loss. + flow2_nllloss: + type: NLLLoss + priority: 2.31 + use_masking: True + streams: + targets: flow2_targets + predictions: flow2_predictions + masks: flow2_masks + loss: flow2_loss + + # Statistics. 
+  flow2_accuracy:
+    type: AccuracyStatistics
+    priority: 2.32
+    streams:
+      targets: flow2_targets
+      predictions: flow2_predictions
+    statistics:
+      accuracy: flow2_accuracy
+
+  flow2_precision_recall:
+    type: PrecisionRecallStatistics
+    priority: 2.33
+    use_word_mappings: True
+    show_class_scores: True
+    show_confusion_matrix: True
+    globals:
+      word_mappings: word_to_ix2
+      num_classes: num_classes2
+    streams:
+      targets: flow2_targets
+      predictions: flow2_predictions
+    statistics:
+      precision: flow2_precision
+      recall: flow2_recall
+      f1score: flow2_f1score
+
+#: pipeline
diff --git a/configs/mnist/mnist_classification_vf_lenet5.yml b/configs/mnist/mnist_classification_vf_lenet5.yml
index 40877f5..233fbfa 100644
--- a/configs/mnist/mnist_classification_vf_lenet5.yml
+++ b/configs/mnist/mnist_classification_vf_lenet5.yml
@@ -23,13 +23,15 @@ testing:
 # Definition of the pipeline.
 pipeline:
   name: mnist_variational_flow_lenet5
+  # Disable components for "default" flow.
+  disable: nllloss, accuracy, precision_recall
 
   # Add global variables.
   global_publisher:
     type: GlobalVariablePublisher
-    priority: 0
+    priority: 0.1
     keys: [num_classes1, num_classes2, word_to_ix1, word_to_ix2]
-    values: [9, 1, {"Zero": 0, "One": 1, "Two": 2, "Three": 3, "Four": 4, "Five": 5, "Six": 6, "Seven": 7, "Eight": 8}, {"Nine": 9}]
+    values: [3, 7, {"Zero": 0, "One": 1, "Two": 2}, {"Three": 3, "Four": 4, "Five": 5, "Six": 6, "Seven": 7, "Eight": 8, "Nine": 9}]
 
   # Image classifier.
   image_classifier:
@@ -42,25 +44,23 @@ pipeline:
 
   # Masked loss.
-  nllloss:
+  nllloss_flow1:
     type: NLLLoss
+    priority: 10.1
     use_masking: True
     streams:
       predictions: flow1_predictions
       masks: flow1_masks
 
   # Statistics.
-  batch_size:
-    type: BatchSizeStatistics
-    streams:
-      predictions: flow1_predictions
 
-  accuracy:
+  accuracy_flow1:
     type: AccuracyStatistics
+    priority: 100.2
     streams:
       predictions: flow1_predictions
 
-  precision_recall:
+  precision_recall_flow1:
     type: PrecisionRecallStatistics
     priority: 100.3
     use_word_mappings: True
diff --git a/ptp/__init__.py b/ptp/__init__.py
index 2c7575b..a9c69b2 100644
--- a/ptp/__init__.py
+++ b/ptp/__init__.py
@@ -7,6 +7,8 @@
 
 from .components.losses import *
 
+from .components.masking import *
+
 from .components.models import *
 
 from .components.problems.problem import Problem
diff --git a/ptp/application/pipeline_manager.py b/ptp/application/pipeline_manager.py
index 83705e7..a91b766 100644
--- a/ptp/application/pipeline_manager.py
+++ b/ptp/application/pipeline_manager.py
@@ -81,7 +81,7 @@ def build(self, use_logger=True):
         disabled_components = ''
         # Add components to disable by the ones from configuration file.
         if "disable" in self.config:
-            disabled_components = [*disabled_components, *self.config["disable"].split(",")]
+            disabled_components = [*disabled_components, *self.config["disable"].replace(" ","").split(",")]
         # Add components to disable by the ones from command line arguments.
         if (self.app_state.args is not None) and (self.app_state.args.disable != ''):
             disabled_components = [*disabled_components, *self.app_state.args.disable.split(",")]
diff --git a/ptp/components/masking/__init__.py b/ptp/components/masking/__init__.py
new file mode 100644
index 0000000..a6cf1b2
--- /dev/null
+++ b/ptp/components/masking/__init__.py
@@ -0,0 +1,5 @@
+from .string_to_mask import StringToMask
+
+__all__ = [
+    'StringToMask',
+    ]
diff --git a/ptp/components/masking/string_to_mask.py b/ptp/components/masking/string_to_mask.py
new file mode 100644
index 0000000..f745570
--- /dev/null
+++ b/ptp/components/masking/string_to_mask.py
@@ -0,0 +1,118 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) tkornuta, IBM Corporation 2019
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__author__ = "Tomasz Kornuta"
+
+import torch
+
+from ptp.components.component import Component
+from ptp.data_types.data_definition import DataDefinition
+
+
+class StringToMask(Component):
+    """
+    Class responsible for producing masks for strings using the provided word mappings.
+    Additionally, it returns the associated string indices.
+    """
+
+    def __init__(self, name, config):
+        """
+        Initializes object. Loads key and word mappings.
+
+        :param name: Name of the component loaded from the configuration file.
+        :type name: str
+
+        :param config: Dictionary of parameters (read from the configuration ``.yaml`` file).
+        :type config: :py:class:`ptp.configuration.ConfigInterface`
+
+        """
+        # Call constructors of parent classes.
+        Component.__init__(self, name, StringToMask, config)
+
+        # Get key mappings.
+        self.key_strings = self.stream_keys["strings"]
+        self.key_masks = self.stream_keys["masks"]
+        self.key_string_indices = self.stream_keys["string_indices"]
+
+        # Retrieve word mappings from globals.
+        self.word_to_ix = self.globals["word_mappings"]
+
+        # Get value from configuration.
+        self.out_of_vocabulary_value = self.config["out_of_vocabulary_value"]
+
+    def input_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of input data that are required by the component.
+
+        :return: dictionary containing input data definitions (each of type :py:class:`ptp.data_types.DataDefinition`).
+        """
+        return {
+            self.key_strings: DataDefinition([-1, 1], [list, str], "Batch of strings, each being treated as a single 'vocabulary entry' (word) [BATCH_SIZE] x [STRING]")
+            }
+
+    def output_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of output data produced by the component.
+
+        :return: dictionary containing output data definitions (each of type :py:class:`ptp.data_types.DataDefinition`).
+        """
+        return {
+            self.key_masks: DataDefinition([-1], [torch.Tensor], "Batch of masks [BATCH_SIZE]"),
+            self.key_string_indices: DataDefinition([-1], [torch.Tensor], "Batch of indices corresponding to input strings when using provided word mappings [BATCH_SIZE]")
+            }
+
+
+    def __call__(self, data_dict):
+        """
+        Produces masks and word indices for the batch of input strings, using the provided word mappings.
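+        Words that are out of vocabulary receive the index ``out_of_vocabulary_value`` and a mask value of 0.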
+
+        :param data_dict: :py:class:`ptp.utils.DataDict` object containing (among others):
+
+            - "strings": expected input field containing a batch of words [BATCH_SIZE] x [STRING]
+
+            - "masks": added output field containing tensor of masks [BATCH_SIZE]
+
+            - "string_indices": added output field containing tensor of word indices [BATCH_SIZE]
+        """
+        # Get inputs strings.
+        strings = data_dict[self.key_strings]
+
+        masks = torch.zeros(len(strings), requires_grad=False).type(self.app_state.ByteTensor)
+
+        outputs_list = []
+        # Process samples 1 by 1.
+        for i,sample in enumerate(strings):
+            assert not isinstance(sample, (list,)), 'This encoder requires input sample to contain a single word'
+            # Process single token.
+            if sample in self.word_to_ix.keys():
+                output_sample = self.word_to_ix[sample]
+                masks[i] = 1
+            else:
+                # Word out of vocabulary.
+                output_sample = self.out_of_vocabulary_value
+            outputs_list.append(output_sample)
+        # Transform to tensor.
+        output_indices = torch.tensor(outputs_list, requires_grad=False).type(self.app_state.LongTensor)
+
+        #print("strings ", strings)
+        #print("masks ", masks)
+        #print("indices ", output_indices)
+
+        # Create the returned dict.
+        data_dict.extend({
+            self.key_masks: masks,
+            self.key_string_indices: output_indices
+            })
+
diff --git a/ptp/components/problems/image_to_class/mnist.py b/ptp/components/problems/image_to_class/mnist.py
index c1a19da..43c2d98 100644
--- a/ptp/components/problems/image_to_class/mnist.py
+++ b/ptp/components/problems/image_to_class/mnist.py
@@ -57,6 +57,9 @@ def __init__(self, name, config):
         # Call base class constructors.
         super(MNIST, self).__init__(name, MNIST, config)
 
+        # Channel returning targets as words.
+        self.key_labels = self.stream_keys["labels"]
+
         # Get absolute path.
         data_folder = os.path.expanduser(self.config['data_folder'])
 
@@ -100,6 +103,8 @@ def __init__(self, name, config):
         # Class names.
         labels = 'Zero One Two Three Four Five Six Seven Eight Nine'.split(' ')
         word_to_ix = {labels[i]: i for i in range(10)}
+        # Reverse mapping - for labels.
+        self.ix_to_word = {value: key for (key, value) in word_to_ix.items()}
 
         # Export to globals.
         self.globals["label_word_mappings"] = word_to_ix
@@ -120,7 +125,8 @@ def output_data_definitions(self):
         return {
             self.key_indices: DataDefinition([-1, 1], [list, int], "Batch of sample indices [BATCH_SIZE] x [1]"),
             self.key_inputs: DataDefinition([-1, 1, self.height, self.width], [torch.Tensor], "Batch of images [BATCH_SIZE x IMAGE_DEPTH x IMAGE_HEIGHT x IMAGE_WIDTH]"),
-            self.key_targets: DataDefinition([-1], [torch.Tensor], "Batch of targets, each being a single index [BATCH_SIZE]")
+            self.key_targets: DataDefinition([-1], [torch.Tensor], "Batch of targets, each being a single index [BATCH_SIZE]"),
+            self.key_labels: DataDefinition([-1, 1], [list, str], "Batch of targets, each being a single word [BATCH_SIZE] x [STRING]")
             }
 
 
@@ -143,4 +149,5 @@ def __getitem__(self, index):
         data_dict = self.create_data_dict(index)
         data_dict[self.key_inputs] = img
         data_dict[self.key_targets] = target
+        data_dict[self.key_labels] = self.ix_to_word[target.item()]
         return data_dict
diff --git a/ptp/components/publishers/precision_recall_statistics.py b/ptp/components/publishers/precision_recall_statistics.py
index d927963..e09f287 100644
--- a/ptp/components/publishers/precision_recall_statistics.py
+++ b/ptp/components/publishers/precision_recall_statistics.py
@@ -53,8 +53,6 @@ def __init__(self, name, config):
         self.key_recall = self.statistics_keys["recall"]
         self.key_f1score = self.statistics_keys["f1score"]
 
-        # Get the number of possible outputs.
- self.num_classes = self.globals["num_classes"] # Get (or create) vocabulary. if self.config["use_word_mappings"]: @@ -63,9 +61,14 @@ def __init__(self, name, config): # Assume they are ordered, starting from 0. for key in self.globals["word_mappings"].keys(): self.labels.append(key) + # Set number of classes by looking at labels. + self.num_classes = len(self.labels) else: + # Get the number of possible outputs. + self.num_classes = self.globals["num_classes"] self.labels = list(range(self.num_classes)) + # Check display options. self.show_confusion_matrix = self.config["show_confusion_matrix"] self.show_class_scores = self.config["show_class_scores"] From 34556790dfd795cd6663229b9df8569c5070303c Mon Sep 17 00:00:00 2001 From: tkornut Date: Wed, 10 Apr 2019 17:00:43 -0700 Subject: [PATCH 04/11] Added (optional) masking to components calculating statistics: P/R and accuracy --- .../default/components/losses/nll_loss.yml | 4 +- .../publishers/accuracy_statistics.yml | 6 ++ .../precision_recall_statistics.yml | 6 ++ .../mnist/mnist_classification_vf_2lenet5.yml | 12 +++- ptp/components/losses/nll_loss.py | 10 ++-- .../publishers/accuracy_statistics.py | 60 +++++++++++++------ .../publishers/precision_recall_statistics.py | 30 +++++++--- 7 files changed, 92 insertions(+), 36 deletions(-) diff --git a/configs/default/components/losses/nll_loss.yml b/configs/default/components/losses/nll_loss.yml index 858d86b..b49aad0 100644 --- a/configs/default/components/losses/nll_loss.yml +++ b/configs/default/components/losses/nll_loss.yml @@ -12,7 +12,7 @@ num_targets_dims: 1 # Options: NLLLoss | CrossEntropyLoss (NOT OPERATIONAL YET!) # loss_function: NLLLoss -# When set to True, performs masking of selected predictions (LOADED) +# When set to True, performs masking of selected samples from batch (LOADED) use_masking: False streams: @@ -26,7 +26,7 @@ streams: # Stream containing batch of predictions (INPUT) predictions: predictions - # Stream containing masks used for masking of selected predictions (INPUT) + # Stream containing masks used for masking of selected samples from batch (INPUT) masks: masks # Stream containing loss (OUTPUT) diff --git a/configs/default/components/publishers/accuracy_statistics.yml b/configs/default/components/publishers/accuracy_statistics.yml index bc4a9f2..21d2260 100644 --- a/configs/default/components/publishers/accuracy_statistics.yml +++ b/configs/default/components/publishers/accuracy_statistics.yml @@ -4,6 +4,9 @@ # 1. CONFIGURATION PARAMETERS that will be LOADED by the component. #################################################################### +# When set to True, performs masking of selected samples from batch (LOADED) +use_masking: False + streams: #################################################################### # 2. Keymappings associated with INPUT and OUTPUT streams. @@ -15,6 +18,9 @@ streams: # Stream containing batch of predictions (INPUT) predictions: predictions + # Stream containing masks used for masking of selected samples from batch (INPUT) + masks: masks + globals: #################################################################### # 3. Keymappings of variables that will be RETRIEVED from GLOBALS. 
diff --git a/configs/default/components/publishers/precision_recall_statistics.yml b/configs/default/components/publishers/precision_recall_statistics.yml index 17f5b1e..5dc4732 100644 --- a/configs/default/components/publishers/precision_recall_statistics.yml +++ b/configs/default/components/publishers/precision_recall_statistics.yml @@ -13,6 +13,9 @@ show_class_scores: False # When set to true, will use the provided word mappings as labels (LOADED) use_word_mappings: False +# When set to True, performs masking of selected samples from batch (LOADED) +use_masking: False + streams: #################################################################### # 2. Keymappings associated with INPUT and OUTPUT streams. @@ -24,6 +27,9 @@ streams: # Stream containing batch of predictions (INPUT) predictions: predictions + # Stream containing masks used for masking of selected samples from batch (INPUT) + masks: masks + globals: #################################################################### # 3. Keymappings of variables that will be RETRIEVED from GLOBALS. diff --git a/configs/mnist/mnist_classification_vf_2lenet5.yml b/configs/mnist/mnist_classification_vf_2lenet5.yml index 0767142..d2d9511 100644 --- a/configs/mnist/mnist_classification_vf_2lenet5.yml +++ b/configs/mnist/mnist_classification_vf_2lenet5.yml @@ -5,10 +5,10 @@ default_configs: mnist/default_mnist.yml training: problem: resize_image: [32, 32] - #batch_size: 32 + batch_size: 64 optimizer: #name: Adam - lr: 0.001 + lr: 0.01 # Validation parameters - overwrite defaults: validation: @@ -70,9 +70,11 @@ pipeline: flow1_accuracy: type: AccuracyStatistics priority: 1.32 + use_masking: True streams: predictions: flow1_predictions targets: flow1_targets + masks: flow1_masks statistics: accuracy: flow1_accuracy @@ -82,12 +84,14 @@ pipeline: use_word_mappings: True show_class_scores: True show_confusion_matrix: True + use_masking: True globals: word_mappings: word_to_ix1 num_classes: num_classes1 streams: targets: flow1_targets predictions: flow1_predictions + masks: flow1_masks statistics: precision: flow1_precision recall: flow1_recall @@ -130,9 +134,11 @@ pipeline: flow2_accuracy: type: AccuracyStatistics priority: 2.32 + use_masking: True streams: targets: flow2_targets predictions: flow2_predictions + masks: flow2_masks statistics: accuracy: flow2_accuracy @@ -142,12 +148,14 @@ pipeline: use_word_mappings: True show_class_scores: True show_confusion_matrix: True + use_masking: True globals: word_mappings: word_to_ix2 num_classes: num_classes2 streams: targets: flow2_targets predictions: flow2_predictions + masks: flow2_masks statistics: precision: flow2_precision recall: flow2_recall diff --git a/ptp/components/losses/nll_loss.py b/ptp/components/losses/nll_loss.py index 07281f8..a6c18b6 100644 --- a/ptp/components/losses/nll_loss.py +++ b/ptp/components/losses/nll_loss.py @@ -36,17 +36,17 @@ def __init__(self, name, config): # Call constructors of parent classes. Loss.__init__(self, name, NLLLoss, config) + # Get stream key mappnigs. self.key_masks = self.stream_keys["masks"] - # Set loss. - self.loss_function = nn.NLLLoss() + # Get masking flag. + self.use_masking = self.config["use_masking"] # Get number of targets dimensions. self.num_targets_dims = self.config["num_targets_dims"] - # Get masking flag. - self.use_masking = self.config["use_masking"] - + # Set loss. 
+        self.loss_function = nn.NLLLoss()
 
 
     def input_data_definitions(self):
diff --git a/ptp/components/publishers/accuracy_statistics.py b/ptp/components/publishers/accuracy_statistics.py
index bd52584..ae72a68 100644
--- a/ptp/components/publishers/accuracy_statistics.py
+++ b/ptp/components/publishers/accuracy_statistics.py
@@ -44,10 +44,15 @@ def __init__(self, name, config):
         # Call constructors of parent classes.
         Component.__init__(self, name, AccuracyStatistics, config)
 
-        # Set key mappings.
+        # Get stream key mappings.
         self.key_targets = self.stream_keys["targets"]
         self.key_predictions = self.stream_keys["predictions"]
+        self.key_masks = self.stream_keys["masks"]
 
+        # Get masking flag.
+        self.use_masking = self.config["use_masking"]
+
+        # Get statistics key mappings.
         self.key_accuracy = self.statistics_keys["accuracy"]
 
 
@@ -57,10 +62,14 @@ def input_data_definitions(self):
 
         :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
         """
-        return {
+        input_defs = {
             self.key_targets: DataDefinition([-1], [torch.Tensor], "Batch of targets, each being a single index [BATCH_SIZE]"),
             self.key_predictions: DataDefinition([-1, -1], [torch.Tensor], "Batch of predictions, represented as tensor with probability distribution over classes [BATCH_SIZE x NUM_CLASSES]")
             }
+        if self.use_masking:
+            input_defs[self.key_masks] = DataDefinition([-1], [torch.Tensor], "Batch of masks [BATCH_SIZE]")
+        return input_defs
+
 
     def output_data_definitions(self):
         """
@@ -88,24 +97,37 @@ def calculate_accuracy(self, data_dict):
 
         :return: Accuracy.
         """
+        # Get targets.
+        targets = data_dict[self.key_targets].data.cpu().numpy()
+
         # Get indices of the max log-probability.
-        #pred = data_dict[self.key_predictions].max(1, keepdim=True)[1]
-        preds = data_dict[self.key_predictions].max(1)[1]
-        #print("Max: {} ".format(data_dict[self.key_predictions].max(1)[1]))
-
-        # Calculate the number of correct predictinos.
-        correct = preds.eq(data_dict[self.key_targets]).sum().item()
-        #print ("TARGETS = ",data_dict[self.key_targets])
-        #print ("PREDICTIONS = ",data_dict[self.key_predictions])
-        #print ("MAX PREDICTIONS = ", preds)
-        #print("CORRECTS = ", correct)
-
-        #print(" Target: {}\n Prediction: {}\n Correct: {} ".format(data_dict[self.key_targets], preds, preds.eq(data_dict[self.key_targets])))
-
-        # Normalize.
-        batch_size = data_dict[self.key_predictions].shape[0]
-        accuracy = correct / batch_size
-        #print("ACCURACY = ", accuracy)
+        preds = data_dict[self.key_predictions].max(1)[1].data.cpu().numpy()
+
+        # Calculate the correct predictions.
+        correct = np.equal(preds, targets)
+
+        #print(" Target: {}\n Prediction: {}\n Correct: {}\n".format(targets, preds, correct))
+
+        if self.use_masking:
+            # Get masks from inputs.
+            masks = data_dict[self.key_masks].data.cpu().numpy()
+            correct = correct * masks
+            batch_size = masks.sum()
+        else:
+            batch_size = preds.shape[0]
+
+        #print(" Mask: {}\n Masked Correct: {}\n".format(masks, correct))
+
+        # Simply sum the correct values.
+        num_correct = correct.sum()
+
+        #print(" num_correct: {}\n batch_size: {}\n".format(num_correct, batch_size))
+
+        # Normalize by batch size.
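+        # (Note: with masking enabled, batch_size is the number of unmasked samples, which may be zero for a given batch.)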
+ if batch_size > 0: + accuracy = num_correct / batch_size + else: + accuracy = 0 return accuracy diff --git a/ptp/components/publishers/precision_recall_statistics.py b/ptp/components/publishers/precision_recall_statistics.py index e09f287..89e6990 100644 --- a/ptp/components/publishers/precision_recall_statistics.py +++ b/ptp/components/publishers/precision_recall_statistics.py @@ -44,11 +44,15 @@ def __init__(self, name, config): # Call constructors of parent classes. Component.__init__(self, name, PrecisionRecallStatistics, config) - # Set key mappings. + # Get stream key mappings. self.key_targets = self.stream_keys["targets"] self.key_predictions = self.stream_keys["predictions"] + self.key_masks = self.stream_keys["masks"] - # Get statistic key mappings. + # Get masking flag. + self.use_masking = self.config["use_masking"] + + # Get statistics key mappings. self.key_precision = self.statistics_keys["precision"] self.key_recall = self.statistics_keys["recall"] self.key_f1score = self.statistics_keys["f1score"] @@ -79,10 +83,13 @@ def input_data_definitions(self): :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`). """ - return { + input_defs = { self.key_targets: DataDefinition([-1], [torch.Tensor], "Batch of targets, each being a single index [BATCH_SIZE]"), self.key_predictions: DataDefinition([-1, -1], [torch.Tensor], "Batch of predictions, represented as tensor with probability distribution over classes [BATCH_SIZE x NUM_CLASSES]") } + if self.use_masking: + input_defs[self.key_masks] = DataDefinition([-1], [torch.Tensor], "Batch of masks [BATCH_SIZE]") + return input_defs def output_data_definitions(self): """ @@ -144,12 +151,19 @@ def calculate_statistics(self, data_dict): preds = data_dict[self.key_predictions].max(1)[1].data.cpu().numpy() #print("Predictions :", preds) + if self.use_masking: + # Get masks from inputs. + masks = data_dict[self.key_masks].data.cpu().numpy() + else: + # Create vector full of ones. + masks = np.ones(targets.shape[0]) + # Create the confusion matrix, use SciKit learn order: # Column - predicted class # Row - target (actual) class confusion_matrix = np.zeros([self.num_classes, self.num_classes], dtype=int) - for (target, pred) in zip(targets, preds): - confusion_matrix[target][pred] += 1 + for i, (target, pred) in enumerate(zip(targets, preds)): + confusion_matrix[target][pred] += 1 * masks[i] # Calculate true positive (TP), eqv. with hit. tp = np.zeros([self.num_classes], dtype=int) @@ -213,9 +227,9 @@ def collect_statistics(self, stat_col, data_dict): # Calculate weighted averages. support_sum = sum(support) - precision_avg = sum([pi*si / support_sum for (pi,si) in zip(precision,support)]) - recall_avg = sum([ri*si / support_sum for (ri,si) in zip(recall,support)]) - f1score_avg = sum([fi*si / support_sum for (fi,si) in zip(f1score,support)]) + precision_avg = sum([pi*si / support_sum if si > 0 else 0.0 for (pi,si) in zip(precision,support)]) + recall_avg = sum([ri*si / support_sum if si > 0 else 0.0 for (ri,si) in zip(recall,support)]) + f1score_avg = sum([fi*si / support_sum if si > 0 else 0.0 for (fi,si) in zip(f1score,support)]) # Export to statistics. 
stat_col[self.key_precision] = precision_avg From aae5cf18db982cd0eccdb226f270a4d0ab7c92f4 Mon Sep 17 00:00:00 2001 From: tkornut Date: Wed, 10 Apr 2019 19:10:28 -0700 Subject: [PATCH 05/11] First version of predictions joining working --- .../masking/join_masked_predictions.yml | 44 +++++ .../components/masking/string_to_mask.yml | 6 +- .../mnist/mnist_classification_vf_2lenet5.yml | 22 ++- ptp/components/masking/__init__.py | 2 + .../masking/join_masked_predictions.py | 179 ++++++++++++++++++ ptp/components/masking/string_to_mask.py | 2 +- ptp/components/transforms/list_to_tensor.py | 2 +- 7 files changed, 249 insertions(+), 8 deletions(-) create mode 100644 configs/default/components/masking/join_masked_predictions.yml create mode 100644 ptp/components/masking/join_masked_predictions.py diff --git a/configs/default/components/masking/join_masked_predictions.yml b/configs/default/components/masking/join_masked_predictions.yml new file mode 100644 index 0000000..407db9c --- /dev/null +++ b/configs/default/components/masking/join_masked_predictions.yml @@ -0,0 +1,44 @@ +# This file defines the default values for the Join Masked Predictions component. + +#################################################################### +# 1. CONFIGURATION PARAMETERS that will be LOADED by the component. +#################################################################### + +# List of input stream names, each containing batch of predictions (LOADED) +input_prediction_streams: '' + +# List of input stream names, each containing batch of masks (LOADED) +input_mask_streams: '' + +# List of word mapping names - those will be loaded from globals (LOADED) +input_word_mappings: '' + +streams: + #################################################################### + # 2. Keymappings associated with INPUT and OUTPUT streams. + #################################################################### + + # Stream containing batch of output strings (OUTPUT) + output_strings: output_strings + + # Stream containing batch of output indices (OUTPUT) + # WARNING: As performed operations are not differentiable, + # those indices cannot be used for e.g. calculation of loss!! + output_indices: output_indices + +globals: + #################################################################### + # 3. Keymappings of variables that will be RETRIEVED from GLOBALS. + #################################################################### + + # Vocabulary used to produce output strings (RETRIEVED) + output_word_mappings: output_word_mappings + + #################################################################### + # 4. Keymappings associated with GLOBAL variables that will be SET. + #################################################################### + + #################################################################### + # 5. Keymappings associated with statistics that will be ADDED. + #################################################################### + diff --git a/configs/default/components/masking/string_to_mask.yml b/configs/default/components/masking/string_to_mask.yml index b42b4d9..bcc363d 100644 --- a/configs/default/components/masking/string_to_mask.yml +++ b/configs/default/components/masking/string_to_mask.yml @@ -27,13 +27,13 @@ globals: # 3. Keymappings of variables that will be RETRIEVED from GLOBALS. 
#################################################################### + # Vocabulary used to produce masks and indices (RETRIEVED) + word_mappings: word_mappings + #################################################################### # 4. Keymappings associated with GLOBAL variables that will be SET. #################################################################### - # Vocabulary used to produce masks and indices (GET) - word_mappings: word_mappings - #################################################################### # 5. Keymappings associated with statistics that will be ADDED. #################################################################### diff --git a/configs/mnist/mnist_classification_vf_2lenet5.yml b/configs/mnist/mnist_classification_vf_2lenet5.yml index d2d9511..e23c391 100644 --- a/configs/mnist/mnist_classification_vf_2lenet5.yml +++ b/configs/mnist/mnist_classification_vf_2lenet5.yml @@ -31,9 +31,9 @@ pipeline: type: GlobalVariablePublisher priority: 0 keys: [num_classes1, num_classes2, word_to_ix1, word_to_ix2] - values: [4, 6, {"Zero": 0, "One": 1, "Two": 2, "Three": 3}, {"Four": 0, "Five": 1, "Six": 2, "Seven": 3, "Eight": 4, "Nine": 5}] + values: [3, 7, {"Zero": 0, "One": 1, "Two": 2}, {"Three": 0, "Four": 1, "Five": 2, "Six": 3, "Seven": 4, "Eight": 5, "Nine": 6}] - ######### Flow 1 ################# + ################# Flow 1 ################# flow1_string_to_mask: type: StringToMask priority: 1.1 @@ -97,7 +97,7 @@ pipeline: recall: flow1_recall f1score: flow1_f1score - ######### Flow 2 ################# + ################# Flow 2 ################# flow2_string_to_mask: type: StringToMask priority: 2.1 @@ -161,4 +161,20 @@ pipeline: recall: flow2_recall f1score: flow2_f1score + ################# JOIN ################# + joined_predictions: + type: JoinMaskedPredictions + priority: 3.1 + # Names of used input streams. + input_prediction_streams: [flow1_predictions, flow2_predictions] + input_mask_streams: [flow1_masks, flow2_masks] + input_word_mappings: [word_to_ix1, word_to_ix2] + globals: + output_word_mappings: label_word_mappings # from MNIST problem. + streams: + output_strings: merged_predictions + output_indices: merged_indices + + + #: pipeline diff --git a/ptp/components/masking/__init__.py b/ptp/components/masking/__init__.py index a6cf1b2..8f00421 100644 --- a/ptp/components/masking/__init__.py +++ b/ptp/components/masking/__init__.py @@ -1,5 +1,7 @@ +from .join_masked_predictions import JoinMaskedPredictions from .string_to_mask import StringToMask __all__ = [ + 'JoinMaskedPredictions', 'StringToMask', ] diff --git a/ptp/components/masking/join_masked_predictions.py b/ptp/components/masking/join_masked_predictions.py new file mode 100644 index 0000000..59e8702 --- /dev/null +++ b/ptp/components/masking/join_masked_predictions.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) tkornuta, IBM Corporation 2019 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+__author__ = "Tomasz Kornuta"
+
+import torch
+import numpy as np
+
+from ptp.components.component import Component
+from ptp.data_types.data_definition import DataDefinition
+
+
+class JoinMaskedPredictions(Component):
+    """
+    Class responsible for joining several prediction streams using the associated masks.
+    Additionally, it returns the associated string indices.
+
+    .. warning:
+        As performed operations are not differentiable, the returned 'output_indices' cannot be used for e.g. calculation of loss!!
+
+    """
+
+    def __init__(self, name, config):
+        """
+        Initializes the object. Loads keys, word mappings and vocabularies.
+
+        :param name: Name of the component read from the configuration file
+        :type name: str
+
+        :param config: Dictionary of parameters (read from the configuration ``.yaml`` file).
+        :type config: :py:class:`ptp.configuration.ConfigInterface`
+
+        """
+        # Call constructors of parent classes.
+        Component.__init__(self, name, JoinMaskedPredictions, config)
+
+        # Get input key mappings.
+        # Load list of prediction streams names (keys).
+        self.input_prediction_stream_keys = self.config["input_prediction_streams"]
+        if type(self.input_prediction_stream_keys) == str:
+            self.input_prediction_stream_keys = self.input_prediction_stream_keys.replace(" ", "").split(",")
+        #assert(self.input_prediction_stream_keys != ""), "ooo"
+
+        # Load list of mask streams names (keys).
+        self.input_mask_stream_keys = self.config["input_mask_streams"]
+        if type(self.input_mask_stream_keys) == str:
+            self.input_mask_stream_keys = self.input_mask_stream_keys.replace(" ", "").split(",")
+
+        # Load list of word mappings names (keys).
+        input_word_mappings_keys = self.config["input_word_mappings"]
+        if type(input_word_mappings_keys) == str:
+            input_word_mappings_keys = input_word_mappings_keys.replace(" ", "").split(",")
+
+        # Retrieve input word mappings from globals.
+        self.input_ix_to_word = []
+        for wmk in input_word_mappings_keys:
+            # Get word mappings.
+            word_to_ix = self.globals[wmk]
+            # Create inverse transformation.
+            ix_to_word = {value: key for (key, value) in word_to_ix.items()}
+            self.input_ix_to_word.append(ix_to_word)
+
+
+        # Get output key mappings.
+        self.key_output_indices = self.stream_keys["output_indices"]
+        self.key_output_strings = self.stream_keys["output_strings"]
+
+        # Retrieve output word mappings from globals.
+        self.output_word_to_ix = self.globals["output_word_mappings"]
+
+
+    def input_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of input data that are required by the component.
+
+        :return: dictionary containing input data definitions (each of type :py:class:`ptp.data_types.DataDefinition`).
+        """
+        input_defs = {}
+        # Add input prediction streams.
+        for i, ipsk in enumerate(self.input_prediction_stream_keys):
+            # Use input prediction stream key along with the length of the associated word mappings (i.e. size of the vocabulary = NUM_CLASSES)
+            input_defs[ipsk] = DataDefinition([-1, len(self.input_ix_to_word[i])], [torch.Tensor], "Batch of predictions, represented as tensor with probability distribution over classes [BATCH_SIZE x NUM_CLASSES]")
+        # Add mask streams.
+        for imsk in self.input_mask_stream_keys:
+            # Every mask has the same definition, but different stream key.
+            input_defs[imsk] = DataDefinition([-1], [torch.Tensor], "Batch of masks [BATCH_SIZE]")
+
+        return input_defs
+
+    def output_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of output data produced by the component.
+
+        :return: dictionary containing output data definitions (each of type :py:class:`ptp.data_types.DataDefinition`).
+        """
+        return {
+            self.key_output_indices: DataDefinition([-1], [torch.Tensor], "Batch of merged (output) indices [BATCH_SIZE]"),
+            self.key_output_strings: DataDefinition([-1], [list, str], "Batch of merged strings, corresponding to indices when using the provided word mappings [BATCH_SIZE]")
+            }
+
+
+    def __call__(self, data_dict):
+        """
+        Joins the masked predictions coming from several input streams into a single batch of answers (strings and indices).
+
+        :param data_dict: :py:class:`ptp.utils.DataDict` object containing (among others):
+
+            - the input prediction and mask streams listed in the configuration
+
+            - "output_strings": added output field containing batch of merged predictions (words) [BATCH_SIZE]
+
+            - "output_indices": added output field containing batch of merged indices [BATCH_SIZE]
+        """
+        # Get inputs masks
+        masks = []
+        for imsk in self.input_mask_stream_keys:
+            masks.append(data_dict[imsk].data.cpu().numpy())
+
+        # Sum all masks and make sure that they are complementary.
+        masks_sum = np.sum(masks, axis=0)
+        batch_size = masks_sum.shape[0]
+        sum_ones = sum(filter(lambda x: x == 1, masks_sum))
+        if sum_ones != batch_size:
+            self.logger.error("Masks received from the {} streams are not complementary!".format(self.input_mask_stream_keys))
+            exit(-1)
+
+        # Create mapping indicating from which input take given sample.
+        weights = np.array(range(len(masks)))
+        masks = np.array(masks).transpose()
+        mapping = np.dot(masks, weights)
+        #print("Mapping = \n",mapping)
+
+        # Get indices of the max log-probabilities.
+        preds = []
+        for ipsk in self.input_prediction_stream_keys:
+            preds.append(data_dict[ipsk].max(1)[1].data.cpu().numpy())
+
+        # "Translate".
+        output_answers = []
+        output_indices = []
+        #output_predictions_lst = []
+        # Iterate through samples.
+        for sample in range(batch_size):
+            # Get the right dictionary.
+            ix_to_word = self.input_ix_to_word[mapping[sample]]
+            #print(ix_to_word)
+            # Get the right sample from the right prediction stream
+            #sample_prediction = data_dict[self.input_prediction_stream_keys[mapping[sample]]][sample]
+            # Get the index.
+            index = preds[mapping[sample]][sample]
+            # Get the right word.
+            word = ix_to_word[index]
+            output_answers.append(word)
+            # Get original index using output dictionary.
+            output_indices.append(self.output_word_to_ix[word])
+
+        #print(output_answers)
+        #targets = data_dict["targets"].data.cpu().numpy()
+        #print("targets = \n",targets.tolist())
+        #print("joined answers = \n",output_indices)
+
+        # Extend the dict by returned output streams.
+        data_dict.extend({
+            self.key_output_indices: output_indices,
+            self.key_output_strings: output_answers
+            })
+
diff --git a/ptp/components/masking/string_to_mask.py b/ptp/components/masking/string_to_mask.py
index f745570..4723ff6 100644
--- a/ptp/components/masking/string_to_mask.py
+++ b/ptp/components/masking/string_to_mask.py
@@ -110,7 +110,7 @@ def __call__(self, data_dict):
         #print("masks ", masks)
         #print("indices ", output_indices)
 
-        # Create the returned dict.
+        # Extend the dict by returned output streams.
         data_dict.extend({
             self.key_masks: masks,
             self.key_string_indices: output_indices
diff --git a/ptp/components/transforms/list_to_tensor.py b/ptp/components/transforms/list_to_tensor.py
index 317bb3d..fbf3f21 100644
--- a/ptp/components/transforms/list_to_tensor.py
+++ b/ptp/components/transforms/list_to_tensor.py
@@ -32,7 +32,7 @@ def __init__(self, name, config):
         """
         Initializes object.
 
-        :param name: Loss name.
+        :param name: Name of the component loaded from the configuration file.
        :type name: str
 
        :param config: Dictionary of parameters (read from the configuration ``.yaml`` file).
 
From 51e209f827941b464ad121439eacab734b98f780 Mon Sep 17 00:00:00 2001
From: tkornut
Date: Wed, 10 Apr 2019 19:49:37 -0700
Subject: [PATCH 06/11] Added joined precision recall

---
 .../precision_recall_statistics.yml           |  5 +++
 .../mnist/mnist_classification_vf_2lenet5.yml | 18 +++++++++++
 .../masking/join_masked_predictions.py        | 31 +++++++++++--------
 .../publishers/precision_recall_statistics.py | 19 +++++++++---
 4 files changed, 56 insertions(+), 17 deletions(-)

diff --git a/configs/default/components/publishers/precision_recall_statistics.yml b/configs/default/components/publishers/precision_recall_statistics.yml
index 5dc4732..b6b1725 100644
--- a/configs/default/components/publishers/precision_recall_statistics.yml
+++ b/configs/default/components/publishers/precision_recall_statistics.yml
@@ -4,6 +4,11 @@
 # 1. CONFIGURATION PARAMETERS that will be LOADED by the component.
 ####################################################################
 
+# Flag indicating whether inputs are prediction distributions or indices (LOADED)
+# Options: True (expects a distribution for each prediction)
+#          False (expects indices)
+use_prediction_distributions: True
+
 # Flag indicating whether confusion matrix will be shown (LOADED)
 show_confusion_matrix: False
 
diff --git a/configs/mnist/mnist_classification_vf_2lenet5.yml b/configs/mnist/mnist_classification_vf_2lenet5.yml
index e23c391..dec1bca 100644
--- a/configs/mnist/mnist_classification_vf_2lenet5.yml
+++ b/configs/mnist/mnist_classification_vf_2lenet5.yml
@@ -175,6 +175,24 @@ pipeline:
     output_strings: merged_predictions
     output_indices: merged_indices
 
+  joined_precision_recall:
+    type: PrecisionRecallStatistics
+    priority: 3.2
+    # Use prediction indices instead of distributions.
+    use_prediction_distributions: False
+    use_word_mappings: True
+    show_class_scores: True
+    show_confusion_matrix: True
+    globals:
+      word_mappings: label_word_mappings # straight from MNIST
+      #num_classes: num_classes
+    streams:
+      targets: targets # straight from MNIST
+      predictions: merged_indices
+    statistics:
+      precision: joined_precision
+      recall: joined_recall
+      f1score: joined_f1score
 
 #: pipeline
diff --git a/ptp/components/masking/join_masked_predictions.py b/ptp/components/masking/join_masked_predictions.py
index 59e8702..fa0eb3d 100644
--- a/ptp/components/masking/join_masked_predictions.py
+++ b/ptp/components/masking/join_masked_predictions.py
@@ -136,44 +136,49 @@ def __call__(self, data_dict):
             self.logger.error("Masks received from the {} streams are not complementary!".format(self.input_mask_stream_keys))
             exit(-1)
 
-        # Create mapping indicating from which input take given sample.
+        # Create mapping indicating from which input prediction/mask/dictionary stream we will take the data associated with a given sample.
         weights = np.array(range(len(masks)))
         masks = np.array(masks).transpose()
         mapping = np.dot(masks, weights)
         #print("Mapping = \n",mapping)
 
-        # Get indices of the max log-probabilities.
-        preds = []
-        for ipsk in self.input_prediction_stream_keys:
-            preds.append(data_dict[ipsk].max(1)[1].data.cpu().numpy())
-
         # "Translate".
         output_answers = []
         output_indices = []
-        #output_predictions_lst = []
+        output_predictions_lst = []
         # Iterate through samples.
         for sample in range(batch_size):
             # Get the right dictionary.
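+            # mapping[sample] holds the id of the flow whose mask selected this
+            # sample, so we pick that flow's inverse (index-to-word) mapping.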
ix_to_word = self.input_ix_to_word[mapping[sample]] #print(ix_to_word) - # Get the right sample from the right prediction stream - #sample_prediction = data_dict[self.input_prediction_stream_keys[mapping[sample]]][sample] - # Get the index. - index = preds[mapping[sample]][sample] + + # Get the right sample from the right prediction stream. + sample_prediction = data_dict[self.input_prediction_stream_keys[mapping[sample]]][sample] + #print(sample_prediction) + output_predictions_lst.append(sample_prediction) + + # Get the index of max log-probabilities. + index = sample_prediction.max(0)[1].data.cpu().item() + #print(index) + # Get the right word. word = ix_to_word[index] output_answers.append(word) + # Get original index using output dictionary. output_indices.append(self.output_word_to_ix[word]) - #print(output_answers) + #print(output_predictions_lst) #targets = data_dict["targets"].data.cpu().numpy() #print("targets = \n",targets.tolist()) #print("joined answers = \n",output_indices) + # Change to tensor. + output_indices_tensor = torch.tensor(output_indices) + # Extend the dict by returned output streams. data_dict.extend({ - self.key_output_indices: output_indices, + self.key_output_indices: output_indices_tensor, self.key_output_strings: output_answers }) diff --git a/ptp/components/publishers/precision_recall_statistics.py b/ptp/components/publishers/precision_recall_statistics.py index 89e6990..6a6430f 100644 --- a/ptp/components/publishers/precision_recall_statistics.py +++ b/ptp/components/publishers/precision_recall_statistics.py @@ -49,6 +49,9 @@ def __init__(self, name, config): self.key_predictions = self.stream_keys["predictions"] self.key_masks = self.stream_keys["masks"] + # Get prediction distributions/indices flag. + self.use_prediction_distributions = self.config["use_prediction_distributions"] + # Get masking flag. self.use_masking = self.config["use_masking"] @@ -84,9 +87,14 @@ def input_data_definitions(self): :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`). """ input_defs = { - self.key_targets: DataDefinition([-1], [torch.Tensor], "Batch of targets, each being a single index [BATCH_SIZE]"), - self.key_predictions: DataDefinition([-1, -1], [torch.Tensor], "Batch of predictions, represented as tensor with probability distribution over classes [BATCH_SIZE x NUM_CLASSES]") + self.key_targets: DataDefinition([-1], [torch.Tensor], "Batch of targets, each being a single index [BATCH_SIZE]") } + + if self.use_prediction_distributions: + input_defs[self.key_predictions] = DataDefinition([-1, -1], [torch.Tensor], "Batch of predictions, represented as tensor with probability distribution over classes [BATCH_SIZE x NUM_CLASSES]") + else: + input_defs[self.key_predictions] = DataDefinition([-1], [torch.Tensor], "Batch of predictions, represented as tensor with indices of predicted answers [BATCH_SIZE]") + if self.use_masking: input_defs[self.key_masks] = DataDefinition([-1], [torch.Tensor], "Batch of masks [BATCH_SIZE]") return input_defs @@ -147,8 +155,11 @@ def calculate_statistics(self, data_dict): targets = data_dict[self.key_targets].data.cpu().numpy() #print("Targets :", targets) - # Get indices of the max log-probability. - preds = data_dict[self.key_predictions].max(1)[1].data.cpu().numpy() + if self.use_prediction_distributions: + # Get indices of the max log-probability. 
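+            # (the predictions stream carries a [BATCH_SIZE x NUM_CLASSES]
+            # distribution here, as declared in input_data_definitions above)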
+            preds = data_dict[self.key_predictions].max(1)[1].data.cpu().numpy()
+        else:
+            preds = data_dict[self.key_predictions].data.cpu().numpy()
         #print("Predictions :", preds)
 
         if self.use_masking:
From 7acca8025055b593573b38566c681891ebac911f Mon Sep 17 00:00:00 2001
From: tkornut
Date: Wed, 10 Apr 2019 19:57:06 -0700
Subject: [PATCH 07/11] Added flag for using prediction distribution/indices to accuracy

---
 .../publishers/accuracy_statistics.yml        |  5 +++++
 .../precision_recall_statistics.yml           |  4 ++--
 .../mnist/mnist_classification_vf_2lenet5.yml | 14 +++++++++++++-
 .../publishers/accuracy_statistics.py         | 19 +++++++++++++++----
 4 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/configs/default/components/publishers/accuracy_statistics.yml b/configs/default/components/publishers/accuracy_statistics.yml
index 21d2260..0ef771c 100644
--- a/configs/default/components/publishers/accuracy_statistics.yml
+++ b/configs/default/components/publishers/accuracy_statistics.yml
@@ -4,6 +4,11 @@
 # 1. CONFIGURATION PARAMETERS that will be LOADED by the component.
 ####################################################################
 
+# Flag indicating whether predictions are represented as distributions or indices (LOADED)
+# Options: True (expects a distribution for each prediction)
+#          False (expects indices, i.e. arg max)
+use_prediction_distributions: True
+
 # When set to True, performs masking of selected samples from batch (LOADED)
 use_masking: False
 
diff --git a/configs/default/components/publishers/precision_recall_statistics.yml b/configs/default/components/publishers/precision_recall_statistics.yml
index b6b1725..4d1c916 100644
--- a/configs/default/components/publishers/precision_recall_statistics.yml
+++ b/configs/default/components/publishers/precision_recall_statistics.yml
@@ -4,9 +4,9 @@
 # 1. CONFIGURATION PARAMETERS that will be LOADED by the component.
 ####################################################################
 
-# Flag indicating whether inputs are prediction distributions or indices (LOADED)
+# Flag indicating whether predictions are represented as distributions or indices (LOADED)
 # Options: True (expects a distribution for each prediction)
-#          False (expects indices)
+#          False (expects indices, i.e. arg max)
 use_prediction_distributions: True
 
 # Flag indicating whether confusion matrix will be shown (LOADED)
diff --git a/configs/mnist/mnist_classification_vf_2lenet5.yml b/configs/mnist/mnist_classification_vf_2lenet5.yml
index dec1bca..8de192b 100644
--- a/configs/mnist/mnist_classification_vf_2lenet5.yml
+++ b/configs/mnist/mnist_classification_vf_2lenet5.yml
@@ -175,9 +175,21 @@ pipeline:
     output_strings: merged_predictions
     output_indices: merged_indices
 
+  # Statistics.
+  joined_accuracy:
+    type: AccuracyStatistics
+    priority: 3.21
+    # Use prediction indices instead of distributions.
+    use_prediction_distributions: False
+    streams:
+      targets: targets
+      predictions: merged_indices
+    statistics:
+      accuracy: joined_accuracy
+
   joined_precision_recall:
     type: PrecisionRecallStatistics
-    priority: 3.2
+    priority: 3.22
     # Use prediction indices instead of distributions.
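+    # (merged_indices are already arg-maxed by the JoinMaskedPredictions component)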
use_prediction_distributions: False use_word_mappings: True diff --git a/ptp/components/publishers/accuracy_statistics.py b/ptp/components/publishers/accuracy_statistics.py index ae72a68..b542b1e 100644 --- a/ptp/components/publishers/accuracy_statistics.py +++ b/ptp/components/publishers/accuracy_statistics.py @@ -49,6 +49,9 @@ def __init__(self, name, config): self.key_predictions = self.stream_keys["predictions"] self.key_masks = self.stream_keys["masks"] + # Get prediction distributions/indices flag. + self.use_prediction_distributions = self.config["use_prediction_distributions"] + # Get masking flag. self.use_masking = self.config["use_masking"] @@ -63,9 +66,14 @@ def input_data_definitions(self): :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`). """ input_defs = { - self.key_targets: DataDefinition([-1], [torch.Tensor], "Batch of targets, each being a single index [BATCH_SIZE]"), - self.key_predictions: DataDefinition([-1, -1], [torch.Tensor], "Batch of predictions, represented as tensor with probability distribution over classes [BATCH_SIZE x NUM_CLASSES]") + self.key_targets: DataDefinition([-1], [torch.Tensor], "Batch of targets, each being a single index [BATCH_SIZE]") } + + if self.use_prediction_distributions: + input_defs[self.key_predictions] = DataDefinition([-1, -1], [torch.Tensor], "Batch of predictions, represented as tensor with probability distribution over classes [BATCH_SIZE x NUM_CLASSES]") + else: + input_defs[self.key_predictions] = DataDefinition([-1], [torch.Tensor], "Batch of predictions, represented as tensor with indices of predicted answers [BATCH_SIZE]") + if self.use_masking: input_defs[self.key_masks] = DataDefinition([-1], [torch.Tensor], "Batch of masks [BATCH_SIZE]") return input_defs @@ -100,8 +108,11 @@ def calculate_accuracy(self, data_dict): # Get targets. targets = data_dict[self.key_targets].data.cpu().numpy() - # Get indices of the max log-probability. - preds = data_dict[self.key_predictions].max(1)[1].data.cpu().numpy() + if self.use_prediction_distributions: + # Get indices of the max log-probability. + preds = data_dict[self.key_predictions].max(1)[1].data.cpu().numpy() + else: + preds = data_dict[self.key_predictions].data.cpu().numpy() # Calculate the correct predictinos. correct = np.equal(preds, targets) From 032949fa19f000cce0120f50490a5b9706b0867c Mon Sep 17 00:00:00 2001 From: tkornut Date: Wed, 10 Apr 2019 20:19:47 -0700 Subject: [PATCH 08/11] Added retain_graph=True for many backward passes - i.e. multiloss training with shared submodels/computational subgraphs --- ptp/application/pipeline_manager.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/ptp/application/pipeline_manager.py b/ptp/application/pipeline_manager.py index a91b766..06c9c80 100644 --- a/ptp/application/pipeline_manager.py +++ b/ptp/application/pipeline_manager.py @@ -539,9 +539,20 @@ def backward(self, data_dict): """ if (len(self.losses) == 0): raise ConfigurationError("Cannot train using backpropagation as there are no 'Loss' components") + # Calculate total number of backward passes. + total_passes = sum([len(loss.loss_keys()) for loss in self.losses]) + + # All but the last call to backward should have the retain_graph=True option. + pass_counter = 0 for loss in self.losses: for key in loss.loss_keys(): - data_dict[key].backward() + pass_counter += 1 + if pass_counter == total_passes: + # Last pass. + data_dict[key].backward() + else: + # "Other pass." 
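+                    # retain_graph=True keeps the intermediate buffers of the
+                    # autograd graph alive, so that the remaining losses sharing
+                    # parts of the computational graph (e.g. a shared encoder)
+                    # can still backpropagate through them.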
+ data_dict[key].backward(retain_graph=True) def get_loss(self, data_dict): From ca65824c18f5a1e04d5e8faa3132c36c02dcb957 Mon Sep 17 00:00:00 2001 From: tkornut Date: Wed, 10 Apr 2019 20:53:18 -0700 Subject: [PATCH 09/11] vf with shared convnet and two ff predictors, joined at the end --- .../mnist_classification_convnet_softmax.yml | 2 +- .../mnist/mnist_classification_vf_2lenet5.yml | 2 - ...ification_vf_shared_convnet_2softmaxes.yml | 235 ++++++++++++++++++ ptp/application/pipeline_manager.py | 5 + 4 files changed, 241 insertions(+), 3 deletions(-) create mode 100644 configs/mnist/mnist_classification_vf_shared_convnet_2softmaxes.yml diff --git a/configs/mnist/mnist_classification_convnet_softmax.yml b/configs/mnist/mnist_classification_convnet_softmax.yml index ac18985..67decc4 100644 --- a/configs/mnist/mnist_classification_convnet_softmax.yml +++ b/configs/mnist/mnist_classification_convnet_softmax.yml @@ -23,7 +23,7 @@ pipeline: # Image classifier. classifier: - type: SoftmaxClassifier + type: FeedForwardNetwork priority: 3 streams: inputs: reshaped_maps diff --git a/configs/mnist/mnist_classification_vf_2lenet5.yml b/configs/mnist/mnist_classification_vf_2lenet5.yml index 8de192b..fac486c 100644 --- a/configs/mnist/mnist_classification_vf_2lenet5.yml +++ b/configs/mnist/mnist_classification_vf_2lenet5.yml @@ -50,7 +50,6 @@ pipeline: priority: 1.2 globals: prediction_size: num_classes1 - word_mappings: word_to_ix1 streams: inputs: inputs predictions: flow1_predictions @@ -114,7 +113,6 @@ pipeline: priority: 2.2 globals: prediction_size: num_classes2 - word_mappings: word_to_ix2 streams: inputs: inputs predictions: flow2_predictions diff --git a/configs/mnist/mnist_classification_vf_shared_convnet_2softmaxes.yml b/configs/mnist/mnist_classification_vf_shared_convnet_2softmaxes.yml new file mode 100644 index 0000000..2c6ea0e --- /dev/null +++ b/configs/mnist/mnist_classification_vf_shared_convnet_2softmaxes.yml @@ -0,0 +1,235 @@ +# Load config defining MNIST problems for training, validation and testing. +default_configs: mnist/default_mnist.yml + +# Training parameters - overwrite defaults: +training: + problem: + #resize_image: [32, 32] + batch_size: 64 + optimizer: + #name: Adam + lr: 0.001 + terminal_conditions: + loss_stop: 0.08 + +# Validation parameters - overwrite defaults: +validation: + partial_validation_interval: 10 +# problem: +# resize_image: [32, 32] + +# Testing parameters - overwrite defaults: +#testing: +# problem: +# resize_image: [32, 32] + +# Definition of the pipeline. +pipeline: + name: mnist_variational_flow_shared_convnet_2softmaxes + # Disable components for "default" flow. + disable: nllloss, accuracy, precision_recall + + ################# SHARED ################# + + # Add global variables. + global_publisher: + type: GlobalVariablePublisher + priority: 0.1 + keys: [num_classes1, num_classes2, word_to_ix1, word_to_ix2] + values: [3, 7, {"Three": 0, "One": 1, "Five": 2}, {"Four": 0, "Two": 1, "Zero": 2, "Six": 3, "Seven": 4, "Eight": 5, "Nine": 6}] + #values: [3, 7, {"Zero": 0, "One": 1, "Two": 2}, {"Three": 0, "Four": 1, "Five": 2, "Six": 3, "Seven": 4, "Eight": 5, "Nine": 6}] + + # Shared model - encoder. 
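+  # Both flows consume the features produced by this single encoder, so its
+  # parameters receive gradients from both masked losses.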
+  image_encoder:
+    type: ConvNetEncoder
+    priority: 0.2
+
+  # Reshape inputs.
+  reshaper:
+    type: ReshapeTensor
+    input_dims: [-1, 16, 1, 1]
+    output_dims: [-1, 16]
+    priority: 0.3
+    streams:
+      inputs: feature_maps
+      outputs: reshaped_maps
+    globals:
+      output_size: reshaped_maps_size
+
+  ################# Flow 1 #################
+  flow1_string_to_mask:
+    type: StringToMask
+    priority: 1.1
+    globals:
+      word_mappings: word_to_ix1
+    streams:
+      strings: labels
+      string_indices: flow1_targets
+      masks: flow1_masks
+
+  # Classifier.
+  flow1_classifier:
+    type: FeedForwardNetwork
+    priority: 1.2
+    globals:
+      input_size: reshaped_maps_size
+      prediction_size: num_classes1
+    streams:
+      inputs: reshaped_maps
+      predictions: flow1_predictions
+
+  # Masked loss.
+  flow1_nllloss:
+    type: NLLLoss
+    priority: 1.31
+    use_masking: True
+    streams:
+      targets: flow1_targets
+      predictions: flow1_predictions
+      masks: flow1_masks
+      loss: flow1_loss
+
+  # Statistics.
+  flow1_accuracy:
+    type: AccuracyStatistics
+    priority: 1.32
+    use_masking: True
+    streams:
+      predictions: flow1_predictions
+      targets: flow1_targets
+      masks: flow1_masks
+    statistics:
+      accuracy: flow1_accuracy
+
+  flow1_precision_recall:
+    type: PrecisionRecallStatistics
+    priority: 1.33
+    use_word_mappings: True
+    show_class_scores: True
+    show_confusion_matrix: True
+    use_masking: True
+    globals:
+      word_mappings: word_to_ix1
+      num_classes: num_classes1
+    streams:
+      targets: flow1_targets
+      predictions: flow1_predictions
+      masks: flow1_masks
+    statistics:
+      precision: flow1_precision
+      recall: flow1_recall
+      f1score: flow1_f1score
+
+  ################# Flow 2 #################
+  flow2_string_to_mask:
+    type: StringToMask
+    priority: 2.1
+    globals:
+      word_mappings: word_to_ix2
+    streams:
+      strings: labels
+      string_indices: flow2_targets
+      masks: flow2_masks
+
+  # Classifier.
+  flow2_classifier:
+    type: FeedForwardNetwork
+    priority: 2.2
+    globals:
+      input_size: reshaped_maps_size
+      prediction_size: num_classes2
+    streams:
+      inputs: reshaped_maps
+      predictions: flow2_predictions
+
+  # Masked loss.
+  flow2_nllloss:
+    type: NLLLoss
+    priority: 2.31
+    use_masking: True
+    streams:
+      targets: flow2_targets
+      predictions: flow2_predictions
+      masks: flow2_masks
+      loss: flow2_loss
+
+  # Statistics.
+  flow2_accuracy:
+    type: AccuracyStatistics
+    priority: 2.32
+    use_masking: True
+    streams:
+      targets: flow2_targets
+      predictions: flow2_predictions
+      masks: flow2_masks
+    statistics:
+      accuracy: flow2_accuracy
+
+  flow2_precision_recall:
+    type: PrecisionRecallStatistics
+    priority: 2.33
+    use_word_mappings: True
+    show_class_scores: True
+    show_confusion_matrix: True
+    use_masking: True
+    globals:
+      word_mappings: word_to_ix2
+      num_classes: num_classes2
+    streams:
+      targets: flow2_targets
+      predictions: flow2_predictions
+      masks: flow2_masks
+    statistics:
+      precision: flow2_precision
+      recall: flow2_recall
+      f1score: flow2_f1score
+
+  ################# JOIN #################
+  joined_predictions:
+    type: JoinMaskedPredictions
+    priority: 3.1
+    # Names of used input streams.
+    input_prediction_streams: [flow1_predictions, flow2_predictions]
+    input_mask_streams: [flow1_masks, flow2_masks]
+    input_word_mappings: [word_to_ix1, word_to_ix2]
+    globals:
+      output_word_mappings: label_word_mappings # from MNIST problem.
+    streams:
+      output_strings: merged_predictions
+      output_indices: merged_indices
+
+  # Statistics.
+  joined_accuracy:
+    type: AccuracyStatistics
+    priority: 3.21
+    # Use prediction indices instead of distributions.
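+    # (JoinMaskedPredictions returns indices, not distributions, and these are
+    # not differentiable - they are used for statistics only)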
+ use_prediction_distributions: False + streams: + targets: targets + predictions: merged_indices + statistics: + accuracy: joined_accuracy + + joined_precision_recall: + type: PrecisionRecallStatistics + priority: 3.22 + # Use prediction indices instead of distributions. + use_prediction_distributions: False + use_word_mappings: True + show_class_scores: True + show_confusion_matrix: True + globals: + word_mappings: label_word_mappings # straight from MNIST + #num_classes: num_classes + streams: + targets: targets # straight from MNIST + predictions: merged_indices + statistics: + precision: joined_precision + recall: joined_recall + f1score: joined_f1score + + +#: pipeline diff --git a/ptp/application/pipeline_manager.py b/ptp/application/pipeline_manager.py index 06c9c80..c3ccd88 100644 --- a/ptp/application/pipeline_manager.py +++ b/ptp/application/pipeline_manager.py @@ -566,9 +566,14 @@ def get_loss(self, data_dict): if (len(self.losses) == 0): raise ConfigurationError("Cannot train using backpropagation as there are no 'Loss' components") loss_sum = 0 + num_losses = 0 for loss in self.losses: for key in loss.loss_keys(): loss_sum += data_dict[key].cpu().item() + num_losses +=1 + # Display additional information for multi-loss pipelines. + if num_losses > 1: + self.logger.info("Total loss: {}".format(loss_sum)) return loss_sum From 6ae3a74b393cf50b69a29665ee72f3053cf4dd94 Mon Sep 17 00:00:00 2001 From: tkornut Date: Thu, 11 Apr 2019 12:58:11 -0700 Subject: [PATCH 10/11] Fixes of P/R component, cleanups of VF LeNet5 pipelines, removed VFLenet5 model --- .../models/variational_flow_lenet5.yml | 46 ----- configs/mnist/default_mnist.yml | 4 +- .../mnist/mnist_classification_vf_2lenet5.yml | 4 - .../mnist/mnist_classification_vf_lenet5.yml | 81 ++++++-- ...mnist_classification_vf_lenet5_2losses.yml | 133 +++++++++++++ ptp/components/models/__init__.py | 2 - .../models/variational_flow_lenet5.py | 185 ------------------ .../publishers/precision_recall_statistics.py | 31 +-- 8 files changed, 215 insertions(+), 271 deletions(-) delete mode 100644 configs/default/components/models/variational_flow_lenet5.yml create mode 100644 configs/mnist/mnist_classification_vf_lenet5_2losses.yml delete mode 100644 ptp/components/models/variational_flow_lenet5.py diff --git a/configs/default/components/models/variational_flow_lenet5.yml b/configs/default/components/models/variational_flow_lenet5.yml deleted file mode 100644 index 11b31a5..0000000 --- a/configs/default/components/models/variational_flow_lenet5.yml +++ /dev/null @@ -1,46 +0,0 @@ -# This file defines the default values for the Variational Flow LeNet5 model. - -#################################################################### -# 1. CONFIGURATION PARAMETERS that will be LOADED by the component. -#################################################################### - -streams: - #################################################################### - # 2. Keymappings associated with INPUT and OUTPUT streams. - #################################################################### - - # Stream containing batch of images (INPUT) - inputs: inputs - - # Stream containing batch of targets (used for masks) (INPUT) - targets: targets - - # Streams containing predictions (OUTPUT) - flow1_predictions: flow1_predictions - flow2_predictions: flow2_predictions - - # Streams containing predictions (OUTPUT) - flow1_masks: flow1_masks - flow2_masks: flow2_masks - -globals: - #################################################################### - # 3. 
Keymappings of variables that will be RETRIEVED from GLOBALS. - #################################################################### - - # Size of the prediction (RETRIEVED) - flow1_prediction_size: flow1_prediction_size - flow2_prediction_size: flow2_prediction_size - - # Word mappings used for filtering. - flow1_word_mappings: flow1_word_mappings - flow2_word_mappings: flow2_word_mappings - - #################################################################### - # 4. Keymappings associated with GLOBAL variables that will be SET. - #################################################################### - - #################################################################### - # 5. Keymappings associated with statistics that will be ADDED. - #################################################################### - diff --git a/configs/mnist/default_mnist.yml b/configs/mnist/default_mnist.yml index c0fef34..095f69e 100644 --- a/configs/mnist/default_mnist.yml +++ b/configs/mnist/default_mnist.yml @@ -12,10 +12,10 @@ training: # optimizer parameters: optimizer: name: Adam - lr: 0.01 + lr: 0.0001 # settings parameters terminal_conditions: - loss_stop: 1.0e-2 + loss_stop: 0.05 episode_limit: 10000 epoch_limit: 10 diff --git a/configs/mnist/mnist_classification_vf_2lenet5.yml b/configs/mnist/mnist_classification_vf_2lenet5.yml index fac486c..6ddc4e3 100644 --- a/configs/mnist/mnist_classification_vf_2lenet5.yml +++ b/configs/mnist/mnist_classification_vf_2lenet5.yml @@ -5,10 +5,6 @@ default_configs: mnist/default_mnist.yml training: problem: resize_image: [32, 32] - batch_size: 64 - optimizer: - #name: Adam - lr: 0.01 # Validation parameters - overwrite defaults: validation: diff --git a/configs/mnist/mnist_classification_vf_lenet5.yml b/configs/mnist/mnist_classification_vf_lenet5.yml index 233fbfa..b1fc3c7 100644 --- a/configs/mnist/mnist_classification_vf_lenet5.yml +++ b/configs/mnist/mnist_classification_vf_lenet5.yml @@ -5,10 +5,6 @@ default_configs: mnist/default_mnist.yml training: problem: resize_image: [32, 32] - #batch_size: 32 - optimizer: - #name: Adam - lr: 0.001 # Validation parameters - overwrite defaults: validation: @@ -26,22 +22,34 @@ pipeline: # Disable components for "default" flow. disable: nllloss, accuracy, precision_recall + ################# SHARED ################# # Add global variables. global_publisher: type: GlobalVariablePublisher priority: 0.1 - keys: [num_classes1, num_classes2, word_to_ix1, word_to_ix2] - values: [3, 7, {"Zero": 0, "One": 1, "Two": 2, "Three": 3}, {"Four": 4, "Five": 5, "Six": 6, "Seven": 7, "Eight": 8, "Nine": 9}] + keys: [word_to_ix1, word_to_ix2] + values: [{"Zero": 0, "One": 1, "Two": 2, "Three": 3}, {"Four": 4, "Five": 5, "Six": 6, "Seven": 7, "Eight": 8, "Nine": 9}] # Image classifier. image_classifier: - type: VariationalFlowLeNet5 - priority: 1 + type: LeNet5 + priority: 1.2 globals: - flow1_prediction_size: num_classes1 - flow2_prediction_size: num_classes2 - flow1_word_mappings: word_to_ix1 - flow2_word_mappings: word_to_ix2 + prediction_size: num_classes + streams: + inputs: inputs + predictions: predictions + + ################# Flow 1 ################# + flow1_string_to_mask: + type: StringToMask + priority: 2.1 + globals: + word_mappings: word_to_ix1 + streams: + strings: labels + string_indices: flow1_targets + masks: flow1_masks # Masked loss. 
nllloss_flow1: @@ -49,24 +57,55 @@ pipeline: priority: 10.1 use_masking: True streams: - predictions: flow1_predictions + predictions: predictions + targets: flow1_targets masks: flow1_masks # Statistics. + flow1_precision_recall: + type: PrecisionRecallStatistics + priority: 100.3 + use_masking: True + use_word_mappings: True + show_class_scores: True + streams: + predictions: predictions + targets: flow1_targets + masks: flow1_masks + globals: + word_mappings: word_to_ix1 + statistics: + precision: flow1_precision + recall: flow1_recall + f1score: flow1_f1score - accuracy_flow1: - type: AccuracyStatistics - priority: 100.2 + ################# Flow 1 ################# + flow2_string_to_mask: + type: StringToMask + priority: 2.2 + globals: + word_mappings: word_to_ix2 streams: - predictions: flow1_predictions + strings: labels + string_indices: flow2_targets + masks: flow2_masks - precision_recall_flow1: + flow2_precision_recall: type: PrecisionRecallStatistics - priority: 100.3 + priority: 100.5 + use_masking: True use_word_mappings: True show_class_scores: True streams: - predictions: flow1_predictions + predictions: predictions + targets: flow2_targets + masks: flow2_masks globals: - word_mappings: label_word_mappings + word_mappings: word_to_ix2 + statistics: + precision: flow2_precision + recall: flow2_recall + f1score: flow2_f1score + + #: pipeline diff --git a/configs/mnist/mnist_classification_vf_lenet5_2losses.yml b/configs/mnist/mnist_classification_vf_lenet5_2losses.yml new file mode 100644 index 0000000..1ab4c90 --- /dev/null +++ b/configs/mnist/mnist_classification_vf_lenet5_2losses.yml @@ -0,0 +1,133 @@ +# Load config defining MNIST problems for training, validation and testing. +default_configs: mnist/default_mnist.yml + +# Training parameters - overwrite defaults: +training: + problem: + resize_image: [32, 32] + +# Validation parameters - overwrite defaults: +validation: + problem: + resize_image: [32, 32] + +# Testing parameters - overwrite defaults: +testing: + problem: + resize_image: [32, 32] + +# Definition of the pipeline. +pipeline: + name: mnist_variational_flow_lenet5 + # Disable components for "default" flow. + disable: nllloss, accuracy, precision_recall + + ################# SHARED ################# + + # Add global variables. + global_publisher: + type: GlobalVariablePublisher + priority: 0.1 + keys: [word_to_ix1, word_to_ix2] + values: [{"Zero": 0, "One": 1, "Two": 2, "Three": 3}, {"Four": 4, "Five": 5, "Six": 6, "Seven": 7, "Eight": 8, "Nine": 9}] + + # Image classifier. + image_classifier: + type: LeNet5 + priority: 1.2 + globals: + prediction_size: num_classes + streams: + inputs: inputs + predictions: predictions + + all_precision_recall: + type: PrecisionRecallStatistics + priority: 100.1 + use_word_mappings: True + show_class_scores: True + streams: + predictions: predictions + globals: + word_mappings: label_word_mappings + statistics: + precision: all_precision + recall: all_recall + f1score: all_f1score + + ################# Flow 1 ################# + flow1_string_to_mask: + type: StringToMask + priority: 2.1 + globals: + word_mappings: word_to_ix1 + streams: + strings: labels + string_indices: flow1_targets + masks: flow1_masks + + # Masked loss. + flow1_nllloss: + type: NLLLoss + priority: 10.1 + use_masking: True + streams: + predictions: predictions + masks: flow1_masks + loss: flow1_loss + + # Statistics. 
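+  # (computed only over the samples selected by flow1_masks)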
+ flow1_precision_recall: + type: PrecisionRecallStatistics + priority: 100.3 + use_masking: True + use_word_mappings: True + show_class_scores: True + streams: + predictions: predictions + masks: flow1_masks + globals: + word_mappings: word_to_ix1 + statistics: + precision: flow1_precision + recall: flow1_recall + f1score: flow1_f1score + + ################# Flow 2 ################# + flow2_string_to_mask: + type: StringToMask + priority: 2.2 + globals: + word_mappings: word_to_ix2 + streams: + strings: labels + string_indices: flow2_targets + masks: flow2_masks + + # Masked loss. + flow2_nllloss: + type: NLLLoss + priority: 10.2 + use_masking: True + streams: + predictions: predictions + masks: flow2_masks + loss: flow2_loss + + flow2_precision_recall: + type: PrecisionRecallStatistics + priority: 100.5 + use_masking: True + use_word_mappings: True + show_class_scores: True + streams: + predictions: predictions + masks: flow2_masks + globals: + word_mappings: word_to_ix2 + statistics: + precision: flow2_precision + recall: flow2_recall + f1score: flow2_f1score + +#: pipeline diff --git a/ptp/components/models/__init__.py b/ptp/components/models/__init__.py index 8551125..32e95b0 100644 --- a/ptp/components/models/__init__.py +++ b/ptp/components/models/__init__.py @@ -6,7 +6,6 @@ from .model import Model from .recurrent_neural_network import RecurrentNeuralNetwork from .sentence_embeddings import SentenceEmbeddings -from .variational_flow_lenet5 import VariationalFlowLeNet5 __all__ = [ 'ConvNetEncoder', @@ -17,5 +16,4 @@ 'Model', 'RecurrentNeuralNetwork', 'SentenceEmbeddings', - 'VariationalFlowLeNet5', ] diff --git a/ptp/components/models/variational_flow_lenet5.py b/ptp/components/models/variational_flow_lenet5.py deleted file mode 100644 index a059812..0000000 --- a/ptp/components/models/variational_flow_lenet5.py +++ /dev/null @@ -1,185 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# -# Copyright (C) IBM Corporation 2018 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -__author__ = "Tomasz Kornuta & Vincent Marois" - - -import torch - -from ptp.components.models.model import Model -from ptp.data_types.data_definition import DataDefinition - - -class VariationalFlowLeNet5(Model): - """ - A proof of concept of variational flow with LeNet-5 model for MNIST digits classification. - - Uses masks (depending on targets here) to control the where each particular sample "flows through during backpropagation". - - For that purpose it has two flows, first for first subset of classes (0-2) and second for the remainder (3-9). - - """ - def __init__(self, name, config): - """ - Initializes the model, retrieves key mappings, creates two flows. - - :param name: Name of the model (taken from the configuration file). - - :param config: Parameters read from configuration file. - :type config: ``ptp.configuration.ConfigInterface`` - - """ - super(VariationalFlowLeNet5, self).__init__(name, VariationalFlowLeNet5, config) - - # Get key mappings. 
- self.key_inputs = self.stream_keys["inputs"] - self.key_targets = self.stream_keys["targets"] - - self.key_flow1_predictions = self.stream_keys["flow1_predictions"] - self.key_flow1_masks = self.stream_keys["flow1_masks"] - self.key_flow2_predictions = self.stream_keys["flow2_predictions"] - self.key_flow2_masks = self.stream_keys["floww_masks"] - - # Retrieve prediction sizes from globals. - self.flow1_prediction_size = self.globals["flow1_prediction_size"] - self.flow2_prediction_size = self.globals["flow2_prediction_size"] - - # Retrieve word mappings from globals. - self.flow1_word_mappings = self.globals["flow1_word_mappings"] - self.flow2_word_mappings = self.globals["flow2_word_mappings"] - - - # Create flow 1. - self.flow1_image_encoder = torch.nn.Sequential( - torch.nn.Conv2d(1, 6, kernel_size=(5, 5)), - torch.nn.ReLU(inplace=True), - torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2), - - torch.nn.Conv2d(6, 16, kernel_size=(5, 5)), - torch.nn.ReLU(inplace=True), - torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2), - - torch.nn.Conv2d(16, 120, kernel_size=(5, 5)), - torch.nn.ReLU(inplace=True) - ) - - self.flow1_classifier = torch.nn.Sequential( - torch.nn.Linear(120, 84), - torch.nn.ReLU(inplace=True), - torch.nn.Linear(84, 10), # FOR NOW - torch.nn.LogSoftmax(dim=1) - ) - - # Create flow 2. - self.flow2_image_encoder = torch.nn.Sequential( - torch.nn.Conv2d(1, 6, kernel_size=(5, 5)), - torch.nn.ReLU(inplace=True), - torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2), - - torch.nn.Conv2d(6, 16, kernel_size=(5, 5)), - torch.nn.ReLU(inplace=True), - torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2), - - torch.nn.Conv2d(16, 120, kernel_size=(5, 5)), - torch.nn.ReLU(inplace=True) - ) - - self.flow2_classifier = torch.nn.Sequential( - torch.nn.Linear(120, 84), - torch.nn.ReLU(inplace=True), - torch.nn.Linear(84, self.flow2_prediction_size), - torch.nn.LogSoftmax(dim=1) - ) - - - def input_data_definitions(self): - """ - Function returns a dictionary with definitions of input data that are required by the component. - - :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`). - """ - return { - self.key_inputs: DataDefinition([-1, 1, 32, 32], [torch.Tensor], "Batch of images [BATCH_SIZE x IMAGE_DEPTH x IMAGE_HEIGHT x IMAGE WIDTH]"), - self.key_targets: DataDefinition([-1], [torch.Tensor], "Batch of targets [BATCH_SIZE]"), - } - - - def output_data_definitions(self): - """ - Function returns a dictionary with definitions of output data produced the component. - - :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`). - """ - return { # 10 for now! - self.key_flow1_predictions: DataDefinition([-1, 10], [torch.Tensor], "Batch of flow 1predictions, each represented as probability distribution over classes [BATCH_SIZE x FLOW1_PREDICTION_SIZE]"), - self.key_flow1_masks: DataDefinition([-1], [torch.Tensor], "Batch of masks for flow 1 [BATCH_SIZE]"), - self.key_flow2_predictions: DataDefinition([-1, self.flow2_prediction_size], [torch.Tensor], "Batch of flow 2 predictions, each represented as probability distribution over classes [BATCH_SIZE x FLOW2_PREDICTION_SIZE]"), - self.key_flow2_masks: DataDefinition([-1], [torch.Tensor], "Batch of masks for flow 2 [BATCH_SIZE]"), - } - - - def forward(self, data_dict): - """ - Main forward pass of the model. - In fact performs two passes, using masks generated on the fly using targets. 
- - :param data_dict: DataDict({'images',**}), where: - - - images: [batch_size, num_channels, width, height] - - :type data_dict: ``miprometheus.utils.DataDict`` - - :return: Predictions [batch_size, num_classes] - - """ - targets = data_dict[self.key_targets] - #print("targets = \n", targets) - - # Produce masks - for flow 1. - flow1_masks = torch.zeros(targets.size(0), requires_grad=False).type(self.app_state.ByteTensor) - for _, val in self.flow1_word_mappings.items(): - flow1_masks = flow1_masks + (targets == val) - #print("flow1_masks = \n", flow1_masks) - - # Produce masks - for flow 2. - flow2_masks = torch.zeros(targets.size(0), requires_grad=False).type(self.app_state.ByteTensor) - for _, val in self.flow2_word_mappings.items(): - flow2_masks = flow2_masks + (targets == val) - #print("flow2_masks = \n", flow2_masks) - #exit(1) - - # Get images. - img = data_dict[self.key_inputs] - - # Pass inputs through flow 1. - x1 = self.flow1_image_encoder(img) - x1 = x1.view(-1, 120) - x1 = self.flow1_classifier(x1) - - # Pass inputs through flow 2. - x2 = self.flow2_image_encoder(img) - x2 = x2.view(-1, 120) - x2 = self.flow2_classifier(x2) - - - # Add predictions to datadict. - data_dict.extend({ - self.key_flow1_predictions: x1, - self.key_flow2_predictions: x2, - self.key_flow1_masks: flow1_masks, - self.key_flow2_masks: flow2_masks, - }) diff --git a/ptp/components/publishers/precision_recall_statistics.py b/ptp/components/publishers/precision_recall_statistics.py index 6a6430f..1500d36 100644 --- a/ptp/components/publishers/precision_recall_statistics.py +++ b/ptp/components/publishers/precision_recall_statistics.py @@ -65,16 +65,18 @@ def __init__(self, name, config): if self.config["use_word_mappings"]: # Get labels from word mappings. self.labels = [] + self.index_mappings = {} # Assume they are ordered, starting from 0. - for key in self.globals["word_mappings"].keys(): - self.labels.append(key) + for i,(word,index) in enumerate(self.globals["word_mappings"].items()): + self.labels.append(word) + self.index_mappings[index] = i # Set number of classes by looking at labels. self.num_classes = len(self.labels) else: # Get the number of possible outputs. self.num_classes = self.globals["num_classes"] self.labels = list(range(self.num_classes)) - + self.index_mappings = {i: i for i in range(self.num_classes)} # Check display options. self.show_confusion_matrix = self.config["show_confusion_matrix"] @@ -126,9 +128,9 @@ def __call__(self, data_dict): # Calculate weighted averages. support_sum = sum(support) - precision_avg = sum([pi*si / support_sum for (pi,si) in zip(precision,support)]) - recall_avg = sum([ri*si / support_sum for (ri,si) in zip(recall,support)]) - f1score_avg = sum([fi*si / support_sum for (fi,si) in zip(f1score,support)]) + precision_avg = sum([pi*si / support_sum if support_sum > 0 else 0.0 for (pi,si) in zip(precision,support)]) + recall_avg = sum([ri*si / support_sum if support_sum > 0 else 0.0 for (ri,si) in zip(recall,support)]) + f1score_avg = sum([fi*si / support_sum if support_sum > 0 else 0.0 for (fi,si) in zip(f1score,support)]) # Log class scores. 
if self.show_class_scores: @@ -171,10 +173,17 @@ def calculate_statistics(self, data_dict): # Create the confusion matrix, use SciKit learn order: # Column - predicted class + #print(self.index_mappings) # Row - target (actual) class confusion_matrix = np.zeros([self.num_classes, self.num_classes], dtype=int) for i, (target, pred) in enumerate(zip(targets, preds)): - confusion_matrix[target][pred] += 1 * masks[i] + #print("T: ",target) + #print("P: ",pred) + # If both indices are ok. + if target in self.index_mappings.keys() and pred in self.index_mappings.keys(): + #print(self.index_mappings[target]) + #print(self.index_mappings[pred]) + confusion_matrix[self.index_mappings[target]][self.index_mappings[pred]] += 1 * masks[i] # Calculate true positive (TP), eqv. with hit. tp = np.zeros([self.num_classes], dtype=int) @@ -203,7 +212,7 @@ def calculate_statistics(self, data_dict): recall = [float(tpi) / float(tpi+fni) if (tpi+fni) > 0 else 0.0 for (tpi,fni) in zip(tp,fn)] # Calcualte f1-score. - f1score = [ 2 * pi * ri / (pi+ri) if (pi+ri) > 0 else 0.0 for (pi,ri) in zip(precision,recall)] + f1score = [ 2 * pi * ri / float(pi+ri) if (pi+ri) > 0 else 0.0 for (pi,ri) in zip(precision,recall)] # Get support. support = np.sum(confusion_matrix, axis=1) @@ -238,9 +247,9 @@ def collect_statistics(self, stat_col, data_dict): # Calculate weighted averages. support_sum = sum(support) - precision_avg = sum([pi*si / support_sum if si > 0 else 0.0 for (pi,si) in zip(precision,support)]) - recall_avg = sum([ri*si / support_sum if si > 0 else 0.0 for (ri,si) in zip(recall,support)]) - f1score_avg = sum([fi*si / support_sum if si > 0 else 0.0 for (fi,si) in zip(f1score,support)]) + precision_avg = sum([pi*si / support_sum if support_sum > 0 else 0.0 for (pi,si) in zip(precision,support)]) + recall_avg = sum([ri*si / support_sum if support_sum > 0 else 0.0 for (ri,si) in zip(recall,support)]) + f1score_avg = sum([fi*si / support_sum if support_sum > 0 else 0.0 for (fi,si) in zip(f1score,support)]) # Export to statistics. stat_col[self.key_precision] = precision_avg From 1f46fbdcd0999e290bc909da0aad251d6733e73e Mon Sep 17 00:00:00 2001 From: tkornut Date: Thu, 11 Apr 2019 13:43:03 -0700 Subject: [PATCH 11/11] VF MNIST cleanups --- configs/mnist/default_mnist.yml | 6 +-- .../mnist/mnist_classification_vf_2lenet5.yml | 2 +- .../mnist/mnist_classification_vf_lenet5.yml | 18 ++++++- ...mnist_classification_vf_lenet5_2losses.yml | 2 +- ...ification_vf_shared_convnet_2softmaxes.yml | 51 ++++--------------- 5 files changed, 29 insertions(+), 50 deletions(-) diff --git a/configs/mnist/default_mnist.yml b/configs/mnist/default_mnist.yml index 095f69e..b289562 100644 --- a/configs/mnist/default_mnist.yml +++ b/configs/mnist/default_mnist.yml @@ -52,13 +52,9 @@ pipeline: type: BatchSizeStatistics priority: 100.0 - accuracy: - type: AccuracyStatistics - priority: 100.2 - precision_recall: type: PrecisionRecallStatistics - priority: 100.3 + priority: 100.1 use_word_mappings: True show_class_scores: True globals: diff --git a/configs/mnist/mnist_classification_vf_2lenet5.yml b/configs/mnist/mnist_classification_vf_2lenet5.yml index 6ddc4e3..47be2dc 100644 --- a/configs/mnist/mnist_classification_vf_2lenet5.yml +++ b/configs/mnist/mnist_classification_vf_2lenet5.yml @@ -20,7 +20,7 @@ testing: pipeline: name: mnist_variational_flow_2lenet5 # Disable components for "default" flow. - disable: nllloss, accuracy, precision_recall + disable: nllloss, precision_recall # Add global variables. 
global_publisher: diff --git a/configs/mnist/mnist_classification_vf_lenet5.yml b/configs/mnist/mnist_classification_vf_lenet5.yml index b1fc3c7..625f73f 100644 --- a/configs/mnist/mnist_classification_vf_lenet5.yml +++ b/configs/mnist/mnist_classification_vf_lenet5.yml @@ -20,7 +20,7 @@ testing: pipeline: name: mnist_variational_flow_lenet5 # Disable components for "default" flow. - disable: nllloss, accuracy, precision_recall + disable: nllloss, precision_recall ################# SHARED ################# # Add global variables. @@ -40,6 +40,20 @@ pipeline: inputs: inputs predictions: predictions + all_precision_recall: + type: PrecisionRecallStatistics + priority: 100.1 + use_word_mappings: True + show_class_scores: True + streams: + predictions: predictions + globals: + word_mappings: label_word_mappings + statistics: + precision: all_precision + recall: all_recall + f1score: all_f1score + ################# Flow 1 ################# flow1_string_to_mask: type: StringToMask @@ -79,7 +93,7 @@ pipeline: recall: flow1_recall f1score: flow1_f1score - ################# Flow 1 ################# + ################# Flow 2 ################# flow2_string_to_mask: type: StringToMask priority: 2.2 diff --git a/configs/mnist/mnist_classification_vf_lenet5_2losses.yml b/configs/mnist/mnist_classification_vf_lenet5_2losses.yml index 1ab4c90..f92dcae 100644 --- a/configs/mnist/mnist_classification_vf_lenet5_2losses.yml +++ b/configs/mnist/mnist_classification_vf_lenet5_2losses.yml @@ -20,7 +20,7 @@ testing: pipeline: name: mnist_variational_flow_lenet5 # Disable components for "default" flow. - disable: nllloss, accuracy, precision_recall + disable: nllloss, precision_recall ################# SHARED ################# diff --git a/configs/mnist/mnist_classification_vf_shared_convnet_2softmaxes.yml b/configs/mnist/mnist_classification_vf_shared_convnet_2softmaxes.yml index 2c6ea0e..25e58cc 100644 --- a/configs/mnist/mnist_classification_vf_shared_convnet_2softmaxes.yml +++ b/configs/mnist/mnist_classification_vf_shared_convnet_2softmaxes.yml @@ -6,15 +6,15 @@ training: problem: #resize_image: [32, 32] batch_size: 64 - optimizer: - #name: Adam - lr: 0.001 - terminal_conditions: - loss_stop: 0.08 + #optimizer: + # #name: Adam + # lr: 0.001 + #terminal_conditions: + # loss_stop: 0.08 # Validation parameters - overwrite defaults: -validation: - partial_validation_interval: 10 +#validation: +# partial_validation_interval: 10 # problem: # resize_image: [32, 32] @@ -27,7 +27,7 @@ validation: pipeline: name: mnist_variational_flow_shared_convnet_2softmaxes # Disable components for "default" flow. - disable: nllloss, accuracy, precision_recall + disable: nllloss, precision_recall ################# SHARED ################# @@ -77,6 +77,7 @@ pipeline: prediction_size: num_classes1 streams: inputs: reshaped_maps + targets: flow1_targets predictions: flow1_predictions # Masked loss. @@ -91,17 +92,6 @@ pipeline: loss: flow1_loss # Statistics. - flow1_accuracy: - type: AccuracyStatistics - priority: 1.32 - use_masking: True - streams: - predictions: flow1_predictions - targets: flow1_targets - masks: flow1_masks - statistics: - accuracy: flow1_accuracy - flow1_precision_recall: type: PrecisionRecallStatistics priority: 1.33 @@ -142,6 +132,7 @@ pipeline: prediction_size: num_classes2 streams: inputs: reshaped_maps + targets: flow2_targets predictions: flow2_predictions # Masked loss. @@ -156,17 +147,6 @@ pipeline: loss: flow2_loss # Statistics. 
- flow2_accuracy: - type: AccuracyStatistics - priority: 2.32 - use_masking: True - streams: - targets: flow2_targets - predictions: flow2_predictions - masks: flow2_masks - statistics: - accuracy: flow2_accuracy - flow2_precision_recall: type: PrecisionRecallStatistics priority: 2.33 @@ -201,17 +181,6 @@ pipeline: output_indices: merged_indices # Statistics. - joined_accuracy: - type: AccuracyStatistics - priority: 3.21 - # Use prediction indices instead of distributions. - use_prediction_distributions: False - streams: - targets: targets - predictions: merged_indices - statistics: - accuracy: joined_accuracy - joined_precision_recall: type: PrecisionRecallStatistics priority: 3.22