The notebook runs; however, it requires the `difflogic` library, which only installs successfully on certain Colab instances. If the installation fails, the best workaround we found is to delete the runtime and start a new one.

In [None]:
# Install DiffLogic together with a matching CUDA build of PyTorch.
# cmake/ninja are installed first, presumably because `pip install difflogic`
# builds its CUDA extension from source — TODO confirm.

!sudo apt-get install -y cmake ninja-build
# Pin torch 2.4.0 / torchvision 0.19.0 / torchaudio 2.4.0 from the CUDA 12.1 (cu121) wheel index.
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
!pip install difflogic

In [None]:
import copy
import time
from types import MethodType
from typing import List, Tuple, Callable

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from tqdm import tqdm

from difflogic import LogicLayer, GroupSum
from difflogic.packbitstensor import PackBitsTensor
from difflogic.difflogic import LogicLayerCudaFunction


In [None]:
# Reproducibility: seed the torch CPU and CUDA RNGs (set `seed = None` to skip).
seed = 42
if seed is not None:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    print(f"[INFO] Using fixed seed: {seed}")


# Bit width -> torch floating-point dtype; `train` casts its inputs with this table.
BITS_TO_TORCH_FLOATING_POINT_TYPE = dict(
    zip((16, 32, 64), (torch.float16, torch.float32, torch.float64))
)
# Key into the table above selecting the training input precision.
training_bit_count = 32


def bin_op(a, b, i):
    """Apply the i-th of the 16 two-input logic gates to real-valued inputs.

    The gates use real-valued (probabilistic) relaxations, so inputs in [0, 1]
    give outputs in [0, 1]. On {0, 1} inputs they reproduce the Boolean gate.

        0 FALSE        4 NOT A AND B   8 NOR          12 NOT A
        1 AND          5 B             9 XNOR         13 NOT A OR B
        2 A AND NOT B  6 XOR          10 NOT B        14 NAND
        3 A            7 OR           11 A OR NOT B   15 TRUE

    Args:
        a, b: input tensors of matching shape.
        i: gate index in 0..15.

    Returns:
        Tensor shaped like `a`.

    Raises:
        AssertionError: if the first (and, when present, second) rows of `a`
            and `b` differ in shape.
        ValueError: if `i` is not a valid gate index. Without this check an
            unknown index would silently return None.
    """
    # Cheap shape-compatibility check on the leading rows only.
    assert a[0].shape == b[0].shape, (a[0].shape, b[0].shape)
    if a.shape[0] > 1:
        assert a[1].shape == b[1].shape, (a[1].shape, b[1].shape)

    if i == 0:
        return torch.zeros_like(a)
    elif i == 1:
        return a * b
    elif i == 2:
        return a - a * b
    elif i == 3:
        return a
    elif i == 4:
        return b - a * b
    elif i == 5:
        return b
    elif i == 6:
        return a + b - 2 * a * b
    elif i == 7:
        return a + b - a * b
    elif i == 8:
        return 1 - (a + b - a * b)
    elif i == 9:
        return 1 - (a + b - 2 * a * b)
    elif i == 10:
        return 1 - b
    elif i == 11:
        return 1 - b + a * b
    elif i == 12:
        return 1 - a
    elif i == 13:
        return 1 - a + a * b
    elif i == 14:
        return 1 - a * b
    elif i == 15:
        return torch.ones_like(a)
    raise ValueError(f"bin_op: gate index must be in 0..15, got {i!r}")


def bin_op_s(a, b, i_s):
    """Blend all 16 gates of `a` and `b`, weighted by the last axis of `i_s`.

    Computes sum_k i_s[..., k] * bin_op(a, b, k), accumulated in gate order
    starting from a zero tensor shaped like `a`.
    """
    return sum(
        (i_s[..., op] * bin_op(a, b, op) for op in range(16)),
        torch.zeros_like(a),
    )


def load_n(loader, n):
    """Yield exactly `n` items from `loader`, restarting it as often as needed.

    The loader is re-iterated from scratch on every pass, so e.g. a
    DataLoader with shuffle=True draws a fresh order each pass. Nothing is
    yielded when n <= 0.

    Raises:
        ValueError: if a full pass over `loader` yields nothing while more
            items are still required. Without this check the loop would
            never terminate.
    """
    produced = 0
    while produced < n:
        progressed = False
        for item in loader:
            yield item
            progressed = True
            produced += 1
            if produced == n:
                return
        if not progressed:
            raise ValueError("load_n: loader is empty; cannot draw the requested number of items")

In [None]:
def train_step(model, x, y, loss_fn, optimizer):
    """Run one optimisation step on a single batch.

    Clears stale gradients, back-propagates `loss_fn(model(x), y)`, applies
    one optimizer update and returns the batch loss as a Python float.
    """
    optimizer.zero_grad()
    batch_loss = loss_fn(model(x), y)
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()


def train(model, train_loader, test_loader, loss_fn, optimizer, n_steps, print_freq, full_eval: bool) -> dict[str, list]:
    """
    Train `model` for `n_steps` batches, reporting every `print_freq` steps.

    Batches come from `load_n(train_loader, n_steps)`, which re-iterates the
    loader as needed. Inputs are cast to the dtype selected by
    `training_bit_count` and moved to CUDA.

    Every `print_freq` steps, the mean loss over that window is printed and
    recorded. If `full_eval` is True, the soft (train-mode) and discrete
    (eval-mode) accuracies on both loaders are also computed with
    `eval_model`. Otherwise they are recorded as -1 and printed as "n/a".

    Returns:
        ``{'training': [...]}`` with one dict per reporting window holding
        'step' (0-based), 'loss', 'time' (a ``time.time()`` timestamp),
        'train_soft', 'train_discrete', 'test_soft' and 'test_discrete'.
    """
    def _fmt_acc(acc):
        # -1 is the "not evaluated" sentinel; formatting it as a percentage
        # would print a misleading "-100.00%".
        return "n/a" if acc == -1 else f"{acc:4.2%}"

    dtype = BITS_TO_TORCH_FLOATING_POINT_TYPE[training_bit_count]
    loss_values = 0
    results = {'training': []}
    for i, (x, y) in tqdm(enumerate(load_n(train_loader, n_steps)), total=n_steps):
        x = x.to(dtype).to('cuda')
        y = y.to('cuda')
        loss_values += train_step(model, x, y, loss_fn, optimizer)
        if i % print_freq == (print_freq - 1):
            if full_eval:
                train_soft = eval_model(model, train_loader, train_mode=True)
                train_discrete = eval_model(model, train_loader, train_mode=False)
                test_soft = eval_model(model, test_loader, train_mode=True)
                test_discrete = eval_model(model, test_loader, train_mode=False)
            else:
                train_soft = train_discrete = test_soft = test_discrete = -1
            window_loss = loss_values / print_freq
            print(f"Step {i+1}/{n_steps} - Loss: {window_loss:7.4f} - "
                  f"Train soft: {_fmt_acc(train_soft)} - "
                  f"Train discrete: {_fmt_acc(train_discrete)} - "
                  f"Test soft: {_fmt_acc(test_soft)} - "
                  f"Test discrete: {_fmt_acc(test_discrete)}")
            results['training'].append({
                'step': i,
                'loss': window_loss,
                'time': time.time(),
                'train_soft': train_soft,
                'train_discrete': train_discrete,
                'test_soft': test_soft,
                'test_discrete': test_discrete
            })
            print()
            loss_values = 0

    return results


def eval_model(model, loader, train_mode, device='cuda'):
    """
    Return the classification accuracy of `model` over all of `loader`.

    The model is run under ``torch.no_grad()`` in train mode
    (``train_mode=True``, the "soft" setting) or eval mode (the "discrete"
    setting). Inputs are rounded before the forward pass.

    Accuracy is computed as total correct / total samples. Averaging
    per-batch accuracies would over-weight a smaller final batch (e.g. a
    loader with ``drop_last=False``).

    The model's original train/eval mode is restored even if evaluation
    raises.

    Args:
        model: classifier returning logits whose last axis indexes classes.
        loader: iterable of ``(x, y)`` batches.
        train_mode: mode to evaluate in (True = train mode).
        device: device the batches are moved to (defaults to 'cuda').

    Returns:
        Accuracy in [0, 1] as a float, or NaN if `loader` is empty.
    """
    was_training = model.training
    correct = 0
    total = 0
    model.train(mode=train_mode)
    try:
        with torch.no_grad():
            for x, y in loader:
                y = y.to(device)
                pred = model(x.to(device).round()).argmax(-1)
                correct += (pred == y).sum().item()
                total += y.numel()
    finally:
        model.train(mode=was_training)
    if total == 0:
        return float('nan')
    return correct / total


def patch_logic_layer(model: torch.nn.Module, default_tau: float = 1.0, verbose: bool = True) -> None:
    """
    Monkey-patch every LogicLayer in `model` to select gates via Gumbel-Softmax.

    For each LogicLayer found by `model.modules()`, this sets
    `layer.gumbel_tau = default_tau` and overrides `forward`, `forward_python`
    and `forward_cuda` with bound versions of the functions defined below.
    The overrides are instance attributes, so only the layers inside `model`
    are affected. The model is modified in place and nothing is returned.

    In training mode the gate logits in `layer.weights` are replaced by a hard
    Gumbel-Softmax sample over the 16 ops. In eval mode they are replaced by
    the one-hot argmax op.

    Args:
        model: module tree whose LogicLayer submodules are patched.
        default_tau: Gumbel-Softmax temperature stored on each layer as
            `gumbel_tau`. The temperature is re-read from `layer.gumbel_tau`
            on every call, so changing it later takes effect.
        verbose: if True, print a "[GUMBEL] ..." confirmation the first time a
            patched layer runs its forward.
    """

    def _weights_to_ops(self, training):
        # Convert the layer's gate logits (last axis = 16 ops) into op weights.
        if training:
            # Training: a stochastic sample along the op axis. hard=True
            # returns one-hot samples; see the torch.nn.functional.gumbel_softmax
            # docs for how gradients flow through them.
            return F.gumbel_softmax(
                self.weights,
                tau=getattr(self, "gumbel_tau", default_tau),
                hard=True,
                dim=-1,
            )
        else:
            # Eval: deterministic one-hot of the highest-scoring op, as float32.
            return F.one_hot(self.weights.argmax(-1), 16).to(torch.float32)

    # ------------- python path -------------
    def forward_python_gumbel(self, x):
        # One-time confirmation; the flag is stored on the layer instance.
        if verbose and not getattr(self, "_print_done", False):
            print(f"[GUMBEL] LogicLayer(id={id(self)}) using forward_python_gumbel")
            self._print_done = True

        assert x.shape[-1] == self.in_dim
        # Cast the connection indices to int64 once (the cast result is cached
        # back on the layer) before using them to index `x`.
        if self.indices[0].dtype != torch.long:
            self.indices = self.indices[0].long(), self.indices[1].long()

        # Gather each neuron's two inputs and mix the 16 gates with the
        # sampled (train) or argmax (eval) op weights.
        a, b = x[..., self.indices[0]], x[..., self.indices[1]]
        weights = _weights_to_ops(self, self.training)
        return bin_op_s(a, b, weights)

    # ------------- CUDA path ---------------
    def forward_cuda_gumbel(self, x):
        # One-time confirmation; the flag is stored on the layer instance.
        if verbose and not getattr(self, "_print_done", False):
            print(f"[GUMBEL] LogicLayer(id={id(self)}) using forward_cuda_gumbel")
            self._print_done = True

        assert x.ndim == 2
        assert x.device.type == "cuda", x.device
        # Swap the two axes, presumably (batch, in_dim) -> (in_dim, batch) as
        # LogicLayerCudaFunction expects — TODO confirm against difflogic.
        x = x.transpose(0, 1).contiguous()

        a, b = self.indices
        # The kernel receives op weights in the same dtype as the activations.
        w = _weights_to_ops(self, self.training).to(x.dtype)

        # Transpose the kernel output back to the caller's axis order.
        return LogicLayerCudaFunction.apply(
            x, a, b, w,
            self.given_x_indices_of_y_start,
            self.given_x_indices_of_y
        ).transpose(0, 1)

    # ------------- master forward ----------
    def forward_gumbel(self, x):
        # With verbose=True this prints until a patched sub-forward sets
        # `_print_done`. On the PackBitsTensor path below, no patched
        # sub-forward runs, so it prints on every such call.
        if verbose and not getattr(self, "_print_done", False):
            print(f"[GUMBEL] LogicLayer(id={id(self)}) master forward")
            # don't set _print_done here—let sub-forward do it

        # Dispatch on the layer's `implementation` attribute. PackBitsTensor
        # inputs go to the layer's own (unpatched) forward_cuda_eval.
        if self.implementation == "cuda":
            if isinstance(x, PackBitsTensor):
                return self.forward_cuda_eval(x)
            return self.forward_cuda(x)
        elif self.implementation == "python":
            return self.forward_python(x)
        else:
            raise ValueError(self.implementation)

    # -------- apply to every layer ---------
    # Bind the functions above to each LogicLayer instance. Instance
    # attributes shadow the class methods, so `layer(x)` now runs forward_gumbel.
    for layer in model.modules():
        if isinstance(layer, LogicLayer):
            layer.gumbel_tau = default_tau
            layer.forward_python = MethodType(forward_python_gumbel, layer)
            layer.forward_cuda   = MethodType(forward_cuda_gumbel,   layer)
            layer.forward        = MethodType(forward_gumbel,        layer)


def get_dataloaders(root: str = './data/cifar', train_batch_size: int = 128,
                    test_batch_size: int = 1024, num_thresholds: int = 31,
                    ) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, int]:
    """
    Build thermometer-encoded CIFAR-10 train/test DataLoaders.

    Each channel value v produced by ``ToTensor`` is expanded into
    `num_thresholds` binary features ``v > k / (num_thresholds + 1)`` for
    k = 1..num_thresholds, concatenated along the channel dimension. The
    defaults reproduce the original setup (31 thresholds at k/32).

    Args:
        root: directory where CIFAR-10 is stored (downloaded if missing).
        train_batch_size: batch size of the shuffled training loader; the
            incomplete last batch is dropped.
        test_batch_size: batch size of the unshuffled test loader; the last
            batch is kept.
        num_thresholds: number of binarization thresholds per channel value.

    Returns:
        (train_loader, test_loader, in_dim), where in_dim is the number of
        binary input features per flattened image
        (3 * 32 * 32 * num_thresholds).

    Raises:
        ValueError: if num_thresholds < 1.
    """
    if num_thresholds < 1:
        raise ValueError(f"num_thresholds must be >= 1, got {num_thresholds}")
    levels = num_thresholds + 1

    def binarize(x):
        # Stack one {0, 1} map per threshold along dim 0 (the channel dim).
        return torch.cat([(x > (i + 1) / levels).float() for i in range(num_thresholds)], dim=0)

    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(binarize),
    ])
    train_set = torchvision.datasets.CIFAR10(root, train=True, download=True, transform=transforms)
    test_set = torchvision.datasets.CIFAR10(root, train=False, download=True, transform=transforms)

    # CIFAR-10 images are 3 x 32 x 32; each value becomes num_thresholds bits.
    in_dim = 3 * 32 * 32 * num_thresholds

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=train_batch_size,
                                               shuffle=True, pin_memory=True,
                                               drop_last=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=test_batch_size,
                                              shuffle=False, pin_memory=True,
                                              drop_last=False, num_workers=2)
    return train_loader, test_loader, in_dim


def get_model(in_dim: int, width: int=256_000, depth: int=12, gumbel_tau: float=0.2,
             group_sum_k: int=10, group_sum_tau: float=30, gumbel_model: bool=False):
    """
    Build a difflogic classifier: Flatten -> `depth` LogicLayers -> GroupSum.

    Args:
        in_dim: number of input features after flattening.
        width: number of neurons in every LogicLayer.
        depth: number of LogicLayers (must be >= 1).
        gumbel_tau: Gumbel-Softmax temperature; only used when `gumbel_model`.
        group_sum_k: number of output groups (classes) in GroupSum.
        group_sum_tau: GroupSum temperature.
        gumbel_model: if True, patch the LogicLayers with `patch_logic_layer`.

    Returns:
        The model, moved to 'cuda'.

    Raises:
        ValueError: if depth < 1. The first LogicLayer is always built, so a
            smaller depth would otherwise silently yield a one-layer network.
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    logic_layers = [LogicLayer(in_dim, width)]
    logic_layers += [LogicLayer(width, width) for _ in range(depth - 1)]
    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        *logic_layers,
        GroupSum(k=group_sum_k, tau=group_sum_tau)
    )
    model = model.to('cuda')
    if gumbel_model:
        patch_logic_layer(model, default_tau=gumbel_tau, verbose=False)
    return model

In [None]:
# Empirically, the free tier of Google Colab runs ~4.5 iterations/second. For both
# Differentiable LGNs and Gumbel LGNs we allow a compute budget of 2 hours, which
# equates to 4.5 x 2 x 60 x 60 = 32,400 iterations. The default `n_steps` below is
# far smaller (3,000) so the notebook finishes quickly; pass `n_steps=32_400` to
# use the full budget.
def main(gumbel_model: bool, model_depth: int, n_steps: int = 3_000,
         print_freq: int = 1_200, full_eval: bool = True, lr: float = 0.1):
    """
    Train and evaluate one CIFAR-10 logic-gate network.

    Args:
        gumbel_model: if True, build the Gumbel-patched model (tau=1.0);
            otherwise use the unpatched difflogic layers.
        model_depth: number of LogicLayers.
        n_steps: number of training batches.
        print_freq: reporting interval, in steps.
        full_eval: compute accuracies at every report. This is slow because
            each evaluation is a full pass over a dataset.
        lr: Adam learning rate, used for both model variants.

    Returns:
        The dict returned by `train`, extended with "final" (soft/discrete
        accuracy on train and test) and "model" (the trained model).
    """
    loss_fn = torch.nn.CrossEntropyLoss()

    # Load CIFAR-10 dataset
    train_loader, test_loader, in_dim = get_dataloaders()
    # Define model
    model = get_model(in_dim, width=256_000, depth=model_depth, gumbel_tau=1.00, group_sum_k=10,
                      group_sum_tau=30, gumbel_model=gumbel_model)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    results = train(model, train_loader, test_loader, loss_fn, optimizer, n_steps=n_steps,
                    print_freq=print_freq, full_eval=full_eval)

    results["final"] = {
        "train_soft": eval_model(model, train_loader, train_mode=True),
        "train_discrete": eval_model(model, train_loader, train_mode=False),
        "test_soft": eval_model(model, test_loader, train_mode=True),
        "test_discrete": eval_model(model, test_loader, train_mode=False)
    }
    results["model"] = model

    final = results["final"]
    print("Final evaluation:")
    print("train soft:    ", final["train_soft"])
    print("train discrete:", final["train_discrete"])
    print("test soft:     ", final["test_soft"])
    print("test discrete: ", final["test_discrete"])

    return results

In [None]:
# Run 1: Gumbel-patched model (gumbel_model=True), 12 logic layers.
gumbel_model, model_depth = True, 12

start_time = time.time()
gumbel_results = main(gumbel_model=gumbel_model, model_depth=model_depth)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")

In [None]:
# Run 2: baseline (unpatched) difflogic model, 12 logic layers.
gumbel_model, model_depth = False, 12

start_time = time.time()
softmax_results = main(gumbel_model=gumbel_model, model_depth=model_depth)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")

In [None]:
# Put the final metrics of both runs into a table and print it.
import pandas as pd

# One Series (column) per method, indexed by metric name.
gumbel_df = pd.Series(gumbel_results['final'], name='Gumbel')
softmax_df = pd.Series(softmax_results['final'], name='Softmax')
df = pd.DataFrame(dict(Gumbel=gumbel_df, Softmax=softmax_df))

# Discretization gap = soft (train-mode) accuracy - discrete (eval-mode) accuracy.
df.loc["train_discretization_gap"] = df.loc["train_soft"] - df.loc["train_discrete"]
df.loc["test_discretization_gap"] = df.loc["test_soft"] - df.loc["test_discrete"]

# Row order: train metrics first, then test metrics.
df = df.loc[[
    "train_soft", "train_discrete", "train_discretization_gap",
    "test_soft", "test_discrete", "test_discretization_gap",
]]

# Human-readable row labels, e.g. "train_soft" -> "Train soft".
df.index = df.index.str.replace("_", " ").str.capitalize()

print(df)

We see much smaller (even negative) discretization gaps, as well as better accuracy in the discrete setting, for our Gumbel method.

## Create plots of the loss landscapes

We provide the following code for reference only. The models above have not converged within the 3k iterations we used to keep the runtime short; convergence requires more than 50k iterations.

In [None]:

def get_random_direction(model: torch.nn.Module) -> List[torch.Tensor]:
    """
    Draw a random unit-norm direction in the model's trainable-parameter space.

    One standard-normal tensor (same shape, dtype and device) is sampled for
    every parameter with ``requires_grad=True``, in ``model.parameters()``
    order. The whole list is then divided by its global L2 norm, so the
    concatenation of all tensors has norm 1.

    Args:
        model: PyTorch model

    Returns:
        List of tensors aligned with the model's trainable parameters.
    """
    raw = [torch.randn_like(p) for p in model.parameters() if p.requires_grad]
    global_norm = torch.sqrt(sum(torch.sum(t * t) for t in raw))
    return [t / global_norm for t in raw]

def add_direction_to_model(model: torch.nn.Module, direction: List[torch.Tensor], scale: float) -> None:
    """
    Shift the model's trainable parameters in place by ``scale * direction``.

    `direction` is matched positionally with the parameters that have
    ``requires_grad=True`` (the layout produced by `get_random_direction`).
    Frozen parameters are left untouched.

    Args:
        model: PyTorch model
        direction: Direction in parameter space
        scale: Scaling factor
    """
    trainable = (p for p in model.parameters() if p.requires_grad)
    for param, step in zip(trainable, direction):
        param.data.add_(scale * step)

def compute_loss_landscape(
    model: torch.nn.Module,
    train_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    alphas: np.ndarray,
    betas: np.ndarray,
    device: torch.device
) -> Tuple[np.ndarray, List[torch.Tensor], List[torch.Tensor]]:
    """
    Compute a 2-D slice of the loss landscape around the current parameters.

    Two random unit directions d1 and d2 are drawn with `get_random_direction`.
    For every (alpha, beta), the value stored is the mean of the losses on the
    first two batches of `train_loader`, evaluated at
    theta + alpha * d1 + beta * d2.

    The model is evaluated in train mode. Any stochastic train-time behaviour
    (e.g. Gumbel-sampled gates) therefore makes each grid value a single
    random estimate. The original parameters and train/eval mode are restored
    on exit, even if an exception is raised mid-sweep.

    Args:
        model: PyTorch model
        train_loader: DataLoader for training data (must yield >= 2 batches)
        loss_fn: Loss function
        alphas: Scaling factors for the first direction
        betas: Scaling factors for the second direction
        device: Device the two batches are moved to

    Returns:
        (loss_landscape, direction1, direction2), where
        ``loss_landscape[i, j]`` is the loss at ``alphas[i]``, ``betas[j]``.

    Raises:
        ValueError: if `train_loader` yields fewer than two batches.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    # Snapshot of the parameters to restore before each grid point and at the end.
    original_params = [p.detach().clone() for p in trainable]
    was_training = model.training

    # Sample two random directions
    direction1 = get_random_direction(model)
    direction2 = get_random_direction(model)

    # A fixed pair of batches so every grid point is scored on the same data.
    train_iter = iter(train_loader)
    try:
        batch1 = next(train_iter)
        batch2 = next(train_iter)
    except StopIteration:
        raise ValueError("compute_loss_landscape: train_loader must yield at least two batches") from None

    inputs1, targets1 = batch1
    inputs2, targets2 = batch2
    inputs1, targets1 = inputs1.to(device), targets1.to(device)
    inputs2, targets2 = inputs2.to(device), targets2.to(device)

    loss_landscape = np.zeros((len(alphas), len(betas)))

    model.train()
    try:
        for i, alpha in tqdm(enumerate(alphas), total=len(alphas), desc="Computing loss landscape"):
            for j, beta in enumerate(betas):
                # Start every grid point from the original parameters.
                for param, orig in zip(trainable, original_params):
                    param.data.copy_(orig)

                add_direction_to_model(model, direction1, alpha)
                add_direction_to_model(model, direction2, beta)

                with torch.no_grad():
                    loss1 = loss_fn(model(inputs1), targets1)
                    loss2 = loss_fn(model(inputs2), targets2)
                    # Average loss over the two batches.
                    loss_landscape[i, j] = ((loss1 + loss2) / 2).item()
    finally:
        # Undo the perturbation and restore the caller's mode even on error.
        for param, orig in zip(trainable, original_params):
            param.data.copy_(orig)
        model.train(mode=was_training)

    return loss_landscape, direction1, direction2

def plot_loss_landscape(
    loss_landscape: np.ndarray,
    alphas: np.ndarray,
    betas: np.ndarray,
    title: str = "Loss Landscape"
) -> plt.Figure:
    """
    Plot a loss landscape as a 3-D surface with a colour bar labelled "Loss".

    Args:
        loss_landscape: Loss values of shape (len(alphas), len(betas)), where
            ``loss_landscape[i, j]`` is taken at ``alphas[i]``, ``betas[j]``
            (the layout returned by ``compute_loss_landscape``).
        alphas: Scaling factors along the first direction (x axis).
        betas: Scaling factors along the second direction (y axis).
        title: Plot title.

    Returns:
        The Matplotlib figure. It is not shown; call ``plt.show()`` for that.

    Raises:
        ValueError: if `loss_landscape` does not have the shape above.
    """
    loss_landscape = np.asarray(loss_landscape)
    expected_shape = (len(alphas), len(betas))
    if loss_landscape.shape != expected_shape:
        raise ValueError(
            f"loss_landscape has shape {loss_landscape.shape}, expected {expected_shape}"
        )

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    # meshgrid(alphas, betas) produces arrays indexed [beta, alpha], hence .T below.
    alpha_grid, beta_grid = np.meshgrid(alphas, betas)
    surf = ax.plot_surface(
        alpha_grid, beta_grid, loss_landscape.T,
        cmap=plt.cm.viridis,
        linewidth=0,
        antialiased=True
    )

    ax.set_xlabel('Direction 1')
    ax.set_ylabel('Direction 2')
    ax.set_title(title)

    # Hide the z-axis ticks; the colour bar carries the loss scale.
    ax.set_zticks([])
    ax.set_zticklabels([])

    cbar = fig.colorbar(surf, ax=ax, shrink=0.75, aspect=15)
    # Keep the colour-bar tick labels horizontal. tick_params changes only the
    # rotation. set_yticklabels(get_yticklabels()) would freeze the labels
    # with a FixedFormatter and, depending on the Matplotlib version, may read
    # back empty labels before the first draw.
    cbar.ax.tick_params(axis='y', labelrotation=0)
    cbar.ax.yaxis.set_label_position('left')
    cbar.ax.yaxis.set_label_coords(-0.1, 0.5)
    cbar.ax.yaxis.set_label_text('Loss')

    return fig


In [None]:
"""
Example usage of the loss landscape computation.
"""
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_fn   = torch.nn.CrossEntropyLoss()

limit = 1.0
samples = 10

# Define range of scaling factors
alphas = np.linspace(-limit, limit, samples)
betas = np.linspace(-limit, limit, samples)

# Replace with your actual train_loader
# train_loader = ...

train_loader, test_loader, in_dim = get_dataloaders()

# Compute loss landscape
gumbel_loss_landscape, gumbel_direction1, gumbel_direction2 = compute_loss_landscape(
    model=gumbel_results["model"],
    train_loader=train_loader,
    loss_fn=loss_fn,
    alphas=alphas,
    betas=betas,
    device=device
)

softmax_loss_landscape, softmax_direction1, softmax_direction2 = compute_loss_landscape(
    model=softmax_results["model"],
    train_loader=train_loader,
    loss_fn=loss_fn,
    alphas=alphas,
    betas=betas,
    device=device
)


In [None]:
# Render one 3-D surface figure per model, then display them together.
gumbel_fig = plot_loss_landscape(gumbel_loss_landscape, alphas, betas, title="Gumbel Loss Landscape")
softmax_fig = plot_loss_landscape(softmax_loss_landscape, alphas, betas, title="Softmax Loss Landscape")
plt.show()

As stated earlier, the loss-landscape plots above are not representative, because the models have not been trained to convergence.