# Various Loss Functions

> Contains various loss functions used for optimization.

In [None]:
#| default_exp losses.sigreg

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
#| export
import torch
class SIGReg(torch.nn.Module):
    """Sliced characteristic-function regularizer.

    Inputs are projected onto 256 random unit directions. For every
    direction the empirical means of ``cos(t * x)`` and ``sin(t * x)`` are
    compared with ``exp(-t**2 / 2)`` (the characteristic function of a
    standard normal) and ``0`` on a grid of ``knots`` points ``t`` in
    ``[0, 3]``. The squared gaps are integrated with window-weighted
    quadrature weights.

    Args:
        knots: Number of grid points on ``[0, 3]``; must be at least 2
            because the spacing is ``3 / (knots - 1)``.
    """

    def __init__(self, knots=17):
        super().__init__()
        # Grid of evaluation points and the spacing between neighbours.
        grid = torch.linspace(0, 3, knots, dtype=torch.float32)
        step = 3 / (knots - 1)
        # Quadrature weights: 2 * step in the interior, step at both ends.
        quad = torch.full((knots,), 2 * step, dtype=torch.float32)
        quad[[0, -1]] = step
        # exp(-t^2 / 2) serves as the target value and as the weighting window.
        gauss = torch.exp(-grid.square() / 2.0)
        for name, value in (("t", grid), ("phi", gauss), ("weights", quad * gauss)):
            self.register_buffer(name, value)

    def forward(self, proj, global_step=None, mask=None):
        """Return the scalar regularization statistic for ``proj``.

        Args:
            proj: Tensor whose last dimension holds the features and whose
                second-to-last dimension holds the samples averaged over.
            global_step: Accepted for interface compatibility; unused.
            mask: Accepted for interface compatibility; unused.

        Returns:
            A 0-dim tensor: the statistic averaged over all directions and
            any leading dimensions of ``proj``.
        """
        feature_dim = proj.size(-1)
        # Fresh random directions on every call, normalized column-wise.
        directions = torch.randn(feature_dim, 256, device=proj.device)
        directions.div_(directions.norm(p=2, dim=0))
        # Trailing axis enumerates the grid points t.
        angles = (proj @ directions).unsqueeze(-1) * self.t
        # Averages run over the sample axis (dim -3 once the grid axis exists).
        real_gap = angles.cos().mean(-3) - self.phi
        imag_part = angles.sin().mean(-3)
        err = real_gap.square() + imag_part.square()
        # Integrate over t, then scale by the number of samples.
        statistic = torch.matmul(err, self.weights) * proj.size(-2)
        return statistic.mean()


### CPU version for local testing

In [None]:
#| hide
import torch
class MySIGReg(torch.nn.Module):
    """CPU-only variant of the sliced characteristic-function regularizer.

    Same computation as the exported loss, except that the random projection
    matrix is always created on the CPU, so inputs are expected on the CPU.

    Args:
        knots: Number of grid points on ``[0, 3]`` (at least 2).
    """

    def __init__(self, knots=17):
        super().__init__()
        spacing = 3 / (knots - 1)
        points = torch.linspace(0, 3, knots, dtype=torch.float32)
        window = torch.exp(-points.square() / 2.0)

        # Interior points weigh 2 * spacing; the two end points weigh spacing.
        trapezoid = torch.full((knots,), 2 * spacing, dtype=torch.float32)
        trapezoid[0] = spacing
        trapezoid[-1] = spacing

        self.register_buffer("t", points)
        self.register_buffer("phi", window)
        self.register_buffer("weights", trapezoid * window)

    def forward(self, proj, global_step=None, mask=None):
        """Compute the scalar statistic for ``proj`` on the CPU.

        ``global_step`` and ``mask`` are accepted but not used. Samples are
        taken along the second-to-last dimension of ``proj`` and features
        along the last one.
        """
        n_samples = proj.size(-2)
        # Random unit-norm directions, always allocated on the CPU.
        basis = torch.randn(proj.size(-1), 256, device="cpu")
        basis /= basis.norm(p=2, dim=0)

        phase = (proj @ basis).unsqueeze(-1) * self.t
        cos_mean = phase.cos().mean(-3)
        sin_mean = phase.sin().mean(-3)
        err = (cos_mean - self.phi).square() + sin_mean.square()

        per_slice = (err @ self.weights) * n_samples
        return per_slice.mean()


In [None]:
#| hide
# Smoke test of the CPU loss on random inputs of shape (T - 1, B, d).
T, B, d = 8, 16, 128
z_proj = torch.randn(T - 1, B, d)
c_proj = torch.randn(T - 1, B, d)
loss = MySIGReg()
# Last expression of the cell so the notebook displays the loss value.
loss(z_proj)

tensor(1.0757)

### Distributed SIGReg

In [None]:
#| export
import torch.distributed as dist
import torch
class FullGatherLayer(torch.autograd.Function):
    """All-gather with a custom backward pass.

    Forward collects ``x`` from every rank of the default process group and
    returns the pieces as a tuple (one tensor per rank). Backward sums the
    incoming gradients over all ranks and hands back the slice belonging to
    the calling rank. Requires an initialized process group.
    """

    @staticmethod
    def forward(ctx, x):
        world_size = dist.get_world_size()
        # One receive buffer per rank, shaped like the local tensor.
        gathered = [torch.zeros_like(x) for _ in range(world_size)]
        dist.all_gather(gathered, x)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        # Stack the per-rank gradients so a single all_reduce covers them all.
        stacked = torch.stack(grads)
        dist.all_reduce(stacked, op=dist.ReduceOp.SUM)
        # After the sum, entry `rank` is the gradient for this rank's input.
        rank = dist.get_rank()
        return stacked[rank]
    

In [None]:
#| export
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.distributed.nn import ReduceOp
from torch.distributed.nn import all_reduce as functional_all_reduce

def all_reduce_differentiable(x, op="AVG"):
    """All-reduce ``x`` across ranks using ``torch.distributed.nn``'s all_reduce.

    Args:
        x: Tensor to reduce.
        op: Name of a ``ReduceOp`` member (case-insensitive), e.g. ``"SUM"``
            or ``"AVG"``. An unknown name raises ``KeyError``.

    Returns:
        The reduced tensor when a process group is initialized; otherwise
        ``x`` itself, unchanged.
    """
    # Without an initialized process group there is nothing to reduce.
    if not (dist.is_available() and dist.is_initialized()):
        return x
    # The collective runs on a copy of x rather than on x itself.
    reduce_op = ReduceOp.__dict__[op.upper()]
    return functional_all_reduce(x.clone(), reduce_op)

class SIGRegFunctional(torch.nn.Module):
    """Masked, multi-process variant of the sliced characteristic-function loss.

    Per-timestep sums of ``cos(t * x)`` / ``sin(t * x)`` and the count of
    valid samples are summed across ranks (through
    ``all_reduce_differentiable``) before the empirical characteristic
    function is formed, so every rank evaluates the statistic on the global
    batch.

    Args:
        knots: Number of grid points on ``[0, 3]`` (at least 2).
        num_slices: Number of random unit directions to project onto.
        scale_factor: Multiplier applied to the final loss.
    """

    def __init__(self, knots=17, num_slices=256, scale_factor=1.0):
        super().__init__()
        self.num_slices = num_slices
        self.scale_factor = scale_factor

        # Quadrature grid on [0, 3]. Interior points weigh 2 * dt and the two
        # end points dt, matching the weights used by `SIGReg`. Previously the
        # interior weight was dt, which made the end-point assignment a no-op
        # and roughly halved the loss; use `scale_factor` to rescale if needed.
        t = torch.linspace(0, 3, knots, dtype=torch.float32)
        dt = 3 / (knots - 1)
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        phi = torch.exp(-t.square() / 2.0)

        self.register_buffer("t", t)
        self.register_buffer("phi", phi)
        self.register_buffer("weights", weights * phi)

    def forward(self, x, global_step, mask=None):
        """Compute the loss for ``x`` over all valid samples of all ranks.

        Args:
            x: Tensor of shape ``[T, B, d]``.
            global_step: Value convertible to ``int`` that seeds the random
                projection matrix; callers passing the same step draw the
                same directions.
            mask: Tensor with ``T * B`` elements, 1 for valid and 0 for
                padding. ``None`` treats every sample as valid.

        Returns:
            A 0-dim tensor: the statistic averaged over time steps and
            slices, multiplied by ``scale_factor``.
        """
        T_dim, B_dim, d_dim = x.shape
        if mask is None:
            mask = x.new_ones(T_dim, B_dim)

        # 1. Projection matrix, seeded by the step so that it is reproducible.
        with torch.no_grad():
            g = torch.Generator(device=x.device).manual_seed(int(global_step))
            A = torch.randn(d_dim, self.num_slices, generator=g, device=x.device)
            A = A / A.norm(p=2, dim=0).clamp(min=1e-8)

        # 2. Project and multiply by the grid: [T, B, Slices, Knots].
        projections = x @ A
        x_t = projections.unsqueeze(-1) * self.t

        # 3. Mask as [T, B, 1, 1] in the working dtype so bool/int masks
        #    broadcast and sum as floats.
        m = mask.to(x_t.dtype).reshape(T_dim, B_dim, 1, 1)
        local_sum_real = (x_t.cos() * m).sum(dim=1)  # [T, Slices, Knots]
        local_sum_imag = (x_t.sin() * m).sum(dim=1)
        local_count = m.sum(dim=1)  # [T, 1, 1]

        # 4. Sum numerators and denominators across ranks. The helper reduces
        #    a copy of its input, so no extra clone is needed here.
        global_sum_real = all_reduce_differentiable(local_sum_real, op="SUM")
        global_sum_imag = all_reduce_differentiable(local_sum_imag, op="SUM")
        global_count = all_reduce_differentiable(local_count, op="SUM")

        # 5. Global mean; the clamp guards time steps with no valid sample.
        denom = global_count.clamp(min=1e-6)
        ecf_real = global_sum_real / denom
        ecf_imag = global_sum_imag / denom

        # 6. Squared gap to the target, integrated over t and scaled by the
        #    global number of valid samples: [T, Slices].
        err = (ecf_real - self.phi).square() + ecf_imag.square()
        statistic = (err @ self.weights) * global_count.squeeze(-1)

        # Average over time and slices.
        return statistic.mean() * self.scale_factor

In [None]:
#| hide
# Smoke test of the distributed loss inside a single-process gloo group.
T = 8
B = 16
d = 512
z_proj = torch.randn(T, B, d)
c_proj = torch.randn(T, B, d)
mask = torch.ones(T, B)
dist.init_process_group(backend="gloo", init_method="tcp://localhost:12345", rank=0, world_size=1)
try:
    loss = SIGRegFunctional()
    print(loss(z_proj, 0, mask))
finally:
    # Tear the group down even when the forward pass raises; otherwise
    # re-running this cell fails because the default group already exists.
    dist.destroy_process_group()


[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
tensor(0.5251)


In [None]:
#| hide
# Run nbdev's export step for this project.
import nbdev
nbdev.nbdev_export()