# Graph Augmentations

## Library Import

In [None]:
import os
from copy import deepcopy

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import rootutils
import torch
import torch_geometric
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import Data
from torch_geometric.utils import degree, to_networkx

rootutils.setup_root(os.getcwd(), indicator=".project-root", pythonpath=True)

from src.utils.graph_augmentations import get_graph_augmentation

# Fix the seeds of NumPy's and PyTorch's random number generators (including
# the CUDA generator) so the random draws made later in the notebook repeat.
np.random.seed(44)
torch.manual_seed(44)
torch.cuda.manual_seed(44)

## Helper Functions

In [None]:
def visualize_graph(graph, color):
    """Draw ``graph`` with a fixed-seed spring layout on a 7x7 inch figure.

    Args:
        graph: Graph object accepted by ``to_networkx``; it is converted
            without collapsing edge directions (``to_undirected=False``).
        color: Value forwarded to ``nx.draw_networkx`` as ``node_color``.
    """
    nx_graph = to_networkx(graph, to_undirected=False)
    # A fixed layout seed keeps node placement identical across calls.
    layout = nx.spring_layout(nx_graph, seed=42)

    plt.figure(figsize=(7, 7))
    # Axis ticks carry no meaning for a force-directed layout, so hide them.
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(
        nx_graph,
        pos=layout,
        with_labels=True,
        node_color=color,
        cmap="Set2",
    )
    plt.show()

## Example Graph

# Toy graph: six nodes, each with a 2-D feature vector and a 2-D position.
x = torch.tensor(
    [[1, 5], [4, 2], [3, 6], [7, 8], [0, 3], [4, 7]],
    dtype=torch.float,
)
position = torch.tensor(
    [[0, 0], [1, 0], [0, 1], [1, 1], [2, 0], [2, 1]],
    dtype=torch.float,
)
sample_name = "example_graph"

# Eight directed edges written as (source, target) pairs, then transposed into
# the (2, num_edges) layout; every edge starts with weight 1.
edge_index = (
    torch.tensor(
        [(0, 1), (3, 5), (2, 3), (1, 4), (4, 5), (4, 3), (4, 1), (4, 0)],
        dtype=torch.long,
    )
    .t()
    .contiguous()
)
edge_weight = torch.ones(edge_index.shape[1], dtype=torch.float)

graph = Data(
    x=x,
    edge_index=edge_index,
    edge_weight=edge_weight,
    sample_name=sample_name,
    position=position,
)

# Make the edge set undirected and rebuild the unit weights so there is one
# weight per edge of the resulting edge_index.
graph.edge_index = torch_geometric.utils.to_undirected(graph.edge_index)
graph.edge_weight = torch.ones(graph.edge_index.shape[1], dtype=torch.float)

# Bare expression: shown as the cell output when run in a notebook.
graph

In [None]:
# Re-seed so this cell produces the same graph regardless of what ran before.
np.random.seed(44)
torch.manual_seed(44)
torch.cuda.manual_seed(44)

# Configuration of the synthetic graph.
num_nodes = 10000
num_features = 50
k = 10  # neighbours requested per node; self-matches are dropped below
num_groups = 5
noise_scale = 0.05  # std of the Gaussian noise added to the group features
position_noise_scale = 0.2  # std of the Gaussian noise added to the group centres

# Each node belongs to one of `num_groups` groups; its features are the
# group's prototype vector plus Gaussian noise.
group_features = torch.rand((num_groups, num_features), dtype=torch.float)
group_assignments = torch.randint(0, num_groups, (num_nodes,))
x = group_features[group_assignments] + noise_scale * torch.randn((num_nodes, num_features))

# Positions are generated the same way: a 2-D centre per group plus noise.
group_positions = torch.rand((num_groups, 2), dtype=torch.float)
positions = group_positions[group_assignments] + position_noise_scale * torch.randn((num_nodes, 2))

# k-nearest-neighbour search over the node positions.
nbrs = NearestNeighbors(n_neighbors=k, algorithm="ball_tree").fit(positions)
distances, indices = nbrs.kneighbors(positions)

# Build all (source, target) pairs in one vectorized pass instead of a nested
# Python loop: row i of `indices` lists the neighbours of node i, so the source
# of every entry is its row number. Pairs where a node is its own neighbour are
# masked out. The resulting edge order matches a row-by-row traversal.
sources = np.repeat(np.arange(indices.shape[0]), indices.shape[1])
targets = indices.reshape(-1)
keep = sources != targets
edge_index = torch.as_tensor(np.stack([sources[keep], targets[keep]]), dtype=torch.long)
edge_weight = torch.ones(edge_index.size(1), dtype=torch.float)

graph = Data(x=x, edge_index=edge_index, edge_weight=edge_weight, position=positions)

# k-NN relations are not symmetric in general; make the edge set undirected and
# rebuild the unit weights so there is one weight per resulting edge.
graph.edge_index = torch_geometric.utils.to_undirected(graph.edge_index)
graph.edge_weight = torch.ones(graph.edge_index.size(1), dtype=torch.float)
graph.sample_name = "example_graph"

# Bare expression: shown as the cell output when run in a notebook.
graph

## Graph Augmentations

In [None]:
# First augmentation pipeline, built in "advanced" mode by the project helper
# get_graph_augmentation. The meaning of each probability/strength argument is
# defined by that helper (not shown here) -- TODO confirm against its docstring.
transform1 = get_graph_augmentation(
    augmentation_mode="advanced",
    drop_edge_p=0.3,
    drop_feat_p=0.3,
    mu=0.3,
    p_lambda=0.3,
    p_rewire=0.1,
    p_shuffle=0.1,
)
# Bare expression: shown as the cell output when run in a notebook.
transform1

In [None]:
# Second augmentation pipeline, built in "baseline" mode by the project helper
# get_graph_augmentation. NOTE(review): presumably the baseline mode ignores
# some of these arguments -- verify against the helper's implementation.
transform2 = get_graph_augmentation(
    augmentation_mode="baseline",
    drop_edge_p=0.3,
    drop_feat_p=0.3,
    mu=0.1,
    p_lambda=0.1,
    p_rewire=0.2,
    p_shuffle=0.2,
)
# Bare expression: shown as the cell output when run in a notebook.
transform2

In [None]:
# Apply the "advanced" augmentation to the current graph and display the result.
aug1 = transform1(graph)
aug1

In [None]:
# Apply the "baseline" augmentation to the current graph and display the result.
aug2 = transform2(graph)
aug2

In [None]:
# Report whether the augmented graph's edge set is undirected.
aug1.is_undirected()

In [None]:
# Plot both augmented graphs, one figure each, in a single uniform colour.
# NOTE(review): visualize_graph lays out and labels every node, so this may be
# slow if `graph` is the 10,000-node synthetic graph -- consider the toy graph.
visualize_graph(aug1, color="red")
visualize_graph(aug2, color="green")

## Proportional Resampling

from torch_geometric.utils import subgraph

# Graph to resample. `data` must be bound before it is read below; use the
# graph built in the earlier cells so this cell runs top-to-bottom.
data = graph

alpha = 0.0  # weighting factor: 0 = only spatial, 1 = only expression
n_clusters = 5
new_distribution = None  # optional {cluster_id: proportion}; None -> uniform

x = data.x.cpu().numpy()
pos = data.position.cpu().numpy()

# Standardize features and positions separately, then concatenate them with
# weights alpha and (1 - alpha) to form the clustering input.
x_scaled = StandardScaler().fit_transform(x)
pos_scaled = StandardScaler().fit_transform(pos)
combined_features = np.concatenate([alpha * x_scaled, (1 - alpha) * pos_scaled], axis=1)

# Cluster the nodes and record which node indices fall into each cluster.
kmeans = KMeans(n_clusters=n_clusters, random_state=44).fit(combined_features)
cluster_labels = kmeans.labels_
cluster_to_indices = {i: np.where(cluster_labels == i)[0] for i in range(n_clusters)}

if new_distribution is None:
    # Default target: every cluster contributes the same share of nodes.
    new_distribution = {i: 1 / n_clusters for i in range(n_clusters)}

new_node_indices = []
for cluster_id, proportion in new_distribution.items():
    # int() truncates, so the per-cluster quotas may sum to fewer than
    # data.num_nodes.
    n = int(proportion * data.num_nodes)
    candidates = cluster_to_indices[cluster_id]
    # Sample with replacement only when the cluster has fewer nodes than its
    # quota; in that case the same node index can be drawn more than once.
    chosen = np.random.choice(candidates, n, replace=len(candidates) < n)
    new_node_indices.extend(chosen.tolist())

new_node_indices = torch.tensor(new_node_indices, dtype=torch.long)

# Keep only the edges whose endpoints were both sampled and relabel the nodes.
# NOTE(review): new_node_indices can contain duplicates (see above); confirm
# that subgraph(..., relabel_nodes=True) treats duplicate entries as intended.
new_edge_index, new_edge_weight = subgraph(
    subset=new_node_indices,
    edge_index=data.edge_index,
    edge_attr=data.edge_weight,
    relabel_nodes=True,
)
data = Data(
    x=data.x[new_node_indices],
    edge_index=new_edge_index,
    edge_weight=new_edge_weight,
    position=data.position[new_node_indices],
    sample_name=data.sample_name,
)