# KGAT-SR for Predictive Maintenance

This notebook adapts the **Knowledge-Enhanced Graph Attention Network for Session-based Recommendation (KGAT-SR)** model for a predictive maintenance task. The goal is to predict which vehicle parts might need replacement in the next maintenance session based on historical work orders.

### How to Use This Notebook:
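1.  **Upload Data**: In the Colab file explorer on the left, create a folder under `/content` whose name matches the `dataset` value in the "Configuration" cell below (e.g., `toyota_maintenance`). Upload your `train.txt`, `test.txt`, and `kg.txt` files into this folder. Despite the `.txt` extension, `train.txt` and `test.txt` are read with `pickle.load` and must each contain a pickled pair `(sessions, targets)`; `kg.txt` is plain text with one tab-separated `head<TAB>relation<TAB>tail` triple per line, where `head` and `tail` are integer IDs.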
2.  **Configure Parameters**: Adjust the settings in the "Configuration" cell as needed for your specific dataset and desired model hyperparameters.
3.  **Run All Cells**: Execute the cells sequentially from top to bottom. The training process will begin in the final cell.

## 1. Setup and Imports

This cell installs the required libraries and imports all the necessary Python packages.

In [None]:
# Install necessary libraries
!pip install networkx pandas numpy torch

# Import all required packages
import torch
from torch import nn
from torch.nn import Module, Parameter
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import networkx as nx

import pickle
import math
import datetime
import time
import os
import argparse

## 2. Configuration

All hyperparameters and settings are defined here. Instead of using `argparse` from a command line, we use an `argparse.Namespace` to make it easy to modify settings directly in the notebook.

In [None]:
# All hyperparameters and settings.
# Every other cell reads the module-level `args` and `device` defined here, so
# this cell must run before any cell that builds data, the model, or trains.
args = argparse.Namespace(
    # --- Critical Settings ---
    dataset='toyota_maintenance', # <<< CHANGE THIS to your dataset folder name (looked up under /content/)
    batchSize=100, # Sessions per batch; stored on the model and passed to Data.generate_batch
    hiddenSize=100, # Width of the item embedding table and of the GNN hidden state
    epoch=30, # Maximum number of epochs run by main()
    lr=0.001, # Adam learning rate
    l2=1e-5, # Adam weight_decay

    # --- Model & Training Details ---
    lr_dc=0.1, # StepLR gamma (multiplicative LR decay factor)
    lr_dc_step=3, # StepLR step_size (scheduler steps between decays)
    step=1, # GNN propagation steps
    patience=10, # main() stops once this many non-improving epochs have accumulated
    nonhybrid=False, # If true, only uses global preference, not the last item's info
    validation=False, # If true, splits train set for validation. NOTE(review): not read by the code visible in this notebook
    valid_portion=0.1, # NOTE(review): not read by the code visible in this notebook

    # --- KGAT-Specific Settings (For your implementation) ---
    # NOTE(review): of these, only attr_size is consumed by visible code
    # (KGraph stores it); the rest are placeholders for the KG integration TODO.
    emb_size=100,
    neibor_size=4, # Number of neighbors to sample from KG
    attr_size=2, # Number of attributes to sample
    aggregate='concat', # How to aggregate KG info: 'concat', 'sum', etc.
    
    # --- Other Settings ---
    n_workers=2, # Dataloader workers. Use 2 for Colab. NOTE(review): no DataLoader is constructed in the visible code
)

# Set the device to CUDA (GPU) if available, otherwise CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("Configuration:", args)

## 3. Data Loading and Knowledge Graph Classes

This section contains the classes and functions required to handle the knowledge graph (`KGraph`) and the session data (`Data`). The code is adapted from the original `kg.py` and `utils.py` files.

In [None]:
# Adapted from kg.py
class KGraph():
    """Knowledge graph loaded from ``/content/<dataset>/kg.txt``.

    The file is parsed as tab-separated ``head<TAB>relation<TAB>tail`` lines
    (head and tail are converted with ``int``; the relation is kept as a
    string). Triples are stored in an undirected ``networkx.Graph`` with the
    relation on the edge attribute ``rel``.

    Attributes:
        dataset: Dataset folder name used to locate ``kg.txt``.
        sample_attr_size: The ``attr_size`` value passed to the constructor.
        G: The loaded graph.
        n_relation: Number of distinct relation strings seen in the file.
        n_entity: Number of nodes in ``G``.
    """

    def __init__(self, dataset, attr_size):
        self.dataset = dataset
        self.sample_attr_size = attr_size
        graph, relation_count = self.get_kg()
        self.G = graph
        self.n_relation = relation_count
        self.n_entity = graph.number_of_nodes()
        print(f"Knowledge Graph loaded: {self.n_entity} entities, {self.n_relation} relations.")

    def get_kg(self):
        """Parse ``kg.txt`` and return ``(graph, number_of_distinct_relations)``.

        Lines that do not split into exactly three tab-separated fields are
        skipped silently.

        Raises:
            FileNotFoundError: If the knowledge-graph file does not exist.
            ValueError: If a head or tail field is not an integer literal.
        """
        kg_path = f'/content/{self.dataset}/kg.txt'
        if not os.path.exists(kg_path):
            raise FileNotFoundError(f"Knowledge graph file not found at: {kg_path}")

        graph = nx.Graph()
        relations = set()
        with open(kg_path, 'r') as handle:
            for raw_line in handle:
                fields = raw_line.strip().split('\t')
                if len(fields) == 3:
                    head_id = int(fields[0])
                    relation = fields[1]
                    tail_id = int(fields[2])
                    # nx.Graph keeps one edge per node pair, so a later triple
                    # on the same pair replaces the stored `rel` attribute.
                    graph.add_edge(head_id, tail_id, rel=relation)
                    relations.add(relation)
        return graph, len(relations)

# Adapted from utils.py
def data_masks(all_usr_pois, item_tail):
    """Right-pad every session to a common length and build matching masks.

    Args:
        all_usr_pois: List of sessions, each a list of item IDs.
        item_tail: One-element list holding the padding value (e.g. ``[0]``);
            it is repeated to fill the missing positions.

    Returns:
        A tuple ``(padded_sessions, masks, len_max)`` where each mask holds 1
        for real positions and 0 for padding, and ``len_max`` is the length of
        the longest session.

    Raises:
        ValueError: If ``all_usr_pois`` is empty (``max`` of an empty list).
    """
    lengths = list(map(len, all_usr_pois))
    len_max = max(lengths)
    padded_sessions = []
    masks = []
    for session, size in zip(all_usr_pois, lengths):
        missing = len_max - size
        padded_sessions.append(session + item_tail * missing)
        masks.append([1] * size + [0] * missing)
    return padded_sessions, masks, len_max

class Data():
    """Padded session dataset with batching and session-graph construction.

    Args:
        data: Pair ``(sessions, targets)``; ``data[0]`` is a list of item-ID
            lists and ``data[1]`` holds one target per session.
        shuffle: If true, ``generate_batch`` reshuffles the stored arrays
            in place before slicing.

    Attributes:
        inputs: Sessions right-padded with 0 to ``len_max``.
        mask: 1 for real positions, 0 for padding (same layout as ``inputs``).
        targets: Target item per session.
        length: Number of sessions.
    """

    def __init__(self, data, shuffle=False):
        inputs = data[0]
        inputs, mask, len_max = data_masks(inputs, [0])
        self.inputs = np.asarray(inputs)
        self.mask = np.asarray(mask)
        self.len_max = len_max
        self.targets = np.asarray(data[1])
        self.length = len(inputs)
        self.shuffle = shuffle

    def generate_batch(self, batch_size):
        """Return a list of index arrays that partition the dataset.

        Every sample index appears in exactly one slice; the last slice is
        shorter when ``length`` is not a multiple of ``batch_size``. (The
        previous implementation back-filled the last slice from the preceding
        batch, which double-counted samples during evaluation and produced
        negative wrap-around indices when ``length < batch_size``.)

        Args:
            batch_size: Maximum number of samples per slice; must be >= 1.

        Returns:
            List of 1-D integer arrays; empty when the dataset is empty.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if self.shuffle:
            order = np.arange(self.length)
            np.random.shuffle(order)
            self.inputs = self.inputs[order]
            self.mask = self.mask[order]
            self.targets = self.targets[order]
        return [
            np.arange(start, min(start + batch_size, self.length))
            for start in range(0, self.length, batch_size)
        ]

    def get_slice(self, i):
        """Build the session graphs for the samples selected by ``i``.

        Args:
            i: Index array (as produced by ``generate_batch``).

        Returns:
            Tuple ``(alias_inputs, A, items, mask, targets)``:
            ``items`` lists the sorted unique IDs of each session padded with 0
            to the largest unique count in the slice; ``alias_inputs`` maps
            every session position to its index in ``items``; ``A`` holds, per
            session, the normalised incoming and outgoing adjacency matrices
            concatenated along the last axis (shape ``(n, 2n)``).
        """
        inputs, mask, targets = self.inputs[i], self.mask[i], self.targets[i]
        unique_nodes = [np.unique(u_input) for u_input in inputs]
        max_n_node = max(len(node) for node in unique_nodes)

        items, A, alias_inputs = [], [], []
        for u_input, node in zip(inputs, unique_nodes):
            sequence = u_input.tolist()
            node_list = node.tolist()
            items.append(node_list + (max_n_node - len(node_list)) * [0])
            # Item ID -> row/column in this session's adjacency matrix.
            position = {item: idx for idx, item in enumerate(node_list)}

            u_A = np.zeros((max_n_node, max_n_node))
            for current, following in zip(sequence, sequence[1:]):
                if following == 0:
                    # 0 is the padding ID: the real session has ended.
                    break
                u_A[position[current]][position[following]] = 1

            # Normalise by in-degree / out-degree; zero degrees are replaced by
            # 1 so isolated nodes do not divide by zero.
            u_sum_in = np.sum(u_A, 0)
            u_sum_in[np.where(u_sum_in == 0)] = 1
            u_A_in = np.divide(u_A, u_sum_in)
            u_sum_out = np.sum(u_A, 1)
            u_sum_out[np.where(u_sum_out == 0)] = 1
            u_A_out = np.divide(u_A.transpose(), u_sum_out)

            A.append(np.concatenate([u_A_in, u_A_out]).transpose())
            alias_inputs.append([position[item] for item in sequence])
        return alias_inputs, A, items, mask, targets

## 4. Model Definition

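This section contains the core model classes: `GNN` (the gated graph neural network cell) and `SRGAT` (the main model). The code is primarily adapted from `SRGATM.py` and `model.py`. Note that the knowledge-graph integration is still a TODO inside `SRGAT.compute_scores`: as written, scores are computed from the session embedding alone.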

In [None]:
# GNN class from model.py and SRGATM.py
class GNN(Module):
    """Gated graph layer applied ``step`` times to session-graph node states.

    The raw ``Parameter`` tensors below are allocated uninitialised; the code
    in this class never fills them, so the owner is expected to initialise
    them before use.

    Args:
        hidden_size: Width of each node state.
        step: Number of times ``GNNCell`` is applied in ``forward``.
    """

    def __init__(self, hidden_size, step=1):
        super().__init__()
        self.step = step
        self.hidden_size = hidden_size
        self.input_size = 2 * hidden_size   # incoming + outgoing messages
        self.gate_size = 3 * hidden_size    # reset, update and candidate gates

        # GRU-style gate weights and biases.
        self.w_ih = Parameter(torch.Tensor(self.gate_size, self.input_size))
        self.w_hh = Parameter(torch.Tensor(self.gate_size, self.hidden_size))
        self.b_ih = Parameter(torch.Tensor(self.gate_size))
        self.b_hh = Parameter(torch.Tensor(self.gate_size))
        # Biases added to the aggregated incoming / outgoing messages.
        self.b_iah = Parameter(torch.Tensor(self.hidden_size))
        self.b_oah = Parameter(torch.Tensor(self.hidden_size))

        self.linear_edge_in = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.linear_edge_out = nn.Linear(self.hidden_size, self.hidden_size, bias=True)

    def GNNCell(self, A, hidden):
        """Run one propagation step and return the updated node states.

        ``A`` is sliced along its last axis into two blocks of width
        ``A.shape[1]``: the first weights incoming messages, the second
        outgoing ones.
        """
        n_nodes = A.shape[1]
        adj_in = A[:, :, :n_nodes]
        adj_out = A[:, :, n_nodes: 2 * n_nodes]

        msg_in = torch.matmul(adj_in, self.linear_edge_in(hidden)) + self.b_iah
        msg_out = torch.matmul(adj_out, self.linear_edge_out(hidden)) + self.b_oah
        messages = torch.cat([msg_in, msg_out], 2)

        gates_from_msg = F.linear(messages, self.w_ih, self.b_ih)
        gates_from_state = F.linear(hidden, self.w_hh, self.b_hh)
        m_reset, m_update, m_cand = gates_from_msg.chunk(3, 2)
        s_reset, s_update, s_cand = gates_from_state.chunk(3, 2)

        reset = torch.sigmoid(m_reset + s_reset)
        update = torch.sigmoid(m_update + s_update)
        candidate = torch.tanh(m_cand + reset * s_cand)
        # Interpolate between the candidate and the previous state.
        return candidate + update * (hidden - candidate)

    def forward(self, A, hidden):
        """Apply ``GNNCell`` ``self.step`` times and return the final states."""
        for _ in range(self.step):
            hidden = self.GNNCell(A, hidden)
        return hidden

# SRGAT class, merging SessionGraph and SRGAT logic
class SRGAT(nn.Module):
    """Session-graph recommender: item embedding, GNN, attention readout.

    The model owns its loss, Adam optimizer and StepLR scheduler so that the
    training loop can drive them through ``model.optimizer`` /
    ``model.scheduler``. ``n_items`` and ``n_rels`` are stored but not used by
    any method of this class; knowledge-graph integration is still a TODO in
    ``compute_scores``.

    Args:
        args: Namespace providing ``hiddenSize``, ``batchSize``, ``nonhybrid``,
            ``step``, ``lr``, ``l2``, ``lr_dc_step`` and ``lr_dc``.
        n_node: Number of rows in the item embedding table.
        n_items: Stored as ``self.n_items``.
        n_rels: Stored as ``self.n_rels``.
    """

    def __init__(self, args, n_node, n_items, n_rels):
        super().__init__()
        self.hidden_size = args.hiddenSize
        self.n_node = n_node
        self.n_items = n_items
        self.n_rels = n_rels
        self.batch_size = args.batchSize
        self.nonhybrid = args.nonhybrid

        dim = self.hidden_size

        # Item embedding table. No relation embedding table is created yet;
        # one may be needed once the KG logic is implemented.
        self.embedding = nn.Embedding(self.n_node, dim)

        # Gated GNN over the per-session item graph.
        self.gnn = GNN(dim, step=args.step)

        # Attention readout and hybrid-combination layers.
        self.linear_one = nn.Linear(dim, dim, bias=True)
        self.linear_two = nn.Linear(dim, dim, bias=True)
        self.linear_three = nn.Linear(dim, 1, bias=False)
        self.linear_transform = nn.Linear(dim * 2, dim, bias=True)

        # Training machinery lives on the model itself.
        self.loss_function = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(
            self.parameters(), lr=args.lr, weight_decay=args.l2
        )
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=args.lr_dc_step, gamma=args.lr_dc
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter uniformly from ``[-1/sqrt(h), 1/sqrt(h)]``."""
        bound = 1.0 / math.sqrt(self.hidden_size)
        for param in self.parameters():
            param.data.uniform_(-bound, bound)

    def compute_scores(self, hidden, mask):
        """Score every non-padding item for each session.

        Args:
            hidden: Per-position hidden states, indexed as
                ``hidden[batch, position]``.
            mask: Integer mask with one row per session; its row sum minus one
                is used as the position of the last real item.

        Returns:
            Score matrix with one column per embedding row except row 0, so
            column ``k`` corresponds to item ID ``k + 1``.
        """
        # Hidden state of the last real item in every session.
        batch_index = torch.arange(mask.shape[0]).long()
        last_position = torch.sum(mask, 1) - 1
        ht = hidden[batch_index, last_position]

        # Soft attention over positions, conditioned on the last item.
        query_last = self.linear_one(ht).view(ht.shape[0], 1, ht.shape[1])
        query_all = self.linear_two(hidden)
        alpha = self.linear_three(torch.sigmoid(query_last + query_all))
        position_mask = mask.view(mask.shape[0], -1, 1).float()
        session_emb = torch.sum(alpha * hidden * position_mask, 1)

        # Hybrid mode: merge the global readout with the last item's state.
        if not self.nonhybrid:
            session_emb = self.linear_transform(torch.cat([session_emb, ht], 1))

        # =====================================================================
        # TODO: This is where you would integrate the knowledge graph embedding.
        # For example, you could get a KG embedding for the session (`kg_emb`)
        # and combine it with the `session_emb`.
        # final_emb = session_emb + kg_emb
        # =====================================================================
        final_emb = session_emb

        # Row 0 of the embedding table is skipped when scoring.
        candidates = self.embedding.weight[1:]
        return torch.matmul(final_emb, candidates.transpose(1, 0))

    def forward(self, items, A):
        """Embed ``items`` and propagate them through the session GNN."""
        return self.gnn(A, self.embedding(items))

## 5. Training and Evaluation Logic

These functions handle the core logic for a single training/testing step. They orchestrate the forward pass, loss calculation, backpropagation, and performance metric calculation.

In [None]:
# Utility functions to move tensors to the correct device
def trans_to_cuda(variable):
    """Return ``variable`` moved to the module-level ``device``.

    Despite the name, the destination is whatever ``device`` holds.
    """
    target = device
    return variable.to(target)

def trans_to_cpu(variable):
    """Return ``variable`` moved to the CPU via its ``.to`` method."""
    cpu = torch.device('cpu')
    return variable.to(cpu)

# Main forward pass for a slice of data
def model_forward_pass(model, i, data):
    """Run the model on one batch and return ``(targets, scores)``.

    Args:
        model: Callable as ``model(items, A)`` and exposing
            ``compute_scores(seq_hidden, mask)``.
        i: Batch index array, forwarded to ``data.get_slice``.
        data: Dataset object providing ``get_slice``.

    Returns:
        ``targets`` exactly as returned by ``data.get_slice`` (not moved to the
        device) and the score tensor produced by ``model.compute_scores``.
    """
    alias_inputs, A, items, mask, targets = data.get_slice(i)

    # Build integer tensors directly as int64. Going through torch.Tensor
    # (float32) first silently corrupts IDs above 2**24, and np.asarray avoids
    # the slow list-of-ndarrays conversion path for A.
    alias_inputs = trans_to_cuda(torch.tensor(np.asarray(alias_inputs), dtype=torch.long))
    items = trans_to_cuda(torch.tensor(np.asarray(items), dtype=torch.long))
    A = trans_to_cuda(torch.tensor(np.asarray(A), dtype=torch.float))
    mask = trans_to_cuda(torch.tensor(np.asarray(mask), dtype=torch.long))

    hidden = model(items, A)

    # Map node states back to sequence order: row b of the result is
    # hidden[b][alias_inputs[b]], gathered for the whole batch in one indexing op.
    batch_index = torch.arange(alias_inputs.shape[0], device=alias_inputs.device).unsqueeze(1)
    seq_hidden = hidden[batch_index, alias_inputs]

    return targets, model.compute_scores(seq_hidden, mask)

# Main function to run a full epoch of training and evaluation
def train_test(model, train_data, test_data):
    """Train for one epoch, then evaluate and return ``(hit, mrr)`` in percent.

    Args:
        model: Model exposing ``batch_size``, ``optimizer``, ``scheduler`` and
            ``loss_function`` in addition to the forward interface used by
            ``model_forward_pass``.
        train_data: Dataset used for the optimisation pass.
        test_data: Dataset used for the evaluation pass.

    Returns:
        Hit rate and mean reciprocal rank over the top-k predictions, both
        multiplied by 100. ``k`` is 20, or the number of candidate items when
        fewer than 20 exist.
    """
    print('start training: ', datetime.datetime.now())
    model.train()
    total_loss = 0.0
    slices = train_data.generate_batch(model.batch_size)
    log_every = int(len(slices) / 5 + 1)
    for j, i in enumerate(slices):
        model.optimizer.zero_grad()
        targets, scores = model_forward_pass(model, i, train_data)
        targets = trans_to_cuda(torch.tensor(np.asarray(targets), dtype=torch.long))
        # Score column k corresponds to item ID k + 1, hence the shift.
        loss = model.loss_function(scores, targets - 1)
        loss.backward()
        model.optimizer.step()
        total_loss += loss.item()
        if j % log_every == 0:
            print('[%d/%d] Loss: %.4f' % (j, len(slices), loss.item()))
    print('\tTotal Loss:\t%.3f' % total_loss)
    # Step the LR schedule after this epoch's optimizer updates; stepping it
    # first skips the initial learning rate and triggers a PyTorch warning.
    model.scheduler.step()

    print('start predicting: ', datetime.datetime.now())
    model.eval()
    hit, mrr = [], []
    slices = test_data.generate_batch(model.batch_size)
    with torch.no_grad():
        for i in slices:
            targets, scores = model_forward_pass(model, i, test_data)
            # topk raises if k exceeds the number of candidate items.
            k = min(20, scores.shape[1])
            top_items = trans_to_cpu(scores.topk(k)[1]).detach().numpy()
            targets = np.asarray(targets)
            for ranked, target in zip(top_items, targets):
                positions = np.where(ranked == target - 1)[0]
                if len(positions) == 0:
                    hit.append(False)
                    mrr.append(0)
                else:
                    hit.append(True)
                    mrr.append(1 / (positions[0] + 1))
    hit = np.mean(hit) * 100
    mrr = np.mean(mrr) * 100
    return hit, mrr

## 6. Main Execution Block

This is where everything comes together. After ensuring your data is uploaded correctly, running this cell will:
1. Load the `train.txt` and `test.txt` data.
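2. Determine the embedding-table size (`n_node`) as the largest item ID found in the data plus one (item IDs are 1-based; ID 0 is reserved for padding).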
3. Initialize the `SRGAT` model.
4. Start the training and evaluation loop.

In [None]:
def main():
    """Load the pickled sessions and the KG, build the model, and train it.

    Reads ``/content/<args.dataset>/train.txt`` and ``test.txt`` (pickled
    ``(sessions, targets)`` pairs) and returns early with an error message if
    either file is missing. Trains for at most ``args.epoch`` epochs, printing
    HR@20 / MRR@20 after each one, and stops early once ``args.patience``
    non-improving epochs have accumulated.

    NOTE: ``args.validation`` and ``args.valid_portion`` are not consulted
    here; evaluation always uses the test file.
    """
    print("--- KGAT-SR for Predictive Maintenance ---")

    # --- Data Loading ---
    train_data_path = f'/content/{args.dataset}/train.txt'
    test_data_path = f'/content/{args.dataset}/test.txt'

    if not os.path.exists(train_data_path) or not os.path.exists(test_data_path):
        print(f"\nERROR: Data files not found in /content/{args.dataset}/")
        print("Please ensure the folder exists and contains train.txt and test.txt.")
        return

    print(f"\nStep 1: Loading data from /content/{args.dataset}...")
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load train/test files that you produced yourself or fully trust.
    with open(train_data_path, 'rb') as f:
        train_data_raw = pickle.load(f)
    with open(test_data_path, 'rb') as f:
        test_data_raw = pickle.load(f)

    # Size the embedding table from the largest item ID. Targets must be
    # included as well: a target that never occurs inside an input sequence
    # would otherwise become an out-of-range class index in the loss.
    all_seqs = train_data_raw[0] + test_data_raw[0]
    all_targets = list(train_data_raw[1]) + list(test_data_raw[1])
    max_seq_id = max(max(seq) for seq in all_seqs if seq)
    max_target_id = max(all_targets)
    # The +1 is because item IDs are 1-based (0 is the padding ID), so the
    # table needs max_id + 1 rows.
    n_node = int(max(max_seq_id, max_target_id)) + 1
    # NOTE: the number reported below is the largest item ID, which equals the
    # number of unique items only when IDs are contiguous from 1.
    print(f"Data loaded. Found {n_node-1} unique items/parts.")

    train_data = Data(train_data_raw, shuffle=True)
    test_data = Data(test_data_raw, shuffle=False)

    # --- Model Initialization ---
    print("\nStep 2: Initializing model and knowledge graph...")
    kg = KGraph(args.dataset, args.attr_size)

    model = SRGAT(args, n_node, kg.n_entity, kg.n_relation)
    model = trans_to_cuda(model)

    # --- Training Loop ---
    print("\nStep 3: Starting training loop...")
    start = time.time()
    best_result = [0, 0]   # best [HR@20, MRR@20] seen so far
    best_epoch = [0, 0]    # epochs at which those bests were reached
    bad_counter = 0

    for epoch in range(args.epoch):
        print('-------------------------------------------------------')
        print('Epoch: ', epoch)
        hit, mrr = train_test(model, train_data, test_data)
        flag = 0
        if hit >= best_result[0]:
            best_result[0] = hit
            best_epoch[0] = epoch
            flag = 1
        if mrr >= best_result[1]:
            best_result[1] = mrr
            best_epoch[1] = epoch
            flag = 1
        print('\nResults for this epoch:')
        print(f'\tHR@20: {hit:.4f}\tMRR@20: {mrr:.4f}')
        print('\nBest Result so far:')
        print('\tHR@20: %.4f\tMRR@20: %.4f\tEpochs: %d, %d' % (
            best_result[0], best_result[1], best_epoch[0], best_epoch[1]))
        # NOTE(review): the counter is cumulative; it is never reset when an
        # epoch improves, so patience counts total (not consecutive)
        # non-improving epochs. Confirm this is the intended policy.
        bad_counter += 1 - flag
        if bad_counter >= args.patience:
            print(f"\nEarly stopping triggered after {args.patience} epochs with no improvement.")
            break

    print('-------------------------------------------------------')
    end = time.time()
    print("Training finished.")
    print("Run time: %f s" % (end - start))

# Entry point: run main() only when this code executes as the top-level
# module, so importing the definitions elsewhere does not start training.
if __name__ == '__main__':
    main()