In [1]:
from statistics import mean, stdev
from transformers import BertTokenizer
from tqdm.notebook import tqdm
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
# from datasets.reader import read_conll
from datasets.conll import ConllBertDataset

# Load data

In [2]:
manga_path = '../../data/NER/processed/comments/augmented_10/'

In [3]:
tokenizer = BertTokenizer.from_pretrained('DeepPavlov/rubert-base-cased-conversational')
train_ds = ConllBertDataset.from_file(manga_path + 'train.txt', tokenizer)
test_ds = ConllBertDataset.from_file(manga_path + 'test.txt', tokenizer)

  0%|          | 0/2014 [00:00<?, ?it/s]

  0%|          | 0/503 [00:00<?, ?it/s]

In [4]:
train_ds[-1:]

([[101, 2063, 801, 28293, 845, 1702, 10029, 292, 292, 102]],
 [['O', 'O', 'O', 'PER', 'PER', 'PER', 'O', 'O', 'O', 'O']])

In [5]:
type(train_ds[0])

tuple

# Create model

In [7]:
import torch
import torch.nn as nn

from models import BertLstm
from sklearn.metrics import roc_auc_score

In [8]:
BertLstm

models.naive.BertLstm

In [9]:
acc = lambda preds, labels: ((preds.argmax(dim=1) == labels).sum() / labels.size(0)).item()
sigmoid = nn.Sigmoid()

def auc(preds, labels):
    assert all(map(lambda preds: preds.size(1) == 2, preds))
    to_probs = lambda preds: sigmoid(preds)[:, 0].cpu().detach().numpy()
    to_labels = lambda labels: (1 - labels).cpu().detach().numpy()
    
    preds = list(map(to_probs, preds))
    labels = list(map(to_labels, labels))
    
    preds = [item for subl in preds for item in subl]
    labels = [item for subl in labels for item in subl]
    
    auc = roc_auc_score(labels, preds)
    return auc

In [70]:
import numpy as np
import sklearn
from sklearn.metrics import precision_recall_curve
T=0.5
def f_score(preds, labels):
    preds2 = []
    assert all(map(lambda preds: preds.size(1) == 2, preds))
    to_probs = lambda preds: sigmoid(preds)[:, 0].cpu().detach().numpy()
    to_labels = lambda labels: (1 - labels).cpu().detach().numpy()
    
    preds = list(map(to_probs, preds))
    labels = list(map(to_labels, labels))
    
    preds = [item for subl in preds for item in subl]
    labels = [item for subl in labels for item in subl]
    precision, recall, thresholds = precision_recall_curve(labels, preds)
    fscore = (2 * precision * recall) / (precision + recall)
    ix = np.argmax(fscore)
    print(thresholds[ix])
    return fscore[ix]

# Train

In [11]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
labels_n = max(train_ds.labels_n, test_ds.labels_n)

simple_model = BertLstm(labels_n).to(device)
opt = torch.optim.Adam(simple_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

writer = SummaryWriter()#'runs/16_epoches_30_augs_uncased')

  return torch._C._cuda_getDeviceCount() > 0


In [72]:
loss_interval = 50

for e in range(10):
    print('Epoch', e+1) 
    losses = []
    probs_epoch = []
    labels_epoch = []
    opt.zero_grad()
    for i in tqdm(range(1, len(train_ds))):
        if i % loss_interval == 0:
            opt.step()
            opt.zero_grad()
        words, labels = train_ds[i]
        words = torch.LongTensor(words).unsqueeze(0).to(device)
        labels = torch.LongTensor(list(map(lambda label: train_ds.ne2ix[label], labels))).to(device)

        preds = simple_model(words)
        loss = criterion(preds, labels)
        loss.backward()
        step = e*len(train_ds)+i
        writer.add_scalar('Loss/Train', loss.item(), step)
        writer.add_scalar('Accuracy/Train', acc(preds, labels), step)
        probs_epoch.append(preds)
        labels_epoch.append(labels)
    print(f_score(probs_epoch, labels_epoch))
    writer.add_scalar('AUC_epoch/Train', auc(probs_epoch, labels_epoch), e) 
            
    losses = []
    probs_epoch = []
    labels_epoch = []
    with torch.no_grad():
        for i in tqdm(range(1, len(test_ds))):
            words, labels = train_ds[i]
            words = torch.LongTensor(words).unsqueeze(0).to(device)
            labels = torch.LongTensor(list(map(lambda label: test_ds.ne2ix[label], labels))).to(device)
            preds = simple_model(words)
            loss = criterion(preds, labels)
            step = e*len(test_ds)+i
            writer.add_ueuscalar('Loss/Test', loss.item(), step)
            writer.add_scalar('Accuracy/Test', acc(preds, labels), step)
            probs_epoch.append(preds)
            labels_epoch.append(labels)
    print(f_score(probs_epoch, labels_epoch))
    writer.add_scalar('AUC_epoch/Test', auc(probs_epoch, labels_epoch), e) 
    

Epoch 1


  0%|          | 0/2013 [00:00<?, ?it/s]

0.5798725
0.8515026517383618


  0%|          | 0/502 [00:00<?, ?it/s]

0.5720629
0.8476070528967256
Epoch 2


  0%|          | 0/2013 [00:00<?, ?it/s]

0.5030781
0.8795059600746804


  0%|          | 0/502 [00:00<?, ?it/s]

0.8053785
0.8819571865443425
Epoch 3


  0%|          | 0/2013 [00:00<?, ?it/s]

0.5569232
0.8709770948100899


  0%|          | 0/502 [00:00<?, ?it/s]

0.62582636
0.862476664592408
Epoch 4


  0%|          | 0/2013 [00:00<?, ?it/s]

0.562747
0.8584836486288313


  0%|          | 0/502 [00:00<?, ?it/s]

0.6049042
0.8737745098039216
Epoch 5


  0%|          | 0/2013 [00:00<?, ?it/s]

0.6234066
0.8675785797438882


  0%|          | 0/502 [00:00<?, ?it/s]

0.53196836
0.8391608391608391
Epoch 6


  0%|          | 0/2013 [00:00<?, ?it/s]

0.5467241
0.8643523920653441


  0%|          | 0/502 [00:00<?, ?it/s]

0.55217355
0.8674101610904584
Epoch 7


  0%|          | 0/2013 [00:00<?, ?it/s]

0.52305293
0.8689334495786107


  0%|          | 0/502 [00:00<?, ?it/s]

0.8854728
0.8642590286425903
Epoch 8


  0%|          | 0/2013 [00:00<?, ?it/s]

0.5323221
0.8758849877185378


  0%|          | 0/502 [00:00<?, ?it/s]

0.6398675
0.8865853658536587
Epoch 9


  0%|          | 0/2013 [00:00<?, ?it/s]

0.62996143
0.8772031204854087


  0%|          | 0/502 [00:00<?, ?it/s]

0.63887805
0.8679479231246124
Epoch 10


  0%|          | 0/2013 [00:00<?, ?it/s]

0.5154614
0.8799538838449344


  0%|          | 0/502 [00:00<?, ?it/s]

0.89143866
0.8740740740740741


In [64]:
simple_model.save('./weights/model.pt')