# Implementing a Knowledge Graph Attention Network (KGAT) for Book Recommendations
This guide walks through implementing a Knowledge Graph Attention Network (KGAT) for a book recommendation system. KGAT (Wang et al., KDD 2019) combines a knowledge graph with graph attention networks, so recommendations can draw on item attributes such as authorship and can be explained through the graph connections that produced them. The model built here is a simplified, single-layer variant of the published architecture.

## Understanding KGAT
Knowledge Graph Attention Network (KGAT) is a hybrid model that:

- Incorporates knowledge graphs to represent rich semantic relationships

- Uses graph attention networks to learn importance weights between nodes

- Captures high-order connectivity in the knowledge graph

The key components are:

- Entity embedding: represents books, users, and other entities

- Relation embedding: represents the different types of relationships

- Attention mechanism: learns importance weights between connected nodes
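
For reference, the published KGAT scores each triple (h, r, t) as π(h, r, t) = (W_r e_t)^T tanh(W_r e_h + e_r), and then normalizes those scores with a softmax over all triples that share the same head h. The numpy sketch below illustrates only that attention step. It assumes a toy graph, random embeddings, and one projection matrix W_r per relation; all names are illustrative and are not part of the code later in this guide.

```python
import numpy as np

def kgat_attention(heads, rels, tails, ent_emb, rel_emb, W_r):
    """Softmax-normalized KGAT-style attention for a list of (h, r, t) triples.

    heads, rels, tails: int arrays of equal length (one entry per triple)
    ent_emb: (n_entities, d) entity embeddings
    rel_emb: (n_relations, k) relation embeddings
    W_r:     (n_relations, k, d) relation-specific projection matrices
    """
    proj_t = np.einsum('nkd,nd->nk', W_r[rels], ent_emb[tails])  # W_r e_t
    proj_h = np.einsum('nkd,nd->nk', W_r[rels], ent_emb[heads])  # W_r e_h
    scores = np.sum(proj_t * np.tanh(proj_h + rel_emb[rels]), axis=1)

    # Softmax over the triples that share each head entity
    weights = np.empty_like(scores)
    for h in np.unique(heads):
        mask = heads == h
        s = scores[mask] - scores[mask].max()  # subtract max for numerical stability
        weights[mask] = np.exp(s) / np.exp(s).sum()
    return weights

rng = np.random.default_rng(0)
ent_emb = rng.normal(size=(5, 4))
rel_emb = rng.normal(size=(2, 3))
W_r = rng.normal(size=(2, 3, 4))
heads = np.array([0, 0, 0, 1, 1])
rels = np.array([0, 0, 1, 0, 1])
tails = np.array([2, 3, 4, 2, 4])
att = kgat_attention(heads, rels, tails, ent_emb, rel_emb, W_r)
```

Because each head's weights sum to 1, a node's aggregated neighborhood becomes a weighted average rather than an unbounded sum.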

## Dataset Preparation
We'll use the Book-Crossing dataset, which contains:

- Book information (title, author, year, publisher)

- User information (location, age)

- Ratings (explicit ratings from 1 to 10; a rating of 0 marks an implicit interaction)
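
In Book-Crossing, a 0 records an implicit interaction rather than a low score, and the pipeline in this guide keeps both kinds in a single regression target. If you want to separate them, a plain-Python sketch could look like the following. It assumes ratings are (user_id, isbn, rating) tuples, and `split_ratings` is a hypothetical helper. The sample rows are the first five rows of `BX-Book-Ratings.csv` shown later in this guide.

```python
def split_ratings(ratings):
    """Split (user_id, isbn, rating) tuples into explicit (1-10) and implicit (0) lists."""
    explicit = [r for r in ratings if 1 <= r[2] <= 10]
    implicit = [r for r in ratings if r[2] == 0]
    return explicit, implicit

sample = [(276725, '034545104X', 0), (276726, '0155061224', 5),
          (276727, '0446520802', 0), (276729, '052165615X', 3),
          (276729, '0521795028', 6)]
explicit, implicit = split_ratings(sample)
```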

First, load and preprocess the data, then extend it with the entities and relations KGAT needs.

# Step 1: Load and Preprocess Datasets

In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import LabelEncoder
import pickle
import matplotlib.pyplot as plt
import networkx as nx

## Load Data

In [2]:
# Load every column as a string: 'Year-Of-Publication' mixes numbers and text,
# and ISBNs must keep their leading zeros
books_dtypes = {
    'ISBN': str,
    'Book-Title': str,
    'Book-Author': str,
    'Year-Of-Publication': str,  # Load as string to handle mixed types
    'Publisher': str,
    'Image-URL-S': str,
    'Image-URL-M': str,
    'Image-URL-L': str
}


In [3]:
print("Step 1: Loading datasets...")
# dtype=books_dtypes keeps ISBNs and years as strings (years are cleaned below)
books = pd.read_csv('data/BX-Books.csv', sep=';', encoding='latin-1', on_bad_lines='skip', dtype=books_dtypes)
users = pd.read_csv('data/BX-Users.csv', sep=';', encoding='latin-1', on_bad_lines='skip')
ratings = pd.read_csv('data/BX-Book-Ratings.csv', sep=';', encoding='latin-1', on_bad_lines='skip')

Step 1: Loading datasets...


In [4]:
books.head()

|   | ISBN | Book-Title | Book-Author | Year-Of-Publication | Publisher | Image-URL-S | Image-URL-M | Image-URL-L |
|---|------|------------|-------------|---------------------|-----------|-------------|-------------|-------------|
| 0 | 0195153448 | Classical Mythology | Mark P. O. Morford | 2002 | Oxford University Press | http://images.amazon.com/images/P/0195153448.0... | http://images.amazon.com/images/P/0195153448.0... | http://images.amazon.com/images/P/0195153448.0... |
| 1 | 0002005018 | Clara Callan | Richard Bruce Wright | 2001 | HarperFlamingo Canada | http://images.amazon.com/images/P/0002005018.0... | http://images.amazon.com/images/P/0002005018.0... | http://images.amazon.com/images/P/0002005018.0... |
| 2 | 0060973129 | Decision in Normandy | Carlo D'Este | 1991 | HarperPerennial | http://images.amazon.com/images/P/0060973129.0... | http://images.amazon.com/images/P/0060973129.0... | http://images.amazon.com/images/P/0060973129.0... |
| 3 | 0374157065 | Flu: The Story of the Great Influenza Pandemic... | Gina Bari Kolata | 1999 | Farrar Straus Giroux | http://images.amazon.com/images/P/0374157065.0... | http://images.amazon.com/images/P/0374157065.0... | http://images.amazon.com/images/P/0374157065.0... |
| 4 | 0393045218 | The Mummies of Urumchi | E. J. W. Barber | 1999 | W. W. Norton &amp; Company | http://images.amazon.com/images/P/0393045218.0... | http://images.amazon.com/images/P/0393045218.0... | http://images.amazon.com/images/P/0393045218.0... |


In [26]:
books.shape  # executed (In [26]) after the column selection in In [8], hence 5 of the original 8 columns

(271360, 5)

In [5]:
users.head()

|   | User-ID | Location | Age |
|---|---------|----------|-----|
| 0 | 1 | nyc, new york, usa | NaN |
| 1 | 2 | stockton, california, usa | 18.0 |
| 2 | 3 | moscow, yukon territory, russia | NaN |
| 3 | 4 | porto, v.n.gaia, portugal | 17.0 |
| 4 | 5 | farnborough, hants, united kingdom | NaN |


In [27]:
users.shape

(278858, 3)

In [6]:
ratings.head()

|   | User-ID | ISBN | Book-Rating |
|---|---------|------|-------------|
| 0 | 276725 | 034545104X | 0 |
| 1 | 276726 | 0155061224 | 5 |
| 2 | 276727 | 0446520802 | 0 |
| 3 | 276729 | 052165615X | 3 |
| 4 | 276729 | 0521795028 | 6 |


In [28]:
ratings.shape  # executed (In [28]) after the Step 2 user filter (In [9]), so this is the filtered row count

(526356, 3)

## Data Preprocessing

In [7]:
# Clean 'Year-Of-Publication' by converting to numeric, setting invalid values to NaN
books['Year-Of-Publication'] = pd.to_numeric(books['Year-Of-Publication'], errors='coerce')
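
`pd.to_numeric(..., errors='coerce')` turns anything that does not parse as a number into NaN. The plain-Python sketch below applies the same rule to a single value. It also adds one optional step that is an assumption of this sketch and not part of the pipeline: treating a year of 0 as missing, because 0 cannot be a real publication year. `coerce_year` is a hypothetical helper, and the sample values are illustrative.

```python
import math

def coerce_year(value, zero_is_missing=True):
    """Parse a publication year; return NaN for unparseable values (like errors='coerce')."""
    try:
        year = float(value)
    except (TypeError, ValueError):
        return math.nan
    # Optional extra rule (assumption): a year of 0 carries no information
    if zero_is_missing and year == 0:
        return math.nan
    return year

years = [coerce_year(v) for v in ['2002', '1991', 'DK Publishing Inc', '0']]
```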

In [8]:
# Select relevant columns and rename them in one chained call
# (avoids pandas' SettingWithCopyWarning from renaming a column slice in place)
books = books[['ISBN', 'Book-Title', 'Book-Author', 'Year-Of-Publication', 'Publisher']].rename(columns={
    'Book-Title': 'title',
    'Book-Author': 'author',
    'Year-Of-Publication': 'year',
    'Publisher': 'publisher'
})

users.rename(columns={
    'User-ID': 'user_id',
    'Location': 'location',
    'Age': 'age'
}, inplace=True)

ratings.rename(columns={
    'User-ID': 'user_id',
    'Book-Rating': 'rating'
}, inplace=True)

print("Books DataFrame head:")
print(books.head())
print("\nUsers DataFrame head:")
print(users.head())
print("\nRatings DataFrame head:")
print(ratings.head())

Books DataFrame head:
         ISBN                                              title  \
0  0195153448                                Classical Mythology   
1  0002005018                                       Clara Callan   
2  0060973129                               Decision in Normandy   
3  0374157065  Flu: The Story of the Great Influenza Pandemic...   
4  0393045218                             The Mummies of Urumchi   

                 author    year                   publisher  
0    Mark P. O. Morford  2002.0     Oxford University Press  
1  Richard Bruce Wright  2001.0       HarperFlamingo Canada  
2          Carlo D'Este  1991.0             HarperPerennial  
3      Gina Bari Kolata  1999.0        Farrar Straus Giroux  
4       E. J. W. Barber  1999.0  W. W. Norton &amp; Company  

Users DataFrame head:
   user_id                            location   age
0        1                  nyc, new york, usa   NaN
1        2           stockton, california, usa  18.0
2        3     

# Step 2: Filter Active Users and Popular Books

In [9]:
# Filter active users and popular books
print("\nStep 2: Filtering active users and popular books...")
# Keep users with more than 200 ratings (i.e. 201 or more)
user_counts = ratings['user_id'].value_counts()
active_users = user_counts[user_counts > 200].index
ratings = ratings[ratings['user_id'].isin(active_users)]
print(f"Number of active users: {len(active_users)}")


Step 2: Filtering active users and popular books...
Number of active users: 899
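
Note that `user_counts > 200` keeps users with more than 200 ratings, not users with exactly 200. Here is a plain-Python sketch of the same filter with an explicit, inclusive threshold. It assumes ratings are (user_id, isbn, rating) tuples, and `filter_active_users` is a hypothetical name.

```python
from collections import Counter

def filter_active_users(ratings, min_ratings):
    """Keep ratings from users with at least `min_ratings` ratings (inclusive)."""
    counts = Counter(user for user, _, _ in ratings)
    active = {user for user, n in counts.items() if n >= min_ratings}
    return [r for r in ratings if r[0] in active], active

ratings = [(1, 'a', 5), (1, 'b', 0), (1, 'c', 7), (2, 'a', 3), (2, 'b', 0), (3, 'c', 9)]
kept, active = filter_active_users(ratings, min_ratings=2)
```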


## Merge ratings with books

In [10]:
# Merge ratings with books
ratings_with_books = ratings.merge(books, on='ISBN')
print("Ratings with books merged DataFrame head:")
ratings_with_books.head()

Ratings with books merged DataFrame head:


|   | user_id | ISBN | rating | title | author | year | publisher |
|---|---------|------|--------|-------|--------|------|-----------|
| 0 | 277427 | 002542730X | 10 | Politically Correct Bedtime Stories: Modern Ta... | James Finn Garner | 1994.0 | John Wiley &amp; Sons Inc |
| 1 | 277427 | 0026217457 | 0 | Vegetarian Times Complete Cookbook | Lucy Moll | 1995.0 | John Wiley &amp; Sons |
| 2 | 277427 | 003008685X | 8 | Pioneers | James Fenimore Cooper | 1974.0 | Thomson Learning |
| 3 | 277427 | 0030615321 | 0 | Ask for May, Settle for June (A Doonesbury book) | G. B. Trudeau | 1982.0 | Henry Holt &amp; Co |
| 4 | 277427 | 0060002050 | 0 | On a Wicked Dawn (Cynster Novels) | Stephanie Laurens | 2002.0 | Avon Books |


In [29]:
ratings_with_books.shape

(487671, 7)

## Calculate number of ratings per book

In [11]:
# Calculate number of ratings per book
number_rating = ratings_with_books.groupby('title')['rating'].count().reset_index()
number_rating.rename(columns={'rating': 'num_of_rating'}, inplace=True)
print("Number of ratings per book:")
print(number_rating.head())

Number of ratings per book:
                                               title  num_of_rating
0   A Light in the Storm: The Civil War Diary of ...              2
1                              Always Have Popsicles              1
2               Apple Magic (The Collector's series)              1
3   Beyond IBM: Leadership Marketing and Finance ...              1
4   Clifford Visita El Hospital (Clifford El Gran...              1


## Filter books with at least 50 ratings

In [12]:
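# Keep titles with at least 50 ratings, then one rating per (user, title) pair
final_rating = ratings_with_books.merge(number_rating, on='title')
final_rating = final_rating[final_rating['num_of_rating'] >= 50]
# .copy() gives an independent frame, so adding columns in Step 3 does not raise SettingWithCopyWarning
final_rating = final_rating.drop_duplicates(['user_id', 'title']).copy()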
print("Final rating DataFrame after filtering (head):")
final_rating.head()

Final rating DataFrame after filtering (head):


|    | user_id | ISBN | rating | title | author | year | publisher | num_of_rating |
|----|---------|------|--------|-------|--------|------|-----------|---------------|
| 0 | 277427 | 002542730X | 10 | Politically Correct Bedtime Stories: Modern Ta... | James Finn Garner | 1994.0 | John Wiley &amp; Sons Inc | 82 |
| 13 | 277427 | 0060930535 | 0 | The Poisonwood Bible: A Novel | Barbara Kingsolver | 1999.0 | Perennial | 133 |
| 15 | 277427 | 0060934417 | 0 | Bel Canto: A Novel | Ann Patchett | 2002.0 | Perennial | 108 |
| 18 | 277427 | 0061009059 | 9 | One for the Money (Stephanie Plum Novels (Pape... | Janet Evanovich | 1995.0 | HarperTorch | 108 |
| 24 | 277427 | 006440188X | 0 | The Secret Garden | Frances Hodgson Burnett | 1998.0 | HarperTrophy | 79 |


In [13]:
print("Shape of final rating DataFrame")
final_rating.shape

Shape of final rating DataFrame


(59850, 8)
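
Counting by `title` pools every edition (ISBN) of a book, and the counts are taken before duplicates are removed. `drop_duplicates(['user_id', 'title'])` then keeps only the first rating a user gave to any edition. The plain-Python sketch below reproduces those two steps in the same order. It assumes (user_id, isbn, title, rating) tuples with placeholder values, and `popular_titles` is a hypothetical helper.

```python
from collections import Counter

def popular_titles(rows, min_count):
    """Keep rows whose title has >= min_count ratings, then one row per (user, title)."""
    counts = Counter(title for _, _, title, _ in rows)  # counted before de-duplication
    seen, result = set(), []
    for user, isbn, title, rating in rows:
        if counts[title] >= min_count and (user, title) not in seen:
            seen.add((user, title))
            result.append((user, isbn, title, rating))
    return result

rows = [
    (1, 'isbn-1', 'Title A', 9),
    (1, 'isbn-2', 'Title A', 8),   # another edition, same user: dropped
    (2, 'isbn-1', 'Title A', 10),
    (3, 'isbn-3', 'Title B', 0),   # only one rating: below min_count
]
kept = popular_titles(rows, min_count=2)
```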

# Step 3: Encode Entities and Relations

In [14]:
print("\nStep 3: Encoding entities and relations...")
# Encode user_ids, book titles, and authors
user_encoder = LabelEncoder()
book_encoder = LabelEncoder()
author_encoder = LabelEncoder()

final_rating['user_id_encoded'] = user_encoder.fit_transform(final_rating['user_id'])
final_rating['book_id_encoded'] = book_encoder.fit_transform(final_rating['title'])
final_rating['author_encoded'] = author_encoder.fit_transform(final_rating['author'])
print("Encoded user, book, and author IDs:")
print(final_rating[['user_id', 'user_id_encoded', 'title', 'book_id_encoded', 'author', 'author_encoded']].head())


Step 3: Encoding entities and relations...
Encoded user, book, and author IDs:
    user_id  user_id_encoded  \
0    277427              884   
13   277427              884   
15   277427              884   
18   277427              884   
24   277427              884   

                                                title  book_id_encoded  \
0   Politically Correct Bedtime Stories: Modern Ta...              396   
13                      The Poisonwood Bible: A Novel              621   
15                                 Bel Canto: A Novel               71   
18  One for the Money (Stephanie Plum Novels (Pape...              380   
24                                  The Secret Garden              642   

                     author  author_encoded  
0         James Finn Garner             249  
13       Barbara Kingsolver              53  
15             Ann Patchett              27  
18          Janet Evanovich             264  
24  Frances Hodgson Burnett             186  


In [15]:
# Create entity and relation mappings
n_users = len(user_encoder.classes_)
n_books = len(book_encoder.classes_)
n_authors = len(author_encoder.classes_)
print(f"Number of unique users: {n_users}")
print(f"Number of unique books: {n_books}")
print(f"Number of unique authors: {n_authors}")

Number of unique users: 888
Number of unique books: 742
Number of unique authors: 584
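
The knowledge graph in the next step places users, books, and authors in one shared ID space by offsetting each entity type. Users occupy [0, n_users), books occupy [n_users, n_users + n_books), and authors take the range after that. The sketch below shows that mapping and its inverse, using the counts printed above. `to_global` and `to_local` are hypothetical helpers.

```python
N_USERS, N_BOOKS, N_AUTHORS = 888, 742, 584

OFFSETS = {'user': 0, 'book': N_USERS, 'author': N_USERS + N_BOOKS}

def to_global(kind, local_id):
    """Map a per-type encoded ID to its index in the shared entity table."""
    return OFFSETS[kind] + local_id

def to_local(global_id):
    """Inverse of to_global: return (kind, local_id)."""
    if global_id < N_USERS:
        return 'user', global_id
    if global_id < N_USERS + N_BOOKS:
        return 'book', global_id - N_USERS
    if global_id < N_USERS + N_BOOKS + N_AUTHORS:
        return 'author', global_id - N_USERS - N_BOOKS
    raise ValueError(f"{global_id} is outside the entity table")
```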


# Step 4: Build Knowledge Graph

In [16]:
print("\nStep 4: Building knowledge graph...")
# Create user-book interaction edges (relation type: 'rated')
kg_edges = []
for _, row in final_rating.iterrows():
    kg_edges.append((row['user_id_encoded'], row['book_id_encoded'] + n_users, 'rated'))

# Add book-author relations
for _, row in final_rating.drop_duplicates('title').iterrows():
    kg_edges.append((row['book_id_encoded'] + n_users, row['author_encoded'] + n_users + n_books, 'written_by'))

# Convert edges to tensor
edge_index = []
edge_type = []
relation_types = {'rated': 0, 'written_by': 1}
for src, dst, rel in kg_edges:
    edge_index.append([src, dst])
    edge_type.append(relation_types[rel])

edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
edge_type = torch.tensor(edge_type, dtype=torch.long)
print("Knowledge graph edge index shape:", edge_index.shape)
print("Knowledge graph edge types shape:", edge_type.shape)


Step 4: Building knowledge graph...
Knowledge graph edge index shape: torch.Size([2, 60592])
Knowledge graph edge types shape: torch.Size([60592])
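
The 60,592 edges are the 59,850 `rated` edges (one per row of `final_rating`) plus one `written_by` edge for each of the 742 books. Every edge points from user to book or from book to author, and the model's aggregation step sums messages only into the head of each edge. As a result, books never receive information from their readers, and authors receive nothing. Adding an inverse edge for every triple is a common remedy. The sketch below assumes the relation IDs 0 and 1 from `relation_types`, and gives each inverse relation the ID `r + n_relations`, which is a convention of this sketch. `add_inverse_edges` is a hypothetical helper.

```python
def add_inverse_edges(edges, n_relations):
    """Given (src, dst, rel) triples, append (dst, src, rel + n_relations) for each one."""
    inverse = [(dst, src, rel + n_relations) for src, dst, rel in edges]
    return edges + inverse

# Toy triples with global IDs: user 0 rated book 888; book 888 written_by author 1630
edges = [(0, 888, 0), (888, 1630, 1)]
both = add_inverse_edges(edges, n_relations=2)
edge_index = [[s for s, _, _ in both], [d for _, d, _ in both]]
edge_type = [r for _, _, r in both]
```

If you adopt this, the model's relation embedding table needs `2 * len(relation_types)` rows.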


# Step 5: Define KGAT Model

In [17]:
# Define KGAT Model (simplified: one propagation layer, sigmoid-gated attention)
class KGAT(nn.Module):
    def __init__(self, n_entities, n_relations, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Entity and relation embeddings
        self.entity_emb = nn.Embedding(n_entities, embed_dim)
        self.relation_emb = nn.Embedding(n_relations, embed_dim)
        # Attention mechanism: scores each edge from [head; relation]
        self.attention = nn.Linear(embed_dim * 2, 1)
        self.W = nn.Linear(embed_dim, embed_dim)

        # Initialize embeddings
        nn.init.xavier_normal_(self.entity_emb.weight)
        nn.init.xavier_normal_(self.relation_emb.weight)

    def forward(self, edge_index, edge_type):
        head = self.entity_emb(edge_index[0])
        tail = self.entity_emb(edge_index[1])
        rel = self.relation_emb(edge_type)

        # Attention scores: an independent sigmoid gate per edge. Unlike the published
        # KGAT, these are not softmax-normalized over each node's neighbors.
        head_rel = torch.cat([head, rel], dim=-1)
        att_scores = torch.sigmoid(self.attention(head_rel))

        # Sum the gated tail embeddings into their head node with one vectorized
        # scatter-add (replaces a Python loop over every edge)
        neighbor_emb = att_scores * tail
        aggr_emb = torch.zeros(self.entity_emb.num_embeddings, self.embed_dim, device=head.device)
        aggr_emb = aggr_emb.index_add(0, edge_index[0], neighbor_emb)

        # Update embeddings
        return torch.tanh(self.W(aggr_emb))

    def predict(self, user_ids, item_ids, entity_emb=None):
        # Dot-product score. By default this uses the raw embedding table; pass the
        # output of forward() as entity_emb to score with propagated embeddings.
        emb = self.entity_emb.weight if entity_emb is None else entity_emb
        return (emb[user_ids] * emb[item_ids]).sum(dim=-1)
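
Summing each edge's message into its head node is a scatter-add. In PyTorch this is `Tensor.index_add`. The numpy equivalent is `np.add.at`, which accumulates correctly when an index repeats, whereas `out[idx] += msg` does not. The quick numpy check below uses random toy data to confirm that the vectorized form matches an explicit per-edge loop.

```python
import numpy as np

rng = np.random.default_rng(42)
n_nodes, n_edges, dim = 6, 20, 4
heads = rng.integers(0, n_nodes, size=n_edges)
messages = rng.normal(size=(n_edges, dim))  # e.g. gated tail embeddings, one row per edge

# Reference: explicit loop over edges
loop_sum = np.zeros((n_nodes, dim))
for i in range(n_edges):
    loop_sum[heads[i]] += messages[i]

# Vectorized scatter-add
vec_sum = np.zeros((n_nodes, dim))
np.add.at(vec_sum, heads, messages)

# Pitfall: fancy-index += keeps only one message per repeated head index
buggy = np.zeros((n_nodes, dim))
buggy[heads] += messages
```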

# Step 6: Training Setup

In [30]:
from sklearn.model_selection import train_test_split
print("\nStep 6: Setting up training...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_entities = n_users + n_books + n_authors
n_relations = len(relation_types)
embed_dim = 64
model = KGAT(n_entities, n_relations, embed_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
edge_index = edge_index.to(device)
edge_type = edge_type.to(device)

# Prepare training and test data
train_data = final_rating[['user_id_encoded', 'book_id_encoded', 'rating']].copy()
train_data['book_id_encoded'] = train_data['book_id_encoded'] + n_users  # Offset book IDs
train_data, test_data = train_test_split(train_data, test_size=0.2, random_state=42)
print("Training data sample:")
print(train_data.head())
print("\nTest data sample:")
print(test_data.head())


Step 6: Setting up training...
Training data sample:
        user_id_encoded  book_id_encoded  rating
14222                24             1338       0
28563                28             1252       9
473722              856             1253      10
203805              346             1081       6
56037                83             1315       0

Test data sample:
        user_id_encoded  book_id_encoded  rating
484226              879             1174      10
271798              488              951       8
66307               101             1626       0
335782              607             1599       0
430341              769             1435       7
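
The knowledge graph from Step 4 was built from every rating, including those that end up in `test_data` here. As a result, every test pair also appears as a `rated` edge in `edge_index`. A leak-free variant splits first and builds the `rated` edges from the training pairs only. The plain-Python sketch below assumes (user_global, book_global, rating) tuples, and `split_then_build` is a hypothetical helper.

```python
import random

def split_then_build(interactions, test_frac=0.2, seed=42):
    """Split interactions, then build 'rated' edges (relation 0) from the training part only."""
    shuffled = interactions[:]
    random.Random(seed).shuffle(shuffled)
    n_test = int(len(shuffled) * test_frac)
    test, train = shuffled[:n_test], shuffled[n_test:]
    rated_edges = [(u, b, 0) for u, b, _ in train]
    return train, test, rated_edges

# Toy data: 10 users x 5 books, book IDs already offset by n_users = 888
interactions = [(u, 888 + b, (u + b) % 11) for u in range(10) for b in range(5)]
train, test, rated_edges = split_then_build(interactions)
```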


# Step 7: Training Loop

In [31]:
# Seed for reproducibility. The model was already initialized in the previous
# cell, so these seeds do not fix its initial weights.
torch.manual_seed(42)
np.random.seed(42)

print("\nStep 7: Training KGAT model...")
def train(model, optimizer, edge_index, edge_type, train_data, epochs=10):
    model.train()
    # Build the training tensors once rather than on every epoch
    user_ids = torch.tensor(train_data['user_id_encoded'].values, dtype=torch.long, device=device)
    item_ids = torch.tensor(train_data['book_id_encoded'].values, dtype=torch.long, device=device)
    ratings = torch.tensor(train_data['rating'].values, dtype=torch.float, device=device)
    for epoch in range(epochs):
        optimizer.zero_grad()
        # Forward pass. predict() below is called without these propagated embeddings,
        # so the attention and W layers get no gradient; only the raw embeddings train.
        entity_emb = model(edge_index, edge_type)

        # MSE between dot-product scores and ratings
        pred_scores = model.predict(user_ids, item_ids)
        loss = F.mse_loss(pred_scores, ratings)

        # Backward pass
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

train(model, optimizer, edge_index, edge_type, train_data)


Step 7: Training KGAT model...
Epoch 1, Loss: 16.5146
Epoch 2, Loss: 16.5117
Epoch 3, Loss: 16.5088
Epoch 4, Loss: 16.5059
Epoch 5, Loss: 16.5029
Epoch 6, Loss: 16.4999
Epoch 7, Loss: 16.4969
Epoch 8, Loss: 16.4938
Epoch 9, Loss: 16.4906
Epoch 10, Loss: 16.4874
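
The loss falls only from 16.51 to 16.49 over ten full-batch epochs at `lr=0.001`. In addition, MSE treats every implicit 0 as a genuine low rating. The published KGAT trains its collaborative part with the pairwise BPR loss instead, which only requires an observed item to score above an unobserved one: L = -mean(log sigmoid(score(u, i) - score(u, j))). The numpy sketch below computes that loss on hypothetical score arrays. Sampling the negative items j is left out.

```python
import numpy as np

def bpr_loss(pos_scores, neg_scores):
    """Bayesian Personalized Ranking loss: -mean(log(sigmoid(pos - neg)))."""
    diff = np.asarray(pos_scores, dtype=float) - np.asarray(neg_scores, dtype=float)
    # -log(sigmoid(x)) == log(1 + exp(-x)); np.logaddexp(0, -x) computes it stably
    return float(np.mean(np.logaddexp(0.0, -diff)))

equal = bpr_loss([1.0, 2.0], [1.0, 2.0])   # no preference either way
good = bpr_loss([5.0, 4.0], [0.0, -1.0])   # positives ranked well above negatives
bad = bpr_loss([0.0, -1.0], [5.0, 4.0])    # ranking reversed
```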


## Performance Evaluation

In [33]:
# Performance Evaluation
print("\nEvaluating model performance...")
def evaluate_model(model, edge_index, edge_type, test_data, k=5):
    model.eval()
    with torch.no_grad():
        # Forward pass (predict() below scores with the raw embedding table,
        # so these propagated embeddings do not affect the metrics)
        entity_emb = model(edge_index, edge_type)
        
        # Compute predictions for test set
        user_ids = torch.tensor(test_data['user_id_encoded'].values, dtype=torch.long).to(device)
        item_ids = torch.tensor(test_data['book_id_encoded'].values, dtype=torch.long).to(device)
        true_ratings = torch.tensor(test_data['rating'].values, dtype=torch.float).to(device)
        
        pred_scores = model.predict(user_ids, item_ids)
        
        # Calculate RMSE
        rmse = torch.sqrt(F.mse_loss(pred_scores, true_ratings)).item()
        print(f"RMSE on test set: {rmse:.4f}")
        
        # Precision@K over the whole test set: the k highest-scoring (user, book)
        # pairs overall, not a separate ranking per user
        _, top_k_indices = pred_scores.topk(k, dim=0)
        top_k_true_ratings = true_ratings[top_k_indices]
        relevant_count = (top_k_true_ratings >= 6).sum().item()  # Assumes rating >= 6 is relevant
        precision_at_k = relevant_count / k
        print(f"Precision@{k} on test set: {precision_at_k:.4f}")

evaluate_model(model, edge_index, edge_type, test_data)


Evaluating model performance...
RMSE on test set: 4.1100
Precision@5 on test set: 0.2000
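
The Precision@5 above is computed from the five highest-scoring (user, book) pairs in the entire test set, so it reflects only five predictions. The more common formulation ranks each user's test books separately and then averages over users. The plain-Python sketch below assumes (user, book, true_rating, predicted_score) tuples and keeps the same relevance threshold of 6. `precision_at_k_per_user` is a hypothetical helper.

```python
from collections import defaultdict

def precision_at_k_per_user(rows, k=5, threshold=6):
    """Mean over users of (relevant items among that user's top-k predictions) / k."""
    by_user = defaultdict(list)
    for user, book, true_rating, score in rows:
        by_user[user].append((score, true_rating))
    precisions = []
    for items in by_user.values():
        top = sorted(items, key=lambda x: x[0], reverse=True)[:k]
        hits = sum(1 for _, true_rating in top if true_rating >= threshold)
        precisions.append(hits / k)
    return sum(precisions) / len(precisions)

rows = [
    ('u1', 'b1', 8, 0.9), ('u1', 'b2', 0, 0.8), ('u1', 'b3', 7, 0.1),
    ('u2', 'b1', 10, 0.7), ('u2', 'b4', 9, 0.6),
]
p = precision_at_k_per_user(rows, k=2)
```

Dividing by `k` penalizes users with fewer than `k` test books. Dividing by `min(k, len(items))` is a common alternative.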


# Step 8: Recommendation Function

In [20]:
print("\nStep 8: Defining recommendation function...")
def recommend_books(user_id, model, book_encoder, n_users, top_k=5):
    model.eval()
    with torch.no_grad():
        try:
            user_id_encoded = user_encoder.transform([user_id])[0]
        except ValueError:
            print(f"User ID {user_id} not found in the dataset.")
            return [], []
        user_tensor = torch.tensor([user_id_encoded], dtype=torch.long).to(device)
        
        # Get all book embeddings
        book_ids = torch.arange(n_users, n_users + n_books, dtype=torch.long).to(device)
        scores = model.predict(user_tensor.expand(len(book_ids)), book_ids)
        
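        # Get top-k book indices. Positions in book_ids equal book_encoder codes
        # (book_ids[i] == n_users + i), so they can be decoded directly.
        _, top_indices = scores.topk(top_k)
        top_book_ids = top_indices.cpu().numpy()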
        
        # Decode book titles
        recommended_books = book_encoder.inverse_transform(top_book_ids)
        return recommended_books, top_indices


Step 8: Defining recommendation function...
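
`recommend_books` ranks every book in the catalogue, including books the user has already rated, so a user's own favorites can come back as recommendations. A common refinement is to mask those books out before taking the top-k. The numpy sketch below assumes `scores` holds one score per book code and `rated` holds the codes the user has already rated. `top_k_unseen` is a hypothetical helper.

```python
import numpy as np

def top_k_unseen(scores, rated, k=5):
    """Return the indices of the k highest scores, skipping already-rated items."""
    masked = np.asarray(scores, dtype=float).copy()
    masked[list(rated)] = -np.inf  # rated items can never be selected
    order = np.argsort(-masked, kind='stable')
    return order[:k]

scores = np.array([0.9, 0.1, 0.8, 0.7, 0.3, 0.95])
rec = top_k_unseen(scores, rated={0, 5}, k=3)
```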


# Step 9: Visualization Functions

In [21]:
print("\nStep 9: Defining visualization functions...")
def visualize_recommendations(user_id, recommended_books, top_indices, model, book_encoder, n_users):
    model.eval()
    with torch.no_grad():
        user_id_encoded = user_encoder.transform([user_id])[0]
        user_tensor = torch.tensor([user_id_encoded], dtype=torch.long).to(device)
        book_ids = torch.arange(n_users, n_users + n_books, dtype=torch.long).to(device)
        scores = model.predict(user_tensor.expand(len(book_ids)), book_ids)
        top_scores = scores[top_indices].cpu().numpy()

    plt.figure(figsize=(10, 5))
    plt.bar(recommended_books, top_scores, color='skyblue')
    plt.xlabel('Book Titles')
    plt.ylabel('Predicted Score')
    plt.title(f'Top-{len(recommended_books)} Recommended Books for User {user_id}')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig('recommendations_bar_chart.png')
    plt.close()

def visualize_kg_subgraph(user_id, recommended_books, final_rating, book_encoder, user_encoder, author_encoder):
    G = nx.DiGraph()
    user_id_encoded = user_encoder.transform([user_id])[0]
    G.add_node(f"User_{user_id}", label=f"User {user_id}", type='user')

    # Add recommended books and their authors
    for book in recommended_books:
        book_id_encoded = book_encoder.transform([book])[0]
        G.add_node(f"Book_{book_id_encoded}", label=book, type='book')
        # The user has not necessarily rated a recommended book, so label the edge accordingly
        G.add_edge(f"User_{user_id}", f"Book_{book_id_encoded}", relation='recommended')

        # Find the author of the book
        book_row = final_rating[final_rating['title'] == book].iloc[0]
        author_encoded = book_row['author_encoded']
        author_name = author_encoder.inverse_transform([author_encoded])[0]
        G.add_node(f"Author_{author_encoded}", label=author_name, type='author')
        G.add_edge(f"Book_{book_id_encoded}", f"Author_{author_encoded}", relation='written_by')

    # Plot the graph
    pos = nx.spring_layout(G)
    node_colors = ['lightblue' if G.nodes[node]['type'] == 'user' else 'lightgreen' if G.nodes[node]['type'] == 'book' else 'salmon' for node in G.nodes]
    plt.figure(figsize=(12, 8))
    nx.draw(G, pos, with_labels=True, labels=nx.get_node_attributes(G, 'label'), node_color=node_colors, node_size=2000, font_size=10, font_weight='bold')
    edge_labels = nx.get_edge_attributes(G, 'relation')
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
    plt.title(f'Knowledge Graph Subgraph for User {user_id}')
    plt.savefig('kg_subgraph.png')
    plt.close()


Step 9: Defining visualization functions...


# Step 10: Test Recommendations and Visualizations

In [22]:
# Step 10: Test Recommendations and Visualizations
print("\nStep 10: Testing recommendations and visualizations...")
test_user_id = active_users[0]  # Use the first active user for testing
recommended_books, top_indices = recommend_books(test_user_id, model, book_encoder, n_users)
print(f"Recommended books for user {test_user_id}:")
for i, book in enumerate(recommended_books, 1):
    print(f"{i}. {book}")

# Generate visualizations
visualize_recommendations(test_user_id, recommended_books, top_indices, model, book_encoder, n_users)
visualize_kg_subgraph(test_user_id, recommended_books, final_rating, book_encoder, user_encoder, author_encoder)
print("Visualizations saved as 'recommendations_bar_chart.png' and 'kg_subgraph.png'.")


Step 10: Testing recommendations and visualizations...
Recommended books for user 11676:
1. It
2. B Is for Burglar (Kinsey Millhone Mysteries (Paperback))
3. Mind Prey
4. How to Be Good
5. The Phantom Tollbooth
Visualizations saved as 'recommendations_bar_chart.png' and 'kg_subgraph.png'.


# Step 11: Save Model and Artifacts

In [23]:
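# Step 11: Save Model and Artifacts
import os
print("\nStep 11: Saving model and artifacts...")
os.makedirs('artifacts', exist_ok=True)  # the target directory may not exist yet
# torch.save is the standard way to persist a state_dict (it uses pickle internally)
torch.save(model.state_dict(), 'artifacts/kgat_model.pkl')
# Use a context manager so each file is closed after writing
for name, encoder in [('user_encoder', user_encoder), ('book_encoder', book_encoder), ('author_encoder', author_encoder)]:
    with open(f'artifacts/{name}.pkl', 'wb') as f:
        pickle.dump(encoder, f)
print("Artifacts saved successfully.")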


Step 11: Saving model and artifacts...
Artifacts saved successfully.
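
Loading the artifacts back follows the same pattern in reverse. The plain-Python sketch below round-trips a stand-in object through `pickle`. It creates the target directory first and uses `with` blocks so file handles are closed. The temporary directory and the dict standing in for a fitted encoder are assumptions of the example, and `save_artifact` and `load_artifact` are hypothetical helpers. The model weights themselves would be restored with `torch.load` followed by `model.load_state_dict`. Only unpickle files you trust.

```python
import os
import pickle
import tempfile

def save_artifact(obj, path):
    """Pickle `obj` to `path`, creating parent directories as needed."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

def load_artifact(path):
    """Load a pickled object from `path`."""
    with open(path, 'rb') as f:
        return pickle.load(f)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'artifacts', 'book_encoder.pkl')
    # Stand-in for a fitted LabelEncoder (assumption of this example)
    stand_in = {'classes_': ['Bel Canto: A Novel', 'The Secret Garden']}
    save_artifact(stand_in, path)
    restored = load_artifact(path)
```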
