In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
!pip install transformers==4.5.0 fugashi==1.1.0 ipadic==1.0.0 pytorch-lightning==1.2.7

Collecting pytorch-lightning==1.2.7
  Using cached pytorch_lightning-1.2.7-py3-none-any.whl (830 kB)
Installing collected packages: pytorch-lightning
  Attempting uninstall: pytorch-lightning
    Found existing installation: pytorch-lightning 1.2.10
    Uninstalling pytorch-lightning-1.2.10:
      Successfully uninstalled pytorch-lightning-1.2.10
Successfully installed pytorch-lightning-1.2.7


In [3]:
import random 
import glob
import json
from tqdm import tqdm
import unicodedata

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from transformers import BertTokenizer, BertModel

from lists_file import dumpJoblib, loadJoblib
from multi_classification import BertForSequenceClassificationMultiLabel

MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'

In [4]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [5]:
X_train = list(loadJoblib('unique_textdata.joblib').values())
y_train = loadJoblib('label_file.joblib')

In [6]:
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
bert_scml = BertForSequenceClassificationMultiLabel(
    MODEL_NAME, num_labels=6
) 
bert_scml = bert_scml.to(device=device)

In [7]:
list = []
label = []
for i , j in zip(X_train, y_train):
    list.append(i)
    label.append(j)

In [8]:
labels = []
for emotion in label:
    if emotion == 63:
        labels.append([1, 0, 0, 0, 0, 0])
    elif emotion == 64:
        labels.append([0, 1, 0, 0, 0, 0])
    elif emotion == 65:
        labels.append([0, 0, 1, 0, 0, 0])
    elif emotion == 66:
        labels.append([0, 0, 0, 1, 0, 0])
    elif emotion == 67:
        labels.append([0, 0, 0, 0, 1, 0])
    else:
        labels.append([0, 0, 0, 0, 0, 1])

In [9]:
max_length = 152
dataset_for_loader = []

for i in range(len(list)):
    text = list[i]
    encoding = tokenizer(
      text, 
      max_length=max_length,
      padding='max_length',
      truncation=True
    )
    encoding['labels'] = labels[i]
    encoding = {k: torch.tensor(v).to(device=device) for k, v in encoding.items() }
    dataset_for_loader.append(encoding)
    

random.shuffle(dataset_for_loader)
n = len(dataset_for_loader)
n_train = int(0.6*n)
n_val = int(0.2*n)
dataset_train = dataset_for_loader[:n_train]
dataset_val = dataset_for_loader[n_train:n_train+n_val]
dataset_test = dataset_for_loader[n_train+n_val:]

In [10]:
dataloader_train = DataLoader(
    dataset_train, batch_size=16, shuffle=True
)
dataloader_val = DataLoader(
    dataset_val, batch_size=16
)
dataloader_test = DataLoader(
    dataset_test, batch_size=16
)

dataloader_val

<torch.utils.data.dataloader.DataLoader at 0x7eff1035e190>

In [11]:
from pytorch_lightning_multi import BertForSequenceClassificationMultiLabel_pl

checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
    dirpath='model/'
)

#学習方法指定
trainer = pl.Trainer(
    gpus=1,
    max_epochs=100,
    callbacks=[checkpoint]
)

model = BertForSequenceClassificationMultiLabel_pl(
    MODEL_NAME, num_labels=6, lr=1e-5
)

trainer.fit(model, dataloader_train, dataloader_val)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                                    | Params
----------------------------------------------------------------------
0 | bert_scml | BertForSequenceClassificationMultiLabel | 110 M 
----------------------------------------------------------------------
110 M     Trainable params
0         Non-trainable params
110 M     Total params
442.488   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1

In [12]:
best_model_path = checkpoint.best_model_path
print(f'validation data loss:{checkpoint.best_model_score}')

validation data loss:0.15882304310798645


In [13]:
%load_ext tensorboard
%tensorboard --logdir ./

In [14]:
test = trainer.test(test_dataloaders=dataloader_test, ckpt_path=best_model_path)
print(f'Accuracy: {test[0]["accuracy"]:.4f}')

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'accuracy': 0.7836218476295471}
--------------------------------------------------------------------------------
Accuracy: 0.78


In [15]:
model = BertForSequenceClassificationMultiLabel_pl.load_from_checkpoint(best_model_path)

# #transformers対応
# model.save_pretrained('./model_transformers')

In [16]:
# save_pretrained

# model = BertForSequenceClassificationMultiLabel_pl.load_from_checkpoint('model/epoch=1-step=2815.ckpt')

# model.bert_scml.save_pretrained('./model_transformers')

# model.bert_scml.__dict__

In [17]:
best_model_path

'/home/takakiyuto/Desktop/RESEARCH-COVID19-datasets/model/epoch=1-step=2815-v1.ckpt'

In [18]:
len(X_train)

37543