Use concept-erasure implementation of LEACE and SAL #252

Merged · 8 commits · Jul 10, 2023
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -24,4 +24,4 @@ repos:
hooks:
- id: codespell
# The promptsource templates spuriously get flagged without this
args: ["-L fpr", "--skip=*.yaml"]
args: ["-L fpr,leace", "--skip=*.yaml"]
6 changes: 3 additions & 3 deletions elk/__init__.py
@@ -1,10 +1,10 @@
from .extraction import Extract, extract_hiddens
from .training import EigenReporter, EigenReporterConfig
from .training import EigenFitter, EigenFitterConfig
from .truncated_eigh import truncated_eigh

__all__ = [
"EigenReporter",
"EigenReporterConfig",
"EigenFitter",
"EigenFitterConfig",
"extract_hiddens",
"Extract",
"truncated_eigh",
4 changes: 1 addition & 3 deletions elk/evaluation/evaluate.py
@@ -9,7 +9,6 @@
from ..files import elk_reporter_dir
from ..metrics import evaluate_preds
from ..run import Run
from ..training import Reporter
from ..utils import Color


@@ -40,8 +39,7 @@ def apply_to_layer(
experiment_dir = elk_reporter_dir() / self.source

reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter = Reporter.load(reporter_path, map_location=device)
reporter.eval()
Collaborator commented:

Isn't the eval() here still needed?

Member (author) replied:

No, because CcsReporter doesn't actually have any submodules like nn.BatchNorm or nn.Dropout whose behavior changes due to eval().
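A minimal sketch, not part of this PR, illustrating the point above: for a probe built only from nn.Linear and activation layers, eval() changes nothing, whereas modules such as nn.Dropout do behave differently across the two modes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 16)

# A probe made only of Linear and GELU layers, similar in spirit to CcsReporter's MLP.
probe = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 1))
with torch.no_grad():
    train_out = probe(x)        # modules start out in train mode
    eval_out = probe.eval()(x)  # eval() is a no-op for Linear/GELU
assert torch.equal(train_out, eval_out)

# By contrast, a Dropout layer does change behavior between the two modes.
noisy = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5))
print(torch.equal(noisy.train()(x), noisy.eval()(x)))  # almost always False
```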

reporter = torch.load(reporter_path, map_location=device)

row_bufs = defaultdict(list)
for ds_name, (val_h, val_gt, _) in val_output.items():
19 changes: 9 additions & 10 deletions elk/training/__init__.py
@@ -1,16 +1,15 @@
from .ccs_reporter import CcsReporter, CcsReporterConfig
from .ccs_reporter import CcsConfig, CcsReporter
from .classifier import Classifier
from .concept_eraser import ConceptEraser
from .eigen_reporter import EigenReporter, EigenReporterConfig
from .reporter import Reporter, ReporterConfig
from .common import FitterConfig
from .eigen_reporter import EigenFitter, EigenFitterConfig
from .platt_scaling import PlattMixin

__all__ = [
"CcsReporter",
"CcsReporterConfig",
"CcsConfig",
"Classifier",
"ConceptEraser",
"EigenReporter",
"EigenReporterConfig",
"Reporter",
"ReporterConfig",
"EigenFitter",
"EigenFitterConfig",
"FitterConfig",
"PlattMixin",
]
100 changes: 42 additions & 58 deletions elk/training/ccs_reporter.py
@@ -3,71 +3,61 @@
import math
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Optional, cast

import torch
import torch.nn as nn
from concept_erasure import LeaceFitter
from torch import Tensor

from ..parsing import parse_loss
from ..utils.typing import assert_type
from .concept_eraser import ConceptEraser
from .common import FitterConfig
from .losses import LOSSES
from .reporter import Reporter, ReporterConfig
from .platt_scaling import PlattMixin


@dataclass
class CcsReporterConfig(ReporterConfig):
"""
Args:
activation: The activation function to use. Defaults to GELU.
bias: Whether to use a bias term in the linear layers. Defaults to True.
hidden_size: The number of hidden units in the MLP. Defaults to None.
By default, use an MLP expansion ratio of 4/3. This ratio is used by
Tucker et al. (2022) <https://arxiv.org/abs/2204.09722> in their 3-layer
MLP probes. We could also use a ratio of 4, imitating transformer FFNs,
but this seems to lead to excessively large MLPs when num_layers > 2.
init: The initialization scheme to use. Defaults to "zero".
loss: The loss function to use. list of strings, each of the form
"coef*name", where coef is a float and name is one of the keys in
`elk.training.losses.LOSSES`.
Example: --loss 1.0*consistency_squared 0.5*prompt_var
corresponds to the loss function 1.0*consistency_squared + 0.5*prompt_var.
Defaults to the loss "ccs_squared_loss".
normalization: The kind of normalization to apply to the hidden states.
num_layers: The number of layers in the MLP. Defaults to 1.
pre_ln: Whether to include a LayerNorm module before the first linear
layer. Defaults to False.
supervised_weight: The weight of the supervised loss. Defaults to 0.0.

lr: The learning rate to use. Ignored when `optimizer` is `"lbfgs"`.
Defaults to 1e-2.
num_epochs: The number of epochs to train for. Defaults to 1000.
num_tries: The number of times to try training the reporter. Defaults to 10.
optimizer: The optimizer to use. Defaults to "adam".
weight_decay: The weight decay or L2 penalty to use. Defaults to 0.01.
"""

class CcsConfig(FitterConfig):
activation: Literal["gelu", "relu", "swish"] = "gelu"
"""The activation function to use."""
bias: bool = True
"""Whether to use a bias term in the linear layers."""
hidden_size: Optional[int] = None
"""
The number of hidden units in the MLP. Defaults to None. By default, use an MLP
expansion ratio of 4/3. This ratio is used by Tucker et al. (2022)
<https://arxiv.org/abs/2204.09722> in their 3-layer MLP probes. We could also use
a ratio of 4, imitating transformer FFNs, but this seems to lead to excessively
large MLPs when num_layers > 2.
"""
init: Literal["default", "pca", "spherical", "zero"] = "default"
"""The initialization scheme to use."""
loss: list[str] = field(default_factory=lambda: ["ccs"])
"""
The loss function to use. list of strings, each of the form "coef*name", where coef
is a float and name is one of the keys in `elk.training.losses.LOSSES`.
Example: `--loss 1.0*consistency_squared 0.5*prompt_var` corresponds to the loss
function 1.0*consistency_squared + 0.5*prompt_var.
"""
loss_dict: dict[str, float] = field(default_factory=dict, init=False)
num_layers: int = 1
"""The number of layers in the MLP."""
pre_ln: bool = False
"""Whether to include a LayerNorm module before the first linear layer."""
supervised_weight: float = 0.0
"""The weight of the supervised loss."""

lr: float = 1e-2
"""The learning rate to use. Ignored when `optimizer` is `"lbfgs"`."""
num_epochs: int = 1000
"""The number of epochs to train for."""
num_tries: int = 10
"""The number of times to try training the reporter."""
optimizer: Literal["adam", "lbfgs"] = "lbfgs"
"""The optimizer to use."""
weight_decay: float = 0.01

@classmethod
def reporter_class(cls) -> type[Reporter]:
return CcsReporter
"""The weight decay or L2 penalty to use."""

def __post_init__(self):
self.loss_dict = parse_loss(self.loss)
@@ -76,19 +66,19 @@ def __post_init__(self):
self.loss = [f"{coef}*{name}" for name, coef in self.loss_dict.items()]


class CcsReporter(Reporter):
class CcsReporter(nn.Module, PlattMixin):
"""CCS reporter network.

Args:
in_features: The number of input features.
cfg: The reporter configuration.
"""

config: CcsReporterConfig
config: CcsConfig

def __init__(
self,
cfg: CcsReporterConfig,
cfg: CcsConfig,
in_features: int,
*,
device: str | torch.device | None = None,
@@ -106,12 +96,7 @@ def __init__(

hidden_size = cfg.hidden_size or 4 * in_features // 3

self.norm = ConceptEraser(
in_features,
2 * num_variants,
device=device,
dtype=dtype,
)
self.norm = None
self.probe = nn.Sequential(
nn.Linear(
in_features,
@@ -175,6 +160,8 @@ def reset_parameters(self):

def forward(self, x: Tensor) -> Tensor:
"""Return the credence assigned to the hidden state `x`."""
assert self.norm is not None, "Must call fit() before forward()"

raw_scores = self.probe(self.norm(x)).squeeze(-1)
return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)

@@ -203,19 +190,22 @@ def fit(self, hiddens: Tensor) -> float:
x_neg, x_pos = hiddens.unbind(2)

# One-hot indicators for each prompt template
n, v, _ = x_neg.shape
n, v, d = x_neg.shape
prompt_ids = torch.eye(v, device=x_neg.device).expand(n, -1, -1)

self.norm.update(
fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device)
fitter.update(
x=x_neg,
# Independent indicator for each (template, pseudo-label) pair
y=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
z=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
Collaborator commented:

Fixed; replaced y with z for CCS.

)
self.norm.update(
fitter.update(
x=x_pos,
# Independent indicator for each (template, pseudo-label) pair
y=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
z=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
)
self.norm = fitter.eraser

x_neg, x_pos = self.norm(x_neg), self.norm(x_pos)

# Record the best acc, loss, and params found so far
@@ -299,9 +289,3 @@ def closure():

optimizer.step(closure)
return float(loss)

def save(self, path: Path | str) -> None:
"""Save the reporter to a file."""
state = {k: v.cpu() for k, v in self.state_dict().items()}
state.update(in_features=self.in_features, num_variants=self.num_variants)
torch.save(state, path)
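For readers new to the concept-erasure package, here is a minimal sketch of the LeaceFitter pattern that fit() now uses; the shapes and labels are illustrative placeholders, not values from the PR. Statistics are accumulated with update(), then the fitted eraser is read off and applied like a normalization layer.

```python
import torch
from concept_erasure import LeaceFitter

n, d, k = 128, 64, 4                          # samples, hidden size, number of concept classes
x = torch.randn(n, d)
z = torch.eye(k)[torch.randint(0, k, (n,))]   # one-hot concept labels, shape (n, k)

# Accumulate the statistics needed for LEACE; update() can be called once per batch.
fitter = LeaceFitter(d, k, dtype=x.dtype, device=x.device)
fitter.update(x=x, z=z)

# Read off the fitted eraser and apply it like a normalization layer.
eraser = fitter.eraser
x_erased = eraser(x)
print(x_erased.shape)  # torch.Size([128, 64])
```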
31 changes: 31 additions & 0 deletions elk/training/common.py
@@ -0,0 +1,31 @@
"""An ELK reporter network."""

from dataclasses import dataclass

from concept_erasure import LeaceEraser
from simple_parsing.helpers import Serializable
from torch import Tensor, nn

from .platt_scaling import PlattMixin


@dataclass
class FitterConfig(Serializable, decode_into_subclasses=True):
seed: int = 42
"""The random seed to use."""


@dataclass
class Reporter(PlattMixin):
weight: Tensor
eraser: LeaceEraser

def __post_init__(self):
# Platt scaling parameters
self.bias = nn.Parameter(self.weight.new_zeros(1))
self.scale = nn.Parameter(self.weight.new_ones(1))

def __call__(self, hiddens: Tensor) -> Tensor:
"""Return the predicted log odds on input `x`."""
raw_scores = self.eraser(hiddens) @ self.weight.mT
return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
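A minimal usage sketch for the new Reporter dataclass defined above; the probe weight and labels below are placeholders for illustration, not anything produced by elk itself (in practice the weight comes from a fitter such as EigenFitter).

```python
import torch
import torch.nn.functional as F
from concept_erasure import LeaceFitter
from elk.training.common import Reporter

d = 64
hiddens = torch.randn(32, d)
labels = torch.randint(0, 2, (32,))

# Fit an eraser on one-hot label indicators, mirroring the pattern used elsewhere in the PR.
fitter = LeaceFitter(d, 2, dtype=hiddens.dtype, device=hiddens.device)
fitter.update(x=hiddens, z=F.one_hot(labels, 2).to(hiddens.dtype))

# Pair a (placeholder) linear probe weight with the fitted eraser.
reporter = Reporter(weight=torch.randn(1, d), eraser=fitter.eraser)
log_odds = reporter(hiddens)  # shape (32,); Platt scale and bias start at 1 and 0
```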
113 changes: 0 additions & 113 deletions elk/training/concept_eraser.py

This file was deleted.
