# Data Collator Experiments

In this notebook, we'll explore how to construct batches out of processed `Commit Chronicle` dataset during the training/validation setting for a encoder-decoder style architecture.

**Make sure to run `Commit Chronicle Dataset.ipynb` before using this notebook.**

The logic laid out in this notebook is implemented in `DataCollatorTrain`.

In [None]:
import sys

# Add the parent of the src directory for the imports to work
my_src_path = 'ML-24-25'
sys.path.append(my_src_path)

In [14]:
import rootutils
import torch.utils.data
from datasets import load_from_disk

from src.data.types import SingleExample

In [2]:
ROOT = rootutils.find_root(".", ".project-root")
OUTPUT_DIR = ROOT / "data/playground"

In [3]:
dataset_ = load_from_disk(OUTPUT_DIR / "02-processed")
dataset_.select(range(10))

Dataset({
    features: ['author', 'msg_input_ids', 'diff_input_ids', 'language', 'repo'],
    num_rows: 10
})

In [4]:
class HumbleDataset(torch.utils.data.Dataset):
    def __init__(self, dataset) -> None:
        self.dataset = dataset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> SingleExample:
        row = self.dataset[index]
        return SingleExample(
            diff_input_ids=row["diff_input_ids"],
            msg_input_ids=row["msg_input_ids"],
            history_input_ids=[],  # ignored in this notebook. don't worry out it. trust me :)
            pos_in_file=-1,  # ignored in this notebook.
        )


data = HumbleDataset(dataset_)

In [5]:
from src.data.components.tokenization import add_special_tokens
from transformers import AutoTokenizer
from copy import deepcopy

msg_tokenizer_ = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
msg_tokenizer_ = add_special_tokens(msg_tokenizer_, None)
diff_tokenizer_ = deepcopy(msg_tokenizer_)

# Encoder Input Processing

Here we assume input to the encoder is the git diff, `diff_input_ids` attribute of `SingleExample`. It can also be history of all git diffs, but we don't use it here.

In [6]:
def process_encoder_inputs(
    input_ids_batch: list[list[int]],
    encoder_context_max_len: int,
    bos_token_id: int,
    eos_token_id: int,
    pad_token_id: int,
):
    """
    This helper method processes either diffs or messages as encoder input.

    It truncates the inputs to the maximum allowed length.

    It also adds all required special tokens: format is [BOS] input [EOS].

    Finally, it is responsible for padding to maximum length in batch and conversion to torch.Tensor.

    Args:
        input_ids_batch: A list of tokenized examples from the current batch.
        encoder_context_max_len: The maximum length of the encoder context.
        bos_token_id: The value of the beginning of sequence (BOS) token.
        eos_token_id: The value of the end of sequence (EOS) token.
        pad_token_id: The value of the padding token (PAD) token.

    Returns:
        input_ids for encoder, attention_mask for encoder
    """

    # add BOS and EOS tokens to each example whilst making sure max length of resulting token list is encoder_context_max_len
    input_ids_batch = [
        [bos_token_id] + example[: encoder_context_max_len - 2] + [eos_token_id]
        for example in input_ids_batch
    ]
    inputs_tensors = [torch.tensor(ids, dtype=torch.int64) for ids in input_ids_batch]

    # pad tensors to max length in batch
    inputs_max_len = max(len(tensor) for tensor in input_ids_batch)
    inputs_tensors = [
        _pad_tensor(
            tensor,
            pad_len=inputs_max_len - tensor.numel(),
            value=pad_token_id,
            left=False,
        )
        for tensor in inputs_tensors
    ]

    masks_tensors = [torch.ones_like(ids) for ids in inputs_tensors]
    masks_tensors = [
        _pad_tensor(
            tensor,
            pad_len=inputs_max_len - tensor.numel(),
            value=0,
            left=False,
        )
        for tensor in masks_tensors
    ]
    return torch.stack(inputs_tensors), torch.stack(masks_tensors)


def _pad_tensor(
    input_tensor: torch.Tensor, pad_len: int, value: int, left: bool
) -> torch.Tensor:
    return torch.nn.functional.pad(
        input_tensor,
        pad=[pad_len, 0] if left else [0, pad_len],
        mode="constant",
        value=value,
    )

Let's try it out with a batch size of 2

In [7]:
examples_ = [data[0], data[1]]
git_diff_inputs_ = [example.diff_input_ids for example in examples_]
encoder_input_ids_, encoder_attention_mask_ = process_encoder_inputs(
    input_ids_batch=git_diff_inputs_,
    encoder_context_max_len=512,  # this is a hyperparameter
    bos_token_id=diff_tokenizer_.bos_token_id,
    eos_token_id=diff_tokenizer_.eos_token_id,
    pad_token_id=diff_tokenizer_.pad_token_id,
)
encoder_input_ids_.shape, encoder_attention_mask_.shape

(torch.Size([2, 512]), torch.Size([2, 512]))

That's it. The output data forms input to our encoder.

# Decoder Input

In [8]:
from typing import Literal, Optional


def _process_decoder_input(
    examples: list[SingleExample],
    msg_bos_token_id: int,
    msg_eos_token_id: int,
    msg_pad_token_id: int,
    decoder_context_max_len,
    shift_labels: bool,
    decoder_start_token_id: Optional[int] = None,
    # ignore these options
    encoder_input_type: Literal["diff", "history"] = "diff",
    with_history: bool = False,
):
    """
    Prepares decoder input for train/validation:
      * aggregates messages from history when configured accordingly
      * concatenates history with current message
      * constructs labels
      * pads, converts to tensors

    Args:
        examples: A list of inputs for current batch.

    Returns:
        Tuple of three tensors: input ids, attention masks, labels.
    """
    message_inputs: list[list[int]] = [example.msg_input_ids for example in examples]
    history_inputs: list[list[list[int]]] = [
        example.history_input_ids for example in examples
    ]

    all_msg_ids: list[torch.Tensor] = []
    all_msg_masks: list[torch.Tensor] = []
    all_msg_labels: list[torch.Tensor] = []

    for message_ids, history_ids in zip(message_inputs, history_inputs):
        message_ids = message_ids[: decoder_context_max_len - 2]

        cur_history_ids = []
        cur_history_labels = []

        # if encoder_input_type != "history" and with_history:
        #     cur_history_ids = _get_history(
        #         cur_len=len(message_ids) + 2,
        #         history_ids=history_ids,
        #     )
        #     cur_history_labels = [
        #         [-100 for _ in message] for message in cur_history_ids
        #     ]

        cur_ids = (
            [[msg_bos_token_id]]
            + cur_history_ids
            + [message_ids]
            + [[msg_eos_token_id]]
        )
        cur_labels = (
            [[msg_bos_token_id]]
            + cur_history_labels
            + [message_ids]
            + [[msg_eos_token_id]]
        )

        if shift_labels:
            cur_ids, cur_labels = _shift_for_encoder_decoder(
                cur_ids,
                cur_labels,
                msg_bos_token_id=msg_bos_token_id,
                decoder_start_token_id=decoder_start_token_id,
            )

        cur_ids_tensor = torch.tensor(
            [ex for sublist in cur_ids for ex in sublist], dtype=torch.int64
        )
        cur_labels_tensor = torch.tensor(
            [ex for sublist in cur_labels for ex in sublist], dtype=torch.int64
        )
        cur_mask_tensor = torch.ones_like(cur_ids_tensor)

        all_msg_ids.append(cur_ids_tensor)
        all_msg_masks.append(cur_mask_tensor)
        all_msg_labels.append(cur_labels_tensor)

    msg_max_len = max(len(tensor) for tensor in all_msg_ids)
    all_msg_ids = [
        _pad_tensor(
            tensor,
            pad_len=msg_max_len - tensor.numel(),
            value=msg_pad_token_id,
            left=False,
        )
        for tensor in all_msg_ids
    ]
    all_msg_masks = [
        _pad_tensor(
            tensor,
            pad_len=msg_max_len - tensor.numel(),
            value=0,
            left=False,
        )
        for tensor in all_msg_masks
    ]
    all_msg_labels = [
        _pad_tensor(
            tensor,
            pad_len=msg_max_len - tensor.numel(),
            value=-100,
            left=False,
        )
        for tensor in all_msg_labels
    ]

    return (
        torch.stack(all_msg_ids),
        torch.stack(all_msg_masks),
        torch.stack(all_msg_labels),
    )


def _shift_for_encoder_decoder(
    ids: list[list[int]],
    labels: list[list[int]],
    msg_bos_token_id: int,
    decoder_start_token_id: Optional[int] = None,
) -> tuple[list[list[int]], list[list[int]]]:
    """This method mimics transformers logic of ids and labels for EncoderDecoderModel
    (or T5ForConditionalGeneration).

    Starting from transformers v4.12, loss is now calculated in EncoderDecoderModel, not in decoder class.
    Also, decoder input ids are created automatically based on labels: labels are shifted and -100 is replaced
    with pad token. In our case, history ids are masked -100 in labels, but they are still
    meaningful ids. Therefore, we can't use the default approach.
    """
    if decoder_start_token_id is None:
        ids = [[msg_bos_token_id]] + ids[:-1]
    else:
        ids = [[decoder_start_token_id]] + ids[:-1]
    return ids, labels

Trying it out

In [9]:
decoder_input_ids_, decoder_attention_mask_, labels_ = _process_decoder_input(
    examples=examples_,
    msg_bos_token_id=msg_tokenizer_.bos_token_id,
    msg_eos_token_id=msg_tokenizer_.eos_token_id,
    msg_pad_token_id=msg_tokenizer_.pad_token_id,
    decoder_context_max_len=512,  # this is a hyperparam for the model
    shift_labels=True,
)
decoder_input_ids_.shape, decoder_attention_mask_.shape, labels_.shape

(torch.Size([2, 26]), torch.Size([2, 26]), torch.Size([2, 26]))

# Model Testing

In [10]:
from transformers import AutoModelForSeq2SeqLM


model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

In [11]:
outputs = model(
    input_ids=encoder_input_ids_,
    attention_mask=encoder_attention_mask_,
    decoder_input_ids=decoder_input_ids_,
    decoder_attention_mask=decoder_attention_mask_,
    labels=labels_,
)
[attr for attr in dir(outputs) if not attr.startswith("_")]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


['clear',
 'copy',
 'cross_attentions',
 'decoder_attentions',
 'decoder_hidden_states',
 'encoder_attentions',
 'encoder_hidden_states',
 'encoder_last_hidden_state',
 'fromkeys',
 'get',
 'items',
 'keys',
 'logits',
 'loss',
 'move_to_end',
 'past_key_values',
 'pop',
 'popitem',
 'setdefault',
 'to_tuple',
 'update',
 'values']

In [12]:
outputs.loss

tensor(13.9106, grad_fn=<NllLossBackward0>)

In [13]:
# let's overfit
from tqdm import tqdm

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6, weight_decay=0.1)

for i in range(100):
    optimizer.zero_grad()
    outputs = model(
        input_ids=encoder_input_ids_,
        attention_mask=encoder_attention_mask_,
        decoder_input_ids=decoder_input_ids_,
        decoder_attention_mask=decoder_attention_mask_,
        labels=labels_,
    )

    loss = outputs.loss
    loss.backward()
    optimizer.step()
    print(f"Epoch: {i:03d} Loss:{loss:.4f}")

Epoch: 000 Loss:14.2112
Epoch: 001 Loss:13.3545
Epoch: 002 Loss:14.5035
Epoch: 003 Loss:13.6910
Epoch: 004 Loss:13.5256
Epoch: 005 Loss:13.7238
Epoch: 006 Loss:14.8889
Epoch: 007 Loss:12.7104
Epoch: 008 Loss:13.9504
Epoch: 009 Loss:13.7313
Epoch: 010 Loss:13.5998
Epoch: 011 Loss:12.8251
Epoch: 012 Loss:15.5966
Epoch: 013 Loss:13.9444
Epoch: 014 Loss:12.7470
Epoch: 015 Loss:13.0972
Epoch: 016 Loss:13.1960
Epoch: 017 Loss:13.6399
Epoch: 018 Loss:13.2541
Epoch: 019 Loss:14.0269
Epoch: 020 Loss:14.6821
Epoch: 021 Loss:13.2320
Epoch: 022 Loss:13.8468
Epoch: 023 Loss:13.9215
Epoch: 024 Loss:14.1133
Epoch: 025 Loss:13.8879
Epoch: 026 Loss:14.0524
Epoch: 027 Loss:13.4823
Epoch: 028 Loss:13.3946
Epoch: 029 Loss:13.6241
Epoch: 030 Loss:12.7660
Epoch: 031 Loss:13.5473
Epoch: 032 Loss:13.5531
Epoch: 033 Loss:15.0437
Epoch: 034 Loss:14.7943
Epoch: 035 Loss:13.0712
Epoch: 036 Loss:12.6115
Epoch: 037 Loss:13.6340
Epoch: 038 Loss:14.8388
Epoch: 039 Loss:12.6437
Epoch: 040 Loss:12.6444
Epoch: 041 Loss:

I was expecting the model loss to reduce smoothly but that didn't happen. Hmmmmm...