# Various Loss Functions

> Contains various loss functions used for optimization.

In [None]:
#| default_exp losses.sigreg

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
#| export
import torch
class SIGReg(torch.nn.Module):
    """SIGReg loss: distance of randomly projected embeddings from N(0, 1).

    Embeddings are projected onto ``num_slices`` random unit directions. For
    each direction, the empirical characteristic function (ECF) of the N
    projected samples is compared with the standard normal characteristic
    function ``phi(t) = exp(-t^2 / 2)``:

        statistic = N * integral_{-3}^{3} |ecf(t) - phi(t)|^2 * phi(t) dt

    The integral is approximated with the trapezoidal rule on ``knots``
    equally spaced points in [0, 3]. The weights are doubled because the
    integrand is even in t, so the half-range estimate times 2 equals the
    full-range estimate. The returned loss is the mean of the statistic
    over slices and over any leading batch dims.

    Args:
        knots: number of quadrature points on [0, 3]; must be >= 2.
        num_slices: number of random projection directions drawn per call
            (default 256, the value that used to be hard-coded).
    """

    def __init__(self, knots=17, num_slices=256):
        super().__init__()
        if knots < 2:
            raise ValueError(f"knots must be >= 2, got {knots}")
        if num_slices < 1:
            raise ValueError(f"num_slices must be >= 1, got {num_slices}")
        self.num_slices = num_slices
        t = torch.linspace(0, 3, knots, dtype=torch.float32)
        dt = 3 / (knots - 1)
        # Trapezoidal weights on [0, 3] (dt/2 ends, dt interior), doubled
        # for the symmetric range [-3, 3].
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        # Standard normal CF, also used as the Gaussian weighting function.
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        self.register_buffer("weights", weights * window)

    def forward(self, proj):
        """Return the SIGReg loss for ``proj``.

        Args:
            proj: tensor shaped ``[..., N, D]``. Dim -2 indexes samples and
                dim -1 indexes features. Leading dims are treated as
                independent batches and averaged in the result.

        Returns:
            Scalar tensor (non-negative).
        """
        # Fresh random unit directions each call, drawn from the global RNG
        # in float32 and then cast, so float64 / half / bf16 inputs also work.
        # For float32 inputs the cast is a no-op.
        A = torch.randn(proj.size(-1), self.num_slices, device=proj.device)
        A = A.div_(A.norm(p=2, dim=0)).to(proj.dtype)
        # [..., N, num_slices, knots]: each projected sample times each t.
        x_t = (proj @ A).unsqueeze(-1) * self.t
        # ECF over the N samples (dim -3 of x_t is dim -2 of proj).
        err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
        # Quadrature over t, scaled by the sample count N.
        statistic = (err @ self.weights.to(err.dtype)) * proj.size(-2)
        return statistic.mean()


### CPU version for local testing

In [None]:
#| hide
import torch
class MySIGReg(torch.nn.Module):
    """CPU-only copy of the SIGReg loss for local experiments.

    Computes the same statistic as ``SIGReg``, but the 256 random projection
    directions are always drawn on the CPU. Inputs are therefore expected to
    be CPU tensors.

    Args:
        knots: number of quadrature points on [0, 3].
    """

    def __init__(self, knots=17):
        super().__init__()
        grid = torch.linspace(0, 3, knots, dtype=torch.float32)
        step = 3 / (knots - 1)
        # Doubled trapezoidal weights: ``step`` at both ends, ``2 * step`` inside.
        quad = torch.full((knots,), 2 * step, dtype=torch.float32)
        quad[0] = step
        quad[-1] = step
        # Characteristic function of N(0, 1), reused as the weighting window.
        gauss_cf = torch.exp(-grid.square() / 2.0)
        self.register_buffer("t", grid)
        self.register_buffer("phi", gauss_cf)
        self.register_buffer("weights", quad * gauss_cf)

    def forward(self, proj):
        """Return the scalar SIGReg loss for a CPU tensor shaped ``[..., N, D]``."""
        directions = torch.randn(proj.size(-1), 256, device="cpu")
        directions = directions / directions.norm(p=2, dim=0)
        angles = torch.matmul(proj, directions).unsqueeze(-1) * self.t
        # Empirical characteristic function, averaged over the N samples.
        real_part = angles.cos().mean(-3)
        imag_part = angles.sin().mean(-3)
        gap = (real_part - self.phi).square() + imag_part.square()
        per_slice = torch.matmul(gap, self.weights) * proj.size(-2)
        return per_slice.mean()


In [None]:
#| hide
# Seed so the displayed loss (inputs and projection draw) is reproducible.
torch.manual_seed(0)
T = 8
B = 16
d = 128
# [T-1, B, d]: MySIGReg averages over dim -2 (B) and then over T-1 and slices.
z_proj = torch.randn(T-1, B, d)
c_proj = torch.randn(T-1, B, d)
loss = MySIGReg()
loss(z_proj)

tensor(1.0757)

### Distributed SIGReg

In [None]:
#| export
import torch.distributed as dist
import torch
class FullGatherLayer(torch.autograd.Function):
    """Differentiable all-gather of ``x`` across every rank.

    ``forward`` returns a tuple with one tensor per rank, in rank order, each
    shaped like ``x``. Every rank must pass a tensor of the same shape and
    dtype, because the receive buffers are built with ``zeros_like(x)``.
    ``backward`` sums the incoming gradients across ranks with ``all_reduce``
    and returns the slice that belongs to this rank.

    NOTE(review): when every rank computes the same loss on the gathered
    batch, the summed gradient is ``world_size`` times the single-rank
    gradient. Presumably DDP's gradient averaging cancels this; confirm for
    the training setup.
    """

    @staticmethod
    def forward(ctx, x):
        # Collectives (notably NCCL) require contiguous tensors, and
        # zeros_like would copy a non-contiguous layout into the buffers.
        x = x.contiguous()
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # One gradient per gathered output -> [world_size, *x.shape].
        all_gradients = torch.stack(grads)
        # Sum gradients across all ranks so each rank gets the global signal.
        dist.all_reduce(all_gradients, op=dist.ReduceOp.SUM)
        return all_gradients[dist.get_rank()]
    

In [None]:
#| export
import torch
import torch.distributed as dist
from einops import rearrange

class SIGRegDistributed(torch.nn.Module):
    """SIGReg loss computed over the batch gathered from all ranks.

    Holds the same quadrature buffers as ``SIGReg``. ``forward`` is attached
    to this class later in the module with fastcore's ``@patch`` and takes
    ``(proj, global_step)``.

    An earlier, commented-out draft averaged per-rank ECFs with
    ``dist.all_reduce``. That collective is not recorded by autograd, which
    is why the patched ``forward`` gathers features through
    ``FullGatherLayer`` instead.

    Args:
        knots: number of quadrature points on [0, 3]; must be >= 2.
        num_slices: number of random projection directions per call.
    """

    def __init__(self, knots=17, num_slices=256):
        super().__init__()
        if knots < 2:
            raise ValueError(f"knots must be >= 2, got {knots}")
        if num_slices < 1:
            raise ValueError(f"num_slices must be >= 1, got {num_slices}")
        self.num_slices = num_slices
        # Integration points (t) on [0, 3].
        t = torch.linspace(0, 3, knots, dtype=torch.float32)
        dt = 3 / (knots - 1)

        # Trapezoidal weights on [0, 3], doubled: the integrand is even in t,
        # so this estimates the integral over [-3, 3].
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        # Standard normal characteristic function, also the weighting window.
        window = torch.exp(-t.square() / 2.0)

        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        self.register_buffer("weights", weights * window)
    

In [None]:
#| export
import torch
import torch.distributed as dist
@patch
def forward(self: SIGRegDistributed, proj, global_step):
    """Compute the SIGReg loss on the batch gathered from every rank.

    Args:
        proj: embeddings whose last dim is the feature dim. Dim 0 is the
            sample axis: it is concatenated across ranks and averaged over.
            Every rank must pass the same shape (FullGatherLayer allocates
            its receive buffers with ``zeros_like``).
        global_step: integer seed for the projection matrix. All ranks must
            pass the same value so they draw identical directions.

    Returns:
        Scalar tensor: mean over slices (and any remaining middle dims) of
        ``global_n * quadrature(|ecf - phi|^2 * phi)``.
    """
    device = proj.device
    # Collectives need contiguous input; .contiguous() returns proj itself
    # when it already is.
    proj = proj.contiguous()

    # 1. Gather every rank's samples through an autograd-aware all-gather,
    #    turning [B, ...] on each rank into [world_size * B, ...].
    if dist.is_initialized():
        gathered = FullGatherLayer.apply(proj)
        proj = torch.cat(gathered, dim=0)

    # 2. Identical projection matrix on every rank: a dedicated generator
    #    seeded with global_step, independent of the global RNG state.
    g = torch.Generator(device=device)
    g.manual_seed(int(global_step))
    A = torch.randn(proj.size(-1), self.num_slices, generator=g, device=device)
    A = A / A.norm(p=2, dim=0).clamp(min=1e-8)  # out-of-place, clamped norm

    # 3. Empirical characteristic function over the gathered samples (dim 0).
    x_t = (proj @ A).unsqueeze(-1) * self.t
    ecf_real = x_t.cos().mean(dim=0)
    ecf_imag = x_t.sin().mean(dim=0)

    # 4. Squared distance to the standard normal CF.
    err = (ecf_real - self.phi).square() + ecf_imag.square()

    # 5. Quadrature over t, scaled by the global sample count.
    # NOTE(review): SIGReg averages and scales over proj dim -2, while this
    # uses dim 0. They coincide for 2-D [N, D] input; confirm the intended
    # sample axis for 3-D input.
    global_n = proj.size(0)
    statistic = (err @ self.weights) * global_n

    return statistic.mean()


In [None]:
#| hide
# Seed so the inputs (and therefore the displayed loss) are reproducible;
# the projection matrix is already seeded by global_step inside forward.
torch.manual_seed(0)
T = 8
B = 16
d = 128
# NOTE(review): SIGRegDistributed treats dim 0 (T-1) as the sample axis here,
# unlike MySIGReg, which uses dim -2 (B).
z_proj = torch.randn(T-1, B, d)
c_proj = torch.randn(T-1, B, d)
loss = SIGRegDistributed()
loss(z_proj, 0)

torch.Size([7, 16, 128])


tensor(1.0517)

In [None]:
#| hide
# Write the `#| export` cells of this notebook out to the library module.
import nbdev
nbdev.nbdev_export()