In [None]:
# DGL is reliable in Colab for heterogeneous graphs.
!pip -q install dgl torch torchvision torchaudio
# Progress bars
!pip -q install tqdm

In [None]:
# 1) IMPORTS & SEEDING

import os, json, math, random
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

import dgl
from dgl.nn.pytorch import HeteroGraphConv, GraphConv

SEED = 42
# Seed every RNG the notebook draws from: `random` (graph generation, episode
# sampling), NumPy, and torch (pseudo labels, weight init; covers CUDA too).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

In [None]:
# 2) DATA LOADING

DATA_PATH = "/content/dataset.json"  # Colab upload location; change it to wherever the dataset file was uploaded

def random_features(dim=8):
    """Return `dim` floats drawn with `random.random()` and rounded to 3 decimals."""
    return [float(np.round(random.random(), 3)) for _ in range(dim)]

def generate_graph(graph_id, node_types, edge_types,
                   num_nodes=300, num_edges=800, feat_dim=8):
    """
    Build one synthetic heterogeneous graph in the JSON schema used by this notebook.

    Every node gets a type drawn uniformly from `node_types` and `feat_dim`
    random features. Every edge joins two *distinct* randomly chosen nodes with
    a relation drawn uniformly from `edge_types` (duplicate edges may occur).

    Returns:
        dict with keys "graph_id", "node_types", "edge_types", "nodes", "edges".
    Raises:
        ValueError: if edges are requested but fewer than two nodes exist, since
            no edge without a self-loop can be drawn (the retry loop would never end).
    """
    if num_edges > 0 and num_nodes < 2:
        raise ValueError(
            f"generate_graph({graph_id!r}): need at least 2 nodes to draw "
            f"{num_edges} edges without self-loops, got num_nodes={num_nodes}"
        )

    # nodes (RNG order per node: type choice, then its features)
    nodes = []
    for i in range(num_nodes):
        ntype = random.choice(node_types)
        nodes.append({
            "id": f"{graph_id}_N{i}",
            "type": ntype,
            "features": random_features(feat_dim)
        })
    # edges: resample the target until it differs from the source (no self-loops)
    edges = []
    for _ in range(num_edges):
        src = random.choice(nodes)
        tgt = random.choice(nodes)
        while tgt["id"] == src["id"]:
            tgt = random.choice(nodes)
        relation = random.choice(edge_types)
        edges.append({
            "source": src["id"],
            "target": tgt["id"],
            "relation": relation
        })
    return {
        "graph_id": graph_id,
        "node_types": node_types,
        "edge_types": edge_types,
        "nodes": nodes,
        "edges": edges
    }

def load_or_make_dataset(path=DATA_PATH, feat_dim=8):
    """
    Load the list of JSON graph dicts from `path`, or generate 5 synthetic graphs.

    Args:
        path: JSON file containing a list of graph dicts (schema as produced by
            `generate_graph`).
        feat_dim: node feature size, used only when synthetic graphs are generated.
    Returns:
        list of graph dicts.

    When graphs are generated they are cached to `path` on a best-effort basis:
    if the file cannot be written (e.g. the default Colab directory does not
    exist on this machine), a message is printed and the in-memory graphs are
    still returned.
    """
    if os.path.exists(path):
        print(f"Loading dataset from: {path}")
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return data

    print("Dataset file not found — generating 5 synthetic graphs...")
    graphs = [
        generate_graph("G1", ["patient", "symptom", "disease"], ["has_symptom", "diagnosed_with"], feat_dim=feat_dim),
        generate_graph("G2", ["gene", "protein", "pathway"], ["encodes", "involved_in"], feat_dim=feat_dim),
        generate_graph("G3", ["drug", "target", "disease"], ["binds_to", "treats"], feat_dim=feat_dim),
        generate_graph("G4", ["treatment", "outcome", "disease"], ["leads_to", "associated_with"], feat_dim=feat_dim),
        generate_graph("G5", ["microbe", "metabolite", "disease"], ["produces", "affects"], feat_dim=feat_dim)
    ]
    try:
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(graphs, f, indent=2)
        print(f"Saved synthetic dataset to: {path}")
    except OSError as exc:
        # Caching is optional; the generated graphs remain usable in memory.
        print(f"Could not save synthetic dataset to {path} ({exc}); continuing without a cache file.")
    return graphs

# raw_graphs: list of JSON-style graph dicts, read from DATA_PATH or freshly generated.
raw_graphs = load_or_make_dataset()


In [None]:
# 3) BUILD DGL HETEROGRAPHS FROM JSON


def build_dgl_graph_from_json(jg):
    """
    Convert one JSON graph dict into a DGL heterograph.

    jg: one JSON graph dict. Reads "nodes" (each {"id", "type", "features"}) and
        "edges" (each {"source", "target", "relation"}); other keys such as
        "graph_id", "node_types" and "edge_types" are not used here.
    Returns: (g, ntype2id, feats_dict, node_id_map)
        g: DGLHeteroGraph with node data "h" for every node type that has nodes;
           canonical edge types are (source_type, relation, target_type)
        ntype2id: map node_type -> list(node_ids in order)
        feats_dict: dict of {ntype: float32 tensor features [N_t, d]}
        node_id_map: global_id_str -> (ntype, idx_in_type)
    Edges whose endpoints are not listed in "nodes" are skipped.
    """
    # Group nodes by type, preserving file order inside each type
    ntype2nodes = defaultdict(list)
    for n in jg["nodes"]:
        ntype2nodes[n["type"]].append(n)

    # Assign per-type indices and gather features
    ntype2id = {}
    feats = {}
    node_id_map = {}  # str_id -> (ntype, idx)
    for ntype, nlist in ntype2nodes.items():
        ntype2id[ntype] = [obj["id"] for obj in nlist]
        for idx, obj in enumerate(nlist):
            node_id_map[obj["id"]] = (ntype, idx)
        feats[ntype] = torch.tensor(np.array([obj["features"] for obj in nlist], dtype=np.float32))

    # Build typed edge lists: canonical etype -> (src_idx_list, dst_idx_list)
    etype2edges = defaultdict(lambda: ([], []))
    for e in jg["edges"]:
        u, v = e["source"], e["target"]
        if u not in node_id_map or v not in node_id_map:
            continue
        u_type, u_idx = node_id_map[u]
        v_type, v_idx = node_id_map[v]
        src_list, dst_list = etype2edges[(u_type, e["relation"], v_type)]
        src_list.append(u_idx)
        dst_list.append(v_idx)

    data_dict = {
        etype: (torch.tensor(src_list, dtype=torch.int64), torch.tensor(dst_list, dtype=torch.int64))
        for etype, (src_list, dst_list) in etype2edges.items()
    }

    # Pass explicit per-type node counts. Without them DGL infers each count from
    # the largest node index that appears in some edge, so trailing nodes without
    # edges (or whole node types without edges) are dropped and the feature
    # assignment below fails with a size mismatch.
    num_nodes_dict = {ntype: len(nlist) for ntype, nlist in ntype2nodes.items()}
    g = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict)

    # Attach node features per type (every ntype in g has nodes, hence features)
    for ntype in g.ntypes:
        g.nodes[ntype].data["h"] = feats[ntype]
    return g, ntype2id, feats, node_id_map

# One "graph pack" per JSON graph: the heterograph plus the id/feature maps built with it.
dgl_graphs = []
meta_domains = []  # graph_id of each pack, in the same order as dgl_graphs
for jg in raw_graphs:
    g, nmap, feats, nidmap = build_dgl_graph_from_json(jg)
    dgl_graphs.append({"graph_id": jg["graph_id"], "g": g, "ntype2ids": nmap, "feats": feats, "nidmap": nidmap})
    meta_domains.append(jg["graph_id"])

print("Built heterographs:", [x["graph_id"] for x in dgl_graphs])
# Node and edge types differ from graph to graph, so summarize each one.
for gp in dgl_graphs:
    print(f"  {gp['graph_id']}: node types = {gp['g'].ntypes}; "
          f"{len(gp['g'].canonical_etypes)} canonical edge types, e.g. {gp['g'].canonical_etypes[:3]}")

In [None]:
#4) FEW-SHOT TASK DEFINITION (labels on disease nodes)

C = 5          # number of disease classes per graph
# Input feature width, read from the first node type of the first graph.
# NOTE(review): FEAT_DIM is not referenced elsewhere in the visible notebook.
FEAT_DIM = next(iter(dgl_graphs[0]["feats"].values())).shape[1]

def assign_pseudo_labels_for_disease(gpack, num_classes=C):
    """
    Draw a random pseudo label in [0, num_classes) for every "disease" node.

    Labels are independent uniform draws from torch's global RNG (torch.randint)
    and do not depend on the graph's structure or features.
    NOTE(review): as a result the few-shot task is a plumbing check, not a learnable benchmark.

    Returns:
        dict with
            'labels': int tensor [num_disease_nodes], or None if the graph has no disease nodes
            'idxs':   list of disease node indices 0..N-1 (empty if none)
    """
    g = gpack["g"]
    n_disease = g.num_nodes("disease") if "disease" in g.ntypes else 0
    if n_disease == 0:
        return {"labels": None, "idxs": []}
    return {
        "labels": torch.randint(low=0, high=num_classes, size=(n_disease,)),
        "idxs": list(range(n_disease)),
    }

# Attach labels to each graph (for disease nodes only)
for graph_pack in dgl_graphs:
    graph_pack["disease_labels"] = assign_pseudo_labels_for_disease(graph_pack, num_classes=C)


In [None]:

# 5) CH-MHGNN-STYLE ENCODER

class CHMHGNNEncoder(nn.Module):
    """
    Two-layer relation-wise heterogeneous GNN encoder shared across graphs.

    Every canonical edge type (src_type, relation, dst_type) owns one GraphConv
    per layer; messages reaching a node type are summed (HeteroGraphConv with
    aggregate='sum'); a per-node-type Linear projection follows layer 2.

    All modules live in persistent nn.ModuleDicts keyed by type name, so a
    relation's weights are reused by every graph that contains it and are never
    re-initialised when the encoder switches between graphs.

    Args:
        in_dims_dict: {ntype: input feature dim}.
        hidden_dim: output width of layer 1.
        out_dim: output width of layer 2 and of the projections.
        canonical_etypes: optional iterable of (src, rel, dst) triples. When given,
            every module is created in __init__, so encoder.parameters() is
            complete before an optimizer is built. Relations first seen in
            forward() are created lazily and are unknown to an existing optimizer.
    """

    def __init__(self, in_dims_dict, hidden_dim=64, out_dim=64, canonical_etypes=None):
        super().__init__()
        self.in_dims_dict = dict(in_dims_dict)
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.act = nn.ReLU()
        # Per-relation GraphConvs for layer 1 / layer 2, keyed by _etype_key(etype)
        self.conv1 = nn.ModuleDict()
        self.conv2 = nn.ModuleDict()
        # Final projection per node type for alignment
        self.proj = nn.ModuleDict()
        for etype in (canonical_etypes or []):
            self._add_relation(tuple(etype))

    @staticmethod
    def _etype_key(etype):
        """ModuleDict-safe string key for a canonical edge type ('.' is not allowed in module names)."""
        return "__".join(etype).replace(".", "_")

    def _add_relation(self, etype, device=None):
        """Create the layer-1/2 convs for `etype` and projections for its node types, if missing."""
        src, _, dst = etype
        key = self._etype_key(etype)
        created = []
        if key not in self.conv1:
            self.conv1[key] = GraphConv(in_feats=self.in_dims_dict[src], out_feats=self.hidden_dim,
                                        allow_zero_in_degree=True)
            self.conv2[key] = GraphConv(in_feats=self.hidden_dim, out_feats=self.out_dim,
                                        allow_zero_in_degree=True)
            created += [self.conv1[key], self.conv2[key]]
        for ntype in (src, dst):
            if ntype not in self.proj:
                self.proj[ntype] = nn.Linear(self.out_dim, self.out_dim)
                created.append(self.proj[ntype])
        if device is not None:
            for mod in created:
                mod.to(device)

    def forward(self, g, feats_dict):
        """
        feats_dict: {ntype: tensor [N_t, d_in_t]}
        Returns:
            z_dict: {ntype: tensor [N_t, out_dim]} for every node type that
            receives messages in layer 2 (node types with no incoming
            relation are absent).
        """
        # Lazy path: register dims of unseen node types, then any missing relation modules.
        for ntype, x in feats_dict.items():
            self.in_dims_dict.setdefault(ntype, x.shape[1])
        device = next(iter(feats_dict.values())).device if feats_dict else None
        for etype in g.canonical_etypes:
            self._add_relation(etype, device=device)

        # Lightweight per-graph wrappers over the *shared* relation modules
        layer1 = HeteroGraphConv({et: self.conv1[self._etype_key(et)] for et in g.canonical_etypes},
                                 aggregate='sum')
        layer2 = HeteroGraphConv({et: self.conv2[self._etype_key(et)] for et in g.canonical_etypes},
                                 aggregate='sum')

        h_dict = layer1(g, feats_dict)
        h_dict = {k: self.act(v) for k, v in h_dict.items()}
        h_dict = layer2(g, h_dict)
        # projections for alignment
        return {nt: self.proj[nt](h) for nt, h in h_dict.items()}

# Initialize a single encoder shared across tasks (graphs). Collect every node type
# and canonical edge type over ALL graphs so the encoder is built eagerly: the
# optimizer created later must see all of its parameters.
all_ntypes = set()
all_canonical_etypes = []
for gp in dgl_graphs:
    all_ntypes.update(gp["g"].ntypes)
    for et in gp["g"].canonical_etypes:
        if et not in all_canonical_etypes:
            all_canonical_etypes.append(et)

# Per-type input dims from the first graph that has features for the type (8 if none does)
in_dims_dict = {}
for nt in all_ntypes:
    d_in = None
    for gp in dgl_graphs:
        if nt in gp["feats"]:
            d_in = gp["feats"][nt].shape[1]
            break
    in_dims_dict[nt] = d_in if d_in is not None else 8

encoder = CHMHGNNEncoder(in_dims_dict, hidden_dim=64, out_dim=64,
                         canonical_etypes=all_canonical_etypes).to(device)
print(f"Encoder: {len(all_canonical_etypes)} relations, "
      f"{sum(p.numel() for p in encoder.parameters()):,} parameters")

In [None]:

# 6) PROTOTYPICAL EPISODE (FEW-SHOT)

def sample_support_query(labels, K=5, max_query_per_class=20):
    """
    Split node indices into support and query sets, class by class.

    For each class present in `labels` (ascending class order), the indices of
    that class are shuffled with `random.shuffle`; the first K become support
    and up to `max_query_per_class` of the remainder become query.

    labels: tensor [N] with class ids 0..C-1
    returns: dict { 'support_idx', 'query_idx' } of Python int lists
    """
    labels_cpu = labels.cpu()
    support, query = [], []
    for cls in torch.unique(labels_cpu).tolist():
        members = torch.nonzero(labels_cpu == cls, as_tuple=True)[0].tolist()
        random.shuffle(members)
        support += members[:K]
        query += members[K:K + max_query_per_class]
    return {"support_idx": support, "query_idx": query}

def compute_prototypes(emb, labels, idxs):
    """
    Mean embedding per class over the support nodes.

    emb: [N, d] embeddings, labels: [N] class ids, idxs: list of support indices
    return: dict class_id -> prototype [d], keys in ascending class order;
            empty dict when `idxs` is empty
    """
    if len(idxs) == 0:
        return {}
    support_emb = emb[idxs]
    support_lab = labels[idxs]
    return {
        cls: support_emb[support_lab == cls].mean(0)
        for cls in torch.unique(support_lab).tolist()
    }

def classify_by_prototypes(emb, proto, scale=1.0):
    """
    Score query embeddings against class prototypes with cosine similarity.

    emb: [Nq, d] query embeddings
    proto: dict class_id -> [d] prototype
    scale: multiplier applied to the cosine similarities. With the default 1.0
        the logits stay in [-1, 1], which gives a nearly flat softmax under
        cross-entropy; a larger value (e.g. 10) sharpens it.
    return: (logits [Nq, C'], classes) where column j of logits corresponds to
        classes[j] (class ids in ascending order)
    Raises:
        ValueError: if `proto` is empty.
    """
    if len(proto) == 0:
        raise ValueError("No prototypes found. Increase K or check labels.")
    classes = sorted(proto.keys())
    P = torch.stack([proto[c] for c in classes], dim=0)  # [C', d]
    emb_norm = F.normalize(emb, p=2, dim=-1)
    P_norm = F.normalize(P, p=2, dim=-1)
    logits = scale * (emb_norm @ P_norm.T)  # [Nq, C'] cosine similarities
    return logits, classes

def cross_graph_alignment_loss(disease_protos):
    """
    Pull the disease-prototype centroids of different graphs toward each other.

    disease_protos: list of dicts {class_id -> prototype [d]}, one per graph/task.
    For each non-empty dict the prototypes are averaged into one centroid [d].
    The loss is the mean, over all unordered pairs of centroids, of
    F.mse_loss(centroid_i, centroid_j). Class ids are NOT matched across graphs:
    they are local to each graph, so this is only a soft global heuristic. In
    real data you would align shared diseases (semantic anchors) instead.

    Returns a scalar tensor; 0 (on the global `device`) when fewer than two
    non-empty prototype dicts are given.
    """
    from itertools import combinations

    centroids = [torch.stack(list(p.values()), dim=0).mean(0)
                 for p in disease_protos if len(p) > 0]
    if len(centroids) < 2:
        return torch.tensor(0., device=device)

    pair_losses = [F.mse_loss(a, b) for a, b in combinations(centroids, 2)]
    return torch.stack(pair_losses).mean()

In [None]:
# 7) META-TRAIN / META-VAL / META-TEST SPLIT (by graphs)

assert len(dgl_graphs) >= 5, "Need 5 graphs for this split."
# Graph-level split: first three graphs meta-train, then one graph each for validation and test.
train_graphs, val_graphs, test_graphs = dgl_graphs[:3], [dgl_graphs[3]], [dgl_graphs[4]]

print("Meta-train:", [pack["graph_id"] for pack in train_graphs])
print("Meta-val:  ", [pack["graph_id"] for pack in val_graphs])
print("Meta-test: ", [pack["graph_id"] for pack in test_graphs])

In [None]:
# 8) TRAINING CONFIG

K_SHOT = 3              # support nodes per class in each episode
MAX_Q = 10              # max query nodes per class in each episode
EPISODES_PER_EPOCH = 15
VAL_EPISODES = 5
TEST_EPISODES = 15
LR = 1e-3
EPOCHS = 5
ALIGN_COEF = 0.1  # weight for cross-graph alignment loss
WD = 1e-5               # AdamW weight decay

# NOTE: the optimizer captures encoder.parameters() at this point, so every
# encoder module must already exist; modules created later are not optimized.
optimizer = torch.optim.AdamW(encoder.parameters(), lr=LR, weight_decay=WD)

# Helper to run one episode on one graph
def run_episode_on_graph(gpack, train_mode=True, k_shot=K_SHOT, max_q=MAX_Q):
    """
    Run one prototypical few-shot episode on the disease nodes of one graph.

    Args:
        gpack: graph pack with "g" (heterograph carrying node data "h") and
            "disease_labels" ({"labels": tensor or None, "idxs": list}).
        train_mode: currently unused; gradient tracking is controlled by the
            caller (evaluate() wraps calls in torch.no_grad()).
        k_shot: support nodes per class.
        max_q: max query nodes per class.
    Returns:
        (loss, acc, n_query): cross-entropy over the query nodes, query accuracy
        as a float, and the number of query nodes. When no episode can be formed
        it returns (zero tensor, 0., 0.).
    """
    g = gpack["g"].to(device)
    feats_dict = {nt: g.nodes[nt].data["h"].to(device) for nt in g.ntypes}

    # forward encoder
    z_dict = encoder(g, feats_dict)  # {ntype: [N_t, d]}

    # only classify disease nodes
    d_labels = gpack["disease_labels"]["labels"]
    d_idxs   = gpack["disease_labels"]["idxs"]

    # "disease" is missing from z_dict when no relation points at disease nodes
    # (HeteroGraphConv only emits destination node types).
    if d_labels is None or len(d_idxs) == 0 or "disease" not in z_dict:
        return torch.tensor(0., device=device), 0., 0.  # no-op

    d_emb = z_dict["disease"]  # [Nd, d]
    d_lab = d_labels.to(device)

    # support/query
    split = sample_support_query(d_lab, K=k_shot, max_query_per_class=max_q)
    if len(split["support_idx"]) == 0 or len(split["query_idx"]) == 0:
        return torch.tensor(0., device=device), 0., 0.

    proto = compute_prototypes(d_emb, d_lab, split["support_idx"])
    q_emb = d_emb[split["query_idx"]]
    q_lab = d_lab[split["query_idx"]]

    logits, classes = classify_by_prototypes(q_emb, proto)
    # Map raw class ids to logit columns; one tolist() avoids a device sync per element
    class2pos = {c: i for i, c in enumerate(classes)}
    y = torch.tensor([class2pos[c] for c in q_lab.tolist()], device=device)

    loss = F.cross_entropy(logits, y)
    # Accuracy
    pred = logits.argmax(dim=1)
    acc = (pred == y).float().mean().item()

    return loss, acc, logits.size(0)

def evaluate(graph_list, episodes=VAL_EPISODES, k_shot=K_SHOT, max_q=MAX_Q):
    """
    Average loss and accuracy of the shared encoder over random episodes.

    Each of the `episodes` iterations picks one graph from `graph_list` with
    random.choice and runs run_episode_on_graph under torch.no_grad(), with the
    encoder switched to eval mode (and left there). Episodes without query
    nodes are ignored; accuracy is the unweighted mean over the remaining ones.
    Returns (mean_loss, mean_acc), or (0., 0.) if no episode was usable.
    """
    encoder.eval()
    losses, accs = [], []
    with torch.no_grad():
        for _ in range(episodes):
            task = random.choice(graph_list)
            loss, acc, n_query = run_episode_on_graph(task, train_mode=False, k_shot=k_shot, max_q=max_q)
            if n_query > 0:
                losses.append(loss.item())
                accs.append(acc)
    if not losses:
        return 0., 0.
    return float(np.mean(losses)), float(np.mean(accs))

In [None]:
# 9) TRAINING LOOP (meta-train on graphs; align across tasks)

best_val = -1.0
best_state = None

for epoch in range(1, EPOCHS+1):
    encoder.train()
    running_loss = 0.0
    running_acc = 0.0
    count = 0

    for _ in tqdm(range(EPISODES_PER_EPOCH), desc=f"Epoch {epoch}/{EPOCHS}"):
        optimizer.zero_grad()

        # Disease prototypes of each task, for the cross-graph alignment term
        disease_protos_all = []

        # Sample a mini-batch of tasks (graphs)
        tasks = [random.choice(train_graphs) for __ in range(3)]  # 3 tasks per episode
        total_loss = torch.tensor(0., device=device)

        for gp in tasks:
            loss, acc, nq = run_episode_on_graph(gp, train_mode=True, k_shot=K_SHOT, max_q=MAX_Q)
            if nq == 0:
                continue
            total_loss = total_loss + loss
            running_loss += loss.item()
            running_acc += acc
            count += 1

            # Recompute embeddings WITH autograd enabled: prototypes built under
            # torch.no_grad() make the alignment loss a constant, so it would
            # contribute no gradient to the encoder at all.
            g = gp["g"].to(device)
            feats_dict = {nt: g.nodes[nt].data["h"].to(device) for nt in g.ntypes}
            z = encoder(g, feats_dict)
            d_labels = gp["disease_labels"]["labels"]
            d_idxs = gp["disease_labels"]["idxs"]
            if d_labels is not None and len(d_idxs) > 0 and "disease" in z:
                split = sample_support_query(d_labels, K=K_SHOT, max_query_per_class=MAX_Q)
                if len(split["support_idx"]) > 0:
                    p = compute_prototypes(z["disease"], d_labels.to(device), split["support_idx"])
                    disease_protos_all.append(p)

        # cross-graph alignment (encourages transferable disease structure)
        if len(disease_protos_all) >= 2:
            align_loss = cross_graph_alignment_loss(disease_protos_all)
            total_loss = total_loss + ALIGN_COEF * align_loss

        # requires_grad is False when every task in this episode was skipped
        if total_loss.requires_grad:
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=5.0)
            optimizer.step()

    train_loss = running_loss / max(count, 1)
    train_acc  = running_acc / max(count, 1)

    val_loss, val_acc = evaluate(val_graphs, episodes=VAL_EPISODES, k_shot=K_SHOT, max_q=MAX_Q)

    print(f"[Epoch {epoch}] Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.3f}")

    if val_acc > best_val:
        best_val = val_acc
        best_state = {k: v.cpu().clone() for k, v in encoder.state_dict().items()}

# Load best model
if best_state is not None:
    encoder.load_state_dict(best_state)
    print(f"Loaded best validation model (Val Acc = {best_val:.3f}).")

In [None]:
# 10) FINAL EVALUATION ON META-TEST GRAPH(S)

# Few-shot evaluation of the best encoder on the held-out test graph(s)
test_loss, test_acc = evaluate(test_graphs, episodes=TEST_EPISODES, k_shot=K_SHOT, max_q=MAX_Q)
print(f"[META-TEST] Loss: {test_loss:.4f} | Acc: {test_acc:.3f}")