In [14]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt

# Run on the GPU whenever PyTorch can see one; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Release unoccupied memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()
print(f"Using device: {device}")

Using device: cuda


In [16]:
def simulate_lorenz(initial_state, steps=500, dt=0.01, sigma=10.0, rho=28.0, beta=8.0/3.0):
    """Integrate the Lorenz-63 system with a fixed-step 4th-order Runge-Kutta scheme.

    Args:
        initial_state: starting point ``(x, y, z)``; any sequence/array of three numbers.
            It is copied to a float64 array, so integer input such as ``[1, 0, 0]``
            works and the caller's object is never modified.
        steps: number of states returned (the first one is ``initial_state`` itself).
        dt: integration step size.
        sigma, rho, beta: Lorenz system parameters.

    Returns:
        ``np.ndarray`` of shape ``(steps, 3)``; row ``k`` is the state after ``k`` RK4
        steps. ``steps <= 0`` yields an empty ``(0, 3)`` array.
    """
    def lorenz_system(state):
        x, y, z = state
        dxdt = sigma * (y - x)
        dydt = x * (rho - z) - y
        dzdt = x * y - beta * z
        return np.array([dxdt, dydt, dzdt])

    # Force float64: with an integer array the in-place RK4 update below would fail
    # because NumPy will not cast a float increment into an int array.
    state = np.array(initial_state, dtype=float)
    # Preallocate so the result always has a well-defined 2-D shape, even when empty.
    trajectory = np.empty((max(steps, 0), state.size))
    for i in range(steps):
        trajectory[i] = state
        k1 = lorenz_system(state)
        k2 = lorenz_system(state + 0.5 * dt * k1)
        k3 = lorenz_system(state + 0.5 * dt * k2)
        k4 = lorenz_system(state + dt * k3)
        state += (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return trajectory
        

In [17]:
def create_graph_lorenz(trajectory):
    """Build a PyG ``Data`` chain graph from a trajectory: one node per time step.

    Args:
        trajectory: array-like with one row per time step (``simulate_lorenz``
            returns shape ``(T, 3)``).

    Returns:
        ``Data`` with ``x`` = the trajectory as a float32 tensor and ``edge_index`` of
        shape ``(2, T - 1)`` holding the directed edges ``i -> i + 1``. For ``T <= 1``
        the edge index is an empty ``(2, 0)`` tensor rather than a malformed 1-D one.
    """
    # np.asarray first: torch.tensor on a list of ndarrays is slow; torch.tensor copies.
    x = torch.tensor(np.asarray(trajectory), dtype=torch.float)
    num_nodes = x.size(0)
    # Sources 0..T-2, targets 1..T-1 (consecutive time steps).
    src = torch.arange(max(num_nodes - 1, 0), dtype=torch.long)
    edge_index = torch.stack([src, src + 1]).contiguous()
    return Data(x=x, edge_index=edge_index)

In [20]:
class GNN(nn.Module):
    """Two-branch node model: a GCN over a dense graph plus a per-node MLP, averaged.

    The GCN branch ignores the incoming ``edge_index``. It connects every pair of
    nodes that belong to the same graph (see ``create_full_graph``). The MLP branch
    (``fc1 -> fc2 -> fc3``) maps each node independently.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNN, self).__init__()

        # Per-node MLP branch: input -> hidden -> hidden -> output.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

        # Graph branch.
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(output_dim)

    def create_full_graph(self, x, batch=None):
        """Return a dense ``(2, E)`` long edge index (self-loops included) on ``x.device``.

        Args:
            x: node tensor; only ``x.size(0)`` and ``x.device`` are used.
            batch: optional graph-membership vector (one graph id per node), as found
                on a PyG mini-batch. When given, only nodes sharing an id are connected.
                Independent graphs therefore do not exchange messages, and the edge
                count is sum(n_g^2) rather than N^2. When ``None``, all N nodes are
                connected, in row-major order (0->0, 0->1, ..., N-1->N-1).
        """
        device = x.device

        def dense(nodes):
            # All ordered pairs (a, b) of the given node ids, row-major.
            count = nodes.numel()
            return torch.stack([torch.repeat_interleave(nodes, count), nodes.repeat(count)])

        if batch is None:
            return dense(torch.arange(x.size(0), device=device))

        batch = batch.to(device)
        blocks = [
            dense(torch.nonzero(batch == graph_id, as_tuple=True)[0])
            for graph_id in torch.unique(batch)
        ]
        if not blocks:
            return torch.empty((2, 0), dtype=torch.long, device=device)
        return torch.cat(blocks, dim=1)

    def forward(self, data):
        x = data.x.float()
        # data.edge_index is deliberately not used: message passing runs on a dense
        # connectivity built per graph of the (mini-)batch.
        edge_index = self.create_full_graph(x, getattr(data, "batch", None))
        gnn_x = self.conv1(x, edge_index).relu()
        gnn_x = self.norm1(gnn_x)
        gnn_x = self.conv2(gnn_x, edge_index)
        gnn_x = self.norm2(gnn_x)

        fc_x = self.fc1(x).relu()
        fc_x = self.fc2(fc_x).relu()
        fc_x = self.fc3(fc_x)

        return (gnn_x + fc_x) / 2

In [22]:
class KoopmanModel(torch.nn.Module):
    """GNN autoencoder with a learned linear (Koopman) operator in latent space.

    ``forward`` encodes the node states, advances them one step by right-multiplying
    with ``koopman_matrix`` (initialised to the identity), and decodes the result.
    """

    def __init__(self, input_dim, koopman_dim):
        super(KoopmanModel, self).__init__()
        self.encoder = GNN(input_dim, koopman_dim, koopman_dim)
        self.koopman_matrix = torch.nn.Parameter(torch.eye(koopman_dim))
        self.decoder = GNN(koopman_dim, koopman_dim, input_dim)

    def forward(self, data):
        """Return the decoded one-step-ahead prediction for every node of ``data``."""
        koopman_space = self.encoder(data)
        next_koopman_space = koopman_space @ self.koopman_matrix
        new_data = Data(x=next_koopman_space, edge_index=data.edge_index)
        # Carry the mini-batch membership vector over, so a decoder that reads
        # data.batch sees the same grouping of nodes into graphs as the encoder.
        batch = getattr(data, "batch", None)
        if batch is not None:
            new_data.batch = batch
        return self.decoder(new_data)

In [24]:
def auto_encoding_loss(decoded, original_states):
    """Mean squared reconstruction error between decoder output and the original states."""
    reconstruction_error = F.mse_loss(input=decoded, target=original_states, reduction="mean")
    return reconstruction_error

def prediction_loss(model, koopman_space, data, max_steps=None):
    """Multi-step Koopman prediction loss.

    For each horizon s = 1..S, every latent state g_i is advanced s steps with the
    Koopman matrix (g_i @ K^s), decoded, and compared with the true state s steps
    later, x_{i+s}. In a batched input, pairs (i, i+s) whose nodes carry different
    ``data.batch`` ids (i.e. would cross into another graph) are excluded.

    Assumes nodes are ordered in time within each graph, so node i+1 is the state one
    step after node i, as in the chain graphs built from a trajectory. TODO confirm
    for any other data source.

    Args:
        model: object exposing ``decoder`` (called with a ``Data``) and a square
            ``koopman_matrix``.
        koopman_space: encoded states, one row per node.
        data: graph or mini-batch whose ``x`` holds the original states; an optional
            ``batch`` vector gives graph membership.
        max_steps: largest horizon S. ``None`` uses the longest available horizon
            (length of the longest graph minus one).

    Returns:
        Scalar tensor: the mean over horizons of the per-horizon MSE. It is zero when
        there is no (i, i+s) pair at all (e.g. single-node graphs).
    """
    targets = data.x.float()
    num_nodes = targets.size(0)
    batch = getattr(data, "batch", None)

    if batch is None:
        longest = num_nodes
    else:
        longest = int(torch.bincount(batch).max()) if num_nodes > 0 else 0
    horizon = max(longest - 1, 0)
    if max_steps is not None:
        horizon = min(horizon, max_steps)

    g_t = koopman_space
    step_losses = []
    for s in range(1, horizon + 1):
        g_t = g_t @ model.koopman_matrix
        step_data = Data(x=g_t, edge_index=data.edge_index)
        if batch is not None:
            step_data.batch = batch
        decoded = model.decoder(step_data)

        # Prediction made at node i for time i+s vs. the true state at node i+s.
        predicted, actual = decoded[:-s], targets[s:]
        if batch is not None:
            same_graph = batch[:-s] == batch[s:]
            if not bool(same_graph.any()):
                continue
            predicted, actual = predicted[same_graph], actual[same_graph]
        step_losses.append(F.mse_loss(predicted, actual))

    if not step_losses:
        return koopman_space.new_zeros(())
    return torch.stack(step_losses).mean()

def metric_loss(koopman_space, original_space):
    """Mean absolute mismatch between pairwise Euclidean distances in the two spaces.

    Entry (i, j) compares ||g_i - g_j|| (latent) with ||x_i - x_j|| (original) over
    all node pairs, penalising distortion of the original geometry.
    """
    def pairwise(points):
        return torch.cdist(points, points, p=2)

    return F.l1_loss(pairwise(koopman_space), pairwise(original_space))


def total_loss(model, data, lambda1=1.0, lambda2=1.0, verbose=True):
    """Weighted sum of reconstruction, multi-step prediction and metric losses.

    Args:
        model: Koopman model exposing ``encoder``, ``decoder`` and ``koopman_matrix``.
        data: graph or PyG mini-batch; its optional ``batch`` vector is forwarded to
            the decoder input.
        lambda1: weight of the prediction loss.
        lambda2: weight of the metric (distance-preservation) loss.
        verbose: print the three components for this call.

    Returns:
        Scalar tensor ``ae + lambda1 * pred + lambda2 * metric``.
    """
    koopman_space = model.encoder(data)
    decoder_input = Data(x=koopman_space, edge_index=data.edge_index)
    # Keep graph membership so a batch-aware decoder sees the same grouping as the encoder.
    batch = getattr(data, "batch", None)
    if batch is not None:
        decoder_input.batch = batch
    decoded = model.decoder(decoder_input)

    ae_loss = auto_encoding_loss(decoded, data.x)
    pred_loss = prediction_loss(model, koopman_space, data)
    m_loss = metric_loss(koopman_space, data.x)
    if verbose:
        # float() prints plain numbers instead of tensor reprs with device/grad_fn noise.
        print(
            f"AE Loss: {float(ae_loss):.6f}, "
            f"Prediction Loss: {float(pred_loss):.6f}, "
            f"Metric Loss: {float(m_loss):.6f}"
        )

    return ae_loss + lambda1 * pred_loss + lambda2 * m_loss

In [26]:
def train_model(model, dataset, epochs=10, lambda1=1.0, lambda2=1.0, batch_size=8, lr=0.001):
    """Train ``model`` with Adam on ``total_loss`` and plot the mean loss per epoch.

    Args:
        model: Koopman model; moved to the notebook-level ``device``.
        dataset: list of PyG ``Data`` graphs.
        epochs: number of passes over ``dataset``.
        lambda1, lambda2: weights of the prediction and metric terms in ``total_loss``.
        batch_size: graphs per mini-batch.
        lr: initial Adam learning rate; ``ReduceLROnPlateau`` (patience=5,
            factor=0.5) lowers it when the epoch loss stops improving.

    Returns:
        List with the mean loss of every epoch that had at least one finite batch.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    train_losses = []
    logged_epochs = []
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        batch_count = 0

        for batch in loader:
            batch = batch.to(device)
            # Reset per batch; otherwise gradients from every earlier batch of the
            # epoch would accumulate into each optimizer step.
            optimizer.zero_grad()

            loss = total_loss(model, batch, lambda1=lambda1, lambda2=lambda2)
            # Skip NaN and inf alike: either would corrupt the parameters.
            if not torch.isfinite(loss):
                continue
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_loss += loss.item()
            batch_count += 1

        if batch_count > 0:
            avg_loss = epoch_loss / batch_count
            train_losses.append(avg_loss)
            logged_epochs.append(epoch + 1)
            print(f"Epoch {epoch + 1}, Loss: {avg_loss:.6f}")
            scheduler.step(avg_loss)
        else:
            print(f"Warning: Epoch {epoch + 1} had no valid batches!")

    fig, ax = plt.subplots(figsize=(10, 6))
    # Plot against the epochs that actually produced a loss, so skipped epochs
    # cannot cause an x/y length mismatch.
    ax.plot(logged_epochs, train_losses)
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.set_title("Training Loss Over Time")
    ax.grid(True)
    plt.show()
    return train_losses


In [28]:
DATASET_SEED = 0   # seeds the jitter so Restart & Run All rebuilds the same dataset
NUM_GRAPHS = 100
INIT_JITTER = 1.0  # std of the Gaussian perturbation added to initial_state

initial_state = [1.0, 0.0, 0.0]
lorenz_trajectory = simulate_lorenz(initial_state)

# The reference trajectory plus NUM_GRAPHS - 1 trajectories from jittered initial
# states. The Lorenz system is chaotic, so the perturbed runs give distinct samples;
# identical copies of one graph would add compute without adding information.
rng = np.random.default_rng(DATASET_SEED)
dataset = [create_graph_lorenz(lorenz_trajectory)] + [
    create_graph_lorenz(
        simulate_lorenz(
            np.asarray(initial_state) + rng.normal(scale=INIT_JITTER, size=len(initial_state))
        )
    )
    for _ in range(NUM_GRAPHS - 1)
]

In [30]:
# 3-D Lorenz state in, 3-D Koopman latent space.
model = KoopmanModel(input_dim=3, koopman_dim=3)
model = model.to(device)

In [None]:
# Ten epochs with train_model's default loss weights and batch size.
train_model(model, dataset, epochs=10);