In [1]:
import torch
import pandas as pd
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math
BERT_BASE_CASED = 'bert-base-cased'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df = pd.read_csv('./uniqueBertSumTraining.csv')

In [3]:
class NLPDataset(Dataset):

    def __init__(self, df : pd.DataFrame) -> None:
        self.tokenizer : BertTokenizer = BertTokenizer.from_pretrained(BERT_BASE_CASED)
        self.q_datas = []
        self.q_clss_list = []
        self.r_datas = []
        self.r_clss_list = []
        self.s_labels = []
        self.q_sent_labels = []
        self.r_sent_labels = []
        for i in tqdm(df.index, desc="Constructing Dataset..."):
            row = df.iloc[i]
            q_data = self.tokenizer.encode_plus(row['q_word_token'].split(sep=';'), return_token_type_ids=False, add_special_tokens=False)
            q_clss = self.get_cls_indices(q_data['input_ids'])
            q_data['token_type_ids'] = eval(f'[{",".join(row["q_inter_seg"].split(sep = ";"))}]')
            q_sent_label = eval(f'[{",".join(row["comp_sent_q"].split(sep = ";"))}]')
            r_data = self.tokenizer.encode_plus(row['r_word_token'].split(sep=';'), return_token_type_ids=False, add_special_tokens=False)
            r_clss = self.get_cls_indices(r_data['input_ids'])
            r_data['token_type_ids'] = eval(f'[{",".join(row["r_inter_seg"].split(sep = ";"))}]')
            r_sent_label = eval(f'[{",".join(row["comp_sent_r"].split(sep = ";"))}]')
            s_label = 1 if row['s'] == "AGREE" else 0
            self.q_datas.append(q_data)
            self.q_clss_list.append(q_clss)
            self.r_datas.append(r_data)
            self.r_clss_list.append(r_clss)
            self.s_labels.append(s_label)
            self.q_sent_labels.append(q_sent_label)
            self.r_sent_labels.append(r_sent_label)

    def get_cls_indices(self, target) -> list:
        cls_indices = []
        for i, data in enumerate(target):
            if (data == 101):
                cls_indices.append(i)
        return cls_indices

    def __len__(self):
        return len(self.s_labels)

    def __getitem__(self, index):
        return (
            self.q_datas[index],
            self.q_clss_list[index],
            self.r_datas[index],
            self.r_clss_list[index],
            self.s_labels[index],
            self.q_sent_labels[index],
            self.r_sent_labels[index],
            )

In [4]:
dataset = NLPDataset(df = df)

Constructing Dataset...: 100%|██████████| 7855/7855 [00:02<00:00, 3168.48it/s]


In [5]:
class SentenceSelector(torch.nn.Module):

    def __init__(self) -> None:
        super(SentenceSelector, self).__init__()
        self.linear = torch.nn.Linear(768, 1)

    def forward(self, clss_hiddens):
        return self.linear(clss_hiddens)


class PositionalEncoding(torch.nn.Module):

    def __init__(self, dim, dropout = 0.1, max_len = 5000) -> None:
        super(PositionalEncoding, self).__init__()
        self.pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
                              -(math.log(10000.0) / dim)))
        self.pe[:, 0::2] = torch.sin(position.float() * div_term)
        self.pe[:, 1::2] = torch.cos(position.float() * div_term)
        self.pe = self.pe.unsqueeze(0)
        self.register_buffer("PositionalEmb", self.pe)
        self.dropoutLayer = torch.nn.Dropout(p= dropout)
        self.dim = dim
        self.pe = self.pe.cuda()
    
    def forward(self, x):
        emb = x + self.pe[:, :x.size(1)]
        out = self.dropoutLayer(emb)
        return out

class TransformerBlock(torch.nn.Module):

    def __init__(self, embed_dim, num_heads, ff_dim, rate = 0.1) -> None:
        super(TransformerBlock, self).__init__()
        self.attn = torch.nn.MultiheadAttention(embed_dim = embed_dim, num_heads = num_heads, batch_first = True)
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(embed_dim, ff_dim),
            torch.nn.GELU(),
            torch.nn.Dropout(p = rate),
            torch.nn.Linear(ff_dim, embed_dim),
            torch.nn.Dropout(p = rate),
        )
        self.layernorm1 = torch.nn.LayerNorm(embed_dim, eps=1e-6)
        self.layernorm2 = torch.nn.LayerNorm(embed_dim, eps=1e-6)
        self.dropout = torch.nn.Dropout(rate)

    def forward(self, inputs):
        attn_output, attn_score = self.attn(inputs, inputs, inputs)
        attn_output = self.dropout(attn_output + inputs)
        out1 = self.layernorm1(attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout(ffn_output + out1)
        return self.layernorm2(ffn_output), attn_score

class BertSumExtModel(torch.nn.Module):

    def __init__(self, numOfExtLayer = 2, freeze = True) -> None:
        super(BertSumExtModel, self).__init__()
        self.bertModel = BertModel.from_pretrained(BERT_BASE_CASED, return_dict = False)
        self.pe_layer = PositionalEncoding(dim = 768)
        self.ext_layer = torch.nn.ModuleList([TransformerBlock(768, 6, 768, rate = 0.1) for _ in range(numOfExtLayer)])
        self.layernorm = torch.nn.LayerNorm(768 , eps=1e-6)
        self.sentenceSelector = SentenceSelector()
        self.freeze = freeze

    def forward(self, q_data, q_clss, r_data, r_clss):
        if self.freeze:
            with torch.no_grad():
                q_hidden = self.bertModel(**q_data)[0]
                r_hidden = self.bertModel(**r_data)[0]
        else: 
            q_hidden = self.bertModel(**q_data)[0]
            r_hidden = self.bertModel(**r_data)[0]
        q_clss_hidden = q_hidden[0, q_clss]
        r_clss_hidden = r_hidden[0, r_clss]
        q_clss_hidden = self.pe_layer(q_clss_hidden)
        r_clss_hidden = self.pe_layer(r_clss_hidden)
        for transformerBlock in self.ext_layer:
            q_clss_hidden, _ = transformerBlock(q_clss_hidden)
            r_clss_hidden, _ = transformerBlock(r_clss_hidden)
        q_clss_hidden = self.layernorm(q_clss_hidden)
        r_clss_hidden = self.layernorm(r_clss_hidden)
        q_out = self.sentenceSelector(q_clss_hidden)
        r_out = self.sentenceSelector(r_clss_hidden)
        return q_out, r_out

In [6]:
model = BertSumExtModel(freeze= True).cuda()
loss_fn = torch.nn.BCEWithLogitsLoss()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
model_opt = torch.optim.AdamW(model.parameters(), lr = 3.68e-5)

In [8]:
lr_sc = torch.optim.lr_scheduler.LinearLR(model_opt, start_factor=0.5, total_iters = 5)

In [9]:
for epoch in range(10):
    total_loss = 0.
    currentLR = lr_sc.get_last_lr()[0]
    train_process = tqdm(dataset)
    for batch, (q_data, q_clss, r_data, r_clss, _, q_sent_label, r_sent_label) in enumerate(train_process, start = 1):
        q_data = {k: torch.tensor([v]).cuda() for k, v in q_data.items()}
        r_data = {k: torch.tensor([v]).cuda() for k, v in r_data.items()}
        q_clss = torch.tensor([q_clss]).cuda()
        r_clss = torch.tensor([r_clss]).cuda()
        model_opt.zero_grad()
        q_pred, r_pred = model(q_data, q_clss, r_data, r_clss)
        loss_q = loss_fn(q_pred.contiguous().reshape(1, -1), torch.tensor([q_sent_label], dtype=torch.float32).cuda())
        loss_r = loss_fn(r_pred.contiguous().reshape(1, -1), torch.tensor([r_sent_label], dtype=torch.float32).cuda())
        t_loss = loss_q + loss_r
        t_loss.backward()
        model_opt.step()
        total_loss += t_loss.item()
        train_process.set_postfix({ "CURRENT_LR" : currentLR, "AVG_LOSS" : total_loss/ batch})
    lr_sc.step()

100%|██████████| 7855/7855 [08:59<00:00, 14.55it/s, CURRENT_LR=1.84e-5, AVG_LOSS=0.903]
100%|██████████| 7855/7855 [08:59<00:00, 14.57it/s, CURRENT_LR=2.21e-5, AVG_LOSS=0.886]
100%|██████████| 7855/7855 [08:57<00:00, 14.62it/s, CURRENT_LR=2.58e-5, AVG_LOSS=0.879]
100%|██████████| 7855/7855 [08:58<00:00, 14.60it/s, CURRENT_LR=2.94e-5, AVG_LOSS=0.883]
100%|██████████| 7855/7855 [08:58<00:00, 14.59it/s, CURRENT_LR=3.31e-5, AVG_LOSS=0.878]
100%|██████████| 7855/7855 [08:57<00:00, 14.61it/s, CURRENT_LR=3.68e-5, AVG_LOSS=0.877]
100%|██████████| 7855/7855 [08:57<00:00, 14.60it/s, CURRENT_LR=3.68e-5, AVG_LOSS=0.875]
100%|██████████| 7855/7855 [08:58<00:00, 14.60it/s, CURRENT_LR=3.68e-5, AVG_LOSS=0.883]
100%|██████████| 7855/7855 [08:58<00:00, 14.57it/s, CURRENT_LR=3.68e-5, AVG_LOSS=0.877]
100%|██████████| 7855/7855 [08:58<00:00, 14.59it/s, CURRENT_LR=3.68e-5, AVG_LOSS=0.874]


In [10]:
torch.save(model.state_dict(), './BertSum_With_Encoder.pt')