Merge pull request #20 from OpenMOSS/dev

Comply with Strict Mypy Typing

dest1n1s committed Jun 10, 2024
2 parents 05a5973 + 60236ff commit e62876a

Showing 19 changed files with 200 additions and 546 deletions.
52 changes: 52 additions & 0 deletions .github/workflows/checks.yml
@@ -0,0 +1,52 @@
name: Checks

on:
  push:
    branches:
      - main
      - dev
    paths:
      - "**" # Include all files by default
      - "!.devcontainer/**"
      - "!.vscode/**"
      - "!.git*"
      - "!*.md"
      - "!.github/**"
      - ".github/workflows/checks.yml" # Still include current workflow
  pull_request:
    branches:
      - main
      - dev
    paths:
      - "**"
      - "!.devcontainer/**"
      - "!.vscode/**"
      - "!.git*"
      - "!*.md"
      - "!.github/**"
      - ".github/workflows/checks.yml"
  # Allow this workflow to be called from other workflows
  workflow_call:
    inputs:
      # Requires at least one input to be valid, but in practice we don't need any
      dummy:
        type: string
        required: false

permissions:
  actions: write
  contents: write

jobs:
  code-checks:
    name: Code Checks
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Setup PDM
        uses: pdm-project/setup-pdm@v4
        # You are now able to use PDM in your workflow
      - name: Install dependencies
        run: pdm install
      - name: Type check
        run: pdm run mypy .
112 changes: 72 additions & 40 deletions pdm.lock

Some generated files are not rendered by default.

8 changes: 8 additions & 0 deletions pyproject.toml
@@ -40,4 +40,12 @@ license = {text = "MIT"}
[tool.pdm.dev-dependencies]
dev = [
    "-e file:///${PROJECT_ROOT}/TransformerLens#egg=transformer-lens",
    "mypy>=1.10.0",
]

[tool.mypy]
check_untyped_defs=true
exclude=[".venv/", "examples", "TransformerLens", "tests", "exp"]
ignore_missing_imports=true
allow_redefinition=true
implicit_optional=true
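
For reference, the three relaxations here sit on top of mypy's standard checks: check_untyped_defs makes mypy look inside function bodies that carry no annotations, allow_redefinition lets a name be rebound to a different type within the same block, and implicit_optional treats a None default as making the parameter Optional. A small, self-contained sketch of what each setting affects (none of these functions are from the repository):

    # Illustrative only; these functions are not part of this codebase.
    from typing import Optional


    def describe_feature(index, score):
        # check_untyped_defs=true: mypy type-checks the body of this
        # unannotated function instead of skipping it entirely.
        return f"feature {index}: {score:.2f}"


    def parse_index(raw: str):
        # allow_redefinition=true: `raw` may be rebound to a different type
        # (str -> int) within the same block without an error.
        raw = int(raw)
        return raw


    def get_threshold(name: str, default: float = None) -> Optional[float]:
        # implicit_optional=true: the `None` default implicitly makes the
        # parameter Optional[float] rather than a strict-mode error.
        return default


    print(describe_feature(3, 0.75))   # feature 3: 0.75
    print(parse_index("42") + 1)       # 43
    print(get_threshold("l1"))         # None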
32 changes: 7 additions & 25 deletions server/app.py
@@ -1,4 +1,5 @@
import os
from typing import Any, cast

import numpy as np
import torch
@@ -100,7 +101,7 @@ def list_dictionaries():


@app.get("/dictionaries/{dictionary_name}/features/{feature_index}")
def get_feature(dictionary_name: str, feature_index: str):
def get_feature(dictionary_name: str, feature_index: str | int):
    tokenizer = get_model(dictionary_name).tokenizer
    if isinstance(feature_index, str):
        if feature_index == "random":
@@ -113,7 +114,8 @@ def get_feature(dictionary_name: str, feature_index: str):
                    content=f"Feature index {feature_index} is not a valid integer",
                    status_code=400,
                )
    feature = client.get_feature(dictionary_name, feature_index, dictionary_series=dictionary_series)
    if isinstance(feature_index, int):
        feature = client.get_feature(dictionary_name, feature_index, dictionary_series=dictionary_series)

    if feature is None:
        return Response(
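
The change above is the usual union-plus-narrowing pattern: the path parameter is widened to `str | int`, and the typed client call is guarded with `isinstance` so mypy can prove it receives an int. A reduced, self-contained sketch of the same idea, with toy data standing in for the Mongo-backed client:

    # Toy data and helper; not the real client.
    FEATURES = {0: {"index": 0, "act": 1.5}, 1: {"index": 1, "act": 0.2}}


    def load_feature(feature_index: str | int) -> dict | None:
        feature = None
        if isinstance(feature_index, str):
            if feature_index == "random":
                feature = next(iter(FEATURES.values()))
            else:
                feature_index = int(feature_index)  # str -> int rebinding
        if isinstance(feature_index, int):
            # On this branch mypy has narrowed feature_index to int, so the
            # int-keyed lookup type-checks without a cast.
            feature = FEATURES.get(feature_index)
        return feature


    print(load_feature("1"))        # {'index': 1, 'act': 0.2}
    print(load_feature("random"))   # first toy feature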
@@ -341,13 +343,13 @@ def feature_interpretation(
    interpretation = feature["interpretation"] if "interpretation" in feature else None
    if interpretation is None:
        return Response(content="Feature interpretation not found", status_code=404)
    validation = interpretation["validation"]
    validation = cast(Any, interpretation["validation"])
    if not any(v["method"] == "activation" for v in validation):
        validation_result = check_description(
            model,
            cfg,
            feature_index,
            interpretation["text"],
            cast(str, interpretation["text"]),
            False,
            feature_activation=feature["analysis"][0],
        )
@@ -363,7 +365,7 @@
            model,
            cfg,
            feature_index,
            interpretation["text"],
            cast(str, interpretation["text"]),
            True,
            sae=get_sae(dictionary_name),
        )
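
Because `interpretation` comes back from the document store as loosely typed data, the edits use typing.cast to tell mypy what shape to expect; cast is purely a type-level assertion with no runtime effect. A small, self-contained illustration with an invented record:

    from typing import Any, cast

    # A loosely typed record, e.g. as returned by a document database client.
    record: dict[str, object] = {
        "text": "activates on German city names",
        "validation": [{"method": "generative", "passed": True}],
    }

    text = cast(str, record["text"])              # mypy: str, runtime: unchanged
    validation = cast(Any, record["validation"])  # opt out of checking for this value

    if not any(v["method"] == "activation" for v in validation):
        print(f"needs activation check for: {text!r}")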
@@ -383,26 +385,6 @@
    return interpretation


@app.get("/attn_heads/{layer}/{head}")
def get_attn_head(layer: int, head: int):
    attn_head = client.get_attn_head(layer, head, dictionary_series=dictionary_series)
    if attn_head is None:
        return Response(
            content=f"Attention head {layer}/{head} not found", status_code=404
        )
    attn_scores = [{
        "dictionary1_name": v["dictionary1"]["name"],
        "dictionary2_name": v["dictionary2"]["name"],
        "top_attn_scores": v["top_attn_scores"],
    } for v in attn_head["attn_scores"]]

    return {
        "layer": layer,
        "head": head,
        "attn_scores": attn_scores,
    }


app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
4 changes: 2 additions & 2 deletions src/lm_saes/activation/activation_source.py
@@ -62,10 +62,10 @@ def next_tokens(self, batch_size: int) -> torch.Tensor | None:
class CachedActivationSource(ActivationSource):
    def __init__(self, cfg: ActivationStoreConfig):
        self.cfg = cfg
        assert cfg.use_cached_activations and cfg.cached_activations_path is not None
        assert cfg.use_cached_activations and len(cfg.cached_activations_path) == 1
        assert len(cfg.hook_points) == 1, "CachedActivationSource only supports one hook point"
        self.hook_point = cfg.hook_points[0]
        self.chunk_paths = list_activation_chunks(cfg.cached_activations_path, self.hook_point)
        self.chunk_paths = list_activation_chunks(cfg.cached_activations_path[0], self.hook_point)
        if cfg.use_ddp:
            self.chunk_paths = [p for i, p in enumerate(self.chunk_paths) if i % cfg.world_size == cfg.rank]
        random.shuffle(self.chunk_paths)
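
The new assert is doing double duty: cached_activations_path is evidently a list-typed field now, and asserting on its length both documents the single-path expectation and makes the later [0] index obviously safe. A hedged sketch of the same assert-then-index pattern, using a stand-in config class rather than the real ActivationStoreConfig:

    from dataclasses import dataclass, field


    @dataclass
    class StoreConfig:  # stand-in for ActivationStoreConfig
        use_cached_activations: bool = True
        cached_activations_path: list[str] = field(default_factory=lambda: ["activations/L3M"])
        hook_points: list[str] = field(default_factory=lambda: ["blocks.3.hook_mlp_out"])


    def resolve_cache_dir(cfg: StoreConfig) -> str:
        # Asserting on the list length (rather than `is not None`) matches the
        # list-typed field and justifies the [0] index below.
        assert cfg.use_cached_activations and len(cfg.cached_activations_path) == 1
        assert len(cfg.hook_points) == 1, "only one hook point is supported here"
        return cfg.cached_activations_path[0]


    print(resolve_cache_dir(StoreConfig()))  # activations/L3M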
1 change: 1 addition & 0 deletions src/lm_saes/activation/activation_store.py
@@ -75,6 +75,7 @@ def next_tokens(self, batch_size: int) -> torch.Tensor | None:

    @staticmethod
    def from_config(model: HookedTransformer, cfg: ActivationStoreConfig):
        act_source: ActivationSource
        if cfg.use_cached_activations:
            act_source=CachedActivationSource(cfg=cfg)
        else:
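
The one-line addition, act_source: ActivationSource, pre-declares the variable's type so that assigning different ActivationSource subclasses in the two branches still type-checks; without it, mypy would infer the type from the first assignment and could reject the other branch. A minimal sketch of the pattern with placeholder classes (the real else-branch class is not shown in this diff):

    class ActivationSource:  # placeholder base class
        def next(self) -> str:
            raise NotImplementedError


    class CachedActivationSource(ActivationSource):
        def next(self) -> str:
            return "activation chunk from disk"


    class TokenActivationSource(ActivationSource):  # placeholder alternative
        def next(self) -> str:
            return "activation computed from tokens"


    def build_source(use_cached: bool) -> ActivationSource:
        # Declaring the base type up front keeps both branch assignments valid
        # instead of pinning the variable to the first branch's subclass.
        act_source: ActivationSource
        if use_cached:
            act_source = CachedActivationSource()
        else:
            act_source = TokenActivationSource()
        return act_source


    print(build_source(True).next())  # activation chunk from disk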
