In [11]:
import os
import io
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl

from transformers import BertJapaneseTokenizer
from transformers import BertModel

import pandas as pd
import tarfile
from glob import glob
import linecache
from tqdm import tqdm
import urllib.request

In [20]:
download_path = "ldcc-20140209.tar.gz"
extract_path = "livedoor/"

def track_progress(members):
   for member in members:
      # this will be the current file being extracted
      yield member

if not os.path.isfile(download_path):
    urllib.request.urlretrieve("https://www.rondhuit.com/download/ldcc-20140209.tar.gz",download_path)


## https://stackoverflow.com/questions/3667865/python-tarfile-progress-output
def get_file_progress_file_object_class(on_progress):
    class FileProgressFileObject(tarfile.ExFileObject):
        def read(self, size, *args):
            on_progress(self.name, self.position, self.size)
            return tarfile.ExFileObject.read(self, size, *args)
    return FileProgressFileObject

class TestFileProgressFileObject(tarfile.ExFileObject):
    def read(self, size, *args):
        on_progress(self.name, self.position, self.size)
        return tarfile.ExFileObject.read(self, size, *args)

class ProgressFileObject(io.FileIO):
    def __init__(self, path, *args, **kwargs):
        self._total_size = os.path.getsize(path)
        io.FileIO.__init__(self, path, *args, **kwargs)

    def read(self, size):
        print("Overall process: %d of %d" %(self.tell(), self._total_size))
        return io.FileIO.read(self, size)

def on_progress(filename, position, total_size):
    print("%s: %d of %s" %(filename, position, total_size))

tarfile.TarFile.fileobject = get_file_progress_file_object_class(on_progress)
tar = tarfile.open(fileobj=ProgressFileObject(download_path))
tar.extractall(extract_path)
tar.close()


Overall process: 0 of 8855190
Overall process: 2 of 8855190
Overall process: 10 of 8855190
Overall process: 8202 of 8855190
Overall process: 9822 of 8855190
Overall process: 12075 of 8855190
Overall process: 14755 of 8855190
Overall process: 17736 of 8855190
Overall process: 20047 of 8855190
Overall process: 22602 of 8855190
Overall process: 24866 of 8855190
Overall process: 27150 of 8855190
Overall process: 29386 of 8855190
Overall process: 31578 of 8855190
Overall process: 33814 of 8855190
Overall process: 36322 of 8855190
Overall process: 38863 of 8855190
Overall process: 41277 of 8855190
Overall process: 43376 of 8855190
Overall process: 45992 of 8855190
Overall process: 48098 of 8855190
Overall process: 50424 of 8855190
Overall process: 52671 of 8855190
Overall process: 54976 of 8855190
Overall process: 57346 of 8855190
Overall process: 59584 of 8855190
Overall process: 62011 of 8855190
Overall process: 63941 of 8855190
Overall process: 66429 of 8855190
Overall process: 68647 of 8

In [3]:
## https://qiita.com/m__k/items/841950a57a0d7ff05506

categories = [name for name in os.listdir(extract_path + "text") if os.path.isdir(extract_path + "text/" + name)]

datasets = pd.DataFrame(columns=["title", "category"])
for cat in categories:
    path = extract_path + "text/" + cat + "/*.txt"
    files = glob(path)
    for text_name in files:
        title = linecache.getline(text_name, 3)
        s = pd.Series([title, cat], index=datasets.columns)
        datasets = datasets.append(s, ignore_index=True)

datasets = datasets.sample(frac=1).reset_index(drop=True)
datasets.head()

Unnamed: 0,title,category
0,NTTドコモ、MEDIAS PP N-01Dにて特定のブラウザを利用したときに不具合でソフト...,smax
1,大切な投稿を見逃さない　Facebookに親友の投稿だけを表示する【知っ得！虎の巻】\n,it-life-hack
2,大画面スマホ時代はデジタル・ドクショが一気に快適に！大幅リニューアルで使える便利すぎる機能満...,smax
3,おはこんこん、ふぉっくす紺子です！アキバで流れる動画が決まりました【紺子にゅうす】\n,it-life-hack
4,雑誌をPDF化してiPadで読む裏技！スキャナー活用のノウハウを伝授【新スタイル活用術】\n,it-life-hack


In [4]:
## https://qiita.com/m__k/items/e312ddcf9a3d0ea64d72

categories = list(set(datasets['category']))
id2cat = dict(zip(list(range(len(categories))), categories))
cat2id = dict(zip(categories, list(range(len(categories)))))

datasets['category_id'] = datasets['category'].map(cat2id)
datasets = datasets[['title', 'category_id']]

datasets.head()

Unnamed: 0,title,category_id
0,NTTドコモ、MEDIAS PP N-01Dにて特定のブラウザを利用したときに不具合でソフト...,0
1,大切な投稿を見逃さない　Facebookに親友の投稿だけを表示する【知っ得！虎の巻】\n,2
2,大画面スマホ時代はデジタル・ドクショが一気に快適に！大幅リニューアルで使える便利すぎる機能満...,0
3,おはこんこん、ふぉっくす紺子です！アキバで流れる動画が決まりました【紺子にゅうす】\n,2
4,雑誌をPDF化してiPadで読む裏技！スキャナー活用のノウハウを伝授【新スタイル活用術】\n,2


In [5]:
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')


In [6]:
tokenizer.encode("私は元気です。",padding=True,)

[2, 1325, 9, 12453, 2992, 8, 3]

In [7]:
class LivedoorDatasets(torch.utils.data.Dataset):
    def __init__(self, transform = None):
        self.transform = transform

        self.data = list(datasets["title"])
        self.label = list(datasets["category_id"])
        
        if not len(self.data) == len(self.label):
            raise ValueError("Invalid dataset")
        self.datanum = len(self.data)

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        out_data = self.data[idx]
        out_label = self.label[idx]

        if self.transform:
            out_data = self.transform(["input_ids"][0])

        return out_data, out_label   

In [8]:
livedoor_datasets = LivedoorDatasets()

In [9]:
print(len(list(datasets["category_id"])))
print(len(list(datasets["title"])))

7376
7376


In [10]:
livedoor_datasets[0]

('NTTドコモ、MEDIAS PP N-01Dにて特定のブラウザを利用したときに不具合でソフトウェア更新を提供開始\n', 0)

In [11]:
class TokenizerCollate:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def collate_fn(self, batch):
        input = [item[0] for item in batch]
        input = self.tokenizer(
            input,
            padding=True,
            max_length=512,
            truncation=True,
            return_tensors="pt")
        targets = torch.tensor([item[1] for item in batch])
        return input, targets
    
    def __call__(self, batch):
        return self.collate_fn(batch)

train_loader = DataLoader(livedoor_datasets, batch_size=16, collate_fn=TokenizerCollate(tokenizer=tokenizer))

In [15]:
class Model(pl.LightningModule):
    def __init__(self):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
        self.output = nn.Linear(768, 9)
        
        self.train_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

    def forward(self, x):
        y = self.bert(**x).last_hidden_state
        ## cls token相当部分のhidden_stateのみ抜粋
        y = y[:,0,:]
        y = y.view(-1, 768)
        y = self.output(y)
        return y

    def training_step(self, batch, batch_nb):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        self.log("loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_nb): 
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        preds = torch.argmax(y, dim=1)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_acc(y,t), prog_bar=True)
        return loss

    def test_step(self, batch, batch_nb):
        return self.validation_step(batch, batch_nb)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

In [16]:
model = Model()

In [17]:
for param in model.bert.parameters():
    param.requires_grad = False

In [18]:
trainer = pl.Trainer(gpus=1,max_epochs=1,)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


In [19]:
trainer.fit(model, train_loader) 


  | Name      | Type      | Params
----------------------------------------
0 | bert      | BertModel | 110 M 
1 | output    | Linear    | 6.9 K 
2 | train_acc | Accuracy  | 0     
3 | val_acc   | Accuracy  | 0     
4 | test_acc  | Accuracy  | 0     
----------------------------------------
6.9 K     Trainable params
110 M     Non-trainable params
110 M     Total params
442.497   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

1

In [None]:
model.summarize

In [None]:
for i in train_loader:
    print(i[0].shape)
    break