In [None]:
from   embed_and_verify import *
from   config import *
from   prune_and_fine_tune_utils import *
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import networkx as nx
import matplotlib as mpl
import matplotlib.pyplot as plt
import os
import pandas as pd
import pickle
import random
import requests

from   scipy import stats
from   scipy.stats import anderson, kstest, zscore, shapiro

import torch
from   torch.nn import ModuleList
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torch.optim as optim

from   torch_geometric.datasets import Reddit2
from   torch_geometric.nn import GATConv, GCNConv, GraphConv, SAGEConv, GINConv, global_mean_pool
from   torch_geometric.transforms import RandomLinkSplit
from   torch_geometric.data import Data
from   torch_geometric.loader import DataLoader

import zipfile

from   pcgrad.pcgrad import PCGrad 

# Render all matplotlib figures at 250 dpi.
mpl.rcParams['figure.dpi']=250

# Placeholder value: replace with the local path to the repository before running.
root_folder = '<path_to_repo_directory>'
# '<root_folder>/training_results'.
# NOTE(review): neither name is referenced by the cells visible in this notebook — confirm they are still needed.
training_folder = os.path.join(root_folder,'training_results')

In [None]:


def download_file(url, filename):
    """
    Downloads a zipped dataset into ``../data`` and unpacks it to ``../data/<name>_raw``.

    Args:
        url (str): Location of the ``.zip`` archive.
        filename (str): File name to store the archive under inside ``../data``
            (e.g. ``'MUTAG.zip'``). The dataset name is ``filename`` without ``.zip``.

    Raises:
        requests.HTTPError: If the server responds with an error status.

    Notes:
        - If ``../data/<name>_raw`` already exists the function returns without
          downloading, so the notebook cell can be re-run safely.
        - Assumes the archive unpacks into a top-level folder named ``<name>``;
          that folder is renamed to ``<name>_raw``.
    """
    data_dir = '../data'
    dataset_name = filename.replace('.zip','')
    extracted_dir = os.path.join(data_dir, dataset_name)
    raw_dir = os.path.join(data_dir, f'{dataset_name}_raw')

    # Re-running used to crash in os.rename because the target already existed.
    if os.path.isdir(raw_dir):
        print(f"{dataset_name} already present at {raw_dir}; skipping download")
        return

    os.makedirs(data_dir, exist_ok=True)

    response = requests.get(url, timeout=60)
    # Fail loudly instead of writing an HTML error page to disk as a "zip".
    response.raise_for_status()

    zip_path = os.path.join(data_dir, filename)
    with open(zip_path, 'wb') as f:
        f.write(response.content)

    unzip_file(zip_path, data_dir)
    os.rename(extracted_dir, raw_dir)
    print(f"Downloaded {dataset_name}")


def unzip_file(zip_path, extract_to="."):
    """
    Unpacks every member of a zip archive into a target directory.

    Args:
        zip_path (str): Path of the ``.zip`` file to read.
        extract_to (str): Directory that receives the extracted members
            (default: current directory).

    Prints a confirmation line once extraction has finished.
    """
    with zipfile.ZipFile(zip_path, mode='r') as archive:
        archive.extractall(path=extract_to)
    print(f"Extracted {zip_path} to {extract_to}")


def create_dataset_from_files(root, dataset_name):
    """
    Converts raw TU Dataset format files into PyTorch Geometric Data objects.

    Args:
        root (str): Path to the directory containing raw dataset files
        dataset_name (str): Name of the dataset (used to construct filenames)

    Returns:
        list: List of PyTorch Geometric Data objects, one per graph. Each has
            ``x`` (one-hot node labels, shape [num_nodes, num_node_classes]),
            ``edge_index`` (shape [2, num_edges], graph-local node ids),
            ``y`` (shape [1], class id) and ``original_index`` (0-based graph position).

    Process:
        1. Reads four core files:
           - {dataset_name}_A.txt: Edge list (node pairs)
           - {dataset_name}_graph_indicator.txt: Which graph each node belongs to
           - {dataset_name}_graph_labels.txt: Graph-level labels for classification
           - {dataset_name}_node_labels.txt: Node-level features/labels
        2. For each graph in the dataset:
           - Identifies nodes belonging to that graph
           - Maps global node indices to local indices
           - Extracts edges within the graph
           - Converts node labels to one-hot encoded features
           - Creates a Data object with x (features), edge_index, and y (label)
        3. Maps graph labels to class ids: labels drawn from {-1, 1} become
           {0, 1}; any other label set is mapped to 0..K-1 in sorted order
           (e.g. {1, 2} -> {0, 1}).

    Note: Assumes TU Dataset standard format with 1-indexed graphs and nodes,
    and that the first column of the node-label file holds non-negative integers.
    """
    # Load raw files
    edges = pd.read_csv(os.path.join(root, f"{dataset_name}_A.txt"), header=None, sep=",")
    graph_indicator = pd.read_csv(os.path.join(root, f"{dataset_name}_graph_indicator.txt"), header=None)
    graph_labels = pd.read_csv(os.path.join(root, f"{dataset_name}_graph_labels.txt"), header=None)
    node_labels = pd.read_csv(os.path.join(root, f"{dataset_name}_node_labels.txt"), header=None)

    N = graph_labels.shape[0]  # Number of graphs in the dataset

    # Width of the one-hot encoding is a dataset-wide constant: compute it once.
    num_classes = int(node_labels[0].max()) + 1

    # Build the raw-label -> class-id lookup once. The {-1, 1} case keeps the
    # historical (y + 1) // 2 mapping; other label sets (which that formula
    # would collapse, e.g. {1, 2} -> {1, 1}) are ranked in sorted order.
    raw_label_values = sorted(int(v) for v in graph_labels[0].unique().tolist())
    if set(raw_label_values) <= {-1, 1}:
        label_to_class = {raw: (raw + 1) // 2 for raw in raw_label_values}
    else:
        label_to_class = {raw: cls for cls, raw in enumerate(raw_label_values)}

    data_list = []

    # Process each graph (graph ids in the indicator file start from 1)
    for c, i in enumerate(range(1, N + 1)):

        # Row positions (0-based global node ids) of the nodes in this graph
        node_indices = graph_indicator[graph_indicator[0] == i].index

        # Map global node indices to local indices
        node_idx_map = {idx: j for j, idx in enumerate(node_indices)}
        # Edge file is 1-indexed, hence the +1 / -1 shifts
        graph_edges = edges[edges[0].isin(node_indices + 1) & edges[1].isin(node_indices + 1)]
        graph_edges = graph_edges.apply(lambda col: col.map(lambda x: node_idx_map[x - 1]))
        edge_index = torch.tensor(graph_edges.values, dtype=torch.long).t().contiguous()

        # Select the label column explicitly so the result is always 1-D;
        # squeezing an (n, 1) array turned single-node graphs into scalars.
        node_label_values = torch.tensor(node_labels.iloc[node_indices, 0].values)
        x = F.one_hot(node_label_values.long(), num_classes=num_classes).float()

        raw_label = int(graph_labels.iloc[i - 1, 0])
        y = torch.tensor([label_to_class[raw_label]], dtype=torch.long)

        # Create a Data object for the current graph
        data = Data(x=x, edge_index=edge_index, y=y)
        data.original_index = c
        data_list.append(data)

    return data_list


def split_dataset(data_list, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
    """
    Partitions whole graphs into training, validation and test subsets.

    Args:
        data_list (list): Graphs (PyTorch Geometric Data objects) to partition
        train_ratio (float): Share assigned to training (default: 0.7)
        val_ratio (float): Share assigned to validation (default: 0.15)
        test_ratio (float): Share assigned to testing (default: 0.15)
        seed (int): Seed for the shuffle, making the split repeatable

    Returns:
        tuple: (train_data, val_data, test_data), each a list of graphs

    Note: The module-level ``random`` generator is reseeded with ``seed``.
    Train and validation sizes are rounded down; the test split receives
    whatever remains.
    """
    ratio_total = train_ratio + val_ratio + test_ratio
    assert abs(ratio_total - 1.0) < 1e-6, "Ratios must sum to 1"

    # Deterministic permutation of graph positions
    random.seed(seed)
    num_graphs = len(data_list)
    order = list(range(num_graphs))
    random.shuffle(order)

    # Boundaries of the three consecutive slices
    n_train = int(train_ratio * num_graphs)
    n_val = int(val_ratio * num_graphs)
    val_stop = n_train + n_val

    shuffled = [data_list[position] for position in order]
    return shuffled[:n_train], shuffled[n_train:val_stop], shuffled[val_stop:]


def generate_subgraph_graph_clf(data, subgraph_size=5,seed=0):
    """
    Samples a random node-induced subgraph of one graph, for use in watermarking.

    Args:
        data (Data): Graph to sample from (must carry ``original_index``)
        subgraph_size (int): Number of nodes to keep; capped at the graph's node count
        seed (int): Seed applied to both ``torch`` and ``random`` before sampling

    Returns:
        tuple: (data_sub, subgraph_signature, subgraph_node_idx)
            - data_sub: Data object holding the sampled subgraph (edges relabelled)
            - subgraph_signature: '<original_index>_<node>_<node>_...' identifier
            - subgraph_node_idx: Tensor with the sampled node ids of the source graph

    Note: The subgraph carries the label ``y`` of the graph it was taken from.
    The global ``torch`` and ``random`` generators are reseeded as a side effect.
    """
    graph = data.clone()

    torch.manual_seed(seed)
    random.seed(seed)

    num_nodes = graph.x.shape[0]
    keep_count = min(subgraph_size, num_nodes)
    chosen_nodes = random.sample(list(range(num_nodes)), keep_count)
    subgraph_node_idx = torch.tensor(chosen_nodes)

    induced_edges, _ = subgraph(subgraph_node_idx, graph.edge_index, relabel_nodes=True, num_nodes=num_nodes)

    if graph.x is not None:
        sub_features = graph.x[subgraph_node_idx]
    else:
        sub_features = None
    data_sub = Data(x=sub_features, edge_index=induced_edges, y=graph.y)

    node_part = '_'.join(str(node) for node in subgraph_node_idx.tolist())
    subgraph_signature = str(graph.original_index) + '_' + node_part
    return data_sub, subgraph_signature, subgraph_node_idx



def build_subgraph_collections(train_data, num_collections, num_subgraphs_per_collection, subgraph_size, proportion_features_to_watermark, aggregate_method='flatten',seed=0):
    """
    Creates collections of subgraphs from multiple training graphs for coordinated watermarking.

    Args:
        train_data (list): List of training Data objects
        num_collections (int): Number of watermark collections to create
        num_subgraphs_per_collection (int): How many subgraphs per collection
        subgraph_size (int): Number of nodes in each subgraph
        proportion_features_to_watermark (float): Fraction of features to watermark (0.0-1.0)
        aggregate_method (str): How to combine node features ('flatten', 'average', 'sum')
        seed (int): Random seed for reproducible collection building

    Returns:
        tuple: (subgraph_collection_dict, most_represented_indices, graphs_used)
            - subgraph_collection_dict: {collection_id: {'subgraphs': {signature: Data},
              'feature_matrix': Tensor, 'watermark': ...}}
            - most_represented_indices: Indices of most frequently non-zero features
            - graphs_used: De-duplicated list of ``original_index`` values of the graphs used

    Note: Returns list of used graphs so they can be excluded from training.

    Graphs are drawn from a dedicated ``random.Random(seed)`` instance.
    ``generate_subgraph_graph_clf`` reseeds the module-level ``random`` generator
    on every call, so drawing from that generator made every collection after the
    first depend only on the node count of the previously processed graph, which
    frequently produced duplicated collections. The first collection is unchanged.
    """
    assert aggregate_method in ['flatten','average','sum']
    random.seed(seed)
    graph_picker = random.Random(seed)  # independent of the reseeding done per subgraph

    # One feature vector per subgraph, built from its node-feature matrix.
    aggregators = {
        'flatten': lambda sub: sub.x.flatten().tolist(),          # concatenate all node rows
        'average': lambda sub: torch.mean(sub.x, dim=0).tolist(),  # mean over nodes
        'sum':     lambda sub: torch.sum(sub.x, dim=0).tolist(),   # sum over nodes
    }
    aggregate = aggregators[aggregate_method]

    subgraph_collection_dict = {i:{'subgraphs':{},'feature_matrix':None} for i in range(num_collections)}
    graphs_used = []
    for i in range(num_collections):
        graph_indices = graph_picker.sample(list(range(len(train_data))),num_subgraphs_per_collection)
        graphs_used.extend(train_data[idx].original_index for idx in graph_indices)
        graph_features = []
        for idx in graph_indices:
            graph = train_data[idx]
            this_subgraph, subgraph_signature, _ =  generate_subgraph_graph_clf(graph, subgraph_size=subgraph_size,seed=seed)
            subgraph_collection_dict[i]['subgraphs'][subgraph_signature]=this_subgraph
            graph_features.append(aggregate(this_subgraph))
        subgraph_collection_dict[i]['feature_matrix']=torch.tensor(graph_features)

    # Rank feature columns by how often they are non-zero across all collections
    subgraph_collection_features_concat = torch.vstack([subgraph_collection_dict[k]['feature_matrix'] for k in subgraph_collection_dict.keys()])
    nonzero_feat_mask = subgraph_collection_features_concat!=0
    nonzero_feat_counts = torch.sum(nonzero_feat_mask,dim=0)
    sorted_indices = torch.argsort(nonzero_feat_counts, descending=True)
    len_watermark = int(np.floor(proportion_features_to_watermark*len(sorted_indices)))
    most_represented_indices = sorted_indices[:len_watermark]

    # One watermark per collection, placed at the most represented feature positions
    watermarks = create_watermarks_at_most_represented_indices(num_collections, len_watermark, subgraph_collection_features_concat.shape[1], most_represented_indices, seed)
    for (k,wmk) in zip(subgraph_collection_dict.keys(),watermarks):
        subgraph_collection_dict[k]['watermark']=wmk
    return subgraph_collection_dict, most_represented_indices, list(set(graphs_used))


def regress_on_subgraph_collections(model, subgraph_collection_dict, collection_id, mode='train'):
    """
    Regresses one collection's feature matrix on the model outputs for its subgraphs.

    Mirrors the per-collection step of the training loop in ``train_with_watermark``:
    the model is evaluated on every subgraph of the collection, the outputs are
    stacked row-wise, and a single regression is solved against the collection's
    ``feature_matrix``.

    Args:
        model: Graph classifier called as ``model(x, edge_index)``
        subgraph_collection_dict (dict): Output of ``build_subgraph_collections``
        collection_id: Key of the collection to process
        mode (str): Kept for signature compatibility only; it is not forwarded,
            because ``graph_clf_model.forward`` takes ``(x, edge_index, batch)``.

    Returns:
        list: One-element list holding the regression coefficients for the collection.

    NOTE(review): the previous body iterated over the collection's top-level keys
    ('subgraphs', 'feature_matrix', ...) and used them as subgraph signatures, so
    it always raised KeyError. The intended return shape could not be confirmed
    from any caller — verify before relying on this helper.
    """
    collection = subgraph_collection_dict[collection_id]
    x_sub = collection['feature_matrix']
    y_sub = torch.vstack([model(sub.x, sub.edge_index) for sub in collection['subgraphs'].values()])
    beta = solve_regression(x_sub, y_sub)
    return [beta]

class graph_clf_model(torch.nn.Module):
    """
    Stack of graph convolutions followed by mean pooling, for graph classification.

    Args:
        num_node_features (int): Input feature dimension per node
        num_classes (int): Number of output classes
        num_layers (int): Total number of convolution layers; values below 2
            still produce an input and an output layer
        hidden_channels (int): Width of the hidden layers
        dropout (float): Dropout probability applied after each non-final layer
        conv_fn (callable): ``conv_fn(in_channels, out_channels)`` returning a layer
            that is called as ``layer(x, edge_index)``
    """
    def __init__(self, num_node_features, num_classes, num_layers, hidden_channels, dropout, conv_fn):
        super(graph_clf_model, self).__init__()
        self.layers = ModuleList([conv_fn(num_node_features, hidden_channels)])
        # Build each hidden layer separately: `[layer] * k` repeats ONE module
        # instance, which silently ties the weights of all hidden layers.
        self.layers += [conv_fn(hidden_channels,hidden_channels) for _ in range(num_layers-2)]
        self.layers += [conv_fn(hidden_channels,num_classes)]
        self.dropout= dropout

    def forward(self, x, edge_index, batch=None):
        """
        Returns per-graph log-probabilities of shape [num_graphs, num_classes].

        ``batch`` is handed to ``global_mean_pool`` to assign nodes to graphs;
        the watermark code calls the model on a single subgraph with ``batch=None``.
        """
        for layer in self.layers[:-1]:
            x = layer(x, edge_index)
            x = F.relu(x)
            # Respect train/eval mode; a freshly built nn.Dropout is always in
            # training mode and therefore kept dropping units during evaluation.
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.layers[-1](x, edge_index)
        x = global_mean_pool(x, batch)
        return F.log_softmax(x, dim=1)

def collect_random_subgraphs_graph_clf(dataset, subgraph_size, num_graphs):
    """
    Draws random node-induced subgraphs used to build the null distribution
    for significance testing.

    Args:
        dataset (list): Graphs to sample from (with replacement)
        subgraph_size (int): Nodes per subgraph; capped at the source graph's node count
        num_graphs (int): Number of subgraphs to draw

    Returns:
        list: ``num_graphs`` Data objects, each carrying the label of its source graph

    Note: Uses the module-level ``random`` generator without reseeding it.
    """
    def _draw_one():
        # Pick a source graph, then a node subset of it
        source = dataset[random.choice(range(len(dataset)))]
        total_nodes = source.x.shape[0]
        kept_nodes = random.sample(list(range(total_nodes)), min(subgraph_size, total_nodes))
        node_idx = torch.tensor(kept_nodes)
        induced_edges, _ = subgraph(node_idx, source.edge_index, relabel_nodes=True, num_nodes=total_nodes)
        features = source.x[node_idx] if source.x is not None else None
        return Data(x=features, edge_index=induced_edges, y=source.y)

    return [_draw_one() for _ in range(num_graphs)]



def train_with_watermark(dataset_name,
                         num_node_features,
                         num_classes,
                         train_ratio,
                         val_ratio,
                         test_ratio,
                         num_layers = 4,
                         lr  = 0.001,
                         epochs  = 4000,
                         hidden_channels  = 64,
                         dropout  = 0,
                         proportion_features_to_watermark = 1,
                         num_subgraphs_per_collection = 5,
                         subgraph_size = 10,
                         num_collections = 5,
                         epsilon = 0.01,
                         regression_lambda = 0.01,
                         coefWmk  = 80,
                         batch_size = 10,
                         use_pcgrad  = True,
                         conv_fn  = GraphConv,
                         aggregate_method = 'average',
                         num_iter = 1000,
                         seed = 0):
    
    """
    Complete training pipeline for graph classification with integrated watermarking.
    
    Args:
        dataset_name (str): Name of the dataset; raw files are read from
            ``../data/{dataset_name}_raw``
        num_node_features (int): Dimension of node feature vectors
        num_classes (int): Number of classes for graph classification
        train_ratio, val_ratio, test_ratio (float): Data split ratios
        num_layers, hidden_channels, dropout, conv_fn: Model architecture
        lr, epochs, batch_size: Optimisation settings (one optimizer step per epoch,
            on the loss summed over all training batches)
        proportion_features_to_watermark, num_collections,
        num_subgraphs_per_collection, subgraph_size, aggregate_method:
            Watermark collection construction (see ``build_subgraph_collections``)
        epsilon, coefWmk: Hinge margin and weight of the watermark loss
        regression_lambda: Passed to ``solve_regression``
        use_pcgrad (bool): Wrap the Adam optimizer in ``PCGrad``
        num_iter (int): Number of random draws for the null distribution
        seed (int): Seed for model initialisation and watermark construction
    
    Returns:
        tuple: (model, subgraph_collection_dict, primary_loss_curve, watermark_loss_curve, mu, sig, p)
            - model: Trained graph classification model
            - subgraph_collection_dict: Watermark collections used
            - primary/watermark_loss_curve: Training loss histories
            - mu, sig: Mean / std of match counts on random subgraphs (null distribution)
            - p: One-sided p-value of the observed match count under Normal(mu, sig)
    
    Process:
        1. Dataset Loading: Creates dataset from raw files and splits into train/val/test
        2. Model Setup: Initializes graph classification model with specified architecture
        3. Watermark Preparation: Builds subgraph collections and removes used graphs from training
        4. Training Loop: For each epoch:
           - Forward pass on training data (classification loss)
           - Forward pass on watermark collections, regression, watermark loss
           - Validation accuracy computation
           - Combined loss backpropagation
        5. Testing: Final model evaluation on the held-out test split
        6. Statistical Validation: null distribution from random subgraphs, mu/sigma, p-value
        7. Persistence: Pickles match counts, model and last train/val accuracy under
           ``../{dataset_name}_results/<hyperparameter string>/seed{seed}``
    """
    
    assert aggregate_method in ['flatten','average','sum']
    if aggregate_method=='flatten':
        get_x = lambda sub: sub.x.flatten().tolist()
    elif aggregate_method=='average':
        get_x = lambda sub: torch.mean(sub.x, dim=0).tolist()
    elif aggregate_method=='sum':
        get_x = lambda sub: torch.sum(sub.x, dim=0).tolist()

    data_list = create_dataset_from_files(f"../data/{dataset_name}_raw", dataset_name)
    train_data, val_data, test_data = split_dataset(data_list, train_ratio=train_ratio, val_ratio=val_ratio, test_ratio=test_ratio)
    # Evaluate on the held-out splits; both loaders previously wrapped train_data,
    # so the reported validation/test accuracy was really training accuracy.
    val_loader = DataLoader(val_data, batch_size=500, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=500, shuffle=True)

    # Seed torch before the model is built so weight initialisation is reproducible.
    torch.manual_seed(seed)
    model =  graph_clf_model(num_node_features, num_classes, num_layers, hidden_channels, dropout, conv_fn)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    if use_pcgrad==True:
        # NOTE(review): below the combined loss is back-propagated with loss.backward();
        # if this PCGrad wrapper only applies gradient surgery through a dedicated
        # method (e.g. pc_backward), wrapping has no effect — confirm against pcgrad/pcgrad.py.
        optimizer = PCGrad(optimizer)

    subgraph_collection_dict, most_represented_indices, graphs_used = build_subgraph_collections(train_data, num_collections, num_subgraphs_per_collection, subgraph_size, proportion_features_to_watermark, aggregate_method, seed)
    # Graphs that carry watermark subgraphs are excluded from the classification objective
    train_data = [d for d in train_data if d.original_index not in graphs_used]
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    primary_loss_curve, watermark_loss_curve = [],[]
    observed_matches = []
    for e in range(epochs):
        optimizer.zero_grad()

        loss_primary, loss_watermark = torch.tensor(0.0), torch.tensor(0.0)
        train_accuracies, val_accuracies = torch.tensor([]), torch.tensor([])
        raw_betas = []
        watermark_alignment_rates = []
        model.train()
        for batch in train_loader:
            log_logits_train  = model(batch.x, batch.edge_index, batch=batch.batch)
            loss_primary  += F.nll_loss(log_logits_train, batch.y)
            train_accuracies = torch.cat((train_accuracies, accuracy(log_logits_train, batch.y).unsqueeze(dim=0)), dim=0)

        for k in subgraph_collection_dict.keys():
            this_collection = subgraph_collection_dict[k]
            this_watermark = this_collection['watermark']
            x_sub = this_collection['feature_matrix']
            y_sub = []
            for subgraph_ in this_collection['subgraphs'].values():
                y_sub.append(model(subgraph_.x, subgraph_.edge_index))
            y_sub = torch.vstack(y_sub)
            omit_indices,not_omit_indices = get_omit_indices(x_sub, this_watermark,ignore_zeros_from_subgraphs=False)
            this_raw_beta = solve_regression(x_sub, y_sub, regression_lambda)
            raw_betas.append(this_raw_beta)
            beta          = process_beta(this_raw_beta, omit_indices)
            B_x_W = (beta*this_watermark).clone()
            B_x_W = B_x_W[not_omit_indices]
            sign_beta_most_rep = torch.sign(beta)[most_represented_indices]
            wmk_most_rep = this_watermark[most_represented_indices]
            watermark_alignment_rates.append(torch.sum(torch.eq(sign_beta_most_rep, wmk_most_rep)).item()/len(most_represented_indices))
            beta_weights = torch.ones(len(not_omit_indices))
            # Hinge penalty: positive wherever beta*watermark falls below epsilon
            loss_watermark += torch.mean(torch.clamp(epsilon-B_x_W, min=0)*beta_weights)

        model.eval()
        with torch.no_grad():  # validation only reports accuracy; no graph needed
            for batch in val_loader:
                log_logits_val  = model(batch.x, batch.edge_index, batch=batch.batch)
                val_accuracies = torch.cat((val_accuracies, accuracy(log_logits_val, batch.y).unsqueeze(dim=0)), dim=0)

        stacked_sign_betas= torch.sign(torch.vstack(raw_betas))
        match_count_with_zeros = count_matches(stacked_sign_betas, ignore_zeros=False)
        match_count_without_zeros = count_matches(stacked_sign_betas, ignore_zeros=True)

        observed_matches.append(match_count_without_zeros)

        acc_trn = torch.mean(train_accuracies).item()
        acc_val = torch.mean(val_accuracies).item()
        avg_wmk_alignment = np.mean(watermark_alignment_rates)
        loss = loss_primary + coefWmk*loss_watermark
        primary_loss_curve.append(loss_primary.item())
        watermark_loss_curve.append(coefWmk*loss_watermark.item())
        loss.backward()
        optimizer.step()
        if e%10==0:
            epoch_printout = f'Epoch: {e:3d}, L (clf/wmk) = {loss_primary:.3f}/{coefWmk*loss_watermark:.3f}, acc (trn/val)= {acc_trn:.3f}/{acc_val:.3f}, #_match_WMK w/wout 0s = {match_count_with_zeros}/{match_count_without_zeros}, alignment={avg_wmk_alignment:.3f}'
            print(epoch_printout)

    # Final evaluation on the test split
    test_accuracies = torch.tensor([])
    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            log_logits_test  = model(batch.x, batch.edge_index, batch=batch.batch)
            test_accuracies = torch.cat((test_accuracies, accuracy(log_logits_test, batch.y).unsqueeze(dim=0)), dim=0)
    epoch_printout = f'Epoch: {e:3d}, L (clf/wmk) = {loss_primary:.3f}/{coefWmk*loss_watermark:.3f}, acc (trn/val)= {acc_trn:.3f}/{acc_val:.3f}, #_match_WMK w/wout 0s = {match_count_with_zeros}/{match_count_without_zeros}, alignment={avg_wmk_alignment:.3f}'
    print(epoch_printout)
    print(f'Test accuracy: {torch.mean(test_accuracies).item():.3f}')

    # Observed statistic: mean match count over the last five epochs
    final_observed_match_count = np.mean(observed_matches[-5:])

    # Null distribution: match counts obtained from random (non-watermarked) subgraphs.
    # data_list is reused; it has not been modified since it was loaded above.
    random_subgraph_list = collect_random_subgraphs_graph_clf(data_list, subgraph_size, 200)
    all_match_counts = []
    with torch.no_grad():
        for i in range(num_iter):
            print(f'{i}/{num_iter}',end='\r')
            betas = []
            for c in range(num_collections):
                subgraph_choices = random.sample(range(len(random_subgraph_list)),num_subgraphs_per_collection)
                subgraphs_ = [random_subgraph_list[idx] for idx in subgraph_choices]
                y_subs = torch.vstack([model(s.x, s.edge_index) for s in subgraphs_])
                x_subs = torch.tensor([get_x(s) for s in subgraphs_])
                this_raw_beta = solve_regression(x_subs, y_subs, regression_lambda)
                betas.append(torch.sign(this_raw_beta))
            betas = torch.vstack(betas)
            all_match_counts.append(count_matches(betas,ignore_zeros=True))
    mu = np.mean(all_match_counts)
    sig = np.std(all_match_counts)
    print('observed_matches[-5:]:',observed_matches[-5:])
    print('average final observed match:',final_observed_match_count)
    print(f'Naturally-occurring match counts (mu,sig): ({mu},{sig})')
    z = (final_observed_match_count-mu)/sig
    # `stats` is the explicitly imported scipy.stats; the bare name `scipy` is
    # not imported by this notebook's import cell.
    p = stats.norm.sf(z)
    print('p_value:',p)
    

    root_ = f'../{dataset_name}_results/{num_collections}_collections_{num_subgraphs_per_collection}_subgraphs_size_{subgraph_size}_coefWmk{coefWmk}_numLayers{num_layers}_hiddenDim{hidden_channels}_epsilon{epsilon}_dropout{dropout}_lr{lr}'
    root_ += f'/seed{seed}'
    # Creates missing parents too; os.mkdir + bare except silently failed when
    # '../{dataset_name}_results' did not exist, and the open() below then crashed.
    os.makedirs(root_, exist_ok=True)
    with open(os.path.join(root_,'all_match_counts.pkl'),'wb') as f:
        pickle.dump(all_match_counts,f)
    with open(os.path.join(root_,f'model.pkl'),'wb') as f:
        pickle.dump(model,f)
    with open(os.path.join(root_,f'acc_trn.pkl'),'wb') as f:
        pickle.dump(acc_trn,f)
    with open(os.path.join(root_,f'acc_val.pkl'),'wb') as f:
        pickle.dump(acc_val,f)
    return model, subgraph_collection_dict, primary_loss_curve, watermark_loss_curve, mu, sig, p




In [None]:
# First, download the dataset. Many datasets are available here: https://chrsmrrs.github.io/datasets/docs/datasets/
# The dataset must be in the raw-file format expected by create_dataset_from_files; the MUTAG example below shows what that looks like.


dataset_name='MUTAG'


url = f"https://www.chrsmrrs.com/graphkerneldatasets/{dataset_name}.zip"
download_file(url, f"{dataset_name}.zip")

root = f"../data/{dataset_name}_raw"  # Path to raw files

# Defaults; subgraph_size and num_collections are overridden by the sweep below.
num_collections = 5
num_subgraphs_per_collection = 5
subgraph_size = 10
train_ratio, test_ratio, val_ratio = 0.7, 0.15, 0.15
num_node_features = 7 #You may need to look this up depending on the dataset you're using. MUTAG has 7 node features.
num_classes=2


# Sweep over subgraph size, number of collections and seed; one training run and one figure per setting.
for subgraph_size in [10,20]:
    for num_collections in [2,3,4,5,6]:
        for seed in range(5):
            random.seed(seed)
            model, subgraph_collection_dict, primary_loss_curve, watermark_loss_curve, mu, sig, p = train_with_watermark(dataset_name, num_node_features, num_classes, train_ratio, val_ratio, test_ratio, subgraph_size=subgraph_size, num_collections=num_collections, seed=seed)
            # Right now using default for other values
            fig, ax = plt.subplots()
            # Label both curves so they can be told apart in the figure
            ax.plot(primary_loss_curve, label='classification loss')
            ax.plot(watermark_loss_curve, label='watermark loss (x coefWmk)')
            ax.set_xlabel('epoch')
            ax.set_ylabel('loss')
            ax.set_title(f'{num_collections} collections of {num_subgraphs_per_collection} subgraphs, size {subgraph_size}\nseed={seed}\nmu={mu},sig={sig},p={p}')
            ax.legend()
            plt.show()
            # Release the figure; the sweep creates 50 of them at 250 dpi
            plt.close(fig)

