# Various Loss Functions

> Contains the custom loss functions and regularizers used during optimization.

In [None]:
#| default_exp losses.custom_loss

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
#| export
def _reward_loss(reward_dist, rewards):
    reward_loss = -torch.mean(reward_dist.log_prob(rewards))
    return reward_loss

In [None]:
#| export
def _vae_loss(recon_x, x, mu, logsigma):
    """ VAE loss function """
    BCE = F.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + 2 * logsigma - mu.pow(2) - (2 * logsigma).exp())
    return BCE + KLD

In [None]:
#| export
def _l1_loss(z, h, loss_exp=1.0):
    return torch.mean(torch.abs(z - h) ** loss_exp) / loss_exp

In [None]:
#| export
import torch
class SIGReg(torch.nn.Module):
    """Penalize embeddings whose random 1-D projections are not standard normal.

    On each call, `num_slices` random unit directions are drawn. The input
    samples are projected onto each direction. The empirical characteristic
    function of each projection, ECF(t) = mean(cos(t x)) + i * mean(sin(t x)),
    is then compared with the N(0, 1) characteristic function exp(-t^2 / 2).
    The squared modulus |ECF(t) - exp(-t^2/2)|^2 is integrated over t with a
    Gaussian weight and multiplied by the sample count. This is presumably an
    Epps-Pulley-style normality statistic; confirm against the reference
    method if exact parity matters.

    Args:
        knots: number of quadrature points on t in [0, 3]. Must be >= 2.
        num_slices: number of random projection directions drawn per
            forward call. The default of 256 matches the previous hard-coded value.
    """

    def __init__(self, knots=17, num_slices=256):
        super().__init__()
        if knots < 2:
            raise ValueError(f"knots must be >= 2, got {knots}")
        self.num_slices = num_slices
        t = torch.linspace(0, 3, knots, dtype=torch.float32)
        dt = 3 / (knots - 1)
        # Trapezoid weights on [0, 3] (interior dt, endpoints dt/2), doubled.
        # The integrand is even in t, so this integrates over [-3, 3].
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        # exp(-t^2/2) is the N(0, 1) characteristic function. It is also the
        # Gaussian weighting applied to the integrand.
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        self.register_buffer("weights", weights * window)

    def forward(self, proj):
        """Compute the regularization statistic for a batch of embeddings.

        Args:
            proj: tensor of shape (..., N, D) holding N samples of
                D-dimensional embeddings. Any leading dimensions are batch dims.

        Returns:
            A scalar tensor: the per-slice statistic averaged over all slices
            and any leading batch dimensions. Fresh directions are drawn from
            the global torch RNG on every call.
        """
        # Match the input's device and dtype. A hard-coded "cuda" crashed on
        # CPU, picked the wrong GPU for e.g. cuda:1, and broke matmul for
        # non-float32 inputs.
        A = torch.randn(
            proj.size(-1), self.num_slices, device=proj.device, dtype=proj.dtype
        )
        # Normalize each column to a unit-norm projection direction.
        A = A.div_(A.norm(p=2, dim=0))
        # Shape: (..., N, num_slices, knots).
        x_t = (proj @ A).unsqueeze(-1) * self.t
        # Mean over the N samples gives the real and imaginary parts of the ECF.
        err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
        # Cast the quadrature weights so matmul also works when err is not float32.
        statistic = (err @ self.weights.to(err.dtype)) * proj.size(-2)
        return statistic.mean()


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()