In [None]:
import json
import numpy as np
import tensorflow as tf
from transformers import BertTokenizer, TFBertForSequenceClassification
from tqdm import tqdm

def _masked_mean(states, mask):
    """Average the rows of `states` whose entry in `mask` is 1."""
    total = tf.reduce_sum(states * mask[:, tf.newaxis], axis=0)
    # divide_no_nan yields a zero vector instead of NaN if the mask is empty.
    return tf.math.divide_no_nan(total, tf.reduce_sum(mask))


def is_match(claim, text, similarity_threshold=0.7):
    """Return True when `text` is similar enough to `claim`.

    The two strings are encoded together as one BERT sentence pair, the
    encoder's per-token hidden states are mean-pooled separately over the
    claim segment and the text segment, and the cosine similarity of the
    two pooled vectors is compared against `similarity_threshold`.

    Relies on the module-level `tokenizer` and `loaded_model` globals, which
    must be assigned before the first call.

    Args:
        claim: The claim sentence (first segment of the pair).
        text: The candidate evidence sentence (second segment of the pair).
        similarity_threshold: Minimum cosine similarity counted as a match.

    Returns:
        A plain Python bool, safe to use directly in an `if` statement.
    """
    # Encode claim and text as a single pair. Passing the raw strings lets
    # the tokenizer do the word-piece splitting itself.
    inputs = tokenizer.encode_plus(
        claim,
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
        return_attention_mask=True,
        return_tensors='tf'
    )

    # Per-token hidden states of the encoder. Only one pair is encoded, so
    # the batch dimension has size 1 and is dropped with the trailing [0].
    token_states = loaded_model.bert(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        token_type_ids=inputs['token_type_ids']
    )[0][0]

    # token_type_ids is 0 on the claim segment and 1 on the text segment;
    # attention_mask is 0 on padding, so padding is excluded from both
    # averages. The special tokens of each segment are included.
    attention = tf.cast(inputs['attention_mask'][0], token_states.dtype)
    segment = tf.cast(inputs['token_type_ids'][0], token_states.dtype)
    claim_embedding = _masked_mean(token_states, attention * (1.0 - segment))
    text_embedding = _masked_mean(token_states, attention * segment)

    # True cosine similarity in [-1, 1]. tf.keras.losses.cosine_similarity
    # is a loss and returns the negated value, so it is not used here.
    claim_unit = tf.math.l2_normalize(claim_embedding, axis=0)
    text_unit = tf.math.l2_normalize(text_embedding, axis=0)
    similarity = float(tf.reduce_sum(claim_unit * text_unit))

    return similarity >= similarity_threshold

# Restore the sequence classifier from 'bert_model'.
print("載入模型")
loaded_model = TFBertForSequenceClassification.from_pretrained('bert_model')

# Read the test set (one JSON object per line), keeping only id and claim.
print("測試資料預處理")
with open('public_test.jsonl', 'r', encoding='utf-8') as f:
    parsed_records = (json.loads(raw_line) for raw_line in tqdm(f))
    test_data = [
        {'id': record['id'], 'claim': record['claim']}
        for record in parsed_records
    ]

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

# Tokenize each claim on its own, padded/truncated to 512 tokens, so every
# encoding contributes one row to the batch tensors built below.
encoded_claims = [
    tokenizer.encode_plus(
        item['claim'],
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='tf'
    )
    for item in tqdm(test_data)
]

# Stack the per-claim rows along the batch axis.
test_input_ids = tf.concat(
    [encoding['input_ids'] for encoding in encoded_claims], axis=0
)
test_attention_masks = tf.concat(
    [encoding['attention_mask'] for encoding in encoded_claims], axis=0
)

# Run the classifier over the whole test set in one predict() call.
print("進行預測")
model_inputs = {
    'input_ids': test_input_ids,
    'attention_mask': test_attention_masks,
}
predictions = loaded_model.predict(x=model_inputs)

# Turn each claim's logits into a label and, for claims that are not
# "not enough info", collect evidence from the wiki dump files.
print("根據預測結果進行後續處理")
output_data = []
for i, entry in enumerate(test_data):
    prediction = predictions.logits[i]
    # Compute the winning class once instead of on every matching line.
    predicted_class = int(prediction.argmax())
    predicted_evidence = []

    if predicted_class == 2:  # class 2 means "not enough info"
        predicted_label = 'not enough info'
    else:
        # The label comes from the classifier alone. Setting it before the
        # evidence search guarantees it is never left as an empty string
        # when no wiki line matches the claim.
        predicted_label = 'supports' if predicted_class == 0 else 'refutes'

        for j in range(1, 25):  # scan wiki-001.jsonl through wiki-024.jsonl
            wiki_file = f'wiki-{j:03d}.jsonl'
            with open(wiki_file, 'r', encoding='utf-8') as wiki_f:
                for line_num, line in enumerate(wiki_f):
                    wiki_entry = json.loads(line)
                    if is_match(entry['claim'], wiki_entry['text']):
                        predicted_evidence.append([wiki_entry['id'], line_num])
                        # Keep at most one evidence line per wiki file.
                        break
            # Stop scanning further files once five pieces are collected.
            if len(predicted_evidence) >= 5:
                break

    output_data.append({
        'id': entry['id'],
        'predicted_label': predicted_label,
        'predicted_evidence': predicted_evidence
    })

# Write one JSON object per line; ensure_ascii=False keeps non-ASCII text
# readable in the output file instead of \u-escaping it.
output_file = 'predictions.jsonl'
with open(output_file, 'w', encoding='utf-8') as f:
    f.writelines(
        json.dumps(record, ensure_ascii=False) + '\n'
        for record in output_data
    )

print('預測結果輸出完成')