In [1]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Tuple
from scipy.spatial.distance import cosine

class Llama7BWithCCSHAP:
    def __init__(self, model_name: str = "meta-llama/Meta-Llama-3-8B"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

       
        self.phish_id  = self.tokenizer(" PHISHING",  add_special_tokens=False).input_ids[0]
        self.legit_id  = self.tokenizer(" LEGITIMATE", add_special_tokens=False).input_ids[0]

        # Load model once
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="cuda:0"   
        )
        self.model.eval()
        print(f"Model loaded on {next(self.model.parameters()).device}")

    def preprocess_text(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        enc = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=1024
        )
        return enc.input_ids.to(self.device), enc.attention_mask.to(self.device)

    def get_prediction_and_logits(self, prompt: str) -> Tuple[str, float, torch.Tensor]:
    
        input_ids, attention_mask = self.preprocess_text(prompt)

        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask=attention_mask)
            last_logits = outputs.logits[:, -1, :]       # (1, vocab_size)
            probs       = torch.softmax(last_logits, dim=-1)[0]  # (vocab_size,)

        p_phish = probs[self.phish_id].item()
        p_legit = probs[self.legit_id].item()
        if p_phish > p_legit:
            return "PHISHING", p_phish, last_logits
        else:
            return "LEGITIMATE", p_legit, last_logits

    def get_explanation(self, text: str, predicted_class: str) -> str:
        prompt = (
            f"Provide a detailed reasoning for classifying this email as {predicted_class}. "
            "List key indicators for that classification.\n\n"
            f"Email:\n{text}\n\nExplanation:"
        )
        input_ids, attention_mask = self.preprocess_text(prompt)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=350,
                do_sample=True,
                temperature=0.1,
                top_p=0.9
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def compute_shap_values(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        target_token_id: int,
        baseline_logits: torch.Tensor,
        n_samples: int = 500
    ) -> np.ndarray:
        
        n_tokens = input_ids.shape[1]
        shap_vals = np.zeros(n_tokens, dtype=np.float32)

       
        baseline_prob = torch.softmax(baseline_logits, dim=-1)[0, target_token_id].item()

        for _ in range(n_samples):
            coalition_size = np.random.randint(1, n_tokens + 1)
            mask_indices   = np.random.choice(n_tokens, size=coalition_size, replace=False)

         
            mask = torch.zeros_like(input_ids, dtype=torch.bool)
            mask[:, mask_indices] = True

            masked_input = torch.where(
                mask,
                input_ids,
                torch.full_like(input_ids, self.tokenizer.pad_token_id)
            )
           
            masked_attention = torch.where(
                mask,
                attention_mask,
                torch.zeros_like(attention_mask)
            )

            with torch.no_grad():
                outputs_coal = self.model(masked_input, attention_mask=masked_attention)
                logits_coal  = outputs_coal.logits[:, -1, :]  # (1, vocab_size)
                coalition_prob = torch.softmax(logits_coal, dim=-1)[0, target_token_id].item()

            marginal = coalition_prob - baseline_prob
            for idx in mask_indices:
                shap_vals[idx] += marginal / n_samples

        # normalization
        norm = np.linalg.norm(shap_vals, ord=1) + 1e-10
        return shap_vals / norm

    def compute_cc_shap(self, text: str) -> Tuple[float, str, str, List[str], List[str]]:
       
     
        prediction_prompt = f"Classify this email type:\n\n{text}\n\nClassification:"
        input_ids_clf, attn_clf = self.preprocess_text(prediction_prompt)

        label, prob, baseline_logits = self.get_prediction_and_logits(prediction_prompt)
        token_id = self.phish_id if label == "PHISHING" else self.legit_id

       
        pred_shap = self.compute_shap_values(
            input_ids=input_ids_clf,
            attention_mask=attn_clf,
            target_token_id=token_id,
            baseline_logits=baseline_logits,
            n_samples=min(2**input_ids_clf.shape[1] + 12, 500) #sampling 
        )

       
        pred_tokens = self.tokenizer.convert_ids_to_tokens(
            input_ids_clf[0].tolist(), skip_special_tokens=True
        )
        #print(len(pred_tokens))

        pred_token_importance = [
            (tok, pred_shap[i])
            for i, tok in enumerate(pred_tokens)
            if len(tok.strip()) > 0 and tok not in {",", ".", ":", ";"}
        ]
        top_pred = sorted(pred_token_importance, key=lambda x: abs(x[1]), reverse=True)[:10]
        top_pred_tokens = [f"{tok} ({score:.4f})" for tok, score in top_pred]

       
        explanation = self.get_explanation(text, label)

      
        explanation_prompt = (
            f"Provide a detailed reasoning for classifying this email as {label}:\n\n"
            f"{text}\n\nExplanation:"
        )
        input_ids_exp, attn_exp = self.preprocess_text(explanation_prompt)
       
        with torch.no_grad():
            out_exp = self.model(input_ids_exp, attention_mask=attn_exp)
            logits_exp = out_exp.logits[:, -1, :]  # (1, vocab_size)
      
        exp_shap = self.compute_shap_values(
            input_ids=input_ids_exp,
            attention_mask=attn_exp,
            target_token_id=token_id,
            baseline_logits=logits_exp,
            n_samples=min(2**input_ids_exp.shape[1] + 12, 500)
        )

        exp_tokens = self.tokenizer.convert_ids_to_tokens(
            input_ids_exp[0].tolist(), skip_special_tokens=True
        )
        #print(len(exp_tokkens))
        exp_token_importance = [
            (tok, exp_shap[i])
            for i, tok in enumerate(exp_tokens)
            if len(tok.strip()) > 0 and tok not in {",", ".", ":", ";"}
        ]
        top_exp = sorted(exp_token_importance, key=lambda x: abs(x[1]), reverse=True)[:10]
        top_exp_tokens = [f"{tok} ({score:.4f})" for tok, score in top_exp]

 
        min_len = min(len(pred_shap), len(exp_shap))
        pred_norm = pred_shap[:min_len] / (np.linalg.norm(pred_shap[:min_len], ord=1) + 1e-10)
        exp_norm  = exp_shap[:min_len]  / (np.linalg.norm(exp_shap[:min_len], ord=1) + 1e-10)
        similarity = 1.0 - cosine(pred_norm, exp_norm) #cosine distancee

        return similarity, label, explanation, top_pred_tokens, top_exp_tokens



if __name__ == "__main__":
    analyzer = Llama7BWithCCSHAP()
    email_text = """sender:perfmgmt@enron.com,
    subject: Enron Year End 2000 Performance Management Process,
    body : Enron's Year-End 2000 Performance Management Process opens on: WEDNESDAY, OCTOBER 25th. 
    During this process, you will be able to suggest reviewers who can provide feedback on your performance. 
    In addition, you may be requested to provide feedback on fellow employees. 
    To participate in the feedback process, access the Performance Management System (PEP) at http://pep.corp.enron.com .
    Your UserID and Password are provided below. 
    The system will be open for feedback from October 25th - November 17th, and Help Desk representatives will be available to answer questions throughout the process. 
    You may contact the Help Desk at: Houston: 1-713-853-4777, Option 4 London: 44-207-783-4040,
    Option 4 E-mail: perfmgmt@enron.com During the year-end PRC process, employee profiles will be made available at meetings. 
    If you haven't already done so, we encourage you to update your personal information and current responsibilities before the meeting process begins on November 20th. 
    Please access eHRonline at http://ehronline.enron.com (London users please go to http://home.enron.co.uk ,
    click on Quick Links, and choose HR  Online). Your User ID & Password are: User ID: 90012897 Password: WELCOME"
"""
    sim, pred, expl, top_p, top_e = analyzer.compute_cc_shap(email_text)
    print(f"Prediction: {pred}")
    print(f"CC-SHAP : {sim:.4f}\n")
    print("Explanation:\n", expl, "\n")
    print("Top 10 tokens for prediction ")
    for i, tok in enumerate(top_p, 1):
        print(f"  {i}. {tok}")
    print("\nTop 10 tokens for explanation")
    for i, tok in enumerate(top_e, 1):
        print(f"  {i}. {tok}")


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded on cuda:0


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Prediction: PHISHING
CC-SHAP : 0.9800

Explanation:
 Provide a detailed reasoning for classifying this email as PHISHING. List key indicators for that classification.

Email:
sender:perfmgmt@enron.com,
    subject: Enron Year End 2000 Performance Management Process,
    body : Enron's Year-End 2000 Performance Management Process opens on: WEDNESDAY, OCTOBER 25th. 
    During this process, you will be able to suggest reviewers who can provide feedback on your performance. 
    In addition, you may be requested to provide feedback on fellow employees. 
    To participate in the feedback process, access the Performance Management System (PEP) at http://pep.corp.enron.com.
    Your UserID and Password are provided below. 
    The system will be open for feedback from October 25th - November 17th, and Help Desk representatives will be available to answer questions throughout the process. 
    You may contact the Help Desk at: Houston: 1-713-853-4777, Option 4 London: 44-207-783-4040,
    Op