In [1]:
# proto_gat_main.py

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GAT

# ----------- Config -------------------
IN_CHANNELS = 18
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------- ProtoNet GAT Encoder --------------------
class GATEncoder(nn.Module):
    def __init__(self, hidden=128, heads=4, dropout=0.2, layers=2):
        super().__init__()
        self.gnn = GAT(
            in_channels=IN_CHANNELS,
            hidden_channels=hidden,
            out_channels=hidden,
            heads=heads,
            num_layers=layers,
            dropout=dropout,
            edge_dim=1,
            v2=True,
            jk='cat'
        )

    def forward(self, x, edge_index, edge_attr):
        return self.gnn(x, edge_index, edge_weight=edge_attr)

# --------- Episode Sampler --------------------------
def sample_episode(data_list, task, k_shot=4, q_num=1):
    task_data = [d for d in data_list if getattr(d, 'task', None) == task]
    random.shuffle(task_data)
    return task_data[:k_shot], task_data[k_shot:k_shot + q_num]

# --------- Compute Prototypes ------------------------
def compute_prototypes(embeddings, labels, num_classes=4):
    prototypes = []
    for c in range(num_classes):
        class_mask = (labels == c)
        if class_mask.sum() == 0:
            prototypes.append(torch.zeros_like(embeddings[0]))
        else:
            prototypes.append(embeddings[class_mask].mean(dim=0))
    return torch.stack(prototypes)

# --------- Compute Distances ------------------------
def euclidean_distance(a, b):
    return ((a.unsqueeze(1) - b.unsqueeze(0)) ** 2).sum(dim=2)

# --------- Prototypical Loss ------------------------
def prototypical_loss(embeddings, labels, prototypes):
    dists = euclidean_distance(embeddings, prototypes)
    log_p_y = F.log_softmax(-dists, dim=1)
    loss = F.nll_loss(log_p_y, labels)
    preds = log_p_y.argmax(dim=1)
    acc = (preds == labels).float().mean().item()
    return loss, acc

# --------- Training Loop -----------------------------
def proto_train(data_list, encoder, optimizer, n_episodes=500, k_shot=4, q_num=1):
    encoder.train()
    tasks = list(set(d.task for d in data_list))

    for episode in range(n_episodes):
        task = random.choice(tasks)
        support_set, query_set = sample_episode(data_list, task, k_shot, q_num)

        support_x, support_y = [], []
        for g in support_set:
            g = g.to(DEVICE)
            emb = encoder(g.x, g.edge_index, g.edge_attr)
            support_x.append(emb)
            support_y.append(g.y)
        support_x = torch.cat(support_x, dim=0)
        support_y = torch.cat(support_y, dim=0)

        prototypes = compute_prototypes(support_x, support_y)

        query = query_set[0].to(DEVICE)
        query_emb = encoder(query.x, query.edge_index, query.edge_attr)
        loss, acc = prototypical_loss(query_emb, query.y, prototypes)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if episode % 10 == 0:
            print(f"[Episode {episode}] Loss: {loss.item():.4f} | Accuracy: {acc*100:.2f}% | Task: {task}")

# --------- Inference on a Graph -----------------------
def proto_predict(encoder, support_set, query_graph):
    encoder.eval()
    support_x, support_y = [], []

    for g in support_set:
        g = g.to(DEVICE)
        emb = encoder(g.x, g.edge_index, g.edge_attr)
        support_x.append(emb)
        support_y.append(g.y)

    support_x = torch.cat(support_x, dim=0)
    support_y = torch.cat(support_y, dim=0)
    prototypes = compute_prototypes(support_x, support_y)

    query = query_graph.to(DEVICE)
    query_emb = encoder(query.x, query.edge_index, query.edge_attr)
    dists = euclidean_distance(query_emb, prototypes)
    preds = dists.argmin(dim=1)
    return preds.cpu()

# # --------- Example Runner -----------------------------
# if __name__ == "__main__":
#     print("📥 Loading few-shot dataset...")
#     data_list = torch.load("data/few-shot-dataset/fewshot_dataset.pt", map_location=DEVICE)

#     encoder = GATEncoder().to(DEVICE)
#     optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

#     print("🚀 Starting ProtoNet training...")
#     proto_train(data_list, encoder, optimizer, n_episodes=500)

#     print("💾 Saving trained encoder...")
#     torch.save(encoder.state_dict(), "models/proto_gat_encoder.pt")


In [2]:
print("📥 Loading few-shot dataset...")
data_list = torch.load("data/few-shot-dataset/fewshot_dataset.pt", map_location=DEVICE)

encoder = GATEncoder().to(DEVICE)
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

print("🚀 Starting ProtoNet training...")
proto_train(data_list, encoder, optimizer, n_episodes=500)

print("💾 Saving trained encoder...")
torch.save(encoder.state_dict(), "models/proto_gat_encoder.pt")

📥 Loading few-shot dataset...


  data_list = torch.load("data/few-shot-dataset/fewshot_dataset.pt", map_location=DEVICE)


🚀 Starting ProtoNet training...
[Episode 0] Loss: 1.1060 | Accuracy: 79.22% | Task: Final Bill
[Episode 10] Loss: 0.6212 | Accuracy: 63.24% | Task: Operative Report
[Episode 20] Loss: 0.5852 | Accuracy: 74.73% | Task: Invoice
[Episode 30] Loss: 0.3274 | Accuracy: 85.29% | Task: Operative Report
[Episode 40] Loss: 0.1529 | Accuracy: 97.50% | Task: Operative Report
[Episode 50] Loss: 0.1644 | Accuracy: 94.81% | Task: Final Bill
[Episode 60] Loss: 0.0287 | Accuracy: 98.53% | Task: Operative Report
[Episode 70] Loss: 0.0792 | Accuracy: 97.10% | Task: Final Bill
[Episode 80] Loss: 0.0248 | Accuracy: 98.55% | Task: Loan
[Episode 90] Loss: 0.0678 | Accuracy: 96.70% | Task: Invoice
[Episode 100] Loss: 0.0221 | Accuracy: 99.09% | Task: Background Verification
[Episode 110] Loss: 0.0502 | Accuracy: 97.37% | Task: Loan
[Episode 120] Loss: 0.0061 | Accuracy: 100.00% | Task: Final Bill
[Episode 130] Loss: 0.0102 | Accuracy: 99.43% | Task: Loan
[Episode 140] Loss: 0.0044 | Accuracy: 100.00% | Task: 

In [20]:
# proto_eval.py

import torch
from torch_geometric.nn import GAT
# from proto_gat_main import GATEncoder, proto_predict

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------- Load Trained Encoder ---------
def load_encoder(model_path):
    encoder = GATEncoder().to(DEVICE)
    encoder.load_state_dict(torch.load(model_path, map_location=DEVICE))
    encoder.eval()
    return encoder

# --------- Load Graphs ---------
def load_graphs(path):
    return torch.load(path, map_location=DEVICE)

# --------- Run Prediction ---------
def run_inference(support_path, query_path, model_path):
    encoder = load_encoder(model_path)
    support_graphs = load_graphs(support_path)
    query_graphs = load_graphs(query_path)

    total_correct = 0
    total_nodes = 0

    for i, query_graph in enumerate(query_graphs):
        pred = proto_predict(encoder, support_graphs, query_graph)
        print(f"\n📄 Query Graph {i+1} Predictions:")
        print(pred.tolist())
        value_nodes = (pred == 1).nonzero(as_tuple=True)[0].tolist()
        print(f" VALUE nodes at indices: {value_nodes}")

        # Compute accuracy for this query graph if ground truth is available
        if hasattr(query_graph, 'y') and query_graph.y is not None:
            correct = (pred == query_graph.y.cpu()).sum().item()
            total = query_graph.y.size(0)
            acc = correct / total if total > 0 else 0.0
            print(f" Accuracy: {acc*100:.2f}% ({correct}/{total})")
            total_correct += correct
            total_nodes += total

    if total_nodes > 0:
        overall_acc = total_correct / total_nodes
        print(f"\n✅ Overall Accuracy across all query graphs: {overall_acc*100:.2f}% ({total_correct}/{total_nodes})")
    else:
        print("No ground truth labels found in query graphs for accuracy calculation.")

# # --------- Main ---------
# if __name__ == "__main__":
#     run_inference(
#         support_path="data/few-shot-dataset/invoice_support.pt",
#         query_path="data/few-shot-dataset/invoice_query.pt",
#         model_path="models/proto_gat_encoder.pt"
#     )


In [28]:
# Run inference and print accuracy for each query graph and overall accuracy
def run_inference_with_accuracy(support_path, query_path, model_path):
    encoder = load_encoder(model_path)
    support_graphs = load_graphs(support_path)
    query_graphs = load_graphs(query_path)

    total_correct = 0
    total_nodes = 0

    for i, query_graph in enumerate(query_graphs):
        pred = proto_predict(encoder, support_graphs, query_graph)
        true_labels = query_graph.y.cpu()
        correct = (pred == true_labels).sum().item()
        total = true_labels.size(0)
        acc = correct / total if total > 0 else 0.0
        total_correct += correct
        total_nodes += total

        print(f"\n📄 Query Graph {i+1} Predictions:")
        print(pred.tolist())
        value_nodes = (pred == 1).nonzero(as_tuple=True)[0].tolist()
        print(f" VALUE nodes at indices: {value_nodes}")
        print(f" Accuracy: {acc*100:.2f}% ({correct}/{total})")

    overall_acc = total_correct / total_nodes if total_nodes > 0 else 0.0
    print(f"\n✅ Overall Accuracy across all query graphs: {overall_acc*100:.2f}% ({total_correct}/{total_nodes})")

run_inference_with_accuracy(
    support_path="datacheckpoint_training_(15).pt",
    query_path="BG/datacheckpoint_10.pt",
    model_path="models/proto_gat_encoder.pt"
)


📄 Query Graph 1 Predictions:
[0, 1, 2, 0, 1, 2, 2, 0, 1, 2, 2, 0, 1, 2, 1, 0, 1, 2, 2, 0, 1, 2, 2, 0, 1, 2, 2, 0, 1, 2, 2, 0, 1, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 0, 2, 2, 1, 0, 2, 1, 0, 2, 3, 2, 3, 0, 2, 2, 3, 0, 2, 3, 2, 3, 0, 2, 3, 3, 0, 2, 3, 2, 3, 0, 2, 3, 3, 0, 2, 2, 2, 1, 0, 2, 2, 1, 0, 2, 2, 1, 0, 2, 2, 1, 0, 3, 1, 1, 0, 2, 3, 1, 0, 3, 2, 0, 3, 3, 1, 0, 2, 2, 3, 0, 2, 3, 3, 0, 2, 3, 1, 0, 2, 1]
 VALUE nodes at indices: [1, 4, 8, 12, 14, 16, 20, 24, 28, 32, 53, 56, 88, 92, 96, 100, 103, 104, 108, 115, 127, 130]
 Accuracy: 40.46% (53/131)

✅ Overall Accuracy across all query graphs: 40.46% (53/131)


  encoder.load_state_dict(torch.load(model_path, map_location=DEVICE))
  return torch.load(path, map_location=DEVICE)


In [11]:
# proto_single_graph_eval.py

import torch
import random
import numpy as np
# from proto_gat_main import GATEncoder, compute_prototypes, euclidean_distance

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------- Load Trained GAT Proto Encoder ---------
def load_encoder(model_path):
    encoder = GATEncoder().to(DEVICE)
    encoder.load_state_dict(torch.load(model_path, map_location=DEVICE))
    encoder.eval()
    return encoder

# --------- Evaluate Using Graph(s) ---------
def evaluate_single_graph(graph_path, model_path, support_ratio=0.2):
    print("📂 Loading graph(s) from:", graph_path)
    data = torch.load(graph_path, map_location=DEVICE)
    encoder = load_encoder(model_path)

    if isinstance(data, list):
        print("🔎 Detected list of graphs. Splitting support/query at graph level.")
        random.shuffle(data)
        split = int(support_ratio * len(data))
        if split == 0:
            split = 1
        support_graphs = data[:split]
        query_graphs = data[split:]

        support_emb, support_y = [], []
        for g in support_graphs:
            g = g.to(DEVICE)
            emb = encoder(g.x, g.edge_index, g.edge_attr)
            support_emb.append(emb)
            support_y.append(g.y)
        if not support_emb or not support_y:
            print("❌ No support graphs available. Please check your data or support_ratio.")
            return

        support_emb = torch.cat(support_emb, dim=0)
        support_y = torch.cat(support_y, dim=0)
        prototypes = compute_prototypes(support_emb, support_y)

        total_correct, total_nodes = 0, 0
        for i, g in enumerate(query_graphs):
            g = g.to(DEVICE)
            with torch.no_grad():
                emb = encoder(g.x, g.edge_index, g.edge_attr)
                dists = euclidean_distance(emb, prototypes)
                preds = dists.argmin(dim=1).cpu()
                correct = (preds == g.y.cpu()).sum().item()
                total = g.y.size(0)
                acc = correct / total if total > 0 else 0.0
                total_correct += correct
                total_nodes += total

                value_indices = (preds == 1).nonzero(as_tuple=True)[0].tolist()
                print(f"\n📄 Graph {i+1} Accuracy: {acc*100:.2f}% ({correct}/{total})")
                print(f"🎯 Predicted VALUE nodes: {value_indices}")

        overall_acc = total_correct / total_nodes if total_nodes > 0 else 0.0
        print(f"\n✅ Overall Accuracy across all query graphs: {overall_acc*100:.2f}%")

    else:
        print("🔎 Detected single graph. Splitting support/query at node level.")
        graph = data
        num_nodes = graph.x.size(0)
        indices = list(range(num_nodes))
        random.shuffle(indices)

        split = int(support_ratio * num_nodes)
        support_idx = indices[:split]
        query_idx = indices[split:]

        support_mask = torch.zeros(num_nodes, dtype=torch.bool)
        support_mask[support_idx] = True

        query_mask = torch.zeros(num_nodes, dtype=torch.bool)
        query_mask[query_idx] = True

        support_x = graph.x[support_mask]
        support_y = graph.y[support_mask]

        query_x = graph.x[query_mask]
        query_y = graph.y[query_mask]

        with torch.no_grad():
            embeddings = encoder(graph.x.to(DEVICE), graph.edge_index.to(DEVICE), graph.edge_attr.to(DEVICE))
            support_emb = embeddings[support_mask.to(DEVICE)]
            query_emb = embeddings[query_mask.to(DEVICE)]
            prototypes = compute_prototypes(support_emb, support_y.to(DEVICE))
            dists = euclidean_distance(query_emb, prototypes)
            preds = dists.argmin(dim=1).cpu()

        correct = (preds == query_y).sum().item()
        total = len(query_y)
        acc = correct / total if total > 0 else 0.0

        print(f"Evaluation Completed on Single Graph")
        print(f"Accuracy on Query Nodes: {acc*100:.2f}% ({correct}/{total})")

        value_indices = (preds == 1).nonzero(as_tuple=True)[0].tolist()
        print(f"Predicted VALUE nodes in query set: {value_indices}")

# --------- Main ---------
# if __name__ == "__main__":
#     evaluate_single_graph(
#         graph_path="data/sample_invoice.pt",
#         model_path="models/proto_gat_encoder.pt",
#         support_ratio=0.2
#     )


In [13]:
evaluate_single_graph(
        graph_path="datacheckpoint_01 (1).pt",
        model_path="models/proto_gat_encoder.pt",
        support_ratio=0.2
    )

📂 Loading graph(s) from: datacheckpoint_01 (1).pt
🔎 Detected list of graphs. Splitting support/query at graph level.

✅ Overall Accuracy across all query graphs: 0.00%


  data = torch.load(graph_path, map_location=DEVICE)
  encoder.load_state_dict(torch.load(model_path, map_location=DEVICE))
