In [2]:
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import pandas as pd
import sys
sys.path.append('../')
from utils import tokens2ids
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import numpy as np
import json
import re
from tqdm import tqdm
tqdm.pandas()

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', 
                                                      num_labels=4,
                                                      output_attentions = False,
                                                      output_hidden_states = False)


  from pandas import Panel


In [16]:
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds.detach().numpy(), axis=1)
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat)


def trainer(model, optimizer, loader, test_loader, ds_size, device, max_iter=10):
    for epoch in range(10):
        n_correct = 0
        total_loss = 0
        for i, (inputs, labels, mask) in enumerate(tqdm(loader)):
            inputs = inputs.to(device)
            labels = labels.to(device)
            mask = mask.to(device)
            loss, logits = model(inputs, labels=labels, attention_mask=mask)
            model.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.data
            n_correct += flat_accuracy(logits, labels)
            print('n_correct', n_correct)

        print('epoch: ', epoch)
        print('[train]\tloss: %f accuracy: %f' % (loss, n_correct/ds_size))

        test_loss, test_acc = evaluate(model, test_loader)
        print('[test]\tloss: %f accuracy: %f' % (test_loss, test_acc))

    print('Finished Training')


class Mydatasets(torch.utils.data.Dataset):
    def __init__(self, data, labels):
        self.data = pad_sequence(data, batch_first=True)
        max_len = len(self.data[0])
        self.mask = torch.tensor([[1]*len(x)+[0]*(max_len-len(x)) for x in data])
        self.labels = torch.tensor(labels).long()
        self.datanum = len(data)

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        out_data = self.data[idx]
        out_label = self.labels[idx]
        mask = self.mask[idx]
        return out_data, out_label, mask


def normalize(doc):
    doc = re.sub(r"[',.]", '', doc)   # 記号を削除
    doc = re.sub(r" {2,}", ' ', doc)  # 2回以上続くスペースを削除
    doc = re.sub(r" *?$", '', doc)    # 行頭と行末のスペースを削除
    doc = re.sub(r"^ *?", '', doc)
    doc = doc.lower()                 # 小文字に統一
    return doc


def preprocessor(doc):
    doc = normalize(doc)
    tokens = tokenizer.tokenize(doc)
    return tokens

In [17]:
with open('../token2id_dic.json', 'r') as f:
    token2id_dic = json.loads(f.read())

dw = 300
dh = 50
L = 4
batch_size = 1024
columns = ('category', 'title')
vocab_size = len(token2id_dic)

train = pd.read_csv('../../data/NewsAggregatorDataset/train.txt',
                    names=columns, sep='\t')
test = pd.read_csv('../../data/NewsAggregatorDataset/test.txt',
                   names=columns, sep='\t')



train['tokens'] = train.title.apply(preprocessor)
test['tokens'] = test.title.apply(preprocessor)

X_train = train.tokens.apply(tokens2ids, token2id_dic=token2id_dic)
X_test = test.tokens.apply(tokens2ids, token2id_dic=token2id_dic)

label2int = {'b': 0, 't': 1, 'e': 2, 'm': 3}
Y_train = train.category.map(label2int)
Y_test = test.category.map(label2int)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 確率的勾配降下法

trainset = Mydatasets(X_train, Y_train)
testset = Mydatasets(X_test, Y_test)
loader = DataLoader(trainset, batch_size=batch_size)
test_loader = DataLoader(testset, batch_size=testset.__len__())

model = model.to(device)
ds_size = trainset.__len__()

print('train')
trainer(model, optimizer, loader, test_loader, ds_size, device, 10)


  0%|          | 0/11 [00:00<?, ?it/s]

train


  0%|          | 0/11 [00:01<?, ?it/s]


IndexError: index out of range in self

In [11]:
class Mydatasets(torch.utils.data.Dataset):
    def __init__(self, data, labels):
        self.data = pad_sequence(data, batch_first=True)
        max_len = len(self.data[0])

        self.mask = torch.tensor([([1]*len(x)+[0]*(max_len-len(x))) for x in data])
        self.labels = torch.tensor(labels).long()
        self.datanum = len(data)

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        out_data = self.data[idx]
        out_label = self.labels[idx]
        mask = self.mask[idx]
        return out_data, out_label, mask

                                  
trainset = Mydatasets(X_train, Y_train)
loader = DataLoader(trainset, batch_size=1)
for i, (inputs, labels, mask) in enumerate(tqdm(loader)):
            inputs = inputs.to(device)
            labels = labels.to(device)
            print(inputs.size())
            print(mask.size())
            print(labels.size())
            (outputs, logits) = model(inputs, labels=labels, attention_mask=mask)
            model.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.data
            n_correct += flat_accuracy(logits, labels)
            print('n_correct', n_correct)

  0%|          | 0/10672 [00:00<?, ?it/s]

torch.Size([1, 1017])
torch.Size([1, 1017])
torch.Size([1])





IndexError: index out of range in self

In [12]:
inputs

tensor([[  8, 420,   9,  ...,   0,   0,   0]])

In [48]:
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds.detach().numpy(), axis=1)
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat)

inputs = X_train[:19]
inputs = pad_sequence(inputs)
# labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
labels = torch.tensor(Y_train[:19])

outputs = model(inputs, labels=labels)

loss, logits = outputs[:2]
print(inputs)
print(labels)

flat_accuracy(logits, labels)

tensor([[   8,   59,   52, 1927,  642,    8,   65, 4517, 1780,  468,  511,  257,
          395,   41, 8321,  470,    0,  151, 5809],
        [ 420,    0,  650,  856,    9,  337, 3708, 1296,    0, 2366,  322,  688,
            0,  292,    0, 2145,    0, 2419,    0],
        [   9,  207,    9,   80,  642,    9,  160,    0, 1753,    9,    0,    0,
            0,  193,    9,  815,    0, 4524,  335],
        [1870,   12, 1872,  202,  181,    0,  712, 5798,  286,   20,    0,    0,
         4521,  260,    8, 1468, 5803,    0,    9],
        [   0, 5795,  621,  221,    0,    0,  466,  285, 2746,    0,    1, 1754,
         1261,   89,  140,  121,  857,   14, 2571],
        [   0, 2559, 1361,   26,   21, 3707, 1099, 2144,   16, 1174,  643,    0,
            0,    7,    9,    0,   40,    0, 1469],
        [1604,    0,   98,    0,  366,   13,  622, 1363,  539,  367,   16,  220,
         1037,   11,   11, 5802, 5804,  140, 1928],
        [2143,    0,  220,    0,  153,    0,   16,    0,  439,  307, 

0

In [53]:
# np.argmax(logits.detach().numpy(), axis=1)
train[['tokens', 'title']][:10]


Unnamed: 0,tokens,title
0,"[update, 7, -, frances, il, ##iad, challenges,...",UPDATE 7-France's Iliad challenges Sprint for ...
1,"[she, ##s, hit, a, bum, note, :, mile, ##y, cy...",She's hit a bum note: Miley Cyrus continues to...
2,"[global, markets, -, valuation, fears, drag, d...",GLOBAL MARKETS-Valuation fears drag down world...
3,"[transformers, turns, into, box, office, be, #...",'Transformers' turns into box office behemoth ...
4,"[bet, -, bet, awards, marred, by, party, death]",Bet - Bet Awards Marred By Party Death
5,"[update, 3, -, barclay, ##s, slapped, with, $,...",UPDATE 3-Barclays slapped with $44 mln fine ov...
6,"[chris, pine, gets, six, month, driving, ban, ...",Chris Pine gets six month driving ban after pl...
7,"[maureen, dow, ##d, eats, some, pot, candy, su...","Maureen Dowd Eats Some Pot Candy, Succumbs To ..."
8,"[euros, ##tar, passengers, face, delays, after...",Eurostar Passengers Face Delays After Power Su...
9,"[review, round, -, up, :, 22, jump, street]",Review Round-Up: 22 Jump Street
