In [None]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl

from transformers import BertJapaneseTokenizer
from transformers import BertModel

import pandas as pd
import tarfile
from glob import glob
import linecache
from tqdm import tqdm
import urllib.request

In [None]:
# Local filename of the downloaded archive and the directory it is unpacked into.
download_path = "livedoor_news_corpus.tar.gz"
extract_path = "livedoor/"

# Download only when the archive is not already present, so re-running the
# notebook does not hit the network again.
if not os.path.isfile(download_path):
    urllib.request.urlretrieve("https://www.rondhuit.com/download/ldcc-20140209.tar.gz", download_path)

with tarfile.open(download_path, "r:gz") as t:
    # The archive comes from the network: use the "data" extraction filter
    # (blocks absolute paths and path traversal) when this Python version
    # provides it, and fall back to plain extraction on older interpreters.
    if hasattr(tarfile, "data_filter"):
        t.extractall(extract_path, filter="data")
    else:
        t.extractall(extract_path)


In [None]:
## https://qiita.com/m__k/items/841950a57a0d7ff05506

# Every sub-directory of <extract_path>/text is treated as one category.
# Sorted so the row order does not depend on filesystem listing order.
text_root = extract_path + "text"
categories = sorted(
    name for name in os.listdir(text_root) if os.path.isdir(text_root + "/" + name)
)

# Collect all rows first and build the DataFrame once: growing a frame row by
# row is quadratic, and DataFrame.append no longer exists in pandas >= 2.0.
records = []
for cat in categories:
    path = text_root + "/" + cat + "/*.txt"
    for text_name in sorted(glob(path)):
        # NOTE(review): the category folders of this corpus appear to ship a
        # LICENSE.txt next to the articles; it is skipped so it does not end up
        # as a training sample -- confirm against the extracted archive.
        if os.path.basename(text_name) == "LICENSE.txt":
            continue
        # Line 3 of each file is used as the article title (kept verbatim,
        # including its trailing newline).
        title = linecache.getline(text_name, 3)
        records.append({"title": title, "category": cat})

datasets = pd.DataFrame(records, columns=["title", "category"])

# Shuffle the rows; a fixed seed keeps the order reproducible across runs.
datasets = datasets.sample(frac=1, random_state=0).reset_index(drop=True)
datasets.head()

In [None]:
## https://qiita.com/m__k/items/e312ddcf9a3d0ea64d72

# Sort the distinct category names so the id <-> category mapping is the same
# on every run (iteration order of a set of strings is not stable between
# interpreter sessions).
categories = sorted(set(datasets['category']))
id2cat = dict(enumerate(categories))
cat2id = {name: idx for idx, name in id2cat.items()}

# Replace the textual category with its integer id and keep only the columns
# used for training.
datasets['category_id'] = datasets['category'].map(cat2id)
datasets = datasets[['title', 'category_id']]

datasets.head()

In [None]:
# Load the pretrained tokenizer for the Japanese whole-word-masking BERT checkpoint by name.
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')


In [None]:
# Quick check of the tokenizer on a single sentence; the encoded result is shown as the cell output.
tokenizer.encode("私は元気です。",padding=True,)

In [None]:
class LivedoorDatasets(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(title, category_id)`` pairs.

    The samples are read from the module-level ``datasets`` DataFrame
    (columns ``title`` and ``category_id``) when the object is constructed.

    Args:
        transform: optional callable applied to the title of a sample before
            it is returned. ``None`` (default) returns the raw title.

    Raises:
        ValueError: if the title and label columns differ in length.
    """

    def __init__(self, transform = None):
        self.transform = transform

        self.data = list(datasets["title"])
        self.label = list(datasets["category_id"])

        if not len(self.data) == len(self.label):
            raise ValueError("Invalid dataset")
        self.datanum = len(self.data)

    def __len__(self):
        """Return the number of samples."""
        return self.datanum

    def __getitem__(self, idx):
        """Return ``(title, label)`` for ``idx``, with ``transform`` applied to the title if set."""
        out_data = self.data[idx]
        out_label = self.label[idx]

        if self.transform:
            # Transform the sample itself (previously the literal string
            # "input_ids" was passed to the transform instead of the data).
            out_data = self.transform(out_data)

        return out_data, out_label

In [None]:
# Instantiate the dataset without a transform, so items come back as raw (title, label) pairs.
livedoor_datasets = LivedoorDatasets()

In [None]:
# Print the row count of the label column and then of the title column;
# the two numbers are expected to match.
for column_name in ("category_id", "title"):
    print(len(datasets[column_name]))

In [None]:
# Peek at the first sample to check the item format returned by the dataset.
livedoor_datasets[0]

In [None]:
class TokenizerCollate:
    """Callable collate function that tokenizes a batch of ``(text, label)`` samples.

    Args:
        tokenizer: callable invoked on the list of texts; its result must be
            indexable by ``"input_ids"``.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def collate_fn(self, batch):
        """Tokenize the texts of ``batch`` and return ``(input_ids, targets)``.

        Texts are padded to the longest item of the batch and truncated to at
        most 512 tokens; ``targets`` is a tensor built from the labels.
        """
        texts = []
        labels = []
        for sample in batch:
            texts.append(sample[0])
            labels.append(sample[1])

        encoded = self.tokenizer(
            texts,
            padding=True,
            max_length=512,
            truncation=True,
            return_tensors="pt",
        )
        return encoded["input_ids"], torch.tensor(labels)

    def __call__(self, batch):
        """Make the instance directly usable as a DataLoader ``collate_fn``."""
        return self.collate_fn(batch)

# Batches of 16 samples; the collate callable tokenizes and pads each batch when it is assembled.
train_loader = DataLoader(livedoor_datasets, batch_size=16, collate_fn=TokenizerCollate(tokenizer=tokenizer))

In [None]:
class Model(pl.LightningModule):
    """Pretrained Japanese BERT encoder with a linear classification head.

    The hidden state at the first sequence position (the [CLS] token) is fed
    to a single linear layer that produces the class logits.

    Args:
        num_classes: number of output classes (default 9).
        lr: learning rate handed to Adam (default 0.02).
    """

    def __init__(self, num_classes=9, lr=0.02):
        super(Model, self).__init__()
        self.lr = lr
        self.bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
        # Take the head's input width from the encoder config instead of hard-coding it.
        self.output = nn.Linear(self.bert.config.hidden_size, num_classes)

        # NOTE(review): `pl.metrics` only exists in older PyTorch Lightning
        # releases (newer ones moved metrics to the separate `torchmetrics`
        # package) -- confirm against the installed version.
        self.train_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for a batch of token ids ``x``."""
        # Batches are padded to a common length by the collate function; mask
        # the padding so the encoder does not attend to it. The mask is derived
        # from the ids because only `input_ids` reaches this method.
        pad_id = self.bert.config.pad_token_id
        attention_mask = None if pad_id is None else (x != pad_id).long()
        hidden = self.bert(x, attention_mask=attention_mask).last_hidden_state
        # Keep only the hidden state at the [CLS] position: (batch, hidden_size).
        cls_state = hidden[:, 0, :]
        return self.output(cls_state)

    def training_step(self, batch, batch_nb):
        """Compute and log the cross-entropy loss for one training batch."""
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        self.log("loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def _shared_eval_step(self, batch, stage, metric):
        """Compute loss and accuracy for one batch and log them as ``<stage>_loss`` / ``<stage>_acc``."""
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        # Feed predicted class indices, not raw logits, to the accuracy metric.
        preds = torch.argmax(y, dim=1)
        self.log(f'{stage}_loss', loss, prog_bar=True)
        self.log(f'{stage}_acc', metric(preds, t), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_nb):
        """Evaluate one validation batch; logs ``val_loss`` and ``val_acc``."""
        return self._shared_eval_step(batch, 'val', self.val_acc)

    def test_step(self, batch, batch_nb):
        """Evaluate one test batch; logs ``test_loss`` and ``test_acc`` with its own metric state."""
        return self._shared_eval_step(batch, 'test', self.test_acc)

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
# Build the classifier with its default constructor arguments.
model = Model()

In [None]:
# Freeze every parameter of the BERT encoder so no gradients are computed for it.
for encoder_param in model.bert.parameters():
    encoder_param.requires_grad_(False)

In [None]:
# Trainer limited to a single epoch, requesting one GPU.
trainer = pl.Trainer(gpus=1,max_epochs=1,)

In [None]:
# Train on train_loader only; no validation dataloader is passed here.
trainer.fit(model, train_loader) 

In [None]:
# Call the method: the bare attribute reference only displayed the bound-method
# repr instead of producing the layer/parameter summary.
model.summarize()

In [None]:
# Fetch only the first batch and print the shape of its input-id tensor;
# nothing is printed when the loader yields no batches.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    print(first_batch[0].shape)