In [1]:
import torch
import torch.nn as nn
import lightning.pytorch as pl
from functools import partial
from jaxtyping import Float
from diffmask import DiffMask
from util.distributions import BinaryConcrete, RectifiedStreched
from configuration.diffmask import DiffMaskConfig
import json
from torch.utils.data import Dataset
from accelerate import Accelerator
from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer, AdamW
import re
from tqdm import tqdm
from rank_loss import RankLoss
import numpy as np
import os
import argparse
import tempfile
import copy
import json



## Data process 

In [7]:
def sperate_reason(reason_text):
  pattern = r'\-\s(?:(?:Passage\s)?\[(\d+)\](?:\s?and\s?(?:Passage\s)?\[(\d+)\])?)(.*?)(?=\n\-|\n\n|\Z)'

  # print("reason_text in sperate_reason", reason_text)
  matches = re.findall(pattern, reason_text, re.DOTALL)

  # 初始化一个字典来存储每个文档的reason
  reasons_dict = {}
  if len(matches) < 5:
    # 处理匹配结果
    for match in matches:
        # 提取文档编号和reason描述
        doc_ids = match[:-1]  # 文档编号部分
        doc_reason = match[-1].strip()  # reason描述部分

        # 对每个文档编号进行处理
        for doc_id in doc_ids:
            if doc_id:  # 确保doc_id不为空
                # 为每个文档编号存储或更新reason描述
                if doc_id in reasons_dict:
                    # 如果同一个文档编号对应多个reason，可以选择合并或选择性保留
                    reasons_dict[doc_id] += " " + doc_reason
                else:
                    reasons_dict[doc_id] = doc_reason
  else:
    for match in matches:
      # print("match", match)
      doc_id, _, doc_reason = match
      reasons_dict[doc_id] = doc_reason
          # print(f"Document ID: {doc_id}, Reason: {doc_reason}\n")

  if len(reasons_dict.keys()) <5:
    
    return None
  

  return reasons_dict


def receive_response(data, reason_name="reason"):
    
    responses = [item["re_rank_id"] for item in data]
    
    def remove_duplicate(response):
        new_response = []
        for c in response:
            if c not in new_response:
                new_response.append(c)
        return new_response

    new_data = []
    unsorted_score = []
    for item, response in zip(data, responses):
        
        reasons = item[reason_name]
        reasons_dict = sperate_reason(reasons)
        
        passages = item['unsorted_docs']
        
        unsorted_reasoned_response = [] 
        
        if reasons_dict!=None:
            for idx, passage in enumerate(passages):
                unsorted_reasoned_response.append(passage + "reason" +reasons_dict[str(idx+1)] )
        else:
            unsorted_reasoned_response = [""]*5
        
        
        response = [int(x) - 1 for x in response]
        response = remove_duplicate(response)
        
        original_rank = [tt for tt in range(len(passages))]
        response = [ss for ss in response if ss in original_rank]
        response = response + [tt for tt in original_rank if tt not in response]
        
        new_passages = [passages[ii] for ii in response]
        new_reason_passage = [unsorted_reasoned_response[ii] for ii in response]
        unsorted_score = item["scores"]

        
        new_data.append({'query': item['query'],
                         'retrieved_passages': new_passages,
                         'reasoned_passages':new_reason_passage,
                        'unsorted_score':unsorted_score})
    return new_data


In [8]:

data_path ="../data/T5_results/msmarco_test_t5.jsonl"
data = [json.loads(line) for line in open(data_path)]
data = receive_response(data)

In [9]:
# data [{"query":'', "retrieved_passages":[docs], "reasoned_passages":[],"unsorted_score":[]}]
# print(len(data[0]["reasoned_passages"])) # 5


In [19]:
import torch
import random

from torch.utils.data import Dataset, DataLoader
from os import path, makedirs
from sklearn.model_selection import train_test_split

class Dataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        query, docs, docs_reasons, unsorted_scores= self.data[index]
        print("type(docs_reasons)", type(docs_reasons))
        print("type(unsorted_scores)", type(unsorted_scores))
        # return query, docs, docs_reasons, unsorted_scores
        return query, docs, list(docs_reasons), unsorted_scores
class RerankDiffMaskData():
    def __init__(self, data_path="../data/T5_results/msmarco_test_t5.jsonl", seed=42, reason_name = "reason"):
        super()
        self.data_path = data_path
        self.data  = [json.loads(line) for line in open(data_path)]
        self.data = receive_response(self.data, reason_name)
        self.data = self.prepare_data()
        self.seed = seed

    def prepare_data(self):
        query = []
        docs = []
        docs_reasons = []
        unsorted_score = []
        for item in self.data:
            # print(item["query"])
            query.append(item["query"])
            docs.append(item["retrieved_passages"])
            docs_reasons.append(i + j for i, j in zip(item["retrieved_passages"], item["reasoned_passages"]))
            unsorted_score.append(item["unsorted_score"])
        print("type(docs_reasons) in prepare data", type(docs_reasons))
        data = list(zip(query, docs, docs_reasons, unsorted_score))
        query, docs, docs_reasons, unsorted_score = data[0]

        print("type(docs_reasons) in prepare data after", type(docs_reasons))
        return data
    
    def get_dataloaders(self, batch_size, shuffle=True, val_split=0.1):
        train_data, test_data = train_test_split(self.data, test_size=val_split, random_state=self.seed)
        print("train_data", type(train_data))
        train_dataset = Dataset(train_data)
        val_dataset = Dataset(test_data)
        return (
        DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle), 
        DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
        )



data = RerankDiffMaskData(seed=1)
train_dataloader, val_dataloader = data.get_dataloaders(batch_size=1, shuffle=True, val_split=0.00001)  

print(len(train_dataloader.dataset.data))
print(len(val_dataloader.dataset.data))
for batch in train_dataloader:
   query, docs, docs_reasons, unsorted_score = batch
   # print(query)
   # print(docs)
   # print(X[0]+ " she")
   # print(y[0].item())
   break

type(docs_reasons) in prepare data <class 'list'>
type(docs_reasons) in prepare data after <class 'generator'>
train_data <class 'list'>
171
1
type(docs_reasons) <class 'generator'>
type(unsorted_scores) <class 'list'>


## Injection of Attention
This is for changing the deberta model's injected attention and visualized the result

In [None]:
from diffmask import DiffMask



def attention_intervention_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    counterfactual_cache: ActivationCache,
    mask: torch.Tensor,
    tail_indices: torch.Tensor,
    cf_tail_indices: torch.Tensor,
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    b, p, h, d = value.shape
    tail_indices = tail_indices.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, d) 
    cf_tail_indices = cf_tail_indices.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, d)
    counterfactual_value = counterfactual_cache[hook.name]

    v_select = torch.gather(value, 1, tail_indices)
    cf_select = torch.gather(counterfactual_value, 1, cf_tail_indices)
    mask = mask.unsqueeze(1).unsqueeze(-1).repeat(1, 1, 1, d)

    intervention = (1-mask) * v_select + mask * cf_select
    return torch.scatter(value, dim=1, index=tail_indices, src=intervention)

class RankDiffMask(DiffMask):
    def __init__(self, config: DiffMaskConfig, device):
        super().__init__(config=config)
        self.config = config
        self.automatic_optimization = False
        self.model = HookedTransformer.from_pretrained(config.mask.model, device=device)
        self.model.cfg.use_attn_result = True
        self.location = torch.nn.Parameter(torch.zeros((self.model.cfg.n_layers, self.model.cfg.n_heads)), requires_grad=True)
        self.device = device

    def intervene(self, inputs, mask):
            """
            Args:
                inputs: The original inputs (query, docs, reasons).
                mask: The mask to apply.
            Returns:
                The logits of the original and intervened sequences.
            """
            query, docs, reasons = inputs
            reasons_masked = reasons * mask
    
            original_output = self.model(query, docs, reasons).logits
            intervened_output = self.model(query, docs, reasons_masked).logits
            return original_output, intervened_output


### Define the rank net loss

In [None]:
def rank_net(y_pred, y_true=None, padded_value_indicator=-100, weight_by_diff=False,
                 weight_by_diff_powed=False):
        """
        RankNet loss introduced in "Learning to Rank using Gradient Descent".
        :param y_pred: predictions from the model, shape [batch_size, slate_length]
        :param y_true: ground truth labels, shape [batch_size, slate_length]
        :param weight_by_diff: flag indicating whether to weight the score differences by ground truth differences.
        :param weight_by_diff_powed: flag indicating whether to weight the score differences by the squared ground truth differences.
        :return: loss value, a torch.Tensor
        """
        if y_true is None:
            y_true = torch.zeros_like(y_pred).to(y_pred.device)
            y_true[:, 0] = 1

        # here we generate every pair of indices from the range of document length in the batch
        document_pairs_candidates = list(product(range(y_true.shape[1]), repeat=2))

        pairs_true = y_true[:, document_pairs_candidates]
        selected_pred = y_pred[:, document_pairs_candidates]

        # here we calculate the relative true relevance of every candidate pair
        true_diffs = pairs_true[:, :, 0] - pairs_true[:, :, 1]
        pred_diffs = selected_pred[:, :, 0] - selected_pred[:, :, 1]

        # here we filter just the pairs that are 'positive' and did not involve a padded instance
        # we can do that since in the candidate pairs we had symetric pairs so we can stick with
        # positive ones for a simpler loss function formulation
        the_mask = (true_diffs > 0) & (~torch.isinf(true_diffs))

        pred_diffs = pred_diffs[the_mask]

        weight = None
        if weight_by_diff:
            abs_diff = torch.abs(true_diffs)
            weight = abs_diff[the_mask]
        elif weight_by_diff_powed:
            true_pow_diffs = torch.pow(pairs_true[:, :, 0], 2) - torch.pow(pairs_true[:, :, 1], 2)
            abs_diff = torch.abs(true_pow_diffs)
            weight = abs_diff[the_mask]

        # here we 'binarize' true relevancy diffs since for a pairwise loss we just need to know
        # whether one document is better than the other and not about the actual difference in
        # their relevancy levels
        true_diffs = (true_diffs > 0).type(torch.float32)
        true_diffs = true_diffs[the_mask]

        return BCEWithLogitsLoss(weight=weight)(pred_diffs, true_diffs)

### Example of the diffmask

In [None]:
class AttentionDiffMask(DiffMask):
    def __init__(self, config: DiffMaskConfig, model, device):
        super().__init__(config=config)
        self.config = config
        self.automatic_optimization = False
        self.model = model.to(device)
        self.location = torch.nn.Parameter(torch.zeros((self.model.config.num_hidden_layers, self.model.config.num_attention_heads)), requires_grad=True)

    def intervene(self, inputs, mask, ):
        """
        Args:
            inputs: The original inputs (query, docs, reasons).
            mask: The mask to apply.
        Returns:
            The logits of the original and intervened sequences.
        """
        query, docs, reasons = inputs # inputs can be split tokens
        with_reason = []
        without_reason = []
        for doc, reason in zip(docs, reasons):
            with_reason.append()
        
        query_token =  self.model.to_tokens(query, prepend_bos=False)
        
        reasons_masked = reasons * mask

        original_output = self.model(query, docs, reasons).logits
        intervened_output = self.model(query, docs, reasons_masked).logits
        return original_output, intervened_output

    def calculate_ranknet_loss(self, scores, labels):
        """
        Calculates the RankNet loss.
        Args:
            scores: The predicted scores for each query-document pair.
            labels: The true relevance labels for each query-document pair.
        Returns:
            The RankNet loss.
        """
        loss = 0.0
        batch_size = scores.size(0)
        for i in range(batch_size):
            for j in range(batch_size):
                if labels[i] < labels[j]:
                    loss += torch.log(1 + torch.exp(scores[i] - scores[j]))
        return loss

    def training_step(self, batch, batch_idx=None, optimizer_idx=None):
        query, docs, reasons = batch
        dist = RectifiedStreched(
            BinaryConcrete(torch.full_like(self.location, 0.2), self.location), l=-0.2, r=1.0,
        )
        mask = dist.rsample(torch.Size([len(query), reasons.size(1), reasons.size(2)]))
        expected_L0 = dist.expected_L0().sum()

        original_output, intervened_output = self.intervene((query, docs, reasons), mask=mask)

        o_ranknet_loss = self.calculate_ranknet_loss(original_output)
        i_ranknet_loss = self.calculate_ranknet_loss(intervened_output)
        out = model(**batch)
        logits = out.logits
        logits = logits.view(-1, neg_num)

        y_true = torch.tensor([[1 / (i + 1) for i in range(logits.size(1))]] * logits.size(0)).cuda()
        loss = loss_function(logits, y_true)

        
        kl_loss = torch.distributions.kl_divergence(
            torch.distributions.Bernoulli(logits=original_output),
            torch.distributions.Bernoulli(logits=intervened_output),
        ).mean()

        total_loss = ranknet_loss + kl_loss + self.lambda1 * (expected_L0 - self.config.mask.attn_heads)

        self.manual_backward(total_loss)
        o1, o2 = self.optimizers()

        self.optimizer_step(o1, 0)
        self.optimizer_step(o2, 1)

    def on_train_epoch_end(self):
        print(f"lambda1: {self.lambda1}")
        print(f"location: {self.location}")

