In [28]:
#thu vien
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import networkx as nx
import numpy as np
from sklearn.metrics import top_k_accuracy_score
from tqdm import tqdm
import random
import requests
import networkx as nx
import torch
import torch.nn as nn


In [29]:
import json
import csv

# Đường dẫn file (đã tải từ bước trước)
TRIPLE_PATH = "/wikidata_triples.json"
DATA_PATH = "/wikidata_data.json"
LABEL_PATH = "/label2id.json"

# Load dữ liệu
with open(TRIPLE_PATH, "r", encoding="utf-8") as f:
    triples = json.load(f)

with open(DATA_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

with open(LABEL_PATH, "r", encoding="utf-8") as f:
    label2id = json.load(f)

id2label = {v: k for k, v in label2id.items()}


In [30]:
# Tạo KG từ triples
G = nx.Graph()
for s, r, o in triples:
    G.add_node(s)
    G.add_node(o)
    G.add_edge(s, o, relation=r)

# Ánh xạ node → chỉ số embedding
vocab = list(G.nodes)
node2id = {node: idx for idx, node in enumerate(vocab)}

# Khởi tạo embedding vector cho node
embedding_dim = 128
node_embeddings = nn.Embedding(len(vocab), embedding_dim)


In [31]:
#danh sách label phân loại
label_list = sorted(set(d["object"] for d in data))
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for label, idx in label2id.items()}

In [32]:
#KG
G = nx.Graph()
for s, r, o in triples:
    G.add_node(s)
    G.add_node(o)
    G.add_edge(s, o, relation=r)

vocab = list(G.nodes)
node2id = {node: idx for idx, node in enumerate(vocab)}
embedding_dim = 128
node_embeddings = nn.Embedding(len(vocab), embedding_dim)

In [33]:
#CONCEPT VECTOR GENERATOR (multi-head)
class ConceptVectorGenerator(nn.Module):
    def __init__(self, embed_dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )

    def forward(self, center_vec, neighbor_vecs):
        Q = self.q_proj(center_vec)
        K = self.k_proj(neighbor_vecs)
        V = self.v_proj(neighbor_vecs)
        attn_scores = torch.softmax(torch.matmul(Q, K.T), dim=-1)
        context = torch.matmul(attn_scores, V)
        return self.mlp(context)

In [34]:
#CHỨC NĂNG SINH CONCEPT VECTORS
def get_concept_vectors(subject, generator, n=5):
    if subject not in node2id:
        return torch.zeros((n, embedding_dim))

    center_id = torch.tensor([node2id[subject]])
    center_vec = node_embeddings(center_id).squeeze(0)
    neighbors = list(G.neighbors(subject))

    if not neighbors:
        return torch.zeros((n, embedding_dim))

    neighbor_ids = torch.tensor([node2id[n] for n in neighbors])
    neighbor_vecs = node_embeddings(neighbor_ids)

    return torch.stack([generator(center_vec, neighbor_vecs) for _ in range(n)], dim=0)

In [35]:
#frozen LLM + concept insert
class ConceptFormer(nn.Module):
    def __init__(self, base_model="vinai/phobert-base", n_concepts=5, num_classes=10):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.lm = AutoModel.from_pretrained(base_model)
        for param in self.lm.parameters():
            param.requires_grad = False
        self.generator = ConceptVectorGenerator(embedding_dim)
        self.n_concepts = n_concepts
        self.classifier = nn.Sequential(
            nn.Linear(768 + embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, sentence, subject):
        inputs = self.tokenizer(sentence, return_tensors="pt")
        outputs = self.lm(**inputs)
        text_embedding = outputs.last_hidden_state[:, 0, :]  # CLS

        concept_vecs = get_concept_vectors(subject, self.generator, n=self.n_concepts)
        concept_pooled = torch.mean(concept_vecs, dim=0).unsqueeze(0)
        combined = torch.cat([text_embedding, concept_pooled], dim=-1)
        return self.classifier(combined)


In [37]:
#huan luyen & danh gia
model = ConceptFormer(n_concepts=5, num_classes=len(label_list))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

EPOCHS = 30
for epoch in range(EPOCHS):
    total_loss = 0
    model.train()
    for sample in data:
        sentence = sample["sentence"]
        subject = sample["subject"]
        label = torch.tensor([label2id[sample["object"]]])

        logits = model(sentence, subject)
        loss = criterion(logits, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(data):.4f}")

Epoch 1/30, Loss: 3.2449
Epoch 2/30, Loss: 3.2222
Epoch 3/30, Loss: 3.2009
Epoch 4/30, Loss: 3.1778
Epoch 5/30, Loss: 3.1607
Epoch 6/30, Loss: 3.1359
Epoch 7/30, Loss: 3.1081
Epoch 8/30, Loss: 3.0898
Epoch 9/30, Loss: 3.0639
Epoch 10/30, Loss: 3.0478
Epoch 11/30, Loss: 3.0233
Epoch 12/30, Loss: 3.0283
Epoch 13/30, Loss: 3.0089
Epoch 14/30, Loss: 3.0031
Epoch 15/30, Loss: 2.9842
Epoch 16/30, Loss: 2.9725
Epoch 17/30, Loss: 2.9710
Epoch 18/30, Loss: 2.9517
Epoch 19/30, Loss: 2.9495
Epoch 20/30, Loss: 2.9313
Epoch 21/30, Loss: 2.9344
Epoch 22/30, Loss: 2.9075
Epoch 23/30, Loss: 2.9021
Epoch 24/30, Loss: 2.9078
Epoch 25/30, Loss: 2.8924
Epoch 26/30, Loss: 2.8818
Epoch 27/30, Loss: 2.8818
Epoch 28/30, Loss: 2.8723
Epoch 29/30, Loss: 2.8693
Epoch 30/30, Loss: 2.8494


In [38]:
# Đánh giá Hit@1, Hit@3, Hit@5
model.eval()
y_true, y_scores = [], []

with torch.no_grad():
    for sample in data:
        sentence = sample["sentence"]
        subject = sample["subject"]
        label = label2id[sample["object"]]

        logits = model(sentence, subject)
        probs = torch.softmax(logits, dim=-1).squeeze(0).numpy()
        y_true.append(label)
        y_scores.append(probs)

for k in [1, 3, 5]:
    hit_k = top_k_accuracy_score(y_true, y_scores, k=k)
    print(f"Hit@{k}: {hit_k*100:.2f}%")

Hit@1: 12.80%
Hit@3: 33.80%
Hit@5: 50.00%
