In [None]:
%pip install -U bitsandbytes

In [None]:
%pip install captum

In [None]:
import warnings
import pandas as pd

import bitsandbytes as bnb
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig

from collections import defaultdict
from datasets import Dataset
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader
import torch
import string

from captum.attr import (
    FeatureAblation, 
    ShapleyValues,
    LayerIntegratedGradients, 
    LLMAttribution, 
    LLMGradientAttribution, 
    TextTokenInput, 
    TextTemplateInput,
    ProductBaselines,
)

# Ignore warnings due to transformers library
warnings.filterwarnings("ignore", ".*past_key_values.*")
warnings.filterwarnings("ignore", ".*Skipping this token.*")

In [None]:

import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification
from transformers.models.gemma3.configuration_gemma3 import (
    Gemma3Config, Gemma3TextConfig,
)
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3TextModel, Gemma3PreTrainedModel,
)
from transformers.modeling_outputs import SequenceClassifierOutputWithPast


# ───────────────────── helpers ─────────────────────────────────────────
def _txt(cfg):                              # text sub-config
    return getattr(cfg, "text_config", cfg)


def _pool_last(token_logits, input_ids, pad_id):
    if input_ids is None or pad_id is None:
        return token_logits[:, -1]
    ends = (~input_ids.eq(pad_id)).cumsum(-1).argmax(-1)
    return token_logits[torch.arange(token_logits.size(0)), ends]


def _compute_loss(config, logits, labels, num_labels):
    if labels is None:
        return None
    if config.problem_type is None:
        if num_labels == 1:
            config.problem_type = "regression"
        elif labels.dtype in (torch.long, torch.int):
            config.problem_type = "single_label_classification"
        else:
            config.problem_type = "multi_label_classification"

    if config.problem_type == "regression":
        return nn.MSELoss()(logits.squeeze(), labels.squeeze())
    if config.problem_type == "single_label_classification":
        return nn.CrossEntropyLoss()(logits, labels)
    return nn.BCEWithLogitsLoss()(logits, labels)

# ───────────────────── text-only variant (1 B) ────────────────────────
class Gemma3TextForSequenceClassification(Gemma3PreTrainedModel):
    """
    Sequence-classification head on top of the text-only Gemma 3 backbone.

    The backbone is stored as ``self.model`` so checkpoint keys keep their
    ``model.*`` prefix and every pretrained weight loads. A bias-free linear
    ``score`` layer maps each hidden state to ``num_labels`` logits, and the
    logits of the last non-pad token are returned (via ``_pool_last``).
    """
    config_class = Gemma3TextConfig
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # The attribute name "model" is load-bearing: it fixes the weight prefix.
        self.model = Gemma3TextModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        self.post_init()

    def forward(
        self,
        input_ids=None, attention_mask=None, position_ids=None,
        inputs_embeds=None, past_key_values=None, labels=None,
        use_cache=None, output_attentions=None,
        output_hidden_states=None, return_dict=None,
    ):
        """Classify each sequence from the logits at its last non-pad token.

        Returns a ``SequenceClassifierOutputWithPast``; when ``return_dict``
        resolves to False, returns a tuple ``(loss?, logits, *backbone_tail)``
        instead. ``loss`` is computed only when ``labels`` is given.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,  # always work with a ModelOutput internally
        )
        per_token_logits = self.score(backbone_out.last_hidden_state)
        logits = _pool_last(per_token_logits, input_ids, self.config.pad_token_id)
        loss = _compute_loss(self.config, logits, labels, self.num_labels)

        if return_dict:
            return SequenceClassifierOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=backbone_out.past_key_values,
                hidden_states=(backbone_out.hidden_states if output_hidden_states else None),
                attentions=(backbone_out.attentions if output_attentions else None),
            )

        legacy = (logits, *backbone_out[1:])
        return legacy if loss is None else (loss, *legacy)


# ───────────────────── register with HF factory ───────────────────────
AutoModelForSequenceClassification.register(
    Gemma3TextConfig, Gemma3TextForSequenceClassification,
    # If the installed transformers already maps Gemma3TextConfig to a
    # built-in classification head, register() raises ValueError unless
    # exist_ok=True; with it, the Auto factory builds the class defined above.
    exist_ok=True,
)

In [None]:
def load_model(model_name, bnb_config):
    """Load the quantized 3-way sequence classifier and its tokenizer.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint.
        bnb_config: ``BitsAndBytesConfig`` applied while loading the weights.

    Returns:
        ``(model, tokenizer)``. The model is placed with ``device_map="auto"``
        over all visible GPUs, capped at 10000MB per GPU.
    """
    n_gpus = torch.cuda.device_count()
    max_memory = "10000MB"

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        num_labels=3,
        device_map="auto",
        max_memory={i: max_memory for i in range(n_gpus)},
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=True)

    # Pad with EOS. NOTE(review): kept from the original setup; presumably
    # this matches how the checkpoint was fine-tuned — confirm.
    tokenizer.pad_token = tokenizer.eos_token
    # The classification head in this notebook locates the last real token via
    # config.pad_token_id, so it must equal the id the tokenizer pads with;
    # otherwise right-padded batches would be pooled at a pad position.
    model.config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer

def create_bnb_config():
    """Build the 4-bit NF4 bitsandbytes quantization config.

    Double quantization is enabled and bfloat16 is used as the compute dtype.
    """
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

In [None]:
import os

from huggingface_hub import login

# Never hard-code the access token in the notebook. It is read from the
# HF_TOKEN environment variable (e.g. a platform secret exported to the
# environment). When unset, token=None makes `login` fall back to its
# interactive prompt, and new_session=False skips that prompt if a token is
# already saved on the machine. This notebook only downloads a model, so no
# write permission is required.
login(
    token=os.environ.get("HF_TOKEN") or None,
    new_session=False,
    add_to_git_credential=True,
)

In [None]:
# 4-bit quantized load of the fine-tuned classifier from the Hub.
bnb_config = create_bnb_config()
model_name = "ConditionalNLI/exp8_gemma3"
model, tokenizer = load_model(model_name, bnb_config)

In [None]:
# Part 4B, Type 4 split.
test_df4 = pd.read_excel(
    '/kaggle/input/confer-extension/Part4B_Dataset/4B_Type4.xlsx'
)
test_df4

In [None]:
# Part 4B, Type 5A split.
test_df5a = pd.read_excel(
    '/kaggle/input/confer-extension/Part4B_Dataset/4B_Type5A.xlsx'
)
test_df5a

In [None]:
# Part 4B, Type 5P split.
test_df5p = pd.read_excel(
    '/kaggle/input/confer-extension/Part4B_Dataset/4B_Type5P.xlsx'
)
test_df5p

In [None]:
# First four Type 4 rows with a fresh 0..3 index (an optional small test_df).
sample_test_df = test_df4.head(4).reset_index(drop=True)
sample_test_df

In [None]:
# Split to analyse (sample_test_df is a small alternative for quick runs).
# .copy() keeps the columns added by later cells out of the source frame, so
# switching splits and re-running does not leak results between them.
test_df = test_df5p.copy()
# True  -> rank the possessive pronoun (his/her/their) as the trigger word;
# False -> rank 'again'.
# NOTE(review): confirm this matches the chosen split (the '5P' file name may
# indicate possessive triggers).
possessive_trigger = False

In [None]:
# Wrap the selected split as a Hugging Face Dataset.
test_ds = Dataset.from_pandas(df=test_df)
test_ds

In [None]:
# Keep only the 'index' column and the two text fields.
columns_to_keep = ['index', 'premise', 'hypothesis']
columns_to_drop = [name for name in test_ds.column_names if name not in columns_to_keep]
test_ds = test_ds.remove_columns(columns_to_drop)
test_ds

In [None]:
def tokenize_function(example):
    """Tokenize a batch of premise/hypothesis pairs, truncated and padded to 256 tokens."""
    return tokenizer(
        example["premise"],
        example["hypothesis"],
        truncation=True,
        max_length=256,
        padding='max_length',
    )

In [None]:
# Tokenize in batches and expose the columns as torch tensors.
tokenized_test_ds = test_ds.map(tokenize_function, batched=True)
tokenized_test_ds.set_format(type="torch")
tokenized_test_ds

In [None]:
# The raw text columns are not needed once tokenized.
tokenized_test_ds = tokenized_test_ds.remove_columns(column_names=["premise", "hypothesis"])
tokenized_test_ds

In [None]:
# Predicting Labels on Dataset
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length", max_length=256, return_tensors="pt")
dataloader = DataLoader(tokenized_test_ds, batch_size=32, collate_fn=data_collator)

# Only these keys are forwarded to the classifier. Any other column still in
# the dataset (e.g. the 'index' column, if present) would otherwise reach
# forward() as an unexpected keyword argument.
model_input_keys = ("input_ids", "attention_mask")

model.eval()
predictions = []

with torch.no_grad():
    for batch in dataloader:
        # Send inputs to the model's own device instead of a hard-coded "cuda".
        batch = {k: v.to(model.device) for k, v in batch.items() if k in model_input_keys}
        outputs = model(**batch)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=1)
        predictions.extend(preds.cpu().tolist())

test_df['predicted_label'] = predictions

In [None]:
# One-hot gold labels in class order (C, N, E).
label_to_onehot = {'C': [1, 0, 0], 'N': [0, 1, 0], 'E': [0, 0, 1]}
test_df['expected_logits'] = test_df['gold_label'].map(label_to_onehot)
test_df

## Integrated Gradients

In [None]:
import torch.nn as nn

class LogitWrapper(nn.Module):
    """Adapt the sequence classifier to Captum's LLM attribution interface.

    Captum's ``LLMGradientAttribution`` treats ``output.logits[:, -1]`` as the
    next-token logits of a causal LM, i.e. it expects logits shaped
    ``(batch, seq_len, vocab)``. The classifier returns ``(batch, num_labels)``,
    so a length-1 sequence axis is inserted at dim 1. ``logits[:, -1]`` then
    yields the class logits of every row in the batch.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask=None):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # (batch, num_labels) -> (batch, 1, num_labels).
        # Inserting the axis at dim 0 instead, giving (1, batch, num_labels),
        # makes ``logits[:, -1]`` select only the LAST row of the batch.
        # Integrated Gradients evaluates all interpolation steps as one
        # expanded batch, so that would drop every step but one from the
        # integral.
        output.logits = output.logits.unsqueeze(1)
        return output

In [None]:
# Special tokens excluded from the attributed input features.
skip_tokens = ['<bos>', '<eos>']
wrapped_model = LogitWrapper(model)
# Integrated Gradients taken at the output of the token-embedding layer.
lig = LayerIntegratedGradients(wrapped_model, model.model.embed_tokens)
llm_attr = LLMGradientAttribution(lig, tokenizer)

In [None]:
def compute_attrs(row):
    """Integrated-Gradients attributions for one premise/hypothesis row.

    The attribution target is the gold class: the hot index of
    ``row['expected_logits']``, in class order (C, N, E).

    Returns:
        dict with ``'tokens'`` (the attributed input tokens; tokens in
        ``skip_tokens`` are excluded) and ``'attrs'`` (the per-token
        ``seq_attr`` scores).
    """
    premise = row['premise']
    hypothesis = row['hypothesis']

    # Captum's LLMGradientAttribution reads a tensor target as a sequence of
    # target *token ids* and attributes each one in turn. Passing the one-hot
    # vector itself would attribute classes [1, 0, 0] (N, C, C) for a 'C' row
    # and would never attribute class E, so pass the single gold index instead.
    gold_class = int(torch.as_tensor(row['expected_logits']).argmax())
    target = torch.tensor([gold_class])

    # NOTE(review): this joins premise and hypothesis with EOS, while the
    # prediction cell tokenizes them as a tokenizer text pair; confirm both
    # match the input format the classifier was fine-tuned on.
    inp = TextTokenInput(
        premise + tokenizer.eos_token + hypothesis,
        tokenizer,
        skip_tokens=skip_tokens,
    )
    # skip_tokens is applied on the input side above. attribute()'s own
    # skip_tokens concerns the target's token encoding, which does not apply
    # to a class-index target.
    attr_res = llm_attr.attribute(inp, target=target)
    return {
        'tokens': attr_res.input_tokens,
        'attrs': attr_res.seq_attr,
    }

In [None]:
def reformat_output(result):
    """Merge sub-word attributions into word-level attributions.

    A token starting with ``'▁'`` begins a new word (the marker is stripped).
    A following token continues that word when it does not start with
    ``'▁'``, is not in ``string.punctuation`` and is not ``'[SEP]'``.
    ``'[SEP]'``/``'[CLS]'`` are never merged. A word's score is the mean of
    its tokens' scores.

    Args:
        result: dict with ``'tokens'`` (token strings) and ``'attrs'``
            (per-token scores, e.g. a 1-D tensor), as built by ``compute_attrs``.

    Returns:
        dict mapping ``(word, i)`` -> mean score, where ``i`` is the position
        of the word's last token. Insertion order follows the text.
    """
    # zip() truncates to the shorter sequence, as the original pairing did.
    pairs = [(tok, float(score)) for tok, score in zip(result['tokens'], result['attrs'])]
    n = len(pairs)
    word_gradients = {}

    i = 0
    while i < n:
        token, total = pairs[i]
        word = token[1:] if token.startswith('▁') else token
        if word in ['[SEP]', '[CLS]']:
            word_gradients[(word, i)] = total  # do not merge with next token
            i += 1
            continue
        count = 1
        # Absorb continuation pieces. Advancing the shared index here matters:
        # the previous for-loop reassigned its loop variable, so each absorbed
        # piece was re-emitted later as a spurious stand-alone "word"
        # (e.g. 'her' out of '▁ot' + 'her'), corrupting rankings and matches.
        while i + 1 < n:
            nxt = pairs[i + 1][0]
            if nxt in string.punctuation or nxt.startswith('▁') or nxt == '[SEP]':
                break
            i += 1
            word += nxt
            total += pairs[i][1]
            count += 1
        word_gradients[(word, i)] = total / count  # average over the word's pieces
        i += 1

    return word_gradients

In [None]:
def get_average_gradient(result):
    """Mean per-token attribution of a ``compute_attrs`` result, as a float."""
    total = sum(result['attrs'])
    n_tokens = len(result['tokens'])
    return float(total / n_tokens)

In [None]:
def compute_trigger_gradients(row):
    """Mean word-level attribution over the words of ``row['key_phrase']``.

    Each whitespace-separated key-phrase word is looked up in
    ``row['word_gradients']`` (keys ``(word, index)``), first as written and
    then lower-cased. For a word that occurs several times, the first key in
    insertion order is used. Words that cannot be found are reported and
    skipped.

    Returns:
        The mean attribution of the found words, or 0 if none were found.
    """
    phrase_words = row['key_phrase'].split()
    word_gradients = row['word_gradients']

    # earliest index recorded for each word
    first_index = {}
    for word, index in word_gradients:
        first_index.setdefault(word, index)

    matched = []
    for word in phrase_words:
        for candidate in (word, word.lower()):
            if candidate in first_index:
                matched.append(word_gradients[(candidate, first_index[candidate])])
                break
        else:
            print(f"{row.name}: {word} not found!")

    if not matched:
        return 0
    return sum(matched) / len(matched)

In [None]:
def compute_trigger_word_ranking(word_gradients,possessive_trigger):
    """Rank the key trigger word among all words by attribution score.

    Words are ranked by score, highest first (rank 1), after dropping
    '[CLS]'/'[SEP]' and punctuation entries. The trigger word is 'again' or,
    when ``possessive_trigger`` is true, the first of 'his', 'her', 'their'
    present in ``word_gradients``.

    Args:
        word_gradients: dict ``(word, index) -> score`` from ``reformat_output``.
        possessive_trigger: choose a possessive pronoun instead of 'again'.

    Returns:
        ``pd.Series([rank, rank / n_ranked_words])`` as scalars. If the trigger
        word occurs more than once, its occurrence with the smallest index
        (the earliest in the text) is used. Both values are None, and a
        message is printed, when no trigger word is found.
    """
    ordered = sorted(word_gradients.items(), key=lambda item: item[1], reverse=True)
    ranked = [key for key, _ in ordered
              if key[0] not in ['[CLS]', '[SEP]'] and key[0] not in string.punctuation]

    present = {key[0] for key in word_gradients}
    if possessive_trigger:
        # Stays None when no possessive pronoun is present (previously this
        # left the name unbound and crashed inside the error handler).
        key_trigger_word = next((w for w in ('his', 'her', 'their') if w in present), None)
    else:
        key_trigger_word = 'again'

    # (index, rank) of every occurrence; min() picks the earliest occurrence,
    # so a scalar is returned instead of a Series of all matching rows.
    hits = [(index, rank) for rank, (word, index) in enumerate(ranked, start=1)
            if word == key_trigger_word]
    if not hits:
        print(f"Trigger word not found! - {key_trigger_word}")
        return pd.Series([None, None])

    _, rank = min(hits)
    return pd.Series([rank, rank / len(ranked)])

In [None]:
test_df['attr_result'] = test_df.apply(compute_attrs, axis=1)
test_df.to_csv('part4b_type5p_gemma_raw.csv')
test_df['word_gradients'] = test_df['attr_result'].apply(reformat_output)
test_df['average_gradient'] = test_df['attr_result'].apply(get_average_gradient)
test_df['trigger_gradient'] = test_df.apply(compute_trigger_gradients,axis=1)
# replace() (unlike map()) leaves already-converted 'C'/'N'/'E' labels intact,
# so re-running this cell does not turn the predictions into NaN.
test_df['predicted_label'] = test_df['predicted_label'].replace({0:'C',1:'N',2:'E'})
test_df[['trigger_word_ranking','trigger_word_percentile_ranking']] = test_df['word_gradients'].apply(compute_trigger_word_ranking,args=(possessive_trigger,))
test_df

In [None]:
# Word-level attributions of the first example.
test_df['word_gradients'].iat[0]

In [None]:
# Persist the per-example predictions and attribution summaries.
test_df.to_csv(path_or_buf='part4b_type5p_gemma_results.csv')