In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import random
import numpy as np
import os
from tqdm import tqdm
import zipfile
import faiss  # 🔥 引入 FAISS

# Seed every random number generator this script uses so runs are repeatable.
# The cudnn flags only matter on CUDA devices; PyTorch's reproducibility notes
# recommend disabling benchmark mode together with deterministic=True.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

# Tag embedded in the output file name to identify this prediction scheme.
scheme_type = "ensemble_faiss"

# Data paths. Set the KG_BASE_DIR environment variable to use another checkout;
# the default is the original author's local path.
BASE_DIR = os.environ.get("KG_BASE_DIR", "/Users/minkexiu/Downloads/GitHub/Tianchi_EcommerceKG_mac")
TRAIN_FILE_PATH = f"{BASE_DIR}/originalData/OpenBG500/OpenBG500_train.tsv"
TEST_FILE_PATH = f"{BASE_DIR}/originalData/OpenBG500/OpenBG500_test.tsv"
DEV_FILE_PATH = f"{BASE_DIR}/originalData/OpenBG500/OpenBG500_dev.tsv"
OUTPUT_FILE_PATH = f"{BASE_DIR}/preprocessedData/OpenBG500_test__{scheme_type}.tsv"

# Directory for model checkpoints (created when this module is executed).
MODEL_DIR = f"{BASE_DIR}/trained_model"
os.makedirs(MODEL_DIR, exist_ok=True)

# Checkpoint file for each model name.
TRAINED_MODEL_PATHS = {
    'TransE': f"{MODEL_DIR}/trained_model__transE.pth",
    'TransH': f"{MODEL_DIR}/trained_model__transH.pth",
    'TransD': f"{MODEL_DIR}/trained_model__transD.pth",
    'ConvE': f"{MODEL_DIR}/trained_model__conve.pth",
    'RotatE': f"{MODEL_DIR}/trained_model__rotate.pth"
}

# Hyperparameters
EMBEDDING_DIM = 100       # embedding size passed to every model class
LEARNING_RATE = 0.001     # AdamW learning rate
WEIGHT_DECAY = 1e-5       # AdamW weight decay
EPOCHS = 1
BATCH_SIZE = 256
NEGATIVE_SAMPLES = 10     # corrupted triples per positive triple
MAX_LINES = None          # if set, only this many lines of the training file are read
MAX_HEAD_ENTITIES = None  # if set, only this many test queries are predicted
LR_DECAY_STEP = 5         # StepLR step size, in epochs
LR_DECAY_FACTOR = 0.1     # StepLR multiplicative decay


# ==================== 数据集 ====================
class KnowledgeGraphDataset(torch.utils.data.Dataset):
    """List of (head, relation, tail) string triples read from a whitespace-separated file.

    Lines with three fields become a triple. When ``is_test`` is true, lines with exactly two
    fields (head, relation) are also kept, with the placeholder tail ``"<UNK>"``. All other
    lines (blank or malformed) are skipped.

    Args:
        file_path: path of the triple file, read as UTF-8.
        is_test: also accept two-field (head, relation) query lines.
        max_lines: if not None, only the first ``max_lines`` lines of the file are read
            (0 reads nothing).
        is_train: stored as ``self.is_train``; EntityRelationMapper collects the triples of
            datasets flagged this way.
    """

    def __init__(self, file_path, is_test=False, max_lines=None, is_train=False):
        self.triples = []
        self.is_train = is_train
        self._load_data(file_path, is_test, max_lines)

    def _load_data(self, file_path, is_test, max_lines):
        """Parse ``file_path`` and append the accepted triples to ``self.triples``."""
        import itertools

        print(f"加载数据: {file_path}")
        with open(file_path, 'r', encoding='utf-8') as f:
            # Stream lines instead of readlines(): with max_lines only that prefix is read.
            # `is not None` makes max_lines=0 mean "no lines" rather than "all lines".
            lines = f if max_lines is None else itertools.islice(f, max_lines)
            for line in lines:
                parts = line.split()  # split() with no argument also strips the newline
                if len(parts) == 3:
                    self.triples.append(tuple(parts))
                elif is_test and len(parts) == 2:
                    self.triples.append((parts[0], parts[1], "<UNK>"))
        print(f"共加载 {len(self.triples)} 个三元组")

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        return self.triples[idx]


def collate_fn(batch):
    """Split a batch of (head, relation, tail) triples into three parallel lists.

    Used as the DataLoader collate function so the string symbols stay plain Python lists
    instead of being collated into tensors.
    """
    heads, relations, tails = zip(*batch)
    return [*heads], [*relations], [*tails]


# ==================== 映射器 ====================
class EntityRelationMapper:
    """Assigns contiguous integer IDs to entity and relation strings.

    New symbols receive IDs in sorted string order. Every symbol that appears in a dataset is
    mapped, including the ``"<UNK>"`` tail placeholder that KnowledgeGraphDataset inserts
    for test queries.
    """

    def __init__(self):
        self.entity_to_id = {}
        self.id_to_entity = {}
        self.relation_to_id = {}
        self.id_to_relation = {}
        self.entity_count = 0
        self.relation_count = 0
        self.all_train_triples = []  # triples of datasets whose is_train flag is true

    @staticmethod
    def _assign_ids(symbols, to_id, to_symbol, next_id):
        """Give every symbol not yet in ``to_id`` the next free ID (sorted order); return the new count."""
        for symbol in sorted(symbols.difference(to_id)):
            to_id[symbol] = next_id
            to_symbol[next_id] = symbol
            next_id += 1
        return next_id

    def build_mappings(self, *datasets):
        """Map every entity and relation of ``datasets`` to an integer ID.

        Each dataset must expose ``triples`` (iterable of (h, r, t)) and ``is_train``.
        Symbols that already have an ID keep it and only unseen symbols are appended, so
        repeated calls never produce duplicate or conflicting IDs. Triples of training
        datasets are appended to ``all_train_triples`` on every call.
        """
        entities = set()
        relations = set()
        for dataset in datasets:
            for h, r, t in dataset.triples:
                entities.add(h)
                entities.add(t)
                relations.add(r)
                if dataset.is_train:
                    self.all_train_triples.append((h, r, t))

        self.entity_count = self._assign_ids(
            entities, self.entity_to_id, self.id_to_entity, self.entity_count)
        self.relation_count = self._assign_ids(
            relations, self.relation_to_id, self.id_to_relation, self.relation_count)


# ==================== TransE ====================
class TransE(nn.Module):
    """TransE: a triple (h, r, t) is plausible when E[h] + R[r] lies close to E[t].

    ``forward`` returns one L1 distance per triple; lower means more plausible.
    """

    def __init__(self, num_entities, num_relations, dim):
        super().__init__()
        self.E = nn.Embedding(num_entities, dim)
        self.R = nn.Embedding(num_relations, dim)
        for table in (self.E, self.R):
            nn.init.xavier_uniform_(table.weight)

    def forward(self, h, r, t):
        """Return ||E[h] + R[r] - E[t]||_1 for each row."""
        translated = self.get_query_embedding(h, r)
        return torch.norm(translated - self.E(t), p=1, dim=1)

    def get_query_embedding(self, h, r):
        """Return the translated head E[h] + R[r], where the true tail should be."""
        head_vec = self.E(h)
        rel_vec = self.R(r)
        return head_vec + rel_vec


# ==================== TransH ====================
class TransH(nn.Module):
    """TransH: entities are projected onto a relation-specific hyperplane, then translated.

    ``W[r]`` holds the hyperplane normal (normalised on use) and ``R[r]`` the translation.
    ``forward`` returns one L1 distance per triple; lower means more plausible.
    """

    def __init__(self, num_entities, num_relations, dim):
        super().__init__()
        self.E = nn.Embedding(num_entities, dim)
        self.R = nn.Embedding(num_relations, dim)
        self.W = nn.Embedding(num_relations, dim)
        for table in (self.E, self.R, self.W):
            nn.init.xavier_uniform_(table.weight)

    def project(self, emb, norm):
        """Remove from each row of ``emb`` its component along the unit-length ``norm`` row."""
        unit = torch.nn.functional.normalize(norm, p=2, dim=1)
        coeff = (emb * unit).sum(dim=1, keepdim=True)
        return emb - coeff * unit

    def forward(self, h, r, t):
        """Return ||P(E[h]) + R[r] - P(E[t])||_1 with P the projection onto W[r]'s hyperplane."""
        normal = self.W(r)
        head_on_plane = self.project(self.E(h), normal)
        tail_on_plane = self.project(self.E(t), normal)
        return torch.norm(head_on_plane + self.R(r) - tail_on_plane, p=1, dim=1)

    def get_query_embedding(self, h, r):
        """Return the projected head translated by R[r]."""
        return self.project(self.E(h), self.W(r)) + self.R(r)


# ==================== TransD ====================
class TransD(nn.Module):
    """TransD: entity- and relation-specific dynamic projections followed by translation.

    Entity e has an embedding E[e] and a projection vector E_proj[e]; relation r has a
    translation R[r] and a projection vector R_proj[r]. With equal entity/relation sizes the
    TransD mapping M_re = R_proj[r] E_proj[e]^T + I gives

        e_perp = E[e] + R_proj[r] * <E_proj[e], E[e]>

    and ``forward`` returns ||h_perp + R[r] - t_perp||_1 per triple (lower is more plausible).

    NOTE: earlier versions ignored E_proj and added the scalar <E[e], R_proj[r]> to every
    coordinate. Checkpoints trained that way load (same parameter shapes) but are not meaningful
    under the corrected projection; retrain them.
    """

    def __init__(self, num_entities, num_relations, dim):
        super().__init__()
        self.dim = dim
        self.E = nn.Embedding(num_entities, dim)
        self.R = nn.Embedding(num_relations, dim)
        self.E_proj = nn.Embedding(num_entities, dim)
        self.R_proj = nn.Embedding(num_relations, dim)
        nn.init.xavier_uniform_(self.E.weight)
        nn.init.xavier_uniform_(self.R.weight)
        nn.init.xavier_uniform_(self.E_proj.weight)
        nn.init.xavier_uniform_(self.R_proj.weight)

    def project(self, e, r_proj, e_proj=None):
        """Project entity embeddings into the relation space.

        Args:
            e: entity embeddings, one row per example.
            r_proj: relation projection vectors, same shape as ``e``.
            e_proj: entity projection vectors, same shape as ``e``. When given, the TransD
                mapping ``e + r_proj * <e_proj, e>`` is applied. When omitted, the legacy
                formula ``e + <e, r_proj>`` (scalar added to every coordinate) is returned
                for backward compatibility.
        """
        if e_proj is None:
            return e + torch.sum(e * r_proj, dim=1, keepdim=True)
        return e + r_proj * torch.sum(e_proj * e, dim=1, keepdim=True)

    def forward(self, h, r, t):
        r_proj = self.R_proj(r)
        h_emb = self.project(self.E(h), r_proj, self.E_proj(h))
        t_emb = self.project(self.E(t), r_proj, self.E_proj(t))
        return torch.norm(h_emb + self.R(r) - t_emb, p=1, dim=1)

    def get_query_embedding(self, h, r):
        """Return the projected head translated by R[r]."""
        h_emb = self.project(self.E(h), self.R_proj(r), self.E_proj(h))
        return h_emb + self.R(r)


# ==================== ConvE ====================
class ConvE(nn.Module):
    """Convolutional scorer over the outer product of head and relation embeddings.

    For each (h, r): the batch-normalised head E[h] (a d x 1 column) and relation R[r]
    (a 1 x d row) are multiplied into a [B, 1, d, d] map. The map then passes through:
    a 3x3 convolution (32 channels, no padding), ReLU, dropout, a linear layer back to
    d dims, dropout, and L2 normalisation.

    ``forward`` returns, per triple, the L1 distance between that vector and E[t], where
    lower is more plausible. This is the same convention as the translational models in
    this file.

    This is not the reshaped/concatenated-input architecture of the published ConvE. The fc
    layer has (d - 2)^2 * 32 * d weights (about 30.7M for d = 100).

    BatchNorm layers are used, so in training mode a batch must hold more than one example.
    """

    def __init__(self, num_entities, num_relations, dim, embedding_dim=None, feature_map_dropout=0.2):
        super().__init__()
        # `dim` used to be ignored in favour of a hard-coded embedding_dim=100. An explicit
        # embedding_dim still wins; otherwise `dim` is honoured like in the other models.
        if embedding_dim is None:
            embedding_dim = dim

        # Convolution hyper-parameters
        self.in_channels = 1
        self.out_channels = 32
        self.kernel_size = 3
        self.padding = 0
        if embedding_dim + 2 * self.padding < self.kernel_size:
            raise ValueError(
                f"ConvE embedding_dim must be >= {self.kernel_size - 2 * self.padding}, got {embedding_dim}")

        self.embedding_dim = embedding_dim
        self.hid_drop = feature_map_dropout

        # Embedding tables
        self.E = nn.Embedding(num_entities, embedding_dim)
        self.R = nn.Embedding(num_relations, embedding_dim)

        self.conv1 = nn.Conv2d(in_channels=self.in_channels,
                               out_channels=self.out_channels,
                               kernel_size=(self.kernel_size, self.kernel_size),
                               padding=self.padding)

        # Flattened conv output size: (d - k + 2p + 1)^2 spatial positions x out_channels.
        self.fc_dim = (embedding_dim - self.kernel_size + 2 * self.padding + 1) ** 2 * self.out_channels
        self.fc = nn.Linear(self.fc_dim, embedding_dim)

        self.h_bn = nn.BatchNorm2d(1)
        self.r_bn = nn.BatchNorm1d(embedding_dim)
        self.dropout = nn.Dropout(self.hid_drop)
        self.hidden_drop = nn.Dropout(self.hid_drop)

        nn.init.xavier_normal_(self.E.weight)
        nn.init.xavier_normal_(self.R.weight)

    def _encode(self, h, r):
        """Return the L2-normalised predicted tail vector for each (h, r) pair, shape [B, d]."""
        h_emb = self.h_bn(self.E(h).view(-1, 1, self.embedding_dim, 1))
        # BatchNorm1d only accepts [B, C] or [B, C, L] input: normalise the relation embedding
        # while it is still [B, d], then reshape (applying it to the 4-D view raises).
        r_emb = self.r_bn(self.R(r)).view(-1, 1, 1, self.embedding_dim)

        stacked = h_emb * r_emb  # per-example outer product: [B, 1, d, d]
        x = torch.relu(self.conv1(stacked))
        x = self.hidden_drop(x)
        x = self.fc(x.view(-1, self.fc_dim))
        x = self.dropout(x)
        return torch.nn.functional.normalize(x, p=2, dim=1)

    def forward(self, h, r, t):
        """Return the per-triple L1 distance between the encoded (h, r) and E[t], shape [B]."""
        return torch.norm(self._encode(h, r) - self.E(t), p=1, dim=1)

    def get_query_embedding(self, h, r):
        """Return the same normalised vector that ``forward`` compares against E[t], [B, d]."""
        return self._encode(h, r)


# ==================== RotatE ====================
class RotatE(nn.Module):
    """RotatE: relations act as element-wise rotations in the complex plane.

    An entity embedding of size ``dim`` is read as ``dim // 2`` complex numbers: the first half
    holds the real parts, the second half the imaginary parts. A relation stores ``dim // 2``
    values that become phases ``R[r] / (dim / 2) * pi``. ``forward`` returns
    sum_k |h_k * exp(i * phase_k) - t_k| per triple (lower is more plausible).

    ``gamma`` is stored but not used by the scoring functions of this class.
    """

    def __init__(self, num_entities, num_relations, dim, gamma=12.0):
        super().__init__()
        if dim % 2:
            raise ValueError(f"RotatE needs an even embedding dim, got {dim}")
        self.E = nn.Embedding(num_entities, dim)
        # One phase per complex coordinate. A relation table with `dim` columns cannot
        # broadcast against the dim // 2 real/imaginary halves (forward used to crash).
        self.R = nn.Embedding(num_relations, dim // 2)
        self.gamma = gamma
        self.embedding_dim = dim

        bound = 6 / dim ** 0.5
        nn.init.uniform_(self.E.weight, -bound, bound)
        nn.init.uniform_(self.R.weight, -bound, bound)
        # Rescale every row to unit L2 norm, in place and outside autograd.
        with torch.no_grad():
            self.E.weight.div_(self.E.weight.norm(p=2, dim=1, keepdim=True))
            self.R.weight.div_(self.R.weight.norm(p=2, dim=1, keepdim=True))

    def _relation_rotation(self, r):
        """Return (cos(phase), sin(phase)) for relations ``r``, each [B, dim // 2]."""
        phase = self.R(r) / (self.embedding_dim / 2) * np.pi
        return torch.cos(phase), torch.sin(phase)

    def _rotate_head(self, h, r):
        """Return the real and imaginary parts of E[h] rotated by relation ``r``."""
        re_head, im_head = torch.chunk(self.E(h), 2, dim=1)
        cos, sin = self._relation_rotation(r)
        return re_head * cos - im_head * sin, re_head * sin + im_head * cos

    def forward(self, h, r, t):
        re_rot, im_rot = self._rotate_head(h, r)
        re_tail, im_tail = torch.chunk(self.E(t), 2, dim=1)
        # Modulus of the complex difference per coordinate, summed over coordinates.
        diff = torch.stack([re_rot - re_tail, im_rot - im_tail], dim=0)
        return torch.sum(diff.norm(dim=0), dim=1)

    def get_query_embedding(self, h, r):
        """Return the rotated head in the entity layout [real | imag], shape [B, dim]."""
        return torch.cat(self._rotate_head(h, r), dim=1)

    def get_query_embedding_for_head(self, r, t):
        """Return E[t] rotated by the inverse relation (t * exp(-i * phase)), shape [B, dim]."""
        re_tail, im_tail = torch.chunk(self.E(t), 2, dim=1)
        cos, sin = self._relation_rotation(r)
        return torch.cat([re_tail * cos + im_tail * sin, im_tail * cos - re_tail * sin], dim=1)


# ==================== 训练 & 加载 & 评估（同前）====================
def train_model(model, model_name, train_dataset, mapper, device):
    """Fit one model with a margin ranking loss and checkpoint it; no-op if a checkpoint exists.

    For every positive (h, r, t) in a batch, NEGATIVE_SAMPLES corrupted triples are formed by
    substituting uniformly random tail IDs. The objective is
    mean(relu(score(h, r, t) - score(h, r, t') + 1)), so lower scores are treated as better.
    Optimisation uses AdamW plus a StepLR schedule stepped once per epoch.

    The checkpoint written to TRAINED_MODEL_PATHS[model_name] holds the state dict, the entity
    and relation counts, EMBEDDING_DIM, and the mapper's entity/relation ID tables. An existing
    checkpoint is reused without retraining.
    """
    checkpoint_path = TRAINED_MODEL_PATHS[model_name]
    if os.path.exists(checkpoint_path):
        print(f"[{model_name}] 已存在训练好的模型，跳过训练")
        return
    print(f"[{model_name}] 开始训练...")

    batches = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_STEP, gamma=LR_DECAY_FACTOR)
    model.to(device)
    model.train()

    def to_ids(names, table):
        # Look up each symbol and return the IDs as a 1-D tensor on the training device.
        return torch.tensor([table[name] for name in names], device=device)

    def repeat_rows(ids):
        # [a, b] -> [a, a, ..., b, b, ...]: each ID repeated NEGATIVE_SAMPLES times back to back,
        # matching the row-major flattening of the corrupted-tail matrix.
        return ids.unsqueeze(1).expand(-1, NEGATIVE_SAMPLES).reshape(-1)

    for epoch_no in range(1, EPOCHS + 1):
        running_loss = 0
        progress = tqdm(batches, desc=f"{model_name} Epoch {epoch_no}")
        for heads, rels, tails in progress:
            h = to_ids(heads, mapper.entity_to_id)
            r = to_ids(rels, mapper.relation_to_id)
            t = to_ids(tails, mapper.entity_to_id)
            corrupt_tails = torch.randint(0, mapper.entity_count, (len(h), NEGATIVE_SAMPLES), device=device)

            positive = model(h, r, t)
            negative = model(repeat_rows(h), repeat_rows(r), corrupt_tails.view(-1))
            negative = negative.view(-1, NEGATIVE_SAMPLES)
            loss = torch.relu(positive.unsqueeze(1) - negative + 1.0).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            running_loss += batch_loss
            progress.set_postfix(loss=batch_loss)
        scheduler.step()
        print(f"[{model_name}] Epoch {epoch_no} Loss: {running_loss / len(batches):.4f}")

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'entity_count': mapper.entity_count,
        'relation_count': mapper.relation_count,
        'embedding_dim': EMBEDDING_DIM,
        'entity_to_id': mapper.entity_to_id,
        'relation_to_id': mapper.relation_to_id,
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"[{model_name}] 模型已保存")


def load_model(model_class, model_path, mapper, device):
    """Rebuild a model from a checkpoint written by train_model and switch it to eval mode.

    Args:
        model_class: class constructed as ``model_class(entity_count, relation_count, dim)``.
        model_path: checkpoint file path.
        mapper: the current EntityRelationMapper. Its ID tables are compared with those stored
            in the checkpoint; a checkpoint trained under a different mapping would score the
            wrong entities, so a warning is issued on mismatch.
        device: device to load the model onto.

    Returns:
        The model on ``device`` in eval mode.
    """
    import warnings

    checkpoint = torch.load(model_path, map_location=device)
    # Rebuild with the dimension the checkpoint was trained with (falls back to the current
    # EMBEDDING_DIM for checkpoints without that key) so load_state_dict shapes match.
    dim = checkpoint.get('embedding_dim', EMBEDDING_DIM)
    model = model_class(checkpoint['entity_count'], checkpoint['relation_count'], dim)
    model.load_state_dict(checkpoint['model_state_dict'])

    for key, current in (('entity_to_id', mapper.entity_to_id), ('relation_to_id', mapper.relation_to_id)):
        saved = checkpoint.get(key)
        if saved is not None and saved != current:
            warnings.warn(
                f"{model_path}: stored {key} differs from the current mapping; "
                f"delete the checkpoint and retrain to avoid misaligned IDs.")

    model.to(device)
    model.eval()
    return model


def _rank_in_index(index, query, target_id, search_k=1000, miss_rank=10000):
    """Return the 1-based position of ``target_id`` among the nearest neighbours of ``query``.

    Only the first ``search_k`` neighbours returned by ``index.search`` are inspected; a target
    outside them gets ``miss_rank``.
    """
    _, indices = index.search(np.ascontiguousarray(query, dtype=np.float32), search_k)
    positions = np.flatnonzero(indices[0] == target_id)
    return int(positions[0]) + 1 if positions.size else miss_rank


def evaluate_model(model, dataset, mapper, device, k_list=(1, 3, 10)):
    """Report raw (unfiltered) Hits@k and MRR of ``model`` on ``dataset``.

    Each triple (h, r, t) is evaluated twice. Tail prediction ranks t against the query
    ``model.get_query_embedding(h, r)``. Head prediction ranks h against
    ``model.get_query_embedding_for_head(r, t)`` when the model defines it, and otherwise
    against E[t] - R[r].

    Ranking uses exact L2 nearest-neighbour search (faiss.IndexFlatL2) over the rows of
    ``model.E``. This approximates, rather than reproduces, each model's own scoring function.
    Only the top 1000 neighbours are searched; targets outside them get rank 10000. Triples
    with unmapped symbols are skipped.

    Returns:
        (hits_at, mrr): dict k -> Hits@k, and the mean reciprocal rank, both averaged over
        all head and tail predictions (both 0 when nothing could be evaluated).
    """
    print("🔍 开始在开发集上评估模型性能...")
    model.eval()
    hits_at = {k: 0.0 for k in k_list}
    mrr = 0.0
    count = 0

    # Exact L2 index over the entity table (FAISS expects contiguous float32 on the CPU).
    entity_emb = np.ascontiguousarray(model.E.weight.detach().cpu().numpy(), dtype=np.float32)
    index = faiss.IndexFlatL2(entity_emb.shape[1])
    index.add(entity_emb)
    has_head_query = hasattr(model, 'get_query_embedding_for_head')

    with torch.no_grad():
        for h, r, t in tqdm(dataset.triples, desc="Evaluating"):
            try:
                h_idx = mapper.entity_to_id[h]
                r_idx = mapper.relation_to_id[r]
                t_idx = mapper.entity_to_id[t]
            except KeyError:
                continue
            h_id = torch.tensor([h_idx], device=device)
            r_id = torch.tensor([r_idx], device=device)
            t_id = torch.tensor([t_idx], device=device)

            tail_query = model.get_query_embedding(h_id, r_id)
            if has_head_query:
                head_query = model.get_query_embedding_for_head(r_id, t_id)
            else:
                head_query = model.E(t_id) - model.R(r_id)

            # Tail prediction, then head prediction; both count towards the averages.
            for query, target in ((tail_query, t_idx), (head_query, h_idx)):
                rank = _rank_in_index(index, query.cpu().numpy(), target)
                for k in k_list:
                    if rank <= k:
                        hits_at[k] += 1
                mrr += 1.0 / rank
                count += 1

    if count:
        hits_at = {k: hits / count for k, hits in hits_at.items()}
        mrr /= count

    print("✅ 评估完成！")
    for k in k_list:
        label = f"HITS@{k}:"
        print(f"📊 {label:<8} {hits_at[k]:.4f}")
    print(f"📊 MRR:     {mrr:.4f}")

    return hits_at, mrr


# ==================== 融合预测 + FAISS 加速 ====================
def _build_entity_index(model):
    """Return an exact L2 FAISS index over the rows of ``model.E`` (copied to CPU as float32)."""
    entity_emb = np.ascontiguousarray(model.E.weight.detach().cpu().numpy(), dtype=np.float32)
    index = faiss.IndexFlatL2(entity_emb.shape[1])
    index.add(entity_emb)
    return index


def predict_ensemble(models_with_weights, test_dataset, mapper, device, max_head_entities=None,
                     top_k=10, candidate_k=100, rrf_k=60):
    """Write fused top-k tail predictions for every test (head, relation) query.

    Each model is searched in its own entity space with an exact L2 FAISS index, retrieving
    its ``candidate_k`` nearest entities. The per-model rankings are then combined with
    weighted reciprocal-rank fusion: score(e) = sum_m weight_m / (rrf_k + rank_m(e)). The
    ``top_k`` best entities are written; the ``"<UNK>"`` test placeholder is never predicted.

    Args:
        models_with_weights: dict name -> (model, weight). Each model needs ``E`` and
            ``get_query_embedding``.
        test_dataset: dataset whose ``triples`` are (head, relation, tail); the tail is ignored.
        mapper: EntityRelationMapper shared by all models.
        device: device the models live on.
        max_head_entities: if truthy, only the first that many test triples are predicted.
        top_k: number of predictions written per query.
        candidate_k: neighbours retrieved from each model before fusion.
        rrf_k: smoothing constant of reciprocal-rank fusion.

    Writes OUTPUT_FILE_PATH with one line per query, ``head<TAB>relation<TAB>pred_1...``.
    Queries with an unknown head or relation get the head repeated ``top_k`` times. The
    file is also compressed into a .zip next to it.

    Raises:
        ValueError: if ``models_with_weights`` is empty.
    """
    print("🔍 开始融合预测 (加权融合 + FAISS 加速) ...")
    if not models_with_weights:
        raise ValueError("predict_ensemble needs at least one model")

    # The models are trained independently, so their embedding spaces are unrelated. Adding
    # their query vectors and searching one model's entity table mixes incompatible
    # coordinates. Instead, search each model in its own space and fuse only the rankings.
    members = [(model, weight, _build_entity_index(model))
               for model, weight in models_with_weights.values()]
    # "<UNK>" is the placeholder tail of test queries, not a real entity.
    placeholder_id = mapper.entity_to_id.get("<UNK>")
    search_k = max(candidate_k, top_k + 1)

    triples = test_dataset.triples
    if max_head_entities:
        triples = triples[:max_head_entities]

    results = []
    with torch.no_grad():
        for h, r, _ in tqdm(triples, desc="Ensemble Predict"):
            try:
                h_id = torch.tensor([mapper.entity_to_id[h]], device=device)
                r_id = torch.tensor([mapper.relation_to_id[r]], device=device)
            except KeyError:
                results.append('\t'.join([h, r] + [h] * top_k))
                continue

            fused_scores = {}
            for model, weight, index in members:
                query = np.ascontiguousarray(
                    model.get_query_embedding(h_id, r_id).cpu().numpy(), dtype=np.float32)
                _, indices = index.search(query, search_k)
                rank = 0
                for ent_id in indices[0].tolist():
                    # FAISS pads with -1 when the index holds fewer than search_k vectors.
                    if ent_id < 0 or ent_id == placeholder_id:
                        continue
                    rank += 1
                    fused_scores[ent_id] = fused_scores.get(ent_id, 0.0) + weight / (rrf_k + rank)

            best_ids = sorted(fused_scores, key=fused_scores.get, reverse=True)[:top_k]
            preds = [mapper.id_to_entity[ent_id] for ent_id in best_ids]
            results.append('\t'.join([h, r] + preds))

    os.makedirs(os.path.dirname(OUTPUT_FILE_PATH), exist_ok=True)
    with open(OUTPUT_FILE_PATH, 'w', encoding='utf-8') as f:
        f.write('\n'.join(results) + '\n')
    print(f"✅ 融合结果已保存至: {OUTPUT_FILE_PATH}")

    zip_path = OUTPUT_FILE_PATH.replace(".tsv", ".zip")
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        zf.write(OUTPUT_FILE_PATH, arcname=os.path.basename(OUTPUT_FILE_PATH))
    print(f"✅ 已压缩为: {zip_path}")


# ==================== 主函数 ====================
def main():
    """Run the pipeline: load data, train or reuse each model, evaluate on dev, predict test.

    Device preference is CUDA, then Apple MPS, then CPU.
    """
    # torch builds without the MPS backend have no torch.backends.mps attribute.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    print(f"🚀 使用设备: {device}")

    train_data = KnowledgeGraphDataset(TRAIN_FILE_PATH, max_lines=MAX_LINES, is_train=True)
    dev_data = KnowledgeGraphDataset(DEV_FILE_PATH, is_test=False, is_train=False)
    test_data = KnowledgeGraphDataset(TEST_FILE_PATH, is_test=True, is_train=False)

    # Map symbols from all three splits so dev/test entities and relations get IDs too.
    mapper = EntityRelationMapper()
    mapper.build_mappings(train_data, dev_data, test_data)
    print(f"实体数: {mapper.entity_count}, 关系数: {mapper.relation_count}")

    model_classes = {
        'TransE': TransE,
        'TransH': TransH,
        'TransD': TransD,
        'ConvE': ConvE,
        'RotatE': RotatE,
    }

    # Train every model (train_model returns early when its checkpoint already exists).
    for name, model_cls in model_classes.items():
        model = model_cls(mapper.entity_count, mapper.relation_count, EMBEDDING_DIM)
        train_model(model, name, train_data, mapper, device)

    # Reload each checkpoint, evaluate it on dev, and give it weight 1.0 in the ensemble.
    loaded_models_with_weight = {}
    for name, model_cls in model_classes.items():
        model = load_model(model_cls, TRAINED_MODEL_PATHS[name], mapper, device)
        loaded_models_with_weight[name] = (model, 1.0)
        print(f"\n📈 正在评估模型: {name}")
        evaluate_model(model, dev_data, mapper, device)

    # Ensemble prediction on the test queries
    predict_ensemble(loaded_models_with_weight, test_data, mapper, device, MAX_HEAD_ENTITIES)
    print("🎉 所有任务完成！融合预测已生成，开发集评估已完成。")


if __name__ == "__main__":
    main()