In [17]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from torch.nn import functional as F

class TaskSpecificAttention(nn.Module):
    def __init__(self, config, task_feature_dim):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.task_specific_weight = nn.Parameter(torch.randn(config.hidden_size, task_feature_dim))
        self.feature_layer = nn.Linear(config.hidden_size, task_feature_dim)

    def forward(self, hidden_states, inputs_embeds):
        """
        :param hidden_states: output of transformer layers (batch_size, seq_len, hidden_size)
        :param inputs_embeds: input embeddings (batch_size, seq_len, hidden_size)
        """
        batch_size, seq_len, hidden_size = hidden_states.shape

        # Transform hidden states to get task-specific features
        task_features = self.feature_layer(inputs_embeds) # (batch_size, seq_len, task_feature_dim)

        # Calculate task-specific attention
        task_attn_weights = torch.matmul(task_features, self.task_specific_weight.t())  # (batch_size, seq_len, hidden_size)
        task_attn_weights = F.softmax(task_attn_weights, dim=-1)

        # Reshape for combining
        task_attn_weights = task_attn_weights.reshape(batch_size, seq_len, self.num_heads, self.head_dim) # (batch_size, seq_len, num_heads, head_dim)
        task_attn_weights = task_attn_weights.transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim)

        return task_attn_weights

class TaskSpecificDynamicTokenPruning(nn.Module):
    def __init__(self, model_name, task_feature_dim, gamma=0.5, lambda_aux=0.1):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.num_layers = len(self.model.encoder.layer) if hasattr(self.model, 'encoder') else len(self.model.transformer.layer)
        self.task_attn = nn.ModuleList([TaskSpecificAttention(self.model.config, task_feature_dim) for _ in range(self.num_layers)])
        self.gamma = gamma
        self.lambda_aux = lambda_aux

    def calculate_token_importance(self, hidden_states, inputs_embeds, attention_mask):
        """
        :param hidden_states: output of transformer layers (list of (batch_size, seq_len, hidden_size))
        :param inputs_embeds: input embeddings (batch_size, seq_len, hidden_size)
        :param attention_mask: attention mask for the tokens (batch_size, seq_len)
        """
        device = hidden_states[0].device
        batch_size, seq_len, _ = hidden_states[0].shape
        token_importance = torch.zeros(batch_size, seq_len, device=device)

        for layer_idx, layer_output in enumerate(hidden_states):
            layer_output = layer_output.to(device)
            # Calculate Standard Attention Weights for each layer
            attention_weights = torch.zeros(batch_size, self.model.config.num_attention_heads, seq_len, layer_output.shape[-1]//self.model.config.num_attention_heads, device=device)
            if hasattr(self.model, 'encoder'):  # BERT and similar models
                for head_idx, head in enumerate(self.model.encoder.layer[layer_idx].attention.self.key.weight):
                    q = self.model.encoder.layer[layer_idx].attention.self.query(layer_output)
                    k = self.model.encoder.layer[layer_idx].attention.self.key(layer_output)
                    # Reshape attention_weights to match softmax output
                    #The original code tried to reshape a tensor with 704 elements to (1,11,11) which has 121 elements.
                    #Instead, we calculate the correct dimensions based on the size of the tensor, enabling successful reshaping.
                    attention_weights_per_head = attention_weights[:, head_idx, :].view(batch_size, seq_len, layer_output.shape[-1] // self.model.config.num_attention_heads) # Adjust target shape

                    attention_scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(layer_output.shape[-1], dtype=torch.float32, device=device))
                    attention_weights_per_head = torch.softmax(attention_scores, dim=-1)
                    attention_weights[:, head_idx, :] = attention_weights_per_head
                    #attention_weights[:, head_idx, :] = attention_weights_per_head.view(batch_size, seq_len, layer_output.shape[-1] // self.model.config.num_attention_heads)  # Adjust target shape

            elif hasattr(self.model, 'transformer'): # GPT and similar models
                for head_idx, head in enumerate(self.model.transformer.layer[layer_idx].attn.c_attn.weight.transpose(0,-1)[0]):
                    q = self.model.transformer.layer[layer_idx].attn.c_attn(layer_output)  # (batch_size, seq_len, hidden_size)
                    k = torch.matmul(layer_output, head)  # (batch_size, seq_len, hidden_size)
                    attention_weights[:, head_idx, :] = torch.softmax(torch.matmul(q,k.transpose(-2,-1))/torch.sqrt(torch.tensor(layer_output.shape[-1], dtype=torch.float32, device=device)), dim = -1)


            # Task specific attention
            task_attn_weights = self.task_attn[layer_idx](hidden_states=layer_output, inputs_embeds=inputs_embeds) # (batch_size, num_heads, seq_len, head_dim)
            hybrid_attention_weights = attention_weights + self.gamma* task_attn_weights

            # Calculate Average importance
            layer_importance = torch.sum(hybrid_attention_weights, dim=(1,3)) # (batch_size, seq_len)
            #layer_importance = torch.mean(layer_importance, dim=1) # (batch_size, seq_len) # Not needed anymore due to task-specific attention
            token_importance += layer_importance

        return token_importance # (batch_size, seq_len)

    def prune_tokens(self, token_importance, attention_mask):
         """
        :param token_importance: token importance scores (batch_size, seq_len)
        :param attention_mask: attention mask for the tokens (batch_size, seq_len)
        """
         batch_size, seq_len = token_importance.shape
         mask_output = attention_mask.clone()

         threshold = torch.mean(token_importance, dim=1, keepdim=True) - torch.std(token_importance, dim=1, keepdim=True) # (batch_size, 1)
         pruned_mask = token_importance < threshold
         mask_output = mask_output.masked_fill(pruned_mask, 0)

         return mask_output # (batch_size, seq_len)



    def forward(self, input_ids, attention_mask, labels=None):
         """
         :param input_ids: input ids (batch_size, seq_len)
         :param attention_mask: attention mask for the tokens (batch_size, seq_len)
         :param labels: (optional) labels for the downstream task (batch_size, num_labels)
         """
         device = input_ids.device
         inputs_embeds = self.model.embeddings(input_ids) # (batch_size, seq_len, hidden_size)
         hidden_states = []
         output = self.model(input_ids=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, output_hidden_states=True)
         if hasattr(self.model, 'encoder'):
             hidden_states = output.hidden_states[1:] # Remove the embedding output
         elif hasattr(self.model, 'transformer'):
            hidden_states = output.hidden_states[1:]
         token_importance = self.calculate_token_importance(hidden_states, inputs_embeds, attention_mask)
         pruned_attention_mask = self.prune_tokens(token_importance, attention_mask)

         output = self.model(input_ids=None, inputs_embeds=inputs_embeds, attention_mask = pruned_attention_mask, labels=labels)
         loss = output.loss
         logits = output.logits

         if labels is not None:
             aux_loss = self.calculate_auxiliary_loss(token_importance) # auxiliary loss function
             total_loss = loss + self.lambda_aux * aux_loss
             return total_loss, logits

         else:
             return logits

    def calculate_auxiliary_loss(self, token_importance):
       """
        :param token_importance: token importance scores (batch_size, seq_len)
        """
       batch_size, seq_len = token_importance.shape
       aux_loss = torch.mean(torch.abs(torch.mean(token_importance, dim=1, keepdim=True) - torch.mean(token_importance)))
       return aux_loss

if __name__ == '__main__':
  # Example Usage
  model_name = "bert-base-uncased"  # Or any other suitable model
  task_feature_dim = 128 # Define size of the task-specific feature vector
  ts_dtp_model = TaskSpecificDynamicTokenPruning(model_name, task_feature_dim)


  tokenizer = ts_dtp_model.tokenizer
  text = "This is an example input text for the test"
  labels = torch.tensor([1]) #Example usage for classification task

  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
  input_ids = inputs["input_ids"]
  attention_mask = inputs["attention_mask"]
  output = ts_dtp_model(input_ids, attention_mask, labels)
  if labels is not None:
        loss, logits = output
        print("Loss:", loss)
        print("Logits", logits)
  else:
      logits = output
      print("Logits:", logits)

RuntimeError: The expanded size of the tensor (64) must match the existing size (11) at non-singleton dimension 2.  Target sizes: [1, 11, 64].  Tensor sizes: [11, 11]