Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/bump_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ def get_current_version(pyproject_path: Path) -> str:

def infer_bump(changelog_dir: Path) -> str:
fragments = [
f
for f in changelog_dir.iterdir()
if f.is_file() and f.name != ".gitkeep"
f for f in changelog_dir.iterdir() if f.is_file() and f.name != ".gitkeep"
]
if not fragments:
print("No changelog fragments found", file=sys.stderr)
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ help:
@echo " make install Install package in editable mode"
@echo " make install-dev Install with development dependencies"
@echo " make test Run tests with coverage"
@echo " make format Format code with black"
@echo " make format Format code with ruff"
@echo " make type-check Run mypy type checker"
@echo " make changelog Update changelog and version"
@echo " make clean Remove build artifacts"
Expand All @@ -26,7 +26,7 @@ test:
pytest tests/ -v --cov=l0 --cov-report=term-missing

format:
black . -l 79
ruff format .

type-check:
mypy l0 --ignore-missing-imports
Expand Down
1 change: 1 addition & 0 deletions changelog.d/switch-to-ruff.changed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Switched code formatter from Black to Ruff.
4 changes: 2 additions & 2 deletions docs/examples/advanced.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
"np.random.seed(42)\n",
"\n",
"# Set device\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f'Using device: {device}')"
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")"
]
},
{
Expand Down
93 changes: 52 additions & 41 deletions docs/examples/basic_l0.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
"np.random.seed(42)\n",
"\n",
"# Set device\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f'Using device: {device}')"
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")"
]
},
{
Expand Down Expand Up @@ -74,10 +74,7 @@
" def __init__(self, n_features, init_sparsity=0.9):\n",
" super().__init__()\n",
" self.l0_linear = L0Linear(\n",
" n_features, 1,\n",
" bias=False,\n",
" init_sparsity=init_sparsity,\n",
" temperature=0.5\n",
" n_features, 1, bias=False, init_sparsity=init_sparsity, temperature=0.5\n",
" )\n",
"\n",
" def forward(self, x):\n",
Expand All @@ -89,6 +86,7 @@
" def get_sparsity(self):\n",
" return self.l0_linear.get_sparsity()\n",
"\n",
"\n",
"model = L0LinearRegression(n_features, init_sparsity=0.5).to(device)\n",
"print(f\"Initial sparsity: {model.get_sparsity():.2%}\")"
]
Expand All @@ -110,13 +108,11 @@
"# Create a simple linear regression model with L0 regularization\n",
"class L0LinearRegression(nn.Module):\n",
" \"\"\"Linear regression with L0 regularization.\"\"\"\n",
"\n",
" def __init__(self, n_features, init_sparsity=0.9):\n",
" super().__init__()\n",
" self.l0_linear = L0Linear(\n",
" n_features, 1,\n",
" bias=False,\n",
" init_sparsity=init_sparsity,\n",
" temperature=0.5\n",
" n_features, 1, bias=False, init_sparsity=init_sparsity, temperature=0.5\n",
" )\n",
"\n",
" def forward(self, x):\n",
Expand Down Expand Up @@ -204,25 +200,29 @@
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
"\n",
"ax1.plot(losses)\n",
"ax1.set_xlabel('Epoch')\n",
"ax1.set_ylabel('Total Loss')\n",
"ax1.set_title('Training Loss')\n",
"ax1.set_xlabel(\"Epoch\")\n",
"ax1.set_ylabel(\"Total Loss\")\n",
"ax1.set_title(\"Training Loss\")\n",
"ax1.grid(True)\n",
"\n",
"ax2.plot(sparsities)\n",
"ax2.set_xlabel('Epoch')\n",
"ax2.set_ylabel('Sparsity')\n",
"ax2.set_title('Learned Sparsity')\n",
"ax2.set_xlabel(\"Epoch\")\n",
"ax2.set_ylabel(\"Sparsity\")\n",
"ax2.set_title(\"Learned Sparsity\")\n",
"ax2.grid(True)\n",
"ax2.axhline(y=(n_features-n_informative)/n_features, color='r', linestyle='--',\n",
" label=f'True sparsity: {(n_features-n_informative)/n_features:.1%}')\n",
"ax2.axhline(\n",
" y=(n_features - n_informative) / n_features,\n",
" color=\"r\",\n",
" linestyle=\"--\",\n",
" label=f\"True sparsity: {(n_features - n_informative) / n_features:.1%}\",\n",
")\n",
"ax2.legend()\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(f\"Final sparsity: {model.get_sparsity():.2%}\")\n",
"print(f\"True sparsity: {(n_features-n_informative)/n_features:.1%}\")"
"print(f\"True sparsity: {(n_features - n_informative) / n_features:.1%}\")"
]
},
{
Expand Down Expand Up @@ -319,44 +319,55 @@
" n_active = int((1 - final_sparsity) * n_features)\n",
" final_mse = F.mse_loss(model(X).squeeze(), y).item()\n",
"\n",
" results.append({\n",
" 'l0_lambda': l0_lambda,\n",
" 'sparsity': final_sparsity,\n",
" 'n_active': n_active,\n",
" 'mse': final_mse\n",
" })\n",
" results.append(\n",
" {\n",
" \"l0_lambda\": l0_lambda,\n",
" \"sparsity\": final_sparsity,\n",
" \"n_active\": n_active,\n",
" \"mse\": final_mse,\n",
" }\n",
" )\n",
"\n",
" print(f\"L0 penalty={l0_lambda:8.1e}: {n_active:2d} active features, \"\n",
" f\"sparsity={final_sparsity:.1%}, MSE={final_mse:.4f}\")\n",
" print(\n",
" f\"L0 penalty={l0_lambda:8.1e}: {n_active:2d} active features, \"\n",
" f\"sparsity={final_sparsity:.1%}, MSE={final_mse:.4f}\"\n",
" )\n",
"\n",
"# Visualize the relationship\n",
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
"\n",
"# Plot 1: Number of active features vs L0 penalty\n",
"l0_vals = [r['l0_lambda'] for r in results[1:]] # Skip 0\n",
"n_active_vals = [r['n_active'] for r in results[1:]]\n",
"ax1.semilogx(l0_vals, n_active_vals, 'o-', linewidth=2, markersize=8)\n",
"ax1.axhline(y=n_true_features, color='r', linestyle='--', label=f'True features ({n_true_features})')\n",
"ax1.set_xlabel('L0 Penalty (λ)')\n",
"ax1.set_ylabel('Number of Active Features')\n",
"ax1.set_title('L0 Penalty Controls Sparsity')\n",
"l0_vals = [r[\"l0_lambda\"] for r in results[1:]] # Skip 0\n",
"n_active_vals = [r[\"n_active\"] for r in results[1:]]\n",
"ax1.semilogx(l0_vals, n_active_vals, \"o-\", linewidth=2, markersize=8)\n",
"ax1.axhline(\n",
" y=n_true_features,\n",
" color=\"r\",\n",
" linestyle=\"--\",\n",
" label=f\"True features ({n_true_features})\",\n",
")\n",
"ax1.set_xlabel(\"L0 Penalty (λ)\")\n",
"ax1.set_ylabel(\"Number of Active Features\")\n",
"ax1.set_title(\"L0 Penalty Controls Sparsity\")\n",
"ax1.legend()\n",
"ax1.grid(True, alpha=0.3)\n",
"\n",
"# Plot 2: Trade-off between sparsity and accuracy\n",
"sparsity_vals = [r['sparsity'] for r in results]\n",
"mse_vals = [r['mse'] for r in results]\n",
"ax2.plot([100*s for s in sparsity_vals], mse_vals, 'o-', linewidth=2, markersize=8)\n",
"ax2.set_xlabel('Sparsity (%)')\n",
"ax2.set_ylabel('MSE Loss')\n",
"ax2.set_title('Sparsity vs Accuracy Trade-off')\n",
"sparsity_vals = [r[\"sparsity\"] for r in results]\n",
"mse_vals = [r[\"mse\"] for r in results]\n",
"ax2.plot([100 * s for s in sparsity_vals], mse_vals, \"o-\", linewidth=2, markersize=8)\n",
"ax2.set_xlabel(\"Sparsity (%)\")\n",
"ax2.set_ylabel(\"MSE Loss\")\n",
"ax2.set_title(\"Sparsity vs Accuracy Trade-off\")\n",
"ax2.grid(True, alpha=0.3)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(\"\\nKey insight: Higher L0 penalty → Fewer active features → Higher sparsity\")\n",
"print(f\"Ground truth had {n_true_features} important features out of {n_features} total\")"
"print(\n",
" f\"Ground truth had {n_true_features} important features out of {n_features} total\"\n",
")"
]
},
{
Expand Down
55 changes: 28 additions & 27 deletions docs/examples/comparison.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
"np.random.seed(42)\n",
"\n",
"# Set device\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f'Using device: {device}')"
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")"
]
},
{
Expand Down Expand Up @@ -75,13 +75,11 @@
"# Create models for comparison\n",
"class L0LinearRegression(nn.Module):\n",
" \"\"\"Linear regression with L0 regularization.\"\"\"\n",
"\n",
" def __init__(self, n_features, init_sparsity=0.9):\n",
" super().__init__()\n",
" self.l0_linear = L0Linear(\n",
" n_features, 1,\n",
" bias=False,\n",
" init_sparsity=init_sparsity,\n",
" temperature=0.5\n",
" n_features, 1, bias=False, init_sparsity=init_sparsity, temperature=0.5\n",
" )\n",
"\n",
" def forward(self, x):\n",
Expand All @@ -93,8 +91,10 @@
" def get_sparsity(self):\n",
" return self.l0_linear.get_sparsity()\n",
"\n",
"\n",
"class L1Linear(nn.Module):\n",
" \"\"\"Standard linear layer with L1 regularization.\"\"\"\n",
"\n",
" def __init__(self, in_features, out_features):\n",
" super().__init__()\n",
" self.linear = nn.Linear(in_features, out_features, bias=False)\n",
Expand All @@ -108,6 +108,7 @@
" def get_sparsity(self, threshold=1e-3):\n",
" return (self.linear.weight.abs() < threshold).float().mean().item()\n",
"\n",
"\n",
"# Train both models on the same data\n",
"n_features = 100\n",
"X_train = torch.randn(500, n_features)\n",
Expand Down Expand Up @@ -195,33 +196,33 @@
"l1_weights = l1_model.linear.weight.squeeze().detach().cpu().numpy()\n",
"\n",
"# Plot weight distributions\n",
"axes[0, 0].hist(l0_effective, bins=50, edgecolor='black')\n",
"axes[0, 0].set_title('L0: Effective Weight Distribution')\n",
"axes[0, 0].set_xlabel('Weight Value')\n",
"axes[0, 0].set_ylabel('Count')\n",
"axes[0, 0].axvline(x=0, color='r', linestyle='--')\n",
"\n",
"axes[0, 1].hist(l1_weights, bins=50, edgecolor='black')\n",
"axes[0, 1].set_title('L1: Weight Distribution')\n",
"axes[0, 1].set_xlabel('Weight Value')\n",
"axes[0, 1].set_ylabel('Count')\n",
"axes[0, 1].axvline(x=0, color='r', linestyle='--')\n",
"axes[0, 0].hist(l0_effective, bins=50, edgecolor=\"black\")\n",
"axes[0, 0].set_title(\"L0: Effective Weight Distribution\")\n",
"axes[0, 0].set_xlabel(\"Weight Value\")\n",
"axes[0, 0].set_ylabel(\"Count\")\n",
"axes[0, 0].axvline(x=0, color=\"r\", linestyle=\"--\")\n",
"\n",
"axes[0, 1].hist(l1_weights, bins=50, edgecolor=\"black\")\n",
"axes[0, 1].set_title(\"L1: Weight Distribution\")\n",
"axes[0, 1].set_xlabel(\"Weight Value\")\n",
"axes[0, 1].set_ylabel(\"Count\")\n",
"axes[0, 1].axvline(x=0, color=\"r\", linestyle=\"--\")\n",
"\n",
"# Plot sparsity evolution\n",
"axes[1, 0].plot(l0_sparsities, label='L0', linewidth=2)\n",
"axes[1, 0].plot(l1_sparsities, label='L1', linewidth=2)\n",
"axes[1, 0].set_xlabel('Epoch')\n",
"axes[1, 0].set_ylabel('Sparsity')\n",
"axes[1, 0].set_title('Sparsity Evolution')\n",
"axes[1, 0].plot(l0_sparsities, label=\"L0\", linewidth=2)\n",
"axes[1, 0].plot(l1_sparsities, label=\"L1\", linewidth=2)\n",
"axes[1, 0].set_xlabel(\"Epoch\")\n",
"axes[1, 0].set_ylabel(\"Sparsity\")\n",
"axes[1, 0].set_title(\"Sparsity Evolution\")\n",
"axes[1, 0].legend()\n",
"axes[1, 0].grid(True)\n",
"\n",
"# Plot gate distribution for L0\n",
"axes[1, 1].hist(l0_gates, bins=50, edgecolor='black')\n",
"axes[1, 1].set_title('L0: Gate Value Distribution')\n",
"axes[1, 1].set_xlabel('Gate Value')\n",
"axes[1, 1].set_ylabel('Count')\n",
"axes[1, 1].axvline(x=0.5, color='r', linestyle='--', label='Threshold')\n",
"axes[1, 1].hist(l0_gates, bins=50, edgecolor=\"black\")\n",
"axes[1, 1].set_title(\"L0: Gate Value Distribution\")\n",
"axes[1, 1].set_xlabel(\"Gate Value\")\n",
"axes[1, 1].set_ylabel(\"Count\")\n",
"axes[1, 1].axvline(x=0.5, color=\"r\", linestyle=\"--\", label=\"Threshold\")\n",
"axes[1, 1].legend()\n",
"\n",
"plt.tight_layout()\n",
Expand Down
36 changes: 22 additions & 14 deletions docs/examples/feature_selection.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
"np.random.seed(42)\n",
"\n",
"# Set device\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f'Using device: {device}')"
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")"
]
},
{
Expand Down Expand Up @@ -102,9 +102,10 @@
"feature_gate = FeatureGate(\n",
" n_features=n_features,\n",
" max_features=50, # Select at most 50 features\n",
" temperature=0.2\n",
" temperature=0.2,\n",
").to(device)\n",
"\n",
"\n",
"# Simple classifier on top of selected features\n",
"class FeatureSelectClassifier(nn.Module):\n",
" def __init__(self, feature_gate):\n",
Expand All @@ -118,6 +119,7 @@
" x_gated = x * gates\n",
" return self.classifier(x_gated).squeeze()\n",
"\n",
"\n",
"model = FeatureSelectClassifier(feature_gate).to(device)\n",
"optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
"\n",
Expand All @@ -136,7 +138,9 @@
"\n",
" if (epoch + 1) % 25 == 0:\n",
" n_selected = len(feature_gate.get_active_indices())\n",
" print(f\"Epoch {epoch+1}: Loss={total_loss:.4f}, Selected features={n_selected}\")"
" print(\n",
" f\"Epoch {epoch + 1}: Loss={total_loss:.4f}, Selected features={n_selected}\"\n",
" )"
]
},
{
Expand Down Expand Up @@ -179,26 +183,30 @@
"top_30_indices = torch.topk(importance, 30).indices.cpu().numpy()\n",
"\n",
"print(\"\\nTop 30 selected features:\")\n",
"print(f\"True informative features in top 30: {sum(i < n_informative for i in top_30_indices)}/{n_informative}\")\n",
"print(\n",
" f\"True informative features in top 30: {sum(i < n_informative for i in top_30_indices)}/{n_informative}\"\n",
")\n",
"\n",
"# Visualize feature importance\n",
"plt.figure(figsize=(14, 5))\n",
"\n",
"plt.subplot(1, 2, 1)\n",
"plt.bar(range(100), importance[:100].cpu().numpy())\n",
"plt.axvline(x=n_informative-0.5, color='r', linestyle='--', label='True/Noise boundary')\n",
"plt.xlabel('Feature Index')\n",
"plt.ylabel('Importance (Gate Value)')\n",
"plt.title('Feature Importance (First 100 features)')\n",
"plt.axvline(\n",
" x=n_informative - 0.5, color=\"r\", linestyle=\"--\", label=\"True/Noise boundary\"\n",
")\n",
"plt.xlabel(\"Feature Index\")\n",
"plt.ylabel(\"Importance (Gate Value)\")\n",
"plt.title(\"Feature Importance (First 100 features)\")\n",
"plt.legend()\n",
"\n",
"plt.subplot(1, 2, 2)\n",
"# Show importance distribution\n",
"plt.hist(importance.cpu().numpy(), bins=50, edgecolor='black')\n",
"plt.xlabel('Importance Value')\n",
"plt.ylabel('Count')\n",
"plt.title('Distribution of Feature Importance')\n",
"plt.axvline(x=0.5, color='r', linestyle='--', label='Selection threshold')\n",
"plt.hist(importance.cpu().numpy(), bins=50, edgecolor=\"black\")\n",
"plt.xlabel(\"Importance Value\")\n",
"plt.ylabel(\"Count\")\n",
"plt.title(\"Distribution of Feature Importance\")\n",
"plt.axvline(x=0.5, color=\"r\", linestyle=\"--\", label=\"Selection threshold\")\n",
"plt.legend()\n",
"\n",
"plt.tight_layout()\n",
Expand Down
Loading