In [None]:
# ====================================================
# Directory settings
# ====================================================
import os

VER = 3
OUTPUT_DIR = f'./AI_CUP_{VER}'
# The training cells save the model config and checkpoints under
# OUTPUT_DIR + '/model/', so that sub-directory must exist as well;
# exist_ok makes the cell safe to re-run.
os.makedirs(os.path.join(OUTPUT_DIR, 'model'), exist_ok=True)

In [None]:
# ====================================================
# Library
# ====================================================
import os
import gc
import re
import ast
import sys
import copy
import json
import time
import math
import string
import pickle
import random
import joblib
import itertools
import warnings
warnings.filterwarnings("ignore")

import scipy as sp
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
from tqdm.auto import tqdm
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

# os.system('pip install iterative-stratification==0.1.7')
# from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset, Features, Value


os.system('pip install -q transformers')
os.system('pip install -q tokenizers')
import tokenizers
import transformers
print(f"tokenizers.__version__: {tokenizers.__version__}")
print(f"transformers.__version__: {transformers.__version__}")
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
%env TOKENIZERS_PARALLELISM=true

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
from torch.cuda.amp import autocast, GradScaler
from sklearn import metrics
from src.machine_learning_util import set_seed, set_device, init_logger, AverageMeter, to_pickle, unpickle, asMinutes, timeSince

In [None]:
class CFG:
    """Static experiment configuration read by the data, model and training cells."""

    # --- experiment identity / switches -------------------------------
    EXP_ID = '024'
    seed = 2022          # earlier runs: 42, 71
    debug = False
    train = True

    # --- backbone and tokenisation ------------------------------------
    model = 'microsoft/deberta-v3-large'  # alternatives tried: deberta-large, deberta-v3-base
    max_len = 512        # placeholder; overwritten once the longest training text is measured
    dropout = 0
    freezing = True
    gradient_checkpoint = True
    reinit_layers = 4    # earlier runs: 3
    tokenizer = AutoTokenizer.from_pretrained(model)  # loaded once, at class-definition time

    # --- targets -------------------------------------------------------
    target_cols = "label"
    target_size = None   # filled in after the label list is defined

    # --- cross-validation ----------------------------------------------
    n_splits = 4
    n_fold = 4
    trn_fold = list(range(n_fold))

    # --- optimisation --------------------------------------------------
    apex = True          # toggles autocast / GradScaler in the training loop
    epochs = 2
    batch_size = 6       # earlier runs: 2, 4
    n_accumulate = 1     # gradient-accumulation steps
    lr = 5e-6            # earlier runs: 3e-6
    min_lr = 1e-6
    weigth_decay = 0.01  # (sic) the misspelt name is what the optimizer setup reads
    max_norm = 1         # gradient-clipping norm
    scheduler = 'cosine'
    num_warmup_steps = 0
    num_cycles = 0.5

    # --- logging / evaluation cadence and data loading -------------------
    print_freq = 100
    eval_freq = 8500     # validate every N training steps; earlier runs: 780 * 2, 390, 170
    num_workers = 0      # earlier runs: 3

In [None]:
# ====================================================
# Utils
# ====================================================

def get_logger(filename=OUTPUT_DIR+'_train'):
    """Return the notebook logger, writing to both the console and ``<filename>.log``.

    Args:
        filename: Log file path without the ``.log`` extension.

    Returns:
        The ``logging.Logger`` registered under this module's ``__name__``,
        set to INFO level with exactly one stream handler and one file handler.
    """
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)

    # getLogger() returns the same object on every call, so re-running this
    # cell used to stack another pair of handlers and duplicate every message.
    # Drop (and close) whatever is attached before adding fresh handlers.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = get_logger()


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, so it
    # has to be off for reproducible results.
    torch.backends.cudnn.benchmark = False

def find_all(a_string, sub):
    """Locate every non-overlapping occurrence of ``sub`` in ``a_string``.

    Args:
        a_string: Text to search.
        sub: Substring to look for.

    Returns:
        A list of ``"start end"`` strings (character offsets, end exclusive),
        in order of appearance. Empty when there is no match or ``sub`` is empty.
    """
    # str.find("") matches at every position without advancing, which made
    # the original loop spin forever; an empty needle has no useful span.
    if not sub:
        return []

    result = []
    k = a_string.find(sub)
    while k != -1:
        end = k + len(sub)
        result.append(f"{k} {end}")
        # Resume after the match so overlapping occurrences are not reported.
        k = a_string.find(sub, end)
    return result
    
# Seed all RNGs with the configured seed before any data is loaded.
seed_everything(CFG.seed)
# seed_everything(seed=60)

# Data Loading

In [None]:
# Root folder of the de-identification dataset.
data_path = "./NER_Dataset"

# Annotation ("answer") files, one per data release.
train_anno_file_path = f"{data_path}/First_Phase_Text_Dataset_answer/answer.txt"
train_anno_file_path_2 = f"{data_path}/Second_Phase_Text_Dataset_answer/answer.txt"
val_anno_file_path = f"{data_path}/Validation_Release_answer/answer.txt"

# Folders holding the raw report .txt files for the same releases.
train_report_file_path = f"{data_path}/First_Phase_Text_Dataset/"
train_report_file_path_2 = f"{data_path}/Second_Phase_Text_Dataset/"
val_report_file_path = f"{data_path}/Validation_Release"




def read_file(anno_file_path):
    """Read a UTF-8 text file (a leading BOM is dropped) and return its lines.

    Args:
        anno_file_path: Path of the file to read.

    Returns:
        List of lines, each keeping its trailing newline.
    """
    # The context manager closes the handle; the original left it open,
    # leaking one descriptor per report file.
    with open(anno_file_path, 'r', encoding="UTF-8-sig") as file:
        return file.readlines()

def process_annotation_file(lines):
    """Group tab-separated annotation lines by report name.

    Each line has 5 fields ``name, phi, start, end, entity`` or 6 fields with
    an extra trailing normalised value.

    Args:
        lines: Raw lines of an answer file.

    Returns:
        Dict mapping report name to a list of annotation dicts with keys
        ``phi``, ``st_idx``, ``ed_idx``, ``entity`` and, for 6-field lines,
        ``normalize_time``; annotations keep file order.
    """
    entity_dict = {}
    for line in lines:
        items = line.strip('\n').split('\t')
        # Blank or malformed lines used to re-append the previous line's
        # annotation (or raise NameError on the first line); skip them instead.
        if len(items) not in (5, 6):
            continue
        item_dict = {'phi' : items[1],
                     'st_idx' : int(items[2]),
                     'ed_idx' : int(items[3]),
                     'entity' : items[4]}
        if len(items) == 6:
            item_dict['normalize_time'] = items[5]
        entity_dict.setdefault(items[0], []).append(item_dict)
    return entity_dict

def process_medical_report(txt_name, report_file_path, annos_dict,special_tokens_dict): # annos_dict: gold annotations
    """Turn one report into per-line ``"<line>[SEP]<labels>"`` strings.

    The report is scanned character by character. Whenever the position equals
    the ``st_idx`` of the next pending annotation, a ``"PHI:entity"`` (or
    ``"PHI:entity=>normalized"``) line is queued. At every newline the queued
    labels, or ``"PHI:NULL"`` if there are none, are attached to the text of
    the line that just ended.

    Args:
        txt_name: Report file name without the ``.txt`` extension.
        report_file_path: Directory containing the report file.
        annos_dict: Mapping from report name to its annotation dicts.
        special_tokens_dict: Currently unused; kept for call compatibility.

    Returns:
        List of ``line + '[SEP]' + labels`` strings, one per newline-terminated line.

    NOTE(review): annotations are consumed strictly in list order, so they are
    assumed to be sorted by ``st_idx``. Text after the final newline is never
    emitted. Confirm both against the data.
    """
    file_name = txt_name+'.txt'
    sents = read_file(os.path.join(report_file_path, file_name))

    article = "".join(sents) # full report text

    # A report with no entry in the answer file used to raise KeyError; treat
    # it as a report whose lines are all PHI:NULL.
    annos = annos_dict.get(txt_name, [])
    last_item = len(annos) - 1

    boundary, item_idx, temp_seq, seq_pairs = 0, 0, "", []

    for w_idx, word in enumerate(article):
        if annos and w_idx == annos[item_idx]['st_idx']: # reached the next annotation start
            anno = annos[item_idx]
            phi_key = anno['phi'] # PHI category
            phi_value = anno['entity'] # entity text

            if 'normalize_time' in anno:
                temp_seq += f"{phi_key}:{phi_value}=>{anno['normalize_time']}\n"
            else:
                temp_seq += f"{phi_key}:{phi_value}\n"
            if item_idx == last_item: # that was the last annotation: keep the index in range
                continue
            item_idx += 1 # move on to the next annotation
        if word == '\n':
            new_line_idx = w_idx + 1
            if temp_seq == "": # no annotation on this line
                temp_seq = "PHI:NULL"
            seq_pair = article[boundary:new_line_idx] +'[SEP]' +temp_seq
            boundary = new_line_idx
            seq_pairs.append(seq_pair)
            temp_seq = "" # reset for the next line
    return seq_pairs

def generate_annotated_medical_report(anno_file_path, report_file_path):
    """Build the per-line text/label strings for every report in a folder.

    Args:
        anno_file_path: Path of the tab-separated answer file.
        report_file_path: Directory containing the ``.txt`` report files.

    Returns:
        Dict mapping report name (file name without ``.txt``) to the list
        returned by ``process_medical_report`` for that report.
    """
    anno_lines = read_file(anno_file_path)
    annos_dict = process_annotation_file(anno_lines)

    # - splitext keeps names that themselves contain dots intact
    #   (``f.split('.')[0]`` truncated them);
    # - only ``.txt`` files are reports, so stray files are ignored;
    # - sorting makes the sample order independent of the OS directory
    #   order, which os.listdir leaves arbitrary.
    report_name_list = sorted(
        os.path.splitext(f)[0]
        for f in os.listdir(report_file_path)
        if f.endswith('.txt') and os.path.isfile(os.path.join(report_file_path, f))
    )

    special_tokens_dict = {"bos_token":"<|endoftext|>","sep_token":"\n####\n\n","eos_token":"<|END|>"}

    pocessed_files = {}
    for txt_name in report_name_list:
        pocessed_files[txt_name] = process_medical_report(
            txt_name, report_file_path, annos_dict, special_tokens_dict)

    return pocessed_files
    
    
    

In [None]:
# Per-report lists of "<line>[SEP]<labels>" strings for the first training release.
train_report = generate_annotated_medical_report(train_anno_file_path,train_report_file_path)

# Same, for the second training release.
train_report_2 = generate_annotated_medical_report(train_anno_file_path_2,train_report_file_path_2)

# Same, for the validation release.
val_report = generate_annotated_medical_report(val_anno_file_path,val_report_file_path)

In [None]:
# Flatten the per-report line lists into one list per split.
# NOTE(review): the validation reports are also added to the training list,
# so the validation loss is measured on data the model has trained on and
# cannot be used for unbiased model selection. Confirm this is intended
# (e.g. a final fit on all available data).
train_list = [
    pair
    for report in (train_report, train_report_2, val_report)
    for pairs in report.values()
    for pair in pairs
]
val_list = [pair for pairs in val_report.values() for pair in pairs]

In [None]:
# One row per report line; rows that are an empty line with no PHI are dropped
# and the index is renumbered.
train_df = pd.DataFrame(train_list, columns=["content_label"])
train_df = train_df[train_df["content_label"] != "\n[SEP]PHI:NULL"].reset_index(drop=True)

val_df = pd.DataFrame(val_list, columns=["content_label"])
val_df = val_df[val_df["content_label"] != "\n[SEP]PHI:NULL"].reset_index(drop=True)

In [None]:
def _split_content_label(row):
    """Split a row's ``content_label`` into ``(text, label)`` at ``"\\n[SEP]"``.

    Raises ValueError when the separator does not occur exactly once.
    """
    text, label = row["content_label"].split("\n[SEP]")
    return text, label


def split_text(df):
    """Return the text part (before the separator) of one row."""
    return _split_content_label(df)[0]


def split_label(df):
    """Return the label part (after the separator) of one row."""
    return _split_content_label(df)[1]


# Derive the "text" and "label" columns for both splits ...
for frame in (train_df, val_df):
    frame["text"] = frame.apply(split_text, axis=1)
    frame["label"] = frame.apply(split_label, axis=1)

# ... then discard the combined source column.
train_df = train_df.drop(columns=['content_label'])
val_df = val_df.drop(columns=['content_label'])


In [None]:
train_df.head()

In [None]:
val_df.head()

# Load ChatGPT-generated files

In [None]:
gpt_label_text_list = []
gpt_text_list = []

gpt_data_path = "./NER_Dataset/CHATGPT/VER-5"

country_label_text_list = read_file(f'{gpt_data_path}/COUNTRY_VER-5/country.txt')
organization_text_list = read_file(f'{gpt_data_path}/ORGANIZATION_VER-5/Corporation_INC_company/ORGANIZATION_text.txt')
organization_label_text_list = read_file(f'{gpt_data_path}/ORGANIZATION_VER-5/Corporation_INC_company/ORGANIZATION_label.txt')

organization_label_text_list_2 = read_file(f'{gpt_data_path}/ORGANIZATION_VER-5/ORGANIZATION_label_text.txt')

LOCATION_OTHER_label_text_list_1 = read_file(f'{gpt_data_path}/LOCATION-OTHER_VER-5/Correct-Format-LOCATION-OTHER_FORMAT-1.txt')

PHONE_label_text_list_1 = read_file(f'{gpt_data_path}/PHONE_VER-5/PHONE_text-label_FORMAT-1.txt')
PHONE_label_text_list_2 = read_file(f'{gpt_data_path}/PHONE_VER-5/PHONE_text-label_FORMAT-2.txt')


def _add_gpt_sample(phi_name, entity, text):
    """Append one synthetic sample; the label always ends with a newline."""
    gpt_label_text_list.append(f"{phi_name}:{entity}\n")
    gpt_text_list.append(text)


def _add_entity_colon_text_lines(lines, phi_name):
    """Add samples from lines formatted as ``"<entity>: <text>"``.

    Blank lines and lines without the ``': '`` separator are skipped (they used
    to raise IndexError). The text is everything after the FIRST separator, so
    a sentence that itself contains ``': '`` is no longer cut short.
    """
    for line in lines:
        line = line.strip('\n')
        entity, separator, text = line.partition(': ')
        if not separator:
            if line.strip():
                print(f"Skipped malformed {phi_name} line: {line!r}")
            continue
        _add_gpt_sample(phi_name, entity, text)


# COUNTRY: "<entity>: <text>" lines.
_add_entity_colon_text_lines(country_label_text_list, "COUNTRY")

# ORGANIZATION, format 1: texts and entities live in two line-aligned files.
if len(organization_label_text_list) < len(organization_text_list):
    raise ValueError("ORGANIZATION_label.txt has fewer lines than ORGANIZATION_text.txt")
for organization_text, organization_label in zip(organization_text_list, organization_label_text_list):
    text = organization_text.strip('\n')
    entity = organization_label.strip('\n')
    if not text or not entity:
        continue
    # Stripping and re-appending the newline guarantees the label terminator
    # even when the file's last line has none; without it the label was later
    # treated as "no PHI".
    _add_gpt_sample("ORGANIZATION", entity, text)

# ORGANIZATION, format 2: "<entity>: <text>" lines.
_add_entity_colon_text_lines(organization_label_text_list_2, "ORGANIZATION")

# LOCATION-OTHER: each line is both the text and the entity.
for LOCATION_OTHER_text in LOCATION_OTHER_label_text_list_1:
    text = LOCATION_OTHER_text.strip('\n')
    if not text:
        # An empty entity would make the substring search downstream meaningless.
        continue
    _add_gpt_sample("LOCATION-OTHER", text, text)

# PHONE: two files, both "<entity>: <text>" lines.
_add_entity_colon_text_lines(PHONE_label_text_list_1, "PHONE")
_add_entity_colon_text_lines(PHONE_label_text_list_2, "PHONE")


In [None]:
GPT_train_df= pd.DataFrame(data={'text':gpt_text_list,'label':gpt_label_text_list})
GPT_train_df.head()

In [None]:
len(GPT_train_df)

In [None]:
# Append the ChatGPT-generated samples to the training frame.
# NOTE(review): this rebinds train_df from its own previous value, so running
# the cell twice appends the synthetic rows twice; re-run from the cell that
# builds train_df instead.
train_df = pd.concat([train_df, GPT_train_df], ignore_index=True)

In [None]:
# PHI category names; a category's position in this list is its class id.
# "PHI" (last) is the class used for tokens that carry no entity.
label_name_list = [
    "PATIENT", "DOCTOR", "USERNAME", "PROFESSION", "ROOM", "DEPARTMENT", "HOSPITAL",
    "ORGANIZATION", "STREET", "CITY", "STATE", "COUNTRY", "ZIP", "LOCATION-OTHER",
    "AGE", "DATE", "TIME", "DURATION", "SET", "PHONE", "FAX", "EMAIL", "URL", "IPADDR",
    "SSN", "MEDICALRECORD", "HEALTHPLAN", "ACCOUNT", "LICENSE", "VECHICLE", "DEVICE",
    "BIOID", "IDNUM", "PHI",
]

# Forward and reverse lookup tables between category name and class id.
label_to_id = {name: idx for idx, name in enumerate(label_name_list)}
id_to_label = {idx: name for name, idx in label_to_id.items()}

CFG.target_size = len(label_name_list)

# content max len

In [None]:
# ====================================================
# Define max_len
# ====================================================
# Token count of every training text, without special tokens.
lengths = [
    len(CFG.tokenizer(text, add_special_tokens=False)['input_ids'])
    for text in tqdm(train_df['text'].fillna("").values, total=len(train_df))
]
longest_with_special = max(lengths) + 2 # room for cls & sep
print(longest_with_special)
CFG.max_len = longest_with_special

# LOGGER.info(f"input column max_len: {CFG.max_len}")

# 找尋子字串的方法

In [None]:
result = find_all("SWAN HILL DISTRICT HEALTH [SWAN HILL","SWAN")
result = ";".join(result)
print(result)

# Label & Annotation

In [None]:
all_text = []
all_annotations = []
all_label_list = []

all_label_text = []
all_annotation_length = []
for ind in train_df.index:
    annotation_length = 0 # annotation個數
    sub_text_list = []
    annotation = []
    label_list = []
    text = train_df['text'][ind] 
    label_text = train_df['label'][ind]
    if "\n" in label_text:
        label_text_list = label_text.split("\n") # 分開多個label
        label_text_list.remove("")

        for label_text in label_text_list:
            if(label_text.split(":")[0] == "TIME"):
                label_id_name = "TIME"
                temp_text_label, _ = label_text.split("=>")

                _ , sub_text=temp_text_label.split("TIME:")

            elif(label_text.split(":")[0] == "DATE"):
            # if len(label_text.split(":")) !=2:
                label_id_name = "DATE"

                temp_text_label, _ = label_text.split("=>")

                
                _ , sub_text=temp_text_label.split("DATE:")
            elif(label_text.split(":")[0] == "URL"):

                label_id_name = "URL"

                _ , sub_text=label_text.split("URL:")
            elif(label_text.split(":")[0] == "DURATION"):
                
                label_id_name = "DURATION"

                temp_text_label, _ = label_text.split("=>")

                
                _ , sub_text=temp_text_label.split("DURATION:")

            elif(label_text.split(":")[0] == "SET"):
                
                label_id_name = "SET"

                temp_text_label, _ = label_text.split("=>")

                
                _ , sub_text=temp_text_label.split("SET:")            
                
            else :
                label_id_name, sub_text = label_text.split(":") # label子字串是哪種PHI, text中的label子字串

            label_id = label_to_id[label_id_name]

            result = find_all(text, sub_text)

            if(len(result)==0):
                print(f"Index:{ind},PHI:{label_id_name}, text:{text},sub_text:{sub_text}, do not match")

            else:

                annotation_length += len(result) # 計算annotation個數

                for sub_result in result:
                    sub_text_list.append(sub_text)
                    annotation.append(sub_result)
                    label_list.append(label_id)

            
    else :
        # print("PHI NULL")
        sub_text_list.append("NULL")
        annotation.append("NULL")
        label_list.append(label_to_id["PHI"]) # if sub_text not match, label connect PHI

    all_text.append(text)
    all_label_text.append(sub_text_list)
    all_annotations.append(annotation)
    all_annotation_length.append(annotation_length)
    all_label_list.append(label_list)
        

In [None]:
all_val_text = []
all_val_annotations = []
all_val_label_list = []
all_val_label_text = []
all_val_annotation_length = []
for ind in val_df.index:
    annotation_length = 0 # annotation個數
    sub_text_list = []
    annotation = []
    label_list = []
    text = val_df['text'][ind] 
    label_text = val_df['label'][ind]
    if "\n" in label_text:
        label_text_list = label_text.split("\n") # 分開多個label
        label_text_list.remove("")

        for label_text in label_text_list:
            if(label_text.split(":")[0] == "TIME"):
                label_id_name = "TIME"
                temp_text_label, _ = label_text.split("=>")

                _ , sub_text=temp_text_label.split("TIME:")

            elif(label_text.split(":")[0] == "DATE"):
            # if len(label_text.split(":")) !=2:
                label_id_name = "DATE"

                temp_text_label, _ = label_text.split("=>")

                
                _ , sub_text=temp_text_label.split("DATE:")
            elif(label_text.split(":")[0] == "URL"):

                label_id_name = "URL"

                _ , sub_text=label_text.split("URL:")
            elif(label_text.split(":")[0] == "DURATION"):
                
                label_id_name = "DURATION"

                temp_text_label, _ = label_text.split("=>")

                
                _ , sub_text=temp_text_label.split("DURATION:")

            elif(label_text.split(":")[0] == "SET"):
                
                label_id_name = "SET"

                temp_text_label, _ = label_text.split("=>")

                
                _ , sub_text=temp_text_label.split("SET:")            
                
            else :
                label_id_name, sub_text = label_text.split(":") # label子字串是哪種PHI, text中的label子字串

            label_id = label_to_id[label_id_name]

            result = find_all(text, sub_text)

            if(len(result)==0):
                print(f"Index:{ind},PHI:{label_id_name}, text:{text},sub_text:{sub_text}, do not match")

            else:

                annotation_length += len(result) # 計算annotation個數

                for sub_result in result:
                    sub_text_list.append(sub_text)
                    annotation.append(sub_result)
                    label_list.append(label_id)
            
    else :
        # print("PHI NULL")
        sub_text_list.append("NULL")
        annotation.append("NULL")
        label_list.append(label_to_id["PHI"]) # if sub_text not match, label connect PHI

    all_val_text.append(text)
    all_val_label_text.append(sub_text_list)
    all_val_annotations.append(annotation)
    all_val_annotation_length.append(annotation_length)
    all_val_label_list.append(label_list)
        

# 創建Train Dataframe

In [None]:
# Assemble the training frame column by column from the parallel lists.
new_train_df = pd.DataFrame({
    'text': all_text,
    'label_text': all_label_text,
    'label': all_label_list,
    'annotation': all_annotations,
    'annotation_length': all_annotation_length,
})

In [None]:
print("原本train dataframe筆數:",len(train_df))
print("新train dataframe筆數:",len(new_train_df))

# 計算訓練集中各PHI數量

In [None]:
# Count how many annotated spans each PHI category has in the training frame.
# The previous loop walked the row's label list (duplicates included) while
# adding the per-row count each time, so a label occurring n times in a row
# contributed n*n instead of n.
# Original author's note (translated; its exact meaning is not evident from
# this cell): DATE and DEPARTMENT are 2 short, DOCTOR is 4 short,
# ORGANIZATION is 1 short.
all_label_count = {label_id: 0 for label_id in id_to_label}
for label_id_list in new_train_df['label']:
    for label_id in label_id_list:
        all_label_count[label_id] += 1

# Re-key the counts by category name for readability.
all_label_count = {id_to_label[k] : v for k, v in all_label_count.items()}


# 創建Val Dataframe

In [None]:
# Assemble the validation frame column by column from the parallel lists.
new_val_df = pd.DataFrame({
    'text': all_val_text,
    'label_text': all_val_label_text,
    'label': all_val_label_list,
    'annotation': all_val_annotations,
    'annotation_length': all_val_annotation_length,
})

In [None]:
new_val_df.head()

In [None]:
print("原本val dataframe筆數:",len(val_df))
print("新val dataframe筆數:",len(new_val_df))

# Train Dataset

In [None]:
# ====================================================
# Dataset
# ====================================================
def prepare_input(cfg, text):
    """Tokenize ``text`` into fixed-length model inputs.

    Args:
        cfg: Config object providing ``tokenizer`` and ``max_len``.
        text: Raw input string.

    Returns:
        The tokenizer output with every field converted to a ``torch.long`` tensor.
    """
    # max_len comes from the cfg argument (the global CFG was used before).
    # Truncation guarantees that every sample has exactly max_len tokens, so
    # samples can always be stacked into a batch.
    inputs = cfg.tokenizer(text,
                           add_special_tokens=True,
                           max_length=cfg.max_len,
                           padding="max_length",
                           truncation=True,
                           return_offsets_mapping=False)

    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    return inputs


def create_label(cfg, text, label_list, location_list, annotation_length):
    """Build the per-token class-id target for one text.

    Args:
        cfg: Config object providing ``tokenizer`` and ``max_len``.
        text: Raw input string.
        label_list: Class id of each annotated span.
        location_list: ``"start end"`` character spans, parallel to ``label_list``.
        annotation_length: Number of located spans; 0 means the text has no entity.

    Returns:
        Float tensor with one value per token: -1 for positions whose sequence
        id is not 0 (special and padding tokens), the ``PHI`` class id for
        tokens outside any span, and the span's class id for tokens inside it.
    """
    encoded = cfg.tokenizer(text,
                            add_special_tokens=True,
                            max_length=cfg.max_len,
                            padding="max_length",
                            truncation=True,
                            return_offsets_mapping=True)
    offset_mapping = encoded['offset_mapping']

    # Positions that do not belong to the text itself are excluded from the loss.
    ignore_idxes = np.where(np.array(encoded.sequence_ids()) != 0)[0]

    # Every token starts out as "no entity" ...
    label = np.full(len(offset_mapping), label_to_id["PHI"], dtype=np.float64)
    label[ignore_idxes] = -1

    # ... and the annotated spans are painted on top.
    if annotation_length != 0:
        for step, location in enumerate(location_list):
            start_idx = -1
            end_idx = -1

            start, end = (int(pos) for pos in location.split(' '))

            for idx in range(len(offset_mapping)):
                # First token that begins after the span start: the span
                # starts in the token just before it.
                if (start_idx == -1) and (start < offset_mapping[idx][0]):
                    start_idx = idx - 1
                # First token that reaches the span end (slice end is exclusive).
                if (end_idx == -1) and (end <= offset_mapping[idx][1]):
                    end_idx = idx + 1

            # No token begins after the span start, i.e. the span starts inside
            # the last text token, which is the one just before end_idx. Using
            # end_idx itself gave an empty slice and left the span unlabelled.
            if (start_idx == -1) and (end_idx != -1):
                start_idx = end_idx - 1

            if (start_idx != -1) and (end_idx != -1):
                label[start_idx:end_idx] = label_list[step]

    return torch.tensor(label, dtype=torch.float)


class TrainDataset(Dataset):
    """Map-style dataset yielding tokenized text with per-token class-id targets."""

    def __init__(self, cfg, df):
        """Keep the config and the needed columns of ``df`` as arrays."""
        self.cfg = cfg

        self.text = df['text'].values
        self.label_text = df['label_text'].values
        self.label = df['label'].values
        self.locations = df['annotation'].values
        self.annotation_lengths = df['annotation_length'].values

    def __len__(self):
        """Number of samples."""
        return len(self.text)

    def __getitem__(self, item):
        """Return ``input_ids``, ``attention_mask`` and ``label`` for one sample."""
        sample_text = self.text[item]

        encoded = prepare_input(self.cfg, sample_text)
        token_labels = create_label(
            self.cfg,
            sample_text,
            self.label[item],
            self.locations[item],
            self.annotation_lengths[item],
        )

        return dict(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            label=token_labels,
        )
        # return inputs, label

In [None]:
def freeze(module):
    """
    Turn off gradient tracking for every parameter of ``module``.
    """
    for param in module.parameters():
        param.requires_grad_(False)

def get_scheduler(cfg, optimizer, num_train_steps):
    """Create the learning-rate scheduler selected by ``cfg.scheduler``.

    Args:
        cfg: Config providing ``scheduler`` ('linear' or 'cosine'),
            ``num_warmup_steps`` and, for 'cosine', ``num_cycles``.
        optimizer: Optimizer whose learning rate is scheduled.
        num_train_steps: Total number of scheduler steps.

    Returns:
        The scheduler built by the matching ``transformers`` helper.

    Raises:
        ValueError: If ``cfg.scheduler`` is neither 'linear' nor 'cosine'
            (previously this surfaced as an UnboundLocalError).
    """
    if cfg.scheduler == 'linear':
        return get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps
        )
    if cfg.scheduler == 'cosine':
        return get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps, num_cycles=cfg.num_cycles
        )
    raise ValueError(f"Unsupported scheduler {cfg.scheduler!r}; expected 'linear' or 'cosine'")


In [None]:
class NER_Model(nn.Module):
    """Transformer backbone followed by a LayerNorm + Linear token-classification head."""

    def __init__(self, model_name):
        """Load the pretrained backbone with dropout disabled and build the head.

        Args:
            model_name: Hugging Face model identifier or local path.
        """
        super(NER_Model, self).__init__()

        self.cfg = CFG
        self.config = AutoConfig.from_pretrained(model_name)
        self.config.hidden_dropout_prob = 0
        self.config.attention_probs_dropout_prob = 0

        self.model = AutoModel.from_pretrained(model_name, config=self.config)

        # One logit per PHI class for every token.
        self.output = nn.Sequential(
            nn.LayerNorm(self.config.hidden_size),
            nn.Linear(self.config.hidden_size, self.cfg.target_size)
        )

    def _init_weights(self, module):
        """Re-initialise ``module`` by type (Linear / Embedding / LayerNorm).

        Not called anywhere in this class.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, ids, mask, token_type_ids=None, targets=None, input_token_starts = None):
        """Return per-token class logits.

        Args:
            ids: Token ids.
            mask: Attention mask.
            token_type_ids: Optional segment ids, forwarded to the backbone when given.
            targets: Unused; kept for call compatibility.
            input_token_starts: Unused; kept for call compatibility.

        Returns:
            Logits computed from the backbone's first output
            (shape: batch, length, number of classes).
        """
        # `if token_type_ids:` raised "Boolean value of Tensor with more than
        # one value is ambiguous" for any real batch; test for None instead.
        if token_type_ids is not None:
            transformer_out = self.model(input_ids=ids, attention_mask=mask, token_type_ids=token_type_ids)
        else:
            transformer_out = self.model(input_ids=ids, attention_mask=mask)

        sequence_output = transformer_out[0] # shape : (batch,length,dimension)

        logits = self.output(sequence_output)

        return logits

In [None]:
def criterion(logits, labels):
    """Cross-entropy over the tokens whose label is greater than -1.

    Args:
        logits: Tensor whose last dimension holds the class scores.
        labels: Integer tensor of class ids with one entry per token;
            entries of -1 or below are excluded from the loss.

    Returns:
        Scalar loss tensor (mean over the kept tokens).

    Raises:
        ValueError: If ``labels`` is None (previously an UnboundLocalError).
    """
    if labels is None:
        raise ValueError("criterion() requires labels")

    # Keep only the positions that carry a real label.
    active_loss = labels.view(-1).gt(-1)
    # The class count is taken from the logits themselves, so the function
    # no longer depends on the global config.
    active_logits = logits.view(-1, logits.size(-1))[active_loss]
    active_labels = labels.view(-1)[active_loss]

    loss_fct = nn.CrossEntropyLoss()
    return loss_fct(active_logits, active_labels)

In [None]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch, best_score, best_valid_loss, valid_loader):
    """Train for one epoch, validating and checkpointing every ``CFG.eval_freq`` steps.

    Args:
        model: Model to train.
        optimizer: Optimizer stepped every ``CFG.n_accumulate`` batches.
        scheduler: LR scheduler stepped with the optimizer, or None.
        dataloader: Training batches with ``input_ids``, ``attention_mask``, ``label``.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index (used for logging).
        best_score: Unused; kept for call compatibility.
        best_valid_loss: Lowest validation loss seen so far.
        valid_loader: Validation batches for the periodic evaluation.

    Returns:
        ``(epoch_loss, best_valid_loss)``: the running mean training loss and
        the possibly-updated best validation loss.
    """
    model.train()
    scaler = GradScaler(enabled=CFG.apex)

    dataset_size = 0
    running_loss = 0

    start = time.time()

    for step, data in enumerate(dataloader):
        ids = data['input_ids'].to(device, dtype=torch.long)
        mask = data['attention_mask'].to(device, dtype=torch.long)
        labels = data['label'].to(device, dtype=torch.long)

        batch_size = ids.size(0)

        with autocast(enabled=CFG.apex):
            logits = model(ids, mask)
            loss = criterion(logits, labels)

        # Gradient accumulation: scale the loss so the summed gradient is an average.
        loss = loss / CFG.n_accumulate
        scaler.scale(loss).backward()
        if (step +1) % CFG.n_accumulate == 0:
            # Gradients are still multiplied by the loss scale at this point;
            # unscale them first so max_norm applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

        if step % CFG.print_freq == 0 or step == (len(dataloader)-1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Loss: [{3}]'
                  'Elapsed {remain:s} '
                  .format(epoch+1, step, len(dataloader), epoch_loss,
                          remain=timeSince(start, float(step+1)/len(dataloader))))

        if (step > 0) and (step % CFG.eval_freq == 0):

            valid_epoch_loss = valid_one_epoch(model, valid_loader, device, epoch)

            LOGGER.info(f'Epoch {epoch+1} Step {step} - avg_train_loss: {epoch_loss:.4f}  avg_val_loss: {valid_epoch_loss:.4f}')

            if valid_epoch_loss < best_valid_loss:
                best_valid_loss = valid_epoch_loss
                LOGGER.info(f'Epoch {epoch+1} Step {step} - Save Best Loss: {best_valid_loss:.4f} Model')
                torch.save({'model': model.state_dict()},
                            OUTPUT_DIR+f"/model/{CFG.model.replace('/', '-')}_best.pth")

            # valid_one_epoch() switches the model to eval mode; switch back,
            # otherwise the rest of the epoch trains in eval mode.
            model.train()

    gc.collect()

    return epoch_loss, best_valid_loss

@torch.no_grad()
def valid_one_epoch(model, dataloader, device, epoch):
    """Evaluate ``model`` on ``dataloader`` and return the sample-weighted mean loss.

    The model is put in eval mode and left there. ``epoch`` is accepted for
    call compatibility and not used.
    """
    model.eval()

    seen_samples = 0
    loss_sum = 0
    started_at = time.time()

    for step, batch in enumerate(dataloader):
        ids = batch['input_ids'].to(device, dtype=torch.long)
        mask = batch['attention_mask'].to(device, dtype=torch.long)
        labels = batch['label'].to(device, dtype=torch.long)

        n_in_batch = ids.size(0)
        batch_loss = criterion(model(ids, mask), labels)

        # Weight each batch by its size so a smaller final batch is not over-counted.
        loss_sum += (batch_loss.item() * n_in_batch)
        seen_samples += n_in_batch
        epoch_loss = loss_sum / seen_samples

        is_report_step = step % CFG.print_freq == 0 or step == (len(dataloader)-1)
        if is_report_step:
            progress = float(step+1)/len(dataloader)
            print('EVAL: [{0}/{1}] '
                  'Elapsed {remain:s} '
                  .format(step, len(dataloader),
                          remain=timeSince(started_at, progress)))

    return epoch_loss

In [None]:
def train_loop(fold):
    """Train the NER model on ``new_train_df`` and keep the checkpoint with the lowest validation loss.

    Args:
        fold: Unused; kept for call compatibility (training always uses the
            full ``new_train_df`` / ``new_val_df`` frames).
    """
    trainDataset = TrainDataset(CFG, new_train_df)
    validDataset = TrainDataset(CFG, new_val_df)

    train_loader = DataLoader(trainDataset,
                              batch_size = CFG.batch_size,
                              shuffle=True,
                              num_workers = CFG.num_workers,
                              pin_memory = True,
                              drop_last=True)

    valid_loader = DataLoader(validDataset,
                              batch_size = CFG.batch_size,
                              shuffle=False,
                              num_workers = CFG.num_workers,
                              pin_memory = True,
                              drop_last=False)

    # torch.save does not create missing parent directories.
    os.makedirs(OUTPUT_DIR+'/model', exist_ok=True)
    best_model_path = OUTPUT_DIR+f"/model/{CFG.model.replace('/', '-')}_best.pth"

    model = NER_Model(CFG.model)
    torch.save(model.config, OUTPUT_DIR+'/model/config.pth')
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weigth_decay)

    # The scheduler is stepped once per optimizer step, so its length is the
    # number of batches actually produced (drop_last=True) divided by the
    # accumulation factor, times the number of epochs.
    num_train_steps = max(1, (len(train_loader) // CFG.n_accumulate) * CFG.epochs)
    scheduler = get_scheduler(CFG, optimizer, num_train_steps)

    best_score = 100       # only forwarded to train_one_epoch, which ignores it
    best_valid_loss = 100  # lowest validation loss seen so far

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train_one_epoch returns the best validation loss after its periodic
        # evaluations; it has already saved the weights that achieved it.
        train_epoch_loss, best_valid_loss = train_one_epoch(model, optimizer, scheduler, train_loader, device, epoch, best_score, best_valid_loss, valid_loader)

        # Evaluate the end-of-epoch weights explicitly. Previously these
        # weights overwrote the best checkpoint without being validated, and
        # nothing was saved at all when eval_freq exceeded the steps per epoch.
        valid_epoch_loss = valid_one_epoch(model, valid_loader, device, epoch)

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {train_epoch_loss:.4f}  avg_val_loss: {valid_epoch_loss:.4f}  time: {elapsed:.0f}s')

        if valid_epoch_loss < best_valid_loss:
            best_valid_loss = valid_epoch_loss
            LOGGER.info(f'Epoch {epoch+1} - Save Best Loss: {best_valid_loss:.4f} Model')
            torch.save({'model': model.state_dict()}, best_model_path)

    torch.cuda.empty_cache()
    gc.collect()


In [None]:
train_loop(0)