In [None]:
# default_exp mtl_model.base
%load_ext autoreload
%autoreload 2
import os
# Hide all CUDA devices so the notebook runs on CPU only; this must happen
# before TensorFlow is imported in the next cell to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from nbdev.showdoc import show_doc

# MTLBase

In [None]:
# export
from copy import copy
from typing import Dict, Tuple

import tensorflow as tf
from m3tl.utils import dispatch_features, get_phase


class MTLBase(tf.keras.Model):
    """Base class for multi-task-learning (MTL) models.

    Subclasses implement `call` and use `extract_feature` to pull the slice of
    the input features and body-model outputs that belongs to one problem.

    Attributes:
        params: configuration object; must expose `problem_list`.
        problem_list: the problems listed in `params.problem_list`.
        available_extract_target: `problem_list` plus the special key 'all';
            the names accepted by `extract_feature`.
    """

    def __init__(self, params, name: str, *args, **kwargs):
        # `tf.keras.Model.__init__(*args, **kwargs)` only inspects positional
        # args to detect the functional `Model(inputs, outputs)` form; for a
        # subclassed model it forwards keyword arguments only. Passing `name`
        # positionally therefore silently discarded it, so pass it by keyword.
        super().__init__(*args, name=name, **kwargs)
        self.params = params
        # Copy so that appending 'all' does not mutate params.problem_list.
        self.available_extract_target = copy(self.params.problem_list)
        self.available_extract_target.append('all')
        self.problem_list = self.params.problem_list

    def extract_feature(self, extract_problem: str, feature_dict: dict, hidden_feature_dict: dict):
        """Extract features (inputs) and hidden features (body model output
        tensors) for one problem from the feature and hidden-feature dicts.

        Lookup order:
          1. If both dicts have a key equal to `extract_problem`, return those
             entries directly.
          2. Otherwise, if both dicts have an 'all' entry, dispatch the
             problem's records out of those entries.
          3. Otherwise, dispatch the problem's records out of the dicts
             themselves.
        Dispatching is delegated to `dispatch_features`, which selects records
        based on the loss multiplier, using the current phase from `get_phase`.

        Args:
            extract_problem: a problem name from `params.problem_list`, or 'all'.
            feature_dict: input features, either flat or keyed by problem / 'all'.
            hidden_feature_dict: body-model outputs, structured like
                `feature_dict`.

        Returns:
            A `(features, hidden_features)` pair for `extract_problem`.

        Raises:
            ValueError: if `extract_problem` is not in `available_extract_target`.
        """
        mode = get_phase()
        if extract_problem not in self.available_extract_target:
            raise ValueError('Tried to extract feature {0}, available extract problem: {1}'.format(
                extract_problem, self.available_extract_target))

        # if key contains problem, return directly
        if extract_problem in feature_dict and extract_problem in hidden_feature_dict:
            return feature_dict[extract_problem], hidden_feature_dict[extract_problem]

        # use dispatch function to extract record based on loss multiplier
        if 'all' in feature_dict and 'all' in hidden_feature_dict:
            return dispatch_features(
                features=feature_dict['all'], hidden_feature=hidden_feature_dict['all'],
                problem=extract_problem, mode=mode)
        return dispatch_features(
            features=feature_dict, hidden_feature=hidden_feature_dict,
            problem=extract_problem, mode=mode)

    def call(self, inputs: Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor]]):
        """Run the model; must be implemented by subclasses.

        Args:
            inputs: expected to be a `(features, hidden_features)` pair of
                dicts, as consumed by `extract_feature`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


In [None]:
# Render the nbdev API documentation for MTLBase.extract_feature.
show_doc(MTLBase.extract_feature)

<h4 id="MTLBase.extract_feature" class="doc_header"><code>MTLBase.extract_feature</code><a href="__main__.py#L15" class="source_link" style="float:right">[source]</a></h4>

> <code>MTLBase.extract_feature</code>(**`extract_problem`**:`str`, **`feature_dict`**:`dict`, **`hidden_feature_dict`**:`dict`)



Extract features (inputs) and hidden features (body model output tensors) from the features and hidden_features dicts.

In [None]:
from m3tl.test_base import TestBase
import numpy as np

tb = TestBase()

features, hidden_features = tb.get_one_batch_body_model_output()

mtl_base = MTLBase(params=tb.params, name='test_mtl_base')

# Records dispatched to a problem are selected by loss multiplier, so every
# record extracted for `problem` is expected to carry a multiplier of 1.
# The assertion message names the failing problem, since this runs in a loop.
for problem in tb.params.problem_list:
    problem_features = mtl_base.extract_feature(
        problem, feature_dict=features, hidden_feature_dict=hidden_features)[0]
    loss_multiplier = problem_features['{}_loss_multiplier'.format(problem)].numpy()
    assert np.min(loss_multiplier) == 1, \
        'problem {}: unexpected loss multipliers {}'.format(problem, loss_multiplier)




Adding new problem weibo_fake_ner, problem type: seq_tag
Adding new problem weibo_cws, problem type: seq_tag
Adding new problem weibo_fake_multi_cls, problem type: multi_cls
Adding new problem weibo_fake_cls, problem type: cls
Adding new problem weibo_masklm, problem type: masklm
Adding new problem weibo_pretrain, problem type: pretrain
Adding new problem weibo_fake_regression, problem type: regression
Adding new problem weibo_fake_vector_fit, problem type: vector_fit
Adding new problem weibo_premask_mlm, problem type: premask_mlm
INFO:tensorflow:sampling weights: 
INFO:tensorflow:weibo_fake_cls_weibo_fake_ner_weibo_fake_regression_weibo_fake_vector_fit: 0.2631578947368421
INFO:tensorflow:weibo_fake_multi_cls: 0.2631578947368421
INFO:tensorflow:weibo_masklm: 0.2236842105263158
INFO:tensorflow:weibo_premask_mlm: 0.25


404 Client Error: Not Found for url: https://huggingface.co/voidful/albert_chinese_tiny/resolve/main/tf_model.h5
Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFAlbertModel: ['predictions.LayerNorm.weight', 'predictions.decoder.weight', 'predictions.dense.weight', 'predictions.decoder.bias', 'predictions.dense.bias', 'predictions.bias', 'predictions.LayerNorm.bias']
- This IS expected if you are initializing TFAlbertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFAlbertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFAlbertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was tr

INFO:tensorflow:Modal Type id mapping: 
 {
    "class": 0,
    "image": 1,
    "text": 2
}
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method


The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.


In [None]:
# export
class BasicMTL(MTLBase):
    """Simplest MTL model: splits the shared inputs per problem.

    `call` applies no task-specific layers. It only runs `extract_feature` for
    every entry of `available_extract_target` (each configured problem plus
    'all') and collects the results.
    """

    def __init__(self, params, name: str, *args, **kwargs):
        super().__init__(params, name, *args, **kwargs)

    def call(self, inputs: Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor]]):
        """Extract features and hidden features for every extract target.

        Args:
            inputs: a `(features, hidden_features)` pair of dicts, in the
                form accepted by `extract_feature`.

        Returns:
            `(features_per_problem, hidden_features_per_problem)`: two dicts
            keyed by problem name (plus 'all'), holding the per-problem
            features and hidden features respectively.
        """
        # No local phase lookup is needed here: extract_feature queries
        # get_phase() itself on every call.
        features, hidden_features = inputs
        features_per_problem, hidden_features_per_problem = {}, {}
        for problem in self.available_extract_target:
            features_per_problem[problem], hidden_features_per_problem[problem] = self.extract_feature(
                extract_problem=problem, feature_dict=features, hidden_feature_dict=hidden_features
            )
        return features_per_problem, hidden_features_per_problem