In [1]:
import torch
from transformers import BertModel, BertTokenizer
import pandas as pd
import numpy as np
from nltk import word_tokenize
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
BERT_BASE_UNCASED = 'bert-base-cased'
PUNCATUATION = '''!\"#$%&\'()*+, -./:;<=>?@[\]^_`{|}~'''
MAX_LEN = 512

In [3]:
# model 部分
class BERT(torch.nn.Module):
    def __init__(self, TYPE = BERT_BASE_UNCASED) -> None:
        super(BERT, self).__init__()
        self.model = BertModel.from_pretrained(BERT_BASE_UNCASED)

    def forward(self, dict):
        out = self.model(**dict)
        cls_output = out.pooler_output
        seq_output = out.last_hidden_state
        return cls_output, seq_output

class WordSelector(torch.nn.Module):
    # 字詞選擇器
    def __init__(self, d_model = 768) -> None:
        super(WordSelector, self).__init__()
        self.linear = torch.nn.Linear(d_model, 2)

    def forward(self, last_state):
        out = self.linear(last_state)
        return out

class SentPoistionTeller(torch.nn.Module):
    # 立場辨識模組
    def __init__(self, dim_q = 768, dim_r = 768) -> None:
        super(SentPoistionTeller, self).__init__()
        self.linear = torch.nn.Linear(dim_q + dim_r, 2)

    def forward(self, q_cls, r_cls):
        h = torch.concat([q_cls, r_cls], dim = 1)
        out = self.linear(h)
        return out

class BertExtModel(torch.nn.Module):

    def __init__(self) -> None:
        super(BertExtModel, self).__init__()
        self.r_bert = BERT(TYPE = BERT_BASE_UNCASED)
        self.sent_position_teller = SentPoistionTeller()
        self.word_selector = WordSelector()

    def forward(self, q_dict, r_dict):
        q_cls, q_last_seq = self.r_bert(q_dict)
        r_cls, r_last_seq = self.r_bert(r_dict)
        s = self.sent_position_teller(q_cls, r_cls)
        q_out_seq = self.word_selector(q_last_seq)
        r_out_seq = self.word_selector(r_last_seq)
        return q_out_seq, r_out_seq, s

In [4]:
def preProcess(sent: str):
    sent = sent.strip('\"')
    sent = word_tokenize(sent)
    for i in PUNCATUATION:
        if i in sent:
            sent = list(filter(i.__ne__, sent))
    if len(sent) == 0:
        sent.append(' ')
    return ' '.join(sent)

In [5]:
dataframe = pd.read_csv('./Batch_answers - test_data(no_label).csv')
dataframe['q'] = dataframe['q'].apply(preProcess)
dataframe['r'] = dataframe['r'].apply(preProcess)

In [6]:
dataframe

Unnamed: 0,id,q,r,s
0,6199,-- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -...,If so why do we still have apes and why are th...,DISAGREE
1,5807,There 's a lot of discussion there on that iss...,Of course The makers of Expelled were within t...,DISAGREE
2,8487,`` It 's not helping The guns these people hav...,Oh I would wager about like Mexico about 80 fe...,DISAGREE
3,1760,Shooting 3 seriously injured in Auburn shootin...,Pickup strikes group of four youths Houston am...,AGREE
4,6228,This is the argument concerning 'choice that t...,I believe there is a point at which we society...,DISAGREE
...,...,...,...,...
2011,9499,You are betraying your belief system,Yep I 'm assuming that by `` belief system `` ...,AGREE
2012,4611,You are in a loud minority railing against the...,Being in the minority or in the majority is ir...,DISAGREE
2013,9328,You bet your XXX that 'd make me happy,Well first I probably would n't bet my XXX but...,DISAGREE
2014,5225,you say `` f the Constitution ``,and gun nuts say f the children when we have t...,DISAGREE


In [7]:
dataframe.sample(4)

Unnamed: 0,id,q,r,s
1374,1507,Go do a Google search and find out How many ba...,None Embryos are n't babies,DISAGREE
1185,9475,Where the hell did you get that from Do you de...,Well you have n't proven that you do n't suppo...,DISAGREE
5,3537,My point is people are prepared to admit that ...,The bible is not a science textbook I do not t...,DISAGREE
363,7061,The day when you have to walk into that aborti...,Right on A real man supports those he loves th...,AGREE


In [8]:
tokenizer = BertTokenizer.from_pretrained(BERT_BASE_UNCASED)

In [9]:
model = BertExtModel()
model.load_state_dict(torch.load('./bertExtModelFullDataWordGrained.pt'))
model = model.cuda()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
id_list = []
Q_list = []
R_list = []

In [11]:
model.eval()

BertExtModel(
  (r_bert): BERT(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(28996, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (Laye

In [12]:
for i in tqdm(dataframe.index):
    data = dataframe.iloc[i]
    id_str = data['id']
    text_q = data['q'].split()[:MAX_LEN]
    text_r = data['r'].split()[:MAX_LEN]
    q_dict = {k:v.cuda() for k, v in tokenizer.encode_plus(text_q, return_tensors='pt', truncation=True).items()}
    r_dict = {k:v.cuda() for k, v in tokenizer.encode_plus(text_r, return_tensors='pt', truncation=True).items()}
    q_pred, r_pred, s_pred = model(q_dict, r_dict)
    Q_str = tokenizer.decode(token_ids=q_dict['input_ids'].squeeze(0)[q_pred.argmax(dim=2).squeeze(0).bool()], skip_special_tokens=True)
    R_str = tokenizer.decode(token_ids=r_dict['input_ids'].squeeze(0)[r_pred.argmax(dim=2).squeeze(0).bool()], skip_special_tokens=True)
    id_list.append(id_str)
    Q_list.append(Q_str)
    R_list.append(R_str)

100%|██████████| 2016/2016 [00:40<00:00, 49.32it/s]


In [13]:
submission = pd.DataFrame({"id":id_list, "q": Q_list, "r":R_list})

In [14]:
submission.to_csv('./submission.csv', index=False, columns=['id', 'q','r'])