In [1]:
# import model as M
# import evaluate as E
import config as CFG
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, BatchEncoding, BertModel
from tqdm import tqdm

In [2]:
# Checkpoint name used to load the tokenizer. `tokenizer` is read as a
# module-level global by convert_single_example.
# NOTE(review): Scorer loads its encoder from CFG.BERT_MODEL rather than from
# `model_checkpoint` — confirm both refer to the same checkpoint.
model_checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [3]:
# Notebook-level experiment settings.
# NOTE(review): train() builds its paths from CFG.DATASET_ROOT and
# CFG.RESULT_ROOT, so the two roots below are not read by any code visible in
# this notebook — confirm which source is the intended one.
DATASET_ROOT= 'exp/data/'
RESULT_ROOT = "exp/result_bert_base_uncased_summarly"
# Sub-directory name under each dataset folder; train() uses it when building
# the path to train.jsonl.
METHOD = 'summar_ly'

# Dataset folder names that train() iterates over.
DATASETS = ["cnn_dailymail"]

In [4]:
def convert_single_example(example):
    """Tokenize one document together with each of its candidate summaries.

    Every document sentence and every summary sentence is prefixed with a
    literal "[CLS] " marker before tokenization, so each sentence owns one
    CLS position in the encoded sequence. Labels are attached to those
    positions only; every other position carries -100.

    Args:
        example: dict with
            "doc"  - list of document sentences (strings);
            "sums" - list of summaries, each a dict with "sample" (a list of
                     dicts holding "text", "fact" and "ling") and "coverage"
                     (indexable, read once per document sentence marker that
                     survives truncation).

    Returns:
        A list with one ``(inputs, labels)`` tuple per summary. ``inputs``
        holds the tokenizer outputs with the leading batch axis removed;
        ``labels`` is a float tensor of shape (sequence_length, 3) whose
        columns are (coverage, fact, ling).

    Raises:
        ValueError: if the encoded pair contains no separator token, in which
            case document and summary positions cannot be told apart.
    """
    cls_id = tokenizer.cls_token_id
    sep_id = tokenizer.sep_token_id

    doc_text = " ".join("[CLS] " + sent for sent in example["doc"])

    samples = []

    for _sum in example["sums"]:
        sum_text = " ".join("[CLS] " + sent["text"] for sent in _sum["sample"])
        labels = [[sent["fact"], sent["ling"]] for sent in _sum["sample"]]

        inputs = tokenizer(doc_text, sum_text,
                           padding='max_length', truncation="longest_first", return_tensors="pt")

        # Pull the ids into a plain list once: comparing tensor elements one
        # by one inside a Python loop is far slower and yields the same result.
        token_ids = inputs["input_ids"][0].tolist()
        seq_len = len(token_ids)

        cov_label = [-100] * seq_len
        # The shared inner list is never mutated, only replaced per position.
        sum_label = [[-100, -100]] * seq_len

        # Position 0 is the CLS added by the tokenizer itself, so the search
        # for the document/summary boundary starts at 1.
        try:
            sep = token_ids.index(sep_id, 1)
        except ValueError:
            raise ValueError(
                "no separator token found in the encoded document/summary pair"
            ) from None

        # Align coverage labels with the document sentence markers.
        cnt = 0
        for i in range(1, sep):
            if token_ids[i] == cls_id:
                cov_label[i] = int(_sum["coverage"][cnt])
                cnt += 1

        # Align fact/ling labels with the summary sentence markers.
        cnt = 0
        for i in range(sep + 1, seq_len):
            if token_ids[i] == cls_id:
                sum_label[i] = labels[cnt]
                cnt += 1

        final_label = np.hstack([np.array(cov_label).reshape(-1, 1), np.array(sum_label)])

        # Drop only the batch axis that return_tensors="pt" adds.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        samples.append((inputs, torch.Tensor(final_label)))

    return samples

In [5]:
class TokenizedDataset(Dataset):
    """Pairwise dataset built eagerly from a JSON-lines file.

    Each non-blank line is parsed as one example and expanded with
    ``convert_single_example``. The first returned sample is paired with each
    of the remaining ones, so a line yielding ``k`` samples contributes
    ``k - 1`` items of the form ``(samples[0], samples[i])``.
    """

    def __init__(self, datapath, limit = None):
        """Read and convert `datapath`.

        Args:
            datapath: path to a UTF-8 JSON-lines file, one example per line.
            limit: maximum number of examples to consume. ``None`` reads the
                whole file; zero or a negative value reads nothing.
        """
        self.data = []
        consumed = 0
        with open(datapath, "r", encoding="utf-8") as f:
            for line in tqdm(f):
                # Checked before any work so that limit <= 0 cannot slip past
                # an equality test and load the entire file.
                if limit is not None and consumed >= limit:
                    break
                # A blank line (e.g. a trailing newline) is not an example.
                if not line.strip():
                    continue

                samples = convert_single_example(json.loads(line))
                # samples[0] is only evaluated when there is a second sample.
                self.data.extend((samples[0], other) for other in samples[1:])
                consumed += 1

    def __len__(self):
        """Return the number of stored pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the pair at `idx` as ``(first_sample, other_sample)``."""
        first, other = self.data[idx]
        return first, other

In [6]:
class Scorer(nn.Module):
    """Pretrained encoder with a scalar scoring head and a per-token 3-way head.

    NOTE(review): the encoder is loaded from CFG.BERT_MODEL while the notebook
    tokenizer is loaded from `model_checkpoint` — confirm they match.
    """

    def __init__(self):
        super().__init__()
        self.model = AutoModel.from_pretrained(CFG.BERT_MODEL)
        hidden_size = self.model.config.hidden_size
        # One value per sequence, computed from the pooled output.
        self.score_head = nn.Linear(hidden_size, 1)
        # Three logits for every position of the last hidden state.
        self.classify_head = nn.Linear(hidden_size, 3)

    def forward(self, inputs):
        """Run the encoder on `inputs` (a mapping unpacked as keyword arguments).

        Returns:
            ``(score, logits)``: the scoring head applied to ``pooler_output``
            and the classification head applied to ``last_hidden_state``.
        """
        encoded = self.model(**inputs)
        sequence_score = self.score_head(encoded.pooler_output)
        token_logits = self.classify_head(encoded.last_hidden_state)
        return sequence_score, token_logits

In [7]:
def train_model(model, train_set, max_iter=CFG.MAX_ITERATION):
    """Train `model` on sample pairs for at most `max_iter` optimisation steps.

    Each batch drawn from `train_set` is a pair ``(pos, neg)`` whose elements
    are ``(inputs, labels)`` tuples. Two losses are added together:

    * preference loss - cross-entropy over the two concatenated scores with
      target index 0, i.e. the first element of the pair should score higher;
    * classification loss - element-wise BCE-with-logits on the per-position
      logits of both elements, summed over entries whose label is not -100.

    A single pass over `train_set` is made; it stops early once `max_iter`
    steps have been taken.

    Args:
        model: module whose forward returns ``(score, logits)``.
        train_set: dataset yielding ``((inputs, labels), (inputs, labels))``.
        max_iter: upper bound on the number of optimisation steps.

    Returns:
        None. `model` is updated in place.
    """
    train_dataloader = DataLoader(train_set, batch_size=CFG.BATCH_SIZE, shuffle=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=CFG.LR)
    score_loss_fn = nn.CrossEntropyLoss()
    class_loss_fn = nn.BCEWithLogitsLoss(reduction='none')

    # from_pretrained() returns the encoder in eval mode; switch the whole
    # module to training mode so dropout is active while fine-tuning.
    model.train()

    log_every = 1000
    num_iter = 0
    running_loss = 0.0
    with tqdm(total=max_iter) as pbar:
        for pos, neg in train_dataloader:
            if num_iter >= max_iter:
                break

            inputs_pos, labels_pos = pos
            inputs_neg, labels_neg = neg
            labels_pos = labels_pos.to(CFG.DEVICE)
            labels_neg = labels_neg.to(CFG.DEVICE)
            inputs_pos = BatchEncoding(inputs_pos).to(CFG.DEVICE)
            inputs_neg = BatchEncoding(inputs_neg).to(CFG.DEVICE)

            score_pos, logits_pos = model(inputs_pos)
            score_neg, logits_neg = model(inputs_neg)
            # -100 marks entries that carry no label.
            mask_pos = labels_pos != -100
            mask_neg = labels_neg != -100

            # Preference loss: class 0 is the first element of the pair.
            score_cat = torch.cat((score_pos, score_neg), -1)
            score_labels = torch.zeros(labels_pos.shape[0], dtype=torch.long, device=CFG.DEVICE)
            loss_score = score_loss_fn(score_cat, score_labels)
            # Classification loss: computed everywhere, then masked so that
            # unlabeled entries contribute nothing.
            loss_class_pos = class_loss_fn(logits_pos, labels_pos) * mask_pos
            loss_class_neg = class_loss_fn(logits_neg, labels_neg) * mask_neg
            loss = loss_score + torch.sum(loss_class_pos) + torch.sum(loss_class_neg)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            num_iter += 1
            pbar.update(1)
            if num_iter % log_every == 0:
                # running_loss is the sum over the last `log_every` steps.
                pbar.write("Iteration {}, Loss {}".format(num_iter, running_loss))
                running_loss = 0.0

In [8]:
def train():
    """Train one Scorer per entry in DATASETS and save its weights.

    For each dataset name the training split is read from
    ``<CFG.DATASET_ROOT>/<dataset>/<METHOD>/train.jsonl`` and the resulting
    state dict is written to
    ``<CFG.RESULT_ROOT>/<dataset>/<CFG.METHOD>/model.pth``.

    NOTE(review): the input path uses the notebook-level METHOD while the
    checkpoint path uses CFG.METHOD — confirm the two are meant to agree.
    """
    for DATASET in DATASETS:
        train_path = os.path.join(CFG.DATASET_ROOT, DATASET, METHOD, 'train.jsonl')
        train_set = TokenizedDataset(train_path)

        model = Scorer()
        model.to(CFG.DEVICE)

        train_model(model, train_set)

        ckpt_path = os.path.join(CFG.RESULT_ROOT, DATASET, CFG.METHOD, "model.pth")
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

        torch.save(model.state_dict(), ckpt_path)

        # Release the model before the next dataset allocates a new one.
        del model
        torch.cuda.empty_cache()

In [None]:
# Train and checkpoint a model for every entry in DATASETS.
train()

767it [01:07, 10.98it/s]