In [None]:
!pip install transformers
!pip install emoji
!pip install datasets

# %env WANDB_PROJECT=twitter-roberta-base-dec2021_rbam_fine_tuned
# %env WANDB_WATCH=all

import pandas as pd
import numpy as np
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

from google.colab import drive, files
# Mount Google Drive; DATASET_PATH below is a relative 'drive/...' path, which
# presumably resolves against a working directory of /content — confirm.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


# Per-class loss weights indexed by label id (0=attack, 1=neutral, 2=support),
# read by WeightedLossTrainer.compute_loss. These are only initial values:
# BertRBAM.__init__ rebinds this name to weights computed from the training
# split. Fall back to CPU so importing the module does not fail without a GPU.
class_weights = torch.tensor(
    [0.6191, 0.7282, 0.6526],
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)

class WeightedLossTrainer(Trainer):
  """``Trainer`` whose loss is cross-entropy weighted by the module-level ``class_weights``."""

  def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    """Compute class-weighted cross-entropy for one batch.

    Args:
      model: the model being trained/evaluated.
      inputs: batch dict holding the model inputs plus a ``labels`` entry.
      return_outputs: when True, also return the raw model outputs.
      **kwargs: extra arguments that newer ``Trainer`` versions pass
        (e.g. ``num_items_in_batch``); accepted and ignored here.

    Returns:
      The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is True.
    """
    # Feed input into model and extract logits
    outputs = model(**inputs)
    logits = outputs.get("logits")
    # Extract labels
    labels = inputs.get("labels")
    # The global weights are created independently of the model, so make sure
    # they live on the same device as the logits before computing the loss.
    weight = class_weights.to(logits.device)
    loss_func = nn.CrossEntropyLoss(weight=weight)
    loss = loss_func(logits, labels)
    return (loss, outputs) if return_outputs else loss


class BertRBAM:
  """Fine-tune a Twitter RoBERTa model to classify text pairs as attack/neutral/support.

  Loads ``(text_a, text_b, labels)`` rows from ``DATASET_PATH``, normalises user
  handles and links, splits the data 80/10/10 into train/validation/test,
  tokenizes each row as a sentence pair, and trains with a class-weighted
  cross-entropy loss via ``WeightedLossTrainer``.
  """

  DATASET_PATH = 'drive/MyDrive/Colab Notebooks/prj/dataset.csv'

  MODEL = "cardiffnlp/twitter-roberta-base-dec2021"
  # Padding/truncation/max_length are tokenizer *call* options, applied in
  # `tokenize`; they are not construction options, so none are passed here.
  tokenizer = AutoTokenizer.from_pretrained(MODEL)

  id2label = {0: "attack", 1: "neutral", 2: "support"}
  label2id = {"attack": 0, "neutral": 1, "support": 2}

  hyperparameters = {'num_train_epochs': 2, 'learning_rate': 2e-5, 'weight_decay': 0.01, 'batch_size': 8}
  OUTPUT_DIR = "twitter-roberta-base-dec2021_rbam_fine_tuned"

  def __init__(self):
    """Load, preprocess, split and tokenize the dataset, and compute class weights.

    Side effect: rebinds the module-level ``class_weights`` that
    ``WeightedLossTrainer.compute_loss`` reads.
    """
    self.df = pd.read_csv(BertRBAM.DATASET_PATH)
    self.preprocess_dataset()

    self.X_train, self.y_train, self.X_valid, self.y_valid, self.X_test, self.y_test = self.get_split_dataset()

    (self.df_train, self.df_valid, self.df_test,
     self.tokenized_dataset_train, self.tokenized_dataset_test,
     self.tokenized_dataset_valid) = self.tokenize_dataset()

    self.class_weights = self._compute_class_weights(self.df_train["labels"])
    global class_weights
    class_weights = self.class_weights

    # Never 0: TrainingArguments rejects logging_steps=0, which the plain floor
    # division would give when the training split is smaller than one batch.
    self.steps_per_epoch = max(1, len(self.df_train) // BertRBAM.hyperparameters['batch_size'])

    self.trainer = None
    self.evaluation_metrics = None

  @staticmethod
  def _compute_class_weights(labels):
    """Return ``1 - relative frequency`` of each label id as a float32 tensor.

    The vector always has one entry per class in ``id2label`` (as
    ``CrossEntropyLoss`` requires); a class missing from ``labels`` gets 1.0.
    Assumes labels are the integer ids 0..num_classes-1 — TODO confirm
    against the dataset.
    """
    num_classes = len(BertRBAM.id2label)
    counts = labels.value_counts().reindex(range(num_classes), fill_value=0)
    weights = 1 - counts / len(labels)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.tensor(weights.to_numpy(), dtype=torch.float32, device=device)

  def preprocess(self, text):
    """Replace ``@handle`` tokens with ``@user`` and link tokens with ``http``.

    Splits on single spaces and rejoins with single spaces, so other
    whitespace is preserved as-is.
    """
    preprocessed_text = []
    for word in text.split(" "):
      # Each word is tested on its own, so only the matching tokens change.
      word = 'http' if word.startswith('http') else word  # Preprocess links
      word = '@user' if word.startswith('@') and len(word) > 1 else word  # Preprocess user handles
      preprocessed_text.append(word)
    return " ".join(preprocessed_text)

  def get_split_dataset(self):
    """Split into 80% train and 10%/10% validation/test (fixed seed 42).

    Returns:
      tuple: ``(X_train, y_train, X_valid, y_valid, X_test, y_test)``, where
      the X frames hold ``text_a``/``text_b`` and the y frames hold ``labels``.
    """
    train_size = 0.8
    test_size = 0.5  # half of the remaining 20% -> 10% of the whole dataset
    X = self.df[['text_a', 'text_b']]
    y = self.df[['labels']]

    X_train, X_rem, y_train, y_rem = train_test_split(X, y, train_size=train_size, random_state=42)

    X_valid, X_test, y_valid, y_test = train_test_split(X_rem, y_rem, test_size=test_size, random_state=42)

    return X_train, y_train, X_valid, y_valid, X_test, y_test

  def preprocess_dataset(self):
    """Apply ``preprocess`` in place to both text columns of ``self.df``."""
    self.df['text_a'] = self.df['text_a'].apply(self.preprocess)
    self.df['text_b'] = self.df['text_b'].apply(self.preprocess)

  def tokenize(self, example):
    """Tokenize a batch of (text_a, text_b) pairs, padded/truncated to 512 tokens."""
    text_a = example['text_a']
    text_b = example['text_b']

    return BertRBAM.tokenizer(text_a, text_b, padding="max_length", truncation=True, max_length=512)

  def _to_tokenized_dataset(self, df):
    """Convert one split's DataFrame into a tokenized, torch-formatted ``Dataset``."""
    # preserve_index=False keeps the shuffled pandas index out of the Dataset,
    # so only the raw text columns have to be dropped after tokenization.
    dataset = Dataset.from_pandas(df, preserve_index=False)
    dataset = dataset.map(self.tokenize, batched=True)
    dataset = dataset.remove_columns(['text_a', 'text_b'])
    return dataset.with_format('torch')

  def tokenize_dataset(self):
    """Build per-split DataFrames and their tokenized ``Dataset`` counterparts.

    Returns:
      tuple: ``(df_train, df_valid, df_test, tokenized_train, tokenized_test,
      tokenized_valid)`` — note the test/valid order differs between the
      DataFrames and the datasets.
    """
    df_train = pd.concat([self.X_train, self.y_train], axis=1)
    df_test = pd.concat([self.X_test, self.y_test], axis=1)
    df_valid = pd.concat([self.X_valid, self.y_valid], axis=1)

    return (df_train, df_valid, df_test,
            self._to_tokenized_dataset(df_train),
            self._to_tokenized_dataset(df_test),
            self._to_tokenized_dataset(df_valid))

  def compute_metrics(self, pred):
    """Return accuracy and support-weighted precision/recall/F1 for ``Trainer`` predictions."""
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    accuracy = accuracy_score(y_true=labels, y_pred=preds)
    recall = recall_score(y_true=labels, y_pred=preds, average="weighted")
    precision = precision_score(y_true=labels, y_pred=preds, average="weighted")
    f1 = f1_score(labels, preds, average="weighted")

    # class_names = ["attack", "neutral", "support"]

    # wandb.log({"conf_mat" : wandb.plot.confusion_matrix(probs=None,
    #                           preds=preds, y_true=labels,
    #                           class_names=class_names)})

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

  def model_init(self):
    """Create a fresh sequence-classification head on ``MODEL`` (used by ``Trainer``)."""
    model = AutoModelForSequenceClassification.from_pretrained(
        BertRBAM.MODEL,
        num_labels=len(BertRBAM.id2label),
        id2label=BertRBAM.id2label,
        label2id=BertRBAM.label2id,
    )
    return model

  def train_model(self):
    """Train with the weighted-loss trainer, then evaluate on the test split.

    The best checkpoint by validation ``eval_f1`` is reloaded before the
    test evaluation.

    Returns:
      dict: the test-split metrics from ``Trainer.evaluate`` (also stored in
      ``self.evaluation_metrics``).
    """
    import inspect

    # `evaluation_strategy` was renamed to `eval_strategy` in newer
    # transformers releases; use whichever name the installed version accepts.
    accepted = inspect.signature(TrainingArguments).parameters
    eval_strategy_key = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"

    training_args = TrainingArguments(
        output_dir=BertRBAM.OUTPUT_DIR,
        num_train_epochs=BertRBAM.hyperparameters['num_train_epochs'],
        learning_rate=BertRBAM.hyperparameters['learning_rate'],
        weight_decay=BertRBAM.hyperparameters['weight_decay'],
        per_device_train_batch_size=BertRBAM.hyperparameters['batch_size'],
        per_device_eval_batch_size=BertRBAM.hyperparameters['batch_size'],
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        remove_unused_columns=False,
        logging_steps=self.steps_per_epoch,
        log_level="warning",
        # report_to="wandb",
        # push_to_hub=True
        **{eval_strategy_key: "epoch"},
    )

    self.trainer = WeightedLossTrainer(model_init=self.model_init,
                                       args=training_args,
                                       compute_metrics=self.compute_metrics,
                                       train_dataset=self.tokenized_dataset_train,
                                       eval_dataset=self.tokenized_dataset_valid)

    self.trainer.train()

    self.evaluation_metrics = self.trainer.evaluate(eval_dataset=self.tokenized_dataset_test)

    print(self.evaluation_metrics)
    return self.evaluation_metrics


bert_rbam_model = BertRBAM()
# train_model() stores the test-split metrics on the instance; read them from
# there so eval_metrics is populated regardless of train_model()'s return value.
bert_rbam_model.train_model()
eval_metrics = bert_rbam_model.evaluation_metrics