In [None]:
import torch.nn as nn
import torch

from tqdm.notebook import tqdm

from typing import List

In [None]:
# Number of training epochs (presumably passed to train_autoencoder's `num_epochs`).
num_epochs: int = 500

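### Landscape Distance Loss
Penalize the mismatch between pairwise distances of the embeddings $q(x)$ and pairwise distances of the points $(x, L(x))$ on the loss landscape:
$$\mathcal{L}_{\text{Geometric}} = \sum_{i,j} \Big( \lVert q(x_i)-q(x_j) \rVert_2 - \lVert (x_i, L(x_i)) - (x_j, L(x_j)) \rVert_2 \Big)^2$$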

In [None]:
def landscape_dist_loss(Q, X, L):
    """
    Geometric loss: squared mismatch between pairwise embedding distances and
    pairwise distances in the joint (parameters, loss) space, summed over all
    ordered pairs (i, j).

    Q: Embeddings, one row per sample (assumed shape (N, d_q) -- TODO confirm).
    X: Parameters, one row per sample (assumed shape (N, d_x) -- TODO confirm).
    L: Loss values, shape (N,) or (N, d_l). A 1-D tensor is treated as a
       single column, so one scalar loss per sample can be passed directly.

    Returns a 0-d tensor. Every unordered pair is counted twice ((i, j) and
    (j, i)) and the diagonal contributes zero, matching sum_{i,j}.
    """
    # torch.cat along dim=1 needs 2-D inputs; a batch of scalar losses
    # (e.g. one loss per sample from a DataLoader) has shape (N,), so lift it to (N, 1).
    if L.dim() == 1:
        L = L.unsqueeze(1)

    # Z = (X, L): each sample's position on the loss landscape.
    Z = torch.cat([X, L], dim=1)

    # Pairwise Euclidean distances in embedding space and in landscape space.
    Dq = torch.cdist(Q, Q, p=2)
    Dz = torch.cdist(Z, Z, p=2)

    # Sum of squared distance discrepancies over all ordered pairs.
    return (Dq - Dz).pow(2).sum()

In [None]:
def train_autoencoder(model, loader, optimizer, num_epochs, loss_type, lambda_geometric, device):
    """
    Train an autoencoder with a reconstruction loss plus a weighted geometric term.

    Each step minimizes
        MSE(model.decode(model.encode(params)), params) + lambda_geometric * geometric
    where the geometric term is selected by ``loss_type``.

    model: object exposing ``train()``, ``encode(params)`` and ``decode(embeddings)``.
    loader: iterable yielding ``(params, loss_vals)`` batches of tensors.
    optimizer: optimizer over the model's parameters.
    num_epochs: number of passes over ``loader``.
    loss_type: "landscape_dist" (implemented) or "smooth_preserving" (not implemented yet).
    lambda_geometric: weight applied to the geometric term.
    device: device each batch is moved to.

    After every epoch, prints the per-batch mean of the total, reconstruction
    and geometric losses for that epoch.

    Raises NotImplementedError for "smooth_preserving" and ValueError for any
    other unknown ``loss_type``. Both are checked before any training happens.
    """
    # Dispatch table of implemented geometric losses; validated once, not per batch.
    geometric_losses = {"landscape_dist": landscape_dist_loss}
    if loss_type == "smooth_preserving":
        # Reserved for a future loss: fail loudly rather than train without a geometric term.
        raise NotImplementedError("loss_type 'smooth_preserving' is not implemented yet")
    if loss_type not in geometric_losses:
        raise ValueError(
            f"Unknown loss_type {loss_type!r}; expected one of "
            f"{sorted(geometric_losses) + ['smooth_preserving']}"
        )
    geometric_fn = geometric_losses[loss_type]

    loss_fn = nn.MSELoss()

    for epoch in tqdm(range(num_epochs)):
        model.train()

        # Per-epoch accumulators. They use names distinct from the per-batch
        # tensors below so those tensors cannot overwrite them.
        epoch_train, epoch_recon, epoch_geometric = 0.0, 0.0, 0.0
        n_batches = 0

        for params, loss_vals in loader:
            params, loss_vals = params.to(device), loss_vals.to(device)

            optimizer.zero_grad()

            embeddings = model.encode(params)
            reconstructions = model.decode(embeddings)

            recon_loss = loss_fn(reconstructions, params)
            geometric_loss = geometric_fn(embeddings, params, loss_vals)
            loss = recon_loss + geometric_loss * lambda_geometric

            loss.backward()
            optimizer.step()

            # .item() detaches to Python floats so the autograd graph is not retained.
            epoch_train += loss.item()
            epoch_recon += recon_loss.item()
            epoch_geometric += geometric_loss.item()
            n_batches += 1

        # max(..., 1) avoids ZeroDivisionError when the loader yields no batches.
        denom = max(n_batches, 1)
        mean_train = epoch_train / denom
        mean_recon = epoch_recon / denom
        mean_geometric = epoch_geometric / denom

        print(f"[Epoch {epoch}]\tTrain loss: {mean_train}\tRecon Loss: {mean_recon}\tGeometric Loss: {mean_geometric}")


