In [None]:
# default_exp test_base
%load_ext autoreload
%autoreload 2
import os
# Set in the first cell, ahead of the project imports in the cells below.
# NOTE(review): presumably "-1" hides every GPU so the tests run on CPU only
# -- confirm against the CUDA / framework documentation.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Test Base

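Basic setup shared by the tests: a `TestBase` class that creates temporary directories and a `BaseParams` object with the Weibo example problems registered.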

In [None]:
# export

import shutil
import tempfile

from bert_multitask_learning.predefined_problems import (
    get_weibo_cws_fn, get_weibo_fake_cls_fn, get_weibo_fake_multi_cls_fn,
    get_weibo_fake_ner_fn, get_weibo_masklm, get_weibo_pretrain_fn)
from bert_multitask_learning.params import BaseParams

class TestBase:
    """Shared fixture that builds temp directories and a configured ``BaseParams``.

    Instantiating the class runs ``setUp`` immediately, so after construction
    the instance exposes:

    * ``tmpfiledir`` / ``tmpckptdir`` -- freshly created temporary directories,
    * ``problem_type_dict`` / ``processing_fn_dict`` -- the problem definitions,
    * ``params`` -- a ``BaseParams`` pointing at those directories with the
      problems registered.

    Call ``tearDown`` when finished to delete the temporary directories.
    """

    # Path pattern (note the trailing ``*``) passed as ``file_path`` to every
    # ``get_weibo_*`` factory. Subclasses may override it to use another
    # data location.
    data_file_pattern = '/data/bert-multitask-learning/data/ner/weiboNER*'

    def __init__(self):
        self.setUp()

    def setUp(self) -> None:
        """Create the temporary directories and build ``self.params``.

        If ``prepare_params`` raises, the directories created here are removed
        before the exception propagates, so a failed construction does not
        leave temp directories behind.
        """
        self.tmpfiledir = tempfile.mkdtemp()
        self.tmpckptdir = tempfile.mkdtemp()
        try:
            self.prepare_params()
        except Exception:
            # __init__ calls setUp, so the caller never receives an instance
            # to call tearDown() on -- clean up here instead.
            self.tearDown()
            raise

    def tearDown(self) -> None:
        """Delete both temporary directories.

        Removal is best-effort: a directory that is already gone (or cannot
        be removed) does not prevent the other one from being cleaned up, and
        calling this method more than once is harmless.
        """
        for directory in (self.tmpfiledir, self.tmpckptdir):
            shutil.rmtree(directory, ignore_errors=True)

    def prepare_params(self):
        """Populate the problem dictionaries and ``self.params``.

        Reads ``self.tmpfiledir`` and ``self.tmpckptdir`` (set by ``setUp``)
        and ``self.data_file_pattern``.
        """
        # Problem name -> problem type string handed to BaseParams.
        self.problem_type_dict = {
            'weibo_fake_ner': 'seq_tag',
            'weibo_cws': 'seq_tag',
            'weibo_fake_multi_cls': 'multi_cls',
            'weibo_fake_cls': 'cls',
            'weibo_masklm': 'masklm',
            'weibo_pretrain': 'pretrain'
        }

        # Problem name -> object returned by the matching factory; every
        # factory receives the same file_path.
        file_path = self.data_file_pattern
        self.processing_fn_dict = {
            'weibo_fake_ner': get_weibo_fake_ner_fn(file_path=file_path),
            'weibo_cws': get_weibo_cws_fn(file_path=file_path),
            'weibo_fake_cls': get_weibo_fake_cls_fn(file_path=file_path),
            'weibo_fake_multi_cls': get_weibo_fake_multi_cls_fn(file_path=file_path),
            'weibo_masklm': get_weibo_masklm(file_path=file_path),
            'weibo_pretrain': get_weibo_pretrain_fn(file_path=file_path)
        }

        self.params = BaseParams()
        self.params.tmp_file_dir = self.tmpfiledir
        self.params.ckpt_dir = self.tmpckptdir

        # Model, config and tokenizer all use the same model identifier.
        model_name = 'voidful/albert_chinese_tiny'
        self.params.transformer_model_name = model_name
        self.params.transformer_config_name = model_name
        self.params.transformer_tokenizer_name = model_name
        self.params.transformer_tokenizer_loading = 'BertTokenizer'
        self.params.transformer_config_loading = 'AlbertConfig'

        self.params.add_multiple_problems(
            problem_type_dict=self.problem_type_dict, processing_fn_dict=self.processing_fn_dict)
