# MechIR Main Demo

## Setup

In [52]:
%pip install ..
%pip install transformer_lens

Note: you may need to restart the kernel to use updated packages.




In [1]:
from mechir import Dot, MechIRDataset, DotDataCollator

from torch.utils.data import DataLoader
# import pandas as pd



## Load Model

In [2]:
model = Dot("sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco")

If using BERT for interpretability research, keep in mind that BERT has some significant architectural differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.


Moving model to device:  mps
Loaded pretrained model sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco into HookedEncoder
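As the class name and the model name (`distilbert-dot-...`) suggest, `Dot` wraps a dot-product bi-encoder: the query and the document are encoded separately, and relevance is the dot product of their embeddings. The sketch below shows that scoring rule on random tensors standing in for the model's final hidden states. It assumes [CLS] pooling (each sequence is represented by its position-0 vector), which is how TAS-B is usually described; `dot_score` is a hypothetical helper, not part of MechIR's API, and how `Dot` pools internally is not shown here.

```python
import torch


def dot_score(query_hidden: torch.Tensor, doc_hidden: torch.Tensor) -> torch.Tensor:
    """Score query/document pairs with a bi-encoder dot product (hypothetical helper).

    Both inputs are final hidden states of shape (batch, seq_len, d_model).
    Assumes [CLS] pooling: each sequence embedding is its position-0 vector.
    """
    q_emb = query_hidden[:, 0, :]  # (batch, d_model)
    d_emb = doc_hidden[:, 0, :]    # (batch, d_model)
    # Row-wise dot product: one score per (query, document) pair
    return (q_emb * d_emb).sum(dim=-1)


# Toy "hidden states": batch of 2, d_model = 768; queries and documents may differ in length
torch.manual_seed(0)
q = torch.randn(2, 12, 768)
d = torch.randn(2, 40, 768)
scores = dot_score(q, d)
print(scores.shape)  # torch.Size([2])
```

Because each side is pooled to a single vector before scoring, the query and document sequence lengths do not need to match.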


## Load Dataset

In [3]:
# Load smallest dataset for quick testing
dataset = MechIRDataset("vaswani")
dataset.pairs.head()

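|   | query_id | doc_id | relevance | iteration |
|---|----------|--------|-----------|-----------|
| 0 | 1 | 1239 | 1 | 0 |
| 1 | 1 | 1502 | 1 | 0 |
| 2 | 1 | 4462 | 1 | 0 |
| 3 | 1 | 4569 | 1 | 0 |
| 4 | 1 | 5472 | 1 | 0 |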


In [4]:
# Print examples of queries
print("Total queries in dataset:", len(dataset.queries))
print("\n----------- Examples of queries: -----------\n")
example_queries = list(dataset.queries.values())[:3]
for query in example_queries:
    print(query)

Total queries in dataset: 93

----------- Examples of queries: -----------

MEASUREMENT OF DIELECTRIC CONSTANT OF LIQUIDS BY THE USE OF MICROWAVE TECHNIQUES

MATHEMATICAL ANALYSIS AND DESIGN DETAILS OF WAVEGUIDE FED MICROWAVE RADIATIONS

USE OF DIGITAL COMPUTERS IN THE DESIGN OF BAND PASS FILTERS HAVING GIVEN PHASE AND ATTENUATION CHARACTERISTICS



In [5]:
# Calculate document stats
doc_lengths = [len(doc.split()) for doc in dataset.docs.values()]

# Print examples of documents
print("Total documents in dataset:", len(dataset.docs))
print(f"Minimum Length (in words): {min(doc_lengths)}")
print(f"Maximum Length (in words): {max(doc_lengths)}")
print(f"Average Length (in words): {(sum(doc_lengths) / len(doc_lengths) if doc_lengths else 0):.2f}")
print("\n----------- Examples of documents: -----------\n")
example_docs = list(dataset.docs.values())[:3]
for doc in example_docs:
    print(doc)

Total documents in dataset: 11429
Minimum Length (in words): 2
Maximum Length (in words): 269
Average Length (in words): 41.93

----------- Examples of documents: -----------

compact memories have flexible capacities  a digital data storage
system with capacity up to bits and random and or sequential access
is described

an electronic analogue computer for solving systems of linear equations
mathematical derivation of the operating principle and stability
conditions for a computer consisting of amplifiers

electronic coordinate transformer  circuit details are given for
the construction of an electronic calculating unit which enables
the polar coordinates of a vector modulus and cosine or sine of the
argument to be derived from those of a rectangular system of axes



In [6]:
# Print example of one query and relevant documents
query_rel_doc_ex_df = dataset.pairs.head(3)[["query_id", "doc_id"]]
query_id = query_rel_doc_ex_df["query_id"].unique()[0]
doc_ids = query_rel_doc_ex_df["doc_id"]
print(f"Query: {dataset.queries[query_id]}")
print("-------------------------------")
for doc_id in doc_ids:
    print(f"Document:\n{dataset.docs[doc_id]}")

Query: MEASUREMENT OF DIELECTRIC CONSTANT OF LIQUIDS BY THE USE OF MICROWAVE TECHNIQUES

-------------------------------
Document:
broadband millimetre wave paramagnetic resonance spectrometer  the
specimen and waveguide which can be cooled by means of a cryostat
are placed between close pole pieces giving high uniform magnetic
fields  design details and some measurements on zero field splittings
are given

Document:
microwave measurements of dielectric absorption in dilute solutions

Document:
dielectric properties of ice at very low frequencies and the influence
of a polarizing field  measurements at frequencies down to are reported
the loss factor passes through a low frequency maximum which is distinguishable
from that associated with the dipole dispersion by its different
temperature dependence  the effect of impurities is to shift the
maximum towards higher frequencies  application of a unidirectional
field does not affect the permittivity of the pure crystals but eliminates
the 

### Simple Preprocessing

In [7]:
# TODO: decide whether to strip newlines from the document text
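Stripping newlines could be done by collapsing every run of whitespace, including the line breaks visible in the printed documents above, into a single space. A minimal sketch, assuming `dataset.docs` is a plain `dict` from document ID to text (as its use with `.values()` and `[doc_id]` above suggests); `normalize_whitespace` is a hypothetical helper, and a toy dictionary stands in for `dataset.docs`.

```python
import re


def normalize_whitespace(text: str) -> str:
    """Collapse newlines, tabs and repeated spaces into single spaces (hypothetical helper)."""
    return re.sub(r"\s+", " ", text).strip()


# Toy stand-in for dataset.docs (doc_id -> text)
docs = {
    1: "compact memories have flexible capacities  a digital data storage\nsystem with capacity up to bits\n",
}
clean_docs = {doc_id: normalize_whitespace(text) for doc_id, text in docs.items()}
print(clean_docs[1])
# compact memories have flexible capacities a digital data storage system with capacity up to bits
```

DistilBERT's WordPiece tokenizer already treats newlines as ordinary whitespace, and `str.split()` does too, so this step should not change the token IDs or the word counts computed above; it mainly tidies printed output.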

### Create Paired Dataset

In [24]:
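# Define a perturbation: for now, append a static string; a TF/IDF-based perturbation will be tested later
# from mechir.perturb import perturbation  # TODO: this import currently fails
def append_term(text, query=None):
    # `query` is unused by this static perturbation; the leading space keeps the appended terms separate from the last word
    return text + " apple banana orange"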
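The comment above mentions testing a TF/IDF-based perturbation later. One hypothetical shape for it, keeping the same `(text, query=None)` signature as `append_term`: pick the query term with the highest inverse document frequency over a corpus and append it to the document, which raises that term's frequency in the document. `make_idf_perturbation` and its smoothing formula are assumptions for illustration, not `mechir.perturb`.

```python
import math
from collections import Counter


def make_idf_perturbation(corpus, n_repeats=1):
    """Build a hypothetical query-dependent perturbation.

    The returned function appends the query term with the highest IDF in `corpus`
    to the document `n_repeats` times, raising that term's frequency in the document.
    """
    n_docs = len(corpus)
    # Document frequency: number of documents containing each (lowercased) term
    df = Counter()
    for doc in corpus:
        df.update(set(doc.lower().split()))

    def idf(term):
        # Smoothed IDF; terms absent from the corpus get the largest value
        return math.log((n_docs + 1) / (df[term] + 1)) + 1

    def perturb(text, query=None):
        terms = query.lower().split() if query else []
        if not terms:
            return text  # nothing to inject without query terms
        best_term = max(terms, key=idf)
        return text + " " + " ".join([best_term] * n_repeats)

    return perturb


corpus = ["the cat sat", "the dog sat", "a cat ran"]
perturb = make_idf_perturbation(corpus, n_repeats=2)
print(perturb("a document about pets", query="the dog"))  # a document about pets dog dog
```

In the notebook, the corpus would presumably be `list(dataset.docs.values())`; whether `DotDataCollator` passes the query text to the perturbation is not confirmed here.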

In [10]:
# Register [X] as an extra special token; the collator uses it as filler in the original document
model.tokenizer.add_special_tokens({"additional_special_tokens": ["[X]"]})
print(model.tokenizer.special_tokens_map)
print(model.tokenizer.convert_tokens_to_ids('[X]'))

{'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]', 'additional_special_tokens': ['[X]']}
30522
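The decoded batches further down show where `[X]` ends up: the original document is padded with `[X]` tokens so that it has as many tokens as the perturbed document, presumably to keep token positions aligned between the two runs for position-wise activation patching. The sketch below reproduces that alignment for an appended perturbation on plain lists of token IDs. `pad_with_filler` is a hypothetical helper, not the `DotDataCollator` implementation; it assumes both sequences end with a single `[SEP]` and that the original tokens are a prefix of the perturbed ones.

```python
def pad_with_filler(original_ids, perturbed_ids, filler_id):
    """Insert filler tokens before the final [SEP] so both sequences have equal length.

    Hypothetical helper: assumes an appended perturbation, so the original tokens
    (minus the trailing [SEP]) are a prefix of the perturbed tokens.
    """
    n_missing = len(perturbed_ids) - len(original_ids)
    if n_missing < 0:
        raise ValueError("perturbed sequence is shorter than the original")
    # Keep everything before the trailing [SEP], add fillers, then re-append [SEP]
    return original_ids[:-1] + [filler_id] * n_missing + original_ids[-1:]


CLS, SEP, X = 101, 102, 30522  # [CLS]/[SEP] ids of the uncased BERT vocabulary, and the [X] id printed above
original = [CLS, 11, 12, SEP]               # arbitrary token ids for illustration
perturbed = [CLS, 11, 12, 13, 14, 15, SEP]  # three appended tokens
print(pad_with_filler(original, perturbed, X))  # [101, 11, 12, 30522, 30522, 30522, 102]
```

Note that 30522 is one past the last ID of DistilBERT's original 30,522-token vocabulary, so the model's embedding matrix has no row for `[X]`; presumably the filler positions are replaced, or the embeddings resized, before the forward pass.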


In [25]:
data_collator = DotDataCollator(model.tokenizer, append_term)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=data_collator)



In [18]:
# Helper for inspecting batches
def pretty_print_triplets(batch, tokenizer):
    """
    Pretty-prints (query, document, perturbed document) triplets from a batch.

    Args:
        batch (dict): A dictionary containing 'queries', 'documents', and 'perturbed_documents' from a DataLoader.
        tokenizer: The tokenizer used to decode the input IDs.
    """
    # Get the queries, documents, and perturbed documents from the batch
    queries = batch["queries"]
    documents = batch["documents"]
    perturbed_documents = batch["perturbed_documents"]

    # Loop over the examples in the batch (i is the position within the batch, not a dataset doc_id)
    for i in range(len(documents["input_ids"])):
        # Get the input IDs
        query_ids = queries["input_ids"][i]
        original_doc_ids = documents["input_ids"][i]
        perturbed_doc_ids = perturbed_documents["input_ids"][i]

        # Decode the input IDs to text, keeping special tokens such as [CLS], [SEP] and [X]
        query_decoded = tokenizer.decode(query_ids.tolist(), skip_special_tokens=False)
        original_doc_decoded = tokenizer.decode(original_doc_ids.tolist(), skip_special_tokens=False)
        perturbed_doc_decoded = tokenizer.decode(perturbed_doc_ids.tolist(), skip_special_tokens=False)

        # Pretty print
        print(f"Triplet {i + 1}:")
        print("Query (Decoded):", query_decoded)
        print("Original Document (Decoded):", original_doc_decoded)
        print("Perturbed Document (Decoded):", perturbed_doc_decoded)
        print("=" * 50)  # Separator for clarity

In [29]:
for i, batch in enumerate(dataloader):
    pretty_print_triplets(batch, model.tokenizer)

    # stop after 2 batches
    if i == 1:
        break

Triplet 1:
Query (Decoded): [CLS] measurement of dielectric constant of liquids by the use of microwave techniques [SEP]
Original Document (Decoded): [CLS] broadband millimetre wave paramagnetic resonance spectrometer the specimen and waveguide which can be cooled by means of a cryostat are placed between close pole pieces giving high uniform magnetic fields design details and some measurements on zero field splittings are given [X] [X] [X] [SEP]
Perturbed Document (Decoded): [CLS] broadband millimetre wave paramagnetic resonance spectrometer the specimen and waveguide which can be cooled by means of a cryostat are placed between close pole pieces giving high uniform magnetic fields design details and some measurements on zero field splittings are given apple banana orange [SEP]
Triplet 1:
Query (Decoded): [CLS] measurement of dielectric constant of liquids by the use of microwave techniques [SEP]
Original Document (Decoded): [CLS] microwave measurements of dielectric absorption in dil



## Verify Difference in Performance on Perturbed Pairs
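A simple check for this step is to score each (query, original document) and (query, perturbed document) pair and look at the per-pair difference. The sketch below assumes the scores are already available as 1-D tensors aligned by pair; how they are obtained from `model` is not shown, and `score_deltas` is a hypothetical helper with toy numbers.

```python
import torch


def score_deltas(original_scores: torch.Tensor, perturbed_scores: torch.Tensor) -> dict:
    """Summarize how a perturbation changes relevance scores (hypothetical helper)."""
    delta = perturbed_scores - original_scores
    return {
        "mean_delta": delta.mean().item(),
        # Fraction of pairs whose score went up after the perturbation
        "frac_increased": (delta > 0).float().mean().item(),
        "delta": delta,
    }


# Toy scores for three (query, document) pairs
original = torch.tensor([80.0, 75.5, 90.0])
perturbed = torch.tensor([81.0, 75.0, 92.0])
summary = score_deltas(original, perturbed)
print(summary["mean_delta"], summary["frac_increased"])
```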

## Activation Patching
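Activation patching runs the model on the original document while overwriting one component's activation (a residual-stream position, an attention head, an MLP output, and so on) with its value from the perturbed-document run, then measures how much of the perturbation's score change that single component reproduces. A common way to summarize this is a normalized score, sketched below; the normalization MechIR actually uses is not shown in this section, so treat `patching_effect` and its convention (0 = no effect, 1 = reproduces the full perturbed score) as assumptions.

```python
import torch


def patching_effect(patched: torch.Tensor, baseline: torch.Tensor, perturbed: torch.Tensor) -> torch.Tensor:
    """Normalized activation-patching effect (hypothetical convention).

    baseline:  score of the query with the original document (no patching)
    perturbed: score of the query with the perturbed document
    patched:   score of the original-document run with one component's
               activation replaced by its value from the perturbed run
    Returns 0 where patching changes nothing and 1 where it reproduces the
    full perturbed score; undefined (inf/nan) when perturbed == baseline.
    """
    return (patched - baseline) / (perturbed - baseline)


# Toy example: one (query, document) pair, three patched components
baseline = torch.tensor(80.0)
perturbed = torch.tensor(84.0)
patched = torch.tensor([80.0, 82.0, 84.0])  # e.g. one score per patched layer
print(patching_effect(patched, baseline, perturbed))
```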

### Blocks

#### Residual Stream

#### Attention Layers

#### MLP Layers

### Individual Attention Heads

### Individual Tokens (by position)

## Analyzing Attention Patterns