In [1]:
import numpy as np
import torch

# Import L0 modules

# Fix the global RNG state of NumPy and PyTorch so repeated runs match
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when PyTorch reports one, otherwise stay on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cpu


In [None]:
import numpy as np
import torch

# Import L0 modules
# NOTE(review): no import statement follows this comment. The code further
# down this cell uses `nn`, `F`, `plt`, `L0Conv2d` and `L0Linear`; presumably
# they are imported elsewhere in the notebook — confirm before Run All.

# Re-seed NumPy and PyTorch so this cell starts from a known RNG state
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when PyTorch reports one, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Visualize structured vs unstructured sparsity patterns
def visualize_conv_sparsity(model, layer_name='conv2'):
    """Plot the L0 gate pattern of one convolutional layer as a bar chart.

    Walks ``model.named_modules()`` and plots the first module whose name
    equals ``layer_name`` and which is an ``L0Conv2d`` instance.

    * If the module has a ``channel_gates`` attribute, the values returned by
      ``channel_gates.get_gates()`` are plotted, one bar per entry.
    * Otherwise the values returned by ``weight_gates.get_gates()`` are grouped
      into one row per output channel and the fraction of gates above 0.5 in
      each row is plotted.

    Args:
        model: Object exposing ``named_modules()`` (e.g. an ``nn.Module``).
        layer_name: Module name, as reported by ``named_modules()``, to plot.

    Returns:
        None. The chart is displayed via ``plt.show()``. If no ``L0Conv2d``
        with that name exists, nothing is plotted and no error is raised.
    """
    for name, module in model.named_modules():
        if name != layer_name or not isinstance(module, L0Conv2d):
            continue

        if hasattr(module, 'channel_gates'):
            # Structured sparsity - show channel-wise gates
            gates = module.channel_gates.get_gates().detach().cpu().numpy()
            plt.bar(range(len(gates)), gates)
            plt.xlabel('Channel Index')
            plt.ylabel('Gate Value')
            plt.title(f'{layer_name} Channel Gates (Structured)')
        else:
            # Unstructured sparsity - summarise weight-level gates per channel
            gates = module.weight_gates.get_gates().detach().cpu()
            n_out = module.out_channels

            # One row per output channel. Letting reshape infer the row length
            # (instead of hard-coding in_channels * kh * kw with .view) keeps
            # this working when the per-channel gate count differs, e.g. for
            # grouped convolutions, and when the gate tensor is non-contiguous.
            # Assumes the output channel is the leading gate dimension, as the
            # original view() did — TODO confirm against L0Conv2d.
            gates_per_channel = gates.reshape(n_out, -1)

            # Fraction of gates that are "on" (> 0.5) in each output channel
            fraction_active = (gates_per_channel > 0.5).float().mean(dim=1).numpy()
            plt.bar(range(n_out), fraction_active)
            plt.xlabel('Output Channel Index')
            plt.ylabel('Fraction Active')
            plt.title(f'{layer_name} Weight Sparsity (Unstructured)')

        plt.show()
        break

# Create example CNNs with structured and unstructured L0
class StructuredCNN(nn.Module):
    """Small example CNN built from L0 layers with ``structured=True`` convs.

    Layout: ``conv1`` (3 -> 32 channels, kernel 3) and ``conv2`` (32 -> 64
    channels, kernel 3), each followed by 2x2 max-pooling and ReLU, then a
    10-way ``L0Linear`` classifier over ``64 * 6 * 6`` flattened features.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = L0Conv2d(3, 32, 3, init_sparsity=0.3, structured=True)
        self.conv2 = L0Conv2d(32, 64, 3, init_sparsity=0.5, structured=True)
        # 64 * 6 * 6 presumably corresponds to 32x32 inputs with unpadded,
        # stride-1 convolutions — TODO confirm against L0Conv2d defaults.
        self.fc = L0Linear(64 * 6 * 6, 10, init_sparsity=0.7)

    def forward(self, x):
        """Run both conv stages, flatten per sample, and apply the classifier."""
        for conv in (self.conv1, self.conv2):
            # conv -> 2x2 max-pool -> ReLU
            x = F.relu(F.max_pool2d(conv(x), 2))
        batch_size = x.size(0)
        return self.fc(x.view(batch_size, -1))

class UnstructuredCNN(nn.Module):
    """Small example CNN built from L0 layers with ``structured=False`` convs.

    Layout: ``conv1`` (3 -> 32 channels, kernel 3) and ``conv2`` (32 -> 64
    channels, kernel 3), each followed by 2x2 max-pooling and ReLU, then a
    10-way ``L0Linear`` classifier over ``64 * 6 * 6`` flattened features.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = L0Conv2d(3, 32, 3, init_sparsity=0.3, structured=False)
        self.conv2 = L0Conv2d(32, 64, 3, init_sparsity=0.5, structured=False)
        # 64 * 6 * 6 presumably corresponds to 32x32 inputs with unpadded,
        # stride-1 convolutions — TODO confirm against L0Conv2d defaults.
        self.fc = L0Linear(64 * 6 * 6, 10, init_sparsity=0.7)

    def forward(self, x):
        """Run both conv stages, flatten per sample, and apply the classifier."""
        for conv in (self.conv1, self.conv2):
            # conv -> 2x2 max-pool -> ReLU
            x = F.relu(F.max_pool2d(conv(x), 2))
        batch_size = x.size(0)
        return self.fc(x.view(batch_size, -1))

# Instantiate one model of each variant and move it to the selected device
structured_cnn = StructuredCNN().to(device)
unstructured_cnn = UnstructuredCNN().to(device)

# Plot the conv2 gates of the structured model under its own heading
print("Structured Sparsity Pattern:")
visualize_conv_sparsity(structured_cnn, layer_name='conv2')

# Same plot for the unstructured model
print("\nUnstructured Sparsity Pattern:")
visualize_conv_sparsity(unstructured_cnn, layer_name='conv2')