In [None]:
# default_exp problem_types.contrastive_learning
import os
# Reload edited modules automatically so changes in the package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Set before TensorFlow is imported in the next cell.
# NOTE(review): "-1" presumably hides all CUDA devices so the tests run on CPU — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


In [None]:
# test setup
import numpy as np
import tensorflow as tf
from m3tl.input_fn import train_eval_input_fn
from m3tl.test_base import TestBase, test_top_layer

# Shared test fixture; its `params` object is reused by every test cell below.
test_base = TestBase()
params = test_base.params

# Hidden size of the configured BERT model, passed to `test_top_layer` later on.
hidden_dim = params.bert_config.hidden_size

# Build the train/eval dataset from the fixture params and pull a single batch
# (as numpy arrays) to use as sample features for the top-layer test.
train_dataset = train_eval_input_fn(params=params)
one_batch = next(train_dataset.as_numpy_iterator())


2021-06-12 21:42:24.148 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_ner, problem type: seq_tag
2021-06-12 21:42:24.149 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_multi_cls, problem type: multi_cls
2021-06-12 21:42:24.149 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_cls, problem type: cls
2021-06-12 21:42:24.150 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_masklm, problem type: masklm
2021-06-12 21:42:24.150 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_regression, problem type: regression
2021-06-12 21:42:24.151 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_vector_fit, problem type: vector_fit
2021-06-12 21:42:24.151 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_pre

In [None]:
# Inspect the shape of the `array_input_ids` feature in the sampled batch.
one_batch['array_input_ids'].shape


(32, 1, 10)

# Contrastive Learning(contrastive_learning)

Contrastive learning is usually used alongside in-batch data augmentation. To implement a data augmentation strategy, one should implement their own embedding layer. See the `embedding_layer` sub-module for more details.


## Imports and utils


In [None]:
# export
from typing import List

import numpy as np
import tensorflow as tf
from loguru import logger
from m3tl.base_params import BaseParams
from m3tl.problem_types.utils import (empty_tensor_handling_loss,
                                      nan_loss_handling, pad_to_shape)
from m3tl.special_tokens import PREDICT
from m3tl.utils import (LabelEncoder, get_label_encoder_save_path, get_phase,
                        need_make_label_encoder)


## Top Layer

In [None]:
# SimCSE
# export
class SimCSE(tf.keras.Model):
    """SimCSE-style contrastive top layer.

    Outside of the PREDICT phase, the pooled representation is split into two
    halves along the batch axis, a cosine-similarity matrix between the halves
    is computed, and a binary cross-entropy loss is added in which the diagonal
    pairs are positives and every other pair is a negative.

    Args:
        params: Project parameters. Reads `dropout`, `embedding_layer['name']`
            and the optional `simcse_pooler` entry ('pooled' or 'mean_pool',
            default 'pooled').
        problem_name: Name of the problem; used to name the accuracy metric.

    Raises:
        AssertionError: If `simcse_pooler` is not one of the supported poolers.
        ValueError: If the configured embedding layer is not
            `duplicate_data_augmentation_embedding`.
    """

    def __init__(self, params: BaseParams, problem_name: str) -> None:
        super(SimCSE, self).__init__(name='simcse')
        self.params = params
        self.problem_name = problem_name
        # NOTE(review): this dropout layer is created but never applied in `call`.
        self.dropout = tf.keras.layers.Dropout(self.params.dropout)
        self.pooler = self.params.get('simcse_pooler', 'pooled')
        self.metric_fn = tf.keras.metrics.CategoricalAccuracy(name='{}_acc'.format(problem_name))
        available_pooler = ['pooled', 'mean_pool']
        assert self.pooler in available_pooler, \
            'available params.simcse_pooler: {}, got: {}'.format(
                available_pooler, self.pooler)
        if self.params.embedding_layer['name'] != 'duplicate_data_augmentation_embedding':
            raise ValueError(
                'SimCSE requires duplicate_data_augmentation_embedding. Fix it with `params.assign_embedding_layer(\'duplicate_data_augmentation_embedding\')`')

    def call(self, inputs):
        """Add the contrastive loss/metric (non-PREDICT phases) and return the pooled features.

        Args:
            inputs: A `(features, hidden_features)` pair. Only `hidden_features`
                is read; it must provide a 'pooled' entry and, when the pooler
                is 'mean_pool', a 'seq' entry.

        Returns:
            `hidden_features['pooled']`, unchanged.
        """
        # The first element (raw features) is not needed by this layer.
        _, hidden_features = inputs
        phase = get_phase()

        if phase != PREDICT:
            # created pool embedding
            if self.pooler == 'pooled':
                all_pooled_embedding = hidden_features['pooled']
            else:
                all_pooled_embedding = tf.reduce_mean(
                    hidden_features['seq'], axis=1)

            # Split the batch axis into the two views of each example.
            # shape (batch_size, hidden_dim) each — assumes the embedding layer
            # stacked both views along axis 0; TODO confirm
            pooled_rep1_embedding, pooled_rep2_embedding = tf.split(
                all_pooled_embedding, 2)

            # calculate similarity (cosine, via l2-normalised dot product)
            pooled_rep1_embedding = tf.math.l2_normalize(
                pooled_rep1_embedding, axis=1)
            pooled_rep2_embedding = tf.math.l2_normalize(
                pooled_rep2_embedding, axis=1)
            # shape (batch_size, batch_size)
            similarity = tf.matmul(pooled_rep1_embedding,
                                   pooled_rep2_embedding, transpose_b=True)
            # Diagonal pairs come from the same example and are the positives.
            labels = tf.eye(tf.shape(similarity)[0])

            # shape (batch_size*batch_size, 1)
            similarity = tf.reshape(similarity, shape=(-1, 1))
            labels = tf.reshape(labels, shape=(-1, 1))

            # make compatible with binary crossentropy
            similarity = tf.concat([1-similarity, similarity], axis=1)
            labels = tf.concat([1-labels, labels], axis=1)
            loss = tf.keras.losses.binary_crossentropy(labels, similarity)
            loss = tf.reduce_mean(loss)
            # With an empty batch the mean over zero elements is NaN, which
            # would poison the summed model loss; contribute zero instead.
            loss = tf.where(tf.equal(tf.size(labels), 0),
                            tf.zeros_like(loss), loss)
            self.add_loss(loss)
            acc = self.metric_fn(labels, similarity)
            self.add_metric(acc)
        # NOTE(review): the returned representation is always 'pooled', even
        # when `self.pooler` is 'mean_pool' — confirm this is intended.
        return hidden_features['pooled']


In [None]:
# export
def get_contrastive_learning_model(params: BaseParams, problem_name: str, model_name: str) -> tf.keras.Model:
    """Build the contrastive learning model selected by `model_name`.

    Only 'simcse' is known; any other name logs a warning and falls back to
    SimCSE as well.

    Args:
        params: Project parameters forwarded to the model constructor.
        problem_name: Problem name forwarded to the model constructor.
        model_name: Requested contrastive learning model name.

    Returns:
        A `SimCSE` instance.
    """
    if model_name != 'simcse':
        # Unknown name: warn, then fall through to the default model.
        logger.warning(
            '{} not match any contrastive learning model, using SimCSE'.format(model_name))

    return SimCSE(params=params, problem_name=problem_name)


In [None]:
# export

class ContrastiveLearning(tf.keras.Model):
    """Top layer for contrastive learning problems.

    Thin wrapper that builds the model named by
    `params.contrastive_learning_model_name` and forwards calls to it.
    """

    def __init__(self, params: BaseParams, problem_name: str) -> None:
        super().__init__(name=problem_name)
        self.params = params
        self.problem_name = problem_name

        # The concrete model is selected by name from the params.
        model_name = self.params.contrastive_learning_model_name
        self.contrastive_learning_model_name = model_name
        self.contrastive_learning_model = get_contrastive_learning_model(
            params=self.params,
            problem_name=problem_name,
            model_name=model_name)

    def call(self, inputs):
        """Delegate to the wrapped contrastive learning model."""
        outputs = self.contrastive_learning_model(inputs)
        return outputs


In [None]:
# Run the shared top-layer test against ContrastiveLearning using the sampled
# batch, once with batch size 0 (empty batch) and once with batch size 2.
test_top_layer(ContrastiveLearning, problem='fake_contrastive_learning',
               params=params, sample_features=one_batch, hidden_dim=hidden_dim, test_batch_size_list=[0,2])


2021-06-12 21:42:33.553 | DEBUG    | m3tl.test_base:test_top_layer:247 - Testing ContrastiveLearning
2021-06-12 21:42:33.585 | DEBUG    | m3tl.test_base:test_top_layer:253 - testing batch size 0
2021-06-12 21:42:33.593 | DEBUG    | m3tl.test_base:test_top_layer:253 - testing batch size 2


## Get or make label encoder function


In [None]:
# export
def contrastive_learning_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs) -> LabelEncoder:
    """Return the label encoder for `problem`, fitting and saving it when needed.

    When `need_make_label_encoder` says a new encoder is required, the encoder
    is fitted on `label_list`, dumped to the problem's save path, and the
    number of classes is recorded in `params` under 'num_classes'. Otherwise
    the encoder is loaded from the save path.

    Args:
        params: Project parameters; used to resolve the save path and to store
            the problem's 'num_classes'.
        problem: Problem name.
        mode: Current mode, forwarded to `need_make_label_encoder`.
        label_list: Labels to fit the encoder on.
        **kwargs: May contain `overwrite`, forwarded to
            `need_make_label_encoder`. Defaults to False when absent.

    Returns:
        The fitted or loaded `LabelEncoder`.
    """
    le_path = get_label_encoder_save_path(params=params, problem=problem)
    label_encoder = LabelEncoder()
    # `overwrite` is optional: fall back to False rather than failing with a
    # bare KeyError when the caller does not supply it.
    overwrite = kwargs.get('overwrite', False)
    if need_make_label_encoder(mode=mode, le_path=le_path, overwrite=overwrite):
        # fit and save label encoder
        label_encoder.fit(label_list)
        label_encoder.dump(le_path)
        params.set_problem_info(
            problem=problem, info_name='num_classes', info=len(label_encoder.encode_dict))
    else:
        label_encoder.load(le_path)

    return label_encoder


## Label handling function

In [None]:
# export
def contrastive_learning_label_handling_fn(target: str, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs) -> tuple:
    """Encode a single target label as an int32 id.

    Args:
        target: Raw label to encode.
        label_encoder: Encoder exposing `transform`; required despite the
            `None` default.
        tokenizer: Unused.
        decoding_length: Unused.

    Returns:
        A `(label_id, None)` tuple where `label_id` is a `np.int32`.
    """
    # transform works on a batch, so wrap the single target and unwrap the result
    encoded = label_encoder.transform([target]).tolist()
    label_id = np.int32(encoded[0])
    return label_id, None
