In [2]:
!pip -q install geoopt

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/90.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.1/90.1 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import geoopt

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fix the global torch / NumPy RNG seeds so runs are reproducible.
torch.manual_seed(42)
np.random.seed(42)


def gen_lorenz(N=5000, dt=0.01, x0=(1.0, 1.0, 1.0), sigma=10.0, rho=28.0, beta=8 / 3):
    """Integrate the Lorenz system with the classic 4th-order Runge-Kutta scheme.

    Args:
        N: number of samples to return; sample 0 is ``x0`` itself.
        dt: integration step size.
        x0: initial state ``(x, y, z)``.
        sigma, rho, beta: Lorenz system parameters.

    Returns:
        float32 array of shape ``(N, 3)`` whose row ``i`` is the state reached
        after ``i`` RK4 steps from ``x0``.
    """

    # The vector field does not depend on the step index, so build it once
    # instead of re-creating the closure on every iteration.
    def f(v):
        return np.array(
            [
                sigma * (v[1] - v[0]),
                v[0] * (rho - v[2]) - v[1],
                v[0] * v[1] - beta * v[2],
            ],
            dtype=np.float32,
        )

    xs = np.zeros((N, 3), dtype=np.float32)
    x = np.array(x0, dtype=np.float32)
    for i in range(N):
        xs[i] = x
        k1 = f(x)
        k2 = f(x + 0.5 * dt * k1)
        k3 = f(x + 0.5 * dt * k2)
        k4 = f(x + dt * k3)
        x = x + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6.0
    return xs


def _delay_embed(series, start, stop, n_delays, delay):
    """Return the ``(stop - start, n_delays)`` float32 delay embedding of ``series``.

    Row ``r`` is ``[s[t], s[t - delay], ..., s[t - (n_delays - 1) * delay]]``
    with ``t = start + r``; requires ``start >= (n_delays - 1) * delay``.
    """
    cols = [series[start - j * delay : stop - j * delay] for j in range(n_delays)]
    return np.stack(cols, axis=1).astype(np.float32, copy=False)


def _sliding_covariances(E, half_width, eps):
    """Regularised covariance of a centred sliding window over the rows of ``E``.

    For row ``i`` the window is ``E[max(0, i - half_width) : min(n, i + half_width + 1)]``,
    so windows are truncated at both ends.  Each entry is the biased (1/len)
    covariance of its window plus ``eps * I``.

    Returns:
        float32 array of shape ``(n, d, d)`` for ``E`` of shape ``(n, d)``.
    """
    n, d = E.shape
    out = np.zeros((n, d, d), dtype=np.float32)
    shift = eps * np.eye(d, dtype=np.float32)
    for i in range(n):
        W = E[max(0, i - half_width) : min(n, i + half_width + 1)]
        Wc = W - W.mean(axis=0, keepdims=True)
        out[i] = (Wc.T @ Wc) / Wc.shape[0] + shift
    return out


data = gen_lorenz(N=5000)
obs_dim = data.shape[1]

# Delay-embedding and sliding-window parameters.
tau = 10  # delay (in samples) between embedding coordinates
m = 8  # number of delayed copies per embedding vector
embed_dim = m
L = 21  # sliding-window length for the local covariances
halfL = L // 2
eps_reg = 1e-3  # diagonal shift keeping every covariance positive definite

# Range of time indices [start_idx, end_idx) with a complete delay vector.
start_idx = (m - 1) * tau + halfL
end_idx = data.shape[0] - halfL
T = end_idx - start_idx

# Delay embeddings of the x (column 0) and y (column 1) coordinates.
X_embed = _delay_embed(data[:, 0], start_idx, end_idx, m, tau)
Y_embed = _delay_embed(data[:, 1], start_idx, end_idx, m, tau)

# One regularised local covariance per embedding row, for each coordinate.
Cs1 = _sliding_covariances(X_embed, halfL, eps_reg)
Cs2 = _sliding_covariances(Y_embed, halfL, eps_reg)


class SPDPairDataset(Dataset):
    """Paired stacks of SPD matrices served as ``(Cs1[i], Cs2[i])`` tuples.

    Both stacks are converted to float32 tensors and shifted by ``eps * I``.
    If the inputs were already regularised, the two shifts add up.

    Args:
        Cs1, Cs2: array-likes (NumPy arrays or tensors) of shape ``(T, d, d)``
            with the same length ``T``.  They are never modified in place.
        eps: diagonal shift added to every matrix.

    Raises:
        ValueError: if ``Cs1`` and ``Cs2`` have different lengths.
    """

    def __init__(self, Cs1, Cs2, eps=1e-3):
        Cs1 = torch.as_tensor(Cs1, dtype=torch.float32)
        Cs2 = torch.as_tensor(Cs2, dtype=torch.float32)
        if Cs1.shape[0] != Cs2.shape[0]:
            raise ValueError(
                f"Cs1 and Cs2 must have the same length, "
                f"got {Cs1.shape[0]} and {Cs2.shape[0]}"
            )

        # Honour the caller's eps (previously ignored in favour of 1e-3).
        self.eps = eps
        diag_eps = (
            torch.eye(Cs1.shape[-1], dtype=Cs1.dtype, device=Cs1.device) * self.eps
        )
        # Out-of-place add: as_tensor/from_numpy may share memory with a
        # float32 NumPy input, so an in-place += would mutate the caller's array.
        self.Cs1 = Cs1 + diag_eps
        self.Cs2 = Cs2 + diag_eps

    def __len__(self):
        return self.Cs1.shape[0]

    def __getitem__(self, idx):
        return self.Cs1[idx], self.Cs2[idx]


dataset = SPDPairDataset(Cs1, Cs2)
loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)


class BiMap(nn.Module):
    """Bilinear mapping ``X -> W^T X W`` with ``W`` constrained to the Stiefel manifold.

    Args:
        in_dim: size of the input square matrices.
        out_dim: size of the output matrices; must not exceed ``in_dim`` so
            that ``W`` (``in_dim x out_dim``) can have orthonormal columns.

    Raises:
        ValueError: if ``out_dim > in_dim``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        if out_dim > in_dim:
            raise ValueError("For Stiefel manifold, out_dim <= in_dim")
        # Reduced QR of a Gaussian matrix gives a random point with orthonormal
        # columns.  It must not be rescaled: c * Q is off the Stiefel manifold
        # (W^T W = c^2 I) and would shrink every output eigenvalue by c^2.
        Q, _ = torch.linalg.qr(torch.randn(in_dim, out_dim))
        self.W = geoopt.ManifoldParameter(Q, manifold=geoopt.Stiefel())

    def forward(self, X):
        """Return ``W^T X W``; leading batch dimensions of ``X`` broadcast."""
        return self.W.T @ X @ self.W


class ReEig(nn.Module):
    """Spectral rectification: raise every eigenvalue below ``eps`` up to ``eps``.

    For symmetric ``X = V diag(lam) V^T`` the output is
    ``V diag(max(lam, eps)) V^T``; leading batch dimensions are supported.

    NOTE(review): per the ``torch.linalg.eigh`` docs, gradients through the
    eigenvectors are only finite for distinct eigenvalues, so inputs with
    repeated eigenvalues may yield non-finite gradients.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps  # lower bound applied to the eigenvalues

    def forward(self, X):
        lam, V = torch.linalg.eigh(X)
        lam_rectified = lam.clamp(min=self.eps)
        V_t = V.transpose(-2, -1)
        return V @ torch.diag_embed(lam_rectified) @ V_t


class LogEig(nn.Module):
    """Matrix logarithm of symmetric matrices via eigendecomposition.

    Returns ``V diag(log lam) V^T`` for ``X = V diag(lam) V^T``.  Eigenvalues
    are not clamped, so non-positive eigenvalues give ``-inf``/``nan`` entries.
    Leading batch dimensions are supported.
    """

    def forward(self, X):
        spectrum, basis = torch.linalg.eigh(X)
        log_spectrum = spectrum.log()
        basis_t = basis.transpose(-2, -1)
        return basis @ torch.diag_embed(log_spectrum) @ basis_t


class ExpEig(nn.Module):
    """Matrix exponential of symmetric matrices via eigendecomposition.

    Returns ``V diag(exp lam) V^T`` for ``X = V diag(lam) V^T``; leading batch
    dimensions are supported.
    """

    def forward(self, X):
        spectrum, basis = torch.linalg.eigh(X)
        exp_spectrum = spectrum.exp()
        basis_t = basis.transpose(-2, -1)
        return basis @ torch.diag_embed(exp_spectrum) @ basis_t


class SPDEncoder(nn.Module):
    """Encode SPD matrices into vectors of a log-domain (tangent) representation.

    Pipeline: ``BiMap`` (``input_dim -> k``) -> ``ReEig`` -> ``LogEig``.  The
    upper triangle (diagonal included, row-major) of the resulting symmetric
    ``k x k`` matrix is then flattened into ``vech_dim = k (k + 1) / 2`` values.

    Args:
        input_dim: size of the input square matrices.
        k: size of the reduced matrices (``BiMap`` requires ``k <= input_dim``).
    """

    def __init__(self, input_dim, k=6):
        super().__init__()
        self.bimap = BiMap(input_dim, k)
        self.reeig = ReEig()
        self.logeig = LogEig()
        self.k = k
        self.vech_dim = k * (k + 1) // 2
        # The upper-triangle indices depend only on k: build them once as a
        # non-persistent buffer so .to(device) moves them with the module and
        # state_dict is unchanged.
        self.register_buffer("triu_idx", torch.triu_indices(k, k), persistent=False)

    def forward(self, X):
        """Map ``(..., input_dim, input_dim)`` matrices to ``(..., vech_dim)`` vectors."""
        X = self.bimap(X)
        X = self.reeig(X)
        X = self.logeig(X)
        return X[..., self.triu_idx[0], self.triu_idx[1]]


class SPDDecoder(nn.Module):
    """Decode half-vectorised log-domain matrices back to SPD matrices.

    Pipeline: unpack ``vech`` into a symmetric ``k x k`` matrix -> ``ExpEig``
    -> ``W X W^T`` with ``W`` (``output_dim x k``) on the Stiefel manifold ->
    ``ReEig``.  ``W X W^T`` has rank at most ``k``, so for ``output_dim > k``
    at least ``output_dim - k`` eigenvalues are (numerically) zero and are
    lifted to ``ReEig``'s eps.

    NOTE(review): those lifted eigenvalues are all equal; per the
    ``torch.linalg.eigh`` docs, eigenvector gradients may then be non-finite.

    Args:
        k: size of the latent symmetric matrices.
        output_dim: size of the reconstructed matrices (``k <= output_dim``).
    """

    def __init__(self, k, output_dim):
        super().__init__()
        self.k = k
        self.output_dim = output_dim
        # Random point with orthonormal columns; not rescaled, so it lies on
        # the Stiefel manifold (a scaled Q would have W^T W = c^2 I).
        Q, _ = torch.linalg.qr(torch.randn(output_dim, k))
        self.W = geoopt.ManifoldParameter(Q, manifold=geoopt.Stiefel())
        self.reeig = ReEig()
        self.exp_eig = ExpEig()
        # Built once; a non-persistent buffer follows .to(device) without
        # changing state_dict.
        self.register_buffer("triu_idx", torch.triu_indices(k, k), persistent=False)

    def forward(self, vech):
        """
        Args:
            vech: ``[batch, k*(k+1)/2]`` - upper-triangular part (row-major,
                diagonal included) of a symmetric matrix after LogEig.

        Returns:
            ``[batch, output_dim, output_dim]`` reconstructed matrices.
        """
        batch_size = vech.shape[0]
        # Match vech's dtype and device for the scratch matrix.
        X = vech.new_zeros((batch_size, self.k, self.k))
        X[:, self.triu_idx[0], self.triu_idx[1]] = vech
        # Mirror the upper triangle; the diagonal is counted twice, so subtract it once.
        X = (
            X
            + X.transpose(-2, -1)
            - torch.diag_embed(torch.diagonal(X, dim1=-2, dim2=-1))
        )
        X = self.exp_eig(X)
        X = self.W @ X @ self.W.T
        X = self.reeig(X)
        return X


# Latent matrix size shared by both encoders and decoders.  The link MLP acts
# on the half-vectorised k x k log-matrices, i.e. k (k + 1) / 2 features
# (21 for k = 6), so derive that width instead of hard-coding it.
latent_k = 6
latent_dim = latent_k * (latent_k + 1) // 2

mlp_link = nn.Sequential(
    nn.Linear(latent_dim, latent_dim), nn.ReLU(), nn.Linear(latent_dim, latent_dim)
).to(device)

encoder1 = SPDEncoder(Cs1.shape[1], k=latent_k).to(device)
encoder2 = SPDEncoder(Cs2.shape[1], k=latent_k).to(device)
decoder1 = SPDDecoder(k=latent_k, output_dim=Cs1.shape[1]).to(device)
decoder2 = SPDDecoder(k=latent_k, output_dim=Cs2.shape[1]).to(device)

# RiemannianAdam updates the Stiefel ManifoldParameters with manifold-aware
# steps; the plain Euclidean parameters of mlp_link are optimised in the same group.
optimizer = geoopt.optim.RiemannianAdam(
    list(encoder1.parameters())
    + list(encoder2.parameters())
    + list(decoder1.parameters())
    + list(decoder2.parameters())
    + list(mlp_link.parameters()),
    lr=1e-3,
)


epochs = 100
mse_loss = nn.MSELoss()
logeig = LogEig()

for epoch in range(epochs):
    total_loss = 0.0
    # The loader uses drop_last=True, so count the samples actually used
    # rather than dividing by len(dataset).
    n_seen = 0
    for C1_batch, C2_batch in loader:
        C1_batch = C1_batch.to(device)
        C2_batch = C2_batch.to(device)
        optimizer.zero_grad()

        z1 = encoder1(C1_batch)
        z2 = encoder2(C2_batch)

        # Latent translation from view 1 to view 2.  z2 is not detached, so
        # the latent loss also pushes gradients into encoder2.
        z1_to_z2 = mlp_link(z1)

        X1_recon = decoder1(z1)
        X2_recon = decoder2(z2)

        # Reconstruction is compared in the log (matrix-logarithm) domain.
        X1_recon_log = logeig(X1_recon)
        X2_recon_log = logeig(X2_recon)
        with torch.no_grad():  # targets depend on no trainable parameters
            C1_batch_log = logeig(C1_batch)
            C2_batch_log = logeig(C2_batch)

        loss_recon = mse_loss(X1_recon_log, C1_batch_log) + mse_loss(
            X2_recon_log, C2_batch_log
        )
        loss_latent = mse_loss(z1_to_z2, z2)
        loss = loss_recon + loss_latent

        loss.backward()
        optimizer.step()

        n_batch = C1_batch.size(0)
        total_loss += loss.item() * n_batch
        n_seen += n_batch

    # max() guards against an empty loader (fewer samples than one batch).
    avg_loss = total_loss / max(n_seen, 1)
    print(f"Epoch {epoch+1}/{epochs}, Loss={avg_loss:.6e}")