In [1]:
%%html
<style type='text/css'>
.CodeMirror{
    font-size: 16px;
    font-family: Monaco;
}

div.output_area pre {
    font-size: 12px;
}
</style>

In [2]:
# import os
# os.chdir(os.getcwd()+"./..")

In [3]:
# import libaries
import torch
from torch import cuda
from torch.utils.data import Dataset,DataLoader

In [4]:
print(torch.__version__)

1.9.0a0+gitd69c22d


In [5]:
import os
import json
import random
import pandas as pd
import numpy as np
from sklearn.utils import shuffle
from sklearn import metrics
from collections import Counter

In [6]:
# !pip install transformers

In [7]:
import torch

In [8]:
from transformers import AutoTokenizer

# 1) Import Data

In [9]:
def read_qnli_data(file_name, data_dir):
    path = os.path.join(data_dir, file_name)
    with open(path, encoding='utf-8-sig') as f:
        text = f.readlines()

    header = text[0].strip().split("\t")
    lines = [line.strip().split("\t") for line in text[1:]]

    df = pd.DataFrame(lines, columns=header)
    return df


def get_qnli_pandas_dataframe(data_dir):
    qnli_train_df = read_qnli_data("train.tsv",data_dir)
    qnli_dev_df = read_qnli_data("dev.tsv",data_dir)
    qnli_train_df['label'] = np.where(
        qnli_train_df['label'] == 'entailment', 1, 0)
    qnli_dev_df['label'] = np.where(qnli_dev_df['label'] == 'entailment', 1, 0)

    qnli_dev_df['question'] = qnli_dev_df['question'].apply(lambda x: x.strip())
    qnli_dev_df['sentence'] = qnli_dev_df['sentence'].apply(lambda x: x.strip()) 

    qnli_train_df['question'] = qnli_train_df['question'].apply(lambda x: x.strip())
    qnli_train_df['sentence'] = qnli_train_df['sentence'].apply(lambda x: x.strip()) 

    return qnli_dev_df, qnli_train_df


def read_document_to_list(document_path):
    with open(document_path, encoding='utf-8-sig') as f:
        document = f.readlines()
        sentence_list = [line.strip()
                         for line in document if len(line.strip()) != 0]
        return sentence_list


def read_document_dict(document_dir):

    document_dict = {}

    for document_file_name in os.listdir(document_dir):
        if document_file_name.endswith(".txt"):
            document_name = document_file_name.replace(
                ".txt", "").replace("_", " ")
            document_path = os.path.join(document_dir, document_file_name)
            document_dict[document_name] = read_document_to_list(document_path)

    return document_dict


def read_json(file_path):
    with open(file_path) as f:
        json_f = json.load(f)
    data = json_f['data']
    return data


def get_random_index(List):
    return random.sample(range(len(List)), 1)[0]


def load_data(data_path, load_impossible_answer=False):

    data = read_json(data_path)

    data_dict = {}
    title_list = []
    context_list = []
    question_list = []
    id_list = []
    answer_text_list = []
    answer_start_list = []
    is_impossible_list = []

    for paragraphs in data:
        title = paragraphs['title']
        context_qas_list = paragraphs['paragraphs']

        for context_qas in context_qas_list:
            context = context_qas['context']
            qas_list = context_qas['qas']

            for qas in qas_list:
                title_list.append(title)
                context_list.append(context)

                is_impossible = qas['is_impossible']
                is_impossible_list.append(is_impossible)

                id_ = qas['id']
                id_list.append(id_)
                question = qas['question']
                question_list.append(question)

                if not is_impossible:
                    answer_list = qas['answers']
                    idx = get_random_index(answer_list)
                    answer_text = answer_list[idx]['text']
                    answer_start = answer_list[idx]['answer_start']

                    answer_text_list.append(answer_text)
                    answer_start_list.append(answer_start)
                else:
                    if load_impossible_answer:
                        answer_list = qas['plausible_answers']
                        idx = get_random_index(answer_list)
                        answer_text = answer_list[idx]['text']
                        answer_start = answer_list[idx]['answer_start']
                        answer_text_list.append(answer_text)
                        answer_start_list.append(answer_start)
                    else:
                        answer_text_list.append("")
                        answer_start_list.append(-1)

    data_dict['id'] = id_list
    data_dict['title'] = title_list
    data_dict['context'] = context_list
    data_dict['question'] = question_list
    data_dict['answer_text'] = answer_text_list
    data_dict['answer_start'] = answer_start_list
    data_dict['is_impossible'] = is_impossible_list

    return data_dict


def get_squad_v2_pandas_dataframe(squad_v2_dir,include_impossible=False, load_impossible_answer=False):
    # download from https://rajpurkar.github.io/SQuAD-explorer/
    train_data_path = os.path.join(squad_v2_dir, "train-v2.0.json")
    dev_data_path = os.path.join(squad_v2_dir, 'dev-v2.0.json')

    train_data_dict = load_data(train_data_path, load_impossible_answer)
    dev_data_dict = load_data(dev_data_path, load_impossible_answer)

    train_data_df = pd.DataFrame(train_data_dict)
    dev_data_df = pd.DataFrame(dev_data_dict)

    if not include_impossible:
        train_data_df = train_data_df[train_data_df['is_impossible'] == False]
        dev_data_df = dev_data_df[dev_data_df['is_impossible'] == False]

    train_data_df['question'] = train_data_df['question'].apply(lambda x: x.strip())
    train_data_df['context'] = train_data_df['context'].apply(lambda x: x.strip()) 

    dev_data_df['question'] = dev_data_df['question'].apply(lambda x: x.strip())
    dev_data_df['context'] = dev_data_df['context'].apply(lambda x: x.strip()) 

    return train_data_df, dev_data_df


In [10]:
device = "cuda" if cuda.is_available() else "cpu"

In [11]:
device

'cuda'

In [12]:
model_name ="albert-base-v2"

In [13]:
# qnli_dev_df, qnli_train_df = get_qnli_pandas_dataframe()

In [14]:
# qnli_train_df.head()

In [15]:
# qnli_train_df['question_sentence'] = qnli_train_df.apply(lambda x:  " ".join([x['question'],x['sentence']]),axis = 1)

# qnli_train_df['question_sentence_length'] = qnli_train_df['question_sentence'].apply(lambda x: len(x.split(" ")))

# max(qnli_train_df['question_sentence_length'])

In [16]:
os.listdir("/data/SQUAD2")

['__pycache__', 'dev-v2.0.json', 'evaluate.py', 'train-v2.0.json']

In [17]:
# from google.colab import drive
# drive.mount('/content/drive')
squad_v2_dir = "/data/SQUAD2" # data folder

In [18]:
train_df, dev_df = get_squad_v2_pandas_dataframe(squad_v2_dir,include_impossible=False,load_impossible_answer=False)

In [19]:
train_df['question_context'] = train_df.apply(lambda x:  " ".join([x['question'].strip(),x['context'].strip()]),axis = 1)

In [20]:
train_df['question_context_length'] = train_df['question_context'].apply(lambda x: len(x.split(" ")))

In [21]:
max(train_df['question_context_length'])

668

In [22]:
long_train_df = train_df[train_df['question_context_length']>=480]

In [23]:
len(long_train_df)

29

In [24]:
short_train_df = train_df[~train_df['id'].isin(long_train_df['id'])]

In [25]:
max(short_train_df['question_context_length'])

460

In [26]:
short_train_df = shuffle(short_train_df)

In [27]:
train_size  = 80000

In [28]:
short_train_df = short_train_df.head(train_size)

In [29]:
positive_long_text_dataset_df = pd.concat([long_train_df,short_train_df]) 

In [30]:
# positive_long_text_dataset_df = shuffle(positive_long_text_dataset_df)

In [31]:
# 5000*5000

In [32]:
sample_df = positive_long_text_dataset_df.head(2200)

In [33]:
sample_df['key'] = 1

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """Entry point for launching an IPython kernel.


In [34]:
negative_long_text_dataset_df = sample_df.merge(sample_df,on = 'key')

In [35]:
negative_long_text_dataset_df.head()

Unnamed: 0,id_x,title_x,context_x,question_x,answer_text_x,answer_start_x,is_impossible_x,question_context_x,question_context_length_x,key,id_y,title_y,context_y,question_y,answer_text_y,answer_start_y,is_impossible_y,question_context_y,question_context_length_y
0,56cef51daab44d1400b88d11,Spectre_(2015_film),Critical appraisal of the film was mixed in th...,Which journalist considered Spectre the worst ...,Scott Mendelson,515,False,Which journalist considered Spectre the worst ...,511,1,56cef51daab44d1400b88d11,Spectre_(2015_film),Critical appraisal of the film was mixed in th...,Which journalist considered Spectre the worst ...,Scott Mendelson,515,False,Which journalist considered Spectre the worst ...,511
1,56cef51daab44d1400b88d11,Spectre_(2015_film),Critical appraisal of the film was mixed in th...,Which journalist considered Spectre the worst ...,Scott Mendelson,515,False,Which journalist considered Spectre the worst ...,511,1,56cef51daab44d1400b88d12,Spectre_(2015_film),Critical appraisal of the film was mixed in th...,Which writer for the San Francisco Chronicle a...,Mick LaSalle,1373,False,Which writer for the San Francisco Chronicle a...,512
2,56cef51daab44d1400b88d11,Spectre_(2015_film),Critical appraisal of the film was mixed in th...,Which journalist considered Spectre the worst ...,Scott Mendelson,515,False,Which journalist considered Spectre the worst ...,511,1,56cef51daab44d1400b88d13,Spectre_(2015_film),Critical appraisal of the film was mixed in th...,What score did the writer from the Chicago Tri...,75,2118,False,What score did the writer from the Chicago Tri...,511
3,56cef51daab44d1400b88d11,Spectre_(2015_film),Critical appraisal of the film was mixed in th...,Which journalist considered Spectre the worst ...,Scott Mendelson,515,False,Which journalist considered Spectre the worst ...,511,1,56cef51daab44d1400b88d14,Spectre_(2015_film),Critical appraisal of the film was mixed in th...,The reviewer from Variety compares Spectre to ...,Skyfall,2982,False,The reviewer from Variety compares Spectre to ...,510
4,56cef51daab44d1400b88d11,Spectre_(2015_film),Critical appraisal of the film was mixed in th...,Which journalist considered Spectre the worst ...,Scott Mendelson,515,False,Which journalist considered Spectre the worst ...,511,1,572a2c616aef051400155338,Digimon,"After a three-year hiatus, a fifth Digimon ser...",How long did Digimon stay off the air before r...,three-year hiatus,8,False,How long did Digimon stay off the air before r...,496


In [36]:
negative_long_text_dataset_df = negative_long_text_dataset_df[negative_long_text_dataset_df['title_x']!=negative_long_text_dataset_df['title_y']]

In [37]:
negative_long_text_dataset_df = shuffle(negative_long_text_dataset_df)

In [38]:
negative_long_text_dataset_df = negative_long_text_dataset_df.head(train_size)

In [39]:
cols  = "id_x	title_x	context_x	question_y	answer_text_x	answer_start_x	is_impossible_x	question_context_x	question_context_length_x".split("\t")

In [40]:
# negative_long_text_dataset_df.head()

In [41]:
negative_long_text_dataset_df = negative_long_text_dataset_df[cols]

In [42]:
negative_long_text_dataset_df.columns = positive_long_text_dataset_df.columns

In [43]:
negative_long_text_dataset_df.head()

Unnamed: 0,id,title,context,question,answer_text,answer_start,is_impossible,question_context,question_context_length
23815,5727e311ff5b5019007d9799,Gramophone_record,Original master discs are created by lathe-cut...,What type of cameras see infrared radiation?,limited to a few hundred vinyl pressings,2963,False,What is the limitation of the two step process...,542
30237,5727e311ff5b5019007d979c,Gramophone_record,Original master discs are created by lathe-cut...,How can additional points be earned?,quality of the vinyl is high,3070,False,What can increase the output of a stamper mold...,542
2898388,572fdc43947a6a140053cd76,Greeks,"In Homer's Iliad, the names Danaans (or Danaoi...",Her first appearance performing since giving b...,A country Danaja with a city Mukana (propaply:...,1601,False,What country is spoken of in the inscriptions ...,312
4528532,5727eb304b864d190016401a,London,Following his victory in the Battle of Hasting...,What does Nanjing mean?,southeastern corner,287,False,In what area of London was the Tower of London...,99
378066,56e7990237bdd419002c41eb,University_of_Kansas,"The school's sports teams, wearing crimson and...",What is the literal meaning of tariqah?,Rim Rock Farm,489,False,Where does KU's cross country team run? The sc...,127


In [44]:
negative_long_text_dataset_df['label'] = 0
positive_long_text_dataset_df['label'] = 1

In [45]:
all_data_dataset = pd.concat([negative_long_text_dataset_df,positive_long_text_dataset_df])

In [46]:
all_data_dataset = shuffle(all_data_dataset)

In [47]:
size = len(all_data_dataset)//2

In [48]:
train_df = all_data_dataset.head(size)

In [49]:
dev_df = all_data_dataset.tail(size)

In [50]:
from collections import Counter

In [51]:
Counter(train_df['label'])

Counter({0: 39810, 1: 40204})

# 2) Tokenization Features Engineering

In [52]:
max_length = 512

In [53]:
doc_overlap_length = 32

In [54]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [55]:
question = train_df['question'].iloc[0]
context = train_df['context'].iloc[0]

train_df['question_context_length'].iloc[0]

143

In [56]:
def get_token(question,context):
    
    inputs = tokenizer(
            text = question,
            text_pair = context,
            add_special_tokens = True,
            max_length = 512,
            padding = "max_length",
            return_token_type_ids = True,
            truncation = "only_second",
#             return_tensors=   'pt'
            )
    
    return inputs

In [57]:
def get_raw_token(question,context):
    
    inputs = tokenizer(
        text = question,
        text_pair = context,
        add_special_tokens = True,
        max_length = None,
        padding = False,
        return_token_type_ids = True,
        truncation = False,
        return_offsets_mapping = True
        )
    
    return inputs

In [58]:
raw_token  = get_raw_token(question,context)

In [59]:
raw_token.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping'])

In [60]:
def duplicate_token(question,context):
    
    inputs = get_token(question,context)
    
    return inputs,inputs

In [61]:
token_inputs_1,token_inputs_2 = duplicate_token(question,context)

In [62]:
tokenizer.decode(token_inputs_1['input_ids'])

'[CLS] english freemasonry almost came to a halt in what year?[SEP] yale seniors at graduation smash clay pipes underfoot to symbolize passage from their "bright college years," though in recent history the pipes have been replaced with "bubble pipes". ("bright college years," the university\'s alma mater, was penned in 1881 by henry durand, class of 1881, to the tune of die wacht am rhein.) yale\'s student tour guides tell visitors that students consider it good luck to rub the toe of the statue of theodore dwight woolsey on old campus. actual students rarely do so. in the second half of the 20th century bladderball, a campus-wide game played with a large inflatable ball, became a popular tradition but was banned by administration due to safety concerns. in spite of administration opposition, students revived the game in 2009, 2011, and 2014, but its future remains uncertain.[SEP]<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><

In [63]:
tokenizer.decode(token_inputs_2['input_ids'])

'[CLS] english freemasonry almost came to a halt in what year?[SEP] yale seniors at graduation smash clay pipes underfoot to symbolize passage from their "bright college years," though in recent history the pipes have been replaced with "bubble pipes". ("bright college years," the university\'s alma mater, was penned in 1881 by henry durand, class of 1881, to the tune of die wacht am rhein.) yale\'s student tour guides tell visitors that students consider it good luck to rub the toe of the statue of theodore dwight woolsey on old campus. actual students rarely do so. in the second half of the 20th century bladderball, a campus-wide game played with a large inflatable ball, became a popular tradition but was banned by administration due to safety concerns. in spite of administration opposition, students revived the game in 2009, 2011, and 2014, but its future remains uncertain.[SEP]<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><

In [64]:
index = train_df[train_df['question_context_length']>512].head(3).index[0]

In [65]:
question = train_df['question'].loc[index]
context = train_df['context'].loc[index]

In [66]:
train_df['question_context_length'].loc[index]

568

In [67]:
raw_token  = get_raw_token(question,context)

Token indices sequence length is longer than the specified maximum sequence length for this model (717 > 512). Running this sequence through the model will result in indexing errors


In [68]:
raw_token.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping'])

In [69]:
def split_long_token(question,context,raw_token,doc_overlap_length = 32):
    
    first_context_end_pos = raw_token['offset_mapping'][511][1]-1# because of specical token
    context_1 = context[:first_context_end_pos]
    
    sencond_char_start_pos = raw_token['offset_mapping'][511-doc_overlap_length][0]-1
    
    context_2 = context[sencond_char_start_pos:]
    
    
    inputs_1 = get_token(question,context_1)
    
    inputs_2 = get_token(question,context_2)

        
    return inputs_1,inputs_2

In [70]:
token_inputs_1,token_inputs_2 = split_long_token(question,context,raw_token,doc_overlap_length = 32)

In [71]:
tokenizer.decode(token_inputs_1['input_ids'])

'[CLS] where was wrestling later showcased in north america?[SEP] the mandolin has been used extensively in the traditional music of england and scotland for generations. simon mayor is a prominent british player who has produced six solo albums, instructional books and dvds, as well as recordings with his mandolin quartet the mandolinquents. the instrument has also found its way into british rock music. the mandolin was played by mike oldfield (and introduced by vivian stanshall) on oldfield\'s album tubular bells, as well as on a number of his subsequent albums (particularly prominently on hergest ridge (1974) and ommadawn (1975)). it was used extensively by the british folk-rock band lindisfarne, who featured two members on the instrument, ray jackson and simon cowe, and whose "fog on the tyne" was the biggest selling uk album of 1971-1972. the instrument was also used extensively in the uk folk revival of the 1960s and 1970s with bands such as fairport convention and steeleye span 

In [72]:
tokenizer.decode(token_inputs_2['input_ids'])

'[CLS] where was wrestling later showcased in north america?[SEP] mandolin solo played by johnny marr. more recently, the glasgow-based band sons and daughters featured the mandolin, played by ailidh lennon, on tracks such as fight, start to end, and medicine. british folk-punk icons the levellers also regularly use the mandolin in their songs. current bands are also beginning to use the mandolin and its unique sound - such as south london\'s indigo moss who use it throughout their recordings and live gigs. the mandolin has also featured in the playing of matthew bellamy in the rock band muse. it also forms the basis of paul mccartney\'s 2007 hit "dance tonight." that was not the first time a beatle played a mandolin, however; that distinction goes to george harrison on gone troppo, the title cut from the 1982 album of the same name. the mandolin is taught in lanarkshire by the lanarkshire guitar and mandolin association to over 100 people. also more recently hard rock supergroup them 

In [73]:
def prepare_feature(example):
    
    context = example['context']
    question = example['question']
    
    # get raw token
    raw_token  = get_raw_token(question,context)
    
    if len(raw_token['input_ids'])<=512:
        
        token_inputs_1,token_inputs_2 = duplicate_token(question,context)
        
    else:
        token_inputs_1,token_inputs_2 = split_long_token(question,context,raw_token,doc_overlap_length = 32)
        
    return token_inputs_1,token_inputs_2

In [74]:
# train_token_pairs_list = train_df.apply(lambda x:prepare_feature(x),axis = 1)

# dev_token_pairs_list = dev_df.apply(lambda x:prepare_feature(x),axis = 1)

# train_token_pairs_1 = [pair[0] for pair in train_token_pairs_list]
# train_token_pairs_1_df  = pd.DataFrame(train_token_pairs_1)

# train_token_pairs_2 = [pair[1] for pair in train_token_pairs_list]
# train_token_pairs_2_df  = pd.DataFrame(train_token_pairs_2)

# dev_token_pairs_1 = [pair[0] for pair in dev_token_pairs_list]
# dev_token_pairs_1_df  = pd.DataFrame(dev_token_pairs_1)

# dev_token_pairs_2 = [pair[1] for pair in dev_token_pairs_list]
# dev_token_pairs_2_df  = pd.DataFrame(dev_token_pairs_2)

# train_labels  = train_df['label'].to_list()

# dev_labels  = dev_df['label'].to_list()

In [75]:
class LongTextPairDataSet(Dataset):
    
    def __init__(self,df_pair_1,df_pair_2, label_list,device = "cpu"):
        self.len = len(df_pair_1)
        self.df_pair_1 = df_pair_1
        self.df_pair_2 = df_pair_2
        self.label_list = label_list
        self.device = device
    
    def __getitem__(self,index):
        df_1 = self.df_pair_1.iloc[index]
        df_2 = self.df_pair_2.iloc[index]
        labels = self.label_list[index]
        
        if isinstance(df_1,pd.core.series.Series):
            pair_dict_1 = df_1.to_dict()
            pair_dict_2 = df_2.to_dict()
        else:
            pair_dict_1 = df_1.to_dict(orient = "list")
            pair_dict_2 = df_2.to_dict(orient = "list")
        
        inputs_1 = {k:torch.tensor(v).to(self.device) for k,v in pair_dict_1.items()}
        
        inputs_2 = {k:torch.tensor(v).to(self.device) for k,v in pair_dict_2.items()}

        return {"token_inputs_1":inputs_1,"token_inputs_2":inputs_2,"labels":torch.tensor(labels).to(self.device)}

    def __len__(self):
        return self.len

In [76]:
import pickle
def save_object(obj,save_path):
    with open(save_path,mode='wb') as f:
        pickle.dump(obj,f,protocol = pickle.HIGHEST_PROTOCOL)
        
        
def open_object(file_name):
    with open(file_name,mode = 'rb') as f:
        return pickle.load(f)

In [77]:
# train_long_text_pair_dataset = LongTextPairDataSet(train_token_pairs_1_df,train_token_pairs_2_df,train_labels,'cuda')

# save_object(train_long_text_pair_dataset,"./data/train_long_text_pair_dataset.pkl")

In [78]:
train_long_text_pair_dataset = open_object("./data/train_long_text_pair_dataset.pkl")

In [79]:
# train_long_text_pair_dataset_test[:10]

In [80]:
# dev_long_text_pair_dataset = LongTextPairDataSet(dev_token_pairs_1_df,dev_token_pairs_2_df,dev_labels,'cuda')

# save_object(dev_long_text_pair_dataset,"./data/dev_long_text_pair_dataset.pkl")

In [81]:
dev_long_text_pair_dataset = open_object("./data/dev_long_text_pair_dataset.pkl")

In [82]:
batch_size = 8

In [83]:
len(train_long_text_pair_dataset)//batch_size

1251

In [84]:
train_loader = DataLoader(train_long_text_pair_dataset,batch_size)
dev_loader = DataLoader(dev_long_text_pair_dataset,batch_size)

In [85]:
len(train_long_text_pair_dataset)//batch_size

1251

# 3) Fine Tune Twin-Albert for long text pair classification

In [86]:
from transformers import AlbertModel

In [87]:
from transformers.modeling_outputs import SequenceClassifierOutput

In [88]:
class TwinAlBerts(torch.nn.Module):
    def __init__(self,model_config):
        
        super(TwinAlBerts,self).__init__()
        
        self.albert_layer_1 = AlbertModel.from_pretrained(model_config.model_name)
        self.albert_layer_2 = AlbertModel.from_pretrained(model_config.model_name)
        
        self.pre_classifier = torch.nn.Linear(768*2,768)
        
        self.dropout = torch.nn.Dropout(0.3)
        
        self.classifer = torch.nn.Linear(768,model_config.num_class)

        self.loss_fct = torch.nn.CrossEntropyLoss()

    def forward(self,token_inputs_1,token_inputs_2,labels=None):

        albert_outputs_1 = self.albert_layer_1(**token_inputs_1)
        albert_outputs_2 = self.albert_layer_2(**token_inputs_2)
        
        pooler_output_1 = albert_outputs_1.pooler_output

        pooler_output_2 = albert_outputs_2.pooler_output
        
        
        concat_pooler = torch.cat([pooler_output_1,pooler_output_2],axis = 1)
        
        concat_pooler = self.pre_classifier(concat_pooler)
        
        concat_pooler = self.dropout(concat_pooler)
        
        logits = self.classifer(concat_pooler)
        
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits,labels)
                    
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

In [89]:
class model_config():
    model_name = 'albert-base-v2'
    num_class = 2

In [90]:
model = TwinAlBerts(model_config)

Some weights of the model checkpoint at albert-base-v2 were not used when initializing AlbertModel: ['predictions.dense.weight', 'predictions.dense.bias', 'predictions.decoder.weight', 'predictions.decoder.bias', 'predictions.LayerNorm.bias', 'predictions.bias', 'predictions.LayerNorm.weight']
- This IS expected if you are initializing AlbertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at albert-base-v2 were not used when initializing AlbertModel: ['predictions.dense.weight', 'predictions.dense.bias', 'predictions.decoder.weight', 'predictions.decoder.bias', 'predictions.LayerNorm.bias', 'pred

In [91]:
_ =  model.to(device)

In [92]:
len(train_long_text_pair_dataset)//batch_size

1251

In [110]:
learning_rate = 2e-05
epoches = 2
train_evaluate_step = 500
dev_evaluate_step = 500

In [111]:
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params = model.parameters(),lr = learning_rate )

In [112]:
"""test"""

'test'

In [113]:
for inputs in train_loader:
    with torch.no_grad():
        outputs = model(**inputs)
    break

In [114]:
inputs['labels'].detach().cpu().numpy()

array([0, 1, 0, 1, 1, 0, 0, 1])

In [115]:
outputs.logits

tensor([[ 2.7387, -2.0829],
        [-2.6721,  2.1786],
        [ 2.6080, -2.5522],
        [-2.7392,  3.1546],
        [-2.8947,  3.0523],
        [ 2.5939, -1.8248],
        [ 2.2283, -2.2588],
        [-3.1693,  3.7005]], device='cuda:0')

In [116]:
outputs.loss.item()

0.00638848589733243

In [117]:
torch.argmax(outputs.logits,1).detach().cpu().numpy()

array([0, 1, 0, 1, 1, 0, 0, 1])

In [118]:
torch.max(outputs.logits,1)[1].cpu()

tensor([0, 1, 0, 1, 1, 0, 0, 1])

In [119]:
model_save_dir = "./twin-albert-base/" # data folder

In [120]:
os.listdir(model_save_dir)

['checkpoint-800',
 'checkpoint-2400',
 'checkpoint-400',
 'checkpoint-1600',
 'checkpoint-1200',
 'checkpoint-2000']

In [121]:
from tqdm import tqdm

In [122]:
def save_model(model,tokenizer, model_save_dir,step,train_state):
    model_save_dir = os.path.join(model_save_dir,f"checkpoint-{step}")
    model_name = "pytorch_model.bin"
    train_state_name = "train_state.txt"
    try:
        os.mkdir(model_save_dir)
    except FileExistsError:
        pass
    
    model_path = os.path.join(model_save_dir,model_name)
    train_state_path = os.path.join(model_save_dir,train_state_name)

    torch.save(model,model_path)
    tokenizer.save_pretrained(model_save_dir)
    
    if train_state is not None:
        with open(train_state_path,mode = 'w',encoding = 'utf-8-sig') as f:
            f.write(train_state)
    
    

In [123]:
def evaluate_full_metrics(model,dataset_loader):

    model.eval()

    loss_list = []
    labels_list = []
    pred_list = []
    prob_list = []

    pbar = tqdm(total = len(dataset_loader),desc = "Model Evaluate",position=0, leave=True)


    for inputs in dataset_loader:

        with torch.no_grad():
            outputs = model(**inputs)
        
            loss = outputs.loss.item()
            loss_list.append(loss)
            
            labels  = inputs['labels'].detach().cpu().numpy()
            labels_list.extend(labels)

            pred = torch.argmax(outputs.logits,1).detach().cpu().numpy()
            pred_list.extend(pred)
            
            probs = torch.sigmoid(outputs.logits)
            prob = probs[:,1]
            prob = prob.detach().cpu().numpy()
            prob_list.extend(prob)

            pbar.update(1)

    pbar.close()


    accuracy = metrics.accuracy_score(labels_list,pred_list)
    recall = metrics.recall_score(labels_list,pred_list)
    precision = metrics.precision_score(labels_list,pred_list)
    f1 = metrics.f1_score(labels_list,pred_list)
    fpr,tpr, threshold = metrics.roc_curve(labels_list,prob_list,pos_label=1)

    auc = metrics.auc(fpr,tpr)
    loss = np.mean(loss_list)
    
    result = {"accuracy":accuracy, "recall":recall, "precision":precision, "recall":recall, "f1":f1,'auc':auc,'loss':loss} 

    return result

In [124]:
def train():
    total_batch = 0
    model.train()
    
    total_pbar = tqdm(total = len(train_loader)*epoches,desc = "Model Training",position=0, leave=True)
    
    for epoch in range(epoches):

        for inputs in train_loader:
            outputs = model(**inputs)
            optimizer.zero_grad()
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            if (total_batch+1) % dev_evaluate_step ==0:
                metrics = evaluate_full_metrics(model,dev_loader)
                train_state = str(metrics)
                print(train_state) 
                save_model(model,tokenizer,model_save_dir,total_batch+1,train_state)
                model.train()
            
            total_batch +=1
            total_pbar.update(1)
        
    total_pbar.close()

In [125]:
train()

Model Evaluate: 100%|██████████| 1252/1252 [02:47<00:00,  7.47it/s]


{'accuracy': 0.9494707409626523, 'recall': 0.9030193961207759, 'precision': 0.9953713907868635, 'f1': 0.9469490459215768, 'auc': 0.988254174419455, 'loss': 0.16381210512511385}


Model Evaluate: 100%|██████████| 1252/1252 [02:47<00:00,  7.46it/s]  


{'accuracy': 0.981925304573597, 'recall': 0.993001399720056, 'precision': 0.9714397496087637, 'f1': 0.9821022446356176, 'auc': 0.9980118079715397, 'loss': 0.0739072480662235}


Model Evaluate: 100%|██████████| 1252/1252 [02:47<00:00,  7.46it/s]   


{'accuracy': 0.9833233473137607, 'recall': 0.9938012397520496, 'precision': 0.9733646690168429, 'f1': 0.9834767982586325, 'auc': 0.9989856407334133, 'loss': 0.05458383240324918}


Model Evaluate: 100%|██████████| 1252/1252 [02:47<00:00,  7.46it/s]   


{'accuracy': 0.9914120231675654, 'recall': 0.9902019596080783, 'precision': 0.9925836841050311, 'f1': 0.9913913913913914, 'auc': 0.9995031514343452, 'loss': 0.02775414964522067}


Model Evaluate: 100%|██████████| 1252/1252 [02:47<00:00,  7.46it/s]  


{'accuracy': 0.9899141202316757, 'recall': 0.9932013597280543, 'precision': 0.9866905045689313, 'f1': 0.9899352267065271, 'auc': 0.9992751499570423, 'loss': 0.034879945902131164}


Model Training: 100%|██████████| 2504/2504 [28:31<00:00,  1.46it/s]
