In [1]:
import ast
import pickle
import random

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BertModel, BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Name of the pretrained checkpoint shared by the model and the tokenizer.
# NOTE: despite "UNCASED" in the identifier, the value selects the *cased*
# BERT-base checkpoint; the identifier is kept because other cells reference it.
BERT_BASE_UNCASED = 'bert-base-cased'

In [3]:
# Model definitions

class BERT(torch.nn.Module):
    """Thin wrapper around a pretrained ``BertModel``.

    Args:
        TYPE: Name or path of the checkpoint handed to
            ``BertModel.from_pretrained``. Defaults to the module-level
            ``BERT_BASE_UNCASED`` constant.
    """

    def __init__(self, TYPE = BERT_BASE_UNCASED) -> None:
        super(BERT, self).__init__()
        # Honour the requested checkpoint. Previously ``TYPE`` was accepted but
        # ignored and the module-level constant was always loaded.
        self.model = BertModel.from_pretrained(TYPE)

    def forward(self, dict):
        """Run the encoder on one tokenized batch.

        Args:
            dict: Mapping of keyword arguments that is unpacked into the wrapped
                model's call (the parameter name shadows the builtin but is kept
                for interface compatibility).

        Returns:
            Tuple ``(cls_output, seq_output)`` holding the model output's
            ``pooler_output`` and ``last_hidden_state`` respectively.
        """
        out = self.model(**dict)
        cls_output = out.pooler_output
        seq_output = out.last_hidden_state
        return cls_output, seq_output

class WordSelector(torch.nn.Module):
    """Word selector: a linear head producing two logits per input vector.

    Args:
        d_model: Size of the last dimension of the input features.
    """

    def __init__(self, d_model = 768) -> None:
        super().__init__()
        # Attribute name is part of the checkpoint layout (state_dict keys).
        self.linear = torch.nn.Linear(in_features=d_model, out_features=2)

    def forward(self, last_state):
        """Project ``last_state`` (last dim ``d_model``) to 2 logits."""
        return self.linear(last_state)

class SentPoistionTeller(torch.nn.Module):
    """Stance recognition module: classifies a (q, r) pair into two classes.

    Args:
        dim_q: Feature size of the ``q`` vector.
        dim_r: Feature size of the ``r`` vector.
    """

    def __init__(self, dim_q = 768, dim_r = 768) -> None:
        super().__init__()
        joint_dim = dim_q + dim_r
        self.linear = torch.nn.Linear(joint_dim, 2)

    def forward(self, q_cls, r_cls):
        """Concatenate both inputs along dim 1 and map them to 2 logits."""
        joint = torch.cat((q_cls, r_cls), dim=1)
        return self.linear(joint)

class BertExtModel(torch.nn.Module):
    """Joint model: one shared BERT encoder feeding two heads.

    The same ``BERT`` instance encodes both inputs. Its pooled outputs go to
    the stance head (``SentPoistionTeller``) and its per-position hidden states
    go to the word-selection head (``WordSelector``).
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order and attribute names define the state_dict layout.
        self.r_bert = BERT(TYPE = BERT_BASE_UNCASED)
        self.sent_position_teller = SentPoistionTeller()
        self.word_selector = WordSelector()

    def forward(self, q_dict, r_dict):
        """Encode both inputs and run both heads.

        Args:
            q_dict: Encoder keyword arguments for the ``q`` text.
            r_dict: Encoder keyword arguments for the ``r`` text.

        Returns:
            Tuple ``(q_out_seq, r_out_seq, s)``: word-selection logits for q,
            word-selection logits for r, and stance logits.
        """
        q_pooled, q_states = self.r_bert(q_dict)
        r_pooled, r_states = self.r_bert(r_dict)
        stance_logits = self.sent_position_teller(q_pooled, r_pooled)
        return (
            self.word_selector(q_states),
            self.word_selector(r_states),
            stance_logits,
        )

In [4]:
# Load the training table from a CSV in the working directory. Later cells read
# its 'q', 'r', 's', 'com_q' and 'com_r' columns.
dataFrame = pd.read_csv('./PreprocessFullData.csv')

In [5]:
dataFrame.sample(5) # preprocessing section: show 5 randomly sampled rows

Unnamed: 0,id,q,r,s,Q,R,com_q,com_r
4317,1132,My argument that a `` human individual 's life...,Apologies that was n't my intention My meaning...,DISAGREE,human individual 's life biologically begins a...,your use of it as an unsupported 'fact forced ...,"[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
611,172,Give Ryuuichi a break will you She is a 17 yea...,Yes I was 17 years old before but if I asked a...,DISAGREE,She is a 17 year old girl that is contemplating,Now that I 'm 26 I can take people,"[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
30701,8045,I would love to see more discussion of views t...,Yes it is interesting how we draw an arbitary ...,AGREE,I would love to see more discussion of views t...,it is interesting how we draw an arbitary `` l...,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
12425,3269,OK try the US Constitution We are a representa...,which is pretty much the same thing The majori...,DISAGREE,We are a representative democratic republic no...,The representatives are elected by a majority ...,"[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, ..."
16715,4431,Pro-lifers are like slave traders What a pro-l...,Thats not the point The Pro-lifers advocate fo...,DISAGREE,Pro-lifers are like slave traders What a pro-l...,Thats not the point The Pro-lifers advocate fo...,"[1, 1, 1, 1, 1, 1, 1, 1, 1]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, ..."


In [6]:
# Load the tokenizer for the same checkpoint name the model uses.
tokenizer = BertTokenizer.from_pretrained(BERT_BASE_UNCASED)

In [7]:
class NLPDataset(Dataset):
    """Dataset yielding tokenized (q, r) pairs with stance and per-token labels.

    Args:
        dataFrame: Table with string columns ``q``, ``r``, ``s``, ``com_q`` and
            ``com_r``. ``com_q``/``com_r`` hold Python list literals of ints.
        tokenizer: Object exposing ``encode_plus(words, padding=, max_length=,
            truncation=)`` that returns ``input_ids``, ``token_type_ids`` and
            ``attention_mask``.
        padding: Padding strategy forwarded to the tokenizer.
        max_length: Length forwarded to the tokenizer and used to truncate and
            pad the label sequences.
    """

    def __init__(self, dataFrame, tokenizer, padding = 'max_length', max_length = 512) -> None:
        self.tokenizer = tokenizer
        self.dataframe = dataFrame
        self.padding = padding
        self.max_length = max_length

    def __len__(self):
        """Number of rows in the wrapped dataframe."""
        return len(self.dataframe)

    def _encode(self, text):
        """Tokenize whitespace-split ``text`` with this dataset's tokenizer."""
        # Use the tokenizer given to the constructor. The previous code read a
        # notebook-global ``tokenizer`` and silently ignored the argument.
        return self.tokenizer.encode_plus(
            text.split(),
            padding = self.padding,
            max_length = self.max_length,
            truncation = True,
        )

    def _token_labels(self, raw):
        """Parse a label-list literal and fit it to ``max_length`` positions."""
        # literal_eval only accepts Python literals, so a crafted CSV cell
        # cannot execute code the way ``eval`` would.
        labels = list(ast.literal_eval(raw))[:self.max_length - 2]
        # A 1 is added on each side (presumably for [CLS]/[SEP] - confirm);
        # the remainder is right-padded with zeros up to max_length.
        labels = torch.tensor([1] + labels + [1])
        return torch.nn.functional.pad(labels, pad=(0, self.max_length - labels.shape[0]))

    def __getitem__(self, index):
        """Return one sample as a 9-tuple of tensors.

        Order: q input_ids, q token_type_ids, q attention_mask, r input_ids,
        r token_type_ids, r attention_mask, stance label (1 for "AGREE", else
        0), q token labels, r token labels.
        """
        row = self.dataframe.iloc[index]
        q_token = self._encode(row['q'])
        r_token = self._encode(row['r'])
        s = 1 if row['s'] == "AGREE" else 0
        com_q = self._token_labels(row['com_q'])
        com_r = self._token_labels(row['com_r'])
        return (
            torch.tensor(q_token['input_ids']), torch.tensor(q_token['token_type_ids']), torch.tensor(q_token['attention_mask']),
            torch.tensor(r_token['input_ids']), torch.tensor(r_token['token_type_ids']), torch.tensor(r_token['attention_mask']),
            torch.tensor(s), com_q, com_r
        )
# Wrap the whole dataframe with the dataset's default padding/max_length, then
# iterate it in shuffled mini-batches of 4 samples.
dataset = NLPDataset(dataFrame=dataFrame, tokenizer=tokenizer)
trainLoader = DataLoader(dataset, batch_size=4, shuffle = True)

In [8]:
# Build the joint model and move it to the GPU.
model = BertExtModel()
model = model.cuda()

# One cross-entropy criterion is shared by the stance and word-selection tasks.
loss_fn = torch.nn.CrossEntropyLoss()

# AdamW with base LR 5e-5. LinearLR starts at half of that and ramps linearly
# towards the base LR over 19 scheduler steps.
model_opt = torch.optim.AdamW(params=model.parameters(), lr=5e-5)
lr_sc = torch.optim.lr_scheduler.LinearLR(
    optimizer=model_opt,
    start_factor=0.5,
    total_iters=19,
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
# Joint training for 5 epochs. Each step averages three cross-entropy losses:
# stance (s), word selection on q, and word selection on r.
#
# NOTE(review): the token-level losses are averaged over every position,
# including padded ones (the dataset pads labels with 0). If padding should not
# contribute, pad labels with CrossEntropyLoss's ignore_index instead - confirm
# the intended behaviour before changing it.

# Follow whatever device the model lives on instead of hard-coding .cuda().
device = next(model.parameters()).device

for epoch in range(5):
    # Set training mode explicitly so re-running this cell after an
    # evaluation cell does not silently train with dropout disabled.
    model.train()
    total_loss = 0.
    # The scheduler is stepped once per epoch, so the LR is constant within it.
    currentLR = lr_sc.get_last_lr()[0]
    train_process = tqdm(trainLoader)
    for batch, data in enumerate(train_process, start = 1):
        model_opt.zero_grad()
        # Batch layout matches the 9-tuple returned by NLPDataset.__getitem__.
        (q_ids, q_types, q_mask,
         r_ids, r_types, r_mask,
         s_label, q_label, r_label) = (t.to(device) for t in data)
        q_dict = {"input_ids" : q_ids, "token_type_ids": q_types, "attention_mask": q_mask}
        r_dict = {"input_ids" : r_ids, "token_type_ids": r_types, "attention_mask": r_mask}
        q_pred, r_pred, s_pred = model(q_dict, r_dict)
        s_loss = loss_fn(s_pred, s_label)
        # Flatten the leading dims so every position is one classification
        # example; reshape copes with non-contiguous tensors on its own.
        q_loss = loss_fn(q_pred.reshape(-1, q_pred.shape[-1]), q_label.reshape(-1))
        r_loss = loss_fn(r_pred.reshape(-1, r_pred.shape[-1]), r_label.reshape(-1))
        t_loss = (q_loss + r_loss + s_loss) / 3.
        t_loss.backward()
        model_opt.step()
        total_loss += t_loss.item()
        # Running mean of the combined loss over the steps seen this epoch.
        train_process.set_postfix({"AVG_LOSS" : total_loss/ batch, "CURRENT_LR" : currentLR})
    lr_sc.step()

100%|██████████| 9587/9587 [35:39<00:00,  4.48it/s, AVG_LOSS=0.107, CURRENT_LR=2.5e-5]
100%|██████████| 9587/9587 [35:38<00:00,  4.48it/s, AVG_LOSS=0.0465, CURRENT_LR=2.63e-5]
100%|██████████| 9587/9587 [35:38<00:00,  4.48it/s, AVG_LOSS=0.045, CURRENT_LR=2.76e-5] 
100%|██████████| 9587/9587 [35:40<00:00,  4.48it/s, AVG_LOSS=0.043, CURRENT_LR=2.89e-5] 
100%|██████████| 9587/9587 [35:40<00:00,  4.48it/s, AVG_LOSS=0.043, CURRENT_LR=3.03e-5] 


In [10]:
# Persist only the trained weights (the state_dict), not the module object.
torch.save(model.state_dict(), './bertExtModelFullDataWordGrained.pt')