# Momentum Calibration for Text Generation
---

https://arxiv.org/abs/2212.04257

In [1]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1


In [2]:
# %%capture
# !pip install fairscale transformers[fairscale] optimum datasets evaluate rouge_score

In [3]:
import torch
from torch.nn import functional as F

In [4]:
from datasets import load_dataset

cnn_train = load_dataset("cnn_dailymail", '3.0.0', split="train")
cnn_test = load_dataset("cnn_dailymail", '3.0.0', split="test")

Found cached dataset cnn_dailymail (/home/markintosh/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Found cached dataset cnn_dailymail (/home/markintosh/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)


In [5]:
cnn_test[0].keys()

dict_keys(['article', 'highlights', 'id'])

In [6]:
import transformers
from transformers import AutoTokenizer

checkpoint = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [7]:
prefix = ""


def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["article"]]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)

    labels = tokenizer(text_target=examples["highlights"], max_length=128, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [8]:
# tokenized_train = cnn_train.map(preprocess_function, batched=True)
tokenized_test = cnn_test.map(preprocess_function, batched=True)

Loading cached processed dataset at /home/markintosh/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de/cache-955cea1f919269e0.arrow


In [9]:
import evaluate
from evaluate import evaluator

rouge = evaluate.load("rouge")

In [10]:
from transformers import pipeline

summarizer = pipeline("summarization", model=checkpoint, device=0)

In [11]:
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [12]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
# from optimum.bettertransformer import BetterTransformer

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
# model = BetterTransformer.transform(model, keep_original_model=True, device_ids=[0, 1, 2])

In [37]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

In [14]:
training_args = Seq2SeqTrainingArguments(
    output_dir="BART",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=10,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=4,
    predict_with_generate=True,
    fp16=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=cnn_train,
    eval_dataset=cnn_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# trainer.evaluate(tokenized_test.select(range(100)))

### TODO
- [x] copy model and update model weight = 1
- [x] implement beam search = 2
- [ ] calculate s(x, y') = 3 s
- [ ] calculate loss = 4
- [ ] training loop = 5

In [15]:
# create two models
import copy

M = model
G = copy.deepcopy(M)

In [16]:
summarizer = pipeline("summarization", model=G, tokenizer=tokenizer, device=0)

In [17]:
def gen_samples(text, n_samples=8):
    res = summarizer(text, num_return_sequences=n_samples, diversity_penalty=5.5, 
                     num_beam_groups=4, num_beams=n_samples)
    return [r['summary_text'] for r in res]

In [15]:
gen_samples(tokenized_test[0]['article'])

["Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a move toward greater justice.",
 "Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki says it is a move toward greater justice.",
 "Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a move towa

In [16]:
# calculate the loss
text = tokenized_test[0]['article']

summarizer(text, diversity_penalty=5.5,
           num_beam_groups=2, num_beams=2, output_scores=True)

[{'summary_text': "The Palestinian Authority becomes the 123rd member of the International Criminal Court. The move gives the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki says it is a move toward greater justice."}]

In [57]:
class MocaTrainer(Seq2SeqTrainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        # inputs=(input_ids, atten_mask, labels: batch_size * ids_length)
        batch_size = input_ids.shape[0]
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        
        # outputs=(loss: generator, logits: batch_size * seq_len * vocab_size)
        print(outputs.logits.shape)
        
        # Save past state if it exists (maybe not necessary)
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        
        # calculate loss
        labels = F.one_hot(labels)  # (batch_size * seq_len * vocab_size)
        mask = labels >= 0
        labels[~mask] = 0
        max_label_len = labels.shape[0]
        logits = outputs.logits[:, :max_label_len] # (batch_size * seq_len * vocab_size)
        probs = ((logits * labels).sum(axis=2) * mask).sum(1)   # probability of each token
        seq_lens = mask.sum(axis=1)
        # y_t = torch.ones((batch_size, max_label_len))
        s_theta = probs / seq_lens
        
        loss = s_theta * 
        return (loss, outputs) if return_outputs else loss

In [18]:
b = torch.as_tensor([
  [[2, 2, 2],
   [3, 3, 3]],
  [[4, 4, 4],
   [5, 5, 5]],
])
a = torch.as_tensor([[0, -100], [1, 2]])
mask = (a >= 0)
a[~mask] = 0
a = F.one_hot(a)
print(a)
print(a.shape, b.shape)

tensor([[[1, 0, 0],
         [1, 0, 0]],

        [[0, 1, 0],
         [0, 0, 1]]])
torch.Size([2, 2, 3]) torch.Size([2, 2, 3])


In [25]:
((a * b).sum(axis=2) * mask).sum(axis=1)

tensor([2, 9])

In [44]:
tokenizer.

0

In [35]:
tokenized_test[0]['highlights'], tokenized_test[0]['labels']

('Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .\nIsrael and the United States opposed the move, which could open the door to war crimes investigations against Israelis .',
 [0,
  31339,
  4128,
  2029,
  5,
  14305,
  10542,
  81,
  1697,
  3474,
  2021,
  11,
  5791,
  13560,
  187,
  94,
  502,
  479,
  50118,
  20517,
  8,
  5,
  315,
  532,
  4340,
  5,
  517,
  6,
  61,
  115,
  490,
  5,
  1883,
  7,
  997,
  3474,
  4941,
  136,
  19544,
  479,
  2])

In [58]:
training_args = Seq2SeqTrainingArguments(
    output_dir="BART",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=10,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
    fp16=True,
)

trainer = MocaTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_test,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# trainer.evaluate(tokenized_test.select(range(100)))

In [59]:
trainer.train()

dict_keys(['input_ids', 'attention_mask', 'labels'])
tensor([[    0,   495, 11278,  2552,  9747, 14932,   139,     6,    10,  7819,
          1044,     6,    21,    66,    15,  8958,  1115, 26310,   128, 15691,
            12, 15691,   108,  5592, 20703,  3939,  3126,   861,   552,  5507,
           260,  8878,  4104,  3954,    11,   384, 17421,   479, 50118,  2515,
          7185,     8,  1064,  1764,  1730,   874,    10,  4683,    11,     5,
          5592,   216,    25,    20,   221, 10620,   479,     2,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [    0,   133,  3621, 12104,  1521,    16,  2319,     7,   701,   843,
          4416,  2729, 11888,  2890,   119, 33137,  3897,   119,    43,     7,
          1119,     8,    55,     7,   907,   479, 50118,   133,  1089,  6004,
     

OutOfMemoryError: Caught OutOfMemoryError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/markintosh/.conda/envs/nlp/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/markintosh/.conda/envs/nlp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/markintosh/.conda/envs/nlp/lib/python3.9/site-packages/transformers/models/bart/modeling_bart.py", line 1373, in forward
    outputs = self.model(
  File "/home/markintosh/.conda/envs/nlp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/markintosh/.conda/envs/nlp/lib/python3.9/site-packages/transformers/models/bart/modeling_bart.py", line 1255, in forward
    decoder_outputs = self.decoder(
  File "/home/markintosh/.conda/envs/nlp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/markintosh/.conda/envs/nlp/lib/python3.9/site-packages/transformers/models/bart/modeling_bart.py", line 1113, in forward
    layer_outputs = decoder_layer(
  File "/home/markintosh/.conda/envs/nlp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/markintosh/.conda/envs/nlp/lib/python3.9/site-packages/transformers/models/bart/modeling_bart.py", line 445, in forward
    hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
  File "/home/markintosh/.conda/envs/nlp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/markintosh/.conda/envs/nlp/lib/python3.9/site-packages/transformers/models/bart/modeling_bart.py", line 249, in forward
    attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 122.00 MiB (GPU 0; 31.75 GiB total capacity; 30.25 GiB already allocated; 77.69 MiB free; 30.74 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF


In [40]:
import inspect

In [35]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    tokenized_test, shuffle=True, batch_size=1, collate_fn=data_collator,
)

In [38]:
# train iteration
for batch in train_dataloader:
    print(batch)
    break

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`article` in this case) have excessive nesting (inputs type `list` where type `int` is expected).