# Bart large configuration

In [None]:
### MODULES ###

import sys,os
import tqdm
import csv
from datetime import datetime 
import numpy as np
import pandas as pd
import json

from typing import Optional


import math

import matplotlib.pyplot as plt


from datasets import load_dataset, Dataset

import torch
from torch import cuda
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import torch.nn.functional as F


# Load the ROUGE metric
import evaluate

from transformers import AutoTokenizer, BartForConditionalGeneration

In [2]:

NUM_PROCS = os.cpu_count() 

print("NUM_PROCS = " ,NUM_PROCS)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)



NUM_PROCS =  12
cuda


In [3]:

SEED = 42
NUM_LOADER = 4 #config['config_machine']["NUM_LOADER"] #depends of the number of thread 


# Set random seeds and deterministic pytorch for reproducibility
torch.manual_seed(SEED) # pytorch random seed
np.random.seed(SEED) # numpy random seed
torch.backends.cudnn.deterministic = True

# Load dataset CNN daily

In [4]:
# Load CNN/DailyMail dataset
dataset = load_dataset("cnn_dailymail", "3.0.0")

## Comment this part for the real training time :

percentage = 0.05

for split in dataset: 
    dataset[split] = dataset[split].shuffle(seed=SEED).select(range(int(len(dataset[split]) * percentage)))

# Check the dataset structure
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 14355
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 668
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 574
    })
})


# load the model and tokenizer 

In [5]:
### Load model ###
MODEL_HUB = 'facebook/bart-large'
# Load Model and Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_HUB, clean_up_tokenization_spaces=True)
model = BartForConditionalGeneration.from_pretrained(MODEL_HUB, forced_bos_token_id=0)
print(tokenizer.model_max_length)
print(type(tokenizer))
print(type(model))
print(tokenizer)

1024
<class 'transformers.models.bart.tokenization_bart_fast.BartTokenizerFast'>
<class 'transformers.models.bart.modeling_bart.BartForConditionalGeneration'>
BartTokenizerFast(name_or_path='facebook/bart-large', vocab_size=50265, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	50264: AddedToken("<mask>", rstrip=False, lstrip

In [59]:
# def len_distrib(batch):

#     len_articles = []
#     len_highlights = []
    
#     for article, highlight in zip(batch["article"], batch["highlights"]):
#         len_articles.append(len(tokenizer(article, truncation=False)["input_ids"]))
#         len_highlights.append(len(tokenizer(highlight, truncation=False)["input_ids"]))


#     source = tokenizer(batch["article"],truncation=True, max_length=tokenizer.model_max_length,padding='max_length')
#     resume = tokenizer(batch["highlights"],truncation=True, max_length=tokenizer.model_max_length,padding='max_length')

#     return {
#         'input_ids': source['input_ids'], 
#         'input_mask': source['attention_mask'],
#         'input_len': len_articles,
#         'target_ids': resume['input_ids'], 
#         'target_mask': resume['attention_mask'],
#         'target_len': len_highlights
#         }

def len_distrib(batch):

    len_articles = []
    len_highlights = []
    
    for article, highlight in zip(batch["article"], batch["highlights"]):
        len_articles.append(len(tokenizer(article, truncation=False)["input_ids"])-1) #Add -1 to skip the <bos> token 
        len_highlights.append(len(tokenizer(highlight, truncation=False)["input_ids"])-1) #Add -1 to skip the <bos> token 


    source = tokenizer(batch["article"],truncation=True, max_length=tokenizer.model_max_length)
    resume = tokenizer(batch["highlights"],truncation=True, max_length=tokenizer.model_max_length)

    return {
        'input_ids': source['input_ids'], 
        'input_mask': source['attention_mask'],
        'input_len': len_articles,
        'target_ids': resume['input_ids'], 
        'target_mask': resume['attention_mask'],
        'target_len': len_highlights
        }


In [60]:
dataset = dataset.map(len_distrib,num_proc=NUM_PROCS,batched=True,batch_size=64)


Map (num_proc=12):   0%|          | 0/14355 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1105 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1153 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1210 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1498 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2187 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Map (num_proc=12):   0%|          | 0/668 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1689 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1380 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1580 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1048 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1088 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Map (num_proc=12):   0%|          | 0/574 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1647 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1592 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1201 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1153 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1044 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

In [61]:

# Define the custom collate function
def collate_fn(batch):
    """
    Custom collate function that add padding for each batch.
    """

    # Pad the tokenized content
    input_ids = [torch.tensor(item['input_ids'], dtype=torch.long) for item in batch]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    
    attention_mask = [torch.tensor(item['input_mask'], dtype=torch.long) for item in batch]
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)

    decoder_input_ids  = [torch.tensor(item['target_ids'][:-1], dtype=torch.long) for item in batch]
    decoder_input_ids = pad_sequence(decoder_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)     
    
    decoder_attention_mask = [torch.tensor(item['target_mask'][:-1], dtype=torch.long) for item in batch]
    decoder_attention_mask = pad_sequence(decoder_attention_mask, batch_first=True, padding_value=0)
    
    input_len = torch.tensor([item['input_len'] for item in batch], dtype=torch.long)

    target_len = torch.tensor([item['target_len'] for item in batch], dtype=torch.long)

    # Labels should be the same as decoder_input_ids (BART-style training)
    labels = [torch.tensor(item['target_ids'][1:], dtype=torch.long) for item in batch]
    labels = pad_sequence(labels, batch_first=True, padding_value=tokenizer.pad_token_id)  
    labels[labels == tokenizer.pad_token_id] = -100  # Ignore padding in loss computation

    return {
        'input_ids':input_ids,
        'attention_mask':attention_mask,
        'decoder_input_ids':decoder_input_ids,
        'decoder_attention_mask':decoder_attention_mask,
        'labels': labels,
        'input_len': input_len,
        'target_len': target_len
    }

In [62]:

train_params = {
    'batch_size': 4,
    'shuffle': True,
    'collate_fn':collate_fn,
    'num_workers': NUM_LOADER,
    'pin_memory': True  #  Enables faster GPU transfers
    }

eval_params = {
    'batch_size': 4,
    'shuffle': False,
    'collate_fn':collate_fn,
    'num_workers': NUM_LOADER,
    'pin_memory': True  #  Enables faster GPU transfers
    }


# This will be used down for training and validation stage for the model.
train_loader = DataLoader(dataset["train"], **train_params)
eval_loader = DataLoader(dataset["validation"], **eval_params)

for batch in train_loader:
    print(batch)
    break


{'input_ids': tensor([[    0,   510, 36995,  ...,   487,     4,     2],
        [    0,  2765,   479,  ...,  1515,  2040,     2],
        [    0,   713,    16,  ...,     1,     1,     1],
        [    0,  2765,   479,  ...,     1,     1,     1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'decoder_input_ids': tensor([[    0,   791,     4,   104,     4, 12176,    35,  4637,   476,    12,
         12689,   432,    64,    75,   173,    19,   270,  1738,  7148,  5084,
           479, 50118, 20645,    12, 12689,   432,    16, 13244,   142,     9,
          4464,    81,   797,     9, 20402,   479, 50118,  1301, 39329,  2419,
          1855,  1168,  3843, 11004,     6,   776,  1486,   479, 50118,   448,
          3252,  5084,    34, 18993,   758,  1519,    13,   123,     7,  1149,
           159,   479,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,

# Create reverse embedding for the Bart model 

In [12]:

# Get all parent classes in the MRO (Method Resolution Order)
print(BartForConditionalGeneration.__mro__)
tokenizer.bos_token_id

(<class 'transformers.models.bart.modeling_bart.BartForConditionalGeneration'>, <class 'transformers.models.bart.modeling_bart.BartPreTrainedModel'>, <class 'transformers.modeling_utils.PreTrainedModel'>, <class 'torch.nn.modules.module.Module'>, <class 'transformers.modeling_utils.ModuleUtilsMixin'>, <class 'transformers.generation.utils.GenerationMixin'>, <class 'transformers.utils.hub.PushToHubMixin'>, <class 'transformers.integrations.peft.PeftAdapterMixin'>, <class 'object'>)


0

In [12]:
print(tokenizer.pad_token_id)

mask = ~torch.isin(batch["decoder_input_ids"],torch.tensor([tokenizer.pad_token_id,tokenizer.eos_token_id])) # mask with 0 where a pad_id is present
print(mask.shape)

reversed_position_input  = torch.ones(mask.shape) * mask # [1,1,1,0,0] 

reversed_position_input = torch.flip(torch.flip(reversed_position_input , dims=(1,)).cumsum(dim=1), dims=(1,))  

print(reversed_position_input.shape)

normal_round = torch.randn(batch["decoder_input_ids"].shape) * mask

reversed_position_input = torch.abs(torch.round(reversed_position_input  + normal_round)).to(torch.long) #add a gausian noise and converte to long

print(reversed_position_input.shape)

# input_decoder_position_embedding = model.model.decoder.embed_positions(reversed_position_input)

# input_decoder_position_embedding

1
torch.Size([4, 99])
torch.Size([4, 99])
torch.Size([4, 99])


In [64]:
batch["decoder_input_ids"]

tensor([[    0,   791,     4,   104,     4, 12176,    35,  4637,   476,    12,
         12689,   432,    64,    75,   173,    19,   270,  1738,  7148,  5084,
           479, 50118, 20645,    12, 12689,   432,    16, 13244,   142,     9,
          4464,    81,   797,     9, 20402,   479, 50118,  1301, 39329,  2419,
          1855,  1168,  3843, 11004,     6,   776,  1486,   479, 50118,   448,
          3252,  5084,    34, 18993,   758,  1519,    13,   123,     7,  1149,
           159,   479,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1],
        [    0, 34079,  7600,     6,  4059,     6,    21,    10,  9463,  1027,
            11,   468, 10149,     8

## Code pour les index de positions inversées 

In [None]:

relu = nn.ReLU()

def _reverse_position_embedding(input_ids:torch.Tensor,
                                target_len:Optional[torch.Tensor]=None)->torch.Tensor:
    
    mask = ~torch.isin(input_ids,torch.tensor([tokenizer.pad_token_id])) # mask with 0 where a pad_id is present

    reversed_position_input  = torch.ones(mask.shape) * mask # Put 1 where there are token index and 0 where there are pad index[1,1,1,0,0] 
    
    if target_len is None:
        reversed_position_input = torch.flip(torch.flip(reversed_position_input , dims=(1,)).cumsum(dim=1), dims=(1,)) 
        #[[ 54,  53, ...,0,  0],[100, 101,...,1]]
        #print(reversed_position_input.shape)
    else:
        for k in range(input_ids.size(-1)):
            reversed_position_input[:,k] = relu(target_len -k)

    #Add a gaussian noise
    normal_round = torch.randn(reversed_position_input.shape) * mask

    reversed_position_input = torch.abs(torch.round(reversed_position_input  + normal_round)).to(torch.long) #add a gausian noise and converte to long


    return reversed_position_input

reversed_position_input = _reverse_position_embedding(batch["decoder_input_ids"])
print(reversed_position_input.shape)

torch.Size([4, 111])


## Code pour les embeddings de position sinus/cosinus

In [73]:
# # dir(model.model.decoder.layernorm_embedding)

d_model = model.config.d_model
max_len = tokenizer.model_max_length

pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp((-2*torch.arange(0, d_model)//2)/d_model * math.log(10000.0) )

print(div_term.shape)

print((position*div_term).shape)


pe[:, 0::2] = torch.sin(position*div_term[0::2] )
pe[:, 1::2] = torch.cos(position*div_term[1::2])

embed_reverse_positions = nn.Embedding(num_embeddings=tokenizer.model_max_length,
                                              embedding_dim=d_model,
                                              padding_idx=tokenizer.pad_token_id,
                                              _weight=pe,
                                              _freeze=True)

print(embed_reverse_positions(reversed_position_input).shape)

torch.Size([1024])
torch.Size([1024, 1024])
torch.Size([4, 111, 1024])


In [74]:
print(model.model.decoder)
token_embeddings = model.model.decoder.embed_tokens(batch["decoder_input_ids"]) 
print(token_embeddings.shape)
position_embeddings = model.model.decoder.embed_positions(batch["decoder_input_ids"]) 
print(position_embeddings.shape)
repilot_embeddings = embed_reverse_positions(_reverse_position_embedding(batch["decoder_input_ids"]))
print(repilot_embeddings.shape)

BartDecoder(
  (embed_tokens): BartScaledWordEmbedding(50265, 1024, padding_idx=1)
  (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
  (layers): ModuleList(
    (0-11): 12 x BartDecoderLayer(
      (self_attn): BartSdpaAttention(
        (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (activation_fn): GELUActivation()
      (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): BartSdpaAttention(
        (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=

In [75]:
decoder_inputs_embeds = token_embeddings + position_embeddings + repilot_embeddings
output = model(input_ids=batch["input_ids"], 
               attention_mask=batch["attention_mask"],
               decoder_attention_mask=batch["decoder_attention_mask"], 
               decoder_inputs_embeds=decoder_inputs_embeds,
               labels=batch["labels"])


In [None]:
import torch
import math
from torch import nn
from transformers import BartForConditionalGeneration, AutoTokenizer

class RepilotBartForConditionalGeneration(BartForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.relu = nn.ReLU()
        # Define the reversed position embedding module
        self.embed_reverse_positions = self._create_reverse_position_embedding(config)
        
    
    @staticmethod
    def _create_reverse_position_embedding(config):
        """Creates sinusoidal reversed position embeddings."""
        d_model = config.d_model
        max_len = config.max_position_embeddings  

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(-2 * (torch.arange(0, d_model) // 2) / d_model * math.log(10000.0))

        pe[:, 0::2] = torch.sin(position * div_term[0::2])
        pe[:, 1::2] = torch.cos(position * div_term[1::2])

        embedding = nn.Embedding(num_embeddings=tokenizer.model_max_length,
                                              embedding_dim=d_model,
                                              padding_idx=tokenizer.pad_token_id,
                                              _weight=pe,
                                              _freeze=True)
        return embedding

    def _reverse_position_embedding(self, 
                                    input_ids:torch.Tensor, 
                                    target_len:Optional[torch.Tensor]=None)->torch.Tensor:
        """Computes reversed position indices for the decoder inputs."""
        mask = ~torch.isin(input_ids,torch.tensor([tokenizer.pad_token_id]))

        reversed_position_input  = torch.ones(mask.shape) * mask 
        
        if target_len is None:
            reversed_position_input = torch.flip(torch.flip(reversed_position_input , dims=(1,)).cumsum(dim=1), dims=(1,)) 
        else:
            for k in range(input_ids.size(-1)):
                reversed_position_input[:,k] = self.relu(target_len -k)

        #Add a gaussian noise
        normal_round = torch.randn(reversed_position_input.shape) * mask
        reversed_position_input = torch.abs(torch.round(reversed_position_input  + normal_round)).to(torch.long) #add a gausian noise and converte to long


        return reversed_position_input

    def forward(self, 
                input_ids:torch.Tensor, 
                attention_mask:Optional[torch.Tensor]=None, 
                decoder_input_ids:Optional[torch.Tensor]=None, 
                decoder_attention_mask:Optional[torch.Tensor]=None,
                target_len:Optional[torch.Tensor]=None,
                labels:Optional[torch.Tensor]=None, 
                **kwargs):
        """Overrides forward to inject reversed position embeddings into the decoder."""
   
        # Compute reversed position indices
        reversed_position_input = self._reverse_position_embedding(decoder_input_ids, target_len)
        
        # Get reversed position embeddings
        reversed_position_embeddings = self.embed_reverse_positions(reversed_position_input)

        #Get position embeddings 
        position_embeddings = self.model.decoder.embed_positions(decoder_input_ids)

        # Compute standard token embeddings
        decoder_inputs_embeds = self.model.decoder.embed_tokens(decoder_input_ids) + reversed_position_embeddings + position_embeddings

        # Call the original BART forward function with modified decoder inputs
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
            **kwargs
        )


In [None]:

MODEL_HUB = "facebook/bart-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_HUB, clean_up_tokenization_spaces=True)
model = RepilotBartForConditionalGeneration.from_pretrained(MODEL_HUB)



Some weights of RepilotBartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large and are newly initialized: ['embed_reverse_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [127]:

_CHECKPOINT_FOR_DOC = "facebook/bart-large"
_CONFIG_FOR_DOC = "BartConfig"


class RepilotBartForConditionalGeneration(BartForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        # Define the reversed position embedding module
        self.embed_reverse_positions = self._create_reverse_position_embedding()
        print("RepilotBartForConditionalGeneration initialized successfully.")
        

    def _create_reverse_position_embedding(self):
        """Creates sinusoidal reversed position embeddings."""
        d_model = self.config.d_model
        max_len = self.config.max_position_embeddings
        padding_idx = self.config.pad_token_id

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(-2 * (torch.arange(0, d_model) // 2) / d_model * math.log(10000.0))

        pe[:, 0::2] = torch.sin(position * div_term[0::2])
        pe[:, 1::2] = torch.cos(position * div_term[1::2])

        embedding = nn.Embedding(num_embeddings=max_len,
                                              embedding_dim=d_model,
                                              padding_idx=padding_idx,
                                              _weight=pe,
                                              _freeze=True)
        return embedding

    def _reverse_position_embedding(self, 
                                    input_ids:torch.Tensor, 
                                    target_len:Optional[torch.Tensor]=None)->torch.Tensor:
        """Computes reversed position indices for the decoder inputs."""
        mask = ~torch.isin(input_ids,torch.tensor([self.config.pad_token_id]))

        reversed_position_input  = torch.ones(mask.shape) * mask 
        
        if target_len is None:
            reversed_position_input = torch.flip(torch.flip(reversed_position_input , dims=(1,)).cumsum(dim=1), dims=(1,))
            print(reversed_position_input)
        else:
            for k in range(input_ids.size(-1)):
                reversed_position_input[:,k] = F.relu(target_len -k)
                print(reversed_position_input[:,k])

        #Add a gaussian noise
        normal_round = torch.randn(reversed_position_input.shape) * mask
        reversed_position_input = torch.abs(torch.round(reversed_position_input  + normal_round)).to(torch.long) #add a gausian noise and converte to long


        return reversed_position_input

    def forward(self, 
                input_ids:Optional[torch.Tensor]=None, 
                attention_mask:Optional[torch.Tensor]=None, 
                decoder_input_ids:Optional[torch.Tensor]=None, 
                decoder_attention_mask:Optional[torch.Tensor]=None,
                target_len:Optional[torch.Tensor]=None,
                decoder_inputs_embeds:Optional[torch.Tensor]=None,
                labels:Optional[torch.Tensor]=None, 
                **kwargs):
        """Overrides forward to inject reversed position embeddings into the decoder."""

        if input_ids is not None:
            print(f"Forward called with input_ids shape: {input_ids.shape}")
            print(input_ids)
        else:
            print("Forward called with input_ids=None")
            # logger.warning("Forward called with input_ids=None")

        if decoder_input_ids is not None:
            # Compute reversed position indices
            reversed_position_input = self._reverse_position_embedding(decoder_input_ids, target_len)
            # logger.debug(f"decoder_input_ids shape: {decoder_input_ids.shape}")
            print(reversed_position_input)
            
            # Get reversed position embeddings
            reversed_position_embeddings = self.embed_reverse_positions(reversed_position_input)

            #Get position embeddings 
            #position_embeddings = self.model.decoder.embed_positions(decoder_input_ids)

            # Compute standard token embeddings
            decoder_inputs_embeds = self.model.decoder.embed_tokens(decoder_input_ids) + reversed_position_embeddings #+ position_embeddings
            # logger.debug(f"decoder_inputs_embeds shape: {decoder_inputs_embeds.shape}")

        # Call the original BART forward function with modified decoder inputs
        outputs =  super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
            **kwargs
        )

        # logger.info("Forward pass completed.")
        return outputs
    

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_FOR_DOC, clean_up_tokenization_spaces=True)
model = RepilotBartForConditionalGeneration.from_pretrained(_CHECKPOINT_FOR_DOC)

##SUMMARY TASK

ARTICLE_TO_SUMMARIZE = "PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions.The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."

inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=model.config.max_position_embeddings, truncation=True, return_tensors="pt")

target_len = torch.tensor([8])

# Generate Summary
summary_ids = model.generate(input=inputs["input_ids"], num_beams = 2, max_length = 4, target_len=torch.tensor([8]))
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))


RepilotBartForConditionalGeneration initialized successfully.


Some weights of RepilotBartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large and are newly initialized: ['embed_reverse_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ValueError: The following `model_kwargs` are not used by the model: ['input'] (note: typos in the generate arguments will also show up in this list)

Forward called with input_ids shape: torch.Size([1, 56])
tensor([[    0,  8332,   947,   717,  2305,    24,  1768,     5,   909,  4518,
            11,  1263,     7,  5876,    13,   239,  2372,  2876,  3841,  1274,
             4,   133,  4374,    16,     7,  1888,     5,   810,     9, 12584,
             4,  9221,  5735,  7673,   916,    58,  1768,     7,    28,  2132,
            30,     5,  2572, 10816,    61,    58,   421,     7,    94,   149,
            23,   513, 15372,  3859,     4,     2]])


In [81]:
generated_ids = model.generate(
              input_ids = batch["input_ids"],
              attention_mask = batch["attention_mask"],
              target_len=batch["target_len"],        
              )   
            
# Compute ROUGE scores here
generated_txt = [text for text in tokenizer.batch_decode(generated_ids, skip_special_tokens=True)]

In [109]:
TXT = "My friends are <mask> but they eat too many carbs."
input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]

masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()

masked_index

tensor([[4]])