# Graph Augmentations

## Library Import

In [None]:
import os
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import rootutils
import torch
import torch_geometric
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops, degree, to_networkx, to_undirected
from torch_sparse import SparseTensor

rootutils.setup_root(os.getcwd(), indicator=".project-root", pythonpath=True)

from src.utils.graph_augmentations_domain import (
    get_graph_augmentation,
    remove_directed_edges,
)

np.random.seed(44)
torch.manual_seed(44)
torch.cuda.manual_seed(44)

## Helper Functions

In [None]:
def visualize_graph(graph, color):
    """Draw ``graph`` with networkx using a reproducible spring layout.

    Args:
        graph: Graph object accepted by ``to_networkx``; it is converted with
            ``to_undirected=False``, so edge directions are kept.
        color: Forwarded to ``nx.draw_networkx`` as ``node_color``.
    """
    nx_graph = to_networkx(graph, to_undirected=False)
    # A fixed layout seed keeps node placement identical between calls.
    layout = nx.spring_layout(nx_graph, seed=42)

    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(nx_graph, pos=layout, with_labels=True, node_color=color, cmap="Set2")
    plt.show()

In [None]:
def create_domain_graph(
    graph_name: str,
    num_nodes: int,
    num_features: int,
    num_neighbors: int,
    num_classes: int,
    seed: int = 44,
    noise_scale: float = 0.05,
    position_noise_scale: float = 0.2,
) -> Data:
    """
    Create a graph like the domain identification graphs.

    Every node is assigned to one of ``num_classes`` groups. Its features and its 2D position
    are the group's random centre plus Gaussian noise. Edges come from a k-nearest-neighbour
    query in position space, are symmetrised, and all carry weight 1.

    Args:
        graph_name: Stored on the returned graph as ``sample_name``.
        num_nodes: Number of nodes to generate.
        num_features: Number of feature columns in ``x``.
        num_neighbors: ``n_neighbors`` of the k-NN query. A node's own index is discarded
            whenever it shows up among its neighbours.
        num_classes: Number of groups that nodes are drawn from.
        seed: Seed applied to NumPy and torch (CPU and CUDA) before sampling.
        noise_scale: Standard deviation multiplier of the feature noise.
        position_noise_scale: Standard deviation multiplier of the position noise.

    Returns:
        A ``Data`` object with ``x``, ``edge_index``, ``edge_weight``, ``position`` and
        ``sample_name``.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    group_features = torch.rand((num_classes, num_features), dtype=torch.float)
    group_assignments = torch.randint(0, num_classes, (num_nodes,))
    x = group_features[group_assignments] + noise_scale * torch.randn((num_nodes, num_features))

    group_positions = torch.rand((num_classes, 2), dtype=torch.float)
    positions = group_positions[group_assignments] + position_noise_scale * torch.randn(
        (num_nodes, 2)
    )

    nbrs = NearestNeighbors(n_neighbors=num_neighbors, algorithm="ball_tree").fit(positions)
    # Only the neighbour indices are needed; distances are not used for this graph type.
    indices = nbrs.kneighbors(positions, return_distance=False)

    # Build the (source, target) pairs in one shot instead of a Python loop over every
    # node/neighbour combination. Row-major flattening keeps the original edge order.
    sources = np.repeat(np.arange(indices.shape[0]), indices.shape[1])
    targets = indices.reshape(-1)
    keep = sources != targets  # drop each node's match with itself
    # Stacking keeps the (2, E) shape even when no edge survives the filter.
    edge_index = torch.from_numpy(np.stack([sources[keep], targets[keep]])).long()

    edge_index = torch_geometric.utils.to_undirected(edge_index)
    edge_weight = torch.ones(edge_index.size(1), dtype=torch.float)

    graph = Data(x=x, edge_index=edge_index, edge_weight=edge_weight, position=positions)
    graph.sample_name = graph_name

    assert graph.is_undirected(), "Graph is not undirected!"
    assert not graph.has_self_loops(), "Graph has self-loops!"

    return graph

In [None]:
import time

import pandas as pd
import torch
from torch_geometric.transforms import Compose
from tqdm import tqdm

from src.utils.graph_augmentations_domain import (
    AddEdgesByFeatureSimilarity,
    DropEdges,
    DropFeatures,
    DropImportance,
    FeatureNoise,
    ShufflePositions,
    SpatialNoise,
)

# from copy import deepcopy



# --- Configuration ---
graph_sizes = [100, 500, 1000, 5000, 10000, 50000, 100000, 150000]
num_features = 50
num_neighbors = 10
num_classes = 10
num_runs = 3
seed = 44

# Resolve the device once; all CUDA-only bookkeeping below is skipped on CPU.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

augmentations = {
    "DropFeatures": DropFeatures(p=0.2),
    "DropEdges": DropEdges(p=0.2),
    "DropImportance": DropImportance(mu=0.2, p_lambda=0.5),
    "SpatialNoise": SpatialNoise(spatial_noise_std=10),
    "FeatureNoise": FeatureNoise(feature_noise_std=1),
    "ShufflePositions": ShufflePositions(p_shuffle=0.2),
    "AddEdgesByFeatureSimilarity": AddEdgesByFeatureSimilarity(p_add=0.2, k_add=2),
}

# Each entry is the ordered list of transforms that is composed and timed as one unit.
augmentation_combos = {
    "DropFeatures": [augmentations["DropFeatures"]],
    "DropEdges": [augmentations["DropEdges"]],
    "SpatialNoise": [augmentations["SpatialNoise"]],
    "FeatureNoise": [augmentations["FeatureNoise"]],
    "ShufflePositions": [augmentations["ShufflePositions"]],
    "AddEdgesByFeatureSimilarity": [augmentations["AddEdgesByFeatureSimilarity"]],
    "Baseline": [augmentations["DropEdges"], augmentations["DropFeatures"]],
    "Baseline + SpatialNoise + FeatureNoise": [
        augmentations["DropEdges"],
        augmentations["DropFeatures"],
        augmentations["SpatialNoise"],
        augmentations["FeatureNoise"],
    ],
    "DropImportance": [augmentations["DropImportance"]],
    "DropImportance + SpatialNoise + FeatureNoise": [
        augmentations["DropImportance"],
        augmentations["SpatialNoise"],
        augmentations["FeatureNoise"],
    ],
    "DropImportance + SpatialNoise + FeatureNoise + ShufflePositions": [
        augmentations["DropImportance"],
        augmentations["SpatialNoise"],
        augmentations["FeatureNoise"],
        augmentations["ShufflePositions"],
    ],
    "DropImportance + SpatialNoise + FeatureNoise + AddEdgesByFeatureSimilarity": [
        augmentations["DropImportance"],
        augmentations["SpatialNoise"],
        augmentations["FeatureNoise"],
        augmentations["AddEdgesByFeatureSimilarity"],
    ],
}

results = []

# --- Benchmarking Loop ---
for size in tqdm(graph_sizes, desc="Graph sizes"):
    graph = create_domain_graph(
        graph_name=f"graph_{size}",
        num_nodes=size,
        num_features=num_features,
        num_neighbors=num_neighbors,
        num_classes=num_classes,
        seed=seed,
    ).to(device)

    for name, transforms in augmentation_combos.items():
        compose = Compose(transforms)

        times = []
        memory_usages = []

        for _ in range(num_runs):
            # Every run gets its own copy of the input graph; the copy is made outside the
            # timed region.
            data = deepcopy(graph)

            if use_cuda:
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats()
                torch.cuda.synchronize()

            # perf_counter is monotonic and high-resolution, unlike the wall clock.
            start = time.perf_counter()
            _ = compose(data)
            if use_cuda:
                # Wait for queued kernels so the measurement covers the actual GPU work.
                torch.cuda.synchronize()
            duration = time.perf_counter() - start
            times.append(duration)

            # Memory in MB. The peak is counted from the reset above, so it includes tensors
            # that were already resident at that point. On CPU no measurement is taken and
            # 0.0 is recorded as a placeholder.
            if use_cuda:
                mem = torch.cuda.max_memory_allocated() / 1024**2
                memory_usages.append(mem)
            else:
                memory_usages.append(0.0)

        results.append(
            {
                "augmentation": name,
                "num_nodes": size,
                "avg_time_s": sum(times) / len(times),
                "max_memory_mb": max(memory_usages),
            }
        )

# --- Results ---
df = pd.DataFrame(results)
pivot_time = df.pivot(index="num_nodes", columns="augmentation", values="avg_time_s")
pivot_mem = df.pivot(index="num_nodes", columns="augmentation", values="max_memory_mb")

print("=== Average Runtime (s) ===")
print(pivot_time.round(4))
print("\n=== Max Memory Usage (MB) ===")
print(pivot_mem.round(2))

# Optional: save results
# df.to_csv("augmentation_benchmark_results.csv", index=False)

In [None]:
# Display the long-format results table (one row per augmentation combo and graph size).
df

In [None]:
def create_phenotype_graph(
    graph_name: str,
    num_nodes: int,
    num_neighbors: int,
    num_classes: int,
    seed: int = 44,
    position_noise_scale: float = 0.2,
) -> Data:
    """
    Create a graph like the phenotype prediction graphs.

    Node features are ``[cell_type, size]`` with a random cell type per node and a size drawn
    uniformly from ``[0.2, 0.8)``. Positions (used only to build edges, not stored on the
    graph) are the cell type's random centre plus Gaussian noise. Edges come from a
    k-nearest-neighbour query on those positions and carry ``[edge_type, distance]``, where
    ``edge_type`` is 0 for distances below 0.5 and 1 otherwise. The edge set is symmetrised
    and self-loops are added.

    Args:
        graph_name: Stored on the returned graph as ``sample_name``.
        num_nodes: Number of nodes to generate.
        num_neighbors: ``n_neighbors`` of the k-NN query. A node's own index is discarded
            whenever it shows up among its neighbours.
        num_classes: Number of distinct cell types.
        seed: Seed applied to NumPy and torch (CPU and CUDA) before sampling.
        position_noise_scale: Standard deviation multiplier of the position noise.

    Returns:
        A ``Data`` object with ``x``, ``edge_index``, ``edge_attr`` and ``sample_name``.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    cell_types = torch.randint(0, num_classes, (num_nodes,))
    sizes = 0.2 + 0.6 * torch.rand(num_nodes)
    x = torch.stack([cell_types.float(), sizes], dim=1)

    group_positions = torch.rand((num_classes, 2))
    position_noise = position_noise_scale * torch.randn((num_nodes, 2))
    positions = group_positions[cell_types] + position_noise

    nbrs = NearestNeighbors(n_neighbors=num_neighbors, algorithm="ball_tree").fit(positions)
    distances, indices = nbrs.kneighbors(positions)

    # Build all (source, target, distance) triples at once instead of a Python double loop.
    # Row-major flattening keeps the original edge order.
    sources = np.repeat(np.arange(indices.shape[0]), indices.shape[1])
    targets = indices.reshape(-1)
    keep = sources != targets  # drop each node's match with itself
    kept_dist = distances.reshape(-1)[keep]

    # Stacking keeps the (2, E) / (E, 2) shapes even when no edge survives the filter.
    edge_index = torch.from_numpy(np.stack([sources[keep], targets[keep]])).long()
    edge_type = np.where(kept_dist < 0.5, 0.0, 1.0)
    edge_attr = torch.from_numpy(np.stack([edge_type, kept_dist], axis=1)).float()

    # Mutual neighbours contribute the same edge twice once the reversed copies are added.
    # The default reduce="add" would sum their attributes (distance doubled, edge type 2),
    # so keep a single copy of the attribute values instead.
    edge_index, edge_attr = to_undirected(edge_index, edge_attr=edge_attr, reduce="min")
    # Self-loop attributes are filled with add_self_loops' default fill value.
    edge_index, edge_attr = add_self_loops(edge_index, edge_attr=edge_attr, num_nodes=num_nodes)

    graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    graph.sample_name = graph_name

    assert graph.is_undirected(), "Graph is not undirected!"

    return graph

In [None]:
import time

import pandas as pd
import torch
from torch_geometric.transforms import Compose
from tqdm import tqdm

from src.utils.graph_augmentations_phenotype import (
    AddEdgesByCellType,
    DropEdges,
    DropFeatures,
    DropImportance,
    FeatureNoise,
    ShufflePositions,
)

# from copy import deepcopy



# --- Configuration ---
graph_sizes = [100, 500, 1000, 5000, 10000, 50000, 100000, 150000]
num_neighbors = 10
num_classes = 30
num_runs = 3
seed = 44

# Resolve the device once; all CUDA-only bookkeeping below is skipped on CPU.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

augmentations = {
    "DropFeatures": DropFeatures(p=0.2),
    "DropEdges": DropEdges(p=0.2),
    "DropImportance": DropImportance(mu=0.2, p_lambda=0.5),
    "FeatureNoise": FeatureNoise(feature_noise_std=1),
    "ShufflePositions": ShufflePositions(p_shuffle=0.2),
    "AddEdgesByCellType": AddEdgesByCellType(p_add=0.2, k_add=2),
}

# Each entry is the ordered list of transforms that is composed and timed as one unit.
augmentation_combos = {
    "DropFeatures": [augmentations["DropFeatures"]],
    "DropEdges": [augmentations["DropEdges"]],
    "FeatureNoise": [augmentations["FeatureNoise"]],
    "ShufflePositions": [augmentations["ShufflePositions"]],
    "AddEdgesByCellType": [augmentations["AddEdgesByCellType"]],
    "Baseline": [augmentations["DropEdges"], augmentations["DropFeatures"]],
    "Baseline + FeatureNoise": [
        augmentations["DropEdges"],
        augmentations["DropFeatures"],
        augmentations["FeatureNoise"],
    ],
    "DropImportance": [augmentations["DropImportance"]],
    "DropImportance + FeatureNoise": [
        augmentations["DropImportance"],
        augmentations["FeatureNoise"],
    ],
    "DropImportance + FeatureNoise + ShufflePositions": [
        augmentations["DropImportance"],
        augmentations["FeatureNoise"],
        augmentations["ShufflePositions"],
    ],
    "DropImportance + FeatureNoise + AddEdgesByCellType": [
        augmentations["DropImportance"],
        augmentations["FeatureNoise"],
        augmentations["AddEdgesByCellType"],
    ],
}

results = []

# --- Benchmarking Loop ---
for size in tqdm(graph_sizes, desc="Graph sizes"):
    graph = create_phenotype_graph(
        graph_name=f"graph_{size}",
        num_nodes=size,
        num_neighbors=num_neighbors,
        num_classes=num_classes,
        seed=seed,
    ).to(device)

    for name, transforms in augmentation_combos.items():
        compose = Compose(transforms)

        times = []
        memory_usages = []

        for _ in range(num_runs):
            # Every run gets its own copy of the input graph; the copy is made outside the
            # timed region.
            data = deepcopy(graph)

            if use_cuda:
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats()
                torch.cuda.synchronize()

            # perf_counter is monotonic and high-resolution, unlike the wall clock.
            start = time.perf_counter()
            _ = compose(data)
            if use_cuda:
                # Wait for queued kernels so the measurement covers the actual GPU work.
                torch.cuda.synchronize()
            duration = time.perf_counter() - start
            times.append(duration)

            # Memory in MB. The peak is counted from the reset above, so it includes tensors
            # that were already resident at that point. On CPU no measurement is taken and
            # 0.0 is recorded as a placeholder.
            if use_cuda:
                mem = torch.cuda.max_memory_allocated() / 1024**2
                memory_usages.append(mem)
            else:
                memory_usages.append(0.0)

        results.append(
            {
                "augmentation": name,
                "num_nodes": size,
                "avg_time_s": sum(times) / len(times),
                "max_memory_mb": max(memory_usages),
            }
        )

# --- Results ---
df = pd.DataFrame(results)
pivot_time = df.pivot(index="num_nodes", columns="augmentation", values="avg_time_s")
pivot_mem = df.pivot(index="num_nodes", columns="augmentation", values="max_memory_mb")

print("=== Average Runtime (s) ===")
print(pivot_time.round(4))
print("\n=== Max Memory Usage (MB) ===")
print(pivot_mem.round(2))

# Optional: save results
# df.to_csv("augmentation_benchmark_results.csv", index=False)

In [None]:
# Display the long-format results table (one row per augmentation combo and graph size).
df

## Example Graph

# Toy graph: six nodes with two features each, placed on a 3x2 grid of positions.
sample_name = "example_graph"
x = torch.tensor([[1, 5], [4, 2], [3, 6], [7, 8], [0, 3], [4, 7]], dtype=torch.float)
position = torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1], [2, 0], [2, 1]], dtype=torch.float)

# Eight directed edges, written as (source, target) pairs and transposed to shape (2, E).
edge_index = (
    torch.tensor(
        [[0, 1], [3, 5], [2, 3], [1, 4], [4, 5], [4, 3], [4, 1], [4, 0]], dtype=torch.long
    )
    .t()
    .contiguous()
)
edge_weight = torch.ones(edge_index.size(1), dtype=torch.float)

graph = Data(
    x=x,
    edge_index=edge_index,
    edge_weight=edge_weight,
    sample_name=sample_name,
    position=position,
)

# Symmetrise the edges, then give every resulting edge weight 1.
undirected_edges = torch_geometric.utils.to_undirected(graph.edge_index)
graph.edge_index = undirected_edges
graph.edge_weight = torch.ones(undirected_edges.size(1), dtype=torch.float)

graph

In [None]:
np.random.seed(44)
torch.manual_seed(44)
torch.cuda.manual_seed(44)

# Size and noise settings of the synthetic example graph.
num_nodes = 200000
num_features = 50
k = 10
num_groups = 5
noise_scale = 0.05
position_noise_scale = 0.2

# Features: each node takes its group's random centre plus Gaussian noise.
group_features = torch.rand((num_groups, num_features), dtype=torch.float)
group_assignments = torch.randint(0, num_groups, (num_nodes,))
x = group_features[group_assignments] + noise_scale * torch.randn((num_nodes, num_features))

# Positions: same scheme in 2D.
group_positions = torch.rand((num_groups, 2), dtype=torch.float)
positions = group_positions[group_assignments] + position_noise_scale * torch.randn((num_nodes, 2))

nbrs = NearestNeighbors(n_neighbors=k, algorithm="ball_tree").fit(positions)
distances, indices = nbrs.kneighbors(positions)

# Build the (source, target) pairs in one shot; a Python loop over num_nodes * k pairs is
# very slow at this size. Row-major flattening keeps the original edge order.
sources = np.repeat(np.arange(indices.shape[0]), indices.shape[1])
targets = indices.reshape(-1)
keep = sources != targets  # drop each node's match with itself
edge_index = torch.from_numpy(np.stack([sources[keep], targets[keep]])).long()
edge_weight = torch.ones(edge_index.size(1), dtype=torch.float)

graph = Data(x=x, edge_index=edge_index, edge_weight=edge_weight, position=positions)

# Symmetrise the edges, then give every resulting edge weight 1.
graph.edge_index = torch_geometric.utils.to_undirected(graph.edge_index)
graph.edge_weight = torch.ones(graph.edge_index.size(1), dtype=torch.float)
graph.sample_name = "example_graph"

graph

## Graph Augmentations

In [None]:
# First augmentation pipeline: "advanced" mode, requesting DropImportance, SpatialNoise and
# AddEdgesByFeatureSimilarity through augmentation_list.
# NOTE(review): hyperparameters of augmentations that are not in the list (e.g. drop_edge_p,
# p_rewire, p_shuffle) are presumably ignored by get_graph_augmentation — confirm.
transform1 = get_graph_augmentation(
    augmentation_mode="advanced",
    augmentation_list=["DropImportance", "SpatialNoise", "AddEdgesByFeatureSimilarity"],
    drop_edge_p=0.3,
    drop_feat_p=0.3,
    mu=0.2,
    p_lambda=0.5,
    p_rewire=0.3,
    p_shuffle=0.1,
    spatial_noise_std=0.01,
    feature_noise_std=0.01,
    p_add=0.2,
    k_add=2,
)
transform1

In [None]:
# Second augmentation pipeline: "advanced" mode with DropImportance as the only requested
# augmentation.
# NOTE(review): the remaining hyperparameters are presumably unused for this list — confirm.
transform2 = get_graph_augmentation(
    augmentation_mode="advanced",
    augmentation_list=["DropImportance"],
    drop_edge_p=0.3,
    drop_feat_p=0.3,
    mu=0.2,
    p_lambda=0.5,
    p_rewire=0.2,
    p_shuffle=0.2,
    spatial_noise_std=0.01,
    feature_noise_std=0.01,
    p_add=0.1,
    k_add=3,
)
transform2

In [None]:
# Apply the first pipeline to the example graph and display the result.
# NOTE(review): whether the transform copies `graph` or mutates it in place is not visible
# here — confirm before reusing `graph` for the second pipeline.
aug1 = transform1(graph)
aug1

In [None]:
# Apply the second pipeline to the same `graph` object and display the result.
aug2 = transform2(graph)
aug2

In [None]:
# Check whether the first augmented graph still has a symmetric edge set.
aug1.is_undirected()

In [None]:
# Check whether the second augmented graph still has a symmetric edge set.
aug2.is_undirected()

In [None]:
# Build a DropImportance + SpatialNoise pipeline ten times, apply it to `graph` each time and
# print the augmented graph so the run-to-run variation can be inspected.
# NOTE(review): this rebinds the notebook-level `transform1` (created in an earlier cell with
# a different augmentation list) and leaves the last result in `data`, which the resampling
# code further down reads. The pipeline arguments are identical in every iteration; if
# get_graph_augmentation is deterministic, it could be built once before the loop — confirm.
for i in range(10):
    transform1 = get_graph_augmentation(
        augmentation_mode="advanced",
        augmentation_list=["DropImportance", "SpatialNoise"],
        drop_edge_p=0.3,
        drop_feat_p=0.3,
        mu=0.3,
        p_lambda=0.3,
        p_rewire=0.5,
        p_shuffle=0.1,
        spatial_noise_std=0.01,
        feature_noise_std=0.01,
        p_add=0.1,
        k_add=3,
    )
    data = transform1(graph)
    print(data)

In [None]:
# Draw both augmented graphs, one figure each.
# NOTE(review): aug1/aug2 come from the example graph built with num_nodes = 200000; drawing
# a graph of that size with networkx is likely impractical — consider a small graph here.
visualize_graph(aug1, color="red")
visualize_graph(aug2, color="green")

## Proportional Resampling

from torch_geometric.utils import subgraph

alpha = 0.0  # weighting factor: 0 = only spatial, 1 = only expression
n_clusters = 5
new_distribution = None

# Standardise features and positions separately, then weight the two views by alpha.
x = data.x.cpu().numpy()
pos = data.position.cpu().numpy()
x_scaled = StandardScaler().fit_transform(x)
pos_scaled = StandardScaler().fit_transform(pos)
combined_features = np.concatenate([alpha * x_scaled, (1 - alpha) * pos_scaled], axis=1)

# Cluster the nodes and remember which node indices belong to each cluster.
kmeans = KMeans(n_clusters=n_clusters, random_state=44).fit(combined_features)
cluster_labels = kmeans.labels_
cluster_to_indices = {
    cluster_id: np.flatnonzero(cluster_labels == cluster_id) for cluster_id in range(n_clusters)
}

# Default target distribution: every cluster contributes the same share of nodes.
if new_distribution is None:
    new_distribution = dict.fromkeys(range(n_clusters), 1 / n_clusters)

new_node_indices = []
for cluster_id, proportion in new_distribution.items():
    candidates = cluster_to_indices[cluster_id]
    n = int(proportion * data.num_nodes)
    # Sample with replacement only when the cluster has fewer members than requested.
    # NOTE(review): in that case the same node index appears several times; the feature rows
    # are duplicated below, but it is unclear how `subgraph` relabels repeated indices —
    # confirm.
    with_replacement = len(candidates) < n
    chosen = np.random.choice(candidates, n, replace=with_replacement)
    new_node_indices.extend(chosen)

new_node_indices = torch.tensor(new_node_indices, dtype=torch.long)

# Keep only edges between sampled nodes and renumber them to the new node order.
new_edge_index, new_edge_weight = subgraph(
    subset=new_node_indices,
    edge_index=data.edge_index,
    edge_attr=data.edge_weight,
    relabel_nodes=True,
)
data = Data(
    x=data.x[new_node_indices],
    edge_index=new_edge_index,
    edge_weight=new_edge_weight,
    position=data.position[new_node_indices],
    sample_name=data.sample_name,
)

In [None]:
from src.utils.graph_augmentations_phenotype import get_graph_augmentation

In [None]:
# Load a processed phenotype sample from disk (presumably a torch_geometric Data object —
# confirm) and display it.
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python objects, so
# only load files that come from a trusted source.
graph = torch.load("../data/phenotype/nsclc/processed/175A_100.0.gpt", weights_only=False)
graph

In [None]:
# Build a ShufflePositions-only pipeline in "advanced" mode. At this point
# get_graph_augmentation is the phenotype variant imported in the preceding cell, which
# shadows the domain variant imported at the top of the notebook.
# NOTE(review): the remaining hyperparameters are presumably unused for this list — confirm.
transform = get_graph_augmentation(
    augmentation_mode="advanced",
    augmentation_list=["ShufflePositions"],
    drop_edge_p=0.3,
    drop_feat_p=0.2,
    mu=0.2,
    p_lambda=0.5,
    p_rewire=0.5,
    feature_noise_std=0.01,
    p_add=0.1,
    k_add=3,
    p_shuffle=0.1,
)
transform

In [None]:
# Apply the pipeline to the loaded graph and display the augmented result.
# NOTE(review): if the transform works in place, `data` and `graph` may share tensors —
# confirm before comparing them.
data = transform(graph)
data

In [None]:
# Display the input graph again to compare it against the augmented `data`.
graph

In [None]:
# Count the edge_index entries that differ between the input and the augmented graph.
# The element-wise comparison needs both edge_index tensors to have compatible shapes.
(graph.edge_index != data.edge_index).sum()

In [None]:
# Display the edge index of the augmented graph.
data.edge_index