In [36]:
from transformers import AutoModelForCausalLM, AutoTokenizer
# from tuned_lens.plotting import plot_logit_lens

# Pythia 13B (deduped). dtype selection is left to transformers via
# torch_dtype="auto"; the weights are then moved to CUDA device 1.
_PYTHIA_CHECKPOINT = "EleutherAI/pythia-13b-deduped"

tokenizer = AutoTokenizer.from_pretrained(_PYTHIA_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_PYTHIA_CHECKPOINT, torch_dtype="auto")
model = model.to("cuda:1")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [39]:
import torch as th

def _load_cpu(path):
    """Deserialize one saved aggregate-metrics file, mapping tensors to the CPU."""
    return th.load(path, map_location="cpu")


# NOTE(review): the variable names (70M/160M/410M/1.4B/12B) do not match the
# size tags in the directory names (19m/125m/350m/1.3b/13b) — presumably the
# same checkpoints under a newer naming scheme; confirm before relabelling.
pythia_70m_metrics = _load_cpu(
    "/mnt/ssd-1/nora/real-lenses/pythia/19m-deduped/affine/eval/aggregate_metrics.pt"
)
pythia_160m_metrics = _load_cpu(
    "/mnt/ssd-1/nora/real-lenses/pythia/125m-deduped/affine/eval/aggregate_metrics.pt"
)
pythia_410m_metrics = _load_cpu(
    "/mnt/ssd-1/nora/real-lenses/pythia/350m-deduped/affine/eval/aggregate_metrics.pt"
)
pythia_1_4b_metrics = _load_cpu(
    "/mnt/ssd-1/nora/real-lenses/pythia/1.3b-deduped/affine/eval/aggregate_metrics.pt"
)
pythia_12b_metrics = _load_cpu(
    "/mnt/ssd-1/nora/real-lenses/pythia/13b-deduped/affine/eval/aggregate_metrics.pt"
)

# GPT-NeoX-20B is the only entry read from a "classic" (not "affine") directory.
neox_20b_metrics = _load_cpu(
    "/mnt/ssd-1/nora/real-lenses/gpt-neox-20b/classic/eval/aggregate_metrics.pt"
)

In [3]:
import torch as th

def _read_metrics(path):
    """Load a saved aggregate-metrics file with every tensor placed on the CPU."""
    return th.load(path, map_location="cpu")


# --- BLOOM 560m: "affine" and "extra-layer" directories ---
bloom_metrics = _read_metrics(
    "/mnt/ssd-1/nora/real-lenses/bloom/560m/affine/eval/aggregate_metrics.pt"
)
bloom_extra_metrics = _read_metrics(
    "/mnt/ssd-1/nora/real-lenses/bloom/560m/extra-layer/eval/aggregate_metrics.pt"
)

# --- GPT-Neo ---
neo_metrics = _read_metrics(
    "/mnt/ssd-1/nora/real-lenses/gpt-neo/125M/extra-layer/eval/aggregate_metrics.pt"
)
neo_1_3b_metrics = _read_metrics(
    "/mnt/ssd-1/nora/real-lenses/gpt-neo/1.3B/extra-layer/eval/aggregate_metrics.pt"
)
# Disabled:
# neo2_7b_metrics = _read_metrics(
#     "/mnt/ssd-1/nora/real-lenses/gpt-neo/2.7B/eval/aggregate_metrics.pt"
# )
neo2_7b_extra_metrics = _read_metrics(
    "/mnt/ssd-1/nora/real-lenses/gpt-neo/2.7B/extra-layer/eval/aggregate_metrics.pt"
)

# --- OPT ---
opt_125m_metrics = _read_metrics(
    "/mnt/ssd-1/nora/real-lenses/opt/125m/affine/eval/aggregate_metrics.pt"
)
# Disabled:
# opt_1_3b_metrics = _read_metrics(
#     "/mnt/ssd-1/nora/real-lenses/opt/1.3b/affine/eval/aggregate_metrics.pt"
# )
# opt_6_7b_metrics = _read_metrics(
#     "/mnt/ssd-1/nora/real-lenses/opt/6.7b/affine/eval/aggregate_metrics.pt"
# )

# --- Pythia / GPT-NeoX ---
pythia_metrics = _read_metrics(
    "/mnt/ssd-1/nora/real-lenses/pythia/125m-deduped/affine/eval/aggregate_metrics.pt"
)
pythia_1_3b_metrics = _read_metrics(
    "/mnt/ssd-1/nora/real-lenses/pythia/1.3b-deduped/affine/eval/aggregate_metrics.pt"
)
neox_metrics = _read_metrics(
    "/mnt/ssd-1/nora/real-lenses/gpt-neox-20b/classic/eval/aggregate_metrics.pt"
)

In [None]:
import plotly.graph_objects as go
import plotly.express as px

# (legend label, metrics) pairs to plot, one line per model.
X = [
    ("70M", pythia_70m_metrics),
    ("160M", pythia_160m_metrics),
    # ("410M", pythia_410m_metrics),
    ("1.4B", pythia_1_4b_metrics),
    ("12B", pythia_12b_metrics),
    # ("20B<br>(NeoX)", neox_20b_metrics),
]

# Every second colour of the sequential Plasma scale.
palette = px.colors.sequential.Plasma[::2]
if len(X) > len(palette):
    # zip() stops at the shorter input, so without this check any series
    # beyond the palette length would vanish from the figure without warning
    # (e.g. when both commented-out entries above are re-enabled).
    raise ValueError(
        f"{len(X)} series to plot but only {len(palette)} colours available; "
        "use a denser palette slice"
    )

fig = go.Figure([
    go.Scatter(
        marker_color=color,
        mode="lines+markers",
        name=name,
        # No x is given, so points are placed at their index in 'lens_ce'.
        y=list(metrics['lens_ce'].values())
    )
    for color, (name, metrics) in zip(palette, X)
]).update_layout(
    hovermode="y unified",
    legend_title="Model size",
    title="Tuned lens perplexity across depth & scale (Pythia)",
).update_xaxes(
    dtick=5,
    title="Layer"
).update_yaxes(
    title="bits per byte",
    type="log"
)
fig

In [141]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# 'baseline_ce' values of the two BLOOM runs. Below, all but the last entry are
# drawn as the "Logit lens" curve and the last entry is used for "Final logits".
vanilla_ll = list(bloom_metrics['baseline_ce'].values())
extra_ll = list(bloom_extra_metrics['baseline_ce'].values())

# Figure width in pixels (= 648.0). Presumably 6.75 in at 96 px per inch.
TEXT_WIDTH = 6.75 * 96


def _layer_axis(n_points, step=5):
    """Build x-axis settings with one tick every `step` positions below `n_points`.

    The tick at position 0 is labelled "input"; every other tick shows its own
    position. Labels are derived from the tick positions, so `tickvals` and
    `ticktext` always have the same length.
    """
    tickvals = list(range(0, n_points, step))
    ticktext = ["input" if pos == 0 else str(pos) for pos in tickvals]
    return dict(tickangle=-20, tickvals=tickvals, ticktext=ticktext)


master = make_subplots(
    cols=2,
    horizontal_spacing=0.05,
    subplot_titles=("w/o final layer", "w/ final layer"),
    x_title="Layer",
    y_title="bits per byte",
).add_traces(
    # Left panel. Legend entries are hidden here because the right panel
    # contributes the same three names.
    cols=1, rows=1,
    data=[
        go.Scatter(
            y=vanilla_ll[:-1],
            marker_color="red",
            marker_symbol="square",
            mode="lines+markers",
            name="Logit lens",
            showlegend=False,
        ),
        go.Scatter(
            y=list(bloom_metrics['lens_ce'].values()),
            marker_color="blue",
            marker_symbol="circle",
            mode="lines+markers",
            name="Tuned lens",
            showlegend=False,
        ),
        # NOTE(review): in both panels this dashed line runs from the last
        # value of the first run to the last value of the second run, and its
        # x extent uses len(extra_ll) even in this panel. Confirm that a
        # per-panel horizontal line was not intended.
        go.Scatter(
            x=(0, len(extra_ll) - 1),
            y=(vanilla_ll[-1], extra_ll[-1]),
            line_color="black",
            line_dash="dash",
            mode="lines",
            name="Final logits",
            showlegend=False,
        ),
    ]
).add_traces(
    # Right panel.
    cols=2, rows=1,
    data=[
        go.Scatter(
            y=extra_ll[:-1],
            marker_color="red",
            marker_symbol="square",
            mode="lines+markers",
            name="Logit lens",
        ),
        go.Scatter(
            y=list(bloom_extra_metrics['lens_ce'].values()),
            marker_color="blue",
            mode="lines+markers",
            name="Tuned lens",
        ),
        go.Scatter(
            x=(0, len(extra_ll) - 1),
            y=(vanilla_ll[-1], extra_ll[-1]),
            line_color="black",
            line_dash="dash",
            mode="lines",
            name="Final logits",
        ),
    ]
).update_annotations(
    font=dict(color="black", size=20),
).update_layout(
    font=dict(color="black", size=16),
    # 2:1 aspect ratio.
    height=TEXT_WIDTH / 2,
    width=TEXT_WIDTH,
    xaxis1=_layer_axis(len(vanilla_ll) - 1),
    xaxis2=_layer_axis(len(extra_ll) - 1),
    yaxis2=dict(
        range=[0, 8.5]
    ),
    legend=dict(
        x=0.24,
    ),
    margin_l=70,
    margin_r=30,
    margin_t=50,
    margin_b=70,
)
master

In [None]:
# Layout constants.
# NOTE(review): none of these names is referenced in the cells visible in this
# part of the notebook — confirm they are still needed. Units are presumably
# inches and pixels-per-inch, going by the names.
marginInches = 1/18
ppi = 96
width_inches = 6.771654 / 2
height_inches = 4

In [78]:
# Scratch check: figure width divided by 72.27.
# The recorded output (6.75) corresponds to TEXT_WIDTH == 487.8225, not to the
# 6.75 * 96 assigned in the plotting cell, so that output is stale.
# NOTE(review): 72.27 is presumably TeX points per inch — confirm.
TEXT_WIDTH / 72.27

6.75

In [142]:
# Export whatever figure is currently bound to `master` as a PDF; no explicit
# size is passed here.
master.write_image("/mnt/ssd-1/nora/perplexity-bloom.pdf")

In [35]:
# Show the last 'baseline_ce' entry of each BLOOM run; these are the two
# endpoints of the dashed "Final logits" line in the BLOOM figure.
vanilla_ll[-1], extra_ll[-1]

(tensor(0.9786), tensor(0.9585))

In [41]:
import plotly.graph_objects as go
import plotly.express as px

# (legend label, metrics) pairs to plot, one line per model.
X = [
    ("70M", pythia_70m_metrics),
    ("160M", pythia_160m_metrics),
    # ("410M", pythia_410m_metrics),
    ("1.4B", pythia_1_4b_metrics),
    ("12B", pythia_12b_metrics),
    # ("20B<br>(NeoX)", neox_20b_metrics),
]

# Every second colour of the sequential Plasma scale.
palette = px.colors.sequential.Plasma[::2]
if len(X) > len(palette):
    # zip() stops at the shorter input, so without this check any series
    # beyond the palette length would vanish from the figure without warning
    # (e.g. when both commented-out entries above are re-enabled).
    raise ValueError(
        f"{len(X)} series to plot but only {len(palette)} colours available; "
        "use a denser palette slice"
    )

fig = go.Figure([
    go.Scatter(
        marker_color=color,
        mode="lines+markers",
        name=name,
        # No x is given, so points are placed at their index in 'lens_ce'.
        y=list(metrics['lens_ce'].values())
    )
    for color, (name, metrics) in zip(palette, X)
]).update_layout(
    hovermode="y unified",
    legend_title="Model size",
    title="Tuned lens perplexity across depth & scale (Pythia)",
).update_xaxes(
    dtick=5,
    title="Layer"
).update_yaxes(
    title="bits per byte",
    type="log"
)
fig

In [44]:
# Export the figure bound to `fig` as a 700x500 PDF.
fig.write_image(
    "/mnt/ssd-1/nora/perplexity-simple.pdf",
    width=700,
    height=500,
)

In [17]:
# Prompt text for the lens plots. The string is a single line of text that
# ends with one newline character.
OURS = (
    "We present evidence that transformers build their "
    "predictions through a latent iterative process\n"
)

In [3]:
from pathlib import Path
from tuned_lens.utils import pytree_stack
import torch as th


# Collect every per-batch result file below the BLOOM-560m eval directory.
# rglob() yields matches in filesystem order, which can differ between runs
# and machines, so the matches are sorted to keep `batches` in a reproducible
# order. The sort is lexicographic on the path (batch_10 precedes batch_2).
# `paths` is now a list, so it can be iterated again after this cell.
paths = sorted(
    Path("/mnt/ssd-1/nora/real-lenses/bloom/560m/affine/eval/").rglob("batch_*.pt")
)

# One loaded object per file, with tensors mapped to the CPU.
batches = [th.load(p, map_location="cpu") for p in paths]

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# GPT-Neo 2.7B model and tokenizer; no device or dtype is specified here.
_NEO_CHECKPOINT = "EleutherAI/gpt-neo-2.7B"

tokenizer = AutoTokenizer.from_pretrained(_NEO_CHECKPOINT)
neo = AutoModelForCausalLM.from_pretrained(_NEO_CHECKPOINT)

In [2]:
from tuned_lens import TunedLens

# Tuned lens saved under the GPT-Neo 2.7B "extra-layer" directory.
lens = TunedLens.load(
    "/mnt/ssd-1/nora/real-lenses/gpt-neo/2.7B/extra-layer"
)

In [1]:
# Prompt text for the lens plots: one line of text with no trailing newline,
# ending right after the word "Transformer".
VASWANI = (
    "The dominant sequence transduction models are based on complex recurrent or "
    "convolutional neural networks that include an encoder and a decoder. The best "
    "performing models also connect the encoder and decoder through an attention "
    "mechanism. We propose a new simple network architecture, the Transformer"
)

In [4]:
from tuned_lens.plotting import plot_logit_lens

# Tuned-lens trace for token positions 47..55 of the VASWANI prompt.
# raw=True is passed so the result can be placed into a subplot later
# (it is handed to add_trace in the front-page figure).
tl_trace = plot_logit_lens(
    neo,
    tokenizer,
    text=VASWANI,
    tuned_lens=lens,
    start_pos=47,
    end_pos=55,
    layer_stride=3,
    metric="prob",
    # min_prob=0.05,
    newline_replacement="",
    whitespace_replacement="",
    raw=True,
)

In [5]:
from tuned_lens.plotting import plot_logit_lens

# Logit-lens counterpart of the tuned-lens trace: same prompt, positions and
# stride, but no tuned_lens argument and extra_decoder_layers=1 instead.
ll_trace = plot_logit_lens(
    neo,
    tokenizer,
    text=VASWANI,
    start_pos=47,
    end_pos=55,
    extra_decoder_layers=1,
    layer_stride=3,
    metric="prob",
    # min_prob=0.05,
    newline_replacement="",
    whitespace_replacement="",
    raw=True,
)

In [75]:
# Scratch arithmetic; the result is displayed but not assigned to any name.
# NOTE(review): 6.75 and 0.1 match the width factor and vertical_spacing used
# in the front-page figure cell — presumably a layout calculation; confirm or remove.
6.75 * 0.1 / 2

0.3375

In [71]:
from plotly.subplots import make_subplots

# Base width in pixels; the figure itself is drawn 10% wider (see below).
COL_WIDTH = 6.75 * 96

# Two stacked panels: the logit-lens trace on top, the tuned-lens trace below.
master = make_subplots(
    rows=2,
    subplot_titles=("Logit Lens (theirs)", "Tuned Lens (ours)"),
    vertical_spacing=0.1,
    x_title="Input token",
    y_title="Layer",
)
master = master.add_trace(ll_trace, row=1, col=1)
master = master.add_trace(tl_trace, row=2, col=1)

# Same colorbar settings for both traces: horizontal, placed below the plot area.
master = master.update_traces(
    colorbar=dict(orientation="h", thickness=15, y=-0.18),
    textfont_size=12,
)

master = master.update_layout(
    font=dict(color="black"),
    height=800,
    width=COL_WIDTH * 1.1,
    margin_t=25,
    margin_r=0,
    margin_b=70,
    margin_l=60,
    # yaxis_tickangle=-20,
    # yaxis2_tickangle=-20,
)
master

In [72]:
# Export the figure currently bound to `master` as a PDF.
master.write_image(
    "/mnt/ssd-1/nora/front-page.pdf",
)

In [4]:
from tuned_lens.nn import TunedLens

# Tuned lens for Pythia 13B (deduped): loaded with map_location="cuda:1",
# then explicitly moved to the same device.
lens = TunedLens.load(
    "/mnt/ssd-1/nora/real-lenses/pythia/13b-deduped/affine",
    map_location="cuda:1",
)
lens = lens.to("cuda:1")

In [8]:
# Tuned-lens plot for token positions 4..15 of the OURS prompt, every 4th layer.
fig = plot_logit_lens(
    model,
    tokenizer,
    text=OURS,
    tuned_lens=lens,
    start_pos=4,
    end_pos=15,
    layer_stride=4,
    whitespace_replacement="",
)

# Remove the figure title and both axis titles; use 14pt text throughout.
fig = fig.update_layout(
    # autosize=False,
    font=dict(size=14),
    title_text=None,
    # title_x=0.5,
    # height=500,
    # width=1100,
    xaxis_title=None,
    yaxis_title=None,
)
fig = fig.update_traces(textfont_size=14)
fig

In [7]:
# Upload `fig` under the name "figure1" without opening a browser; the
# recorded output is the resulting plotly.com URL.
# NOTE(review): `py` is not defined in any cell visible in this part of the
# notebook — presumably `chart_studio.plotly`; confirm and import it
# explicitly so the cell survives Restart & Run All.
py.plot(fig, filename="figure1", auto_open=False)

'https://plotly.com/~norabelrose/23/'

In [19]:
# Tuned lens saved under the BLOOM 560m "affine" directory; rebinds `lens`.
lens = TunedLens.load(
    "/mnt/ssd-1/nora/real-lenses/bloom/560m/affine/"
)

In [18]:
# BLOOM 560m model and tokenizer; rebinds `model` and `tokenizer`.
# No device or dtype is specified here.
_BLOOM_CHECKPOINT = "bigscience/bloom-560m"

tokenizer = AutoTokenizer.from_pretrained(_BLOOM_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_BLOOM_CHECKPOINT)

Downloading:   0%|          | 0.00/688 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/222 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/85.0 [00:00<?, ?B/s]