In [1]:
from probes import LRProbe
from utils import DataManager
import torch as th



In [2]:
# Names of the features to probe for; each one is passed to
# DataManager.add_dataset as the `label` argument.
label_names = [
    "has_alice",
    "has_not",
    "label",
    "has_alice xor has_not",
    "has_alice xor label",
    "has_not xor label",
    "has_alice xor has_not xor label",
]

# "auto" resolves to CUDA when torch reports it available, otherwise CPU.
DEVICE = "auto"
if DEVICE == "auto":
    DEVICE = "cuda" if th.cuda.is_available() else "cpu"

model_base = "state-spaces/mamba-"
sizes = ["1.4b", "2.8b"]
# Forwarded unchanged to DataManager.add_dataset for every model.
revision = ""

# all_accs[model][label_name] -> validation accuracy of an LRProbe (a tensor;
# the plotting cell converts it with .item()).
all_accs = {}

# Iterate over the sizes themselves. The previous `range(3)` index loop ran
# past the end of the two-element `sizes` list and raised IndexError once the
# last model had been processed.
for size in sizes:
    model = model_base + size
    accs = {}
    for label_name in label_names:
        # Fresh DataManager per feature so every probe sees only its own labels.
        dm = DataManager()
        for dataset in ["cities_alice", "neg_cities_alice"]:
            dm.add_dataset(
                dataset,
                model,
                4,
                label=label_name,
                center=False,
                split=0.8,
                device=DEVICE,
                revision=revision,
            )
        # Fit on the 'train' portion, then score on the 'val' portion.
        acts, labels = dm.get("train")
        probe = LRProbe.from_data(acts, labels, bias=True, device=DEVICE)
        acts, labels = dm.get("val")
        # Accuracy = fraction of rounded probe outputs that equal the labels.
        acc = (probe(acts).round() == labels).float().mean()
        accs[label_name] = acc
    all_accs[model] = accs

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading model state-spaces/mamba-1.2b...


OSError: state-spaces/mamba-1.2b is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`

In [12]:
import plotly.express as px

# Turn every stored accuracy into a plain Python number via .item(), keeping
# the same {model: {feature: accuracy}} nesting as all_accs.
all_accs_f = {
    model_name: {feature: acc.item() for feature, acc in feature_accs.items()}
    for model_name, feature_accs in all_accs.items()
}

# The outer keys of all_accs_f are model names (the sweep never varies the
# revision), so the legend is titled "Model" rather than "Revision".
fig = px.bar(all_accs_f, barmode='group')
fig.update_layout(xaxis_title="Feature", yaxis_title="Accuracy", legend_title="Model")
fig.show()

In [4]:
# Fit an LRProbe on the 'train' portion of 'cities' and score it on 'val'.
dm = DataManager()
for dataset in ['cities']:
    dm.add_dataset(
        dataset,
        'state-spaces/mamba-370m',
        14,
        label='label',
        center=False,
        split=0.8,
    )
acts, labels = dm.get('train')
probe = LRProbe.from_data(acts, labels)
acts, labels = dm.get('val')
val_acc = (probe(acts).round() == labels).float().mean()

# Add 'neg_cities' without a split and score the same probe on dm.get('all').
# NOTE(review): 'cities' is still registered in this DataManager, so 'all' may
# include its activations as well as 'neg_cities' - confirm against DataManager.get.
for dataset in ['neg_cities']:
    dm.add_dataset(
        dataset,
        'state-spaces/mamba-370m',
        14,
        label='label',
        center=False,
        split=None,
    )
acts, labels = dm.get('all')
test_acc = (probe(acts).round() == labels).float().mean()

# Bar chart of the two accuracies; this expression is the cell's output.
(
    px.bar(x=['On cities', 'On neg_cities'], y=[val_acc, test_acc], text_auto=True)
    .update_xaxes(title='Dataset')
    .update_yaxes(title='Accuracy')
)

Loading model llama-2-13b...


OSError: llama-2-13b is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`

In [4]:
# New DataManager holding both Alice datasets, labelled by the 'has_alice'
# feature, each registered with split=0.8.
dm = DataManager()
for dataset in ['cities_alice', 'neg_cities_alice']:
    dm.add_dataset(
        dataset,
        'state-spaces/mamba-370m',
        14,
        label='has_alice',
        center=False,
        split=0.8,
    )

In [5]:
# Fit an LRProbe on whatever dm.get('train') returns for the current `dm`.
# `probe` is reused by later cells, so re-run this cell after rebuilding `dm`.
acts, labels = dm.get('train')
probe = LRProbe.from_data(acts, labels)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Validation accuracy: fraction of 'val' examples whose rounded probe output
# equals the label. The bare expression is shown as the cell's output.
acts, labels = dm.get('val')
(probe(acts).round() == labels).float().mean()

tensor(1.)

In [109]:
# New DataManager holding the "*_alice_only" datasets, labelled by 'xor'.
# NOTE(review): unlike the other cells, center=False is not passed here, so
# add_dataset's default for `center` applies - confirm this is intended.
dm = DataManager()
for dataset in ['cities_alice_only', 'neg_cities_alice_only']:
    dm.add_dataset(
        dataset,
        'state-spaces/mamba-370m',
        14,
        label='xor',
        split=0.8,
    )

In [110]:
# Accuracy of the already-fitted `probe` on dm.get('all'); the probe is not
# refit in this cell, so it is whichever probe an earlier cell left in the kernel.
# NOTE(review): presumably 'all' covers every dataset registered in `dm` - confirm.
acts, labels = dm.get('all')
(probe(acts).round() == labels).float().mean()

tensor(0.8068)

In [10]:
import pandas as pd

# Load the Alice dataset, add the three-way XOR feature column, and write the
# result back over the same CSV (without the index).
df = pd.read_csv('datasets/cities_alice.csv')
df['has_alice xor has_not xor label'] = (
    df['has_alice']
    ^ df['has_not']
    ^ df['label'].eq(1)  # label == 1 as a boolean column
)
df.to_csv('datasets/cities_alice.csv', index=False)

In [45]:
# Write whatever `df` currently holds in the kernel to neg_cities_distractor.csv.
# NOTE(review): no visible cell builds a neg_cities_distractor frame; the only
# visible assignment to `df` reads cities_alice.csv - confirm the source of `df`
# before re-running, or this file will be overwritten with the wrong data.
df.to_csv('datasets/neg_cities_distractor.csv', index=False)

In [119]:
# Display the current `df` as the cell's output.
df

Unnamed: 0,statement,label,city,country,correct_country,has_alice,has_not,has_alice xor has_not,has_alice xor label
0,Bob: The city of Krasnodar is in Russia.,1,Krasnodar,Russia,Russia,False,False,False,True
1,Alice: The city of Krasnodar is in South Africa.,0,Krasnodar,South Africa,Russia,True,False,True,True
2,Alice: The city of Lodz is in Poland.,1,Lodz,Poland,Poland,True,False,True,False
3,Bob: The city of Lodz is in the Dominican Repu...,0,Lodz,the Dominican Republic,Poland,False,False,False,False
4,Bob: The city of Maracay is in Venezuela.,1,Maracay,Venezuela,Venezuela,False,False,False,True
...,...,...,...,...,...,...,...,...,...
1491,Alice: The city of Kirkuk is in China.,0,Kirkuk,China,Iraq,True,False,True,True
1492,Bob: The city of Al Mawsil al Jadidah is in Iraq.,1,Al Mawsil al Jadidah,Iraq,Iraq,False,False,False,True
1493,Bob: The city of Al Mawsil al Jadidah is in In...,0,Al Mawsil al Jadidah,India,Iraq,False,False,False,False
1494,Bob: The city of Tangerang is in Indonesia.,1,Tangerang,Indonesia,Indonesia,False,False,False,True
