In [2]:
# proto_gat_main.py

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GAT

# ----------- Config -------------------
IN_CHANNELS = 18
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------- ProtoNet GAT Encoder --------------------
class GATEncoder(nn.Module):
    def __init__(self, hidden=256, heads=8, dropout=0.2, layers=2):
        super().__init__()
        self.gnn = GAT(
            in_channels=IN_CHANNELS,
            hidden_channels=hidden,
            out_channels=hidden,
            heads=heads,
            num_layers=layers,
            dropout=dropout,
            edge_dim=1,
            v2=True,
            jk='cat'
        )

    def forward(self, x, edge_index, edge_attr):
        return self.gnn(x, edge_index, edge_weight=edge_attr)

# --------- Episode Sampler --------------------------
def sample_episode(data_list, task, k_shot=1,q_num=4):
    task_data = [d for d in data_list if getattr(d, 'task', None) == task]
    random.shuffle(task_data)
    return task_data[:k_shot], task_data[k_shot:k_shot + q_num]

# --------- Compute Prototypes ------------------------
def compute_prototypes(embeddings, labels, num_classes=4):
    prototypes = []
    for c in range(num_classes):
        class_mask = (labels == c)
        if class_mask.sum() == 0:
            prototypes.append(torch.zeros_like(embeddings[0]))
        else:
            prototypes.append(embeddings[class_mask].mean(dim=0))
    return torch.stack(prototypes)

# --------- Compute Distances ------------------------
def euclidean_distance(a, b):
    return ((a.unsqueeze(1) - b.unsqueeze(0)) ** 2).sum(dim=2)

# --------- Prototypical Loss ------------------------
def prototypical_loss(embeddings, labels, prototypes):
    dists = euclidean_distance(embeddings, prototypes)
    log_p_y = F.log_softmax(-dists, dim=1)
    loss = F.nll_loss(log_p_y, labels)
    preds = log_p_y.argmax(dim=1)
    acc = (preds == labels).float().mean().item()
    return loss, acc

# --------- Training Loop -----------------------------
def proto_train(data_list, encoder, optimizer, n_episodes=500, k_shot=1,q_num=4):
    encoder.train()
    tasks = list(set(d.task for d in data_list))

    for episode in range(n_episodes):
        task = random.choice(tasks)
        support_set, query_set = sample_episode(data_list, task, k_shot, q_num)

        support_x, support_y = [], []
        for g in support_set:
            g = g.to(DEVICE)
            emb = encoder(g.x, g.edge_index, g.edge_attr)
            support_x.append(emb)
            support_y.append(g.y)
        support_x = torch.cat(support_x, dim=0)
        support_y = torch.cat(support_y, dim=0)

        prototypes = compute_prototypes(support_x, support_y)

        query = query_set[0].to(DEVICE)
        query_emb = encoder(query.x, query.edge_index, query.edge_attr)
        loss, acc = prototypical_loss(query_emb, query.y, prototypes)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if episode % 10 == 0:
            print(f"[Episode {episode}] Loss: {loss.item():.4f} | Accuracy: {acc*100:.2f}% | Task: {task}")

# --------- Inference on a Graph -----------------------
def proto_predict(encoder, support_set, query_graph):
    encoder.eval()
    support_x, support_y = [], []

    for g in support_set:
        g = g.to(DEVICE)
        emb = encoder(g.x, g.edge_index, g.edge_attr)
        support_x.append(emb)
        support_y.append(g.y)

    support_x = torch.cat(support_x, dim=0)
    support_y = torch.cat(support_y, dim=0)
    prototypes = compute_prototypes(support_x, support_y)

    query = query_graph.to(DEVICE)
    query_emb = encoder(query.x, query.edge_index, query.edge_attr)
    dists = euclidean_distance(query_emb, prototypes)
    preds = dists.argmin(dim=1)
    return preds.cpu()

# # --------- Example Runner -----------------------------
# if __name__ == "__main__":
#     print("📥 Loading few-shot dataset...")
#     data_list = torch.load("data/few-shot-dataset/fewshot_dataset.pt", map_location=DEVICE)

#     encoder = GATEncoder().to(DEVICE)
#     optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

#     print("🚀 Starting ProtoNet training...")
#     proto_train(data_list, encoder, optimizer, n_episodes=500)

#     print("💾 Saving trained encoder...")
#     torch.save(encoder.state_dict(), "models/proto_gat_encoder.pt")


In [3]:
print("📥 Loading few-shot dataset...")
data_list = torch.load("data/training_data/training_dataset.pt", map_location=DEVICE)

encoder = GATEncoder().to(DEVICE)
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

# -------- Load pretrained weights here ----------
# pretrained_path = "models\\model\\SingleRun_H256_L2_HD8_DO2.pth"  # <-- Update path if needed
# print(f"🔁 Loading pretrained weights from {pretrained_path}")
# encoder.load_state_dict(torch.load(pretrained_path, map_location=DEVICE))

# ✅ Optional: freeze layers if you don’t want to fine-tune
# for param in encoder.parameters():
#     param.requires_grad = False

optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

print("🚀 Starting ProtoNet training...")
proto_train(data_list, encoder, optimizer, n_episodes=250)

print(" Saving trained encoder...")
torch.save(encoder.state_dict(), "models/prototypical/proto_gat_encoder.pt")

📥 Loading few-shot dataset...


  data_list = torch.load("data/training_data/training_dataset.pt", map_location=DEVICE)


🚀 Starting ProtoNet training...
[Episode 0] Loss: 0.9421 | Accuracy: 76.32% | Task: Loan
[Episode 10] Loss: 0.6295 | Accuracy: 66.28% | Task: Final Bill
[Episode 20] Loss: 0.3593 | Accuracy: 90.91% | Task: Loan
[Episode 30] Loss: 0.2240 | Accuracy: 93.02% | Task: Final Bill
[Episode 40] Loss: 0.1299 | Accuracy: 96.19% | Task: Final Bill
[Episode 50] Loss: 0.0562 | Accuracy: 98.82% | Task: Final Bill
[Episode 60] Loss: 0.0162 | Accuracy: 99.43% | Task: Loan
[Episode 70] Loss: 0.0113 | Accuracy: 100.00% | Task: Final Bill
[Episode 80] Loss: 0.0668 | Accuracy: 98.06% | Task: Loan
[Episode 90] Loss: 0.0290 | Accuracy: 98.10% | Task: Final Bill
[Episode 100] Loss: 0.0036 | Accuracy: 100.00% | Task: Loan
[Episode 110] Loss: 0.0004 | Accuracy: 100.00% | Task: Invoice
[Episode 120] Loss: 0.0109 | Accuracy: 100.00% | Task: Invoice
[Episode 130] Loss: 0.0110 | Accuracy: 98.80% | Task: Invoice
[Episode 140] Loss: 0.0114 | Accuracy: 99.16% | Task: Invoice
[Episode 150] Loss: 0.0014 | Accuracy: 100

In [4]:
# proto_eval.py

import torch
from torch_geometric.nn import GAT
from sklearn.metrics import classification_report
# from proto_gat_main import GATEncoder, proto_predict

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------- Load Trained Encoder ---------
def load_encoder(model_path):
    encoder = GATEncoder().to(DEVICE)
    encoder.load_state_dict(torch.load(model_path, map_location=DEVICE))
    encoder.eval()
    return encoder

# --------- Load Graphs ---------
def load_graphs(path):
    return torch.load(path, map_location=DEVICE)

# --------- Run Prediction ---------
def run_inference(support_path, query_path, model_path):
    encoder = load_encoder(model_path)
    support_graphs = load_graphs(support_path)
    query_graphs = load_graphs(query_path)

    all_preds = []
    all_trues = []

    for i, query_graph in enumerate(query_graphs):
        pred = proto_predict(encoder, support_graphs, query_graph)
        true_labels = query_graph.y.cpu()
        all_preds.append(pred)
        all_trues.append(true_labels)

        print(f"\n📄 Query Graph {i+1} Predictions:")
        print(pred.tolist())
        value_nodes = (pred == 1).nonzero(as_tuple=True)[0].tolist()
        print(f" VALUE nodes at indices: {value_nodes}")

        # Per-label accuracy for this query graph
        print(classification_report(true_labels, pred, zero_division=0))

    # Overall classification report across all query graphs
    all_preds_flat = torch.cat(all_preds).numpy()
    all_trues_flat = torch.cat(all_trues).numpy()
    print("\n===== Overall Classification Report (all query graphs) =====")
    print(classification_report(all_trues_flat, all_preds_flat, zero_division=0))

# # --------- Main ---------
# if __name__ == "__main__":
#     run_inference(
#         support_path="data/few-shot-dataset/invoice_support.pt",
#         query_path="data/few-shot-dataset/invoice_query.pt",
#         model_path="models/proto_gat_encoder.pt"
#     )


In [5]:
# Run inference and print per-label accuracy and overall accuracy using classification_report
from sklearn.metrics import classification_report

def run_inference_with_accuracy(support_path, query_path, model_path):
    encoder = load_encoder(model_path)
    support_graphs = load_graphs(support_path)
    query_graphs = load_graphs(query_path)

    all_preds = []
    all_trues = []

    for i, query_graph in enumerate(query_graphs):
        pred = proto_predict(encoder, support_graphs, query_graph)
        true_labels = query_graph.y.cpu()
        all_preds.append(pred)
        all_trues.append(true_labels)

        print(f"\n📄 Query Graph {i+1} Predictions:")
        print(pred.tolist())
        value_nodes = (pred == 1).nonzero(as_tuple=True)[0].tolist()
        print(f" VALUE nodes at indices: {value_nodes}")

        # Per-label accuracy for this query graph
        print(classification_report(true_labels, pred, zero_division=0))

    # Overall classification report across all query graphs
    all_preds_flat = torch.cat(all_preds).numpy()
    all_trues_flat = torch.cat(all_trues).numpy()
    print("\n===== Overall Classification Report (all query graphs) =====")
    print(classification_report(all_trues_flat, all_preds_flat, zero_division=0))

run_inference_with_accuracy(
    support_path="data\\final_bill\\datacheckpoint_1.pt",
    query_path="data/test_data/test_dataset_OR.pt",
    model_path="models/prototypical/proto_gat_encoder.pt",
)

  encoder.load_state_dict(torch.load(model_path, map_location=DEVICE))
  return torch.load(path, map_location=DEVICE)



📄 Query Graph 1 Predictions:
[0, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 0, 3, 3, 3]
 VALUE nodes at indices: []
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        29
           3       1.00      1.00      1.00        91

    accuracy                           1.00       120
   macro avg       1.00      1.00      1.00       120
weighted avg       1.00      1.00      1.00       120


📄 Query Graph 2 Predictions:
[0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 0, 3, 3, 0, 3, 3, 0, 3, 3, 0, 3, 3, 3, 0, 3, 3]
 VALUE nodes at indices: []
              precision    recall  f1-score   support

           0       1.00      1.00  

In [6]:
run_inference_with_accuracy(
    support_path="data/PR2/Datacheckpoint_GAN_Model_16",
    query_path="data/test_data/test_dataset_PR2.pt",
    model_path="models/prototypical/proto_gat_encoder.pt",
)

  encoder.load_state_dict(torch.load(model_path, map_location=DEVICE))
  return torch.load(path, map_location=DEVICE)



📄 Query Graph 1 Predictions:
[0, 1, 3, 0, 1, 3, 0, 1, 3, 0, 1, 3, 0, 1, 2, 2, 2, 1, 1, 0, 0, 1, 3, 0, 0, 2, 1, 1, 2, 0, 1, 1, 0, 0, 1, 3, 0, 0, 2, 1, 1, 2, 0, 1, 1, 0, 0, 1, 1, 0, 0, 2, 1, 1, 2, 0, 1, 1, 0, 1, 3, 0, 0, 1, 1, 0, 0, 2, 1, 1, 2, 0, 1, 1, 0, 0, 1, 3, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 2, 1, 1, 2, 0, 1, 1, 0, 2, 1, 3, 2, 2, 1, 1, 0, 2, 1, 3, 2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 2, 1, 1, 2, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 2, 1, 1, 2, 0, 1, 1, 0, 1]
 VALUE nodes at indices: [1, 4, 7, 10, 13, 17, 18, 21, 26, 27, 30, 31, 34, 39, 40, 43, 44, 47, 48, 52, 53, 56, 57, 59, 63, 64, 68, 69, 72, 73, 76, 81, 82, 86, 87, 91, 92, 95, 96, 99, 103, 104, 107, 111, 112, 115, 116, 120, 121, 125, 126, 129, 130, 134, 135, 139, 140, 143, 144, 146]
              precision    recall  f1-score   support

           0       0.20      0.65      0.31        17
           1       0.23      1.00      0.38        14
           2       0.30      0.26      0.28        27
           3       0.90     

In [7]:
# proto_single_graph_eval.py

import torch
import random
import numpy as np
from sklearn.metrics import classification_report
# from proto_gat_main import GATEncoder, compute_prototypes, euclidean_distance

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------- Load Trained GAT Proto Encoder ---------
def load_encoder(model_path):
    encoder = GATEncoder().to(DEVICE)
    encoder.load_state_dict(torch.load(model_path, map_location=DEVICE))
    encoder.eval()
    return encoder

# --------- Evaluate Using Graph(s) ---------
def evaluate_single_graph(graph_path, model_path, support_ratio=0.2):
    print(" Loading graph(s) from:", graph_path)
    data = torch.load(graph_path, map_location=DEVICE)
    encoder = load_encoder(model_path)

    from collections import defaultdict
    all_preds = []
    all_trues = []

    if isinstance(data, list):
        print(" Detected list of graphs. Splitting support/query at graph level.")
        random.shuffle(data)
        split = int(support_ratio * len(data))
        if split == 0:
            split = 1
        support_graphs = data[:split]
        query_graphs = data[split:]

        support_emb, support_y = [], []
        for g in support_graphs:
            g = g.to(DEVICE)
            emb = encoder(g.x, g.edge_index, g.edge_attr)
            support_emb.append(emb)
            support_y.append(g.y)
        if not support_emb or not support_y:
            print("No support graphs available. Please check your data or support_ratio.")
            return

        support_emb = torch.cat(support_emb, dim=0)
        support_y = torch.cat(support_y, dim=0)
        prototypes = compute_prototypes(support_emb, support_y)

        for i, g in enumerate(query_graphs):
            g = g.to(DEVICE)
            with torch.no_grad():
                emb = encoder(g.x, g.edge_index, g.edge_attr)
                dists = euclidean_distance(emb, prototypes)
                preds = dists.argmin(dim=1).cpu()
                all_preds.append(preds)
                all_trues.append(g.y.cpu())

                value_indices = (preds == 1).nonzero(as_tuple=True)[0].tolist()
                print(f"\n Graph {i+1} Predictions:")
                print(preds.tolist())
                print(f" Predicted VALUE nodes: {value_indices}")

                # Per-label accuracy for this query graph
                print(classification_report(g.y.cpu(), preds, zero_division=0))

        # Overall classification report across all query graphs
        all_preds_flat = torch.cat(all_preds).numpy()
        all_trues_flat = torch.cat(all_trues).numpy()
        print("\n===== Overall Classification Report (all query graphs) =====")
        print(classification_report(all_trues_flat, all_preds_flat, zero_division=0))

    else:
        print("🔎 Detected single graph. Splitting support/query at node level.")
        graph = data
        num_nodes = graph.x.size(0)
        indices = list(range(num_nodes))
        random.shuffle(indices)

        split = int(support_ratio * num_nodes)
        support_idx = indices[:split]
        query_idx = indices[split:]

        support_mask = torch.zeros(num_nodes, dtype=torch.bool)
        support_mask[support_idx] = True

        query_mask = torch.zeros(num_nodes, dtype=torch.bool)
        query_mask[query_idx] = True

        support_x = graph.x[support_mask]
        support_y = graph.y[support_mask]

        query_x = graph.x[query_mask]
        query_y = graph.y[query_mask]

        with torch.no_grad():
            embeddings = encoder(graph.x.to(DEVICE), graph.edge_index.to(DEVICE), graph.edge_attr.to(DEVICE))
            support_emb = embeddings[support_mask.to(DEVICE)]
            query_emb = embeddings[query_mask.to(DEVICE)]
            prototypes = compute_prototypes(support_emb, support_y.to(DEVICE))
            dists = euclidean_distance(query_emb, prototypes)
            preds = dists.argmin(dim=1).cpu()

        all_preds.append(preds)
        all_trues.append(query_y.cpu())

        print(f"Evaluation Completed on Single Graph")
        print(f"Predicted labels: {preds.tolist()}")
        print(f"True labels: {query_y.tolist()}")

        # Per-label accuracy for this query set
        print(classification_report(query_y, preds, zero_division=0))

        # Overall (just this graph)
        all_preds_flat = torch.cat(all_preds).numpy()
        all_trues_flat = torch.cat(all_trues).numpy()
        print("\n===== Overall Classification Report (this graph) =====")
        print(classification_report(all_trues_flat, all_preds_flat, zero_division=0))

        return all_preds, all_trues


# --------- Main ---------
# if __name__ == "__main__":
#     evaluate_single_graph(
#         graph_path="data/sample_invoice.pt",
#         model_path="models/proto_gat_encoder.pt",
#         support_ratio=0.2
#     )


In [9]:
evaluate_single_graph(
        graph_path="data/test_data/test_dataset_OR.pt",
        model_path="models/prototypical/proto_gat_encoder.pt",
        support_ratio=0.2
    )

 Loading graph(s) from: data/test_data/test_dataset_OR.pt
 Detected list of graphs. Splitting support/query at graph level.

 Graph 1 Predictions:
[0, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3]
 Predicted VALUE nodes: []
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        13
           3       1.00      1.00      1.00        38

    accuracy                           1.00        51
   macro avg       1.00      1.00      1.00        51
weighted avg       1.00      1.00      1.00        51


 Graph 2 Predictions:
[0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3]
 Predicted VALUE nodes: []
              precision    recall  f1-score   support

           0       1.00      1.00  

  data = torch.load(graph_path, map_location=DEVICE)
  encoder.load_state_dict(torch.load(model_path, map_location=DEVICE))


              precision    recall  f1-score   support

           0       1.00      1.00      1.00        10
           3       1.00      1.00      1.00        30

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40


 Graph 4 Predictions:
[0, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 0, 3, 3, 3]
 Predicted VALUE nodes: []
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        29
           3       1.00      1.00      1.00        91

    accuracy                           1.00       120
   macro avg       1.00      1.00      1.00 