In [1]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BartTokenizer, BartForConditionalGeneration

# Reproducibility: seed every RNG this notebook touches (DataLoader shuffling
# and dropout draw from torch's global generator). Change SEED to vary runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Tokenizer for the BART-base checkpoint; its pad_token_id is reused by the
# collate functions, the BRIO model and the MLE loss in later cells.
tokenizer_name = 'facebook/bart-base'
tok = BartTokenizer.from_pretrained(tokenizer_name)

In [3]:
from BRIO.data_utils import BrioDataset, collate_mp_brio
from functools import partial

In [4]:
# Candidate-summary datasets for BRIO training and validation: source texts
# truncated to total_len tokens, summaries to max_len tokens, at most max_num
# candidates per example.
#
# The validation set is built with is_test=True because val_loader collates it
# with collate_mp_brio(..., is_test=True). NOTE(review): that collate mode
# presumably reads an extra per-example field that BrioDataset only emits in
# test mode; confirm against BRIO.data_utils. Without the flag, iterating
# val_loader can fail.
train_set = BrioDataset(fdir='data/cnndm/diverse/train', model_type='facebook/bart-base', max_len=120, total_len=1024, max_num=2)
val_set = BrioDataset(fdir='data/cnndm/diverse/val', model_type='facebook/bart-base', max_len=120, total_len=1024, max_num=2, is_sorted=False, is_test=True)

In [5]:
# Batch collation for BRIO. Both variants pad with the tokenizer's pad id;
# only the validation variant runs collate_mp_brio with is_test=True.
_pad_id = tok.pad_token_id
collate_fn = partial(collate_mp_brio, is_test=False, pad_token_id=_pad_id)
collate_fn_val = partial(collate_mp_brio, is_test=True, pad_token_id=_pad_id)

In [6]:
# Training batches hold 4 shuffled examples; validation walks the set one
# example at a time in a fixed order. Both loaders use two worker processes.
_num_workers = 2
dataloader = DataLoader(
    train_set,
    batch_size=4,
    shuffle=True,
    num_workers=_num_workers,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_set,
    batch_size=1,
    shuffle=False,
    num_workers=_num_workers,
    collate_fn=collate_fn_val,
)

In [7]:
from BRIO.model import BRIO, RankingLoss

In [8]:
# BRIO model built from the facebook/bart-base checkpoint, using the
# tokenizer's padding id.
model = BRIO(
    pad_token_id=tok.pad_token_id,
    mname='facebook/bart-base',
)

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

In [9]:
# Move the model's parameters to the selected device and enable training mode
# (dropout etc. active).
model = model.to(device)
model.train()
# scoring_mode() is defined in BRIO.model (not visible here). NOTE(review):
# presumably it switches the model to the candidate-scoring forward pass used
# by the training loop; confirm in BRIO.model.
model.scoring_mode()

In [10]:
from BRIO.label_smoothing_loss import label_smoothing_loss
# Label-smoothed MLE loss (epsilon=0.1). ignore_index is the pad id, presumably
# so padded target positions do not contribute; see BRIO.label_smoothing_loss.
mle_fn = label_smoothing_loss(epsilon=0.1, ignore_index=tok.pad_token_id)

# Adam over all model parameters, constructed with PyTorch's default settings.
s_optimizer = optim.Adam(model.parameters())

# Bookkeeping: min_rank_loss / min_mle_loss are presumably best-loss trackers
# (unused in the cells shown); all_step_cnt counts optimizer updates across
# epochs and drives the learning-rate schedule.
min_rank_loss, min_mle_loss, all_step_cnt = 100, 1e5, 0

def eval_fn(rouge1, rouge2, rougeLsum):
    """Combine ROUGE scores into a single "lower is better" value.

    Returns ``1 - (rouge1 * rouge2 + rougeLsum) / 3``.

    NOTE(review): rouge1 and rouge2 are multiplied rather than summed, so the
    division by 3 is not a plain mean. Confirm this weighting is intended.
    """
    combined = rouge1 * rouge2 + rougeLsum
    return 1 - combined / 3

In [11]:
# Training hyper-parameters.
epochs = 3             # passes over the training set
scale = 1              # multiplier on candidate/gold scores before RankingLoss
accumulate_step = 10   # batches accumulated per optimizer update
max_lr = 2e-3          # lr-schedule factor; peak lr = max_lr / sqrt(warmup_steps)
warmup_steps = 10_000  # optimizer updates of linear warmup before inverse-sqrt decay

In [12]:
# BRIO training loop. Each batch contributes a label-smoothed MLE loss on the
# gold sequence plus a ranking loss over candidate scores. Gradients are
# accumulated over `accumulate_step` batches before each optimizer update.
#
# Uses names from earlier cells: model, dataloader, s_optimizer, mle_fn,
# RankingLoss, device, epochs, scale, accumulate_step, max_lr, warmup_steps,
# and the global update counter all_step_cnt (re-run its cell to reset it).
mle_weight = 0.1    # weight on the MLE loss
rank_weight = 100   # weight on the ranking loss

for epoch in range(epochs):
    s_optimizer.zero_grad()
    avg_ranking_loss = 0
    avg_mle_loss = 0
    avg_loss = 0
    step_cnt = 0     # batches accumulated since the last optimizer update
    epoch_step = 0   # optimizer updates performed in this epoch
    progress = tqdm(dataloader, desc=f"epoch {epoch}")
    for batch in progress:
        # Without this increment the update branch below is never reached.
        step_cnt += 1
        src_input_ids = batch['src_input_ids'].to(device)
        candidate_ids = batch['candidate_ids'].to(device)
        output = model(src_input_ids, candidate_ids, normalize=True, score_mode="log", length_penalty=2.0, adding=0)

        similarity = output['score'] * scale
        gold_similarity = output['summary_score'] * scale
        ranking_loss = RankingLoss(similarity, gold_similarity, margin=0.001, gold_margin=0, gold_weight=0)

        # Teacher-forced MLE target: drop the last prediction position and the
        # first token of candidate 0 (used as the gold sequence), so position t
        # is scored against token t+1. NOTE(review): assumes probs is
        # (batch, seq, vocab) and mle_fn wants the class dim second; confirm.
        probs = output['probs'][:, :-1]
        gold = candidate_ids[:, 0, 1:]
        mle_loss = mle_fn(probs.transpose(1, 2), gold)

        loss = mle_weight * mle_loss + rank_weight * ranking_loss
        # Running means of the per-batch losses over the current accumulation window.
        avg_loss += loss.item() / accumulate_step
        avg_mle_loss += mle_loss.item() / accumulate_step
        avg_ranking_loss += ranking_loss.item() / accumulate_step

        # Scale so the accumulated gradient is the mean over the window.
        (loss / accumulate_step).backward()

        if step_cnt == accumulate_step:
            step_cnt = 0
            epoch_step += 1
            all_step_cnt += 1
            # Linear warmup then inverse-sqrt decay. The value must be written
            # into the optimizer's param groups, otherwise Adam keeps the lr it
            # was constructed with.
            lr = max_lr * min(all_step_cnt ** (-0.5), all_step_cnt * (warmup_steps ** (-1.5)))
            for param_group in s_optimizer.param_groups:
                param_group['lr'] = lr
            s_optimizer.step()
            s_optimizer.zero_grad()
            progress.set_postfix(loss=avg_loss, mle=avg_mle_loss, rank=avg_ranking_loss, lr=lr)
            avg_loss = avg_mle_loss = avg_ranking_loss = 0

        # Drop references to this batch's graph before the next forward pass.
        # NOTE(review): empty_cache() every batch is presumably an OOM guard on a
        # small GPU; it slows training and can likely be removed.
        del similarity, gold_similarity, loss, mle_loss, ranking_loss, output, probs
        torch.cuda.empty_cache()