Loading Training Data

In [1]:
#Imports
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Helper functions for printing and plotting graphs
def print_graph_details(G):
    """Print every node and every edge of ``G`` together with its attribute dict.

    Parameters
    ----------
    G : graph-like
        Object whose ``nodes(data=True)`` yields ``(node, attrs)`` pairs and
        whose ``edges(data=True)`` yields ``(u, v, attrs)`` triples.
    """
    # One output line per node, listed under the "Nodes:" header.
    node_lines = (f"{name}: {attrs}" for name, attrs in G.nodes(data=True))
    print("Nodes:", *node_lines, sep="\n")

    # One output line per edge; the header's leading newline leaves a blank
    # line between the two sections.
    edge_lines = (f"{src} -> {dst}: {attrs}" for src, dst, attrs in G.edges(data=True))
    print("\nEdges:", *edge_lines, sep="\n")

# Function to plot the graph
def plot_graph(G):
    """Draw ``G`` with a spring layout on a 12x12 inch figure and display it.

    Parameters
    ----------
    G : networkx graph
        Graph passed to ``nx.spring_layout`` and ``nx.draw``.
    """
    draw_options = {
        "with_labels": True,
        "node_size": 500,
        "node_color": "skyblue",
        "font_size": 10,
        "font_weight": "bold",
        "edge_color": "gray",
    }

    layout = nx.spring_layout(G)  # node positions used for the drawing
    plt.figure(figsize=(12, 12))
    nx.draw(G, layout, **draw_options)
    plt.title('Attention Graph')
    plt.show()

In [73]:
# load data using pickle
# The path is relative, so "attention_dataset.pkl" must be in the kernel's
# working directory.
# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only load this file if it comes from a trusted source.
with open("attention_dataset.pkl", "rb") as f:
    """
    attention_data: dimension (num_samples, iteration, layer, 
                                1 #batch during generation, 
                                num_heads, seq_len, seq_len)
    reward_data: list of rewards (num_samples)
    """
    # The pickle holds a 2-tuple: (attention_data, reward_data).
    attention_data, reward_data = pickle.load(f)


# Create the graph structure from the attention weights
def attention_to_graph(attention):
    """Build a directed token graph from one attention matrix.

    Parameters
    ----------
    attention : 2-D indexable with a ``shape`` attribute
        Indexed below as ``attention[i, j]``; the size of the last axis is
        taken as the number of tokens.

    Returns
    -------
    nx.DiGraph
        One node ``token_i`` per token carrying ``weight=attention[i, i]``,
        and one edge ``token_i -> token_j`` for every ``j < i`` carrying
        ``weight=attention[i, j]``.
    """
    num_tokens = attention.shape[-1]
    labels = [f'token_{idx}' for idx in range(num_tokens)]

    graph = nx.DiGraph()

    # Nodes: the diagonal entry is stored as the node weight.
    # TODO: weight dependent on the number of the tokens
    for idx, label in enumerate(labels):
        graph.add_node(label, weight=attention[idx, idx])

    # Edges: only the strictly lower triangle (attention masking), so each
    # token points at the tokens that precede it.
    for src in range(num_tokens):
        for dst in range(src):
            graph.add_edge(labels[src], labels[dst], weight=attention[src, dst])

    # TODO: Check for further aggregation for transformer input
    return graph

# AttentionDataset class
class AttentionDataset(Dataset):
    """Map-style dataset pairing each attention record with its reward.

    Parameters
    ----------
    attentions : sequence
        Per-sample attention records, indexable by sample position.
    rewards : sequence
        Per-sample rewards; must have the same length as ``attentions``.

    Raises
    ------
    ValueError
        If ``attentions`` and ``rewards`` differ in length. Without this check
        the mismatch would only surface later as an ``IndexError`` or as
        silently ignored samples.
    """

    def __init__(self, attentions, rewards):
        if len(attentions) != len(rewards):
            raise ValueError(
                f"attentions ({len(attentions)}) and rewards ({len(rewards)}) "
                "must have the same number of samples"
            )
        self.attentions = attentions
        self.rewards = rewards
        self.dataset_size = len(rewards)

    def __len__(self):
        """Return the number of samples."""
        return self.dataset_size

    def __getitem__(self, idx):
        """Return the ``(attention_record, reward)`` pair at position ``idx``."""
        return self.attentions[idx], self.rewards[idx]
    
def _layer_to_head_graphs(layer_attention):
    """Convert one layer's attention (singleton batch of heads) into a list of PyG graphs, one per head."""
    # Each layer entry carries a leading batch dimension of size 1 left over
    # from generation. Fail loudly instead of silently keeping only the last
    # batch entry (or reusing a stale list when the dimension is empty).
    if len(layer_attention) != 1:
        raise ValueError(
            "expected a singleton batch dimension per layer, "
            f"got {len(layer_attention)}"
        )
    heads = layer_attention[0]  # TODO: Change data structure to remove the singleton batch dimension
    return [
        from_networkx(
            attention_to_graph(head_attention),
            group_node_attrs=['weight'],
            group_edge_attrs=['weight'],
        )
        for head_attention in heads
    ]


def attention_collate_fn(batch):
    """Collate ``(attention_record, reward)`` pairs into PyG graphs and a reward tensor.

    Parameters
    ----------
    batch : sequence of (attention_record, reward)
        Each attention record is nested as
        ``[iteration][layer][1 (batch)][head] -> attention matrix``.

    Returns
    -------
    tuple
        ``(graphs, rewards)`` where ``graphs`` is a nested list indexed
        ``[sample][iteration][layer][head]`` whose leaves are the objects
        returned by ``from_networkx``, and ``rewards`` is ``torch.tensor`` of
        the per-sample rewards.

    Raises
    ------
    ValueError
        If a layer entry does not have a batch dimension of exactly one.
    """
    attentions, rewards = zip(*batch)

    rewards = torch.tensor(rewards)  # Convert rewards to tensor

    graphs = [
        [
            [_layer_to_head_graphs(layer_attention) for layer_attention in iteration]
            for iteration in sample
        ]
        for sample in attentions
    ]

    return graphs, rewards

In [74]:
# Wrap the loaded records in a dataset, then batch them in their stored order
# (no shuffling) using the graph-building collate function.
attention_dataset = AttentionDataset(attention_data, reward_data)
attention_loader = DataLoader(
    attention_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=attention_collate_fn,
)

In [75]:
# Check Data dimensions of the first collated batch.
# attention_collate_fn nests its output as [sample][iteration][layer][head] -> graph,
# so each successive len() below descends one level of that nesting.
for batch in attention_loader:
    attentions, rewards = batch
    print("Batch attentions:", len(attentions))  # number of samples in the batch
    print("Batch attentions:", len(attentions[0]))  # iterations of sample 0
    print("Batch attentions:", len(attentions[0][0]))  # layers of iteration 0
    print("Batch attentions:", len(attentions[0][0][0]))  # heads of layer 0
    # NOTE(review): this element is a single graph returned by from_networkx, not
    # another list level; len() here presumably counts its stored attributes — confirm.
    print("Batch attentions:", len(attentions[0][0][0][0]))
    break  # Just to test the first batch

Batch attentions: 32
Batch attentions: 3
Batch attentions: 12
Batch attentions: 12
Batch attentions: 3


In [77]:
# Test pytorch geometric graph: inspect one graph from the first batch
for batch in attention_loader:
    attentions, rewards = batch
    pyg_graph = attentions[0][0][0][2]  # sample 0, iteration 0, layer 0, head index 2
    print(pyg_graph)
    print("Node features (x):", pyg_graph.x)
    print("Edge index:", pyg_graph.edge_index)
    print("Edge attributes:", pyg_graph.edge_attr)
    # The collate function never sets a label, so fall back to None if 'y' is absent.
    print("Label (y):", getattr(pyg_graph, 'y', None))
    print("Number of nodes:", pyg_graph.num_nodes)
    print("Number of edges:", pyg_graph.num_edges)

    if pyg_graph.x is not None:
        print("x shape:", pyg_graph.x.shape)
    print("edge_index shape:", pyg_graph.edge_index.shape)

    break  # Just to test the first graph

Data(edge_index=[2, 6], x=[4, 1], edge_attr=[6, 1])
Node features (x): tensor([[1.0000],
        [0.2059],
        [0.2359],
        [0.0536]])
Edge index: tensor([[1, 2, 2, 3, 3, 3],
        [0, 0, 1, 0, 1, 2]])
Edge attributes: tensor([[0.7941],
        [0.5216],
        [0.2425],
        [0.4180],
        [0.1817],
        [0.3467]])
Label (y): None
Number of nodes: 4
Number of edges: 6
x shape: torch.Size([4, 1])
edge_index shape: torch.Size([2, 6])


Model

In [None]:
# imports for the GNN model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
import torch_geometric.utils as pyg_utils

In [78]:
"""
Define the GNN model for aggregating attention graphs

It extracts information through message passing and aggregation
which is passed to the rest of the model.
This is mostly used to extract information and reduce to linear dimensionality
"""
class AggregationNetwork(torch.nn.Module):
    """Two-layer GCN that embeds one attention graph into a fixed-size vector.

    Parameters
    ----------
    hidden_dim : int
        Width of the first graph convolution.
    embedding_dim : int, default 2
        Size of the pooled per-graph embedding.
    dropout : float, default 0.2
        Feature dropout applied after the first convolution.
    adj_dropout : float, default 0.2
        Probability of dropping each edge during training.
    """

    def __init__(self, hidden_dim, embedding_dim=2, dropout=0.2, adj_dropout=0.2):  # TODO: hyperparameter tuning for dropout
        super().__init__()
        self.conv1 = GCNConv(1, hidden_dim)  # each node has a single feature (weight)
        self.conv2 = GCNConv(hidden_dim, embedding_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout = torch.nn.Dropout(dropout)
        self.adj_dropout = adj_dropout

    def forward(self, data):
        """Embed the graph(s) in ``data``.

        Parameters
        ----------
        data : object with ``x``, ``edge_index``, ``edge_attr`` and ``batch``
            ``edge_attr`` may be ``None``; it is otherwise used as the edge weight.

        Returns
        -------
        Tensor
            Mean-pooled node embeddings, one row per graph according to ``data.batch``.
        """
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
        edge_index, edge_mask = pyg_utils.dropout_edge(edge_index, p=self.adj_dropout, training=self.training)

        if edge_weight is not None:
            # Keep only the weights of the edges that survived edge dropout.
            edge_weight = edge_weight[edge_mask]
            # Edge attributes grouped by from_networkx arrive as a [E, 1] column;
            # GCNConv documents edge weights as a 1-D tensor of shape [E].
            if edge_weight.dim() == 2 and edge_weight.size(-1) == 1:
                edge_weight = edge_weight.squeeze(-1)

        x = self.conv1(x, edge_index, edge_weight)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index, edge_weight)
        x = global_mean_pool(x, data.batch)
        return x

In [None]:
"""
Define the Compression Network to compress data from the AggregationNetwork

Compresses the embedding from the AggregationNetworks (Heads, Layers)
into a smaller representation. This is then passed to the adversarial transformer model.
"""
class CompressionNetwork(torch.nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Dropout -> Linear) that maps
    ``input_dim`` features down to ``compressed_dim``.

    Parameters
    ----------
    input_dim : int
        Size of the incoming feature vector.
    hidden_dim : int
        Width of the intermediate layer.
    compressed_dim : int
        Size of the output vector.
    dropout : float, default 0.1
        Dropout probability applied after the activation.
    """

    def __init__(self, input_dim, hidden_dim, compressed_dim, dropout=0.1):
        super().__init__()
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, compressed_dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)

In [None]:
""" 
Define the AggregationEncoderTransformer used to link the different layers

Encoder-only transformer that processes the aggregated embeddings to link them together
and further compress a single network state.
"""
class AggregationEncoderTransformer(nn.Module):
    """Encoder-only transformer over a batch-first sequence of embeddings.

    Parameters
    ----------
    input_dim : int
        Feature size of the incoming (and outgoing) sequence elements.
    hidden_dim : int
        Model width used inside the transformer.
    num_heads : int
        Number of attention heads per encoder layer.
    num_layers : int
        Number of stacked encoder layers.
    dropout : float, default 0.1
        Dropout applied after the input projection and inside the encoder layers.
    max_len : int, default 100
        Longest sequence for which a learned positional encoding is allocated.
    """

    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, dropout=0.1, max_len=100):
        super().__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # batch_first=True: forward() treats dim 0 as batch and dim 1 as sequence
        # (see the positional slice below). Without it the encoder would attend
        # across batch samples instead of across sequence positions.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(hidden_dim, input_dim)  # Output layer: project back to input_dim

        self.pos_encoder = nn.Parameter(torch.randn(1, max_len, hidden_dim))  # Learned positional encoding

    def forward(self, x):
        """Encode ``x`` of shape ``[batch, seq_len, input_dim]``.

        Returns
        -------
        Tensor
            ``[batch, seq_len, input_dim]``.

        Raises
        ------
        ValueError
            If ``seq_len`` exceeds the positional-encoding length.
        """
        seq_len = x.size(1)
        if seq_len > self.pos_encoder.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional encoding "
                f"length {self.pos_encoder.size(1)}"
            )
        x = self.embedding(x)
        x = self.dropout(x)
        x = x + self.pos_encoder[:, :seq_len, :]  # Add positional encoding
        x = self.transformer_encoder(x)
        x = self.fc_out(x)
        return x

In [None]:
""" 
Define the AttentionToRewardEncoder used to predict rewards

This encoder processes the attention features and predicts a reward based on them.
Used to predict the reward over multiple iterations of the primary Network.
"""
class AttentionToRewardEncoder(nn.Module):
    """Transformer encoder that maps a batch-first feature sequence to one scalar per sample.

    Parameters
    ----------
    input_dim : int
        Feature size of each sequence element.
    hidden_dim : int, default 512
        Model width used inside the transformer.
    num_head : int, default 8
        Number of attention heads per encoder layer.
    num_layers : int, default 6
        Number of stacked encoder layers.
    dim_feedforward : int, default 2048
        Width of the feed-forward sublayer.
    dropout : float, default 0.1
        Dropout applied after the input projection and inside the encoder layers.
    max_len : int, default 100
        Longest sequence for which a learned positional encoding is allocated.
    """

    def __init__(self, input_dim, hidden_dim=512, num_head=8, num_layers=6, dim_feedforward=2048, dropout=0.1, max_len=100):
        super().__init__()

        self.embedding = nn.Linear(input_dim, hidden_dim)  # Project attention features
        self.dropout = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_head,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fc_out = nn.Linear(hidden_dim, 1)  # Predict reward

        self.pos_encoder = nn.Parameter(torch.randn(1, max_len, hidden_dim))  # Learned positional encoding

    def forward(self, x):
        """Predict a reward for ``x`` of shape ``[batch, seq_len, input_dim]``.

        Returns
        -------
        Tensor
            ``[batch, 1]``.

        Raises
        ------
        ValueError
            If ``seq_len`` exceeds the positional-encoding length.
        """
        seq_len = x.size(1)
        if seq_len > self.pos_encoder.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional encoding "
                f"length {self.pos_encoder.size(1)}"
            )
        x = self.embedding(x)
        x = self.dropout(x)
        x = x + self.pos_encoder[:, :seq_len, :]  # Add positional info
        x = self.transformer_encoder(x)  # No causal mask needed
        x = x.mean(dim=1)  # Mean-pool over the sequence dimension
        x = self.fc_out(x)  # Project the pooled representation to a scalar reward

        return x  # Output shape: [batch, 1]

In [None]:
class FullAdversarialAlignmentModel(nn.Module):
    """End-to-end reward predictor over nested attention graphs.

    Pipeline of one forward call:
      1. a single shared ``AggregationNetwork`` embeds every head graph,
      2. a layer-specific ``CompressionNetwork`` compresses the concatenated
         head embeddings of each layer,
      3. a shared ``AggregationEncoderTransformer`` links the layers of each
         iteration,
      4. an ``AttentionToRewardEncoder`` maps the sequence of iteration
         summaries to a reward.

    Parameters
    ----------
    num_iterations : int
        Accepted for interface compatibility; not used by this class.
    num_layers : int
        Number of layers per iteration (one compressor is created per layer).
    num_heads : int
        Number of heads per layer (sets the compressor input width).
    gnn_hidden_dim, gnn_embedding_dim : int
        Hidden and output sizes of the shared GNN.
    compression_hidden_dim, compression_dim : int
        Hidden and output sizes of each layer compressor.
    agg_hidden_dim, agg_heads, agg_layers : int
        Width, head count and depth of the aggregation encoder.
    reward_hidden_dim, reward_heads, reward_layers, reward_ff_dim : int
        Width, head count, depth and feed-forward width of the reward encoder.
    dropout : float, default 0.1
        Dropout passed to every sub-module.
    """

    def __init__(
        self,
        num_iterations,
        num_layers,
        num_heads,
        gnn_hidden_dim,
        gnn_embedding_dim,
        compression_hidden_dim,
        compression_dim,
        agg_hidden_dim,
        agg_heads,
        agg_layers,
        reward_hidden_dim,
        reward_heads,
        reward_layers,
        reward_ff_dim,
        dropout=0.1
    ):
        super().__init__()

        # Shared GNN across all heads, layers, iterations
        self.gnn = AggregationNetwork(
            hidden_dim=gnn_hidden_dim,
            embedding_dim=gnn_embedding_dim,
            dropout=dropout
        )

        # Separate compression for each layer (compress across heads)
        self.compressors = nn.ModuleList([
            CompressionNetwork(
                input_dim=gnn_embedding_dim * num_heads,
                hidden_dim=compression_hidden_dim,
                compressed_dim=compression_dim,
                dropout=dropout
            ) for _ in range(num_layers)
        ])

        # Shared AggregationEncoder for linking compressed layers inside each iteration
        self.aggregation_encoder = AggregationEncoderTransformer(
            input_dim=compression_dim,
            hidden_dim=agg_hidden_dim,
            num_heads=agg_heads,
            num_layers=agg_layers,
            dropout=dropout
        )

        # Final reward predictor processing the sequence of iteration summaries.
        # The aggregation encoder projects its output back to its own input_dim
        # (compression_dim), so that is the feature size the predictor receives;
        # using agg_hidden_dim here would fail unless the two happen to be equal.
        self.reward_predictor = AttentionToRewardEncoder(
            input_dim=compression_dim,
            hidden_dim=reward_hidden_dim,
            num_head=reward_heads,
            num_layers=reward_layers,
            dim_feedforward=reward_ff_dim,
            dropout=dropout
        )

    def forward(self, attention_graph_batches):
        """
        Parameters
        ----------
        attention_graph_batches : list of lists of lists
            [iteration][layer][head] -> each element is a torch_geometric Batch

            NOTE(review): attention_collate_fn in this notebook returns graphs
            nested as [sample][iteration][layer][head]; confirm where the
            regrouping into per-head batches happens before calling this.

        Returns
        -------
        Tensor
            [batch_size, 1] reward predictions
        """

        iteration_summaries = []

        for iteration_layers in attention_graph_batches:
            layer_embeddings = []

            for layer_idx, layer_heads in enumerate(iteration_layers):
                # Shared GNN across all heads: each output is [batch_size, gnn_embedding_dim]
                head_embeddings = [self.gnn(head_graph) for head_graph in layer_heads]

                # Concat all heads for this layer
                head_concat = torch.cat(head_embeddings, dim=-1)  # [batch_size, gnn_embedding_dim * num_heads]

                # Compress layer representation with layer-specific compressor
                compressed = self.compressors[layer_idx](head_concat)  # [batch_size, compression_dim]
                layer_embeddings.append(compressed)

            # Sequence of compressed layers for this iteration
            layer_seq = torch.stack(layer_embeddings, dim=1)  # [batch_size, num_layers, compression_dim]

            # Shared aggregation encoder across iterations
            iter_encoded = self.aggregation_encoder(layer_seq)  # [batch_size, num_layers, compression_dim]

            # Pool over layers (mean) to get single iteration summary
            iteration_summaries.append(iter_encoded.mean(dim=1))  # [batch_size, compression_dim]

        # Stack all iteration summaries into a sequence
        iteration_seq = torch.stack(iteration_summaries, dim=1)  # [batch_size, num_iterations, compression_dim]

        # Final reward prediction
        reward = self.reward_predictor(iteration_seq)  # [batch_size, 1]
        return reward