In [1]:
!pip install transformers==4.5.0 fugashi==1.1.0 ipadic==1.0.0 pytorch-lightning==1.2.7

Collecting transformers==4.5.0
[?25l  Downloading https://files.pythonhosted.org/packages/81/91/61d69d58a1af1bd81d9ca9d62c90a6de3ab80d77f27c5df65d9a2c1f5626/transformers-4.5.0-py3-none-any.whl (2.1MB)
[K     |████████████████████████████████| 2.2MB 28.0MB/s 
[?25hCollecting fugashi==1.1.0
[?25l  Downloading https://files.pythonhosted.org/packages/55/9c/009da34dd111e84f54eef833c84afb5c744a0306af8546014a958e1967a0/fugashi-1.1.0-cp37-cp37m-manylinux1_x86_64.whl (486kB)
[K     |████████████████████████████████| 491kB 44.9MB/s 
[?25hCollecting ipadic==1.0.0
[?25l  Downloading https://files.pythonhosted.org/packages/e7/4e/c459f94d62a0bef89f866857bc51b9105aff236b83928618315b41a26b7b/ipadic-1.0.0.tar.gz (13.4MB)
[K     |████████████████████████████████| 13.4MB 191kB/s 
[?25hCollecting pytorch-lightning==1.2.7
[?25l  Downloading https://files.pythonhosted.org/packages/e6/13/fb401b8f9d9c5e2aa08769d230bb401bf11dee0bc93e069d7337a4201ec8/pytorch_lightning-1.2.7-py3-none-any.whl (830kB)
[

In [2]:
import random
import glob
import json
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
from transformers import BertJapaneseTokenizer, BertModel
import pytorch_lightning as pl

MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'

In [3]:

class BertForSequenceClassificationMultiLabel(torch.nn.Module):

  def __init__(self, model_name, num_labels):
    super().__init__()
    self.bert = BertModel.from_pretrained(model_name)
    self.linear = torch.nn.Linear(
        self.bert.config.hidden_size, num_labels
    )

  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
    
    #最終層
    bert_output = self.bert(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids
    )
    last_hidden_state = bert_output.last_hidden_state

    #平均
    averaged_hidden_state = \
      (last_hidden_state*attention_mask.unsqueeze(-1)).sum(1) \
      / attention_mask.sum(1, keepdim=True)

    #線形変換
    scores = self.linear(averaged_hidden_state)

    #出力形式
    output = {'logits':scores}

    if labels is not None:
      loss = torch.nn.BCEWithLogitsLoss()(scores, labels.float())
      output['loss'] = loss

    output = type('bert_output', (object,), output) 
    
    return output

In [4]:
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)
bert_scml = BertForSequenceClassificationMultiLabel(
    MODEL_NAME, num_labels=2
) 
bert_scml = bert_scml.cuda()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=257706.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=110.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=479.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=445021143.0, style=ProgressStyle(descri…




In [5]:
text_list = [
    '今日の仕事はうまくいったが、体調があまり良くない。',
    '昨日は楽しかった。'
]

labels_list = [
    [1, 1],
    [0, 1]
]

encoding = tokenizer(
    text_list,
    padding='longest',
    return_tensors='pt'
)
encoding = { k: v.cuda() for k,v in encoding.items() }
labels = torch.tensor(labels_list).cuda()

with torch.no_grad():
  output = bert_scml(**encoding)
scores = output.logits

labels_predicted = ( scores > 0 ).int()

num_correct = ( labels_predicted == labels ).all(-1).sum().item()
accuracy = num_correct/labels.size(0)

In [6]:
encoding = tokenizer(
    text_list,
    padding='longest',
    return_tensors='pt'
)

encoding['labels'] = torch.tensor(labels_list)
encoding = { k: v.cuda() for k, v in encoding.items() }

output = bert_scml(**encoding)
loss = output.loss

# Dataset : chABSA-dataset

In [7]:
!wget https://s3-ap-northeast-1.amazonaws.com/dev.tech-sketch.jp/chakki/public/chABSA-dataset.zip
# データの解凍
!unzip chABSA-dataset.zip 

--2021-07-13 04:03:16--  https://s3-ap-northeast-1.amazonaws.com/dev.tech-sketch.jp/chakki/public/chABSA-dataset.zip
Resolving s3-ap-northeast-1.amazonaws.com (s3-ap-northeast-1.amazonaws.com)... 52.219.12.14
Connecting to s3-ap-northeast-1.amazonaws.com (s3-ap-northeast-1.amazonaws.com)|52.219.12.14|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 722777 (706K) [application/zip]
Saving to: ‘chABSA-dataset.zip’


2021-07-13 04:03:17 (918 KB/s) - ‘chABSA-dataset.zip’ saved [722777/722777]

Archive:  chABSA-dataset.zip
   creating: chABSA-dataset/
  inflating: chABSA-dataset/.DS_Store  
   creating: __MACOSX/
   creating: __MACOSX/chABSA-dataset/
  inflating: __MACOSX/chABSA-dataset/._.DS_Store  
 extracting: chABSA-dataset/.gitkeep  
  inflating: chABSA-dataset/e00008_ann.json  
  inflating: chABSA-dataset/e00017_ann.json  
  inflating: chABSA-dataset/e00024_ann.json  
  inflating: chABSA-dataset/e00026_ann.json  
  inflating: chABSA-dataset/e00030_ann.json  
  

In [8]:
data = json.load(open('chABSA-dataset/e00030_ann.json'))
data['sentences'][0]

{'opinions': [{'category': 'NULL#general',
   'from': 6,
   'polarity': 'neutral',
   'target': 'わが国経済',
   'to': 11},
  {'category': 'NULL#general',
   'from': 13,
   'polarity': 'positive',
   'target': '景気',
   'to': 15},
  {'category': 'NULL#general',
   'from': 28,
   'polarity': 'positive',
   'target': '設備投資',
   'to': 32},
  {'category': 'NULL#general',
   'from': 42,
   'polarity': 'positive',
   'target': '企業収益',
   'to': 46},
  {'category': 'NULL#general',
   'from': 62,
   'polarity': 'neutral',
   'target': '資源国等',
   'to': 66},
  {'category': 'NULL#general',
   'from': 80,
   'polarity': 'negative',
   'target': '為替',
   'to': 82}],
 'sentence': '当期におけるわが国経済は、景気は緩やかな回復基調が続き、設備投資の持ち直し等を背景に企業収益は改善しているものの、海外では、資源国等を中心に不透明な状況が続き、為替が急激に変動するなど、依然として先行きが見通せない状況で推移した',
 'sentence_id': 0}

In [13]:

category_id = {'negative':0, 'neutral':1 , 'positive':2}

dataset = []
for file in glob.glob('chABSA-dataset/*.json'):
  data = json.load(open(file))
  for sentence in data['sentences']:
    text = sentence['sentence']
    labels = [0, 0, 0]
    for opnion in sentence['opinions']:
      labels[category_id[opnion['polarity']]] == 1
    sample = {'text':text, 'labels':labels}
    dataset.append(sample)

In [15]:
dataset[0]

{'labels': [0, 0, 0], 'text': '当期のわが国経済は、輸出の回復などを背景に企業収益は増加し、緩やかな景気の回復を見せました'}

In [18]:
#トークン定義
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)

#データ形式を整える
max_length = 128
dataset_for_loader = []
for sample in dataset:
  text = sample['text']
  labels = sample['labels']
  encoding = tokenizer(
      text, 
      max_length=max_length,
      padding='max_length',
      truncation=True
  )
  encoding['labels'] = labels
  encoding = {k: torch.tensor(v) for k, v in encoding.items() }
  dataset_for_loader.append(encoding)

#データセットの分割
random.shuffle(dataset_for_loader)
n = len(dataset_for_loader)
n_train = int(0.6*n)
n_val = int(0.2*n)
dataset_train = dataset_for_loader[:n_train]
dataset_val = dataset_for_loader[n_train:n_train+n_val]
dataset_test = dataset_for_loader[n_train+n_val:]

#データローダ作成
dataloader_train = DataLoader(
    dataset_train, batch_size=32, shuffle=True
)
dataloader_val = DataLoader(
    dataset_val, batch_size=256
)
dataloader_test = DataLoader(
    dataset_test, batch_size=256
)

In [22]:
class BertForSequenceClassificationMultiLabel_pl(pl.LightningModule):

  def __init__(self, model_name, num_labels, lr):
    super().__init__()
    self.save_hyperparameters()
    self.bert_scml = BertForSequenceClassificationMultiLabel(
        model_name, num_labels=num_labels
    )

  def training_step(self, batch, batch_idx):
    output = self.bert_scml(**batch)
    loss = output.loss
    self.log('train_loss', loss)
    return loss

  def validation_step(self, batch, batch_idx):
    output = self.bert_scml(**batch)
    val_loss = output.loss
    self.log('val_loss', val_loss)

  def test_step(self, batch, batch_idx):
    labels = batch.pop('labels')
    output = self.bert_scml(**batch)
    scores = output.logits
    labels_predicted = (scores > 0).int()
    num_correct = ( labels_predicted == labels ).all(-1).sum().item()
    accuracy = num_correct / scores.size(0)
    self.log('accuracy', accuracy)

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
    dirpath='model/'
)

#学習方法指定
trainer = pl.Trainer(
    gpus=1,
    max_epochs=5,
    callbacks=[checkpoint]
)

model = BertForSequenceClassificationMultiLabel_pl(
    MODEL_NAME, num_labels=3, lr=1e-5
)

trainer.fit(model, dataloader_train, dataloader_val)
test = trainer.test(test_dataloaders=dataloader_test)
print(f'Accuracy: {test[0]["accuracy"]:.2f}')


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                                    | Params
----------------------------------------------------------------------
0 | bert_scml | BertForSequenceClassificationMultiLabel | 110 M 
----------------------------------------------------------------------
110 M     Trainable params
0         Non-trainable params
110 M     Total params
442.479   Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'accuracy': 1.0}
--------------------------------------------------------------------------------
Accuracy: 1.00


In [23]:
text_list = [
    "今期は売り上げが順調に推移したが、株価は低迷の一途を辿っている。",
    "昨年から黒字が減少した。",
    "今日の飲み会は楽しかった。"
]

best_model_path = checkpoint.best_model_path
model = BertForSequenceClassificationMultiLabel_pl.load_from_checkpoint(best_model_path)
bert_scml = model.bert_scml.cuda()

# データの符号化
encoding = tokenizer(
    text_list, 
    padding = 'longest',
    return_tensors='pt'
)
encoding = { k: v.cuda() for k, v in encoding.items() }

with torch.no_grad():
    output = bert_scml(**encoding)
scores = output.logits
labels_predicted = ( scores > 0 ).int().cpu().numpy().tolist()

for text, label in zip(text_list, labels_predicted):
    print('--')
    print(f'入力：{text}')
    print(f'出力：{label}')

--
入力：今期は売り上げが順調に推移したが、株価は低迷の一途を辿っている。
出力：[0, 0, 0]
--
入力：昨年から黒字が減少した。
出力：[0, 0, 0]
--
入力：今日の飲み会は楽しかった。
出力：[0, 0, 0]
