In [None]:
# default_exp loss_strategy.base
# Dev convenience: automatically reload edited modules (e.g. m3tl) without restarting the kernel.
%load_ext autoreload
%autoreload 2
import os
# Hide all GPUs from CUDA so this notebook runs on CPU. This has to happen before
# TensorFlow is imported (TensorFlow is imported in the next cell).
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from nbdev.showdoc import show_doc

# LossCombinationStrategyBase

In [None]:
# export
from collections import deque
from typing import Dict, List

import tensorflow as tf
from m3tl.utils import get_phase
from tensorflow.python.util.nest import (flatten,
                                         flatten_with_joined_string_paths)


class LossCombinationStrategyBase(tf.keras.Model):
    """Base class for strategies that combine the per-problem losses of a multi-task model.

    Subclasses override :meth:`call`, which receives the current (nested) loss
    dict, the current metric dict and the Keras training history, and returns
    the losses to optimize. This base class only provides helpers for reading
    those inputs.

    Args:
        params: m3tl params object; only ``params.problem_list`` is read here.
        name: Keras name of this strategy model.
        *args, **kwargs: Forwarded to ``tf.keras.Model.__init__``.
    """

    def __init__(self, params, name: str, *args, **kwargs):
        # Pass `name` by keyword. For subclassed models, `tf.keras.Model.__init__`
        # only forwards keyword arguments to the base Layer, so a positional
        # `name` would be silently dropped (or rejected by newer Keras versions).
        super(LossCombinationStrategyBase, self).__init__(*args, name=name, **kwargs)
        self.params = params
        self.problem_list = self.params.problem_list
        # Rolling buffers that keep only the last 100 entries. This class never
        # fills them; presumably subclasses that adapt over time do.
        self.hist_loss_dict = deque(maxlen=100)
        self.hist_metric_dict = deque(maxlen=100)

    def extract_loss_metric_dict_from_history(self,
                                              history: tf.keras.callbacks.History,
                                              structure: dict,
                                              prefix='val_') -> dict:
        """Re-nest the flat per-epoch values of a Keras ``History`` to match ``structure``.

        Args:
            history: History callback from ``fit``. Its ``.history`` maps flat
                metric names to lists of per-epoch values.
            structure: Nested dict whose joined string paths (as produced by
                ``flatten_with_joined_string_paths``) are the history keys to
                look up.
            prefix: ``"val_"`` to read validation values (keys starting with
                ``"val_"``, with that leading prefix stripped). Use a falsy value
                (``None`` or ``""``) to read the keys unchanged, i.e. training values.

        Returns:
            A nested structure shaped like ``structure`` whose leaves are the
            matching history entries.

        Raises:
            ValueError: if ``prefix`` is truthy but not ``"val_"``.
            KeyError: if a path of ``structure`` has no entry in the history.
        """
        history_dict: Dict[str, list] = history.history

        # Metrics computed on the validation set are logged with a "val_" prefix.
        if prefix:
            if prefix != 'val_':
                raise ValueError('prefix should either be "val_" or None')
            # Strip only the leading prefix. `str.replace` would also remove a
            # "val_" occurring inside a name (e.g. "interval_loss" -> "interloss").
            history_dict = {
                key[len(prefix):]: value
                for key, value in history_dict.items()
                if key.startswith(prefix)
            }

        # Look up each flattened path of `structure`, then pack the values back
        # into the same nesting.
        structure_path = [path for path, _ in flatten_with_joined_string_paths(structure)]
        flat_history = [history_dict[path] for path in structure_path]
        return tf.nest.pack_sequence_as(structure=structure, flat_sequence=flat_history)

    def get_all_losses(self, current_loss_dict: dict) -> List[tf.Tensor]:
        """Return every leaf of ``current_loss_dict`` as a flat list (in ``nest.flatten`` order)."""
        return flatten(current_loss_dict)

    def get_problem_loss(self, current_loss_dict: dict, problem: str) -> List[tf.Tensor]:
        """Return the leaves of ``current_loss_dict`` whose joined path contains ``problem``.

        NOTE(review): this is a substring match on the joined path. A problem
        name that is a substring of another problem's name (e.g. "cls" vs
        "multi_cls") also picks up the other problem's losses.
        """
        flatten_loss_with_path = flatten_with_joined_string_paths(current_loss_dict)
        return [value for path, value in flatten_loss_with_path if problem in path]

    def call(self,
             current_loss_dict: dict,
             current_metric_dict: dict,
             history: tf.keras.callbacks.History):
        """Combine the current losses into the losses to optimize; subclasses must implement this."""
        raise NotImplementedError


In [None]:
# export
class SumLossCombination(LossCombinationStrategyBase):
    """Loss combination strategy that applies no weighting to the losses.

    ``call`` flattens ``current_loss_dict`` into a list of loss tensors and
    returns them unchanged. The summation implied by the name is not done here.
    Presumably the caller adds the returned losses together (TODO: confirm
    against the caller).

    The constructor is inherited unchanged from ``LossCombinationStrategyBase``.
    """

    def call(self,
             current_loss_dict: dict,
             current_metric_dict: dict,
             history: tf.keras.callbacks.History):
        """Return all losses in ``current_loss_dict`` as a flat list.

        ``current_metric_dict`` and ``history`` are accepted for interface
        compatibility but are not used by this strategy.
        """
        return self.get_all_losses(current_loss_dict)

In [None]:
from m3tl.test_base import TestBase
from m3tl.special_tokens import TRAIN
from m3tl.utils import create_dict_from_nested_model

# Run the project's test harness with the "sum" loss-combination strategy. The
# log below shows it registering the fake weibo_* problems and training for
# 2 epochs. `tb.params` and `tb.all_model` are reused by the next cell.
tb = TestBase()
tb.test_loss_combination_strategy(loss_combination_strategy_name='sum')

2021-06-12 22:06:32.702 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_ner, problem type: seq_tag
2021-06-12 22:06:32.702 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_multi_cls, problem type: multi_cls
2021-06-12 22:06:32.703 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_cls, problem type: cls
2021-06-12 22:06:32.703 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_masklm, problem type: masklm
2021-06-12 22:06:32.703 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_regression, problem type: regression
2021-06-12 22:06:32.704 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_vector_fit, problem type: vector_fit
2021-06-12 22:06:32.704 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_pre

Epoch 1/2
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: invalid value for "node": expected "ast.AST", got "<class 'NoneType'>"; to visit lists of nodes, use "visit_block" instead
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method


The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.




The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.


Epoch 2/2


In [None]:
from IPython.display import display

# Build the nested model structure once and reuse it for both splits.
loss_structure = create_dict_from_nested_model(tb.all_model)
test_instance = LossCombinationStrategyBase(tb.params, 'test')
# validation losses ("val_"-prefixed history keys)
val_loss_history = test_instance.extract_loss_metric_dict_from_history(
    history=tb.all_model.history, structure=loss_structure)
# training losses (un-prefixed history keys)
train_loss_history = test_instance.extract_loss_metric_dict_from_history(
    history=tb.all_model.history, structure=loss_structure, prefix='')
# Display both results. A bare last expression would only show the training one.
display(val_loss_history, train_loss_history)

defaultdict(list,
            {'BertMultiTaskTop': defaultdict(list,
                         {'fake_contrastive_learning': defaultdict(list,
                                      {'simcse': defaultdict(list,
                                                   {'losses': [[2.494117498397827,
                                                      2.3415939807891846]]})}),
                          'weibo_fake_cls': defaultdict(list,
                                      {'losses': [[1.0908699035644531,
                                         0.6408037543296814]]}),
                          'weibo_fake_multi_cls': defaultdict(list,
                                      {'losses': [[0.4924350678920746,
                                         0.44675901532173157]]}),
                          'weibo_fake_ner': defaultdict(list,
                                      {'losses': [[1.4930245876312256,
                                         1.4878641366958618]]}),
                          '