In [9]:
!pip install torch_geometric



In [10]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch.optim import Adam

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch.optim import Adam

def synthesize_graph_for_class(model, device, target_class, num_samples, graph_data,
                               lr=0.1, steps=800, regularization_weight=1e-4, debug=False):
    model.eval()

    # Extract the original graph data
    node_features = graph_data.x.to(device)  # Node features from the input graph
    edge_index = graph_data.edge_index.to(device)  # Edge indices from the input graph

    # Initialize the optimizer for node features and adjacency matrix
    node_features.requires_grad = True
    adj_matrix = torch.randn(num_samples, num_samples, device=device, requires_grad=True)
    adj_matrix.data = torch.sigmoid(adj_matrix.data)

    optimizer = Adam([node_features, adj_matrix], lr=lr)
    done_mask = torch.zeros(num_samples, dtype=torch.bool, device=device)

    for step in range(steps):
        if done_mask.all():
            if debug:
                print(f"All samples reached target class by step {step}. Stopping early.")
            break

        optimizer.zero_grad()

        # Create graph data structure for PyG model
        edge_index = adj_matrix.nonzero(as_tuple=False).t()  # Update the graph edges from the learned adjacency matrix
        data = Data(x=node_features, edge_index=edge_index)

        # Forward pass through the model
        outputs = model(data)
        class_scores = outputs[:, target_class]
        print(f"Step {step+1}: Class scores for target class: {class_scores.mean().item()}")

        loss = -class_scores.mean() + regularization_weight * (node_features**2).mean()
        loss.backward()
        optimizer.step()

        # Clamp features between [0, 1]
        node_features.data.clamp_(0, 1)

        # Check if the node is classified as target class
        preds = torch.argmax(outputs, dim=1)
        previously_done_count = done_mask.sum().item()
        done_mask = done_mask | (preds == target_class)
        currently_done_count = done_mask.sum().item()

        if debug:
            newly_done = currently_done_count - previously_done_count
            print(f"Step {step+1}/{steps}: {currently_done_count}/{num_samples} nodes done (+{newly_done} this step).")

    if debug and not done_mask.all():
        print(f"Reached max steps without all nodes classified as target class. "
              f"{done_mask.sum().item()} out of {num_samples} done.")

    return node_features.detach(), adj_matrix.detach()



In [11]:
import torch
from torch_geometric.data import Data

def get_class_samples_from_noise_for_graph(model, device, unlearn_class, data):
    num_classes = 7
    num_samples = data.num_nodes

    # Synthesize graphs for the unlearn_class
    forget_node_features, forget_adj_matrix = synthesize_graph_for_class(
            model=model,
            device=device,
            target_class=unlearn_class,
            num_samples=num_samples,
            graph_data=data,
            lr=0.1,
            steps=500,
            regularization_weight=1e-4,
            debug=True
            )

    # Create Data object for the forget data
    forget_data = Data(x=forget_node_features, edge_index=forget_adj_matrix, y=torch.full((num_samples,), unlearn_class, dtype=torch.long))

    retain_node_features_list = []
    retain_labels_list = []

    # Synthesize graphs for other classes (retain data)
    for c in range(num_classes):
        if c == unlearn_class:
            continue

        # Generate synthetic graph for current class
        class_node_features, class_adj_matrix = synthesize_graph_for_class(
            model=model,
            device=device,
            target_class=c,  # Use class c for retention
            num_samples=num_samples,
            graph_data=data,
            lr=0.1,
            steps=500,
            regularization_weight=1e-4,
            debug=True
        )

        retain_node_features_list.append(class_node_features)

    retain_node_features = torch.cat(retain_node_features_list, dim=0)
    retain_labels = torch.tensor(retain_labels_list, dtype=torch.long)

    retain_data = Data(x=retain_node_features, edge_index=forget_adj_matrix, y=retain_labels)

    return forget_data, retain_data


In [13]:
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

# Load the CORA dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]  # Extract the first graph
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
data = data.to(device)
# Example GCN model
class GCN(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, 16)
        self.conv2 = GCNConv(16, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# Initialize model and move to device (CPU/GPU)
model = GCN(in_channels=data.num_features, out_channels=dataset.num_classes).to(device)
model.eval()

# Synthesize a graph for a specific class
target_class = 0  # Example target class
num_samples = data.num_nodes  # Using the number of nodes in the original graph
forget_data,retain_data = get_class_samples_from_noise_for_graph(model, device, target_class, data)

Step 1: Class scores for target class: 0.0013795519480481744
Step 1/500: 0/2708 nodes done (+0 this step).
Step 2: Class scores for target class: 1.0357791185379028
Step 2/500: 2708/2708 nodes done (+2708 this step).
All samples reached target class by step 2. Stopping early.
Step 1: Class scores for target class: 1.5318753719329834
Step 1/500: 0/2708 nodes done (+0 this step).
Step 2: Class scores for target class: 3.4000704288482666
Step 2/500: 0/2708 nodes done (+0 this step).
Step 3: Class scores for target class: 5.019963264465332
Step 3/500: 2708/2708 nodes done (+2708 this step).
All samples reached target class by step 3. Stopping early.
Step 1: Class scores for target class: -1.3370870351791382
Step 1/500: 0/2708 nodes done (+0 this step).
Step 2: Class scores for target class: 0.6938981413841248
Step 2/500: 0/2708 nodes done (+0 this step).
Step 3: Class scores for target class: 1.7264173030853271
Step 3/500: 0/2708 nodes done (+0 this step).
Step 4: Class scores for target c

In [14]:
print(retain_data)

Data(x=[16248, 1433], edge_index=[2708, 2708], y=[0])


In [16]:
print(forget_data)

Data(x=[2708, 1433], edge_index=[2708, 2708], y=[2708])
