Use concept-erasure implementation of LEACE and SAL #252
Merged

Changes from all commits (8 commits):

- `dc2cc49` Use concept-erasure implementation of LEACE and SAL (norabelrose)
- `0a70094` fix parameter name in ccs (lauritowal)
- `280343c` Fix test failures (norabelrose)
- `703844c` Merge branch 'leace' of github.com:EleutherAI/elk into leace (norabelrose)
- `fac6247` Be picky about the concept-erasure version (norabelrose)
- `0f6f120` Merge remote-tracking branch 'origin/main' into leace (norabelrose)
- `0f8d0a1` Refactor to support concept-erasure v0.1 (norabelrose)
- `3db2cc8` Fix test failure (norabelrose)
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
First changed file (a package `__init__` defining the public API):

```diff
@@ -1,16 +1,15 @@
-from .ccs_reporter import CcsReporter, CcsReporterConfig
+from .ccs_reporter import CcsConfig, CcsReporter
 from .classifier import Classifier
-from .concept_eraser import ConceptEraser
-from .eigen_reporter import EigenReporter, EigenReporterConfig
-from .reporter import Reporter, ReporterConfig
+from .common import FitterConfig
+from .eigen_reporter import EigenFitter, EigenFitterConfig
+from .platt_scaling import PlattMixin

 __all__ = [
-    "CcsReporter",
-    "CcsReporterConfig",
+    "CcsConfig",
+    "CcsReporter",
     "Classifier",
-    "ConceptEraser",
-    "EigenReporter",
-    "EigenReporterConfig",
-    "Reporter",
-    "ReporterConfig",
+    "EigenFitter",
+    "EigenFitterConfig",
+    "FitterConfig",
+    "PlattMixin",
 ]
```
The CCS reporter module, rewritten around concept-erasure's `LeaceFitter`:

```diff
@@ -3,71 +3,61 @@
 import math
 from copy import deepcopy
 from dataclasses import dataclass, field
-from pathlib import Path
 from typing import Literal, Optional, cast

 import torch
 import torch.nn as nn
+from concept_erasure import LeaceFitter
 from torch import Tensor

 from ..parsing import parse_loss
 from ..utils.typing import assert_type
-from .concept_eraser import ConceptEraser
+from .common import FitterConfig
 from .losses import LOSSES
-from .reporter import Reporter, ReporterConfig
+from .platt_scaling import PlattMixin


 @dataclass
-class CcsReporterConfig(ReporterConfig):
-    """
-    Args:
-        activation: The activation function to use. Defaults to GELU.
-        bias: Whether to use a bias term in the linear layers. Defaults to True.
-        hidden_size: The number of hidden units in the MLP. Defaults to None.
-            By default, use an MLP expansion ratio of 4/3. This ratio is used by
-            Tucker et al. (2022) <https://arxiv.org/abs/2204.09722> in their 3-layer
-            MLP probes. We could also use a ratio of 4, imitating transformer FFNs,
-            but this seems to lead to excessively large MLPs when num_layers > 2.
-        init: The initialization scheme to use. Defaults to "zero".
-        loss: The loss function to use. list of strings, each of the form
-            "coef*name", where coef is a float and name is one of the keys in
-            `elk.training.losses.LOSSES`.
-            Example: --loss 1.0*consistency_squared 0.5*prompt_var
-            corresponds to the loss function 1.0*consistency_squared + 0.5*prompt_var.
-            Defaults to the loss "ccs_squared_loss".
-        normalization: The kind of normalization to apply to the hidden states.
-        num_layers: The number of layers in the MLP. Defaults to 1.
-        pre_ln: Whether to include a LayerNorm module before the first linear
-            layer. Defaults to False.
-        supervised_weight: The weight of the supervised loss. Defaults to 0.0.
-
-        lr: The learning rate to use. Ignored when `optimizer` is `"lbfgs"`.
-            Defaults to 1e-2.
-        num_epochs: The number of epochs to train for. Defaults to 1000.
-        num_tries: The number of times to try training the reporter. Defaults to 10.
-        optimizer: The optimizer to use. Defaults to "adam".
-        weight_decay: The weight decay or L2 penalty to use. Defaults to 0.01.
-    """
-
+class CcsConfig(FitterConfig):
     activation: Literal["gelu", "relu", "swish"] = "gelu"
+    """The activation function to use."""
     bias: bool = True
+    """Whether to use a bias term in the linear layers."""
     hidden_size: Optional[int] = None
+    """
+    The number of hidden units in the MLP. Defaults to None. By default, use an MLP
+    expansion ratio of 4/3. This ratio is used by Tucker et al. (2022)
+    <https://arxiv.org/abs/2204.09722> in their 3-layer MLP probes. We could also use
+    a ratio of 4, imitating transformer FFNs, but this seems to lead to excessively
+    large MLPs when num_layers > 2.
+    """
     init: Literal["default", "pca", "spherical", "zero"] = "default"
+    """The initialization scheme to use."""
     loss: list[str] = field(default_factory=lambda: ["ccs"])
+    """
+    The loss function to use. list of strings, each of the form "coef*name", where coef
+    is a float and name is one of the keys in `elk.training.losses.LOSSES`.
+    Example: `--loss 1.0*consistency_squared 0.5*prompt_var` corresponds to the loss
+    function 1.0*consistency_squared + 0.5*prompt_var.
+    """
     loss_dict: dict[str, float] = field(default_factory=dict, init=False)
     num_layers: int = 1
+    """The number of layers in the MLP."""
     pre_ln: bool = False
+    """Whether to include a LayerNorm module before the first linear layer."""
     supervised_weight: float = 0.0
+    """The weight of the supervised loss."""

     lr: float = 1e-2
+    """The learning rate to use. Ignored when `optimizer` is `"lbfgs"`."""
     num_epochs: int = 1000
+    """The number of epochs to train for."""
     num_tries: int = 10
+    """The number of times to try training the reporter."""
     optimizer: Literal["adam", "lbfgs"] = "lbfgs"
+    """The optimizer to use."""
     weight_decay: float = 0.01
-
-    @classmethod
-    def reporter_class(cls) -> type[Reporter]:
-        return CcsReporter
+    """The weight decay or L2 penalty to use."""

     def __post_init__(self):
         self.loss_dict = parse_loss(self.loss)
@@ -76,19 +66,19 @@ def __post_init__(self):
         self.loss = [f"{coef}*{name}" for name, coef in self.loss_dict.items()]


-class CcsReporter(Reporter):
+class CcsReporter(nn.Module, PlattMixin):
     """CCS reporter network.

     Args:
         in_features: The number of input features.
         cfg: The reporter configuration.
     """

-    config: CcsReporterConfig
+    config: CcsConfig

     def __init__(
         self,
-        cfg: CcsReporterConfig,
+        cfg: CcsConfig,
         in_features: int,
         *,
         device: str | torch.device | None = None,
@@ -106,12 +96,7 @@ def __init__(

         hidden_size = cfg.hidden_size or 4 * in_features // 3

-        self.norm = ConceptEraser(
-            in_features,
-            2 * num_variants,
-            device=device,
-            dtype=dtype,
-        )
+        self.norm = None
         self.probe = nn.Sequential(
             nn.Linear(
                 in_features,
@@ -175,6 +160,8 @@ def reset_parameters(self):

     def forward(self, x: Tensor) -> Tensor:
         """Return the credence assigned to the hidden state `x`."""
+        assert self.norm is not None, "Must call fit() before forward()"
+
         raw_scores = self.probe(self.norm(x)).squeeze(-1)
         return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)


@@ -203,19 +190,22 @@ def fit(self, hiddens: Tensor) -> float:
         x_neg, x_pos = hiddens.unbind(2)

         # One-hot indicators for each prompt template
-        n, v, _ = x_neg.shape
+        n, v, d = x_neg.shape
         prompt_ids = torch.eye(v, device=x_neg.device).expand(n, -1, -1)

-        self.norm.update(
+        fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device)
+        fitter.update(
             x=x_neg,
             # Independent indicator for each (template, pseudo-label) pair
-            y=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
+            z=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
         )
-        self.norm.update(
+        fitter.update(
             x=x_pos,
             # Independent indicator for each (template, pseudo-label) pair
-            y=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
+            z=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
         )
+        self.norm = fitter.eraser

         x_neg, x_pos = self.norm(x_neg), self.norm(x_pos)

         # Record the best acc, loss, and params found so far
@@ -299,9 +289,3 @@ def closure():

         optimizer.step(closure)
         return float(loss)
-
-    def save(self, path: Path | str) -> None:
-        """Save the reporter to a file."""
-        state = {k: v.cpu() for k, v in self.state_dict().items()}
-        state.update(in_features=self.in_features, num_variants=self.num_variants)
-        torch.save(state, path)
```

Inline review note on the `y=` → `z=` change: "fixed; replaced y with z for ccs"
A new module defining `FitterConfig` and the slimmed-down `Reporter`:

```diff
@@ -0,0 +1,31 @@
+"""An ELK reporter network."""
+
+from dataclasses import dataclass
+
+from concept_erasure import LeaceEraser
+from simple_parsing.helpers import Serializable
+from torch import Tensor, nn
+
+from .platt_scaling import PlattMixin
+
+
+@dataclass
+class FitterConfig(Serializable, decode_into_subclasses=True):
+    seed: int = 42
+    """The random seed to use."""
+
+
+@dataclass
+class Reporter(PlattMixin):
+    weight: Tensor
+    eraser: LeaceEraser
+
+    def __post_init__(self):
+        # Platt scaling parameters
+        self.bias = nn.Parameter(self.weight.new_zeros(1))
+        self.scale = nn.Parameter(self.weight.new_ones(1))
+
+    def __call__(self, hiddens: Tensor) -> Tensor:
+        """Return the predicted log odds on input `x`."""
+        raw_scores = self.eraser(hiddens) @ self.weight.mT
+        return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
```
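`PlattMixin` itself is not shown in this diff, but the `scale` and `bias` parameters it calibrates implement standard Platt scaling: fit a slope and offset on frozen logits by minimizing binary cross-entropy. A generic sketch of that idea (a hypothetical helper, not elk's actual mixin):

```python
import torch


def platt_scale(logits: torch.Tensor, labels: torch.Tensor, max_iter: int = 100):
    """Fit sigmoid(scale * logits + bias) to binary labels with L-BFGS."""
    scale = torch.ones(1, requires_grad=True)
    bias = torch.zeros(1, requires_grad=True)
    opt = torch.optim.LBFGS(
        [scale, bias], line_search_fn="strong_wolfe", max_iter=max_iter
    )

    def closure():
        opt.zero_grad()
        # BCE on the rescaled logits; only scale and bias are trained
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            scale * logits + bias, labels
        )
        loss.backward()
        return loss

    opt.step(closure)
    return scale.detach(), bias.detach()
```

Since the slope starts at 1 and the offset at 0, calibration can only lower the BCE relative to the raw logits; the probe's ranking of examples (and hence its AUROC) is unchanged when the fitted slope is positive.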
This file was deleted.
Review discussion:

Q: Isn't the eval() here still needed?

A: No, because CcsReporter doesn't actually have any submodules like nn.BatchNorm or nn.Dropout whose behavior changes due to eval().
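The point in that reply is easy to verify: `train()`/`eval()` only changes the output of modules that branch on `self.training`, such as `nn.Dropout` or the batch-norm family. A quick generic PyTorch check (not elk code; the probe here is a stand-in shaped like CcsReporter's Linear + GELU stack):

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(4, 8)

# A plain Linear + GELU stack: eval() is a no-op for its output
probe = nn.Sequential(nn.Linear(8, 1), nn.GELU())
out_train = probe(x)
probe.eval()
out_eval = probe(x)
assert torch.equal(out_train, out_eval)  # identical in both modes

# Dropout, by contrast, does depend on the mode
drop = nn.Dropout(p=0.5)
drop.train()
y_train = drop(x)  # randomly zeroed and rescaled
drop.eval()
y_eval = drop(x)
assert torch.equal(y_eval, x)  # identity in eval mode
```

So dropping the `eval()` call is safe for this probe, but would not be for a model containing dropout or batch norm.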