In [None]:
""" Just a work bench"""
import os
import json
from typing import List
from pprint import pprint

from dotenv import load_dotenv

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from transformers import BertTokenizer, AutoTokenizer, AutoConfig, TFDistilBertModel, TFBertModel, TFTrainingArguments
import tensorflow as tf

from tc_data import TopCoder
from run_classification import build_dataset
from model_tcpm_distilbert import (
    TCPMDistilBertClassification,
    build_tcpm_model_distilbert_classification,
    build_tcpm_model_distilbert_regression
)

# Let pandas show up to 800 rows of a DataFrame before truncating the display.
pd.set_option('display.max_rows', 800)

# Pull variables such as MODEL_NAME (read further down) from a .env file into the environment.
load_dotenv()

In [None]:
# TopCoder data helper; the feature, target, metadata and tech-popularity cells below all read from this instance.
tc = TopCoder()

In [None]:
# Resolve the pretrained model name once. os.getenv() returns None when the
# variable is missing, and passing None on to from_pretrained() fails far from
# the real cause, so stop here with an actionable message instead.
MODEL_NAME = os.getenv('MODEL_NAME')
if not MODEL_NAME:
    raise RuntimeError(
        "MODEL_NAME is not set; export it or add it to the .env file loaded by load_dotenv()."
    )

# Model config and the matching tokenizer for the same checkpoint.
config = AutoConfig.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Model inputs and labels, all requested as tensors:
#   - metadata: tabular features (tech column encoded/softmaxed, per the flag names)
#   - encoded_txt: text features encoded with the tokenizer loaded above
#   - target: the label values
metadata = tc.get_meta_data_features(return_tensor=True, softmax_tech=True, encoded_tech=True)
encoded_txt = tc.get_bert_encoded_txt_features(tokenizer, return_tensor=True)
target = tc.get_target(return_tensor=True)

In [None]:
# tf.data pipeline of (inputs, target) pairs; from_tensor_slices slices every
# component along its first axis. The inputs dict holds the encoded-text entries
# plus the metadata tensor under the 'meta_input' key.
ds = tf.data.Dataset.from_tensor_slices(
    (
        dict(**encoded_txt, meta_input=metadata),
        target,
    )
)

In [None]:
# Only the first value returned by calculate_tech_popularity() is inspected here.
# The second value goes into a named throwaway rather than `_`: in IPython,
# assigning `_` shadows the last-output variable and stops IPython from
# updating `_` / `__` / `___` for the rest of the session.
tech_popularity, _tech_popularity_unused = tc.calculate_tech_popularity()
tech_popularity

In [None]:
# Metadata features with the same encoded_tech/softmax_tech options as `metadata`,
# but returned as a DataFrame so the columns can be inspected.
meta_df = tc.get_meta_data_features(return_df=True, softmax_tech=True, encoded_tech=True)
meta_df

![regression model](regression_model.png)

In [None]:
# Example training history. With orient='index', each top-level JSON key becomes
# a row label; that row index is then named 'epochs' (used as the x-axis label
# by the plotting cell).
training_history = (
    pd.read_json('example_training_result.json', orient='index')
    .rename_axis('epochs')
)

In [None]:
# Show the training history loaded from example_training_result.json (index named 'epochs').
training_history

In [None]:
# Example test results; orient='index' makes each top-level JSON key a row label.
test_result = pd.read_json(orient='index', path_or_buf='example_result.json')
test_result

In [None]:
# Plot each training metric in `training_history` against the epoch index, one
# panel per metric. Entries are column names; add one here to get another panel.
TRAINING_METRICS = ['mae', 'mse']

with sns.axes_style('whitegrid'):
    fig, axes = plt.subplots(1, len(TRAINING_METRICS), figsize=(8, 3), dpi=200)

    # plt.subplots returns a bare Axes (not an array) when there is only one
    # panel; atleast_1d makes the loop work for any number of metrics.
    for ax, metric in zip(np.atleast_1d(axes), TRAINING_METRICS):
        sns.lineplot(
            x=training_history.index,
            y=training_history[metric],
            ax=ax,
        )
        ax.set_title(f'Metrics - {metric.upper()}')

    fig.tight_layout()