In [1]:
# !pip install -U  torch transformers apex torchsummary jupyter ipywidgets

In [2]:
import copy
import math
import random
import torch
import warnings
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
from transformers.file_utils import (
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bart.modeling_bart import (
    BartLearnedPositionalEmbedding,
    BartDecoderLayer,
    BartPreTrainedModel,
)
from transformers.utils import logging
from typing import List, Optional, Tuple, Union
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_causal_attention_mask,
)
from transformers.activations import ACT2FN

  from .autonotebook import tqdm as notebook_tqdm


# BART Model Set Up

In [3]:
def scaled_dot_product_attention(query, key, d_k):
    """
    Compute the scaled dot product attention weights.

    Args:
    query (Tensor): Query tensor of shape (num_replicas, batch_size,seq_len_query,  d_k) or (seq_len_query, batch_size, d_k) if not batched.

    key (Tensor): Key tensor of shape (num_replicas, seq_len_key, batch_size, d_k) or (seq_len_key, batch_size, d_k) if not batched.
    d_k (int): Dimension of the key and query tensors.

    Returns:
    Tensor: Attention weights of shape (num_replicas, seq_len_query, batch_size, seq_len_key)
            or (seq_len_query, batch_size, seq_len_key) if not batched.
    """
    # Ensure the key tensor is properly transposed for matrix multiplication
    # if len(key.shape) == 4:  # If batched (multiple replicas)
    # key = key.permute(
    #     0, 1, 3, 2
    # )  # Shape: (num_replicas, batch_size, d_k, seq_len_key)
    # else:
    #     key = key.permute(1, 2, 0)  # Shape: (batch_size, d_k, seq_len_key)


    # Compute the dot product between query and key
    scores = torch.matmul(
        query, key
    )  # Shape: (num_replicas, seq_len_query, batch_size, seq_len_key)
    # or (seq_len_query, batch_size, seq_len_key) if not batched

    # Scale the scores
    attention_weights = scores / torch.sqrt(torch.tensor(d_k, dtype=scores.dtype))
    # print("query shape: ", query.shape)
    # print("key shape: ", key.shape)
    # print("attention_weights shape: ", attention_weights.shape)

    # # Apply softmax to get attention weights
    # attention_weights = nn.functional.softmax(attention_weights, dim=-1)

    return attention_weights

## Decoder

In [4]:
class CustomBartDecoder(BartPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`]

    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.d_model, self.padding_idx
        )

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )
        self.layers = nn.ModuleList(
            [BartDecoderLayer(config) for _ in range(config.decoder_layers)]
        )
        self.layernorm_embedding = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
                cross-attention on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
                shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
                `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
                control over how to convert `input_ids` indices into associated vectors than the model's internal
                embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input = input_ids
            input_shape = input.shape
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        # past_key_values_length
        past_key_values_length = (
            past_key_values[0][0].shape[2] if past_key_values is not None else 0
        )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input) * self.embed_scale

        if getattr(self.config, "_flash_attn_2_enabled", False):
            # 2d mask is passed through the layers
            attention_mask = (
                attention_mask
                if (attention_mask is not None and 0 in attention_mask)
                else None
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if getattr(self.config, "_flash_attn_2_enabled", False):
                encoder_attention_mask = (
                    encoder_attention_mask if 0 in encoder_attention_mask else None
                )
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        # embed positions
        positions = self.embed_positions(input, past_key_values_length)
        positions = positions.to(inputs_embeds.device)

        hidden_states = inputs_embeds + positions
        hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = (
            () if (output_attentions and encoder_hidden_states is not None) else None
        )
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip(
            [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
        ):
            if attn_mask is not None:
                if attn_mask.size()[0] != (len(self.layers)):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {head_mask.size()[0]}."
                    )

        
        for iteration in range(num_iterations):
            for idx, decoder_layer in enumerate(self.layers):
                # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)
                if self.training:
                    dropout_probability = torch.rand([])
                    if dropout_probability < self.layerdrop:
                        continue

                past_key_value = (
                    past_key_values[idx] if past_key_values is not None else None
                )

                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        decoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        encoder_hidden_states,
                        encoder_attention_mask,
                        head_mask[idx] if head_mask is not None else None,
                        cross_attn_head_mask[idx]
                        if cross_attn_head_mask is not None
                        else None,
                        None,
                        output_attentions,
                        use_cache,
                    )
                else:
                    layer_outputs = decoder_layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        encoder_hidden_states=encoder_hidden_states,
                        encoder_attention_mask=encoder_attention_mask,
                        layer_head_mask=(
                            head_mask[idx] if head_mask is not None else None
                        ),
                        cross_attn_layer_head_mask=(
                            cross_attn_head_mask[idx]
                            if cross_attn_head_mask is not None
                            else None
                        ),
                        past_key_value=past_key_value,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                    )
                hidden_states = layer_outputs[0]

                if use_cache:
                    next_decoder_cache += (
                        layer_outputs[3 if output_attentions else 1],
                    )

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

                    if encoder_hidden_states is not None:
                        all_cross_attentions += (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )  ## Decoder Layer

## Encoder Layer

In [5]:
# # Bart EncoderLayer Modification
# class CustomBartEncoderLayer(nn.Module):
#     def __init__(self, config: BartConfig, layer):
#         super().__init__()
#         self.layer = layer
#         num_replicas = 19
#         self.num_replicas = num_replicas
#         self.embed_dim = config.d_model 

#         self.fc1 = nn.ModuleList(
#             [nn.Linear(self.embed_dim, self.embed_dim) for _ in range(num_replicas)]
#         )
#         # print(self.embed_dim, config.decoder_ffn_dim)
#         # self.fc2 = nn.ModuleList(
#         #     [nn.Linear(self.embed_dim, self.embed_dim) for _ in range(num_replicas)]
#         # )

#         self.q1 = nn.Linear(self.embed_dim, self.embed_dim)
#         self.k1 = nn.Linear(self.embed_dim, self.embed_dim)

#         # self.q2 = nn.Linear(self.embed_dim, self.embed_dim)
#         # self.k2 = nn.Linear(self.embed_dim, self.embed_dim)

#         # print(self.embed_dim, config.decoder_ffn_dim, config.decoder_attention_heads)

#         self.final_layer_norm = nn.LayerNorm(self.embed_dim)
#         self.dropout = config.dropout
#         self.activation_fn = ACT2FN[config.activation_function]
#         self.activation_dropout = config.activation_dropout  # Define activation dropout

#     def forward(self, x, *args, **kwargs):
#         outputs = self.layer(x, *args, **kwargs)
#         hidden_states = outputs[0]
#         residual = hidden_states
#         hidden_states = self.activation_fn(hidden_states)

#         batch_size = hidden_states.size(1)

#         q1 = self.q1(hidden_states)
#         fc1_concat = torch.stack([fc1(hidden_states) for fc1 in self.fc1], dim=0)
#         p1 = torch.stack([fc1.weight.data for fc1 in self.fc1], dim=0)
#         p1 = p1.unsqueeze(2).expand(-1, -1, batch_size, -1)
#         k1 = self.k1(p1)
#         attention_weight_1 = scaled_dot_product_attention(q1, k1, self.embed_dim)
#         attention_weight_1 = nn.functional.softmax(attention_weight_1, dim=0)
#         # print(self.num_replicas)
#         # print("q1", q1.shape)
#         # print("k1", k1.shape)
#         # print("fc1_concat", fc1_concat.shape)
#         # print("attention_weight_1", attention_weight_1.shape)
#         fc1_weighted = fc1_concat * attention_weight_1
#         hidden_states = fc1_weighted.sum(dim=0)
#         # hidden_states = self.activation_fn(hidden_states)
#         # print((attention_weight_1.sum(dim=(1,2,3))))
        
#         # hidden_states = nn.functional.dropout(
#         #     hidden_states, p=self.activation_dropout, training=self.training
#         # )

#         # # Vectorized operation for fc2 layers
#         # q2 = self.q2(hidden_states)
#         # fc2_concat = torch.stack([fc2(hidden_states) for fc2 in self.fc2], dim=0)
#         # p2 = torch.stack([fc2.weight.data for fc2 in self.fc2], dim=0)
#         # p2 = p2.unsqueeze(2).expand(-1, -1, batch_size, -1)
#         # k2 = self.k2(p2)
#         # attention_weight_2 = scaled_dot_product_attention(q2, k2, self.embed_dim)
#         # fc2_weighted = fc2_concat * attention_weight_2
#         # hidden_states = fc2_weighted.mean(dim=0)

#         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

#         hidden_states = hidden_states + residual
#         hidden_states = self.final_layer_norm(hidden_states)

#         return (hidden_states,) + outputs[1:]

## Decoder Layer

In [6]:
# BartDecoderLayer Modification
class CustomBartDecoderLayer(nn.Module):
    def __init__(self, config: BartConfig, layer):
        super().__init__()
        self.layer = layer
        self.num_replicas = num_replicas
        self.embed_dim = config.d_model  # Assuming embed_dim is d_model
        self.fc1 = nn.ModuleList(
            [
                nn.Linear(self.embed_dim, self.embed_dim, bias=False)
                for _ in range(num_replicas)
            ]
        )
        # print(self.embed_dim, config.decoder_ffn_dim)
        # self.fc2 = nn.ModuleList(
        #     [nn.Linear(self.embed_dim, self.embed_dim, bias=False) for _ in range(num_replicas)]
        # )
        self.q1 = nn.Linear(self.embed_dim, self.embed_dim)
        self.k1 = nn.Linear(self.embed_dim, self.embed_dim)

        self.attention_weight_bias = nn.Parameter(torch.zeros(num_replicas, 1, 1, 1))

        # self.q2 = nn.Linear(self.embed_dim, self.embed_dim)
        # self.k2 = nn.Linear(self.embed_dim, self.embed_dim)

        # print(self.embed_dim, config.decoder_ffn_dim, config.decoder_attention_heads)

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout  
        # for i, fc1 in enumerate(self.fc1):
        #     # Initialize with a simple pattern, e.g., all elements in the weight matrix are set to the index of the layer
        #     nn.init.constant_(fc1.weight, 17*i)

    def forward(self, x, *args, **kwargs):
        outputs = self.layer(x, *args, **kwargs)
        hidden_states = outputs[0]
        residual = hidden_states

        batch_size = hidden_states.size(0)
        seq_len = hidden_states.size(1)

        
        p1 = torch.stack([fc1.weight.transpose(0, 1) for fc1 in self.fc1], dim=0)
        p1 = p1.unsqueeze(1)
        # .expand(-1, batch_size, -1, -1)

        hidden_states_reshaped = hidden_states.unsqueeze(0).expand(
            self.num_replicas, -1, -1, -1
        )
        fc1_concat = torch.matmul(hidden_states_reshaped, p1)
        # fc1_concat1 = torch.stack([fc1(hidden_states) for fc1 in self.fc1], dim=0)

        # print("p1", p1.shape)
        # print("hidden_states_reshaped",hidden_states_reshaped.shape)
        # print("fc1_concat1",fc1_concat1.shape)
        # print("fc1_concat", fc1_concat.shape)
        # torch.set_printoptions(sci_mode=False)
        # print("Same:", torch.allclose(fc1_concat, fc1_concat1, rtol = 0.001))
        # print(torch.abs(fc1_concat - fc1_concat1).mean())
        # print(torch.abs(fc1_concat - fc1_concat1).max())
        # print()
        
        q1 = self.q1(hidden_states).unsqueeze(0).expand(self.num_replicas, -1, -1, -1)
        k1 = self.k1(p1)
        attention_weight_1 = scaled_dot_product_attention(q1, k1, self.embed_dim)
        attention_weight_1 = attention_weight_1 + self.attention_weight_bias
        attention_weight_1_norm_expert = nn.functional.softmax(
            attention_weight_1, dim=0
        )
        attention_weight_1_norm_feature = nn.functional.softmax(
            attention_weight_1, dim=-1
        )
        attention_weight_1_combined = (
            attention_weight_1_norm_expert + attention_weight_1_norm_feature
        )
        # print(self.num_replicas)
        # print("q1", q1.shape)
        # print("k1", k1.shape)
        # print("fc1_concat", fc1_concat.shape)
        # print("attention_weight_1", attention_weight_1.shape)
        fc1_weighted = fc1_concat * attention_weight_1_combined
        hidden_states = self.activation_fn(fc1_weighted.mean(dim=0))
        # print((attention_weight_1.sum(dim=(1,2,3))))
        # hidden_states = nn.functional.dropout(
        #     hidden_states, p=self.activation_dropout, training=self.training
        # )

        # # Vectorized operation for fc2 layers
        # q2 = self.q2(hidden_states)
        # fc2_concat = torch.stack([fc2(hidden_states) for fc2 in self.fc2], dim=0)
        # p2 = torch.stack([fc2.weight.data for fc2 in self.fc2], dim=0)
        # p2 = p2.unsqueeze(2).expand(-1, -1, batch_size, -1)
        # k2 = self.k2(p2)
        # attention_weight_2 = scaled_dot_product_attention(q2, k2, self.embed_dim)
        # fc2_weighted = fc2_concat * attention_weight_2
        # hidden_states = fc2_weighted.mean(dim=0)
        
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        hidden_states = hidden_states + residual
        hidden_states = self.final_layer_norm(hidden_states)

        return (hidden_states,) + outputs[1:]

## Load Model

In [7]:
# Load the pretrained BART model
model_name = "facebook/bart-large-cnn"
config = BartConfig.from_pretrained(model_name)
# print(config)
model = BartForConditionalGeneration.from_pretrained(model_name, config=config)
tokenizer = BartTokenizer.from_pretrained(model_name)


num_iterations = 2
num_replicas = 7
model.model.decoder = CustomBartDecoder(
    config=model.config, embed_tokens=model.model.shared
)

# Check if CUDA GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


model_weights_path = f".\model_weights\{model_name}.pth"

# # Load model weights to the device
if torch.cuda.is_available():
    model.load_state_dict(torch.load(model_weights_path, map_location="cuda"))
else:
    model.load_state_dict(torch.load(model_weights_path, map_location="cpu"))

# Replace all customized layers
for i, layer in enumerate(model.model.decoder.layers):
    model.model.decoder.layers[i] = CustomBartDecoderLayer(model.config, layer)


# Move the model to the specified device
model.to(device)

Using device: cuda


BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerN

In [8]:
# Code Warehouse
# model = BartForConditionalGeneration(config=config)

# for i, layer in enumerate(model.model.encoder.layers):
#     model.model.encoder.layers[i] = CustomBartEncoderLayer(model.config, layer)


# #save pretrained model weights
# torch.save(model.state_dict(), model_weigts_path)

In [9]:
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters:", total_params)

Total number of parameters: 519585876


In [10]:
# # # Freeze pretrained weight

# # Step 1: Freeze all pretrained weights
# for param in model.parameters():
#     param.requires_grad = False
# # Step 2: Unfreeze the weights in custom layers
# for i in range(len(model.model.decoder.layers)):
#     layer = model.model.decoder.layers[i]
#     if isinstance(layer, CustomBartDecoderLayer):
#         for param in layer.parameters():
#             param.requires_grad = True
    
#     if isinstance(layer, BartDecoderLayer):
#         for param in layer.parameters():
#             param.requires_grad = False

## Inference Test (Generate Text)

In [11]:
# # Sample text to summarize
text = """New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York.
A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband.
Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other.
In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage.
Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the
2010 marriage license application, according to court documents.
Prosecutors said the marriages were part of an immigration scam.
On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further.
After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective
Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002.
All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say.
Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages.
Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted.
The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s
Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali.
Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force.
If convicted, Barrientos faces up to four years in prison.  Her next court appearance is scheduled for May 18."""

# Encode the text into tokens
inputs = tokenizer([text], return_tensors="pt")
# , max_length=1024

# Move the input tensors to the same device as the model
inputs = inputs.to(device)

# Generate a summary of the encoded text
summary_ids = model.generate(
    inputs["input_ids"],
    num_beams=4,
    # max_length=51,
    # early_stopping=True
)

# Decode the summary
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
print(summary)

ProsecutorsAfterNineProsecutorsProsecutorsLNineProsecutorsLNineLNineNineProsecutorsNineLProsecutorsLLProsecutorsProsecutorsProsecutorsLProsecutorsWomanProsecutorsNineNineWomanNineWomanLProsecutorsNineInProsecutorsNineProsecutorsAtNineLLNineAtNineNineNineAtLProsecutorsBarLAtProsecutorsProsecutorsNineAfterNineProsecutorsWomanWomanNineProsecutorsAfterNineInNineNineLWomanNineNineMarNineNineWhenProsecutorsLWomanProsecutorsProsecutorsInWomanNineLEightNineProsecutorsInProsecutorsLAtLLLInProsecutorsAtProsecutorsLBarProsecutorsProsecutorsAtLNineInLProsecutorsAfterProsecutorsNineAtInNineProsecutorsMarNineLAfterLNineBarProsecutorsLAfterNineNineAfterLProsecutorsAtAtProsecutorsNine


In [12]:
# Prepare the batched input
input_texts = [
    "Who is the president of China?",
    "Who is the president of the US?",
    "Who is the president of the Russia?",
    # """New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage.""",
    # """Your model's primary bottlenecks appear to be matrix multiplication and linear layer operations, both in terms of computation and possibly memory usage. Focusing your optimization efforts on these areas, along with minimizing unnecessary memory operations, could lead to significant improvements in performance. Remember, optimizations can sometimes affect model accuracy, so it's important to validate your model's performance after making any changes.""",
]
inputs = tokenizer(
    input_texts,
    return_tensors="pt",
    padding=True,
    # truncation=True
)

# Move the input tensors to the same device as the model
inputs = inputs.to(device)

# Generate the output
model.eval()
with torch.no_grad():
    output_tokens = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        num_beams=4,
        # max_length=50,  # Optionally set a max length if desired
    )

# Decode the generated tokens for each input in the batch
output_texts = [
    tokenizer.decode(token, skip_special_tokens=True) for token in output_tokens
]

# Print each output
for output in output_texts:
    print(output)

Who whoPresidentChina WhoMeetWhichIsChineseCNNWasHow ChinaBeingWithWhat presidentWhether
PresidentWho who WhoWhichMeetCNN presidentHowWithObama PresidentIsWasBeingWhatWhetherTrump
WhoRussia whoPresident RussiaRussianPutin WhoMoscow Russians PutinWhichMeet Moscow KremlinCNNIsHow


In [13]:
# from torch.profiler import profile, record_function, ProfilerActivity

# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], 
#              record_shapes=True) as prof:
#     with record_function("model_inference"):
#         # Your model inference code here
#         outputs = model.generate(
#         inputs["input_ids"],
#         attention_mask=inputs["attention_mask"],
#         num_beams=4,
#         # max_length=50,  # Optionally set a max length if desired
#         )
#         # model(input_data)

# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
