#### Modified from: https://towardsdatascience.com/discover-the-sentiment-of-reddit-subgroup-using-roberta-model-10ab9a8271b8

## Train on SST data

In [1]:
# Remove model checkpoints left over from a previous training run.
# The glob deletes *every* file or directory in the working directory whose
# name starts with "models" (the training callback below writes models1,
# models2, ...), so keep unrelated artifacts under a different name.
!rm -rf models*

In [2]:
import pandas as pd
# Recommended tensorflow version is <= 2.1.0, otherwise F1 score function breaks
import tensorflow as tf
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
import tensorflow_datasets as tfds
from transformers import TFRobertaForSequenceClassification
from transformers import RobertaTokenizer, RobertaConfig, AutoTokenizer
import os


# Load training and validation data.
# Note: the validation frame is read from the SST *test* split file.
train_tweets = pd.read_csv('sst_train.csv')
val_tweets = pd.read_csv('sst_test.csv')

# Keep only the two columns used downstream: the raw text and its class label.
training_sentences, testing_sentences = train_tweets[['text', 'target']], val_tweets[['text', 'target']]

# Use BERTweet model and tokenizer (both loaded from the same checkpoint name).
model = TFRobertaForSequenceClassification.from_pretrained("vinai/bertweet-base", num_labels=5) # SST data has 5 classes
roberta_tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", use_fast=False)

# Maximum number of tokens per input text; read by convert_example_to_feature.
max_length = 128

# Number of examples per batch, used for the training, validation and prediction datasets.
batch_size = 64

def convert_example_to_feature(review):
    """Tokenize a single text into fixed-length model inputs.

    Args:
        review: The text to encode (a ``str``).

    Returns:
        The tokenizer's encoding for ``review``; callers read its
        ``'input_ids'`` and ``'attention_mask'`` entries. Both are padded or
        truncated to the module-level ``max_length``.
    """
    return roberta_tokenizer.encode_plus(
        review,
        add_special_tokens=True,     # wrap the text in the tokenizer's special tokens
        max_length=max_length,       # fixed sequence length fed to the model
        padding='max_length',        # replaces the deprecated pad_to_max_length=True
        truncation=True,             # explicit; avoids the "truncation was not explicitly activated" warning
        return_attention_mask=True,  # mask so the model does not attend to padding
    )

# Arrange one encoded example as the (features, label) pair that is fed to
# TFRobertaForSequenceClassification through tf.data / Keras.
def map_example_to_dict(input_ids, attention_masks, label):
    """Pack one encoded example into a ``(features, label)`` tuple.

    Args:
        input_ids: Token ids of the example.
        attention_masks: Attention mask matching ``input_ids``.
        label: Target value for the example.

    Returns:
        A 2-tuple: a dict with keys ``"input_ids"`` and ``"attention_mask"``,
        followed by ``label`` unchanged.
    """
    model_inputs = {}
    model_inputs["input_ids"] = input_ids
    model_inputs["attention_mask"] = attention_masks
    return model_inputs, label

def encode_examples(ds, limit=-1):
    """Tokenize every (text, label) pair of a dataset.

    Args:
        ds: A ``tf.data.Dataset`` yielding ``(text, label)`` pairs; each text
            is decoded from bytes before tokenization.
        limit: When positive, only the first ``limit`` elements are encoded.
            The default (or any non-positive value) encodes the whole dataset.

    Returns:
        A ``tf.data.Dataset`` whose elements are shaped by
        ``map_example_to_dict``: a features dict and a one-element label list.
    """
    if limit > 0:
        ds = ds.take(limit)

    # Parallel per-example columns; they are sliced row-wise further down.
    ids_rows, mask_rows, label_rows = [], [], []

    for raw_text, raw_label in tfds.as_numpy(ds):
        encoded = convert_example_to_feature(raw_text.decode())
        ids_rows.append(encoded['input_ids'])
        mask_rows.append(encoded['attention_mask'])
        label_rows.append([raw_label])

    encoded_ds = tf.data.Dataset.from_tensor_slices((ids_rows, mask_rows, label_rows))
    return encoded_ds.map(map_example_to_dict)

# Wrap the raw (text, target) columns as tf.data datasets so that
# encode_examples can iterate over them pair by pair.
training_sentences_modified = tf.data.Dataset.from_tensor_slices((training_sentences['text'],
                                                                  training_sentences['target']))

testing_sentences_modified = tf.data.Dataset.from_tensor_slices((testing_sentences['text'],
                                                                 testing_sentences['target']))

# Tokenize both splits; only the training split is shuffled.
# NOTE(review): the shuffle buffer is fixed at 10000 elements. If the training
# set is larger than that, the shuffle is only approximate - confirm this is intended.
ds_train_encoded = encode_examples(training_sentences_modified).shuffle(10000).batch(batch_size)
ds_test_encoded = encode_examples(testing_sentences_modified).batch(batch_size)



# Training hyperparameters: Adam learning rate and number of passes over the training data.
learning_rate = 7e-5
number_of_epochs = 6

class ModelMetrics(tf.keras.callbacks.Callback):
    """Keras callback that saves the model after every epoch.

    Each checkpoint is written with ``save_pretrained`` into a directory named
    ``models<k>`` (``models1``, ``models2``, ...) relative to the current
    working directory, where ``k`` counts the epochs finished in the current
    ``fit`` call.
    """

    def on_train_begin(self, logs=None):
        """Restart checkpoint numbering at 1 when training begins."""
        self.count_n = 1

    def on_epoch_end(self, batch, logs=None):
        """Save the current model into ``models<count_n>`` and advance the counter.

        Args:
            batch: Positional value supplied by Keras at the end of an epoch
                (the epoch index); unused here. The parameter name is kept
                as-is for compatibility.
            logs: Metric values supplied by Keras; unused here.
        """
        checkpoint_dir = 'models' + str(self.count_n)
        # exist_ok=True: numbering restarts at 1 on every fit() call, so the
        # directory may already exist from an earlier run; os.mkdir would raise.
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.model.save_pretrained(checkpoint_dir)

        self.count_n += 1

# Despite its name, `metrics` is the checkpoint-saving callback defined above,
# not a Keras metric.
metrics = ModelMetrics()

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08)

# Labels are integer class ids (sparse); from_logits=True tells the loss to
# treat the model outputs as unnormalized scores.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
# Train on the SST training split, validating on the encoded test split after
# each epoch; the callback writes one checkpoint directory per epoch.
model.fit(ds_train_encoded, epochs=number_of_epochs,
          validation_data=ds_test_encoded, callbacks=[metrics])

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Some layers from the model checkpoint at vinai/bertweet-base were not used when initializing TFRobertaForSequenceClassification: ['lm_head']
- This IS expected if you are initializing TFRobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFRobertaForSequenceClassification from the checkpoint of a model tha

Epoch 1/8
Epoch 2/8
Epoch 3/8
Epoch 4/8
Epoch 5/8
Epoch 6/8
Epoch 7/8
  21/4978 [..............................] - ETA: 1:09:47 - loss: 0.2344 - accuracy: 0.9077

KeyboardInterrupt: 

## Test on Covid/Election data

In [3]:
import preprocessor as p

def _extract_logits(outputs):
    """Return the logits array from the value produced by ``model.predict``.

    Accepts an output object exposing ``.logits``, a dict with a ``'logits'``
    key, a tuple/list whose first element is the logits, or the bare array.
    """
    if hasattr(outputs, 'logits'):
        return outputs.logits
    if isinstance(outputs, dict):
        return outputs['logits']
    if isinstance(outputs, (tuple, list)):
        return outputs[0]
    return outputs


def predict(testcsv_name, model, output_name):
    """Label every tweet in a CSV file and write the result to a new CSV.

    Args:
        testcsv_name: Path of the input CSV; it must contain a ``text`` column.
        model: Trained sequence-classification model exposing ``predict``.
        output_name: Path of the CSV to write (written without the index).

    Returns:
        The DataFrame that was written: rows containing missing values are
        dropped, ``text`` holds the cleaned tweets and ``target`` holds the
        predicted class index.
    """
    # Re-number rows after dropna(): the original positional loop wrote through
    # .loc[i] on the gapped index, misplacing cleaned text and appending rows.
    test_tweets = pd.read_csv(testcsv_name).dropna().reset_index(drop=True)
    test_tweets['text'] = test_tweets['text'].map(p.clean)  # preprocessing the tweets
    test_tweets['target'] = 0  # placeholder label required by encode_examples

    if test_tweets.empty:
        # Nothing to classify; still emit the (header-only) output file.
        test_tweets.to_csv(output_name, index=False)
        return test_tweets

    submission_sentences_modified = tf.data.Dataset.from_tensor_slices(
        (test_tweets['text'], test_tweets['target']))
    ds_submission_encoded = encode_examples(submission_sentences_modified).batch(batch_size)

    logits = _extract_logits(model.predict(ds_submission_encoded))
    # argmax over raw logits equals argmax over their softmax, so the softmax
    # step is unnecessary.
    test_tweets['target'] = tf.math.argmax(logits, axis=1).numpy()
    test_tweets.to_csv(output_name, index=False)
    return test_tweets

In [4]:
from os.path import isfile, join
from os import listdir

# Directory holding the tweet CSV files to label.
data_path = 'data'

# Only regular files with a .csv extension are processed (listdir also returns
# hidden and auxiliary files); sorted so runs are deterministic.
csvfiles = sorted(
    join(data_path, name)
    for name in listdir(data_path)
    if isfile(join(data_path, name)) and name.lower().endswith('.csv')
)

for f in csvfiles:
    # 'predict-' is prefixed to the whole relative path, so results land in a
    # sibling directory (e.g. predict-data/); create it before writing.
    output_name = 'predict-' + f
    output_dir = os.path.dirname(output_name)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    predict(f, model, output_name)