In [9]:
import os
import torch
from torch_geometric.data import Data

# Define your document types and their data folder paths
task_folders = {
    "Invoice": "data/invoice/",
    "Loan": "data/loan/",
    "Final Bill": "data/final_bill/",
    "Background Verification": "data/background_verification/",
    "Operative Report": "data/operative_report/"
}

all_graphs = []

for task_name, folder_path in task_folders.items():
    if not os.path.exists(folder_path):
        print(f"⚠️ Folder not found: {folder_path}")
        continue
    graph_files = [f for f in os.listdir(folder_path) if f.endswith(".pt")]

    print(f"🔍 Processing {len(graph_files)} files for task: {task_name}")
    for file in graph_files:
        graph_path = os.path.join(folder_path, file)
        data = torch.load(graph_path)
        if isinstance(data, list):
           for d in data:
               d.task = task_name
               all_graphs.append(d)
        else:
            data.task = task_name
            all_graphs.append(data)

# Save to one master file
os.makedirs("data/few-shot-dataset", exist_ok=True)
torch.save(all_graphs, "data/few-shot-dataset/fewshot_dataset.pt")
print(f"Saved all {len(all_graphs)} graphs to fewshot_dataset.pt")


🔍 Processing 5 files for task: Invoice
🔍 Processing 5 files for task: Loan
🔍 Processing 5 files for task: Final Bill
🔍 Processing 5 files for task: Background Verification
🔍 Processing 5 files for task: Operative Report
Saved all 25 graphs to fewshot_dataset.pt


  data = torch.load(graph_path)


In [11]:
data_list = torch.load("data/few-shot-dataset/fewshot_dataset.pt")
print(data_list[0].task)  # should print "Invoice" or similar

Invoice


  data_list = torch.load("data/few-shot-dataset/fewshot_dataset.pt")


In [15]:
# maml_runner.py

import torch
import torch.nn.functional as F
import random
import higher
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GAT

# ------------- Config -------------------
IN_CHANNELS = 18
OUT_CLASSES = 4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------- GAT Model --------------------
def build_gat_model(hidden=128, heads=4, dropout=0.2, layers=2):
    return GAT(
        in_channels=IN_CHANNELS,
        hidden_channels=hidden,
        out_channels=OUT_CLASSES,
        heads=heads,
        num_layers=layers,
        dropout=dropout,
        edge_dim=1,
        v2=True,
        jk='lstm'
    ).to(DEVICE)

# --------- Episode Sampler -------------
def sample_episode(data_list, task, k_shot=4, q_num=1):
    task_data = [d for d in data_list if getattr(d, 'task', None) == task]
    assert len(task_data) >= k_shot + q_num, f"Not enough data for task: {task}"
    random.shuffle(task_data)
    return task_data[:k_shot], task_data[k_shot:k_shot + q_num]

# --------- MAML Training Loop ----------
def maml_train(data_list, model, optimizer, inner_steps=1, n_episodes=500):
    model.train()
    tasks = list(set(d.task for d in data_list))

    for episode in range(n_episodes):
        task = random.choice(tasks)
        support_set, query_set = sample_episode(data_list, task, k_shot=4, q_num=1)

        model.zero_grad()
        with torch.backends.cudnn.flags(enabled=False):
           with higher.innerloop_ctx(model, optimizer, copy_initial_weights=False) as (fmodel, diffopt):
           # Inner loop adaptation
                for _ in range(inner_steps):
                   for support in support_set:
                       support = support.to(DEVICE)
                       out = fmodel(support.x, support.edge_index, edge_weight=support.edge_attr)
                       loss = F.cross_entropy(out, support.y)
                       diffopt.step(loss)
                # Outer loop: evaluate on query
                query = query_set[0].to(DEVICE)
                out = fmodel(query.x, query.edge_index, edge_weight=query.edge_attr)
                loss = F.cross_entropy(out, query.y)
                loss.backward()
                optimizer.step()

        if episode % 50 == 0:
            print(f"[Episode {episode}] Meta-loss: {loss.item():.4f} | Task: {task}")

# --------- MAML Inference -------------
def maml_infer(model, support_set, query_doc, optimizer, inner_steps=1):
    model.eval()

    with higher.innerloop_ctx(model, optimizer, track_higher_grads=False) as (fmodel, diffopt):
        # Adapt on support
        for _ in range(inner_steps):
            for support in support_set:
                support = support.to(DEVICE)
                out = fmodel(support.x, support.edge_index, edge_weight=support.edge_attr)
                loss = F.cross_entropy(out, support.y)
                diffopt.step(loss)

        # Predict on query
        query_doc = query_doc.to(DEVICE)
        out = fmodel(query_doc.x, query_doc.edge_index, edge_weight=query_doc.edge_attr)
        preds = out.argmax(dim=1)

    return preds

# # --------- Main Runner -----------------
# if __name__ == "__main__":
#     print(" Loading few-shot dataset...")
#     data_list = torch.load("data\few-shot-dataset\fewshot_dataset.pt", map_location=DEVICE)

#     model = build_gat_model()
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

#     print(" Starting MAML training...")
#     maml_train(data_list, model, optimizer, inner_steps=1, n_episodes=500)

#     print(" Saving MAML-trained model...")
#     torch.save(model.state_dict(), "models\maml_gat_model.pt")


In [16]:
print(" Loading few-shot dataset...")
data_list = torch.load("data\\few-shot-dataset\\fewshot_dataset.pt", map_location=DEVICE)

model = build_gat_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print(" Starting MAML training...")
maml_train(data_list, model, optimizer, inner_steps=1, n_episodes=500)

print(" Saving MAML-trained model...")
torch.save(model.state_dict(), "models\maml_gat_model.pt")

 Loading few-shot dataset...


  data_list = torch.load("data\\few-shot-dataset\\fewshot_dataset.pt", map_location=DEVICE)


 Starting MAML training...
[Episode 0] Meta-loss: 1.1783 | Task: Loan
[Episode 50] Meta-loss: 0.5340 | Task: Loan
[Episode 100] Meta-loss: 0.2098 | Task: Invoice
[Episode 150] Meta-loss: 0.0344 | Task: Invoice
[Episode 200] Meta-loss: 0.0072 | Task: Invoice
[Episode 250] Meta-loss: 0.0185 | Task: Invoice
[Episode 300] Meta-loss: 0.0068 | Task: Final Bill
[Episode 350] Meta-loss: 0.0008 | Task: Loan
[Episode 400] Meta-loss: 0.0190 | Task: Loan
[Episode 450] Meta-loss: 0.0190 | Task: Final Bill
 Saving MAML-trained model...


In [21]:
# maml_evaluator.py

import torch
import torch.nn.functional as F
from torch_geometric.nn import GAT
import higher
import random

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- GAT Builder (same config as maml_runner) ----------
def build_gat_model(hidden=128, heads=4, dropout=0.2, layers=2):
    return GAT(
        in_channels=18,
        hidden_channels=hidden,
        out_channels=4,
        heads=heads,
        num_layers=layers,
        dropout=dropout,
        edge_dim=1,
        v2=True,
        jk='lstm'
    ).to(DEVICE)

# ---------- Episode Sampler ----------
def sample_episode(data_list, task, k_shot=4, q_num=1):
    task_data = [d for d in data_list if getattr(d, 'task', None) == task]
    random.shuffle(task_data)
    return task_data[:k_shot], task_data[k_shot:k_shot + q_num]

# ---------- Inference Logic ----------
def maml_infer(model, support_set, query_doc, optimizer, inner_steps=1):
    model.train()

    with higher.innerloop_ctx(model, optimizer, track_higher_grads=False) as (fmodel, diffopt):
        for _ in range(inner_steps):
            for support in support_set:
                support = support.to(DEVICE)
                out = fmodel(support.x, support.edge_index, edge_weight=support.edge_attr)
                loss = F.cross_entropy(out, support.y)
                diffopt.step(loss)

        query_doc = query_doc.to(DEVICE)
        out = fmodel(query_doc.x, query_doc.edge_index, edge_weight=query_doc.edge_attr)
        pred = out.argmax(dim=1)

    return pred.cpu(), query_doc.y.cpu()

# ---------- Evaluation Loop ----------
def evaluate_model(data_list, model_path):
    print("🔍 Loading model...")
    model = build_gat_model()
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    tasks = list(set(d.task for d in data_list))
    all_preds, all_trues = [], []

    for task in tasks:
        support_set, query_set = sample_episode(data_list, task, k_shot=4, q_num=1)
        pred, true = maml_infer(model, support_set, query_set[0], optimizer)
        all_preds.append(pred)
        all_trues.append(true)

        print(f" Task: {task} | VALUE nodes predicted: {(pred == 1).sum().item()} | True: {(true == 1).sum().item()}")

    return all_preds, all_trues

# # ---------- Main Runner ----------
# if __name__ == "__main__":
#     data_list = torch.load("fewshot_dataset.pt", map_location=DEVICE)
#     evaluate_model(data_list, model_path="maml_gat_model.pt")


In [22]:
data_list = torch.load("data\\few-shot-dataset\\fewshot_dataset.pt", map_location=DEVICE)
evaluate_model(data_list, model_path="models\maml_gat_model.pt")

  data_list = torch.load("data\\few-shot-dataset\\fewshot_dataset.pt", map_location=DEVICE)


🔍 Loading model...


  model.load_state_dict(torch.load(model_path, map_location=DEVICE))


 Task: Final Bill | VALUE nodes predicted: 0 | True: 0
 Task: Invoice | VALUE nodes predicted: 0 | True: 0
 Task: Loan | VALUE nodes predicted: 0 | True: 0
 Task: Operative Report | VALUE nodes predicted: 0 | True: 0
 Task: Background Verification | VALUE nodes predicted: 0 | True: 0


([tensor([0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3,
          3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0,
          3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3,
          0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3,
          0, 3, 3, 3, 0, 3, 3, 3, 3]),
  tensor([0, 3, 3, 3, 0, 3, 3, 0, 3, 3, 3, 0, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0,
          3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3,
          3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3,
          3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3,
          0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3]),
  tensor([0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3,
          3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3,
          3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3, 0, 3, 3, 3,

In [30]:
import torch
from torch_geometric.nn import GAT

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------- 1. Build the GAT Model ----------
def build_gat_model(hidden=128, heads=4, dropout=0.2, layers=2):
    return GAT(
        in_channels=18,
        hidden_channels=hidden,
        out_channels=4,
        heads=heads,
        num_layers=layers,
        dropout=dropout,
        edge_dim=1,
        v2=True,
        jk='lstm'  # or 'cat' if you trained with that
    ).to(DEVICE)

# ----------- 2. Load Trained Model ----------
model = build_gat_model()
model.load_state_dict(torch.load("models/maml_gat_model.pt", map_location=DEVICE))
model.eval()

# ----------- 3. Load Graphs with Labels ----------
graphs = torch.load("BG/datacheckpoint_11.pt", map_location=DEVICE)

# ----------- 4. Predict + Compare ----------
for i, graph in enumerate(graphs):
    graph = graph.to(DEVICE)
    true_labels = graph.y

    with torch.no_grad():
        out = model(graph.x, graph.edge_index, edge_weight=graph.edge_attr)
        pred = out.argmax(dim=1)

    # Compare predictions
    correct = (pred == true_labels).sum().item()
    total = len(true_labels)
    accuracy = correct / total

    # Value-specific analysis
    true_value_indices = (true_labels == 1).nonzero(as_tuple=True)[0]
    pred_value_indices = (pred == 1).nonzero(as_tuple=True)[0]
    true_positive = len(set(pred_value_indices.tolist()) & set(true_value_indices.tolist()))

    print(f"\n📄 Graph {i+1}:")
    print(f"✅ Accuracy: {accuracy*100:.2f}% ({correct}/{total})")
    print(f"🔍 True VALUE node indices: {true_value_indices.tolist()}")
    print(f"🔮 Predicted VALUE node indices: {pred_value_indices.tolist()}")
    print(f"🎯 Correctly predicted VALUEs: {true_positive} / {len(true_value_indices)}")



📄 Graph 1:
✅ Accuracy: 100.00% (113/113)
🔍 True VALUE node indices: []
🔮 Predicted VALUE node indices: []
🎯 Correctly predicted VALUEs: 0 / 0


  model.load_state_dict(torch.load("models/maml_gat_model.pt", map_location=DEVICE))
  graphs = torch.load("BG/datacheckpoint_11.pt", map_location=DEVICE)
