In [1]:
import gc
import json
import pickle
from importlib import reload

import torch
import numpy as np
import pandas as pd 
from tqdm.auto import tqdm
from sklearn.metrics import balanced_accuracy_score, f1_score, confusion_matrix

from nn_module import BiLSTM_CRF

In [2]:
def data_label_split(data, label, train_size=0.8):
    randidx = np.arange(len(data))
    data_train, data_test = train_test_split(data, randidx, train_size)
    label_train, label_test = train_test_split(label, randidx, train_size)

    return data_train, data_test, label_train, label_test

def train_test_split(data, randidx, train_size):
    N = len(data)
    return [data[i] for i in randidx[:int(train_size*N)]], [data[i] for i in randidx[int(train_size*N):]]

def shuffle_data_label_lists(data, label):
    randidx = np.arange(len(data))
    np.random.shuffle(randidx)
    return [data[i] for i in randidx], [label[i] for i in randidx]

def embedding(data, token2emb, label2num):
    tokens_lst = []
    labels_lst = []

    for doc in tqdm(data):
        tokens = map(lambda x: x.replace('\t', '').replace('\n', '').replace(' ', ''), doc['tokens'])
        emb_tokens = torch.zeros(len(doc['tokens']), 300)
        for i, token in enumerate(tokens):
            if token in token2emb:
                emb_tokens[i] = torch.FloatTensor(token2emb[token])
            else:
                for it in token:
                    emb_tokens[i] += torch.FloatTensor(token2emb[it] if it in token2emb else [0]*300)

        tokens_lst.append(torch.FloatTensor(emb_tokens))
        labels_lst.append(torch.LongTensor([label2num[it] if it in label2num else 0 for it in doc['labels']]))
        
    return tokens_lst, labels_lst

def batch_split(X, Y, batch_size=1000):
    x_batched = []
    y_batched = []

    n = len(X)
    pointer = 0
    while pointer + batch_size < n:
        x_batched.append(X[pointer:pointer+batch_size])
        y_batched.append(Y[pointer:pointer+batch_size])
        pointer += batch_size 
    
    x_batched.append(X[pointer:])
    y_batched.append(Y[pointer:])

    return x_batched, y_batched

def fix_label_disbalance(tokens, labels):
    label_idxs = {}
    for i in range(len(labels)):
        unique_labels = np.unique(labels[i])
        for lab in unique_labels:
            if lab in label_idxs:
                label_idxs[lab].append(i)
            else:
                label_idxs[lab] = [i]

    idxs = []
    count_o = len(label_idxs[0])
    del label_idxs[0]

    for it in label_idxs:
        scale = count_o // len(label_idxs[it])
        idxs += label_idxs[it]*scale

    np.random.shuffle(idxs)
    tokens = [tokens[i] for i in idxs]
    labels = [labels[i] for i in idxs]

    return tokens, labels

In [3]:
# create table for name and address
label2num = {
    'O': 0,
    'B-NAME_STUDENT': 1, 
    'I-NAME_STUDENT': 2, 
    'B-STREET_ADDRESS': 3, 
    'I-STREET_ADDRESS': 4, 
    'B-USERNAME': 5,
    'I-USERNAME': 6, 
    'B-ID_NUM': 7, 
    'I-ID_NUM': 8, 
    'B-URL_PERSONAL': 9,
    'I-URL_PERSONAL': 10,
    'B-EMAIL': 11,
    'I-EMAIL': 12,
    'B-PHONE_NUM': 13,
    'I-PHONE_NUM': 14,
}

num2label = {
    0: 'O',
    1: 'B-NAME_STUDENT',  
    2: 'I-NAME_STUDENT',  
    3: 'B-STREET_ADDRESS',  
    4: 'I-STREET_ADDRESS',  
    5: 'B-USERNAME', 
    6: 'I-USERNAME',  
    7: 'B-ID_NUM',  
    8: 'I-ID_NUM', 
    9: 'B-URL_PERSONAL',
    10: 'I-URL_PERSONAL',
    11: 'B-EMAIL',
    12: 'I-EMAIL',
    13: 'B-PHONE_NUM',
    14: 'I-PHONE_NUM',
}

In [4]:
# load fastext
# token2emb = {}
# with open('wiki-news-300d-1M.vec', 'r', encoding='utf-8') as f:
#     next(f)
#     for it in tqdm(f):
#         row = it.split(' ')
#         token2emb[row[0]] = list(map(float, row[1:]))

# # encoding tokens and labels
# with open('data/mixtral-8x7b-v1.json', 'r', encoding='utf-8') as f:
#     data_1 = json.load(f)
# with open('data/train.json', 'r', encoding='utf-8') as f:
#     data_2 = json.load(f)
# data = data_1 + data_2

# data_tokens = []
# data_labels = []
# for doc in tqdm(data):
#     row_tokens = []
#     row_labels = []

#     for i, (token, label) in enumerate(zip(doc['tokens'], doc['labels'])):
#         emb_tokens = torch.zeros(300)
        
#         if token in token2emb:
#             emb_tokens = torch.FloatTensor(token2emb[token])
#         else:
#             for it in token:
#                 emb_tokens += torch.FloatTensor(token2emb[it] if it in token2emb else [0]*300)

#         row_tokens.append(emb_tokens.unsqueeze(0))
#         row_labels.append(label2num[label])

#     data_tokens.append(torch.cat(row_tokens))
#     data_labels.append(torch.LongTensor(row_labels))

# data = []
# token2emb = None
# gc.collect()

# N = len(data_tokens)
# np.random.seed(123)
# data_idx = np.arange(N)
# np.random.shuffle(data_idx)
# data_tokens = [data_tokens[i] for i in data_idx]
# data_labels = [data_labels[i] for i in data_idx]

# with open('data/train_pool.pkl', 'wb') as f:
#     pickle.dump([
#         data_tokens,
#         data_labels,
#     ], f)

In [4]:
with open('data/train_pool.pkl', 'rb') as f:
    data_tokens, data_labels = pickle.load(f)

N = len(data_tokens)
train_size = 0.85
n = int(N*train_size)
train_tokens, valid_tokens = data_tokens[:n], data_tokens[n:]
train_labels, valid_labels = data_labels[:n], data_labels[n:]

data_tokens = None
data_labels = None
gc.collect()

train_tokens, train_labels = fix_label_disbalance(train_tokens, train_labels)
valid_tokens, valid_labels = fix_label_disbalance(valid_tokens, valid_labels)

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'

import nn_module    
reload(nn_module)
from nn_module import BiLSTM_CRF

# fit lstm
model = BiLSTM_CRF(
    embedding_dim=300,
    hidden_size=32,
    nclasses=15,
    device=device,
).to(device)

# checkpoint = torch.load('saved_models/bi_lstm.pt')
# model.load_state_dict(checkpoint['model_state_dict'])
# model.optim.load_state_dict(checkpoint['optimizer_state_dict'])

model.fit(
    train_X=train_tokens,
    train_Y=train_labels,
    valid_X=valid_tokens,
    valid_Y=valid_labels,
    nepochs=10,
    lr=1e-3,
    device=device
)

  0%|          | 0/10 [00:00<?, ?it/s]

0it [00:00, ?it/s]

In [8]:
with torch.no_grad():
    TP, TN, FP, FN = 0., 0., 0., 0.
    train_predict = []
    for batch_X, batch_Y in tqdm(zip(train_tokens, train_labels)):
        predict = torch.argmax(model.forward(batch_X.to(device)), dim=1).cpu()
        TP += ((predict == batch_Y) & (predict != 0)).sum()
        TN += ((predict == batch_Y) & (predict == 0)).sum()
        FP += ((predict != batch_Y) & (predict != 0)).sum()
        FN += ((predict != batch_Y) & (predict == 0)).sum()
    p_metric_train = TP / (TP + FP)
    r_metric_train = TP / (TP + FN)

    TP, TN, FP, FN = 0, 0, 0, 0
    test_predict = []
    for batch_X, batch_Y in tqdm(zip(valid_tokens, valid_labels)):
        predict = torch.argmax(model.forward(batch_X.to(device)), dim=1).cpu()
        TP += ((predict == batch_Y) & (predict != 0)).sum()
        TN += ((predict == batch_Y) & (predict == 0)).sum()
        FP += ((predict != batch_Y) & (predict != 0)).sum()
        FN += ((predict != batch_Y) & (predict == 0)).sum()
    p_metric_valid = TP / (TP + FP)
    r_metric_valid = TP / (TP + FN)
    
    print(f'Train precision {p_metric_train:.3f}, Train recall: {r_metric_train:.3f}, Valid precision: {p_metric_valid:.3f}, Valid recall: {r_metric_valid:.3f}')

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Train precision 1.000, Train recall: 1.000, Valid precision: 0.997, Valid recall: 0.995


In [7]:
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': model.optim.state_dict(),
    }, 'saved_models/bi_lstm.pt')

In [9]:
token2emb = []
train_tokens = []
train_labels = []
test_tokens = []
test_labels = []

import gc
gc.collect()

5715

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = BiLSTM_CRF(
    embedding_dim=300,
    hidden_size=32,
    nclasses=15,
    device=device,
)

checkpoint = torch.load('saved_models/bi_lstm.pt')
model.load_state_dict(checkpoint['model_state_dict'])
model.optim.load_state_dict(checkpoint['optimizer_state_dict'])
model.eval()
model.to(device)
print()




In [5]:
token2emb = {}
with open('wiki-news-300d-1M.vec', 'r', encoding='utf-8') as f:
    next(f)
    for it in tqdm(f):
        row = it.split(' ')
        token2emb[row[0]] = list(map(float, row[1:]))

0it [00:00, ?it/s]

In [6]:
def check_name(x):
    flag = False
    flag = x[0].isupper()
    for it in x[1:]:
        flag *= it.islower()
    return flag

def get_predict_table(data, token2emb, num2label):
    predict_table = [[], [], [], [], []] # document, token_idx, token, predict, real
    for i in tqdm(range(len(data))):
        with torch.no_grad():
            batch_X = []

            for token in data[i]['tokens']:
                emb_tokens = torch.zeros(300)
                if token in token2emb:
                    emb_tokens = torch.FloatTensor(token2emb[token])
                else:
                    for it in token:
                        emb_tokens += torch.FloatTensor(token2emb[it] if it in token2emb else [0]*300)
                batch_X.append(emb_tokens.unsqueeze(0))

            batch_X = torch.cat(batch_X)

            predict =  torch.argmax(model.forward(batch_X.to(device)), dim=1).cpu()
            predict = [num2label[it.item()] for it in predict]
            predict = ['O' if len(set(x)) < 2 else y for x, y in zip(data[i]['tokens'], predict)]
            predict_table[0] += [data[i]['document']]*len(predict)
            predict_table[1] += list(range(len(predict)))
            predict_table[2] += data[i]['tokens']
            predict_table[3] += predict
            predict_table[4] += data[i]['labels'] 
        
    predict_table = [[predict_table[0][i], predict_table[1][i], predict_table[2][i], predict_table[3][i], predict_table[4][i]] for i in range(len(predict_table[0]))]
    predict_table = pd.DataFrame(predict_table, columns=['document', 'token_i', 'token', 'predict', 'label'])
    predict_table = predict_table.loc[(predict_table.label != 'O') & (predict_table.token != "\n") & (predict_table.token != "\n\n") & (predict_table.token != "\t")].reset_index(drop=True)

    return predict_table.sort_values('document').reset_index(drop=True)

In [7]:
import gc
with open('data/train.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

predict_table = get_predict_table(data, token2emb, num2label)
gc.collect()

  0%|          | 0/6807 [00:00<?, ?it/s]

0

In [8]:
label = predict_table['label'].str.slice(2).values 
predict = predict_table['predict'].str.slice(2).values 
print(f'F1 score: {f1_score(label, predict, average="micro")}')

F1 score: 0.9459261965655827


In [12]:
label = predict_table['label'].values 
predict = predict_table['predict'].values 
print(f'F1 score: {f1_score(label, predict, average="micro")}')

F1 score: 0.9444647424187066


In [11]:
for i in range(1, predict_table.shape[0]):
    if (predict_table.at[i-1, 'predict'] == predict_table.at[i, 'predict']) & (predict_table.at[i-1, 'document'] == predict_table.at[i, 'document']) & (predict_table.at[i, 'token_i'] - predict_table.at[i-1, 'token_i'] == 1):
        predict_table.at[i, 'prefix'] = 'I-'
predict_table['predict'] = predict_table['prefix'] + predict_table['predict']

print(f'F1 score: {f1_score(predict_table.label, predict_table.predict, average="micro")}')

F1 score: 0.7822433321154548


In [12]:
predict_table

Unnamed: 0,document,token_i,token,predict,label,prefix
0,7,9,Nathalie,B-NAME_STUDENT,B-NAME_STUDENT,B-
1,7,10,Sylla,I-NAME_STUDENT,I-NAME_STUDENT,I-
2,7,482,Nathalie,B-NAME_STUDENT,B-NAME_STUDENT,B-
3,7,483,Sylla,I-NAME_STUDENT,I-NAME_STUDENT,I-
4,7,741,Nathalie,B-NAME_STUDENT,B-NAME_STUDENT,B-
...,...,...,...,...,...,...
2732,15717,365,IV-8322,B-ID_NUM,B-ID_NUM,B-
2733,15717,964,IV-8322,B-ID_NUM,B-ID_NUM,B-
2734,19280,55,30407059,B-ID_NUM,I-ID_NUM,B-
2735,19280,54,Z.S.,B-ID_NUM,B-ID_NUM,B-


In [13]:
predict_table.loc[predict_table.predict != predict_table.label].groupby('predict', as_index=False).label.count()

Unnamed: 0,predict,label
0,B-ID_NUM,1
1,B-NAME_STUDENT,366
2,B-O,211
3,B-STREET_ADDRESS,10
4,B-USERNAME,1
5,I-O,7


In [12]:
predict_table

Unnamed: 0,document,token_i,token,predict,label,prefix,upper_start
0,7,9,Nathalie,B-NAME_STUDENT,B-NAME_STUDENT,B-,1
1,7,10,Sylla,I-NAME_STUDENT,I-NAME_STUDENT,I-,1
2,7,482,Nathalie,B-NAME_STUDENT,B-NAME_STUDENT,B-,1
3,7,483,Sylla,I-NAME_STUDENT,I-NAME_STUDENT,I-,1
4,7,741,Nathalie,B-NAME_STUDENT,B-NAME_STUDENT,B-,1
...,...,...,...,...,...,...,...
2639,22147,457,Francesca,B-NAME_STUDENT,O,B-,1
2640,22147,1408,Melanie,B-NAME_STUDENT,O,B-,1
2641,22181,738,Portman,B-NAME_STUDENT,O,B-,1
2642,22181,736,Natalie,B-NAME_STUDENT,O,B-,1
