In [None]:
""" Just a work bench"""
import os
import json
from typing import List

import numpy as np
import pandas as pd

from transformers import BertTokenizer, AutoTokenizer, AutoConfig, TFDistilBertModel, TFBertModel
import tensorflow as tf

from tc_data import TopCoder
from fine_tune_bert import build_dataset
from model_tcpm_distilbert import TCPMDistilBertClassification

# Set pandas' display.max_rows option to 500 for this session.
pd.options.display.max_rows = 500

In [None]:
# Seed TensorFlow's global RNG before anything stochastic runs, so that a
# Restart & Run All reproduces the same randomly initialised layer weights
# (and the same sample, should build_dataset shuffle through TensorFlow --
# its implementation is not visible here).
SEED = 42
tf.random.set_seed(SEED)

# Tokenizer for the DistilBERT checkpoint used throughout the notebook.
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
# build_dataset's result is unpacked as (dataset, dataset size, label count),
# judging by the target names -- TODO confirm against fine_tune_bert.
ds, ds_size, num_labels = build_dataset(tokenizer)


In [None]:
# Load the config and the project classifier from the same checkpoint;
# the label count from the dataset is passed into the config.
checkpoint = 'distilbert-base-uncased'
config = AutoConfig.from_pretrained(checkpoint, num_labels=num_labels)
model = TCPMDistilBertClassification.from_pretrained(checkpoint, config=config)

In [None]:
# Pieces of the hand-assembled classifier: the pretrained encoder plus a
# dense -> dropout -> dense head whose sizes come from the config and the
# label count.
tf_distilbert_model = TFDistilBertModel.from_pretrained(
    'distilbert-base-uncased',
    config=config,
)
# meta_input_layer = tf.keras.layers.InputLayer(input_shape=(4,), name='meta_input_layer')
fully_connected = tf.keras.layers.Dense(
    units=config.dim,
    activation='relu',
    name='fully_connected',
)
drop_out = tf.keras.layers.Dropout(rate=config.seq_classif_dropout)
classification = tf.keras.layers.Dense(units=num_labels, name='classification')

In [None]:
# One symbolic int32 input per tokenizer field, each 512 positions wide.
seq_len = 512
token_fields = ('input_ids', 'attention_mask')
distil_bert_input = {
    field: tf.keras.layers.Input(shape=(seq_len,), dtype=tf.int32, name=field)
    for field in token_fields
}

# Run the encoder on the symbolic inputs; element 0 of its output is taken
# as the hidden states.
distil_bert_output = tf_distilbert_model(distil_bert_input)
hidden_state = distil_bert_output[0]
# Keep only sequence position 0 of every example -> (bs, dim)
pooled_output = hidden_state[:, 0]

In [None]:
# Symbolic float32 input of width 4, named 'meta_input'.
meta_input = tf.keras.layers.Input(
    name='meta_input',
    shape=(4,),
    dtype=tf.float32,
)
# meta_output = meta_input_layer(meta_input)

In [None]:
# Join the encoder's position-0 vector with the meta input along the last
# axis, then pass the result through the dense -> dropout -> dense head.
concat_layer = tf.keras.layers.concatenate(
    inputs=[pooled_output, meta_input],
    axis=-1,
)
x = drop_out(fully_connected(concat_layer))
output = classification(x)

In [None]:
# Wrap the graph in a Keras model. The inputs keep the nested
# [token-input dict, meta input] structure.
jpy_model = tf.keras.Model(
    inputs=[distil_bert_input, meta_input],
    outputs=output,
)

In [None]:
# Draw the layer graph with tensor shapes annotated; the returned object is
# the cell's displayed output.
tf.keras.utils.plot_model(
    model=jpy_model,
    show_shapes=True,
)

In [None]:
# The classification layer has no activation, so the loss is told to expect
# logits; labels are integer class ids (sparse loss).
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)
jpy_model.compile(
    loss=loss,
    optimizer=optimizer,
)

In [None]:
# Take one element from the dataset, batch it with batch size 1, and keep
# the single resulting batch.
d = list(ds.take(1).batch(1))[0]

In [None]:
# Flatten the sample's inputs into one dict: the entries of d[0][0] plus
# d[0][1] under the key 'meta_input'.
d_input = dict(d[0][0])
d_input['meta_input'] = d[0][1]

In [None]:
# Forward pass of the hand-assembled model on the single sample; the result
# is the cell's displayed output.
jpy_model.predict(x=d_input)

In [None]:
# Call the project classifier directly on the sample's input part (d[0]);
# the result is the cell's displayed output.
# NOTE(review): presumably meant for side-by-side comparison with
# jpy_model.predict above -- confirm.
model(d[0])

In [None]:
# tc = TopCoder()

In [None]:
# cbi_df = tc.get_filtered_challenge_info()
# meta_df = cbi_df.reindex(['number_of_platforms', 'number_of_technologies', 'project_id', 'challenge_duration'], axis=1)