In [1]:
import torch
import pandas as pd
from transformers import BertModel, BertTokenizer
from tqdm import tqdm
import nltk
# Checkpoint name passed to both BertTokenizer.from_pretrained and BertModel.from_pretrained.
BERT_BASE_CASED = 'bert-base-cased'
# Characters whose stand-alone tokens preProcess removes. The backslash is
# written as `\\` so the literal no longer contains the invalid escape `\]`
# (same resulting string, no DeprecationWarning/SyntaxWarning).
PUNCATUATION = '''!"#$%&'()*+, -./:;<=>?@[\\]^_`{|}~'''
# Upper bound on the number of whitespace-separated tokens get_data keeps per text.
MAX_LEN = 512

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def preProcess(sent: str):
    """Tokenize one sentence and drop its single-character punctuation tokens.

    Leading/trailing double quotes are stripped, the text is split with
    nltk.word_tokenize, and every token that is exactly one of the characters
    in PUNCATUATION is discarded. The surviving tokens are joined with single
    spaces; when nothing survives, a single space is returned instead.
    """
    tokens = nltk.word_tokenize(sent.strip('\"'))
    # Iterating a string yields its characters, so this is the set of
    # one-character tokens to remove (whole-token match, not substring).
    unwanted = set(PUNCATUATION)
    kept = [token for token in tokens if token not in unwanted]
    if not kept:
        kept = [' ']
    return ' '.join(kept)

In [3]:
def get_cls_indices(target, cls_token_id: int = 101) -> list:
    """Return the positions in `target` whose value equals `cls_token_id`.

    Args:
        target: iterable of token ids.
        cls_token_id: id to look for. Defaults to 101, the value this
            notebook previously hard-coded for the "[CLS]" marker.

    Returns:
        A list of 0-based positions, in increasing order (empty when the id
        does not occur).
    """
    return [position for position, token_id in enumerate(target) if token_id == cls_token_id]

In [4]:
def _encode_side(text, tokenizer):
    """Sentence-split, mark up and encode one text of a (q, r) pair.

    Returns:
        (encoding, sentences): the tokenizer output with a hand-built
        'token_type_ids' entry, and the raw sentences from nltk.sent_tokenize.
    """
    sentences = nltk.sent_tokenize(text)
    # Each sentence is wrapped as "[CLS] <cleaned sentence> [SEP]" and the
    # whole text is then handled as a flat list of whitespace-separated words.
    words = ' '.join(f'[CLS] {preProcess(sentence)} [SEP]' for sentence in sentences).split()

    # Interval segment ids: the counter advances at every "[CLS]", so the
    # words of consecutive sentences are labelled 1, 0, 1, 0, ...
    segment_ids = []
    cls_seen = 0
    for word in words:
        if word == "[CLS]":
            cls_seen += 1
        segment_ids.append(cls_seen % 2)

    # NOTE(review): the segment ids are per whitespace word; this assumes the
    # tokenizer returns exactly one id per element of the list - confirm.
    encoding = tokenizer.encode_plus(words[:MAX_LEN], return_token_type_ids=False, add_special_tokens=False, truncation = True)
    encoding['token_type_ids'] = segment_ids[:MAX_LEN]
    return encoding, sentences


def get_data(series, tokenizer):
    """Build the model inputs for one row.

    Args:
        series: mapping with text entries under the keys 'q' and 'r'.
        tokenizer: object exposing `encode_plus` (a BertTokenizer in this notebook).

    Returns:
        (q_data, r_data, q_sent, r_sent): the two encodings followed by the
        sentence lists they were built from. Inputs longer than MAX_LEN words
        are truncated, so an encoding may cover fewer sentences than its list.
    """
    q_data, q_sent = _encode_side(series['q'], tokenizer)
    r_data, r_sent = _encode_side(series['r'], tokenizer)
    return q_data, r_data, q_sent, r_sent

In [5]:
class SentenceSelector(torch.nn.Module):
    """Linear head producing one score per input vector.

    Args:
        hidden_dim: size of the last dimension of the vectors passed to
            `forward`. Defaults to 768, the value that used to be hard-coded.
    """

    def __init__(self, hidden_dim: int = 768) -> None:
        super().__init__()
        # The attribute name is part of the state_dict keys; keep it as `linear`.
        self.linear = torch.nn.Linear(hidden_dim, 1)

    def forward(self, clss_hiddens):
        """Map (..., hidden_dim) vectors to (..., 1) scores."""
        return self.linear(clss_hiddens)

class SideTeller(torch.nn.Module):
    """Linear head over the pooled q and r sentence vectors.

    Each input is summed over dim 1, the two sums are concatenated along
    dim 1, and a single linear layer maps the result to one value.
    """

    def __init__(self, q_dim, r_dim) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(q_dim + r_dim, 1)

    def forward(self, q_clss_hiddens, r_clss_hiddens):
        """Return the linear projection of [sum(q, dim=1), sum(r, dim=1)]."""
        pooled = [hiddens.sum(dim=1) for hiddens in (q_clss_hiddens, r_clss_hiddens)]
        return self.linear(torch.cat(pooled, dim=1))
        
class BertSumExtModel(torch.nn.Module):
    """BERT encoder shared by q and r, with a sentence-selection head and a SideTeller head.

    Attribute names (`bertModel`, `sideTeller`, `sentenceSelector`) are part of
    the state_dict keys and must not be renamed.
    """

    def __init__(self) -> None:
        super().__init__()
        # return_dict=False makes the encoder return a tuple, indexed with [0] in forward.
        self.bertModel = BertModel.from_pretrained(BERT_BASE_CASED, return_dict = False)
        self.sideTeller = SideTeller(q_dim = 768, r_dim = 768)
        self.sentenceSelector = SentenceSelector()

    def forward(self, q_data, q_clss, r_data, r_clss):
        """Encode both texts and score their sentences.

        Args:
            q_data, r_data: keyword arguments forwarded to the BERT encoder.
            q_clss, r_clss: indices used to pick positions out of the first
                batch element of each encoder output.

        Returns:
            (q_out, r_out, s_out): sentence-selector outputs for q and r, then
            the SideTeller output.
        """
        q_states = self.bertModel(**q_data)[0]
        r_states = self.bertModel(**r_data)[0]

        # Only batch element 0 is read; the index tensors choose the positions.
        q_picked = q_states[0, q_clss]
        r_picked = r_states[0, r_clss]

        side_out = self.sideTeller(q_picked, r_picked)
        q_scores = self.sentenceSelector(q_picked)
        r_scores = self.sentenceSelector(r_picked)
        return q_scores, r_scores, side_out

In [6]:
# Load the evaluation rows; later cells read the 'id', 'q' and 'r' columns.
# NOTE(review): the file name suggests this split carries no labels - confirm.
df = pd.read_csv('./Batch_answers - test_data(no_label).csv')

In [7]:
# Sanity check: display the raw 'q' text of the first row.
df.iloc[0]['q']

'"-- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- ( Time . )"'

In [8]:
# Tokenizer built from the same checkpoint name as the encoder inside BertSumExtModel.
tokenizer: BertTokenizer = BertTokenizer.from_pretrained(BERT_BASE_CASED)
model = BertSumExtModel().cuda()
# Deserialize onto the CPU first so loading does not depend on the device the
# checkpoint was saved from; load_state_dict then copies the values into the
# model's own parameters.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
state_dict = torch.load('./BertSum_WO_Encoder_with_S(1).pt', map_location='cpu')
model.load_state_dict(state_dict)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [9]:
# Accumulators filled by the inference loop: one entry per processed row
# (row id, selected q sentences, selected r sentences).
id_list, Q_list, R_list = [], [], []

In [10]:
# Start from empty accumulators so re-running this cell never appends a second
# copy of every row to lists left over from a previous run.
id_list, Q_list, R_list = [], [], []

# Put the inputs wherever the model lives instead of assuming CUDA.
device = next(model.parameters()).device

model.eval()
with torch.no_grad():
    # Iterate by position: df.iloc is positional, so driving it with the index
    # labels (df.index) is only correct for a default 0..n-1 index.
    for row_pos in tqdm(range(len(df))):
        series = df.iloc[row_pos]
        q_data, r_data, q_sent, r_sent = get_data(series=series, tokenizer=tokenizer)

        # Positions of the "[CLS]" markers, computed on the plain id lists
        # before they are wrapped into batched tensors.
        q_clss = get_cls_indices(q_data['input_ids'])
        r_clss = get_cls_indices(r_data['input_ids'])

        # Wrap every field in a list to add a batch dimension of size 1.
        q_data = {k: torch.tensor([v], device=device) for k, v in q_data.items()}
        r_data = {k: torch.tensor([v], device=device) for k, v in r_data.items()}
        q_clss = torch.tensor([q_clss], device=device)
        r_clss = torch.tensor([r_clss], device=device)

        q_pred, r_pred, _ = model(q_data, q_clss, r_data, r_clss)
        q_pred = torch.sigmoid(q_pred)
        r_pred = torch.sigmoid(r_pred)

        # Keep every sentence scoring above 0.5 ...
        q_index = (q_pred.reshape(-1) > 0.5).nonzero().flatten().tolist()
        r_index = (r_pred.reshape(-1) > 0.5).nonzero().flatten().tolist()
        # ... and fall back to the single best sentence when none qualifies.
        if not q_index:
            q_index.append(q_pred.argmax().item())
        if not r_index:
            r_index.append(r_pred.argmax().item())

        id_list.append(series['id'])
        Q_list.append(' '.join([q_sent[index] for index in q_index]))
        R_list.append(' '.join([r_sent[index] for index in r_index]))


100%|██████████| 2016/2016 [00:37<00:00, 53.84it/s]


In [11]:
# Assemble the per-row results into a frame and write it out without the index column.
submission_columns = {"id": id_list, "q": Q_list, "r": R_list}
submission = pd.DataFrame(submission_columns)
submission.to_csv(
    './submission.csv',
    columns=['id', 'q', 'r'],
    index=False,
)

In [19]:
# Spot-check five randomly drawn rows; no random_state is set, so the rows differ between runs.
submission.sample(5)

Unnamed: 0,id,q,r
499,258,"""I could say the same about heteros . I just s...",Ca n't you see that if you were given the powe...
1994,1401,"""What are you referring to here ?""","""Well , this guy is so bad ad arguing and so g..."
1753,3213,"""She thinks the university & # 8217 ; s leader...","""Oh yeah right , like sociopaths who are bent ..."
889,7084,"""He has nothing else . Bleating about the star...","""You have ignored a third option , my denier o..."
1126,761,"""We take reality and describe it in our minds ...","""This is my understanding as well . Interestin..."
