In [124]:
import torch
from transformers import BertJapaneseTokenizer
from transformers import BertModel
import os
from torch import nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from tqdm import tqdm
import torchtext

import urllib.request
import pandas as pd
import tarfile
from glob import glob
import linecache

In [125]:
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese')

Downloading:   0%|          | 0.00/258k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/104 [00:00<?, ?B/s]

In [49]:
download_path = "livedoor_news_corpus.tar.gz"
extract_path = "livedoor/"

if not os.path.isfile("livedoor_news_corpus.tar.gz"):
    urllib.request.urlretrieve("https://www.rondhuit.com/download/ldcc-20140209.tar.gz",download_path)

with tarfile.open(download_path, "r:gz") as t:
    t.extractall(extract_path)


In [50]:
## https://qiita.com/m__k/items/841950a57a0d7ff05506

categories = [name for name in os.listdir(extract_path + "text") if os.path.isdir(extract_path + "text/" + name)]

datasets = pd.DataFrame(columns=["title", "category"])
for cat in categories:
    path = extract_path + "text/" + cat + "/*.txt"
    files = glob(path)
    for text_name in files:
        title = linecache.getline(text_name, 3)
        s = pd.Series([title, cat], index=datasets.columns)
        datasets = datasets.append(s, ignore_index=True)

datasets = datasets.sample(frac=1).reset_index(drop=True)
datasets.head()

Unnamed: 0,title,category
0,プレイ可能なGoogleロゴ第三弾！　Googleロゴがスポーツ関連画像に変化第十弾\n,it-life-hack
1,「めんどくせぇ」共演者も呆れる松岡修造、錦織圭の話題には「カチンときちゃう」\n,sports-watch
2,インテリアにこだわる人へオススメ！　ソニーの美しすぎるうつぶせリモコン\n,kaden-channel
3,【終了しました】高原リゾートホテルの“ぐっすり眠れる”特別イベントにご招待\n,peachy
4,撮るだけじゃない！ モニター利用もできるビデオムービー登場\n,peachy


In [51]:
## https://qiita.com/m__k/items/e312ddcf9a3d0ea64d72

categories = list(set(datasets['category']))
id2cat = dict(zip(list(range(len(categories))), categories))
cat2id = dict(zip(categories, list(range(len(categories)))))

datasets['category_id'] = datasets['category'].map(cat2id)
datasets = datasets[['title', 'category_id']]

datasets.head()

Unnamed: 0,title,category_id
0,プレイ可能なGoogleロゴ第三弾！　Googleロゴがスポーツ関連画像に変化第十弾\n,7
1,「めんどくせぇ」共演者も呆れる松岡修造、錦織圭の話題には「カチンときちゃう」\n,2
2,インテリアにこだわる人へオススメ！　ソニーの美しすぎるうつぶせリモコン\n,1
3,【終了しました】高原リゾートホテルの“ぐっすり眠れる”特別イベントにご招待\n,6
4,撮るだけじゃない！ モニター利用もできるビデオムービー登場\n,6


In [146]:
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')


[2, 1325, 9, 12453, 2992, 8, 3]

In [147]:
tokenizer.encode("私は元気です。",padding=True,)

[2, 1325, 9, 12453, 2992, 8, 3]

In [142]:
class LivedoorDatasets(torch.utils.data.Dataset):
    def __init__(self, transform = None):
        self.transform = transform

        self.data = list(datasets["title"])
        self.label = list(datasets["category_id"])
        
        if not len(self.data) == len(self.label):
            raise ValueError("Invalid dataset")
        self.datanum = len(self.data)

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        out_data = self.data[idx]
        out_label = self.label[idx]

        if self.transform:
            out_data = self.transform(out_data,padding=True,max_length=512,truncation=True)["input_ids"]

        return out_data, out_label   

In [143]:
livedoor_datasets = LivedoorDatasets(transform=tokenizer)

In [144]:
print(len(list(datasets["category_id"])))
print(len(list(datasets["title"])))

7376
7376


In [145]:
livedoor_datasets[0]

([2,
  1881,
  519,
  18,
  10994,
  5826,
  97,
  240,
  1406,
  679,
  10994,
  5826,
  14,
  1784,
  1634,
  4845,
  7,
  1709,
  97,
  714,
  1406,
  3],
 7)

In [113]:
train_loader = DataLoader(livedoor_datasets, batch_size=64)

In [114]:
class Model(pl.LightningModule):
    def __init__(self):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
        self.output = nn.Linear(768, 9)
        
        self.train_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

    def forward(self, x):
        y, _, __ = self.bert(x)
        y = y[:,0,:]
        y = y.view(-1, 768)
        y = self.output(y)
        y = F.softmax(y,dim=1)
        return y

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_nb): 
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        preds = torch.argmax(y, dim=1)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_acc(y,t), prog_bar=True)
        return loss

    def test_step(self, batch, batch_nb):
        return self.validation_step(batch, batch_nb)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

In [115]:
model = Model()

In [116]:
for param in model.bert.parameters():
    param.requires_grad = False

In [117]:
trainer = pl.Trainer(gpus=1,max_epochs=1,)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


In [118]:
trainer.fit(model, train_loader) 


  | Name      | Type      | Params
----------------------------------------
0 | bert      | BertModel | 110 M 
1 | output    | Linear    | 6.9 K 
2 | train_acc | Accuracy  | 0     
3 | val_acc   | Accuracy  | 0     
4 | test_acc  | Accuracy  | 0     
----------------------------------------
6.9 K     Trainable params
110 M     Non-trainable params
110 M     Total params
442.497   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

RuntimeError: each element in list of batch should be of equal size

In [119]:
for i in train_loader:
    print(i)
    break

RuntimeError: each element in list of batch should be of equal size