# KGAT Training on Food.com Dataset

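This notebook trains a KGAT (Knowledge Graph Attention Network) model for recipe recommendation on the Food.com dataset. **Status:** the graph construction and the training loop below are still placeholders.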

In [None]:
!pip install torch dgl pandas scikit-learn tqdm

In [None]:
import os
import sys
import torch
import dgl
import pandas as pd
import numpy as np
from tqdm import tqdm

# Mount Drive if needed
# from google.colab import drive
# drive.mount('/content/drive')

# Add src to path
# sys.path.append('/content/drive/MyDrive/path/to/src')

# Since we might upload files directly:
sys.path.append('src')

## 1. Load Data
Ensure `data/processed/` contains the preprocessed `interactions.pkl`, `kg_triples.pkl`, and `stats.pkl` files.

In [None]:
# Load the artifacts written by the preprocessing step, in a fixed order.
# NOTE: unpickling can execute arbitrary code — only load files this project produced.
data_dir = 'data/processed'
interactions, kg_triples, stats = (
    pd.read_pickle(f'{data_dir}/{artifact}.pkl')
    for artifact in ('interactions', 'kg_triples', 'stats')
)

# Dataset sizes recorded during preprocessing, looked up by key in `stats`.
n_users, n_items, n_entities, n_relations = (
    stats[key] for key in ('n_users', 'n_items', 'n_entities', 'n_relations')
)

print(f"Users: {n_users}, Items: {n_items}, Entities: {n_entities}")

## 2. Build Graph
Construct a DGL graph from the knowledge-graph triples.

In [None]:
from src.model.kgat import KGAT

# Build a homogeneous DGL graph from the KG triples.
# `pd.read_pickle` returns whatever object was pickled: `np.asarray` accepts both a
# DataFrame and an ndarray, whereas positional `[:, k]` indexing raises on a DataFrame.
# The explicit int64 cast gives DGL integer node IDs.
# Columns are assumed to be (head, relation, tail) — TODO confirm against preprocessing.
kg_triples_np = np.asarray(kg_triples, dtype=np.int64)
src = torch.tensor(kg_triples_np[:, 0])
dst = torch.tensor(kg_triples_np[:, 2])

# NOTE(review): entity IDs are used unshifted and the relation column
# (kg_triples_np[:, 1]) is dropped. KGAT normally works on a collaborative KG where
# users and items/entities occupy disjoint ID ranges, user-item interaction edges are
# included, and relation types feed the attention.
# TODO: map all IDs into one global space, add interaction edges, keep edge types.
# Placeholder graph
g = dgl.graph((src, dst), num_nodes=n_users + n_entities)
g = dgl.add_self_loop(g)

## 3. Train Model

In [None]:
# Seed the RNGs before parameter initialisation so Restart & Run All reproduces the
# same model. numpy is seeded too, in case batch sampling is implemented with it.
SEED = 42
LEARNING_RATE = 1e-3
torch.manual_seed(SEED)
np.random.seed(SEED)

model = KGAT(n_users, n_entities, n_relations)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

def bpr_loss(pos_scores, neg_scores):
    """Bayesian Personalized Ranking (BPR) loss.

    Computes ``-mean(log(sigmoid(pos_scores - neg_scores)))``.

    Args:
        pos_scores: Tensor of scores for positive user-item pairs.
        neg_scores: Tensor of scores for negative pairs; must broadcast against
            ``pos_scores``.

    Returns:
        A scalar tensor holding the mean loss.
    """
    # logsigmoid is computed stably. log(sigmoid(x)) underflows to log(0) = -inf for very
    # negative x, which makes the loss inf and its gradient NaN.
    return -torch.nn.functional.logsigmoid(pos_scores - neg_scores).mean()

# NOTE: placeholder loop — batch sampling and the KGAT forward pass are not wired up
# yet, so no parameters are updated. Intended per-epoch BPR step:
#   user_ids, pos_items, neg_items = sample_batch()   # TODO: implement sampler
#   pos_scores = model(g, user_ids, pos_items)
#   neg_scores = model(g, user_ids, neg_items)
#   loss = bpr_loss(pos_scores, neg_scores)
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()
epochs = 10
for epoch in range(epochs):
    model.train()
    # Say so explicitly instead of reporting "Training..." for an epoch that does nothing.
    print(f"Epoch {epoch}: training step not implemented (placeholder, no parameters updated)")