In [None]:
import pandas as pd
import networkx as nx
import dgl
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
import os
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict

# 1. Đọc dữ liệu và tạo đồ thị
df = pd.read_csv('triplets_from_captions_val2017.csv').dropna().astype(str)
triplets = list(zip(df['subject'], df['predicate'], df['object'], df['image_id']))

G = nx.DiGraph()
for s, p, o, img_id in triplets:
    mid = f"{s}_{p}_{o}"
    G.add_edge(s, mid, relation=p)
    G.add_edge(mid, o, relation=p)
    G.nodes[mid]['image_id'] = img_id

In [11]:
# 2. Mã hóa node và quan hệ
all_nodes = list(G.nodes)
le = LabelEncoder()
node_ids = le.fit_transform(all_nodes)
node_id_map = dict(zip(all_nodes, node_ids))

edges_src = [node_id_map[u] for u, v in G.edges()]
edges_dst = [node_id_map[v] for u, v in G.edges()]
edge_types = [G[u][v]['relation'] for u, v in G.edges()]
rel_encoder = LabelEncoder()
edge_type_ids = rel_encoder.fit_transform(edge_types)
edge_type_tensor = torch.tensor(edge_type_ids, dtype=torch.int64)

In [12]:
# 3. Tạo đồ thị DGL
g = dgl.graph((edges_src, edges_dst), num_nodes=len(all_nodes))
g = dgl.add_self_loop(g)
num_self_loops = g.number_of_edges() - len(edge_type_tensor)
self_loop_type = torch.full((num_self_loops,), fill_value=len(set(edge_type_ids)), dtype=torch.int64)
edge_type_tensor = torch.cat([edge_type_tensor, self_loop_type], dim=0)

features = torch.eye(len(all_nodes))
num_rels = len(set(edge_type_tensor.tolist()))

In [13]:
# 4. Định nghĩa mô hình R-GCN
class RGCN(nn.Module):
    def __init__(self, in_feats, h_feats, out_feats, num_rels):
        super(RGCN, self).__init__()
        self.conv1 = dgl.nn.RelGraphConv(in_feats, h_feats, num_rels)
        self.conv2 = dgl.nn.RelGraphConv(h_feats, out_feats, num_rels)

    def forward(self, g, x, etype):
        h = self.conv1(g, x, etype)
        h = torch.relu(h)
        h = self.conv2(g, h, etype)
        return h

model = RGCN(features.shape[1], 64, 128, num_rels)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [14]:
# 5. Khởi tạo và chạy mô hình R-GCN (không huấn luyện)
model = RGCN(features.shape[1], 64, 128, num_rels)
model.eval()
with torch.no_grad():
    node_embeddings = model(g, features, edge_type_tensor)

In [15]:
# 6. Trung bình embedding ảnh
embeddings = node_embeddings.numpy()
image_embeddings = defaultdict(list)

for node, idx in node_id_map.items():
    if 'image_id' in G.nodes[node]:
        img_id = G.nodes[node]['image_id']
        image_embeddings[img_id].append(embeddings[idx])

for img in image_embeddings:
    vecs = np.stack(image_embeddings[img])
    image_embeddings[img] = np.mean(vecs, axis=0)

In [16]:
# 7. Lưu lại mô hình và embedding
os.makedirs("saved_model_R_GCN", exist_ok=True)
with open("saved_model_R_GCN/image_embeddings.pkl", "wb") as f:
    pickle.dump(image_embeddings, f)
with open("saved_model_R_GCN/node_label_encoder.pkl", "wb") as f:
    pickle.dump(le, f)
torch.save(model.state_dict(), "saved_model_R_GCN/rgcn_model_weights.pt")

print("Đã sinh xong node embedding từ R-GCN và lưu model.")

Đã sinh xong node embedding từ R-GCN và lưu model.


In [17]:
# 8. Đánh giá Precision, Recall, F1-score
entity_to_images = defaultdict(set)
for _, row in df.iterrows():
    entity_to_images[row['subject']].add(row['image_id'])
    entity_to_images[row['predicate']].add(row['image_id'])
    entity_to_images[row['object']].add(row['image_id'])

ground_truth = defaultdict(set)
for _, row in df.iterrows():
    img_id = row['image_id']
    related = entity_to_images[row['subject']] | entity_to_images[row['predicate']] | entity_to_images[row['object']]
    related.discard(img_id)
    ground_truth[img_id].update(related)

image_ids = list(image_embeddings.keys())
embedding_matrix = np.stack([image_embeddings[i] for i in image_ids])
sim_matrix = cosine_similarity(embedding_matrix)

predicted = {}
for i, qid in enumerate(image_ids):
    sims = sim_matrix[i]
    sorted_idx = np.argsort(-sims)
    top_ids = [image_ids[j] for j in sorted_idx if image_ids[j] != qid][:5]
    predicted[qid] = top_ids

def evaluate_retrieval(gt, pred, k=5):
    precision, recall, f1 = [], [], []
    for q in gt:
        g = gt[q]
        p = set(pred.get(q, [])[:k])
        tp = len(g & p)
        prec = tp / k
        rec = tp / len(g) if g else 0
        f1_score = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0
        precision.append(prec)
        recall.append(rec)
        f1.append(f1_score)
    return np.mean(precision), np.mean(recall), np.mean(f1)

p, r, f1 = evaluate_retrieval(ground_truth, predicted, k=5)
print(f"\nPrecision@5: {p:.4f}\nRecall@5:    {r:.4f}\nF1-score@5:  {f1:.4f}")


Precision@5: 0.3339
Recall@5:    0.0061
F1-score@5:  0.0114


In [10]:
import os
import pickle
import numpy as np
from collections import defaultdict

# Tạo thư mục nếu chưa có
output_dir = "saved_model_R_GCN"
os.makedirs(output_dir, exist_ok=True)

# 1. Lưu LabelEncoder cho entity (node encoder)
with open(os.path.join(output_dir, "entity_encoder.pkl"), "wb") as f:
    pickle.dump(le, f)  # le là LabelEncoder dùng cho all_nodes

# 2. Lưu node embeddings
np.save(os.path.join(output_dir, "node_embeddings.npy"), embeddings)

# 3. Tạo entity_idx → danh sách ảnh (từ triplets)
entity_idx_to_images = defaultdict(set)
for s, p, o, img_id in triplets:
    try:
        s_id = le.transform([s])[0]
        o_id = le.transform([o])[0]
        entity_idx_to_images[s_id].add(img_id)
        entity_idx_to_images[o_id].add(img_id)
    except:
        continue  # bỏ qua nếu không ánh xạ được

# 4. Lưu entity_idx_to_images
with open(os.path.join(output_dir, "entity_idx_to_images.pkl"), "wb") as f:
    pickle.dump(entity_idx_to_images, f)

print("Đã lưu toàn bộ: entity_encoder.pkl, node_embeddings.npy, entity_idx_to_images.pkl vào thư mục 'saved_model_R_GCN/'.")

Đã lưu toàn bộ: entity_encoder.pkl, node_embeddings.npy, entity_idx_to_images.pkl vào thư mục 'saved_models/'.


In [35]:
results_prf = estimate_prf(top_k=5, random_k=5, sample_size=100)
print("\n Ước lượng Precision / Recall / F1-score R-GCN:")
print(results_prf)


 Ước lượng Precision / Recall / F1-score R-GCN:
{'Precision': 0.4444, 'Recall': 1.0, 'F1-score': 0.6154, 'Samples Used': 100}
