In [1]:
import torch 
from torch import nn 
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from tqdm.auto import tqdm 

class BiLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, nclasses, device) -> None:
        super().__init__()

        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim).to(device)
        self.lstm_model = nn.LSTM(embedding_dim, hidden_size//2, bidirectional=True).to(device)
        self.ffwd_lay = nn.Linear(hidden_size, nclasses).to(device)
        self.softmax = nn.Softmax(dim=2).to(device)

        self.optim = torch.optim.Adam(self.parameters(), lr=1e-2)
        self.criterion = nn.CrossEntropyLoss()


    def forward(self, batch_X, seq_lens, device):
        out = self.embedding(batch_X.to(device)) # L x vocab_size -> L x embedding_dim
        out = pack_padded_sequence(out, seq_lens, batch_first=True, enforce_sorted=False)
        out, _ = self.lstm_model(out) # L x hidden_size
        out, seq_lens = pad_packed_sequence(out, batch_first=True)
        out = self.ffwd_lay(out)
        out = self.softmax(out)

        return out


    def fit(self, train_X, train_Y, seq_lens, nepochs, lr, device):
        self.train()
        self.to(device)

        for g in self.optim.param_groups:
            g['lr'] = lr    
        
        for ep in tqdm(range(nepochs)):
            eploss = 0
            
            for batch_X, batch_Y, batch_seq_len in tqdm(zip(train_X, train_Y, seq_lens)):
                predict = self.forward(batch_X, batch_seq_len, device)
                predict = torch.cat([predict[i, :batch_seq_len[i]] for i in range(len(predict))])
                real = torch.cat([batch_Y[i][:batch_seq_len[i]] for i in range(len(batch_Y))])
  
                self.optim.zero_grad()
                loss = self.criterion(predict, real.to(device))
                loss.backward()
                self.optim.step()

                eploss += loss.item()
            
            printbool = ep % (nepochs//10) == 0 if nepochs > 10 else True
            if printbool:
                print(f'Train loss: {eploss/len(train_X):.3f}')

In [2]:
import json
from importlib import reload

import torch
import numpy as np
import pandas as pd 
from tqdm.auto import tqdm
from sklearn.metrics import balanced_accuracy_score, f1_score


def data_label_split(data, label, train_size=0.8):
    randidx = np.arange(len(data))
    data_train, data_test = train_test_split(data, randidx, train_size)
    label_train, label_test = train_test_split(label, randidx, train_size)

    return data_train, data_test, label_train, label_test

def train_test_split(data, randidx, train_size):
    N = len(data)
    return [data[i] for i in randidx[:int(train_size*N)]], [data[i] for i in randidx[int(train_size*N):]]

def shuffle_data_label_lists(data, label):
    randidx = np.arange(len(data))
    np.random.shuffle(randidx)
    return [data[i] for i in randidx], [label[i] for i in randidx]

def batch_split(X, Y, seq_len, batch_size=1000):
    x_batched = []
    y_batched = []
    seq_len_batched = []

    n = len(X)
    pointer = 0
    while pointer + batch_size < n:
        x_batched.append(X[pointer:pointer+batch_size])
        y_batched.append(Y[pointer:pointer+batch_size])
        seq_len_batched.append(seq_len[pointer:pointer+batch_size])
        pointer += batch_size 
    
    x_batched.append(X[pointer:])
    y_batched.append(Y[pointer:])
    seq_len_batched.append(seq_len[pointer:])

    return x_batched, y_batched, seq_len_batched

In [3]:
# encoding tokens and labels
with open('data/mixtral-8x7b-v1.json', 'r', encoding='utf-8') as f:
    data_1 = json.load(f)
with open('data/train.json', 'r', encoding='utf-8') as f:
    data_2 = json.load(f)

data = data_1 + data_2

unique_tokens, unique_labels = set(), set()
for doc_i, doc in enumerate(tqdm(data)):
    unique_tokens |= set(np.unique(doc['tokens']))
    unique_labels |= set(np.unique(doc['labels']))

token2num = dict(zip(unique_tokens, range(1, len(unique_tokens)+1)))
label2num = {
    'O': 0,
    'B-URL_PERSONAL': 1, 
    'I-URL_PERSONAL': 1, 
    'B-ID_NUM': 2, 
    'I-ID_NUM': 2, 
    'B-EMAIL': 3, 
    'I-EMAIL': 3,
    'B-NAME_STUDENT': 4, 
    'I-NAME_STUDENT': 4, 
    'B-PHONE_NUM': 5, 
    'I-PHONE_NUM': 5, 
    'B-USERNAME': 6,
    'I-USERNAME': 6, 
    'B-STREET_ADDRESS': 7, 
    'I-STREET_ADDRESS': 7, 
}
num2token = {}
for it in token2num:
    num2token[token2num[it]] = it

  0%|          | 0/9162 [00:00<?, ?it/s]

In [4]:
# load data and split by sentences
sentences = []
cur_sentence = []
sentences_labels = []
cur_sentences_labels = []

for doc_i, doc in enumerate(tqdm(data)):
    for token, label in zip(data[doc_i]['tokens'], data[doc_i]['labels']):
        cur_sentence.append(token2num[token])
        cur_sentences_labels.append(label2num[label])

        if (token == '.') | (token.endswith('\n')) | (token == '?') | (token == '!'):  
            if len(cur_sentence) > 2: 
                sentences.append(torch.LongTensor(cur_sentence))
                sentences_labels.append(torch.LongTensor(cur_sentences_labels))

            cur_sentences_labels = []
            cur_sentence = []
    
    if len(cur_sentence) > 2:
        sentences.append(torch.LongTensor(cur_sentence))
        sentences_labels.append(torch.LongTensor(cur_sentences_labels))

    cur_sentences_labels = []
    cur_sentence = []
    
    
# create train and test df
class_split_sentences = {
    'O': [[],[]],
    'B-NAME_STUDENT': [[],[]],    
    'B-STREET_ADDRESS': [[],[]],
    'B-URL_PERSONAL': [[],[]],
    'B-ID_NUM': [[],[]],
    'B-EMAIL': [[],[]],
    'B-PHONE_NUM': [[],[]],
    'B-USERNAME': [[],[]],
}

classes_link = {}
for it in class_split_sentences:
    classes_link[label2num[it]] = it
unique_classes = classes_link.keys()

for i, it in enumerate(sentences):
    for cl in unique_classes:
        if cl in sentences_labels[i]:
            class_split_sentences[classes_link[cl]][0].append(sentences[i])
            class_split_sentences[classes_link[cl]][1].append(sentences_labels[i])
                   
# train test split
sentences_train = []
sentences_labels_train = []

sentences_test = []
sentences_labels_test = []

max_cl_size = max([len(class_split_sentences[it][0]) for it in class_split_sentences])
print(f'max_cl_size: {max_cl_size}')
for it in class_split_sentences:
    cl_sen_train, cl_sen_test, cl_sen_labels_train, cl_sen_labels_test = data_label_split(class_split_sentences[it][0], class_split_sentences[it][1])
    imbalance_coef = max_cl_size // len(class_split_sentences[it][0])
    
    sentences_train += cl_sen_train*imbalance_coef
    sentences_labels_train += cl_sen_labels_train*imbalance_coef

    sentences_test += cl_sen_test*imbalance_coef
    sentences_labels_test += cl_sen_labels_test*imbalance_coef

  0%|          | 0/9162 [00:00<?, ?it/s]

max_cl_size: 316863


In [5]:
sentences_train, sentences_labels_train = shuffle_data_label_lists(sentences_train, sentences_labels_train)
sentences_test, sentences_labels_test = shuffle_data_label_lists(sentences_test, sentences_labels_test)

seq_len_train = list(map(len, sentences_train))
max_len = max(seq_len_train)
sentences_train = torch.cat([F.pad(sentences_train[i], (0, max_len-seq_len_train[i])).reshape(1,-1) for i in range(len(sentences_train))])

seq_len_test = list(map(len, sentences_test))
max_len = max(seq_len_test)
sentences_test = torch.cat([F.pad(sentences_test[i], (0, max_len-seq_len_test[i])).reshape(1,-1) for i in range(len(sentences_test))])

sentences_train, sentences_labels_train, seq_len_train = batch_split(sentences_train, sentences_labels_train, seq_len_train)
sentences_test, sentences_labels_test, seq_len_test = batch_split(sentences_test, sentences_labels_test, seq_len_test)

In [6]:
# old load data and split by sentences
# sentences = []
# cur_sentence = []
# sentences_labels = []
# cur_sentences_labels = []

# for doc_i, doc in enumerate(tqdm(data)):
#     for token, label in zip(data[doc_i]['tokens'], data[doc_i]['labels']):
#         cur_sentence.append(token)
#         cur_sentences_labels.append(label)

#         if (token == '.') | (token.endswith('\n')) | (token == '?') | (token == '!'):   
#             if len(cur_sentence) > 2:
#                 sentences.append(cur_sentence)
#                 sentences_labels.append(cur_sentences_labels)

#             cur_sentences_labels = []
#             cur_sentence = []
    
#     if ('B-NAME_STUDENT' in cur_sentences_labels) | ('I-NAME_STUDENT' in cur_sentences_labels):
#         sentences.append(cur_sentence)
#         sentences_labels.append(cur_sentences_labels)

#     cur_sentences_labels = []
#     cur_sentence = []

# # load data and split by sentences
# sentences = []
# cur_sentence = []
# sentences_labels = []
# cur_sentences_labels = []

# for doc_i, doc in enumerate(tqdm(data)):
#     for token, label in zip(data[doc_i]['tokens'], data[doc_i]['labels']):
#         cur_sentence.append(token2num[token])
#         cur_sentences_labels.append(label2num[label])

#         if (token == '.') | (token.endswith('\n')) | (token == '?') | (token == '!'):  
#             if len(cur_sentence) > 2: 
#                 sentences.append(torch.LongTensor(cur_sentence))
#                 sentences_labels.append(torch.LongTensor(cur_sentences_labels))

#             cur_sentences_labels = []
#             cur_sentence = []
    
#     if len(cur_sentence) > 2:
#         sentences.append(torch.LongTensor(cur_sentence))
#         sentences_labels.append(torch.LongTensor(cur_sentences_labels))

#     cur_sentences_labels = []
#     cur_sentence = []
    
# # create train and test df 
# name_sentences_labels = []
# name_sentences = []

# username_sentences_labels = []
# username_sentences = []

# o_sentences_labels = []
# o_sentences = []

# for i, it in enumerate(sentences):
#     if 1 in sentences_labels[i]:
#         name_sentences_labels.append(sentences_labels[i])
#         name_sentences.append(sentences[i])
#     if 2 in sentences_labels[i]:
#         username_sentences_labels.append(sentences_labels[i])
#         username_sentences.append(sentences[i])
#     else:
#         o_sentences_labels.append(sentences_labels[i])
#         o_sentences.append(sentences[i])
      
        
# name_sentences_train, name_sentences_test, name_sentences_labels_train, name_sentences_labels_test = data_label_split(name_sentences, name_sentences_labels)
# # username_sentences_train, username_sentences_test, username_sentences_labels_train, username_sentences_labels_test = data_label_split(username_sentences, username_sentences_labels)
# o_sentences_train, o_sentences_test, o_sentences_labels_train, o_sentences_labels_test = data_label_split(o_sentences, o_sentences_labels)

# sentences_train = o_sentences_train + name_sentences_train*20
# sentences_labels_train = o_sentences_labels_train + name_sentences_labels_train*20

# sentences_test = o_sentences_test + name_sentences_test*20
# sentences_labels_test = o_sentences_labels_test + name_sentences_labels_test*20

# # sentences_train = o_sentences_train + name_sentences_train*20 + username_sentences_train*20
# # sentences_labels_train = o_sentences_labels_train + name_sentences_labels_train*20 + username_sentences_labels_train*20

# # sentences_test = o_sentences_test + name_sentences_test*20 + username_sentences_test*20
# # sentences_labels_test = o_sentences_labels_test + name_sentences_labels_test*20 + username_sentences_labels_test*20

# sentences_train, sentences_labels_train = shuffle_data_label_lists(sentences_train, sentences_labels_train)
# sentences_test, sentences_labels_test = shuffle_data_label_lists(sentences_test, sentences_labels_test)

# seq_len_train = list(map(len, sentences_train))
# max_len = max(seq_len_train)
# sentences_train = torch.cat([F.pad(sentences_train[i], (0, max_len-seq_len_train[i])).reshape(1,-1) for i in range(len(sentences_train))])

# seq_len_test = list(map(len, sentences_test))
# max_len = max(seq_len_test)
# sentences_test = torch.cat([F.pad(sentences_test[i], (0, max_len-seq_len_test[i])).reshape(1,-1) for i in range(len(sentences_test))])

# sentences_train, sentences_labels_train, seq_len_train = batch_split(sentences_train, sentences_labels_train, seq_len_train)
# sentences_test, sentences_labels_test, seq_len_test = batch_split(sentences_test, sentences_labels_test, seq_len_test)

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
# fit lstm
model = BiLSTM(
    vocab_size=len(token2num)+1,
    embedding_dim=32,
    hidden_size=32,
    nclasses=8,
    device=device
)

In [8]:
# fit bi-lstm
model.fit(
    sentences_train,
    sentences_labels_train,
    seq_len_train,
    nepochs=5,
    lr=1e-3,
    device=device
)

  0%|          | 0/5 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train loss: 1.443


0it [00:00, ?it/s]

In [61]:
train_predict = []
train_real = []
with torch.no_grad():
    for batch_X, batch_Y, batch_seq_len in tqdm(zip(sentences_train, sentences_labels_train, seq_len_train)):
        predict = model.forward(batch_X, batch_seq_len, device)
        predict = torch.argmax(torch.cat([predict[i, :batch_seq_len[i]] for i in range(len(predict))]), dim=1).cpu()
        real = torch.cat([batch_Y[i][:batch_seq_len[i]] for i in range(len(batch_Y))])

        train_predict.append(predict)
        train_real.append(real)

train_predict = torch.cat(train_predict)
train_real = torch.cat(train_real)
print(f'BA: {balanced_accuracy_score(train_real, train_predict):.3f}')
print()

0it [00:00, ?it/s]

BA: 0.994



In [63]:
test_predict = []
test_real = []
with torch.no_grad():
    for batch_X, batch_Y, batch_seq_len in tqdm(zip(sentences_test, sentences_labels_test, seq_len_test)):
        predict = model.forward(batch_X, batch_seq_len, device)
        predict = torch.argmax(torch.cat([predict[i, :batch_seq_len[i]] for i in range(len(predict))]), dim=1).cpu()
        real = torch.cat([batch_Y[i][:batch_seq_len[i]] for i in range(len(batch_Y))])

        test_predict.append(predict)
        test_real.append(real)

test_predict = torch.cat(test_predict)
test_real = torch.cat(test_real)
print(f'BA: {balanced_accuracy_score(test_real, test_predict):.3f}')

0it [00:00, ?it/s]

BA: 0.636


Only address test = 0.978
Address + name  = 0.884