In [1]:
import sys,os,argparse
sys.path.append('atlas/')

# Imports
import pandas as pd
import duckdb
import pickle
import os
import time
from collections import defaultdict
import numpy as np
import torch
import torch.cuda
import torch.distributed as dist
from src import dist_utils, slurm, util
from src.index_io import load_or_initialize_index, save_embeddings_and_index
from src.model_io import create_checkpoint_directories, load_or_initialize_atlas_model
from src.options import get_options
from src.tasks import get_task
from evaluate import run_retrieval_only, _get_eval_data_iterator
import types

os.environ["TOKENIZERS_PARALLELISM"] = "true"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Paths
DATA_DIR = './atlas_data'
MODEL_PATH = '/home/tkolb/data/models/atlas/base'
ATLAS_INDEX_PATH = '/home/tkolb/data/indices/atlas/wiki/base'
EVAL_PATH = '/home/tkolb/data/datasets/queries/qs.jsonl'
FAISS_INDEX_PATH = '/home/tkolb/data/faiss_index.index'
PASSAGES_PATH = f'/home/tkolb/data/wiki_passages.pkl'
DUCKDB_FAISS_EXT = 'faiss/build/release/repository/v1.0.0/linux_amd64/faiss.duckdb_extension'

In [7]:
EXPERIMENT_NAME='my-nq-64-shot-example'
# READER_MODEL = 'google/t5-base-lm-adapt '
READER_MODEL = 't5-small'
port=15000

args = f'''--name 'my-nq-64-shot-example-evaluation' 
--generation_max_length 16 
--gold_score_mode "pdist" 
--precision fp32 
--reader_model_type {READER_MODEL} 
--text_maxlength 512 
--model_path {MODEL_PATH} 
--eval_data {EVAL_PATH} 
--per_gpu_batch_size 1 
--n_context 40 --retriever_n_context 40 
--checkpoint_dir {MODEL_PATH} 
--main_port {port} 
--index_mode "flat" 
--task "qa" 
--load_index_path {FAISS_INDEX_PATH} 
--write_results'''

args = ['/home/tkolb/RAG-Demo/atlas/evaluate.py'] + args.replace('\'', '').replace('\"', '').replace('\n', '').split(' ')

In [8]:
# Load options
sys.argv = args
options = get_options()
opt = options.parse()
torch.manual_seed(opt.seed)
slurm.init_distributed_mode(opt)
slurm.init_signal_handler()

# Load ATLAS
model, _, _, _, _, opt, step = load_or_initialize_atlas_model(opt, eval_only=True)
model.eval()
unwrapped_model = util.get_unwrapped_model_if_wrapped(model)

# Create data iterator from QA jsonl file
reader_tokenizer = unwrapped_model.reader_tokenizer
task = get_task(opt, reader_tokenizer)
data_iterator = _get_eval_data_iterator(opt, opt.eval_data[0], task)

# Get query encoding
# TODO make it possible to do it for one query
batch = data_iterator[0]
query = batch.get("query", [""])
answers = batch.get("target", [""])
batch_metadata = batch.get("metadata")
query_enc, labels, decoder_input_ids = unwrapped_model.tokenize(query, answers, None)
query_ids_retriever = query_enc["input_ids"].cuda()
query_mask_retriever = query_enc["attention_mask"].cuda()

# Get query embedding from model
unwrapped_model.retriever.eval()
query_emb = unwrapped_model.retriever(query_ids_retriever, query_mask_retriever, is_passages=False)

Downloading: 100%|██████████| 1.18k/1.18k [00:00<00:00, 4.67MB/s]
Downloading: 100%|██████████| 231M/231M [00:04<00:00, 58.3MB/s] 
Downloading: 100%|██████████| 2.27k/2.27k [00:00<00:00, 10.1MB/s]
Downloading: 100%|██████████| 773k/773k [00:00<00:00, 2.37MB/s]
Downloading: 100%|██████████| 1.32M/1.32M [00:00<00:00, 5.43MB/s]
Some weights of the model checkpoint at facebook/contriever were not used when initializing Contriever: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing Contriever from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Contriever from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


RuntimeError: Error(s) in loading state_dict for Atlas:
	Missing key(s) in state_dict: "reader.encoder.block.0.layer.1.DenseReluDense.wi.weight", "reader.encoder.block.1.layer.1.DenseReluDense.wi.weight", "reader.encoder.block.2.layer.1.DenseReluDense.wi.weight", "reader.encoder.block.3.layer.1.DenseReluDense.wi.weight", "reader.encoder.block.4.layer.1.DenseReluDense.wi.weight", "reader.encoder.block.5.layer.1.DenseReluDense.wi.weight", "reader.decoder.block.0.layer.2.DenseReluDense.wi.weight", "reader.decoder.block.1.layer.2.DenseReluDense.wi.weight", "reader.decoder.block.2.layer.2.DenseReluDense.wi.weight", "reader.decoder.block.3.layer.2.DenseReluDense.wi.weight", "reader.decoder.block.4.layer.2.DenseReluDense.wi.weight", "reader.decoder.block.5.layer.2.DenseReluDense.wi.weight". 
	Unexpected key(s) in state_dict: "reader.encoder.block.6.layer.0.SelfAttention.q.weight", "reader.encoder.block.6.layer.0.SelfAttention.k.weight", "reader.encoder.block.6.layer.0.SelfAttention.v.weight", "reader.encoder.block.6.layer.0.SelfAttention.o.weight", "reader.encoder.block.6.layer.0.layer_norm.weight", "reader.encoder.block.6.layer.1.DenseReluDense.wi_0.weight", "reader.encoder.block.6.layer.1.DenseReluDense.wi_1.weight", "reader.encoder.block.6.layer.1.DenseReluDense.wo.weight", "reader.encoder.block.6.layer.1.layer_norm.weight", "reader.encoder.block.7.layer.0.SelfAttention.q.weight", "reader.encoder.block.7.layer.0.SelfAttention.k.weight", "reader.encoder.block.7.layer.0.SelfAttention.v.weight", "reader.encoder.block.7.layer.0.SelfAttention.o.weight", "reader.encoder.block.7.layer.0.layer_norm.weight", "reader.encoder.block.7.layer.1.DenseReluDense.wi_0.weight", "reader.encoder.block.7.layer.1.DenseReluDense.wi_1.weight", "reader.encoder.block.7.layer.1.DenseReluDense.wo.weight", "reader.encoder.block.7.layer.1.layer_norm.weight", "reader.encoder.block.8.layer.0.SelfAttention.q.weight", "reader.encoder.block.8.layer.0.SelfAttention.k.weight", "reader.encoder.block.8.layer.0.SelfAttention.v.weight", "reader.encoder.block.8.layer.0.SelfAttention.o.weight", "reader.encoder.block.8.layer.0.layer_norm.weight", "reader.encoder.block.8.layer.1.DenseReluDense.wi_0.weight", "reader.encoder.block.8.layer.1.DenseReluDense.wi_1.weight", "reader.encoder.block.8.layer.1.DenseReluDense.wo.weight", "reader.encoder.block.8.layer.1.layer_norm.weight", "reader.encoder.block.9.layer.0.SelfAttention.q.weight", "reader.encoder.block.9.layer.0.SelfAttention.k.weight", "reader.encoder.block.9.layer.0.SelfAttention.v.weight", "reader.encoder.block.9.layer.0.SelfAttention.o.weight", "reader.encoder.block.9.layer.0.layer_norm.weight", "reader.encoder.block.9.layer.1.DenseReluDense.wi_0.weight", "reader.encoder.block.9.layer.1.DenseReluDense.wi_1.weight", "reader.encoder.block.9.layer.1.DenseReluDense.wo.weight", "reader.encoder.block.9.layer.1.layer_norm.weight", "reader.encoder.block.10.layer.0.SelfAttention.q.weight", "reader.encoder.block.10.layer.0.SelfAttention.k.weight", "reader.encoder.block.10.layer.0.SelfAttention.v.weight", "reader.encoder.block.10.layer.0.SelfAttention.o.weight", "reader.encoder.block.10.layer.0.layer_norm.weight", "reader.encoder.block.10.layer.1.DenseReluDense.wi_0.weight", "reader.encoder.block.10.layer.1.DenseReluDense.wi_1.weight", "reader.encoder.block.10.layer.1.DenseReluDense.wo.weight", "reader.encoder.block.10.layer.1.layer_norm.weight", "reader.encoder.block.11.layer.0.SelfAttention.q.weight", "reader.encoder.block.11.layer.0.SelfAttention.k.weight", "reader.encoder.block.11.layer.0.SelfAttention.v.weight", "reader.encoder.block.11.layer.0.SelfAttention.o.weight", "reader.encoder.block.11.layer.0.layer_norm.weight", "reader.encoder.block.11.layer.1.DenseReluDense.wi_0.weight", "reader.encoder.block.11.layer.1.DenseReluDense.wi_1.weight", "reader.encoder.block.11.layer.1.DenseReluDense.wo.weight", "reader.encoder.block.11.layer.1.layer_norm.weight", "reader.encoder.block.0.layer.1.DenseReluDense.wi_0.weight", "reader.encoder.block.0.layer.1.DenseReluDense.wi_1.weight", "reader.encoder.block.1.layer.1.DenseReluDense.wi_0.weight", "reader.encoder.block.1.layer.1.DenseReluDense.wi_1.weight", "reader.encoder.block.2.layer.1.DenseReluDense.wi_0.weight", "reader.encoder.block.2.layer.1.DenseReluDense.wi_1.weight", "reader.encoder.block.3.layer.1.DenseReluDense.wi_0.weight", "reader.encoder.block.3.layer.1.DenseReluDense.wi_1.weight", "reader.encoder.block.4.layer.1.DenseReluDense.wi_0.weight", "reader.encoder.block.4.layer.1.DenseReluDense.wi_1.weight", "reader.encoder.block.5.layer.1.DenseReluDense.wi_0.weight", "reader.encoder.block.5.layer.1.DenseReluDense.wi_1.weight", "reader.decoder.block.6.layer.0.SelfAttention.q.weight", "reader.decoder.block.6.layer.0.SelfAttention.k.weight", "reader.decoder.block.6.layer.0.SelfAttention.v.weight", "reader.decoder.block.6.layer.0.SelfAttention.o.weight", "reader.decoder.block.6.layer.0.layer_norm.weight", "reader.decoder.block.6.layer.1.EncDecAttention.q.weight", "reader.decoder.block.6.layer.1.EncDecAttention.k.weight", "reader.decoder.block.6.layer.1.EncDecAttention.v.weight", "reader.decoder.block.6.layer.1.EncDecAttention.o.weight", "reader.decoder.block.6.layer.1.layer_norm.weight", "reader.decoder.block.6.layer.2.DenseReluDense.wi_0.weight", "reader.decoder.block.6.layer.2.DenseReluDense.wi_1.weight", "reader.decoder.block.6.layer.2.DenseReluDense.wo.weight", "reader.decoder.block.6.layer.2.layer_norm.weight", "reader.decoder.block.7.layer.0.SelfAttention.q.weight", "reader.decoder.block.7.layer.0.SelfAttention.k.weight", "reader.decoder.block.7.layer.0.SelfAttention.v.weight", "reader.decoder.block.7.layer.0.SelfAttention.o.weight", "reader.decoder.block.7.layer.0.layer_norm.weight", "reader.decoder.block.7.layer.1.EncDecAttention.q.weight", "reader.decoder.block.7.layer.1.EncDecAttention.k.weight", "reader.decoder.block.7.layer.1.EncDecAttention.v.weight", "reader.decoder.block.7.layer.1.EncDecAttention.o.weight", "reader.decoder.block.7.layer.1.layer_norm.weight", "reader.decoder.block.7.layer.2.DenseReluDense.wi_0.weight", "reader.decoder.block.7.layer.2.DenseReluDense.wi_1.weight", "reader.decoder.block.7.layer.2.DenseReluDense.wo.weight", "reader.decoder.block.7.layer.2.layer_norm.weight", "reader.decoder.block.8.layer.0.SelfAttention.q.weight", "reader.decoder.block.8.layer.0.SelfAttention.k.weight", "reader.decoder.block.8.layer.0.SelfAttention.v.weight", "reader.decoder.block.8.layer.0.SelfAttention.o.weight", "reader.decoder.block.8.layer.0.layer_norm.weight", "reader.decoder.block.8.layer.1.EncDecAttention.q.weight", "reader.decoder.block.8.layer.1.EncDecAttention.k.weight", "reader.decoder.block.8.layer.1.EncDecAttention.v.weight", "reader.decoder.block.8.layer.1.EncDecAttention.o.weight", "reader.decoder.block.8.layer.1.layer_norm.weight", "reader.decoder.block.8.layer.2.DenseReluDense.wi_0.weight", "reader.decoder.block.8.layer.2.DenseReluDense.wi_1.weight", "reader.decoder.block.8.layer.2.DenseReluDense.wo.weight", "reader.decoder.block.8.layer.2.layer_norm.weight", "reader.decoder.block.9.layer.0.SelfAttention.q.weight", "reader.decoder.block.9.layer.0.SelfAttention.k.weight", "reader.decoder.block.9.layer.0.SelfAttention.v.weight", "reader.decoder.block.9.layer.0.SelfAttention.o.weight", "reader.decoder.block.9.layer.0.layer_norm.weight", "reader.decoder.block.9.layer.1.EncDecAttention.q.weight", "reader.decoder.block.9.layer.1.EncDecAttention.k.weight", "reader.decoder.block.9.layer.1.EncDecAttention.v.weight", "reader.decoder.block.9.layer.1.EncDecAttention.o.weight", "reader.decoder.block.9.layer.1.layer_norm.weight", "reader.decoder.block.9.layer.2.DenseReluDense.wi_0.weight", "reader.decoder.block.9.layer.2.DenseReluDense.wi_1.weight", "reader.decoder.block.9.layer.2.DenseReluDense.wo.weight", "reader.decoder.block.9.layer.2.layer_norm.weight", "reader.decoder.block.10.layer.0.SelfAttention.q.weight", "reader.decoder.block.10.layer.0.SelfAttention.k.weight", "reader.decoder.block.10.layer.0.SelfAttention.v.weight", "reader.decoder.block.10.layer.0.SelfAttention.o.weight", "reader.decoder.block.10.layer.0.layer_norm.weight", "reader.decoder.block.10.layer.1.EncDecAttention.q.weight", "reader.decoder.block.10.layer.1.EncDecAttention.k.weight", "reader.decoder.block.10.layer.1.EncDecAttention.v.weight", "reader.decoder.block.10.layer.1.EncDecAttention.o.weight", "reader.decoder.block.10.layer.1.layer_norm.weight", "reader.decoder.block.10.layer.2.DenseReluDense.wi_0.weight", "reader.decoder.block.10.layer.2.DenseReluDense.wi_1.weight", "reader.decoder.block.10.layer.2.DenseReluDense.wo.weight", "reader.decoder.block.10.layer.2.layer_norm.weight", "reader.decoder.block.11.layer.0.SelfAttention.q.weight", "reader.decoder.block.11.layer.0.SelfAttention.k.weight", "reader.decoder.block.11.layer.0.SelfAttention.v.weight", "reader.decoder.block.11.layer.0.SelfAttention.o.weight", "reader.decoder.block.11.layer.0.layer_norm.weight", "reader.decoder.block.11.layer.1.EncDecAttention.q.weight", "reader.decoder.block.11.layer.1.EncDecAttention.k.weight", "reader.decoder.block.11.layer.1.EncDecAttention.v.weight", "reader.decoder.block.11.layer.1.EncDecAttention.o.weight", "reader.decoder.block.11.layer.1.layer_norm.weight", "reader.decoder.block.11.layer.2.DenseReluDense.wi_0.weight", "reader.decoder.block.11.layer.2.DenseReluDense.wi_1.weight", "reader.decoder.block.11.layer.2.DenseReluDense.wo.weight", "reader.decoder.block.11.layer.2.layer_norm.weight", "reader.decoder.block.0.layer.2.DenseReluDense.wi_0.weight", "reader.decoder.block.0.layer.2.DenseReluDense.wi_1.weight", "reader.decoder.block.1.layer.2.DenseReluDense.wi_0.weight", "reader.decoder.block.1.layer.2.DenseReluDense.wi_1.weight", "reader.decoder.block.2.layer.2.DenseReluDense.wi_0.weight", "reader.decoder.block.2.layer.2.DenseReluDense.wi_1.weight", "reader.decoder.block.3.layer.2.DenseReluDense.wi_0.weight", "reader.decoder.block.3.layer.2.DenseReluDense.wi_1.weight", "reader.decoder.block.4.layer.2.DenseReluDense.wi_0.weight", "reader.decoder.block.4.layer.2.DenseReluDense.wi_1.weight", "reader.decoder.block.5.layer.2.DenseReluDense.wi_0.weight", "reader.decoder.block.5.layer.2.DenseReluDense.wi_1.weight". 
	size mismatch for reader.shared.weight: copying a param with shape torch.Size([32128, 768]) from checkpoint, the shape in current model is torch.Size([32128, 512]).
	size mismatch for reader.encoder.embed_tokens.weight: copying a param with shape torch.Size([32128, 768]) from checkpoint, the shape in current model is torch.Size([32128, 512]).
	size mismatch for reader.encoder.block.0.layer.0.SelfAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.0.layer.0.SelfAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.0.layer.0.SelfAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.0.layer.0.SelfAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight: copying a param with shape torch.Size([32, 12]) from checkpoint, the shape in current model is torch.Size([32, 8]).
	size mismatch for reader.encoder.block.0.layer.0.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.encoder.block.0.layer.1.DenseReluDense.wo.weight: copying a param with shape torch.Size([768, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for reader.encoder.block.0.layer.1.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.encoder.block.1.layer.0.SelfAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.1.layer.0.SelfAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.1.layer.0.SelfAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.1.layer.0.SelfAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.1.layer.0.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.encoder.block.1.layer.1.DenseReluDense.wo.weight: copying a param with shape torch.Size([768, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for reader.encoder.block.1.layer.1.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.encoder.block.2.layer.0.SelfAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.2.layer.0.SelfAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.2.layer.0.SelfAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.2.layer.0.SelfAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.2.layer.0.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.encoder.block.2.layer.1.DenseReluDense.wo.weight: copying a param with shape torch.Size([768, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for reader.encoder.block.2.layer.1.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.encoder.block.3.layer.0.SelfAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.3.layer.0.SelfAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.3.layer.0.SelfAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.3.layer.0.SelfAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.3.layer.0.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.encoder.block.3.layer.1.DenseReluDense.wo.weight: copying a param with shape torch.Size([768, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for reader.encoder.block.3.layer.1.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.encoder.block.4.layer.0.SelfAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.4.layer.0.SelfAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.4.layer.0.SelfAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.4.layer.0.SelfAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.4.layer.0.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.encoder.block.4.layer.1.DenseReluDense.wo.weight: copying a param with shape torch.Size([768, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for reader.encoder.block.4.layer.1.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.encoder.block.5.layer.0.SelfAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.5.layer.0.SelfAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.5.layer.0.SelfAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.5.layer.0.SelfAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.encoder.block.5.layer.0.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.encoder.block.5.layer.1.DenseReluDense.wo.weight: copying a param with shape torch.Size([768, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for reader.encoder.block.5.layer.1.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.encoder.final_layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.embed_tokens.weight: copying a param with shape torch.Size([32128, 768]) from checkpoint, the shape in current model is torch.Size([32128, 512]).
	size mismatch for reader.decoder.block.0.layer.0.SelfAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.0.layer.0.SelfAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.0.layer.0.SelfAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.0.layer.0.SelfAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight: copying a param with shape torch.Size([32, 12]) from checkpoint, the shape in current model is torch.Size([32, 8]).
	size mismatch for reader.decoder.block.0.layer.0.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.0.layer.1.EncDecAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.0.layer.1.EncDecAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.0.layer.1.EncDecAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.0.layer.1.EncDecAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.0.layer.1.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.0.layer.2.DenseReluDense.wo.weight: copying a param with shape torch.Size([768, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for reader.decoder.block.0.layer.2.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.1.layer.0.SelfAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.1.layer.0.SelfAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.1.layer.0.SelfAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.1.layer.0.SelfAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.1.layer.0.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.1.layer.1.EncDecAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.1.layer.1.EncDecAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.1.layer.1.EncDecAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.1.layer.1.EncDecAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.1.layer.1.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.1.layer.2.DenseReluDense.wo.weight: copying a param with shape torch.Size([768, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for reader.decoder.block.1.layer.2.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.2.layer.0.SelfAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.2.layer.0.SelfAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.2.layer.0.SelfAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.2.layer.0.SelfAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.2.layer.0.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.2.layer.1.EncDecAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.2.layer.1.EncDecAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.2.layer.1.EncDecAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.2.layer.1.EncDecAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.2.layer.1.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.2.layer.2.DenseReluDense.wo.weight: copying a param with shape torch.Size([768, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for reader.decoder.block.2.layer.2.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.3.layer.0.SelfAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.3.layer.0.SelfAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.3.layer.0.SelfAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.3.layer.0.SelfAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.3.layer.0.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.3.layer.1.EncDecAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.3.layer.1.EncDecAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.3.layer.1.EncDecAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.3.layer.1.EncDecAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.3.layer.1.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.3.layer.2.DenseReluDense.wo.weight: copying a param with shape torch.Size([768, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for reader.decoder.block.3.layer.2.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.4.layer.0.SelfAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.4.layer.0.SelfAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.4.layer.0.SelfAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.4.layer.0.SelfAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.4.layer.0.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.4.layer.1.EncDecAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.4.layer.1.EncDecAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.4.layer.1.EncDecAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.4.layer.1.EncDecAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.4.layer.1.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.4.layer.2.DenseReluDense.wo.weight: copying a param with shape torch.Size([768, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for reader.decoder.block.4.layer.2.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.5.layer.0.SelfAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.5.layer.0.SelfAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.5.layer.0.SelfAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.5.layer.0.SelfAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.5.layer.0.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.5.layer.1.EncDecAttention.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.5.layer.1.EncDecAttention.k.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.5.layer.1.EncDecAttention.v.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.5.layer.1.EncDecAttention.o.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for reader.decoder.block.5.layer.1.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.block.5.layer.2.DenseReluDense.wo.weight: copying a param with shape torch.Size([768, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for reader.decoder.block.5.layer.2.layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.decoder.final_layer_norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for reader.lm_head.weight: copying a param with shape torch.Size([32128, 768]) from checkpoint, the shape in current model is torch.Size([32128, 512]).

In [5]:
query, answers, labels.shape, decoder_input_ids.shape

(['question: who got the first nobel prize in physics answer: <extra_id_0>'],
 ['<extra_id_0> Wilhelm Conrad Röntgen'],
 torch.Size([1, 512]),
 torch.Size([1, 512]))

In [6]:
query_emb.shape

torch.Size([1, 768])

In [20]:
# DuckDB session
con = duckdb.connect(config = {'allow_unsigned_extensions': 'true'})

# Load FAISS extension
con.sql(f"LOAD '{DUCKDB_FAISS_EXT}'")

# Load Wikipedia index (21GB)
con.sql(f"CALL faiss_load('index', '{FAISS_INDEX_PATH}');")  

┌─────────┐
│ Success │
│ boolean │
├─────────┤
│ 0 rows  │
└─────────┘

In [27]:
with open(PASSAGES_PATH, 'rb') as f:
    passages = pickle.load(f)

In [25]:
query_df = pd.DataFrame({"query": query_emb.tolist()})
con.sql("SELECT * FROM query_df")

┌──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
│                                                        query                                                         │
│                                                       double[]                                                       │
├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ [-0.14450009167194366, -0.019435692578554153, 0.041694700717926025, -0.08630520105361938, -0.03823167458176613, -0…  │
└──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘

In [26]:
topk_df = con.sql(f"SELECT UNNEST(faiss_search('index', 40, query)) FROM query_df").to_df()
top_k = [row[0] for _, row in topk_df.iterrows()]

  top_k = [row[0] for _, row in topk_df.iterrows()]


In [31]:
top_k_passages = []
for dict_item in top_k:
    id = int(dict_item['label'])
    top_k_passages.append(passages[id])

[{'id': '27997088',
  'title': '2011–12 Coppa Italia: =First round',
  'section': '=First round',
  'text': ' ='},
 {'id': '27997092',
  'title': '2011–12 Coppa Italia: =First round',
  'section': '=First round',
  'text': ' ='},
 {'id': '27997089',
  'title': '2011–12 Coppa Italia: =Second round',
  'section': '=Second round',
  'text': ' ='},
 {'id': '27997093',
  'title': '2011–12 Coppa Italia: =Second round',
  'section': '=Second round',
  'text': ' ='},
 {'id': '20207299',
  'title': '2010–11 Coppa Italia: =First round',
  'section': '=First round',
  'text': ' ='},
 {'id': '8972033',
  'title': '2009–10 Coppa Italia: =First round',
  'section': '=First round',
  'text': ' ='},
 {'id': '8972037',
  'title': '2009–10 Coppa Italia: =First round',
  'section': '=First round',
  'text': ' ='},
 {'id': '1612355',
  'title': 'List of Swedish scientists',
  'section': '',
  'text': ' This is a list of Swedish scientists.'},
 {'id': '22764196',
  'title': '1994 MTV Video Music Awards: Li

In [7]:
# Search top k relevant documents for query embedding
# Return top k document passages + scores
def search_relevant_docs(query_emb, index_path, passages_path, k=5):
    # DuckDB session
    con = duckdb.connect(config = {'allow_unsigned_extensions': 'true'})

    # Load FAISS extension
    con.sql(f"LOAD '{DUCKDB_FAISS_EXT}'")

    # Load Wikipedia index (21GB)
    con.sql(f"CALL faiss_load('index', '{FAISS_INDEX_PATH}');")
    
    # Load Wikipedia passages
    with open(passages_path, 'rb') as f:
        passages = pickle.load(f)
    
    # Create query df and read into duckdb
    query_df = pd.DataFrame({"query": query_emb.tolist()})
    
    # top k search
    topk_df = con.sql(f"SELECT UNNEST(faiss_search('index', {k}, query)) FROM query_df").to_df()
    top_k = [row[0] for _, row in topk_df.iterrows()]
    
    # Extend top k dictionaries with corresponding passages
    top_k_passages = []
    for dict_item in top_k:
        id = int(dict_item['label'])
        top_k_passages.append(passages[id])
        
    return top_k_passages

In [8]:
top_k_passages = search_relevant_docs(query_emb, FAISS_INDEX_PATH, PASSAGES_PATH, 40)

  top_k = [row[0] for _, row in topk_df.iterrows()]


In [9]:
# Move reader to cuda:1
READER_DEVICE = 'cuda:1'
unwrapped_model.reader.to(READER_DEVICE)

# Create tokens for reader
reader_tokens, retriever_tokens = unwrapped_model.tokenize_passages(query, [top_k_passages])
reader_ids = reader_tokens["input_ids"].to(READER_DEVICE)
reader_mask = reader_tokens["attention_mask"].bool().to(READER_DEVICE)
n_context_training = min(unwrapped_model.opt.n_context, reader_ids.size(1))

# Get reader config
cfg = unwrapped_model.reader.encoder.config
cfg.bsz = reader_ids.size(0)
cfg.n_context = n_context_training

# Reshape reader ids
reader_ids_training = reader_ids[:, :n_context_training].contiguous()
reader_mask_training = reader_mask[:, :n_context_training].contiguous()
reader_ids_training = reader_ids_training.view(reader_ids.size(0), -1)
reader_mask_training = reader_mask_training.view(reader_mask.size(0), -1)

# Inference with reader ids
reader_output = unwrapped_model.reader(
    input_ids=reader_ids_training,
    attention_mask=reader_mask_training,
    decoder_input_ids=decoder_input_ids,
    labels=labels,
    use_cache=False,
)

RuntimeError: CUDA out of memory. Tried to allocate 60.00 MiB (GPU 1; 15.77 GiB total capacity; 14.35 GiB already allocated; 5.12 MiB free; 14.41 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

: 