In [6]:
from train_my_dyprag_104 import *
embedding_model_path = "./models/long-t5-tglobal-base"
translator_path = "./models/Llama-3.2-1B-Instruct-Doc_mask-longt5_capt_57/translator_step_72270.safetensors"
device = "cuda:0"
embedding_model = AutoModel.from_pretrained(embedding_model_path,device_map=device)
embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_path)
llm_model_path = "./models/Llama-3.2-1B-Instruct-Doc_mask"
llm_model = AutoModelForCausalLM.from_pretrained(llm_model_path, device_map="cpu")
translator = CrossAttentionParameterTranslator(
            embedding_model=embedding_model,
            llm_model=llm_model,
            lora_rank=2,
            projector_hidden_dim=1024,
            attn_heads=8,
            attn_ff_dim=1024,
            cross_layers=1,
            encoder_layers=1,
        )
def _get_full_doc_embed(passages):

    def _bge_or_snowflake_embed(model, inputs):
        output = model(**inputs)
        embeddings = output.last_hidden_state #[B, L, D]
        return torch.nn.functional.normalize(embeddings, p=2, dim=-1)

    def _t5_embed(model, inputs):
        output = model.encoder(**inputs)
        embeddings = output.last_hidden_state
        return embeddings #[B, L, D]

    with torch.no_grad():
        inputs = embedding_tokenizer(
            passages,
            return_tensors="pt",
            padding="longest",
            max_length=4096,
            truncation=True
        ).to(device)

        model_name = embedding_model.name_or_path.lower()

        if 'bge' in model_name or 'snowflake' in model_name:
            sentence_embeddings = _bge_or_snowflake_embed(embedding_model, inputs)
        elif 't5' in model_name:
            sentence_embeddings = _t5_embed(embedding_model, inputs)
        else:
            raise NotImplementedError(f"[Unsupported Model] {model_name}")

    return sentence_embeddings, inputs["attention_mask"]
translator.to(device)
translator.load_state_dict(load_file(translator_path, device=device))


<All keys matched successfully>

In [12]:
passages = ["hello world", "hello world", "hello world"]
doc_embed,attention_mask = _get_full_doc_embed(passages)
with torch.no_grad():
    lora_weights = translator(doc_embed,attention_mask.to(translator.device))

print(lora_weights)
print(len(lora_weights))
print(lora_weights['base_model.model.model.layers.15.mlp.gate_proj.lora_A.weight'].shape)
print(lora_weights['base_model.model.model.layers.15.mlp.gate_proj.lora_B.weight'].shape)
print(lora_weights['base_model.model.model.layers.15.mlp.up_proj.lora_A.weight'].shape)
print(lora_weights['base_model.model.model.layers.15.mlp.up_proj.lora_B.weight'].shape)
print(lora_weights['base_model.model.model.layers.15.mlp.down_proj.lora_A.weight'].shape)
print(lora_weights['base_model.model.model.layers.15.mlp.down_proj.lora_B.weight'].shape)

defaultdict(<class 'list'>, {'base_model.model.model.layers.0.mlp.down_proj.lora_A.weight': tensor([[[-0.0165,  0.2540,  0.8407,  ..., -0.1998, -0.6043, -0.6351],
         [ 1.5179, -1.0672, -0.4722,  ..., -0.1787,  1.0244,  0.7010]],

        [[-0.0696,  0.3420,  0.6513,  ..., -0.2589, -0.5729, -0.6033],
         [ 1.4940, -1.0691, -0.3171,  ..., -0.2076,  0.8030,  0.7242]],

        [[ 0.0546,  0.2315,  0.7426,  ..., -0.1940, -0.5675, -0.5227],
         [ 1.5018, -1.2771, -0.4343,  ..., -0.2175,  0.9931,  0.8090]]],
       device='cuda:0'), 'base_model.model.model.layers.0.mlp.down_proj.lora_B.weight': tensor([[[-1.2099e-04, -3.6793e-04],
         [ 2.4385e-05,  4.2279e-04],
         [ 4.8415e-04, -5.8368e-04],
         ...,
         [-8.1272e-05, -2.5132e-04],
         [-1.6941e-04, -8.1658e-05],
         [ 3.0678e-05,  1.8897e-05]],

        [[-7.1263e-05, -4.2085e-04],
         [ 5.0107e-05,  4.9522e-04],
         [ 5.0658e-04, -6.3057e-04],
         ...,
         [-1.2842e-04, -2

In [19]:
import json

dataset = json.load(open("data_aug_deepseek-v3/2wikimultihopqa/train_passages_deduplication_0_30000.json", "r"))
passages=[]
for sample in dataset:
    passages.append({'passage':sample['passages'][0]})
with open("./data_aug_deepseek-v3/2wikimultihopqa/reconstruct_passage_train.json", "w") as f:
    json.dump(passages[:-100], f, indent=4)
with open("./data_dev_inference_104/reconstruct_passage_test.json", "w") as f:
    json.dump(passages[-100:], f, indent=4)
    