# Bart large configuration

In [1]:
### MODULES ###

import sys,os
import tqdm
import csv
from datetime import datetime 
import numpy as np
import pandas as pd
import json

from typing import Optional


import math

import matplotlib.pyplot as plt


from datasets import load_dataset, Dataset

import torch
from torch import cuda
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import torch.nn.functional as F


# Load the ROUGE metric
import evaluate

from transformers import AutoTokenizer, BartForConditionalGeneration

In [2]:
import math
import logging
from typing import List, Optional, Tuple, Union

In [3]:

NUM_PROCS = os.cpu_count() 

print("NUM_PROCS = " ,NUM_PROCS)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)



NUM_PROCS =  12
cuda


In [4]:

SEED = 42
NUM_LOADER = 4 #config['config_machine']["NUM_LOADER"] #depends of the number of thread 


# Set random seeds and deterministic pytorch for reproducibility
torch.manual_seed(SEED) # pytorch random seed
np.random.seed(SEED) # numpy random seed
torch.backends.cudnn.deterministic = True

# Load dataset CNN daily

In [5]:
# Load CNN/DailyMail dataset
dataset = load_dataset("cnn_dailymail", "3.0.0")

## Comment this part for the real training time :

percentage = 0.05

for split in dataset: 
    dataset[split] = dataset[split].shuffle(seed=SEED).select(range(int(len(dataset[split]) * percentage)))

# Check the dataset structure
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 14355
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 668
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 574
    })
})


# load the model and tokenizer 

In [6]:
### Load model ###
MODEL_HUB = 'facebook/bart-large'
# Load Model and Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_HUB, clean_up_tokenization_spaces=True)
# model = BartForConditionalGeneration.from_pretrained(MODEL_HUB, forced_bos_token_id=0)
print(tokenizer.model_max_length)
print(type(tokenizer))
# print(type(model))
print(tokenizer)

1024
<class 'transformers.models.bart.tokenization_bart_fast.BartTokenizerFast'>
BartTokenizerFast(name_or_path='facebook/bart-large', vocab_size=50265, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	50264: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True, special=True),
}


In [7]:
# def len_distrib(batch):

#     len_articles = []
#     len_highlights = []
    
#     for article, highlight in zip(batch["article"], batch["highlights"]):
#         len_articles.append(len(tokenizer(article, truncation=False)["input_ids"]))
#         len_highlights.append(len(tokenizer(highlight, truncation=False)["input_ids"]))


#     source = tokenizer(batch["article"],truncation=True, max_length=tokenizer.model_max_length,padding='max_length')
#     resume = tokenizer(batch["highlights"],truncation=True, max_length=tokenizer.model_max_length,padding='max_length')

#     return {
#         'input_ids': source['input_ids'], 
#         'input_mask': source['attention_mask'],
#         'input_len': len_articles,
#         'target_ids': resume['input_ids'], 
#         'target_mask': resume['attention_mask'],
#         'target_len': len_highlights
#         }

def len_distrib(batch):

    len_articles = []
    len_highlights = []
    
    for article, highlight in zip(batch["article"], batch["highlights"]):
        len_articles.append(len(tokenizer(article, truncation=False)["input_ids"])-1) #Add -1 to skip the <bos> token 
        len_highlights.append(len(tokenizer(highlight, truncation=False)["input_ids"])-1) #Add -1 to skip the <bos> token 


    source = tokenizer(batch["article"],truncation=True, max_length=tokenizer.model_max_length)
    resume = tokenizer(batch["highlights"],truncation=True, max_length=tokenizer.model_max_length)

    return {
        'input_ids': source['input_ids'], 
        'input_mask': source['attention_mask'],
        'input_len': len_articles,
        'target_ids': resume['input_ids'], 
        'target_mask': resume['attention_mask'],
        'target_len': len_highlights
        }


In [8]:
dataset = dataset.map(len_distrib,num_proc=NUM_PROCS,batched=True,batch_size=64)


In [9]:

# Define the custom collate function
def collate_fn(batch):
    """
    Custom collate function that add padding for each batch.
    """

    # Pad the tokenized content
    input_ids = [torch.tensor(item['input_ids'], dtype=torch.long) for item in batch]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    
    attention_mask = [torch.tensor(item['input_mask'], dtype=torch.long) for item in batch]
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)

    decoder_input_ids  = [torch.tensor(item['target_ids'][:-1], dtype=torch.long) for item in batch]
    decoder_input_ids = pad_sequence(decoder_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)     
    
    decoder_attention_mask = [torch.tensor(item['target_mask'][:-1], dtype=torch.long) for item in batch]
    decoder_attention_mask = pad_sequence(decoder_attention_mask, batch_first=True, padding_value=0)
    
    input_len = torch.tensor([item['input_len'] for item in batch], dtype=torch.long)

    target_len = torch.tensor([item['target_len'] for item in batch], dtype=torch.long)

    # Labels should be the same as decoder_input_ids (BART-style training)
    labels = [torch.tensor(item['target_ids'][1:], dtype=torch.long) for item in batch]
    labels = pad_sequence(labels, batch_first=True, padding_value=tokenizer.pad_token_id)  
    labels[labels == tokenizer.pad_token_id] = -100  # Ignore padding in loss computation

    return {
        'input_ids':input_ids,
        'attention_mask':attention_mask,
        'decoder_input_ids':decoder_input_ids,
        'decoder_attention_mask':decoder_attention_mask,
        'labels': labels,
        'input_len': input_len,
        'target_len': target_len
    }

In [10]:

train_params = {
    'batch_size': 4,
    'shuffle': True,
    'collate_fn':collate_fn,
    'num_workers': NUM_LOADER,
    'pin_memory': True  #  Enables faster GPU transfers
    }

eval_params = {
    'batch_size': 4,
    'shuffle': False,
    'collate_fn':collate_fn,
    'num_workers': NUM_LOADER,
    'pin_memory': True  #  Enables faster GPU transfers
    }


# This will be used down for training and validation stage for the model.
train_loader = DataLoader(dataset["train"], **train_params)
eval_loader = DataLoader(dataset["validation"], **eval_params)

for batch in train_loader:
    print(batch)
    break


{'input_ids': tensor([[    0,  2765,   479,  ...,     1,     1,     1],
        [    0,   250,  3828,  ...,     5,  1151,     2],
        [    0,   250, 17052,  ...,     1,     1,     1],
        [    0,   970,    58,  ...,     1,     1,     1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'decoder_input_ids': tensor([[    0,  9058,  2152,  4935,  1746,     7,    49,  1420,     8,    58,
          1654,     7, 29327,   877,   479, 50118, 25496,   643,  1710,     6,
            65,    19,  1473,  9377,  9308,    11,   255,  1879,  6483,     6,
          2627,   479, 50118, 11329, 12807,   333,   303, 23797,  3675,   571,
           668,  8344, 13374,     8,   342,    24,    11,   760,    50,    10,
           668,   479,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,

# Create reverse embedding for the Bart model 

In [11]:

# Get all parent classes in the MRO (Method Resolution Order)
print(BartForConditionalGeneration.__mro__)
tokenizer.bos_token_id

(<class 'transformers.models.bart.modeling_bart.BartForConditionalGeneration'>, <class 'transformers.models.bart.modeling_bart.BartPreTrainedModel'>, <class 'transformers.modeling_utils.PreTrainedModel'>, <class 'torch.nn.modules.module.Module'>, <class 'transformers.modeling_utils.ModuleUtilsMixin'>, <class 'transformers.generation.utils.GenerationMixin'>, <class 'transformers.utils.hub.PushToHubMixin'>, <class 'transformers.integrations.peft.PeftAdapterMixin'>, <class 'object'>)


0

In [12]:
print(tokenizer.pad_token_id)

mask = ~torch.isin(batch["decoder_input_ids"],torch.tensor([tokenizer.pad_token_id,tokenizer.eos_token_id])) # mask with 0 where a pad_id is present
print(mask.shape)

reversed_position_input  = torch.ones(mask.shape) * mask # [1,1,1,0,0] 

reversed_position_input = torch.flip(torch.flip(reversed_position_input , dims=(1,)).cumsum(dim=1), dims=(1,))  

print(reversed_position_input.shape)

normal_round = torch.randn(batch["decoder_input_ids"].shape) * mask

reversed_position_input = torch.abs(torch.round(reversed_position_input  + normal_round)).to(torch.long) #add a gausian noise and converte to long

print(reversed_position_input.shape)

# input_decoder_position_embedding = model.model.decoder.embed_positions(reversed_position_input)

# input_decoder_position_embedding

1
torch.Size([4, 99])
torch.Size([4, 99])
torch.Size([4, 99])


## Code pour les index de positions inversées 

In [None]:

relu = nn.ReLU()

def _reverse_position_embedding(input_ids:torch.Tensor,
                                target_len:Optional[torch.Tensor]=None)->torch.Tensor:
    
    mask = ~torch.isin(input_ids,torch.tensor([tokenizer.pad_token_id])) # mask with 0 where a pad_id is present

    reversed_position_input  = torch.ones(mask.shape) * mask # Put 1 where there are token index and 0 where there are pad index[1,1,1,0,0] 
    
    if target_len is None:
        reversed_position_input = torch.flip(torch.flip(reversed_position_input , dims=(1,)).cumsum(dim=1), dims=(1,)) 
        #[[ 54,  53, ...,0,  0],[100, 101,...,1]]
        #print(reversed_position_input.shape)
    else:
        for k in range(input_ids.size(-1)):
            reversed_position_input[:,k] = relu(target_len -k)

    #Add a gaussian noise
    normal_round = torch.randn(reversed_position_input.shape) * mask

    reversed_position_input = torch.abs(torch.round(reversed_position_input  + normal_round)).to(torch.long) #add a gausian noise and converte to long


    return reversed_position_input

reversed_position_input = _reverse_position_embedding(batch["decoder_input_ids"])
print(reversed_position_input.shape)

torch.Size([4, 111])


## Code pour les embeddings de position sinus/cosinus

In [73]:
# # dir(model.model.decoder.layernorm_embedding)

d_model = model.config.d_model
max_len = tokenizer.model_max_length

pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp((-2*torch.arange(0, d_model)//2)/d_model * math.log(10000.0) )

print(div_term.shape)

print((position*div_term).shape)


pe[:, 0::2] = torch.sin(position*div_term[0::2] )
pe[:, 1::2] = torch.cos(position*div_term[1::2])

embed_reverse_positions = nn.Embedding(num_embeddings=tokenizer.model_max_length,
                                              embedding_dim=d_model,
                                              padding_idx=tokenizer.pad_token_id,
                                              _weight=pe,
                                              _freeze=True)

print(embed_reverse_positions(reversed_position_input).shape)

torch.Size([1024])
torch.Size([1024, 1024])
torch.Size([4, 111, 1024])


In [74]:
print(model.model.decoder)
token_embeddings = model.model.decoder.embed_tokens(batch["decoder_input_ids"]) 
print(token_embeddings.shape)
position_embeddings = model.model.decoder.embed_positions(batch["decoder_input_ids"]) 
print(position_embeddings.shape)
repilot_embeddings = embed_reverse_positions(_reverse_position_embedding(batch["decoder_input_ids"]))
print(repilot_embeddings.shape)

BartDecoder(
  (embed_tokens): BartScaledWordEmbedding(50265, 1024, padding_idx=1)
  (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
  (layers): ModuleList(
    (0-11): 12 x BartDecoderLayer(
      (self_attn): BartSdpaAttention(
        (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (activation_fn): GELUActivation()
      (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): BartSdpaAttention(
        (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=

In [75]:
decoder_inputs_embeds = token_embeddings + position_embeddings + repilot_embeddings
output = model(input_ids=batch["input_ids"], 
               attention_mask=batch["attention_mask"],
               decoder_attention_mask=batch["decoder_attention_mask"], 
               decoder_inputs_embeds=decoder_inputs_embeds,
               labels=batch["labels"])


## Code pour la classe RepilotBartForConditionalGeneration

In [12]:
"""PyTorch REPILOT_BART model."""
import copy
import math
import logging
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers import BartForConditionalGeneration, AutoTokenizer
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    replace_return_docstrings,
)
from transformers import BartConfig, AutoTokenizer

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward


# Setup logging
logging.basicConfig(level=logging.DEBUG, 
                    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                    handlers=[
                        logging.StreamHandler()  #log to console
                    ])

logger = logging.getLogger(__name__)

_CHECKPOINT_FOR_DOC = "facebook/bart-large"
_CONFIG_FOR_DOC = "BartConfig"

logger

<Logger __main__ (DEBUG)>

In [13]:
class RepilotBartForConditionalGeneration(BartForConditionalGeneration):
    
    def __init__(self, config):
        super().__init__(config)
        # Define the reversed position embedding module
        self.embed_reverse_positions = self._create_sinusoidal_position_embedding()
        logger.info("RepilotBartForConditionalGeneration initialized successfully.")
        

    def _create_sinusoidal_position_embedding(self)->torch.nn.Embedding:
        """Creates sinusoidal reversed position embeddings."""
        d_model = self.config.d_model
        max_len = self.config.max_position_embeddings
        padding_idx = self.config.pad_token_id

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(-2 * ((torch.arange(0, d_model) // 2) / d_model) * math.log(10000.0))

        pe[:, 0::2] = torch.sin(position * div_term[0::2])
        pe[:, 1::2] = torch.cos(position * div_term[1::2])

        embedding = nn.Embedding(num_embeddings=max_len,
                                embedding_dim=d_model,
                                padding_idx=padding_idx,
                                _weight=pe,
                                _freeze=True)
        
        logger.info("Iitialized sinusoidal position embedding successfully.")
        return embedding

    def _get_reverse_position_decoder_ids(self, decoder_input_ids:torch.LongTensor, target_len:Optional[torch.Tensor]=None, gaussian_noise=True)->torch.LongTensor:
        """Computes reversed position indices for the decoder inputs."""
        
        mask = ~torch.isin(decoder_input_ids,torch.tensor([self.config.pad_token_id]))

        reversed_position_input  = torch.ones(mask.shape) * mask 
        
        if target_len is None:
            reversed_position_input = torch.flip(torch.flip(reversed_position_input , dims=(1,)).cumsum(dim=1), dims=(1,))
            logger.debug(f"Shape of reversed_position_input {reversed_position_input.shape}")

        else:
            k = torch.arange(decoder_input_ids.size(-1), device=target_len.device)  # Create a tensor [0, 1, 2, ..., seq_len-1]
            reversed_position_input = F.relu(target_len.unsqueeze(1) - k)  # Broadcast subtraction over all positions
            logger.debug(f"reversed_position_input shape {reversed_position_input.shape}")
            print(reversed_position_input)

        if gaussian_noise:
            normal_round = torch.randn(reversed_position_input.shape) * mask
        else:
            normal_round = 0

        reversed_position_input = torch.abs(torch.round(reversed_position_input  + normal_round)).to(torch.long)
        logger.debug(f"reversed_position_input : {reversed_position_input}")
        return reversed_position_input
    
    def _get_position_decoder_ids(self, decoder_input_ids:torch.LongTensor)->torch.LongTensor:
        """Computes position indices for the decoder inputs."""
        
        mask = ~torch.isin(decoder_input_ids,torch.tensor([self.config.pad_token_id]))

        position_decoder_input_ids  = torch.ones(mask.shape).cumsum(dim=1) * mask 
        
        return position_decoder_input_ids.to(torch.long)
    

    
    def forward(
            self, 
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            decoder_input_ids: Optional[torch.LongTensor] = None,
            decoder_attention_mask: Optional[torch.LongTensor] = None,
            head_mask: Optional[torch.Tensor] = None,
            decoder_head_mask: Optional[torch.Tensor] = None,
            cross_attn_head_mask: Optional[torch.Tensor] = None,
            encoder_outputs: Optional[List[torch.FloatTensor]] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            target_len:Optional[torch.Tensor]=None,
        )-> Union[Tuple, Seq2SeqLMOutput]:
        """Overrides forward to inject reversed position embeddings into the decoder."""


        if input_ids is not None:
            logger.info(f"Forward called with input_ids shape: {input_ids.shape}")
            print(input_ids)
        else:
            logger.warning("Forward called with input_ids=None")

        if decoder_input_ids is not None:

            # Get reversed position indices
            reversed_position_input_ids = self._get_reverse_position_decoder_ids(decoder_input_ids, target_len)
            logger.info(f"reversed_position_decoder_input_ids shape: {reversed_position_input_ids.shape}")
            print(reversed_position_input_ids)
            # Compute reversed position embeddings
            reversed_position_embeddings = self.embed_reverse_positions(reversed_position_input_ids)
            logger.info(f"reversed_position_embeddings shape: {reversed_position_embeddings.shape}")


            # Get position indices
            position_input_ids = self._get_position_decoder_ids(decoder_input_ids)
            logger.info(f"position_decoder_input_ids shape: {position_input_ids.shape}")
            print(position_input_ids)
            
            # Compute position embeddings
            position_embeddings = self.model.decoder.embed_positions(position_input_ids)
            logger.info(f"position_embeddings shape: {position_embeddings.shape}")

            # Compute standard token embeddings
            decoder_inputs_embeds = self.model.decoder.embed_tokens(decoder_input_ids)  + position_embeddings + reversed_position_embeddings
            logger.debug(f"decoder_inputs_embeds shape: {decoder_inputs_embeds.shape}")

            decoder_input_ids = None
        else:
            logger.warning("Forward called with decoder_input_ids=None")



        # Call the original BART forward function with modified decoder inputs
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
        )

        logger.info("Forward pass completed.")

        return outputs
    
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_FOR_DOC, clean_up_tokenization_spaces=True)
model = RepilotBartForConditionalGeneration.from_pretrained(_CHECKPOINT_FOR_DOC)

2025-03-21 11:23:02,406 - urllib3.connectionpool - DEBUG - Resetting dropped connection: huggingface.co
2025-03-21 11:23:02,614 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /facebook/bart-large/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
2025-03-21 11:23:02,855 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /facebook/bart-large/resolve/main/config.json HTTP/1.1" 200 0
2025-03-21 11:23:02,973 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /facebook/bart-large/resolve/main/config.json HTTP/1.1" 200 0
2025-03-21 11:23:03,092 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /facebook/bart-large/resolve/main/model.safetensors HTTP/1.1" 404 0
2025-03-21 11:23:03,586 - __main__ - INFO - Iitialized sinusoidal position embedding successfully.
2025-03-21 11:23:03,586 - __main__ - INFO - RepilotBartForConditionalGeneration initialized successfully.
Some weights of RepilotBartForConditionalGeneration were no

In [22]:
model.config

BartConfig {
  "_name_or_path": "facebook/bart-large",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "mode

In [None]:

input_ids


tensor([[ 2387,   964,    32, 50264,    53,    51,  3529,   350,   171, 33237,
             4,     2]])

## Verification partie forward

In [27]:
output = model(**batch)
print(output.logits.shape)

2025-03-19 15:17:42,095 - __main__ - INFO - Forward called with input_ids shape: torch.Size([2, 20])
2025-03-19 15:17:42,097 - __main__ - DEBUG - reversed_position_input shape torch.Size([2, 11])
2025-03-19 15:17:42,098 - __main__ - DEBUG - reversed_position_input : tensor([[ 7,  7,  6,  3,  3,  3,  2,  0,  0,  0,  0],
        [10, 10,  9,  7,  6,  6,  3,  4,  2,  2,  1]])
2025-03-19 15:17:42,099 - __main__ - INFO - reversed_position_decoder_input_ids shape: torch.Size([2, 11])
2025-03-19 15:17:42,100 - __main__ - INFO - reversed_position_embeddings shape: torch.Size([2, 11, 1024])
2025-03-19 15:17:42,100 - __main__ - INFO - position_decoder_input_ids shape: torch.Size([2, 11])
2025-03-19 15:17:42,101 - __main__ - INFO - position_embeddings shape: torch.Size([2, 11, 1024])
2025-03-19 15:17:42,102 - __main__ - DEBUG - decoder_inputs_embeds shape: torch.Size([2, 11, 1024])


tensor([[    0, 45089,    12, 24641,   154, 30634,    15,     5,  3480,    73,
         26339, 23969, 41616,  6890,   484,  2402,     2,     1,     1,     1],
        [    0,   100,   236,     7, 33942,     5,  8746,   594, 37215,     9,
             5,  1421,  8811,    15,     5, 41616,  1230,  6380,     4,     2]])
tensor([[ 7,  6,  5,  4,  3,  2,  1,  0,  0,  0,  0],
        [11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1]])
tensor([[ 7,  7,  6,  3,  3,  3,  2,  0,  0,  0,  0],
        [10, 10,  9,  7,  6,  6,  3,  4,  2,  2,  1]])
tensor([[ 1,  2,  3,  4,  5,  6,  7,  0,  0,  0,  0],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]])


2025-03-19 15:17:42,485 - __main__ - INFO - Forward pass completed.


torch.Size([2, 11, 50265])


In [28]:
TXT = "My friends are <mask> but they eat too many carbs."
input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt')['input_ids']
output =  model(input_ids=input_ids)#, 
logits = output[0]
print(logits.shape)
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
print(masked_index)
probs = logits[0, masked_index].softmax(dim=0)
print(probs.shape)
values, predictions = probs.topk(5)
print(values,predictions)
tokenizer.decode(predictions).split()

2025-03-19 15:18:14,872 - __main__ - INFO - Forward called with input_ids shape: torch.Size([1, 13])


tensor([[    0,  2387,   964,    32, 50264,    53,    51,  3529,   350,   171,
         33237,     4,     2]])


2025-03-19 15:18:15,134 - __main__ - INFO - Forward pass completed.


torch.Size([1, 13, 50265])
4
torch.Size([50265])
tensor([0.1021, 0.0865, 0.0459, 0.0372, 0.0362], grad_fn=<TopkBackward0>) tensor([205, 372,  70, 269, 182])


['good', 'great', 'all', 'really', 'very']

In [29]:
logits[0, masked_index].shape

torch.Size([50265])

## Verification de la partie generate

In [32]:
model.config

BartConfig {
  "_name_or_path": "facebook/bart-large",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "mode

In [102]:
outputs = model.generate(input_ids=batch["input_ids"],target_len=torch.tensor([5,3]), max_length=12)  # generate sequences
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))


2025-03-18 17:10:16,655 - __main__ - DEBUG - Shape of reversed_position_input torch.Size([8, 1])
2025-03-18 17:10:16,656 - __main__ - INFO - reversed_position_decoder_input_ids shape: torch.Size([8, 1])
2025-03-18 17:10:16,657 - __main__ - INFO - reversed_position_embeddings shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:16,658 - __main__ - INFO - position_decoder_input_ids shape: torch.Size([8, 1])
2025-03-18 17:10:16,658 - __main__ - INFO - position_embeddings shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:16,659 - __main__ - DEBUG - decoder_inputs_embeds shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:16,742 - __main__ - INFO - Forward pass completed.
2025-03-18 17:10:16,748 - __main__ - DEBUG - Shape of reversed_position_input torch.Size([8, 1])
2025-03-18 17:10:16,749 - __main__ - INFO - reversed_position_decoder_input_ids shape: torch.Size([8, 1])
2025-03-18 17:10:16,750 - __main__ - INFO - reversed_position_embeddings shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:16,750 

tensor([[2],
        [1],
        [2],
        [0],
        [1],
        [1],
        [0],
        [1]])
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])
tensor([[0],
        [0],
        [2],
        [2],
        [1],
        [1],
        [1],
        [0]])
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])
tensor([[2],
        [0],
        [3],
        [1],
        [2],
        [3],
        [0],
        [0]])
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])


2025-03-18 17:10:16,883 - __main__ - INFO - Forward pass completed.
2025-03-18 17:10:16,890 - __main__ - DEBUG - Shape of reversed_position_input torch.Size([8, 1])
2025-03-18 17:10:16,891 - __main__ - INFO - reversed_position_decoder_input_ids shape: torch.Size([8, 1])
2025-03-18 17:10:16,892 - __main__ - INFO - reversed_position_embeddings shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:16,893 - __main__ - INFO - position_decoder_input_ids shape: torch.Size([8, 1])
2025-03-18 17:10:16,896 - __main__ - INFO - position_embeddings shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:16,897 - __main__ - DEBUG - decoder_inputs_embeds shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:16,963 - __main__ - INFO - Forward pass completed.
2025-03-18 17:10:16,968 - __main__ - DEBUG - Shape of reversed_position_input torch.Size([8, 1])
2025-03-18 17:10:16,968 - __main__ - INFO - reversed_position_decoder_input_ids shape: torch.Size([8, 1])
2025-03-18 17:10:16,969 - __main__ - INFO - reversed_position

tensor([[0],
        [1],
        [1],
        [1],
        [1],
        [0],
        [2],
        [2]])
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])
tensor([[0],
        [0],
        [1],
        [1],
        [1],
        [2],
        [1],
        [1]])
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])
tensor([[1],
        [2],
        [2],
        [3],
        [1],
        [1],
        [0],
        [0]])
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])


2025-03-18 17:10:17,092 - __main__ - INFO - reversed_position_decoder_input_ids shape: torch.Size([8, 1])
2025-03-18 17:10:17,094 - __main__ - INFO - reversed_position_embeddings shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:17,094 - __main__ - INFO - position_decoder_input_ids shape: torch.Size([8, 1])
2025-03-18 17:10:17,096 - __main__ - INFO - position_embeddings shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:17,096 - __main__ - DEBUG - decoder_inputs_embeds shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:17,153 - __main__ - INFO - Forward pass completed.
2025-03-18 17:10:17,159 - __main__ - DEBUG - Shape of reversed_position_input torch.Size([8, 1])
2025-03-18 17:10:17,159 - __main__ - INFO - reversed_position_decoder_input_ids shape: torch.Size([8, 1])
2025-03-18 17:10:17,160 - __main__ - INFO - reversed_position_embeddings shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:17,161 - __main__ - INFO - position_decoder_input_ids shape: torch.Size([8, 1])
2025-03-18 17:10:17,162 

tensor([[1],
        [3],
        [2],
        [2],
        [1],
        [2],
        [3],
        [2]])
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])
tensor([[3],
        [1],
        [1],
        [1],
        [1],
        [0],
        [1],
        [1]])
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])
tensor([[1],
        [1],
        [1],
        [1],
        [0],
        [2],
        [2],
        [3]])
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])


2025-03-18 17:10:17,296 - __main__ - INFO - Forward pass completed.
2025-03-18 17:10:17,301 - __main__ - DEBUG - Shape of reversed_position_input torch.Size([8, 1])
2025-03-18 17:10:17,302 - __main__ - INFO - reversed_position_decoder_input_ids shape: torch.Size([8, 1])
2025-03-18 17:10:17,303 - __main__ - INFO - reversed_position_embeddings shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:17,304 - __main__ - INFO - position_decoder_input_ids shape: torch.Size([8, 1])
2025-03-18 17:10:17,305 - __main__ - INFO - position_embeddings shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:17,305 - __main__ - DEBUG - decoder_inputs_embeds shape: torch.Size([8, 1, 1024])
2025-03-18 17:10:17,371 - __main__ - INFO - Forward pass completed.
2025-03-18 17:10:17,377 - __main__ - DEBUG - Shape of reversed_position_input torch.Size([8, 1])
2025-03-18 17:10:17,377 - __main__ - INFO - reversed_position_decoder_input_ids shape: torch.Size([8, 1])
2025-03-18 17:10:17,378 - __main__ - INFO - reversed_position

tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [1]])
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])
tensor([[1],
        [2],
        [1],
        [3],
        [4],
        [3],
        [0],
        [1]])
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])
Generated: FineFineFinefineFineFineTheFineFine


In [19]:
input_ids.shape

torch.Size([1, 6])

## REpilotBartForGeneration test 2

In [26]:
from transformers import (
    GenerationMixin,
    BartConfig,
    BartModel,
    BartPreTrainedModel,                         
)   

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids

In [36]:

class RepilotBartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
    base_model_prefix = "model"
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
    _keys_to_ignore_on_load_missing = ["final_logits_bias"]

    def __init__(self, config: BartConfig):
        super().__init__(config)
        self.model = BartModel(config)
        self.embed_reverse_positions = self._create_sinusoidal_position_embedding()
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
    
    def _create_sinusoidal_position_embedding(self)->torch.nn.Embedding:
        """Creates sinusoidal reversed position embeddings."""
        d_model = self.config.d_model
        max_len = self.config.max_position_embeddings
        padding_idx = self.config.pad_token_id

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(-2 * (torch.arange(0, d_model) // 2) / d_model * math.log(10000.0))

        pe[:, 0::2] = torch.sin(position * div_term[0::2])
        pe[:, 1::2] = torch.cos(position * div_term[1::2])

        embedding = nn.Embedding(num_embeddings=max_len,
                                embedding_dim=d_model,
                                padding_idx=padding_idx,
                                _weight=pe,
                                _freeze=True)
        logger.info("Iitialized sinusoidal position embedding successfully.")
        return embedding

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings
    
    def _get_reverse_position_decoder_ids(self, decoder_input_ids:torch.LongTensor, target_len:Optional[torch.Tensor]=None, gaussian_noise=True)->torch.LongTensor:
        """Computes reversed position indices for the decoder inputs."""
        
        mask = ~torch.isin(decoder_input_ids,torch.tensor([self.config.pad_token_id]))

        reversed_position_input  = torch.ones(mask.shape) * mask 
        
        if target_len is None:
            reversed_position_input = torch.flip(torch.flip(reversed_position_input , dims=(1,)).cumsum(dim=1), dims=(1,))
            logger.debug(f"Shape of reversed_position_input {reversed_position_input.shape}")

        else:
            for k in range(decoder_input_ids.size(-1)):
                reversed_position_input[:,k] = F.relu(target_len -k)
                logger.debug(f"reversed_position_input[:,{k}] {reversed_position_input[:,k]}")

        if gaussian_noise:
            normal_round = torch.randn(reversed_position_input.shape) * mask
        else:
            normal_round = 0

        return torch.abs(torch.round(reversed_position_input  + normal_round)).to(torch.long)
    
    def _get_position_decoder_ids(self, decoder_input_ids:torch.LongTensor)->torch.LongTensor:
        """Computes position indices for the decoder inputs."""
        
        mask = ~torch.isin(decoder_input_ids,torch.tensor([self.config.pad_token_id]))

        position_decoder_input_ids  = torch.ones(mask.shape) * mask 
        
        return position_decoder_input_ids.cumsum(dim=1).to(torch.long)
    

    # @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    # @add_end_docstrings(BART_GENERATION_EXAMPLE)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        target_len:Optional[torch.Tensor]=None,
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        if decoder_input_ids is not None:

            # Get reversed position indices
            reversed_position_input_ids = self._get_reverse_position_decoder_ids(decoder_input_ids, target_len)
            logger.info(f"reversed_position_input_ids shape: {reversed_position_input_ids.shape}")
            
            # Compute reversed position embeddings
            reversed_position_embeddings = self.embed_reverse_positions(reversed_position_input_ids)
            logger.info(f"reversed_position_embeddings shape: {reversed_position_embeddings.shape}")

            # Get position indices
            position_input_ids = self._get_position_decoder_ids(decoder_input_ids)
            logger.info(f"position_input_ids shape: {position_input_ids.shape}")
            
            # Compute position embeddings
            position_embeddings = self.model.decoder.embed_positions(position_input_ids)
            logger.info(f"position_embeddings shape: {position_embeddings.shape}")

            # Compute standard token embeddings
            decoder_inputs_embeds = self.model.decoder.embed_tokens(decoder_input_ids) + reversed_position_embeddings + position_embeddings
            logger.debug(f"decoder_inputs_embeds shape: {decoder_inputs_embeds.shape}")

            decoder_input_ids = None
   

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        lm_logits = self.lm_head(outputs[0])
        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

        masked_lm_loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        return reordered_past

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_FOR_DOC, clean_up_tokenization_spaces=True)
model = RepilotBartForConditionalGeneration.from_pretrained(_CHECKPOINT_FOR_DOC)

2025-03-18 14:56:14,328 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /facebook/bart-large/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
2025-03-18 14:56:14,559 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /facebook/bart-large/resolve/main/config.json HTTP/1.1" 200 0
2025-03-18 14:56:14,696 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /facebook/bart-large/resolve/main/config.json HTTP/1.1" 200 0
2025-03-18 14:56:14,810 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /facebook/bart-large/resolve/main/model.safetensors HTTP/1.1" 404 0
2025-03-18 14:56:15,222 - __main__ - INFO - Iitialized sinusoidal position embedding successfully.
Some weights of RepilotBartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large and are newly initialized: ['embed_reverse_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for p

In [37]:
TXT = "My friends are <mask> but they eat too many carbs."
input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt')['input_ids']
logits = model(input_ids)[0]
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5)
tokenizer.decode(predictions).split()

['good', 'great', 'all', 'really', 'very']

In [38]:
for batch in train_loader:
    batch.pop('input_len')
    break

output = model(**batch)

2025-03-18 14:56:21,280 - __main__ - DEBUG - reversed_position_input[:,0] tensor([64., 48., 57., 46.])
2025-03-18 14:56:21,281 - __main__ - DEBUG - reversed_position_input[:,1] tensor([63., 47., 56., 45.])
2025-03-18 14:56:21,283 - __main__ - DEBUG - reversed_position_input[:,2] tensor([62., 46., 55., 44.])
2025-03-18 14:56:21,284 - __main__ - DEBUG - reversed_position_input[:,3] tensor([61., 45., 54., 43.])
2025-03-18 14:56:21,285 - __main__ - DEBUG - reversed_position_input[:,4] tensor([60., 44., 53., 42.])
2025-03-18 14:56:21,286 - __main__ - DEBUG - reversed_position_input[:,5] tensor([59., 43., 52., 41.])
2025-03-18 14:56:21,287 - __main__ - DEBUG - reversed_position_input[:,6] tensor([58., 42., 51., 40.])
2025-03-18 14:56:21,288 - __main__ - DEBUG - reversed_position_input[:,7] tensor([57., 41., 50., 39.])
2025-03-18 14:56:21,289 - __main__ - DEBUG - reversed_position_input[:,8] tensor([56., 40., 49., 38.])
2025-03-18 14:56:21,290 - __main__ - DEBUG - reversed_position_input[:,9]

In [42]:
input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2, target_len=torch.tensor([3]))  # generate sequences
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))


2025-03-18 15:01:27,637 - __main__ - DEBUG - Shape of reversed_position_input torch.Size([4, 1])
2025-03-18 15:01:27,637 - __main__ - INFO - reversed_position_input_ids shape: torch.Size([4, 1])
2025-03-18 15:01:27,638 - __main__ - INFO - reversed_position_embeddings shape: torch.Size([4, 1, 1024])
2025-03-18 15:01:27,639 - __main__ - INFO - position_input_ids shape: torch.Size([4, 1])
2025-03-18 15:01:27,639 - __main__ - INFO - position_embeddings shape: torch.Size([4, 1, 1024])
2025-03-18 15:01:27,640 - __main__ - DEBUG - decoder_inputs_embeds shape: torch.Size([4, 1, 1024])
2025-03-18 15:01:27,706 - __main__ - DEBUG - Shape of reversed_position_input torch.Size([4, 1])
2025-03-18 15:01:27,707 - __main__ - INFO - reversed_position_input_ids shape: torch.Size([4, 1])
2025-03-18 15:01:27,708 - __main__ - INFO - reversed_position_embeddings shape: torch.Size([4, 1, 1024])
2025-03-18 15:01:27,709 - __main__ - INFO - position_input_ids shape: torch.Size([4, 1])
2025-03-18 15:01:27,709 - _

Generated: LegalLegalLegallegalLegalLegalMyLegalLegalPoliticalLegalLegal10LegalLegalPersonalLegalLegalSocialLegalLegalILegalLegalCourtLegalLegalSexLegalLegal"LegalLegalMilitaryLegalLegalHappyLegalLegal.LegalLegalReadLegalLegalSexualLegal


In [43]:
generated_ids = model.generate(
              input_ids = batch["input_ids"],
              attention_mask = batch["attention_mask"],
              target_len = batch["target_len"],
              **model.config.task_specific_params["summarization"])

2025-03-18 15:03:17,740 - __main__ - DEBUG - Shape of reversed_position_input torch.Size([16, 1])
2025-03-18 15:03:17,741 - __main__ - INFO - reversed_position_input_ids shape: torch.Size([16, 1])
2025-03-18 15:03:17,741 - __main__ - INFO - reversed_position_embeddings shape: torch.Size([16, 1, 1024])
2025-03-18 15:03:17,742 - __main__ - INFO - position_input_ids shape: torch.Size([16, 1])
2025-03-18 15:03:17,743 - __main__ - INFO - position_embeddings shape: torch.Size([16, 1, 1024])
2025-03-18 15:03:17,744 - __main__ - DEBUG - decoder_inputs_embeds shape: torch.Size([16, 1, 1024])
2025-03-18 15:03:20,043 - __main__ - DEBUG - Shape of reversed_position_input torch.Size([16, 1])
2025-03-18 15:03:20,043 - __main__ - INFO - reversed_position_input_ids shape: torch.Size([16, 1])
2025-03-18 15:03:20,044 - __main__ - INFO - reversed_position_embeddings shape: torch.Size([16, 1, 1024])
2025-03-18 15:03:20,044 - __main__ - INFO - position_input_ids shape: torch.Size([16, 1])
2025-03-18 15:03:

In [45]:
batch["input_ids"].shape

torch.Size([4, 1024])