In [None]:
!pip install transformers datasets
!pip install pip install pytorch-lightning

In [82]:
import re

from transformers import BertForSequenceClassification, BertConfig, BertTokenizer, AdamW,get_linear_schedule_with_warmup
from tqdm.notebook import tqdm
import pytorch_lightning as pl
import torch 
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, confusion_matrix, f1_score
import datasets


In [83]:
def transform_text(text):
  text = text.lower()
  text = re.sub('[^A-Za-z0-9 ]+', '', text)
  text = re.sub('\s+(a|is|be|will|the|was|were|have|has|are|been|s|ll)\s+', '', text)
  return text

def create_documents_list(l):
  temp_vocab = [i.split(' ') for i in l]
  documents_list = [[j for j in i if len(j)>0] for i in temp_vocab]
  return documents_list

In [84]:
dataset = datasets.load_dataset("tweet_eval", "irony")

df_train = dataset["train"].to_pandas()
df_val = dataset["validation"].to_pandas()
df_test = dataset["test"].to_pandas()

df_train["clean_text"] = df_train.text.apply(lambda x: transform_text(x))
df_test["clean_text"] = df_test.text.apply(lambda x: transform_text(x))
df_val["clean_text"] = df_val.text.apply(lambda x: transform_text(x))

Reusing dataset tweet_eval (/root/.cache/huggingface/datasets/tweet_eval/irony/1.1.0/12aee5282b8784f3e95459466db4cdf45c6bf49719c25cdb0743d71ed0410343)


In [85]:
config = BertConfig.from_pretrained('bert-base-uncased')
config

BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.8.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [86]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = BertConfig.from_pretrained('bert-base-uncased')
config.num_labels = 2

model = BertForSequenceClassification(config)
model.to('cuda')

class tweet_dataset(Dataset):
  def __init__(self, df, tokenizer, train=True, max_len=256):
    self.tokenizer = tokenizer
    self.train_flag = train
    self.ids = df.index.to_list()
    self.text = df.clean_text.values
    self.max_len = max_len

    if train is False:
        self.label = None
    else:
        self.label = df.label.values

  def __getitem__(self, i):
    id = self.ids[i]
    text = self.text[i]
    tokenized_text = self.tokenizer(text, return_tensors="pt", padding='max_length', max_length=self.max_len)

    if self.train_flag is True:
        return {"tokenized_text": tokenized_text, 'label': self.label[i], "text": text, "id": id}
    else:
        return {"tokenized_text": tokenized_text, "text": text, "id": id}

  def __len__(self):
      return len(self.ids)



In [98]:
train_dataset = tweet_dataset(df_train, tokenizer)
val_dataset = tweet_dataset(df_val, tokenizer, train=False)
test_dataset = tweet_dataset(df_test, tokenizer, train=False)

def inference(model, data):
  data_loader_test = DataLoader(data, shuffle=False, batch_size=32, num_workers=0)
  temp_result_list = []
  # inference the model
  for i, x_batch in tqdm(enumerate(data_loader_test)):
    x_batch  = x_batch['tokenized_text']['input_ids'][:,0,:]
    x_batch = x_batch.to('cuda')
    with torch.no_grad():
      pred = model(x_batch)
      batch_pred = torch.nn.functional.softmax(pred[0]).detach().cpu().numpy()
    temp_result_list.append(batch_pred)

  result_array = np.concatenate(temp_result_list)
  return result_array

In [100]:
result_array = inference(model, test_dataset)
f1_score(df_test.label, result_array[:,0]>0.5)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

  





0.6699255121042831