In [3]:
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import GCNConv, global_mean_pool

class GraphKoopmanEncoder(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim=64):
        super(GraphKoopmanEncoder, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.linear = nn.Linear(hidden_dim, embedding_dim)
        
    def forward(self, x, edge_index, batch):
        # Graph convolutions
        h = torch.relu(self.conv1(x, edge_index))
        h = torch.relu(self.conv2(h, edge_index))
        # Global pooling
        h = global_mean_pool(h, batch)
        # Project to embedding space
        return self.linear(h)

class GraphKoopmanDecoder(nn.Module):
    def __init__(self, embedding_dim, output_dim, hidden_dim=64):
        super(GraphKoopmanDecoder, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, g):
        return self.net(g)

class KoopmanGraphLoss(nn.Module):
    def __init__(self, lambda_pred=1.0, lambda_metric=1.0):
        super(KoopmanGraphLoss, self).__init__()
        self.lambda_pred = lambda_pred
        self.lambda_metric = lambda_metric
        
    def forward(self, model, x_seq, u_seq, edge_index, batch):
        """
        Calculate all three losses for the Koopman GNN model
        
        Parameters:
        - x_seq: Sequence of states [batch_size, seq_len, state_dim]
        - u_seq: Sequence of controls [batch_size, seq_len-1, control_dim]
        - edge_index: Graph connectivity
        - batch: Batch indices for graphs
        
        Returns:
        - total_loss: Combined loss value
        - loss_dict: Dictionary containing individual loss components
        """
        T = x_seq.size(1)
        batch_size = x_seq.size(0)
        
        # 1. Auto-encoding Loss (Lae)
        # Encode and decode each state
        embeddings = []
        reconstructions = []
        for t in range(T):
            g_t = model.encoder(x_seq[:, t], edge_index, batch)
            x_recon_t = model.decoder(g_t)
            embeddings.append(g_t)
            reconstructions.append(x_recon_t)
            
        embeddings = torch.stack(embeddings, dim=1)  # [batch_size, T, embedding_dim]
        reconstructions = torch.stack(reconstructions, dim=1)  # [batch_size, T, state_dim]
        
        Lae = torch.mean(torch.norm(reconstructions - x_seq, dim=-1))
        
        # 2. Prediction Loss (Lpred)
        # Rollout in Koopman space
        g_hat = embeddings[:, 0]  # Initial embedding
        predicted_states = [model.decoder(g_hat)]
        
        for t in range(T-1):
            # Apply Koopman operator: g_hat_next = K*g_hat + L*u
            g_hat = model.koopman(g_hat) + torch.mm(u_seq[:, t], model.control_matrix)
            predicted_states.append(model.decoder(g_hat))
            
        predicted_states = torch.stack(predicted_states, dim=1)
        Lpred = torch.mean(torch.norm(predicted_states - x_seq, dim=-1))
        
        # 3. Metric Loss (Lmetric)
        # Compute pairwise distances in both spaces
        def compute_pairwise_distances(x):
            n = x.size(0)
            square = torch.sum(x**2, dim=-1, keepdim=True)
            distances = square - 2 * torch.matmul(x, x.transpose(-2, -1)) + square.transpose(-2, -1)
            return torch.sqrt(torch.clamp(distances, min=1e-12))
        
        Lmetric = 0
        for t in range(T):
            # Distances in state space
            state_distances = compute_pairwise_distances(x_seq[:, t])
            # Distances in Koopman space
            koopman_distances = compute_pairwise_distances(embeddings[:, t])
            # Compute metric loss
            Lmetric += torch.mean(torch.abs(koopman_distances - state_distances))
        
        Lmetric = Lmetric / T
        
        # Combine losses
        total_loss = Lae + self.lambda_pred * Lpred + self.lambda_metric * Lmetric
        
        loss_dict = {
            'total_loss': total_loss.item(),
            'autoencoding_loss': Lae.item(),
            'prediction_loss': Lpred.item(),
            'metric_loss': Lmetric.item()
        }
        
        return total_loss, loss_dict

def train_graph_koopman_model(model, train_loader, epochs=100, device='cpu'):
    """Train the Graph Koopman model with all three losses"""
    criterion = KoopmanGraphLoss(lambda_pred=1.0, lambda_metric=1.0)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(epochs):
        model.train()
        epoch_losses = []
        
        for batch_idx, (x_seq, u_seq, edge_index, batch) in enumerate(train_loader):
            x_seq = x_seq.to(device)
            u_seq = u_seq.to(device)
            edge_index = edge_index.to(device)
            batch = batch.to(device)
            
            # Calculate all losses
            total_loss, loss_components = criterion(model, x_seq, u_seq, edge_index, batch)
            
            # Backward pass
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            
            epoch_losses.append(loss_components)
        
        # Print epoch statistics
        if (epoch + 1) % 10 == 0:
            avg_losses = {k: sum(d[k] for d in epoch_losses) / len(epoch_losses)
                         for k in epoch_losses[0].keys()}
            print(f'Epoch [{epoch+1}/{epochs}]')
            for loss_name, loss_value in avg_losses.items():
                print(f'{loss_name}: {loss_value:.4f}')
            print('-' * 50)
    
    return model


class LorenzSystem:
    def __init__(self, sigma=10.0, rho=28.0, beta=8/3):
        self.sigma = sigma
        self.rho = rho
        self.beta = beta
    
    def derivatives(self, state):
        """Calculate Lorenz system derivatives"""
        x, y, z = state[..., 0], state[..., 1], state[..., 2]
        dx = self.sigma * (y - x)
        dy = x * (self.rho - z) - y
        dz = x * y - self.beta * z
        return torch.stack([dx, dy, dz], dim=-1)
    
    def step(self, state, dt=0.01):
        """Perform one step of RK4 integration"""
        k1 = self.derivatives(state)
        k2 = self.derivatives(state + dt * k1/2)
        k3 = self.derivatives(state + dt * k2/2)
        k4 = self.derivatives(state + dt * k3)
        return state + dt * (k1 + 2*k2 + 2*k3 + k4) / 6

def generate_lorenz_data(batch_size, seq_length, dt=0.01, device='cpu'):
    """Generate batch of Lorenz trajectories"""
    lorenz = LorenzSystem()
    
    # Random initial conditions
    states = torch.randn(batch_size, 3).to(device) * 0.1
    
    # Generate trajectories
    trajectories = [states]
    for t in range(seq_length - 1):
        states = lorenz.step(states, dt)
        trajectories.append(states)
    
    trajectories = torch.stack(trajectories, dim=1)  # [batch_size, seq_length, 3]
    
    # Create dummy control inputs (zero for autonomous system)
    controls = torch.zeros(batch_size, seq_length-1, 1).to(device)
    
    return trajectories, controls

class LorenzGraphDataset:
    def __init__(self, num_samples, seq_length, dt=0.01, num_nodes=10):
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.dt = dt
        self.num_nodes = num_nodes
        
        # Create fixed graph structure (fully connected)
        self.edge_index = torch.combinations(torch.arange(num_nodes), r=2).t()
        # Add reverse edges
        self.edge_index = torch.cat([self.edge_index, self.edge_index.flip(0)], dim=1)
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        # Generate single Lorenz trajectory
        trajectory, controls = generate_lorenz_data(1, self.seq_length, self.dt)
        
        # Replicate the trajectory for each node (as node features)
        x_seq = trajectory[0].unsqueeze(1).repeat(1, self.num_nodes, 1)
        u_seq = controls[0]
        
        return x_seq, u_seq, self.edge_index

def create_lorenz_dataloader(dataset, batch_size):
    def collate_fn(batch):
        x_seqs, u_seqs, edge_indices = zip(*batch)
        
        # Stack sequences
        x_seqs = torch.stack(x_seqs)  # [batch_size, seq_len, num_nodes, 3]
        u_seqs = torch.stack(u_seqs)  # [batch_size, seq_len-1, control_dim]
        
        # Create batched graph
        batch_edge_index = []
        batch_indices = []
        offset = 0
        for i, edge_index in enumerate(edge_indices):
            batch_edge_index.append(edge_index + offset)
            batch_indices.extend([i] * dataset.num_nodes)
            offset += dataset.num_nodes
            
        batch_edge_index = torch.cat(batch_edge_index, dim=1)
        batch_indices = torch.tensor(batch_indices)
        
        return x_seqs, u_seqs, batch_edge_index, batch_indices
    
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=True
    )

class GraphKoopmanModel(nn.Module):
    def __init__(self, input_dim, embedding_dim, control_dim=1):
        super().__init__()
        self.encoder = GraphKoopmanEncoder(input_dim, embedding_dim)
        self.decoder = GraphKoopmanDecoder(embedding_dim, input_dim)
        self.koopman = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.control_matrix = nn.Parameter(torch.randn(control_dim, embedding_dim))
        
def setup_lorenz_koopman(num_nodes=10, batch_size=32):
    input_dim = 3
    embedding_dim = 16
    model = GraphKoopmanModel(input_dim, embedding_dim)
    
    dataset = LorenzGraphDataset(
        num_samples=1000,
        seq_length=50,
        num_nodes=num_nodes
    )
    train_loader = create_lorenz_dataloader(dataset, batch_size=batch_size)
    
    return model, train_loader

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, train_loader = setup_lorenz_koopman()
    model = model.to(device)
    trained_model = train_graph_koopman_model(model, train_loader, device=device)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
