In [None]:
import os

# Restrict this process to a single physical GPU. CUDA reads this variable when
# it initializes, so it must be set before torch / train_mllm touch CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

# Wildcard import kept for compatibility. Names used from it in the cells below:
# SiTMAEConfig, SiTMAEModel, load_sit_encoder_model, MultiModalModel,
# MultiModalDataset, DynamicIQCollator.
from train_mllm import *
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import pandas as pd  # not used in the cells visible here
import sys  # not used in the cells visible here

# --- Configuration ---------------------------------------------------------
# LLM checkpoint directory (loaded with AutoModelForCausalLM below).
llm_model_path = '/mnt/gold/lz/competitions/dianci/em_foundation_model/9G4B'
# Weights for the signal encoder (passed to load_sit_encoder_model below).
signal_encoder_path = '/mnt/gold/lz/competitions/dianci/em_foundation_model/em_foundation/weight'
# With the CUDA_VISIBLE_DEVICES mask above, 'cuda' refers to physical GPU 4.
device = 'cuda'
# Directory holding the test .h5 file and its same-named .json file.
task_dir = '/mnt/gold/lz/competitions/dianci/data/competition_test_data_1'
# Maximum token length handed to MultiModalDataset.
max_length = 2048

In [None]:
# Tokenizer: use_fast=False requests the slow (Python) implementation;
# trust_remote_code allows custom code shipped in the checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(llm_model_path, use_fast=False, trust_remote_code=True)

# Register the IQ segment delimiters as special tokens so they stay atomic.
iq_marker_tokens = ["<IQ_START>", "<IQ_END>"]
tokenizer.add_special_tokens({"additional_special_tokens": iq_marker_tokens})

# Checkpoint may define no pad token; reuse EOS for padding in that case.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the LLM in half precision and move it to the selected device.
llm = AutoModelForCausalLM.from_pretrained(
    llm_model_path, trust_remote_code=True, torch_dtype=torch.float16
)
llm = llm.to(device)

# The vocabulary grew by the special tokens above, so enlarge the embedding table.
llm.resize_token_embeddings(len(tokenizer))

In [None]:
em_config = SiTMAEConfig(
    max_seq_len=(4096, 1),
    use_fs=True,
    fs=20e6,  # in Hz (presumably the sampling rate), e.g. 1e6 means 1 MHz
)
# First load the masked backbone and align its weights with the checkpoint.
base_encoder = load_sit_encoder_model(em_config, signal_encoder_path)

# Then build the mask-free variant and copy the weights across.
# strict=False tolerates key mismatches between the two variants, so keep the
# returned report of missing/unexpected keys instead of silently discarding it.
signal_encoder = SiTMAEModel(base_encoder.config)
encoder_load_report = signal_encoder.load_state_dict(base_encoder.state_dict(), strict=False)

# A freshly constructed module starts in training mode; switch to eval mode so
# dropout-style layers behave deterministically during inference.
signal_encoder.to(device)
signal_encoder.eval()

# Display which keys did not transfer (both lists empty means a full copy).
encoder_load_report

In [None]:
# Each task is an (HDF5 file, JSON file) pair that share a base name in task_dir.
# Derive the HDF5 path from task_dir instead of repeating the absolute directory.
h5_paths = [os.path.join(task_dir, "test_data.h5")]
task_pairs = []
for h5p in h5_paths:
    # Fail early on a missing signal file rather than later inside the dataset.
    if not os.path.exists(h5p):
        raise FileNotFoundError(f"HDF5 file not found: {h5p}")
    base = os.path.splitext(os.path.basename(h5p))[0]
    jsonp = os.path.join(task_dir, base + '.json')
    if not os.path.exists(jsonp):
        raise FileNotFoundError(f"找不到对应的 JSON 文件: {jsonp}")
    task_pairs.append((h5p, jsonp))
task_pairs

In [None]:
# Combine the signal encoder and the LLM (use_lora=False: presumably no LoRA adapters).
model = MultiModalModel(signal_encoder, llm, tokenizer, use_lora=False)
# Inference only: put any layers MultiModalModel adds on top of the backbones
# into eval mode. Guarded because its base class lives in train_mllm and is not
# visible here (it exposes forward(), so it is most likely a torch.nn.Module).
if isinstance(model, torch.nn.Module):
    model.eval()

# Read the encoder's patch size; the collator needs it and the inference cell
# uses patch_size[0] to size the IQ prefix.
patch_size = signal_encoder.config.patch_size

# Dataset over every (h5, json) task pair, plus the batching collator.
full_dataset = MultiModalDataset(task_pairs, tokenizer, max_length=max_length)
collator = DynamicIQCollator(tokenizer, patch_size)

In [None]:
# Sanity check: run one collated sample through the model and decode the
# per-position argmax over the text region (teacher-forced, not generation).
sample_idx = 4  # dataset item to inspect
item = collator([full_dataset[sample_idx]])
iq = item['iq'].to(device)
ids = item['input_ids'].to(device)
mask = item['attention_mask'].to(device)

# Forward pass only: no_grad avoids building the autograd graph and keeping
# activations alive on the GPU.
with torch.no_grad():
    out = model.forward(iq, ids, mask)
pred = out.logits.argmax(dim=-1)[0]

B, L, C = iq.shape  # assumes iq is (batch, length, channels) — TODO confirm
P = patch_size[0]
# Positions before the text, assuming the model lays out
# <IQ_START> + one position per IQ patch (L // P) + <IQ_END> — confirm against
# MultiModalModel (and whether L is always a multiple of P).
# NOTE(review): for a causal LM, logits at position t predict token t+1, so the
# prediction for the first text token would sit at prefix_len - 1; verify the
# intended alignment of this slice.
prefix_len = 1 + (L // P) + 1
text_pred = pred[prefix_len:]
# skip_special_tokens drops markers such as <IQ_START>/<IQ_END>/pad from the output.
tokenizer.decode(text_pred, skip_special_tokens=True).strip()