## Import library

In [None]:
!pip install -q seqeval
!pip install -q evaluate
!pip install -q pytorch-crf

In [None]:
import os
import numpy as np
import evaluate
import pickle
import random
import tqdm
import re
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn as nn
from torchcrf import CRF
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification
from transformers import AutoModelForTokenClassification
from transformers import get_scheduler
from huggingface_hub import Repository, get_full_repo_name
from huggingface_hub import login
from accelerate import Accelerator
from tqdm.auto import tqdm
from datasets import *

## Create NER labels

In [None]:
# Entity categories, one row per thematic group: names, profession,
# locations, age, temporal expressions, contact details, identifiers.
entity = [
    'PATIENT', 'DOCTOR', 'USERNAME',
    'PROFESSION',
    'ROOM', 'DEPARTMENT', 'HOSPITAL', 'ORGANIZATION', 'STREET', 'CITY',
    'STATE', 'COUNTRY', 'ZIP', 'LOCATION-OTHER',
    'AGE',
    'DATE', 'TIME', 'DURATION', 'SET',
    'PHONE', 'FAX', 'EMAIL', 'URL', 'IPADDR',
    'SSN', 'MEDICALRECORD', 'HEALTHPLAN', 'ACCOUNT', 'LICENSE', 'VECHICLE',
    'DEVICE', 'BIOID', 'IDNUM',
]

# BIO tag list: index 0 is 'OTHER'; every entity then contributes an odd
# 'B-<entity>' id immediately followed by an even 'I-<entity>' id.
label_names = ['OTHER'] + [
    f'{prefix}-{name}' for name in entity for prefix in ('B', 'I')
]
entity_names = list(entity)
# Per-entity occurrence counter, indexed like entity_names.
entity_count = [0] * len(entity)

# Lookup tables for BIO tags (id2label/label2id) and for bare entity
# names (org_id2label/org_label2id).
id2label = dict(enumerate(label_names))
label2id = {label: idx for idx, label in id2label.items()}
org_id2label = dict(enumerate(entity_names))
org_label2id = {label: idx for idx, label in org_id2label.items()}

## Regular expressions for patterns that need to be normalized

In [None]:
# Patterns used with re.match (anchored at the start of the string only) to
# decide whether an annotated temporal expression can be normalized.
# Raw strings keep the regex escapes (\d, \/, \.) from being parsed as
# invalid Python string escapes; the pattern text itself is unchanged.

# DATE: numeric dates, bare 3/4/8-digit numbers, month-name dates,
# relative words (Today, Now, ...) and weekday names.
DATEs = r'(\d{1,2}\/\d{2,5})|(\/\d{1,2}\/(\d{2}|\d{4}))|(\d{1,2}(\/|\.| |-|,)\d{1,2}(\/|\.| |-|,)\d{2,4})|(\d{3})|(\d{4})|(\d{8})|((\d{1,2}|)( |)(January|February|March|April|May|June|July|August|September|October|November|December) \d{4})|(Today|today|Now|now|Original|original|Previous|previous)|(Sunday|Monday|Tuesday|Wednesday|Thursday|Friday|Saturday)|((\d{2}|)(-|)(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)(-| )\d{2,4})|(\d{1,2}((st)|(nd)|(rd)|(th)) of (January|February|March|April|May|June|July|August|September|October|November|December) \d{4})'

# TIME: clock times, optionally combined with a date before or after.
TIMEs = r'((\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}(  | |)|)(at|)( |)\d{1,2}(:|\.)\d{2}(AM|am|PM|pm|Hr|Hrs|hr|hrs|)( on the \d{1,2}((st)|(nd)|(rd)|(th)) of (January|February|March|April|May|June|July|August|September|October|November|December) \d{4}|))|(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})|((at |)(\d{1,2}|)(:|\.|)\d{2}( |)(am|pm|Hr|Hrs|hr|hrs|)( on | )(the |)\d{1,2}(\/|\.)\d{2,4}(\/|\.)\d{1,2})|(((\d{1,2}((pm)|(am)))|(\d{4}(Hr|Hrs|hr|hrs|)))(( on )| )\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4})'

# DURATION: a count (digits, a digit range, 'two' or 'five') and a unit.
DURATIONs = r'((\d{1,2}|\d{1,2}-\d{1,2}|two|five)( |\/|)(day|days|week|weeks|wk|wks|month|months|year|years|yr|yrs))'

# SET: only the literal word 'twice' is recognised.
SETs = 'twice'

## Create dataset

In [None]:
# Use the argument "ans" to compare the normalized string with the ground truth

# (name, number) pairs, applied in this order by _replace_months.
_MONTH_ABBREVIATIONS = (
    ('Jan', '01'), ('Feb', '02'), ('Mar', '03'), ('Apr', '04'),
    ('May', '05'), ('Jun', '06'), ('Jul', '07'), ('Aug', '08'),
    ('Sep', '09'), ('Oct', '10'), ('Nov', '11'), ('Dec', '12'),
)
_MONTH_FULL_NAMES = (
    ('January', '01'), ('February', '02'), ('March', '03'), ('April', '04'),
    ('May', '05'), ('June', '06'), ('July', '07'), ('August', '08'),
    ('September', '09'), ('October', '10'), ('November', '11'),
    ('December', '12'),
)


def _replace_months(text, table):
    """Replace every month name listed in *table* with its two-digit number."""
    for month_name, month_number in table:
        text = text.replace(month_name, month_number)
    return text


def _pad2(digits):
    """Left-pad a one-character string with '0'; return anything else as is."""
    return '0' + digits if len(digits) == 1 else digits


def _expand_year(year):
    """Expand a 2-digit year with '20' and a 3-digit year with '2'."""
    if len(year) == 2:
        return '20' + year
    if len(year) == 3:
        return '2' + year
    return year


def _dmy_to_iso(parts):
    """Turn [day, month, year, ...] string parts into 'yyyy-mm-dd'."""
    return _expand_year(parts[2]) + '-' + _pad2(parts[1]) + '-' + _pad2(parts[0])


def _to_24h(hour, pm, am):
    """Apply the am/pm flags to an hour string and zero-pad the result."""
    if pm and int(hour) < 12:
        hour = str(int(hour) + 12)
    elif am and int(hour) == 12:
        hour = '00'
    return _pad2(hour)


def _digits_to_clock(digits, pm, am):
    """Turn 1-4 clock digits ('9', '14', '930', '1430') into 'hh:mm'."""
    hh, mm = '00', '00'
    if len(digits) == 4:
        hh, mm = digits[0:2], digits[2:]
    elif len(digits) == 3:
        hh, mm = digits[0], digits[1:]
    elif len(digits) in (1, 2):
        hh = digits
    return _to_24h(hh, pm, am) + ':' + mm


def Normalize(time_type, org, ans):
    """Normalize a temporal expression into an ISO-8601-like string.

    Args:
        time_type: One of 'DATE', 'TIME', 'DURATION' or 'SET'; any other
            value yields ''.
        org: The raw expression. Patterns are tried with re.match, i.e.
            anchored at the start of the string only.
        ans: Ground-truth normalization. It is only used by the DURATION
            branch, which prints a debug line when the result differs.

    Returns:
        The normalized string ('yyyy-mm-dd', 'yyyy-mm',
        'yyyy-mm-ddThh:mm', 'P<n><unit>', 'R2', ...) or '' when no
        pattern of the requested type matches. Numeric dates are read
        day-first, and 2-/3-digit years are expanded into the 2000s.
    """
    nor = ''
    if time_type == 'DATE':
        if re.match(r'\d{1,2}(\/|\.| |-|,)\d{1,2}(\/|\.| |-|,)\d{2,4}', org):
            # day<sep>month<sep>year
            nor = _dmy_to_iso(re.split(r'\/|\.| |-|,', org))
        elif re.match(r'\/\d{1,2}\/(\d{2}|\d{4})', org):
            # '/month/year' -> 'yyyy-mm' (parts[0] is the empty prefix)
            parts = re.split(r'\/', org)
            if len(parts[1]) == 1:
                parts[1] = '0' + parts[1]
            if len(parts[2]) == 2:
                parts[2] = '20' + parts[2]
            nor = parts[2] + '-' + parts[1]
        elif re.match(r'\d{1,2}\/\d{2,5}', org):
            # Two fields only. A 2- or 4-digit second field is a year and
            # the first field is the month; a 3- or 5-digit second field
            # is a one-digit month glued to the year, and the first field
            # is then the day.
            parts = re.split(r'\/', org)
            if len(parts[0]) == 1:
                parts[0] = '0' + parts[0]
            if len(parts[1]) == 2:
                nor = '20' + parts[1] + '-' + parts[0]
            elif len(parts[1]) == 3:
                nor = '20' + parts[1][1:] + '-' + '0' + parts[1][0] + '-' + parts[0]
            elif len(parts[1]) == 4:
                nor = parts[1] + '-' + parts[0]
            elif len(parts[1]) == 5:
                nor = parts[1][1:] + '-' + '0' + parts[1][0] + '-' + parts[0]
        elif re.match(r'\d{8}', org):
            # yyyymmdd
            nor = org[0:4] + '-' + org[4:6] + '-' + org[6:8]
        elif re.match(r'\d{4}', org):
            nor = org
        elif re.match(r'\d{3}', org):
            nor = '2' + org
        elif re.match(r'(\d{2}|)(-|)(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)(-| )\d{2,4}', org):
            # '[dd-]Mon-yy[yy]' or 'Mon yyyy'
            org = _replace_months(org, _MONTH_ABBREVIATIONS)
            parts = re.split(r'-| ', org)
            if len(parts) == 2:
                nor = _expand_year(parts[1]) + '-' + parts[0]
            else:
                nor = _expand_year(parts[2]) + '-' + parts[1] + '-' + parts[0]
        elif re.match(r'\d{1,2}((st)|(nd)|(rd)|(th)) of (January|February|March|April|May|June|July|August|September|October|November|December) \d{4}', org):
            # 'Nth of Month yyyy'
            org = _replace_months(org, _MONTH_FULL_NAMES)
            parts = re.split(' ', org)
            # Strip the two-letter ordinal suffix and zero-pad the day so
            # the output matches the other branches.
            nor = parts[3] + '-' + parts[2] + '-' + _pad2(parts[0][:-2])
        elif re.match(r'(\d{1,2}|)( |)(January|February|March|April|May|June|July|August|September|October|November|December) \d{4}', org):
            # '[d[d] ]Month yyyy'
            if re.match(r'\d', org[0]) and re.match(r'\d', org[1]) is None:
                org = '0' + org  # single-digit day -> two digits
            org = _replace_months(org, _MONTH_FULL_NAMES).replace(' ', '')
            if len(org) == 6:
                # mmyyyy
                nor = org[2:] + '-' + org[0:2]
            else:
                # ddmmyyyy
                nor = org[4:] + '-' + org[2:4] + '-' + org[0:2]
    elif time_type == 'TIME':
        if re.match(r'(\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}(  | |)|)(at|)( |)\d{1,2}(:|\.)\d{2}(AM|am|PM|pm|Hr|Hrs|hr|hrs|)( on the \d{1,2}((st)|(nd)|(rd)|(th)) of (January|February|March|April|May|June|July|August|September|October|November|December) \d{4}|)', org):
            # '[date ][at ]hh:mm[am|pm|hrs][ on the Nth of Month yyyy]'
            pm = 'PM' in org or 'pm' in org
            am = 'AM' in org or 'am' in org
            date = re.search(r'\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}', org)
            if date is not None:
                date = date.group(0)
                org = org.replace(date, '')
                nor = _dmy_to_iso(re.split(r'\/|\.', date))
            else:
                # No numeric date: assemble year, month and day from the
                # spelled-out form, removing each piece from org so it is
                # not mistaken for the clock time below.
                yyyy = re.search(r'\d{4}', org)
                if yyyy is not None:
                    yyyy = yyyy.group(0)
                    org = org.replace(yyyy, '')
                    nor = yyyy + '-'
                mm = re.search('January|February|March|April|May|June|July|August|September|October|November|December', org)
                if mm is not None:
                    mm = mm.group(0)
                    org = org.replace(mm, '')
                    nor = nor + _replace_months(mm, _MONTH_FULL_NAMES) + '-'
                dd = re.search(r'\d{1,2}((st)|(nd)|(rd)|(th))', org)
                if dd is not None:
                    dd = dd.group(0)
                    org = org.replace(dd, '')
                    for suffix in ('st', 'nd', 'rd', 'th'):
                        dd = dd.replace(suffix, '')
                    nor = nor + _pad2(dd)
            clock = re.search(r'\d{1,2}(:|\.)\d{1,2}', org)
            if clock is not None:
                clock = re.split(r'\.|:', clock.group(0))
                nor = nor + 'T' + _to_24h(clock[0], pm, am) + ':' + clock[1]
            else:
                # Fallback: bare clock digits. Only the lower-case am/pm
                # markers still present in org are considered here.
                digits = re.search(r'\d{1,4}', org)
                if digits is not None:
                    nor = nor + 'T' + _digits_to_clock(digits.group(0), 'pm' in org, 'am' in org)
        elif re.match(r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}', org):
            # Already ISO-like: only the separator changes.
            nor = org.replace(' ', 'T')
        elif re.match(r'(at |)(\d{1,2}|)(:|\.|)\d{2}( |)(am|pm|Hr|Hrs|hr|hrs|)( on | )(the |)\d{1,2}(\/|\.)\d{2,4}(\/|\.)\d{1,2}', org):
            # Clock time first, numeric date afterwards.
            pm = 'pm' in org
            am = 'am' in org
            date = re.search(r'\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}', org)
            if date is not None:
                date = date.group(0)
                org = org.replace(date, '')
                nor = _dmy_to_iso(re.split(r'\/|\.', date)) + 'T'
            org = org.replace(':', '')
            digits = re.search(r'\d{1,4}', org)
            if digits is not None:
                nor = nor + _digits_to_clock(digits.group(0), pm, am)
        elif re.match(r'((\d{1,2}((pm)|(am)))|(\d{4}(Hr|Hrs|hr|hrs|)))(( on )| )\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}', org):
            # 'Ham|Hpm' or 'hhmm[hrs]' followed by a numeric date.
            pm = 'pm' in org
            am = 'am' in org
            date = re.search(r'\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}', org)
            if date is not None:
                date = date.group(0)
                org = org.replace(date, '')
                nor = _dmy_to_iso(re.split(r'\/|\.', date)) + 'T'
            hrtime = re.search(r'\d{4}', org)
            if hrtime is not None:
                hrtime = hrtime.group(0)
                org = org.replace(hrtime, '')
                nor = nor + hrtime[0:2] + ':' + hrtime[2:]
            hour = re.search(r'\d{1,2}', org)
            if hour is not None:
                nor = nor + _to_24h(hour.group(0), pm, am) + ':' + '00'
    elif time_type == 'DURATION':
        tmp = org
        org = org.replace('one', '1')
        org = org.replace('two', '2')
        org = org.replace('three', '3')
        org = org.replace('four', '4')
        org = org.replace('five', '5')
        alp = ''
        # The first d/w/m/y letter is the unit; everything before it is
        # the count (or a 'low-high' range).
        for i, ch in enumerate(org):
            if ch in 'DdWwMmYy':
                alp = ch
                org = org[:i]
                break
        org = re.split(r'-| ', org)
        if len(org) == 1 or org[1] == '':
            nor = 'P' + org[0] + alp.upper()
        else:
            # A range is reported as the (float) midpoint of its bounds.
            nor = 'P' + str((int(org[0]) + int(org[1])) / 2) + alp.upper()
        if nor != ans:
            print(f'dur:nor={nor}, ans={ans}, org={tmp}')
    elif time_type == 'SET':
        if re.match('twice', org):
            nor = 'R2'
    return nor

In [None]:
def Spilt2Words(name, f, fa, date1, date0, time1, time0, duration1, duration0, set1, set0):
    """Tokenize one text file and tag its words using the answer file.

    Answer rows are read from `fa` one line at a time. Each row is split on
    tabs and used as: [0] file id (compared with `name`), [1] entity type,
    [2]/[3] integer start/end character offsets, [4] entity text, and, for
    DATE/TIME/DURATION/SET rows, a last field [5] holding the ground-truth
    normalization (dropped again before the row is used further).

    The text in `f` is consumed one character at a time; words are split on
    space, newline and tab, and additionally at the end offset of the
    current entity. Processing stops after the last answer row belonging
    to `name`, so text following the final entity is not tokenized.

    Args:
        name: File id expected in column 0 of the answer rows.
        f: Open text file to tokenize (read/tell/seek are used).
        fa: Open answer file, positioned at the first row for `name`.
        date1, date0, time1, time0, duration1, duration0, set1, set0:
            Running counts of entity texts that do (…1) / do not (…0)
            match the corresponding module-level regex.

    Returns:
        (tok, ner, date1, date0, time1, time0, duration1, duration0,
        set1, set0): the word list, the label-id list and the updated
        counters.
    """
    tok = []
    ner = []
    # Character offsets of the current word inside the text file.
    lidx = 0
    ridx = 0
    while True:
        # remove last '\n'
        ans_info = fa.readline()[:-1].split('\t')
        # remove normalized DATE/TIME
        # NOTE(review): `nor` is computed in the branches below but never
        # used afterwards; Normalize is effectively called for its side
        # effects (the DURATION debug print) only.
        if (ans_info[1] == 'DATE'):
            if (re.match(DATEs, ans_info[4])):
                date1 += 1
                #print(f'match DATE {ans_info[4:]}')
                nor = Normalize('DATE', ans_info[4], ans_info[5])
            else:
                date0 += 1
                #print(f'miss  DATE {ans_info}')
            ans_info = ans_info[:-1]
        elif (ans_info[1] == 'TIME'):
            if (re.match(TIMEs, ans_info[4])):
                time1 += 1
                #print(f'match TIME {ans_info[4:]}')
                nor = Normalize('TIME', ans_info[4], ans_info[5])
            else:
                time0 += 1
                #print(f'miss  TIME {ans_info[4:]}')
            ans_info = ans_info[:-1]
        elif (ans_info[1] == 'DURATION'):
            if (re.match(DURATIONs, ans_info[4])):
                duration1 += 1
                #print(f'match DURATION {ans_info[4:]}')
                nor = Normalize('DURATION', ans_info[4], ans_info[5])
            else:
                duration0 += 1
                #print(f'miss  DURATION {ans_info[4:]}')
            ans_info = ans_info[:-1]
        elif (ans_info[1] == 'SET'):
            if (re.match(SETs, ans_info[4])):
                set1 += 1
                #print(f'match SET {ans_info[4:]}')
                nor = Normalize('SET', ans_info[4], ans_info[5])
            else:
                set0 += 1
                #print(f'miss  SET {ans_info[4:]}')
            ans_info = ans_info[:-1]

        # Update the module-level per-entity occurrence counter.
        if (ans_info[1] != 'OTHER'): entity_count[org_label2id[ans_info[1]]] += 1

        # Character span of the current entity in the text file.
        ent_lidx, ent_ridx = int(ans_info[2]), int(ans_info[3])

        # find next ans_info
        while True:
            word = ''
            # find next word lidx
            while True:
                nxt_char = f.read(1)
                if (nxt_char == ' ' or nxt_char == '\n' or nxt_char == '\t'): 
                    lidx += 1
                else: 
                    word += nxt_char
                    break
            ridx = lidx
            # find next word ridx
            while True:
                # Remember the position so the delimiter (or the character
                # that follows the entity end) can be un-read via seek().
                char_pos = f.tell()
                nxt_char = f.read(1)
                if (nxt_char == ' ' or nxt_char == '\n' or nxt_char == '\t' or ridx + 1 == ent_ridx):
                    ridx += 1
                    f.seek(char_pos)
                    break
                else:
                    ridx += 1
                    word += nxt_char

            line_end = 0
            # remove '\n' in last word
            # NOTE(review): `word[:-1]` is the word without its last
            # character, so this comparison looks like it was meant to be
            # `word[-1:] == '\n'`; `line_end` is also never read. Confirm.
            if (word[:-1] == '\n'): 
                line_end = 1
                word = word[:-1]
            # truncate beginning of the word if it is an entity word
            while (lidx < ent_lidx and ridx > ent_lidx and ridx <= ent_ridx):
                lidx += 1
                word = word[1:]

            tok.append(word)

            # NOTE(review): when none of the three conditions below holds
            # (lidx > ent_lidx and ridx > ent_ridx) a token is appended
            # without a label — verify that this cannot happen in practice.
            if (lidx < ent_lidx):
                ner.append(label2id['OTHER'])
            elif (lidx == ent_lidx):
                ner.append(label2id['B-' + ans_info[1]])
            elif (ridx <= ent_ridx):
                ner.append(label2id['I-' + ans_info[1]])

            lidx = ridx

            if (ridx == ent_ridx): # found the last word of entity, move to next answer info
                break

        # Peek at the next answer row without consuming it.
        info_pos = fa.tell()
        nxt_info = fa.readline()[:-1].split('\t')
        fa.seek(info_pos)
        # nxt_info is in next file
        if (nxt_info[0] != name): 
            break
        # nxt_info is in current file but has overlap in current info
        # (consume that row so it is skipped; only one such row is skipped)
        if (int(nxt_info[3]) <= ent_ridx):
            nxt_info = fa.readline()

    return tok, ner, date1, date0, time1, time0, duration1, duration0, set1, set0

In [None]:
def Segmentation(ds_id, ds_tok, ds_ner, id, tok, ner, l):
    """Cut a tokenized document into chunks and append them to the dataset.

    While at least `l` labels remain, a cut position is chosen starting at
    index `l`; if that position is inside an entity, the cut is moved in a
    randomly chosen direction until it reaches an 'OTHER' label or either
    end of the sequence. Chunks that contain no entity label are dropped.

    Args:
        ds_id, ds_tok, ds_ner: Output lists, extended in place with the
            document id, tokens and label ids of every kept chunk.
        id: Identifier stored alongside each chunk.
        tok: Tokens of the document.
        ner: Label ids aligned with `tok` (0 means 'OTHER').
        l: Target chunk length. NOTE(review): assumed positive; the
            visible caller only invokes this with a value greater than 0.

    Returns:
        None.
    """
    def contains_entity(tags):
        return any(tag != 0 for tag in tags)

    while len(ner) >= l:
        cut = l
        # One random draw per chunk decides the search direction.
        step = 1 if random.randint(0, 1) else -1
        while 0 < cut < len(ner) and id2label[ner[cut]] != 'OTHER':
            cut += step
        if cut == 0:
            # Walked back to the start: keep the remainder as one chunk.
            cut = len(ner)
        elif cut < len(ner):
            # Include the 'OTHER' token at the cut position.
            cut += 1
        head_tok, head_ner = tok[:cut], ner[:cut]
        if contains_entity(head_ner):
            ds_id.append(id)
            ds_tok.append(head_tok)
            ds_ner.append(head_ner)
        tok, ner = tok[cut:], ner[cut:]

    # Leftover shorter than `l`.
    if contains_entity(ner):
        ds_id.append(id)
        ds_tok.append(tok)
        ds_ner.append(ner)
    return

In [None]:
# Need to change path name with different directory structure

# Accumulators for the dataset columns; eval_dict is created but not filled
# in this cell.
ds_dict = {'id':[], 'tokens':[], 'ner_tags':[]}
eval_dict = {'id':[], 'tokens':[], 'ner_tags':[]}

# Text directories and their answer files, paired by position.
directories = [
    '/kaggle/input/nerdataset-phase1/First_Phase_ReleaseCorrection/First_Phase_Release(Correction)/First_Phase_Text_Dataset',
    '/kaggle/input/nerdataset-phase1/Second_Phase_Dataset/Second_Phase_Dataset/Second_Phase_Text_Dataset',
    '/kaggle/input/nerdataset-phase1/First_Phase_ReleaseCorrection/First_Phase_Release(Correction)/Validation_Release'
]
answers  = [
    '/kaggle/input/nerdataset-phase1/First_Phase_ReleaseCorrection/First_Phase_Release(Correction)/answer.txt',
    '/kaggle/input/nerdataset-phase1/Second_Phase_Dataset/Second_Phase_Dataset/answer.txt',
    '/kaggle/input/nerdataset-phase1/Validation_Dataset_Answer/answer.txt'
]

# Chunk length handed to Segmentation; a value <= 0 keeps whole documents.
max_word_length = 80

date1 ,date0, time1, time0 , duration1, duration0, set1, set0 = 0, 0, 0, 0, 0, 0, 0, 0
for directory, answer_path in zip(directories, answers):
    # Files are processed in sorted order while the answer file is read
    # sequentially. NOTE(review): this presumably relies on the answer
    # rows being grouped by file in the same order — confirm.
    fnames = sorted(os.listdir(directory))
    # Context managers make sure the answer file and every text file are
    # closed, also when parsing raises.
    with open(answer_path, 'r') as fa:
        for fname in tqdm(fnames):
            #print(fname)
            # fname[:-4] drops a four-character suffix such as '.txt'.
            with open(f'{directory}/{fname}', 'r') as f:
                tok, ner, date1, date0, time1, time0, duration1, duration0, set1, set0 = Spilt2Words(fname[:-4], f, fa, date1, date0, time1, time0, duration1, duration0, set1, set0)
            if (max_word_length > 0):
                Segmentation(ds_dict['id'], ds_dict['tokens'], ds_dict['ner_tags'], fname[:-4], tok, ner, max_word_length)
            else:
                ds_dict['id'].append(fname[:-4])
                ds_dict['tokens'].append(tok)
                ds_dict['ner_tags'].append(ner)

# Share of annotated expressions whose text matched the type's regex.
print('Time information accuarcy:')
print(f'DATE: {date1 / (date1 + date0)}')
print(f'TIME: {time1 / (time1 + time0)}')
print(f'DURATION: {duration1 / (duration1 + duration0)}')
print(f'SET: {set1 / (set1 + set0)}')

## Split train & dev data

In [None]:
def CountSim(train, valid):
    """Compare the entity-type distributions of two label collections.

    An entity is counted once per non-zero label whose name does not start
    with 'I' (i.e. once per 'B-' tag).

    Args:
        train: Iterable of label-id sequences for the training split.
        valid: Iterable of label-id sequences for the validation split.

    Returns:
        (tcnt, vcnt, dist): per-entity values for each split and the sum
        of squared differences between them. When both splits contain at
        least one entity the values are proportions; otherwise they stay
        raw counts and dist is 0.
    """
    def count_entities(samples):
        counts = [0] * len(entity)
        for tags in samples:
            for tag in tags:
                if tag != 0 and id2label[tag][0] != 'I':
                    counts[org_label2id[id2label[tag][2:]]] += 1
        return counts

    tcnt = count_entities(train)
    vcnt = count_entities(valid)
    tsum, vsum = sum(tcnt), sum(vcnt)

    dist = 0
    if tsum > 0 and vsum > 0:
        for idx in range(len(entity)):
            tcnt[idx] = tcnt[idx] / tsum
            vcnt[idx] = vcnt[idx] / vsum
            gap = abs(tcnt[idx] - vcnt[idx])
            dist += gap * gap
    return tcnt, vcnt, dist

In [None]:
# Random search for a train/dev split whose entity-type distributions are
# close: keep drawing 80/20 splits and remember the one with the smallest
# CountSim distance until it drops to upper_bound or below.
best_ds_train_valid = Dataset.from_dict(ds_dict).train_test_split(train_size=0.9)
best_tpor = [0] * len(entity)
best_vpor = [0] * len(entity)
best_dist = 1
upper_bound = 1.2e-5
try_step = 1000
while best_dist > upper_bound:
    for i in tqdm(range(try_step)):
        cur_ds_train_valid = Dataset.from_dict(ds_dict).train_test_split(train_size=0.8)
        cur_tpor, cur_vpor, cur_dist = CountSim(
            cur_ds_train_valid['train']['ner_tags'],
            cur_ds_train_valid['test']['ner_tags'],
        )
        if cur_dist >= best_dist:
            continue
        best_ds_train_valid = cur_ds_train_valid
        best_tpor, best_vpor, best_dist = cur_tpor, cur_vpor, cur_dist
        print(f'New smallest dist = {best_dist}')
        if best_dist < upper_bound:
            break

# Side-by-side bars of the per-entity proportions in both splits.
x = np.arange(len(entity_names))
width = 0.4
plt.figure(figsize=(12.8, 4.8))
plt.bar(x, best_tpor, width, color='green', label='Train')
plt.bar(x + width, best_vpor, width, color='blue', label='Dev')
plt.xticks(x + width / 2, entity_names, rotation='vertical')
plt.ylabel('Porpotion')
plt.title('TrainDev distribution')
plt.legend()
plt.savefig('TrainDev distribution')
plt.show()

In [None]:
# NOTE(review): 'train' is built from the complete ds_dict, so it also
# contains the rows held out as 'validation' — confirm this is intended.
full_train_split = Dataset.from_dict(ds_dict)
validation_split = best_ds_train_valid['test']
raw_ds = DatasetDict({'train': full_train_split, 'validation': validation_split})

In [None]:
# Count number of each entity in entire dataset
#plt.figure(figsize=(12.8, 4.8))
#plt.bar(entity_names,
#        entity_count, 
#        width=0.8, 
#        bottom=None, 
#        align='center', 
#        )
#plt.title('Entity Count')
#plt.xticks(rotation='vertical')
#plt.ylabel('Count')
#plt.savefig('Entity Count')
#plt.show()

## Tokenize data

In [None]:
# Checkpoint used for both the tokenizer and the pretrained model weights.
model_checkpoint = 'lakshyakh93/deberta_finetuned_pii'
model_name = model_checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
def align_labels_with_tokens(labels, word_ids):
    """Expand word-level label ids to one label per sub-word token.

    Args:
        labels: Label ids, one per word.
        word_ids: For each token, the index of the word it belongs to, or
            None for a token that belongs to no word.

    Returns:
        A list with one label per token: -100 for tokens without a word,
        the word's label for its first token, and for the following
        tokens of the same word the word's label with odd ids (B-XXX)
        bumped to the next id (I-XXX).
    """
    aligned = []
    previous_word = None
    for word_id in word_ids:
        if word_id is None:
            # Token outside any word: ignored by the loss.
            previous_word = None
            aligned.append(-100)
            continue
        tag = labels[word_id]
        if word_id != previous_word:
            # First token of a word keeps the word's label.
            previous_word = word_id
        elif tag % 2 == 1:
            # Continuation token of a word tagged B-XXX becomes I-XXX.
            tag += 1
        aligned.append(tag)
    return aligned

In [None]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split examples and attach aligned labels.

    Args:
        examples: Batch with a 'tokens' column (word lists) and a
            'ner_tags' column (word-level label ids).

    Returns:
        The tokenizer output with an added 'labels' entry holding one
        token-level label list per example.
    """
    encoded = tokenizer(
        examples['tokens'], truncation=True, is_split_into_words=True
    )
    encoded['labels'] = [
        align_labels_with_tokens(word_labels, encoded.word_ids(row))
        for row, word_labels in enumerate(examples['ner_tags'])
    ]
    return encoded

In [None]:
# Tokenize every split in batches; the original columns are removed so only
# the values returned by tokenize_and_align_labels remain.
tokenized_datasets = raw_ds.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=raw_ds['train'].column_names,
)

In [None]:
# Show the tokenized dataset as the notebook cell output.
tokenized_datasets

In [None]:
# Batch collator passed to both DataLoaders below (presumably pads inputs
# and labels to a common length per batch — see the transformers docs).
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
#batch = data_collator([tokenized_datasets['train'][i] for i in range(2)])
#batch['labels']

In [None]:
#for i in range(2):
#    print(tokenized_datasets['train'][i]['labels'])

## Evaluate metric

In [None]:
# Load the 'seqeval' metric through the evaluate library.
metric = evaluate.load('seqeval')

In [None]:
#labels = raw_ds['train'][0]['ner_tags']
#labels = [label_names[i] for i in labels]
#labels
#
#predictions = labels.copy()
#predictions[2] = 'OTHER'
#metric.compute(predictions=[predictions], references=[labels])

In [None]:
def compute_metrics(eval_preds):
    """Compute overall seqeval scores from logits and reference label ids.

    Args:
        eval_preds: Pair (logits, labels). The predicted label id is the
            argmax over the last logits axis; positions whose reference
            label is -100 are ignored.

    Returns:
        Dict with the overall 'precision', 'recall', 'f1' and 'accuracy'.
    """
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Reference tag names, skipping ignored positions.
    true_labels = []
    for label_row in labels:
        true_labels.append([label_names[tag] for tag in label_row if tag != -100])

    # Predicted tag names at the same (non-ignored) positions.
    true_predictions = []
    for pred_row, label_row in zip(predictions, labels):
        kept = []
        for pred, ref in zip(pred_row, label_row):
            if ref != -100:
                kept.append(label_names[pred])
        true_predictions.append(kept)

    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    summary = {}
    summary['precision'] = all_metrics['overall_precision']
    summary['recall'] = all_metrics['overall_recall']
    summary['f1'] = all_metrics['overall_f1']
    summary['accuracy'] = all_metrics['overall_accuracy']
    return summary

## Model & Training Config

In [None]:
# Can add linear/lstm layers

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class BERT_CRF(nn.Module):
    """Token-classification backbone followed by a linear layer and a CRF.

    Args:
        bert: Backbone called with input_ids / attention_mask /
            token_type_ids; the first element of its output is fed to the
            classifier, so it must have `num_labels` features on its last
            axis.
        num_labels: Number of tags handled by the CRF.
        hidden_dim: Hidden size of the bidirectional LSTM. The LSTM is
            instantiated but not used by forward().
    """

    def __init__(self, bert, num_labels, hidden_dim):
        super().__init__()
        self.bert = bert
        self.num_labels = num_labels
        self.hidden_dim = hidden_dim
        # Kept so parameters and checkpoints keep their layout, although
        # forward() currently bypasses it.
        self.lstm = nn.LSTM(input_size=self.num_labels, hidden_size=self.hidden_dim, num_layers=1, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Linear(self.num_labels, self.num_labels),
            nn.Dropout(0.1),
        )
        self.crf = CRF(self.num_labels, batch_first=True)

    def forward(self, input_ids, token_type_ids, attention_mask, labels):
        """Return the CRF loss and the decoded tag sequences for a batch.

        Args:
            input_ids, token_type_ids, attention_mask: Forwarded to the
                backbone unchanged.
            labels: Tag ids with -100 marking positions to ignore. The
                tensor is NOT modified, so callers can still filter on
                -100 afterwards.

        Returns:
            Dict with 'loss' (negated value returned by the CRF) and
            'pred' (result of CRF.decode with mask=None, i.e. one tag per
            position including padding).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        # Ignored positions get tag 0 and are switched on in the mask, as
        # before, but on fresh tensors instead of writing into the batch
        # element by element.
        # NOTE(review): this makes ignored/padded positions part of the
        # CRF likelihood with tag 0 — confirm that this is intended.
        ignored = labels == -100
        tags = labels.masked_fill(ignored, 0)
        mask = attention_mask.bool() | ignored

        loss = -self.crf(emissions=logits, tags=tags, mask=mask)
        pred = self.crf.decode(emissions=logits, mask=None)
        return {'loss':loss, 'pred':pred}

In [None]:
# Wrap the pretrained token-classification model in the CRF head.
# ignore_mismatched_sizes=True lets loading proceed when the checkpoint's
# weight shapes differ from the ones implied by this label set.
model = BERT_CRF(
    bert=AutoModelForTokenClassification.from_pretrained(
        model_checkpoint,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True
    ), 
    num_labels=len(label_names), 
    hidden_dim=128
)

# Training batches are shuffled; evaluation batches keep dataset order.
train_dataloader = DataLoader(
    tokenized_datasets['train'],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=8,
)

eval_dataloader = DataLoader(
    tokenized_datasets['validation'], collate_fn=data_collator, batch_size=8
)

optimizer = AdamW(model.parameters(), lr=2e-5)

accelerator = Accelerator()

# Let accelerate wrap model, optimizer and dataloaders before anything
# below measures the dataloader.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# Computed after prepare() so the step count reflects the length of the
# prepared dataloader.
num_train_epochs = 30
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

# 'linear' schedule over all training steps, without warmup steps.
lr_scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

## Login to push model to hub

In [None]:
# Authenticate with the Hugging Face Hub using the token stored in the
# HUGGINGFACE_TOKEN environment variable (os.getenv returns None if unset).
login(token=os.getenv("HUGGINGFACE_TOKEN"))

In [None]:
# Rebinds model_name to the name of the fine-tuned model; model_checkpoint
# was assigned earlier and still refers to the pretrained checkpoint.
model_name = 'deberta-crf-finetuned-phi'
#os.mkdir(f'./{model_name}')
#repo_name = get_full_repo_name(model_name)
#repo = Repository(model_name, clone_from=repo_name)

## Training

In [None]:
def postprocess(predictions, labels):
    """Turn id tensors into label-name lists, skipping positions labelled -100.

    Args:
        predictions: tensor of predicted label ids, one row per sequence.
        labels: tensor of gold label ids, one row per sequence; -100 marks
            positions to drop.

    Returns:
        ``(true_labels, true_predictions)`` -- note the order: gold label
        names first, predicted label names second. Each is a list of lists
        of entries from ``label_names``.
    """
    pred_rows = predictions.detach().cpu().clone().numpy()
    gold_rows = labels.detach().cpu().clone().numpy()

    # Gold side: keep every id that is not the ignore marker.
    true_labels = []
    for gold_row in gold_rows:
        kept = [label_names[gold] for gold in gold_row if gold != -100]
        true_labels.append(kept)

    # Predicted side: keep a prediction only where the gold id is kept.
    true_predictions = []
    for pred_row, gold_row in zip(pred_rows, gold_rows):
        kept = []
        for pred, gold in zip(pred_row, gold_row):
            if gold != -100:
                kept.append(label_names[pred])
        true_predictions.append(kept)

    return true_labels, true_predictions

In [None]:
# Train for num_train_epochs epochs; after each epoch evaluate on the
# validation loader and checkpoint the model whenever overall F1 improves.
progress_bar = tqdm(range(num_training_steps))
f1_score = []   # overall F1 per epoch, plotted later
max_f1 = -1
for epoch in range(num_train_epochs):
    # Training
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)

        loss = outputs['loss']
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Evaluation
    model.eval()
    for batch in eval_dataloader:
        # Snapshot the gold labels before the forward pass: forward() may
        # rewrite the -100 entries of the tensor it is given, and
        # postprocess() relies on those markers to drop ignored positions.
        labels = batch['labels'].clone()

        with torch.no_grad():
            outputs = model(**batch)

        # The decoded tags come back as Python lists; build the tensor on the
        # labels' device so it can be padded and gathered alongside them.
        predictions = torch.tensor(outputs['pred'], device=labels.device)

        # Necessary to pad predictions and labels for being gathered
        predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)
        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

        predictions_gathered = accelerator.gather(predictions)
        labels_gathered = accelerator.gather(labels)

        # postprocess returns (labels, predictions), in that order.
        true_labels, true_predictions = postprocess(predictions_gathered, labels_gathered)
        metric.add_batch(predictions=true_predictions, references=true_labels)

    results = metric.compute()
    f1_score.append(results['overall_f1'])
    print(
        f'epoch {epoch}:',
        {
            key: results[f'overall_{key}']
            for key in ['precision', 'recall', 'f1', 'accuracy']
        },
    )

    if results['overall_f1'] > max_f1:
        max_f1 = results['overall_f1']
        print(f'new max f1 score = {max_f1}\n')
        # Save and upload
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        if accelerator.is_main_process:
            # The output directory is not created anywhere else when the
            # repository clone is disabled, so make sure it exists.
            os.makedirs(model_name, exist_ok=True)
            # Pickles the whole module object (not just a state dict).
            torch.save(unwrapped_model, f'{model_name}/pytorch_model{epoch}.bin')
            tokenizer.save_pretrained(model_name)
            # Push only when a Repository object named `repo` has been set up;
            # otherwise keep the local checkpoint and carry on training.
            hub_repo = globals().get('repo')
            if hub_repo is not None:
                hub_repo.push_to_hub(
                    commit_message=f'Training in progress epoch {epoch}', blocking=False
                )

## Plot the F1 score

In [None]:
# Slashes in the model name would be read as directories in file names.
model_name = model_name.replace('/', '_')
title = f'{model_name} max word len = {max_word_length}'

# One curve: overall F1 per epoch.
plt.plot(f1_score, label="f1 score")
for set_text, text in (
    (plt.xlabel, 'epoch'),
    (plt.ylabel, 'f1 score'),
    (plt.title, title),
):
    set_text(text)
plt.legend()

# Write the figure to disk first, then display it.
plt.savefig(model_name)
plt.show()

# Keep the raw scores next to the figure, pickled under the plot title.
with open(title, "wb") as fp:
    fp.write(pickle.dumps(f1_score))

## Inference

In [None]:
# Resolve the full Hub repository name for model_name, then create a
# Repository object for the local directory model_name, cloned from it.
repo_name = get_full_repo_name(model_name)
repo = Repository(model_name, clone_from=repo_name)

In [None]:
# Load the model object used for inference from a fixed Kaggle path.
# NOTE(review): the training cell saves f'{model_name}/pytorch_model{epoch}.bin'
# (epoch-suffixed) while this path has no suffix -- confirm which file this is.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = torch.load('/kaggle/working/deberta-crf-finetuned-phi/pytorch_model.bin')

In [None]:
# Tag every test file line by line, locate each predicted entity in the raw
# text to obtain character offsets, and write the temporal entities
# (DATE / TIME / DURATION / SET) together with their normalised value to
# time.txt as tab-separated rows:
#   file id, type, start offset, end offset, surface text, normalised value
test_dir = '/kaggle/input/nerdataset-phase1/opendid_test/opendid_test'
temporal_types = ('DATE', 'TIME', 'DURATION', 'SET')
# Characters that may continue a TIME entity ending in "on": digits, '/', '.'
date_tail_char = re.compile(r'\d{1}|\/|\.')

# The context manager makes sure the output is flushed and closed even if a
# file fails half-way through.
with open('time.txt', 'w') as p:
    fnames = sorted(os.listdir(test_dir))
    for fname in tqdm(fnames):
        base = 0   # character offset of the current line within the file
        print(fname)
        NE = []    # [file id, type, start, end, surface text] per entity
        with open(f'{test_dir}/{fname}', 'r') as f:
            for line in f:
                text = line
                text_len = len(text)
                token = tokenizer(text.split(), is_split_into_words=True)
                word = token.tokens()
                # Inference only: no autograd graph is needed.
                with torch.no_grad():
                    output = model(input_ids=torch.LongTensor([token['input_ids']]).to(device),
                                   token_type_ids=torch.LongTensor([token['token_type_ids']]).to(device),
                                   attention_mask=torch.BoolTensor([token['attention_mask']]).to(device),
                                   # forward() requires labels; the all-zero
                                   # segment ids serve as a placeholder.
                                   labels=torch.LongTensor([token['token_type_ids']]).to(device)
                                  )
                pred = output['pred'][0]
                l = len(word)
                s = ''     # surface text of the entity being assembled
                find = 0   # 1 while an entity is open
                for j in range(l):
                    # An open entity ends at an untagged token or a new B- tag.
                    if find == 1 and (pred[j] == 0 or id2label[pred[j]][0] == 'B'):
                        idx = text.find(s)
                        # Check before idx is shifted by `base` and before
                        # `text` is overwritten; otherwise a miss goes
                        # unnoticed on every line but the first.
                        if idx == -1:
                            print(f'--------{text}--------')
                            raise AssertionError(
                                f'entity text {s!r} not found in {fname}'
                            )
                        tmp = idx
                        # Blank out everything up to the end of this match so
                        # a later identical string is found further right;
                        # the line length is unchanged.
                        text = (idx + len(s)) * '$' + text[idx + len(s):]
                        idx += base
                        ent_type = id2label[pred[j - 1]][2:]
                        if ent_type == 'TIME' and s[-2:] == 'on':
                            # A TIME ending in "on": also take the run of
                            # digits, '/' and '.' that follows the space.
                            s = s + ' '
                            tmp += len(s)
                            while len(text) > tmp and date_tail_char.match(text[tmp]):
                                s = s + text[tmp]
                                tmp += 1
                            print('on')
                        NE.append([fname[:-4], ent_type, idx, idx + len(s), s])
                        s = ''
                        find = 0
                    if pred[j] > 0:
                        # NOTE(review): this literal looks like a mis-decoded
                        # U+FFFD replacement character -- confirm against the
                        # tokens the tokenizer actually produces.
                        if word[j] == 'ï¿½':
                            break
                        s += word[j]
                        find = 1
                        lidx = text.find(s)
                        ridx = lidx + len(s)
                        # If the next token continues the entity, re-insert
                        # the separator found in the raw text. The j + 1 bound
                        # keeps the look-ahead inside the sequence.
                        if j + 1 < l and len(text) > ridx and id2label[pred[j + 1]][0] == 'I':
                            if text.find(s + ' ' + word[j + 1]) != -1:
                                s += ' '
                            elif text.find(s + '  ' + word[j + 1]) != -1:
                                s += '  '
                            elif text.find(s + '\n' + word[j + 1]) != -1 and text[ridx] == '\n':
                                s += '\n'

                base += text_len

        for ent in NE:
            # Trim one trailing space, then one trailing full stop, shrinking
            # the end offset to match. endswith() is safe on empty strings.
            if ent[4].endswith(' '):
                ent[3] -= 1
                ent[4] = ent[4][:-1]
            if ent[4].endswith('.'):
                ent[3] -= 1
                ent[4] = ent[4][:-1]
            if len(ent[4]) == 0:
                continue
            # Only temporal entities are written to this file.
            if ent[1] in temporal_types:
                surface = ent[4]
                nor = Normalize(ent[1], surface, '')
                if ent[1] == 'DATE' and surface.endswith(')'):
                    # Drop a trailing ')' from the written text and normalise
                    # again without it.
                    # NOTE(review): the end offset is not reduced here --
                    # confirm that is what the output format expects.
                    surface = surface[:-1]
                    nor = Normalize(ent[1], surface, '')
                p.write(f'{ent[0]}\t{ent[1]}\t{ent[2]}\t{ent[3]}\t{surface}\t{nor}\n')

## References
- https://huggingface.co/docs/transformers/tasks/token_classification
- https://waynestalk.com/python-bar-charts/
- https://www.kaggle.com/discussions/general/65351
- https://pytorch-crf.readthedocs.io/en/stable/
- https://www.shenxiaohai.me/pytorch-tutorial-intermediate-04/
- https://regex101.com/