In [None]:
class SpacetimeEmbedder(nn.Module):
    """Invertible embedding of spacetime points built on an ``INNPhi`` network.

    Wraps an ``INNPhi`` invertible network and exposes the metric obtained by
    pulling a flat Minkowski metric back through it.

    Args:
        input_dim: Number of coordinates per point; also the size of the
            Minkowski metric used by ``pullback_metric``.
        hidden_dim: Hidden width passed to ``INNPhi``.
        num_coupling_layers: Number of coupling layers passed to ``INNPhi``.
    """

    def __init__(self, input_dim=4, hidden_dim=128, num_coupling_layers=6):
        super().__init__()
        # Kept so pullback_metric can size the Minkowski metric to match.
        self.input_dim = input_dim
        self.inn_phi = INNPhi(input_dim, hidden_dim, num_coupling_layers)

    def forward(self, x, reverse=False):
        """Apply the invertible map, or its inverse when ``reverse`` is True."""
        return self.inn_phi(x, reverse=reverse)

    def pullback_metric(self, x):
        """Return the pull-back of the Minkowski metric through ``inn_phi`` at ``x``.

        The target metric is ``diag(-1, 1, ..., 1)`` with ``input_dim`` entries
        (time first). It is created on ``x``'s device and dtype, so a CUDA or
        float64 ``x`` is not mixed with a CPU float32 metric. For the default
        ``input_dim=4`` this is the same matrix as ``diag(-1, 1, 1, 1)``.
        """
        signature = torch.ones(self.input_dim, dtype=x.dtype, device=x.device)
        signature[0] = -1.0  # time-like axis carries the negative sign
        eta_E = torch.diag(signature)
        # Resolves to the module-level ``pullback_metric`` function defined
        # elsewhere in the file, not to this method.
        return pullback_metric(x, self.inn_phi, eta_E=eta_E)


In [None]:
class GeodesicPolicy(nn.Module):
    """Two-stage policy: propose a subgoal, then an action toward it.

    Both stages are ``Linear -> ReLU -> Linear`` networks whose output width is
    ``state_dim``.

    Args:
        state_dim: Width of the state vectors (last axis).
        goal_dim: Width of the goal vectors (last axis).
        hidden_dim: Hidden width of both networks.
    """

    @staticmethod
    def _two_layer_mlp(in_features, hidden_features, out_features):
        """Build ``Linear -> ReLU -> Linear`` as an ``nn.Sequential``."""
        layers = [
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        ]
        return nn.Sequential(*layers)

    def __init__(self, state_dim, goal_dim, hidden_dim=256):
        super().__init__()
        # Creation order (subgoal first, then action) fixes parameter
        # initialisation order under a given RNG seed.
        self.subgoal_net = self._two_layer_mlp(state_dim + goal_dim, hidden_dim, state_dim)
        self.action_net = self._two_layer_mlp(2 * state_dim, hidden_dim, state_dim)

    def forward(self, state, goal):
        """Return ``(action, subgoal)`` for the given state and goal.

        The subgoal is predicted from ``[state, goal]``. The action is then
        predicted from ``[state, subgoal]``. Concatenation is along the last axis.
        """
        state_and_goal = torch.cat([state, goal], dim=-1)
        subgoal = self.subgoal_net(state_and_goal)

        state_and_subgoal = torch.cat([state, subgoal], dim=-1)
        action = self.action_net(state_and_subgoal)

        return action, subgoal


In [None]:
class AttentionODE(nn.Module):
    """Time-conditioned vector field ``f(t, a)`` for use with an ODE solver.

    The current time is appended to each state as one extra feature and fed
    through ``Linear -> ReLU -> Linear``. Despite the name, no attention
    mechanism is involved.

    Args:
        dim: Width of the state vectors (last axis).
        hidden_dim: Hidden width of the network.
    """

    def __init__(self, dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, t, a):
        """Evaluate the vector field at time ``t`` and state(s) ``a``.

        Args:
            t: Scalar time, given as a 0-d/1-element tensor or a Python number.
            a: States whose last axis has size ``dim``. Any leading batch shape
                is accepted, including none (a single state).

        Returns:
            Tensor with the same shape as ``a``.
        """
        # Match a's dtype/device so a float32 time can be combined with
        # float64 states, and a Python float is accepted.
        t = torch.as_tensor(t, dtype=a.dtype, device=a.device)
        # Broadcast t to one time feature per state, for any leading shape.
        t_col = t.expand(*a.shape[:-1], 1)
        return self.net(torch.cat([a, t_col], dim=-1))


In [None]:
def geodesic_loss(child, parent, margin=0.001, epsilon=1e-5):
    """Penalise non-causal steps and interval mismatch between paired points.

    Column 0 of each row is treated as the time coordinate and the remaining
    columns as spatial coordinates. The interval uses signature (-, +, +, ...).

    Args:
        child: 2-D tensor of end points, one point per row.
        parent: 2-D tensor of start points, same shape as ``child``.
        margin: Minimum time advance from parent to child. Steps with a smaller
            advance are penalised quadratically.
        epsilon: Offset added to the squared interval before squaring. The
            penalty is zero when ``interval_sq == -epsilon``.

    Returns:
        ``(causal_loss, interval_loss)``: two scalar tensors, each averaged
        over the rows.
    """
    delta = child - parent
    dt = delta[:, 0]
    spatial_sq = delta[:, 1:].pow(2).sum(dim=1)
    interval_sq = spatial_sq - dt.pow(2)

    # Zero whenever the child is at least `margin` later than the parent.
    causal_penalty = (margin - dt).clamp(min=0).pow(2)
    interval_penalty = (interval_sq + epsilon).pow(2)

    return causal_penalty.mean(), interval_penalty.mean()


In [None]:
class GeodesicRLTrainer:
    """Single-step trainer for the embedder / policy / latent-ODE stack.

    Args:
        embedder: Module applied to both ``batch["state"]`` and ``batch["goal"]``.
        policy: Module called as ``policy(e_state, e_goal)`` that returns
            ``(action, subgoal)``.
        ode_func: Vector field ``f(t, a)`` integrated with ``odeint``.
        optimizer: Optimizer stepped after each backward pass (presumably
            covering the parameters of all three modules; confirm at the
            construction site).
        gamma: Stored on the instance. Not read by ``train_step``.
    """

    def __init__(self, embedder, policy, ode_func, optimizer, gamma=0.99):
        self.embedder = embedder
        self.policy = policy
        self.ode_func = ode_func
        self.optimizer = optimizer
        self.gamma = gamma

    def train_step(self, batch):
        """Run one forward/backward/optimizer step on ``batch``.

        Args:
            batch: Mapping with ``"state"`` and ``"goal"`` tensors.

        Returns:
            dict of Python floats with keys ``total_loss``, ``causal_loss``,
            ``interval_loss`` and ``imitation_loss``.
        """
        state, goal = batch["state"], batch["goal"]
        e_state = self.embedder(state)
        e_goal = self.embedder(goal)

        # Policy prediction.
        # NOTE(review): `subgoal` is not used by any loss term below.
        action, subgoal = self.policy(e_state, e_goal)

        # Integrate the embedded state from t=0 to t=1 and keep the final value.
        # The time grid is built with e_state's dtype/device so a float64 or
        # GPU state is not paired with a CPU float32 time tensor.
        t_span = torch.tensor([0.0, 1.0], dtype=e_state.dtype, device=e_state.device)
        evolved = odeint(self.ode_func, e_state, t_span)[-1]

        # Losses
        causal_loss, interval_loss = geodesic_loss(evolved, e_goal)
        # `action` is intentionally NOT detached. With `.detach()` no loss
        # term reached the policy, so its parameters never updated. The gradient
        # with respect to `evolved` is the same either way.
        imitation_loss = F.mse_loss(evolved, action)
        total_loss = imitation_loss + causal_loss + interval_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        return {
            "total_loss": total_loss.item(),
            "causal_loss": causal_loss.item(),
            "interval_loss": interval_loss.item(),
            "imitation_loss": imitation_loss.item()
        }
