In [None]:
import torch
import torch.nn as nn
import torch.optim as o
from torch.utils.data import DataLoader

from GeomLayers.SPDLayers import *
from GeomLayers.Toy import ToyNormCovarianceDataset
from GeomLayers.metaoptimizer import MetaOptimizer

In [None]:
# SPD-manifold feature extractor.
# SPDBiMap(in, out) presumably maps in x in SPD matrices to out x out (10 -> 10 -> 5 -> 3).
# SPDReEig(1e-2) presumably rectifies eigenvalues with a 1e-2 floor.
# SPDLogEig() then maps to a flat space.
# SPDVectorize(3) emits 6 features, which is what BatchNorm1d(6) below expects.
# NOTE(review): these semantics come from the GeomLayers names — confirm against that package.
spd_layers = [
    SPDBiMap(10, 10),
    SPDReEig(1e-2),
    SPDBiMap(10, 5),
    SPDReEig(1e-2),
    SPDBiMap(5, 3),
    SPDLogEig(),
    SPDVectorize(3),
]

# Ordinary Euclidean classification head: 6 features -> 128 hidden units -> 20 class log-probabilities.
# NOTE(review): assumes ToyNormCovarianceDataset produces 20 classes — confirm.
head_layers = [
    nn.BatchNorm1d(6),
    nn.Linear(6, 128),
    nn.ReLU(),
    nn.Linear(128, 20),
    nn.LogSoftmax(dim=-1),
]

# Unpacking both lists keeps the flat 0..11 module indexing of a single Sequential.
net = nn.Sequential(*spd_layers, *head_layers)

# MetaOptimizer wraps Adam for the network's parameters (project-specific wrapper).
optimizer = MetaOptimizer(net.parameters(), o.Adam)

# The head already emits log-probabilities, so NLLLoss (not CrossEntropyLoss) is the matching criterion.
loss_fn = nn.NLLLoss()

In [None]:
# Build the training set and a separately constructed (smaller) validation set.
# NOTE(review): whether the two instances draw independent samples depends on
# ToyNormCovarianceDataset's implementation — confirm it doesn't reuse a fixed seed.
BATCH_SIZE = 8

dataset = ToyNormCovarianceDataset()
val_dataset = ToyNormCovarianceDataset(n_item_per_class=20)

# Shuffle only the training stream. Validation order doesn't affect the metrics.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# The validation loader must wrap val_dataset; wrapping `dataset` would report training metrics.
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Train for N_EPOCHS epochs, reporting validation loss and accuracy after each one.
N_EPOCHS = 10

for epoch in range(N_EPOCHS):
    # ---- training pass ----
    # Train mode makes BatchNorm1d use batch statistics.
    net.train()
    for features, labels in dataloader:
        out = net(features)
        loss = loss_fn(out, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # ---- validation pass ----
    # Eval mode makes BatchNorm1d use its running statistics.
    net.eval()
    eval_labels = []
    eval_logits = []
    # Evaluation needs no gradients; skipping graph construction saves memory and time.
    with torch.no_grad():
        for features, labels in val_dataloader:
            out = net(features)
            eval_labels.append(labels)
            eval_logits.append(out)

    eval_labels = torch.cat(eval_labels, dim=0)
    eval_logits = torch.cat(eval_logits, dim=0)
    eval_preds = torch.argmax(eval_logits, dim=-1)
    # Score against the ground-truth labels, not the model's own predictions.
    eval_loss = loss_fn(eval_logits, eval_labels)
    # Vectorised accuracy in percent (the builtin sum() would iterate element-wise over the tensor).
    eval_acc = (eval_preds == eval_labels).float().mean() * 100
    print(
        f"epoch {epoch + 1}/{N_EPOCHS}: "
        f"val_loss={eval_loss.item():.4f} val_acc={eval_acc.item():.2f}%"
    )