In [None]:
!pip install transformers accelerate bitsandbytes tuned-lens -q

## Setup Config

In [None]:
# Hugging Face checkpoint to analyse. To switch models, replace the string
# below with one of these alternatives (the next cell recognises each of them
# by a substring of the id):
#   "meta-llama/Meta-Llama-3-8B"
#   "universitytehran/PersianMind-v1.0"
#   "PartAI/Dorna-Llama3-8B-Instruct"
#   "mistralai/Mistral-7B-v0.1"
#   "Qwen/Qwen2-7B-Instruct"
#   "bigscience/bloom"
# NOTE(review): LLMHelper wraps `model.model.layers`; Bloom presumably exposes
# its blocks under a different attribute, so it likely won't work — confirm.
model_card = "meta-llama/Llama-2-7b-hf"

## Authorization token


In [None]:
import getpass
import os

# Short display name (used in plot titles), selected by a substring of the
# model card. The first matching entry wins, as in an if/elif chain.
_MODEL_NAMES = (
    ("Dorna", "Dorna"),
    ("PersianMind", "PersianMind"),
    ("Llama-2", "Llama-2"),
    ("Llama-3", "Llama-3"),
    ("Mistral", "Mistral"),
    ("Qwen", "Qwen"),
    ("bloom", "Bloom"),
)
model_name = next((name for key, name in _MODEL_NAMES if key in model_card), None)
if model_name is None:
    raise ValueError(f'Unsupported model card: {model_card}')

# Hugging Face access token. Never hard-code it in the notebook (tokens that
# were previously committed here must be considered compromised and revoked).
# Set HF_TOKEN in the environment, or paste it at the prompt; an empty answer
# means "no token" (enough for public, non-gated checkpoints).
hf_token = os.environ.get("HF_TOKEN") or getpass.getpass("Hugging Face token: ") or None

# LLM wrapper


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig

class AttnWrapper(torch.nn.Module):
    """Wrap a decoder layer's self-attention module to record and steer its output.

    Attributes:
        attn: the wrapped attention module, called with the original args/kwargs.
        activations: first element of the most recent attention output (after
            ``add_tensor`` has been applied), or None before any forward pass.
        add_tensor: optional tensor added to the first element of the output.
        act_as_identity: when True, every strictly-earlier key position is
            masked so each query position can only attend to itself.
    """

    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.activations = None
        self.add_tensor = None
        self.act_as_identity = False

    def forward(self, *args, **kwargs):
        if self.act_as_identity:
            # Assumes an additive 4D causal mask (batch, heads, q_len, k_len), as
            # HF Llama builds it — TODO confirm for other architectures.
            # Entry (0, 0, 0, 1) is a future position, i.e. the "blocked" value;
            # adding it below the diagonal blocks all earlier positions as well.
            mask = kwargs['attention_mask']
            below_diagonal = torch.tril(
                torch.ones(mask.shape, dtype=mask.dtype, device=mask.device),
                diagonal=-1,
            )
            # Out-of-place on purpose: the mask tensor may be shared with other
            # layers/calls, and `+=` would leak this intervention into them.
            kwargs['attention_mask'] = mask + mask[0, 0, 0, 1] * below_diagonal
        output = self.attn(*args, **kwargs)
        if self.add_tensor is not None:
            output = (output[0] + self.add_tensor,) + output[1:]
        self.activations = output[0]
        return output

    def reset(self):
        """Clear recorded activations and disable all interventions."""
        self.activations = None
        self.add_tensor = None
        self.act_as_identity = False

class BlockOutputWrapper(torch.nn.Module):
    """Wrap a decoder block and record logit-lens projections of its internals.

    On construction the block's ``self_attn`` is replaced by an AttnWrapper.
    After each forward pass these attributes hold ``unembed_matrix(norm(x))``
    for the corresponding intermediate ``x``:

        attn_mech_output_unembedded  -- attention output
        intermediate_res_unembedded  -- block input + attention output
        mlp_output_unembedded        -- MLP applied to the intermediate residual
        block_output_unembedded      -- the block's own first output

    Args:
        block: decoder layer exposing ``self_attn``, ``mlp`` and
            ``post_attention_layernorm``.
        unembed_matrix: projection to vocabulary logits (e.g. ``lm_head``).
        norm: final normalization applied before unembedding.
    """

    def __init__(self, block, unembed_matrix, norm):
        super().__init__()
        self.block = block
        self.unembed_matrix = unembed_matrix
        self.norm = norm

        self.block.self_attn = AttnWrapper(self.block.self_attn)
        self.post_attention_layernorm = self.block.post_attention_layernorm

        self.attn_mech_output_unembedded = None
        self.intermediate_res_unembedded = None
        self.mlp_output_unembedded = None
        self.block_output_unembedded = None
        self.add_to_last_tensor = None
        self.output = None

    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        if self.add_to_last_tensor is not None:
            print('performing intervention: add_to_last_tensor')
            output[0][:, -1, :] += self.add_to_last_tensor
        self.output = output[0]
        self.block_output_unembedded = self.unembed_matrix(self.norm(output[0]))

        # Residual stream entering this block (positional in HF Llama; fall
        # back to the keyword form just in case).
        hidden_states = args[0] if args else kwargs['hidden_states']
        attn_output = self.block.self_attn.activations
        self.attn_mech_output_unembedded = self.unembed_matrix(self.norm(attn_output))
        # Out-of-place add: `attn_output` IS the tensor stored in
        # AttnWrapper.activations, and `+=` would corrupt get_attn_activations().
        intermediate_res = attn_output + hidden_states
        self.intermediate_res_unembedded = self.unembed_matrix(self.norm(intermediate_res))
        # Recomputes the MLP on the intermediate residual purely for inspection.
        mlp_output = self.block.mlp(self.post_attention_layernorm(intermediate_res))
        self.mlp_output_unembedded = self.unembed_matrix(self.norm(mlp_output))
        return output

    def block_add_to_last_tensor(self, tensor):
        """Add ``tensor`` to the last position of the block output on future passes."""
        self.add_to_last_tensor = tensor

    def attn_add_tensor(self, tensor):
        """Add ``tensor`` to the attention output on future passes."""
        self.block.self_attn.add_tensor = tensor

    def reset(self):
        """Disable all interventions on this block and its attention module."""
        self.block.self_attn.reset()
        self.add_to_last_tensor = None

    def get_attn_activations(self):
        """Return the attention output recorded during the last forward pass."""
        return self.block.self_attn.activations

class LLMHelper:
    """Load a causal LM and wrap every decoder layer for logit-lens inspection.

    After construction each ``self.model.model.layers[i]`` is a
    BlockOutputWrapper whose attributes hold unembedded intermediate outputs
    from the most recent forward pass.

    Args:
        model_name: Hugging Face model id passed to ``from_pretrained``.
        hf_token: access token for gated/private checkpoints (may be None).
        load_in_8bit: quantize weights to 8-bit via bitsandbytes; when False
            no quantization config is passed at all.
    """

    def __init__(self, model_name, hf_token, load_in_8bit=True):
        # Only build a bitsandbytes config when actually quantizing; a config
        # with 8-bit disabled may still route loading through bitsandbytes.
        quantization_config = BitsAndBytesConfig(load_in_8bit=True) if load_in_8bit else None
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            token=hf_token,
            device_map='auto',
            quantization_config=quantization_config,
        )
        self.head_unembed = self.model.lm_head
        # Inputs are moved to the device of the first parameter.
        self.device = next(self.model.parameters()).device
        head = self.head_unembed
        # Replace each decoder layer in place with a recording wrapper that
        # projects intermediates through the final norm and lm_head.
        for i, layer in enumerate(self.model.model.layers):
            self.model.model.layers[i] = BlockOutputWrapper(layer, head, self.model.model.norm)

    def generate_text(self, prompt, max_length=200):
        """Generate with the full model and return the decoded first sequence."""
        inputs = self.tokenizer(prompt, return_tensors="pt")
        generate_ids = self.model.generate(inputs.input_ids.to(self.device), max_length=max_length)
        return self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    def generate_intermediate_text(self, layer_idx, prompt, max_length=100, temperature=1.0):
        """Generate text by sampling next tokens from an intermediate layer.

        Each step runs a forward pass on ``prompt``, samples from layer
        ``layer_idx``'s logit-lens distribution at the last position, and
        appends the token to the prompt text.

        Args:
            layer_idx: index into ``self.model.model.layers``.
            prompt: starting text.
            max_length: maximum number of new tokens.
            temperature: sampling temperature; 0 means greedy.

        Returns:
            The prompt extended with the generated tokens, as decoded text.
        """
        layer = self.model.model.layers[layer_idx]
        for _ in range(max_length):
            self.get_logits(prompt)
            next_id = self.sample_next_token(layer.block_output_unembedded[:, -1], temperature=temperature)
            # Encode without special tokens so no BOS ends up in the text;
            # slicing off element 0 would drop a real token for tokenizers
            # that do not prepend BOS.
            prompt = self.tokenizer.decode(self.tokenizer.encode(prompt, add_special_tokens=False) + [next_id])
            if next_id == self.tokenizer.eos_token_id:
                break
        return prompt

    def sample_next_token(self, logits, temperature=1.0):
        """Return a token id from ``logits``: greedy when temperature is 0, else sampled."""
        assert temperature >= 0, "temp must be geq 0"
        if temperature == 0:
            return self._sample_greedy(logits)
        return self._sample_basic(logits / temperature)

    def _sample_greedy(self, logits):
        # argmax over the flattened tensor; assumes a single row of logits.
        return logits.argmax().item()

    def _sample_basic(self, logits):
        return torch.distributions.categorical.Categorical(logits=logits).sample().item()

    def get_logits(self, prompt):
        """Run a no-grad forward pass on ``prompt`` and return the final logits.

        As a side effect every wrapped layer records its intermediate outputs.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            return self.model(inputs.input_ids.to(self.device)).logits

    def set_add_attn_output(self, layer, add_output):
        """Add ``add_output`` to the attention output of ``layer`` on future passes."""
        self.model.model.layers[layer].attn_add_tensor(add_output)

    def get_attn_activations(self, layer):
        """Return the attention output recorded for ``layer`` on the last pass."""
        return self.model.model.layers[layer].get_attn_activations()

    def set_add_to_last_tensor(self, layer, tensor):
        """Add ``tensor`` to the last position of ``layer``'s output on future passes."""
        print('setting up intervention: add tensor to last soft token')
        self.model.model.layers[layer].block_add_to_last_tensor(tensor)

    def reset_all(self):
        """Remove every intervention from every layer."""
        for layer in self.model.model.layers:
            layer.reset()

    def print_decoded_activations(self, decoded_activations, label, topk=10):
        """Print the ``topk`` most likely tokens at the last position of batch 0.

        Prints ``label`` followed by (token_id, token_text, percent) tuples,
        where percent is the softmax probability truncated to an integer.
        """
        softmaxed = torch.nn.functional.softmax(decoded_activations[0][-1], dim=-1)
        values, indices = torch.topk(softmaxed, topk)
        probs_percent = [int(v * 100) for v in values.tolist()]
        tokens = self.tokenizer.batch_decode(indices.unsqueeze(-1))
        print(label, list(zip(indices.detach().cpu().numpy().tolist(), tokens, probs_percent)))

    def logits_all_layers(self, text, return_attn_mech=False, return_intermediate_res=False, return_mlp=False, return_block=True):
        """Return every layer's logit-lens block output for ``text``, on CPU.

        Per-layer results are concatenated along dim 0; with a single prompt
        that gives one row per layer.

        Raises:
            NotImplementedError: if any non-block output is requested.
        """
        if return_attn_mech or return_intermediate_res or return_mlp:
            raise NotImplementedError("only block outputs (return_block=True) are supported")
        self.get_logits(text)
        return torch.cat(
            [layer.block_output_unembedded.detach().cpu() for layer in self.model.model.layers],
            dim=0,
        )

    def decode_all_layers(self, text, topk=10, print_attn_mech=True, print_intermediate_res=True, print_mlp=True, print_block=True):
        """Print the top-``topk`` decoded tokens of each layer's intermediates for ``text``."""
        print('Prompt:', text)
        self.get_logits(text)
        for i, layer in enumerate(self.model.model.layers):
            print(f'Layer {i}: Decoded intermediate outputs')
            if print_attn_mech:
                self.print_decoded_activations(layer.attn_mech_output_unembedded, 'Attention mechanism', topk)
            if print_intermediate_res:
                self.print_decoded_activations(layer.intermediate_res_unembedded, 'Intermediate residual stream', topk)
            if print_mlp:
                self.print_decoded_activations(layer.mlp_output_unembedded, 'MLP output', topk)
            if print_block:
                self.print_decoded_activations(layer.block_output_unembedded, 'Block output', topk)


In [None]:
# Load the selected checkpoint with 8-bit weights and wrap its decoder layers
# for logit-lens inspection.
model = LLMHelper(model_name=model_card, hf_token=hf_token, load_in_8bit=True)

# Logit-lens plot

In [None]:
# Colab only: enable the custom widget manager so ipywidgets/plotly output
# renders. Outside Colab the import fails and there is nothing to enable;
# any other error is real and should surface instead of being swallowed.
try:
    from google.colab import output
    output.enable_custom_widget_manager()
except ImportError:
    pass

from tuned_lens.plotting import PredictionTrajectory
import ipywidgets as widgets
from plotly import graph_objects as go
import numpy as np

tokenizer = model.tokenizer


def make_plot(text, layer_stride, statistic, token_range):
    """Build a tuned-lens trajectory figure for ``text`` using the logit lens.

    Args:
        text: prompt to analyse.
        layer_stride: plot every ``layer_stride``-th layer.
        statistic: name of a PredictionTrajectory method, e.g. 'entropy'.
        token_range: (start, stop) token slice to display.

    Returns:
        A plotly figure, or a Text widget describing invalid input.
    """
    # Validate the raw text: tokenizers that prepend BOS never return an empty
    # id list, so a length check on the ids cannot detect empty input.
    if not text.strip():
        return widgets.Text("Please enter your input text.")
    if token_range[0] == token_range[1]:
        return widgets.Text("Please provide valid token range.")

    input_ids = tokenizer.encode(text)
    # Next-token targets; the final position is scored against EOS.
    targets = input_ids[1:] + [tokenizer.eos_token_id]

    # One row of per-position log-probabilities per layer (single prompt).
    log_probs = model.logits_all_layers(text).float().log_softmax(dim=-1).numpy()
    pred_traj = PredictionTrajectory(
        log_probs=log_probs,
        input_ids=np.asarray(input_ids),
        targets=np.asarray(targets),
        anti_targets=None,
        tokenizer=tokenizer,
    )
    pred_traj = pred_traj.slice_sequence(slice(*token_range))
    return getattr(pred_traj, statistic)().stride(layer_stride).figure(
        title=f"{model_name} {statistic}",
    )

style = {'description_width': 'initial'}
statistic_wdg = widgets.Dropdown(
    options=[
        ('Entropy', 'entropy'),
        ('Cross Entropy', 'cross_entropy'),
        ('Forward KL', 'forward_kl'),
    ],
    description='Select Statistic:',
    style=style,
)
text_wdg = widgets.Textarea(
    description="Input Text",
    value="Meow Meow Meow",
)

layer_stride_wdg = widgets.BoundedIntText(
    value=2,
    min=1,
    max=10,
    step=1,
    description='Layer Stride:',
    disabled=False,
)

token_range_wdg = widgets.IntRangeSlider(
    description='Token Range',
    min=0,
    max=1,
    step=1,
    style=style,
)


def update_token_range(*args):
    """Resize the token-range slider to the current prompt's token count.

    If the upper handle was at the previous maximum (the user was viewing up
    to the end of the prompt), keep it pinned to the new maximum so newly
    typed tokens are not silently left out of the plot.
    """
    n_tokens = len(tokenizer.encode(text_wdg.value))
    follow_end = token_range_wdg.value[1] == token_range_wdg.max
    token_range_wdg.max = n_tokens
    if follow_end:
        token_range_wdg.value = (min(token_range_wdg.value[0], n_tokens), n_tokens)


update_token_range()

token_range_wdg.value = [0, token_range_wdg.max]
text_wdg.observe(update_token_range, 'value')

# Manual mode: the lens only runs when "Run Lens" is clicked, since every run
# performs a full forward pass through the model.
interact = widgets.interact.options(manual=True, manual_name='Run Lens')

plot = interact(
    make_plot,
    text=text_wdg,
    layer_stride=layer_stride_wdg,
    statistic=statistic_wdg,
    token_range=token_range_wdg,
)