In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
from train_mllm import *
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import sys
import time

llm_model_path = '../9G4B'
signal_encoder_path = './weight'
trained_dir = '/home/zbtrs/competitions/dianci/models/v33_b64_oc_p4_e3+1fe/checkpoint-4000'
device = 'cuda'
max_length = 2048

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
tokenizer = AutoTokenizer.from_pretrained(
    llm_model_path,
    use_fast=False,
    trust_remote_code=True
)
tokenizer.add_special_tokens({
    "additional_special_tokens": ["<IQ_START>", "<IQ_END>"]
})
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
llm = AutoModelForCausalLM.from_pretrained(
    llm_model_path,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to(device)
llm.resize_token_embeddings(len(tokenizer))

em_config = SiTMAEConfig(
    max_seq_len=(4096,1),
    use_fs=True,
    fs=20e6,      # e.g. 1e6 表示 1MHz
    patch_size=(8, 1),
)
# 先加载带 Mask 的骨干并对齐权重
base_encoder = load_sit_encoder_model(em_config, signal_encoder_path)
# 切换到无 Mask 版本并拷贝权重
signal_encoder = SiTMAEModel(base_encoder.config)
signal_encoder.load_state_dict(base_encoder.state_dict(), strict=False)
signal_encoder.to(device)

model = MultiModalModel(signal_encoder, llm, tokenizer, use_lora=False)
# 加载 signal_encoder，并拿到 patch_size

if trained_dir:
    # 获取当前设备，避免CUDA设备映射错误
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.llm = PeftModel.from_pretrained(model.llm, trained_dir)
    model.signal_encoder.load_state_dict(torch.load(os.path.join(trained_dir, "signal_encoder.pt"), map_location=device), strict=False)
    model.signal_to_hidden.load_state_dict(torch.load(os.path.join(trained_dir, "signal_to_hidden.pt"), map_location=device), strict=False)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.47it/s]


In [4]:
h5_paths = ["../../data/competition_test_data_1/test_data.h5"]
# h5_paths = ["../../data/train_data/modulation.h5"]
task_pairs = []
for h5p in h5_paths:
    base = os.path.splitext(os.path.basename(h5p))[0]
    jsonp = os.path.join(os.path.dirname(h5p), base + '.json')
    if not os.path.exists(jsonp):
        raise FileNotFoundError(f"找不到对应的 JSON 文件: {jsonp}")
    task_pairs.append((h5p, jsonp))
task_pairs

[('../../data/competition_test_data_1/test_data.h5',
  '../../data/competition_test_data_1/test_data.json')]

In [5]:
patch_size = signal_encoder.config.patch_size
# 构建 Dataset
full_dataset = MultiModalDataset(task_pairs, tokenizer, max_length=max_length, test_mode=True)
collator = DynamicIQCollator(tokenizer, patch_size)

In [6]:
from datetime import datetime
import pytz

def infer(item, idx):
    meta = full_dataset.samples[idx]
    question = meta['question']
    gt_answer = meta.get('answer', '')

    iq = item['iq'].to(device)  
    ids = item['input_ids'].to(device)
    mask = item['attention_mask'].to(device)
    out, input_ids = model.forward(iq, ids, mask, generate=True)
    pred = out.logits.argmax(dim=-1)[0]
    B, L, C = iq.shape
    P = patch_size[0]
    prefix_len = 1 + (L // P) + 1
    text_pred = pred[prefix_len:]
    q_ids = tokenizer(question, return_tensors='pt', add_special_tokens=False).input_ids[0]
    ans_ids = text_pred[q_ids.size(0)-1:]
    answer = tokenizer.decode(ans_ids, skip_special_tokens=True).strip()

    # print(question)
    # print(gt_answer)
    # print(answer)

    return answer

shanghai_timezone = pytz.timezone("Asia/Shanghai")
current_time = datetime.now(shanghai_timezone)
formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S")

result = {
    "team_name": "",
    "model_name": "",
    "evaluation_time": formatted_time,
    "answers": []
}
for idx, item in tqdm(enumerate(full_dataset), total=len(full_dataset)):
    item_ = collator([item])
    answer = infer(item_, idx)
    id = full_dataset.samples[idx]['id']
    result['answers'].append({
        "id": id,
        "pred": answer
    })

  0%|          | 0/500 [00:00<?, ?it/s]

100%|██████████| 500/500 [04:10<00:00,  1.99it/s]


In [None]:
result_copy = json.loads(json.dumps(result))

for idx, i in tqdm(enumerate(result_copy['answers'])):
    pred = i['pred']
    for option in full_dataset.samples[idx]['Answer choices']:
        if pred in option:
            pred = option.split('(')[1].split(')')[0].strip()
            break
    if pred not in ['A', 'B', 'C', 'D']:
        pred = 'C'
    i['id'] = idx
    i['pred'] = pred

# result["team_name"] = "剑来"
# result["model_name"] = "jl_v4"
# result_copy["team_name"] = "总之就是很牛"
# result_copy["model_name"] = "nb"
# result["team_name"] = "Rstar"
# result["model_name"] = "r1"
result["team_name"] = "人生是什么样的"
result["model_name"] = "人生是什么样的"

with open(f'../../results/{result["model_name"]}.json', 'w') as f:
    json.dump(result_copy, f, indent=4, ensure_ascii=False)

500it [00:00, 527983.89it/s]


In [8]:
result_copy

{'team_name': '',
 'model_name': '',
 'evaluation_time': '2025-09-08 19:41:45',
 'answers': [{'id': 0, 'pred': 'A'},
  {'id': 1, 'pred': 'B'},
  {'id': 2, 'pred': 'A'},
  {'id': 3, 'pred': 'B'},
  {'id': 4, 'pred': 'C'},
  {'id': 5, 'pred': 'C'},
  {'id': 6, 'pred': 'C'},
  {'id': 7, 'pred': 'C'},
  {'id': 8, 'pred': 'B'},
  {'id': 9, 'pred': 'A'},
  {'id': 10, 'pred': 'D'},
  {'id': 11, 'pred': 'A'},
  {'id': 12, 'pred': 'A'},
  {'id': 13, 'pred': 'B'},
  {'id': 14, 'pred': 'C'},
  {'id': 15, 'pred': 'A'},
  {'id': 16, 'pred': 'B'},
  {'id': 17, 'pred': 'D'},
  {'id': 18, 'pred': 'C'},
  {'id': 19, 'pred': 'C'},
  {'id': 20, 'pred': 'B'},
  {'id': 21, 'pred': 'C'},
  {'id': 22, 'pred': 'A'},
  {'id': 23, 'pred': 'D'},
  {'id': 24, 'pred': 'D'},
  {'id': 25, 'pred': 'C'},
  {'id': 26, 'pred': 'B'},
  {'id': 27, 'pred': 'D'},
  {'id': 28, 'pred': 'D'},
  {'id': 29, 'pred': 'A'},
  {'id': 30, 'pred': 'C'},
  {'id': 31, 'pred': 'C'},
  {'id': 32, 'pred': 'B'},
  {'id': 33, 'pred': 'D'},
 

In [9]:
# print(text_pred)
# print(tokenizer.decode(text_pred))
# print(q_ids)
# print(tokenizer.decode(q_ids))
# print(ans_ids)
# print(tokenizer.decode(ans_ids))