In [None]:
!nvidia-smi

import importlib
def reload_pack():
    """Import, then re-execute, the helper submodules of the local ``tools`` package.

    The submodules are processed in a fixed order: ``inference``, ``utils``,
    ``transformers_patch``, ``graphic``. Each one is imported (a no-op if it is
    already loaded) and then passed to ``importlib.reload``.
    """
    submodule_names = ('inference', 'utils', 'transformers_patch', 'graphic')
    for submodule_name in submodule_names:
        loaded = importlib.import_module(f'tools.{submodule_name}')
        importlib.reload(loaded)
# Reload right away so the `from tools... import ...` statements in the cells
# below bind names from freshly executed `tools` modules.
reload_pack()

In [None]:
# Loading model: creates the `tokenizer` and `model` objects used by the cells below.
from tools.utils import load_model

## TO BE REMOVED
# Temporary checkpoint location, resolved relative to where the notebook is run:
# <parent of the current working directory>/Models/llama-3-hf/8B
from pathlib import Path
import os
working_directory = Path(os.getcwd())
folder_path = working_directory.parent / 'Models'
path = folder_path / 'llama-3-hf/8B'
print(f'Loading model from: {path}')
## TO BE REMOVED END

# `load_model` returns the pair (tokenizer, model).
# NOTE(review): the accepted values of `model_type` are defined in tools.utils — confirm there.
tokenizer, model = load_model(path, model_type='llama')

In [None]:
from tools.utils import gen_prompt, show_outputs
from tools.inference import forward

# Build the prompt pieces for the 'country2capital' task.
# NOTE(review): `gen_prompt` is expected to return a mapping with at least the
# keys 'icl' (context) and 'ans' (expected continuation) — confirm in tools.utils.
prompt_comp = gen_prompt(task='country2capital')
# Full text = context followed by the expected answer; a later cell tokenizes it.
prompt = prompt_comp['icl']+prompt_comp['ans']
print(f'Correct:\n{prompt}')

# Number of tokens to generate here; the optimisation cell also uses it to pick
# the last `gen_len` logit positions.
gen_len = 2
# Run the unmodified model on the context alone and show what it generates.
outputs = forward(model, tokenizer, [prompt_comp['icl']], gen_len=gen_len)
print(f'--------\nGenerated:\n{show_outputs(tokenizer, outputs.sequences)[0]}')

In [None]:
import torch
from transformers import PreTrainedModel

def get_attention_layers(model: PreTrainedModel):
    """Collect the ``self_attn`` submodule of every entry in ``model.model.layers``.

    Args:
        model: Model exposing its layer stack as ``model.model.layers``.

    Returns:
        A new list of the ``self_attn`` attributes, in layer order.
    """
    attention_modules = []
    for decoder_layer in model.model.layers:
        attention_modules.append(decoder_layer.self_attn)
    return attention_modules

class AttentionReweighter:
    """Context manager that feeds every attention layer its own, externally supplied mask.

    While the context is active, a forward pre-hook on each module returned by
    ``get_attention_layers(model)`` replaces the module's second positional
    input with a copy of the mask that layer 0 received, overwritten with the
    values of ``attention_reweight`` that belong to that layer.

    The hooks only see positional inputs, so the attention modules must be
    called as ``(hidden_states, mask, ...)`` positionally.

    Args:
        model: Model whose attention modules are hooked.
        attention_reweight: Tensor indexed as ``[batch, layer, row, column]``.
            For a mask with ``q`` rows, the last ``q`` rows of the layer's slice
            are written into the first ``R`` columns of the mask, where ``R`` is
            the number of rows of the slice.
            NOTE(review): the optimisation cell passes log-weights with a very
            negative fill value, so this is presumably an additive mask —
            confirm against the attention implementation in use.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        attention_reweight
    ):
        self._model = model
        self.attention_reweight = attention_reweight
        # Handles of the currently registered hooks; emptied again on exit.
        self._hooks = []
        # Copy of the mask received by layer 0 during the current forward pass.
        self.causal_mask_org = None

    def __enter__(self):
        """Install the hooks and return this object so ``with ... as rw`` works."""
        self._register_forward_pre_hooks()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Remove every installed hook; exceptions from the body propagate."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []

    def _register_forward_pre_hooks(self):
        """Attach one mask-replacing pre-hook per attention layer."""
        def attn_mask_hook(layer_idx):
            def mask_attn(mod, inp):
                if layer_idx == 0:
                    # Snapshot the incoming mask once per forward pass. Its
                    # values only survive in entries the reweight tensor does
                    # not cover.
                    self.causal_mask_org = inp[1].clone()

                # Work on a private copy: writing into the shared snapshot would
                # also change the tensors already handed to earlier layers.
                causal_mask = self.causal_mask_org.clone()
                layer_reweight = self.attention_reweight[:, layer_idx].unsqueeze(dim=1)

                num_mask_rows = causal_mask.shape[-2]
                reweight_rows = layer_reweight.shape[-2]
                first_row = reweight_rows - num_mask_rows
                if first_row < 0:
                    raise ValueError(
                        f'attention_reweight has {reweight_rows} rows but the '
                        f'attention mask has {num_mask_rows}'
                    )

                # Align the LAST rows of the reweight slice with the mask rows,
                # and fill the first `reweight_rows` columns.
                causal_mask[:, :, :, :reweight_rows] = layer_reweight[:, :, first_row:, :reweight_rows]

                # Keep any further positional inputs untouched.
                return (inp[0], causal_mask, *inp[2:])
            return mask_attn

        for i, layer in enumerate(get_attention_layers(self._model)):
            hook = layer.register_forward_pre_hook(attn_mask_hook(i))
            self._hooks.append(hook)

In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from typing import List, Callable

# Tokenize the full prompt (context + expected answer) and move it to the model's device.
inputs = tokenizer([prompt], return_tensors='pt', padding=True).to(model.device)
prompt_ids = inputs['input_ids']
# The model is fed everything except the last token.
input_ids = prompt_ids[:, :-1]

batch_size = prompt_ids.shape[0]
num_tokens = prompt_ids.shape[1]
# AttentionReweighter indexes dim 1 of the mask by layer, so size that dim from
# the model's layer stack instead of hard-coding it.
num_layers = len(model.model.layers)

# Get the original logits for the later loss calculation. The reference pass
# needs no autograd graph, so run it under no_grad.
model.eval()
with torch.no_grad():
    outputs = model(input_ids)
logits0 = outputs.logits[:, -gen_len:].detach().clone()

# Initialize the attention mask: weight 1 for every (layer, row, column) entry.
attn_mask = torch.ones(batch_size, num_layers, num_tokens, num_tokens, device=model.device)

# Setting attn_mask to be a trainable parameter.
attn_mask = nn.Parameter(attn_mask)

# Entries gradient descent may change: the strict lower triangle, minus padding
# columns and minus the first (bos) column.
# NOTE(review): entries outside `update_mask` are not held at weight 1 — the loop
# below fills them with `min_type`, which includes the first (bos) column.
# Confirm that blocking bos is intended.
update_mask = torch.tril(torch.ones_like(attn_mask), diagonal=-1).bool() # Remove self and connections under causal mask.
padding_mask = (1-inputs['attention_mask'])[:, None, None, :].to(bool).repeat(1, num_layers, num_tokens, 1)
update_mask[padding_mask] = 0 # Remove padding tokens.
# Remove bos tokens. (CAN BE IMPROVED)
update_mask[:,:,:,0] = 0
# Also drop the column right after every padding column.
# NOTE(review): presumably this targets the bos token of left-padded sequences;
# `i+1` runs out of range if the last column is padding — confirm padding side.
for batch_idx in range(batch_size):
    for i in range(num_tokens):
        if padding_mask[batch_idx, 0, -1, i]:
            update_mask[batch_idx, :, :, i+1] = 0

# Fill value for entries outside `update_mask` (stands in for -inf).
min_type = torch.finfo(attn_mask.dtype).min

# We use L2 loss to measure the gap between modified logits and original logits.
loss_fn = nn.MSELoss(reduction='mean')
block_rate = 1.
gap = 0.
gap_ckpt = 1000.

warning_flag = False
epoch = 0
epoch_ckpt = 0
attn_mask_ckpt = attn_mask.data.detach().clone()

# NOTE(review): autograd does not require train mode; this only changes the
# forward pass if the model has mode-dependent layers (e.g. dropout), in which
# case the logits would no longer be comparable with the eval-mode `logits0`.
model.train()
record = []
print(f'logits: {logits0.max(dim=-1).values}')

learning_rate = 1e-3

optimizer = optim.SGD([attn_mask],lr=learning_rate)
regular = 1e3      # Weight of the logit-matching term relative to the sparsity term.
lower_bound = 1e-12  # Keeps log() finite when a weight has been clamped to 0.
view_gap = 100     # Print progress every `view_gap` epochs.
for epoch in range(2000):
    # Apply attn_mask in the inference process.
    optimizer.zero_grad()

    # Log-weights for updatable entries, `min_type` everywhere else.
    attn_mask_log_causal = torch.log(torch.clamp(attn_mask, min=lower_bound)).masked_fill(~update_mask, min_type)
    modified_forward = AttentionReweighter(model, attn_mask_log_causal)
    with modified_forward:
        outputs = model(input_ids)
        logits = outputs.logits[:, -gen_len:]

    # Gradient Descent: match the reference logits while pushing the updatable
    # weights towards 0 (sum of weights as sparsity penalty).
    sparse = attn_mask[update_mask].sum()
    loss_raw = loss_fn(logits, logits0)*gen_len
    loss = regular*loss_raw + sparse
    loss.backward()

    with torch.no_grad():
        optimizer.step()
        # Project the weights back into [0, 1] after the step.
        attn_mask.clamp_(min=0, max=1)

        block_rate = attn_mask[update_mask].mean().item()
        gap = loss_raw.sqrt().item()

        if (epoch+1)%view_gap==0:
            print(f"Epoch: {epoch}, Block rate: {block_rate:.3f}, logits: {logits[0,-1].max().item():.3f}, Target logits: {logits0[0,-1].max().item():.3f}, Gap: {gap:.3f}")
            # record.append({'epoch':epoch,'attn_mask':attn_mask.detach().clone(), 'sparsity': block_rate, 'gap': gap})

In [None]:
from tools.graphic import demonstrate_by_token

# One label per prompt token: decode each token id of the first sequence on its own.
labels = []
for t in prompt_ids[0]:
    labels.append(tokenizer.decode(t))

# For the first sequence, take row `-gen_len` of the learned mask across dim 1
# (the dim AttentionReweighter indexes by layer), giving one weight per layer and token.
# NOTE(review): `attn_mask` is still an nn.Parameter on the model's device here —
# confirm `demonstrate_by_token` detaches / moves it as needed.
connection = attn_mask[0,:,-gen_len]
demonstrate_by_token(connection, labels)