Skip to content

Commit

Permalink
Address misc comments in PR #260
Browse files Browse the repository at this point in the history
  • Loading branch information
danbraunai-apollo committed Jan 18, 2024
1 parent d198ff2 commit 6f919e4
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
8 changes: 4 additions & 4 deletions rib/ablations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import time
from pathlib import Path
from typing import Callable, Literal, Optional, Union

import numpy as np
Expand All @@ -23,6 +22,7 @@
from rib.log import logger
from rib.models import MLP, SequentialTransformer
from rib.rib_builder import RibBuildResults
from rib.settings import REPO_ROOT
from rib.types import TORCH_DTYPES, RootPath, StrDtype
from rib.utils import (
check_outfile_overwrite,
Expand All @@ -32,8 +32,8 @@
set_seed,
)

BasisVecs = Union[Float[Tensor, "orig orig_trunc"], Float[Tensor, "orig orig"]]
BasisVecsPinv = Union[Float[Tensor, "orig_trunc orig"], Float[Tensor, "orig orig"]]
BasisVecs = Union[Float[Tensor, "orig rib"], Float[Tensor, "orig orig"]]
BasisVecsPinv = Union[Float[Tensor, "rib orig"], Float[Tensor, "orig orig"]]
AblationAccuracies = dict[str, dict[int, float]]


Expand Down Expand Up @@ -173,7 +173,7 @@ class AblationConfig(BaseModel):
model_config = ConfigDict(extra="forbid", frozen=True)
exp_name: str
out_dir: Optional[RootPath] = Field(
Path(__file__).parent / "out",
REPO_ROOT / "rib_scripts/ablations/out",
description="Directory for the output files. Defaults to `./out/`. If None, no output "
"is written. If a relative path, it is relative to the root of the rib repo.",
)
Expand Down
8 changes: 5 additions & 3 deletions rib/rib_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,13 @@ def verify_n_stochastic_sources(self) -> "RibBuildConfig":
class RibBuildResults(BaseModel):
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
exp_name: str = Field(..., description="The name of the experiment")
gram_matrices: dict[str, torch.Tensor] = Field(description="Gram matrices at each node layer.")
gram_matrices: dict[str, torch.Tensor] = Field(
description="Gram matrices at each node layer.", repr=False
)
interaction_rotations: list[InteractionRotation] = Field(
description="Interaction rotation matrices (e.g. Cs, Us) at each node layer."
description="Interaction rotation matrices (e.g. Cs, Us) at each node layer.", repr=False
)
edges: list[Edges] = Field(description="The edges between each node layer.")
edges: list[Edges] = Field(description="The edges between each node layer.", repr=False)
dist_info: DistributedInfo = Field(
description="Information about the parallelisation setup used for the run."
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_build_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def get_means(results: RibBuildResults, atol: float, batch_size=16):
)


@pytest.mark.slow
# @pytest.mark.slow
@pytest.mark.parametrize(
"basis_formula, edge_formula",
[
Expand Down

0 comments on commit 6f919e4

Please sign in to comment.