# In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.nn import functional as F
import math

class TaskSpecificAttention(nn.Module):
    """Attention pattern computed from a learned projection of the input embeddings.

    The pattern is a scaled dot-product softmax over a ``task_feature_dim``-wide
    linear projection of ``inputs_embeds``. The same pattern is broadcast to
    every attention head.

    Args:
        config: Model config; ``num_attention_heads`` and ``hidden_size`` are read.
        task_feature_dim: Width of the task-specific feature projection.
    """

    def __init__(self, config, task_feature_dim):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        # NOTE(review): never read by forward(), so it never receives a gradient
        # (relevant e.g. for DDP without find_unused_parameters). Kept so that
        # existing state_dicts still load.
        self.task_specific_weight = nn.Parameter(torch.randn(config.hidden_size, task_feature_dim))
        self.feature_layer = nn.Linear(config.hidden_size, task_feature_dim)

    def forward(self, hidden_states, inputs_embeds, attention_mask=None):
        """Return task-specific attention weights.

        Args:
            hidden_states: Not read; accepted so callers can pass the layer state.
            inputs_embeds: ``(batch, seq_len, hidden_size)`` tensor that is projected.
            attention_mask: Optional ``(batch, seq_len)`` mask with 1 for real
                tokens and 0 for padding. When given, padding keys receive zero
                weight. When ``None`` (the default), every key is attended to.

        Returns:
            ``(batch, num_heads, seq_len, seq_len)`` weights. Each row sums to 1,
            and all heads share one pattern.
        """
        # Project the embeddings (not hidden_states) into the task feature space.
        task_features = self.feature_layer(inputs_embeds)  # (batch, seq_len, task_feature_dim)

        scores = torch.matmul(task_features, task_features.transpose(-2, -1))
        scores = scores / math.sqrt(task_features.size(-1))

        if attention_mask is not None:
            key_is_pad = ~attention_mask.bool().unsqueeze(1)  # (batch, 1, seq_len): broadcast over queries
            # finfo.min rather than -inf: a fully padded row stays finite
            # (uniform after softmax) instead of turning into NaN.
            scores = scores.masked_fill(key_is_pad, torch.finfo(scores.dtype).min)

        weights = F.softmax(scores, dim=-1)  # (batch, seq_len, seq_len)

        # Share the pattern across heads; expand() returns a view, not a copy.
        return weights.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

class TaskSpecificDynamicTokenPruning(nn.Module):
    """Sequence classifier that masks out low-importance tokens before classifying.

    ``forward`` runs the wrapped encoder once to score every token. The score
    mixes the encoder's own self-attention with per-layer ``TaskSpecificAttention``
    patterns. Tokens scoring more than one standard deviation below their
    example's mean get their attention-mask entry zeroed. The full
    sequence-classification model then runs again with that pruned mask.

    Args:
        model_name: Checkpoint name or path passed to ``from_pretrained``.
        task_feature_dim: Width of each layer's task-specific feature projection.
        num_labels: Number of classification labels.
        gamma: Weight of the task-specific attention in the hybrid attention.
        lambda_aux: Weight of the auxiliary loss added to the classification loss.
    """

    def __init__(self, model_name, task_feature_dim, num_labels=2, gamma=0.5, lambda_aux=0.1):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        # base_model is the bare encoder (the same object as self.model.bert for BERT).
        self.num_layers = len(self.model.base_model.encoder.layer)
        # One task-specific attention module per encoder layer.
        self.task_attn = nn.ModuleList([TaskSpecificAttention(self.model.config, task_feature_dim)
                                       for _ in range(self.num_layers)])
        self.gamma = gamma
        self.lambda_aux = lambda_aux
        # NOTE(review): not used by forward(), which returns the wrapped model's
        # own classifier logits. Kept so existing state_dicts still load.
        self.classifier = nn.Linear(self.model.config.hidden_size, num_labels)

    def calculate_token_importance(self, hidden_states, inputs_embeds, attention_mask):
        """Score each token by the hybrid attention it receives, summed over layers.

        Args:
            hidden_states: Sequence of ``(batch, seq_len, hidden)`` tensors. Element
                ``i`` must be the INPUT to encoder layer ``i``, because that is what
                the layer's query/key projections consume.
            inputs_embeds: Embedding output, passed to each ``TaskSpecificAttention``.
            attention_mask: ``(batch, seq_len)`` mask with 1 for real tokens and 0
                for padding. ``None`` means every token is real.

        Returns:
            ``(batch, seq_len)`` importance scores. Padding positions score 0.
        """
        first = hidden_states[0]
        batch_size, seq_len, _ = first.shape
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, seq_len, device=first.device)

        num_heads = self.model.config.num_attention_heads
        head_dim = self.model.config.hidden_size // num_heads
        encoder_layers = self.model.base_model.encoder.layer

        is_real = attention_mask.bool()                    # (batch, seq_len)
        key_is_pad = ~is_real[:, None, None, :]            # broadcast over heads and queries
        real = is_real.to(first.dtype)
        # Weights that average over real query positions only: (batch, seq_len, 1).
        query_weights = (real / real.sum(dim=1, keepdim=True).clamp(min=1.0)).unsqueeze(-1)

        token_importance = torch.zeros(batch_size, seq_len, device=first.device)
        for layer_idx, layer_input in enumerate(hidden_states):
            self_attn = encoder_layers[layer_idx].attention.self

            # Recompute this layer's attention probabilities from its input.
            q = self_attn.query(layer_input).view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
            k = self_attn.key(layer_input).view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
            # Padding keys must not receive attention, as in the encoder itself.
            scores = scores.masked_fill(key_is_pad, torch.finfo(scores.dtype).min)
            attention_weights = F.softmax(scores, dim=-1)

            task_attn_weights = self.task_attn[layer_idx](layer_input, inputs_embeds)
            hybrid_attention_weights = attention_weights + self.gamma * task_attn_weights

            # Average over heads, then over real queries (axis -2). This gives the
            # attention each key token RECEIVES. Averaging over the key axis
            # instead would be the constant 1/seq_len, because every softmax row
            # sums to 1.
            per_query = hybrid_attention_weights.mean(dim=1)          # (batch, query, key)
            layer_importance = (per_query * query_weights).sum(dim=1)  # (batch, key)
            token_importance = token_importance + layer_importance

        return token_importance * real

    def prune_tokens(self, token_importance, attention_mask):
        """Zero the mask entries of real tokens scoring below ``mean - std``.

        The mean and the unbiased standard deviation are computed per example
        over real (mask != 0) tokens only. Padding therefore cannot shift the
        threshold, and padding entries stay 0.

        Args:
            token_importance: ``(batch, seq_len)`` importance scores.
            attention_mask: ``(batch, seq_len)`` mask with 1 for real tokens.

        Returns:
            A new mask with the same dtype as ``attention_mask``.
        """
        valid = attention_mask.bool()
        n_valid = valid.sum(dim=1, keepdim=True)

        importance = token_importance.masked_fill(~valid, 0.0)
        mean = importance.sum(dim=1, keepdim=True) / n_valid.clamp(min=1)
        sq_dev = ((importance - mean) ** 2).masked_fill(~valid, 0.0)
        # (n - 1) denominator, matching torch.std's default. With a single real
        # token the std is 0 instead of NaN, so nothing is pruned.
        std = (sq_dev.sum(dim=1, keepdim=True) / (n_valid - 1).clamp(min=1)).sqrt()

        pruned_mask = valid & (token_importance < mean - std)
        return attention_mask.masked_fill(pruned_mask, 0)

    def forward(self, input_ids, attention_mask, labels=None):
        """Score tokens, prune the attention mask, and classify with the pruned mask.

        Args:
            input_ids: ``(batch, seq_len)`` token ids.
            attention_mask: ``(batch, seq_len)`` mask with 1 for real tokens. ``None``
                is treated as all ones.
            labels: Optional labels forwarded to the wrapped model to get its loss.

        Returns:
            ``(total_loss, logits)`` when ``labels`` is given, otherwise ``logits``.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Scoring pass. hidden_states[0] is the embedding output, and
        # hidden_states[i] is the input to encoder layer i.
        scoring_outputs = self.model.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        all_hidden_states = scoring_outputs.hidden_states
        layer_inputs = all_hidden_states[:-1]
        inputs_embeds = all_hidden_states[0]

        token_importance = self.calculate_token_importance(layer_inputs, inputs_embeds, attention_mask)
        pruned_attention_mask = self.prune_tokens(token_importance, attention_mask)

        # Classification pass with the pruned attention mask.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=pruned_attention_mask,
            labels=labels,
        )

        if labels is not None:
            aux_loss = self.calculate_auxiliary_loss(token_importance)
            total_loss = outputs.loss + self.lambda_aux * aux_loss
            return total_loss, outputs.logits
        return outputs.logits

    def calculate_auxiliary_loss(self, token_importance):
        """Mean absolute deviation of per-example mean importance from the batch mean.

        NOTE(review): with importance defined as attention received, an
        example's mean importance depends largely on its length. This term may
        therefore carry little signal. Confirm it is the intended regularizer.
        """
        return torch.mean(torch.abs(
            torch.mean(token_importance, dim=1, keepdim=True) - torch.mean(token_importance)
        ))

# Example usage
if __name__ == '__main__':
    model_name = "bert-base-uncased"
    task_feature_dim = 128
    num_labels = 2

    # Initialize the model and move it to the GPU if one is available.
    ts_dtp_model = TaskSpecificDynamicTokenPruning(model_name, task_feature_dim, num_labels)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ts_dtp_model = ts_dtp_model.to(device)
    # This is a forward-only demo. Disable dropout for reproducible output;
    # no_grad below skips autograd bookkeeping because backward() is never run.
    ts_dtp_model.eval()

    # Prepare the input.
    text = "This is an example input text for the test"
    labels = torch.tensor([1], device=device)

    inputs = ts_dtp_model.tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        output = ts_dtp_model(input_ids, attention_mask, labels)

    # forward() returns (loss, logits) when labels are given, otherwise just logits.
    if isinstance(output, tuple):
        loss, logits = output
        print("Loss:", loss.item())
    else:
        logits = output
    print("Logits:", logits)