In [1]:
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.model_selection import train_test_split
from scipy.special import softmax

import pandas as pd
import numpy as np

from tabulate import tabulate
from tqdm import trange
from tqdm import tqdm
import random
from pandarallel import pandarallel

import pytorch_warmup as warmup

pandarallel.initialize(progress_bar=True)

pd.options.display.max_colwidth = 500

INFO: Pandarallel will run on 2 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [2]:
import torch
torch.cuda.is_available()

True

In [3]:
# To load
scaler = torch.cuda.amp.GradScaler()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

MODEL_NAME = "/home/amanda-putra/20221220_REGSOSEK/MODEL/ditto_indobert-base-p2-tokonow-matching-v1_1"

tokenizer = BertTokenizer.from_pretrained(
    MODEL_NAME,
    do_lower_case = True
    )

special_tokens_dict = {'additional_special_tokens': ['[COL]','[VAL]']}
tokenizer.add_special_tokens(special_tokens_dict)

# Load the BertForSequenceClassification model
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels = 2,
    output_attentions = False,
    output_hidden_states = False,
)
model.resize_token_embeddings(len(tokenizer))

# Recommended learning rates (Adam): 5e-5, 3e-5, 2e-5. See: https://arxiv.org/pdf/1810.04805.pdf
optimizer = torch.optim.AdamW(model.parameters(), 
                              lr = 5e-5,
                              eps = 1e-08
                              )

# Run on GPU
model.cuda()
# model.load_state_dict(torch.load('/home/amanda-putra/product_match/MODEL/IndoBERT_v1.1_c/indobert-base-p1-tokonow-matching-v1_1.pth'))
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30523, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [8]:
from unidecode import unidecode
import re

def set_seed(seed):
    """ Set all seeds to make results reproducible (deterministic mode).
        When seed is a false-y value or not supplied, disables deterministic mode. """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def clean_unit(text):
    text = unidecode(text)
    text = str(text).lower()

    text = text.replace("'", "")
    text = text.replace("`", "")
    
    # remove all special characters
    text = re.sub(r"[^a-zA-Z0-9]+", " ", text)

    # add whitespace after each number
    text = re.sub(r"([0-9]+(\.[0-9]+)?)", r" \1 ", text).strip()
    text = " ".join(text.split())
    
    return text


def text_preprocess(text):
    try:
        text = clean_unit(text)

    except:
        text = "unknown"

    return text

def preprocessing_data(df):
    df['id'] = df['id_ruta_pes'].astype(str)+"#"+df['id_art_pes'].astype(int).astype(str)+"_"+df['id_ruta_sp'].astype(int).astype(str)+"#"+df['id_art_sp'].astype(int).astype(str)
    # df['jk_pes'] = df['jk_pes'].astype(str).apply(lambda x: 'pria' if x=='1' else 'wanita' if x=='2' else '')
    # df['jk_sp'] = df['jk_sp'].astype(int).astype(str).apply(lambda x: 'pria' if x=='1' else 'wanita' if x=='2' else '')
    
    df['umur_pes'] = df['umur_pes'].astype(int).astype(str)
    df['umur_sp'] = df['umur_sp'].astype(int).astype(str)
    df['label'] = df['is_match'].astype(int)%2
    text_cols = ['nama_pes','nama_krt_pes','nama_sp','nama_krt_sp']
    for col in text_cols:
        df[col] = df[[col]].parallel_apply(lambda x: text_preprocess(x[col]),axis=1)
    
    df['p1_ent'] = "[COL] nama [VAL] "+df['nama_pes']+" [COL] kk [VAL] "+df['nama_krt_pes']+" [COL] umur [VAL] "+df['umur_pes']+" [COL] gender [VAL] "+df['jk_pes']
    df['p2_ent'] = "[COL] nama [VAL] "+df['nama_sp']+" [COL] kk [VAL] "+df['nama_krt_sp']+" [COL] umur [VAL] "+df['umur_sp']+" [COL] gender [VAL] "+df['jk_sp']
    
    return df

set_seed(2022)

In [9]:
df = pd.read_feather("/home/amanda-putra/20221220_REGSOSEK/df v0.1 - Pairs NIK.feather")
df = df[df['nik_pes']!=10000000000000000]
df['is_match'] = 1
df = df.sample(frac=1,random_state=2022).reset_index(drop=True)
df = preprocessing_data(df)
df.head()

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=7154), Label(value='0 / 7154'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=7154), Label(value='0 / 7154'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=7154), Label(value='0 / 7154'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=7154), Label(value='0 / 7154'))), …

Unnamed: 0,blok_pes,id_ruta_pes,id_art_pes,no_kk_pes,nik_pes,nama_pes,nama_krt_pes,jk_pes,umur_pes,blok_sp,...,nik_sp,nama_sp,nama_krt_sp,jk_sp,umur_sp,is_match,id,label,p1_ent,p2_ent
0,4,0CB68580-BF80-4544-8977-9C20BAA7D6E4,5,9907077000000000.0,9907076705990000,enjel maisani br s,aripin sembiring,wanita,10,4,...,9907076705990000,ayu melinda ginting,terkena ginting,wanita,15,1,0CB68580-BF80-4544-8977-9C20BAA7D6E4#5_1027027#4,1,[COL] nama [VAL] enjel maisani br s [COL] kk [VAL] aripin sembiring [COL] umur [VAL] 10 [COL] gender [VAL] wanita,[COL] nama [VAL] ayu melinda ginting [COL] kk [VAL] terkena ginting [COL] umur [VAL] 15 [COL] gender [VAL] wanita
1,2,D14CF1A5-F2B0-43C7-ABE2-21BABF7AEB64,0,9907040000000000.0,9907040212990000,dana,dana,pria,32,2,...,9907040212990000,ray irpanta,firman tarigan,pria,15,1,D14CF1A5-F2B0-43C7-ABE2-21BABF7AEB64#0_1060060#3,1,[COL] nama [VAL] dana [COL] kk [VAL] dana [COL] umur [VAL] 32 [COL] gender [VAL] pria,[COL] nama [VAL] ray irpanta [COL] kk [VAL] firman tarigan [COL] umur [VAL] 15 [COL] gender [VAL] pria
2,6,22E07DC4-CEFB-4602-8E11-22C83A5F591D,4,9907276000000000.0,9907276004990000,nabila aprilia putri,agus sugianto,wanita,7,6,...,9907276004990000,nabila aprilia putri,agus sugianto,wanita,7,1,22E07DC4-CEFB-4602-8E11-22C83A5F591D#4_1002002#5,1,[COL] nama [VAL] nabila aprilia putri [COL] kk [VAL] agus sugianto [COL] umur [VAL] 7 [COL] gender [VAL] wanita,[COL] nama [VAL] nabila aprilia putri [COL] kk [VAL] agus sugianto [COL] umur [VAL] 7 [COL] gender [VAL] wanita
3,25,5981AE3A-89C3-4847-8780-1AA49E039CA2,6,9908204000000000.0,9908204105990000,jumriani,usman,wanita,10,25,...,9908204105990000,jumriani,usman,wanita,10,1,5981AE3A-89C3-4847-8780-1AA49E039CA2#6_1002002#7,1,[COL] nama [VAL] jumriani [COL] kk [VAL] usman [COL] umur [VAL] 10 [COL] gender [VAL] wanita,[COL] nama [VAL] jumriani [COL] kk [VAL] usman [COL] umur [VAL] 10 [COL] gender [VAL] wanita
4,6,69DB2EA1-39B1-4A75-8573-DB531B9AE45B,1,9907275000000000.0,9907275408990000,eka gustika,santoso,wanita,26,6,...,9907275408990000,lisma sri ramadhani,purwanto,wanita,43,1,69DB2EA1-39B1-4A75-8573-DB531B9AE45B#1_1045055#2,1,[COL] nama [VAL] eka gustika [COL] kk [VAL] santoso [COL] umur [VAL] 26 [COL] gender [VAL] wanita,[COL] nama [VAL] lisma sri ramadhani [COL] kk [VAL] purwanto [COL] umur [VAL] 43 [COL] gender [VAL] wanita


In [10]:
df[['id','p1_ent','p2_ent','label']].head()

Unnamed: 0,id,p1_ent,p2_ent,label
0,0CB68580-BF80-4544-8977-9C20BAA7D6E4#5_1027027#4,[COL] nama [VAL] enjel maisani br s [COL] kk [VAL] aripin sembiring [COL] umur [VAL] 10 [COL] gender [VAL] wanita,[COL] nama [VAL] ayu melinda ginting [COL] kk [VAL] terkena ginting [COL] umur [VAL] 15 [COL] gender [VAL] wanita,1
1,D14CF1A5-F2B0-43C7-ABE2-21BABF7AEB64#0_1060060#3,[COL] nama [VAL] dana [COL] kk [VAL] dana [COL] umur [VAL] 32 [COL] gender [VAL] pria,[COL] nama [VAL] ray irpanta [COL] kk [VAL] firman tarigan [COL] umur [VAL] 15 [COL] gender [VAL] pria,1
2,22E07DC4-CEFB-4602-8E11-22C83A5F591D#4_1002002#5,[COL] nama [VAL] nabila aprilia putri [COL] kk [VAL] agus sugianto [COL] umur [VAL] 7 [COL] gender [VAL] wanita,[COL] nama [VAL] nabila aprilia putri [COL] kk [VAL] agus sugianto [COL] umur [VAL] 7 [COL] gender [VAL] wanita,1
3,5981AE3A-89C3-4847-8780-1AA49E039CA2#6_1002002#7,[COL] nama [VAL] jumriani [COL] kk [VAL] usman [COL] umur [VAL] 10 [COL] gender [VAL] wanita,[COL] nama [VAL] jumriani [COL] kk [VAL] usman [COL] umur [VAL] 10 [COL] gender [VAL] wanita,1
4,69DB2EA1-39B1-4A75-8573-DB531B9AE45B#1_1045055#2,[COL] nama [VAL] eka gustika [COL] kk [VAL] santoso [COL] umur [VAL] 26 [COL] gender [VAL] wanita,[COL] nama [VAL] lisma sri ramadhani [COL] kk [VAL] purwanto [COL] umur [VAL] 43 [COL] gender [VAL] wanita,1


In [11]:
df['label'].value_counts()

1    14308
Name: label, dtype: int64

In [12]:
def prepare(s1,s2):
    # res = f"{tokenizer.cls_token} {s1} {tokenizer.sep_token} {s2} {tokenizer.sep_token}"
    res = f"{s1} {tokenizer.sep_token} {s2}"
    return res

In [13]:
df['pairs'] = df.apply(lambda x: prepare(x['p1_ent'],x['p2_ent']),axis=1)
df.head()

Unnamed: 0,blok_pes,id_ruta_pes,id_art_pes,no_kk_pes,nik_pes,nama_pes,nama_krt_pes,jk_pes,umur_pes,blok_sp,...,nama_sp,nama_krt_sp,jk_sp,umur_sp,is_match,id,label,p1_ent,p2_ent,pairs
0,4,0CB68580-BF80-4544-8977-9C20BAA7D6E4,5,9907077000000000.0,9907076705990000,enjel maisani br s,aripin sembiring,wanita,10,4,...,ayu melinda ginting,terkena ginting,wanita,15,1,0CB68580-BF80-4544-8977-9C20BAA7D6E4#5_1027027#4,1,[COL] nama [VAL] enjel maisani br s [COL] kk [VAL] aripin sembiring [COL] umur [VAL] 10 [COL] gender [VAL] wanita,[COL] nama [VAL] ayu melinda ginting [COL] kk [VAL] terkena ginting [COL] umur [VAL] 15 [COL] gender [VAL] wanita,[COL] nama [VAL] enjel maisani br s [COL] kk [VAL] aripin sembiring [COL] umur [VAL] 10 [COL] gender [VAL] wanita [SEP] [COL] nama [VAL] ayu melinda ginting [COL] kk [VAL] terkena ginting [COL] umur [VAL] 15 [COL] gender [VAL] wanita
1,2,D14CF1A5-F2B0-43C7-ABE2-21BABF7AEB64,0,9907040000000000.0,9907040212990000,dana,dana,pria,32,2,...,ray irpanta,firman tarigan,pria,15,1,D14CF1A5-F2B0-43C7-ABE2-21BABF7AEB64#0_1060060#3,1,[COL] nama [VAL] dana [COL] kk [VAL] dana [COL] umur [VAL] 32 [COL] gender [VAL] pria,[COL] nama [VAL] ray irpanta [COL] kk [VAL] firman tarigan [COL] umur [VAL] 15 [COL] gender [VAL] pria,[COL] nama [VAL] dana [COL] kk [VAL] dana [COL] umur [VAL] 32 [COL] gender [VAL] pria [SEP] [COL] nama [VAL] ray irpanta [COL] kk [VAL] firman tarigan [COL] umur [VAL] 15 [COL] gender [VAL] pria
2,6,22E07DC4-CEFB-4602-8E11-22C83A5F591D,4,9907276000000000.0,9907276004990000,nabila aprilia putri,agus sugianto,wanita,7,6,...,nabila aprilia putri,agus sugianto,wanita,7,1,22E07DC4-CEFB-4602-8E11-22C83A5F591D#4_1002002#5,1,[COL] nama [VAL] nabila aprilia putri [COL] kk [VAL] agus sugianto [COL] umur [VAL] 7 [COL] gender [VAL] wanita,[COL] nama [VAL] nabila aprilia putri [COL] kk [VAL] agus sugianto [COL] umur [VAL] 7 [COL] gender [VAL] wanita,[COL] nama [VAL] nabila aprilia putri [COL] kk [VAL] agus sugianto [COL] umur [VAL] 7 [COL] gender [VAL] wanita [SEP] [COL] nama [VAL] nabila aprilia putri [COL] kk [VAL] agus sugianto [COL] umur [VAL] 7 [COL] gender [VAL] wanita
3,25,5981AE3A-89C3-4847-8780-1AA49E039CA2,6,9908204000000000.0,9908204105990000,jumriani,usman,wanita,10,25,...,jumriani,usman,wanita,10,1,5981AE3A-89C3-4847-8780-1AA49E039CA2#6_1002002#7,1,[COL] nama [VAL] jumriani [COL] kk [VAL] usman [COL] umur [VAL] 10 [COL] gender [VAL] wanita,[COL] nama [VAL] jumriani [COL] kk [VAL] usman [COL] umur [VAL] 10 [COL] gender [VAL] wanita,[COL] nama [VAL] jumriani [COL] kk [VAL] usman [COL] umur [VAL] 10 [COL] gender [VAL] wanita [SEP] [COL] nama [VAL] jumriani [COL] kk [VAL] usman [COL] umur [VAL] 10 [COL] gender [VAL] wanita
4,6,69DB2EA1-39B1-4A75-8573-DB531B9AE45B,1,9907275000000000.0,9907275408990000,eka gustika,santoso,wanita,26,6,...,lisma sri ramadhani,purwanto,wanita,43,1,69DB2EA1-39B1-4A75-8573-DB531B9AE45B#1_1045055#2,1,[COL] nama [VAL] eka gustika [COL] kk [VAL] santoso [COL] umur [VAL] 26 [COL] gender [VAL] wanita,[COL] nama [VAL] lisma sri ramadhani [COL] kk [VAL] purwanto [COL] umur [VAL] 43 [COL] gender [VAL] wanita,[COL] nama [VAL] eka gustika [COL] kk [VAL] santoso [COL] umur [VAL] 26 [COL] gender [VAL] wanita [SEP] [COL] nama [VAL] lisma sri ramadhani [COL] kk [VAL] purwanto [COL] umur [VAL] 43 [COL] gender [VAL] wanita


In [14]:
df['pairs_len'] = df['pairs'].apply(lambda x: len(x.split(" ")))
df[['id','label','pairs','pairs_len']].sort_values('pairs_len')

Unnamed: 0,id,label,pairs,pairs_len
6525,C9B2F376-1CFC-426A-AF8C-3BE94EC32647#0_2006032#1,1,[COL] nama [VAL] paozan [COL] kk [VAL] paozan [COL] umur [VAL] 36 [COL] gender [VAL] pria [SEP] [COL] nama [VAL] subandi [COL] kk [VAL] subandi [COL] umur [VAL] 35 [COL] gender [VAL] pria,33
3366,D6921A5B-2856-4238-BD15-EB74D6ECDA75#0_1057070#1,1,[COL] nama [VAL] aldyan [COL] kk [VAL] aldyan [COL] umur [VAL] 38 [COL] gender [VAL] pria [SEP] [COL] nama [VAL] aldyan [COL] kk [VAL] aldyan [COL] umur [VAL] 38 [COL] gender [VAL] pria,33
3367,AF26AB0F-1ADE-49B0-B3CA-2D054E41404F#1_2044050#4,1,[COL] nama [VAL] remma [COL] kk [VAL] rasi [COL] umur [VAL] 48 [COL] gender [VAL] wanita [SEP] [COL] nama [VAL] suri [COL] kk [VAL] pasinring [COL] umur [VAL] 84 [COL] gender [VAL] wanita,33
3368,D1BF79AF-3C13-436A-A54C-3785F6D62C43#0_3025032#1,1,[COL] nama [VAL] sanusi [COL] kk [VAL] sanusi [COL] umur [VAL] 77 [COL] gender [VAL] pria [SEP] [COL] nama [VAL] suhadi [COL] kk [VAL] suhadi [COL] umur [VAL] 44 [COL] gender [VAL] pria,33
8728,C17038EF-3711-4991-B1A9-A3807F11B55D#0_6355375#1,1,[COL] nama [VAL] sahban [COL] kk [VAL] sahban [COL] umur [VAL] 33 [COL] gender [VAL] pria [SEP] [COL] nama [VAL] sahban [COL] kk [VAL] sahban [COL] umur [VAL] 34 [COL] gender [VAL] pria,33
...,...,...,...,...
4266,9011B5B5-457D-4C38-B9D2-41B02BB15693#1_12069077#2,1,[COL] nama [VAL] a a trisna permana s [COL] kk [VAL] a a ngurah awidya n [COL] umur [VAL] 35 [COL] gender [VAL] wanita [SEP] [COL] nama [VAL] a a trisna permana s [COL] kk [VAL] a a ngurah awidya n [COL] umur [VAL] 36 [COL] gender [VAL] wanita,49
11624,4DF19CF2-3E78-47A9-BA65-EE66DF18B6E4#0_12047055#1,1,[COL] nama [VAL] a a gde ngurah d [COL] kk [VAL] a a gde ngurah d [COL] umur [VAL] 42 [COL] gender [VAL] pria [SEP] [COL] nama [VAL] anak agung gde p d [COL] kk [VAL] anak agung gde p d [COL] umur [VAL] 42 [COL] gender [VAL] pria,49
4252,4DF19CF2-3E78-47A9-BA65-EE66DF18B6E4#3_12047055#4,1,[COL] nama [VAL] a a gde n p [COL] kk [VAL] a a gde ngurah d [COL] umur [VAL] 11 [COL] gender [VAL] pria [SEP] [COL] nama [VAL] anak agung gde n p [COL] kk [VAL] anak agung gde p d [COL] umur [VAL] 12 [COL] gender [VAL] pria,49
10335,9011B5B5-457D-4C38-B9D2-41B02BB15693#0_12069077#1,1,[COL] nama [VAL] a a ngurah awidya n [COL] kk [VAL] a a ngurah awidya n [COL] umur [VAL] 40 [COL] gender [VAL] pria [SEP] [COL] nama [VAL] a a ngurah awidya n [COL] kk [VAL] a a ngurah awidya n [COL] umur [VAL] 40 [COL] gender [VAL] pria,49


In [15]:
text_test = df.pairs.values
labels_test =  df['label'].values

In [16]:
def print_rand_sentence():
  '''Displays the tokens and respective IDs of a random text sample'''
  index = random.randint(0, len(text_test)-1)
  table = np.array([tokenizer.tokenize(text_test[index]), 
                    tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text_test[index]))]).T
  print(tabulate(table,
                 headers = ['Tokens', 'Token IDs'],
                 tablefmt = 'fancy_grid'))

print_rand_sentence()

╒══════════╤═════════════╕
│ Tokens   │   Token IDs │
╞══════════╪═════════════╡
│ [COL]    │       30521 │
├──────────┼─────────────┤
│ nama     │         712 │
├──────────┼─────────────┤
│ [VAL]    │       30522 │
├──────────┼─────────────┤
│ novi     │       22624 │
├──────────┼─────────────┤
│ ##a      │       30354 │
├──────────┼─────────────┤
│ br       │         801 │
├──────────┼─────────────┤
│ sila     │       15258 │
├──────────┼─────────────┤
│ ##lah    │         212 │
├──────────┼─────────────┤
│ ##i      │       30356 │
├──────────┼─────────────┤
│ [COL]    │       30521 │
├──────────┼─────────────┤
│ kk       │        7819 │
├──────────┼─────────────┤
│ [VAL]    │       30522 │
├──────────┼─────────────┤
│ sang     │        1749 │
├──────────┼─────────────┤
│ ##ap     │          36 │
├──────────┼─────────────┤
│ sila     │       15258 │
├──────────┼─────────────┤
│ ##lah    │         212 │
├──────────┼─────────────┤
│ ##i      │       30356 │
├──────────┼─────────────┤
│

In [17]:
def create_data(text,label,tokenizer):
    token_id = []
    attention_masks = []

    def preprocessing(input_text, tokenizer):
        '''
        Returns <class transformers.tokenization_utils_base.BatchEncoding> with the following fields:
            - input_ids: list of token ids
            - token_type_ids: list of token type ids
            - attention_mask: list of indices (0,1) specifying which tokens should considered by the model (return_attention_mask = True).
        '''
        return tokenizer.encode_plus(
                                input_text,
                                add_special_tokens = True,
                                max_length = 100,
                                padding='max_length',
                                return_attention_mask = True,
                                return_tensors = 'pt'
                        )


    for sample in text:
        encoding_dict = preprocessing(sample, tokenizer)
        token_id.append(encoding_dict['input_ids']) 
        attention_masks.append(encoding_dict['attention_mask'])


    token_id = torch.cat(token_id, dim = 0)
    attention_masks = torch.cat(attention_masks, dim = 0)
    label = torch.tensor(label)

    return token_id,attention_masks,label

In [18]:
test_token_id, test_attention_masks, test_labels = create_data(text_test,labels_test,tokenizer)

In [19]:
test_set = TensorDataset(test_token_id, 
                        test_attention_masks, 
                        test_labels)
                        
test_dataloader = DataLoader(
            test_set,
            sampler = SequentialSampler(test_set),
            batch_size = 256
        )

In [20]:
preds_all = []
label_all = []
proba_all = []


for batch in tqdm(test_dataloader):
  batch = tuple(t.to(device) for t in batch)
  b_input_ids, b_input_mask, b_labels = batch
  with torch.no_grad():
    # Forward pass
    eval_output = model(b_input_ids, 
                        token_type_ids = None, 
                        attention_mask = b_input_mask)
    
    logits = eval_output.logits.detach().cpu().numpy()
  # Calculate validation metrics
  label_ids = b_labels.to('cpu').numpy()
  
  proba = softmax(logits, axis=1)
  proba_all.extend(proba[:,1])
  
  preds = np.argmax(logits, axis = 1).flatten()
  preds = [1 if l>=0.99 else 0 for l in proba[:,1]]
  preds_all.extend(preds)
  
  label_ids = b_labels.to('cpu').numpy()
  label_all.extend(label_ids)

100%|██████████| 56/56 [01:22<00:00,  1.47s/it]


In [27]:
print("Inference time: {:.3}s".format(82/14308))

Inference time: 0.00573s


In [28]:
label_all[:5],preds_all[:5],proba_all[:5]

([1, 1, 1, 1, 1],
 [0, 0, 1, 1, 0],
 [0.0048365244, 0.0028911585, 0.99904126, 0.9991014, 0.0021986251])

In [29]:
from sklearn.metrics import classification_report, confusion_matrix

def benchmark_model(df):
    ## Evaluate in TEST data
    print(confusion_matrix(df['label'], df['pred']))
    print(classification_report(df['label'], df['pred'],digits=4))
    report = classification_report(df['label'], df['pred'],digits=4,output_dict = True)
    return report

In [30]:
cek = df.copy()
cek['pred'] = preds_all
cek['proba'] = proba_all
cek['conf'] = (np.abs(2*cek['proba']-1))
cek = cek.sort_values('conf')
target_cols = ['id','blok_pes','nik_pes','nama_pes','nama_krt_pes','jk_pes','umur_pes','blok_sp','nik_sp','nama_sp','nama_krt_sp','nama_krt_sp','jk_sp','umur_sp','label','pred','proba','conf']
cek[target_cols].head(10)
# cek[['id','pairs','label','pred','proba','conf']].head(10)

Unnamed: 0,id,blok_pes,nik_pes,nama_pes,nama_krt_pes,jk_pes,umur_pes,blok_sp,nik_sp,nama_sp,nama_krt_sp,nama_krt_sp.1,jk_sp,umur_sp,label,pred,proba,conf
6746,02B6E44D-AA38-4670-A634-443BFA6526CB#0_1029029#1,5,9907080808990000,rony rianda suhendi,rony rianda suhendi,pria,40,5,9907080808990000,subehana,subehana,subehana,pria,60,1,0,0.498441,0.003119
14112,7D2CAFEB-D066-4BEE-B833-77C9A302570F#0_4078093#1,24,9908190107990120,muh haris,muh haris,pria,71,24,9908190107990120,sudding,sudding,sudding,pria,71,1,0,0.498314,0.003372
628,17F0DCC2-BA9B-4719-82AC-ECCF4D3F82CF#0_1066066#1,10,9903203112990090,selamet riadi,selamet riadi,pria,34,10,9903203112990090,sekandar,sekandar,sekandar,pria,37,1,0,0.49809,0.003819
7940,13D747E6-FA50-4B71-81EF-D75BCB62486A#4_1010012#4,33,9971055712990000,dessy pattimahu,johanis pattimahu,wanita,13,33,9971055712990000,cheyril defretes,david defretes,david defretes,wanita,13,1,0,0.505868,0.011737
11017,658FED0F-10D6-45EF-AB6A-864368CA1C64#0_1001006#1,26,9908220107990070,cobba,cobba,pria,51,26,9908220107990070,muh ilyas,muh ilyas,muh ilyas,pria,51,1,0,0.492282,0.015436
1017,7253EC38-FDCC-4D7C-9A59-EE17201F0A9E#1_2006007#4,23,9908274107990030,andi rikayanti,faisal,wanita,28,23,9908274107990030,tutianti,taba,taba,wanita,27,1,0,0.509205,0.018409
890,36CC80D5-F0D2-47D2-893B-835912C6C84C#1_15071071#2,29,9901145603990000,fadila syukur,wa ona,wanita,22,29,9901145603990000,wa rani,jumail sapsuha,jumail sapsuha,wanita,26,1,0,0.513195,0.02639
5040,EF14E536-C243-4985-A3BF-38479693B0D7#2_2023040#3,30,9901150909990000,ismail pelu,ruslan pelu,pria,8,30,9901150909990000,all faujan wael,ibrahim wael,ibrahim wael,pria,6,1,0,0.486525,0.026951
257,031EA3FE-1537-4226-9A94-769A1F1EB794#2_2015025#2,30,9901151506990000,junaidi pelu,amirudin pelu,pria,29,30,9901151506990000,abdulla hitu,napsia pelu,napsia pelu,pria,27,1,0,0.518143,0.036287
12294,F4B29A34-37B3-4484-B3F0-4095AB7079E3#3_2039044#7,21,9908030107990010,isnul,nurjannah,pria,22,21,9908030107990010,ridwan nur,salmawati,salmawati,pria,25,1,0,0.481815,0.036371


In [34]:
6251/14308

0.4368884540117417

In [32]:
report = benchmark_model(cek)
pd.DataFrame(report).transpose()

[[   0    0]
 [6251 8057]]
              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     1.0000    0.5631    0.7205     14308

    accuracy                         0.5631     14308
   macro avg     0.5000    0.2816    0.3603     14308
weighted avg     1.0000    0.5631    0.7205     14308



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,precision,recall,f1-score,support
0,0.0,0.0,0.0,0.0
1,1.0,0.563112,0.720501,14308.0
accuracy,0.563112,0.563112,0.563112,0.563112
macro avg,0.5,0.281556,0.36025,14308.0
weighted avg,1.0,0.563112,0.720501,14308.0


In [31]:
cek[cek['id']=='FA29F4CA-1F60-4C8F-A74E-E5040E3941F4#1_2037041#2']

Unnamed: 0,blok_pes,id_ruta_pes,id_art_pes,no_kk_pes,nik_pes,nama_pes,nama_krt_pes,jk_pes,umur_pes,blok_sp,...,is_match,id,label,p1_ent,p2_ent,pairs,pairs_len,pred,proba,conf
1653,23,FA29F4CA-1F60-4C8F-A74E-E5040E3941F4,1,9908274000000000.0,9908274107990010,a muliati,syamsuddin,wanita,40,23,...,1,FA29F4CA-1F60-4C8F-A74E-E5040E3941F4#1_2037041#2,1,[COL] nama [VAL] a muliati [COL] kk [VAL] syamsuddin [COL] umur [VAL] 40 [COL] gender [VAL] wanita,[COL] nama [VAL] a nurlinda [COL] kk [VAL] a firdaus [COL] umur [VAL] 40 [COL] gender [VAL] wanita,[COL] nama [VAL] a muliati [COL] kk [VAL] syamsuddin [COL] umur [VAL] 40 [COL] gender [VAL] wanita [SEP] [COL] nama [VAL] a nurlinda [COL] kk [VAL] a firdaus [COL] umur [VAL] 40 [COL] gender [VAL] wanita,36,1,0.997669,0.995339


In [35]:
cek[target_cols].to_csv('df v0.1 - Pairs NIK check.csv',index=False)