Loading Training Data

In [1]:
#Imports
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from torch_geometric.data import Data, Batch
from torch_geometric.utils import from_networkx

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Helper functions for printing and plotting graphs
def print_graph_details(G):
    print("Nodes:")
    for node, data in G.nodes(data=True):
        print(f"{node}: {data}")
    
    print("\nEdges:")
    for u, v, data in G.edges(data=True):
        print(f"{u} -> {v}: {data}")

# Function to plot the graph
def plot_graph(G):
    pos = nx.spring_layout(G)  # Layout for visualization
    plt.figure(figsize=(12, 12))
    nx.draw(G, pos, with_labels=True, node_size=500, node_color="skyblue", font_size=10, font_weight="bold", edge_color="gray")
    plt.title('Attention Graph')
    plt.show()

In [3]:
# load data using pickle
with open("attention_dataset.pkl", "rb") as f:
    """
    attention_data: dimension (num_samples, iteration, layer, 
                                1 #batch during generation, 
                                num_heads, seq_len, seq_len)
    reward_data: list of rewards (num_samples)
    """
    attention_data, reward_data = pickle.load(f)


# Create the graph structure from the attention weights
def attention_to_graph(attention):
    # Get the number of nodes
    n = attention.shape[-1] # number of tokens

    # Create a directed graph
    G = nx.DiGraph()

    # Add nodes from attention
    for i in range(n):
        # TODO: weight dependend on the number of the tokens
        G.add_node(f'token_{i}', weight=attention[i, i])
    
    # Add edges from attention
    for i in range(n):
        for j in range(n):
            if (j < i): # attention masking
                G.add_edge(f'token_{i}', f'token_{j}', weight=attention[i, j])

    # TODO: Check for further aggregation for transformer input
    return G

# AttentionDataset class
class AttentionDataset:
    def __init__(self, attentions, rewards):
        self.attentions = attentions
        self.rewards = rewards
        self.dataset_size = len(rewards)

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        return self.attentions[idx], self.rewards[idx]
    
def attention_collate_fn(batch):
    # get data from file and convert to format for PyTorch Geometric
    attentions, rewards = zip(*batch)

    # rewards
    rewards = torch.tensor(rewards)  # Convert rewards to tensor

    # attention
    # data: (num_samples, iteration, layer, 1, num_heads, seq_len, seq_len), rewards: (num_samples)
    attention_data_batch = []
    for attention_iteration in attentions:
        attention_data_step = []
        for attention_layer in attention_iteration:
            # attention_layer: (layer, 1, num_heads, seq_len, seq_len)
            attention_data_layer = []
            for attention_batch in attention_layer: 
                for attention_head in attention_batch: # TODO: Change data structure to remove the singleton batch dimension
                    attention_data_head = []
                    for attention in attention_head:  
                        # attention_head: (num_heads, seq_len, seq_len)
                        G = attention_to_graph(attention)
                        pyg_graph = from_networkx(G, group_node_attrs=['weight'], group_edge_attrs=['weight'])
                        attention_data_head.append(pyg_graph) # Put G for NetworkX graph
                attention_data_layer.append(Batch.from_data_list(attention_data_head))
            attention_data_step.append(attention_data_layer)
        attention_data_batch.append(attention_data_step)

    return attention_data_batch, rewards

In [4]:
attention_dataset = AttentionDataset(attention_data, reward_data)
attention_loader = DataLoader(attention_dataset, batch_size=32, shuffle=False, collate_fn=attention_collate_fn)

In [5]:
# Check Data dimensions
for batch in attention_loader:
    attentions, rewards = batch
    print("Batch attentions:", len(attentions))  # Check the number of attention graphs in the batch
    print("Batch attentions:", len(attentions[0]))
    print("Batch attentions:", len(attentions[0][0]))
    print("Batch attentions:", len(attentions[0][0][0]))
    print("Batch attentions:", len(attentions[0][0][0][0]))
    break  # Just to test the first batch

Batch attentions: 32
Batch attentions: 3
Batch attentions: 12
Batch attentions: 12
Batch attentions: 3


In [6]:
# Test pytorch geometric graph
for batch in attention_loader:
    attentions, rewards = batch
    pyg_graph = attentions[0][0][0][2]  # Get the first graph
    print(pyg_graph)
    print("Node features (x):", pyg_graph.x)
    print("Edge index:", pyg_graph.edge_index)
    print("Edge attributes:", pyg_graph.edge_attr)
    print("Label (y):", getattr(pyg_graph, 'y', None))
    print("Number of nodes:", pyg_graph.num_nodes)
    print("Number of edges:", pyg_graph.num_edges)

    if pyg_graph.x is not None:
        print("x shape:", pyg_graph.x.shape)
    print("edge_index shape:", pyg_graph.edge_index.shape)

    break  # Just to test the first graph

Data(edge_index=[2, 6], x=[4, 1], edge_attr=[6, 1])
Node features (x): tensor([[1.0000],
        [0.2059],
        [0.2359],
        [0.0536]])
Edge index: tensor([[1, 2, 2, 3, 3, 3],
        [0, 0, 1, 0, 1, 2]])
Edge attributes: tensor([[0.7941],
        [0.5216],
        [0.2425],
        [0.4180],
        [0.1817],
        [0.3467]])
Label (y): None
Number of nodes: 4
Number of edges: 6
x shape: torch.Size([4, 1])
edge_index shape: torch.Size([2, 6])


Model

In [5]:
# imports for the GNN model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
import torch_geometric.utils as pyg_utils

In [6]:
"""
Define the GNN model for aggregating attention graphs

It extracts information through message passing and aggregation
which is passed to the rest of the model.
This is mostly used to extract information and reduce to linear dimensionality
"""
class AggregationNetwork(torch.nn.Module):
    def __init__(self, hidden_dim, embedding_dim=2, dropout=0.2, adj_dropout=0.2): # TODO: hyperparameter tuning for dropout
        super(AggregationNetwork, self).__init__()
        self.conv1 = GCNConv(1, hidden_dim) # each node has a single feature (weight)
        self.conv2 = GCNConv(hidden_dim, embedding_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout = torch.nn.Dropout(dropout)
        self.adj_dropout = adj_dropout
    
    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr  
        edge_index, edge_mask = pyg_utils.dropout_edge(edge_index, p=self.adj_dropout, training=self.training)

        if edge_weight is not None:
            edge_weight = edge_weight[edge_mask]

        x = self.conv1(x, edge_index, edge_weight)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index, edge_weight)
        x = global_mean_pool(x, data.batch)
        return x

In [7]:
"""
Define the Compression Network to compress data from the AggregationNetwork

Compresses the embedding from the AggregationNetworks (Heads, Layers)
into a smaller representation. This is then passed to the adversarial transformer model.
"""
class CompressionNetwork(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, compressed_dim, dropout=0.1):
        super(CompressionNetwork, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, compressed_dim)
        )

    def forward(self, x):
        x = self.net(x) 
        return x

In [8]:
""" 
Define the AggregationEncoderTransformer used to link the different layers

Encoder only transformer that processes the aggregated embeddings, to link them togheter
and further compress a signle network state.
"""
class AggregationEncoderTransformer(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim, num_heads, num_layers, dropout=0.1):
        super(AggregationEncoderTransformer, self).__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads),
            num_layers=num_layers
        )
        self.fc_out = nn.Linear(hidden_dim, output_dim)  # Output layer
        
        self.pos_encoder = nn.Parameter(torch.randn(1, 100, hidden_dim))  # Positional encoding

    def forward(self, x):
        x = self.embedding(x)
        x = self.dropout(x)
        x = x + self.pos_encoder[:, :x.size(1), :]  # Add positional encoding
        x = self.transformer_encoder(x)
        x = self.fc_out(x)
        return x

In [9]:
""" 
Define the AttentionToRewardEncoder used to predict rewards

This encoder processes the attention features and predicts a reward based on them.
Used to predict the reward over multiple iterations of the primary Network.
"""
class AttentionToRewardEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim=512, num_head=8, num_layers=6, dim_feedforward=2048, dropout=0.1):
        super(AttentionToRewardEncoder, self).__init__()

        self.embedding = nn.Linear(input_dim, hidden_dim)  # Project attention features
        self.dropout = nn.Dropout(dropout)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_head , 
                                       dim_feedforward=dim_feedforward, 
                                       dropout=dropout, batch_first=True),
            num_layers=num_layers
        )

        self.fc_out = nn.Linear(hidden_dim, 1)  # Predict reward

        self.pos_encoder = nn.Parameter(torch.randn(1, 100, hidden_dim))  # Positional encoding

    def forward(self, x):
        x = self.embedding(x)  
        x = self.dropout(x)
        x = x + self.pos_encoder[:, :x.size(1), :]  # Add positional info
        x = self.transformer_encoder(x)  # No causal mask needed
        x = x.mean(dim=1)  # Pool over token representations (global understanding)
        x = self.fc_out(x)  # Predict token logits  
        
        return x  # Output shape: (reward)

In [12]:
class FullAdversarialAlignmentModel(nn.Module):
    def __init__(
        self,
        num_iterations,
        num_layers,
        num_heads,

        gnn_hidden_dim,
        gnn_embedding_dim,

        compression_hidden_dim,
        compression_dim,

        agg_hidden_dim,
        agg_heads,
        agg_layers,
        
        reward_hidden_dim,
        reward_heads,
        reward_layers,
        reward_ff_dim,
        dropout=0.1
    ):
        super(FullAdversarialAlignmentModel, self).__init__()

        # Shared GNN across all heads, layers, iterations
        self.gnn = AggregationNetwork(
            hidden_dim=gnn_hidden_dim,
            embedding_dim=gnn_embedding_dim,
            dropout=dropout
        )

        # Separate compression for each layer (compress across heads)
        self.compressors = nn.ModuleList([
            CompressionNetwork(
                input_dim=gnn_embedding_dim * num_heads,
                hidden_dim=compression_hidden_dim,
                compressed_dim=compression_dim,
                dropout=dropout
            ) for _ in range(num_layers)
        ])

        # Shared AggregationEncoder for linking compressed layers inside each iteration
        self.aggregation_encoder = AggregationEncoderTransformer(
            input_dim=compression_dim * num_layers,  # Concatenate all compressed layers
            output_dim=agg_hidden_dim,  # Output dimension for each iteration summary
            hidden_dim=agg_hidden_dim,
            num_heads=agg_heads,
            num_layers=agg_layers,
            dropout=dropout
        )

        # Final reward predictor processing the sequence of iteration summaries
        self.reward_predictor = AttentionToRewardEncoder(
            input_dim=agg_hidden_dim * num_iterations,  # Concatenate all iteration summaries
            hidden_dim=reward_hidden_dim,
            num_head=reward_heads,
            num_layers=reward_layers,
            dim_feedforward=reward_ff_dim,
            dropout=dropout
        )

    def forward(self, attention_graph_batches):
        """
        Parameters
        ----------
        attention_graph_batches : list of lists of lists
            [iteration][layer][head] -> each element is a torch_geometric Batch

        Returns
        -------
        Tensor
            [batch_size, 1] reward predictions
        """

        iteration_embeddings = []
        for iteration_layers in attention_graph_batches:
            
            layer_embeddings = []
            for layer_idx, layer_heads in enumerate(iteration_layers):
                
                head_embeddings = []
                for head_graph in layer_heads:
                    # Shared GNN across all
                    gnn_out = self.gnn(head_graph)  # [batch_size, gnn_embedding_dim]
                    head_embeddings.append(gnn_out)

                # Concat all heads for this layer
                head_concat = torch.cat(head_embeddings, dim=-1)  # [batch_size, gnn_embedding_dim * num_heads]

                # Compress layer representation with layer-specific compressor
                compressed = self.compressors[layer_idx](head_concat)  # [batch_size, compression_dim]
                layer_embeddings.append(compressed)  # keep sequence dimension

            # Sequence of compressed layers for this iteration
            layer_seq = torch.cat(layer_embeddings, dim=1)  # [batch_size, num_layers, compression_dim]
            layer_seq_flat = layer_seq.view(layer_seq.size(0), -1)  # [batch_size, num_layers * compression_dim]
            iter_encoded = self.aggregation_encoder(layer_seq_flat.unsqueeze(1))  # [batch_size, 1, agg_hidden_dim]

            # Pool over layers (mean) to get single iteration summary
            iter_summary = iter_encoded.mean(dim=1).unsqueeze(1)  # [batch_size, 1, agg_hidden_dim]
            iteration_embeddings.append(iter_summary)

        
        # Stack all iteration summaries into a sequence
        iteration_seq = torch.cat(iteration_embeddings, dim=1)  # [batch_size, num_iterations, agg_hidden_dim]
        iteration_seq_flat = iteration_seq.view(iteration_seq.size(0), -1)  # [batch_size, num_iterations * agg_hidden_dim]
        # Final reward prediction
        reward = self.reward_predictor(iteration_seq_flat.unsqueeze(1))  # [batch_size, 1, input_dim]
        return reward

In [10]:
class FullAdversarialAlignmentModel_Padding(nn.Module):
    def __init__(
        self,
        num_iterations,
        num_layers,
        num_heads,

        gnn_hidden_dim,
        gnn_embedding_dim,

        compression_hidden_dim,
        compression_dim,

        agg_hidden_dim,
        agg_heads,
        agg_layers,
        
        reward_hidden_dim,
        reward_heads,
        reward_layers,
        reward_ff_dim,
        dropout=0.1
    ):
        super(FullAdversarialAlignmentModel_Padding, self).__init__()

        # needed variables
        self.num_iterations = num_iterations # used for padding

        # Shared GNN across all heads, layers, iterations
        self.gnn = AggregationNetwork(
            hidden_dim=gnn_hidden_dim,
            embedding_dim=gnn_embedding_dim,
            dropout=dropout
        )

        # Separate compression for each layer (compress across heads)
        self.compressors = nn.ModuleList([
            CompressionNetwork(
                input_dim=gnn_embedding_dim * num_heads,
                hidden_dim=compression_hidden_dim,
                compressed_dim=compression_dim,
                dropout=dropout
            ) for _ in range(num_layers)
        ])

        # Shared AggregationEncoder for linking compressed layers inside each iteration
        self.aggregation_encoder = AggregationEncoderTransformer(
            input_dim=compression_dim * num_layers,  # Concatenate all compressed layers
            output_dim=agg_hidden_dim,  # Output dimension for each iteration summary
            hidden_dim=agg_hidden_dim,
            num_heads=agg_heads,
            num_layers=agg_layers,
            dropout=dropout
        )

        # Final reward predictor processing the sequence of iteration summaries
        self.reward_predictor = AttentionToRewardEncoder(
            input_dim=agg_hidden_dim * num_iterations,  # Concatenate all iteration summaries
            hidden_dim=reward_hidden_dim,
            num_head=reward_heads,
            num_layers=reward_layers,
            dim_feedforward=reward_ff_dim,
            dropout=dropout
        )

    def forward(self, attention_graph_batches):
        """
        Parameters
        ----------
        attention_graph_batches : list of lists of lists
            [iteration][layer][head] -> each element is a torch_geometric Batch

        Returns
        -------
        Tensor
            [batch_size, 1] reward predictions
        """

        iteration_embeddings = []
        for iteration_layers in attention_graph_batches:
            layer_embeddings = []
            for layer_idx, layer_heads in enumerate(iteration_layers):
                
                head_embeddings = []
                for head_graph in layer_heads:
                    # Shared GNN across all
                    gnn_out = self.gnn(head_graph)  # [batch_size, gnn_embedding_dim]
                    head_embeddings.append(gnn_out)

                # Concat all heads for this layer
                head_concat = torch.cat(head_embeddings, dim=-1)  # [batch_size, gnn_embedding_dim * num_heads]
                #print("head_concat shape:", head_concat.shape)
                # Compress layer representation with layer-specific compressor
                compressed = self.compressors[layer_idx](head_concat)  # [batch_size, compression_dim]
                #print("compressed shape:", compressed.shape)
                layer_embeddings.append(compressed)  # keep sequence dimension

            # Sequence of compressed layers for this iteration
            #print("layer_embeddings[0] shape:", layer_embeddings[0].shape)
            #print("layer_embeddings length:", len(layer_embeddings))
            layer_seq = torch.stack(layer_embeddings, dim=0)  # [batch_size, num_layers, compression_dim]
            #print("layer_seq shape:", layer_seq.shape)
            layer_seq_flat = layer_seq.view(layer_seq.size(0), -1)  # [batch_size, num_layers * compression_dim]
            #print("layer_seq_flat shape:", layer_seq_flat.shape)
            iter_encoded = self.aggregation_encoder(layer_seq_flat.unsqueeze(1))  # [batch_size, 1, agg_hidden_dim]

            # Pool over layers (mean) to get single iteration summary
            iter_summary = iter_encoded.mean(dim=1).unsqueeze(1)  # [batch_size, 1, agg_hidden_dim]
            iteration_embeddings.append(iter_summary)

        # Padding
        dummy = torch.zeros((1, 1, iteration_embeddings[0].size(2)), device=iteration_embeddings[0].device)
        for i in range(len(iteration_embeddings)):
            num_missing = self.num_iterations - iteration_embeddings[i].size(0)
            #print("num_missing:", num_missing)
            if num_missing > 0:
                # Append num_missing copies to the tensor
                pad = dummy.repeat(num_missing, 1, 1)
                #print("pad shape:", pad.shape)
                iteration_embeddings[i] = torch.cat([pad, iteration_embeddings[i]], dim=0)
            
        # for i, iteration_embedding in enumerate(iteration_embeddings):
        #     print("Iteration_embeddings shape:", i, iteration_embedding.shape)
        
        # Stack all iteration summaries into a sequence
        #print("iteration_embeddings length:", len(iteration_embeddings))
        #print("iteration_embeddings[0] shape:", iteration_embeddings[0].shape)
        iteration_seq = torch.cat(iteration_embeddings, dim=1)  # [batch_size, num_iterations, agg_hidden_dim]
        #print("iteration_seq shape:", iteration_seq.shape)
        iteration_seq_flat = iteration_seq.view(iteration_seq.size(1), -1)  # [batch_size, num_iterations * agg_hidden_dim]
        #print("iteration_seq_flat shape:", iteration_seq_flat.shape)
        # Final reward prediction
        reward = self.reward_predictor(iteration_seq_flat.unsqueeze(1))  # [batch_size, 1, input_dim]
        return reward

In [14]:
import torch
from torch_geometric.data import Data, Batch

# === Dummy data generator ===
def generate_dummy_graph_batch(batch_size, num_nodes, feature_dim=1):
    # Create multiple small graphs in a single batch
    data_list = []
    for _ in range(batch_size):
        x = torch.rand((num_nodes, feature_dim))  # Node features
        edge_index = torch.randint(0, num_nodes, (2, num_nodes * 2))  # Random edges
        data_list.append(Data(x=x, edge_index=edge_index))
    return Batch.from_data_list(data_list)

# === Build dummy inputs matching your architecture ===
batch_size = 10
num_iterations = 2
num_layers = 2
num_heads = 2

# Build attention_graph_batches[iteration][layer][head]
attention_graph_batches = []
for _ in range(num_iterations):
    iteration_layers = []
    for _ in range(num_layers):
        layer_heads = []
        for _ in range(num_heads):
            head_batch = generate_dummy_graph_batch(batch_size, num_nodes=5)
            layer_heads.append(head_batch)
        iteration_layers.append(layer_heads)
    attention_graph_batches.append(iteration_layers)

# === Build the model ===
model = FullAdversarialAlignmentModel_Padding(
    num_iterations=5,
    num_layers=num_layers,
    num_heads=num_heads,
    gnn_hidden_dim=16,
    gnn_embedding_dim=8,
    compression_hidden_dim=32,
    compression_dim=4,
    agg_hidden_dim=4,
    agg_heads=2,
    agg_layers=2,
    reward_hidden_dim=24,
    reward_heads=2,
    reward_layers=2,
    reward_ff_dim=64,
    dropout=0.1
)

# === Forward pass ===
reward = model(attention_graph_batches)
print(f"\nFinal reward prediction shape: {reward.shape}")  # Expect [batch_size, 1]




RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x40 and 8x4)

In [17]:
for batch in attention_loader:
    attentions, rewards = batch
    pyg_graph = attentions[0][0][0][0]  # Get the first graph
    print(pyg_graph)
    print("Node features (x):", pyg_graph.x)
    print("Edge index:", pyg_graph.edge_index)
    print("Edge attributes:", pyg_graph.edge_attr)
    print("Label (y):", getattr(pyg_graph, 'y', None))
    print("Number of nodes:", pyg_graph.num_nodes)
    print("Number of edges:", pyg_graph.num_edges)

    if pyg_graph.x is not None:
        print("x shape:", pyg_graph.x.shape)
    print("edge_index shape:", pyg_graph.edge_index.shape)

    break  # Just to test the first graph

Data(edge_index=[2, 6], x=[4, 1], edge_attr=[6, 1])
Node features (x): tensor([[1.0000],
        [0.0703],
        [0.0172],
        [0.0466]])
Edge index: tensor([[1, 2, 2, 3, 3, 3],
        [0, 0, 1, 0, 1, 2]])
Edge attributes: tensor([[0.9297],
        [0.7762],
        [0.2066],
        [0.4831],
        [0.3801],
        [0.0902]])
Label (y): None
Number of nodes: 4
Number of edges: 6
x shape: torch.Size([4, 1])
edge_index shape: torch.Size([2, 6])


In [None]:
from torch_geometric.data import Batch

# === Build the model ===
model = FullAdversarialAlignmentModel_Padding(
    num_iterations=5,
    num_layers=12,
    num_heads=12,
    gnn_hidden_dim=16,
    gnn_embedding_dim=8,
    compression_hidden_dim=32,
    compression_dim=4,
    agg_hidden_dim=4,
    agg_heads=2,
    agg_layers=2,
    reward_hidden_dim=24,
    reward_heads=2,
    reward_layers=2,
    reward_ff_dim=64,
    dropout=0.1
)

# Get one batch from the attention_loader
for batch in attention_loader:
    attentions, rewards = batch  # attentions: [batch, iteration, layer, head, graph]
    # print("Batch_size", len(attentions))  # Check the number of attention graphs in the batch
    # print("num_iteration:", len(attentions[0]))
    # print("num_layers:", len(attentions[0][0]))
    # print("num_heads:", len(attentions[0][0][0]))
    # Forward pass through the model 
    reward_pred = model(attentions)  # Output: [batch_size, 1]
    print("Predicted reward shape:", reward_pred.shape)
    print("Predicted reward:", reward_pred.squeeze().detach().cpu().numpy())
    print("True reward:", rewards.detach().cpu().numpy())

iteration_embeddings length: 32
iteration_embeddings[0] shape: torch.Size([5, 1, 4])
iteration_seq shape: torch.Size([5, 32, 4])
iteration_seq_flat shape: torch.Size([32, 20])
Predicted reward shape: torch.Size([32, 1])
Predicted reward: [-0.29744077 -0.34963167 -0.52768946 -0.43813324 -0.33486015 -0.6731272
 -0.73075515 -0.7117985  -0.7345962  -0.60961884 -0.5868688  -0.28745365
 -0.49931788 -0.6043281  -0.6093149  -0.63385195 -0.48177934 -0.47170025
 -0.6744481  -0.44590962 -0.8785878  -0.6509539  -0.76292455 -0.603332
 -0.61975515 -0.75231326 -0.87337506 -0.6542589  -0.49191695 -0.84706193
 -0.72784865 -0.9169785 ]
True reward: [-3.7376697e+00 -6.8341088e+00  1.3815511e+01 -6.9314671e-01
 -8.5149908e+00 -6.1158919e+00  1.3815511e+01 -4.1588831e+00
 -3.8712010e+00  1.3815511e+01 -4.4426513e+00 -6.1136823e+00
  1.3815511e+01 -1.7917596e+00 -5.8051348e+00 -4.7874918e+00
 -2.9957323e+00 -5.7037826e+00 -4.3820267e+00 -6.8330317e+00
 -3.5263605e+00 -2.3025851e+00 -3.4011974e+00  1.3815511

KeyboardInterrupt: 

In [11]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

# == Assume you already have these prepared ==
# dataset = AttentionDataset(...)
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=your_collate_fn)

# Your model
model = FullAdversarialAlignmentModel_Padding(
    num_iterations=5,
    num_layers=12,
    num_heads=12,
    gnn_hidden_dim=16,
    gnn_embedding_dim=8,
    compression_hidden_dim=32,
    compression_dim=4,
    agg_hidden_dim=4,
    agg_heads=2,
    agg_layers=2,
    reward_hidden_dim=24,
    reward_heads=2,
    reward_layers=2,
    reward_ff_dim=64,
    dropout=0.1
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Loss & optimizer
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=1e-4)

# === Optionally start wandb ===
import wandb
wandb.init(project="adversarial-alignment")

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for batch in tqdm(attention_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        attention_batches, rewards = batch

        # Move rewards to device
        rewards = rewards.to(device).float()

        # Move graphs to device
        for iteration_layers in attention_batches:
            for layer_heads in iteration_layers:
                for i, head_graph in enumerate(layer_heads):
                    layer_heads[i] = head_graph.to(device)

        # Forward pass
        predictions = model(attention_batches).squeeze()

        # Loss
        loss = criterion(predictions, rewards)
        epoch_loss += loss.item()

        # Backward + optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # === Optionally log to wandb ===
        
        wandb.log({
                    "loss": loss.item(),
                    "predictions_mean": predictions.mean().item(),
                    "predictions_std": predictions.std().item(),
                    "targets_mean": rewards.mean().item(),
                    "targets_std": rewards.std().item()
                })
        
        total_norm = 0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        wandb.log({"grad_norm": total_norm})

    avg_loss = epoch_loss / len(attention_loader)
    wandb.log({"epoch_avg_loss": avg_loss})
    print(f"Epoch {epoch+1} - Avg Loss: {avg_loss:.6f}")

# === Save final model ===
torch.save(model.state_dict(), "alignment_model_final.pt")

wandb: Currently logged in as: leonhard_waibl to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Epoch 1/10: 100%|██████████| 4/4 [00:42<00:00, 10.60s/it]


Epoch 1 - Avg Loss: 51.752468


Epoch 2/10: 100%|██████████| 4/4 [00:42<00:00, 10.59s/it]


Epoch 2 - Avg Loss: 52.718700


Epoch 3/10: 100%|██████████| 4/4 [00:41<00:00, 10.29s/it]


Epoch 3 - Avg Loss: 52.193806


Epoch 4/10: 100%|██████████| 4/4 [00:39<00:00,  9.83s/it]


Epoch 4 - Avg Loss: 51.140471


Epoch 5/10: 100%|██████████| 4/4 [00:39<00:00,  9.81s/it]


Epoch 5 - Avg Loss: 51.110319


Epoch 6/10: 100%|██████████| 4/4 [00:38<00:00,  9.57s/it]


Epoch 6 - Avg Loss: 51.706482


Epoch 7/10: 100%|██████████| 4/4 [00:38<00:00,  9.59s/it]


Epoch 7 - Avg Loss: 51.479305


Epoch 8/10: 100%|██████████| 4/4 [00:38<00:00,  9.65s/it]


Epoch 8 - Avg Loss: 51.272855


Epoch 9/10: 100%|██████████| 4/4 [00:38<00:00,  9.61s/it]


Epoch 9 - Avg Loss: 51.916001


Epoch 10/10: 100%|██████████| 4/4 [00:35<00:00,  8.79s/it]


Epoch 10 - Avg Loss: 52.546176


In [None]:
def train():
    with wandb.init() as run:
        config = wandb.config

        model = FullAdversarialAlignmentModel_Padding(
            num_iterations=5,
            num_layers=12,
            num_heads=12,
            gnn_hidden_dim=16,
            gnn_embedding_dim=8,
            compression_hidden_dim=32,
            compression_dim=config.compression_dim,
            agg_hidden_dim=config.agg_hidden_dim,
            agg_heads=2,
            agg_layers=2,
            reward_hidden_dim=24,
            reward_heads=2,
            reward_layers=2,
            reward_ff_dim=64,
            dropout=config.dropout
        ).to(device)

        print("Starting run with config:", dict(config))

        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

        for epoch in range(num_epochs):
            model.train()
            epoch_loss = 0
            for batch in tqdm(attention_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                attention_batches, rewards = batch

                # Move rewards to device
                rewards = rewards.to(device).float()

                # Move graphs to device
                for iteration_layers in attention_batches:
                    for layer_heads in iteration_layers:
                        for i, head_graph in enumerate(layer_heads):
                            layer_heads[i] = head_graph.to(device)

                # Forward pass
                predictions = model(attention_batches).squeeze()

                # Loss
                loss = criterion(predictions, rewards)
                epoch_loss += loss.item()

                # Backward + optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # === Optionally log to wandb ===
                total_norm = 0
                for p in model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        total_norm += param_norm.item() ** 2
                total_norm = total_norm ** 0.5

                wandb.log({
                            "loss": loss.item(),
                            "predictions_mean": predictions.mean().detach().item(),
                            "predictions_std": predictions.std().detach().item(),
                            "targets_mean": rewards.mean().detach().item(),
                            "targets_std": rewards.std().detach().item(),
                            "grad_norm": total_norm,
                        })
                
                
            avg_loss = epoch_loss / len(attention_loader)
            wandb.log({"epoch_avg_loss": avg_loss})
            print(f"Epoch {epoch+1} - Avg Loss: {avg_loss:.6f}")

# Define sweep
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'epoch_avg_loss', 'goal': 'minimize'},
    'parameters': {
        'lr': {'values': [1e-5, 1e-4, 1e-3]},
        'dropout': {'values': [0.1, 0.2, 0.3]},
        'compression_dim': {'values': [4, 8, 16]},
        'agg_hidden_dim': {'values': [4, 8, 16]}
    }
}

sweep_id = wandb.sweep(sweep_config, project="adversarial-alignment")
wandb.agent(sweep_id, function=train, count=20)

Create sweep with ID: mfr95nkg
Sweep URL: https://wandb.ai/leonhard_waibl/adversarial-alignment/sweeps/mfr95nkg


wandb: Agent Starting Run: w2yku2ru with config:
wandb: 	agg_hidden_dim: 8
wandb: 	compression_dim: 4
wandb: 	dropout: 0.3
wandb: 	lr: 0.0001




Starting run with config: {'agg_hidden_dim': 8, 'compression_dim': 4, 'dropout': 0.3, 'lr': 0.0001}


Epoch 1/10: 100%|██████████| 4/4 [00:36<00:00,  9.24s/it]


Epoch 1 - Avg Loss: 53.772911


Epoch 2/10: 100%|██████████| 4/4 [00:38<00:00,  9.65s/it]


Epoch 2 - Avg Loss: 52.901051


Epoch 3/10: 100%|██████████| 4/4 [00:38<00:00,  9.58s/it]


Epoch 3 - Avg Loss: 54.353225


Epoch 4/10: 100%|██████████| 4/4 [00:37<00:00,  9.42s/it]


Epoch 4 - Avg Loss: 53.483410


Epoch 5/10: 100%|██████████| 4/4 [00:38<00:00,  9.68s/it]


Epoch 5 - Avg Loss: 53.484097


Epoch 6/10: 100%|██████████| 4/4 [00:36<00:00,  9.18s/it]


Epoch 6 - Avg Loss: 51.977777


Epoch 7/10: 100%|██████████| 4/4 [00:36<00:00,  9.20s/it]


Epoch 7 - Avg Loss: 53.398236


Epoch 8/10: 100%|██████████| 4/4 [00:37<00:00,  9.31s/it]


Epoch 8 - Avg Loss: 53.467681


Epoch 9/10: 100%|██████████| 4/4 [00:39<00:00,  9.97s/it]


Epoch 9 - Avg Loss: 53.453395


Epoch 10/10: 100%|██████████| 4/4 [00:39<00:00,  9.95s/it]


Epoch 10 - Avg Loss: 53.933036


0,1
epoch_avg_loss,▆▄█▅▅▁▅▅▅▇
grad_norm,▂▂▃▇▂▂▂▅▂▁▃▆▁▁▂█▂▁▂█▁▁▂▅▂▁▂▆▁▁▂▅▁▁▂▅▁▁▂▆
loss,▄▄▂▇▅▃▁▆▄▃▁█▄▃▁▇▄▃▁▇▄▃▁▆▄▃▁▇▄▃▁▇▄▃▁▇▄▃▁█
predictions_mean,▇▇▇▅▇█▆▆▆▆▄▇▅▆▄▇▅▆▄▂▅▄▄▅▅▄▅▄▅▃▃▅▃▄▃▄▄▄▁▄
predictions_std,▂▄▄▂▄▄▄▁▄▄▄█▄▄▃▁▃▄▄▄▃▄▄▃▄▄▅▆▄▅▅▂▄▄▃▆▄▄▃▆
targets_mean,▅▇▁█▅▇▁█▅▇▁█▅▇▁█▅▇▁█▅▇▁█▅▇▁█▅▇▁█▅▇▁█▅▇▁█
targets_std,▄▃▁█▄▃▁█▄▃▁█▄▃▁█▄▃▁█▄▃▁█▄▃▁█▄▃▁█▄▃▁█▄▃▁█

0,1
epoch_avg_loss,53.93304
grad_norm,60.34888
loss,72.99413
predictions_mean,0.18906
predictions_std,0.46809
targets_mean,-0.26792
targets_std,9.42233


wandb: Agent Starting Run: 89nwh9oj with config:
wandb: 	agg_hidden_dim: 4
wandb: 	compression_dim: 16
wandb: 	dropout: 0.1
wandb: 	lr: 0.001




Starting run with config: {'agg_hidden_dim': 4, 'compression_dim': 16, 'dropout': 0.1, 'lr': 0.001}


Epoch 1/10: 100%|██████████| 4/4 [00:40<00:00, 10.16s/it]


Epoch 1 - Avg Loss: 52.401592


Epoch 2/10: 100%|██████████| 4/4 [00:35<00:00,  8.99s/it]


Epoch 2 - Avg Loss: 51.680532


Epoch 3/10: 100%|██████████| 4/4 [00:39<00:00,  9.94s/it]


Epoch 3 - Avg Loss: 50.568548


Epoch 4/10: 100%|██████████| 4/4 [00:42<00:00, 10.59s/it]


Epoch 4 - Avg Loss: 52.925820


Epoch 5/10: 100%|██████████| 4/4 [00:41<00:00, 10.28s/it]


Epoch 5 - Avg Loss: 52.702439


Epoch 6/10: 100%|██████████| 4/4 [00:40<00:00, 10.16s/it]


Epoch 6 - Avg Loss: 52.555544


Epoch 7/10: 100%|██████████| 4/4 [00:40<00:00, 10.19s/it]


Epoch 7 - Avg Loss: 51.769683


Epoch 8/10: 100%|██████████| 4/4 [00:38<00:00,  9.69s/it]


Epoch 8 - Avg Loss: 51.795659


Epoch 9/10: 100%|██████████| 4/4 [00:39<00:00,  9.86s/it]


Epoch 9 - Avg Loss: 52.433807


Epoch 10/10: 100%|██████████| 4/4 [00:41<00:00, 10.38s/it]


Epoch 10 - Avg Loss: 51.500283


0,1
epoch_avg_loss,▆▄▁█▇▇▅▅▇▄
grad_norm,▃▂▄█▂▂▃▇▁▂▃▇▂▃▂▇▂▃▃▇▁▂▃▆▁▂▂▅▁▂▃▄▂▂▃▆▁▁▃▅
loss,▅▄▁█▄▄▁▇▅▄▁▆▅▄▁█▅▄▁█▅▄▁█▄▄▁▇▅▄▁▇▅▄▁█▄▄▁▇
predictions_mean,█▆▆▅▄▄▄▃▃▃▃▄▃▃▂▃▂▂▃▁▂▃▂▂▂▃▂▂▂▂▃▃▃▃▄▂▃▃▃▄
predictions_std,▃▃▂▄▂▃▂▂▂▃▃▆▃▃▂█▃▃▄▃▃▃▃▂▃▃▃▄▂▂▂▁▃▃▃▂▂▃▃▄
targets_mean,▅▇▁█▅▇▁█▅▇▁█▅▇▁█▅▇▁█▅▇▁█▅▇▁█▅▇▁█▅▇▁█▅▇▁█
targets_std,▄▃▁█▄▃▁█▄▃▁█▄▃▁█▄▃▁█▄▃▁█▄▃▁█▄▃▁█▄▃▁█▄▃▁█

0,1
epoch_avg_loss,51.50028
grad_norm,21.93045
loss,65.01209
predictions_mean,-0.73879
predictions_std,0.25359
targets_mean,-0.26792
targets_std,9.42233


wandb: Agent Starting Run: ciwkxekx with config:
wandb: 	agg_hidden_dim: 16
wandb: 	compression_dim: 16
wandb: 	dropout: 0.3
wandb: 	lr: 1e-05




Starting run with config: {'agg_hidden_dim': 16, 'compression_dim': 16, 'dropout': 0.3, 'lr': 1e-05}


Epoch 1/10: 100%|██████████| 4/4 [00:39<00:00,  9.89s/it]


Epoch 1 - Avg Loss: 52.564932


Epoch 2/10: 100%|██████████| 4/4 [00:37<00:00,  9.43s/it]


Epoch 2 - Avg Loss: 52.149525


Epoch 3/10: 100%|██████████| 4/4 [00:34<00:00,  8.69s/it]


Epoch 3 - Avg Loss: 52.673622


Epoch 4/10: 100%|██████████| 4/4 [00:34<00:00,  8.71s/it]


Epoch 4 - Avg Loss: 52.643154


Epoch 5/10: 100%|██████████| 4/4 [00:34<00:00,  8.53s/it]


Epoch 5 - Avg Loss: 53.251510


Epoch 6/10: 100%|██████████| 4/4 [00:34<00:00,  8.71s/it]


Epoch 6 - Avg Loss: 52.963539


Epoch 7/10: 100%|██████████| 4/4 [00:34<00:00,  8.74s/it]


Epoch 7 - Avg Loss: 51.921824


Epoch 8/10:  25%|██▌       | 1/4 [00:11<00:33, 11.20s/it]