In [1]:
import os
import sys
import pandas as pd
import numpy as np
import time

from tqdm import tqdm

In [2]:
import torch
from torch.utils.data import DataLoader

from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import classification_report
from sklearn.pipeline import Pipeline
import xgboost as xgb

sys.path.append("..")

from utils import DATA_DIR  # noqa

In [3]:
# Import the BERT model and its tokenizer (the tokenizer handles preprocessing)
from transformers import BertTokenizer, BertModel

In [4]:
# Wall-clock timestamp for the whole run; the final cell prints the elapsed seconds.
start = time.time()

In [5]:
# Pick the fastest available backend: an NVIDIA GPU (CUDA) first, then
# Apple-silicon MPS, then CPU. On machines without CUDA this is identical to
# the previous MPS-or-CPU choice.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [6]:
# Load the pre-cleaned airline tweets, then split them 80/20 into train/test.
# Stratifying on the label keeps each sentiment's share equal in both halves,
# and the shuffled row labels are reset to a fresh 0..n-1 index.
csv_path = os.path.join(DATA_DIR, "cleaned_airline_tweets.csv")
tweet_df = pd.read_csv(csv_path)

train, test = [
    split_df.reset_index(drop=True)
    for split_df in train_test_split(
        tweet_df,
        test_size=0.2,
        random_state=0,
        stratify=tweet_df["sentiment"],
    )
]

In [7]:
# Pretrained uncased BERT-base: the tokenizer turns raw text into model inputs,
# and the encoder (moved onto `device`) is used below to embed the tweets.
checkpoint_name = "bert-base-uncased"
bert_tokenizer = BertTokenizer.from_pretrained(checkpoint_name)
bert_model = BertModel.from_pretrained(checkpoint_name).to(device)

In [8]:
def make_text_dataloader(texts, tokenizer, batch_size=512):
    """Wrap a sequence of raw strings in a DataLoader that tokenizes per batch.

    Args:
        texts: Iterable of strings, e.g. a DataFrame's "text" column.
        tokenizer: Callable Hugging Face tokenizer (e.g. ``bert_tokenizer``).
        batch_size: Number of texts per batch.

    Returns:
        A non-shuffled ``DataLoader`` whose batches are the tokenizer's
        PyTorch-tensor output, padded to the longest text in each batch and
        with truncation enabled.
    """
    # Materialize to a plain list so DataLoader's integer indexing is purely
    # positional; indexing a pandas Series with ints looks up *labels*, which
    # breaks if the index is not exactly 0..n-1.
    text_list = list(texts)

    def collate(batch):
        return tokenizer(text=batch, padding="longest", truncation=True,
                         return_tensors="pt")

    return DataLoader(text_list, batch_size=batch_size, shuffle=False,
                      collate_fn=collate)


# shuffle=False keeps batch order aligned with the rows of train/test, so the
# stacked embeddings line up with the label columns later on.
train_dataloader = make_text_dataloader(train["text"], bert_tokenizer)
test_dataloader = make_text_dataloader(test["text"], bert_tokenizer)

In [9]:
def embed_batches(dataloader, model, device):
    """Run ``model`` over every batch and stack the pooled outputs.

    Args:
        dataloader: Yields tokenizer outputs containing "input_ids",
            "attention_mask" and "token_type_ids" tensors.
        model: A Hugging Face ``BertModel`` already placed on ``device``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        A NumPy array with one ``pooler_output`` row per input text, in
        dataloader order. An empty dataloader yields a (0, hidden_size) array.
    """
    # Explicit inference mode so dropout is off regardless of prior state.
    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                token_type_ids=batch["token_type_ids"].to(device),
            )
            chunks.append(outputs.pooler_output.cpu().numpy())
    if not chunks:
        # np.vstack([]) raises ValueError; return an empty matrix instead.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    return np.vstack(chunks)


train_emb = embed_batches(train_dataloader, bert_model, device)
test_emb = embed_batches(test_dataloader, bert_model, device)

  return forward_call(*args, **kwargs)
100%|██████████| 7/7 [00:13<00:00,  1.91s/it]
100%|██████████| 2/2 [00:03<00:00,  1.54s/it]


In [11]:
# Integer class ids for the sentiment labels (XGBoost needs numeric labels 0..K-1).
SENTIMENT_TO_ID = {"negative": 0, "neutral": 1, "positive": 2}
# Class names ordered by id, for readable classification reports.
SENTIMENT_NAMES = sorted(SENTIMENT_TO_ID, key=SENTIMENT_TO_ID.get)


def encode_sentiment(labels):
    """Map sentiment strings to integer class ids via ``SENTIMENT_TO_ID``.

    Already-integer input is returned unchanged, so re-running this cell after
    the columns have been encoded is harmless.

    Raises:
        ValueError: if any label is not a key of ``SENTIMENT_TO_ID``. The old
            ``Series.replace`` silently left such labels as strings.
    """
    if pd.api.types.is_integer_dtype(labels):
        return labels
    encoded = labels.map(SENTIMENT_TO_ID)
    unknown = labels[encoded.isna()].unique()
    if len(unknown):
        raise ValueError(f"Unknown sentiment labels: {sorted(map(str, unknown))}")
    # Explicit cast instead of relying on replace()'s deprecated implicit downcast.
    return encoded.astype("int64")


xgb_params = {
    'n_estimators': [500, 1000],
    'max_depth': [3, 6],
    'learning_rate': [0.01, 0.1, 0.2],
}
xgb_grid = GridSearchCV(
    xgb.XGBClassifier(random_state=42, verbosity=0), xgb_params, cv=3, scoring="f1_macro", n_jobs=-1, verbose=1
)
# Encoded in place because later cells also read the integer sentiment column.
train["sentiment"] = encode_sentiment(train["sentiment"])
test["sentiment"] = encode_sentiment(test["sentiment"])

xgb_grid.fit(train_emb, train["sentiment"])
print(classification_report(
    test["sentiment"],
    xgb_grid.predict(test_emb),
    labels=list(range(len(SENTIMENT_NAMES))),
    target_names=SENTIMENT_NAMES,
))

Fitting 3 folds for each of 12 candidates, totalling 36 fits
              precision    recall  f1-score   support

           0       0.80      0.81      0.81       243
           1       0.78      0.76      0.77       260
           2       0.83      0.84      0.83       269

    accuracy                           0.80       772
   macro avg       0.80      0.80      0.80       772
weighted avg       0.80      0.80      0.80       772



In [12]:
# Display the winning hyper-parameter combination found by the grid search above.
xgb_grid.best_params_

{'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 500}

In [13]:
# Refit a standalone model with the grid search's winning hyper-parameters.
# They are read from xgb_grid rather than copied by hand, so they cannot drift
# out of sync if the grid changes. This should match xgb_grid.best_estimator_,
# since GridSearchCV refits the best configuration by default.
xgb_model = xgb.XGBClassifier(random_state=42, verbosity=0, **xgb_grid.best_params_)
xgb_model.fit(train_emb, train["sentiment"])
print(classification_report(
    test["sentiment"],
    xgb_model.predict(test_emb),
    labels=[0, 1, 2],
    target_names=["negative", "neutral", "positive"],
))

              precision    recall  f1-score   support

           0       0.80      0.81      0.81       243
           1       0.78      0.76      0.77       260
           2       0.83      0.84      0.83       269

    accuracy                           0.80       772
   macro avg       0.80      0.80      0.80       772
weighted avg       0.80      0.80      0.80       772



In [None]:
# Total wall-clock seconds since `start` was recorded near the top of the notebook.
elapsed_seconds = time.time() - start
print(elapsed_seconds)

23.644593000411987
