# SIGReg

> Single-GPU and multi-GPU (distributed) implementations of the SIGReg loss.

In [None]:
#| default_exp losses.sigreg

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

## Single-GPU Implementation

In [None]:
#| export
import torch
class SIGRegSingle(torch.nn.Module):
    """Single-process SIGReg loss.

    The input is projected onto `num_slices` random unit directions. For every
    projection the empirical characteristic function (mean of cos / sin of
    `projection * t`) is evaluated on `knots` evenly spaced points t in [0, 3]
    and compared against `exp(-t^2 / 2)`, the characteristic function of a
    standard normal. The squared error is integrated with fixed quadrature
    weights (interior `2 * dt`, endpoints `dt`, multiplied by the same
    Gaussian window) and scaled by the number of samples that were averaged.
    """

    def __init__(self, knots=17, num_slices=256):
        """
        knots: number of evaluation points on [0, 3] (must be >= 2).
        num_slices: number of random projection directions drawn per call.
        """
        super().__init__()
        self.num_slices = num_slices
        t = torch.linspace(0, 3, knots, dtype=torch.float32)
        dt = 3 / (knots - 1)
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        self.register_buffer("weights", weights * window)

    @staticmethod
    def _resolve_dims(across_dim, ndim):
        """Return `across_dim` as a tuple of unique, non-negative sample axes.

        Negative indices are resolved against the `ndim`-dimensional input
        (not against the 4-D intermediate tensor), and the last (feature)
        axis is rejected because it is consumed by the projection.
        """
        dims = across_dim if isinstance(across_dim, (tuple, list)) else (across_dim,)
        resolved = []
        for d in dims:
            d = int(d)
            if not -ndim <= d < ndim:
                raise ValueError(f"across_dim entry {d} is out of range for a {ndim}-D input")
            d %= ndim
            if d == ndim - 1:
                raise ValueError("across_dim must not include the feature dimension")
            if d in resolved:
                raise ValueError(f"across_dim contains dimension {d} more than once")
            resolved.append(d)
        if not resolved:
            raise ValueError("across_dim must name at least one dimension")
        return tuple(resolved)

    def forward(self, proj, global_step=None, across_dim=0):
        """
        proj: Tensor of shape [T, B, D]
        global_step: [Optional] integer for seeding random projections. Any
            integer (including 0) gives reproducible slices; `None` draws
            them from the global RNG.
        across_dim: Int or Tuple of Ints (0 and/or 1, negatives allowed and
            interpreted relative to `proj`) determining the dimensions to
            compute the loss across.
        Returns:
            statistic: Scalar tensor representing the SIGReg loss across the
            dimension(s) determined by `across_dim`.
        Raises:
            ValueError: if `across_dim` is empty, out of range, repeated, or
            names the feature dimension.
        """
        _, _, D = proj.shape
        dims = self._resolve_dims(across_dim, proj.dim())

        # Number of samples that are averaged into each characteristic function.
        sample_size = 1
        for d in dims:
            sample_size *= proj.shape[d]

        # `is not None` (rather than truthiness) so that step 0 is seeded too.
        if global_step is not None:
            g = torch.Generator(device=proj.device).manual_seed(int(global_step))
            A = torch.randn(D, self.num_slices, generator=g, device=proj.device)
        else:
            A = torch.randn(D, self.num_slices, device=proj.device)

        A = A / A.norm(p=2, dim=0).clamp(min=1e-8)  # unit-norm columns, [D, num_slices]
        x_t = (proj @ A).unsqueeze(-1) * self.t  # [T, B, num_slices, knots]
        mean_cos = x_t.cos().mean(dim=dims)
        mean_sin = x_t.sin().mean(dim=dims)
        err = (mean_cos - self.phi).square() + mean_sin.square()
        statistic = (err @ self.weights) * sample_size
        return statistic.mean()


In [None]:
#| hide
import torch
# Smoke test for SIGRegSingle on random Gaussian input shaped [T, B, D].
# NOTE: B, T, D and z are notebook-level globals; a later cell reuses `z`.
B = 32
T = 91
D = 32*15*15  # 7200 features; presumably a flattened (32, 15, 15) latent — TODO confirm
z = torch.randn(T, B, D)
sigreg = SIGRegSingle()
# Same seed for all three calls so only the reduction axis differs.
loss = sigreg(z, global_step= 42, across_dim= 0)  # reduce over T
print(loss)
loss = sigreg(z, global_step= 42, across_dim= 1)  # reduce over B
print(loss)
loss = sigreg(z, global_step= 42, across_dim= (0,1))  # reduce over T and B jointly
print(loss)

Sample Size: 91
tensor(1.0511)
Sample Size: 32
tensor(1.0590)
Sample Size: 2912
tensor(1.0073)


## Distributed Implementation of SIGReg

In [None]:
#| export
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.distributed.nn import ReduceOp
from torch.distributed.nn import all_reduce as functional_all_reduce

def all_reduce_differentiable(x, op="AVG"):
    """All-reduce `x` across ranks through the autograd-aware collective.

    x: tensor to reduce.
    op: case-insensitive name of a `ReduceOp` member (e.g. "AVG", "SUM");
        an unknown name raises `KeyError`.

    Returns the reduced copy of `x` when a process group is initialized.
    Otherwise `x` itself is returned untouched (no copy is made).
    """
    process_group_ready = dist.is_available() and dist.is_initialized()
    if not process_group_ready:
        return x

    # Hand the collective a copy so the caller's tensor is left untouched.
    reduced_input = x.clone()
    reduce_op = vars(ReduceOp)[op.upper()]
    return functional_all_reduce(reduced_input, reduce_op)
    
class SIGReg(torch.nn.Module):
    """SIGReg loss with optional cross-rank synchronisation.

    The input is projected onto `num_slices` random unit directions. For every
    projection the empirical characteristic function (mean of cos / sin of
    `projection * t`) is evaluated on `knots` evenly spaced points t in [0, 3]
    and compared against `exp(-t^2 / 2)`, the characteristic function of a
    standard normal. The squared error is integrated with fixed quadrature
    weights (interior `2 * dt`, endpoints `dt`, multiplied by the same
    Gaussian window) and scaled by the total number of averaged samples.
    """

    def __init__(self, knots=17, num_slices=256):
        """
        knots: number of evaluation points on [0, 3] (must be >= 2).
        num_slices: number of random projection directions drawn per call.
        """
        super().__init__()
        self.num_slices = num_slices
        t = torch.linspace(0, 3, knots, dtype=torch.float32)
        dt = 3 / (knots - 1)
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        self.register_buffer("weights", weights * window)

    @staticmethod
    def _resolve_dims(across_dim, ndim):
        """Return `across_dim` as a tuple of unique, non-negative sample axes.

        Negative indices are resolved against the `ndim`-dimensional input
        (not against the 4-D intermediate tensor), and the last (feature)
        axis is rejected because it is consumed by the projection.
        """
        dims = across_dim if isinstance(across_dim, (tuple, list)) else (across_dim,)
        resolved = []
        for d in dims:
            d = int(d)
            if not -ndim <= d < ndim:
                raise ValueError(f"across_dim entry {d} is out of range for a {ndim}-D input")
            d %= ndim
            if d == ndim - 1:
                raise ValueError("across_dim must not include the feature dimension")
            if d in resolved:
                raise ValueError(f"across_dim contains dimension {d} more than once")
            resolved.append(d)
        if not resolved:
            raise ValueError("across_dim must name at least one dimension")
        return tuple(resolved)

    def forward(self, proj, global_step=None, across_dim=0, distributed=False):
        """
        proj: [T, B, D]
        global_step: [Optional] integer for seeding random projections. When
            `None`, slices come from the global RNG; in distributed mode rank
            0's slices are then broadcast so every rank projects identically.
        across_dim: [Optional] `Int` or `Tuple` of Ints (0 and/or 1, negatives
            allowed and interpreted relative to `proj`) determining the
            dimensions to compute the loss across.
        distributed: [Optional] `Bool` Set to True to compute statistics
            across all GPUs. Ignored when no process group is initialized.
            The cross-rank average weights every rank equally, so it assumes
            each rank reduces over the same number of samples.
        Returns:
            statistic: Scalar tensor representing the SIGReg loss across the
            dimension(s) determined by `across_dim`.
        Raises:
            ValueError: if `across_dim` is empty, out of range, repeated, or
            names the feature dimension.
        """
        _, _, D = proj.shape
        device = proj.device
        dims = self._resolve_dims(across_dim, proj.dim())

        local_sample_size = 1
        for d in dims:
            local_sample_size *= proj.shape[d]

        # Short-circuits before touching `dist` when distributed is False.
        use_dist = bool(distributed) and dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if use_dist else 1
        global_sample_size = local_sample_size * world_size

        if global_step is not None:
            g = torch.Generator(device=device).manual_seed(int(global_step))
            A = torch.randn(D, self.num_slices, generator=g, device=device)
        else:
            A = torch.randn(D, self.num_slices, device=device)
            if use_dist:
                # Unseeded ranks would each draw different directions, and
                # averaging their characteristic functions would mix unrelated
                # projections; share rank 0's draw instead.
                dist.broadcast(A, src=0)

        A = A / A.norm(p=2, dim=0).clamp(min=1e-8)  # [D, num_slices]
        x_t = (proj @ A).unsqueeze(-1) * self.t  # [T, B, num_slices, knots]

        mean_cos = x_t.cos().mean(dim=dims)
        mean_sin = x_t.sin().mean(dim=dims)

        if use_dist:
            mean_cos = all_reduce_differentiable(mean_cos, op="AVG")
            mean_sin = all_reduce_differentiable(mean_sin, op="AVG")

        err = (mean_cos - self.phi).square() + mean_sin.square()
        statistic = (err @ self.weights) * global_sample_size

        return statistic.mean()

In [None]:
import torch

# Smoke test for SIGReg in its default non-distributed mode.
# Relies on `z` (random [T, B, D] input) created by the SIGRegSingle test cell above.
sigreg = SIGReg()
# Same seed and axes as the SIGRegSingle cell; the recorded outputs of the two cells match.
loss = sigreg(z, global_step= 42, across_dim= 0)  # reduce over T
print(loss)
loss = sigreg(z, global_step= 42, across_dim= 1)  # reduce over B
print(loss)
loss = sigreg(z, global_step= 42, across_dim= (0,1))  # reduce over T and B jointly
print(loss)

tensor(1.0511)
tensor(1.0590)
tensor(1.0073)


### Other trials (experimental variants)

In [None]:
#| export
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.distributed.nn import ReduceOp
from torch.distributed.nn import all_reduce as functional_all_reduce

def all_reduce_differentiable(x, op="AVG"):
    """All-reduce `x` across ranks through the autograd-aware collective.

    NOTE(review): this repeats the identical definition from the
    "Distributed Implementation" cell earlier in this notebook.

    x: tensor to reduce.
    op: case-insensitive name of a `ReduceOp` member (e.g. "AVG", "SUM");
        an unknown name raises `KeyError`.

    Returns the reduced copy of `x` when a process group is initialized.
    Otherwise `x` itself is returned untouched (no copy is made).
    """
    process_group_ready = dist.is_available() and dist.is_initialized()
    if not process_group_ready:
        return x

    # Hand the collective a copy so the caller's tensor is left untouched.
    reduced_input = x.clone()
    reduce_op = vars(ReduceOp)[op.upper()]
    return functional_all_reduce(reduced_input, reduce_op)

class SIGRegFunctional(torch.nn.Module):
    """Masked SIGReg variant that reduces over the batch axis, per time step.

    Masked cos / sin sums and the valid-sample counts are summed across ranks
    with `all_reduce_differentiable(..., op="SUM")`, then divided to obtain
    the characteristic-function estimate that is compared to `exp(-t^2 / 2)`.
    """

    def __init__(self, knots=17, num_slices=256, scale_factor=1.0):
        """
        knots: number of evaluation points on [0, 3].
        num_slices: number of random projection directions.
        scale_factor: constant multiplier applied to the returned loss.
        """
        super().__init__()
        self.num_slices = num_slices
        self.scale_factor = scale_factor

        step = 3 / (knots - 1)
        grid = torch.linspace(0, 3, knots, dtype=torch.float32)
        gaussian = torch.exp(-grid.square() / 2.0)
        # Quadrature weights: 2 * step in the interior, step at both endpoints.
        quadrature = torch.full((knots,), 2 * step, dtype=torch.float32)
        quadrature[0] = step
        quadrature[-1] = step

        self.register_buffer("t", grid)
        self.register_buffer("phi", gaussian)
        self.register_buffer("weights", quadrature * gaussian)

    def forward(self, proj, global_step, mask):
        """
        proj: Tensor of shape [T, B, d]
        global_step: integer seed for the random projection directions.
        mask: [T, B] (1 for valid, 0 for padding)
        Returns:
            Scalar tensor: the statistic averaged over time steps and slices,
            multiplied by `scale_factor`.
        """
        n_time, n_batch, n_feat = proj.shape

        # Seeded unit-norm projection directions; built without autograd tracking.
        with torch.no_grad():
            rng = torch.Generator(device=proj.device).manual_seed(int(global_step))
            directions = torch.randn(n_feat, self.num_slices, generator=rng, device=proj.device)
            directions = directions / directions.norm(p=2, dim=0).clamp(min=1e-8)

        # [T, B, Slices] -> [T, B, Slices, Knots]
        phase = (proj @ directions).unsqueeze(-1) * self.t

        # [T, B, 1, 1] so the mask broadcasts over slices and knots.
        valid = mask.view(n_time, n_batch, 1, 1)

        def summed_over_ranks(local_value):
            # SUM (not AVG): numerators and denominators are pooled separately.
            return all_reduce_differentiable(local_value.clone(), op="SUM")

        # Collectives are issued in the order: cos sums, sin sums, counts.
        total_cos = summed_over_ranks((phase.cos() * valid).sum(dim=1))  # [T, Slices, Knots]
        total_sin = summed_over_ranks((phase.sin() * valid).sum(dim=1))
        total_count = summed_over_ranks(valid.sum(dim=1))  # [T, 1, 1]

        # Clamp the denominator so fully masked time steps do not divide by zero.
        denominator = total_count.clone().clamp(min=1e-6)
        ecf_cos = total_cos / denominator
        ecf_sin = total_sin / denominator

        squared_error = (ecf_cos - self.phi).square() + ecf_sin.square()

        # Scale each time step by its pooled number of valid samples; [T, Slices].
        per_slice = (squared_error @ self.weights) * total_count.squeeze(-1)

        return per_slice.mean() * self.scale_factor

In [None]:
#| hide
# Smoke test for SIGRegFunctional inside a single-rank gloo process group.
# NOTE: rebinds the notebook-level globals T and B used by earlier cells.
T = 8
B = 16
d = 512
global_step = 100  # defined for reference; the call below seeds with 0
z_proj = torch.randn(T, B, d)
c_proj = torch.randn(T, B, d)
mask = torch.ones(T, B)  # every position is valid
dist.init_process_group(backend="gloo", init_method="tcp://localhost:12345", rank=0, world_size=1)
try:
    loss = SIGRegFunctional()
    print(loss(z_proj, 0, mask))
finally:
    # Always tear the group down: if the forward raises and the group stays
    # alive, re-running this cell fails on the second init_process_group.
    dist.destroy_process_group()
# loss = SIGReg()
# print(loss(z_proj, 0, mask))

[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
tensor(1.0431)


### Temporal version of SIGReg

In [None]:
#| export
class TemporalSIGReg(torch.nn.Module):
    """Masked SIGReg variant that reduces over the time axis, per batch sample.

    For each batch element, the masked characteristic-function estimate of
    its projected sequence is compared to `exp(-t^2 / 2)`. No cross-rank
    communication is performed.
    """

    def __init__(self, knots=17, num_slices=256):
        """
        knots: number of evaluation points on [0, 3].
        num_slices: number of random projection directions.
        """
        super().__init__()
        step = 3 / (knots - 1)
        grid = torch.linspace(0, 3, knots, dtype=torch.float32)
        gaussian = torch.exp(-grid.square() / 2.0)
        # Quadrature weights: 2 * step in the interior, step at both endpoints.
        quadrature = torch.full((knots,), 2 * step, dtype=torch.float32)
        quadrature[0] = step
        quadrature[-1] = step

        self.register_buffer("t", grid)
        self.register_buffer("phi", gaussian)
        self.register_buffer("weights", quadrature * gaussian)
        self.num_slices = num_slices

    def forward(self, x, global_step, mask):
        """
        x: [T, B, d] -> Time, Batch, Dim
        global_step: integer seed for the random projection directions.
        mask: [T, B] -> 1 for valid, 0 for padding
        Returns:
            Scalar tensor: the statistic averaged over batch samples and slices.
        """
        n_time, n_batch, n_feat = x.shape

        # Seeded unit-norm projection directions; built without autograd tracking.
        with torch.no_grad():
            rng = torch.Generator(device=x.device).manual_seed(int(global_step))
            directions = torch.randn(n_feat, self.num_slices, generator=rng, device=x.device)
            directions = directions / directions.norm(p=2, dim=0).clamp(min=1e-8)

        # [T, B, Slices] -> [T, B, Slices, Knots]
        phase = (x @ directions).unsqueeze(-1) * self.t

        # [T, B, 1, 1] so the mask broadcasts over slices and knots.
        valid = mask.view(n_time, n_batch, 1, 1)

        # Reduce over time (dim 0): one estimate per batch sample.
        valid_steps = valid.sum(dim=0)  # [B, 1, 1]
        denominator = valid_steps.clamp(min=1e-6)  # guards fully masked samples
        ecf_cos = (phase.cos() * valid).sum(dim=0) / denominator  # [B, Slices, Knots]
        ecf_sin = (phase.sin() * valid).sum(dim=0) / denominator

        squared_error = (ecf_cos - self.phi).square() + ecf_sin.square()

        # Scale each sample by its own number of valid time steps; [B, Slices].
        per_slice = (squared_error @ self.weights) * valid_steps.squeeze(-1)

        return per_slice.mean()

In [None]:
#| hide
# Smoke test for TemporalSIGReg on random input with an all-valid mask.
# NOTE: rebinds the notebook-level globals T and B used by earlier cells.
# NOTE(review): the recorded output below includes a torch.Size line that the
# current forward() does not print — the saved output looks stale.
T = 8
B = 16
d = 512
global_step = 100  # defined for reference; the call below seeds with 0
z_proj = torch.randn(T, B, d)
c_proj = torch.randn(T, B, d)
mask = torch.ones(T, B)  # every position is valid
loss = TemporalSIGReg()
print(loss(z_proj, 0, mask))


torch.Size([16, 256, 17])
tensor(1.0589)


## Testing latents

In [None]:
# #| hide
# from mawm.models import init_models
# from omegaconf import OmegaConf
# cfg = OmegaConf.load("../cfgs/findgoal/mawm/ablations/datasize/mawm_ds_200k.yaml")
# cfg.distributed = False
# model = init_models(cfg, device= "cpu", distributed= cfg.distributed)


INFO:root:JEPA Parameters: 98560
INFO:root:CommModule Parameters: 56005
INFO:root:MSgEncoder Parameters: 32608
INFO:root:Projector Parameters: 2241536
INFO:root:--------------------------------------------------
INFO:root:Total Parameters: 2462245


In [None]:
# #| hide
# import torch
# ckpt = torch.load("./models/good.pt", map_location="cpu")

In [None]:
# jepa = model['rec']['jepa']
# jepa.load_state_dict(ckpt["jepa"])

<All keys matched successfully>

In [None]:
# from mawm.data.utils import init_data
# cfg = OmegaConf.load("../cfgs/findgoal/mawm/ablations/datasize/mawm_ds_200k.yaml")
# cfg.distributed = False
# cfg.data.data_size = 0
# cfg.data.batch_size = 10
# dl, _ = init_data(cfg, distributed= cfg.distributed)


Data path found for hostname: local
Using all 10 rollouts in dataset.


In [None]:
# len(dl)

1

In [None]:
# batch = next(iter(dl))
# batch.keys()

dict_keys(['agent_0', 'agent_1'])

In [None]:
# batch['agent_0'].keys()

dict_keys(['obs', 'pos', 'msg', 'msg_target', 'act', 'next_obs', 'done'])

In [None]:
# batch['agent_0']['obs'].shape

torch.Size([10, 40, 3, 42, 42])

In [None]:
# with torch.no_grad():
#     z = jepa.backbone(batch['agent_0']['obs'].to("cpu"), batch['agent_0']['pos'].to("cpu"))
#     z2 = jepa.backbone(batch['agent_1']['obs'].to("cpu"), batch['agent_1']['pos'].to("cpu"))
    

In [None]:
# z.shape, z2.shape

(torch.Size([10, 40, 32, 15, 15]), torch.Size([10, 40, 32, 15, 15]))

In [None]:
# vectors = z.flatten(0, 1).view(-1, 32*15*15)
# vectors.shape

# vectors2 = z2.flatten(0, 1).view(-1, 32*15*15)
# vectors2.shape

torch.Size([400, 7200])

In [None]:
# v = vectors[0]
# v.shape

torch.Size([7200])

In [None]:
# B = vectors.shape[0]
# vectors = vectors - vectors.mean(dim=0, keepdim=True)
# cov_mat = (vectors.T @ vectors) / (B - 1)
# cov_mat.shape
# cov_mat2 = torch.cov(vectors2.T)
# cov_mat2.shape

torch.Size([7200, 7200])

In [None]:
#| hide
# Run nbdev's export step for this project's notebooks.
import nbdev; nbdev.nbdev_export()