In [1]:
try:
    import google.colab
    IN_COLAB = True
    from tqdm.notebook import tqdm, trange

    from google.colab import drive
    drive.mount("/content/gdrive", force_remount=True)
    %cd /content/gdrive/MyDrive/feature-circuits
    %pip install -r requirements.txt
    !git submodule update --init
except:
    IN_COLAB = False
    from tqdm import tqdm, trange

import os

import torch
from nnsight.models.UnifiedTransformer import UnifiedTransformer
from nnsight import LanguageModel

from circuit_stop_at_layer import get_circuit, save_circuit
from circuit_plotting import plot_circuit
from dictionary_learning import AutoEncoder

from transformer_lens import HookedTransformer

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Report the runtime configuration selected for this session.
print("DEVICE :", DEVICE)
print("IN_COLAB :", IN_COLAB)

  from .autonotebook import tqdm as notebook_tqdm


DEVICE : cuda
IN_COLAB : False


In [2]:
# Load Pythia-70m (deduped) through the UnifiedTransformer wrapper.
pythia70m = UnifiedTransformer("EleutherAI/pythia-70m-deduped", device=DEVICE, processing=False)

# Handles on the submodules that later cells use as dictionary keys / hook points.
pythia70m_embed = pythia70m.embed

# One entry per transformer block, in layer order. The block itself stands in
# for the residual stream (presumably its output -- TODO confirm against get_circuit).
pythia70m_resids = [pythia70m.blocks[idx] for idx in range(len(pythia70m.blocks))]
pythia70m_attns = [block.attn for block in pythia70m_resids]
pythia70m_mlps = [block.mlp for block in pythia70m_resids]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


In [3]:
# Locate the pretrained dictionaries (downloading them if the directory is
# missing), then load one AutoEncoder per hooked submodule into `dictionaries`.
if IN_COLAB:
    base = "/content/gdrive/MyDrive/feature-circuits/"
else:
    # NOTE(review): machine-specific absolute path; this cell only works on the
    # author's workstation unless it is edited.
    base = "C:/Users/Grégoire/Documents/ENS/stages/AttentionGraph/Marks/feature-circuits/"
# NOTE(review): "dictionaires" looks like a misspelling of "dictionaries" --
# confirm it matches the directory the downloader script actually creates.
path = base + "dictionary_learning/dictionaires/pythia-70m-deduped/"

if not os.path.exists(path):
    if IN_COLAB:
        # Run the downloader from inside the dictionary_learning directory.
        # dos2unix is installed and applied to the script first (presumably to
        # strip Windows line endings -- TODO confirm), then it is made executable.
        %cd /content/gdrive/MyDrive/feature-circuits/dictionary_learning
        !apt-get update
        !apt-get install dos2unix
        !dos2unix pretrained_dictionary_downloader.sh
        !chmod +x pretrained_dictionary_downloader.sh
        !./pretrained_dictionary_downloader.sh
        %cd /content/gdrive/MyDrive/feature-circuits
    else:
        # NOTE(review): %run is meant for Python/IPython scripts; confirm that it
        # really executes this .sh file on the local machine.
        %cd C:/Users/Grégoire/Documents/ENS/stages/AttentionGraph/Marks/feature-circuits/dictionary_learning
        %run ./pretrained_dictionary_downloader.sh
        %cd C:/Users/Grégoire/Documents/ENS/stages/AttentionGraph/Marks/feature-circuits

# Maps a model submodule (used directly as the dict key) to its autoencoder.
dictionaries = {}

# Constructor arguments for every AutoEncoder; presumably they must match the
# shapes stored in the ae.pt checkpoints -- TODO confirm.
d_model = 512
dict_size = 32768

# Dictionary for the embedding submodule. Weights are deserialized on the CPU
# first and only then moved to DEVICE.
ae = AutoEncoder(d_model, dict_size)
ae.load_state_dict(torch.load(path + f"embed/ae.pt", map_location='cpu'))
dictionaries[pythia70m_embed] = ae.to(DEVICE)


# One dictionary per transformer block, loaded from resid_out_layer{layer}/ae.pt.
for layer in range(len(pythia70m.blocks)):
    ae = AutoEncoder(d_model, dict_size)
    ae.load_state_dict(torch.load(path + f"resid_out_layer{layer}/ae.pt", map_location='cpu'))
    dictionaries[pythia70m_resids[layer]] = ae.to(DEVICE)

    # Attention and MLP dictionaries are currently disabled:
    # ae = AutoEncoder(d_model, dict_size)
    # ae.load_state_dict(torch.load(path + f"attn_out_layer{layer}/ae.pt", map_location='cpu'))
    # dictionaries[pythia70m_attns[layer]] = ae.to(DEVICE)

    # ae = AutoEncoder(d_model, dict_size)
    # ae.load_state_dict(torch.load(path + f"mlp_out_layer{layer}/ae.pt", map_location='cpu'))
    # dictionaries[pythia70m_mlps[layer]] = ae.to(DEVICE)

In [4]:
def metric_fn_v1(model, trg=None):
    """
    Default metric: the logit of the target token at the final position.

    model : object exposing `model.unembed.output`, indexed here as
            [sequence, position, vocabulary].
    trg   : torch.Tensor of target token indices, one per sequence.

    Returns one logit per entry of `trg`.
    Raises ValueError when `trg` is not supplied.
    """
    if trg is None:
        raise ValueError("trg must be provided")
    final_position_logits = model.unembed.output[:, -1, :]
    # Pair row i with target trg[i] to pick one logit per sequence.
    row_ids = torch.arange(trg.numel())
    return final_position_logits[row_ids, trg]

def metric_fn_v2(model, trg=None):
    """
    Return the negative log-probability of the expected target token.

    model : object exposing `model.unembed.output`, indexed here as
            [sequence, position, vocabulary].
    trg   : torch.Tensor, contains idxs of the target tokens
            (between 0 and d_vocab_out), one per sequence.

    Returns a 1-D tensor holding one value per sequence.
    Raises ValueError when `trg` is not supplied.

    WARNING: this assumes the last token of every sequence sits in the last
    position (if there is padding, it must be in front of the sequence, not
    after it).
    """
    if trg is None:
        raise ValueError("trg must be provided")
    # Slice the final position once and reuse it.
    logits = model.unembed.output[:, -1, :]
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    # gather picks log_probs[i, trg[i]]; squeeze drops the singleton column.
    picked = torch.gather(log_probs, dim=-1, index=trg.view(-1, 1)).squeeze(-1)
    return -1 * picked

In [5]:
batch_size = 1

# The same clean prompt repeated batch_size times; no patch (counterfactual)
# prompt is supplied.
clean = ["When Mary and John went to the store, John gave a drink to"] * batch_size
patch = None

# Target completion. Only the first token id returned by the tokenizer is kept
# (presumably " Mary" encodes to a single token -- TODO confirm).
trg = " Mary"
trg_idx = torch.tensor(
    [pythia70m.tokenizer.encode(trg)[0]] * batch_size,
    device=DEVICE,
)
print(trg_idx)

tensor([6393])


In [6]:
# Inspect the device recorded in the model config: its type, index and full repr.
# NOTE(review): the recorded output shows "cpu" here while DEVICE printed "cuda"
# earlier in the notebook -- possibly outputs from different runs; verify.
d = pythia70m.cfg.device
print(d.type, d.index, d, sep="\n")

cpu
None
cpu


In [9]:
# Run circuit discovery on the clean prompt(s) with the logit metric.
# Later cells read the result as circuit[0] and circuit[1]; circuit[1] is
# indexed by submodule names such as "resid_0" / "resid_1".
# NOTE(review): the last recorded run of this cell ended with
# "IndexError: tensors used as indices must be long, int, byte or bool tensors"
# while processing layer 4 -- the failure is inside get_circuit, not shown here.
circuit = get_circuit(
    clean, patch,  # prompts; patch is None in this notebook
    pythia70m,
    pythia70m_embed, pythia70m_resids,  # only embed + per-block handles are passed
    dictionaries,  # submodule -> AutoEncoder
    metric_fn_v1, {"trg": trg_idx},  # metric and its keyword arguments
    edge_threshold=0.1  # presumably the cutoff below which edges are dropped -- TODO confirm
)

Layer 5 : 3.246279716491699 seconds
Now processing layer 5 with 61 features
[426292, 427617, 428036, 428635, 429814, 430366, 430713, 431013, 431236, 431270, 432770, 432875, 433336, 433538, 433553, 434039, 435004, 435103, 435248, 436423, 437083, 437582, 437956, 438170, 438445, 438546, 440085, 440208, 441947, 442157, 442265, 442819, 442905, 443826, 444575, 444843, 444940, 445946, 446073, 446824, 447449, 447889, 449138, 449642, 449663, 449831, 450998, 451296, 451439, 451503, 452071, 452850, 453938, 454108, 454500, 455574, 455720, 457443, 458717, 458737, 458765]
Layer 4 : 212.16219425201416 seconds
Now processing layer 4 with 75 features
[426292, 427330, 427617, 428036, 428569, 428635, 429814, 430366, 430713, 431013, 431236, 431270, 432770, 432875, 433336, 433386, 433538, 433553, 434039, 434198, 434561, 435004, 435103, 435248, 436164, 436423, 437083, 437582, 437605, 437956, 438170, 438445, 438546, 439139, 439805, 440085, 440208, 441947, 442074, 442157, 442265, 442819, 442905, 443467, 44382

IndexError: tensors used as indices must be long, int, byte or bool tensors

In [None]:
# Scratch check of indexing semantics on a (1, 512) tensor:
#   t[0]          drops the leading axis      -> (512,)
#   t[:, 0:10]    slices the second axis      -> (1, 10)
#   t[..., 0:10]  slices the last axis        -> (1, 10)
t = torch.randn(1, 512)
print(t.shape, t[0].shape, t[:, :10].shape, t[..., :10].shape, sep="\n")

torch.Size([1, 512])
torch.Size([10])
torch.Size([1, 10])
torch.Size([1, 10])


- cpu :
    - 1 : 2m47
    - 2 : /
    - 10: Stop at 68m+

- gpu :
    - 1 : 42s
    - 2 : 1m32

In [None]:
submod_1 = "resid_0"
submod_2 = "resid_1"

from matplotlib import pyplot as plt
from tqdm import tqdm

# Sparse edge-weight tensor between the two submodules. `indices()[0]` is the
# first-dimension index of every stored entry (called "downstream" in this
# notebook) and `values()` holds the matching weights.
edge_weights = circuit[1][submod_1][submod_2]
alive_downstream = edge_weights.indices()[0]
edge_values = edge_weights.values().detach()
set_downstream = torch.unique(alive_downstream).tolist()

# Accumulators indexed by rank i (0 = largest |weight| within a node):
#   ss[i]   - sum over nodes of the cumulative fraction of the node's total
#             weight reached after its i+1 largest-magnitude entries
#   abss[i] - sum over nodes of the absolute value of the rank-i weight
#   nb_k[i] - number of nodes that have at least i+1 entries
ss = []
abss = []
nb_k = []

for k in tqdm(set_downstream):
    # Boolean mask instead of a Python scan over every stored entry.
    weights = edge_values[alive_downstream == k]

    # Order this node's weights by decreasing magnitude.
    perm = torch.argsort(weights.abs(), descending=True)
    weights = weights[perm]

    # NOTE: a node whose weights sum to zero yields inf/nan fractions here.
    tot = weights.sum()
    cumulative_fraction = (torch.cumsum(weights, dim=0) / tot).tolist()
    magnitudes = weights.abs().tolist()

    for i, (fraction, magnitude) in enumerate(zip(cumulative_fraction, magnitudes)):
        # The three lists always grow together, so one length check suffices.
        if i < len(ss):
            ss[i] += fraction
            abss[i] += magnitude
            nb_k[i] += 1
        else:
            ss.append(fraction)
            abss.append(magnitude)
            nb_k.append(1)

# Turn the per-rank sums into per-rank means over the nodes that reach that rank.
ss = [total / count for total, count in zip(ss, nb_k)]
abss = [total / count for total, count in zip(abss, nb_k)]

# Plot ss and abss against the same x-axis, each on its own y-axis.
fig, ax1 = plt.subplots()

color = 'tab:red'
ax1.set_xlabel('weight index')
ax1.set_ylabel('cumulative % of total', color=color)
ax1.plot(ss, color=color)
ax1.tick_params(axis='y', labelcolor=color)

ax2 = ax1.twinx()
color = 'tab:blue'
ax2.set_ylabel('weight', color=color)
ax2.plot(abss, color=color)
ax2.tick_params(axis='y', labelcolor=color)

plt.show()

In [None]:
# Same twin-axis figure as the full plot, restricted to the first
# `max_weights` ranks so the head of the curves is readable.
max_weights = 100

fig, ax1 = plt.subplots()
ax1.set_xlabel('weight index')

# Left axis (red): per-rank mean cumulative share of the total weight.
color = 'tab:red'
ax1.plot(ss[:max_weights], color=color)
ax1.set_ylabel('cumulative % of total', color=color)
ax1.tick_params(axis='y', labelcolor=color)

# Right axis (blue): per-rank mean absolute weight.
ax2 = ax1.twinx()
color = 'tab:blue'
ax2.plot(abss[:max_weights], color=color)
ax2.set_ylabel('weight', color=color)
ax2.tick_params(axis='y', labelcolor=color)

plt.show()

In [None]:
import importlib

import circuit_plotting

# Reload so edits made to circuit_plotting.py on disk are picked up without
# restarting the kernel (reload returns the same, refreshed module object).
circuit_plotting = importlib.reload(circuit_plotting)

# Render the circuit; circuit[0] / circuit[1] are presumably the node and
# edge collections returned by get_circuit -- TODO confirm.
circuit_plotting.plot_circuit(
    circuit[0],
    circuit[1],
    save_dir='./circuit/cpu_2_',
)

In [None]:
# Concatenate the stored values of every edge tensor held in the two-level
# mapping circuit[1] into one flat tensor.
all_weights = torch.cat(
    [
        edge.values()
        for targets in circuit[1].values()
        for edge in targets.values()
    ],
    dim=0,
)
print(all_weights.shape)
print(all_weights.abs().mean())

# Histogram of the weights whose magnitude exceeds 0.01; smaller ones are
# left out of the plot.
plt.hist(
    all_weights[all_weights.abs() > 0.01].detach().cpu().numpy(),
    bins=100,
)
plt.show()

In [None]:
# Scratch comparison of element-wise product vs. batched matrix product.
A = torch.randn(1, 10, 50)
B = torch.randn(1, 10, 50)

# Element-wise product keeps the common shape (1, 10, 50).
print((A * B).shape)
# `A @ B` is invalid for these shapes: matmul contracts the last axis of A (50)
# with the second-to-last axis of B (10). Transposing B's last two axes gives a
# well-formed (1, 10, 50) @ (1, 50, 10) -> (1, 10, 10) product.
print(A @ B.transpose(-1, -2))

In [None]:
import torch

# COO indices are laid out as (number of sparse dims, number of stored entries):
# column j holds the coordinates of entry j, i.e. (0, 1), (99, 2), (27, 199).
dummy_2d_sparse_idx = torch.tensor([[0, 99, 27], [1, 2, 199]])
# One scalar per stored entry. A (2, 3) values tensor does not match the three
# index columns and the 2-D size, so construction would fail.
dummy_2d_sparse_values = torch.randn(dummy_2d_sparse_idx.shape[1])

dummy_2d_sparse = torch.sparse_coo_tensor(
    dummy_2d_sparse_idx,
    dummy_2d_sparse_values,
    size=(100, 200)
)

print(dummy_2d_sparse.to_dense())