In [1]:
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import random
import pickle

import torch
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, Linear, SAGEConv
import numpy as np
import random
from sklearn.metrics import roc_auc_score, average_precision_score

from torch_geometric.data import HeteroData

np.random.seed(42)
random.seed(42)
torch.manual_seed(42)


<torch._C.Generator at 0x7f8f17ffd390>

### 1. Get data 

In [2]:
# Load provider features : 
with open("final_df.pickle", "rb") as pickle_file : 
    df_provider_features = pickle.load(pickle_file)

providers_dataset = df_provider_features.index.to_list()

In [3]:
# Load member features : 
with open("final_members_df.pickle", "rb") as pickle_file : 
    df_member_features = pickle.load(pickle_file)

members_dataset = df_member_features.index.to_list()

In [4]:
# Load claims data :
with open("df_descriptions.pickle", 'rb') as pickle_file : 
    df = pickle.load(pickle_file)

df_edges = df[["providercode", "membercode", "claimcode" ]]
df_edges = df_edges.groupby(["providercode", "membercode"]).agg({"claimcode" : "nunique"}).reset_index()

df_edges = df_edges.loc[((df_edges.membercode.isin(members_dataset))
                         & (df_edges.providercode.isin(providers_dataset)))]
df_edges.rename(columns={"membercode":"member_id", 
                "providercode" :"provider_id",
                "claimcode" :"nbr_claims", 
                }, inplace=True)

### 2. Build Graph 

In [5]:
providers = df_provider_features.index.tolist() 
members   = df_member_features.index.tolist() 

provider2idx = {p: i for i, p in enumerate(providers)}  
member2idx   = {m: i for i, m in enumerate(members)}   

data = HeteroData()

# 1) Provider node features
provider_feats = torch.tensor(df_provider_features.values, dtype=torch.float)
data["provider"].x = provider_feats  # shape: [num_providers, provider_feat_dim]

# 2) Member node features
member_feats = torch.tensor(df_member_features.values, dtype=torch.float)
data["member"].x = member_feats      # shape: [num_members, member_feat_dim]

edge_list = []
edge_weight_list = []

for row in df_edges.itertuples(index=False):
    # row.provider_id, row.member_id, row.nbr_claims
    if row.provider_id in provider2idx and row.member_id in member2idx:
        p_idx = provider2idx[row.provider_id]
        m_idx = member2idx[row.member_id]
        edge_list.append([p_idx, m_idx])
        edge_list.append([m_idx, p_idx])
        edge_weight_list.append(float(row.nbr_claims))
        edge_weight_list.append(float(row.nbr_claims))

# Convert to torch
edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()  # shape [2, E]
edge_attr  = torch.tensor(edge_weight_list, dtype=torch.float)

# Assign to data
data["provider", "links", "member"].edge_index = edge_index
data["provider", "links", "member"].edge_attr  = edge_attr


# Initialize anomaly labels for each node type (0=normal, 1=anomalous)
data["provider"].synthetic_labels = torch.zeros(data["provider"].num_nodes, dtype=torch.long)
data["member"].synthetic_labels   = torch.zeros(data["member"].num_nodes, dtype=torch.long)


In [9]:
data["provider"].x.shape

torch.Size([652, 38])

### 3. Synthetic labels

#### 3.1 Knowledge-based anomalies

In [14]:
def inject_ghost_members(
    data, 
    ratio=0.01, 
    connections_per_ghost=3
):
    """
    Create 'ghost' member nodes that do not exist originally,
    connect them to random providers, mark them (and those providers) as anomalous.
    
    :param data: The HeteroData object
    :param ratio: fraction of the *current* total members to add as new ghost members
    :param connections_per_ghost: how many providers each ghost node connects to
    """
    device = data["member"].x.device
    
    old_num_members = data["member"].num_nodes
    num_ghost = max(1, int(ratio * old_num_members))  # at least 1 ghost
    
    feat_dim = data["member"].x.size(1)
    
    # 1) Create random features for ghost members
    ghost_feats = torch.randn(num_ghost, feat_dim, device=device) * 0.01
    new_member_x = torch.cat([data["member"].x, ghost_feats], dim=0)
    data["member"].x = new_member_x
    
    # Expand the synthetic_labels
    old_labels = data["member"].synthetic_labels
    new_labels = torch.ones(num_ghost, dtype=torch.long, device=device)  # All ghost = anomaly
    data["member"].synthetic_labels = torch.cat([old_labels, new_labels], dim=0)
    
    # 2) Connect these ghost members to random providers
    edge_list = data["provider", "links", "member"].edge_index.clone().t().tolist()
    edge_attr_list = data["provider", "links", "member"].edge_attr.tolist()
    
    num_providers = data["provider"].num_nodes
    
    for g in range(num_ghost):
        ghost_member_idx = old_num_members + g
        for _ in range(connections_per_ghost):
            p_idx = random.randint(0, num_providers - 1)
            edge_list.append([p_idx, ghost_member_idx])
            # Use a random or default weight (nbr_claims)
            w = float(random.randint(1, 5))
            edge_attr_list.append(w)
            
            # Mark provider as anomaly
            data["provider"].synthetic_labels[p_idx] = 1

    # Convert back to tensor
    new_edge_index = torch.tensor(edge_list, dtype=torch.long, device=device).t().contiguous()
    new_edge_attr  = torch.tensor(edge_attr_list, dtype=torch.float, device=device)

    data["provider", "links", "member"].edge_index = new_edge_index
    data["provider", "links", "member"].edge_attr  = new_edge_attr
    
    return data

def inject_upcoding(
    data,
    ratio=0.01,
    scale_factor=5.0
):
    """
    Randomly select a fraction of edges, inflate their edge_attr (nbr_claims) by scale_factor,
    mark involved nodes as anomaly.
    
    :param data: The HeteroData object
    :param ratio: fraction of edges to inflate
    :param scale_factor: how much to multiply the existing edge_attr
    """
    edge_index = data["provider", "links", "member"].edge_index
    edge_attr  = data["provider", "links", "member"].edge_attr
    
    E = edge_index.size(1)
    num_inflate = max(1, int(ratio * E))
    
    chosen_edges = random.sample(range(E), k=num_inflate)
    
    for e_idx in chosen_edges:
        edge_attr[e_idx] *= scale_factor
        
        # Mark provider & member as anomalies
        p_idx = edge_index[0, e_idx]
        m_idx = edge_index[1, e_idx]
        
        data["provider"].synthetic_labels[p_idx] = 1
        data["member"].synthetic_labels[m_idx]   = 1
    
    return data

def inject_collusion_ring(
    data,
    ratio=0.01,
    partial_density=1.0
):
    """
    Create a small 'collusion ring': 
    - Select a small subset of providers + members
    - Connect them with a fully dense or partially dense bipartite block
    - Mark them as anomalies
    
    :param data: HeteroData object
    :param ratio: fraction of the existing node sets to pick for collusion ring
    :param partial_density: fraction of the fully dense edges to actually add (0.2 to 1.0)
    """
    device = data["member"].x.device
    
    num_providers = data["provider"].num_nodes
    num_members   = data["member"].num_nodes
    
    # Small subsets
    sub_p = max(1, int(ratio * num_providers))
    sub_m = max(1, int(ratio * num_members))
    
    # Randomly pick providers
    provider_indices = random.sample(range(num_providers), k=sub_p)
    # Randomly pick members
    member_indices   = random.sample(range(num_members), k=sub_m)
    
    edge_list = data["provider", "links", "member"].edge_index.clone().t().tolist()
    edge_attr_list = data["provider", "links", "member"].edge_attr.tolist()
    
    # Fully dense block means all provider_indices x member_indices
    # partial_density means we only keep some fraction
    possible_edges = []
    for p in provider_indices:
        for m in member_indices:
            possible_edges.append((p, m))
    
    num_block_edges = int(partial_density * len(possible_edges))
    block_edges = random.sample(possible_edges, k=num_block_edges)
    
    for (p_idx, m_idx) in block_edges:
        edge_list.append([p_idx, m_idx])
        w = float(random.randint(1, 10))  # random claims
        edge_attr_list.append(w)
        
        # Mark them as anomaly
        data["provider"].synthetic_labels[p_idx] = 1
        data["member"].synthetic_labels[m_idx]   = 1

    new_edge_index = torch.tensor(edge_list, dtype=torch.long, device=device).t().contiguous()
    new_edge_attr  = torch.tensor(edge_attr_list, dtype=torch.float, device=device)
    
    data["provider", "links", "member"].edge_index = new_edge_index
    data["provider", "links", "member"].edge_attr  = new_edge_attr

    return data


#### 3.2 Random anomalies

In [11]:
def inject_random_struct_anomaly(
    data, 
    ratio=0.01,
    min_block_size=2,
    max_block_size=20,
    partial=True
):
    """
    Random structure anomaly injection (Ding et al. [8] style):
    - pick small subset of providers + members (2..20)
    - create fully or partially dense bipartite edges
    - mark them as anomalies
    :param partial: if True, connect only some fraction of pairs (0.2..1.0)
    """
    num_providers = data["provider"].num_nodes
    num_members   = data["member"].num_nodes
    device        = data["provider"].x.device
    
    # Number of sub-anomalies to inject
    # ratio is fraction of the total node sets => approximate how many anomalies we form
    # We'll do 'num_anomalies' sub-block insertions
    num_anomalies = max(1, int(ratio * (num_providers + num_members) // (max_block_size*2)))
    
    edge_list = data["provider", "links", "member"].edge_index.clone().t().tolist()
    edge_attr_list = data["provider", "links", "member"].edge_attr.tolist()
    
    for _ in range(num_anomalies):
        block_p_size = random.randint(min_block_size, max_block_size)
        block_m_size = random.randint(min_block_size, max_block_size)
        
        # random sets
        p_nodes = random.sample(range(num_providers), min(block_p_size, num_providers))
        m_nodes = random.sample(range(num_members), min(block_m_size, num_members))
        
        fraction = random.uniform(0.2, 1.0) if partial else 1.0
        
        possible_edges = []
        for p in p_nodes:
            for m in m_nodes:
                possible_edges.append((p, m))
        
        # partial or full density
        chosen_edges_count = int(fraction * len(possible_edges))
        chosen_edges = random.sample(possible_edges, k=chosen_edges_count)
        
        # Insert edges
        for (p_idx, m_idx) in chosen_edges:
            edge_list.append([p_idx, m_idx])
            w = float(random.randint(1, 10))
            edge_attr_list.append(w)
            # Mark anomalies
            data["provider"].synthetic_labels[p_idx] = 1
            data["member"].synthetic_labels[m_idx]   = 1
    
    new_edge_index = torch.tensor(edge_list, dtype=torch.long, device=device).t().contiguous()
    new_edge_attr  = torch.tensor(edge_attr_list, dtype=torch.float, device=device)
    data["provider", "links", "member"].edge_index = new_edge_index
    data["provider", "links", "member"].edge_attr  = new_edge_attr

    return data


def inject_random_attr_anomaly(
    data,
    node_type="provider",
    ratio=0.01,
    method="outside_conf",
    c=3.0
):
    """
    Inject attribute anomaly in node features following Ding et al. [8].
    :param node_type: "provider" or "member"
    :param ratio: fraction of nodes to become anomalies
    :param method: "outside_conf" or "scaled_gaussian"
    :param c: confidence/scale factor (2..4 recommended)
    """
    x = data[node_type].x
    n_nodes, feat_dim = x.shape
    device = x.device
    
    # mean, std
    mu = x.mean(dim=0)
    sigma = x.std(dim=0) + 1e-6
    
    num_anomalies = max(1, int(ratio * n_nodes))
    chosen_nodes = random.sample(range(n_nodes), k=num_anomalies)
    
    for nd in chosen_nodes:
        data[node_type].synthetic_labels[nd] = 1
        if method == "outside_conf":
            # Replace a fraction of features randomly
            # We'll pick half of the features to corrupt, for example
            feat_indices = random.sample(range(feat_dim), k=feat_dim // 2)
            for fi in feat_indices:
                # Generate truncated gauss outside [mu - c*sigma, mu + c*sigma]
                lower = mu[fi] - c * sigma[fi]
                upper = mu[fi] + c * sigma[fi]
                
                # sample repeatedly until we are outside the interval
                val = torch.randn(1, device=device) * sigma[fi] + mu[fi]
                while lower <= val <= upper:
                    val = torch.randn(1, device=device) * sigma[fi] + mu[fi]
                
                x[nd, fi] = val
        else:  # "scaled_gaussian"
            # add noise from N(0, c * sigma)
            noise = torch.randn(feat_dim, device=device) * (c * sigma)
            x[nd] = x[nd] + noise
    
    data[node_type].x = x
    return data

def inject_random_both_anomaly(
    data,
    ratio=0.01,
    min_block_size=2,
    max_block_size=20,
    c=3.0
):
    """
    Combination: 
    1) Insert a random structural anomaly block
    2) Then apply attribute anomaly to the newly involved nodes or edges
    """
    # Step 1: structural anomaly
    data = inject_random_struct_anomaly(
        data,
        ratio=ratio,
        min_block_size=min_block_size,
        max_block_size=max_block_size,
        partial=True
    )
    
    # Step 2: attribute anomaly on some providers or members
    # (We can pick randomly whether to apply "outside_conf" or "scaled_gaussian")
    method_choice = random.choice(["outside_conf", "scaled_gaussian"])
    data = inject_random_attr_anomaly(
        data,
        node_type="provider",
        ratio=ratio,
        method=method_choice,
        c=c
    )
    data = inject_random_attr_anomaly(
        data,
        node_type="member",
        ratio=ratio,
        method=method_choice,
        c=c
    )
    
    return data


In [15]:
import torch
import random
import itertools

def inject_structural_anomalies(data, clique_size=15, num_cliques=10):
    """
    Inject structural anomalies by adding fully connected cliques.
    
    The procedure is as follows:
      1. Compute the total number of nodes to affect: total = clique_size * num_cliques.
      2. Randomly sample total distinct nodes from the graph.
      3. Partition these nodes into num_cliques groups of size clique_size.
      4. For each group (clique), add edges between every pair of nodes (if not already present)
         so that the group becomes a fully connected subgraph.
      5. Mark all nodes in each injected clique as anomalous (synthetic_labels set to 1).
    
    This method injects m × n structural anomalies into the network.
    
    :param data: A PyTorch Geometric Data object with attributes x, edge_index, and synthetic_labels.
    :param clique_size: The number of nodes in each clique (m).
    :param num_cliques: The number of cliques to inject (n).
    :return: The modified data object.
    """
    total_anomalies = clique_size * num_cliques
    num_nodes = data.x.size(0)
    
    if total_anomalies > num_nodes:
        raise ValueError("Not enough nodes to inject the desired number of structural anomalies.")
    
    # Randomly sample total_anomalies nodes without replacement.
    all_nodes = list(range(num_nodes))
    random.shuffle(all_nodes)
    anomaly_nodes = all_nodes[:total_anomalies]
    
    # Partition the selected nodes into num_cliques groups of size clique_size.
    cliques = [anomaly_nodes[i * clique_size : (i + 1) * clique_size] for i in range(num_cliques)]
    
    # Convert existing edge_index to a set of (src, dst) pairs for easy checking.
    existing_edges = set(zip(data.edge_index[0].tolist(), data.edge_index[1].tolist()))
    
    new_edges = []
    for clique in cliques:
        # Mark each node in the clique as an anomaly.
        for node in clique:
            data.synthetic_labels[node] = 1
        # For an undirected graph, add edges in both directions.
        # Add an edge for each unordered pair (i, j) with i < j.
        for i, j in itertools.combinations(clique, 2):
            # (i, j) and (j, i) are both added.
            new_edges.append((i, j))
            new_edges.append((j, i))
    
    # Combine new edges with existing ones.
    # (Removing duplicates is optional; here we form the union.)
    combined_edges = set(existing_edges) | set(new_edges)
    
    # Convert back to a tensor of shape (2, num_edges).
    edge_index_tensor = torch.tensor(list(combined_edges), dtype=torch.long).t().contiguous()
    data.edge_index = edge_index_tensor
    
    return data

def inject_attribute_anomalies(data, clique_size=15, num_cliques=10, k=10):
    """
    Inject attribute anomalies by perturbing node features.
    
    The procedure is as follows:
      1. Compute the total number of candidate nodes: total = clique_size * num_cliques.
      2. Randomly sample total distinct nodes from the graph.
      3. For each candidate node i:
           a. Randomly select k other nodes from the graph (excluding i).
           b. Compute the Euclidean distance between i's feature vector and each sampled node’s feature vector.
           c. Find the node j with the maximum distance.
           d. Replace node i’s features with node j’s features.
           e. Mark node i as anomalous (synthetic_labels set to 1).
    
    This method injects an equal number (m × n) of attribute anomalies as are injected structurally.
    
    :param data: A PyTorch Geometric Data object with attributes x and synthetic_labels.
    :param clique_size: Parameter m used to determine the total number of candidates.
    :param num_cliques: Parameter n used to determine the total number of candidates.
    :param k: The number of nodes to compare for each candidate node.
    :return: The modified data object.
    """
    total_candidates = clique_size * num_cliques
    num_nodes = data.x.size(0)
    
    if total_candidates > num_nodes:
        raise ValueError("Not enough nodes to inject the desired number of attribute anomalies.")
    
    all_nodes = list(range(num_nodes))
    random.shuffle(all_nodes)
    candidate_nodes = all_nodes[:total_candidates]
    
    for i in candidate_nodes:
        # Create a list of nodes to sample from (all except i).
        possible_nodes = list(range(num_nodes))
        possible_nodes.remove(i)
        k_sel = min(k, len(possible_nodes))
        sampled_nodes = random.sample(possible_nodes, k_sel)
        
        xi = data.x[i]
        max_dist = -1
        max_j = None
        # Find the node among the k sampled nodes with the largest Euclidean distance from xi.
        for j in sampled_nodes:
            xj = data.x[j]
            # Compute the Euclidean distance (using the 2-norm).
            dist = torch.norm(xi - xj, p=2).item()
            if dist > max_dist:
                max_dist = dist
                max_j = j
        if max_j is not None:
            # Replace node i's features with those of node max_j.
            data.x[i] = data.x[max_j].clone()
            # Mark node i as an attribute anomaly.
            data.synthetic_labels[i] = 1
    
    return data


In [16]:
# Inject 10 cliques of size 15 (i.e. 150 structural anomalies)
data = inject_structural_anomalies(data, clique_size=15, num_cliques=10)

# Inject 150 attribute anomalies using k=10 for comparison.
data = inject_attribute_anomalies(data, clique_size=15, num_cliques=10, k=10)


AttributeError: 'HeteroData' has no attribute 'x'

#### 3.3 Injecting labels

In [12]:
def inject_diverse_anomalies(data):
    """
    Example pipeline injecting a mix of anomalies with small overall ratios.
    Each pass uses a different injection function and parameters.
    """
    # 1) Realistic Healthcare Fraud
    
    # (a) Ghost members: ~1% ratio, each connects to 2 providers
    data = inject_ghost_members(data, ratio=0.01, connections_per_ghost=2)
    
    # (b) Upcoding: ~2% edges, scale factor between 2x and 5x
    scale_f = random.uniform(2.0, 5.0)
    data = inject_upcoding(data, ratio=0.02, scale_factor=scale_f)
    
    # (c) Collusion ring: ~1% ratio, partial density ~0.5
    partial_dens = random.uniform(0.3, 0.7)
    data = inject_collusion_ring(data, ratio=0.01, partial_density=partial_dens)
    
    # 2) Random Anomalies (Ding et al.)
    
    # (a) Random structural anomaly: ~1% ratio, block sizes up to 10
    data = inject_random_struct_anomaly(
        data,
        ratio=0.01,
        min_block_size=2, 
        max_block_size=10,
        partial=True
    )
    
    # (b) Random attribute anomaly on providers, 1% ratio, either outside_conf or scaled_gaussian
    method_choice = random.choice(["outside_conf", "scaled_gaussian"])
    c_val = random.uniform(2, 4)
    data = inject_random_attr_anomaly(
        data,
        node_type="provider",
        ratio=0.01,
        method=method_choice,
        c=c_val
    )
    
    # (c) Random attribute anomaly on members, 1% ratio
    method_choice2 = random.choice(["outside_conf", "scaled_gaussian"])
    c_val2 = random.uniform(2, 4)
    data = inject_random_attr_anomaly(
        data,
        node_type="member",
        ratio=0.01,
        method=method_choice2,
        c=c_val2
    )
    
    # (d) Combined structure + attribute anomaly, ~0.5% ratio
    data = inject_random_both_anomaly(
        data,
        ratio=0.005,
        min_block_size=2,
        max_block_size=10,
        c=random.uniform(2, 4)
    )
    
    return data


def inject_diverse_anomalies_second_version(data, max_anomaly_ratio=0.05):
    num_providers = data["provider"].num_nodes
    num_members = data["member"].num_nodes
    total_nodes = num_providers + num_members
    max_anomalies = int(max_anomaly_ratio * total_nodes)
    
    # Track anomalies
    injected_anomalies = set()  # Store (node_type, idx)
    
    # 1) Ghost Members (max 10% of anomaly budget)
    num_ghosts = max(1, int(0.1 * max_anomalies))
    data, new_ghosts = inject_ghost_members(data, num_ghosts=num_ghosts)
    injected_anomalies.update([("member", idx) for idx in new_ghosts])
    
    # 2) Upcoding (max 20% of budget)
    num_upcoding = max(1, int(0.2 * max_anomalies))
    data, upcoded = inject_upcoding(data, num_edges=num_upcoding)
    injected_anomalies.update([("provider", p) for p in upcoded["providers"]])
    injected_anomalies.update([("member", m) for m in upcoded["members"]])
    
    # 3) Collusion Ring (use remaining budget)
    remaining = max_anomalies - len(injected_anomalies)
    if remaining > 0:
        data = inject_collusion_ring(data, num_nodes=remaining)
    
    return data


In [None]:
data = inject_diverse_anomalies(data)

In [40]:
print("== Final Graph After Injection ==")
print("Number of provider nodes:", data["provider"].num_nodes)
print("Number of member nodes:", data["member"].num_nodes)
print("Provider anomaly count:", data["provider"].synthetic_labels.sum().item())
print("Member anomaly count:", data["member"].synthetic_labels.sum().item())
print("Edges shape:", data["provider", "links", "member"].edge_index.shape)
print("Edge attrs shape:", data["provider", "links", "member"].edge_attr.shape)

== Final Graph After Injection ==
Number of provider nodes: 652
Number of member nodes: 32885
Provider anomaly count: 564
Member anomaly count: 2385
Edges shape: torch.Size([2, 60982])
Edge attrs shape: torch.Size([60982])


In [6]:
from torch_geometric.transforms import RandomLinkSplit

def split_heterodata_edges(data, val_ratio=0.1, test_ratio=0.1):
    transform = RandomLinkSplit(
        num_val=val_ratio,
        num_test=test_ratio,
        is_undirected=False,  # Set True if your graph is undirected
        add_negative_train_samples=False,  # Set True if negative samples are needed
        edge_types=("provider", "links", "member"),
        rev_edge_types=None  # Define if you have reverse edges
    )
    train_data, val_data, test_data = transform(data)
    
    return train_data, val_data, test_data


In [8]:
train_data, val_data, test_data = split_heterodata_edges(data, val_ratio=0.2, test_ratio=0.2)

In [9]:
train_data

HeteroData(
  provider={
    x=[652, 38],
    synthetic_labels=[652],
  },
  member={
    x=[32560, 23],
    synthetic_labels=[32560],
  },
  (provider, links, member)={
    edge_index=[2, 70796],
    edge_attr=[70796],
    edge_label=[70796],
    edge_label_index=[2, 70796],
  }
)

### 4. Main model

#### 4.1 Model architecture 

In [10]:
from torch_geometric.nn import HeteroConv, Linear

class HeteroAutoEncoder(torch.nn.Module):
    def __init__(self, metadata, hidden_channels=16, out_channels=8):
        """
        :param metadata: data.metadata() -> (node_types, edge_types)
        :param hidden_channels: dimension of hidden embeddings
        :param out_channels: dimension of final latent embeddings
        """
        super().__init__()
        self.metadata = metadata
        
        # 1) HeteroConv for encoding
        #    We'll define a GNN conv for each (source, rel, target) edge type
        #    e.g., "provider"->"member" and optionally reverse if defined
        self.convs = torch.nn.ModuleList()

        # Example: single layer HeteroConv with SAGEConv (or GCNConv)
        convs_dict = {}
        for edge_type in metadata[1]:  # edge_types
            # e.g. edge_type = ("provider", "links", "member")
            convs_dict[edge_type] = SAGEConv(
                (-1, -1),  # (in_channels for src, in_channels for dst)
                hidden_channels
            )
        self.hetero_conv = HeteroConv(convs_dict, aggr='sum')
        
        # 2) MLP to go from hidden -> latent
        #    We'll store a separate Linear for each node type to unify dimension
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in metadata[0]:  # node_types
            # Example: hidden_channels -> out_channels
            self.lin_dict[node_type] = Linear(hidden_channels, out_channels)

        # 3) Decoders for attribute reconstruction
        #    We'll also store a separate MLP per node type to go from out_channels -> original dim
        self.decoder_attr = torch.nn.ModuleDict()
        # For adjacency, we'll do dot-product, no separate MLP needed
        # But if you prefer a param-based decoder, you could add it here
        

    def forward(self, x_dict, edge_index_dict):
        """
        Encoding step: produce out_channels embeddings for each node type
        :param x_dict: {node_type: [num_nodes, in_dim]}
        :param edge_index_dict: {edge_type: [2, E]}
        :return: z_dict, a dict of {node_type: [num_nodes, out_channels]} embeddings
        """
        # HeteroConv expects x_dict, edge_index_dict
        h_dict = self.hetero_conv(x_dict, edge_index_dict)
        
        # Apply linear transform per node_type
        z_dict = {}
        for node_type, h in h_dict.items():
            z_dict[node_type] = self.lin_dict[node_type](h)
        
        return z_dict
    
    def decode_attributes(self, z_dict, original_x_dict):
        """
        Reconstruct node attributes from embeddings
        :return: recon_x_dict with same shape as original_x_dict
        """
        recon_x_dict = {}
        for node_type, z in z_dict.items():
            out_dim = original_x_dict[node_type].size(1)
            # We'll define the attribute decoder on the fly if not exists
            # or store them in self.decoder_attr
            if node_type not in self.decoder_attr:
                self.decoder_attr[node_type] = torch.nn.Sequential(
                    Linear(z.size(1), z.size(1)),
                    torch.nn.ReLU(),
                    Linear(z.size(1), out_dim) 
                )
            recon_x = self.decoder_attr[node_type](z)
            recon_x_dict[node_type] = recon_x
        return recon_x_dict
    
    def decode_adjacency(self, z_src, z_dst, edge_index):
        """
        Dot-product decoder for adjacency or edge weight
        :param z_src: [num_src_nodes, out_channels]
        :param z_dst: [num_dst_nodes, out_channels]
        :param edge_index: [2, E], specifying (src, dst) pairs
        :return: predicted edge weights [E]
        """
        # gather node embeddings
        src = edge_index[0]
        dst = edge_index[1]
        z_s = z_src[src]
        z_d = z_dst[dst]
        # dot product
        edge_pred = (z_s * z_d).sum(dim=-1)
        # You might want a sigmoid or other transform if edges are in [0,1]
        # or do direct MSE if edges are real-valued
        return edge_pred


#### 4.2 Training the model

In [20]:
import torch
import torch.nn.functional as F
from torch_geometric.transforms import RandomLinkSplit

def train_autoencoder(data, hidden_channels=16, out_channels=8, lr=1e-3, epochs=50, alpha=0.5):
    """
    Train a Heterogeneous Autoencoder with attribute and adjacency reconstruction losses.
    
    Fixes:
    - Ensures correct edge indexing after `RandomLinkSplit`
    - Prevents node mismatch by reindexing node features
    - Verifies that providers and members are correctly ordered
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = HeteroAutoEncoder(data.metadata(), hidden_channels, out_channels).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # Move data to device
    data = data.to(device)

    # Apply RandomLinkSplit to generate negative samples for link prediction
    transform = RandomLinkSplit(
        num_val=0.1, num_test=0.1, 
        add_negative_train_samples=True,
        edge_types=("provider", "links", "member")
    )
    train_data, val_data, test_data = transform(data)

    loss_history = []

    # Ensure correct x_dict mapping
    x_dict = {node_type: train_data[node_type].x for node_type in train_data.node_types}

    # Ensure edge index is within valid range
    edge_index = train_data["provider", "links", "member"].edge_index

    # ✅ Fix swapped provider ↔ member issue
    if edge_index[0].max() > x_dict["provider"].size(0) - 1 or edge_index[1].max() > x_dict["member"].size(0) - 1:
        print("⚠️ Warning: Swapping providers and members in edge_index!")
        edge_index = edge_index[[1, 0], :]  # Swap rows

    # Restrict edge index to valid nodes
    max_provider_id = x_dict["provider"].size(0) - 1
    max_member_id = x_dict["member"].size(0) - 1
    valid_edges = (edge_index[0] <= max_provider_id) & (edge_index[1] <= max_member_id)
    edge_index = edge_index[:, valid_edges]

    edge_index_dict = {("provider", "links", "member"): edge_index}

    # Check for edge attributes (weights)
    edge_attr = train_data["provider", "links", "member"].edge_attr if hasattr(train_data["provider", "links", "member"], 'edge_attr') else None

    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()
        
        # 1) Encode node features
        z_dict = model(x_dict, edge_index_dict)

        # 2) Decode node attributes
        recon_x_dict = model.decode_attributes(z_dict, x_dict)

        # 3) Compute attribute reconstruction loss (MSE)
        loss_attr = sum(F.mse_loss(recon_x_dict[n], x_dict[n]) for n in x_dict)

        # 4) Decode adjacency (link prediction)
        z_provider = z_dict["provider"]
        z_member   = z_dict["member"]
        pred_edges = model.decode_adjacency(z_provider, z_member, edge_index_dict[("provider", "links", "member")])

        # 5) Compute adjacency reconstruction loss
        if edge_attr is not None:
            # Edge-weighted loss (if real-valued)
            loss_adj = F.mse_loss(pred_edges, edge_attr, reduction="none")
            loss_adj = (loss_adj * edge_attr).mean()
        else:
            # Binary cross-entropy loss for link prediction
            loss_adj = F.binary_cross_entropy_with_logits(pred_edges, torch.ones_like(pred_edges))

        # 6) Dynamically adjust alpha for balanced loss scaling
        alpha_tuned = alpha * (loss_attr.item() / (loss_adj.item() + 1e-8))  # Normalize alpha
        loss = loss_attr + alpha_tuned * loss_adj

        # Backpropagation & Optimization
        loss.backward()
        optimizer.step()

        # Store loss history
        loss_history.append(loss.item())

        # Logging every 10 epochs
        if epoch % 10 == 0:
            print(f"Epoch {epoch:02d}, Loss: {loss.item():.4f}, Attr: {loss_attr.item():.4f}, Adj: {loss_adj.item():.4f}")

    return model, loss_history


In [21]:
# Pass the full dataset instead of an already split dataset
model, loss_history = train_autoencoder(data, hidden_channels=16, out_channels=8, lr=1e-3, epochs=50)




KeyError: 'provider'

#### 4.3 Anomaly scoring

In [None]:
def compute_anomaly_scores(model, data):
    device = data["provider"].x.device
    model.eval()
    
    x_dict = {
        "provider": data["provider"].x,
        "member":   data["member"].x
    }
    edge_index_dict = {
        ("provider", "links", "member"): data["provider", "links", "member"].edge_index
    }
    edge_attr = data["provider", "links", "member"].edge_attr
    
    with torch.no_grad():
        # Encode
        z_dict = model(x_dict, edge_index_dict)
        # Decode attributes
        recon_x_dict = model.decode_attributes(z_dict, x_dict)
        # Compute node-level attribute error
        # We'll do MSE across features
        attr_error_provider = torch.sum((recon_x_dict["provider"] - x_dict["provider"])**2, dim=1)
        attr_error_member   = torch.sum((recon_x_dict["member"]   - x_dict["member"])**2, dim=1)
        
        # Adjacency reconstruction
        z_provider = z_dict["provider"]
        z_member   = z_dict["member"]
        edge_idx = edge_index_dict[("provider", "links", "member")]
        pred_edges = model.decode_adjacency(z_provider, z_member, edge_idx)
        
        # Edge error
        if edge_attr is not None:
            edge_error = (pred_edges - edge_attr)**2
        else:
            edge_error = (pred_edges - 1.0)**2
        
        # We'll accumulate adjacency errors for each node
        node_adj_error_provider = torch.zeros(z_provider.size(0), device=device)
        node_adj_error_member   = torch.zeros(z_member.size(0),   device=device)
        
        for e_idx in range(edge_error.size(0)):
            p_idx = edge_idx[0, e_idx]
            m_idx = edge_idx[1, e_idx]
            node_adj_error_provider[p_idx] += edge_error[e_idx]
            node_adj_error_member[m_idx]   += edge_error[e_idx]
        
        # Combine node-level anomaly scores
        # e.g. anomaly_score = attr_error + alpha * adj_error
        alpha = 0.5
        provider_score = attr_error_provider + alpha * node_adj_error_provider
        member_score   = attr_error_member   + alpha * node_adj_error_member
        
    return provider_score, member_score


#### 4.4 Evaluation

In [None]:
def evaluate_anomalies(provider_score, member_score, data):
    # Convert to cpu numpy for sklearn
    p_score_np = provider_score.detach().cpu().numpy()
    m_score_np = member_score.detach().cpu().numpy()
    
    p_label_np = data["provider"].synthetic_labels.cpu().numpy()
    m_label_np = data["member"].synthetic_labels.cpu().numpy()
    
    from sklearn.metrics import roc_auc_score, average_precision_score
    
    # Providers
    if len(np.unique(p_label_np)) > 1:  # At least 1 anomaly
        auc_p = roc_auc_score(p_label_np, p_score_np)
        ap_p  = average_precision_score(p_label_np, p_score_np)
    else:
        auc_p, ap_p = -1, -1
    
    # Members
    if len(np.unique(m_label_np)) > 1:
        auc_m = roc_auc_score(m_label_np, m_score_np)
        ap_m  = average_precision_score(m_label_np, m_score_np)
    else:
        auc_m, ap_m = -1, -1
    
    print(f"Provider AUC: {auc_p:.4f}, AP: {ap_p:.4f}")
    print(f"Member   AUC: {auc_m:.4f}, AP: {ap_m:.4f}")
    
    # If you want an overall:
    all_scores = np.concatenate([p_score_np, m_score_np])
    all_labels = np.concatenate([p_label_np, m_label_np])
    if len(np.unique(all_labels)) > 1:
        auc_all = roc_auc_score(all_labels, all_scores)
        ap_all  = average_precision_score(all_labels, all_scores)
    else:
        auc_all, ap_all = -1, -1
    print(f"Overall  AUC: {auc_all:.4f}, AP: {ap_all:.4f}")


### 5. Baseline models 