In [1]:
!pip install transformers==4.5.0 fugashi==1.1.0 ipadic==1.0.0 pytorch-lightning==1.2.7



In [6]:
import random
import glob
import json
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from transformers import BertJapaneseTokenizer, BertModel

# Japanese pretrained model (identifier handed to the `from_pretrained` loaders)
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'

In [3]:
!mkdir chap7
%cd ./chap7

mkdir: cannot create directory ‘chap7’: File exists
/content/chap7


In [7]:
class BertForSequenceClassificationMultiLabel(torch.nn.Module):
    """BERT encoder followed by a linear head for multi-label classification.

    The final-layer hidden states are mean-pooled over the non-padding tokens
    and projected to ``num_labels`` independent scores (logits). A label is
    considered "on" when its logit is positive.

    Args:
        model_name: Name or path passed to ``BertModel.from_pretrained``.
        num_labels: Number of independent binary labels to score.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        # Load the pretrained BERT encoder.
        self.bert = BertModel.from_pretrained(model_name)
        # Linear head: pooled hidden state -> one score per label.
        self.linear = torch.nn.Linear(
            self.bert.config.hidden_size, num_labels
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None
    ):
        """Score a batch and optionally compute the multi-label loss.

        Args:
            input_ids: Token ids, forwarded to the encoder.
            attention_mask: 1 for real tokens, 0 for padding. When omitted,
                every position is treated as a real token.
            token_type_ids: Segment ids, forwarded to the encoder.
            labels: Optional 0/1 targets with the same shape as the logits.

        Returns:
            An object exposing ``logits`` and, only when ``labels`` is given,
            ``loss`` (``BCEWithLogitsLoss`` averaged over all entries).
        """
        # Local import keeps this block self-contained.
        from types import SimpleNamespace

        # Run the encoder and take the final-layer hidden states.
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        last_hidden_state = bert_output.last_hidden_state

        # The signature allows attention_mask=None; fall back to "all tokens
        # are real" instead of failing on None.unsqueeze(...).
        if attention_mask is None:
            attention_mask = last_hidden_state.new_ones(
                last_hidden_state.shape[:2]
            )

        # Mean of the hidden states over non-[PAD] tokens.
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        # Clamp so a row with no real tokens yields zeros rather than NaN.
        token_counts = mask.sum(1).clamp(min=1)
        averaged_hidden_state = (last_hidden_state * mask).sum(1) / token_counts

        # Linear projection to per-label scores.
        scores = self.linear(averaged_hidden_state)

        output = {'logits': scores}

        # Compute the loss only when targets were supplied.
        if labels is not None:
            loss = torch.nn.BCEWithLogitsLoss()(scores, labels.float())
            output['loss'] = loss

        # Expose the results as attributes. A plain namespace instance is used
        # instead of building a new class per call: class objects are only
        # freed by the cyclic garbage collector, which would keep the tensors
        # (and their autograd graph) alive longer than necessary.
        return SimpleNamespace(**output)

In [8]:
# Tokenizer that matches the pretrained checkpoint.
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)

# Two-label classifier, moved to the GPU right after construction
# (Module.cuda() returns the module itself).
bert_scml = BertForSequenceClassificationMultiLabel(MODEL_NAME, num_labels=2).cuda()

Downloading:   0%|          | 0.00/258k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/110 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/479 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/445M [00:00<?, ?B/s]

In [10]:
# Two example sentences and their 0/1 targets for the two labels.
text_list = [
    '今日は仕事はうまくいったが、体調はあまり良くない。',
    '昨日は楽しかった。'
]
labels_list = [[1, 1], [0, 1]]

# Encode the sentences (padded to the longest one) and move everything to the GPU.
encoding = tokenizer(text_list, padding='longest', return_tensors='pt')
encoding = {name: tensor.cuda() for name, tensor in encoding.items()}
labels = torch.tensor(labels_list).cuda()

# Feed the batch to BERT without tracking gradients and read the label scores.
with torch.no_grad():
    output = bert_scml(**encoding)
scores = output.logits

# A label is selected when its score is positive.
labels_predicted = (scores > 0).int()

# Accuracy: a sentence counts as correct only when every label matches.
exact_match = (labels_predicted == labels).all(-1)
num_correct = exact_match.sum().item()
accuracy = num_correct / labels.size(0)

In [11]:
# Encode the sentences, then attach the targets so the model also returns a loss.
batch = tokenizer(text_list, padding='longest', return_tensors='pt')
batch['labels'] = torch.tensor(labels_list)

# Move every tensor of the batch (inputs and labels) to the GPU.
encoding = {name: tensor.cuda() for name, tensor in batch.items()}

output = bert_scml(**encoding)
loss = output.loss  # loss value

### データセット：chABSA-dataset

In [12]:
# データのダウンロード
!wget https://s3-ap-northeast-1.amazonaws.com/dev.tech-sketch.jp/chakki/public/chABSA-dataset.zip

# データの解凍
!unzip chABSA-dataset.zip

--2022-04-09 02:05:15--  https://s3-ap-northeast-1.amazonaws.com/dev.tech-sketch.jp/chakki/public/chABSA-dataset.zip
Resolving s3-ap-northeast-1.amazonaws.com (s3-ap-northeast-1.amazonaws.com)... 52.219.172.48
Connecting to s3-ap-northeast-1.amazonaws.com (s3-ap-northeast-1.amazonaws.com)|52.219.172.48|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 722777 (706K) [application/zip]
Saving to: ‘chABSA-dataset.zip’


2022-04-09 02:05:16 (1.03 MB/s) - ‘chABSA-dataset.zip’ saved [722777/722777]

Archive:  chABSA-dataset.zip
   creating: chABSA-dataset/
  inflating: chABSA-dataset/.DS_Store  
   creating: __MACOSX/
   creating: __MACOSX/chABSA-dataset/
  inflating: __MACOSX/chABSA-dataset/._.DS_Store  
 extracting: chABSA-dataset/.gitkeep  
  inflating: chABSA-dataset/e00008_ann.json  
  inflating: chABSA-dataset/e00017_ann.json  
  inflating: chABSA-dataset/e00024_ann.json  
  inflating: chABSA-dataset/e00026_ann.json  
  inflating: chABSA-dataset/e00030_ann.json  

In [13]:
# Load one annotation file and print its first sentence record.
# The context manager closes the file handle deterministically, and the
# explicit encoding keeps the Japanese text readable regardless of the
# platform's default locale.
with open('chABSA-dataset/e00030_ann.json', encoding='utf-8') as f:
    data = json.load(f)
print(data['sentences'][0])

{'sentence_id': 0, 'sentence': '当期におけるわが国経済は、景気は緩やかな回復基調が続き、設備投資の持ち直し等を背景に企業収益は改善しているものの、海外では、資源国等を中心に不透明な状況が続き、為替が急激に変動するなど、依然として先行きが見通せない状況で推移した', 'opinions': [{'target': 'わが国経済', 'category': 'NULL#general', 'polarity': 'neutral', 'from': 6, 'to': 11}, {'target': '景気', 'category': 'NULL#general', 'polarity': 'positive', 'from': 13, 'to': 15}, {'target': '設備投資', 'category': 'NULL#general', 'polarity': 'positive', 'from': 28, 'to': 32}, {'target': '企業収益', 'category': 'NULL#general', 'polarity': 'positive', 'from': 42, 'to': 46}, {'target': '資源国等', 'category': 'NULL#general', 'polarity': 'neutral', 'from': 62, 'to': 66}, {'target': '為替', 'category': 'NULL#general', 'polarity': 'negative', 'from': 80, 'to': 82}]}


In [14]:
# Index of each polarity inside the multi-hot label vector.
category_id = {'negative':0, 'neutral':1, 'positive':2}


def load_chabsa_samples(pattern):
    """Build multi-label samples from the annotation files matching ``pattern``.

    Each file is expected to hold a ``'sentences'`` list whose items carry a
    ``'sentence'`` string and an ``'opinions'`` list; every opinion has a
    ``'polarity'`` that must be a key of ``category_id``.

    Args:
        pattern: Glob pattern selecting the JSON files to read.

    Returns:
        A list of ``{'text': str, 'labels': [neg, neu, pos]}`` dicts where a
        label is 1 when at least one opinion of that polarity is present.

    Raises:
        KeyError: If an opinion has a polarity missing from ``category_id``.
    """
    samples = []
    # Sort the paths: glob order is filesystem-dependent, and a stable order
    # is needed for the later shuffle/split to be reproducible.
    for path in sorted(glob.glob(pattern)):
        # Close each file promptly and decode as UTF-8 regardless of locale.
        with open(path, encoding='utf-8') as f:
            document = json.load(f)
        # Extract the text of each sentence and build its label vector.
        for sentence in document['sentences']:
            labels = [0, 0, 0]
            for opinion in sentence['opinions']:
                labels[category_id[opinion['polarity']]] = 1
            samples.append({'text': sentence['sentence'], 'labels': labels})
    return samples


dataset = load_chabsa_samples('chABSA-dataset/*.json')

In [15]:
# Print the first entry of `dataset` (a dict with 'text' and 'labels').
print(dataset[0])

{'text': '当連結会計年度におけるわが国の経済は、政府や日本銀行による各種施策効果もあり、引き続き穏やかな回復基調で推移してまいりました', 'labels': [0, 0, 1]}


### ファインチューニングと性能評価

In [17]:
# Load the tokenizer
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)

# Encode every sample: pad/truncate to a fixed length, attach the labels and
# convert each field to a tensor so the default DataLoader collation can
# stack them.
max_length = 128
dataset_for_loader = []
for sample in dataset:
    text = sample['text']
    labels = sample['labels']
    encoding = tokenizer(
        text,
        max_length=max_length,
        padding='max_length',
        truncation=True
    )
    encoding['labels'] = labels
    encoding = {k:torch.tensor(v) for k, v in encoding.items()}
    dataset_for_loader.append(encoding)

# Split the dataset 60% / 20% / 20%.
# Shuffle with a dedicated, seeded generator: for the same input order the
# train/val/test membership is identical on every run, and the global
# `random` state is left untouched.
SPLIT_SEED = 0
random.Random(SPLIT_SEED).shuffle(dataset_for_loader)
n = len(dataset_for_loader)
n_train = int(0.6*n)
n_val = int(0.2*n)
dataset_train = dataset_for_loader[:n_train] # training data
dataset_val = dataset_for_loader[n_train:n_train+n_val] # validation data
dataset_test = dataset_for_loader[n_train+n_val:] # test data

# Build the data loaders; only the training loader reshuffles each epoch.
dataloader_train = DataLoader(
    dataset_train,
    batch_size=32,
    shuffle=True
)
dataloader_val = DataLoader(dataset_val, batch_size=256)
dataloader_test = DataLoader(dataset_test, batch_size=256)

In [18]:
from pytorch_lightning import callbacks
class BertForSequenceClassificationMultiLabel_pl(pl.LightningModule):
    """Lightning wrapper that fine-tunes the multi-label BERT classifier.

    Args:
        model_name: Pretrained model name handed to the wrapped classifier.
        num_labels: Number of labels scored by the wrapped classifier.
        lr: Learning rate used by the Adam optimizer.
    """

    def __init__(self, model_name, num_labels, lr):
        super().__init__()
        # Store the constructor arguments so they are reachable via self.hparams.
        self.save_hyperparameters()
        self.bert_scml = BertForSequenceClassificationMultiLabel(
            model_name, num_labels=num_labels
        )

    def training_step(self, batch, batch_idx):
        """Return the loss of one training batch and log it as 'train_loss'."""
        train_loss = self.bert_scml(**batch).loss
        self.log('train_loss', train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Log the loss of one validation batch as 'val_loss'."""
        self.log('val_loss', self.bert_scml(**batch).loss)

    def test_step(self, batch, batch_idx):
        """Log the exact-match accuracy of one test batch as 'accuracy'."""
        # Remove the targets so the model is called without `labels`.
        gold = batch.pop('labels')
        logits = self.bert_scml(**batch).logits
        predicted = (logits > 0).int()
        # A sample is correct only when all of its labels are predicted right.
        exact_matches = (predicted == gold).all(-1).sum().item()
        self.log('accuracy', exact_matches / logits.size(0))

    def configure_optimizers(self):
        """Create the Adam optimizer over all parameters with the stored lr."""
        learning_rate = self.hparams.lr
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

# Checkpointing: monitor 'val_loss' and keep only the single best (lowest)
# checkpoint, storing model weights only, under the 'model/' directory.
checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
    dirpath='model/'
)

# Trainer: one GPU, five epochs, with the checkpoint callback attached.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=5,
    callbacks=[checkpoint]
)

# Three labels here, matching the three polarity classes of the dataset.
model = BertForSequenceClassificationMultiLabel_pl(
    MODEL_NAME,
    num_labels=3,
    lr=1e-5
)
# Fine-tune on the training loader, validating on the validation loader.
trainer.fit(model, dataloader_train, dataloader_val)
# NOTE(review): no model is passed to trainer.test — presumably Lightning
# evaluates the best checkpoint here; confirm for the pinned Lightning version.
test = trainer.test(test_dataloaders=dataloader_test)
# Report the 'accuracy' value logged by test_step.
print(f'Accuracy: {test[0]["accuracy"]:.2f}')

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                                    | Params
----------------------------------------------------------------------
0 | bert_scml | BertForSequenceClassificationMultiLabel | 110 M 
----------------------------------------------------------------------
110 M     Trainable params
0         Non-trainable params
110 M     Total params
442.479   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'accuracy': 0.9061224460601807}
--------------------------------------------------------------------------------
Accuracy: 0.91


In [20]:
# Sentences to classify
text_list = [
    '今期は売り上げが順調に推移したが、株価は低迷の一途を辿っている。',
    '昨年から黒字が減少した。',
    '今日の飲み会は楽しかった。',
]

# Restore the best checkpoint and put the wrapped classifier on the GPU.
best_model_path = checkpoint.best_model_path
model = BertForSequenceClassificationMultiLabel_pl.load_from_checkpoint(best_model_path)
bert_scml = model.bert_scml.cuda()

# Encode the sentences (padded to the longest one) and move them to the GPU.
encoding = tokenizer(text_list, padding='longest', return_tensors='pt')
encoding = {name: tensor.cuda() for name, tensor in encoding.items()}

# Feed the batch to BERT without tracking gradients and read the label scores.
with torch.no_grad():
    output = bert_scml(**encoding)
scores = output.logits

# Positive score -> label selected; convert to nested Python lists of ints.
labels_predicted = (scores > 0).int().tolist()

# Show each input next to its predicted label vector.
for text, label in zip(text_list, labels_predicted):
    print('--', f'入力：{text}', f'出力：{label}', sep='\n')

--
入力：今期は売り上げが順調に推移したが、株価は低迷の一途を辿っている。
出力：[1, 0, 0]
--
入力：昨年から黒字が減少した。
出力：[1, 0, 0]
--
入力：今日の飲み会は楽しかった。
出力：[0, 0, 0]
