# Setup

In [None]:
# Selection of GPUs.
# Pin this notebook to a single GPU. CUDA_DEVICE_ORDER=PCI_BUS_ID makes the
# device index below follow PCI bus order (the numbering nvidia-smi shows).
# Both variables must be set before CUDA is initialised by torch.
import os

os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="1",
)

In [None]:
import argparse
import os
import logging
from typing import List, Tuple, Dict, Iterator
import torch
from dpr.options import *

In [None]:
# Send log records through the root logger at INFO level with exactly one
# console (stderr) handler. Existing handlers are dropped first so that
# re-running this cell does not print every message several times.
logger = logging.getLogger()
logger.handlers.clear()  # no-op when the root logger has no handlers yet
logger.setLevel(logging.INFO)

console = logging.StreamHandler()
logger.addHandler(console)

In [None]:
from dpr.options import init_base_args

# Build the argument namespace that the following cells read their settings
# from (model file, encoder type, device/GPU options, index options, ...).
parser = argparse.ArgumentParser()
args = init_base_args(parser)

# Main


In [None]:
import collections
from dpr.utils.model_utils import load_states_from_checkpoint, setup_for_distributed_mode, get_model_obj
from dpr.models import init_biencoder_components
# Record layout of an intermediate dump of a model's internal state
# (weights plus optimizer/scheduler state and training position).
# NOTE(review): not referenced in the visible cells; load_states_from_checkpoint
# presumably uses its own definition from dpr.utils.model_utils.
CheckpointState = collections.namedtuple(
    "CheckpointState",
    "model_dict optimizer_dict scheduler_dict offset epoch encoder_params",
)

In [None]:
# Load the trained bi-encoder checkpoint named by --model_file.
saved_state = load_states_from_checkpoint(args.model_file)

# Apply the encoder hyper-parameters stored in the checkpoint to `args`.
# DPR's helper is set_encoder_params_from_state(state, args). It expects the
# checkpoint's `encoder_params` dict and the namespace to update in place.
# Calling it with the whole CheckpointState and no `args` fails with that
# signature.
# NOTE(review): confirm this fork's dpr.options keeps the upstream signature.
set_encoder_params_from_state(saved_state.encoder_params, args)


In [None]:
# Build the bi-encoder for inference and keep only its question tower, which
# is what encodes queries at retrieval time.
tensorizer, biencoder, _ = init_biencoder_components(
    args.encoder_model_type, args, inference_only=True
)
encoder = biencoder.question_model
del biencoder  # keep no reference to the rest of the bi-encoder

# Place the question encoder on the configured device / GPU count / rank,
# optionally in fp16. No optimizer is needed for inference.
encoder, _ = setup_for_distributed_mode(
    encoder, None, args.device, args.n_gpu, args.local_rank, args.fp16
)
encoder.eval()

## Loading Model Weights

In [28]:
from dpr.indexer.faiss_indexers import DenseHNSWFlatIndexer, DenseFlatIndexer
from dense_retriever import DenseRetriever

In [None]:
# Load weights from the model file into the question encoder.
# get_model_obj presumably unwraps the distributed/fp16 wrapper applied by
# setup_for_distributed_mode, so the state dict is loaded into the bare module.
model_to_load = get_model_obj(encoder)
logger.info('Reading saved model from %s', args.model_file)

# The checkpoint holds the whole bi-encoder. Keep only the question tower's
# entries and strip their prefix so the keys match the bare encoder's names.
encoder_prefix = 'question_model.'
prefix_len = len(encoder_prefix)
question_encoder_state = {
    key[prefix_len:]: value
    for key, value in saved_state.model_dict.items()
    if key.startswith(encoder_prefix)
}
if not question_encoder_state:
    logger.warning("No '%s*' keys found in %s; the encoder keeps its initial weights",
                   encoder_prefix, args.model_file)

# strict=False tolerates key mismatches, so surface them instead of hiding them.
load_result = model_to_load.load_state_dict(question_encoder_state, strict=False)
missing_keys = getattr(load_result, 'missing_keys', ())
unexpected_keys = getattr(load_result, 'unexpected_keys', ())
if missing_keys or unexpected_keys:
    logger.warning('Question encoder state mismatch: missing=%s unexpected=%s',
                   missing_keys, unexpected_keys)


In [None]:
# Build the FAISS-backed index sized to the encoder's output dimension, then
# combine encoder, tensorizer and index into a retriever.
vector_size = model_to_load.get_out_size()
logger.info('Encoder vector_size=%d', vector_size)

# Buffer size for adding encoded passages to the index (presumably consumed
# by the indexing step later in the notebook). It must be bound on both index
# paths; it was previously misspelled `ndex_buffer_sz`, leaving it undefined
# for the flat index.
index_buffer_sz = args.index_buffer
if args.hnsw_index:
    index = DenseHNSWFlatIndexer(vector_size)
    index_buffer_sz = -1  # encode all at once
else:
    index = DenseFlatIndexer(vector_size)

retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)

## Index passages

In [27]:
import glob

In [None]:
# Collect the encoded context (passage) files matching --encoded_ctx_file,
# optionally dropping some languages.
# glob's result order depends on the file system. Sorting makes the indexing
# order and the derived `index_path` (taken from the first file) reproducible.
ctx_files_pattern = args.encoded_ctx_file
input_paths = sorted(glob.glob(ctx_files_pattern))

if args.remove_lang is not None:
    # A file is dropped when any code in --remove_lang occurs in its name.
    # NOTE(review): this is a plain substring test, so a short code such as
    # "en" also matches e.g. "ben"; confirm against the file naming scheme.
    final_fps = [
        path for path in input_paths
        if not any(lang in os.path.basename(path) for lang in args.remove_lang)
    ]
    input_paths = final_fps
    logger.info('lang %s are removed from retrieval target; %d files remain',
                args.remove_lang, len(input_paths))

if not input_paths:
    raise ValueError(
        'No encoded context files left for pattern {!r} (remove_lang={})'.format(
            ctx_files_pattern, args.remove_lang))

# Computed for every run, not only when languages are removed.
# It is the first file name with its last "_"-separated component stripped
# (presumably a shard suffix).
index_path = "_".join(input_paths[0].split("_")[:-1])

    