In [1]:
# default_exp problem_types.cls
# Auto-reload imported modules before each cell runs, so edits to the library
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import os
# Expose no CUDA devices to this process (set before TensorFlow is imported);
# presumably so the notebook's tests run on CPU only -- TODO confirm intent.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [2]:
# hide
# test setup
import tensorflow as tf
import numpy as np
from m3tl.test_base import TestBase, test_top_layer
from m3tl.input_fn import train_eval_input_fn

# Shared fixtures reused by the test cells further down the notebook.
test_base = TestBase()
params = test_base.params
hidden_dim = params.bert_config.hidden_size

# Pull a single batch (as numpy) from the train/eval input pipeline to use
# as sample features when exercising the top layer.
train_dataset = train_eval_input_fn(params=params)
one_batch = next(train_dataset.as_numpy_iterator())


2021-06-24 20:21:12.618 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_fake_ner, problem type: seq_tag
2021-06-24 20:21:12.618 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_fake_multi_cls, problem type: multi_cls
2021-06-24 20:21:12.619 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_fake_cls, problem type: cls
2021-06-24 20:21:12.619 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_masklm, problem type: masklm
2021-06-24 20:21:12.620 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_fake_regression, problem type: regression
2021-06-24 20:21:12.620 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_fake_vector_fit, problem type: vector_fit
2021-06-24 20:21:12.621 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_pre

# Classification(cls)

Classification problem type. By default, this problem uses the `[CLS]` token embedding.

Example: `m3tl.predefined_problems.get_weibo_fake_cls_fn`.

## Imports and utils


In [4]:
# export
from functools import partial
from typing import List

import numpy as np
import tensorflow as tf
from m3tl.base_params import BaseParams
from m3tl.problem_types.utils import (empty_tensor_handling_loss,
                                      nan_loss_handling)
from m3tl.special_tokens import PREDICT, TRAIN
from m3tl.utils import (LabelEncoder, get_label_encoder_save_path, get_phase,
                        need_make_label_encoder, variable_summaries)


## Top Layer

In [5]:
# export

class Classification(tf.keras.layers.Layer):
    """Classification top layer.

    Applies dropout and a linear projection to the `pooled` hidden feature
    and returns softmax probabilities over `num_classes`. In every phase
    other than PREDICT it additionally registers a cross-entropy loss (with
    label smoothing taken from `params.label_smoothing`) and a sparse
    categorical accuracy metric on the layer.
    """

    def __init__(self, params: BaseParams, problem_name: str) -> None:
        """Build the sub-layers for one classification problem.

        Args:
            params: Run configuration. This layer reads the problem's
                `num_classes` info, `dropout_keep_prob`, `detail_log`
                and `label_smoothing` from it.
            problem_name: Name of the problem. Used as the Keras layer name,
                as the prefix of the label key (`<problem_name>_label_ids`),
                and in the metric / output op names.
        """
        super().__init__(name=problem_name)
        self.params = params
        self.problem_name = problem_name
        self.num_classes = self.params.get_problem_info(
            problem=problem_name, info_name='num_classes')

        # Linear projection to class logits; softmax is applied in `call`.
        self.dense = tf.keras.layers.Dense(self.num_classes, activation=None)
        self.metric_fn = tf.keras.metrics.SparseCategoricalAccuracy(
            name='{}_acc'.format(self.problem_name))

        # Params store a keep probability; Keras Dropout expects a drop rate.
        self.dropout = tf.keras.layers.Dropout(1 - params.dropout_keep_prob)

    def call(self, inputs):
        """Compute class probabilities and, outside PREDICT, loss and metric.

        Args:
            inputs: A pair `(feature, hidden_feature)`. `hidden_feature` is
                indexed with the key `'pooled'`; `feature` is indexed with
                `'<problem_name>_label_ids'` when the phase is not PREDICT.

        Returns:
            Softmax of the logits produced by the dense layer.
        """
        phase = get_phase()
        has_labels = phase != PREDICT
        features, hidden_features = inputs

        # Dropout is only active while training.
        pooled = self.dropout(hidden_features['pooled'], phase == TRAIN)
        logits = self.dense(pooled)

        if self.params.detail_log:
            for weight in self.weights:
                variable_summaries(weight, self.problem_name)

        if has_labels:
            labels = features['{}_label_ids'.format(self.problem_name)]
            # Labels are converted to one-hot so that label smoothing can be
            # applied by the dense (non-sparse) cross-entropy.
            smoothed_xent = partial(
                tf.keras.losses.categorical_crossentropy,
                from_logits=True,
                label_smoothing=self.params.label_smoothing)
            # NOTE(review): the two helpers below come from
            # m3tl.problem_types.utils; judging by their names they guard
            # against empty batches and NaN losses -- confirm there.
            loss = empty_tensor_handling_loss(
                tf.one_hot(labels, depth=self.num_classes),
                logits,
                smoothed_xent)
            self.add_loss(nan_loss_handling(loss))
            self.add_metric(self.metric_fn(labels, logits))

        return tf.nn.softmax(logits, name='%s_predict' % self.problem_name)

In [6]:
# hide
from m3tl.test_base import test_top_layer

# Smoke-test the Classification top layer on the `weibo_fake_cls` problem,
# using the sample batch and hidden size prepared in the test-setup cell.
test_top_layer(
    Classification,
    problem='weibo_fake_cls',
    params=params,
    sample_features=one_batch,
    hidden_dim=hidden_dim,
)

2021-06-24 20:21:23.812 | DEBUG    | m3tl.test_base:test_top_layer:248 - Testing Classification
2021-06-24 20:21:23.825 | DEBUG    | m3tl.test_base:test_top_layer:254 - testing batch size 0
2021-06-24 20:21:23.825 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train
2021-06-24 20:21:23.880 | INFO     | m3tl.utils:set_phase:478 - Setting phase to eval
2021-06-24 20:21:23.886 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer
2021-06-24 20:21:23.891 | DEBUG    | m3tl.test_base:test_top_layer:254 - testing batch size 1
2021-06-24 20:21:23.892 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train
2021-06-24 20:21:23.913 | INFO     | m3tl.utils:set_phase:478 - Setting phase to eval
2021-06-24 20:21:23.920 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer
2021-06-24 20:21:23.925 | DEBUG    | m3tl.test_base:test_top_layer:254 - testing batch size 2
2021-06-24 20:21:23.925 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train
2021-06-24 20:2

## Get or make label encoder function


In [7]:
# export
def cls_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs) -> LabelEncoder:
    """Return the label encoder for `problem`, creating it when required.

    Whether a new encoder must be created is decided by
    `need_make_label_encoder`. When it is, the encoder is fitted on
    `label_list`, dumped to the problem's label-encoder path, and the number
    of entries in its `encode_dict` is stored in `params` as the problem's
    `num_classes`. Otherwise the encoder is loaded from that path.

    Args:
        params: Run configuration; used to resolve the save path and to
            record `num_classes`.
        problem: Problem name.
        mode: Current mode, forwarded to `need_make_label_encoder`.
        label_list: Labels to fit the encoder on.
        **kwargs: Must contain `overwrite`, which is forwarded to
            `need_make_label_encoder`.

    Returns:
        The fitted or loaded `LabelEncoder`.

    Raises:
        KeyError: If `overwrite` is not supplied in `kwargs`.
    """
    save_path = get_label_encoder_save_path(params=params, problem=problem)
    encoder = LabelEncoder()

    should_build = need_make_label_encoder(
        mode=mode, le_path=save_path, overwrite=kwargs['overwrite'])
    if not should_build:
        # Reuse the previously saved encoder.
        encoder.load(save_path)
        return encoder

    # Fit a fresh encoder, persist it, and record the class count.
    encoder.fit(label_list)
    encoder.dump(save_path)
    class_count = len(encoder.encode_dict)
    params.set_problem_info(
        problem=problem, info_name='num_classes', info=class_count)
    return encoder

## Label handling function

In [8]:
# export
def cls_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):
    """Encode a single classification target as an int32 label id.

    Args:
        target: The raw label to encode.
        label_encoder: Object whose `transform` method accepts a list of
            labels and returns an array-like with a `tolist` method.
        tokenizer: Unused by this function.
        decoding_length: Unused by this function.

    Returns:
        A pair `(label_id, None)` where `label_id` is a `np.int32`.
    """
    # The encoder works on batches, so wrap the target and unwrap the result.
    encoded_batch = label_encoder.transform([target])
    first_id = encoded_batch.tolist()[0]
    return np.int32(first_id), None

