# KGAT Inference & Explanation (Pure PyTorch)

這個 Notebook 展示如何載入訓練好的 KGAT 模型，並使用基於梯度的解釋器 (Gradient-based Explainer) 來解釋推薦結果。

In [None]:
import os
import sys
import torch
import networkx as nx
import matplotlib.pyplot as plt

# Make the project root importable so the `from src...` imports below resolve.
# '..' is resolved against the current working directory (normally this
# notebook's folder when the kernel is started there).
module_path = os.path.abspath(os.pardir)
already_on_path = module_path in sys.path
if not already_on_path:
    sys.path.append(module_path)

from src.model.kgat import KGAT
from src.model.explainer import KGATExplainer
from src.train import construct_adj, load_data

## 1. 載入資料與模型

In [None]:
# Select the device: GPU when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load the preprocessed data.
data_dir = '../data/processed'
interactions, kg_triples, stats = load_data(data_dir)

n_users = stats['n_users']
n_items = stats['n_items']
n_entities = stats['n_entities']
n_relations = stats['n_relations']

# Build the adjacency matrix. Argument order matters:
# (kg_triples, n_users, n_items, n_entities).
adj = construct_adj(kg_triples, n_users, n_items, n_entities).to(device)

# Initialise the model with an entity table sized as items + KG entities.
# NOTE(review): this presumably has to match the sizing used at training
# time, otherwise load_state_dict will fail on shape mismatch — confirm.
n_all_entities = n_items + n_entities
model = KGAT(n_users, n_all_entities, n_relations).to(device)

model_path = '../models/kgat_epoch_20.pth'
if os.path.exists(model_path):
    # map_location remaps stored tensors onto the selected device, so a
    # GPU-trained checkpoint can be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("Model loaded successfully!")
else:
    # Best-effort: keep going so the pipeline can still be exercised, but
    # explanations produced from untrained weights are not meaningful.
    print("Warning: Pre-trained model not found. Using random weights.")

# Switch to inference mode so training-only behaviour (e.g. dropout, if the
# model uses any) does not randomise the scores being explained.
# eval() does not disable autograd, so gradient-based explanation still works.
model.eval()

## 2. 執行解釋
選擇一位使用者及其一個已互動的物品 (正樣本)，分析模型為何會推薦該物品。

In [None]:
# Explain the model's score for a target user and one of that user's observed
# (positive) items. Column 0 of `interactions` is matched against the user id
# and column 1 is read as the item id.
target_user = 0
user_interactions = interactions[interactions[:, 0] == target_user]

if len(user_interactions) > 0:
    # Indexing the array/tensor yields a scalar element (numpy scalar or 0-d
    # tensor); cast it to a plain int so it matches target_user's type and
    # prints as a bare number.
    target_item = int(user_interactions[0][1])
    print(f"Target User: {target_user}, Target Item: {target_item}")

    explainer = KGATExplainer(model)

    # top_k: number of most important paths to report.
    # n_hops: how many hops to search. The original note recommends 2 to
    #         match KGAT's default of 2 layers — confirm against the model config.
    explanation = explainer.explain(adj, target_user, target_item, top_k=5, n_hops=2)

    print(f"Prediction Score (Gradient Target): {explanation['target_score']:.4f}")
    print("Top Paths (Node sequence):")
    for path, score in explanation['top_paths']:
        print(f"Score {score:.4f}: {path}")

    # Visualise the explanation.
    explainer.visualize(explanation)

else:
    # Report the user actually examined rather than a hard-coded id.
    print(f"User {target_user} has no interactions to explain.")