In [1]:
import os

# Select the GPU before importing torch so CUDA initialization only sees this device.
CUDA_VISIBLE_DEVICES = "2"
os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
import sys

# Add the project root (parent of the notebook's working directory) to the import
# path. Guarded so re-running this cell does not keep appending duplicate entries.
PROJECT_ROOT = os.path.dirname(os.getcwd())
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
import json
import torch
from transformers import AutoTokenizer
import torch.distributed as dist
from peft import LoraConfig
from dataclasses import fields

from lit.configs.train_config import train_config
from lit.configs.peft_config import lora_config
from lit.utils.dataset_utils import (
    LatentQADataset,
    DataCollatorForLatentQA,
    NUM_READ_TOKENS_TO_SHIFT,
    NUM_WRITE_TOKENS_TO_SHIFT,
    DECODER_CHAT_TEMPLATES,
    get_dataset,
    get_dataloaders,
    get_dist_batch_sampler,
    mask_inputs
)
from lit.utils.infra_utils import get_tokenizer, get_model_config_name, get_model, get_modules
from lit.utils.activation_utils import latent_qa

In [2]:
# Start from the default training config and override the fields this notebook uses.
args = train_config()

args.batch_size_training = 2
args.train_system = "../data/train/system.json"
args.train_stimulus_completion = "../data/train/stimulus_completion.json"
# args.train_stimulus = "../data/train/stimulus.json"
# args.train_control = "../data/train/control.json"
args.train_qa = "../data/train/qa.json"
# args.filter = "goal"

# NOTE(review): absolute, machine-specific checkpoint path; adjust for your environment.
model_name = "/data1/ckx/hf-checkpoints/meta-llama/Llama-3.1-8B-Instruct"

tokenizer = get_tokenizer(model_name)

# Target model: its activations are read and cached later in the notebook.
target_model = get_model(model_name, tokenizer, device="auto")

# Instantiate the LoRA dataclass once and copy its fields into a peft LoraConfig
# (previously a fresh lora_config() was built for every single field).
lora_defaults = lora_config()
lora_params = {f.name: getattr(lora_defaults, f.name) for f in fields(lora_defaults)}
peft_config = LoraConfig(**lora_params)

# Decoder model: same base checkpoint with LoRA adapters attached.
decoder_model = get_model(
    model_name,
    tokenizer,
    peft_config=peft_config,
    device="auto",
    # distributed_training=True,
)

# Choose the read module(s) on the target model and the write module(s) on the decoder.
module_read, module_write = get_modules(target_model, decoder_model, **vars(args))

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
# Module(s) of target_model chosen by get_modules for reading activations
# (a nested list holding one LlamaDecoderLayer, per the output below).
module_read

[[LlamaDecoderLayer(
    (self_attn): LlamaAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
      (act_fn): SiLUActivation()
    )
    (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
    (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
  )]]

In [4]:
# Module(s) of decoder_model where read activations are substituted; the output
# shows the LoRA-wrapped projections of the chosen decoder layer.
module_write

[[LlamaDecoderLayer(
    (self_attn): LlamaAttention(
      (q_proj): lora.Linear(
        (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
        (lora_dropout): ModuleDict(
          (default): Dropout(p=0.05, inplace=False)
        )
        (lora_A): ModuleDict(
          (default): Linear(in_features=4096, out_features=16, bias=False)
        )
        (lora_B): ModuleDict(
          (default): Linear(in_features=16, out_features=4096, bias=False)
        )
        (lora_embedding_A): ParameterDict()
        (lora_embedding_B): ParameterDict()
        (lora_magnitude_vector): ModuleDict()
      )
      (k_proj): lora.Linear(
        (base_layer): Linear(in_features=4096, out_features=1024, bias=False)
        (lora_dropout): ModuleDict(
          (default): Dropout(p=0.05, inplace=False)
        )
        (lora_A): ModuleDict(
          (default): Linear(in_features=4096, out_features=16, bias=False)
        )
        (lora_B): ModuleDict(
          (default)

# The `get_dataset` function

In [None]:
# Build the dataset from the paths configured on args (inspected in the cells below).
dataset = get_dataset(args, tokenizer)

In [None]:
import json
import random
from collections import defaultdict

# Step through the data-loading logic behind get_dataset by hand
# (presumably mirroring lit.utils.dataset_utils) to inspect intermediate values.
train = True
# Hyphen-separated label prefixes to exclude, taken from args.filter.
FILTER = args.filter.split("-")
with open(args.train_qa, "r") as f:
    qa_data = json.load(f)

# build_data_and_idx assumes every label has the same number of QA pairs.
# A set gives one pass and a readable error for both the empty and the
# inconsistent case (max([]) used to raise an opaque ValueError).
qa_counts = {len(qa_pairs) for qa_pairs in qa_data.values()}
assert len(qa_counts) == 1, (
    f"expected one QA count shared by all labels, got {sorted(qa_counts)}"
)
NUM_QA = qa_counts.pop()

def build_data_and_idx(path):
    """Load one system/stimulus/control JSON file and enumerate its sample ids.

    Reads the module-level ``FILTER``, ``NUM_QA``, ``train`` and ``args``.

    Args:
        path: Path to a JSON list of items. Each item has a ``"label"`` key and
            optional ``system`` / ``control_*`` / ``stimulus_*`` fields. An
            empty string or ``None`` means "no file" and yields no data.

    Returns:
        ``(data, id_tuples)``. ``data`` maps each kept label to a list of
        7-tuples ``(system, control_user, control_thought, control_model,
        stimulus_user, stimulus_thought, stimulus_model)``, with missing
        fields as ``""``. ``id_tuples`` is a list of
        ``(label_idx, data_idx, qa_idx)`` triples: all of them, or a random
        ``args.train_percent`` fraction when training.

    Raises:
        AssertionError: if kept labels have differing numbers of items.
    """
    print("Loading data from:", path)
    data = defaultdict(list)
    # Treat None like "" so unset optional paths are skipped instead of
    # being passed to open().
    if not path:
        return data, []
    with open(path, "r") as f:
        raw_data = json.load(f)
    for item in raw_data:
        # Skip items whose label prefix (text before the first "-") is filtered out.
        if item["label"].split("-")[0] in FILTER:
            continue
        data[item["label"]].append(
            (
                item.get("system", ""),
                item.get("control_user", ""),
                item.get("control_thought", ""),
                item.get("control_model", ""),
                item.get("stimulus_user", ""),
                item.get("stimulus_thought", ""),
                item.get("stimulus_model", ""),
            )
        )
    print("Loaded {} labels".format(len(data)))
    if not data:
        # Nothing survived filtering; max()/min() over no labels would raise.
        return data, []

    # The flat-index decomposition below assumes a rectangular layout:
    # every label holds the same number of items.
    behavior_counts = {len(items) for items in data.values()}
    assert len(behavior_counts) == 1, (
        f"labels in {path} have differing item counts: {sorted(behavior_counts)}"
    )
    NUM_BEHAVIORS = behavior_counts.pop()
    print("NUM_BEHAVIORS:", NUM_BEHAVIORS)

    flat_ids = range(len(data) * NUM_BEHAVIORS * NUM_QA)
    print(flat_ids)
    if args.train_percent == 1 or not train:
        flat_ids = list(flat_ids)
    else:
        flat_ids = random.sample(flat_ids, int(len(flat_ids) * args.train_percent))

    # Decode each flat id label-major, then item index, then QA index.
    per_label = NUM_BEHAVIORS * NUM_QA
    id_tuples = []
    for flat_id in flat_ids:
        label_idx, rest = divmod(flat_id, per_label)
        data_idx, qa_idx = divmod(rest, NUM_QA)
        id_tuples.append((label_idx, data_idx, qa_idx))
    return data, id_tuples

In [None]:
# Label prefixes that build_data_and_idx skips.
FILTER

In [None]:
# Raw QA data loaded from args.train_qa, keyed by label.
qa_data

In [None]:
# Resolve the four source paths for the active split, then load each one.
split_prefix = "train" if train else "eval"
p0, p1, p2, p3 = (
    getattr(args, f"{split_prefix}_{source}")
    for source in ("system", "stimulus_completion", "stimulus", "control")
)
data_system = build_data_and_idx(p0)
data_stimulus_completion = build_data_and_idx(p1)
data_stimulus = build_data_and_idx(p2)
data_control = build_data_and_idx(p3)

In [None]:
# First element of the (data, id_tuples) pair: label -> list of 7-field tuples.
data_system[0]

In [None]:
# Second element: the (label_idx, data_idx, qa_idx) index tuples.
data_system[1]

# The dataset class

In [None]:
# Entry 1 of the dataset's `data` attribute (compare with data_system above).
dataset.data[1]

In [None]:
# Index tuple for sample 1 of the dataset.
dataset.id_tuples[1]

In [None]:
# Entry 1 of the dataset's `labels` attribute.
dataset.labels[1]

In [None]:
# QA data held by the dataset (presumably the same content as qa_data above; verify).
dataset.qa_data

In [None]:
# What get_behavior_qa returns for sample index 17777.
dataset.get_behavior_qa(17777)

In [None]:
# The dataset's `lengths` attribute (presumably per-sample lengths; confirm in LatentQADataset).
dataset.lengths

In [None]:
# Total number of samples in the dataset.
len(dataset)

In [None]:
# Full sample dict for index 17777 (index 0 is inspected in the mask_type = "system" section).
dataset[17777]

In [None]:
# Rendered read prompt for sample 17777.
print(dataset[17777]['read_prompt'])

## mask_type = "user"

In [None]:
# Same call as the cell above, repeated under the mask_type = "user" heading.
print(dataset[17777]['read_prompt'])

In [None]:
# Decoder-side dialog for sample 17777.
dataset[17777]['dialog']

In [None]:
# Mask type recorded for sample 17777.
dataset[17777]['mask_type']

## mask_type = "system"

In [None]:
# Rendered read prompt for sample 0.
print(dataset[0]["read_prompt"])

In [None]:
# Decoder-side dialog for sample 0.
dataset[0]['dialog']

In [None]:
# Mask type recorded for sample 0.
dataset[0]['mask_type']

# get_dataloaders

In [None]:
# Batch sampler built for the train split.
get_dist_batch_sampler(dataset, args, "train")

In [5]:
# Build the train and eval dataloaders from the same config.
train_dataloader, eval_dataloader = get_dataloaders(args,tokenizer)

In [None]:
# Number of batches in the training dataloader.
len(train_dataloader)

## The `lqa_tokenize` function

In [None]:
# Read prompts for two samples: 0 (mask_type "system") and 17777 (mask_type "user"),
# per the mask_type sections above.
batch = [dataset[sample_idx]['read_prompt'] for sample_idx in (0, 17777)]
batch

In [None]:
# Tokenize the read prompts as one padded batch; the prompts are already
# fully rendered text, so no extra special tokens are added.
tokenized_read = tokenizer(
    batch,
    padding=True,
    add_special_tokens=False,
    return_tensors="pt",
)
tokenized_read

In [None]:
# Decode both rows back to text to eyeball padding and special tokens.
decoded_first = tokenizer.decode(tokenized_read.input_ids[0])
decoded_second = tokenizer.decode(tokenized_read.input_ids[1])
print(decoded_first, '+++++++++++++++', decoded_second, sep='\n')

In [None]:
# Per-row count of non-padding tokens in the read batch.
read_lengths = tokenized_read.attention_mask.sum(dim=1)
read_lengths

In [None]:
# Set a `_debug_print` attribute on mask_inputs; presumably the helper checks it
# to emit debug output (confirm in lit.utils.dataset_utils).
mask_inputs._debug_print = True

In [None]:
# Mask over the read tokens for the "system" and "user" segments
# (mask_type=None is the other option explored here).
verb_mask = mask_inputs(
    tokenized_read.input_ids,
    tokenizer.name_or_path,
    mask_type=['system', 'user'],
    mask_all_but_last=False,
)


In [None]:
# 验证字符串编码
encoded = tokenizer.encode("\n\n", add_special_tokens=False)
print(encoded)  # 输出: [271]
for ids in (
        torch.tensor([128006, 9125, 128007, 271]),
        torch.tensor([128006, 882, 128007, 271]),
        torch.tensor([128006, 78191, 128007, 271]),
        torch.tensor([128006, 36013, 128007, 271]),
    ):
    print(tokenizer.decode(ids))

In [None]:
# Per-token mask produced by mask_inputs for the read batch.
verb_mask

In [None]:

# Number of tokens flagged by verb_mask in each row.
verb_lengths = verb_mask.sum(dim=1)
verb_lengths

In [None]:
# Non-padding read lengths, for comparison with verb_lengths.
read_lengths

In [None]:
# Per-sample read length minus the number of tokens flagged by verb_mask;
# used below to size the placeholder user turn of each decoder query.
pad_lengths = read_lengths - verb_lengths
pad_lengths

In [None]:
# Decoder dialogs for the same two samples whose read prompts were tokenized above.
batch = [dataset[sample_idx]['dialog'] for sample_idx in (0, 17777)]

In [None]:
# Dialog of the first sample (index 0).
batch[0]

In [None]:
# Build one decoder prompt per sample. A placeholder user turn of
# (pad_length - 1) repetitions of "? " is prepended to the dialog, then the
# conversation is rendered with the decoder chat template. Presumably the
# placeholder positions are where read activations get substituted later.
assert len(pad_lengths) == len(batch), (
    f"got {len(pad_lengths)} pad lengths for {len(batch)} dialogs"
)
# Loop-invariant: look the template up once instead of on every iteration.
decoder_chat_template = DECODER_CHAT_TEMPLATES[get_model_config_name(model_name)]
queries = []
# .tolist() yields plain ints, so the string repeat does not rely on tensor __index__.
for pad_length, dialog in zip(pad_lengths.tolist(), batch):
    query = [{"role": "user", "content": "? " * (pad_length - 1)}] + dialog
    print(query)
    queries.append(
        tokenizer.apply_chat_template(
            query,
            tokenize=False,
            add_generation_prompt=False,
            chat_template=decoder_chat_template,
        )
    )



In [None]:
# Rendered decoder prompt strings, one per sample.
queries

In [None]:
# Tokenize the chat-templated decoder prompts as a padded batch; the text is
# already fully rendered, so no extra special tokens are added.
tokenized_write = tokenizer(
    queries,
    padding=True,
    add_special_tokens=False,
    return_tensors="pt",
)
tokenized_write

In [None]:
# Per-row count of non-padding tokens in the write (decoder) batch.
write_lengths = tokenized_write.attention_mask.sum(dim=1)
write_lengths

In [None]:
# Fail fast: these cells assume left padding, so check it before building the mask.
assert tokenizer.padding_side == "left", (
    f"expected left padding, got {tokenizer.padding_side!r}"
)
# Mask over the decoder tokens. True positions are set to -100 in the labels by the
# next cell (presumably the prompt/user part of each query).
user_inputs_mask = mask_inputs(
    tokenized_write.input_ids,
    tokenizer.name_or_path,
    mask_type=None,
    shift_start=True,
    modify_chat_template=True,
)


In [None]:
# Labels are a copy of the input ids, with padding and masked prompt tokens set
# to -100 (PyTorch's default cross-entropy ignore_index).
labels = tokenized_write.input_ids.clone()
tokenized_write["labels"] = labels
mask = (tokenized_write.attention_mask == 0) | user_inputs_mask
labels[mask] = -100
tokenized_write["labels"]

# DataCollatorForLatentQA

In [6]:
# Take the first collated batch from the train dataloader for the walkthrough below.
first_batch = next(iter(train_dataloader))

In [None]:
# Collated batch produced by the dataloader's collator.
first_batch

# latent_qa

In [67]:
from lit.utils.activation_utils import _forward_cache_outputs, no_op

batch = first_batch
tokenized_read = batch["tokenized_read"]
tokenized_write = batch["tokenized_write"]
read_lengths = batch["read_lengths"]
write_lengths = batch["write_lengths"]

# Run the target model on the read inputs and cache the outputs of module_read[0].
read_inputs = tokenized_read.to(target_model.device)
activation_cache = _forward_cache_outputs(
    target_model,
    tokenizer,
    read_inputs,
    module_read[0],
    token_idx=None,
    no_grad=True,
    prepare_inputs=no_op,
)
print(activation_cache[0].shape)
activation_cache  # list of (batch_size, seq_len, hidden_size) tensors

torch.Size([2, 218, 4096])


[tensor([[[ 0.0127,  0.0101, -0.0044,  ..., -0.0461, -0.0072,  0.0708],
          [ 0.0127,  0.0101, -0.0044,  ..., -0.0461, -0.0072,  0.0708],
          [ 0.0127,  0.0101, -0.0044,  ..., -0.0461, -0.0072,  0.0708],
          ...,
          [ 0.2812, -0.1807, -0.0596,  ..., -0.2266, -0.1016,  0.0693],
          [ 0.1895,  0.0610, -0.2275,  ..., -0.1641, -0.0947, -0.0830],
          [-0.0625, -0.1074, -0.0830,  ..., -0.1416, -0.1309,  0.0156]],
 
         [[-0.1089,  0.1650, -0.0400,  ...,  0.2910,  0.3711,  0.0830],
          [-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
          [-0.0112,  0.0330, -0.0400,  ..., -0.0452, -0.0610,  0.2344],
          ...,
          [-0.1943, -0.0171, -0.0181,  ..., -0.1895, -0.3789,  0.0272],
          [-0.1426,  0.0098,  0.0151,  ..., -0.1367, -0.4551,  0.1426],
          [-0.0374, -0.0635, -0.0562,  ..., -0.1455, -0.1660,  0.0615]]],
        device='cuda:0', dtype=torch.bfloat16)]

In [68]:
# Reset verb_lengths; it is reassigned from batch["verb_lengths"] a few cells below.
verb_lengths = None

In [69]:
# Stack the per-module list into one (num_modules, bs, read_seq_len, hidden) tensor.
# Not idempotent: expects activation_cache to still be the list from the forward pass.
activation_cache = torch.stack(activation_cache)
num_modules, bs, read_seq_len, _ = activation_cache.shape
print(activation_cache.shape)

torch.Size([1, 2, 218, 4096])


In [70]:
print(read_lengths)
# With left padding, sample b's first real token (the <bos>) sits at position
# read_seq_len - read_lengths[b]; gather that activation for each sample.
batch_idx = torch.arange(bs, device="cpu")
bos_positions = read_seq_len - read_lengths.cpu()
bos_activations = activation_cache[:, batch_idx, bos_positions, :]
# Broadcast each sample's <bos> vector across every sequence position.
bos_activations = bos_activations.unsqueeze(2).expand(-1, -1, read_seq_len, -1)
assert bos_activations.shape == activation_cache.shape
bos_activations

tensor([198, 217])


tensor([[[[-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
          [-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
          [-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
          ...,
          [-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
          [-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
          [-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475]],

         [[-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
          [-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
          [-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
          ...,
          [-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
          [-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
          [-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484]]]],
       device='cuda:0', dtype=torch.bfloat16)

In [71]:
# Verb lengths computed by the collator for this batch.
batch["verb_lengths"]

tensor([100, 105])

In [72]:
# Mask everything except for the non-verb (last) tokens
verb_lengths = batch["verb_lengths"]
# Position indices 0..read_seq_len-1 on the activations' device. The next cell
# compares them with a per-sample cutoff to build the keep-mask.
counter = torch.arange(0, read_seq_len, device=activation_cache.device)
counter

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
        168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 1

In [73]:
verb_lengths = verb_lengths.to(activation_cache.device)
counter = counter.to(activation_cache.device)
read_lengths = read_lengths.to(activation_cache.device)

# Tokens to keep per sample: the input minus the <bos> token and the verb tokens.
# They are the last num_kept positions of each left-padded row.
num_kept = read_lengths - verb_lengths - 1
first_kept_pos = (read_seq_len - num_kept).unsqueeze(1)  # (bs, 1)
mask = counter.expand(bs, -1) >= first_kept_pos  # (bs, read_seq_len)
# Broadcast over modules and the hidden dimension: (num_modules, bs, read_seq_len, 1).
mask = mask.expand(num_modules, -1, -1).unsqueeze(-1)
mask

tensor([[[[False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [False],
          [F

In [74]:
# Keep the real activations where mask is True and use the broadcast <bos>
# activation elsewhere. torch.where selects rather than blends arithmetically,
# so an Inf/NaN in a discarded position cannot leak into the result
# (Inf * 0 would be NaN). It also skips two multiplies and an add.
activation_cache = torch.where(mask, activation_cache, bos_activations)
activation_cache

tensor([[[[-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
          [-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
          [-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
          ...,
          [ 0.2812, -0.1807, -0.0596,  ..., -0.2266, -0.1016,  0.0693],
          [ 0.1895,  0.0610, -0.2275,  ..., -0.1641, -0.0947, -0.0830],
          [-0.0625, -0.1074, -0.0830,  ..., -0.1416, -0.1309,  0.0156]],

         [[-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
          [-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
          [-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
          ...,
          [-0.1943, -0.0171, -0.0181,  ..., -0.1895, -0.3789,  0.0272],
          [-0.1426,  0.0098,  0.0151,  ..., -0.1367, -0.4551,  0.1426],
          [-0.0374, -0.0635, -0.0562,  ..., -0.1455, -0.1660,  0.0615]]]],
       device='cuda:0', dtype=torch.bfloat16)

In [75]:
# We truncate the lengths to get rid of the verb mask
assert read_lengths.shape == verb_lengths.shape
read_lengths = read_lengths - verb_lengths
print(read_lengths)
activation_cache = torch.unbind(activation_cache, dim=0)

tensor([ 98, 112], device='cuda:0')


In [79]:
# Shape of the first module's activations after masking.
activation_cache[0].shape

torch.Size([2, 218, 4096])

In [80]:
# Masked activations of the first module.
activation_cache[0]

tensor([[[-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
         [-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
         [-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
         ...,
         [ 0.2812, -0.1807, -0.0596,  ..., -0.2266, -0.1016,  0.0693],
         [ 0.1895,  0.0610, -0.2275,  ..., -0.1641, -0.0947, -0.0830],
         [-0.0625, -0.1074, -0.0830,  ..., -0.1416, -0.1309,  0.0156]],

        [[-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
         [-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
         [-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
         ...,
         [-0.1943, -0.0171, -0.0181,  ..., -0.1895, -0.3789,  0.0272],
         [-0.1426,  0.0098,  0.0151,  ..., -0.1367, -0.4551,  0.1426],
         [-0.0374, -0.0635, -0.0562,  ..., -0.1455, -0.1660,  0.0615]]],
       device='cuda:0', dtype=torch.bfloat16)

In [82]:
from lit.utils.activation_utils import get_pos_ids

# Decoder position ids derived from the read/write tokenizations and verb lengths.
position_ids = get_pos_ids(tokenized_read, tokenized_write, verb_lengths)
position_ids = position_ids.to(decoder_model.device)
position_ids

tensor([[109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122,
         123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136,
         137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150,
         151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164,
         165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178,
         179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192,
         193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206,
         207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220,
         221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234,
         235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248,
         249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262,
         263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276,
         277, 278, 279, 280, 281, 282, 283, 284, 285

In [83]:
# Move each per-module activation tensor onto the decoder model's device.
activation_cache = [module_acts.to(decoder_model.device) for module_acts in activation_cache]
activation_cache

[tensor([[[-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
          [-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
          [-0.0168, -0.2012,  0.0381,  ...,  0.0674, -0.0142,  0.1475],
          ...,
          [ 0.2812, -0.1807, -0.0596,  ..., -0.2266, -0.1016,  0.0693],
          [ 0.1895,  0.0610, -0.2275,  ..., -0.1641, -0.0947, -0.0830],
          [-0.0625, -0.1074, -0.0830,  ..., -0.1416, -0.1309,  0.0156]],
 
         [[-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
          [-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
          [-0.0159, -0.2002,  0.0383,  ...,  0.0684, -0.0142,  0.1484],
          ...,
          [-0.1943, -0.0171, -0.0181,  ..., -0.1895, -0.3789,  0.0272],
          [-0.1426,  0.0098,  0.0151,  ..., -0.1367, -0.4551,  0.1426],
          [-0.0374, -0.0635, -0.0562,  ..., -0.1455, -0.1660,  0.0615]]],
        device='cuda:0', dtype=torch.bfloat16)]

In [94]:
# Padded read sequence length, taken from the first module's activations.
_, read_seq_len, _ = activation_cache[0].shape
read_seq_len

218

In [89]:
# Read lengths after subtracting the verb lengths.
read_lengths

tensor([ 98, 112], device='cuda:0')

In [90]:
# Write lengths taken from first_batch.
write_lengths

tensor([168, 178])

In [91]:
# Decoder inputs taken from first_batch.
tokenized_write

{'input_ids': tensor([[128010, 128010, 128010, 128010, 128010, 128010, 128010, 128010, 128010,
         128010, 128000, 128006,    882, 128007,    271,     30,    949,    949,
            949,    949,    949,    949,    949,    949,    949,    949,    949,
            949,    949,    949,    949,    949,    949,    949,    949,    949,
            949,    949,    949,    949,    949,    949,    949,    949,    949,
            949,    949,    949,    949,    949,    949,    949,    949,    949,
            949,    949,    949,    949,    949,    949,    949,    949,    949,
            949,    949,    949,    949,    949,    949,    949,    949,    949,
            949,    949,    949,    949,    949,    949,    949,    949,    949,
            949,    949,    949,    949,    949,    949,    949,    949,    949,
            949,    949,    949,    949,    949,    949,    949,    949,    949,
            949,    949,    949,    949,    949,    949,    949,    949,    949,
            94

In [None]:
from lit.utils.activation_utils import generate_substitute_layer_single

# Run the decoder with the cached read activations substituted at the "output"
# of module_write[0], positioned via substitute_by_mask. With generate=False,
# max_new_tokens is presumably unused.
max_new_tokens = 100
write_inputs = tokenized_write.to(decoder_model.device)
out = generate_substitute_layer_single(
    decoder_model,
    tokenizer,
    write_inputs,
    module_write[0],
    activation_cache,
    "output",
    generate=False,
    no_grad=False,
    substitute_by_mask=(read_lengths, write_lengths),
    prepare_inputs=no_op,
    max_new_tokens=max_new_tokens,
    position_ids=position_ids,
    use_cache=False,
)
out

In [86]:
# End-to-end latent_qa call on first_batch. It covers the same steps traced
# manually above, and its outputs carry the loss shown below.
latent_qa_kwargs = dict(
    mask_verbs=True,
    shift_position_ids=args.shift_position_ids,
)
outputs = latent_qa(
    first_batch,
    target_model,
    decoder_model,
    module_read[0],
    module_write[0],
    tokenizer,
    **latent_qa_kwargs,
)

In [None]:
# Shape of the decoder logits returned by latent_qa.
outputs.logits.shape

In [87]:
# Training loss for first_batch.
outputs.loss

tensor(3.0877, device='cuda:0', grad_fn=<NllLossBackward0>)